From fc185200c446e567d3097fb4272beb8bb372e048 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Mon, 9 Jun 2025 11:54:10 -0700 Subject: [PATCH 01/39] Use public API instead of removed private function in `te_llama.py` (#1856) Use public API instead of removed private function * replaced use of _load_state_dict_into_model with model.load_state_dict because the private function _load_state_dict_into_model was removed in https://github.com/huggingface/transformers/pull/36335 Signed-off-by: Jan Bielak --- docs/examples/te_llama/te_llama.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index b6ec290b03..8297ac6d2e 100644 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -19,7 +19,7 @@ LlamaRMSNorm, LlamaConfig, ) -from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model +from transformers.modeling_utils import _add_variant, load_state_dict from transformers.utils import WEIGHTS_INDEX_NAME from transformers.utils.hub import get_checkpoint_shard_files @@ -148,8 +148,8 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k state_dict = load_state_dict(shard_file) # replace_params copies parameters relevant only to TransformerEngine replace_params(state_dict, vanilla_model.state_dict(), config) - # _load_state_dict_into_model copies parameters other than those in TransformerEngine - _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="") + # load_state_dict copies parameters other than those in TransformerEngine + vanilla_model.load_state_dict(state_dict, strict=False) # Force mem release. 
Taken from huggingface code del state_dict From ddcda1ffd2e1d49e60aa84fadc08628697e70c4a Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 9 Jun 2025 14:45:25 -0700 Subject: [PATCH 02/39] Manage dependencies and add missing `einops` req (#1859) * Manage deps and add einops Signed-off-by: Kirthi Shankar Sivamani * Update build.yml Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- .github/workflows/build.yml | 2 +- build_tools/jax.py | 11 ++++++++++- build_tools/pytorch.py | 16 ++++++++++++++++ setup.py | 18 ++++++++---------- transformer_engine/jax/setup.py | 6 +++--- transformer_engine/pytorch/setup.py | 6 +++--- 6 files changed, 41 insertions(+), 18 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 58cb6f6b1e..78ef679f9d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -87,7 +87,7 @@ jobs: with: submodules: recursive - name: 'Build' - run: pip install --no-build-isolation . -v + run: pip install --no-build-isolation . -v --no-deps env: NVTE_FRAMEWORK: all MAX_JOBS: 1 diff --git a/build_tools/jax.py b/build_tools/jax.py index db4dba1c60..2b9ad1a30f 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -4,7 +4,6 @@ """JAX related extensions.""" import os -import shutil from pathlib import Path import setuptools @@ -13,6 +12,16 @@ from typing import List +def install_requirements() -> List[str]: + """Install dependencies for TE/JAX extensions.""" + return ["jax[cuda12]", "flax>=0.7.1"] + + +def test_requirements() -> List[str]: + """Test dependencies for TE/JAX extensions.""" + return ["numpy"] + + def xla_path() -> str: """XLA root path lookup. 
Throws FileNotFoundError if XLA source is not found.""" diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index e4716cca10..a40db3bac4 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -9,6 +9,22 @@ import setuptools from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs, debug_build_enabled +from typing import List + + +def install_requirements() -> List[str]: + """Install dependencies for TE/PyTorch extensions.""" + reqs = ["torch>=2.1", "einops"] + reqs.append( + "nvdlfw-inspect @" + " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect" + ) + return reqs + + +def test_requirements() -> List[str]: + """Test dependencies for TE/PyTorch extensions.""" + return ["numpy", "torchvision", "transformers"] def setup_pytorch_extension( diff --git a/setup.py b/setup.py index ddddc75d18..f3599e71ac 100644 --- a/setup.py +++ b/setup.py @@ -120,19 +120,17 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: # Framework-specific requirements if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "pytorch" in frameworks: + from build_tools.pytorch import install_requirements, test_requirements + setup_reqs.extend(["torch>=2.1"]) - install_reqs.extend(["torch>=2.1"]) - install_reqs.append( - "nvdlfw-inspect @" - " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect" - ) - # Blackwell is not supported as of Triton 3.2.0, need custom internal build - # install_reqs.append("triton") - test_reqs.extend(["numpy", "torchvision", "transformers"]) + install_reqs.extend(install_requirements()) + test_reqs.extend(test_requirements()) if "jax" in frameworks: + from build_tools.jax import install_requirements, test_requirements + setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"]) - install_reqs.extend(["jax", "flax>=0.7.1"]) - test_reqs.extend(["numpy"]) + install_reqs.extend(install_requirements()) + test_reqs.extend(test_requirements()) return [remove_dups(reqs) for reqs in 
[setup_reqs, install_reqs, test_reqs]] diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index 2b64543ecf..c428374f80 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -46,7 +46,7 @@ from build_tools.build_ext import get_build_ext from build_tools.utils import copy_common_headers, install_and_import, cuda_toolkit_include_path from build_tools.te_version import te_version -from build_tools.jax import setup_jax_extension +from build_tools.jax import setup_jax_extension, install_requirements, test_requirements install_and_import("pybind11") from pybind11.setup_helpers import build_ext as BuildExtension @@ -116,8 +116,8 @@ ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, setup_requires=setup_requires, - install_requires=["jax", "flax>=0.7.1"], - tests_require=["numpy"], + install_requires=install_requirements(), + tests_require=test_requirements(), ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): shutil.rmtree(common_headers_dir) diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index ee088bc908..0e0af06abf 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -31,7 +31,7 @@ from build_tools.build_ext import get_build_ext from build_tools.utils import copy_common_headers, cuda_toolkit_include_path from build_tools.te_version import te_version -from build_tools.pytorch import setup_pytorch_extension +from build_tools.pytorch import setup_pytorch_extension, install_requirements, test_requirements os.environ["NVTE_PROJECT_BUILDING"] = "1" @@ -70,8 +70,8 @@ ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, setup_requires=setup_requires, - install_requires=["torch>=2.1"], - tests_require=["numpy", "torchvision"], + install_requires=install_requirements(), + tests_require=test_requirements(), ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): 
shutil.rmtree(common_headers_dir) From 031c6cf6cdc79d3bf7fe3b16ea6a293e734e1767 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 9 Jun 2025 17:57:06 -0700 Subject: [PATCH 03/39] Python 3.12+ support (#1862) Signed-off-by: Kirthi Shankar Sivamani --- setup.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/setup.py b/setup.py index f3599e71ac..f0ad7e9011 100644 --- a/setup.py +++ b/setup.py @@ -201,14 +201,8 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: long_description_content_type="text/x-rst", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, - python_requires=">=3.8, <3.13", - classifiers=[ - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - ], + python_requires=">=3.8", + classifiers=["Programming Language :: Python :: 3"], setup_requires=setup_requires, install_requires=install_requires, license_files=("LICENSE",), From faee0e8bb046bfe9a481158e7ac9796d10e8640f Mon Sep 17 00:00:00 2001 From: yuzhongw-nvidia Date: Wed, 11 Jun 2025 04:43:05 +0800 Subject: [PATCH 04/39] Support Context Parallel for Multi Latent Attention (MLA) (#1729) * Support MLA (qk_dim != v_dim) for AttnFuncWithCPAndKVP2P Signed-off-by: Yuzhong Wang * add UT for MLA CP Signed-off-by: Yuzhong Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refine the code Signed-off-by: Yuzhong Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refine the code Signed-off-by: Yuzhong Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Yuzhong Wang Co-authored-by: pre-commit-ci[bot] 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao Co-authored-by: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> --- .../fused_attn/test_fused_attn_with_cp.py | 14 + .../dot_product_attention/context_parallel.py | 604 +++++++++++++----- .../attention/dot_product_attention/utils.py | 5 - 3 files changed, 465 insertions(+), 158 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index b17c85327c..4ecc54b530 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -107,6 +107,18 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_2_4": ModelConfig( 2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) ), # GQA + "cp_3_0": ModelConfig( + 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64 + ), # MLA + "cp_3_1": ModelConfig( + 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=64 + ), # MLA + "cp_3_2": ModelConfig( + 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias", head_dim_v=64 + ), # MLA + "cp_3_3": ModelConfig( + 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias", head_dim_v=64 + ), # MLA } @@ -159,6 +171,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha ) if dtype != "fp8" and fp8_mha: pytest.skip("Only fp8 works with fp8_mha=True!") + if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: + pytest.skip("MLA CP currently only support KV P2P!") subprocess.run( get_bash_arguments( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index f9a5d02496..9f4822784e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ 
b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -461,6 +461,7 @@ def forward( ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward") + enable_mla = k.shape[-1] != v.shape[-1] if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -498,7 +499,10 @@ def forward( cu_seqlens_q_half, cu_seqlens_kv_half = None, None if qkv_format in ["bshd", "sbhd"]: seq_dim = qkv_format.index("s") - qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] + if enable_mla: + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + else: + qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None if use_fused_attention: batch_dim = qkv_format.index("b") @@ -676,9 +680,16 @@ def forward( fwd_results_correction_done = torch.cuda.Event() p2p_comm_buffers = [None for _ in range(cp_size)] - if qkv_format in ["bshd", "sbhd"]: + if enable_mla: + # If MLA, the shape of k and v does not match, so we flatten them + # and split them after receiving them. + k_shape = k.shape + k_numel = k.numel() + v_shape = v.shape + p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1) + elif qkv_format in ["bshd", "sbhd"]: p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) - else: + else: # qkv_format == "thd" p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) send_recv_reqs = [[], []] @@ -707,6 +718,10 @@ def forward( else: # KV exchange is in BF16/FP16, cast received KV in each step kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data + if enable_mla: + # If MLA, k and v are flattened, so split them after receiving. 
+ k_part = kv_inputs[i % 2][:k_numel].view(*k_shape) + v_part = kv_inputs[i % 2][k_numel:].view(*v_shape) if causal: if i == 0: if pad_between_seqs: @@ -725,17 +740,27 @@ def forward( if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - k.shape[0], -1, 2, *k.shape[-2:] - ) + if enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) + v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + k.shape[0], -1, 2, *k.shape[-2:] + ) elif qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - -1, k.shape[2], 2, *k.shape[-2:] - ) + if enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + k_part = k_part.view(-1, *k_part.shape[2:]) + v_part = v_part.view(-1, *v_part.shape[2:]) + else: + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + -1, k.shape[2], 2, *k.shape[-2:] + ) elif qkv_format == "thd": q_inputs[i % 2] = q if use_fused_attention: @@ -750,16 +775,19 @@ def forward( ).contiguous() q_part = q_inputs[i % 2] - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) + if not enable_mla: + # If MHA, then split the KV into k_part and v_part. + # Otherwise (MLA), k_part and v_part have already been split. 
+ k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fp8_meta_kwargs = {} if fp8: q_part = QKV_quantizer.create_tensor_from_data( @@ -810,6 +838,7 @@ def forward( max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, ) + # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q_inputs[i % 2], ( @@ -858,36 +887,60 @@ def forward( if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...] + if enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk//2, np, hn] + k_part = k_part[:, 0, ...] + v_part = v_part[:, 0, ...] + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...] elif qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][0] + if enable_mla: + # [2, sk//2, b, np, hn] -> [sk//2, b, np, hn] + k_part = k_part[0] + v_part = v_part[0] + else: + # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2][0] elif qkv_format == "thd": q_inputs[i % 2] = q - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_inputs[i % 2] = tex.thd_read_half_tensor( - kv_inputs[i % 2], cu_seqlens_kv_padded, 0 - ) + if enable_mla: + # [t, np, hn] -> [t/2, np, hn] + k_part = tex.thd_read_half_tensor( + k_part, cu_seqlens_kv_padded, 0 + ) + v_part = tex.thd_read_half_tensor( + v_part, cu_seqlens_kv_padded, 0 + ) + else: + # [2, t, np, hn] -> [2, t/2, np, hn] + kv_inputs[i % 2] = tex.thd_read_half_tensor( + kv_inputs[i % 2], cu_seqlens_kv_padded, 0 + ) if use_fused_attention: - kv_inputs[i % 2] = kv_inputs[i % 
2].contiguous() + if enable_mla: + k_part = k_part.contiguous() + v_part = v_part.contiguous() + else: + kv_inputs[i % 2] = kv_inputs[i % 2].contiguous() if attn_bias is not None: idx = (rank - i) % cp_size attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() q_part = q_inputs[i % 2] - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) + if not enable_mla: + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fp8_meta_kwargs = {} if fp8: q_part = QKV_quantizer.create_tensor_from_data( @@ -948,6 +1001,7 @@ def forward( elif fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 + # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q_inputs[i % 2], ( @@ -996,17 +1050,27 @@ def forward( if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] q_inputs[i % 2] = q[:, 1, ...] 
- # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - k.shape[0], -1, 2, *k.shape[-2:] - ) + if enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) + v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + k.shape[0], -1, 2, *k.shape[-2:] + ) elif qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] q_inputs[i % 2] = q[1] - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - -1, k.shape[2], 2, *k.shape[-2:] - ) + if enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + k_part = k_part.view(-1, *k_part.shape[2:]) + v_part = v_part.view(-1, *v_part.shape[2:]) + else: + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + -1, k.shape[2], 2, *k.shape[-2:] + ) elif qkv_format == "thd": # [t, np, hn] -> [t/2, np, hn] q_inputs[i % 2] = tex.thd_read_half_tensor( @@ -1025,16 +1089,17 @@ def forward( ).contiguous() q_part = q_inputs[i % 2] - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) + if not enable_mla: + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fp8_meta_kwargs = {} if fp8: q_part = QKV_quantizer.create_tensor_from_data( @@ -1095,6 +1160,7 @@ def forward( elif fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 + # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q_inputs[i % 2], ( @@ -1152,16 +1218,17 
@@ def forward( ).contiguous() q_part = q - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) + if not enable_mla: + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fp8_meta_kwargs = {} if fp8: q_part = QKV_quantizer.create_tensor_from_data( @@ -1211,6 +1278,7 @@ def forward( max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, ) + # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q, ( @@ -1257,7 +1325,15 @@ def forward( if i == 1: softmax_lse = torch.clone(softmax_lse_per_step[0]) if qkv_format == "thd": - out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) + if enable_mla: + out = torch.zeros_like(v if not fp8 else out_per_step[0]).view( + v_shape + ) + else: + # MHA or GQA + out = torch.zeros_like(q if not fp8 else out_per_step[0]).view( + q.shape + ) elif (i - 1) <= rank or not causal: flash_attn_fwd_softmax_lse_correction( softmax_lse, softmax_lse_per_step[i - 1] @@ -1295,7 +1371,10 @@ def forward( softmax_lse_per_step[0], seq_dim, ) - out = out.view(q.shape) + if enable_mla: + out = out.view(v_shape) + else: + out = out.view(q.shape) else: flash_attn_fwd_out_correction( out.view(*out_per_step[i].shape), @@ -1417,6 +1496,12 @@ def forward( ctx.is_output_fp8 = is_output_fp8 ctx.use_flash_attn_3 = use_flash_attn_3 + ctx.enable_mla = enable_mla + if enable_mla: + ctx.k_numel = k_numel + ctx.k_shape = k_shape + ctx.v_shape = v_shape + ctx.qkv_dtype = qkv_dtype ctx.dQKV_quantizer = dQKV_quantizer ctx.dQKV_CP_quantizer = dQKV_CP_quantizer @@ -1466,7 +1551,10 @@ def backward(ctx, dout): seq_dim = None if ctx.qkv_format in ["bshd", "sbhd"]: seq_dim = 
ctx.qkv_format.index("s") - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:] + if ctx.enable_mla: + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + else: + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:] else: qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format @@ -1595,8 +1683,13 @@ def backward(ctx, dout): ) dout = dout.dequantize(dtype=dout_dtype) - out = out.view(*q.shape) - dout = dout.view(*q.shape) + if ctx.enable_mla: + out = out.view(*ctx.v_shape) + dout = dout.view(*ctx.v_shape) + else: + # MHA or GQA + out = out.view(*q.shape) + dout = dout.view(*q.shape) send_recv_reqs = [] flash_attn_bwd = None @@ -1672,6 +1765,9 @@ def backward(ctx, dout): kv = p2p_comm_buffers[i % 2][0] q_, kv_, out_, dout_ = None, None, None, None dq_, dk_, dv_ = None, None, None + if ctx.enable_mla: + k_part = kv[: ctx.k_numel].view(*ctx.k_shape) + v_part = kv[ctx.k_numel :].view(*ctx.v_shape) # In reversed order of fwd if causal: if i == (cp_size - 1): @@ -1680,13 +1776,23 @@ def backward(ctx, dout): q_, out_, dout_ = [ x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] ] - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) + if ctx.enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) + v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) elif ctx.qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_ = kv.view(-1, *kv.shape[-4:]) + if ctx.enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + k_part = k_part.view(-1, *k_part.shape[-3:]) + v_part = v_part.view(-1, *v_part.shape[-3:]) + 
else: + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_ = kv.view(-1, *kv.shape[-4:]) elif ctx.qkv_format == "thd": q_, kv_, out_, dout_ = q, kv, out, dout if ctx.use_fused_attention: @@ -1701,8 +1807,13 @@ def backward(ctx, dout): if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] q_part = q_ - k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + if not ctx.enable_mla: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) out_part = out_ dout_part = dout_ @@ -1784,6 +1895,7 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = 0 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout_, q_, @@ -1801,19 +1913,38 @@ def backward(ctx, dout): q_, out_, dout_ = [ x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] ] - # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] - kv_ = kv[:, 0] + if ctx.enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk//2, np, hn] + k_part = k_part[:, 0] + v_part = v_part[:, 0] + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] + kv_ = kv[:, 0] elif ctx.qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] - # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] - kv_ = kv[0] + if ctx.enable_mla: + # [2, sk//2, b, np, hn] -> [sk//2, b, np, hn] + k_part = k_part[0] + v_part = v_part[0] + else: + # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] + kv_ = kv[0] elif ctx.qkv_format == "thd": q_, out_, dout_ = q, out, dout - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) + if ctx.enable_mla: + # [t, np, hn] -> [t/2, np, hn] 
+ k_part = tex.thd_read_half_tensor(k_part, cu_seqlens_kv_padded, 0) + v_part = tex.thd_read_half_tensor(v_part, cu_seqlens_kv_padded, 0) + else: + # [2, t, np, hn] -> [2, t/2, np, hn] + kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) if ctx.use_fused_attention: - kv_ = kv_.contiguous() + if ctx.enable_mla: + k_part = k_part.contiguous() + v_part = v_part.contiguous() + else: + kv_ = kv_.contiguous() if ctx.fp8: aux_ctx_tensors = [ softmax_lse, @@ -1825,8 +1956,13 @@ def backward(ctx, dout): if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] q_part = q_ - k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + if not ctx.enable_mla: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) out_part = out_ dout_part = dout_ @@ -1910,6 +2046,7 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = -1 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout_, q_, @@ -1925,13 +2062,23 @@ def backward(ctx, dout): if ctx.qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1] - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) + if ctx.enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) + v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) elif ctx.qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] q_, out_, dout_ = q[1], out[1], dout[1] - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, 
hn] - kv_ = kv.view(-1, *kv.shape[-4:]) + if ctx.enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + k_part = k_part.view(-1, *k_part.shape[-3:]) + v_part = v_part.view(-1, *v_part.shape[-3:]) + else: + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_ = kv.view(-1, *kv.shape[-4:]) elif ctx.qkv_format == "thd": # [t, np, hn] -> [t/2, np, hn] q_, out_, dout_ = [ @@ -1953,8 +2100,13 @@ def backward(ctx, dout): aux_ctx_tensors += [attn_biases[cp_size - i - 1]] q_part = q_ - k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + if not ctx.enable_mla: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) out_part = out_ dout_part = dout_ @@ -2038,6 +2190,7 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = -1 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout_, q_, @@ -2058,8 +2211,9 @@ def backward(ctx, dout): if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] q_part = q - k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] - v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] + if not ctx.enable_mla: + k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] + v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] out_part = out dout_part = dout @@ -2133,6 +2287,7 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = -1 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout, q, @@ -2225,15 +2380,18 @@ def backward(ctx, dout): else: dkv = p2p_comm_buffers[(i 
+ 1) % 2][1] if ctx.use_fused_attention: - if ctx.qkv_format in ["bshd", "sbhd"]: + if ctx.enable_mla: + dkv_ = None + elif ctx.qkv_format in ["bshd", "sbhd"]: dkv_ = combine_tensors([dk_, dv_], -2) elif ctx.qkv_format == "thd": dkv_ = torch.cat( (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0 ) # pylint: disable=used-before-assignment - if ctx.qkv_format in ["bshd", "sbhd"]: + if not ctx.enable_mla and ctx.qkv_format in ["bshd", "sbhd"]: # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn] + # dkv is a buffer, so we do not need to transpose it, but only need to reshape it. dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:]) dkv_ = dkv_.movedim(-3, 0) if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): @@ -2241,91 +2399,225 @@ def backward(ctx, dout): # [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn] dkv_ = dkv_.view(*dkv.shape) - if ctx.fp8: - if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): - if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].copy_(dkv_) - dkv[:, :, 1, ...].fill_(0) - elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].copy_(dkv_) - dkv[:, 1, ...].fill_(0) - else: - dkv.copy_(dkv_) - elif causal: - if i == (cp_size - 1): - if rank == 0: + if ctx.enable_mla: + # [b, 2, sk//2, np, hn] or + # [2, sk//2, b, np, hn] + dk = dkv[: ctx.k_numel].view(*ctx.k_shape) + dv = dkv[ctx.k_numel :].view(*ctx.v_shape) + if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): + dk_ = dk_.view(*ctx.k_shape) + dv_ = dv_.view(*ctx.v_shape) + + if ctx.fp8: + # enable_mla and fp8 + if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...]) - dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...]) + dk[:, 0, ...].copy_(dk_) + dk[:, 1, ...].fill_(0) + dv[:, 0, ...].copy_(dv_) + dv[:, 1, ...].fill_(0) elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].add_(dkv_[:, 0, ...]) - dkv[:, 1, ...].copy_(dkv_[:, 1, ...]) - elif 
ctx.qkv_format == "thd": - tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "copy") - else: - dkv.add_(dkv_) - elif i >= (cp_size - rank - 1): - if i == 0 and rank == (cp_size - 1): + dk[0].copy_(dk_) + dk[1].fill_(0) + dv[0].copy_(dv_) + dv[1].fill_(0) + else: + dk.copy_(dk_) + dv.copy_(dv_) + elif causal: + # enable_mla and not fp8 and causal + if i == (cp_size - 1): + if rank == 0: + if ctx.qkv_format == "bshd": + dk[:, 0, ...].add_(dk_[:, 0, ...]) + dk[:, 1, ...].copy_(dk_[:, 1, ...]) + dv[:, 0, ...].add_(dv_[:, 0, ...]) + dv[:, 1, ...].copy_(dv_[:, 1, ...]) + elif ctx.qkv_format == "sbhd": + dk[0, ...].add_(dk_[0, ...]) + dk[1, ...].copy_(dk_[1, ...]) + dv[0, ...].add_(dv_[0, ...]) + dv[1, ...].copy_(dv_[1, ...]) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dk, dk_, cu_seqlens_kv_padded, "add", "copy" + ) + tex.thd_grad_correction( + dv, dv_, cu_seqlens_kv_padded, "add", "copy" + ) + else: + dk.add_(dk_) + dv.add_(dv_) + elif i >= (cp_size - rank - 1): + if i == 0 and rank == (cp_size - 1): + if ctx.qkv_format == "bshd": + dk[:, 0, ...].copy_(dk_) + dv[:, 0, ...].copy_(dv_) + elif ctx.qkv_format == "sbhd": + dk[0, ...].copy_(dk_) + dv[0, ...].copy_(dv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dk, dk_, cu_seqlens_kv_padded, "copy", "none" + ) + tex.thd_grad_correction( + dv, dv_, cu_seqlens_kv_padded, "copy", "none" + ) + else: + if ctx.qkv_format == "bshd": + dk[:, 0, ...].add_(dk_) + dv[:, 0, ...].add_(dv_) + elif ctx.qkv_format == "sbhd": + dk[0, ...].add_(dk_) + dv[0, ...].add_(dv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dk, dk_, cu_seqlens_kv_padded, "add", "none" + ) + tex.thd_grad_correction( + dv, dv_, cu_seqlens_kv_padded, "add", "none" + ) + elif i > 0: + dk.add_(dk_) + dv.add_(dv_) + else: # i == 0 + dk.copy_(dk_) + dv.copy_(dv_) + else: + # enable_mla and not fp8 and not causal + if i == 0: + dk.copy_(dk_) + dv.copy_(dv_) + else: # i > 0 + dk.add_(dk_) + dv.add_(dv_) + else: + 
if ctx.fp8: + # fp8 + if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): if ctx.qkv_format == "bshd": dkv[:, :, 0, ...].copy_(dkv_) + dkv[:, :, 1, ...].fill_(0) elif ctx.qkv_format == "sbhd": dkv[:, 0, ...].copy_(dkv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "copy", "none") + dkv[:, 1, ...].fill_(0) else: - if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].add_(dkv_) - elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].add_(dkv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "none") - elif i > 0: - dkv.add_(dkv_) - else: - dkv.copy_(dkv_) - else: - if i == 0: - dkv.copy_(dkv_) + dkv.copy_(dkv_) + elif causal: + # not fp8 and causal + if i == (cp_size - 1): + if rank == 0: + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...]) + dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...]) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].add_(dkv_[:, 0, ...]) + dkv[:, 1, ...].copy_(dkv_[:, 1, ...]) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dkv, dkv_, cu_seqlens_kv_padded, "add", "copy" + ) + else: + dkv.add_(dkv_) + elif i >= (cp_size - rank - 1): + if i == 0 and rank == (cp_size - 1): + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].copy_(dkv_) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].copy_(dkv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dkv, dkv_, cu_seqlens_kv_padded, "copy", "none" + ) + else: + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].add_(dkv_) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].add_(dkv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dkv, dkv_, cu_seqlens_kv_padded, "add", "none" + ) + elif i > 0: + dkv.add_(dkv_) + else: # i == 0 + dkv.copy_(dkv_) else: - dkv.add_(dkv_) + # not fp8 and not causal + if i == 0: + dkv.copy_(dkv_) + else: # i > 0 + dkv.add_(dkv_) if ctx.fp8 and ctx.use_fused_attention: amax_cp_bwd = amax_per_step.amax(dim=1) 
ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) ctx.dQKV_CP_quantizer.amax.copy_(amax_cp_bwd[1]) - if ctx.qkv_format in ["bshd", "sbhd"]: - # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or - # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] - dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) dq = ctx.dQKV_CP_quantizer.create_tensor_from_data( dq_fp8, fake_dtype=torch.float32, internal=True ) - dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data( - dkv_fp8, fake_dtype=torch.float32, internal=True - ) - dq, dkv = [x.dequantize(dtype=torch.float32) for x in [dq, dkv]] - dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]] + + if ctx.enable_mla: + # [cp, b, 2, sk//2, np, hn] or [cp, 2, sk//2, b, np, hn] + dk_fp8 = dkv_fp8[: ctx.k_numel].view(cp_size, *ctx.k_shape) + dv_fp8 = dkv_fp8[ctx.k_numel :].view(cp_size, *ctx.v_shape) + dk = ctx.dQKV_CP_quantizer.create_tensor_from_data( + dk_fp8, fake_dtype=torch.float32, internal=True + ) + dv = ctx.dQKV_CP_quantizer.create_tensor_from_data( + dv_fp8, fake_dtype=torch.float32, internal=True + ) + dq, dk, dv = [x.dequantize(dtype=torch.float32) for x in [dq, dk, dv]] + dq, dk, dv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dk, dv]] + else: + if ctx.qkv_format in ["bshd", "sbhd"]: + # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or + # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] + dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) + dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data( + dkv_fp8, fake_dtype=torch.float32, internal=True + ) + dq, dkv = [x.dequantize(dtype=torch.float32) for x in [dq, dkv]] + dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]] if causal: if ctx.qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] dq = dq.view(dq.shape[0], -1, *dq.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] - dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:]) + if 
ctx.enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + dk = dk.view(dk.shape[0], -1, *dk.shape[-2:]) + dv = dv.view(dv.shape[0], -1, *dv.shape[-2:]) + else: + # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] + dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:]) elif ctx.qkv_format == "sbhd": # [2, sq//2, b, np, hn] -> [sq, b, np, hn] dq = dq.view(-1, *dq.shape[-3:]) - # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] - dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:]) + if ctx.enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + dk = dk.view(-1, *dk.shape[-3:]) + dv = dv.view(-1, *dv.shape[-3:]) + else: + # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] + dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:]) if ctx.qkv_format == "thd" and not ctx.use_fused_attention: dq[cu_seqlens_q_padded[-1] :].fill_(0) - dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0) + if ctx.enable_mla: + dk[cu_seqlens_kv_padded[-1] :].fill_(0) + dv[cu_seqlens_kv_padded[-1] :].fill_(0) + else: + dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0) if ctx.fp8 and ctx.is_input_fp8: assert torch.uint8 not in [dq.dtype, dkv.dtype] - dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]] - dk, dv = dkv[0], dkv[1] + if ctx.enable_mla: + dq, dk, dv = [ctx.dQKV_quantizer(x)._data for x in [dq, dk, dv]] + else: + dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]] + if not ctx.enable_mla: + dk, dv = dkv[0], dkv[1] if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device) @@ -3584,6 +3876,12 @@ def attn_forward_func_with_cp( "all_gather", ], "The context parallel running configs cannot support sliding window attetnion!" + enable_mla = k.shape[-1] != v.shape[-1] + assert not enable_mla or cp_comm_type in [ + "p2p", + "a2a+p2p", + ], "The context parallel running configs cannot support MLA!" 
+ args = [ is_training, q, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index ba90fcfeb7..ec93e8c5c8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -608,11 +608,6 @@ def get_attention_backend( " bias for THD format" ) use_fused_attention = False - elif head_dim_qk != head_dim_v: - logger.debug( - "Disabling FusedAttention as it does not support context parallelism with MLA" - ) - use_fused_attention = False # Filter: Attention mask # attn_mask_type | attention_mask | supported backends From aedd7e1012d4a0ad704ee9e70e4e3f91c9ba87ea Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 10 Jun 2025 13:59:28 -0700 Subject: [PATCH 05/39] pyproject.toml (#1852) * Initial basic setup Signed-off-by: Kirthi Shankar Sivamani * rm setup reqs Signed-off-by: Kirthi Shankar Sivamani * fix Signed-off-by: Kirthi Shankar Sivamani * buil-isolation support Signed-off-by: Kirthi Shankar Sivamani * rm not needed funcs Signed-off-by: Kirthi Shankar Sivamani * Fix workflows Signed-off-by: Kirthi Shankar Sivamani * fix wheel Signed-off-by: Kirthi Shankar Sivamani * Fix invalid wheel Signed-off-by: Kirthi Shankar Sivamani * Fix JAX build in baremetal env Signed-off-by: Kirthi Shankar Sivamani * Update install inst in readme Signed-off-by: Kirthi Shankar Sivamani * Update build.yml Signed-off-by: Kirthi Shankar Sivamani * docstring fix Signed-off-by: Kirthi Shankar Sivamani * fix Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- .github/workflows/build.yml | 10 +++--- README.rst | 4 +-- build_tools/jax.py | 17 ++-------- build_tools/utils.py | 7 ---- build_tools/wheel_utils/build_wheels.sh | 15 +++++---- pyproject.toml | 10 ++++++ setup.py | 39 +++-------------------- transformer_engine/jax/pyproject.toml | 10 ++++++ 
transformer_engine/jax/setup.py | 18 +---------- transformer_engine/pytorch/pyproject.toml | 10 ++++++ transformer_engine/pytorch/setup.py | 17 +--------- 11 files changed, 56 insertions(+), 101 deletions(-) create mode 100755 pyproject.toml create mode 100755 transformer_engine/jax/pyproject.toml create mode 100755 transformer_engine/pytorch/pyproject.toml diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 78ef679f9d..80703adf11 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -18,8 +18,8 @@ jobs: - name: 'Dependencies' run: | apt-get update - apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12 - pip install cmake==3.21.0 + apt-get install -y git python3.9 pip cudnn9-cuda-12 + pip install cmake==3.21.0 pybind11[global] ninja - name: 'Checkout' uses: actions/checkout@v3 with: @@ -42,8 +42,8 @@ jobs: - name: 'Dependencies' run: | apt-get update - apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12 - pip install cmake torch pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops + apt-get install -y git python3.9 pip cudnn9-cuda-12 + pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops - name: 'Checkout' uses: actions/checkout@v3 with: @@ -62,6 +62,8 @@ jobs: image: ghcr.io/nvidia/jax:jax options: --user root steps: + - name: 'Dependencies' + run: pip install pybind11[global] - name: 'Checkout' uses: actions/checkout@v3 with: diff --git a/README.rst b/README.rst index c318554460..cfd5c687e4 100644 --- a/README.rst +++ b/README.rst @@ -216,13 +216,13 @@ Alternatively, install directly from the GitHub repository: .. code-block:: bash - pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable + pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable When installing from GitHub, you can explicitly specify frameworks using the environment variable: .. 
code-block:: bash - NVTE_FRAMEWORK=pytorch,jax pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable + NVTE_FRAMEWORK=pytorch,jax pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable conda Installation ^^^^^^^^^^^^^^^^^^ diff --git a/build_tools/jax.py b/build_tools/jax.py index 2b9ad1a30f..4fe4b78ba5 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -14,7 +14,7 @@ def install_requirements() -> List[str]: """Install dependencies for TE/JAX extensions.""" - return ["jax[cuda12]", "flax>=0.7.1"] + return ["jax", "flax>=0.7.1"] def test_requirements() -> List[str]: @@ -75,20 +75,9 @@ def setup_jax_extension( # Define TE/JAX as a Pybind11Extension from pybind11.setup_helpers import Pybind11Extension - class Pybind11CPPExtension(Pybind11Extension): - """Modified Pybind11Extension to allow custom CXX flags.""" - - def _add_cflags(self, flags: List[str]) -> None: - if isinstance(self.extra_compile_args, dict): - cxx_flags = self.extra_compile_args.pop("cxx", []) - cxx_flags += flags - self.extra_compile_args["cxx"] = cxx_flags - else: - self.extra_compile_args[:0] = flags - - return Pybind11CPPExtension( + return Pybind11Extension( "transformer_engine_jax", sources=[str(path) for path in sources], include_dirs=[str(path) for path in include_dirs], - extra_compile_args={"cxx": cxx_flags}, + extra_compile_args=cxx_flags, ) diff --git a/build_tools/utils.py b/build_tools/utils.py index 3c8554dc07..0dc5e36898 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -354,10 +354,3 @@ def copy_common_headers( new_path = dst_dir / path.relative_to(src_dir) new_path.parent.mkdir(exist_ok=True, parents=True) shutil.copy(path, new_path) - - -def install_and_import(package): - """Install a package via pip (if not already installed) and import into globals.""" - main_package = package.split("[")[0] - subprocess.check_call([sys.executable, "-m", "pip", "install", package]) - globals()[main_package] = 
importlib.import_module(main_package) diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index 9acb22aee6..bf4f9d2bc2 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -20,6 +20,9 @@ cd /TransformerEngine git checkout $TARGET_BRANCH git submodule update --init --recursive +# Install deps +/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja + if $BUILD_METAPACKAGE ; then cd /TransformerEngine NVTE_BUILD_METAPACKAGE=1 /opt/python/cp310-cp310/bin/python setup.py bdist_wheel 2>&1 | tee /wheelhouse/logs/metapackage.txt @@ -31,15 +34,15 @@ if $BUILD_COMMON ; then WHL_BASE="transformer_engine-${VERSION}" # Create the wheel. - /opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt + /opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt # Repack the wheel for cuda specific package, i.e. cu12. - /opt/python/cp38-cp38/bin/wheel unpack dist/* + /opt/python/cp310-cp310/bin/wheel unpack dist/* # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" - /opt/python/cp38-cp38/bin/wheel pack ${WHL_BASE} + /opt/python/cp310-cp310/bin/wheel pack ${WHL_BASE} # Rename the wheel to make it python version agnostic. 
whl_name=$(basename dist/*) @@ -51,14 +54,14 @@ fi if $BUILD_PYTORCH ; then cd /TransformerEngine/transformer_engine/pytorch - /opt/python/cp38-cp38/bin/pip install torch - /opt/python/cp38-cp38/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt + /opt/python/cp310-cp310/bin/pip install torch + /opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt cp dist/* /wheelhouse/ fi if $BUILD_JAX ; then cd /TransformerEngine/transformer_engine/jax - /opt/python/cp310-cp310/bin/pip install "jax[cuda12_local]" jaxlib + /opt/python/cp310-cp310/bin/pip install "jax[cuda12_local]" jaxlib /opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt cp dist/* /wheelhouse/ fi diff --git a/pyproject.toml b/pyproject.toml new file mode 100755 index 0000000000..ef112d2798 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,10 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +[build-system] +requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax[cuda12]", "flax>=0.7.1"] + +# Use legacy backend to import local packages in setup.py +build-backend = "setuptools.build_meta:__legacy__" + diff --git a/setup.py b/setup.py index f0ad7e9011..8cdedde844 100644 --- a/setup.py +++ b/setup.py @@ -16,13 +16,8 @@ from build_tools.te_version import te_version from build_tools.utils import ( cuda_archs, - found_cmake, - found_ninja, - found_pybind11, get_frameworks, - install_and_import, remove_dups, - cuda_toolkit_include_path, ) frameworks = get_frameworks() @@ -36,7 +31,6 @@ if "pytorch" in frameworks: from torch.utils.cpp_extension import BuildExtension elif "jax" in frameworks: - install_and_import("pybind11[global]") from pybind11.setup_helpers import build_ext as BuildExtension @@ -82,26 +76,13 @@ def setup_common_extension() -> CMakeExtension: ) -def setup_requirements() -> Tuple[List[str], List[str], List[str]]: +def setup_requirements() -> Tuple[List[str], List[str]]: """Setup Python dependencies - Returns dependencies for build, runtime, and testing. + Returns dependencies for runtime and testing. 
""" # Common requirements - setup_reqs: List[str] = [] - if cuda_toolkit_include_path() is None: - setup_reqs.extend( - [ - "nvidia-cuda-runtime-cu12", - "nvidia-cublas-cu12", - "nvidia-cudnn-cu12", - "nvidia-cuda-cccl-cu12", - "nvidia-cuda-nvcc-cu12", - "nvidia-nvtx-cu12", - "nvidia-cuda-nvrtc-cu12", - ] - ) install_reqs: List[str] = [ "pydantic", "importlib-metadata>=1.0", @@ -109,30 +90,20 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ] test_reqs: List[str] = ["pytest>=8.2.1"] - # Requirements that may be installed outside of Python - if not found_cmake(): - setup_reqs.append("cmake>=3.21") - if not found_ninja(): - setup_reqs.append("ninja") - if not found_pybind11(): - setup_reqs.append("pybind11") - # Framework-specific requirements if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "pytorch" in frameworks: from build_tools.pytorch import install_requirements, test_requirements - setup_reqs.extend(["torch>=2.1"]) install_reqs.extend(install_requirements()) test_reqs.extend(test_requirements()) if "jax" in frameworks: from build_tools.jax import install_requirements, test_requirements - setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"]) install_reqs.extend(install_requirements()) test_reqs.extend(test_requirements()) - return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]] + return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]] if __name__ == "__main__": @@ -149,14 +120,13 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ext_modules = [] package_data = {} include_package_data = False - setup_requires = [] install_requires = ([f"transformer_engine_cu12=={__version__}"],) extras_require = { "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"], } else: - setup_requires, install_requires, test_requires = setup_requirements() + install_requires, test_requires = setup_requirements() ext_modules = [setup_common_extension()] 
package_data = {"": ["VERSION.txt"]} include_package_data = True @@ -203,7 +173,6 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, python_requires=">=3.8", classifiers=["Programming Language :: Python :: 3"], - setup_requires=setup_requires, install_requires=install_requires, license_files=("LICENSE",), include_package_data=include_package_data, diff --git a/transformer_engine/jax/pyproject.toml b/transformer_engine/jax/pyproject.toml new file mode 100755 index 0000000000..ff0e356ed9 --- /dev/null +++ b/transformer_engine/jax/pyproject.toml @@ -0,0 +1,10 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +[build-system] +requires = ["setuptools>=61.0", "pybind11[global]", "pip", "jax[cuda12]", "flax>=0.7.1"] + +# Use legacy backend to import local packages in setup.py +build-backend = "setuptools.build_meta:__legacy__" + diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index c428374f80..ca83cf465e 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -44,11 +44,10 @@ from build_tools.build_ext import get_build_ext -from build_tools.utils import copy_common_headers, install_and_import, cuda_toolkit_include_path +from build_tools.utils import copy_common_headers from build_tools.te_version import te_version from build_tools.jax import setup_jax_extension, install_requirements, test_requirements -install_and_import("pybind11") from pybind11.setup_helpers import build_ext as BuildExtension os.environ["NVTE_PROJECT_BUILDING"] = "1" @@ -94,20 +93,6 @@ ) ] - setup_requires = ["jax[cuda12]", "flax>=0.7.1"] - if cuda_toolkit_include_path() is None: - setup_requires.extend( - [ - "nvidia-cuda-runtime-cu12", - "nvidia-cublas-cu12", - "nvidia-cudnn-cu12", - "nvidia-cuda-cccl-cu12", - "nvidia-cuda-nvcc-cu12", - "nvidia-nvtx-cu12", - 
"nvidia-cuda-nvrtc-cu12", - ] - ) - # Configure package setuptools.setup( name="transformer_engine_jax", @@ -115,7 +100,6 @@ description="Transformer acceleration library - Jax Lib", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, - setup_requires=setup_requires, install_requires=install_requirements(), tests_require=test_requirements(), ) diff --git a/transformer_engine/pytorch/pyproject.toml b/transformer_engine/pytorch/pyproject.toml new file mode 100755 index 0000000000..e5a4549db2 --- /dev/null +++ b/transformer_engine/pytorch/pyproject.toml @@ -0,0 +1,10 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +[build-system] +requires = ["setuptools>=61.0", "pip", "torch>=2.1"] + +# Use legacy backend to import local packages in setup.py +build-backend = "setuptools.build_meta:__legacy__" + diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index 0e0af06abf..ae1b5780bb 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -29,7 +29,7 @@ from build_tools.build_ext import get_build_ext -from build_tools.utils import copy_common_headers, cuda_toolkit_include_path +from build_tools.utils import copy_common_headers from build_tools.te_version import te_version from build_tools.pytorch import setup_pytorch_extension, install_requirements, test_requirements @@ -48,20 +48,6 @@ ) ] - setup_requires = ["torch>=2.1"] - if cuda_toolkit_include_path() is None: - setup_requires.extend( - [ - "nvidia-cuda-runtime-cu12", - "nvidia-cublas-cu12", - "nvidia-cudnn-cu12", - "nvidia-cuda-cccl-cu12", - "nvidia-cuda-nvcc-cu12", - "nvidia-nvtx-cu12", - "nvidia-cuda-nvrtc-cu12", - ] - ) - # Configure package setuptools.setup( name="transformer_engine_torch", @@ -69,7 +55,6 @@ description="Transformer acceleration library - Torch Lib", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, - 
setup_requires=setup_requires, install_requires=install_requirements(), tests_require=test_requirements(), ) From 0efc7daf59a8700a6d3fecd14bcad98e7e040281 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 11 Jun 2025 20:46:08 -0700 Subject: [PATCH 06/39] [PyTorch] Fix backward compatibility for checkpoint loading (#1868) Fix for loading old ckpt formats Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/module/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 5dd44da5ed..acbd871c7e 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -820,6 +820,11 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: def set_extra_state(self, state: torch.Tensor) -> None: """Load previous state.""" + + # Maintain backwards compatibility with older checkpoints. + if state is None: + return + # Load state if isinstance(state, torch.Tensor): # No FP8 is indicated by an empty tensor we don't need to unpickle. 
From c293d3a88995277fc974587651d98542fa1c871e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Thu, 12 Jun 2025 10:06:22 +0200 Subject: [PATCH 07/39] [PyTorch] Fix typo in GrouppedLinear (#1867) typo fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/grouped_linear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 1bede9d933..5fe351578e 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -241,8 +241,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], biases = saved_tensors[3 * N : 4 * N] main_grads = ctx.main_grads - if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: # TOSO - for i in ctx.num_gemms: + if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: + for i in range(ctx.num_gemms): w = torch.nn.Parameter(weights[i], weights[i].requires_grad) w.main_grad = main_grads[i] weights[i] = w From 5d01ef2113d154846883623c947224e51bf75ac6 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 12 Jun 2025 10:47:32 -0400 Subject: [PATCH 08/39] [JAX] GroupedDense v.2 without dynamic shape (#1721) * Implemented GroupedDense and TestGroupedDense for BF16, FP16, and FP8 * Fix GroupedGemmFFI cuBLAS workspace alignment bug Signed-off-by: Hua Huang Signed-off-by: Phuong Nguyen --- tests/jax/test_custom_call_compute.py | 352 ++++++-------- .../common/gemm/cublaslt_gemm.cu | 3 + transformer_engine/jax/cpp_extensions/gemm.py | 459 ++++++++++++------ .../jax/cpp_extensions/quantization.py | 23 +- .../jax/csrc/extensions/gemm.cpp | 308 +++++++----- transformer_engine/jax/dense.py | 301 ++++++++---- .../jax/quantize/dequantizer.py | 12 +- 7 files changed, 881 insertions(+), 577 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py 
b/tests/jax/test_custom_call_compute.py index 25a463aeaa..9ff0c11757 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -40,10 +40,11 @@ ScalingMode, QuantizerFactory, QuantizeLayout, + noop_quantizer_set, ) from transformer_engine.jax.quantize import helper from transformer_engine.jax.activation import activation -from transformer_engine.jax.dense import dense +from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense GEMM_CASES = [ @@ -1204,24 +1205,6 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype) -# This function is modified from transformer_engine/jax/cpp_extensions/gemm.py::_jax_gemm() -def _quantize_gemm_pair(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer): - ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims - lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1 - rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1 - lhs_q = lhs_quantizer.quantize( - lhs, - is_rowwise=lhs_is_rowwise, - is_colwise=not lhs_is_rowwise, - ) - rhs_q = rhs_quantizer.quantize( - rhs, - is_rowwise=rhs_is_rowwise, - is_colwise=not rhs_is_rowwise, - ) - return lhs_q, rhs_q - - # E5M2 * E5M2 is not supported fwd_bwd_dtypes = [ [jnp.float8_e4m3fn, jnp.float8_e4m3fn], @@ -1229,219 +1212,194 @@ def _quantize_gemm_pair(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer [jnp.float8_e5m2, jnp.float8_e4m3fn], ] -""" -@pytest_parametrize_wrapper( - "shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]] -) +GROUPED_DENSE_INPUT_SHAPES = [ + # (n_groups, m, n, k), the actual m will be multiplied by 32 + (5, 32, 128, 64), # Test the case where n_groups is not a multiple of 4 + (8, 64, 32, 128), + (8, 64, 128, 256), +] + + +@pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES) class TestGroupedDense: - def _ref_grouped_gemm_with_jnp_dot(self, 
lhs_list, rhs_list, contracting_dims_list): - ref_out_list = [] - for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list): - dim_nums = (contracting_dims, ((), ())) - ref_out_list.append(jax.lax.dot_general(lhs, rhs, dim_nums)) - return ref_out_list - - def _generate_grouped_gemm_input(self, dtype, shape_list, layout_list): + def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims): + lhs_contract_dim, _ = contracting_dims + assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3 + if bias is None: + bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype) + else: + assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2]) + remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop() + lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis) + rhs = jnp.split(rhs, rhs.shape[0], axis=0) + bias = jnp.split(bias, bias.shape[0], axis=0) + ref_out = [] + dim_num = (contracting_dims, ((), ())) + for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias): + out_i = jax.lax.dot_general(lhs_i, rhs_i, dim_num) + jnp.expand_dims(bias_i, axis=0) + ref_out.append(jnp.squeeze(out_i)) + return ref_out + + def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False): key = jax.random.PRNGKey(0) - subkeys = jax.random.split(key, len(shape_list) * 2) - - lhs_list, rhs_list, contracting_dims_list = [], [], [] - for i, ((m, n, k), data_layout) in enumerate(zip(shape_list, layout_list)): - lhs = jax.random.uniform( - subkeys[2 * i], - (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m), - dtype=dtype, - ) - rhs = jax.random.uniform( - subkeys[2 * i + 1], - (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k), - dtype=dtype, - ) - lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,) - rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,) - contracting_dims = (lhs_contracting_dim, rhs_contracting_dim) + 
subkeys = jax.random.split(key, 4) + n_groups, m, n, k = input_shape + + group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) + group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])]) + group_sizes = jnp.diff(group_sizes) + assert group_sizes.sum() == m + + # *32 to make sure that input shape works for MXFP8 + group_sizes = group_sizes * 32 + m = m * 32 - lhs_list.append(lhs) - rhs_list.append(rhs) - contracting_dims_list.append(contracting_dims) + lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m) + rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k) + bias_shape = (n_groups, n) - return lhs_list, rhs_list, contracting_dims_list + lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype) + rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype) + bias = jax.random.uniform(subkeys[3], bias_shape, dtype=dtype) if with_bias else None + + lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,) + rhs_contracting_dim = (1,) if data_layout[1] == "N" else (2,) + contracting_dims = (lhs_contracting_dim, rhs_contracting_dim) + + return lhs, rhs, group_sizes, contracting_dims, bias + + def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype): + assert out.dtype == ref_list[0].dtype + out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0) + for i in range(len(ref_list)): + assert_allclose(out_list[i], ref_list[i], dtype=dtype) @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) - @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]]) - def test_grouped_gemm_fp16(self, dtype, shape_list, layout_list): - lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input( - dtype, shape_list, layout_list + @pytest_parametrize_wrapper("layout", ["NN"]) + def test_grouped_gemm_fp16(self, dtype, input_shape, layout): + lhs, rhs, group_sizes, contracting_dims, _ = 
self._generate_grouped_dense_input( + dtype, input_shape, layout ) - ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list) - primitive_out = tex.grouped_gemm(lhs_list, rhs_list, contracting_dims_list) - for i in range(len(shape_list)): - assert_allclose(primitive_out[i], ref_out[i], dtype=dtype) + ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) + prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims) + self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]]) - def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list, layout_list): + @pytest_parametrize_wrapper("layout", ["NN"]) + def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout): + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + pytest.skip("MXFP8 is not supported in grouped_gemm yet") + fwd_dtype, bwd_dtype = fwd_bwd_dtype quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=False + scaling_mode=scaling_mode, + fwd_dtype=fwd_dtype, + bwd_dtype=bwd_dtype, + is_2x2x=False, + n_groups=input_shape[0], ) + # quantizer_set.{x, kernel} has fwd_dtype, while quantizer_set.grad has bwd_dtype + # We want to test E4M3 * E5M2, manually set the quantizer_set.kernel.q_dtype to bwd_dtype + quantizer_set.kernel.q_dtype = bwd_dtype + for quantizer in quantizer_set.kernel.quantizers: + quantizer.q_dtype = bwd_dtype + out_dtype = jnp.bfloat16 - lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input( - out_dtype, shape_list, layout_list + lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( + out_dtype, input_shape, 
layout + ) + ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) + prim_out = tex.grouped_gemm( + lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set ) - q_lhs_list = [] - q_rhs_list = [] - for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list): - # quantizer_set.x and quantizer_set.kernel have the same q_dtype, we want to - # test the case where lhs and rhs have different q_dtypes - q_lhs, q_rhs = _quantize_gemm_pair( - lhs, rhs, contracting_dims, quantizer_set.x, quantizer_set.dgrad - ) - q_lhs_list.append(q_lhs) - q_rhs_list.append(q_rhs) - - ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list) - primitive_out = tex.grouped_gemm(q_lhs_list, q_rhs_list, contracting_dims_list) allclose_dtype = jnp.float8_e4m3fn - if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2: + if jnp.float8_e5m2 in fwd_bwd_dtype: allclose_dtype = jnp.float8_e5m2 - for i in range(len(shape_list)): - assert_allclose(primitive_out[i], ref_out[i], dtype=allclose_dtype) - @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) - def test_grouped_dense_grad_fp16(self, dtype, shape_list): - group_size = len(shape_list) - layout_list = ["NN" for _ in range(group_size)] + self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, allclose_dtype) - x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input( - dtype, shape_list, layout_list + def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims): + out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims) + # Note: we use jnp.sum instead of jnp.mean to make the gradient larger + # and prevent them from being clamped to zero + out_sum_list = [jnp.sum(out) for out in out_list] + return jnp.sum(jnp.asarray(out_sum_list)) + + def _primitive_sum_grouped_dense( + self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set + ): +
out = grouped_dense( + x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set ) - bias_list = [] - key = jax.random.PRNGKey(1) - for shape in shape_list: - n = shape[1] - bias = jax.random.uniform(key, n, dtype=dtype) - bias_list.append(bias) - - def ref_func(x_list, kernel_list, bias_list, contracting_dims_list): - out_list = [] - for i in range(len(x_list)): - out_list.append( - dense( - x_list[i], - kernel_list[i], - bias_list[i], - contracting_dims=contracting_dims_list[i], - ) - ) - # Note: we use jnp.sum instead of jnp.mean to make the gradient larger - # and prevent them from being clamp to zero - out_sum_list = [jnp.sum(out) for out in out_list] - return jnp.sum(jnp.asarray(out_sum_list)) + return jnp.sum(jnp.asarray(out)) - def primitive_func(x_list, kernel_list, bias_list, contracting_dims_list): - out_list = grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list) - out_sum_list = [jnp.sum(out) for out in out_list] - return jnp.sum(jnp.asarray(out_sum_list)) + @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) + def test_grouped_dense_grad_fp16(self, dtype, input_shape): + x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( + dtype, + input_shape, + with_bias=True, + ) - value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) - value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2)) + value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) + value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)) - ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func( - x_list, kernel_list, bias_list, contracting_dims_list + ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func( + x, kernel, bias, group_sizes, contracting_dims ) - primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = ( - value_n_grad_primitive_func(x_list, 
kernel_list, bias_list, contracting_dims_list) + prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func( + x, kernel, bias, group_sizes, contracting_dims ) - assert_allclose(primitive_out_mean, ref_out_mean, dtype=dtype) - for i in range(group_size): - assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=dtype) - assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=dtype) - assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=dtype) + assert_allclose(prim_out_sum, ref_out_sum, dtype=dtype) + assert_allclose(prim_dgrad, ref_dgrad, dtype=dtype) + assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype) + assert_allclose(prim_dbias, ref_dbias, dtype=dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes) + @pytest.mark.parametrize( + "fwd_bwd_dtype", + [(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)], + ) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list): - group_size = len(shape_list) - layout_list = ["NN" for _ in range(group_size)] - fwd_dtype, bwd_dtype = fwd_bwd_dtype - if fwd_dtype == jnp.float8_e5m2: - pytest.skip("We never use E5M2 for fwd_dtype in training") - - # Question: should we use different quantizers for different groups? 
- ref_quantizer_set_list = [] - quantizer_set_list = [] - for _ in range(group_size): - ref_quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True - ) - ref_quantizer_set_list.append(ref_quantizer_set) - quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True - ) - quantizer_set_list.append(quantizer_set) + def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + pytest.skip("MXFP8 is not supported in grouped_dense yet") - out_dtype = jnp.bfloat16 - x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input( - out_dtype, shape_list, layout_list + fwd_dtype, bwd_dtype = fwd_bwd_dtype + dtype = jnp.bfloat16 + x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( + dtype, + input_shape, + with_bias=True, ) - bias_list = [] - key = jax.random.PRNGKey(1) - for shape in shape_list: - n = shape[1] - bias = jax.random.uniform(key, n, dtype=out_dtype) - bias_list.append(bias) - - def ref_func(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list): - out_list = [] - for i in range(len(x_list)): - out_list.append( - dense( - x_list[i], - kernel_list[i], - bias_list[i], - contracting_dims=contracting_dims_list[i], - quantizer_set=quantizer_set_list[i], - ) - ) - # Note: we use jnp.sum instead of jnp.mean to make the gradient larger - # and prevent them from being clamp to zero - out_sum_list = [jnp.sum(out) for out in out_list] - return jnp.sum(jnp.asarray(out_sum_list)) - - def primitive_func( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list - ): - out_list = grouped_dense( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list - ) - out_sum_list = [jnp.sum(out) for out in out_list] - return jnp.sum(jnp.asarray(out_sum_list)) - - 
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) - value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2)) - ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func( - x_list, kernel_list, bias_list, contracting_dims_list, ref_quantizer_set_list + quantizer_set = QuantizerFactory.create_set( + scaling_mode=scaling_mode, + fwd_dtype=fwd_dtype, + bwd_dtype=bwd_dtype, + is_2x2x=True, + n_groups=group_sizes.size, ) - primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = ( - value_n_grad_primitive_func( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list - ) + value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) + value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)) + + ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func( + x, + kernel, + bias, + group_sizes, + contracting_dims, + ) + prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func( + x, kernel, bias, group_sizes, contracting_dims, quantizer_set=quantizer_set ) - allclose_dtype = jnp.float8_e4m3fn - if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2: - allclose_dtype = jnp.float8_e5m2 - assert_allclose(primitive_out_mean, ref_out_mean, dtype=allclose_dtype) - for i in range(group_size): - assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype) - assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype) - assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype) -""" + assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype) + assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype) + assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype) + assert_allclose(prim_dbias, ref_dbias, dtype=dtype) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 
4080ae1668..fa8785dcc7 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -525,6 +525,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.B)); const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C)); const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D)); + const auto workspace_alignment = _getAlignment(reinterpret_cast<uintptr_t>(workspace)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( @@ -533,6 +534,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment))); + NVTE_CHECK(workspace_alignment % 256 == 0, + "cuBLAS workspace pointer must be aligned to 256 bytes, got ", workspace_alignment); const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index c38a04f85a..cc02ec3404 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -6,22 +6,28 @@ from typing import Tuple, Sequence, Union, Dict from functools import partial, reduce import operator +import math import jax import jax.numpy as jnp from transformer_engine_jax import get_device_compute_capability from .base import BasePrimitive, register_primitive +from .quantization import grouped_quantize from ..quantize import ( ScaledTensor, + GroupedScaledTensor1x, ScalingMode, Quantizer, + GroupedQuantizer, QuantizeConfig, +
QuantizerSet, + QuantizeLayout, noop_quantizer_set, ) -__all__ = ["gemm"] +__all__ = ["gemm", "grouped_gemm", "is_gemm_with_all_layouts_supported"] num_cublas_streams = 4 @@ -34,6 +40,11 @@ def get_cublas_workspace_size_bytes() -> None: return 4_194_304 +def is_gemm_with_all_layouts_supported() -> bool: + """Return True if device 0 has compute capability >= 100 (Blackwell or newer), else False.""" + return get_device_compute_capability(0) >= 100 + + class GroupedGemmPrimitive(BasePrimitive): """ Primitive for grouped GEMM @@ -41,73 +52,139 @@ class GroupedGemmPrimitive(BasePrimitive): name = "te_grouped_gemm_ffi" multiple_results = True - impl_static_args = () + impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15) inner_primitive = None outer_primitive = None @staticmethod - def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias): + def abstract( + lhs_data_aval, + lhs_scale_inv_aval, + rhs_data_aval, + rhs_scale_inv_aval, + bias_aval, + group_sizes_aval, + group_offset_aval, + *, + M, + N, + K, + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + is_grouped_dense_wgrad, + ): """ + Grouped GEMM operation. + Args: - *args: Size num_gemms * 4 or num_gemms * 5 depending on has_bias: - args[ 0 : num_gemms] are the lhs tensors, - args[ num_gemms : 2*num_gemms] are the rhs tensors, - args[2*num_gemms : 3*num_gemms] are the lhs scale_inv tensors, - args[3*num_gemms : 4*num_gemms] are the rhs scale_inv tensors, - args[4*num_gemms : 5*num_gemms] are the bias tensors if has_bias is True. - num_gemms: Number of GEMM operations to perform. - scaling_mode: Scaling mode for the GEMM operations. - out_dtype: Data type of the output tensors. - has_bias: Boolean indicating if bias tensors are provided.
+ lhs_data: Left-hand side input matrix data, 1D flattened array + lhs_scale_inv: Left-hand side input scale_inv matrix, 1D flattened array + rhs_data: Right-hand side input matrix data, 1D flattened array + rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array + bias: Bias matrix of shape (G, N) + group_sizes: 1D array containing the sizes of each group + group_offset: 1D array containing offsets for each group (not yet implemented) + M: Number of rows in the output matrix + N: Number of columns in the output matrix + K: Number of columns in the left-hand side matrix + lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed + rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed + scaling_mode: Scaling mode for the GEMM operations + out_dtype: Data type of the output tensors + has_bias: Boolean indicating if bias tensors are provided + is_grouped_dense_wgrad: Boolean indicating if this is a grouped dense wgrad operation + where both lhs and rhs are 2D matrices and output is (G, M, N) Returns: - A tuple of ShapedArray objects of size num_gemms+1: - ret[0 : num_gemms]: GEMM output tensors, - ret[num_gemms]:workspace tensor. 
+ A jnp.ndarray containing the result of the grouped GEMM operation """ - del scaling_mode - expected_num_args = 5 * num_gemms if has_bias else 4 * num_gemms - assert ( - len(args) == expected_num_args - ), f"Expected {expected_num_args} input arguments, but got {len(args)}" - A_list = args[0:num_gemms] - B_list = args[num_gemms : 2 * num_gemms] - # A and B have shapes [1, m, k] and [1, n, k] - out_list_aval = tuple( - jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype) - for A, B in zip(A_list, B_list) - ) + del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval + del K, lhs_is_trans, rhs_is_trans, scaling_mode, has_bias + # TODO(Phuong): move some shape checks from Cpp to here workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams + workspace_size += lhs_scale_inv_aval.size + rhs_scale_inv_aval.size workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) - return (*out_list_aval, workspace_aval) + out_shape = (M, N) + if is_grouped_dense_wgrad: + out_shape = (group_sizes_aval.size, M, N) + out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) + return (out_aval, workspace_aval) @staticmethod def outer_abstract(*args, **kwargs): (out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs) - return out_aval + return (out_aval,) @staticmethod - def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias): + def lowering( + ctx, + *args, + M, + N, + K, + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + is_grouped_dense_wgrad, + ): del out_dtype return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)( ctx, *args, - num_gemms=num_gemms, - scaling_mode=int(scaling_mode), + M=M, + N=N, + K=K, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode.value, has_bias=has_bias, + is_grouped_dense_wgrad=is_grouped_dense_wgrad, ) @staticmethod - def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias): + def impl( + lhs_data, + 
lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + group_sizes, + group_offset, + M, + N, + K, + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + is_grouped_dense_wgrad, + ): assert GroupedGemmPrimitive.inner_primitive is not None - out = GroupedGemmPrimitive.inner_primitive.bind( - *args, - num_gemms=num_gemms, - scaling_mode=scaling_mode.value, + (out, _) = GroupedGemmPrimitive.inner_primitive.bind( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + group_sizes, + group_offset, + M=M, + N=N, + K=K, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode, out_dtype=out_dtype, has_bias=has_bias, + is_grouped_dense_wgrad=is_grouped_dense_wgrad, ) - return out[:-1] # out is [out_list, wkspace], only return out_list + return (out,) register_primitive(GroupedGemmPrimitive) @@ -285,7 +362,7 @@ def gemm( lhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor], contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), - quantizer_set: Dict["str", Quantizer] = noop_quantizer_set, + quantizer_set: QuantizerSet = noop_quantizer_set, ) -> jnp.ndarray: """General matrix multiplication with optional quantization. 
@@ -310,130 +387,190 @@ def gemm( return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set) -""" -def swizzled_scale(scales): - # Swizzle the scale tensor for FP8 GEMM - assert scales.ndim == 2 - rows, cols = scales.shape - scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4) - scales = jnp.transpose(scales, (0, 3, 2, 1, 4)) - scales = scales.reshape(rows, cols) - return scales +def grouped_gemm( + lhs: Union[jnp.ndarray, GroupedScaledTensor1x], + rhs: Union[jnp.ndarray, GroupedScaledTensor1x], + group_sizes: jnp.ndarray, + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)), + bias: jnp.ndarray = None, + precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, + preferred_element_type: jnp.dtype = None, + group_offset: jnp.array = None, + quantizer_set: QuantizerSet = noop_quantizer_set, +) -> jnp.ndarray: + """ + Grouped GEMM operation. + + Args: + lhs: Left-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x + rhs: Right-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x + group_sizes: 1D array containing the sizes of each group + contracting_dims: Tuple of two sequences representing the contracting dimensions + bias: Bias tensor of shape (G, N) + precision: JAX precision for the GEMM operation + preferred_element_type: Preferred data type for the output tensor + group_offset: 1D array containing offsets for each group (not yet implemented) + quantizer_set: Set of quantizers for FP8 quantization of the input and output + Returns: + A jnp.ndarray containing the result of the grouped GEMM operation -def grouped_gemm( - lhs_list: List[Union[jnp.ndarray, ScaledTensor]], - rhs_list: List[Union[jnp.ndarray, ScaledTensor]], - contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]], - bias_list: List[jnp.ndarray] = None, -) -> List[jnp.ndarray]: - # Grouped GEMM for multiple pairs of tensors. 
- assert ( - len(lhs_list) == len(rhs_list) == len(contracting_dims_list) - ), "lhs_list, rhs_list, contracting_dims_list must have the same length" - - num_gemms = len(lhs_list) - lhs_list_ = [] - rhs_list_ = [] - lhs_sinv_list_ = [] - rhs_sinv_list_ = [] - bias_list_ = [] - for i in range(num_gemms): - lhs = lhs_list[i] - rhs = rhs_list[i] - contracting_dims = contracting_dims_list[i] - dim_nums = (contracting_dims, ((), ())) - if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): - scaling_mode = lhs.scaling_mode - lhs_shape = lhs.data.shape - rhs_shape = rhs.data.shape - out_dtype = lhs.dq_dtype - # For ScaledTensors and DELAYED_TENSOR_SCALING, need to handle internal data_layout - if lhs.scaling_mode.is_tensor_scaling(): - assert not ( - lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 - ), "FP8 GEMM does not support E5M2 * E5M2" - ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims - if lhs.data_layout == "T": - lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim - if rhs.data_layout == "T": - rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim - dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ()) + Note: + Tested shapes: + lhs: [M, K] or [K, N] + rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K] + """ + # TODO(Phuong): implement the group_offset + group_offset = jnp.zeros((1,), jnp.int32) if group_offset is None else group_offset + + # TODO(Phuong): implement the precision + del precision + + if isinstance(lhs, jnp.ndarray): + assert isinstance(rhs, jnp.ndarray) + out_dtype = lhs.dtype + lhs_shape = lhs.shape + rhs_shape = rhs.shape + lhs_data = lhs + rhs_data = rhs + lhs_scale_inv = rhs_scale_inv = jnp.empty((0,), jnp.float32) + scaling_mode = ScalingMode.NO_SCALING + elif isinstance(lhs, GroupedScaledTensor1x): + assert isinstance(rhs, GroupedScaledTensor1x) + out_dtype = lhs.dq_dtype + lhs_shape = lhs.original_shape + rhs_shape = rhs.original_shape + lhs_data = lhs.data + rhs_data = rhs.data +
lhs_scale_inv = lhs.scale_inv + rhs_scale_inv = rhs.scale_inv + assert lhs.scaling_mode == rhs.scaling_mode + scaling_mode = lhs.scaling_mode + else: + raise TypeError("Unsupported lhs type object!") + + out_dtype = preferred_element_type or out_dtype + + lhs_contract_dim, rhs_contract_dim = contracting_dims + + lhs_is_trans = lhs_contract_dim[-1] != len(lhs_shape) - 1 + lhs_flatten_axis = len(lhs_contract_dim) * (1 if lhs_is_trans else -1) + + # rhs_shape [G, K, N] + rhs_is_trans = rhs_contract_dim[0] != 1 + rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim) + + is_grouped_dense_wgrad = False + if len(rhs_shape) == 2: + rhs_is_trans = rhs_contract_dim[0] != 0 + is_grouped_dense_wgrad = True + + # TODO(Hua): these are for fp16 dense wgrad, any better way to handle this? + if ( + is_grouped_dense_wgrad + and not isinstance(lhs, ScaledTensor) + and not isinstance(rhs, ScaledTensor) + ): + lhs_is_trans = True + rhs_is_trans = False + lhs_flatten_axis = 1 + rhs_flatten_axis = 1 + + if ( + not isinstance(lhs, ScaledTensor) + and not isinstance(rhs, ScaledTensor) + and quantizer_set != noop_quantizer_set + ): + assert isinstance(quantizer_set.x, GroupedQuantizer) + assert type(quantizer_set.x) is type(quantizer_set.kernel) + scaling_mode = quantizer_set.x.scaling_mode + if ( + # TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later + # scaling_mode.is_tensor_scaling() + # and is_gemm_with_all_layouts_supported() + scaling_mode.is_1d_block_scaling() + ): + lhs_is_rowwise = rhs_is_rowwise = True else: - # For jnp.ndarray, only consider contracting_dims, data_layout is always NN - scaling_mode = ScalingMode.NO_SCALING - lhs_shape = lhs.shape - rhs_shape = rhs.shape - out_dtype = lhs.dtype - - (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums - lhs_dn = (lhs_contract, lhs_batch) - rhs_dn = (rhs_contract, rhs_batch) - - lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract) -
rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract) - - # Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy - if scaling_mode == ScalingMode.NO_SCALING: - lhs_3d = _shape_normalization(lhs, lhs_dn) - rhs_3d = _shape_normalization(rhs, rhs_dn) - elif scaling_mode.is_tensor_scaling(): - lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N") - rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T") - elif scaling_mode == ScalingMode.MXFP8_1D_SCALING: - lhs_3d = _shape_normalization(lhs.data, lhs_dn) - rhs_3d = _shape_normalization(rhs.data, rhs_dn) - lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn) - rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn) - # swizzled_scale requires a matrix - lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze()) - rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze()) + lhs_is_rowwise = not lhs_is_trans + rhs_is_rowwise = lhs_is_trans + quantizer_set.x.q_layout = ( + QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE + ) + quantizer_set.kernel.q_layout = ( + QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE + ) + lhs_q = grouped_quantize(lhs, quantizer_set.x, group_sizes, lhs_flatten_axis) + rhs_q = grouped_quantize( + rhs, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis + ) + lhs_data = lhs_q.data + rhs_data = rhs_q.data + lhs_scale_inv = lhs_q.scale_inv + rhs_scale_inv = rhs_q.scale_inv + + assert not ( + lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2 + ), "FP8 GEMM does not support E5M2 * E5M2" + + # Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs + # thus additional transpose is required + # TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later + if scaling_mode.is_tensor_scaling(): # and not is_gemm_with_all_layouts_supported(): + lhs_is_trans = False + rhs_is_trans = True + if isinstance(lhs, ScaledTensor) 
and isinstance(rhs, ScaledTensor): + lhs_layout_is_T = lhs.data_layout == "T" + rhs_layout_is_T = rhs.data_layout == "T" else: - raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}") - - # Note: already_transposed doesn't matter for the output shape - # x.shape = [B, D1, D2] - # contracting_dims = (2, ) --> output.shape = [1, B * D1, D2] - # contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1] - # x.shape = [D1, D2] - # contracting_dims = (1, ) --> output.shape = [1, D1, D2] - # contracting_dims = (0, ) --> output.shape = [1, D2, D1] - bm = lhs_remain_shape[0] - bn = rhs_remain_shape[0] - kl = lhs_3d.shape[-1] - kr = rhs_3d.shape[-1] - assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}" - if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0): - print("grouped_gemm input pair {i} has invalid problem shape for lowering: ") - print(f"m = {bm}, n = {bn}, k = {kl}; ") - print("cuBLAS requires the problem shapes being multiples of 16") - assert (bm % 16 == 0) and (bn % 16 == 0) and (kl % 16 == 0) - - lhs_list_.append(lhs_3d) - rhs_list_.append(rhs_3d) - if scaling_mode == ScalingMode.NO_SCALING: - lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) - rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) - if scaling_mode.is_tensor_scaling(): - lhs_sinv_list_.append(lhs.scale_inv) - rhs_sinv_list_.append(rhs.scale_inv) - if scaling_mode == ScalingMode.MXFP8_1D_SCALING: - lhs_sinv_list_.append(lhs_scale_inv) - rhs_sinv_list_.append(rhs_scale_inv) - if bias_list is not None: - bias_list_.append(bias_list[i]) - - out_list = GroupedGemmPrimitive.outer_primitive.bind( - *lhs_list_, - *rhs_list_, - *lhs_sinv_list_, - *rhs_sinv_list_, - *bias_list_, - num_gemms=num_gemms, - scaling_mode=scaling_mode, + lhs_layout_is_T = lhs_q.data_layout == "T" + rhs_layout_is_T = rhs_q.data_layout == "T" + lhs_ndim = len(lhs_shape) + rhs_ndim = len(rhs_shape) + if lhs_layout_is_T: + lhs_contract_dim = tuple((lhs_ndim - 1 - i) % 
lhs_ndim for i in lhs_contract_dim) + if rhs_layout_is_T: + rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) + lhs_data = _shape_normalization(lhs_data, (lhs_contract_dim, ()), not lhs_layout_is_T) + rhs_data = _shape_normalization(rhs_data, (rhs_contract_dim, ()), rhs_layout_is_T) + + # Calling GroupedGEMM Custom Call + K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim) + K_rhs = math.prod(rhs_shape[i] for i in rhs_contract_dim) + assert K_lhs == K_rhs + M = math.prod(_calculate_remaining_shape(lhs_shape, lhs_contract_dim)) + N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)[1:]) # Exclude G + + if is_grouped_dense_wgrad: + N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)) + else: + assert group_sizes.size == rhs_shape[0] + + assert group_offset.size == 1 + + has_bias = bias is not None + assert not has_bias or bias.shape == (group_sizes.size, N) + bias = jnp.empty((), jnp.float32) if bias is None else bias + + # TODO(Phuong): support MXFP8_1D_SCALING + assert scaling_mode != ScalingMode.MXFP8_1D_SCALING, "MXFP8_1D_SCALING is not yet supported" + + (out,) = GroupedGemmPrimitive.outer_primitive.bind( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + group_sizes, + group_offset, + M=M, + N=N, + K=K_lhs, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode.value, out_dtype=out_dtype, - has_bias=1 if bias_list is not None else 0, + has_bias=has_bias, + is_grouped_dense_wgrad=is_grouped_dense_wgrad, ) - - return out_list -""" + return out diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 7ed0db0298..07d8f81df0 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -47,7 +47,7 @@ from jax.extend import ffi # pylint: disable=ungrouped-imports -__all__ = ["quantize", "quantize_dbias", 
"grouped_quantize"] +__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"] class BaseDBiasQuantizePrimitive(BasePrimitive): @@ -1032,3 +1032,24 @@ def grouped_quantize( group_axis=group_axis, ) return out + + +def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray: + """ + Compute the grouped bias gradient. + + Args: + grad: jnp.ndarray of shape (M, N) + group_sizes: jnp.ndarray of shape(num_groups,), sum(group_sizes) == M + + Returns: + dbias: jnp.ndarray of shape (num_groups, N) + """ + assert grad.ndim == 2, "Input grad must be a 2D tensor." + assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor." + + segment_ids = jnp.repeat(jnp.arange(group_sizes.shape[0]), group_sizes) + grad_fp32 = grad.astype(jnp.float32) + dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0]) + dbias = dbias_fp32.astype(grad.dtype) + return dbias diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 0825bd2f73..d9d519fa00 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -13,43 +13,127 @@ #include "transformer_engine/multi_stream.h" #include "xla/ffi/api/c_api.h" +#define MXFP8_BLOCK_SIZE 32 + namespace transformer_engine { namespace jax { -Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, - Variadic_Result_Type output_list, int64_t num_gemms, - JAXX_Scaling_Mode scaling_mode, int64_t has_bias) { +Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, + Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, + Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, + Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans, + bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, + bool is_grouped_dense_wgrad) { // Notes on matrix layouts and transpose: // 
Jax uses row-major data_layout, on entering this function, each input matrix pair: - // A: row-major with size [m, k], - // B: row-major with size [n, k], needs transpose, + // A: row-major [m, k] for N - [k, m] for T + // B: row-major [k, n] for N - [n, k] for T // on exiting this function, JAX expect: // C: row-major with size [m, n]. // cuBLAS uses column-major data_layout, in this view, each input matrix pair: - // A: column-major with size [k, m], needs transpose, - // B: column-major with size [k, n]. + // A: column-major with size [k, m] for T - [m, k] for N + // B: column-major with size [n, k] for T - [k, n] for N + // // If we call cuBLAS GEMM for A * B, the output will be: // C: column-major with size [m, n] --> row-major with size [n, m]. // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. - if (num_gemms <= 0) { - return ffi_with_cuda_error_check(); + int num_streams = nvte_get_num_compute_streams(); + + // Inputs + auto lhs_ptr = reinterpret_cast(lhs_data.untyped_data()); + auto rhs_ptr = reinterpret_cast(rhs_data.untyped_data()); + auto lhs_sinv_ptr = reinterpret_cast(lhs_sinv.untyped_data()); + auto rhs_sinv_ptr = reinterpret_cast(rhs_sinv.untyped_data()); + auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type()); + auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type()); + auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type()); + auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type()); + auto bias_ptr = has_bias ? 
reinterpret_cast(bias.untyped_data()) : nullptr; + auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); + + NVTE_CHECK(group_sizes.dimensions().size() == 1); + size_t num_gemms = group_sizes.dimensions()[0]; + + // Outputs + auto out_ptr = reinterpret_cast(output->untyped_data()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); + auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); + auto workspace_total_size = product(workspace->dimensions()); + + auto lhs_sinv_size = product(lhs_sinv.dimensions()); + auto rhs_sinv_size = product(rhs_sinv.dimensions()); + auto workspace_size = (workspace_total_size - lhs_sinv_size - rhs_sinv_size) / num_streams; + auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams; + auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size; + + size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); + size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); + size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype); + size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype); + size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype); + size_t out_dtype_bytes = te_dtype_bytes(out_dtype); + + NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); + NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, + "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); + + size_t expected_lhs_size = m * k; + size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n); + size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n); + size_t actual_lhs_size = product(lhs_data.dimensions()); + size_t actual_rhs_size = product(rhs_data.dimensions()); + size_t actual_out_size = product(output->dimensions()); + NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! 
Expect ", + expected_lhs_size, ", got ", actual_lhs_size); + if (!is_grouped_dense_wgrad) { + NVTE_CHECK(expected_rhs_size == actual_rhs_size, + "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k, + " = ", expected_rhs_size, ", got ", actual_rhs_size); + NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! Expect m * n = ", m, + " * ", n, " = ", expected_out_size, ", got ", actual_out_size); + } else { + NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k, + " * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size); + NVTE_CHECK(expected_out_size == actual_out_size, + "Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n, + " = ", expected_out_size, ", got ", actual_out_size); } - size_t expected_input_size = has_bias ? 5 * num_gemms : 4 * num_gemms; - size_t expected_output_size = num_gemms + 1; - size_t actual_input_size = input_list.size(); - size_t actual_output_size = output_list.size(); - NVTE_CHECK(actual_input_size == expected_input_size, "Expected %zu input tensors, got %zu", - expected_input_size, actual_input_size); - NVTE_CHECK(actual_output_size == expected_output_size, "Expected %zu output tensors, got %zu", - expected_output_size, actual_output_size); - - bool trans_lhs = true; - bool trans_rhs = false; + + size_t dim_list_bytes = sizeof(int32_t) * num_gemms; + std::vector dim_list_host(num_gemms); + auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); + cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, + stream); + // Note: This may break cudaGraph. + cudaStreamSynchronize(stream); + size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); + if (!is_grouped_dense_wgrad) { + NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! 
M = ", m, + ", got sum(group_sizes)=", sum_group_sizes); + } else { + NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, + ", got sum(group_sizes)=", sum_group_sizes); + } + auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); bool grad = false; bool accumulate = false; bool use_split_accumulator = false; + auto bias_shape = std::vector{has_bias ? n : 0}; + const int arch = cuda::sm_arch(); + + // It is weird that TE/Common GEMM only use colwise for MXFP8 + const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); + const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; + const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans; + const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans; + + if (arch < 100 && is_fp8_gemm) { + NVTE_CHECK(!lhs_is_trans && rhs_is_trans, + "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ", + "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); + } // These lists are to keep the TensorWrapper objects alive std::vector lhs_wrapper_list; @@ -67,96 +151,83 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, std::vector out_list; std::vector workspace_list; - int lhs_list_offset = 0; - int rhs_list_offset = num_gemms; - int lhs_sinv_list_offset = 2 * num_gemms; - int rhs_sinv_list_offset = 3 * num_gemms; - int bias_list_offset = 4 * num_gemms; - int out_list_offset = 0; - for (int i = 0; i < num_gemms; i++) { - Buffer_Type lhs_i = input_list.get(lhs_list_offset + i).value(); - Buffer_Type rhs_i = input_list.get(rhs_list_offset + i).value(); - Buffer_Type lhs_sinv_i = input_list.get(lhs_sinv_list_offset + i).value(); - Buffer_Type rhs_sinv_i = input_list.get(rhs_sinv_list_offset + i).value(); - Result_Type out_i = output_list.get(out_list_offset + i).value(); - - DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type()); - DType rhs_dtype = 
convert_ffi_datatype_to_te_dtype(rhs_i.element_type()); - DType out_dtype = convert_ffi_datatype_to_te_dtype(out_i->element_type()); - - void *lhs_ptr = lhs_i.untyped_data(); - void *rhs_ptr = rhs_i.untyped_data(); - void *lhs_sinv_ptr = lhs_sinv_i.untyped_data(); - void *rhs_sinv_ptr = rhs_sinv_i.untyped_data(); - void *out_ptr = out_i->untyped_data(); - - // Placeholder for bias since it can be empty - DType bias_dtype = DType::kFloat32; - void *bias_ptr = nullptr; - - auto lhs_shape_ = lhs_i.dimensions(); - auto rhs_shape_ = rhs_i.dimensions(); - - // lhs and rhs has shape [1, m, k] and [1, n, k] - size_t m = lhs_shape_[1]; - size_t n = rhs_shape_[1]; - size_t k = lhs_shape_[2]; - - auto lhs_shape = std::vector{m, k}; - auto rhs_shape = std::vector{n, k}; - auto out_shape = std::vector{n, m}; - auto lhs_sinv_shape = std::vector{1, 1}; - auto rhs_sinv_shape = std::vector{1, 1}; - - if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || - scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || - scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) { - float *amax_dptr = nullptr; - float *scale_dptr = nullptr; - auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr, - reinterpret_cast(lhs_sinv_ptr)); - auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, amax_dptr, scale_dptr, - reinterpret_cast(rhs_sinv_ptr)); - lhs_wrapper_list.push_back(std::move(lhs_i_)); - rhs_wrapper_list.push_back(std::move(rhs_i_)); - } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { - // Note: the scale_inv array should have been swizzled in Python before lowering - auto lhs_sinv_shape_ = lhs_sinv_i.dimensions(); - auto rhs_sinv_shape_ = rhs_sinv_i.dimensions(); - for (int i = 0; i < 2; i++) { - lhs_sinv_shape[i] = lhs_sinv_shape_[i]; - rhs_sinv_shape[i] = rhs_sinv_shape_[i]; - } - - NVTEScalingMode nvte_scaling_mode = get_nvte_scaling_mode(scaling_mode); - TensorWrapper lhs_i_(nvte_scaling_mode); - TensorWrapper 
rhs_i_(nvte_scaling_mode); - lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape); - rhs_i_.set_rowwise_data(rhs_ptr, rhs_dtype, rhs_shape); - lhs_i_.set_rowwise_scale_inv(lhs_sinv_ptr, DType::kFloat8E8M0, lhs_sinv_shape); - rhs_i_.set_rowwise_scale_inv(rhs_sinv_ptr, DType::kFloat8E8M0, rhs_sinv_shape); - - lhs_wrapper_list.push_back(std::move(lhs_i_)); - rhs_wrapper_list.push_back(std::move(rhs_i_)); - } else { - NVTE_ERROR("Unsupported scaling mode: ", static_cast(scaling_mode)); + for (size_t i = 0; i < num_gemms; i++) { + // Matrix data shapes + size_t m_i = dim_list_host[i]; + auto lhs_shape = std::vector{m_i, k}; + auto rhs_shape = std::vector{rhs_is_trans ? n : k, rhs_is_trans ? k : n}; + auto out_shape = std::vector{m_i, n}; + if (is_grouped_dense_wgrad) { + size_t k_i = dim_list_host[i]; + lhs_shape[0] = lhs_is_trans ? k_i : m; + lhs_shape[1] = lhs_is_trans ? m : k_i; + rhs_shape[0] = rhs_is_trans ? n : k_i; + rhs_shape[1] = rhs_is_trans ? k_i : n; + out_shape[0] = m; + out_shape[1] = n; } - auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype); - void *pre_gelu_ptr = nullptr; - auto bias_shape = std::vector{0}; - auto pre_gelu_shape = std::vector{0}; - if (has_bias) { - auto bias_i_get = input_list.get(bias_list_offset + i); - Buffer_Type bias_i = bias_i_get.value(); - bias_ptr = bias_i.untyped_data(); - bias_dtype = convert_ffi_datatype_to_te_dtype(bias_i.element_type()); - bias_shape[0] = n; + // Set matrix data pointers + auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + auto out_i = TensorWrapper(static_cast(out_ptr), out_shape, out_dtype); + void *lhs_vptr = static_cast(lhs_ptr); + void *rhs_vptr = static_cast(rhs_ptr); + if (rhs_use_colwise) // MatA to enter cuBLAS + rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape); + else + rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape); + if (lhs_use_colwise) // MatB to enter cuBLAS + 
lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape); + else + lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape); + + // Scale_inv shapes + auto lhs_sinv_size = std::vector{1}; + auto rhs_sinv_size = std::vector{1}; + if (is_mxfp8_scaling) { + NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)", + MXFP8_BLOCK_SIZE, k); + size_t scale_k = k / MXFP8_BLOCK_SIZE; + lhs_sinv_size[0] = m_i * scale_k; + rhs_sinv_size[0] = n * scale_k; + // Need to add swizzle here } + + // Set scale_inv pointers + void *rhs_sinv_vptr = static_cast(rhs_sinv_ptr); + void *lhs_sinv_vptr = static_cast(lhs_sinv_ptr); + if (is_fp8_gemm) { + if (rhs_use_colwise) // MatA to enter cuBLAS + rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size); + else + rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size); + if (lhs_use_colwise) // MatB to enter cuBLAS + lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size); + else + lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size); + } else { + NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, + "Unsupported scaling mode: ", static_cast(scaling_mode)); + } + auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype); - auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype); + auto pre_gelu_i = TensorWrapper(nullptr, std::vector{0}, out_dtype); + + // Update pointer for the next GEMM pair + lhs_ptr += lhs_shape[0] * lhs_shape[1] * lhs_dtype_bytes; + rhs_ptr += rhs_shape[0] * rhs_shape[1] * rhs_dtype_bytes; + out_ptr += out_shape[0] * out_shape[1] * out_dtype_bytes; + if (is_fp8_gemm) { + lhs_sinv_ptr += lhs_sinv_size[0] * lhs_sinv_dtype_bytes; + rhs_sinv_ptr += rhs_sinv_size[0] * rhs_sinv_dtype_bytes; + } + if (has_bias) bias_ptr += n * bias_dtype_bytes; - out_wrapper_list.push_back(std::move(out_i_)); + // Move objects to the lists to keep them alive + lhs_wrapper_list.push_back(std::move(lhs_i)); + 
rhs_wrapper_list.push_back(std::move(rhs_i)); + out_wrapper_list.push_back(std::move(out_i)); bias_wrapper_list.push_back(std::move(bias_i)); pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i)); @@ -167,11 +238,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, out_list.push_back(out_wrapper_list.back().data()); } - auto workspace_get = output_list.get(num_gemms); - Result_Type workspace = workspace_get.value(); - uint8_t *workspace_ptr = reinterpret_cast(workspace->untyped_data()); - auto num_streams = nvte_get_num_compute_streams(); - size_t workspace_size = workspace->dimensions()[0] / num_streams; auto workspace_shape = std::vector{workspace_size}; for (int i = 0; i < num_streams; i++) { auto workspace_i = @@ -182,7 +248,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, } nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), - pre_gelu_list.data(), num_gemms, trans_lhs, trans_rhs, grad, + pre_gelu_list.data(), num_gemms, rhs_is_trans, lhs_is_trans, grad, workspace_list.data(), accumulate, use_split_accumulator, num_math_sm, stream); @@ -192,11 +258,23 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, FFI::Bind() .Ctx() // stream - .RemainingArgs() // input list - .RemainingRets() // output list - .Attr("num_gemms") + .Arg() // lhs_data + .Arg() // lhs_sinv + .Arg() // rhs_data + .Arg() // rhs_sinv + .Arg() // bias + .Arg() // group_sizes + .Arg() // group_offset + .Ret() // output + .Ret() // workspace + .Attr("M") + .Attr("N") + .Attr("K") + .Attr("lhs_is_trans") + .Attr("rhs_is_trans") .Attr("scaling_mode") - .Attr("has_bias"), + .Attr("has_bias") + .Attr("is_grouped_dense_wgrad"), FFI_CudaGraph_Traits); } // namespace jax diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 55d60e4189..bba101c722 100644 --- 
a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -153,28 +153,28 @@ def _dense_bwd_rule( # GEMM NT # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim - g_constracting_dim = tuple( + g_contracting_dim = tuple( range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) ) # k_non_contracting_dims - k_constracting_dim = tuple( + k_contracting_dim = tuple( dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims ) dgrad = tex.gemm( casted_grad.get_rowwise_tensor(), rowwise_casted_kernel, - (g_constracting_dim, k_constracting_dim), + (g_contracting_dim, k_contracting_dim), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) # GEMM TN # x_non_contracting_dims - g_constracting_dim = x_constracting_dim = tuple( + g_contracting_dim = x_contracting_dim = tuple( range(0, len(x_shape) - len(fwd_x_contracting_dims)) ) wgrad = tex.gemm( - colwise_casted_x, casted_grad.get_colwise_tensor(), (x_constracting_dim, g_constracting_dim) + colwise_casted_x, casted_grad.get_colwise_tensor(), (x_contracting_dim, g_contracting_dim) ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) @@ -184,135 +184,240 @@ def _dense_bwd_rule( _dense.defvjp(_dense_fwd_rule, _dense_bwd_rule) -""" def grouped_dense( - x_list, - kernel_list, - bias_list, - contracting_dims_list, - quantizer_set_list=None, + x: jnp.ndarray, + kernel: jnp.ndarray, + group_sizes: jnp.ndarray, + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)), + bias: jnp.ndarray = None, + precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, + preferred_element_type: jnp.dtype = None, + group_offset: jnp.array = None, + quantizer_set: QuantizerSet = noop_quantizer_set, ): - # Perform grouped_dense layer transformation with optional quantization. + """ + Perform grouped dense (linear) layer transformation with optional quantization. 
- output_list = _grouped_dense( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list + Args: + x: Input tensor of shape (M, K) + kernel: Weight matrix of shape (G, K, N) + group_sizes: 1D array of shape (G,) specifying the size of each group + contracting_dims: Tuple of sequences specifying which dimensions to contract + (currently only supports ((1,), (1,))) + bias: Bias tensor of shape (G, N) + precision: JAX precision for the GEMM operation + preferred_element_type: Preferred data type for the output tensor + group_offset: 1D array containing offsets for each group (not yet implemented) + quantizer_set: Set of quantizers for FP8 quantization of the input and output + + Returns: + A jnp.ndarray containing the result of the grouped linear operation + """ + output = _grouped_dense( + x, + kernel, + group_sizes, + contracting_dims, + bias, + precision, + preferred_element_type, + group_offset, + quantizer_set, ) - return output_list + return output -@partial(jax.custom_vjp, nondiff_argnums=(3,)) -def _grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list): - output_list, _ = _grouped_dense_fwd_rule( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list +@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7)) +def _grouped_dense( + x, + kernel, + group_sizes, + contracting_dims, + bias, + precision, + preferred_element_type, + group_offset, + quantizer_set, +): + output, _ = _grouped_dense_fwd_rule( + x, + kernel, + group_sizes, + contracting_dims, + bias, + precision, + preferred_element_type, + group_offset, + quantizer_set, ) - return output_list + return output def _grouped_dense_fwd_rule( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list + x, + kernel, + group_sizes, + contracting_dims, + bias, + precision, + preferred_element_type, + group_offset, + quantizer_set, ): - use_bias = bias_list is not None - output_list = [] - x_rowwise_list = [] - x_colwise_list 
= [] - kernel_colwise_list = [] - kernel_rowwise_list = [] - x_shape_list = [] - kernel_shape_list = [] - if quantizer_set_list is None: - x_rowwise_list = x_list - x_colwise_list = x_list - kernel_colwise_list = kernel_list - kernel_rowwise_list = kernel_list - x_shape_list = [x.shape for x in x_list] - kernel_shape_list = [kernel.shape for kernel in kernel_list] + use_bias = bias is not None + is_noop_quantizer_set = quantizer_set == noop_quantizer_set + + if is_noop_quantizer_set: + grouped_gemm_x = x + grouped_gemm_kernel = kernel + ctx_x = x + ctx_kernel = kernel + flatten_axis_k = None else: - for i in range(len(x_list)): # pylint: disable=consider-using-enumerate - q_x = tex.quantize(x_list[i], quantizer_set_list[i].x) - q_kernel = tex.quantize(kernel_list[i], quantizer_set_list[i].kernel) - x_rowwise_list.append(q_x.get_rowwise_tensor()) - x_colwise_list.append(q_x.get_colwise_tensor()) - kernel_colwise_list.append(q_kernel.get_colwise_tensor()) - kernel_rowwise_list.append(q_kernel.get_rowwise_tensor()) - x_shape_list.append(x_rowwise_list[-1].data.shape) - kernel_shape_list.append(kernel_rowwise_list[-1].data.shape) - - output_list = tex.grouped_gemm( - x_rowwise_list, kernel_colwise_list, contracting_dims_list, bias_list + x_contracting_dims, k_contracting_dims = contracting_dims + flatten_axis_x = -len(x_contracting_dims) + flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1 # +1 for G axis + + assert x.ndim == 2, "Grouped dense expects a 2D input tensor of shape (M, K)" + assert kernel.ndim == 3, "Grouped dense expects a 3D kernel tensor of shape (G, K, N)" + # Expected k_contracting_dims == (1,), need to tweak it for grouped_gemm FP8 extra transpose + # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? 
+ assert x_contracting_dims == (1,) and k_contracting_dims == (1,), ( + "grouped_dense for FP8 can only handle x_contracting_dims=(1,) " + "and k_contracting_dims=(1,) for now, " + f"got {x_contracting_dims=} and {k_contracting_dims=}" + ) + k_contracting_dims = (0,) + + casted_x = tex.grouped_quantize( + x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x + ) + casted_kernel = tex.grouped_quantize( + kernel, quantizer_set.kernel, flatten_axis=flatten_axis_k + ) + contracting_dims = (x_contracting_dims, k_contracting_dims) + + # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have + # rowwise_casted_x.original_shape == (M, K) + # colwise_casted_kernel.original_shape == (G, N, K) + grouped_gemm_x = casted_x.get_rowwise_tensor() + grouped_gemm_kernel = casted_kernel.get_colwise_tensor() + # TODO(Hua): Shall we give warning/error if not quantizer_set.x.is_2x2x()? + ctx_x = casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None + ctx_kernel = casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None + + output = tex.grouped_gemm( + grouped_gemm_x, + grouped_gemm_kernel, + group_sizes, + contracting_dims, + bias, + precision, + preferred_element_type, + group_offset, ) ctx = ( - x_colwise_list, - kernel_rowwise_list, - x_shape_list, - kernel_shape_list, + group_sizes, + ctx_x, + ctx_kernel, + x.shape, + kernel.shape, use_bias, - quantizer_set_list, + is_noop_quantizer_set, + quantizer_set, + flatten_axis_k, ) - return output_list, ctx + return output, ctx -def _grouped_dense_bwd_rule(contracting_dims_list, ctx, grad_list): +def _grouped_dense_bwd_rule( + contracting_dims, precision, preferred_element_type, group_offset, ctx, grad +): + fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims + ( - colwise_x_list, - rowwise_kernel_list, - x_shape_list, - kernel_shape_list, + group_sizes, + ctx_x, + ctx_kernel, + x_shape, + kernel_shape, use_bias, - quantizer_set_list, + is_noop_quantizer_set, + 
quantizer_set, + flatten_axis_k, ) = ctx - group_size = len(grad_list) - dbias_list = [] - grad_rowwise_list = [] - grad_colwise_list = [] - dgrad_contracting_dims_list = [] - wgrad_contracting_dims_list = [] - for i in range(group_size): - grad = grad_list[i] - x_shape = x_shape_list[i] - kernel_shape = kernel_shape_list[i] - fwd_contracting_dims = contracting_dims_list[i] - - if quantizer_set_list is None: - casted_grad = grad - dbias = tex.quantization._jax_dbias(grad) - grad_rowwise_list.append(grad) - grad_colwise_list.append(grad) - else: - quantizer_set = quantizer_set_list[i] - casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad - ) - grad_rowwise_list.append(casted_grad.get_rowwise_tensor()) - grad_colwise_list.append(casted_grad.get_colwise_tensor()) - dbias_list.append(dbias) - - # GEMM NT - fwd_x_contracting_dims, fwd_k_contracting_dims = fwd_contracting_dims + if is_noop_quantizer_set: + # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?) 
+ # g_contracting_dim = (1, ) + # k_contracting_dim = (2, ) g_contracting_dim = tuple( - range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) + range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) ) k_contracting_dim = tuple( - dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims + dim for dim in range(1, len(kernel_shape)) if dim not in fwd_k_contracting_dims ) dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim) - dgrad_contracting_dims_list.append(dgrad_contracting_dims) + dgrad_grad = grad + dgrad_kernel_T = ctx_kernel - # GEMM TN + # g_contracting_dim = (0, ) + # x_contracting_dim = (0, ) g_contracting_dim = x_contracting_dim = tuple( range(0, len(x_shape) - len(fwd_x_contracting_dims)) ) wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) - wgrad_contracting_dims_list.append(wgrad_contracting_dims) + wgrad_x_T = ctx_x + wgrad_grad = grad + else: + casted_grad = tex.grouped_quantize( + grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k + ) + + # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we need to use + # g_contracting_dim = (1,) and k_contracting_dim = (2,) to make it work after the + # extra transpose for FP8 in grouped_gemm + # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? + g_contracting_dim = (1,) + k_contracting_dim = (2,) + dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim) + dgrad_grad = casted_grad.get_rowwise_tensor() + dgrad_kernel_T = ctx_kernel + + # We need to use g_contracting_dim = (0,) and x_contracting_dim = (1,) to make it work + # after the extra transpose for FP8 in grouped_gemm + # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? 
+ g_contracting_dim = (0,) + x_contracting_dim = (1,) + wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) + wgrad_x_T = ctx_x + wgrad_grad = casted_grad.get_colwise_tensor() + + dgrad = tex.grouped_gemm( + dgrad_grad, + dgrad_kernel_T, + group_sizes, + dgrad_contracting_dims, + precision=precision, + preferred_element_type=preferred_element_type, + group_offset=group_offset, + ) - dgrad_list = tex.grouped_gemm( - grad_rowwise_list, rowwise_kernel_list, dgrad_contracting_dims_list + wgrad = tex.grouped_gemm( + wgrad_x_T, + wgrad_grad, + group_sizes, + wgrad_contracting_dims, + precision=precision, + preferred_element_type=preferred_element_type, + group_offset=group_offset, ) - wgrad_list = tex.grouped_gemm(colwise_x_list, grad_colwise_list, wgrad_contracting_dims_list) - return dgrad_list, wgrad_list, dbias_list, quantizer_set_list + group_sizes_grad = None + dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None + + return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set _grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule) -""" diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 45ec4fd1fa..06a2562fb1 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -127,14 +127,16 @@ def _dequantize_func(data, scale_inv, dq_dtype, scaling_mode, is_colwise, flatte def dequantize(scaled_tensor): """Dequantize a tensor using block scaling. - This function dequantizes a tensor that was quantized using block scaling - by applying the inverse scaling factor to each block of data. 
- Args: - scaled_tensor: The quantized tensor to dequantize + data: The quantized tensor data + scale_inv: The inverse scaling factors + dq_dtype: The data type for dequantized values + scaling_mode: The scaling mode used for quantization + is_colwise: Whether the scaling is column-wise + flatten_axis: The axis along which the tensor could be flattened to 2D Returns: - The dequantized tensor in the specified data type + The dequantized tensor """ return BlockScaleDequantizer._dequantize_func( scaled_tensor.data, From 4d4f1edb6aeb4a17206276f3421126bc58ae67e7 Mon Sep 17 00:00:00 2001 From: Selvaraj Anandaraj Date: Thu, 12 Jun 2025 10:18:58 -0700 Subject: [PATCH 09/39] Cpu reload double buffer (#1695) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Added double buffering support initial commit Signed-off-by: Selvaraj Anandaraj * Fixed bugs Signed-off-by: Selvaraj Anandaraj * Make only one double buffer creation Signed-off-by: Selvaraj Anandaraj * Fixed bug Signed-off-by: Selvaraj Anandaraj * Fixed typo Signed-off-by: Selvaraj Anandaraj * Fixed flag setting Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Merge conflict Signed-off-by: Selvaraj Anandaraj * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * lint fix Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Signed-off-by: Pawel Gadzinski Co-authored-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: 
Paweł Gadziński <62263673+pggPL@users.noreply.github.com> Co-authored-by: Pawel Gadzinski Co-authored-by: Kirthi Shankar Sivamani --- tests/pytorch/test_cpu_offloading.py | 5 ++ transformer_engine/pytorch/cpu_offload.py | 64 +++++++++++++++++++++-- transformer_engine/pytorch/utils.py | 8 +++ 3 files changed, 72 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index ab4b7634b8..87494f3c21 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -97,6 +97,8 @@ def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload max_mem_used = torch.cuda.memory_allocated() / (1024**2) torch.cuda.synchronize() + tensor.sum().backward() + return max_mem_used @@ -115,6 +117,9 @@ def test_cpu_offload(fp8_recipe, model_key) -> None: the difference being the size of the FP8 cache that is not offloaded to the CPU. We also expect this memory consumption to be smaller than in scenario (1). """ + import gc + + gc.collect() model_cls = model_types[model_key] models_list = [model_cls() for _ in range(NUM_LAYERS)] diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index b0cca99f84..4262c2103d 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -253,12 +253,20 @@ def offload(src_tensor, pin_memory=True): return state @staticmethod - def reload(state, non_blocking=None): + def reload(state, non_blocking=None, copy_buffer=None): """Reload.""" dev, cpu_backup = state if non_blocking is None: non_blocking = cpu_backup.is_pinned() - return cpu_backup.to(dev, non_blocking=non_blocking) + + if copy_buffer is None: + return cpu_backup.to(dev, non_blocking=non_blocking) + + assert cpu_backup.size() == copy_buffer.size(), "Can't copy two buffers of different sizes!" 
+ + copy_buffer.copy_(cpu_backup, non_blocking=non_blocking) + + return copy_buffer def tensor_push(self, tensor: torch.Tensor, **kwargs): """Tensor push.""" @@ -300,6 +308,7 @@ def __init__( num_offload_group, # must be <= actual number of groups (number of commits) num_model_group, tensor_need_offloading_checker=(lambda t: True), + double_buffering=False, debug=False, ) -> None: super().__init__( @@ -320,6 +329,11 @@ def __init__( # Core data structure that decides the window for offloading self.layer_window_map = {} + # Data structures for double buffered reloading + self.double_buffering = double_buffering + self.reload_double_buffer = [[], []] + self.double_buffer_created = False + # Logic to make offloading load balance across computation # for optimal CPU/GPU interconnect usage constant = 0 @@ -413,8 +427,10 @@ def tensor_pop(self, tensor_tag, **kwargs): self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor) tensor = self.fp8_tensor_object_map.pop(tensor_tag) - self.tensor_tag_to_buf.pop(tensor_tag, None) + if self.double_buffering: + tensor.do_not_clear = True + self.tensor_tag_to_buf.pop(tensor_tag, None) # the tensor should have been copied back in on_group_commit_backward() # which invokes bulk_reload_group. 
assert not isinstance(tensor, tuple) @@ -466,6 +482,20 @@ def synchronize_on_group_commit_forward(self, current_group): # the first compute completion if current_group == 0: self.d2h_stream.wait_stream(torch.cuda.current_stream()) + + if not self.double_buffer_created: + # Creating the first copy of double buffer for tensors that are offloaded + for tensor_tag, buf in self.tensor_tag_to_buf.items(): + if isinstance(buf, list): + for b in buf: + self.reload_double_buffer[0].append( + torch.empty_like(b) if self.double_buffering else None + ) + else: + self.reload_double_buffer[0].append( + torch.empty_like(buf) if self.double_buffering else None + ) + self.bulk_offload_group(current_group) # Window map data structure helps us synchronize based on number @@ -495,6 +525,15 @@ def synchronize_on_group_commit_forward(self, current_group): # Increment the offload group count to keep track self.offloaded_group_count += 1 + if not self.double_buffer_created: + # Creating second copy of double buffer for tensors that are offloaded + if current_group == (self.num_layers - 1): + for buf in self.reload_double_buffer[0]: + self.reload_double_buffer[1].append( + torch.empty_like(buf) if self.double_buffering else None + ) + self.double_buffer_created = True + def on_group_commit_forward(self): """This function will cause host device synchronization""" # handle synchronization events @@ -506,21 +545,32 @@ def bulk_reload_group(self, group_to_reload): """Bulk reload group.""" assert group_to_reload < self.num_offload_group + buffer_idx = 0 + double_buffer_idx = group_to_reload % 2 + with torch.cuda.stream(self.h2d_stream): # move back tensors for tensor_label, state in self.tensor_tag_to_state.items(): group_id, _ = tensor_label if group_id == group_to_reload: if isinstance(state, tuple): - recovered_tensor = SynchronizedGroupOffloadHandler.reload(state) + recovered_tensor = SynchronizedGroupOffloadHandler.reload( + state, True, 
self.reload_double_buffer[double_buffer_idx][buffer_idx] + ) + buffer_idx = buffer_idx + 1 self.tensor_tag_to_state[tensor_label] = recovered_tensor elif isinstance(state, list): tensor_list = [] for state_tuple in state: if isinstance(state_tuple, tuple): tensor_list.append( - SynchronizedGroupOffloadHandler.reload(state_tuple) + SynchronizedGroupOffloadHandler.reload( + state_tuple, + True, + self.reload_double_buffer[double_buffer_idx][buffer_idx], + ) ) + buffer_idx = buffer_idx + 1 else: tensor_list.append(state_tuple) @@ -574,6 +624,7 @@ def get_cpu_offload_context( model_layers: int = 1, offload_activations: bool = True, offload_weights: bool = False, + double_buffering: bool = False, ): """ This function returns the CPU Offload context and the synchronizer function that needs to be @@ -602,6 +653,8 @@ def get_cpu_offload_context( When set to `True`, offloads the activations for the TE layer. offload_weights: bool, default = `True` When set to `True`, offloads the weights for the TE layer. + double_buffering: bool, default = `False` + When set to `True`, uses double buffering for offloading. """ @@ -633,6 +686,7 @@ def tensor_need_offloading_checker_activations(tensor): num_offload_group=num_layers, num_model_group=model_layers, tensor_need_offloading_checker=tensor_need_offloading_checker, + double_buffering=double_buffering, ) def group_prefetch_offload_commit_async(tensor): diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 3abebdf1e4..e66477476f 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -37,8 +37,16 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: Must be used carefully. 
""" + for t in tensors: if t is not None: + # Workaround for double buffering in cpu offload + if hasattr(t, "do_not_clear"): + continue + if hasattr(t, "get_data_tensors"): + if any(hasattr(tensor, "do_not_clear") for tensor in t.get_data_tensors()): + continue + if hasattr(t, "clear"): t.clear() else: From c3b7c2aee161ccfebb9b37eea9b0afdd10e972c1 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 12 Jun 2025 13:30:56 -0400 Subject: [PATCH 10/39] Revert "[JAX] GroupedDense v.2 without dynamic shape" (#1874) Revert "[JAX] GroupedDense v.2 without dynamic shape (#1721)" This reverts commit 5d01ef2113d154846883623c947224e51bf75ac6. Signed-off-by: Phuong Nguyen --- tests/jax/test_custom_call_compute.py | 352 ++++++++------ .../common/gemm/cublaslt_gemm.cu | 3 - transformer_engine/jax/cpp_extensions/gemm.py | 459 ++++++------------ .../jax/cpp_extensions/quantization.py | 23 +- .../jax/csrc/extensions/gemm.cpp | 308 +++++------- transformer_engine/jax/dense.py | 301 ++++-------- .../jax/quantize/dequantizer.py | 12 +- 7 files changed, 577 insertions(+), 881 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 9ff0c11757..25a463aeaa 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -40,11 +40,10 @@ ScalingMode, QuantizerFactory, QuantizeLayout, - noop_quantizer_set, ) from transformer_engine.jax.quantize import helper from transformer_engine.jax.activation import activation -from transformer_engine.jax.dense import dense, grouped_dense +from transformer_engine.jax.dense import dense from transformer_engine.jax.layernorm_dense import layernorm_dense GEMM_CASES = [ @@ -1205,6 +1204,24 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype) +# This function is modified from transformer_engine/jax/cpp_extensions/gemm.py::_jax_gemm() +def _quantize_gemm_pair(lhs, rhs, contracting_dims, lhs_quantizer, 
rhs_quantizer): + ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims + lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1 + rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1 + lhs_q = lhs_quantizer.quantize( + lhs, + is_rowwise=lhs_is_rowwise, + is_colwise=not lhs_is_rowwise, + ) + rhs_q = rhs_quantizer.quantize( + rhs, + is_rowwise=rhs_is_rowwise, + is_colwise=not rhs_is_rowwise, + ) + return lhs_q, rhs_q + + # E5M2 * E5M2 is not supported fwd_bwd_dtypes = [ [jnp.float8_e4m3fn, jnp.float8_e4m3fn], @@ -1212,194 +1229,219 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): [jnp.float8_e5m2, jnp.float8_e4m3fn], ] -GROUPED_DENSE_INPUT_SHAPES = [ - # (n_groups, m, n, k), the actual m will be multiplied by 32 - (5, 32, 128, 64), # Test the case where n_groups is not a multiple of 4 - (8, 64, 32, 128), - (8, 64, 128, 256), -] - - -@pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES) +""" +@pytest_parametrize_wrapper( + "shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]] +) class TestGroupedDense: - def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims): - lhs_contract_dim, _ = contracting_dims - assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3 - if bias is None: - bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype) - else: - assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2]) - remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop() - lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis) - rhs = jnp.split(rhs, rhs.shape[0], axis=0) - bias = jnp.split(bias, bias.shape[0], axis=0) - ref_out = [] - dim_num = (contracting_dims, ((), ())) - for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias): - out_i = jax.lax.dot_general(lhs_i, rhs_i, dim_num) + jnp.expand_dims(bias_i, axis=0) - ref_out.append(jnp.squeeze(out_i)) - return ref_out - - def _generate_grouped_dense_input(self, dtype, input_shape, 
data_layout="NN", with_bias=False): + def _ref_grouped_gemm_with_jnp_dot(self, lhs_list, rhs_list, contracting_dims_list): + ref_out_list = [] + for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list): + dim_nums = (contracting_dims, ((), ())) + ref_out_list.append(jax.lax.dot_general(lhs, rhs, dim_nums)) + return ref_out_list + + def _generate_grouped_gemm_input(self, dtype, shape_list, layout_list): key = jax.random.PRNGKey(0) - subkeys = jax.random.split(key, 4) - n_groups, m, n, k = input_shape - - group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) - group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])]) - group_sizes = jnp.diff(group_sizes) - assert group_sizes.sum() == m - - # *32 to make sure that input shape works for MXFP8 - group_sizes = group_sizes * 32 - m = m * 32 - - lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m) - rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k) - bias_shape = (n_groups, n) - - lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype) - rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype) - bias = jax.random.uniform(subkeys[3], bias_shape, dtype=dtype) if with_bias else None - - lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,) - rhs_contracting_dim = (1,) if data_layout[1] == "N" else (2,) - contracting_dims = (lhs_contracting_dim, rhs_contracting_dim) + subkeys = jax.random.split(key, len(shape_list) * 2) + + lhs_list, rhs_list, contracting_dims_list = [], [], [] + for i, ((m, n, k), data_layout) in enumerate(zip(shape_list, layout_list)): + lhs = jax.random.uniform( + subkeys[2 * i], + (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m), + dtype=dtype, + ) + rhs = jax.random.uniform( + subkeys[2 * i + 1], + (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k), + dtype=dtype, + ) + lhs_contracting_dim = (1,) if 
data_layout[0] == "N" else (0,) + rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,) + contracting_dims = (lhs_contracting_dim, rhs_contracting_dim) - return lhs, rhs, group_sizes, contracting_dims, bias + lhs_list.append(lhs) + rhs_list.append(rhs) + contracting_dims_list.append(contracting_dims) - def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype): - assert out.dtype == ref_list[0].dtype - out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0) - for i in range(len(ref_list)): - assert_allclose(out_list[i], ref_list[i], dtype=dtype) + return lhs_list, rhs_list, contracting_dims_list @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) - @pytest_parametrize_wrapper("layout", ["NN"]) - def test_grouped_gemm_fp16(self, dtype, input_shape, layout): - lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( - dtype, input_shape, layout + @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]]) + def test_grouped_gemm_fp16(self, dtype, shape_list, layout_list): + lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input( + dtype, shape_list, layout_list ) - ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) - prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims) - self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) + ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list) + primitive_out = tex.grouped_gemm(lhs_list, rhs_list, contracting_dims_list) + for i in range(len(shape_list)): + assert_allclose(primitive_out[i], ref_out[i], dtype=dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - @pytest_parametrize_wrapper("layout", ["NN"]) - def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout): - 
if scaling_mode == ScalingMode.MXFP8_1D_SCALING: - pytest.skip("MXFP8 is not supported in grouped_gemm yet") - + @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]]) + def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list, layout_list): fwd_dtype, bwd_dtype = fwd_bwd_dtype quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, - fwd_dtype=fwd_dtype, - bwd_dtype=bwd_dtype, - is_2x2x=False, - n_groups=input_shape[0], + scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=False ) - # quantizer_set.{x, kernel} has fwd_dtype, while quantizer_set.grad has bwd_dtype - # We want to test E4M3 * E5M2, manually set the quantizer_set.kernel.q_dtype to bwd_dtype - quantizer_set.kernel.q_dtype = bwd_dtype - for quantizer in quantizer_set.kernel.quantizers: - quantizer.q_dtype = bwd_dtype - out_dtype = jnp.bfloat16 - lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( - out_dtype, input_shape, layout - ) - ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) - prim_out = tex.grouped_gemm( - lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set + lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input( + out_dtype, shape_list, layout_list ) + q_lhs_list = [] + q_rhs_list = [] + for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list): + # quantizer_set.x and quantizer_set.kernel have the same q_dtype, we want to + # test the case where lhs and rhs have different q_dtypes + q_lhs, q_rhs = _quantize_gemm_pair( + lhs, rhs, contracting_dims, quantizer_set.x, quantizer_set.dgrad + ) + q_lhs_list.append(q_lhs) + q_rhs_list.append(q_rhs) + + ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list) + primitive_out = tex.grouped_gemm(q_lhs_list, q_rhs_list, contracting_dims_list) allclose_dtype = jnp.float8_e4m3fn - if jnp.float8_e5m2 in fwd_bwd_dtype: + if 
fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2: allclose_dtype = jnp.float8_e5m2 + for i in range(len(shape_list)): + assert_allclose(primitive_out[i], ref_out[i], dtype=allclose_dtype) - self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, allclose_dtype) - - def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims): - out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims) - # Note: we use jnp.sum instead of jnp.mean to make the gradient larger - # and prevent them from being clamp to zero - out_sum_list = [jnp.sum(out) for out in out_list] - return jnp.sum(jnp.asarray(out_sum_list)) + @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) + def test_grouped_dense_grad_fp16(self, dtype, shape_list): + group_size = len(shape_list) + layout_list = ["NN" for _ in range(group_size)] - def _primitive_sum_grouped_dense( - self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set - ): - out = grouped_dense( - x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set + x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input( + dtype, shape_list, layout_list ) - return jnp.sum(jnp.asarray(out)) + bias_list = [] + key = jax.random.PRNGKey(1) + for shape in shape_list: + n = shape[1] + bias = jax.random.uniform(key, n, dtype=dtype) + bias_list.append(bias) + + def ref_func(x_list, kernel_list, bias_list, contracting_dims_list): + out_list = [] + for i in range(len(x_list)): + out_list.append( + dense( + x_list[i], + kernel_list[i], + bias_list[i], + contracting_dims=contracting_dims_list[i], + ) + ) + # Note: we use jnp.sum instead of jnp.mean to make the gradient larger + # and prevent them from being clamp to zero + out_sum_list = [jnp.sum(out) for out in out_list] + return jnp.sum(jnp.asarray(out_sum_list)) - @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) - def 
test_grouped_dense_grad_fp16(self, dtype, input_shape): - x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( - dtype, - input_shape, - with_bias=True, - ) + def primitive_func(x_list, kernel_list, bias_list, contracting_dims_list): + out_list = grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list) + out_sum_list = [jnp.sum(out) for out in out_list] + return jnp.sum(jnp.asarray(out_sum_list)) - value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) - value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)) + value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) + value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2)) - ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func( - x, kernel, bias, group_sizes, contracting_dims + ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func( + x_list, kernel_list, bias_list, contracting_dims_list ) - prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func( - x, kernel, bias, group_sizes, contracting_dims + primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = ( + value_n_grad_primitive_func(x_list, kernel_list, bias_list, contracting_dims_list) ) - assert_allclose(prim_out_sum, ref_out_sum, dtype=dtype) - assert_allclose(prim_dgrad, ref_dgrad, dtype=dtype) - assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype) - assert_allclose(prim_dbias, ref_dbias, dtype=dtype) + assert_allclose(primitive_out_mean, ref_out_mean, dtype=dtype) + for i in range(group_size): + assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=dtype) + assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=dtype) + assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize( - "fwd_bwd_dtype", - [(jnp.float8_e4m3fn, 
jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)], - ) + @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): - if scaling_mode == ScalingMode.MXFP8_1D_SCALING: - pytest.skip("MXFP8 is not supported in grouped_dense yet") - + def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list): + group_size = len(shape_list) + layout_list = ["NN" for _ in range(group_size)] fwd_dtype, bwd_dtype = fwd_bwd_dtype - dtype = jnp.bfloat16 - x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( - dtype, - input_shape, - with_bias=True, - ) + if fwd_dtype == jnp.float8_e5m2: + pytest.skip("We never use E5M2 for fwd_dtype in training") + + # Question: should we use different quantizers for different groups? + ref_quantizer_set_list = [] + quantizer_set_list = [] + for _ in range(group_size): + ref_quantizer_set = QuantizerFactory.create_set( + scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True + ) + ref_quantizer_set_list.append(ref_quantizer_set) + quantizer_set = QuantizerFactory.create_set( + scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True + ) + quantizer_set_list.append(quantizer_set) - quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, - fwd_dtype=fwd_dtype, - bwd_dtype=bwd_dtype, - is_2x2x=True, - n_groups=group_sizes.size, + out_dtype = jnp.bfloat16 + x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input( + out_dtype, shape_list, layout_list ) - value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) - value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)) - - ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func( - x, - kernel, - bias, - group_sizes, - contracting_dims, + bias_list = 
[] + key = jax.random.PRNGKey(1) + for shape in shape_list: + n = shape[1] + bias = jax.random.uniform(key, n, dtype=out_dtype) + bias_list.append(bias) + + def ref_func(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list): + out_list = [] + for i in range(len(x_list)): + out_list.append( + dense( + x_list[i], + kernel_list[i], + bias_list[i], + contracting_dims=contracting_dims_list[i], + quantizer_set=quantizer_set_list[i], + ) + ) + # Note: we use jnp.sum instead of jnp.mean to make the gradient larger + # and prevent them from being clamp to zero + out_sum_list = [jnp.sum(out) for out in out_list] + return jnp.sum(jnp.asarray(out_sum_list)) + + def primitive_func( + x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list + ): + out_list = grouped_dense( + x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list + ) + out_sum_list = [jnp.sum(out) for out in out_list] + return jnp.sum(jnp.asarray(out_sum_list)) + + value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) + value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2)) + + ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func( + x_list, kernel_list, bias_list, contracting_dims_list, ref_quantizer_set_list ) - prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func( - x, kernel, bias, group_sizes, contracting_dims, quantizer_set=quantizer_set + primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = ( + value_n_grad_primitive_func( + x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list + ) ) - assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype) - assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype) - assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype) - assert_allclose(prim_dbias, ref_dbias, dtype=dtype) + allclose_dtype = jnp.float8_e4m3fn + if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2: + 
allclose_dtype = jnp.float8_e5m2 + assert_allclose(primitive_out_mean, ref_out_mean, dtype=allclose_dtype) + for i in range(group_size): + assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype) + assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype) + assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype) +""" diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index fa8785dcc7..4080ae1668 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -525,7 +525,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const auto B_alignment = _getAlignment(reinterpret_cast(param.B)); const auto C_alignment = _getAlignment(reinterpret_cast(C)); const auto D_alignment = _getAlignment(reinterpret_cast(D)); - const auto workspace_alignment = _getAlignment(reinterpret_cast(workspace)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( @@ -534,8 +533,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment))); - NVTE_CHECK(workspace_alignment % 256 == 0, - "cuBLAS workspace pointer must be aligned to 256 bytes, got ", workspace_alignment); const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index cc02ec3404..c38a04f85a 100644 --- 
a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -6,28 +6,22 @@ from typing import Tuple, Sequence, Union, Dict from functools import partial, reduce import operator -import math import jax import jax.numpy as jnp from transformer_engine_jax import get_device_compute_capability from .base import BasePrimitive, register_primitive -from .quantization import grouped_quantize from ..quantize import ( ScaledTensor, - GroupedScaledTensor1x, ScalingMode, Quantizer, - GroupedQuantizer, QuantizeConfig, - QuantizerSet, - QuantizeLayout, noop_quantizer_set, ) -__all__ = ["gemm", "grouped_gemm", "is_gemm_with_all_layouts_supported"] +__all__ = ["gemm"] num_cublas_streams = 4 @@ -40,11 +34,6 @@ def get_cublas_workspace_size_bytes() -> None: return 4_194_304 -def is_gemm_with_all_layouts_supported() -> False: - """Return True if using blackwell, False otherwise.""" - return get_device_compute_capability(0) >= 100 - - class GroupedGemmPrimitive(BasePrimitive): """ Primitive for grouped GEMM @@ -52,139 +41,73 @@ class GroupedGemmPrimitive(BasePrimitive): name = "te_grouped_gemm_ffi" multiple_results = True - impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15) + impl_static_args = () inner_primitive = None outer_primitive = None @staticmethod - def abstract( - lhs_data_aval, - lhs_scale_inv_aval, - rhs_data_aval, - rhs_scale_inv_aval, - bias_aval, - group_sizes_aval, - group_offset_aval, - *, - M, - N, - K, - lhs_is_trans, - rhs_is_trans, - scaling_mode, - out_dtype, - has_bias, - is_grouped_dense_wgrad, - ): + def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias): """ - Grouped GEMM operation. 
- Args: - lhs_data: Left-hand side input matrix data, 1D flattened array - lhs_scale_inv: Left-hand side input scale_inv matrix, 1D flattened array - rhs_data: Right-hand side input matrix data, 1D flattened array - rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array - bias: Bias matrix of shape (G, N) - group_sizes: 1D array containing the sizes of each group - group_offset: 1D array containing offsets for each group (not yet implemented) - M: Number of rows in the output matrix - N: Number of columns in the output matrix - K: Number of columns in the left-hand side matrix - lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed - rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed - scaling_mode: Scaling mode for the GEMM operations - out_dtype: Data type of the output tensors - has_bias: Boolean indicating if bias tensors are provided - is_grouped_dense_wgrad: Boolean indicating if this is a grouped dense wgrad operation - where both lhs and rhs are 2D matrices and output is (G, M, N) + *args: Size num_gemms * 4 or num_gemms * 5 depending on has_bias: + args[ 0 : num_gemms] are the lhs tensors, + args[ num_gemms : 2*num_gemms] are the rhs tensors, + args[2*num_gemms : 3*num_gemms] are the lhs scale_inv tensors, + args[3*num_gemms : 4*num_gemms] are the rhs scale_inv tensors, + args[4*num_gemms : 5*num_gemms] are the bias tensors if has_bias is True. + num_gemms: Number of GEMM operations to perform. + scaling_mode: Scaling mode for the GEMM operations. + out_dtype: Data type of the output tensors. + has_bias: Boolean indicating if bias tensors are provided. Returns: - A jnp.ndarray containing the result of the grouped GEMM operation + A tuple of ShapedArray objects of size num_gemms+1: + ret[0 : num_gemms]: GEMM output tensors, + ret[num_gemms]:workspace tensor. 
""" - del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval - del K, lhs_is_trans, rhs_is_trans, scaling_mode, has_bias - # TODO(Phuong): move some shape checks from Cpp to here + del scaling_mode + expected_num_args = 5 * num_gemms if has_bias else 4 * num_gemms + assert ( + len(args) == expected_num_args + ), f"Expected {expected_num_args} input arguments, but got {len(args)}" + A_list = args[0:num_gemms] + B_list = args[num_gemms : 2 * num_gemms] + # A and B have shapes [1, m, k] and [1, n, k] + out_list_aval = tuple( + jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype) + for A, B in zip(A_list, B_list) + ) workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams - workspace_size += lhs_scale_inv_aval.size + rhs_scale_inv_aval.size workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) - out_shape = (M, N) - if is_grouped_dense_wgrad: - out_shape = (group_sizes_aval.size, M, N) - out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) - return (out_aval, workspace_aval) + return (*out_list_aval, workspace_aval) @staticmethod def outer_abstract(*args, **kwargs): (out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs) - return (out_aval,) + return out_aval @staticmethod - def lowering( - ctx, - *args, - M, - N, - K, - lhs_is_trans, - rhs_is_trans, - scaling_mode, - out_dtype, - has_bias, - is_grouped_dense_wgrad, - ): + def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias): del out_dtype return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)( ctx, *args, - M=M, - N=N, - K=K, - lhs_is_trans=lhs_is_trans, - rhs_is_trans=rhs_is_trans, - scaling_mode=scaling_mode.value, + num_gemms=num_gemms, + scaling_mode=int(scaling_mode), has_bias=has_bias, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, ) @staticmethod - def impl( - lhs_data, - lhs_scale_inv, - rhs_data, - rhs_scale_inv, - bias, - group_sizes, - group_offset, - M, - N, - K, - lhs_is_trans, - rhs_is_trans, - 
scaling_mode, - out_dtype, - has_bias, - is_grouped_dense_wgrad, - ): + def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias): assert GroupedGemmPrimitive.inner_primitive is not None - (out, _) = GroupedGemmPrimitive.inner_primitive.bind( - lhs_data, - lhs_scale_inv, - rhs_data, - rhs_scale_inv, - bias, - group_sizes, - group_offset, - M=M, - N=N, - K=K, - lhs_is_trans=lhs_is_trans, - rhs_is_trans=rhs_is_trans, - scaling_mode=scaling_mode, + out = GroupedGemmPrimitive.inner_primitive.bind( + *args, + num_gemms=num_gemms, + scaling_mode=scaling_mode.value, out_dtype=out_dtype, has_bias=has_bias, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, ) - return (out,) + return out[:-1] # out is [out_list, wkspace], only return out_list register_primitive(GroupedGemmPrimitive) @@ -362,7 +285,7 @@ def gemm( lhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor], contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), - quantizer_set: QuantizerSet = noop_quantizer_set, + quantizer_set: Dict["str", Quantizer] = noop_quantizer_set, ) -> jnp.ndarray: """General matrix multiplication with optional quantization. @@ -387,190 +310,130 @@ def gemm( return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set) -def grouped_gemm( - lhs: Union[jnp.ndarray, GroupedScaledTensor1x], - rhs: Union[jnp.ndarray, GroupedScaledTensor1x], - group_sizes: jnp.ndarray, - contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)), - bias: jnp.ndarray = None, - precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, - preferred_element_type: jnp.dtype = None, - group_offset: jnp.array = None, - quantizer_set: QuantizerSet = noop_quantizer_set, -) -> jnp.ndarray: - """ - Grouped GEMM operation. 
- - Args: - lhs: Left-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x - rhs: Right-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x - group_sizes: 1D array containing the sizes of each group - contracting_dims: Tuple of two sequences representing the contracting dimensions - bias: Bias tensor of shape (G, N) - precision: JAX precision for the GEMM operation - preferred_element_type: Preferred data type for the output tensor - group_offset: 1D array containing offsets for each group (not yet implemented) - quantizer_set: Set of quantizers for FP8 quantization of the input and output - - Returns: - A jnp.ndarray containing the result of the grouped GEMM operation - - Note: - Tested shapes: - lhs: [M, K] or [K, N] - rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K] - """ - # TODO(Phuong): implement the group_offset - group_offset = group_offset or jnp.zeros((1,), jnp.int32) - - # TODO(Phuong): implement the precision - del precision - - if isinstance(lhs, jnp.ndarray): - assert isinstance(rhs, jnp.ndarray) - out_dtype = lhs.dtype - lhs_shape = lhs.shape - rhs_shape = rhs.shape - lhs_data = lhs - rhs_data = rhs - lhs_scale_inv = rhs_scale_inv = jnp.empty((0,), jnp.float32) - scaling_mode = ScalingMode.NO_SCALING - elif isinstance(lhs, GroupedScaledTensor1x): - assert isinstance(rhs, GroupedScaledTensor1x) - out_dtype = lhs.dq_dtype - lhs_shape = lhs.original_shape - rhs_shape = rhs.original_shape - lhs_data = lhs.data - rhs_data = rhs.data - lhs_scale_inv = lhs.scale_inv - rhs_scale_inv = rhs.scale_inv - assert lhs.scaling_mode == rhs.scaling_mode - scaling_mode = lhs.scaling_mode - else: - raise TypeError("Unsupported lhs type object!") - - out_dtype = preferred_element_type or out_dtype - - lhs_contract_dim, rhs_contract_dim = contracting_dims - - lhs_is_trans = lhs_contract_dim[-1] != len(lhs_shape) - 1 - lhs_flatten_axis = len(lhs_contract_dim) * (1 if lhs_is_trans else -1) - - # rhs_shape [G, K, N] - rhs_is_trans = 
rhs_contract_dim[0] != 1 - rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim) +""" +def swizzled_scale(scales): + # Swizzle the scale tensor for FP8 GEMM + assert scales.ndim == 2 + rows, cols = scales.shape + scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4) + scales = jnp.transpose(scales, (0, 3, 2, 1, 4)) + scales = scales.reshape(rows, cols) + return scales - is_grouped_dense_wgrad = False - if len(rhs_shape) == 2: - rhs_is_trans = rhs_contract_dim[0] != 0 - is_grouped_dense_wgrad = True - # TODO(Hua): thses are for fp16 dense wgrad, any better way to handle this? - if ( - is_grouped_dense_wgrad - and not isinstance(lhs, ScaledTensor) - and not isinstance(rhs, ScaledTensor) - ): - lhs_is_trans = True - rhs_is_trans = False - lhs_flatten_axis = 1 - rhs_flatten_axis = 1 - - if ( - not isinstance(lhs, ScaledTensor) - and not isinstance(rhs, ScaledTensor) - and quantizer_set != noop_quantizer_set - ): - assert isinstance(quantizer_set.x, GroupedQuantizer) - assert type(quantizer_set.x) is type(quantizer_set.kernel) - scaling_mode = quantizer_set.x.scaling_mode - if ( - # TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later - # scaling_mode.is_tensor_scaling() - # and is_gemm_with_all_layouts_supported() - scaling_mode.is_1d_block_scaling() - ): - lhs_is_rowwise = rhs_is_rowwise = True - else: - lhs_is_rowwise = not lhs_is_trans - rhs_is_rowwise = lhs_is_trans - quantizer_set.x.q_layout = ( - QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE - ) - quantizer_set.kernel.q_layout = ( - QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE - ) - lhs_q = grouped_quantize(lhs, quantizer_set.x, group_sizes, lhs_flatten_axis) - rhs_q = grouped_quantize( - rhs, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis - ) - lhs_data = lhs_q.data - rhs_data = rhs_q.data - lhs_scale_inv = lhs_q.scale_inv - rhs_scale_inv = rhs_q.scale_inv - - assert not ( - 
lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2 - ), "FP8 GEMM does not support E5M2 * E5M2" - - # Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs - # thus additional transpose is required - # TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later - if scaling_mode.is_tensor_scaling(): # and not is_gemm_with_all_layouts_supported(): - lhs_is_trans = False - rhs_is_trans = True +def grouped_gemm( + lhs_list: List[Union[jnp.ndarray, ScaledTensor]], + rhs_list: List[Union[jnp.ndarray, ScaledTensor]], + contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]], + bias_list: List[jnp.ndarray] = None, +) -> List[jnp.ndarray]: + # Grouped GEMM for multiple pairs of tensors. + assert ( + len(lhs_list) == len(rhs_list) == len(contracting_dims_list) + ), "lhs_list, rhs_list, contracting_dims_list must have the same length" + + num_gemms = len(lhs_list) + lhs_list_ = [] + rhs_list_ = [] + lhs_sinv_list_ = [] + rhs_sinv_list_ = [] + bias_list_ = [] + for i in range(num_gemms): + lhs = lhs_list[i] + rhs = rhs_list[i] + contracting_dims = contracting_dims_list[i] + dim_nums = (contracting_dims, ((), ())) if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): - lhs_layout_is_T = lhs.data_layout == "T" - rhs_layout_is_T = rhs.data_layout == "T" + scaling_mode = lhs.scaling_mode + lhs_shape = lhs.data.shape + rhs_shape = rhs.data.shape + out_dtype = lhs.dq_dtype + # For ScaledTensors and DELAYED_TENSOR_SCALING, need to handle internal data_layout + if lhs.scaling_mode.is_tensor_scaling(): + assert not ( + lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 + ), "FP8 GEMM does not support E5M2 * E5M2" + ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims + if lhs.data_layout == "T": + lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim + if rhs.data_layout == "T": + rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim + dim_nums = 
((lhs_contract_dim,), (rhs_contract_dim,)), ((), ()) else: - lhs_layout_is_T = lhs_q.data_layout == "T" - rhs_layout_is_T = rhs_q.data_layout == "T" - lhs_ndim = len(lhs_shape) - rhs_ndim = len(rhs_shape) - if lhs_layout_is_T: - lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim) - if rhs_layout_is_T: - rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) - lhs_data = _shape_normalization(lhs_data, (lhs_contract_dim, ()), not lhs_layout_is_T) - rhs_data = _shape_normalization(rhs_data, (rhs_contract_dim, ()), rhs_layout_is_T) - - # Calling GroupedGEMM Custom Call - K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim) - K_rhs = math.prod(rhs_shape[i] for i in rhs_contract_dim) - assert K_lhs == K_rhs - M = math.prod(_calculate_remaining_shape(lhs_shape, lhs_contract_dim)) - N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)[1:]) # Exclude G - - if is_grouped_dense_wgrad: - N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)) - else: - assert group_sizes.size == rhs_shape[0] - - assert group_offset.size == 1 - - has_bias = bias is not None - assert not has_bias or bias.shape == (group_sizes.size, N) - bias = jnp.empty((), jnp.float32) if bias is None else bias - - # TODO(Phuong): support MXFP8_1D_SCALING - assert scaling_mode != ScalingMode.MXFP8_1D_SCALING, "MXFP8_1D_SCALING is not yet supported" - - (out,) = GroupedGemmPrimitive.outer_primitive.bind( - lhs_data, - lhs_scale_inv, - rhs_data, - rhs_scale_inv, - bias, - group_sizes, - group_offset, - M=M, - N=N, - K=K_lhs, - lhs_is_trans=lhs_is_trans, - rhs_is_trans=rhs_is_trans, - scaling_mode=scaling_mode.value, + # For jnp.ndarray, only consider contracting_dims, data_layout is always NN + scaling_mode = ScalingMode.NO_SCALING + lhs_shape = lhs.shape + rhs_shape = rhs.shape + out_dtype = lhs.dtype + + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums + lhs_dn = (lhs_contract, lhs_batch) + rhs_dn = 
(rhs_contract, rhs_batch) + + lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract) + rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract) + + # Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy + if scaling_mode == ScalingMode.NO_SCALING: + lhs_3d = _shape_normalization(lhs, lhs_dn) + rhs_3d = _shape_normalization(rhs, rhs_dn) + elif scaling_mode.is_tensor_scaling(): + lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N") + rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T") + elif scaling_mode == ScalingMode.MXFP8_1D_SCALING: + lhs_3d = _shape_normalization(lhs.data, lhs_dn) + rhs_3d = _shape_normalization(rhs.data, rhs_dn) + lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn) + rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn) + # swizzled_scale requires a matrix + lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze()) + rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze()) + else: + raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}") + + # Note: already_transposed doesn't matter for the output shape + # x.shape = [B, D1, D2] + # contracting_dims = (2, ) --> output.shape = [1, B * D1, D2] + # contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1] + # x.shape = [D1, D2] + # contracting_dims = (1, ) --> output.shape = [1, D1, D2] + # contracting_dims = (0, ) --> output.shape = [1, D2, D1] + bm = lhs_remain_shape[0] + bn = rhs_remain_shape[0] + kl = lhs_3d.shape[-1] + kr = rhs_3d.shape[-1] + assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}" + if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0): + print("grouped_gemm input pair {i} has invalid problem shape for lowering: ") + print(f"m = {bm}, n = {bn}, k = {kl}; ") + print("cuBLAS requires the problem shapes being multiples of 16") + assert (bm % 16 == 0) and (bn % 16 == 0) and (kl % 16 == 0) + + lhs_list_.append(lhs_3d) + 
rhs_list_.append(rhs_3d) + if scaling_mode == ScalingMode.NO_SCALING: + lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) + rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) + if scaling_mode.is_tensor_scaling(): + lhs_sinv_list_.append(lhs.scale_inv) + rhs_sinv_list_.append(rhs.scale_inv) + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + lhs_sinv_list_.append(lhs_scale_inv) + rhs_sinv_list_.append(rhs_scale_inv) + if bias_list is not None: + bias_list_.append(bias_list[i]) + + out_list = GroupedGemmPrimitive.outer_primitive.bind( + *lhs_list_, + *rhs_list_, + *lhs_sinv_list_, + *rhs_sinv_list_, + *bias_list_, + num_gemms=num_gemms, + scaling_mode=scaling_mode, out_dtype=out_dtype, - has_bias=has_bias, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, + has_bias=1 if bias_list is not None else 0, ) - return out + + return out_list +""" diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 07d8f81df0..7ed0db0298 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -47,7 +47,7 @@ from jax.extend import ffi # pylint: disable=ungrouped-imports -__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"] +__all__ = ["quantize", "quantize_dbias", "grouped_quantize"] class BaseDBiasQuantizePrimitive(BasePrimitive): @@ -1032,24 +1032,3 @@ def grouped_quantize( group_axis=group_axis, ) return out - - -def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray: - """ - Compute the grouped bias gradient. - - Args: - grad: jnp.ndarray of shape (M, N) - group_sizes: jnp.ndarray of shape(num_groups,), sum(group_sizes) == M - - Returns: - dbias: jnp.ndarray of shape (num_groups, N) - """ - assert grad.ndim == 2, "Input grad must be a 2D tensor." - assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor." 
- - segment_ids = jnp.repeat(jnp.arange(group_sizes.shape[0]), group_sizes) - grad_fp32 = grad.astype(jnp.float32) - dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0]) - dbias = dbias_fp32.astype(grad.dtype) - return dbias diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index d9d519fa00..0825bd2f73 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -13,127 +13,43 @@ #include "transformer_engine/multi_stream.h" #include "xla/ffi/api/c_api.h" -#define MXFP8_BLOCK_SIZE 32 - namespace transformer_engine { namespace jax { -Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, - Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, - Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans, - bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, - bool is_grouped_dense_wgrad) { +Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, + Variadic_Result_Type output_list, int64_t num_gemms, + JAXX_Scaling_Mode scaling_mode, int64_t has_bias) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: - // A: row-major [m, k] for N - [k, m] for T - // B: row-major [k, n] for N - [n, k] for T + // A: row-major with size [m, k], + // B: row-major with size [n, k], needs transpose, // on exiting this function, JAX expect: // C: row-major with size [m, n]. // cuBLAS uses column-major data_layout, in this view, each input matrix pair: - // A: column-major with size [k, m] for T - [m, k] for N - // B: column-major with size [n, k] for T - [k, n] for N - // + // A: column-major with size [k, m], needs transpose, + // B: column-major with size [k, n]. 
// If we call cuBLAS GEMM for A * B, the output will be: // C: column-major with size [m, n] --> row-major with size [n, m]. // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. - int num_streams = nvte_get_num_compute_streams(); - - // Inputs - auto lhs_ptr = reinterpret_cast(lhs_data.untyped_data()); - auto rhs_ptr = reinterpret_cast(rhs_data.untyped_data()); - auto lhs_sinv_ptr = reinterpret_cast(lhs_sinv.untyped_data()); - auto rhs_sinv_ptr = reinterpret_cast(rhs_sinv.untyped_data()); - auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type()); - auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type()); - auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type()); - auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type()); - auto bias_ptr = has_bias ? reinterpret_cast(bias.untyped_data()) : nullptr; - auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); - - NVTE_CHECK(group_sizes.dimensions().size() == 1); - size_t num_gemms = group_sizes.dimensions()[0]; - - // Outputs - auto out_ptr = reinterpret_cast(output->untyped_data()); - auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); - auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); - auto workspace_total_size = product(workspace->dimensions()); - - auto lhs_sinv_size = product(lhs_sinv.dimensions()); - auto rhs_sinv_size = product(rhs_sinv.dimensions()); - auto workspace_size = (workspace_total_size - lhs_sinv_size - rhs_sinv_size) / num_streams; - auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams; - auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size; - - size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); - size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); - size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype); - size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype); - size_t 
bias_dtype_bytes = te_dtype_bytes(bias_dtype); - size_t out_dtype_bytes = te_dtype_bytes(out_dtype); - - NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); - NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, - "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); - - size_t expected_lhs_size = m * k; - size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n); - size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n); - size_t actual_lhs_size = product(lhs_data.dimensions()); - size_t actual_rhs_size = product(rhs_data.dimensions()); - size_t actual_out_size = product(output->dimensions()); - NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ", - expected_lhs_size, ", got ", actual_lhs_size); - if (!is_grouped_dense_wgrad) { - NVTE_CHECK(expected_rhs_size == actual_rhs_size, - "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k, - " = ", expected_rhs_size, ", got ", actual_rhs_size); - NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! Expect m * n = ", m, - " * ", n, " = ", expected_out_size, ", got ", actual_out_size); - } else { - NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k, - " * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size); - NVTE_CHECK(expected_out_size == actual_out_size, - "Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n, - " = ", expected_out_size, ", got ", actual_out_size); + if (num_gemms <= 0) { + return ffi_with_cuda_error_check(); } - - size_t dim_list_bytes = sizeof(int32_t) * num_gemms; - std::vector dim_list_host(num_gemms); - auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); - cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - stream); - // Note: This may break cudaGraph. 
- cudaStreamSynchronize(stream); - size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); - if (!is_grouped_dense_wgrad) { - NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, - ", got sum(group_sizes)=", sum_group_sizes); - } else { - NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, - ", got sum(group_sizes)=", sum_group_sizes); - } - + size_t expected_input_size = has_bias ? 5 * num_gemms : 4 * num_gemms; + size_t expected_output_size = num_gemms + 1; + size_t actual_input_size = input_list.size(); + size_t actual_output_size = output_list.size(); + NVTE_CHECK(actual_input_size == expected_input_size, "Expected %zu input tensors, got %zu", + expected_input_size, actual_input_size); + NVTE_CHECK(actual_output_size == expected_output_size, "Expected %zu output tensors, got %zu", + expected_output_size, actual_output_size); + + bool trans_lhs = true; + bool trans_rhs = false; auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); bool grad = false; bool accumulate = false; bool use_split_accumulator = false; - auto bias_shape = std::vector{has_bias ? 
n : 0}; - const int arch = cuda::sm_arch(); - - // It is weird that TE/Common GEMM only use colwise for MXFP8 - const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); - const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; - const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans; - const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans; - - if (arch < 100 && is_fp8_gemm) { - NVTE_CHECK(!lhs_is_trans && rhs_is_trans, - "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ", - "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); - } // These lists are to keep the TensorWrapper objects alive std::vector lhs_wrapper_list; @@ -151,83 +67,96 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type std::vector out_list; std::vector workspace_list; - for (size_t i = 0; i < num_gemms; i++) { - // Matrix data shapes - size_t m_i = dim_list_host[i]; - auto lhs_shape = std::vector{m_i, k}; - auto rhs_shape = std::vector{rhs_is_trans ? n : k, rhs_is_trans ? k : n}; - auto out_shape = std::vector{m_i, n}; - if (is_grouped_dense_wgrad) { - size_t k_i = dim_list_host[i]; - lhs_shape[0] = lhs_is_trans ? k_i : m; - lhs_shape[1] = lhs_is_trans ? m : k_i; - rhs_shape[0] = rhs_is_trans ? n : k_i; - rhs_shape[1] = rhs_is_trans ? 
k_i : n; - out_shape[0] = m; - out_shape[1] = n; - } - - // Set matrix data pointers - auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto out_i = TensorWrapper(static_cast(out_ptr), out_shape, out_dtype); - void *lhs_vptr = static_cast(lhs_ptr); - void *rhs_vptr = static_cast(rhs_ptr); - if (rhs_use_colwise) // MatA to enter cuBLAS - rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape); - else - rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape); - if (lhs_use_colwise) // MatB to enter cuBLAS - lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape); - else - lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape); - - // Scale_inv shapes - auto lhs_sinv_size = std::vector{1}; - auto rhs_sinv_size = std::vector{1}; - if (is_mxfp8_scaling) { - NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)", - MXFP8_BLOCK_SIZE, k); - size_t scale_k = k / MXFP8_BLOCK_SIZE; - lhs_sinv_size[0] = m_i * scale_k; - rhs_sinv_size[0] = n * scale_k; - // Need to add swizzle here - } - - // Set scale_inv pointers - void *rhs_sinv_vptr = static_cast(rhs_sinv_ptr); - void *lhs_sinv_vptr = static_cast(lhs_sinv_ptr); - if (is_fp8_gemm) { - if (rhs_use_colwise) // MatA to enter cuBLAS - rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size); - else - rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size); - if (lhs_use_colwise) // MatB to enter cuBLAS - lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size); - else - lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size); + int lhs_list_offset = 0; + int rhs_list_offset = num_gemms; + int lhs_sinv_list_offset = 2 * num_gemms; + int rhs_sinv_list_offset = 3 * num_gemms; + int bias_list_offset = 4 * num_gemms; + int out_list_offset = 0; + for (int i = 0; i < num_gemms; i++) { + Buffer_Type lhs_i = input_list.get(lhs_list_offset + 
i).value(); + Buffer_Type rhs_i = input_list.get(rhs_list_offset + i).value(); + Buffer_Type lhs_sinv_i = input_list.get(lhs_sinv_list_offset + i).value(); + Buffer_Type rhs_sinv_i = input_list.get(rhs_sinv_list_offset + i).value(); + Result_Type out_i = output_list.get(out_list_offset + i).value(); + + DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type()); + DType rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_i.element_type()); + DType out_dtype = convert_ffi_datatype_to_te_dtype(out_i->element_type()); + + void *lhs_ptr = lhs_i.untyped_data(); + void *rhs_ptr = rhs_i.untyped_data(); + void *lhs_sinv_ptr = lhs_sinv_i.untyped_data(); + void *rhs_sinv_ptr = rhs_sinv_i.untyped_data(); + void *out_ptr = out_i->untyped_data(); + + // Placeholder for bias since it can be empty + DType bias_dtype = DType::kFloat32; + void *bias_ptr = nullptr; + + auto lhs_shape_ = lhs_i.dimensions(); + auto rhs_shape_ = rhs_i.dimensions(); + + // lhs and rhs has shape [1, m, k] and [1, n, k] + size_t m = lhs_shape_[1]; + size_t n = rhs_shape_[1]; + size_t k = lhs_shape_[2]; + + auto lhs_shape = std::vector{m, k}; + auto rhs_shape = std::vector{n, k}; + auto out_shape = std::vector{n, m}; + auto lhs_sinv_shape = std::vector{1, 1}; + auto rhs_sinv_shape = std::vector{1, 1}; + + if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || + scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || + scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) { + float *amax_dptr = nullptr; + float *scale_dptr = nullptr; + auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr, + reinterpret_cast(lhs_sinv_ptr)); + auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, amax_dptr, scale_dptr, + reinterpret_cast(rhs_sinv_ptr)); + lhs_wrapper_list.push_back(std::move(lhs_i_)); + rhs_wrapper_list.push_back(std::move(rhs_i_)); + } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { + // Note: the scale_inv array should have been swizzled in 
Python before lowering + auto lhs_sinv_shape_ = lhs_sinv_i.dimensions(); + auto rhs_sinv_shape_ = rhs_sinv_i.dimensions(); + for (int i = 0; i < 2; i++) { + lhs_sinv_shape[i] = lhs_sinv_shape_[i]; + rhs_sinv_shape[i] = rhs_sinv_shape_[i]; + } + + NVTEScalingMode nvte_scaling_mode = get_nvte_scaling_mode(scaling_mode); + TensorWrapper lhs_i_(nvte_scaling_mode); + TensorWrapper rhs_i_(nvte_scaling_mode); + lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape); + rhs_i_.set_rowwise_data(rhs_ptr, rhs_dtype, rhs_shape); + lhs_i_.set_rowwise_scale_inv(lhs_sinv_ptr, DType::kFloat8E8M0, lhs_sinv_shape); + rhs_i_.set_rowwise_scale_inv(rhs_sinv_ptr, DType::kFloat8E8M0, rhs_sinv_shape); + + lhs_wrapper_list.push_back(std::move(lhs_i_)); + rhs_wrapper_list.push_back(std::move(rhs_i_)); } else { - NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, - "Unsupported scaling mode: ", static_cast(scaling_mode)); + NVTE_ERROR("Unsupported scaling mode: ", static_cast(scaling_mode)); } - auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype); - auto pre_gelu_i = TensorWrapper(nullptr, std::vector{0}, out_dtype); - - // Update pointer for the next GEMM pair - lhs_ptr += lhs_shape[0] * lhs_shape[1] * lhs_dtype_bytes; - rhs_ptr += rhs_shape[0] * rhs_shape[1] * rhs_dtype_bytes; - out_ptr += out_shape[0] * out_shape[1] * out_dtype_bytes; - if (is_fp8_gemm) { - lhs_sinv_ptr += lhs_sinv_size[0] * lhs_sinv_dtype_bytes; - rhs_sinv_ptr += rhs_sinv_size[0] * rhs_sinv_dtype_bytes; + auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype); + void *pre_gelu_ptr = nullptr; + auto bias_shape = std::vector{0}; + auto pre_gelu_shape = std::vector{0}; + if (has_bias) { + auto bias_i_get = input_list.get(bias_list_offset + i); + Buffer_Type bias_i = bias_i_get.value(); + bias_ptr = bias_i.untyped_data(); + bias_dtype = convert_ffi_datatype_to_te_dtype(bias_i.element_type()); + bias_shape[0] = n; } - if (has_bias) bias_ptr += n * bias_dtype_bytes; + auto bias_i = TensorWrapper(bias_ptr, 
bias_shape, bias_dtype); + auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype); - // Move objects to the lists to keep them alive - lhs_wrapper_list.push_back(std::move(lhs_i)); - rhs_wrapper_list.push_back(std::move(rhs_i)); - out_wrapper_list.push_back(std::move(out_i)); + out_wrapper_list.push_back(std::move(out_i_)); bias_wrapper_list.push_back(std::move(bias_i)); pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i)); @@ -238,6 +167,11 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type out_list.push_back(out_wrapper_list.back().data()); } + auto workspace_get = output_list.get(num_gemms); + Result_Type workspace = workspace_get.value(); + uint8_t *workspace_ptr = reinterpret_cast(workspace->untyped_data()); + auto num_streams = nvte_get_num_compute_streams(); + size_t workspace_size = workspace->dimensions()[0] / num_streams; auto workspace_shape = std::vector{workspace_size}; for (int i = 0; i < num_streams; i++) { auto workspace_i = @@ -248,7 +182,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type } nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), - pre_gelu_list.data(), num_gemms, rhs_is_trans, lhs_is_trans, grad, + pre_gelu_list.data(), num_gemms, trans_lhs, trans_rhs, grad, workspace_list.data(), accumulate, use_split_accumulator, num_math_sm, stream); @@ -258,23 +192,11 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, FFI::Bind() .Ctx() // stream - .Arg() // lhs_data - .Arg() // lhs_sinv - .Arg() // rhs_data - .Arg() // rhs_sinv - .Arg() // bias - .Arg() // group_sizes - .Arg() // group_offset - .Ret() // output - .Ret() // workspace - .Attr("M") - .Attr("N") - .Attr("K") - .Attr("lhs_is_trans") - .Attr("rhs_is_trans") + .RemainingArgs() // input list + .RemainingRets() // output list + .Attr("num_gemms") 
.Attr("scaling_mode") - .Attr("has_bias") - .Attr("is_grouped_dense_wgrad"), + .Attr("has_bias"), FFI_CudaGraph_Traits); } // namespace jax diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index bba101c722..55d60e4189 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -153,28 +153,28 @@ def _dense_bwd_rule( # GEMM NT # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim - g_contracting_dim = tuple( + g_constracting_dim = tuple( range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) ) # k_non_contracting_dims - k_contracting_dim = tuple( + k_constracting_dim = tuple( dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims ) dgrad = tex.gemm( casted_grad.get_rowwise_tensor(), rowwise_casted_kernel, - (g_contracting_dim, k_contracting_dim), + (g_constracting_dim, k_constracting_dim), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) # GEMM TN # x_non_contracting_dims - g_contracting_dim = x_contracting_dim = tuple( + g_constracting_dim = x_constracting_dim = tuple( range(0, len(x_shape) - len(fwd_x_contracting_dims)) ) wgrad = tex.gemm( - colwise_casted_x, casted_grad.get_colwise_tensor(), (x_contracting_dim, g_contracting_dim) + colwise_casted_x, casted_grad.get_colwise_tensor(), (x_constracting_dim, g_constracting_dim) ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) @@ -184,240 +184,135 @@ def _dense_bwd_rule( _dense.defvjp(_dense_fwd_rule, _dense_bwd_rule) +""" def grouped_dense( - x: jnp.ndarray, - kernel: jnp.ndarray, - group_sizes: jnp.ndarray, - contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)), - bias: jnp.ndarray = None, - precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, - preferred_element_type: jnp.dtype = None, - group_offset: jnp.array = None, - quantizer_set: QuantizerSet = noop_quantizer_set, + x_list, + kernel_list, + bias_list, + 
contracting_dims_list, + quantizer_set_list=None, ): - """ - Perform grouped dense (linear) layer transformation with optional quantization. + # Perform grouped_dense layer transformation with optional quantization. - Args: - x: Input tensor of shape (M, K) - kernel: Weight matrix of shape (G, K, N) - group_sizes: 1D array of shape (G,) specifying the size of each group - contracting_dims: Tuple of sequences specifying which dimensions to contract - (currently only supports ((1,), (1,))) - bias: Bias tensor of shape (G, N) - precision: JAX precision for the GEMM operation - preferred_element_type: Preferred data type for the output tensor - group_offset: 1D array containing offsets for each group (not yet implemented) - quantizer_set: Set of quantizers for FP8 quantization of the input and output - - Returns: - A jnp.ndarray containing the result of the grouped linear operation - """ - output = _grouped_dense( - x, - kernel, - group_sizes, - contracting_dims, - bias, - precision, - preferred_element_type, - group_offset, - quantizer_set, + output_list = _grouped_dense( + x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list ) - return output + return output_list -@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7)) -def _grouped_dense( - x, - kernel, - group_sizes, - contracting_dims, - bias, - precision, - preferred_element_type, - group_offset, - quantizer_set, -): - output, _ = _grouped_dense_fwd_rule( - x, - kernel, - group_sizes, - contracting_dims, - bias, - precision, - preferred_element_type, - group_offset, - quantizer_set, +@partial(jax.custom_vjp, nondiff_argnums=(3,)) +def _grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list): + output_list, _ = _grouped_dense_fwd_rule( + x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list ) - return output + return output_list def _grouped_dense_fwd_rule( - x, - kernel, - group_sizes, - contracting_dims, - bias, - precision, - 
preferred_element_type, - group_offset, - quantizer_set, + x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list ): - use_bias = bias is not None - is_noop_quantizer_set = quantizer_set == noop_quantizer_set - - if is_noop_quantizer_set: - grouped_gemm_x = x - grouped_gemm_kernel = kernel - ctx_x = x - ctx_kernel = kernel - flatten_axis_k = None + use_bias = bias_list is not None + output_list = [] + x_rowwise_list = [] + x_colwise_list = [] + kernel_colwise_list = [] + kernel_rowwise_list = [] + x_shape_list = [] + kernel_shape_list = [] + if quantizer_set_list is None: + x_rowwise_list = x_list + x_colwise_list = x_list + kernel_colwise_list = kernel_list + kernel_rowwise_list = kernel_list + x_shape_list = [x.shape for x in x_list] + kernel_shape_list = [kernel.shape for kernel in kernel_list] else: - x_contracting_dims, k_contracting_dims = contracting_dims - flatten_axis_x = -len(x_contracting_dims) - flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1 # +1 for G axis - - assert x.ndim == 2, "Grouped dense expects a 2D input tensor of shape (M, K)" - assert kernel.ndim == 3, "Grouped dense expects a 3D kernel tensor of shape (G, K, N)" - # Expected k_contracting_dims == (1,), need to tweak it for grouped_gemm FP8 extra transpose - # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? 
- assert x_contracting_dims == (1,) and k_contracting_dims == (1,), ( - "grouped_dense for FP8 can only handle x_contracting_dims=(1,) " - "and k_contracting_dims=(1,) for now, " - f"got {x_contracting_dims=} and {k_contracting_dims=}" - ) - k_contracting_dims = (0,) - - casted_x = tex.grouped_quantize( - x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x - ) - casted_kernel = tex.grouped_quantize( - kernel, quantizer_set.kernel, flatten_axis=flatten_axis_k - ) - contracting_dims = (x_contracting_dims, k_contracting_dims) - - # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have - # rowwise_casted_x.original_shape == (M, K) - # colwise_casted_kernel.original_shape == (G, N, K) - grouped_gemm_x = casted_x.get_rowwise_tensor() - grouped_gemm_kernel = casted_kernel.get_colwise_tensor() - # TODO(Hua): Shall we give warning/error if not quantizer_set.x.is_2x2x()? - ctx_x = casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None - ctx_kernel = casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None - - output = tex.grouped_gemm( - grouped_gemm_x, - grouped_gemm_kernel, - group_sizes, - contracting_dims, - bias, - precision, - preferred_element_type, - group_offset, + for i in range(len(x_list)): # pylint: disable=consider-using-enumerate + q_x = tex.quantize(x_list[i], quantizer_set_list[i].x) + q_kernel = tex.quantize(kernel_list[i], quantizer_set_list[i].kernel) + x_rowwise_list.append(q_x.get_rowwise_tensor()) + x_colwise_list.append(q_x.get_colwise_tensor()) + kernel_colwise_list.append(q_kernel.get_colwise_tensor()) + kernel_rowwise_list.append(q_kernel.get_rowwise_tensor()) + x_shape_list.append(x_rowwise_list[-1].data.shape) + kernel_shape_list.append(kernel_rowwise_list[-1].data.shape) + + output_list = tex.grouped_gemm( + x_rowwise_list, kernel_colwise_list, contracting_dims_list, bias_list ) ctx = ( - group_sizes, - ctx_x, - ctx_kernel, - x.shape, - kernel.shape, + x_colwise_list, + 
kernel_rowwise_list, + x_shape_list, + kernel_shape_list, use_bias, - is_noop_quantizer_set, - quantizer_set, - flatten_axis_k, + quantizer_set_list, ) - return output, ctx + return output_list, ctx -def _grouped_dense_bwd_rule( - contracting_dims, precision, preferred_element_type, group_offset, ctx, grad -): - fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims - +def _grouped_dense_bwd_rule(contracting_dims_list, ctx, grad_list): ( - group_sizes, - ctx_x, - ctx_kernel, - x_shape, - kernel_shape, + colwise_x_list, + rowwise_kernel_list, + x_shape_list, + kernel_shape_list, use_bias, - is_noop_quantizer_set, - quantizer_set, - flatten_axis_k, + quantizer_set_list, ) = ctx - if is_noop_quantizer_set: - # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?) - # g_contracting_dim = (1, ) - # k_contracting_dim = (2, ) + group_size = len(grad_list) + dbias_list = [] + grad_rowwise_list = [] + grad_colwise_list = [] + dgrad_contracting_dims_list = [] + wgrad_contracting_dims_list = [] + for i in range(group_size): + grad = grad_list[i] + x_shape = x_shape_list[i] + kernel_shape = kernel_shape_list[i] + fwd_contracting_dims = contracting_dims_list[i] + + if quantizer_set_list is None: + casted_grad = grad + dbias = tex.quantization._jax_dbias(grad) + grad_rowwise_list.append(grad) + grad_colwise_list.append(grad) + else: + quantizer_set = quantizer_set_list[i] + casted_grad, dbias = tex.quantize_dbias( + grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad + ) + grad_rowwise_list.append(casted_grad.get_rowwise_tensor()) + grad_colwise_list.append(casted_grad.get_colwise_tensor()) + dbias_list.append(dbias) + + # GEMM NT + fwd_x_contracting_dims, fwd_k_contracting_dims = fwd_contracting_dims g_contracting_dim = tuple( - range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) + range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) ) k_contracting_dim = tuple( - dim 
for dim in range(1, len(kernel_shape)) if dim not in fwd_k_contracting_dims + dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims ) dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim) - dgrad_grad = grad - dgrad_kernel_T = ctx_kernel + dgrad_contracting_dims_list.append(dgrad_contracting_dims) - # g_contracting_dim = (0, ) - # x_contracting_dim = (0, ) + # GEMM TN g_contracting_dim = x_contracting_dim = tuple( range(0, len(x_shape) - len(fwd_x_contracting_dims)) ) wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) - wgrad_x_T = ctx_x - wgrad_grad = grad - else: - casted_grad = tex.grouped_quantize( - grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k - ) - - # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we need to use - # g_contracting_dim = (1,) and k_contracting_dim = (2,) to make it work after the - # extra transpose for FP8 in grouped_gemm - # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? - g_contracting_dim = (1,) - k_contracting_dim = (2,) - dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim) - dgrad_grad = casted_grad.get_rowwise_tensor() - dgrad_kernel_T = ctx_kernel - - # We need to use g_contracting_dim = (0,) and x_contracting_dim = (1,) to make it work - # after the extra transpose for FP8 in grouped_gemm - # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? 
- g_contracting_dim = (0,) - x_contracting_dim = (1,) - wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) - wgrad_x_T = ctx_x - wgrad_grad = casted_grad.get_colwise_tensor() - - dgrad = tex.grouped_gemm( - dgrad_grad, - dgrad_kernel_T, - group_sizes, - dgrad_contracting_dims, - precision=precision, - preferred_element_type=preferred_element_type, - group_offset=group_offset, - ) + wgrad_contracting_dims_list.append(wgrad_contracting_dims) - wgrad = tex.grouped_gemm( - wgrad_x_T, - wgrad_grad, - group_sizes, - wgrad_contracting_dims, - precision=precision, - preferred_element_type=preferred_element_type, - group_offset=group_offset, + dgrad_list = tex.grouped_gemm( + grad_rowwise_list, rowwise_kernel_list, dgrad_contracting_dims_list ) + wgrad_list = tex.grouped_gemm(colwise_x_list, grad_colwise_list, wgrad_contracting_dims_list) - group_sizes_grad = None - dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None - - return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set + return dgrad_list, wgrad_list, dbias_list, quantizer_set_list _grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule) +""" diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 06a2562fb1..45ec4fd1fa 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -127,16 +127,14 @@ def _dequantize_func(data, scale_inv, dq_dtype, scaling_mode, is_colwise, flatte def dequantize(scaled_tensor): """Dequantize a tensor using block scaling. + This function dequantizes a tensor that was quantized using block scaling + by applying the inverse scaling factor to each block of data. 
+ Args: - data: The quantized tensor data - scale_inv: The inverse scaling factors - dq_dtype: The data type for dequantized values - scaling_mode: The scaling mode used for quantization - is_colwise: Whether the scaling is column-wise - flatten_axis: The axis along which the tensor could be flattened to 2D + scaled_tensor: The quantized tensor to dequantize Returns: - The dequantized tensor + The dequantized tensor in the specified data type """ return BlockScaleDequantizer._dequantize_func( scaled_tensor.data, From c9d7f3f272c654cd46a571beb5fe97a7dd7602cf Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 12 Jun 2025 13:39:53 -0400 Subject: [PATCH 11/39] [JAX] GroupedDense v.2 without dynamic shape (#1875) * Implemented GroupedDense and TestGroupedDense for BF16, FP16, and FP8 * Fix GroupedGemmFFI cuBLAS workspace alignment bug Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Co-authored-by: Hua Huang --- tests/jax/test_custom_call_compute.py | 352 ++++++-------- .../common/gemm/cublaslt_gemm.cu | 3 + transformer_engine/jax/cpp_extensions/gemm.py | 459 ++++++++++++------ .../jax/cpp_extensions/quantization.py | 23 +- .../jax/csrc/extensions/gemm.cpp | 308 +++++++----- transformer_engine/jax/dense.py | 301 ++++++++---- .../jax/quantize/dequantizer.py | 12 +- 7 files changed, 881 insertions(+), 577 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 25a463aeaa..9ff0c11757 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -40,10 +40,11 @@ ScalingMode, QuantizerFactory, QuantizeLayout, + noop_quantizer_set, ) from transformer_engine.jax.quantize import helper from transformer_engine.jax.activation import activation -from transformer_engine.jax.dense import dense +from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense GEMM_CASES = [ @@ -1204,24 +1205,6 @@ def 
ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype) -# This function is modified from transformer_engine/jax/cpp_extensions/gemm.py::_jax_gemm() -def _quantize_gemm_pair(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer): - ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims - lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1 - rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1 - lhs_q = lhs_quantizer.quantize( - lhs, - is_rowwise=lhs_is_rowwise, - is_colwise=not lhs_is_rowwise, - ) - rhs_q = rhs_quantizer.quantize( - rhs, - is_rowwise=rhs_is_rowwise, - is_colwise=not rhs_is_rowwise, - ) - return lhs_q, rhs_q - - # E5M2 * E5M2 is not supported fwd_bwd_dtypes = [ [jnp.float8_e4m3fn, jnp.float8_e4m3fn], @@ -1229,219 +1212,194 @@ def _quantize_gemm_pair(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer [jnp.float8_e5m2, jnp.float8_e4m3fn], ] -""" -@pytest_parametrize_wrapper( - "shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]] -) +GROUPED_DENSE_INPUT_SHAPES = [ + # (n_groups, m, n, k), the actual m will be multiplied by 32 + (5, 32, 128, 64), # Test the case where n_groups is not a multiple of 4 + (8, 64, 32, 128), + (8, 64, 128, 256), +] + + +@pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES) class TestGroupedDense: - def _ref_grouped_gemm_with_jnp_dot(self, lhs_list, rhs_list, contracting_dims_list): - ref_out_list = [] - for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list): - dim_nums = (contracting_dims, ((), ())) - ref_out_list.append(jax.lax.dot_general(lhs, rhs, dim_nums)) - return ref_out_list - - def _generate_grouped_gemm_input(self, dtype, shape_list, layout_list): + def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims): + lhs_contract_dim, _ = contracting_dims + assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3 + if bias is None: + bias = 
jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype) + else: + assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2]) + remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop() + lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis) + rhs = jnp.split(rhs, rhs.shape[0], axis=0) + bias = jnp.split(bias, bias.shape[0], axis=0) + ref_out = [] + dim_num = (contracting_dims, ((), ())) + for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias): + out_i = jax.lax.dot_general(lhs_i, rhs_i, dim_num) + jnp.expand_dims(bias_i, axis=0) + ref_out.append(jnp.squeeze(out_i)) + return ref_out + + def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False): key = jax.random.PRNGKey(0) - subkeys = jax.random.split(key, len(shape_list) * 2) - - lhs_list, rhs_list, contracting_dims_list = [], [], [] - for i, ((m, n, k), data_layout) in enumerate(zip(shape_list, layout_list)): - lhs = jax.random.uniform( - subkeys[2 * i], - (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m), - dtype=dtype, - ) - rhs = jax.random.uniform( - subkeys[2 * i + 1], - (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k), - dtype=dtype, - ) - lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,) - rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,) - contracting_dims = (lhs_contracting_dim, rhs_contracting_dim) + subkeys = jax.random.split(key, 4) + n_groups, m, n, k = input_shape + + group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) + group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])]) + group_sizes = jnp.diff(group_sizes) + assert group_sizes.sum() == m + + # *32 to make sure that input shape works for MXFP8 + group_sizes = group_sizes * 32 + m = m * 32 - lhs_list.append(lhs) - rhs_list.append(rhs) - contracting_dims_list.append(contracting_dims) + lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == 
"N" else m) + rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k) + bias_shape = (n_groups, n) - return lhs_list, rhs_list, contracting_dims_list + lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype) + rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype) + bias = jax.random.uniform(subkeys[3], bias_shape, dtype=dtype) if with_bias else None + + lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,) + rhs_contracting_dim = (1,) if data_layout[1] == "N" else (2,) + contracting_dims = (lhs_contracting_dim, rhs_contracting_dim) + + return lhs, rhs, group_sizes, contracting_dims, bias + + def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype): + assert out.dtype == ref_list[0].dtype + out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0) + for i in range(len(ref_list)): + assert_allclose(out_list[i], ref_list[i], dtype=dtype) @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) - @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]]) - def test_grouped_gemm_fp16(self, dtype, shape_list, layout_list): - lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input( - dtype, shape_list, layout_list + @pytest_parametrize_wrapper("layout", ["NN"]) + def test_grouped_gemm_fp16(self, dtype, input_shape, layout): + lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( + dtype, input_shape, layout ) - ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list) - primitive_out = tex.grouped_gemm(lhs_list, rhs_list, contracting_dims_list) - for i in range(len(shape_list)): - assert_allclose(primitive_out[i], ref_out[i], dtype=dtype) + ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) + prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims) + self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) @pytest.mark.skipif(not 
is_fp8_supported, reason=reason) @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]]) - def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list, layout_list): + @pytest_parametrize_wrapper("layout", ["NN"]) + def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout): + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + pytest.skip("MXFP8 is not supported in grouped_gemm yet") + fwd_dtype, bwd_dtype = fwd_bwd_dtype quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=False + scaling_mode=scaling_mode, + fwd_dtype=fwd_dtype, + bwd_dtype=bwd_dtype, + is_2x2x=False, + n_groups=input_shape[0], ) + # quantizer_set.{x, kernel} has fwd_dtype, while quantizer_set.grad has bwd_dtype + # We want to test E4M3 * E5M2, manually set the quantizer_set.kernel.q_dtype to bwd_dtype + quantizer_set.kernel.q_dtype = bwd_dtype + for quantizer in quantizer_set.kernel.quantizers: + quantizer.q_dtype = bwd_dtype + out_dtype = jnp.bfloat16 - lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input( - out_dtype, shape_list, layout_list + lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( + out_dtype, input_shape, layout + ) + ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) + prim_out = tex.grouped_gemm( + lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set ) - q_lhs_list = [] - q_rhs_list = [] - for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list): - # quantizer_set.x and quantizer_set.kernel have the same q_dtype, we want to - # test the case where lhs and rhs have different q_dtypes - q_lhs, q_rhs = _quantize_gemm_pair( - lhs, rhs, contracting_dims, quantizer_set.x, quantizer_set.dgrad - ) - 
q_lhs_list.append(q_lhs) - q_rhs_list.append(q_rhs) - - ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list) - primitive_out = tex.grouped_gemm(q_lhs_list, q_rhs_list, contracting_dims_list) allclose_dtype = jnp.float8_e4m3fn - if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2: + if jnp.float8_e5m2 in fwd_bwd_dtype: allclose_dtype = jnp.float8_e5m2 - for i in range(len(shape_list)): - assert_allclose(primitive_out[i], ref_out[i], dtype=allclose_dtype) - @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) - def test_grouped_dense_grad_fp16(self, dtype, shape_list): - group_size = len(shape_list) - layout_list = ["NN" for _ in range(group_size)] + self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, allclose_dtype) - x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input( - dtype, shape_list, layout_list + def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims): + out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims) + # Note: we use jnp.sum instead of jnp.mean to make the gradient larger + # and prevent them from being clamp to zero + out_sum_list = [jnp.sum(out) for out in out_list] + return jnp.sum(jnp.asarray(out_sum_list)) + + def _primitive_sum_grouped_dense( + self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set + ): + out = grouped_dense( + x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set ) - bias_list = [] - key = jax.random.PRNGKey(1) - for shape in shape_list: - n = shape[1] - bias = jax.random.uniform(key, n, dtype=dtype) - bias_list.append(bias) - - def ref_func(x_list, kernel_list, bias_list, contracting_dims_list): - out_list = [] - for i in range(len(x_list)): - out_list.append( - dense( - x_list[i], - kernel_list[i], - bias_list[i], - contracting_dims=contracting_dims_list[i], - ) - ) - # Note: we use jnp.sum instead of jnp.mean to 
make the gradient larger - # and prevent them from being clamp to zero - out_sum_list = [jnp.sum(out) for out in out_list] - return jnp.sum(jnp.asarray(out_sum_list)) + return jnp.sum(jnp.asarray(out)) - def primitive_func(x_list, kernel_list, bias_list, contracting_dims_list): - out_list = grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list) - out_sum_list = [jnp.sum(out) for out in out_list] - return jnp.sum(jnp.asarray(out_sum_list)) + @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) + def test_grouped_dense_grad_fp16(self, dtype, input_shape): + x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( + dtype, + input_shape, + with_bias=True, + ) - value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) - value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2)) + value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) + value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)) - ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func( - x_list, kernel_list, bias_list, contracting_dims_list + ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func( + x, kernel, bias, group_sizes, contracting_dims ) - primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = ( - value_n_grad_primitive_func(x_list, kernel_list, bias_list, contracting_dims_list) + prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func( + x, kernel, bias, group_sizes, contracting_dims ) - assert_allclose(primitive_out_mean, ref_out_mean, dtype=dtype) - for i in range(group_size): - assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=dtype) - assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=dtype) - assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=dtype) + assert_allclose(prim_out_sum, ref_out_sum, dtype=dtype) + 
assert_allclose(prim_dgrad, ref_dgrad, dtype=dtype) + assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype) + assert_allclose(prim_dbias, ref_dbias, dtype=dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes) + @pytest.mark.parametrize( + "fwd_bwd_dtype", + [(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)], + ) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list): - group_size = len(shape_list) - layout_list = ["NN" for _ in range(group_size)] - fwd_dtype, bwd_dtype = fwd_bwd_dtype - if fwd_dtype == jnp.float8_e5m2: - pytest.skip("We never use E5M2 for fwd_dtype in training") - - # Question: should we use different quantizers for different groups? - ref_quantizer_set_list = [] - quantizer_set_list = [] - for _ in range(group_size): - ref_quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True - ) - ref_quantizer_set_list.append(ref_quantizer_set) - quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True - ) - quantizer_set_list.append(quantizer_set) + def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + pytest.skip("MXFP8 is not supported in grouped_dense yet") - out_dtype = jnp.bfloat16 - x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input( - out_dtype, shape_list, layout_list + fwd_dtype, bwd_dtype = fwd_bwd_dtype + dtype = jnp.bfloat16 + x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( + dtype, + input_shape, + with_bias=True, ) - bias_list = [] - key = jax.random.PRNGKey(1) - for shape in shape_list: - n = shape[1] - bias = jax.random.uniform(key, n, dtype=out_dtype) - 
bias_list.append(bias) - - def ref_func(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list): - out_list = [] - for i in range(len(x_list)): - out_list.append( - dense( - x_list[i], - kernel_list[i], - bias_list[i], - contracting_dims=contracting_dims_list[i], - quantizer_set=quantizer_set_list[i], - ) - ) - # Note: we use jnp.sum instead of jnp.mean to make the gradient larger - # and prevent them from being clamp to zero - out_sum_list = [jnp.sum(out) for out in out_list] - return jnp.sum(jnp.asarray(out_sum_list)) - - def primitive_func( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list - ): - out_list = grouped_dense( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list - ) - out_sum_list = [jnp.sum(out) for out in out_list] - return jnp.sum(jnp.asarray(out_sum_list)) - - value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) - value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2)) - ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func( - x_list, kernel_list, bias_list, contracting_dims_list, ref_quantizer_set_list + quantizer_set = QuantizerFactory.create_set( + scaling_mode=scaling_mode, + fwd_dtype=fwd_dtype, + bwd_dtype=bwd_dtype, + is_2x2x=True, + n_groups=group_sizes.size, ) - primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = ( - value_n_grad_primitive_func( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list - ) + value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) + value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)) + + ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func( + x, + kernel, + bias, + group_sizes, + contracting_dims, + ) + prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func( + x, kernel, bias, group_sizes, contracting_dims, quantizer_set=quantizer_set 
) - allclose_dtype = jnp.float8_e4m3fn - if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2: - allclose_dtype = jnp.float8_e5m2 - assert_allclose(primitive_out_mean, ref_out_mean, dtype=allclose_dtype) - for i in range(group_size): - assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype) - assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype) - assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype) -""" + assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype) + assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype) + assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype) + assert_allclose(prim_dbias, ref_dbias, dtype=dtype) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 4080ae1668..fa8785dcc7 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -525,6 +525,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const auto B_alignment = _getAlignment(reinterpret_cast(param.B)); const auto C_alignment = _getAlignment(reinterpret_cast(C)); const auto D_alignment = _getAlignment(reinterpret_cast(D)); + const auto workspace_alignment = _getAlignment(reinterpret_cast(workspace)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( @@ -533,6 +534,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment))); + NVTE_CHECK(workspace_alignment % 256 == 0, + "cuBLAS workspace pointer must be aligned to 256 
bytes, got ", workspace_alignment); const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index c38a04f85a..cc02ec3404 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -6,22 +6,28 @@ from typing import Tuple, Sequence, Union, Dict from functools import partial, reduce import operator +import math import jax import jax.numpy as jnp from transformer_engine_jax import get_device_compute_capability from .base import BasePrimitive, register_primitive +from .quantization import grouped_quantize from ..quantize import ( ScaledTensor, + GroupedScaledTensor1x, ScalingMode, Quantizer, + GroupedQuantizer, QuantizeConfig, + QuantizerSet, + QuantizeLayout, noop_quantizer_set, ) -__all__ = ["gemm"] +__all__ = ["gemm", "grouped_gemm", "is_gemm_with_all_layouts_supported"] num_cublas_streams = 4 @@ -34,6 +40,11 @@ def get_cublas_workspace_size_bytes() -> None: return 4_194_304 +def is_gemm_with_all_layouts_supported() -> False: + """Return True if using blackwell, False otherwise.""" + return get_device_compute_capability(0) >= 100 + + class GroupedGemmPrimitive(BasePrimitive): """ Primitive for grouped GEMM @@ -41,73 +52,139 @@ class GroupedGemmPrimitive(BasePrimitive): name = "te_grouped_gemm_ffi" multiple_results = True - impl_static_args = () + impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15) inner_primitive = None outer_primitive = None @staticmethod - def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias): + def abstract( + lhs_data_aval, + lhs_scale_inv_aval, + rhs_data_aval, + rhs_scale_inv_aval, + bias_aval, + group_sizes_aval, + group_offset_aval, + *, + M, + N, + K, + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + is_grouped_dense_wgrad, + ): """ + Grouped GEMM operation. 
+ Args: - *args: Size num_gemms * 4 or num_gemms * 5 depending on has_bias: - args[ 0 : num_gemms] are the lhs tensors, - args[ num_gemms : 2*num_gemms] are the rhs tensors, - args[2*num_gemms : 3*num_gemms] are the lhs scale_inv tensors, - args[3*num_gemms : 4*num_gemms] are the rhs scale_inv tensors, - args[4*num_gemms : 5*num_gemms] are the bias tensors if has_bias is True. - num_gemms: Number of GEMM operations to perform. - scaling_mode: Scaling mode for the GEMM operations. - out_dtype: Data type of the output tensors. - has_bias: Boolean indicating if bias tensors are provided. + lhs_data: Left-hand side input matrix data, 1D flattened array + lhs_scale_inv: Left-hand side input scale_inv matrix, 1D flattened array + rhs_data: Right-hand side input matrix data, 1D flattened array + rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array + bias: Bias matrix of shape (G, N) + group_sizes: 1D array containing the sizes of each group + group_offset: 1D array containing offsets for each group (not yet implemented) + M: Number of rows in the output matrix + N: Number of columns in the output matrix + K: Number of columns in the left-hand side matrix + lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed + rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed + scaling_mode: Scaling mode for the GEMM operations + out_dtype: Data type of the output tensors + has_bias: Boolean indicating if bias tensors are provided + is_grouped_dense_wgrad: Boolean indicating if this is a grouped dense wgrad operation + where both lhs and rhs are 2D matrices and output is (G, M, N) Returns: - A tuple of ShapedArray objects of size num_gemms+1: - ret[0 : num_gemms]: GEMM output tensors, - ret[num_gemms]:workspace tensor. 
+ A jnp.ndarray containing the result of the grouped GEMM operation """ - del scaling_mode - expected_num_args = 5 * num_gemms if has_bias else 4 * num_gemms - assert ( - len(args) == expected_num_args - ), f"Expected {expected_num_args} input arguments, but got {len(args)}" - A_list = args[0:num_gemms] - B_list = args[num_gemms : 2 * num_gemms] - # A and B have shapes [1, m, k] and [1, n, k] - out_list_aval = tuple( - jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype) - for A, B in zip(A_list, B_list) - ) + del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval + del K, lhs_is_trans, rhs_is_trans, scaling_mode, has_bias + # TODO(Phuong): move some shape checks from Cpp to here workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams + workspace_size += lhs_scale_inv_aval.size + rhs_scale_inv_aval.size workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) - return (*out_list_aval, workspace_aval) + out_shape = (M, N) + if is_grouped_dense_wgrad: + out_shape = (group_sizes_aval.size, M, N) + out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) + return (out_aval, workspace_aval) @staticmethod def outer_abstract(*args, **kwargs): (out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs) - return out_aval + return (out_aval,) @staticmethod - def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias): + def lowering( + ctx, + *args, + M, + N, + K, + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + is_grouped_dense_wgrad, + ): del out_dtype return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)( ctx, *args, - num_gemms=num_gemms, - scaling_mode=int(scaling_mode), + M=M, + N=N, + K=K, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode.value, has_bias=has_bias, + is_grouped_dense_wgrad=is_grouped_dense_wgrad, ) @staticmethod - def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias): + def impl( + lhs_data, + 
lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + group_sizes, + group_offset, + M, + N, + K, + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + is_grouped_dense_wgrad, + ): assert GroupedGemmPrimitive.inner_primitive is not None - out = GroupedGemmPrimitive.inner_primitive.bind( - *args, - num_gemms=num_gemms, - scaling_mode=scaling_mode.value, + (out, _) = GroupedGemmPrimitive.inner_primitive.bind( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + group_sizes, + group_offset, + M=M, + N=N, + K=K, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode, out_dtype=out_dtype, has_bias=has_bias, + is_grouped_dense_wgrad=is_grouped_dense_wgrad, ) - return out[:-1] # out is [out_list, wkspace], only return out_list + return (out,) register_primitive(GroupedGemmPrimitive) @@ -285,7 +362,7 @@ def gemm( lhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor], contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), - quantizer_set: Dict["str", Quantizer] = noop_quantizer_set, + quantizer_set: QuantizerSet = noop_quantizer_set, ) -> jnp.ndarray: """General matrix multiplication with optional quantization. 
@@ -310,130 +387,190 @@ def gemm( return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set) -""" -def swizzled_scale(scales): - # Swizzle the scale tensor for FP8 GEMM - assert scales.ndim == 2 - rows, cols = scales.shape - scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4) - scales = jnp.transpose(scales, (0, 3, 2, 1, 4)) - scales = scales.reshape(rows, cols) - return scales +def grouped_gemm( + lhs: Union[jnp.ndarray, GroupedScaledTensor1x], + rhs: Union[jnp.ndarray, GroupedScaledTensor1x], + group_sizes: jnp.ndarray, + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)), + bias: jnp.ndarray = None, + precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, + preferred_element_type: jnp.dtype = None, + group_offset: jnp.array = None, + quantizer_set: QuantizerSet = noop_quantizer_set, +) -> jnp.ndarray: + """ + Grouped GEMM operation. + + Args: + lhs: Left-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x + rhs: Right-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x + group_sizes: 1D array containing the sizes of each group + contracting_dims: Tuple of two sequences representing the contracting dimensions + bias: Bias tensor of shape (G, N) + precision: JAX precision for the GEMM operation + preferred_element_type: Preferred data type for the output tensor + group_offset: 1D array containing offsets for each group (not yet implemented) + quantizer_set: Set of quantizers for FP8 quantization of the input and output + Returns: + A jnp.ndarray containing the result of the grouped GEMM operation -def grouped_gemm( - lhs_list: List[Union[jnp.ndarray, ScaledTensor]], - rhs_list: List[Union[jnp.ndarray, ScaledTensor]], - contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]], - bias_list: List[jnp.ndarray] = None, -) -> List[jnp.ndarray]: - # Grouped GEMM for multiple pairs of tensors. 
- assert ( - len(lhs_list) == len(rhs_list) == len(contracting_dims_list) - ), "lhs_list, rhs_list, contracting_dims_list must have the same length" - - num_gemms = len(lhs_list) - lhs_list_ = [] - rhs_list_ = [] - lhs_sinv_list_ = [] - rhs_sinv_list_ = [] - bias_list_ = [] - for i in range(num_gemms): - lhs = lhs_list[i] - rhs = rhs_list[i] - contracting_dims = contracting_dims_list[i] - dim_nums = (contracting_dims, ((), ())) - if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): - scaling_mode = lhs.scaling_mode - lhs_shape = lhs.data.shape - rhs_shape = rhs.data.shape - out_dtype = lhs.dq_dtype - # For ScaledTensors and DELAYED_TENSOR_SCALING, need to handle internal data_layout - if lhs.scaling_mode.is_tensor_scaling(): - assert not ( - lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 - ), "FP8 GEMM does not support E5M2 * E5M2" - ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims - if lhs.data_layout == "T": - lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim - if rhs.data_layout == "T": - rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim - dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ()) + Note: + Tested shapes: + lhs: [M, K] or [K, N] + rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K] + """ + # TODO(Phuong): implement the group_offset + group_offset = group_offset or jnp.zeros((1,), jnp.int32) + + # TODO(Phuong): implement the precision + del precision + + if isinstance(lhs, jnp.ndarray): + assert isinstance(rhs, jnp.ndarray) + out_dtype = lhs.dtype + lhs_shape = lhs.shape + rhs_shape = rhs.shape + lhs_data = lhs + rhs_data = rhs + lhs_scale_inv = rhs_scale_inv = jnp.empty((0,), jnp.float32) + scaling_mode = ScalingMode.NO_SCALING + elif isinstance(lhs, GroupedScaledTensor1x): + assert isinstance(rhs, GroupedScaledTensor1x) + out_dtype = lhs.dq_dtype + lhs_shape = lhs.original_shape + rhs_shape = rhs.original_shape + lhs_data = lhs.data + rhs_data = rhs.data + 
lhs_scale_inv = lhs.scale_inv + rhs_scale_inv = rhs.scale_inv + assert lhs.scaling_mode == rhs.scaling_mode + scaling_mode = lhs.scaling_mode + else: + raise TypeError("Unsupported lhs type object!") + + out_dtype = preferred_element_type or out_dtype + + lhs_contract_dim, rhs_contract_dim = contracting_dims + + lhs_is_trans = lhs_contract_dim[-1] != len(lhs_shape) - 1 + lhs_flatten_axis = len(lhs_contract_dim) * (1 if lhs_is_trans else -1) + + # rhs_shape [G, K, N] + rhs_is_trans = rhs_contract_dim[0] != 1 + rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim) + + is_grouped_dense_wgrad = False + if len(rhs_shape) == 2: + rhs_is_trans = rhs_contract_dim[0] != 0 + is_grouped_dense_wgrad = True + + # TODO(Hua): thses are for fp16 dense wgrad, any better way to handle this? + if ( + is_grouped_dense_wgrad + and not isinstance(lhs, ScaledTensor) + and not isinstance(rhs, ScaledTensor) + ): + lhs_is_trans = True + rhs_is_trans = False + lhs_flatten_axis = 1 + rhs_flatten_axis = 1 + + if ( + not isinstance(lhs, ScaledTensor) + and not isinstance(rhs, ScaledTensor) + and quantizer_set != noop_quantizer_set + ): + assert isinstance(quantizer_set.x, GroupedQuantizer) + assert type(quantizer_set.x) is type(quantizer_set.kernel) + scaling_mode = quantizer_set.x.scaling_mode + if ( + # TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later + # scaling_mode.is_tensor_scaling() + # and is_gemm_with_all_layouts_supported() + scaling_mode.is_1d_block_scaling() + ): + lhs_is_rowwise = rhs_is_rowwise = True else: - # For jnp.ndarray, only consider contracting_dims, data_layout is always NN - scaling_mode = ScalingMode.NO_SCALING - lhs_shape = lhs.shape - rhs_shape = rhs.shape - out_dtype = lhs.dtype - - (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums - lhs_dn = (lhs_contract, lhs_batch) - rhs_dn = (rhs_contract, rhs_batch) - - lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract) - 
rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract) - - # Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy - if scaling_mode == ScalingMode.NO_SCALING: - lhs_3d = _shape_normalization(lhs, lhs_dn) - rhs_3d = _shape_normalization(rhs, rhs_dn) - elif scaling_mode.is_tensor_scaling(): - lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N") - rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T") - elif scaling_mode == ScalingMode.MXFP8_1D_SCALING: - lhs_3d = _shape_normalization(lhs.data, lhs_dn) - rhs_3d = _shape_normalization(rhs.data, rhs_dn) - lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn) - rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn) - # swizzled_scale requires a matrix - lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze()) - rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze()) + lhs_is_rowwise = not lhs_is_trans + rhs_is_rowwise = lhs_is_trans + quantizer_set.x.q_layout = ( + QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE + ) + quantizer_set.kernel.q_layout = ( + QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE + ) + lhs_q = grouped_quantize(lhs, quantizer_set.x, group_sizes, lhs_flatten_axis) + rhs_q = grouped_quantize( + rhs, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis + ) + lhs_data = lhs_q.data + rhs_data = rhs_q.data + lhs_scale_inv = lhs_q.scale_inv + rhs_scale_inv = rhs_q.scale_inv + + assert not ( + lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2 + ), "FP8 GEMM does not support E5M2 * E5M2" + + # Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs + # thus additional transpose is required + # TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later + if scaling_mode.is_tensor_scaling(): # and not is_gemm_with_all_layouts_supported(): + lhs_is_trans = False + rhs_is_trans = True + if isinstance(lhs, ScaledTensor) 
and isinstance(rhs, ScaledTensor): + lhs_layout_is_T = lhs.data_layout == "T" + rhs_layout_is_T = rhs.data_layout == "T" else: - raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}") - - # Note: already_transposed doesn't matter for the output shape - # x.shape = [B, D1, D2] - # contracting_dims = (2, ) --> output.shape = [1, B * D1, D2] - # contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1] - # x.shape = [D1, D2] - # contracting_dims = (1, ) --> output.shape = [1, D1, D2] - # contracting_dims = (0, ) --> output.shape = [1, D2, D1] - bm = lhs_remain_shape[0] - bn = rhs_remain_shape[0] - kl = lhs_3d.shape[-1] - kr = rhs_3d.shape[-1] - assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}" - if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0): - print("grouped_gemm input pair {i} has invalid problem shape for lowering: ") - print(f"m = {bm}, n = {bn}, k = {kl}; ") - print("cuBLAS requires the problem shapes being multiples of 16") - assert (bm % 16 == 0) and (bn % 16 == 0) and (kl % 16 == 0) - - lhs_list_.append(lhs_3d) - rhs_list_.append(rhs_3d) - if scaling_mode == ScalingMode.NO_SCALING: - lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) - rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32)) - if scaling_mode.is_tensor_scaling(): - lhs_sinv_list_.append(lhs.scale_inv) - rhs_sinv_list_.append(rhs.scale_inv) - if scaling_mode == ScalingMode.MXFP8_1D_SCALING: - lhs_sinv_list_.append(lhs_scale_inv) - rhs_sinv_list_.append(rhs_scale_inv) - if bias_list is not None: - bias_list_.append(bias_list[i]) - - out_list = GroupedGemmPrimitive.outer_primitive.bind( - *lhs_list_, - *rhs_list_, - *lhs_sinv_list_, - *rhs_sinv_list_, - *bias_list_, - num_gemms=num_gemms, - scaling_mode=scaling_mode, + lhs_layout_is_T = lhs_q.data_layout == "T" + rhs_layout_is_T = rhs_q.data_layout == "T" + lhs_ndim = len(lhs_shape) + rhs_ndim = len(rhs_shape) + if lhs_layout_is_T: + lhs_contract_dim = tuple((lhs_ndim - 1 - i) % 
lhs_ndim for i in lhs_contract_dim) + if rhs_layout_is_T: + rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) + lhs_data = _shape_normalization(lhs_data, (lhs_contract_dim, ()), not lhs_layout_is_T) + rhs_data = _shape_normalization(rhs_data, (rhs_contract_dim, ()), rhs_layout_is_T) + + # Calling GroupedGEMM Custom Call + K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim) + K_rhs = math.prod(rhs_shape[i] for i in rhs_contract_dim) + assert K_lhs == K_rhs + M = math.prod(_calculate_remaining_shape(lhs_shape, lhs_contract_dim)) + N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)[1:]) # Exclude G + + if is_grouped_dense_wgrad: + N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)) + else: + assert group_sizes.size == rhs_shape[0] + + assert group_offset.size == 1 + + has_bias = bias is not None + assert not has_bias or bias.shape == (group_sizes.size, N) + bias = jnp.empty((), jnp.float32) if bias is None else bias + + # TODO(Phuong): support MXFP8_1D_SCALING + assert scaling_mode != ScalingMode.MXFP8_1D_SCALING, "MXFP8_1D_SCALING is not yet supported" + + (out,) = GroupedGemmPrimitive.outer_primitive.bind( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + group_sizes, + group_offset, + M=M, + N=N, + K=K_lhs, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode.value, out_dtype=out_dtype, - has_bias=1 if bias_list is not None else 0, + has_bias=has_bias, + is_grouped_dense_wgrad=is_grouped_dense_wgrad, ) - - return out_list -""" + return out diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 7ed0db0298..07d8f81df0 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -47,7 +47,7 @@ from jax.extend import ffi # pylint: disable=ungrouped-imports -__all__ = ["quantize", "quantize_dbias", 
"grouped_quantize"] +__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"] class BaseDBiasQuantizePrimitive(BasePrimitive): @@ -1032,3 +1032,24 @@ def grouped_quantize( group_axis=group_axis, ) return out + + +def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray: + """ + Compute the grouped bias gradient. + + Args: + grad: jnp.ndarray of shape (M, N) + group_sizes: jnp.ndarray of shape(num_groups,), sum(group_sizes) == M + + Returns: + dbias: jnp.ndarray of shape (num_groups, N) + """ + assert grad.ndim == 2, "Input grad must be a 2D tensor." + assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor." + + segment_ids = jnp.repeat(jnp.arange(group_sizes.shape[0]), group_sizes) + grad_fp32 = grad.astype(jnp.float32) + dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0]) + dbias = dbias_fp32.astype(grad.dtype) + return dbias diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 0825bd2f73..d9d519fa00 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -13,43 +13,127 @@ #include "transformer_engine/multi_stream.h" #include "xla/ffi/api/c_api.h" +#define MXFP8_BLOCK_SIZE 32 + namespace transformer_engine { namespace jax { -Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, - Variadic_Result_Type output_list, int64_t num_gemms, - JAXX_Scaling_Mode scaling_mode, int64_t has_bias) { +Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, + Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, + Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, + Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans, + bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, + bool is_grouped_dense_wgrad) { // Notes on matrix layouts and transpose: // 
Jax uses row-major data_layout, on entering this function, each input matrix pair: - // A: row-major with size [m, k], - // B: row-major with size [n, k], needs transpose, + // A: row-major [m, k] for N - [k, m] for T + // B: row-major [k, n] for N - [n, k] for T // on exiting this function, JAX expect: // C: row-major with size [m, n]. // cuBLAS uses column-major data_layout, in this view, each input matrix pair: - // A: column-major with size [k, m], needs transpose, - // B: column-major with size [k, n]. + // A: column-major with size [k, m] for T - [m, k] for N + // B: column-major with size [n, k] for T - [k, n] for N + // // If we call cuBLAS GEMM for A * B, the output will be: // C: column-major with size [m, n] --> row-major with size [n, m]. // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. - if (num_gemms <= 0) { - return ffi_with_cuda_error_check(); + int num_streams = nvte_get_num_compute_streams(); + + // Inputs + auto lhs_ptr = reinterpret_cast(lhs_data.untyped_data()); + auto rhs_ptr = reinterpret_cast(rhs_data.untyped_data()); + auto lhs_sinv_ptr = reinterpret_cast(lhs_sinv.untyped_data()); + auto rhs_sinv_ptr = reinterpret_cast(rhs_sinv.untyped_data()); + auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type()); + auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type()); + auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type()); + auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type()); + auto bias_ptr = has_bias ? 
reinterpret_cast(bias.untyped_data()) : nullptr; + auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); + + NVTE_CHECK(group_sizes.dimensions().size() == 1); + size_t num_gemms = group_sizes.dimensions()[0]; + + // Outputs + auto out_ptr = reinterpret_cast(output->untyped_data()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); + auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); + auto workspace_total_size = product(workspace->dimensions()); + + auto lhs_sinv_size = product(lhs_sinv.dimensions()); + auto rhs_sinv_size = product(rhs_sinv.dimensions()); + auto workspace_size = (workspace_total_size - lhs_sinv_size - rhs_sinv_size) / num_streams; + auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams; + auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size; + + size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); + size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); + size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype); + size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype); + size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype); + size_t out_dtype_bytes = te_dtype_bytes(out_dtype); + + NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); + NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, + "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); + + size_t expected_lhs_size = m * k; + size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n); + size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n); + size_t actual_lhs_size = product(lhs_data.dimensions()); + size_t actual_rhs_size = product(rhs_data.dimensions()); + size_t actual_out_size = product(output->dimensions()); + NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! 
Expect ", + expected_lhs_size, ", got ", actual_lhs_size); + if (!is_grouped_dense_wgrad) { + NVTE_CHECK(expected_rhs_size == actual_rhs_size, + "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k, + " = ", expected_rhs_size, ", got ", actual_rhs_size); + NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! Expect m * n = ", m, + " * ", n, " = ", expected_out_size, ", got ", actual_out_size); + } else { + NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k, + " * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size); + NVTE_CHECK(expected_out_size == actual_out_size, + "Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n, + " = ", expected_out_size, ", got ", actual_out_size); } - size_t expected_input_size = has_bias ? 5 * num_gemms : 4 * num_gemms; - size_t expected_output_size = num_gemms + 1; - size_t actual_input_size = input_list.size(); - size_t actual_output_size = output_list.size(); - NVTE_CHECK(actual_input_size == expected_input_size, "Expected %zu input tensors, got %zu", - expected_input_size, actual_input_size); - NVTE_CHECK(actual_output_size == expected_output_size, "Expected %zu output tensors, got %zu", - expected_output_size, actual_output_size); - - bool trans_lhs = true; - bool trans_rhs = false; + + size_t dim_list_bytes = sizeof(int32_t) * num_gemms; + std::vector dim_list_host(num_gemms); + auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); + cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, + stream); + // Note: This may break cudaGraph. + cudaStreamSynchronize(stream); + size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); + if (!is_grouped_dense_wgrad) { + NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! 
M = ", m, + ", got sum(group_sizes)=", sum_group_sizes); + } else { + NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, + ", got sum(group_sizes)=", sum_group_sizes); + } + auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); bool grad = false; bool accumulate = false; bool use_split_accumulator = false; + auto bias_shape = std::vector{has_bias ? n : 0}; + const int arch = cuda::sm_arch(); + + // It is weird that TE/Common GEMM only use colwise for MXFP8 + const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); + const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; + const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans; + const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans; + + if (arch < 100 && is_fp8_gemm) { + NVTE_CHECK(!lhs_is_trans && rhs_is_trans, + "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ", + "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); + } // These lists are to keep the TensorWrapper objects alive std::vector lhs_wrapper_list; @@ -67,96 +151,83 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, std::vector out_list; std::vector workspace_list; - int lhs_list_offset = 0; - int rhs_list_offset = num_gemms; - int lhs_sinv_list_offset = 2 * num_gemms; - int rhs_sinv_list_offset = 3 * num_gemms; - int bias_list_offset = 4 * num_gemms; - int out_list_offset = 0; - for (int i = 0; i < num_gemms; i++) { - Buffer_Type lhs_i = input_list.get(lhs_list_offset + i).value(); - Buffer_Type rhs_i = input_list.get(rhs_list_offset + i).value(); - Buffer_Type lhs_sinv_i = input_list.get(lhs_sinv_list_offset + i).value(); - Buffer_Type rhs_sinv_i = input_list.get(rhs_sinv_list_offset + i).value(); - Result_Type out_i = output_list.get(out_list_offset + i).value(); - - DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type()); - DType rhs_dtype = 
convert_ffi_datatype_to_te_dtype(rhs_i.element_type()); - DType out_dtype = convert_ffi_datatype_to_te_dtype(out_i->element_type()); - - void *lhs_ptr = lhs_i.untyped_data(); - void *rhs_ptr = rhs_i.untyped_data(); - void *lhs_sinv_ptr = lhs_sinv_i.untyped_data(); - void *rhs_sinv_ptr = rhs_sinv_i.untyped_data(); - void *out_ptr = out_i->untyped_data(); - - // Placeholder for bias since it can be empty - DType bias_dtype = DType::kFloat32; - void *bias_ptr = nullptr; - - auto lhs_shape_ = lhs_i.dimensions(); - auto rhs_shape_ = rhs_i.dimensions(); - - // lhs and rhs has shape [1, m, k] and [1, n, k] - size_t m = lhs_shape_[1]; - size_t n = rhs_shape_[1]; - size_t k = lhs_shape_[2]; - - auto lhs_shape = std::vector{m, k}; - auto rhs_shape = std::vector{n, k}; - auto out_shape = std::vector{n, m}; - auto lhs_sinv_shape = std::vector{1, 1}; - auto rhs_sinv_shape = std::vector{1, 1}; - - if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || - scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || - scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) { - float *amax_dptr = nullptr; - float *scale_dptr = nullptr; - auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr, - reinterpret_cast(lhs_sinv_ptr)); - auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, amax_dptr, scale_dptr, - reinterpret_cast(rhs_sinv_ptr)); - lhs_wrapper_list.push_back(std::move(lhs_i_)); - rhs_wrapper_list.push_back(std::move(rhs_i_)); - } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { - // Note: the scale_inv array should have been swizzled in Python before lowering - auto lhs_sinv_shape_ = lhs_sinv_i.dimensions(); - auto rhs_sinv_shape_ = rhs_sinv_i.dimensions(); - for (int i = 0; i < 2; i++) { - lhs_sinv_shape[i] = lhs_sinv_shape_[i]; - rhs_sinv_shape[i] = rhs_sinv_shape_[i]; - } - - NVTEScalingMode nvte_scaling_mode = get_nvte_scaling_mode(scaling_mode); - TensorWrapper lhs_i_(nvte_scaling_mode); - TensorWrapper 
rhs_i_(nvte_scaling_mode); - lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape); - rhs_i_.set_rowwise_data(rhs_ptr, rhs_dtype, rhs_shape); - lhs_i_.set_rowwise_scale_inv(lhs_sinv_ptr, DType::kFloat8E8M0, lhs_sinv_shape); - rhs_i_.set_rowwise_scale_inv(rhs_sinv_ptr, DType::kFloat8E8M0, rhs_sinv_shape); - - lhs_wrapper_list.push_back(std::move(lhs_i_)); - rhs_wrapper_list.push_back(std::move(rhs_i_)); - } else { - NVTE_ERROR("Unsupported scaling mode: ", static_cast(scaling_mode)); + for (size_t i = 0; i < num_gemms; i++) { + // Matrix data shapes + size_t m_i = dim_list_host[i]; + auto lhs_shape = std::vector{m_i, k}; + auto rhs_shape = std::vector{rhs_is_trans ? n : k, rhs_is_trans ? k : n}; + auto out_shape = std::vector{m_i, n}; + if (is_grouped_dense_wgrad) { + size_t k_i = dim_list_host[i]; + lhs_shape[0] = lhs_is_trans ? k_i : m; + lhs_shape[1] = lhs_is_trans ? m : k_i; + rhs_shape[0] = rhs_is_trans ? n : k_i; + rhs_shape[1] = rhs_is_trans ? k_i : n; + out_shape[0] = m; + out_shape[1] = n; } - auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype); - void *pre_gelu_ptr = nullptr; - auto bias_shape = std::vector{0}; - auto pre_gelu_shape = std::vector{0}; - if (has_bias) { - auto bias_i_get = input_list.get(bias_list_offset + i); - Buffer_Type bias_i = bias_i_get.value(); - bias_ptr = bias_i.untyped_data(); - bias_dtype = convert_ffi_datatype_to_te_dtype(bias_i.element_type()); - bias_shape[0] = n; + // Set matrix data pointers + auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + auto out_i = TensorWrapper(static_cast(out_ptr), out_shape, out_dtype); + void *lhs_vptr = static_cast(lhs_ptr); + void *rhs_vptr = static_cast(rhs_ptr); + if (rhs_use_colwise) // MatA to enter cuBLAS + rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape); + else + rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape); + if (lhs_use_colwise) // MatB to enter cuBLAS + 
lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape); + else + lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape); + + // Scale_inv shapes + auto lhs_sinv_size = std::vector{1}; + auto rhs_sinv_size = std::vector{1}; + if (is_mxfp8_scaling) { + NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)", + MXFP8_BLOCK_SIZE, k); + size_t scale_k = k / MXFP8_BLOCK_SIZE; + lhs_sinv_size[0] = m_i * scale_k; + rhs_sinv_size[0] = n * scale_k; + // Need to add swizzle here } + + // Set scale_inv pointers + void *rhs_sinv_vptr = static_cast(rhs_sinv_ptr); + void *lhs_sinv_vptr = static_cast(lhs_sinv_ptr); + if (is_fp8_gemm) { + if (rhs_use_colwise) // MatA to enter cuBLAS + rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size); + else + rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size); + if (lhs_use_colwise) // MatB to enter cuBLAS + lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size); + else + lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size); + } else { + NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, + "Unsupported scaling mode: ", static_cast(scaling_mode)); + } + auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype); - auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype); + auto pre_gelu_i = TensorWrapper(nullptr, std::vector{0}, out_dtype); + + // Update pointer for the next GEMM pair + lhs_ptr += lhs_shape[0] * lhs_shape[1] * lhs_dtype_bytes; + rhs_ptr += rhs_shape[0] * rhs_shape[1] * rhs_dtype_bytes; + out_ptr += out_shape[0] * out_shape[1] * out_dtype_bytes; + if (is_fp8_gemm) { + lhs_sinv_ptr += lhs_sinv_size[0] * lhs_sinv_dtype_bytes; + rhs_sinv_ptr += rhs_sinv_size[0] * rhs_sinv_dtype_bytes; + } + if (has_bias) bias_ptr += n * bias_dtype_bytes; - out_wrapper_list.push_back(std::move(out_i_)); + // Move objects to the lists to keep them alive + lhs_wrapper_list.push_back(std::move(lhs_i)); + 
rhs_wrapper_list.push_back(std::move(rhs_i)); + out_wrapper_list.push_back(std::move(out_i)); bias_wrapper_list.push_back(std::move(bias_i)); pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i)); @@ -167,11 +238,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, out_list.push_back(out_wrapper_list.back().data()); } - auto workspace_get = output_list.get(num_gemms); - Result_Type workspace = workspace_get.value(); - uint8_t *workspace_ptr = reinterpret_cast(workspace->untyped_data()); - auto num_streams = nvte_get_num_compute_streams(); - size_t workspace_size = workspace->dimensions()[0] / num_streams; auto workspace_shape = std::vector{workspace_size}; for (int i = 0; i < num_streams; i++) { auto workspace_i = @@ -182,7 +248,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, } nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), - pre_gelu_list.data(), num_gemms, trans_lhs, trans_rhs, grad, + pre_gelu_list.data(), num_gemms, rhs_is_trans, lhs_is_trans, grad, workspace_list.data(), accumulate, use_split_accumulator, num_math_sm, stream); @@ -192,11 +258,23 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, FFI::Bind() .Ctx() // stream - .RemainingArgs() // input list - .RemainingRets() // output list - .Attr("num_gemms") + .Arg() // lhs_data + .Arg() // lhs_sinv + .Arg() // rhs_data + .Arg() // rhs_sinv + .Arg() // bias + .Arg() // group_sizes + .Arg() // group_offset + .Ret() // output + .Ret() // workspace + .Attr("M") + .Attr("N") + .Attr("K") + .Attr("lhs_is_trans") + .Attr("rhs_is_trans") .Attr("scaling_mode") - .Attr("has_bias"), + .Attr("has_bias") + .Attr("is_grouped_dense_wgrad"), FFI_CudaGraph_Traits); } // namespace jax diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 55d60e4189..bba101c722 100644 --- 
a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -153,28 +153,28 @@ def _dense_bwd_rule( # GEMM NT # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim - g_constracting_dim = tuple( + g_contracting_dim = tuple( range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) ) # k_non_contracting_dims - k_constracting_dim = tuple( + k_contracting_dim = tuple( dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims ) dgrad = tex.gemm( casted_grad.get_rowwise_tensor(), rowwise_casted_kernel, - (g_constracting_dim, k_constracting_dim), + (g_contracting_dim, k_contracting_dim), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) # GEMM TN # x_non_contracting_dims - g_constracting_dim = x_constracting_dim = tuple( + g_contracting_dim = x_contracting_dim = tuple( range(0, len(x_shape) - len(fwd_x_contracting_dims)) ) wgrad = tex.gemm( - colwise_casted_x, casted_grad.get_colwise_tensor(), (x_constracting_dim, g_constracting_dim) + colwise_casted_x, casted_grad.get_colwise_tensor(), (x_contracting_dim, g_contracting_dim) ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) @@ -184,135 +184,240 @@ def _dense_bwd_rule( _dense.defvjp(_dense_fwd_rule, _dense_bwd_rule) -""" def grouped_dense( - x_list, - kernel_list, - bias_list, - contracting_dims_list, - quantizer_set_list=None, + x: jnp.ndarray, + kernel: jnp.ndarray, + group_sizes: jnp.ndarray, + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)), + bias: jnp.ndarray = None, + precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, + preferred_element_type: jnp.dtype = None, + group_offset: jnp.array = None, + quantizer_set: QuantizerSet = noop_quantizer_set, ): - # Perform grouped_dense layer transformation with optional quantization. + """ + Perform grouped dense (linear) layer transformation with optional quantization. 
- output_list = _grouped_dense( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list + Args: + x: Input tensor of shape (M, K) + kernel: Weight matrix of shape (G, K, N) + group_sizes: 1D array of shape (G,) specifying the size of each group + contracting_dims: Tuple of sequences specifying which dimensions to contract + (currently only supports ((1,), (1,))) + bias: Bias tensor of shape (G, N) + precision: JAX precision for the GEMM operation + preferred_element_type: Preferred data type for the output tensor + group_offset: 1D array containing offsets for each group (not yet implemented) + quantizer_set: Set of quantizers for FP8 quantization of the input and output + + Returns: + A jnp.ndarray containing the result of the grouped linear operation + """ + output = _grouped_dense( + x, + kernel, + group_sizes, + contracting_dims, + bias, + precision, + preferred_element_type, + group_offset, + quantizer_set, ) - return output_list + return output -@partial(jax.custom_vjp, nondiff_argnums=(3,)) -def _grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list): - output_list, _ = _grouped_dense_fwd_rule( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list +@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7)) +def _grouped_dense( + x, + kernel, + group_sizes, + contracting_dims, + bias, + precision, + preferred_element_type, + group_offset, + quantizer_set, +): + output, _ = _grouped_dense_fwd_rule( + x, + kernel, + group_sizes, + contracting_dims, + bias, + precision, + preferred_element_type, + group_offset, + quantizer_set, ) - return output_list + return output def _grouped_dense_fwd_rule( - x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list + x, + kernel, + group_sizes, + contracting_dims, + bias, + precision, + preferred_element_type, + group_offset, + quantizer_set, ): - use_bias = bias_list is not None - output_list = [] - x_rowwise_list = [] - x_colwise_list 
= [] - kernel_colwise_list = [] - kernel_rowwise_list = [] - x_shape_list = [] - kernel_shape_list = [] - if quantizer_set_list is None: - x_rowwise_list = x_list - x_colwise_list = x_list - kernel_colwise_list = kernel_list - kernel_rowwise_list = kernel_list - x_shape_list = [x.shape for x in x_list] - kernel_shape_list = [kernel.shape for kernel in kernel_list] + use_bias = bias is not None + is_noop_quantizer_set = quantizer_set == noop_quantizer_set + + if is_noop_quantizer_set: + grouped_gemm_x = x + grouped_gemm_kernel = kernel + ctx_x = x + ctx_kernel = kernel + flatten_axis_k = None else: - for i in range(len(x_list)): # pylint: disable=consider-using-enumerate - q_x = tex.quantize(x_list[i], quantizer_set_list[i].x) - q_kernel = tex.quantize(kernel_list[i], quantizer_set_list[i].kernel) - x_rowwise_list.append(q_x.get_rowwise_tensor()) - x_colwise_list.append(q_x.get_colwise_tensor()) - kernel_colwise_list.append(q_kernel.get_colwise_tensor()) - kernel_rowwise_list.append(q_kernel.get_rowwise_tensor()) - x_shape_list.append(x_rowwise_list[-1].data.shape) - kernel_shape_list.append(kernel_rowwise_list[-1].data.shape) - - output_list = tex.grouped_gemm( - x_rowwise_list, kernel_colwise_list, contracting_dims_list, bias_list + x_contracting_dims, k_contracting_dims = contracting_dims + flatten_axis_x = -len(x_contracting_dims) + flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1 # +1 for G axis + + assert x.ndim == 2, "Grouped dense expects a 2D input tensor of shape (M, K)" + assert kernel.ndim == 3, "Grouped dense expects a 3D kernel tensor of shape (G, K, N)" + # Expected k_contracting_dims == (1,), need to tweak it for grouped_gemm FP8 extra transpose + # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? 
+ assert x_contracting_dims == (1,) and k_contracting_dims == (1,), ( + "grouped_dense for FP8 can only handle x_contracting_dims=(1,) " + "and k_contracting_dims=(1,) for now, " + f"got {x_contracting_dims=} and {k_contracting_dims=}" + ) + k_contracting_dims = (0,) + + casted_x = tex.grouped_quantize( + x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x + ) + casted_kernel = tex.grouped_quantize( + kernel, quantizer_set.kernel, flatten_axis=flatten_axis_k + ) + contracting_dims = (x_contracting_dims, k_contracting_dims) + + # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have + # rowwise_casted_x.original_shape == (M, K) + # colwise_casted_kernel.original_shape == (G, N, K) + grouped_gemm_x = casted_x.get_rowwise_tensor() + grouped_gemm_kernel = casted_kernel.get_colwise_tensor() + # TODO(Hua): Shall we give warning/error if not quantizer_set.x.is_2x2x()? + ctx_x = casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None + ctx_kernel = casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None + + output = tex.grouped_gemm( + grouped_gemm_x, + grouped_gemm_kernel, + group_sizes, + contracting_dims, + bias, + precision, + preferred_element_type, + group_offset, ) ctx = ( - x_colwise_list, - kernel_rowwise_list, - x_shape_list, - kernel_shape_list, + group_sizes, + ctx_x, + ctx_kernel, + x.shape, + kernel.shape, use_bias, - quantizer_set_list, + is_noop_quantizer_set, + quantizer_set, + flatten_axis_k, ) - return output_list, ctx + return output, ctx -def _grouped_dense_bwd_rule(contracting_dims_list, ctx, grad_list): +def _grouped_dense_bwd_rule( + contracting_dims, precision, preferred_element_type, group_offset, ctx, grad +): + fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims + ( - colwise_x_list, - rowwise_kernel_list, - x_shape_list, - kernel_shape_list, + group_sizes, + ctx_x, + ctx_kernel, + x_shape, + kernel_shape, use_bias, - quantizer_set_list, + is_noop_quantizer_set, + 
quantizer_set, + flatten_axis_k, ) = ctx - group_size = len(grad_list) - dbias_list = [] - grad_rowwise_list = [] - grad_colwise_list = [] - dgrad_contracting_dims_list = [] - wgrad_contracting_dims_list = [] - for i in range(group_size): - grad = grad_list[i] - x_shape = x_shape_list[i] - kernel_shape = kernel_shape_list[i] - fwd_contracting_dims = contracting_dims_list[i] - - if quantizer_set_list is None: - casted_grad = grad - dbias = tex.quantization._jax_dbias(grad) - grad_rowwise_list.append(grad) - grad_colwise_list.append(grad) - else: - quantizer_set = quantizer_set_list[i] - casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad - ) - grad_rowwise_list.append(casted_grad.get_rowwise_tensor()) - grad_colwise_list.append(casted_grad.get_colwise_tensor()) - dbias_list.append(dbias) - - # GEMM NT - fwd_x_contracting_dims, fwd_k_contracting_dims = fwd_contracting_dims + if is_noop_quantizer_set: + # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?) 
+ # g_contracting_dim = (1, ) + # k_contracting_dim = (2, ) g_contracting_dim = tuple( - range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) + range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) ) k_contracting_dim = tuple( - dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims + dim for dim in range(1, len(kernel_shape)) if dim not in fwd_k_contracting_dims ) dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim) - dgrad_contracting_dims_list.append(dgrad_contracting_dims) + dgrad_grad = grad + dgrad_kernel_T = ctx_kernel - # GEMM TN + # g_contracting_dim = (0, ) + # x_contracting_dim = (0, ) g_contracting_dim = x_contracting_dim = tuple( range(0, len(x_shape) - len(fwd_x_contracting_dims)) ) wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) - wgrad_contracting_dims_list.append(wgrad_contracting_dims) + wgrad_x_T = ctx_x + wgrad_grad = grad + else: + casted_grad = tex.grouped_quantize( + grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k + ) + + # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we need to use + # g_contracting_dim = (1,) and k_contracting_dim = (2,) to make it work after the + # extra transpose for FP8 in grouped_gemm + # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? + g_contracting_dim = (1,) + k_contracting_dim = (2,) + dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim) + dgrad_grad = casted_grad.get_rowwise_tensor() + dgrad_kernel_T = ctx_kernel + + # We need to use g_contracting_dim = (0,) and x_contracting_dim = (1,) to make it work + # after the extra transpose for FP8 in grouped_gemm + # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? 
+ g_contracting_dim = (0,) + x_contracting_dim = (1,) + wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) + wgrad_x_T = ctx_x + wgrad_grad = casted_grad.get_colwise_tensor() + + dgrad = tex.grouped_gemm( + dgrad_grad, + dgrad_kernel_T, + group_sizes, + dgrad_contracting_dims, + precision=precision, + preferred_element_type=preferred_element_type, + group_offset=group_offset, + ) - dgrad_list = tex.grouped_gemm( - grad_rowwise_list, rowwise_kernel_list, dgrad_contracting_dims_list + wgrad = tex.grouped_gemm( + wgrad_x_T, + wgrad_grad, + group_sizes, + wgrad_contracting_dims, + precision=precision, + preferred_element_type=preferred_element_type, + group_offset=group_offset, ) - wgrad_list = tex.grouped_gemm(colwise_x_list, grad_colwise_list, wgrad_contracting_dims_list) - return dgrad_list, wgrad_list, dbias_list, quantizer_set_list + group_sizes_grad = None + dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None + + return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set _grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule) -""" diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 45ec4fd1fa..06a2562fb1 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -127,14 +127,16 @@ def _dequantize_func(data, scale_inv, dq_dtype, scaling_mode, is_colwise, flatte def dequantize(scaled_tensor): """Dequantize a tensor using block scaling. - This function dequantizes a tensor that was quantized using block scaling - by applying the inverse scaling factor to each block of data. 
- Args: - scaled_tensor: The quantized tensor to dequantize + data: The quantized tensor data + scale_inv: The inverse scaling factors + dq_dtype: The data type for dequantized values + scaling_mode: The scaling mode used for quantization + is_colwise: Whether the scaling is column-wise + flatten_axis: The axis along which the tensor could be flattened to 2D Returns: - The dequantized tensor in the specified data type + The dequantized tensor """ return BlockScaleDequantizer._dequantize_func( scaled_tensor.data, From 40a30a5f8de9d97105a31bad7f23ed40abdf739b Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Thu, 12 Jun 2025 19:56:29 +0200 Subject: [PATCH 12/39] [PyTorch] Support L2Normalization basic op -> use for qk_norm (#1864) * Support L2Norm basic op Signed-off-by: Evgeny * Add L2Norm module wrapper Signed-off-by: Evgeny * Expose qk_norm to MHA nd transformer laayer Signed-off-by: Evgeny * Move tests into separate file Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix pass Signed-off-by: Evgeny * Add license Signed-off-by: Evgeny * Remove module Signed-off-by: Evgeny * Resollve comments Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Evgeny Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_fusible_ops.py | 52 ++++ tests/pytorch/test_jit.py | 59 +++++ tests/pytorch/test_qk_norm.py | 242 ++++++++++++++++++ .../pytorch/attention/multi_head_attention.py | 38 +++ transformer_engine/pytorch/jit.py | 91 +++++++ .../pytorch/ops/basic/__init__.py | 1 + .../pytorch/ops/basic/l2normalization.py | 128 +++++++++ transformer_engine/pytorch/transformer.py | 16 ++ 8 files changed, 627 insertions(+) create mode 100644 tests/pytorch/test_qk_norm.py create mode 100644 transformer_engine/pytorch/ops/basic/l2normalization.py diff --git 
a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index a0c0ee5faa..b1706db612 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1273,6 +1273,58 @@ def test_rmsnorm( torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(dw_test, w_ref.grad, **tols) + @pytest.mark.parametrize("in_shape", ((32,), (6, 16, 64), (32, 64))) + @pytest.mark.parametrize("dtype", _dtypes) + def test_l2normalization( + self, + *, + in_shape: Iterable[int], + dtype: torch.dtype, + device: torch.device = "cuda", + eps: float = 1e-6, + ) -> None: + """L2 Normalization""" + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + # L2 norm: x / ||x||_2 = x / sqrt(sum(x^2) + eps) + l2_norm_squared = x_ref.pow(2).sum(dim=-1, keepdim=True) + rsqrt_norm = torch.rsqrt(l2_norm_squared + eps) + y_ref = x_ref * rsqrt_norm + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = te_ops.L2Normalization( + eps=eps, + ) + y_test = op(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + + torch.testing.assert_close(y_test, y_ref, **tols) + # L2Norm backward pass requires slightly looser atol for bfloat16 + if dtype == torch.bfloat16: + tols["atol"] = 2e-3 + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) @pytest.mark.parametrize("fp8", (False, True)) diff --git a/tests/pytorch/test_jit.py b/tests/pytorch/test_jit.py index a697cc0483..e670070bc0 100644 --- a/tests/pytorch/test_jit.py +++ 
b/tests/pytorch/test_jit.py @@ -63,3 +63,62 @@ def test_lazy_compile(): from transformer_engine.pytorch.jit import dgelu_fused_ dgelu_fused_(torch.randn(10, 10), torch.randn(10, 10)) + + +def test_l2normalization_fused(): + """Smoke test for L2Normalization fusion functions.""" + from transformer_engine.pytorch.jit import ( + l2normalization_fused, + l2normalization_fwd_fused, + l2normalization_backward_fused, + ) + + # Basic smoke test like other JIT functions + x = torch.randn(10, 128, device="cuda", dtype=torch.float32) + eps = 1e-6 + + # Test inference version + output_inf = l2normalization_fused(x, eps) + + # Test training version with backward + x_train = torch.randn(10, 128, device="cuda", dtype=torch.float32, requires_grad=True) + output_train, rsqrt_norm = l2normalization_fwd_fused(x_train, eps) + grad_output = torch.randn_like(output_train) + grad_input = l2normalization_backward_fused(grad_output, x_train, rsqrt_norm, eps) + + +def test_l2normalization_fused_correctness(): + """Simple verification that L2Normalization fusion matches reference implementation.""" + from transformer_engine.pytorch.jit import ( + l2normalization_fwd_fused, + l2normalization_backward_fused, + ) + + device = "cuda" if torch.cuda.is_available() else "cpu" + x = torch.randn(16, 64, device=device, dtype=torch.float32, requires_grad=True) + eps = 1e-6 + + # Test fused forward + output_fused, rsqrt_norm = l2normalization_fwd_fused(x, eps) + + # Reference implementation + x_ref = x.clone().detach().requires_grad_(True) + x_squared = x_ref.pow(2) + l2_norm_squared = x_squared.sum(dim=-1, keepdim=True) + rsqrt_norm_ref = torch.rsqrt(l2_norm_squared + eps) + output_ref = x_ref * rsqrt_norm_ref + + # Check forward pass matches + torch.testing.assert_close(output_fused, output_ref, atol=1e-6, rtol=1e-5) + torch.testing.assert_close(rsqrt_norm, rsqrt_norm_ref, atol=1e-6, rtol=1e-5) + + # Test fused backward + grad_output = torch.randn_like(output_fused) + grad_input_fused = 
l2normalization_backward_fused(grad_output, x, rsqrt_norm, eps) + + # Reference backward + output_ref.backward(grad_output) + grad_input_ref = x_ref.grad + + # Check backward pass matches + torch.testing.assert_close(grad_input_fused, grad_input_ref, atol=1e-5, rtol=1e-4) diff --git a/tests/pytorch/test_qk_norm.py b/tests/pytorch/test_qk_norm.py new file mode 100644 index 0000000000..6f4e62f81a --- /dev/null +++ b/tests/pytorch/test_qk_norm.py @@ -0,0 +1,242 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from transformer_engine.pytorch import MultiheadAttention + +import pytest +import torch + + +@pytest.mark.parametrize("use_qk_norm", [False, True]) +@pytest.mark.parametrize("attention_type", ["self", "cross"]) +@pytest.mark.parametrize("qk_norm_eps", [1e-6, 1e-5]) +def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None: + """Test QK normalization functionality, module structure, and numerical behavior.""" + hidden_size = 256 + num_attention_heads = 8 + seq_len = 128 + + # Create MultiheadAttention module + mha = MultiheadAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_type=attention_type, + use_qk_norm=use_qk_norm, + qk_norm_eps=qk_norm_eps, + bias=False, + device="cuda", + ).cuda() + + # Check module structure based on use_qk_norm parameter + if use_qk_norm: + assert hasattr(mha, "qk_norm"), "Should have qk_norm module when use_qk_norm=True" + assert not hasattr(mha, "q_l2norm"), "Should not have separate q_l2norm module" + assert not hasattr(mha, "k_l2norm"), "Should not have separate k_l2norm module" + # Check that the module is L2Norm type + from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization + + assert isinstance( + mha.qk_norm, L2Normalization + ), "qk_norm should be an L2Normalization module" + else: + assert not hasattr(mha, "qk_norm"), "Should not have qk_norm module when 
use_qk_norm=False" + + # Create input tensors + batch_size = 2 # Use a fixed batch size for testing + hidden_states = torch.randn( + seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32 + ) + + if attention_type == "cross": + encoder_output = torch.randn( + seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32 + ) + else: + encoder_output = None + + # Test forward pass + with torch.no_grad(): + if attention_type == "cross": + output = mha(hidden_states, encoder_output=encoder_output) + else: + output = mha(hidden_states) + + # Check output shape and numerical properties + assert output.shape == ( + seq_len, + batch_size, + hidden_size, + ), f"Output shape mismatch: {output.shape}" + assert not torch.isnan(output).any(), "Output contains NaN" + assert not torch.isinf(output).any(), "Output contains Inf" + + # Test with RoPE (if self-attention) + if attention_type == "self": + head_dim = hidden_size // num_attention_heads + rotary_dim = head_dim // 2 + rotary_pos_emb = torch.randn(seq_len, 1, 1, rotary_dim, device="cuda", dtype=torch.float32) + + with torch.no_grad(): + output_with_rope = mha(hidden_states, rotary_pos_emb=rotary_pos_emb) + + assert output_with_rope.shape == ( + seq_len, + batch_size, + hidden_size, + ), "Output shape with RoPE mismatch" + assert not torch.isnan(output_with_rope).any(), "RoPE output contains NaN" + assert not torch.isinf(output_with_rope).any(), "RoPE output contains Inf" + + +def test_qk_norm_output_difference() -> None: + """Test that QK normalization actually changes the output compared to no normalization.""" + hidden_size = 256 + num_attention_heads = 8 + seq_len = 128 + batch_size = 2 + + # Use same random seed to ensure identical weight initialization + current_rng_state = torch.get_rng_state() + current_cuda_rng_state = torch.cuda.get_rng_state() + + # Reset to a known seed for reproducible initialization + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + # Create model with QK 
normalization + mha_with_norm = MultiheadAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + use_qk_norm=True, + bias=False, + device="cuda", + ).cuda() + + # Reset to same seed for identical initialization + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + # Create identical model without QK normalization + mha_no_norm = MultiheadAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + use_qk_norm=False, + bias=False, + device="cuda", + ).cuda() + + # Create input tensors + hidden_states = torch.randn( + seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32 + ) + + # Compare outputs with identical weights but different QK norm settings + with torch.no_grad(): + output_with_norm = mha_with_norm(hidden_states) + output_no_norm = mha_no_norm(hidden_states) + + # Outputs should be different when QK normalization is enabled + assert not torch.allclose( + output_with_norm, output_no_norm, atol=1e-6 + ), "QK normalization should change the output, but outputs are identical" + + +def test_qk_norm_with_fused_qkv() -> None: + """Test QK normalization works with fused QKV parameters.""" + hidden_size = 256 + num_attention_heads = 8 + seq_len = 64 + + mha = MultiheadAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + fuse_qkv_params=True, + use_qk_norm=True, + bias=False, + device="cuda", + ).cuda() + + # Create input and test forward pass + batch_size = 2 # Use a fixed batch size for testing + hidden_states = torch.randn( + seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32 + ) + + with torch.no_grad(): + output = mha(hidden_states) + + assert output.shape == ( + seq_len, + batch_size, + hidden_size, + ), f"Output shape mismatch: {output.shape}" + + +def test_qk_norm_transformer_layer_output_difference() -> None: + """Test that QK normalization actually changes TransformerLayer output compared to no normalization.""" + from transformer_engine.pytorch 
import TransformerLayer + + hidden_size = 256 + ffn_hidden_size = 1024 + num_attention_heads = 8 + seq_len = 128 + batch_size = 2 + + # Use same random seed to ensure identical weight initialization + current_rng_state = torch.get_rng_state() + current_cuda_rng_state = torch.cuda.get_rng_state() + + # Reset to a known seed for reproducible initialization + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + # Create TransformerLayer with QK normalization + transformer_with_norm = TransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, + use_qk_norm=True, + bias=False, + device="cuda", + ).cuda() + + # Reset to same seed for identical initialization + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + # Create identical TransformerLayer without QK normalization + transformer_no_norm = TransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, + use_qk_norm=False, + bias=False, + device="cuda", + ).cuda() + + # Create input tensors + hidden_states = torch.randn( + seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32 + ) + + # Compare outputs with identical weights but different QK norm settings + with torch.no_grad(): + output_with_norm = transformer_with_norm(hidden_states) + output_no_norm = transformer_no_norm(hidden_states) + + # Outputs should be different when QK normalization is enabled + assert not torch.allclose( + output_with_norm, output_no_norm, atol=1e-6 + ), "QK normalization should change the TransformerLayer output, but outputs are identical" + + # Check that outputs have expected shapes and properties + assert output_with_norm.shape == ( + seq_len, + batch_size, + hidden_size, + ), f"Output shape mismatch: {output_with_norm.shape}" + assert not torch.isnan(output_with_norm).any(), "Output with QK norm contains NaN" + assert not torch.isinf(output_with_norm).any(), "Output with QK norm contains Inf" 
+ assert not torch.isnan(output_no_norm).any(), "Output without QK norm contains NaN" + assert not torch.isinf(output_no_norm).any(), "Output without QK norm contains Inf" diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 256eec742a..142044240b 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -12,6 +12,7 @@ from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.module import LayerNormLinear, Linear +from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization from transformer_engine.pytorch.utils import ( SplitAlongDim, divide, @@ -174,6 +175,22 @@ class MultiheadAttention(torch.nn.Module): parameter for query-key-value. This enables optimizations such as QKV fusion without concatentations/splits and also enables the argument `fuse_wgrad_accumulation`. + use_qk_norm: bool, default = 'False' + if set to `True`, L2 normalization is applied to query and key tensors + after RoPE (if applicable) but before attention computation. + This follows the Llama4 approach for QK normalization to improve + training stability and model performance. + qk_norm_eps: float, default = 1e-6 + epsilon value for L2 normalization of query and key tensors. + Only used when `use_qk_norm` is True. + seq_length: Optional[int], default = `None` + sequence length of input samples. Needed for JIT Warmup, a technique where jit + fused functions are warmed up before training to ensure same kernels are used for + forward propagation and activation recompute phase. + micro_batch_size: Optional[int], default = `None` + batch size per training step. 
Needed for JIT Warmup, a technique where jit + fused functions are warmed up before training to ensure same kernels are + used for forward propagation and activation recompute phase. """ def __init__( @@ -214,6 +231,10 @@ def __init__( device: Union[torch.device, str] = "cuda", qkv_format: str = "sbhd", name: str = None, + use_qk_norm: bool = False, + qk_norm_eps: float = 1e-6, + seq_length: Optional[int] = None, + micro_batch_size: Optional[int] = None, ) -> None: super().__init__() @@ -267,6 +288,7 @@ def __init__( self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups self.name = name + self.use_qk_norm = use_qk_norm common_gemm_kwargs = { "fuse_wgrad_accumulation": fuse_wgrad_accumulation, @@ -278,6 +300,14 @@ def __init__( "device": device, } + # Initialize L2 normalization modules for query and key if enabled + if self.use_qk_norm: + self.qk_norm = L2Normalization( + eps=qk_norm_eps, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + ) + qkv_parallel_mode = "column" if set_parallel_mode else None if self.attention_type == "self": @@ -812,6 +842,14 @@ def forward( interleaved=self.rotary_pos_interleaved, ) + # =========================== + # Apply L2 normalization to query and key tensors + # =========================== + + if self.use_qk_norm: + query_layer = self.qk_norm(query_layer) + key_layer = self.qk_norm(key_layer) + # =========================== # Core attention computation # =========================== diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index 8410b1551e..3902d0a48b 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -121,6 +121,35 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor: return dgelu +@jit_fuser +def l2normalization_fused_(x: torch.Tensor, eps: float) -> torch.Tensor: + """L2 normalization fused - inference version""" + x_squared = x.pow(2) + l2_norm_squared = x_squared.sum(dim=-1, 
keepdim=True) + rsqrt_norm = torch.rsqrt(l2_norm_squared + eps) + return x * rsqrt_norm + + +@jit_fuser +def l2normalization_fwd_fused_(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]: + """L2 normalization fused - training version that returns intermediate values""" + x_squared = x.pow(2) + l2_norm_squared = x_squared.sum(dim=-1, keepdim=True) + rsqrt_norm = torch.rsqrt(l2_norm_squared + eps) + y = x * rsqrt_norm + return y, rsqrt_norm + + +@jit_fuser +def l2normalization_backward_fused_( + grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float +) -> torch.Tensor: + """L2 normalization backward fused""" + x_dy_sum = (x * grad_output).sum(dim=-1, keepdim=True) + x_norm_squared = x.pow(2).sum(dim=-1, keepdim=True) + eps + return rsqrt_norm * (grad_output - x * x_dy_sum / x_norm_squared) + + def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: """Disable native AMP for bias_gelu_fused_""" with gpu_autocast_ctx(enabled=False): @@ -139,6 +168,26 @@ def bgrad_dgelu_fused( return None, dgelu_fused_(grad_output, inp) +def l2normalization_fused(x: torch.Tensor, eps: float) -> torch.Tensor: + """Disable native AMP for l2normalization_fused_ - inference version""" + with gpu_autocast_ctx(enabled=False): + return l2normalization_fused_(x, eps) + + +def l2normalization_fwd_fused(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]: + """Disable native AMP for l2normalization_fwd_fused_ - training version""" + with gpu_autocast_ctx(enabled=False): + return l2normalization_fwd_fused_(x, eps) + + +def l2normalization_backward_fused( + grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float +) -> torch.Tensor: + """Disable native AMP for l2normalization_backward_fused_""" + with gpu_autocast_ctx(enabled=False): + return l2normalization_backward_fused_(grad_output, x, rsqrt_norm, eps) + + def bias_dropout_add( x: torch.Tensor, bias: torch.Tensor, @@ -264,3 +313,45 @@ def 
warmup_jit_bias_gelu_all_dtypes( """Call `warmup_jit_bias_gelu` for all training dtypes""" for dtype in [torch.float32, torch.bfloat16, torch.float16]: warmup_jit_bias_gelu(ffn_hidden_size, dtype, seq_length, micro_batch_size) + + +def warmup_jit_l2normalization( + hidden_size: int, dtype: torch.dtype, seq_length: int, micro_batch_size: int +) -> None: + """Compile L2Normalization JIT function before the main training steps""" + + # Save cuda RNG state to ensure warmup does not affect reproducibility. + rng_state = torch.cuda.get_rng_state() + + inp = torch.rand( + (seq_length * micro_batch_size, hidden_size), + dtype=dtype, + device="cuda", + ) + eps = 1e-6 + # Warmup JIT fusions with the input grad_enable state of both forward + # prop and recomputation + for input_grad in [False, True]: + inp.requires_grad = input_grad + for _ in range(5): + if input_grad: + # Test training version that returns intermediate values + output, rsqrt_norm = l2normalization_fwd_fused_(inp, eps) + # Test backward pass as well + grad_out = torch.rand_like(output) + _ = l2normalization_backward_fused_(grad_out, inp, rsqrt_norm, eps) + else: + # Test inference version + output = l2normalization_fused_(inp, eps) + del inp, output + + torch.cuda.empty_cache() + torch.cuda.set_rng_state(rng_state) + + +def warmup_jit_l2normalization_all_dtypes( + hidden_size: int, seq_length: int, micro_batch_size: int +) -> None: + """Call `warmup_jit_l2normalization` for all training dtypes""" + for dtype in [torch.float32, torch.bfloat16, torch.float16]: + warmup_jit_l2normalization(hidden_size, dtype, seq_length, micro_batch_size) diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index ae635c956a..c69e3df027 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -11,6 +11,7 @@ from .basic_linear import BasicLinear from .bias import Bias from .identity import Identity +from 
.l2normalization import L2Normalization from .layer_norm import LayerNorm from .make_extra_output import MakeExtraOutput from .quantize import Quantize diff --git a/transformer_engine/pytorch/ops/basic/l2normalization.py b/transformer_engine/pytorch/ops/basic/l2normalization.py new file mode 100644 index 0000000000..c7f0b7999a --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/l2normalization.py @@ -0,0 +1,128 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusable operation for L2 Normalization.""" + +from __future__ import annotations +from typing import Optional + +import torch + +from ...tensor import QuantizedTensor +from ...utils import clear_tensor_data +from ..op import BasicOperation, OperationContext +from ...jit import ( + l2normalization_fused, + l2normalization_fwd_fused, + l2normalization_backward_fused, + set_jit_fusion_options, + warmup_jit_l2normalization_all_dtypes, +) + + +class L2Normalization(BasicOperation): + r"""L2 Normalization + + Applies L2 normalization over the last dimension of input tensors. + This is a parameter-free normalization that scales each vector to unit L2 norm. + + .. math:: + y = \frac{x}{\sqrt{\sum_{i} x_i^2 + \varepsilon}} + + This operation is used e.g. for query-key normalization in attention mechanisms. + + Parameters + ---------- + eps : float, default = 1e-6 + A value added to the denominator for numerical stability + seq_length: int, default = None + sequence length of input samples. Needed for JIT Warmup, a technique where jit fused + functions are warmed up before training to ensure same kernels are used for forward + propagation and activation recompute phase. + micro_batch_size: int, default = None + batch size per training step. Needed for JIT Warmup, a technique where jit + fused functions are warmed up before training to ensure same kernels are + used for forward propagation and activation recompute phase. 
+ + """ + + def __init__( + self, + *, + eps: float = 1e-6, + seq_length: Optional[int] = None, + micro_batch_size: Optional[int] = None, + ) -> None: + super().__init__() + self.eps: float = eps + + # JIT warmup for L2Normalization fused operations + if seq_length and micro_batch_size: + if torch.cuda.is_available(): + set_jit_fusion_options() + # For L2Normalization, we don't know the hidden size until forward pass, + # but we can warm up with common sizes. For QK normalization, this will be + # the attention head dimension (hidden_size_per_attention_head), not the full + # model hidden dimension. Common head dimensions are 32, 64, 80, 96, 128, 256. + common_hidden_sizes = [32, 64, 80, 96, 128, 256] + for hidden_size in common_hidden_sizes: + warmup_jit_l2normalization_all_dtypes(hidden_size, seq_length, micro_batch_size) + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + # Use input directly - torch.compile can handle multi-dimensional tensors + x = input_ + + if isinstance(x, QuantizedTensor): + x = x.dequantize() + + # Check if backward pass is needed + requires_grad = ctx.requires_grad + + # Compute L2 normalization using fused implementation + # L2 norm: x / sqrt(sum(x^2) + eps) = x * rsqrt(sum(x^2) + eps) + if requires_grad: + # Training: use version that returns both output and intermediate values + y, rsqrt_norm = l2normalization_fwd_fused(x, self.eps) + else: + # Inference: use lightweight version that only returns output + y = l2normalization_fused(x, self.eps) + rsqrt_norm = None # Not needed for inference + + # Save state for backward pass + if requires_grad: + ctx.save_for_backward(x, rsqrt_norm) + ctx.has_prev_op = prev_op is not None + + return y + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + x, 
rsqrt_norm = ctx.saved_tensors + + dy = grad_output + + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() + + # Compute L2 norm backward pass using fused implementation + dx = l2normalization_backward_fused(dy, x, rsqrt_norm, self.eps) + + # Clear saved tensors if possible + if ctx.has_prev_op: + clear_tensor_data(x) + clear_tensor_data(rsqrt_norm) + + # No parameters, so empty tuple for param grads + return dx, () diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 63874e2f55..3135b01688 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -235,6 +235,14 @@ class TransformerLayer(torch.nn.Module): parameter for query-key-value. This enables optimizations such as QKV fusion without concatentations/splits and also enables the argument `fuse_wgrad_accumulation`. + use_qk_norm: bool, default = 'False' + if set to `True`, L2 normalization is applied to query and key tensors + after RoPE (if applicable) but before attention computation. + This follows the Llama4 approach for QK normalization to improve + training stability and model performance. + qk_norm_eps: float, default = 1e-6 + epsilon value for L2 normalization of query and key tensors. + Only used when `use_qk_norm` is True. 
""" def __init__( @@ -284,6 +292,8 @@ def __init__( device: Union[torch.device, str] = "cuda", attn_input_format: str = "sbhd", name: str = None, + use_qk_norm: bool = False, + qk_norm_eps: float = 1e-6, ) -> None: super().__init__() @@ -373,6 +383,8 @@ def __init__( "ub_overlap_rs": ub_overlap_rs, "ub_overlap_rs_dgrad": ub_overlap_rs_dgrad, "qkv_format": self.attn_input_format, + "seq_length": seq_length, + "micro_batch_size": micro_batch_size, } self.self_attention = MultiheadAttention( @@ -384,6 +396,8 @@ def __init__( return_bias=not self.parallel_attention_mlp, normalization=normalization, device=device, + use_qk_norm=use_qk_norm, + qk_norm_eps=qk_norm_eps, name=name + ".self_attention" if name is not None else None, ) @@ -398,6 +412,8 @@ def __init__( return_bias=True, normalization=normalization, device=device, + use_qk_norm=use_qk_norm, + qk_norm_eps=qk_norm_eps, name=name + ".inter_attention" if name is not None else None, ) From 227961e6099a519e7c005ebae36cc210e1858813 Mon Sep 17 00:00:00 2001 From: Hua Huang Date: Thu, 12 Jun 2025 14:34:50 -0700 Subject: [PATCH 13/39] [JAX] Distinguish the reasons why fp8 / mxfp8 is not supported in unit test (#1873) Distinguish the reasons why fp8 is not supported and mxfp8 is not supported Signed-off-by: Hua Huang --- tests/jax/test_custom_call_compute.py | 38 +++++++++++++-------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 9ff0c11757..b81f3fb9bf 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -57,8 +57,8 @@ FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2] LN_CASES = [(256, 128), (128, 256)] DTYPES = [jnp.bfloat16, jnp.float32] -is_fp8_supported, reason = helper.is_fp8_available() -is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING) +is_fp8_supported, fp8_unsupported_reason = helper.is_fp8_available() +is_mxfp8_supported, 
mxfp8_unsupported_reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING) supported_scaling_modes = [] """ Find supported scaling modes""" @@ -209,7 +209,7 @@ def test_act_grad(self, shape, activation_type): assert_allclose(prim_out, ref_out, dtype=x.dtype) assert_allclose(prim_grad, ref_grad, dtype=x.dtype) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @@ -240,7 +240,7 @@ def test_act_grad_with_tensor_scaling_fp8( assert_allclose(prim_out, ref_out, dtype=output_type) assert_allclose(prim_grad, ref_grad, dtype=output_type) - @pytest.mark.skipif(not is_mxfp8_supported, reason=reason) + @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @@ -270,7 +270,7 @@ def test_act_forward_with_tensor_scaling_fp8( assert_bitwise_scaled_tensors(te_output, jax_output) - @pytest.mark.skipif(not is_mxfp8_supported, reason=reason) + @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)]) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @@ -391,7 +391,7 @@ def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None ) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) # No Norm FWD E5M2 in TE backend 
@pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn]) @pytest_parametrize_wrapper( @@ -506,7 +506,7 @@ def _test_norm_forward( if norm_type == "layernorm": assert_allclose(mu, ref_mu, dtype=inp_dtype) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) # No Norm FWD E5M2 in TE backend @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn]) @pytest_parametrize_wrapper( @@ -542,7 +542,7 @@ def test_norm_forward_with_tensor_scaling_fp8( q_layout=q_layout, ) - @pytest.mark.skipif(not is_mxfp8_supported, reason=reason) + @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) @pytest.mark.parametrize("out_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) def test_norm_forward_with_block_scaling_fp8( self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype @@ -591,7 +591,7 @@ def test_norm_forward_with_block_scaling_fp8( } -@pytest.mark.skipif(not is_fp8_supported, reason=reason) +@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES) @@ -638,7 +638,7 @@ def test_quantize_bitwise( assert_bitwise_scaled_tensors(te_output, jax_output) -@pytest.mark.skipif(not is_fp8_supported, reason=reason) +@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) @pytest_parametrize_wrapper("input_shape", [(8, 16, 32)]) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn]) @@ -692,7 +692,7 @@ def test_grouped_qdq( @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) class TestFusedQuantize: - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.skipif(not is_fp8_supported, 
reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES) @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) @@ -793,7 +793,7 @@ def test_quantize_dact_dbias_no_quantization( q_layout=QuantizeLayout.ROWWISE, ) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) @@ -817,7 +817,7 @@ def test_quantize_dact_dbias_tensor_scaling( q_layout=q_layout, ) - @pytest.mark.skipif(not is_mxfp8_supported, reason=reason) + @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper( "input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)] @@ -886,7 +886,7 @@ def test_gemm_bf16(self, m, n, k, data_layout): assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @@ -928,7 +928,7 @@ def ref_func(x, w, data_layout): assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16) assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, 
jnp.float8_e5m2]) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @@ -992,7 +992,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan class TestFusedDense: - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @@ -1077,7 +1077,7 @@ def ref_func(x, w, gamma, beta): if beta is not None: assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @@ -1284,7 +1284,7 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims) self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("layout", ["NN"]) @@ -1360,7 +1360,7 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape): assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype) assert_allclose(prim_dbias, ref_dbias, dtype=dtype) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize( "fwd_bwd_dtype", [(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)], From 
ecaf3e21503d885f002fa06261f0131b9fc7b523 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 12 Jun 2025 19:01:50 -0400 Subject: [PATCH 14/39] Fixes for JIT-able grouped_gemm (#1872) * fixes for jittable grouped_quantize * fixes for jittable grouped_gemm * fix contracting_dim for wgrad gemm * exclude jitted grouped_gemm from the unit test as it does not work cudaGraph --------- Signed-off-by: Phuong Nguyen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/jax/test_custom_call_compute.py | 29 ++++++- transformer_engine/jax/cpp_extensions/gemm.py | 11 ++- .../jax/cpp_extensions/quantization.py | 12 +-- transformer_engine/jax/csrc/extensions.h | 1 + .../jax/csrc/extensions/gemm.cpp | 15 ++-- .../jax/csrc/extensions/pybind.cpp | 1 + transformer_engine/jax/dense.py | 2 +- transformer_engine/jax/quantize/tensor.py | 76 ++++++++++--------- 8 files changed, 90 insertions(+), 57 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index b81f3fb9bf..f689bce6a5 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -4,7 +4,6 @@ import jax import jax.numpy as jnp -import numpy as np import pytest from jax import jit, value_and_grad from functools import reduce @@ -13,7 +12,6 @@ from utils import ( assert_allclose, - assert_tree_like_allclose, pytest_parametrize_wrapper, ) from transformer_engine.jax.layernorm import layernorm @@ -682,6 +680,10 @@ def test_grouped_qdq( n_groups=n_groups, ) + # grouped_quantize does not work with cudaGraph yet, so the jitting will breaks + # To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to + # disable cudaGraph, then use the following jitted function + scaled_tensor = tex.grouped_quantize( x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer ) @@ -1281,6 +1283,16 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): dtype, 
input_shape, layout ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) + + # grouped_gemm does not work with cudaGraph yet, so the jitting will breaks + # To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to + # disable cudaGraph, then use the following jitted function + + # jitting grouped_gemm + # prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( + # lhs, rhs, group_sizes, contracting_dims, + # ) + prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims) self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) @@ -1312,6 +1324,12 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout out_dtype, input_shape, layout ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) + + # jitting grouped_gemm + # prim_out = jax.jit(tex.grouped_gemm, static_argnames=('contracting_dims',))( + # lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set + # ) + prim_out = tex.grouped_gemm( lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set ) @@ -1346,6 +1364,9 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape): ) value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) + # jitting the grouped_dense + # value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), + # static_argnums=(4,)) value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)) ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func( @@ -1386,6 +1407,10 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): n_groups=group_sizes.size, ) value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) + + # jitting the grouped_dense + # value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), + # static_argnums=(4,)) 
value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)) ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func( diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index cc02ec3404..d3c23015c1 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -9,7 +9,7 @@ import math import jax import jax.numpy as jnp -from transformer_engine_jax import get_device_compute_capability +from transformer_engine_jax import get_device_compute_capability, get_num_compute_streams from .base import BasePrimitive, register_primitive from .quantization import grouped_quantize @@ -30,7 +30,7 @@ __all__ = ["gemm", "grouped_gemm", "is_gemm_with_all_layouts_supported"] -num_cublas_streams = 4 +num_cublas_streams = get_num_compute_streams() def get_cublas_workspace_size_bytes() -> None: @@ -103,10 +103,15 @@ def abstract( """ del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval del K, lhs_is_trans, rhs_is_trans, scaling_mode, has_bias + del lhs_scale_inv_aval, rhs_scale_inv_aval # TODO(Phuong): move some shape checks from Cpp to here workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams - workspace_size += lhs_scale_inv_aval.size + rhs_scale_inv_aval.size + # JAX buffer pointers are 128-aligned + # 255 is added to the workspace size to ensure workspace ptr is 256-aligned + workspace_size += 255 workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) + # TODO(phuong): We should make separate tmp buffers for swizzled scales to avoid unaligned-by-256 workspace ptr issue + out_shape = (M, N) if is_grouped_dense_wgrad: out_shape = (group_sizes_aval.size, M, N) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 07d8f81df0..11b3cdc2a3 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ 
b/transformer_engine/jax/cpp_extensions/quantization.py @@ -839,9 +839,7 @@ def outer_abstract(*args, **kwargs): scale_inv, colwise_scale_inv, updated_amax, - _dbias, - _wkspace, - ) = DBiasQuantizePrimitive.abstract(*args, **kwargs) + ) = GroupedQuantizePrimitive.abstract(*args, **kwargs) return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax @staticmethod @@ -975,7 +973,9 @@ def grouped_quantize( if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim)) - segment_ids = jnp.repeat(jnp.arange(n_groups), group_sizes) + segment_ids = jnp.repeat( + jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis] + ) grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups) for i in range(n_groups): tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype) @@ -1048,7 +1048,9 @@ def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray: assert grad.ndim == 2, "Input grad must be a 2D tensor." assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor." 
- segment_ids = jnp.repeat(jnp.arange(group_sizes.shape[0]), group_sizes) + segment_ids = jnp.repeat( + jnp.arange(group_sizes.size), group_sizes, total_repeat_length=grad.shape[0] + ) grad_fp32 = grad.astype(jnp.float32) dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0]) dbias = dbias_fp32.astype(grad.dtype) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 3b9bd1944a..aa257abe95 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -30,6 +30,7 @@ #include "extensions/misc.h" #include "extensions/utils.h" #include "transformer_engine/activation.h" +#include "transformer_engine/multi_stream.h" // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index d9d519fa00..d57d4682ca 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -10,7 +10,6 @@ #include "../extensions.h" #include "common/util/cuda_runtime.h" #include "common/util/system.h" -#include "transformer_engine/multi_stream.h" #include "xla/ffi/api/c_api.h" #define MXFP8_BLOCK_SIZE 32 @@ -58,14 +57,12 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // Outputs auto out_ptr = reinterpret_cast(output->untyped_data()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); - auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); - auto workspace_total_size = product(workspace->dimensions()); - - auto lhs_sinv_size = product(lhs_sinv.dimensions()); - auto rhs_sinv_size = product(rhs_sinv.dimensions()); - auto workspace_size = (workspace_total_size - lhs_sinv_size - rhs_sinv_size) / num_streams; - auto swizzled_lhs_sinv_ptr = 
workspace_ptr + workspace_size * num_streams; - auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size; + // Here we clear the lower 8 bits of the buffer address to ensure the buffer is 256-aligned + auto workspace_ptr = + reinterpret_cast((reinterpret_cast(workspace->untyped_data()) + 255) & + ~static_cast(255)); + auto workspace_total_size = product(workspace->dimensions()) - 255; + auto workspace_size = workspace_total_size / num_streams; size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 785bf1a198..2d7801cc20 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -69,6 +69,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_cuda_version", &GetCudaRuntimeVersion); m.def("get_cudnn_version", &GetCudnnRuntimeVersion); m.def("get_device_compute_capability", &GetDeviceComputeCapability); + m.def("get_num_compute_streams", &nvte_get_num_compute_streams); m.def("get_cublasLt_version", &cublasLtGetVersion); m.def("get_dact_dbias_quantize_workspace_sizes", &GetDActDBiasQuantizeWorkspaceSizes); m.def("get_dbias_quantize_workspace_sizes", &GetDBiasQuantizeWorkspaceSizes); diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index bba101c722..8834f4f73c 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -389,7 +389,7 @@ def _grouped_dense_bwd_rule( # after the extra transpose for FP8 in grouped_gemm # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? 
g_contracting_dim = (0,) - x_contracting_dim = (1,) + x_contracting_dim = (0,) wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) wgrad_x_T = ctx_x wgrad_grad = casted_grad.get_colwise_tensor() diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 0a7829fc4d..02b1a1a99e 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -131,22 +131,16 @@ def __post_init__(self): Ensures the scale_inv shape matches the expected shape based on the scaling mode and quantization direction. Pads the scale_inv if necessary. """ - flatten_axis = ( - len(self.data.shape) + self.flatten_axis if self.flatten_axis < 0 else self.flatten_axis - ) + assert self.flatten_axis > 0 assert ( - 0 < flatten_axis < len(self.data.shape) - ), f"flatten_axis {flatten_axis} is out of bounds for shape {self.data.shape}" - - if self.data_layout == "T": - flatten_axis = self.data.ndim - flatten_axis - self.flatten_axis = flatten_axis + 0 < self.flatten_axis < len(self.data.shape) + ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}" expected_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=True, flatten_axis=flatten_axis + self.data.shape, self.is_colwise, is_padded=True, flatten_axis=self.flatten_axis ) expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=False, flatten_axis=flatten_axis + self.data.shape, self.is_colwise, is_padded=False, flatten_axis=self.flatten_axis ) if self.scale_inv.shape != expected_scale_shape: assert self.scale_inv.shape == expected_unpadded_scale_shape, ( @@ -291,6 +285,7 @@ def __init__( original_shape, group_axis=0, ): + self.flatten_axis = flatten_axis self.group_sizes = group_sizes self.original_shape = original_shape self.group_axis = group_axis @@ -301,44 +296,25 @@ def __init__( def __post_init__(self): assert 
self.scale_inv.ndim == 1, "Only support flattened scale_inv" assert self.data.ndim == 1, "Only support flattened data" + assert self.group_axis >= 0 + assert self.flatten_axis > 0 data_ndim = len(self.original_shape) - flatten_axis = data_ndim + self.flatten_axis if self.flatten_axis < 0 else self.flatten_axis assert ( - 0 < flatten_axis < data_ndim - ), f"flatten_axis {flatten_axis} is out of bounds for data.ndim = {data_ndim}" + 0 < self.flatten_axis < data_ndim + ), f"flatten_axis {self.flatten_axis} is out of bounds for data.ndim = {data_ndim}" - group_axis = ( - len(self.original_shape) + self.group_axis if self.group_axis < 0 else self.group_axis - ) assert ( - 0 <= group_axis < data_ndim - ), f"group_axis {group_axis} is out of bounds for shape {self.original_shape}" - - if self.data_layout == "T": - if self.original_shape[0] == self.group_sizes.size: - self.original_shape = ( - self.original_shape[0], - *self.original_shape[flatten_axis:], - *self.original_shape[1:flatten_axis], - ) - flatten_axis = len(self.original_shape) - flatten_axis + 1 - else: - self.original_shape = ( - *self.original_shape[flatten_axis:], - *self.original_shape[:flatten_axis], - ) - self.group_axis = flatten_axis - flatten_axis = len(self.original_shape) - flatten_axis + 0 <= self.group_axis < data_ndim + ), f"group_axis {self.group_axis} is out of bounds for shape {self.original_shape}" - self.flatten_axis = flatten_axis expected_scale_shape = self.scaling_mode.get_grouped_scale_shape( self.original_shape, self.group_sizes.size, self.group_axis, self.is_colwise, is_padded=True, - flatten_axis=flatten_axis, + flatten_axis=self.flatten_axis, ) assert self.scale_inv.shape == expected_scale_shape, ( @@ -479,10 +455,31 @@ def create_1x( A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided """ dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) + if group_sizes is not None: + flatten_axis = len(original_shape) + flatten_axis if 
flatten_axis < 0 else flatten_axis assert ( original_shape is not None ), "original_shape is not given for GroupedScaledTensor1x" + + # Handling attrs of transposed tensors + group_axis = len(original_shape) + group_axis if group_axis < 0 else group_axis + if data_layout == "T": + if original_shape[0] == group_sizes.size: + original_shape = ( + original_shape[0], + *original_shape[flatten_axis:], + *original_shape[1:flatten_axis], + ) + flatten_axis = len(original_shape) - flatten_axis + 1 + else: + original_shape = ( + *original_shape[flatten_axis:], + *original_shape[:flatten_axis], + ) + group_axis = flatten_axis + flatten_axis = len(original_shape) - flatten_axis + return GroupedScaledTensor1x( data=data, scale_inv=scale_inv, @@ -497,6 +494,11 @@ def create_1x( group_axis=group_axis, ) + # Handling attrs of transposed tensors + flatten_axis = data.ndim + flatten_axis if flatten_axis < 0 else flatten_axis + if data_layout == "T": + flatten_axis = data.ndim - flatten_axis + return ScaledTensor1x( data, scale_inv, From d90ced7c0348174b066832ddc0f00c8db8405d09 Mon Sep 17 00:00:00 2001 From: Daniel Stokes <40156487+djns99@users.noreply.github.com> Date: Fri, 13 Jun 2025 12:49:24 +1200 Subject: [PATCH 15/39] Add support for overlapping wgrad NCCL AG with dgrad GEMM (#1849) * Add support for overlapping wgrad NCCL AG with dgrad GEMM Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> * Remove unused wait on memcpy API from UB Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> * Add better commenting to MXFP8 overlap Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --------- Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> Co-authored-by: dastokes --- setup.py | 5 +++ .../distributed/run_gemm_with_overlap.py | 4 +- .../distributed/run_layer_with_overlap.py | 1 + transformer_engine/pytorch/csrc/extensions.h | 4 ++ .../csrc/extensions/comm_gemm_overlap.cpp | 8 ++++ .../pytorch/csrc/extensions/pybind.cpp | 6 ++- 
.../pytorch/module/layernorm_linear.py | 41 +++++++++++-------- .../pytorch/module/layernorm_mlp.py | 40 ++++++++++-------- transformer_engine/pytorch/module/linear.py | 23 +++++++---- .../ops/fused/userbuffers_backward_linear.py | 28 +++++++++---- 10 files changed, 109 insertions(+), 51 deletions(-) diff --git a/setup.py b/setup.py index 8cdedde844..0b1b523277 100644 --- a/setup.py +++ b/setup.py @@ -66,6 +66,11 @@ def setup_common_extension() -> CMakeExtension: if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))): cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON") + # Add custom CMake arguments from environment variable + nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS") + if nvte_cmake_extra_args: + cmake_flags.extend(nvte_cmake_extra_args.split()) + # Project directory root root_path = Path(__file__).resolve().parent diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index ce936f9644..6d9e2f1526 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -273,7 +273,9 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") assert dist.is_nccl_available() dist.init_process_group(**dist_init_kwargs) - tp_group = dist.new_group(backend="nccl") + tp_group = dist.new_group( + backend="nccl", pg_options=dist.ProcessGroupNCCL.Options(is_high_priority_stream=True) + ) tp_rank = dist.get_rank(tp_group) tp_size = dist.get_world_size(tp_group) dist_print( diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 48ace31c33..8638c1bcea 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -323,6 +323,7 @@ def _train(opts): new_group_kwargs = { "backend": "nccl", "ranks": 
tp_rank_list, + "pg_options": dist.ProcessGroupNCCL.Options(is_high_priority_stream=True), } else: opts.tp = WORLD_SIZE diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 361c24b22c..d4218a08be 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -430,6 +430,8 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve at::Tensor get_buffer(bool local_chunk = false, std::optional> shape = std::nullopt); + at::Stream get_communication_stream(); + }; // CommOverlap class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase { @@ -449,6 +451,8 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm at::Tensor get_buffer(bool local_chunk = false, std::optional> shape = std::nullopt); + at::Stream get_communication_stream(); + }; // CommOverlapP2P #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 61eab1c654..0e7bca25b1 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -216,6 +216,10 @@ at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional, transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>( @@ -402,5 +403,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"), py::arg("local_chunk") = false) .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false, - py::arg("shape") = std::nullopt); + py::arg("shape") = std::nullopt) + .def("get_communication_stream", &CommOverlapP2P::get_communication_stream); } diff --git a/transformer_engine/pytorch/module/layernorm_linear.py 
b/transformer_engine/pytorch/module/layernorm_linear.py index 3f3c70e027..dba98b0150 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -743,6 +743,31 @@ def backward( wgrad = None if ctx.requires_wgrad: + # Prepare grad output tensor + # Note: Synchronize tensor-parallel communication and + # make sure required data is available + if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): + # UB does not support overlapping grad output + # all-gather with wgrad GEMM. Also, we can't + # convert row-scaled MXFP8 to column-scaled, so we + # can't reuse the grad output that was gathered + # for the dgrad GEMM. We work around by explicitly + # overlapping the NCCL operation with the dgrad GEMM. + ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) + + # Get the communication stream from the dgrad GEMM and set it as the current torch stream + dgrad_comm_stream = ub_obj_dgrad.get_communication_stream() + with torch.cuda.stream(dgrad_comm_stream): + # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather + # This ensures that we don't start until all communication for the dgrad GEMM is complete + grad_output, mxfp8_grad_output_work = gather_along_first_dim( + grad_outputs[0], + ctx.tp_group, + async_op=True, + quantizer=ctx.grad_output_quantizer, + ) + # Synchronize with the main stream + mxfp8_grad_output_work.wait() # Prepare input tensor # Note: Synchronize tensor-parallel communication and @@ -757,22 +782,6 @@ def backward( ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) - # Prepare grad output tensor - # Note: Synchronize tensor-parallel communication and - # make sure required data is available - if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): - # UB does not support overlapping grad output - # all-gather with wgrad GEMM. 
Also, we can't - # convert row-scaled MXFP8 to column-scaled, so we - # can't reuse the grad output that was gathered - # for the dgrad GEMM. We work around with blocking - # all-gather for column-scaled MXFP8 data. - ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - grad_output, _ = gather_along_first_dim( - grad_outputs[0], - ctx.tp_group, - quantizer=ctx.grad_output_quantizer, - ) if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorBase): grad_output.update_usage(columnwise_usage=True) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index d4298b91f8..f1f0475907 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -832,6 +832,30 @@ def backward( fc2_wgrad = None if ctx.fc2_weight_requires_grad: + # Prepare grad output tensor + # Note: Synchronize tensor-parallel communication and + # make sure required data is available + if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer): + # UB does not support overlapping grad output + # all-gather with wgrad GEMM. Also, we can't + # convert row-scaled MXFP8 to column-scaled, so we + # can't reuse the grad output that was gathered + # for the dgrad GEMM. We work around by explicitly + # overlapping the NCCL operation with the dgrad GEMM. 
+ ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) + # Get the communication stream from the dgrad GEMM and set it as the current torch stream + dgrad_comm_stream = ub_obj_fc2_dgrad.get_communication_stream() + with torch.cuda.stream(dgrad_comm_stream): + # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather + # This ensures that we don't start until all communication for the dgrad GEMM is complete + grad_output, mxfp8_fc2_grad_output_work = gather_along_first_dim( + grad_outputs[0], + ctx.tp_group, + async_op=True, + quantizer=ctx.fc2_grad_output_quantizer, + ) + # Synchronize with the main stream + mxfp8_fc2_grad_output_work.wait() # Prepare input tensor # Note: Synchronize tensor-parallel communication and @@ -843,22 +867,6 @@ def backward( ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) - # Prepare grad output tensor - # Note: Synchronize tensor-parallel communication and - # make sure required data is available - if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer): - # UB does not support overlapping grad output - # all-gather with wgrad GEMM. Also, we can't - # convert row-scaled MXFP8 to column-scaled, so we - # can't reuse the grad output that was gathered - # for the dgrad GEMM. We work around with blocking - # all-gather for column-scaled MXFP8 data. 
- ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - grad_output, _ = gather_along_first_dim( - grad_outputs[0], - ctx.tp_group, - quantizer=ctx.fc2_grad_output_quantizer, - ) if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorBase): grad_output.update_usage(columnwise_usage=True) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index dd1a214979..8a7c0ce2d1 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -689,14 +689,23 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we # can't reuse the grad output that was gathered - # for the dgrad GEMM. We work around with blocking - # all-gather for column-scaled MXFP8 data. + # for the dgrad GEMM. We work around by explicitly + # overlapping the NCCL operation with the dgrad GEMM. 
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - grad_output, _ = gather_along_first_dim( - grad_output_arg, - ctx.tp_group, - quantizer=ctx.grad_output_quantizer, - ) + # Get the communication stream from the dgrad GEMM and set it as the current torch stream + dgrad_comm_stream = ub_obj_dgrad.get_communication_stream() + with torch.cuda.stream(dgrad_comm_stream): + # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather + # This ensures that we don't start until all communication for the dgrad GEMM is complete + grad_output, grad_output_work = gather_along_first_dim( + grad_output_arg, + ctx.tp_group, + async_op=True, + quantizer=ctx.grad_output_quantizer, + ) + # Synchronize with the main stream + grad_output_work.wait() + if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorBase): grad_output.update_usage(columnwise_usage=True) diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 3b04e991d9..67779d709d 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -407,16 +407,26 @@ def _functional_backward( # Initialize grad output if tensor_parallel_mode == "row" and isinstance(grad_output_quantizer, MXFP8Quantizer): # UB does not support overlapping grad output - # all-gather with wgrad GEMM. Also, MXFP8 does not - # allow reusing the grad output that was gathered for - # the dgrad GEMM. We work around with blocking - # all-gather for column-scaled MXFP8 data. + # all-gather with wgrad GEMM. Also, we can't + # convert row-scaled MXFP8 to column-scaled, so we + # can't reuse the grad output that was gathered + # for the dgrad GEMM. We work around by explicitly + # overlapping the NCCL operation with the dgrad GEMM. 
grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - dy, _ = gather_along_first_dim( - grad_output, - tensor_parallel_group, - quantizer=grad_output_quantizer, - ) + # Get the communication stream from the dgrad GEMM and set it as the current torch stream + dgrad_comm_stream = ub_comm_dgrad.get_communication_stream() + with torch.cuda.stream(dgrad_comm_stream): + # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather + # This ensures that we don't start until all communication for the dgrad GEMM is complete + dy, dy_work = gather_along_first_dim( + dy_local, + tensor_parallel_group, + async_op=True, + quantizer=grad_output_quantizer, + ) + # Synchronize with the main stream + dy_work.wait() + if tensor_parallel_mode == "column": dy = dy_local if dy is None: From 8d4bdbc2cd85063e9bcf750bdb8eba776956e2d3 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Thu, 12 Jun 2025 17:49:39 -0700 Subject: [PATCH 16/39] Optimize `/ops/fuser.py` by moving computation from `forward` to `__init__` (#1870) * Flatten basic op params during fuser init Signed-off-by: Jan Bielak (cherry picked from commit 949abe97070721b1da5117903067608250f5fb61) * Add caching for is_non_tn_fp8_gemm_supported Signed-off-by: Jan Bielak (cherry picked from commit fd830ae24ffbd2d0727010b1a8a119ca72f61ce5) * Pass fuser to _OperationFuserAutogradFunction.forward and moving computation to __init__ Signed-off-by: Jan Bielak (cherry picked from commit fd808991993958b670726896254b82fcb967fa07) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Pass basic_op_kwargs and is_grad_enabled as parameters rather than in fuser Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon 
<4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/ops/fuser.py | 81 ++++++++++--------------- transformer_engine/pytorch/utils.py | 1 + 2 files changed, 32 insertions(+), 50 deletions(-) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 8ff0242229..cf61d15eff 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -61,13 +61,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): def forward( func_ctx: Optional[torch.autograd.function.FunctionCtx], input_: torch.Tensor, - forward_ops: list[tuple[FusibleOperation, list[int]]], - backward_ops: list[tuple[FusibleOperation, list[int]]], - basic_ops: list[BasicOperation], + fuser: OperationFuser, basic_op_kwargs: list[dict[str, Any]], is_grad_enabled: bool, - num_params: int, - num_extra_inputs: int, *params_and_extra_inputs: torch.nn.Parameter, ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Forward pass @@ -78,20 +74,12 @@ def forward( Context for PyTorch autograd function input_: torch.Tensor Input to first operation in pipeline - forward_ops: list of tuple - Forward pass operations and the indices of the - corresponding basic operations. The order should match - basic_ops. - backward_ops: list of tuple - Backward pass operations and the indices of the - corresponding basic operations. The order should be the - reverse of basic_ops. - basic_ops: list of BasicOperation - Basic operations + fuser: OperationFuser + Container for the pipeline of operations to run basic_op_kwargs: list of dict Keyword arguments to BasicOperation - num_params: int - Number of parameter tensors to include in autograd graph. + is_grad_enabled: bool + Should context be saved for backward *params_and_extra_inputs: torch.Tensor Other tensor inputs to include in autograd graph. Consists of parameter tensors, followed by extra operation inputs. 
@@ -106,26 +94,20 @@ def forward( """ # Operation autograd contexts - basic_op_ctxs = [OperationContext() for _ in range(len(basic_ops))] + basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)] # Unflatten list of parameters and extra tensor inputs - if len(params_and_extra_inputs) != num_params + num_extra_inputs: - raise ValueError( - f"Expected {num_params + num_extra_inputs} extra tensor arguments " - f"({num_params} parameters, {num_extra_inputs} extra inputs), " - f"but got {len(params_and_extra_inputs)}" - ) - _, extra_inputs = _split_tuple(params_and_extra_inputs, num_params) + extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :] basic_op_extra_inputs = [] - for op in basic_ops: + for op in fuser._basic_ops: xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs) basic_op_extra_inputs.append(xs) # Apply forward ops x = input_ requires_grad = is_grad_enabled and x.requires_grad - extra_outputs = [None for _ in range(len(basic_ops))] - for op, basic_op_idxs in forward_ops: + extra_outputs = [None] * fuser._num_basic_ops + for op, basic_op_idxs in fuser._forward_ops: # Check if backward op is required if is_grad_enabled: @@ -143,9 +125,10 @@ def forward( # Forward op extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs] - prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs] + prev_ops = [fuser._basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs] next_ops = [ - basic_ops[idx + 1] if (idx < len(basic_ops) - 1) else None for idx in basic_op_idxs + fuser._basic_ops[idx + 1] if (idx < fuser._num_basic_ops - 1) else None + for idx in basic_op_idxs ] x, fused_op_extra_outputs = op.fuser_forward( [basic_op_ctxs[idx] for idx in basic_op_idxs], @@ -165,7 +148,7 @@ def forward( extra_outputs_flat = [] for idx, ys in enumerate(extra_outputs): ys = list(ys) - num_extra_outputs = basic_ops[idx].num_extra_outputs + num_extra_outputs = fuser._basic_ops[idx].num_extra_outputs if 
len(ys) != num_extra_outputs: raise RuntimeError( f"Expected op {idx} to generate " @@ -189,11 +172,11 @@ def forward( func_ctx.save_for_backward(*to_save) # Other context - func_ctx.backward_ops = backward_ops - func_ctx.basic_ops = basic_ops + func_ctx.backward_ops = fuser._backward_ops + func_ctx.basic_ops = fuser._basic_ops func_ctx.basic_op_ctxs = basic_op_ctxs - func_ctx.basic_op_num_params = [sum(1 for _ in op.parameters()) for op in basic_ops] - func_ctx.num_extra_inputs = num_extra_inputs + func_ctx.basic_op_num_params = fuser._num_list_basic_op_params + func_ctx.num_extra_inputs = fuser._num_extra_inputs func_ctx.num_extra_outputs = len(extra_outputs_flat) func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() @@ -293,13 +276,9 @@ def backward( return ( dx, # input_ - None, # forward_ops - None, # backward_ops - None, # basic_ops + None, # fuser None, # basic_op_kwargs None, # is_grad_enabled - None, # num_params - None, # num_extra_inputs *grad_params_flat, *grad_extra_inputs_flat, ) @@ -346,6 +325,10 @@ def __init__( if fuse_ops: self.fuse_ops() + # Flatten list of parameters + self._basic_op_params = [param for op in self._basic_ops for param in op.parameters()] + self._num_list_basic_op_params = [sum(1 for _ in op.parameters()) for op in self._basic_ops] + @classmethod def _fuse_forward_ops( cls, @@ -378,6 +361,11 @@ def __call__( *extra_inputs: torch.Tensor, basic_op_kwargs: Optional[list[dict[str, Any]]] = None, ) -> torch.Tensor | tuple[torch.Tensor, ...]: + # Verify extra input count + if len(extra_inputs) != self._num_extra_inputs: + raise ValueError( + f"Expected {self._num_extra_inputs} extra inputs but got {len(extra_inputs)}" + ) # Initialization before forward pass for op in self._basic_ops: @@ -385,10 +373,7 @@ def __call__( # Canonicalize op kwargs if basic_op_kwargs is None: - basic_op_kwargs = [{} for _ in range(len(self._basic_ops))] - - # Flatten list of parameters - params = [param for op in self._basic_ops for param 
in op.parameters()] + basic_op_kwargs = [{}] * self._num_basic_ops # Fuser forward pass is_grad_enabled = torch.is_grad_enabled() @@ -400,14 +385,10 @@ def __call__( args = [None] args += ( input, - self._forward_ops, - self._backward_ops, - self._basic_ops, + self, basic_op_kwargs, is_grad_enabled, - len(params), - self._num_extra_inputs, - *params, + *self._basic_op_params, *extra_inputs, ) return forward_func(*args) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index e66477476f..4adc8cc467 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -448,6 +448,7 @@ def is_bf16_compatible() -> None: return torch.cuda.get_device_capability()[0] >= 8 +@functools.lru_cache(maxsize=None) def is_non_tn_fp8_gemm_supported() -> bool: """Checks whether the device supports non-TN layouts for FP8 GEMMs. From 655512c1230c2173cba5072f33254b5e8a03b1c4 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Thu, 12 Jun 2025 19:04:29 -0700 Subject: [PATCH 17/39] [PyTorch] Inference mode disables initializing quantized weights with column-wise usage (#1847) * Do not initialize quantized weights with column-wise usage in inference mode Signed-off-by: Tim Moon * Fix bug in test Signed-off-by: Tim Moon * Use no-grad mode instead of inference mode in tests Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- tests/pytorch/test_sanity.py | 82 ++++++++++++++++++- transformer_engine/pytorch/module/base.py | 32 +++++--- .../pytorch/module/layernorm_linear.py | 2 +- .../pytorch/module/layernorm_mlp.py | 8 +- .../pytorch/tensor/mxfp8_tensor.py | 1 + 5 files changed, 108 insertions(+), 17 deletions(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 
2ca133e77b..a7ff2b2a91 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -10,6 +10,7 @@ import pytest import os +import transformer_engine.pytorch from transformer_engine.pytorch.fp8 import ( fp8_autocast, FP8GlobalStateManager, @@ -38,9 +39,11 @@ from transformer_engine.pytorch.module.base import get_workspace from transformer_engine.pytorch.tensor import QuantizedTensor from transformer_engine.pytorch.tensor.float8_tensor import ( - Float8Quantizer, Float8CurrentScalingQuantizer, + Float8Quantizer, + Float8Tensor, ) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor.utils import replace_raw_data from transformer_engine.pytorch.distributed import checkpoint from test_numerics import reset_rng_states, dtype_tols @@ -1338,3 +1341,80 @@ def backward(ctx, grad_output): # Assert that gradients are the same torch.testing.assert_close(grad_checkpoint, grad_standard) + + +@pytest.mark.parametrize( + "module_name", + ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"), +) +@pytest.mark.parametrize( + "quantization", + (None, "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"), +) +def test_inference_mode( + module_name: str, + quantization: Optional[str], +) -> None: + """Test heuristics for initializing quantized weights""" + + # Tensor dimensions + sequence_length = 32 + hidden_size = 32 + + # Skip invalid configurations + if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available: + pytest.skip(reason_for_no_fp8) + if quantization == "mxfp8" and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + + # Construct quantization recipe + with_quantization = quantization not in (None, "None") + quantization_recipe = None + if quantization == "fp8_delayed_scaling": + quantization_recipe = recipe.DelayedScaling() + elif quantization == "fp8_current_scaling": + quantization_recipe = recipe.Float8CurrentScaling() + elif 
quantization == "mxfp8": + quantization_recipe = recipe.MXFP8BlockScaling() + + # Construct module + module = None + with torch.no_grad(): + with fp8_model_init(enabled=with_quantization, recipe=quantization_recipe): + if module_name == "Linear": + module = Linear(hidden_size, hidden_size) + elif module_name == "LayerNormLinear": + module = LayerNormLinear(hidden_size, hidden_size) + elif module_name == "LayerNormMLP": + module = LayerNormMLP(hidden_size, hidden_size) + elif module_name == "GroupedLinear": + module = GroupedLinear(1, hidden_size, hidden_size) + elif module_name == "ops.Linear": + module = transformer_engine.pytorch.ops.Linear(hidden_size, hidden_size) + + def check_weights(): + """Helper function to check that weight parameters have expected data""" + for param in module.parameters(): + if isinstance(param, Float8Tensor): + assert param._data is not None, "Missing FP8 data" + assert ( + param._transpose is None and param._transpose_invalid + ), "FP8 transpose is not expected for inference" + if isinstance(param, MXFP8Tensor): + assert param._rowwise_data is not None, "Missing row-wise MXFP8 data" + assert ( + param._columnwise_data is None + ), "Column-wise MXFP8 data is not expected for inference" + + # Check that modules have expected weights after initialization + check_weights() + + # Check that modules have expected weights after forward pass + with torch.inference_mode(): + x = torch.zeros(sequence_length, hidden_size, device="cuda") + kwargs = {} + if module_name == "GroupedLinear": + kwargs["m_splits"] = [sequence_length] + with fp8_autocast(enabled=with_quantization, fp8_recipe=quantization_recipe): + y = module(x, **kwargs) + check_weights() diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index acbd871c7e..3d06a47313 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1183,18 +1183,23 @@ def reset_parameters(self, defer_init: 
Optional[bool] = False) -> None: with get_rng_state_tracker().fork(): init_fn(param) - # If primary weights are in fp8, wrap the parameter as FP8Tensor + # Wrap parameters in QuantizedTensor if needed fp8_meta_index = self.param_init_meta[name].fp8_meta_index high_precision_init_val = None if self.primary_weights_in_fp8 and fp8_meta_index is not None: + + # Keep high-precision values on CPU if needed if self.preserve_high_precision_init_val: high_precision_init_val = param.detach().cpu() + # Configure quantizer quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] - assert ( - quantizer is not None - ) # to use primary fp8 weight one needs to use FP8 autocast with specific recipe. + if quantizer is None: + raise RuntimeError("Weight quantizer has not been initialized") + quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) quantizer.internal = False + + # Quantize parameter param = quantizer(param) # Redo parameter wrap in case we broke it above @@ -1202,6 +1207,8 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: # re-applying the nn.Parameter() wrap is a no-op when the input is already # a parameter so we always re-apply it just for extra safety. param = torch.nn.Parameter(param) + + # Keep high-precision values on CPU if needed if high_precision_init_val is not None: # - Master weights are initialized from model weights, if we use fp8 primary @@ -1245,7 +1252,7 @@ def get_weight_workspace( fsdp_group: Optional[dist_group_type] = None, workspace_dtype: Optional[torch.dtype] = None, ) -> QuantizedTensor: - """Get FP8 workspace buffer and maybe update its values + """Get workspace buffer for weights and maybe update its values The workspace buffer may be cached for future function calls. @@ -1271,13 +1278,16 @@ def get_weight_workspace( for debug quantization, this is dtype of the tensor. 
""" - # FP8 primary weights + # Handle case where weights are already quantized + # Note: Make sure weights have required usages, but do not + # destroy unnecessary usages since they may be used later. if isinstance(tensor, QuantizedTensor): - if update_workspace and quantizer is not None: - tensor.update_usage( - rowwise_usage=quantizer.rowwise_usage, - columnwise_usage=quantizer.columnwise_usage, - ) + update_rowwise_usage = True if quantizer.rowwise_usage else None + update_columnwise_usage = True if quantizer.columnwise_usage else None + tensor.update_usage( + rowwise_usage=update_rowwise_usage, + columnwise_usage=update_columnwise_usage, + ) return tensor # Try getting workspace from cache diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index dba98b0150..b99952ad2a 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -271,7 +271,7 @@ def forward( # Configure quantizer if weight_quantizer is not None: - weight_quantizer.set_usage(rowwise=True, columnwise=True) + weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) # Get quantized weight update_workspace = is_first_microbatch is None or is_first_microbatch diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index f1f0475907..375db477b0 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -325,8 +325,8 @@ def forward( # which handles weight caching etc. 
# FP8 cast to workspace buffer update_workspace = is_first_microbatch is None or is_first_microbatch - fc1_weight_quantizer.set_usage(rowwise=True, columnwise=True) - fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True) + fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) fc1_weight_final = module.get_weight_workspace( tensor=fc1_weight, quantizer=fc1_weight_quantizer, @@ -1762,9 +1762,9 @@ def forward( fc2_bias = self.fc2_bias if self.use_bias else None if not self.fp8: if isinstance(fc1_weight, Float8Tensor): - fc1_weight = fc1_weight.from_float8() + fc1_weight = fc1_weight.dequantize() if isinstance(fc2_weight, Float8Tensor): - fc2_weight = fc2_weight.from_float8() + fc2_weight = fc2_weight.dequantize() # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index c930cdbff5..e20927c24a 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -384,6 +384,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Quantize to FP8 assert self._quantizer is not None, "Can't quantize without a quantizer" + self._quantizer.internal = False self.data = self._quantizer.quantize(tensor) if self.requires_grad != tensor.requires_grad: self.requires_grad_(requires_grad=tensor.requires_grad) From e963e4a95bd9057eb190b4d06f73fa76753ecec0 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Thu, 12 Jun 2025 19:05:14 -0700 Subject: [PATCH 18/39] [PyTorch] Add support for FP8 current scaling in operation-based API (#1858) * Add FP8 current scaling to te.Sequential tests Signed-off-by: Tim Moon * Helper function for test/ref tensors does not produce quantized tensor 
by default Signed-off-by: Tim Moon * Add FP8 current scaling to distributed te.Sequential tests Signed-off-by: Tim Moon * Add FP8 current scaling to Userbuffers te.Sequential tests Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Debug MXFP8 tests Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/distributed/test_fusible_ops.py | 185 +++++---- .../test_fusible_ops_with_userbuffers.py | 74 ++-- tests/pytorch/test_fusible_ops.py | 350 +++++++++--------- tests/pytorch/utils.py | 22 ++ transformer_engine/pytorch/distributed.py | 2 +- .../pytorch/ops/basic/basic_linear.py | 34 +- .../ops/fused/userbuffers_forward_linear.py | 12 +- 7 files changed, 368 insertions(+), 311 deletions(-) diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index 472d20c508..6f025817df 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -22,19 +22,28 @@ import transformer_engine.pytorch as te from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.tensor import QuantizedTensor -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops._common import is_float8_tensor from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex +# Import utility functions +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import dtype_tols, make_recipe + # Check what 
quantization schemes are supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() quantization_list: list[Optional[str]] = [None] if fp8_available: - quantization_list.append("fp8") + quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) if mxfp8_available: quantization_list.append("mxfp8") @@ -63,11 +72,12 @@ def reset_rng(seed: int = 1234) -> None: @torch.no_grad() def make_reference_and_test_tensors( shape: int | Iterable[int], + quantization: Optional[str] = None, ref_dtype: torch.dtype = torch.float64, ref_device: torch.device = "cpu", test_dtype: torch.dtype = torch.float32, test_device: torch.device = "cuda", - test_is_fp8: bool = False, + test_is_quantized: bool = False, requires_grad: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Construct tensors with the same values @@ -76,78 +86,55 @@ def make_reference_and_test_tensors( operations in high precision. The test tensor is intended for use in Transformer Engine operations. + If a quantization scheme is provided, the tensor values are + quantized so that they are representable. 
+ """ + + # Random reference tensor ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) + + # Construct test tensor from reference tensor test = ref.to(device=test_device, dtype=test_dtype) - if test_is_fp8: + if quantization is None: + if test_is_quantized: + raise ValueError("Quantization scheme not provided") + if test.data_ptr() == ref.data_ptr(): + test = test.clone() + elif quantization in ("fp8", "fp8_delayed_scaling"): quantizer = Float8Quantizer( - scale=torch.ones(1, dtype=torch.float32, device=test_device), + scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), amax=torch.zeros(1, dtype=torch.float32, device=test_device), fp8_dtype=tex.DType.kFloat8E4M3, ) test = quantizer(test) - elif test.data_ptr() == ref.data_ptr(): - test = test.clone() + elif quantization == "fp8_current_scaling": + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device=test_device, + ) + test = quantizer(test) + elif quantization == "mxfp8": + test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + else: + raise ValueError(f"Unsupported quantization scheme ({quantization})") + if isinstance(test, QuantizedTensor) and not test_is_quantized: + test = test.dequantize() + + # Make sure reference and test tensors match each other ref.copy_(test) + ref.requires_grad_(requires_grad) test.requires_grad_(requires_grad) return ref, test -def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: - """Estimated numerical error for a datatype - - Based on tolerances for torch.testing.assert_close. 
- - """ - - # Transformer Engine dtypes - if isinstance(dtype, tex.DType): - if dtype == tex.DType.kFloat8E4M3: - return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 - if dtype == tex.DType.kFloat8E5M2: - return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 - dtype = { - tex.DType.kByte: torch.uint8, - tex.DType.kInt32: torch.int32, - tex.DType.kFloat32: torch.float32, - tex.DType.kFloat16: torch.half, - tex.DType.kBFloat16: torch.bfloat16, - }[dtype] - - # PyTorch dtypes - if dtype == torch.float16: - return dict(rtol=1e-3, atol=1e-5) - if dtype == torch.bfloat16: - return dict(rtol=1.6e-2, atol=1e-5) - if dtype == torch.float32: - return dict(rtol=1.3e-6, atol=1e-5) - if dtype == torch.float64: - return dict(rtol=1e-7, atol=1e-7) - raise ValueError(f"Unsupported dtype ({dtype})") - - -def make_recipe(name: Optional[str] = None) -> Optional[Recipe]: - """Make recipe for quantization scheme""" - if name is None: - return None - if name == "fp8": - return transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - if name == "mxfp8": - return transformer_engine.common.recipe.MXFP8BlockScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - raise ValueError(f"Unsupported quantization scheme ({name})") - - def _test_all_reduce( *, - local_size: int = 17, + local_size: int = 32, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8: bool = False, + quantization: Optional[str] = None, ) -> None: # Distributed process group @@ -156,22 +143,25 @@ def _test_all_reduce( world_size = torch.distributed.get_world_size(process_group) # Tensor dimensions - in_shape = [world_size, local_size] - out_shape = [local_size] + in_shape = [world_size, local_size, local_size] + out_shape = [local_size, local_size] # Random data reset_rng() + with_quantization = quantization is not None x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, 
test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) # Plain PyTorch implementation @@ -199,10 +189,10 @@ def _test_all_reduce( def _test_all_gather( *, - local_size: int = 13, + local_size: int = 32, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8: bool = False, + quantization: Optional[str] = None, ) -> None: # Distributed process group @@ -211,26 +201,29 @@ def _test_all_gather( world_size = torch.distributed.get_world_size(process_group) # Tensor dimensions - in_shape = [world_size, local_size] - out_shape = [world_size, world_size * local_size] + in_shape = [world_size, local_size, local_size] + out_shape = [world_size, world_size * local_size, local_size] # Random data reset_rng() + with_quantization = quantization is not None x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) # Plain PyTorch implementation - y_ref = x_ref.tile((world_size, 1)).reshape(out_shape) + y_ref = x_ref.tile((world_size, 1, 1)).reshape(out_shape) y_ref.backward(dy_ref) # Convert to distributed tensors @@ -257,10 +250,10 @@ def _test_all_gather( def _test_reduce_scatter( *, - local_size: int = 11, + local_size: int = 32, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8: bool = False, + quantization: Optional[str] = None, ) -> None: # Distributed process group @@ -269,22 +262,25 @@ def _test_reduce_scatter( world_size = torch.distributed.get_world_size(process_group) # Tensor dimensions - in_shape = 
[world_size, world_size * local_size] - out_shape = [world_size, local_size] + in_shape = [world_size, world_size * local_size, local_size] + out_shape = [world_size, local_size, local_size] # Random data reset_rng() + with_quantization = quantization is not None x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) # Plain PyTorch implementation @@ -324,7 +320,11 @@ def _test_basic_linear( tensor_parallel_mode: str = "column", sequence_parallel: bool = False, ) -> None: + + # Skip invalid configurations quantized_compute = quantization is not None + if not quantized_compute and quantized_weight: + return # Distributed process group process_group = world_group() @@ -348,30 +348,23 @@ def _test_basic_linear( reset_rng() x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) - if isinstance(w_test, QuantizedTensor): - w_test = w_test.dequantize() dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, requires_grad=False, ) - if isinstance(dy_test, QuantizedTensor): - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) @@ -468,7 +461,11 @@ def _test_linear( 
tensor_parallel_mode: str = "column", sequence_parallel: bool = False, ) -> None: + + # Skip invalid configurations quantized_compute = quantization is not None + if not quantized_compute and quantized_weight: + return # Distributed process group process_group = world_group() @@ -492,21 +489,16 @@ def _test_linear( reset_rng() x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) - if isinstance(w_test, QuantizedTensor): - w_test = w_test.dequantize() b_ref, b_test = None, None if bias: if tensor_parallel_mode == "row": @@ -520,13 +512,11 @@ def _test_linear( ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, requires_grad=False, ) - if isinstance(dy_test, QuantizedTensor): - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) @@ -773,9 +763,10 @@ def run_parallel_tests() -> None: if rank == 0: print(f"Running _test_all_reduce") _test_all_reduce() - if rank == 0: - print(f"Running _test_all_gather") - _test_all_gather() + for quantization in quantization_list: + if rank == 0: + print(f"Running _test_all_gather with quantization={quantization}") + _test_all_gather(quantization=quantization) if rank == 0: print(f"Running _test_reduce_scatter") _test_reduce_scatter() diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 42070ea0f4..68083a0e03 100644 --- 
a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -26,21 +26,25 @@ UserbuffersBackwardLinear, UserbuffersForwardLinear, ) -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor from transformer_engine.pytorch.utils import is_bf16_compatible # Import utility functions _current_file = pathlib.Path(__file__).resolve() sys.path.append(str(_current_file.parent.parent)) -from utils import dtype_tols, str_to_dtype +from utils import dtype_tols, make_recipe, str_to_dtype # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() quantization_list: list[Optional[str]] = [None] if fp8_available: - quantization_list.append("fp8") + quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) if mxfp8_available: quantization_list.append("mxfp8") @@ -118,11 +122,12 @@ def reset_rng(seed: int = 1234) -> None: @torch.no_grad() def make_reference_and_test_tensors( shape: int | Iterable[int], + quantization: Optional[str] = None, ref_dtype: torch.dtype = torch.float64, ref_device: torch.device = "cpu", test_dtype: torch.dtype = torch.float32, test_device: torch.device = "cuda", - test_is_fp8: bool = False, + test_is_quantized: bool = False, requires_grad: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Construct tensors with the same values @@ -131,47 +136,49 @@ def make_reference_and_test_tensors( operations in high precision. The test tensor is intended for use in Transformer Engine operations. 
+ If a quantization scheme is provided, the tensor values are + quantized so that they are representable. + """ - # Random data + # Random reference tensor ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) - # Make copy of tensor + # Construct test tensor from reference tensor test = ref.to(device=test_device, dtype=test_dtype) - if test_is_fp8: + if quantization is None: + if test_is_quantized: + raise ValueError("Quantization scheme not provided") + if test.data_ptr() == ref.data_ptr(): + test = test.clone() + elif quantization in ("fp8", "fp8_delayed_scaling"): quantizer = Float8Quantizer( - scale=torch.ones(1, dtype=torch.float32, device=test_device), + scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), amax=torch.zeros(1, dtype=torch.float32, device=test_device), fp8_dtype=tex.DType.kFloat8E4M3, ) test = quantizer(test) - elif test.data_ptr() == ref.data_ptr(): - test = test.clone() + elif quantization == "fp8_current_scaling": + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device=test_device, + ) + test = quantizer(test) + elif quantization == "mxfp8": + test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + else: + raise ValueError(f"Unsupported quantization scheme ({quantization})") + if isinstance(test, QuantizedTensor) and not test_is_quantized: + test = test.dequantize() - # Make sure reference and test tensors represent exact same values + # Make sure reference and test tensors match each other ref.copy_(test) - # Return reference and test tensors ref.requires_grad_(requires_grad) test.requires_grad_(requires_grad) return ref, test -def make_recipe(name: Optional[str] = None) -> Optional[Recipe]: - """Make recipe for quantization scheme""" - if name is None: - return None - if name == "fp8": - return transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - if name == "mxfp8": - return 
transformer_engine.common.recipe.MXFP8BlockScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - raise ValueError(f"Unsupported quantization scheme ({name})") - - def _test_linear( *, model_config: ModelConfig, @@ -201,21 +208,16 @@ def _test_linear( reset_rng() x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(w_test, QuantizedTensor): - w_test = w_test.dequantize() b_ref, b_test = None, None if bias: if tensor_parallel_mode == "row": @@ -229,13 +231,11 @@ def _test_linear( ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, requires_grad=False, ) - if isinstance(dy_test, QuantizedTensor): - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index b1706db612..f78fa581b5 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -7,6 +7,8 @@ from collections.abc import Iterable import io import math +import pathlib +import sys from typing import Optional import pytest @@ -24,10 +26,20 @@ ForwardLinearBiasAdd, ) from transformer_engine.pytorch.tensor import QuantizedTensor -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Tensor, + Float8CurrentScalingQuantizer, + Float8Quantizer, +) +from transformer_engine.pytorch.tensor.mxfp8_tensor import 
MXFP8Tensor, MXFP8Quantizer from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex +# Import utility functions +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent)) +from utils import dtype_tols, make_recipe + # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() @@ -40,6 +52,13 @@ # Supported devices _devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")] +# Supported quantization recipes +_quantization_list: list[Optional[str]] = [None] +if fp8_available: + _quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) +if mxfp8_available: + _quantization_list.append("mxfp8") + def maybe_skip_quantization( quantization: Optional[str], @@ -47,13 +66,14 @@ def maybe_skip_quantization( dims: Optional[Iterable[int] | int] = None, device: Optional[torch.device | str] = None, ) -> None: + """Skip test case if a quantization scheme is not supported""" # Don't skip if there is no quantization if quantization is None: return # Check if quantization scheme is supported - if quantization == "fp8" and not fp8_available: + if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available: pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) @@ -61,7 +81,7 @@ def maybe_skip_quantization( if dims is not None: if not isinstance(dims, Iterable): dims = (dims,) - if quantization == "fp8": + if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"): if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: pytest.skip("FP8 GEMMs require dims that are divisible by 16") elif quantization == "mxfp8": @@ -73,47 +93,15 @@ def maybe_skip_quantization( pytest.skip("Quantization is only supported on CUDA devices") -def dtype_tols(dtype: 
torch.dtype | tex.DType) -> dict[str, float]: - """Estimated numerical error for a datatype - - Based on tolerances for torch.testing.assert_close. - - """ - - # Transformer Engine dtypes - if isinstance(dtype, tex.DType): - if dtype == tex.DType.kFloat8E4M3: - return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 - if dtype == tex.DType.kFloat8E5M2: - return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 - dtype = { - tex.DType.kByte: torch.uint8, - tex.DType.kInt32: torch.int32, - tex.DType.kFloat32: torch.float32, - tex.DType.kFloat16: torch.half, - tex.DType.kBFloat16: torch.bfloat16, - }[dtype] - - # PyTorch dtypes - if dtype == torch.float16: - return dict(rtol=1e-3, atol=1e-5) - if dtype == torch.bfloat16: - return dict(rtol=1.6e-2, atol=1e-5) - if dtype == torch.float32: - return dict(rtol=1.3e-6, atol=1e-5) - if dtype == torch.float64: - return dict(rtol=1e-7, atol=1e-7) - raise ValueError(f"Unsupported dtype ({dtype})") - - @torch.no_grad() def make_reference_and_test_tensors( shape: int | Iterable[int], + quantization: Optional[str] = None, ref_dtype: torch.dtype = torch.float64, ref_device: torch.device = "cpu", test_dtype: torch.dtype = torch.float32, test_device: torch.device = "cuda", - test_is_fp8: bool = False, + test_is_quantized: bool = False, requires_grad: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Construct tensors with the same values @@ -122,39 +110,49 @@ def make_reference_and_test_tensors( operations in high precision. The test tensor is intended for use in Transformer Engine operations. + If a quantization scheme is provided, the tensor values are + quantized so that they are representable. 
+ """ + + # Random reference tensor ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) + + # Construct test tensor from reference tensor test = ref.to(device=test_device, dtype=test_dtype) - if test_is_fp8: + if quantization is None: + if test_is_quantized: + raise ValueError("Quantization scheme not provided") + if test.data_ptr() == ref.data_ptr(): + test = test.clone() + elif quantization in ("fp8", "fp8_delayed_scaling"): quantizer = Float8Quantizer( scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), amax=torch.zeros(1, dtype=torch.float32, device=test_device), fp8_dtype=tex.DType.kFloat8E4M3, ) test = quantizer(test) - elif test.data_ptr() == ref.data_ptr(): - test = test.clone() + elif quantization == "fp8_current_scaling": + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device=test_device, + ) + test = quantizer(test) + elif quantization == "mxfp8": + test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + else: + raise ValueError(f"Unsupported quantization scheme ({quantization})") + if isinstance(test, QuantizedTensor) and not test_is_quantized: + test = test.dequantize() + + # Make sure reference and test tensors match each other ref.copy_(test) + ref.requires_grad_(requires_grad) test.requires_grad_(requires_grad) return ref, test -def make_recipe(name: Optional[str] = None) -> Optional[Recipe]: - """Make recipe for quantization scheme""" - if name is None: - return None - if name == "fp8": - return transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - if name == "mxfp8": - return transformer_engine.common.recipe.MXFP8BlockScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - raise ValueError(f"Unsupported quantization scheme ({name})") - - class TestSequential: """Tests for sequential container""" @@ -364,7 +362,7 @@ def test_fp8_scale_update( @pytest.mark.parametrize("init_dtype", _dtypes) 
@pytest.mark.parametrize("final_dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) def test_dtype_cast( self, *, @@ -377,8 +375,9 @@ def test_dtype_cast( """Check dtype cast functions""" # Skip invalid configurations - maybe_skip_quantization(quantization, device=device) + in_shape = (size, size) with_quantization = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data dtype = torch.float32 @@ -388,9 +387,9 @@ def test_dtype_cast( dtype = torch.bfloat16 w_ref, w_test = make_reference_and_test_tensors( (size, size), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=with_quantization, ) # Construct operation @@ -412,11 +411,11 @@ def test_dtype_cast( assert isinstance(op.weight, QuantizedTensor) == with_quantization assert op.weight.dtype == final_dtype w_test = op.weight.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(w_test, w_ref, rtol=0, atol=0) + torch.testing.assert_close(w_test, w_ref, **dtype_tols(dtype)) # Check forward and backward pass x = torch.zeros( - (size, size), + in_shape, dtype=init_dtype, device=device, requires_grad=True, @@ -429,7 +428,7 @@ def test_dtype_cast( @pytest.mark.parametrize("model_dtype", _dtypes) @pytest.mark.parametrize("autocast_dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) def test_pyt_autocast( self, *, @@ -444,8 +443,9 @@ def test_pyt_autocast( device = torch.device(device) # Skip invalid configurations + in_shape = (size, size) quantized_compute = quantization is not None - maybe_skip_quantization(quantization) + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Construct operation recipe = make_recipe(quantization) @@ -454,7 +454,7 @@ def test_pyt_autocast( # Check forward and backward pass x = torch.zeros( - 
(size, size), + in_shape, dtype=model_dtype, device=device, requires_grad=True, @@ -492,33 +492,34 @@ def setup_class(cls) -> None: @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) - @pytest.mark.parametrize("fp8", (False, True)) + @pytest.mark.parametrize("quantization", _quantization_list) def test_identity( self, *, - in_shape: Iterable[int] = (1,), + in_shape: Iterable[int] = (32, 32), dtype: torch.dtype, device: torch.device, - fp8: bool, + quantization: Optional[str], ) -> None: # Skip invalid configurations - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + with_quantization = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) @@ -554,7 +555,7 @@ def test_identity( ), ) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8_current_scaling")) def test_reshape( self, *, @@ -562,31 +563,32 @@ def test_reshape( dtype: torch.dtype, device: torch.device = "cuda", memory_format: torch.memory_format = torch.contiguous_format, - fp8: bool, + quantization: Optional[str], ) -> None: in_shape, out_shape = shapes # Skip invalid configurations if memory_format == torch.channels_last and len(in_shape) != 4: pytest.skip("torch.channels_last only supports 4D tensors") - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and torch.device(device).type != "cuda": - 
pytest.skip("FP8 is only supported on CUDA devices") + maybe_skip_quantization(quantization, device=device) + with_quantization = quantization is not None # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) x_test = x_test.contiguous(memory_format=memory_format) x_test = x_test.detach().requires_grad_() dy_ref, dy_test = make_reference_and_test_tensors( x_ref.reshape(out_shape).size(), + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) @@ -615,10 +617,10 @@ def test_reshape( torch.testing.assert_close(dx_test, x_ref.grad, **tols) @pytest.mark.parametrize("size", (1, 7, 32)) - @pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (2, 3, 4, -1))) + @pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (4, 3, 8, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", _devices) - @pytest.mark.parametrize("fp8", (False, True)) + @pytest.mark.parametrize("quantization", _quantization_list) def test_bias( self, *, @@ -626,24 +628,23 @@ def test_bias( in_shape: Iterable[int], dtype: torch.dtype, device: torch.device, - fp8: bool, + quantization: Optional[str], ) -> None: # Make input and bias shapes consistent in_shape = list(in_shape)[:-1] + [size] # Skip invalid configurations - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + with_quantization = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) b_ref, b_test = make_reference_and_test_tensors( size, @@ 
-652,8 +653,10 @@ def test_bias( ) dy_ref, dy_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) @@ -678,7 +681,7 @@ def test_bias( torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(db_test, b_ref.grad, **tols) - @pytest.mark.parametrize("quantization", ("fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("cast_forward", (False, True)) @pytest.mark.parametrize("cast_backward", (False, True)) def test_quantize( @@ -694,25 +697,26 @@ def test_quantize( """Quantize""" # Skip invalid configurations - maybe_skip_quantization(quantization) + with_quantization = quantization is not None + maybe_skip_quantization(quantization, device=device) + if quantization == "mxfp8": + maybe_skip_quantization(quantization, dims=in_shape) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - requires_grad=False, - test_is_fp8=True, + requires_grad=True, ) - x_test = x_test.dequantize().requires_grad_() dy_ref, dy_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, requires_grad=False, - test_is_fp8=True, ) - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = x_ref @@ -721,13 +725,14 @@ def test_quantize( # Implementation with fusible operation op = te_ops.Quantize(forward=cast_forward, backward=cast_backward) recipe = make_recipe(quantization) - with te.fp8_autocast(fp8_recipe=recipe): + with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe): y_test = op(x_test) y_test.backward(dy_test) # Check tensor types - assert isinstance(y_test, QuantizedTensor) == cast_forward - assert isinstance(x_test.grad, QuantizedTensor) == cast_backward + if with_quantization: + assert isinstance(y_test, 
QuantizedTensor) == cast_forward + assert isinstance(x_test.grad, QuantizedTensor) == cast_backward # Check values tols = dict(rtol=0, atol=0) @@ -762,10 +767,25 @@ def _test_basic_linear( # Skip invalid configurations maybe_skip_quantization(quantization, dims=in_shape, device=device) maybe_skip_quantization(quantization, dims=out_shape) - if quantization == "fp8" and quantized_output and not quantized_compute: - pytest.skip("FP8 output is only supported with FP8 GEMMs") - if quantization == "fp8" and quantized_grad_input and not quantized_compute: - pytest.skip("FP8 grad input is only supported with FP8 GEMMs") + quantization_needed = any( + ( + quantized_compute, + quantized_input, + quantized_weight, + quantized_output, + quantized_grad_output, + quantized_grad_input, + ) + ) + if quantization is None and quantization_needed: + pytest.skip("Quantization scheme is not specified") + if quantization is not None and not quantization_needed: + pytest.skip("Quantization scheme is not used") + if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"): + if quantized_output and not quantized_compute: + pytest.skip("FP8 output is only supported with FP8 GEMMs") + if quantized_grad_input and not quantized_compute: + pytest.skip("FP8 grad input is only supported with FP8 GEMMs") if quantization == "mxfp8" and quantized_output: pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs") if quantization == "mxfp8" and quantized_grad_input: @@ -774,28 +794,25 @@ def _test_basic_linear( # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_input), + test_is_quantized=quantized_input, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, 
test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_grad_output), + test_is_quantized=quantized_grad_output, requires_grad=False, ) - if isinstance(dy_test, QuantizedTensor): - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) @@ -858,7 +875,7 @@ def _test_basic_linear( @pytest.mark.parametrize("weight_shape", ((64, 32), (3, 5))) @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (4, 2, 4, -1))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) def test_basic_linear( self, @@ -880,7 +897,7 @@ def test_basic_linear( ) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - @pytest.mark.parametrize("quantization", ("fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantized_compute", (False, True)) @pytest.mark.parametrize("quantized_input", (False, True)) @pytest.mark.parametrize("quantized_weight", (False, True)) @@ -899,6 +916,8 @@ def test_basic_linear_quantized( quantized_grad_input: bool, ) -> None: """GEMM with FP8 inputs and outputs""" + if quantization is None: + pytest.skip("Skipping case without quantization") self._test_basic_linear( dtype=torch.bfloat16, quantization=quantization, @@ -911,7 +930,8 @@ def test_basic_linear_quantized( ) @pytest.mark.parametrize("bias", (False, True)) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("quantized_compute", (False, True)) @pytest.mark.parametrize("quantized_weight", (False, True)) 
@pytest.mark.parametrize("input_requires_grad", (False, True)) @pytest.mark.parametrize("weight_requires_grad", (False, True)) @@ -924,6 +944,7 @@ def test_linear( dtype: torch.dtype = torch.float32, device: torch.device = "cuda", quantization: Optional[str], + quantized_compute: bool, quantized_weight: bool, input_requires_grad: bool, weight_requires_grad: bool, @@ -936,26 +957,25 @@ def test_linear( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - quantized_compute = quantization is not None maybe_skip_quantization(quantization, dims=in_shape, device=device) maybe_skip_quantization(quantization, dims=out_shape) + if quantization is None and (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not specified") + if quantization is not None and not (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not used") # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - with torch.no_grad(): - if isinstance(x_test, QuantizedTensor): - x_test = x_test.dequantize() - x_test.requires_grad_(requires_grad=input_requires_grad) w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) b_ref, b_test = None, None if bias: @@ -966,6 +986,7 @@ def test_linear( ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, requires_grad=False, @@ -1022,7 +1043,7 @@ def test_linear( @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("zero_centered_gamma", (False, True)) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) def 
test_layer_norm( self, *, @@ -1192,7 +1213,7 @@ def test_layer_norm_autocast( @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("zero_centered_gamma", (False, True)) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) def test_rmsnorm( self, *, @@ -1327,14 +1348,14 @@ def test_l2normalization( @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) - @pytest.mark.parametrize("fp8", (False, True)) + @pytest.mark.parametrize("quantization", _quantization_list) def test_add_in_place( self, *, - in_shape: Iterable[int] = (1,), + in_shape: Iterable[int] = (32, 32), dtype: torch.dtype, device: torch.device, - fp8: bool, + quantization: Optional[str], ) -> None: """Add two tensors @@ -1343,28 +1364,30 @@ def test_add_in_place( """ # Skip invalid configurations - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + with_quantization = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x1_ref, x1_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) x2_ref, x2_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy_ref, dy_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) @@ -1381,7 +1404,7 @@ def test_add_in_place( # Check results tols = dtype_tols(dtype) - if fp8: + if with_quantization: tols = dtype_tols(x1_test._fp8_dtype) y_test = 
y_test.to(dtype=torch.float64, device="cpu") dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") @@ -1392,14 +1415,14 @@ def test_add_in_place( @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) - @pytest.mark.parametrize("fp8", (False, True)) + @pytest.mark.parametrize("quantization", _quantization_list) def test_make_extra_output( self, *, - in_shape: Iterable[int] = (1,), + in_shape: Iterable[int] = (32, 32), dtype: torch.dtype, device: torch.device, - fp8: bool, + quantization: Optional[str], ) -> None: """Output tensor twice @@ -1408,28 +1431,31 @@ def test_make_extra_output( """ # Skip invalid configurations - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + with_quantization = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=fp8, + test_is_quantized=with_quantization, ) dy1_ref, dy1_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) dy2_ref, dy2_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, + test_is_quantized=with_quantization, requires_grad=False, ) @@ -1455,7 +1481,7 @@ def test_make_extra_output( @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu")) @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("cache_quantized_input", (False, True)) def 
test_activation( self, @@ -1478,26 +1504,21 @@ def test_activation( quantized_compute = quantization is not None maybe_skip_quantization(quantization, dims=in_shape, device=device) if cache_quantized_input: - maybe_skip_quantization("fp8", device=device) + maybe_skip_quantization("fp8_current_scaling", device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization="fp8_current_scaling" if cache_quantized_input else None, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, requires_grad=False, ) - if quantized_compute: - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() - dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref: torch.Tensor @@ -1540,8 +1561,6 @@ def test_activation( tols = dtype_tols(dtype) if quantized_compute or cache_quantized_input: tols = dtype_tols(tex.DType.kFloat8E4M3) - if activation == "relu" and not cache_quantized_input: - tols = {"atol": 0, "rtol": 0} # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1550,7 +1569,7 @@ def test_activation( torch.testing.assert_close(dx_test, x_ref.grad, **tols) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantize_forward", (False, True)) @pytest.mark.parametrize("quantize_backward", (False, True)) def test_swiglu( @@ -1628,7 +1647,7 @@ def setup_class(cls) -> None: @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5))) @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) 
@pytest.mark.parametrize("quantized_weight", (False, True)) def test_forward_linear_bias_activation( self, @@ -1660,18 +1679,15 @@ def test_forward_linear_bias_activation( # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if quantized_compute: - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) b_ref, b_test = None, None if bias: @@ -1682,6 +1698,7 @@ def test_forward_linear_bias_activation( ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, requires_grad=False, @@ -1738,7 +1755,7 @@ def test_forward_linear_bias_activation( @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) def test_forward_linear_bias_add( self, *, @@ -1767,18 +1784,15 @@ def test_forward_linear_bias_add( # Random data x1_ref, x1_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(x1_test, QuantizedTensor): - with torch.no_grad(): - x1_test = x1_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) b_ref, b_test = None, None if bias: @@ -1794,6 +1808,7 @@ def test_forward_linear_bias_add( ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, 
requires_grad=False, @@ -1852,7 +1867,7 @@ def test_forward_linear_bias_add( torch.testing.assert_close(db_test, b_ref.grad, **tols) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) def test_backward_linear_add( self, *, @@ -1880,27 +1895,26 @@ def test_backward_linear_add( # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=quantized_compute, ) - if isinstance(x_test, QuantizedTensor): - with torch.no_grad(): - x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, - test_is_fp8=(quantized_compute or quantized_weight), ) dy1_ref, dy1_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, requires_grad=False, ) dy2_ref, dy2_test = make_reference_and_test_tensors( out_shape, + quantization=quantization, test_dtype=dtype, test_device=device, requires_grad=False, @@ -1964,7 +1978,7 @@ def setup_class(cls) -> None: torch.manual_seed(seed) torch.cuda.manual_seed(seed) - @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantized_weight", (False, True)) def test_linear( self, diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 450c24da33..f4a8ce69c6 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -7,6 +7,7 @@ import torch import transformer_engine +import transformer_engine.common.recipe import transformer_engine.pytorch as te import transformer_engine_torch as tex @@ -83,3 +84,24 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: if dtype == torch.float8_e5m2: return dict(rtol=0.25, atol=0.125) # 
epsilon = 0.152 raise ValueError(f"Unsupported dtype ({dtype})") + + +def make_recipe(name: Optional[str]) -> Optional[Recipe]: + """Make recipe for quantization scheme""" + if name is None: + return None + if name in ("fp8", "fp8_delayed_scaling"): + return transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + if name == "fp8_current_scaling": + return transformer_engine.common.recipe.Float8CurrentScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + if name == "mxfp8": + return transformer_engine.common.recipe.MXFP8BlockScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + if name == "fp8_block_scaling": + return transformer_engine.common.recipe.Float8BlockScaling() + raise ValueError(f"Unsupported quantization scheme ({name})") diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 55246a3d1d..868fc3a27a 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -947,7 +947,7 @@ def _all_gather_fp8( out = quantizer.make_empty(out_shape, dtype=dtype, device=device) elif isinstance(inp, Float8Tensor): out = inp.make_like(inp, shape=out_shape) - out._data = torch.empty_like( + out._data = torch.empty( out_shape, dtype=torch.uint8, device=inp.device, diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 0e786ca96f..4bd94b3ad2 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -22,7 +22,7 @@ from ...fp8 import FP8GlobalStateManager from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD from ...tensor import Quantizer, QuantizedTensor -from ...tensor.float8_tensor import Float8Quantizer +from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ...tensor.float8_blockwise_tensor 
import Float8BlockQuantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase @@ -324,12 +324,38 @@ def pre_forward(self, *args, **kwargs) -> None: weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + # Recipe-specific configuration + recipe = FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_current_scaling(): + if any( + not isinstance(q, Float8CurrentScalingQuantizer) + for q in (input_quantizer, weight_quantizer, grad_output_quantizer) + ): + raise RuntimeError( + "FP8 current-scaling recipe is enabled, " + f"but input quantizer is {input_quantizer.__class__.__name__}, " + f"weight quantizer is {weight_quantizer.__class__.__name__}, " + f"grad output quantizer is {grad_output_quantizer.__class__.__name__}" + ) + input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + if self.sequence_parallel and self.tensor_parallel_mode == "column": + input_quantizer.with_amax_reduction = True + input_quantizer.amax_reduction_group = self.tensor_parallel_group + if self.sequence_parallel and self.tensor_parallel_mode == "row": + grad_output_quantizer.with_amax_reduction = True + grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group + # Make sure weight tensor has correct quantizer # Note: Quantizer might have changed if quantization # recipe changed - if isinstance(weight_quantizer, Float8Quantizer) and isinstance( - weight, Float8TensorBase - ): + if isinstance( + 
weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ) and isinstance(weight, Float8TensorBase): weight._quantizer = weight_quantizer @staticmethod diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index c35d029403..0078f7ae65 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -21,7 +21,7 @@ _2X_ACC_FPROP, ) from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer -from ...tensor.float8_tensor import Float8Quantizer +from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase from ...utils import canonicalize_device, canonicalize_dtype from ..basic import BasicLinear, Bias, ReduceScatter @@ -208,7 +208,9 @@ def _functional_forward( if input_quantizer is not None: if not isinstance(x_local, QuantizedTensorBase): input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) - if isinstance(input_quantizer, Float8Quantizer): + if isinstance( + input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ): input_quantizer.set_usage(columnwise=False) x_local = input_quantizer(x_local) input_quantizer.set_usage(rowwise=True, columnwise=False) @@ -327,8 +329,10 @@ def fuser_forward( grad_input_quantizer = None if with_quantized_compute: recipe = FP8GlobalStateManager.get_fp8_recipe() - if not recipe.delayed() and not recipe.mxfp8(): - raise RuntimeError("Userbuffers is only supported with FP8 delayed scaling recipe") + if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())): + raise RuntimeError( + f"Unsupported recipe for Userbuffers ({recipe.__class__.__name__})" + ) input_quantizer = linear_op.get_quantizer("forward", 0) weight_quantizer = linear_op.get_quantizer("forward", 1) grad_output_quantizer = 
linear_op.get_quantizer("backward", 0) From 7b94bd99b24af9f898b6995b86ecc976df18d47b Mon Sep 17 00:00:00 2001 From: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com> Date: Fri, 13 Jun 2025 05:53:18 +0200 Subject: [PATCH 19/39] [common] Added support of FP4 data type (#1779) * Added support of FP4 data type Signed-off-by: Oleg Goncharov * Refactoring to BitsNum in progress Signed-off-by: Oleg Goncharov * Fixed compilation errors. All C++ tests passed Signed-off-by: Oleg Goncharov * Fixed a typo Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added FP4 guard to TMA tensor descriptor data type Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed errors in JAX C++ extensions Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Removed dummy NVFP4 C++ test file Signed-off-by: Oleg Goncharov * Make pytorch changes Signed-off-by: Kirthi Shankar Sivamani * Refactored the code per the review notes. Fixed JAX build error. 
Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Removed unnecessary static casts Signed-off-by: Oleg Goncharov * Typo fix Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com> * Pass correct num bits to create_2D_tensor_map; fixes CI Signed-off-by: Kirthi Shankar Sivamani * inline funcs Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Oleg Goncharov Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- tests/cpp/operator/test_normalization.h | 3 +- tests/cpp/test_common.cu | 61 +++++++------- tests/cpp/test_common.h | 77 ++++++++++++++---- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 39 +++++---- transformer_engine/common/common.cu | 49 ++++++++---- transformer_engine/common/common.h | 79 ++++++++++++++++++- .../common/fused_attn/context_parallel.cu | 4 +- .../fused_attn_f16_arbitrary_seqlen.cu | 20 ++--- .../common/fused_attn/fused_attn_fp8.cu | 16 ++-- .../transformer_engine/transformer_engine.h | 66 +++++++++++++--- .../common/normalization/common.cpp | 5 +- .../common/transformer_engine.cpp | 40 ++++++++-- .../common/transpose/cast_transpose_fusion.cu | 11 +-- .../common/transpose/multi_cast_transpose.cu | 4 +- .../quantize_transpose_square_blockwise.cu | 2 +- .../common/transpose/transpose_fusion.cu | 11 +-- .../common/util/cast_gated_kernels.cuh | 27 ++++--- .../common/util/cast_kernels.cuh | 16 ++-- .../common/util/dequantize_kernels.cuh | 4 +- transformer_engine/common/util/padding.cu | 4 +- transformer_engine/common/utils.cuh | 4 + transformer_engine/pytorch/csrc/common.h | 12 +-- .../pytorch/csrc/extensions/attention.cpp | 6 +- 23 files changed, 391 insertions(+), 169 deletions(-) diff --git a/tests/cpp/operator/test_normalization.h b/tests/cpp/operator/test_normalization.h index 368ffa66c9..f8dfb9f6eb 100644 --- 
a/tests/cpp/operator/test_normalization.h +++ b/tests/cpp/operator/test_normalization.h @@ -67,7 +67,8 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const // Remove the use_cudnn check here when it is supported by both backends. const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype; - if constexpr (std::is_same_v || std::is_same_v){ + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v){ compute_t g = static_cast(gamma); if (zero_centered_gamma) { g += static_cast(1.f); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 4c78ebedb5..0f64d7c01b 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -45,7 +45,7 @@ bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) { return true; } -size_t typeToSize(DType type) { +size_t typeToNumBits(DType type) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T, { return TypeInfo::size; @@ -62,7 +62,8 @@ const std::string &typeName(DType type) { {DType::kBFloat16, "bfloat16"}, {DType::kFloat8E4M3, "float8e4m3"}, {DType::kFloat8E5M2, "float8e5m2"}, - {DType::kFloat8E8M0, "float8e8m0"}}; + {DType::kFloat8E8M0, "float8e8m0"}, + {DType::kFloat4E2M1, "float4e2m1"}}; return name_map.at(type); } @@ -109,9 +110,16 @@ size_t DIVUP(const size_t &x, const size_t &y){ struct scale_inv_meta { std::vector shape; DType type; - size_t type_size; + size_t type_size_bits; + size_t bytes() const noexcept { + return (product(shape) * type_size_bits) / 8; + } }; +size_t bytes(const NVTEShape& shape, const DType type) { + return (product(shape) * typeToNumBits(type)) / 8; +} + NVTEShape convertShape(const std::vector& s) { return nvte_make_shape(s.data(), s.size()); } @@ -122,7 +130,7 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret; ret.shape = {1}; ret.type = DType::kFloat32; - ret.type_size = sizeof(float); + ret.type_size_bits = typeToNumBits(DType::kFloat32); return {ret, ret}; } if 
(scaling_mode == NVTE_MXFP8_1D_SCALING) { @@ -152,8 +160,8 @@ std::pair get_scales(const NVTEShape& shape, } ret_rowwise.type = DType::kFloat8E8M0; ret_colwise.type = DType::kFloat8E8M0; - ret_rowwise.type_size = sizeof(uint8_t); - ret_colwise.type_size = sizeof(uint8_t); + ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); + ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); return {ret_rowwise, ret_colwise}; } @@ -179,8 +187,8 @@ std::pair get_scales(const NVTEShape& shape, } ret_rowwise.type = DType::kFloat32; ret_colwise.type = DType::kFloat32; - ret_rowwise.type_size = sizeof(float); - ret_colwise.type_size = sizeof(float); + ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32); + ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32); return {ret_rowwise, ret_colwise}; } @@ -205,8 +213,8 @@ std::pair get_scales(const NVTEShape& shape, } ret_rowwise.type = DType::kFloat32; ret_colwise.type = DType::kFloat32; - ret_rowwise.type_size = sizeof(float); - ret_colwise.type_size = sizeof(float); + ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32); + ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32); return {ret_rowwise, ret_colwise}; } @@ -222,8 +230,7 @@ Tensor::Tensor(const std::string& name, gen_.seed(seed); rowwise_ = rowwise; columnwise_ = columnwise; - size_t s = typeToSize(type); - size_t total_size = product(shape) * s; + size_t total_size = bytes(shape, type); void *dptr_rowwise = nullptr; void *dptr_columnwise = nullptr; cpu_data_rowwise_ = nullptr; @@ -305,8 +312,8 @@ Tensor::Tensor(const std::string& name, } else { auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode()); - auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; - auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; + auto rowwise_scale_size = rowwise_scale_meta.bytes(); + auto columnwise_scale_size = 
colwise_scale_meta.bytes(); auto scale_shape = rowwise_scale_meta.shape; auto columnwise_scale_shape = colwise_scale_meta.shape; if (rowwise) { @@ -331,7 +338,7 @@ Tensor::Tensor(const std::string& name, void Tensor::to_cpu() const { const NVTEShape s = tensor_.shape(); - const size_t size = product(s) * typeToSize(tensor_.dtype()); + const size_t size = bytes(s, tensor_.dtype()); if (rowwise_) { cudaMemcpy(cpu_data_rowwise_.get(), tensor_.get_rowwise_data().data_ptr, @@ -360,14 +367,14 @@ void Tensor::to_cpu() const { auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); if (rowwise_) { - auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; + auto scale_size = rowwise_scale_meta.bytes(); cudaMemcpy(rowwise_scale_inv_cpu_data_.get(), tensor_.get_rowwise_scale_inv().data_ptr, scale_size, cudaMemcpyDeviceToHost); } if (columnwise_) { - auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; + auto scale_size = colwise_scale_meta.bytes(); cudaMemcpy(columnwise_scale_inv_cpu_data_.get(), tensor_.get_columnwise_scale_inv().data_ptr, scale_size, @@ -378,34 +385,32 @@ void Tensor::to_cpu() const { void Tensor::from_cpu() const { const NVTEShape s = tensor_.shape(); - const size_t size = product(s) * typeToSize(tensor_.dtype()); + const size_t size = bytes(s, tensor_.dtype()); if (rowwise_) { - cudaMemcpy(tensor_.get_rowwise_data().data_ptr, - cpu_data_rowwise_.get(), size, cudaMemcpyHostToDevice); + cudaMemcpy(tensor_.get_rowwise_data().data_ptr, cpu_data_rowwise_.get(), size, + cudaMemcpyHostToDevice); } if (columnwise_) { - cudaMemcpy(tensor_.get_columnwise_data().data_ptr, - cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice); + cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size, + cudaMemcpyHostToDevice); } if (isFp8Type(dtype())) { if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { if (tensor_.amax() != nullptr){ - 
cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), - cudaMemcpyHostToDevice); + cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); } - cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), - cudaMemcpyHostToDevice); + cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); } auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); if (rowwise_) { - auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; + auto scale_size = rowwise_scale_meta.bytes(); cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, rowwise_scale_inv_cpu_data_.get(), scale_size, cudaMemcpyHostToDevice); } if (columnwise_) { - auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; + auto scale_size = colwise_scale_meta.bytes(); cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr, columnwise_scale_inv_cpu_data_.get(), scale_size, cudaMemcpyHostToDevice); diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 82ff66bdcf..3597c94d85 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -10,10 +10,15 @@ #include #include #include +#include +#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080) #include #include #include +#if FP4_TYPE_SUPPORTED +#include +#endif #include #include @@ -55,19 +60,32 @@ using bf16 = nv_bfloat16; using fp8e4m3 = __nv_fp8_e4m3; using fp8e5m2 = __nv_fp8_e5m2; using fp8e8m0 = uint8_t; +#if FP4_TYPE_SUPPORTED +using fp4e2m1 = __nv_fp4_e2m1; +#endif template -struct TypeInfo{ - using types = std::tuple; +struct BitsNumber; + +#if FP4_TYPE_SUPPORTED +template <> +struct BitsNumber { + static constexpr size_t num_bits = 4; +}; +#endif + +template +struct BitsNumber { + static constexpr size_t num_bits = 8 * sizeof(T); +}; + +template +struct TypeInfo { +#if FP4_TYPE_SUPPORTED + using types = std::tuple; +#else + using types = std::tuple; +#endif template struct Helper 
{ @@ -94,7 +112,7 @@ struct TypeInfo{ } constexpr static DType dtype = getType(); - constexpr static size_t size = sizeof(T); + constexpr static size_t size = BitsNumber::num_bits;; }; class Tensor { @@ -416,9 +434,10 @@ inline float dsilu(const float x) { return x * dsigmoid(x) + sigmoid(x); } inline float srelu(const float x) { return x > 0 ? x * x : 0; } inline float dsrelu(const float x) { return fmaxf(0, 2 * x); } -size_t typeToSize(DType type); +size_t typeToNumBits(DType type); size_t product(const NVTEShape &shape); size_t product(const std::vector &shape); +size_t bytes(const NVTEShape& shape, const DType type); size_t first_dimension(const std::vector &shape); size_t last_dimension(const std::vector &shape); @@ -464,6 +483,16 @@ constexpr int32_t blackwellComputeCapability = 100; } // namespace test +#if FP4_TYPE_SUPPORTED +#define SWITCH_FP4_TYPE_HANDLE(type, ...) \ + case DType::kFloat4E2M1: { \ + using type = fp4e2m1; \ + { __VA_ARGS__ } \ + } break; +#else +#define SWITCH_FP4_TYPE_HANDLE(type, ...) // do nothing +#endif + #define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ @@ -515,8 +544,16 @@ constexpr int32_t blackwellComputeCapability = 100; {__VA_ARGS__} \ } \ break; \ + case DType::kFloat8E8M0: \ + { \ + using type = fp8e8m0; \ + {__VA_ARGS__} \ + } \ + break; \ + SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ default: \ - NVTE_ERROR("Invalid type."); \ + printf("dtype: %d\n", static_cast(dtype)); \ + NVTE_ERROR("Invalid type MARKED TEST."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \ @@ -535,7 +572,15 @@ constexpr int32_t blackwellComputeCapability = 100; } \ break; \ default: \ - NVTE_ERROR("Invalid type."); \ + NVTE_ERROR("Invalid type MARKED TEST 2."); \ + } + +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) 
\ + switch (dtype) { \ + using namespace transformer_engine; \ + SWITCH_FP4_HANDLE(type, __VA_ARGS__) \ + default: \ + NVTE_ERROR("Invalid type MARKED TEST 3."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \ @@ -560,5 +605,5 @@ constexpr int32_t blackwellComputeCapability = 100; } \ break; \ default: \ - NVTE_ERROR("Invalid type."); \ + NVTE_ERROR("Invalid type MARKED TEST 4."); \ } diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 120fad14ab..40595ea988 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -196,7 +196,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz if (param_type == NVTETensorParam::kNVTERowwiseData || param_type == NVTETensorParam::kNVTEColumnwiseData) { // Offset data pointer - param_dptr += chunk_offset * typeToSize(param_dtype); + param_dptr += get_buffer_size_bytes(chunk_offset, param_dtype); param_shape = chunk_shape; if (param_type == NVTETensorParam::kNVTEColumnwiseData && @@ -217,7 +217,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz } else { chunk_scale_height /= 32; } - param_dptr += (chunk_offset / 32) * typeToSize(param_dtype); + param_dptr += get_buffer_size_bytes(chunk_offset / 32, param_dtype); param_shape = {chunk_scale_height, chunk_scale_width}; } @@ -236,7 +236,7 @@ TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source auto chunk = get_tensor_chunk(source, chunk_offset, chunk_shape); // Update chunk with offset data pointers from the communication buffer - auto ubuf_ptr = reinterpret_cast(_ubuf.dptr()) + (chunk_offset * _ubuf.element_size()); + auto ubuf_ptr = reinterpret_cast(_ubuf.dptr()) + chunk_offset * _ubuf.element_size(); if (chunk.dptr() != nullptr) { 
chunk.set_rowwise_data(reinterpret_cast(ubuf_ptr), chunk.dtype(), chunk.shape()); } @@ -269,7 +269,7 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType "or 2 (multi-atomic)."); NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); - size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype); + size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); void *buffer_ptr; _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg); @@ -306,7 +306,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_comm, 0)); // Communication: AG and RS - int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size + int comm_elements = _ubuf.bytes() / 2; // UBUF uses 2Byte element size if (comm_type == CommOverlapType::AG) { allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, (cudaEvent_t)_comm_launch_event); @@ -606,7 +606,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, // Create workspace tensor with userbuffer NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); - size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype); + size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); int buffer_chunk_bytes = buffer_bytes / tp_size; _num_ubuf_chunks = tp_size; if (_is_reduce_scatter) { @@ -704,7 +704,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag( assert(pre_gelu_out.numel() == 0); // Get communication and GEMM output chunk sizes - const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const int comm_bytes = _ubufs[0].bytes(); // Create an GEMM output buffer with N+1 chunks in a 
contiguous memory void *D_buffer_ptr; @@ -762,21 +762,20 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag( if (B_copy.numel() > 0) { assert(B_copy.numel() == _ubufs[_self_chunk_id].numel()); assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size()); - NVTE_CHECK_CUDA( - cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(), - _ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(), - cudaMemcpyDeviceToDevice, _stream_send[0])); + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(), + _ubufs[_self_chunk_id].bytes(), cudaMemcpyDeviceToDevice, + _stream_send[0])); NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); } // Copy the first GEMM output chunk to the end chunk position of D_buffer char *src_ptr = reinterpret_cast(D_buffer.dptr()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, D_chunk_bytes, + NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + D.bytes(), src_ptr, D_chunk_bytes, cudaMemcpyDeviceToDevice, stream_main)); // Return the last N rows of D_buffer - NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.numel() * D.element_size(), + NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.bytes(), cudaMemcpyDeviceToDevice, stream_main)); // Clean up buffer allocation @@ -806,7 +805,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, const size_t n_chunk = _ubufs[0].size(0); // Get communication and GEMM output chunk sizes - const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const int comm_bytes = _ubufs[0].bytes(); const bool do_gelu = pre_gelu_out.numel() > 0; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); @@ -882,8 +881,8 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, assert(B_copy.numel() == _ubufs[_tp_id].numel()); assert(B_copy.element_size() 
== _ubufs[_tp_id].element_size()); NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), - _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, _stream_send[0])); + _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice, + _stream_send[0])); } } } else { @@ -935,8 +934,8 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, assert(B_copy.numel() == _ubufs[_tp_id].numel()); assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), - _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, _stream_send[0])); + _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice, + _stream_send[0])); } } } @@ -966,7 +965,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs( _ub_comm->cga_size = _cga_size; // Get communication and GEMM input chunk sizes - const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const int comm_bytes = _ubufs[0].bytes(); // Reset counters int *counter_ptr = reinterpret_cast(_counter.dptr()); @@ -1033,7 +1032,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, size_t m = transa ? A.size(0) : A.size(1); size_t k = transa ? 
A.size(1) : A.size(0); size_t n_chunk = _ubufs[0].size(0); - const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const int comm_bytes = _ubufs[0].bytes(); // Get input and workspace data pointers size_t input_chunk_size = n_chunk * k; diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index cdaf60f778..192c915a84 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -116,13 +116,20 @@ void checkCuDriverContext(CUstream stream) { } CUtensorMapDataType get_CUtensorMapDataType(DType dtype) { - static const std::unordered_map dtypeMapping = { - {DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}, - {DType::kFloat32, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32}, - {DType::kFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16}, - {DType::kBFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16}, - {DType::kFloat8E4M3, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}, - {DType::kFloat8E5M2, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}}; + static const std::unordered_map dtypeMapping = []() { + std::unordered_map typeMapping = { + {DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}, + {DType::kFloat32, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32}, + {DType::kFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16}, + {DType::kBFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16}, + {DType::kFloat8E4M3, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}, + {DType::kFloat8E5M2, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}}; +#if FP4_TYPE_SUPPORTED + typeMapping.insert( + {DType::kFloat4E2M1, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B}); +#endif + return typeMapping; + }(); return dtypeMapping.at(dtype); } @@ -130,7 +137,7 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype) { void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, const 
uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX, const uint32_t stride_elems, - const uint32_t offset_elems, const size_t type_size) { + const uint32_t offset_elems, const size_t type_num_bits) { // Get a function pointer to the cuTensorMapEncodeTiled driver API // Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13 static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() { @@ -142,7 +149,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, uint64_t size[rank] = {globalX, globalY}; // The stride is the number of bytes to traverse from the first element of one row to the next - uint64_t stride[rank - 1] = {stride_elems * type_size}; + uint64_t stride[rank - 1] = {(stride_elems * type_num_bits) / 8}; // The boxSize is the size of the shared memory buffer that is used as the // source/destination of a TMA transfer @@ -152,15 +159,15 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, uint32_t elemStride[rank] = {1, 1}; const CUtensorMapDataType tensorDataType = get_CUtensorMapDataType(tensor.dtype); - void *dataPtr = - reinterpret_cast(reinterpret_cast(tensor.dptr) + offset_elems * type_size); + void *dataPtr = reinterpret_cast(reinterpret_cast(tensor.dptr) + + (offset_elems * type_num_bits) / 8); NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment), "Tensor data pointer must be 16B aligned"); - const int TMA_needed_size = TMA_gmem_alignment / type_size; - NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_size, - "-byte data type, expected multiple of ", TMA_needed_size, ", got ", globalX); + const int TMA_needed_size = (TMA_gmem_alignment * 8) / type_num_bits; + NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits, + "-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX); // Create the tensor descriptor. 
NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled( @@ -206,4 +213,18 @@ std::vector> convert_tensor_array(NVTETensor **nvte_tensor return ret; } +size_t get_buffer_size_bytes(const size_t elements_num, const DType buffer_dtype) { + return (elements_num * typeToNumBits(buffer_dtype)) / 8; +} + +size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last, + const DType buffer_dtype) { + if (buffer_dtype == DType::kFloat4E2M1) { + NVTE_CHECK(dim_last % 2 == 0, + "Last dimension of a tensor with FP4 type of data must be an even number!"); + } + const size_t elements_num = dim_first * dim_last; + return get_buffer_size_bytes(elements_num, buffer_dtype); +} + } // namespace transformer_engine diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 21f43e31cc..22b448a001 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -8,9 +8,15 @@ #define TRANSFORMER_ENGINE_COMMON_COMMON_H_ #include +#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080) + #include #include #include +#if FP4_TYPE_SUPPORTED +#include +#endif + #include #include @@ -183,6 +189,7 @@ struct Tensor { } break; case NVTE_MXFP8_1D_SCALING: + case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING: if (!has_data() && has_columnwise_data()) { return columnwise_data.shape; } else { @@ -268,6 +275,13 @@ constexpr T DIVUP(const T &x, const T &y) { return (((x) + ((y)-1)) / (y)); } +template +constexpr __device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(const T1 &N, const T2 &M) { + static_assert(std::is_integral::value && std::is_integral::value, + "Integral type required."); + return DIVUP(static_cast(N), static_cast(M)) * M; +} + using byte = uint8_t; using int16 = int16_t; using int32 = int32_t; @@ -280,6 +294,9 @@ using fp8e5m2 = __nv_fp8_e5m2; #if CUDA_VERSION >= 12080 using fp8e8m0 = __nv_fp8_e8m0; #endif +#if FP4_TYPE_SUPPORTED +using fp4e2m1 = __nv_fp4_e2m1; +#endif using e8m0_t = uint8_t; namespace detail { @@ -303,11 
+320,21 @@ TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2) #if CUDA_VERSION >= 12080 TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0) #endif +#if FP4_TYPE_SUPPORTED +TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp4_e2m1) +#endif #undef TRANSFORMER_ENGINE_TYPE_NAME template struct TypeExtrema; +#if FP4_TYPE_SUPPORTED +template <> +struct TypeExtrema { + static constexpr float max = 6.0f; +}; +#endif + template <> struct TypeExtrema { static constexpr float max = 448.0f; @@ -337,9 +364,28 @@ struct TypeExtrema { } // namespace detail +template +struct BitsNumber; + +#if FP4_TYPE_SUPPORTED +template <> +struct BitsNumber { + static constexpr size_t num_bits = 4; +}; +#endif + +template +struct BitsNumber { + static constexpr size_t num_bits = 8 * sizeof(T); +}; + template struct TypeInfo { +#if FP4_TYPE_SUPPORTED + using types = std::tuple; +#else using types = std::tuple; +#endif template struct Helper { @@ -364,11 +410,21 @@ struct TypeInfo { } constexpr static DType dtype = getType(); - constexpr static size_t size = sizeof(T); + constexpr static size_t size = BitsNumber::num_bits; constexpr static float max_finite_value = detail::TypeExtrema::max; constexpr static const char *name = detail::type_name(); }; +#if FP4_TYPE_SUPPORTED +#define SWITCH_FP4_TYPE_HANDLE(type, ...) \ + case DType::kFloat4E2M1: { \ + using type = fp4e2m1; \ + { __VA_ARGS__ } \ + } break; +#else +#define SWITCH_FP4_TYPE_HANDLE(type, ...) // do nothing +#endif + #define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) 
\ switch (dtype) { \ using namespace transformer_engine; \ @@ -412,6 +468,7 @@ struct TypeInfo { using type = byte; \ { __VA_ARGS__ } \ } break; \ + SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ default: \ NVTE_ERROR("Invalid type."); \ } @@ -523,6 +580,9 @@ struct TypeInfo { case DType::kFloat8E4M3: { \ NVTE_ERROR("FP8 type not instantiated for input."); \ } break; \ + case DType::kFloat4E2M1: { \ + NVTE_ERROR("FP4 type not instantiated for input."); \ + } break; \ default: \ NVTE_ERROR("Invalid type."); \ } @@ -593,6 +653,14 @@ struct is_fp8 : std::true_type {}; template <> struct is_fp8 : std::true_type {}; +template +struct is_fp4 : std::false_type {}; + +#if FP4_TYPE_SUPPORTED +template <> +struct is_fp4 : std::true_type {}; +#endif + // [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors constexpr size_t scale_tensor_alignment_X_rowwise = 4; constexpr size_t scale_tensor_alignment_Y_rowwise = 128; @@ -611,13 +679,16 @@ inline bool is_aligned_tensor_data(const Tensor &t, size_t alignment) { } size_t typeToSize(const DType type); +size_t typeToNumBits(const DType type); + +size_t get_buffer_size_bytes(const size_t N, const DType buffer_dtype); +size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last, + const DType buffer_dtype); void CheckNoopTensor(const Tensor &t, const std::string &name); void CheckInputTensor(const Tensor &t, const std::string &name); void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false); -bool is_fp8_dtype(const DType t); - /*! 
\brief Update a tensor's FP8 scale-inverse * * The FP8 scale-inverse (dequantization scaling factor) is updated @@ -636,7 +707,7 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype); void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX, const uint32_t stride_elems, - const uint32_t offset_elems, const size_t type_size); + const uint32_t offset_elems, const size_t type_num_bits); bool is_supported_by_CC_100(); diff --git a/transformer_engine/common/fused_attn/context_parallel.cu b/transformer_engine/common/fused_attn/context_parallel.cu index e340242c63..15708d2d59 100644 --- a/transformer_engine/common/fused_attn/context_parallel.cu +++ b/transformer_engine/common/fused_attn/context_parallel.cu @@ -325,7 +325,7 @@ void thd_read_half_tensor(const Tensor &tensor, const Tensor &cu_seqlens, Tensor int batch = cu_seqlens_shape[0] - 1; int num_heads = tensor_shape[seq_dim + 1]; int dim_per_head = tensor_shape[seq_dim + 2]; - int hidden_size_in_bytes = num_heads * dim_per_head * typeToSize(tensor.dtype()); + int hidden_size_in_bytes = (num_heads * dim_per_head * typeToNumBits(tensor.dtype())) / 8; // For 128-bits load/store NVTE_CHECK(hidden_size_in_bytes % 16 == 0); @@ -582,7 +582,7 @@ static void thd_grad_correction_helper(Tensor grad, const Tensor &grad_per_step, NVTE_CHECK(grad_per_step_shape[seq_dim + 2] == dim_per_head); size_t hidden_size = num_heads * dim_per_head; - NVTE_CHECK((hidden_size * typeToSize(grad.dtype())) % 16 == 0); + NVTE_CHECK(((hidden_size * typeToNumBits(grad.dtype())) / 8) % 16 == 0); constexpr unsigned int block = 256; unsigned int grid_x; diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 8a827fde4b..0932b2cf85 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ 
b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -377,7 +377,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t)); const size_t actual_seqlen_workspace_size = is_padding ? 2 * num_bytes_per_seqlen : 0; const size_t num_bytes_per_ragged_offset = - alignTo<16>((b + 1) * typeToSize(ragged_offset_type)); + alignTo<16>(((b + 1) * typeToNumBits(ragged_offset_type)) / 8); size_t seqlen_offsets_workspace_size = 0; if (is_ragged_q || is_ragged_kv) { size_t count = 2 * (static_cast(is_ragged_q) + static_cast(is_ragged_kv)); @@ -831,7 +831,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t)); const size_t actual_seqlen_workspace_size = is_padding ? 2 * num_bytes_per_seqlen : 0; const size_t num_bytes_per_ragged_offset = - alignTo<16>((b + 1) * typeToSize(ragged_offset_type)); + alignTo<16>(((b + 1) * typeToNumBits(ragged_offset_type)) / 8); size_t seqlen_offsets_workspace_size = 0; if (is_ragged_q || is_ragged_kv) { size_t count = 2 * (static_cast(is_ragged_q) + static_cast(is_ragged_kv)); @@ -957,9 +957,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = typeToSize(QKV_type) * num_attn_heads * head_dim; + stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = typeToSize(QKV_type) * head_dim; + stride = (typeToNumBits(QKV_type) * head_dim) / 8; } void *devPtrQ = static_cast(devPtrQKV); void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); @@ -1082,9 +1082,9 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = typeToSize(QKV_type) 
* num_attn_heads * head_dim; + stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = typeToSize(QKV_type) * head_dim; + stride = (typeToNumBits(QKV_type) * head_dim) / 8; } void *devPtrQ = devPtrQKV; void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); @@ -1173,9 +1173,9 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; + stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = typeToSize(QKV_type) * head_dim; + stride = (typeToNumBits(QKV_type) * head_dim) / 8; } void *devPtrK = devPtrKV; void *devPtrV = static_cast(static_cast(devPtrKV) + stride); @@ -1313,9 +1313,9 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; + stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = typeToSize(QKV_type) * head_dim; + stride = (typeToNumBits(QKV_type) * head_dim) / 8; } void *devPtrK = devPtrKV; void *devPtrV = static_cast(static_cast(devPtrKV) + stride); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 431c401cf1..3e38a5066e 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2364,9 +2364,9 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma NVTE_QKV_Layout_Group layout_group = 
nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = typeToSize(QKV_type) * num_attn_heads * head_dim; + stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = typeToSize(QKV_type) * head_dim; + stride = (typeToNumBits(QKV_type) * head_dim) / 8; } void* devPtrQ = static_cast(devPtrQKV); void* devPtrK = static_cast(static_cast(devPtrQKV) + stride); @@ -2466,9 +2466,9 @@ void fused_attn_fp8_bwd_qkvpacked( NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = typeToSize(QKV_type) * num_attn_heads * head_dim; + stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = typeToSize(QKV_type) * head_dim; + stride = (typeToNumBits(QKV_type) * head_dim) / 8; } void* devPtrQ = devPtrQKV; void* devPtrK = static_cast(static_cast(devPtrQKV) + stride); @@ -2564,9 +2564,9 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; + stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = typeToSize(QKV_type) * head_dim; + stride = (typeToNumBits(QKV_type) * head_dim) / 8; } void* devPtrK = devPtrKV; void* devPtrV = static_cast(static_cast(devPtrKV) + stride); @@ -2671,9 +2671,9 @@ void fused_attn_fp8_bwd_kvpacked( NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = typeToSize(QKV_type) * num_gqa_groups * 
head_dim; + stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = typeToSize(QKV_type) * head_dim; + stride = (typeToNumBits(QKV_type) * head_dim) / 8; } void* devPtrK = devPtrKV; void* devPtrV = static_cast(static_cast(devPtrKV) + stride); diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 37ba1b4770..dab4fcfe75 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -22,17 +22,18 @@ extern "C" { * \brief TE datatype. */ enum NVTEDType { - kNVTEByte = 0, /*!< Byte */ - kNVTEInt16 = 1, /*!< 16-bit integer */ - kNVTEInt32 = 2, /*!< 32-bit integer */ - kNVTEInt64 = 3, /*!< 64-bit integer */ - kNVTEFloat32 = 4, /*!< 32-bit float */ - kNVTEFloat16 = 5, /*!< 16-bit float (E5M10) */ - kNVTEBFloat16 = 6, /*!< 16-bit bfloat (E8M7) */ - kNVTEFloat8E4M3 = 7, /*!< 8-bit float (E4M3) */ - kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */ - kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */ - kNVTENumTypes /*!< Number of supported types */ + kNVTEByte = 0, /*!< Byte */ + kNVTEInt16 = 1, /*!< 16-bit integer */ + kNVTEInt32 = 2, /*!< 32-bit integer */ + kNVTEInt64 = 3, /*!< 64-bit integer */ + kNVTEFloat32 = 4, /*!< 32-bit float */ + kNVTEFloat16 = 5, /*!< 16-bit float (E5M10) */ + kNVTEBFloat16 = 6, /*!< 16-bit bfloat (E8M7) */ + kNVTEFloat8E4M3 = 7, /*!< 8-bit float (E4M3) */ + kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */ + kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */ + kNVTEFloat4E2M1 = 10, /*!< 4-bit float (E2M1) */ + kNVTENumTypes /*!< Number of supported types */ }; /*! \struct NVTEShape @@ -87,6 +88,10 @@ enum NVTEScalingMode { */ NVTE_BLOCK_SCALING_1D = 2, NVTE_BLOCK_SCALING_2D = 3, + /*! 
Single NVFP4 scale per block of 16 contiguous elements in forward pass (FWD), + and single MXFP8 scale per block of 32 contiguous elements in backward pass (BWD). + */ + NVTE_FWD_NVFP4_BWD_MXFP8_SCALING = 4, NVTE_INVALID_SCALING = 100 }; @@ -177,6 +182,14 @@ size_t nvte_tensor_ndims(const NVTETensor tensor); */ size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim); +/*! \brief Get the byte size for the tensor. + * + * \param[in] tensor Tensor. + * + * \return Byte size of the tensor. + */ +size_t nvte_tensor_size_bytes(const NVTETensor tensor); + /*! \brief Get a tensor's total number of elements. * * \param[in] tensor Tensor. @@ -193,6 +206,14 @@ size_t nvte_tensor_numel(const NVTETensor tensor); */ size_t nvte_tensor_element_size(const NVTETensor tensor); +/*! \brief Get the bit size for the tensor's data type. + * + * \param[in] tensor Tensor. + * + * \return Bit size of the tensor's data type. + */ +size_t nvte_tensor_element_size_bits(const NVTETensor tensor); + /*! \brief Get a tensor's data type. * * \param[in] tensor Tensor. @@ -390,6 +411,7 @@ enum class DType { kFloat8E4M3 = 7, kFloat8E5M2 = 8, kFloat8E8M0 = 9, + kFloat4E2M1 = 10, kNumTypes }; @@ -398,7 +420,16 @@ enum class DType { * Return true if TE datatype is FP8 * \param[in] DType TE Datatype of interest */ -bool is_fp8_dtype(const DType t); +inline bool is_fp8_dtype(const DType t) { + return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2; +} + +/*! \brief Check if TE datatype is FP4 + * + * Return true if TE datatype is FP4 + * \param[in] DType TE Datatype of interest + */ +inline bool is_fp4_dtype(const DType t) { return t == DType::kFloat4E2M1; } /*! \struct TensorWrapper * \brief C++ wrapper for the NVTETensor class. @@ -627,6 +658,15 @@ class TensorWrapper { return nvte_tensor_element_size(tensor_); } + /*! \brief Get the tensor's element size in bits. + * + * \return Element size in bits. 
+ */ + size_t element_size_bits() const noexcept { + if (tensor_ == nullptr) return 0; + return nvte_tensor_element_size_bits(tensor_); + } + /*! \brief Get the tensor's allocated size in bytes. This will return 0 for tensors with nullptr * data even if the TensorWrapper has a non-zero shape and valid dtype. * @@ -634,7 +674,7 @@ class TensorWrapper { */ size_t bytes() const noexcept { if (tensor_ == nullptr || this->dptr() == nullptr) return 0; - return nvte_tensor_numel(tensor_) * nvte_tensor_element_size(tensor_); + return nvte_tensor_size_bytes(tensor_); } /*! \brief Get the data type of this TensorWrapper. diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index ae89c7773c..9df81a917f 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -212,8 +212,11 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor } const auto gamma_dtype = use_zero_centered_gamma_in_weight_dtype() ? 
wtype : ctype; + NVTE_CHECK(gamma_dtype == DType::kFloat32 || gamma_dtype == DType::kFloat16 || + gamma_dtype == DType::kBFloat16, + "Gamma of type FP4 is not supported"); - _scalar_dptr = std::make_unique(typeToSize(gamma_dtype)); + _scalar_dptr = std::make_unique(typeToNumBits(gamma_dtype) / 8); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( gamma_dtype, cpp_dtype, *(reinterpret_cast(_scalar_dptr.get())) = (cpp_dtype)1.0f;); diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 5f58e2a99a..6c395837fb 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -18,12 +18,15 @@ namespace transformer_engine { -size_t typeToSize(const DType type) { +size_t typeToNumBits(const DType type) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T, return TypeInfo::size;); // NOLINT(*) } -bool is_fp8_dtype(const DType t) { return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2; } +size_t typeToSize(const DType type) { + NVTE_CHECK(type != DType::kFloat4E2M1, "typeToSize() Does not support FP4 data type."); + return typeToNumBits(type) / 8; +} std::string to_string(const DType type) { switch (type) { @@ -41,6 +44,8 @@ std::string to_string(const DType type) { return "Float8E5M2"; case DType::kFloat8E8M0: return "Float8E8M0"; + case DType::kFloat4E2M1: + return "Float4E2M1"; case DType::kInt32: return "Int32"; case DType::kInt64: @@ -56,6 +61,8 @@ std::string to_string(const NVTEScalingMode &mode) { return "NVTE_DELAYED_TENSOR_SCALING"; case NVTE_MXFP8_1D_SCALING: return "NVTE_MXFP8_1D_SCALING"; + case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING: + return "NVTE_FWD_NVFP4_BWD_MXFP8_SCALING"; case NVTE_INVALID_SCALING: return "NVTE_INVALID_SCALING"; } @@ -85,10 +92,13 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { t.columnwise_scale_inv.shape, ")"); } } else { - if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) { + if (t.scaling_mode == NVTE_MXFP8_1D_SCALING 
|| + t.scaling_mode == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) { // Need (4, 128) alignment even for e8 scaling factor auto block_alignment = std::vector{128ul, 4ul}; size_t expected_x, expected_y, alignment; + const size_t block_size_rowwise = (t.scaling_mode == NVTE_MXFP8_1D_SCALING) ? 32 : 16; + const size_t block_size_colwise = 32; if (t.has_data()) { alignment = block_alignment[0]; @@ -96,7 +106,8 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { DIVUP(DIVUP(t.flat_first_dim(), static_cast(1)), alignment) * alignment; alignment = block_alignment[1]; expected_y = - DIVUP(DIVUP(t.flat_last_dim(), static_cast(32)), alignment) * alignment; + DIVUP(DIVUP(t.flat_last_dim(), static_cast(block_size_rowwise)), alignment) * + alignment; const auto &expected = std::vector{expected_x, expected_y}; NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name, "\" has invalid scale_inv shape (expected ", expected, ", got ", @@ -105,7 +116,8 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { if (t.has_columnwise_data()) { alignment = block_alignment[1]; expected_x = - DIVUP(DIVUP(t.flat_first_dim(), static_cast(32)), alignment) * alignment; + DIVUP(DIVUP(t.flat_first_dim(), static_cast(block_size_colwise)), alignment) * + alignment; alignment = block_alignment[0]; expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast(1)), alignment) * alignment; const auto &expected = std::vector{expected_x, expected_y}; @@ -384,10 +396,24 @@ size_t nvte_tensor_numel(const NVTETensor tensor) { return numel; } +size_t nvte_tensor_element_size_bits(const NVTETensor tensor) { + auto *t = transformer_engine::convertNVTETensor(tensor); + if (t == nullptr) return 8 * sizeof(float); + return transformer_engine::typeToNumBits(t->dtype()); +} + size_t nvte_tensor_element_size(const NVTETensor tensor) { auto *t = transformer_engine::convertNVTETensor(tensor); if (t == nullptr) return sizeof(float); - return transformer_engine::typeToSize(t->dtype()); + 
NVTE_CHECK(!is_fp4_dtype(t->dtype()), + "For FP4 type please use the nvte_tensor_element_size_bits."); + return nvte_tensor_element_size_bits(tensor) / 8; +} + +size_t nvte_tensor_size_bytes(const NVTETensor tensor) { + auto *t = transformer_engine::convertNVTETensor(tensor); + if (t == nullptr) return 0; + return (nvte_tensor_numel(tensor) * nvte_tensor_element_size_bits(tensor)) / 8; } void *nvte_tensor_data(const NVTETensor tensor) { @@ -514,7 +540,7 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) { const auto &t = *transformer_engine::convertNVTETensorCheck(tensor); // Zero out tensor data if allocated if (t.data.dptr != nullptr) { - size_t size_in_bytes = nvte_tensor_element_size(tensor) * nvte_tensor_numel(tensor); + const size_t size_in_bytes = nvte_tensor_size_bytes(tensor); cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream); } // Set amax to 0 if allocated diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu index f81fbd3213..ca48a055a7 100644 --- a/transformer_engine/common/transpose/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu @@ -192,17 +192,18 @@ void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, / workspace->data.dtype = DType::kFloat32; } else { // Check that workspace matches expected size - const size_t workspace_size = + const size_t workspace_size = get_buffer_size_bytes( std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1, - std::multiplies()) * - typeToSize(workspace->data.dtype); - const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32); + std::multiplies()), + workspace->data.dtype); + const size_t required_size = + get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32); NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (", num_rows_partial_dbias, 
",", row_length, "), found ())"); NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(", num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32), "; found dims=", workspace->data.shape, - ", dtype=", typeToSize(workspace->data.dtype), ")"); + ", dtype=", typeToNumBits(workspace->data.dtype), " bits)"); } } diff --git a/transformer_engine/common/transpose/multi_cast_transpose.cu b/transformer_engine/common/transpose/multi_cast_transpose.cu index a46e34688a..2be365465b 100644 --- a/transformer_engine/common/transpose/multi_cast_transpose.cu +++ b/transformer_engine/common/transpose/multi_cast_transpose.cu @@ -237,8 +237,8 @@ void multi_cast_transpose(const std::vector input_list, std::vectordata.dtype = DType::kFloat32; } else { // Check that workspace matches expected size - const size_t workspace_size = + const size_t workspace_size = get_buffer_size_bytes( std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1, - std::multiplies()) * - typeToSize(workspace->data.dtype); - const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32); + std::multiplies()), + workspace->data.dtype); + const size_t required_size = + get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32); NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (", num_rows_partial_dbias, ",", row_length, "), found ())"); NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(", num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32), "; found dims=", workspace->data.shape, - ", dtype=", typeToSize(workspace->data.dtype), ")"); + ", dtype=", typeToNumBits(workspace->data.dtype), " bits)"); } } diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index d85291666e..e2d9ecc519 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh 
+++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -754,19 +754,20 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu if constexpr (IS_DGATED) { create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, - cols, 0, sizeof(IType)); + cols, 0, typeToNumBits(gated_input.dtype())); } const uint32_t tensor_stride_elems = output_cols; create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols * 2, 0, sizeof(IType)); + SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype())); create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols * 2, cols, sizeof(IType)); + SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype())); create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, tensor_stride_elems, 0, sizeof(OType)); + SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype())); create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, tensor_stride_elems, cols, sizeof(OType)); + SHMEM_DIM_X, tensor_stride_elems, cols, + typeToNumBits(output->dtype())); const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; const size_t buff_size_aligned_in = @@ -849,31 +850,33 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out if constexpr (IS_DGATED) { create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols, 0, sizeof(IType)); + SHMEM_DIM_X, cols, 0, typeToNumBits(gated_input.dtype())); } const uint32_t tensor_stride_elems = output_cols; create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0, sizeof(IType)); + SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0, + typeToNumBits(gated_input.dtype())); create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, - SHMEM_DIM_Y, 
SHMEM_DIM_X, cols * 2, cols, sizeof(IType)); + SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols, + typeToNumBits(gated_input.dtype())); if (USE_ROWWISE_SCALING) { create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, 0, - sizeof(OType)); + typeToNumBits(output->dtype())); create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, cols, - sizeof(OType)); + typeToNumBits(output->dtype())); } if (USE_COLWISE_SCALING) { create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, - 0, sizeof(OType)); + 0, typeToNumBits(output->dtype())); create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, - cols, sizeof(OType)); + cols, typeToNumBits(output->dtype())); } const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 14695b095b..610cbf41fa 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -895,15 +895,15 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T alignas(64) CUtensorMap tensor_map_output{}; create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); + FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); if constexpr (IS_DACT) { create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); + FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); } create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, sizeof(OType)); + FP8_SHMEM_DIM_X, cols, 0, 
typeToNumBits(output->data.dtype)); cast_fp8_2D_kernel <<>>(tensor_map_input, tensor_map_act_input, tensor_map_output, @@ -991,24 +991,24 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, alignas(64) CUtensorMap tensor_map_output_colwise{}; create_2D_tensor_map(tensor_map_input, input.data, rows, cols, MXFP8_SHMEM_DIM_Y, - MXFP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); + MXFP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype())); if constexpr (IS_DACT) { create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, - sizeof(IType)); + typeToNumBits(input.dtype())); } if (use_rowwise_scaling) { create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, - sizeof(OType)); + typeToNumBits(output->dtype())); } if (use_colwise_scaling) { create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, - sizeof(OType)); + typeToNumBits(output->dtype())); } cast_mxfp8_2D_kernelflat_last_dim(); constexpr int TMA_bytes = 16; - const int alignment_requirement = TMA_bytes / typeToSize(t->dtype()); + const int alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype()); return cols % alignment_requirement == 0; } diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index 967a0df3aa..e716065abd 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -319,9 +319,9 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s alignas(64) CUtensorMap tensor_map_output{}; create_2D_tensor_map(tensor_map_input, input_data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols, 0, sizeof(IType)); + SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype())); create_2D_tensor_map(tensor_map_output, output->data, rows, cols, 
SHMEM_DIM_Y, - SHMEM_DIM_X, cols, 0, sizeof(OType)); + SHMEM_DIM_X, cols, 0, typeToNumBits(output->dtype())); dequantize_mxfp8_kernel <<>>(tensor_map_input, tensor_map_output, scales_ptr, diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu index 51bb006021..df11ddd3f6 100644 --- a/transformer_engine/common/util/padding.cu +++ b/transformer_engine/common/util/padding.cu @@ -155,8 +155,8 @@ void multi_padding(const std::vector input_list, std::vector o // Input matrices are divided into tiles // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles - const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size / typeToSize(type); - const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size / typeToSize(type); + const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size * 8 / typeToNumBits(type); + const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size * 8 / typeToNumBits(type); // Add tensors to kernel argument struct MultiPaddingArgs kernel_args; diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 227b3aaa48..e6a54108ed 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -11,6 +11,10 @@ #include #include +#if CUDA_VERSION >= 12080 +#include +#endif + #if !defined(__CUDACC_RTC__) #include #else diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index f5698855ee..1dcb4e4e45 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -221,21 +221,23 @@ std::vector getTensorShape(at::Tensor t); transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); -inline size_t typeToSize(transformer_engine::DType t) { +inline size_t typeToNumBits(transformer_engine::DType t) { switch (t) { case transformer_engine::DType::kInt64: - return 8; + return 64; case 
transformer_engine::DType::kInt32: case transformer_engine::DType::kFloat32: - return 4; + return 32; case transformer_engine::DType::kInt16: case transformer_engine::DType::kFloat16: case transformer_engine::DType::kBFloat16: - return 2; + return 16; case transformer_engine::DType::kByte: case transformer_engine::DType::kFloat8E4M3: case transformer_engine::DType::kFloat8E5M2: - return 1; + return 8; + case transformer_engine::DType::kFloat4E2M1: + return 4; default: NVTE_ERROR("Invalid type"); } diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 9d6a99e6a9..55a1ec169a 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -24,12 +24,12 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s NVTE_CHECK(fcd_size % block_size == 0, "input size not aligned to block size"); - size_t element_size = transformer_engine::pytorch::typeToSize(self.dtype()); + size_t element_size_bits = transformer_engine::pytorch::typeToNumBits(self.dtype()); int32_t start_row = start_index.data_ptr()[0]; void *base_ptr = static_cast(self.get_rowwise_data().data_ptr) + - static_cast(start_row) * fcd_size * element_size; + static_cast(start_row) * fcd_size * element_size_bits / 8; size_t num_rows_to_zero = max_tokens - start_row; - size_t total_bytes = num_rows_to_zero * fcd_size * element_size; + size_t total_bytes = num_rows_to_zero * fcd_size * element_size_bits / 8; NVTE_SCOPED_GIL_RELEASE( { nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream()); }); From 71c76b6ba2055c42b0121e6d3f0f34eedd5f7988 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Sat, 14 Jun 2025 01:45:54 +0800 Subject: [PATCH 20/39] Add support for head_dim > 128 (#1797) * add support for head dim > 128 Signed-off-by: Charlene Yang * remove debugging Signed-off-by: 
Charlene Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * raise tols slightly to tolerate 1/2048 mismatches Signed-off-by: Charlene Yang * fix is_training for test_te_layer Signed-off-by: Charlene Yang * add bprop support for blackwell Signed-off-by: Charlene Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor tweak for format Signed-off-by: Charlene Yang * fix backend selection results Signed-off-by: Charlene Yang * bump sm100 to sm100+ Signed-off-by: Charlene Yang * add sq=1 test for MLA Signed-off-by: Charlene Yang * enable sq=1 for bprop Signed-off-by: Charlene Yang * minor tweak in comments Signed-off-by: Charlene Yang * fix head_dim logic and remove pytest skip Signed-off-by: Charlene Yang * add FE fix for d>128 Signed-off-by: Charlene Yang * update FE again to take in small fixes Signed-off-by: Charlene Yang * add cuDNN version info in L0 tests Signed-off-by: Charlene Yang * increase tols for Unfused + large dim Signed-off-by: Charlene Yang * Revert "add cuDNN version info in L0 tests" This reverts commit 3e1b426ca5319a2c0540b9e73bba7047d0e583e5. 
Signed-off-by: Charlene Yang * fix tols for Unfused Signed-off-by: Charlene Yang --------- Signed-off-by: Charlene Yang Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- tests/jax/test_distributed_fused_attn.py | 3 + tests/jax/test_fused_attn.py | 1 + tests/pytorch/fused_attn/test_fused_attn.py | 74 ++++++++++++++++--- tests/pytorch/fused_attn/test_kv_cache.py | 28 +++++-- .../common/fused_attn/fused_attn.cpp | 65 +++++++++------- .../include/transformer_engine/fused_attn.h | 9 ++- transformer_engine/jax/attention.py | 2 + .../jax/cpp_extensions/attention.py | 3 + transformer_engine/jax/csrc/extensions.h | 2 +- .../jax/csrc/extensions/attention.cpp | 20 ++--- transformer_engine/jax/flax/transformer.py | 2 + .../attention/dot_product_attention/utils.py | 1 + transformer_engine/pytorch/csrc/extensions.h | 12 ++- .../pytorch/csrc/extensions/attention.cpp | 14 ++-- 15 files changed, 166 insertions(+), 72 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 724f0ec8ce..f937055efc 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 724f0ec8ce06027feada51f2d948cd3313e63720 +Subproject commit f937055efc6d414d11f4c6577e3977fe74f35fb6 diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index adef8dd627..afb3a1df0c 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -68,6 +68,7 @@ def impl_test_self_attn( batch, seqlen, num_head, hidden = data_shape if not is_fused_attn_kernel_available( + is_training, dtype, dtype, QKVLayout.BS3HD, @@ -214,6 +215,7 @@ def test_cross_attn( batch, seqlen, num_head, hidden = data_shape if not is_fused_attn_kernel_available( + is_training, dtype, dtype, QKVLayout.BSHD_BS2HD, @@ -346,6 +348,7 @@ def impl_test_context_parallel_attn( 
def check_has_backend_for_mask(mask_type): return is_fused_attn_kernel_available( + is_training, dtype, dtype, qkv_layout, diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 745f1cc633..2332bbc0de 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -347,6 +347,7 @@ def _check_configs(self): ) self.backend = FusedAttnHelper( + self.is_training, self.dtype, self.dtype, self.qkv_layout, diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index b82665911a..a05e64fca3 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -222,13 +222,19 @@ def test(): model_configs_base = { - # test: b, h, hg, d, sq, skv, p, mask, bias # attn , backend - "base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0 - "base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0 - "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1 - "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1 - "base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), # inference - "base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), # inference + # test: b, h, hg, d, sq, skv, p, mask, bias + "base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), + "base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), + "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), + "base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), + "base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), + "base_4_0": ModelConfig(8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias"), + "base_4_1": ModelConfig(8, 16, 16, 192, 
128, 2048, 0.0, "no_mask", "no_bias"), + "base_5_0": ModelConfig(8, 16, 16, 512, 1, 2048, 0.0, "no_mask", "no_bias"), + "base_5_1": ModelConfig(8, 16, 16, 512, 128, 2048, 0.0, "no_mask", "no_bias"), + "base_6_0": ModelConfig(8, 16, 16, 1024, 1, 2048, 0.0, "no_mask", "no_bias"), + "base_6_1": ModelConfig(8, 16, 16, 1024, 128, 2048, 0.0, "no_mask", "no_bias"), } @@ -270,14 +276,28 @@ def test_dot_product_attention( if config.window_size == (-1, -1) and swa: config.window_size = [2, 2] config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) + + is_training = True available_backends, _, fused_attn_backends = _get_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, window_size=config.window_size, pad_between_seqs=pad_between_seqs, + is_training=is_training, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if not fused_attn_supported: + is_training = False + available_backends, _, fused_attn_backends = _get_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + window_size=config.window_size, + pad_between_seqs=pad_between_seqs, + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention # mannually pads and unpads the input and output of FlashAttention for testing purposes @@ -296,7 +316,6 @@ def test_dot_product_attention( if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: pytest.skip("Less than two backends to compare.") - is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128 # UnfusedDotProductAttention backend if unfused_attn_supported: unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention( @@ -360,6 +379,7 @@ def test_dot_product_attention( is_training, ) + logging.info(f"[test_dot_product_attention]: is_training = {is_training}") if unfused_attn_supported and 
flash_attn_supported: logging.info("[test_dot_product_attention]: unfused attn vs flash attn") torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols) @@ -399,18 +419,27 @@ def test_dpa_checkpoint(dtype, model_configs, model): "mla_1_1": ModelConfig( 4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128 ), # cross, 0 + "mla_1_2": ModelConfig( + 4, 16, 16, 192, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128 + ), # cross, 0 "mla_2_0": ModelConfig( 2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64 ), # self , 1 "mla_2_1": ModelConfig( 1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64 ), # cross, 1 + "mla_2_2": ModelConfig( + 1, 24, 24, 192, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=128 + ), # cross, 1 "mla_3_0": ModelConfig( 8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64 ), # inference "mla_3_1": ModelConfig( 8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128 ), # inference + "mla_3_2": ModelConfig( + 8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128 + ), # inference } @@ -1024,6 +1053,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: layer_number=1, attention_type=config.attn_type, ).to(dtype=dtype, device="cuda") + if not is_training: + block = block.eval() # Run a forward and backward pass if backend in ["FlashAttention", "UnfusedDotProductAttention"]: @@ -1136,14 +1167,29 @@ def test_transformer_layer( workspace_opt = True # Test backend availability + is_training = True available_backends, _, fused_attn_backends = _get_attention_backends( config, qkv_dtype=dtype, qkv_layout=( qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd") ), + is_training=is_training, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if not fused_attn_supported: + is_training = False + available_backends, _, fused_attn_backends = _get_attention_backends( + config, + 
qkv_dtype=dtype, + qkv_layout=( + qkv_format.replace("hd", "h3d") + if fused_qkv_params + else qkv_format.replace("hd", "3hd") + ), + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends # Skip if only unfused backend is supported if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: @@ -1163,6 +1209,7 @@ def test_transformer_layer( workspace_opt, fused_qkv_params, RoPE, + is_training, ) # FusedAttention backend @@ -1176,6 +1223,7 @@ def test_transformer_layer( workspace_opt, fused_qkv_params, RoPE, + is_training, ) # FlashAttention backend @@ -1189,8 +1237,10 @@ def test_transformer_layer( workspace_opt, fused_qkv_params, RoPE, + is_training, ) + logging.info(f"[test_transformer_layer]: is_training = {is_training}") if unfused_attn_supported and fused_attn_supported: logging.info("[test_transformer_layer]: unfused attn vs fused attn") torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) @@ -1257,6 +1307,7 @@ def _run_transformer_layer( workspace_opt: bool, fused_qkv_params: bool, RoPE: bool, + is_training: bool, ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """Run TransformerLayer module with one forward pass and one backward pass""" @@ -1410,6 +1461,8 @@ def _run_transformer_layer( bias=True, attn_input_format=qkv_format, ).to(dtype=dtype, device="cuda") + if not is_training: + block = block.eval() # Create ALiBi slopes alibi_slopes = None @@ -1432,8 +1485,9 @@ def _run_transformer_layer( cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, ) - loss = out.sum() - loss.backward() + if is_training: + loss = out.sum() + loss.backward() return out, inp.grad diff --git a/tests/pytorch/fused_attn/test_kv_cache.py b/tests/pytorch/fused_attn/test_kv_cache.py index eb3838ff12..9673094597 100644 --- a/tests/pytorch/fused_attn/test_kv_cache.py +++ b/tests/pytorch/fused_attn/test_kv_cache.py @@ -52,7 +52,7 @@ 4, 16, 16, 128, 64, 64, 
0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16 ), "infer_1": ModelConfig( - 2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16 + 2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16 ), } @@ -370,12 +370,24 @@ def generate_args( ] -def get_tols(module, backend, dtype): +def get_tols(config, module, backend, dtype): if module == "TransformerLayer": - tols = { - torch.half: (5e-3, 5e-3), - torch.bfloat16: (3.5e-2, 3.5e-2), - } + if config.head_dim_qk <= 128: + tols = { + torch.half: (5e-3, 5e-3), + torch.bfloat16: (3.5e-2, 3.5e-2), + } + else: + if backend == "UnfusedAttention": + tols = { + torch.half: (1.6e-2, 1.6e-2), + torch.bfloat16: (1.2e-1, 1e-1), + } + else: + tols = { + torch.half: (1e-2, 1e-2), + torch.bfloat16: (8e-2, 7e-2), + } if module == "DotProductAttention": tols = { torch.half: (1e-3, 1e-3), @@ -662,7 +674,9 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g incremental_output = incremental_output[0] # compare results - atol, rtol = get_tols(module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn) + atol, rtol = get_tols( + config, module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn + ) for i, seq in enumerate(sim.t_seq_ids): token_index = sim.step_lens[i] - 1 if qkv_format == "bshd": diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index a6784bacbb..b512133efd 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -134,10 +134,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, - 
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - int64_t window_size_left, int64_t window_size_right) { + bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, + size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { using namespace transformer_engine; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; const int device_id = cuda::current_device(); @@ -216,12 +216,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } if ( // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging - // special conditions for blackwell - // TODO: enable THD max_t in f16_arbitrary_seqlen when support becomes available in 9.7 - !(sm_arch_ >= 100 && (head_dim_qk > 128 || head_dim_v > 128)) && // architecture - ((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) || - (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) && + ((cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90)) || + (cudnn_runtime_version >= 8903 && sm_arch_ >= 80 && sm_arch_ < 100) || + (cudnn_runtime_version >= 90700 && sm_arch_ >= 80)) && // sequence length ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) || (cudnn_runtime_version >= 90000)) && @@ -229,11 +227,28 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) || (cudnn_runtime_version >= 8907)) && // head dimension - ((head_dim_qk <= 128 && head_dim_qk % 8 == 0 && head_dim_v <= 128 && head_dim_v % 8 == 0) || - // TODO (cyang): add is_training to nvte_get_fused_attn_backend - // d=256 only supported for forward - (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim_qk <= 256 && - 
head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) && + // multiples of 8 + (head_dim_qk % 8 == 0 && head_dim_v % 8 == 0 && + // <= 128 + ((head_dim_qk <= 128 && head_dim_v <= 128) || + // 9.1: <= 256 + Hopper + fprop + // 9.5: <= 256 + Hopper + bprop + (head_dim_qk <= 256 && head_dim_v <= 256 && + ((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) || + (is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) || + // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1 + (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 && + layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) || + // 9.10: any head_dim + any arch + fprop + paged + // 9.10: any head_dim + any arch + fprop + non_paged + sq > 1 + // 9.10: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM} + (!is_training && cudnn_runtime_version >= 91000 && + (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 || + (max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) || + // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged + (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 && + cudnn_runtime_version >= 91100))) && // bias type ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || (cudnn_runtime_version >= 8906 && @@ -423,8 +438,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, - max_seqlen, d, d, window_size_left, window_size_right); + is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, + max_seqlen, 
max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -505,7 +520,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, + true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { @@ -636,8 +651,8 @@ void nvte_fused_attn_fwd_kvpacked( const NVTEDType KV_type = static_cast(input_KV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, d, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, + max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -731,8 +746,8 @@ void nvte_fused_attn_bwd_kvpacked( const NVTEDType KV_type = static_cast(input_KV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, d, window_size_left, window_size_right); + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, + max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -862,8 +877,8 @@ void 
nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTEDType KV_type = static_cast(input_K->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, + max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -954,8 +969,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTEDType KV_type = static_cast(input_K->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, + max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index ebe8341cca..44f5791490 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -172,6 +172,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); /*! \brief Get fused attention backend based on input parameters. * + * \param[in] is_training Whether the model is in training mode. * \param[in] q_dtype The data type of Tensor Q. * \param[in] kv_dtype The data type of Tensors K, V. 
* \param[in] qkv_layout The layout of Tensors Q, K, V. @@ -188,10 +189,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * \param[in] window_size_right Sliding window size (the right half). */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - int64_t window_size_left, int64_t window_size_right); + bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, + size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); /*! \brief Compute dot product attention with packed QKV input. 
* diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 2c57d284de..d24e853e1c 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -277,6 +277,7 @@ def canonicalize_attn_mask_type(attn_mask_type: str): def is_fused_attn_kernel_available( + is_training, q_dtype, kv_dtype, qkv_layout, @@ -296,6 +297,7 @@ def is_fused_attn_kernel_available( def make_helper(attn_mask_type): return tex.FusedAttnHelper( + is_training, q_dtype, kv_dtype, qkv_layout, diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 6b617355a3..e8907eb127 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -103,6 +103,7 @@ class FusedAttnHelper: Helper for the fused attention backend """ + is_training: bool q_dtype: jnp.dtype kv_dtype: jnp.dtype qkv_layout: QKVLayout @@ -123,6 +124,7 @@ def is_fused_attn_kernel_available(self): def get_fused_attn_backend(self): """Get the fused attention kernel backend""" return transformer_engine_jax.get_fused_attn_backend( + self.is_training, jax_dtype_to_te_dtype(self.q_dtype), jax_dtype_to_te_dtype(self.kv_dtype), self.qkv_layout.value, @@ -276,6 +278,7 @@ def abstract( # backend determines the softmax buffer shape/dtype backend = FusedAttnHelper( + config.is_training, q_dtype, k_dtype, config.qkv_layout, diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index aa257abe95..47399bc791 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -96,7 +96,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); -NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, +NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype, NVTE_QKV_Layout 
qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float dropout_probability, size_t q_num_heads, size_t kv_num_heads, diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 7235d3f232..d3bb845642 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -11,7 +11,7 @@ namespace transformer_engine { namespace jax { -NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, +NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, @@ -19,9 +19,9 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, size_t head_dim, int64_t window_size_left, int64_t window_size_right) { auto backend = nvte_get_fused_attn_backend( - static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, - mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, - head_dim, head_dim, window_size_left, window_size_right); + is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, + bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, + kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right); return backend; } @@ -263,9 +263,9 @@ static void FusedAttnForwardImpl( /* Prepare RNG state */ auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); auto backend = nvte_get_fused_attn_backend( - static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, - mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, - head_dim, head_dim, window_size_left, window_size_right); + is_training, static_cast(dtype), static_cast(dtype), qkv_layout, + bias_type, mask_type, 
dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, + kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -518,9 +518,9 @@ static void FusedAttnBackwardImpl( NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); auto backend = nvte_get_fused_attn_backend( - static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, - mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, - head_dim, head_dim, window_size_left, window_size_right); + is_training, static_cast(dtype), static_cast(dtype), qkv_layout, + bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, + kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias); diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 1ac22a6d2f..e9a3047742 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -596,6 +596,8 @@ def __call__( seqlen_kv = key.shape[sequence_dim] has_fused_attn_kernel = is_fused_attn_kernel_available( + # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode. 
+ not deterministic, self.dtype, self.dtype, qkv_layout, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index ec93e8c5c8..d98dde0159 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -761,6 +761,7 @@ def get_attention_backend( q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) kv_type = q_type fused_attention_backend = tex.get_fused_attn_backend( + is_training, q_type, kv_type, QKVLayout[qkv_layout], diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index d4218a08be..72f6f27596 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -35,13 +35,11 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T * Attention **************************************************************************************************/ -NVTE_Fused_Attn_Backend get_fused_attn_backend(const DType q_dtype, const DType kv_dtype, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, float p_dropout, - size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, - size_t head_dim_qk, size_t head_dim_v, - int64_t window_size_left, int64_t window_size_right); +NVTE_Fused_Attn_Backend get_fused_attn_backend( + bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, + size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, diff --git 
a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 55a1ec169a..71a8062b1a 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -57,14 +57,14 @@ namespace transformer_engine::pytorch { // get the fused attention backend NVTE_Fused_Attn_Backend get_fused_attn_backend( - const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - int64_t window_size_left, int64_t window_size_right) { + bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, + size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, - attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, - head_dim_qk, head_dim_v, window_size_left, window_size_right); + is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, + bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, + max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right); return fused_attention_backend; } From 1ddfa0c6a468be6cfe02f2cb051110c94ba17a61 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Fri, 13 Jun 2025 15:14:06 -0700 Subject: [PATCH 21/39] [JAX] Add support for Fused Attn MLA head_dim_qk != head_dim_v (#1851) * Add support for Fused Attn MLA 
head_dim_qk != head_dim_v Modify is_fused_attn_kernel_available() to accept different head_dims for qk and v Modify FusedAttnHelper to accept different head_dims for qk and v and modify assert dims checks in parse_qkv_aval() Modify FusedAttnFwdPrimitive and FusedAttnBwdPrimitive to accept different head_dims for qk and v Modify Fused Attn related cpp and csrc extension API calls to accept different head_dims for qk and v Modify DotProductAttention call() to extract head dims separately for qk and v Modify the FusedAttn Tests to accommodate for API changes in FusedAttn API Add test case for head_dim_qk != head_dim_v (failing) Modify the baseline JAX appropriately to reshape the output vector based on v dims and not q dims Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix context dims in general DPA in test_fused_attn Signed-off-by: Kshitij Janardan Lakhani * Fix dim for output tensor by replacing with v head dim rather than q head dim Add test cases for jax fused attn where head_dim_qk != head_dim_v for a combination of data types and attention type Signed-off-by: Kshitij Janardan Lakhani * Modify the fused attn jax unit test case for head dim qk != head dim v Signed-off-by: Kshitij Janardan Lakhani * Use new FusedAttnRunner function signature for separate hidden dim for qk and v in Fused Attn distributed tests Code clean up Signed-off-by: Kshitij Janardan Lakhani * Fix usage of is_fused_attn signature in distributed tests Signed-off-by: Kshitij Janardan Lakhani * Remove unnecessary assert Signed-off-by: Kshitij Janardan Lakhani --------- Signed-off-by: Kshitij Janardan Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/jax/test_distributed_fused_attn.py | 6 + tests/jax/test_fused_attn.py | 75 +++++++++--- transformer_engine/jax/attention.py | 10 +- .../jax/cpp_extensions/attention.py | 115 +++++++++++++----- 
transformer_engine/jax/csrc/extensions.h | 12 +- .../jax/csrc/extensions/attention.cpp | 99 +++++++-------- transformer_engine/jax/flax/transformer.py | 13 +- 7 files changed, 220 insertions(+), 110 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index afb3a1df0c..e88108155e 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -80,6 +80,7 @@ def impl_test_self_attn( seqlen, seqlen, hidden, + hidden, None, # no window ): pytest.skip("No FusedAttn backend found") @@ -99,6 +100,7 @@ def impl_test_self_attn( num_head, num_head, hidden, + hidden, attn_bias_type, attn_mask_type, dropout_prob, @@ -227,6 +229,7 @@ def test_cross_attn( seqlen, seqlen, hidden, + hidden, None, # no window ): pytest.skip("No FusedAttn backend found") @@ -239,6 +242,7 @@ def test_cross_attn( num_head, num_head, hidden, + hidden, attn_bias_type, attn_mask_type, dropout_prob, @@ -329,6 +333,7 @@ def impl_test_context_parallel_attn( num_head, num_kv_heads, hidden, + hidden, attn_bias_type, attn_mask_type, dropout_prob, @@ -360,6 +365,7 @@ def check_has_backend_for_mask(mask_type): seqlen, seqlen, hidden, + hidden, None, ) # no SWA for CP diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 2332bbc0de..f9e5c8ad2e 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -106,7 +106,8 @@ def general_dot_product_attention( softmax_out = softmax_out * multiplier context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value) - context = jnp.reshape(context, query.shape) + context_shape = query.shape[:-1] + (value.shape[-1],) + context = jnp.reshape(context, context_shape) return context @@ -294,7 +295,8 @@ class FusedAttnRunner: max_seqlen_kv: int num_heads_q: int num_heads_kv: int - head_dim: int + head_dim_qk: int + head_dim_v: int attn_bias_type: AttnBiasType attn_mask_type: AttnMaskType dropout_prob: float @@ -346,6 +348,14 @@ def 
_check_configs(self): "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN" ) + # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors + # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats + if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate(): + pytest.skip( + "For head_dim_qk != head_dim_v, it is necessary that the QKV layout " + "is either BSHD_BSHD_BSHD or THD_THD_THD" + ) + self.backend = FusedAttnHelper( self.is_training, self.dtype, @@ -358,7 +368,8 @@ def _check_configs(self): self.num_heads_kv, self.max_seqlen_q, self.max_seqlen_kv, - self.head_dim, + self.head_dim_qk, + self.head_dim_v, (-1, -1) if self.window_size is None else self.window_size, ).get_fused_attn_backend() if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: @@ -391,13 +402,9 @@ def _setup_inputs(self): key = jax.random.PRNGKey(0) q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5) - q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim) - k_shape = v_shape = ( - self.batch_size, - self.max_seqlen_kv, - self.num_heads_kv, - self.head_dim, - ) + q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk) + k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk) + v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_v) if self.attn_bias_type == AttnBiasType.NO_BIAS: bias_shape = None @@ -616,7 +623,7 @@ def generate_random_segment_ids( raise ValueError(f"Unknown {self.seq_desc_format=}") self.dropout_rng = dropout_key if self.dropout_prob > 0 else None - self.scaling_factor = 1.0 / sqrt(self.head_dim) + self.scaling_factor = 1.0 / sqrt(self.head_dim_qk) # Setup distributed sharding specs # Setup shardings for distributed tests @@ -935,9 +942,31 @@ def check_dqkv(primitive, reference, pad, idx): ], ) @pytest.mark.parametrize( - "b, s_q, s_kv, h_q, h_kv, d, dtype", + "b, s_q, 
s_kv, h_q, h_kv, d_qk, d_v, dtype", [ - pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"), + pytest.param( + 2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-64-BF16-SELF" + ), + pytest.param( + 2, + 2048, + 1024, + 12, + 12, + 64, + 64, + jnp.bfloat16, + id="2-2048-1024-12-12-64-64-BF16-CROSS", + ), + pytest.param( + 2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA" + ), + pytest.param( + 4, 128, 128, 16, 16, 64, 64, jnp.float16, id="4-128-128-16-16-64-64-FP16-SELF" + ), + pytest.param( + 4, 128, 128, 16, 16, 64, 32, jnp.float16, id="4-128-128-16-16-64-32-FP16-SELF" + ), pytest.param( 2, 2048, @@ -945,11 +974,13 @@ def check_dqkv(primitive, reference, pad, idx): 12, 12, 64, + 32, jnp.bfloat16, - id="2-2048-1024-12-12-64-BF16-CROSS", + id="2-2048-1024-12-12-64-32-BF16-CROSS", + ), + pytest.param( + 2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA" ), - pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"), - pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"), ], ) @pytest.mark.parametrize( @@ -1003,7 +1034,8 @@ def _test_forward( s_kv, h_q, h_kv, - d, + d_qk, + d_v, attn_bias_type, attn_mask_type, dropout_prob, @@ -1028,7 +1060,8 @@ def _test_forward( s_kv, h_q, h_kv, - d, + d_qk, + d_v, attn_bias_type, attn_mask_type, dropout_prob, @@ -1055,7 +1088,8 @@ def test_backward( s_kv, h_q, h_kv, - d, + d_qk, + d_v, attn_bias_type, attn_mask_type, dropout_prob, @@ -1077,7 +1111,8 @@ def test_backward( s_kv, h_q, h_kv, - d, + d_qk, + d_v, attn_bias_type, attn_mask_type, dropout_prob, diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index d24e853e1c..fe4109cee8 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -188,7 +188,7 @@ class ReorderStrategy(Enum): - DualChunkSwap: This strategy splits each query 
into two chunks and do the mirror swap between GPUs. This is currently used for non-THD load balance. It requires the max_seqlens be the - mulitple of 2 * cp_size. + multiple of 2 * cp_size. Examples: - Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; GPU2: [8, 9, 10, 11]; GPU3: [12, 13, 14, 15]; - After reorder: GPU0: [0, 1, 14, 15]; GPU1: [4, 5, 10, 11]; GPU2: [8, 9, 6, 7]; GPU3: [12, 13, 2, 3] @@ -288,7 +288,8 @@ def is_fused_attn_kernel_available( kv_num_heads, q_max_seqlen, kv_max_seqlen, - head_dim, + head_dim_qk, + head_dim_v, window_size: Optional[Tuple[int, int]] = None, ): """ @@ -308,7 +309,8 @@ def make_helper(attn_mask_type): kv_num_heads, q_max_seqlen, kv_max_seqlen, - head_dim, + head_dim_qk, + head_dim_v, (-1, -1) if window_size is None else window_size, ) @@ -491,7 +493,7 @@ def _segment_ids_to_seqlens(segment_ids_q, segment_ids_kv, attn_mask_type): @jax.tree_util.register_pytree_node_class class SequenceDescriptor: - """A class to descibe the sequences with flexible initialization. + """A class to describe the sequences with flexible initialization. - SequenceDescriptor.from_seqlens For non-THD (non-packed) cases, where each batch has only 1 sequence. 
- SequenceDescriptor.from_seqlens_and_offsets diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index e8907eb127..089ef75f1c 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -114,7 +114,8 @@ class FusedAttnHelper: kv_num_heads: int q_max_seqlen: int kv_max_seqlen: int - head_dim: int + head_dim_qk: int + head_dim_v: int window_size: Tuple[int, int] def is_fused_attn_kernel_available(self): @@ -135,7 +136,8 @@ def get_fused_attn_backend(self): self.kv_num_heads, self.q_max_seqlen, self.kv_max_seqlen, - self.head_dim, + self.head_dim_qk, + self.head_dim_v, self.window_size[0], self.window_size[1], ) @@ -155,23 +157,49 @@ def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout): kv_batch_shape = q_batch_shape kv_max_seqlen = q_max_seqlen num_gqa_groups = attn_heads - kv_head_dim = q_head_dim + v_head_dim = q_head_dim assert nqkv == 3 elif qkv_layout.is_kvpacked(): *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape - *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape + *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, v_head_dim = k_aval.shape + assert q_batch_shape == kv_batch_shape + assert q_head_dim == v_head_dim assert nkv == 2 elif qkv_layout.is_separate(): *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape - *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape - assert k_aval.shape == v_aval.shape, f"{k_aval.shape=} {v_aval.shape=}" + *k_batch_shape, k_max_seqlen, k_num_gqa_groups, k_head_dim = k_aval.shape + *v_batch_shape, v_max_seqlen, v_num_gqa_groups, v_head_dim = v_aval.shape + assert ( + q_head_dim == k_head_dim + ), f"Mismatched q_head_dim: {q_head_dim} and k_head_dim: {k_head_dim}" + assert ( + k_max_seqlen == v_max_seqlen + ), f"Mismatched k_max_seqlen: {k_max_seqlen} and v_max_seqlen: {v_max_seqlen}" + kv_max_seqlen = 
k_max_seqlen + assert q_batch_shape == k_batch_shape == v_batch_shape, ( + f"Mismatched qkv batch size for q_batch_shape: {q_batch_shape}, k_batch_shape:" + f" {k_batch_shape} and v_batch_shape: {v_batch_shape}" + ) + assert k_num_gqa_groups == v_num_gqa_groups, ( + f"Mismatched k_num_gqa_groups: {k_num_gqa_groups} and v_num_gqa_groups:" + f" {v_num_gqa_groups}" + ) + num_gqa_groups = k_num_gqa_groups else: raise ValueError(f"Unexpected {qkv_layout=}") - assert q_batch_shape == kv_batch_shape - assert q_head_dim == kv_head_dim - assert q_aval.dtype == k_aval.dtype == v_aval.dtype - - return (q_batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, q_head_dim) + assert q_aval.dtype == k_aval.dtype == v_aval.dtype, ( + f"Mismatched data types for q_aval: {q_aval.dtype}, k_aval: {k_aval.dtype}, v_aval:" + f" {v_aval.dtype}" + ) + return ( + q_batch_shape, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + q_head_dim, + v_head_dim, + ) @dataclass(frozen=True) @@ -269,11 +297,17 @@ def abstract( f" kv_seqlen_or_cu_seqlen_aval={kv_seqlen_or_cu_seqlen_aval}" ) - batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( - FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) - ) + ( + batch_shape, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + q_head_dim, + v_head_dim, + ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) - output_shape = (*batch_shape, q_max_seqlen, attn_heads, head_dim) + output_shape = (*batch_shape, q_max_seqlen, attn_heads, v_head_dim) out_aval = q_aval.update(shape=output_shape, dtype=q_dtype) # backend determines the softmax buffer shape/dtype @@ -289,7 +323,8 @@ def abstract( num_gqa_groups, q_max_seqlen, kv_max_seqlen, - head_dim, + q_head_dim, + v_head_dim, config.window_size, ).get_fused_attn_backend() @@ -340,7 +375,8 @@ def abstract( attn_heads, num_gqa_groups, bias_heads, - head_dim, + q_head_dim, + v_head_dim, 
config.scaling_factor, config.dropout_probability, config.attn_bias_type.value, @@ -392,9 +428,15 @@ def lowering( """ q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in - batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( - FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) - ) + ( + batch_shape, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + q_head_dim, + v_head_dim, + ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) input_batch = reduce(operator.mul, batch_shape) @@ -433,7 +475,8 @@ def lowering( attn_heads=attn_heads, num_gqa_groups=num_gqa_groups, bias_heads=bias_heads, - head_dim=head_dim, + qk_head_dim=q_head_dim, + v_head_dim=v_head_dim, max_segments_per_seq=config.max_segments_per_seq, scaling_factor=float(config.scaling_factor), dropout_probability=float(config.dropout_probability), @@ -711,9 +754,15 @@ def abstract( assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype - batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( - FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) - ) + ( + batch_shape, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + qk_head_dim, + v_head_dim, + ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) if config.attn_bias_type == AttnBiasType.NO_BIAS: bias_batch = bias_heads = 0 @@ -732,7 +781,8 @@ def abstract( attn_heads, num_gqa_groups, bias_heads, - head_dim, + qk_head_dim, + v_head_dim, config.scaling_factor, config.dropout_probability, config.attn_bias_type.value, @@ -791,9 +841,15 @@ def lowering( """ q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in - batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( - FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) - ) + ( + batch_shape, + 
q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + qk_head_dim, + v_head_dim, + ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) input_batch = reduce(operator.mul, batch_shape) @@ -835,7 +891,8 @@ def lowering( attn_heads=attn_heads, num_gqa_groups=num_gqa_groups, bias_heads=bias_heads, - head_dim=head_dim, + qk_head_dim=qk_head_dim, + v_head_dim=v_head_dim, max_segments_per_seq=config.max_segments_per_seq, scaling_factor=float(config.scaling_factor), dropout_probability=float(config.dropout_probability), diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 47399bc791..0789478348 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -101,20 +101,20 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy NVTE_Mask_Type mask_type, float dropout_probability, size_t q_num_heads, size_t kv_num_heads, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t head_dim, int64_t window_size_left, - int64_t window_size_right); + size_t qk_head_dim, size_t v_head_dim, + int64_t window_size_left, int64_t window_size_right); pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, + size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right); pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t attn_heads, size_t 
num_gqa_groups, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, + size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right); diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index d3bb845642..40089dc2d6 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -16,12 +16,12 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy NVTE_Mask_Type mask_type, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t head_dim, int64_t window_size_left, - int64_t window_size_right) { + size_t qk_head_dim, size_t v_head_dim, + int64_t window_size_left, int64_t window_size_right) { auto backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, - kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right); + kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); return backend; } @@ -117,24 +117,24 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + size_t attn_heads, size_t num_gqa_groups, 
size_t bias_heads, size_t qk_head_dim, + size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) { // For qkv_packed - auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; + auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); // For kv_packed - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; + auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim}; auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); // For separate q, k, v - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; + auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); - auto v_shape = k_shape; + auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; auto v_tensor = TensorWrapper(nullptr, v_shape, dtype); auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; @@ -237,17 +237,17 @@ static void FusedAttnForwardImpl( void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *output, void *softmax_aux, void *rng_state, void *workspace, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, - size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, - float dropout_probability, NVTE_Bias_Type 
bias_type, NVTE_Mask_Type mask_type, - NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, - bool deterministic, int64_t window_size_left, int64_t window_size_right) { + size_t qk_head_dim, size_t v_head_dim, size_t max_segments_per_seq, size_t wkspace_size, + float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, + bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) { FUSED_ATTN_IMPL_COMMON_BLOCK; /* Input tensors */ auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); if (is_ragged) { - auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim; + auto output_size = input_batch * q_max_seqlen * attn_heads * v_head_dim; cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream); // Memset to 0xF0 for filling large negative numbers @@ -257,7 +257,7 @@ static void FusedAttnForwardImpl( /* Output tensors */ auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in F16 - auto o_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto o_shape = std::vector{input_batch * q_max_seqlen, attn_heads, v_head_dim}; auto o_tensor = TensorWrapper(output, o_shape, dtype); /* Prepare RNG state */ @@ -265,7 +265,7 @@ static void FusedAttnForwardImpl( auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, - kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right); + kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -278,7 +278,7 @@ static void FusedAttnForwardImpl( /* Call the underlying NVTE API */ 
auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector{1}, DType::kInt32); if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; + auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), @@ -287,8 +287,9 @@ static void FusedAttnForwardImpl( qkv_layout, bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; + auto kv_shape = + std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim}; auto q_tensor = TensorWrapper(q, q_shape, dtype); auto kv_tensor = TensorWrapper(k, kv_shape, dtype); nvte_fused_attn_fwd_kvpacked( @@ -299,9 +300,9 @@ static void FusedAttnForwardImpl( is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; - auto v_shape = k_shape; + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; + auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; + auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; auto q_tensor = TensorWrapper(q, q_shape, dtype); 
auto k_tensor = TensorWrapper(k, k_shape, dtype); auto v_tensor = TensorWrapper(v, v_shape, dtype); @@ -327,7 +328,8 @@ static void FusedAttnForwardImpl( size_t attn_heads = get_attr_value(attrs, "attn_heads"); \ size_t num_gqa_groups = get_attr_value(attrs, "num_gqa_groups"); \ size_t bias_heads = get_attr_value(attrs, "bias_heads"); \ - size_t head_dim = get_attr_value(attrs, "head_dim"); \ + size_t qk_head_dim = get_attr_value(attrs, "qk_head_dim"); \ + size_t v_head_dim = get_attr_value(attrs, "v_head_dim"); \ size_t max_segments_per_seq = get_attr_value(attrs, "max_segments_per_seq"); \ auto window_size_left = get_attr_value(attrs, "window_size_left"); \ auto window_size_right = get_attr_value(attrs, "window_size_right"); \ @@ -362,9 +364,9 @@ Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Ty is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, output_buf->untyped_data(), softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, - head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, - mask_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic, window_size_left, - window_size_right); + qk_head_dim, v_head_dim, max_segments_per_seq, wkspace_size, scaling_factor, + dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype, is_training, + deterministic, window_size_left, window_size_right); return ffi_with_cuda_error_check(); } @@ -391,33 +393,33 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI, pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + size_t attn_heads, 
size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, + size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) { // For qkv_packed - auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; + auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); // For kv_packed - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; + auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim}; auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype); // For separate q, k, v - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; + auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype); - auto v_shape = k_shape; + auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; auto v_tensor = TensorWrapper(nullptr, v_shape, dtype); auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype); - auto output_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto output_shape = std::vector{input_batch * q_max_seqlen, attn_heads, v_head_dim}; auto doutput_tensor = 
TensorWrapper(nullptr, output_shape, dtype); auto output_tensor = TensorWrapper(nullptr, output_shape, dtype); @@ -498,15 +500,15 @@ static void FusedAttnBackwardImpl( void *output, void *doutput, void *q_cu_seqlens, void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *dq, void *dk, void *dv, void *dbias, void *workspace, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, + size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, + size_t v_head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) { FUSED_ATTN_IMPL_COMMON_BLOCK; /* Input tensors */ - auto output_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto output_shape = std::vector{input_batch * q_max_seqlen, attn_heads, v_head_dim}; auto output_tensor = TensorWrapper(output, output_shape, dtype); auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype); @@ -520,14 +522,14 @@ static void FusedAttnBackwardImpl( auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, - kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right); + kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias); /* Call the underly NVTE API */ if (layout_group == 
NVTE_QKV_Layout_Group::NVTE_3HD) { - auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; + auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype); if (is_ragged) { @@ -543,8 +545,9 @@ static void FusedAttnBackwardImpl( bias_type, mask_type, window_size_left, window_size_right, deterministic, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; + auto kv_shape = + std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim}; auto q_tensor = TensorWrapper(q, q_shape, dtype); auto kv_tensor = TensorWrapper(k, kv_shape, dtype); auto dq_tensor = TensorWrapper(dq, q_shape, dtype); @@ -564,9 +567,9 @@ static void FusedAttnBackwardImpl( dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; - auto v_shape = k_shape; + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; + auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; + auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; auto q_tensor = TensorWrapper(q, q_shape, dtype); auto k_tensor = TensorWrapper(k, k_shape, dtype); auto v_tensor = TensorWrapper(v, v_shape, dtype); @@ -614,9 +617,9 @@ Error_Type FusedAttnBackwardFFI(cudaStream_t stream, 
Buffer_Type q_buf, Buffer_T is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, dq_buf->untyped_data(), dk_buf->untyped_data(), dv_buf->untyped_data(), dbias_buf->untyped_data(), workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, - attn_heads, num_gqa_groups, bias_heads, head_dim, max_segments_per_seq, wkspace_size, - scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype, - is_training, deterministic, window_size_left, window_size_right); + attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim, max_segments_per_seq, + wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype, + wkspace_dtype, is_training, deterministic, window_size_left, window_size_right); return ffi_with_cuda_error_check(); } diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index e9a3047742..97ab519b9c 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -594,6 +594,12 @@ def __call__( seqlen_kv = seqlen_q else: seqlen_kv = key.shape[sequence_dim] + if qkv_layout.is_separate(): + head_dim_qk = query.shape[-1] + head_dim_v = value.shape[-1] + else: + head_dim_qk = self.head_dim + head_dim_v = self.head_dim has_fused_attn_kernel = is_fused_attn_kernel_available( # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode. 
@@ -608,7 +614,8 @@ def __call__( self.num_gqa_groups, seqlen_q, seqlen_kv, - self.head_dim, + head_dim_qk, + head_dim_v, self.window_size, ) @@ -621,7 +628,7 @@ def __call__( "Please try to update the cuDNN and TE to the latest version.\n" f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n" f"{self.attention_dropout=}\n{self.num_attention_heads=}\n" - f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{self.head_dim=}\n" + f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{head_dim_qk=}\n{head_dim_v=}\n" ) dropout_rng = None @@ -629,7 +636,7 @@ def __call__( dropout_rng = self.make_rng(self.dropout_rng_name) if self.scale_factor is None: - scale_factor = 1.0 / sqrt(self.head_dim) + scale_factor = 1.0 / sqrt(head_dim_qk) else: scale_factor = self.scale_factor del self.scale_factor From 980c4342406bcd3d9d8b15538b67b9d13468d2e5 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 13 Jun 2025 16:07:06 -0700 Subject: [PATCH 22/39] Changed VERSION to 2.5.0 Signed-off-by: Przemek Tredak --- build_tools/VERSION.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index ece4f82d95..437459cd94 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -2.5.0.dev0 +2.5.0 From efe19c3c3b2ec6b951cbe8c2816041eee2dc498c Mon Sep 17 00:00:00 2001 From: Hua Huang Date: Mon, 16 Jun 2025 08:33:37 -0700 Subject: [PATCH 23/39] [JAX] Grouped GEMM & Dense support MXFP8 and handle empty matrices (#1871) * Support MXFP8 and handle empty matrices Signed-off-by: Hua Huang --------- Signed-off-by: Hua Huang --- tests/jax/test_custom_call_compute.py | 9 +- .../include/transformer_engine/multi_stream.h | 22 ++ .../common/util/multi_stream.cpp | 8 + transformer_engine/jax/cpp_extensions/gemm.py | 17 +- .../jax/csrc/extensions/gemm.cpp | 200 ++++++++++++++---- transformer_engine/jax/dense.py | 10 +- 6 files changed, 204 insertions(+), 62 deletions(-) diff --git 
a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index f689bce6a5..54ceecdab6 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1250,6 +1250,9 @@ def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", wi group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])]) group_sizes = jnp.diff(group_sizes) + # Make one empty input lhs to test empty GEMM handling + group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1]) + group_sizes = group_sizes.at[1].set(0) assert group_sizes.sum() == m # *32 to make sure that input shape works for MXFP8 @@ -1301,9 +1304,6 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("layout", ["NN"]) def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout): - if scaling_mode == ScalingMode.MXFP8_1D_SCALING: - pytest.skip("MXFP8 is not supported in grouped_gemm yet") - fwd_dtype, bwd_dtype = fwd_bwd_dtype quantizer_set = QuantizerFactory.create_set( scaling_mode=scaling_mode, @@ -1388,9 +1388,6 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape): ) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): - if scaling_mode == ScalingMode.MXFP8_1D_SCALING: - pytest.skip("MXFP8 is not supported in grouped_dense yet") - fwd_dtype, bwd_dtype = fwd_bwd_dtype dtype = jnp.bfloat16 x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( diff --git a/transformer_engine/common/include/transformer_engine/multi_stream.h b/transformer_engine/common/include/transformer_engine/multi_stream.h index 6e0506100a..e406a07867 100644 --- 
a/transformer_engine/common/include/transformer_engine/multi_stream.h +++ b/transformer_engine/common/include/transformer_engine/multi_stream.h @@ -11,6 +11,8 @@ #ifndef TRANSFORMER_ENGINE_MULTI_STREAM_H #define TRANSFORMER_ENGINE_MULTI_STREAM_H +#include "cuda_runtime.h" + #ifdef __cplusplus extern "C" { #endif @@ -18,6 +20,26 @@ extern "C" { /*! \brief Number of CUDA streams to use in multi-stream operations */ int nvte_get_num_compute_streams(); +/*! \brief Get a CUDA stream for compute operations. + * + * \param[in] idx Index of the stream to retrieve. + * \return A cudaStream_t. + * + * This function returns a CUDA stream that can be used for compute operations. + * The index should be in the range [0, nvte_get_num_compute_streams() - 1]. + */ +cudaStream_t nvte_get_compute_stream(const int idx); + +/*! \brief Get a CUDA event for compute operations. + * + * \param[in] idx Index of the event to retrieve. + * \return A cudaEvent_t. + * + * This function returns a CUDA event that can be used to synchronize compute operations. + * The index should be in the range [0, nvte_get_num_compute_streams() - 1]. 
+ */ +cudaEvent_t nvte_get_compute_stream_event(const int idx); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/util/multi_stream.cpp b/transformer_engine/common/util/multi_stream.cpp index ffce1f4c31..70d7376afa 100644 --- a/transformer_engine/common/util/multi_stream.cpp +++ b/transformer_engine/common/util/multi_stream.cpp @@ -58,4 +58,12 @@ int get_num_compute_streams() { int nvte_get_num_compute_streams() { return transformer_engine::detail::get_num_compute_streams(); } +cudaStream_t nvte_get_compute_stream(const int idx) { + return transformer_engine::detail::get_compute_stream(idx); +} + +cudaEvent_t nvte_get_compute_stream_event(const int idx) { + return transformer_engine::detail::get_compute_stream_event(idx); +} + #endif // TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_ diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d3c23015c1..94c05f5aa8 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -103,14 +103,15 @@ def abstract( """ del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval del K, lhs_is_trans, rhs_is_trans, scaling_mode, has_bias - del lhs_scale_inv_aval, rhs_scale_inv_aval # TODO(Phuong): move some shape checks from Cpp to here workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams - # JAX buffer pointers are 128-aligned - # 255 is added to the workspace size to ensure workspace ptr is 256-aligned - workspace_size += 255 + # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not + # necessarily 256 bytes aligned, we add some padding to ensure alignment. + # We also pad scale_inv swizzle buffers size for 256 bytes alignment. 
+ workspace_size += 256 + workspace_size += lhs_scale_inv_aval.size + 256 + workspace_size += rhs_scale_inv_aval.size + 256 workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) - # TODO(phuong): We should make separate tmp buffers for swizzled scales to avoid unaligned-by-256 workspace ptr issue out_shape = (M, N) if is_grouped_dense_wgrad: @@ -495,7 +496,8 @@ def grouped_gemm( # and is_gemm_with_all_layouts_supported() scaling_mode.is_1d_block_scaling() ): - lhs_is_rowwise = rhs_is_rowwise = True + lhs_is_rowwise = True + rhs_is_rowwise = False else: lhs_is_rowwise = not lhs_is_trans rhs_is_rowwise = lhs_is_trans @@ -557,9 +559,6 @@ def grouped_gemm( assert not has_bias or bias.shape == (group_sizes.size, N) bias = jnp.empty((), jnp.float32) if bias is None else bias - # TODO(Phuong): support MXFP8_1D_SCALING - assert scaling_mode != ScalingMode.MXFP8_1D_SCALING, "MXFP8_1D_SCALING is not yet supported" - (out,) = GroupedGemmPrimitive.outer_primitive.bind( lhs_data, lhs_scale_inv, diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index d57d4682ca..c03f7f7751 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -10,6 +10,8 @@ #include "../extensions.h" #include "common/util/cuda_runtime.h" #include "common/util/system.h" +#include "transformer_engine/multi_stream.h" +#include "transformer_engine/swizzle.h" #include "xla/ffi/api/c_api.h" #define MXFP8_BLOCK_SIZE 32 @@ -17,6 +19,12 @@ namespace transformer_engine { namespace jax { +static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) { + // Move the pointer to the next 256B aligned address + return reinterpret_cast((reinterpret_cast(ptr) + 255) & + ~static_cast(255)); +} + Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type group_sizes, Buffer_Type 
group_offset, Result_Type output, @@ -58,11 +66,18 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type auto out_ptr = reinterpret_cast(output->untyped_data()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); // Here we clear the lower 8 bits of the buffer address to ensure the buffer is 256-aligned - auto workspace_ptr = - reinterpret_cast((reinterpret_cast(workspace->untyped_data()) + 255) & - ~static_cast(255)); - auto workspace_total_size = product(workspace->dimensions()) - 255; - auto workspace_size = workspace_total_size / num_streams; + auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); + workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr); + auto workspace_total_size = product(workspace->dimensions()); + + auto lhs_sinv_size = product(lhs_sinv.dimensions()); + auto rhs_sinv_size = product(rhs_sinv.dimensions()); + auto workspace_size = + (workspace_total_size - lhs_sinv_size - rhs_sinv_size - 3 * 256) / num_streams; + auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams; + swizzled_lhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_lhs_sinv_ptr); + auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size; + swizzled_rhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_rhs_sinv_ptr); size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); @@ -122,6 +137,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // It is weird that TE/Common GEMM only use colwise for MXFP8 const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); + const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || + scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans; const bool lhs_use_colwise = is_mxfp8_scaling && 
lhs_is_trans; @@ -135,6 +152,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // These lists are to keep the TensorWrapper objects alive std::vector lhs_wrapper_list; std::vector rhs_wrapper_list; + std::vector lhs_swizzle_wrapper_list; // For MXFP8 scale_inv swizzling + std::vector rhs_swizzle_wrapper_list; std::vector bias_wrapper_list; std::vector pre_gelu_wrapper_list; std::vector out_wrapper_list; @@ -143,66 +162,119 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // These lists are the actual NVTETensor (void *) lists for multi-stream GEMM std::vector lhs_list; std::vector rhs_list; + std::vector lhs_swizzle_list; + std::vector rhs_swizzle_list; std::vector bias_list; std::vector pre_gelu_list; std::vector out_list; std::vector workspace_list; + size_t lhs_sinv_total_size = 0; + size_t rhs_sinv_total_size = 0; + + std::vector zero_out_dptr_list; + std::vector zero_out_size_list; + for (size_t i = 0; i < num_gemms; i++) { // Matrix data shapes size_t m_i = dim_list_host[i]; - auto lhs_shape = std::vector{m_i, k}; - auto rhs_shape = std::vector{rhs_is_trans ? n : k, rhs_is_trans ? k : n}; - auto out_shape = std::vector{m_i, n}; + auto lhs_shape_i = std::vector{m_i, k}; + auto rhs_shape_i = std::vector{rhs_is_trans ? n : k, rhs_is_trans ? k : n}; + auto out_shape_i = std::vector{m_i, n}; if (is_grouped_dense_wgrad) { size_t k_i = dim_list_host[i]; - lhs_shape[0] = lhs_is_trans ? k_i : m; - lhs_shape[1] = lhs_is_trans ? m : k_i; - rhs_shape[0] = rhs_is_trans ? n : k_i; - rhs_shape[1] = rhs_is_trans ? k_i : n; - out_shape[0] = m; - out_shape[1] = n; + lhs_shape_i[0] = lhs_is_trans ? k_i : m; + lhs_shape_i[1] = lhs_is_trans ? m : k_i; + rhs_shape_i[0] = rhs_is_trans ? n : k_i; + rhs_shape_i[1] = rhs_is_trans ? 
k_i : n; + out_shape_i[0] = m; + out_shape_i[1] = n; + } + + size_t lhs_size = lhs_shape_i[0] * lhs_shape_i[1]; + size_t rhs_size = rhs_shape_i[0] * rhs_shape_i[1]; + size_t out_size = out_shape_i[0] * out_shape_i[1]; + bool is_empty_gemm = lhs_size == 0 || rhs_size == 0; + if (is_empty_gemm && out_size > 0) { + zero_out_dptr_list.push_back(out_ptr); + zero_out_size_list.push_back(out_size * out_dtype_bytes); } // Set matrix data pointers auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto out_i = TensorWrapper(static_cast(out_ptr), out_shape, out_dtype); + auto out_i = TensorWrapper(static_cast(out_ptr), out_shape_i, out_dtype); void *lhs_vptr = static_cast(lhs_ptr); void *rhs_vptr = static_cast(rhs_ptr); if (rhs_use_colwise) // MatA to enter cuBLAS - rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape); + rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); else - rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape); + rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); if (lhs_use_colwise) // MatB to enter cuBLAS - lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape); + lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); else - lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape); - - // Scale_inv shapes - auto lhs_sinv_size = std::vector{1}; - auto rhs_sinv_size = std::vector{1}; - if (is_mxfp8_scaling) { - NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)", - MXFP8_BLOCK_SIZE, k); - size_t scale_k = k / MXFP8_BLOCK_SIZE; - lhs_sinv_size[0] = m_i * scale_k; - rhs_sinv_size[0] = n * scale_k; - // Need to add swizzle here - } + lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - // Set scale_inv pointers + // Set scale_inv shapes and pointers void *rhs_sinv_vptr = static_cast(rhs_sinv_ptr); void *lhs_sinv_vptr = static_cast(lhs_sinv_ptr); - if (is_fp8_gemm) { + size_t lhs_sinv_size_i = 0; + size_t 
rhs_sinv_size_i = 0; + if (is_tensor_scaling) { + auto tensor_scaling_sinv_shape = std::vector{1}; + // If is_empty_gemm, scale_inv does not have the corresponding value, do not move the pointers + if (!is_empty_gemm) { + lhs_sinv_size_i = 1; + rhs_sinv_size_i = 1; + } if (rhs_use_colwise) // MatA to enter cuBLAS - rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size); + rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape); else - rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size); + rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape); if (lhs_use_colwise) // MatB to enter cuBLAS - lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size); + lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape); else - lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size); + lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape); + } else if (is_mxfp8_scaling) { + auto lhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + auto rhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + void *swizzled_lhs_sinv_vptr = static_cast(swizzled_lhs_sinv_ptr); + void *swizzled_rhs_sinv_vptr = static_cast(swizzled_rhs_sinv_ptr); + + // {lhs, rhs}_swizzle_i point to unswizzled scale_inv data as input, while {lhs, rhs}_i + // point to swizzled scale_inv data (store on workspace, only used for GEMM). 
+ // Note: even if is_empty_gemm is true, sinv are still non-empty, need to move the pointers + auto lhs_sinv_shape_i = + get_mxfp8_scale_shape(lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise); + auto rhs_sinv_shape_i = + get_mxfp8_scale_shape(rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise); + lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1]; + rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1]; + if (lhs_use_colwise) { + lhs_swizzle_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); + lhs_swizzle_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); + lhs_i.set_columnwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); + } else { + lhs_swizzle_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); + lhs_swizzle_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); + lhs_i.set_rowwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); + } + if (rhs_use_colwise) { + rhs_swizzle_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); + rhs_swizzle_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); + rhs_i.set_columnwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); + } else { + rhs_swizzle_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); + rhs_swizzle_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); + rhs_i.set_rowwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); + } + + if (!is_empty_gemm) { + lhs_swizzle_wrapper_list.push_back(std::move(lhs_swizzle_i)); + rhs_swizzle_wrapper_list.push_back(std::move(rhs_swizzle_i)); + lhs_swizzle_list.push_back(lhs_swizzle_wrapper_list.back().data()); + rhs_swizzle_list.push_back(rhs_swizzle_wrapper_list.back().data()); + } } else { NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, "Unsupported scaling mode: ", static_cast(scaling_mode)); @@ -212,16 +284,23 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, 
Buffer_Type lhs_data, Buffer_Type auto pre_gelu_i = TensorWrapper(nullptr, std::vector{0}, out_dtype); // Update pointer for the next GEMM pair - lhs_ptr += lhs_shape[0] * lhs_shape[1] * lhs_dtype_bytes; - rhs_ptr += rhs_shape[0] * rhs_shape[1] * rhs_dtype_bytes; - out_ptr += out_shape[0] * out_shape[1] * out_dtype_bytes; + lhs_ptr += lhs_size * lhs_dtype_bytes; + rhs_ptr += rhs_size * rhs_dtype_bytes; + out_ptr += out_size * out_dtype_bytes; if (is_fp8_gemm) { - lhs_sinv_ptr += lhs_sinv_size[0] * lhs_sinv_dtype_bytes; - rhs_sinv_ptr += rhs_sinv_size[0] * rhs_sinv_dtype_bytes; + lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes; + rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes; + lhs_sinv_total_size += lhs_sinv_size_i; + rhs_sinv_total_size += rhs_sinv_size_i; + if (is_mxfp8_scaling) { + swizzled_lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes; + swizzled_rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes; + } } if (has_bias) bias_ptr += n * bias_dtype_bytes; // Move objects to the lists to keep them alive + if (is_empty_gemm) continue; lhs_wrapper_list.push_back(std::move(lhs_i)); rhs_wrapper_list.push_back(std::move(rhs_i)); out_wrapper_list.push_back(std::move(out_i)); @@ -244,10 +323,41 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type workspace_ptr += workspace_size; } + if (is_fp8_gemm) { + NVTE_CHECK(lhs_sinv_total_size <= lhs_sinv_size, "Actual total lhs_sinv size ", + lhs_sinv_total_size, " exceeds estimated upper bound ", lhs_sinv_size); + NVTE_CHECK(rhs_sinv_total_size <= rhs_sinv_size, "Actual total rhs_sinv size ", + rhs_sinv_total_size, " exceeds estimated upper bound ", rhs_sinv_size); + } + + size_t num_non_empty_gemms = lhs_list.size(); + + if (is_mxfp8_scaling) { + for (int i = 0; i < num_non_empty_gemms; i++) { + // The i-th GEMM will use the (i % num_streams)-th stream to compute, + // use the same stream to swizzle the scaling factors to make sure that + // the swizzling is done before 
the GEMM computation starts. + int stream_id = i % num_streams; + cudaStream_t stream_i = nvte_get_compute_stream(stream_id); + nvte_swizzle_scaling_factors(lhs_swizzle_list[i], lhs_list[i], stream_i); + nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i); + } + } + + // Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM + size_t num_zero_outs = zero_out_dptr_list.size(); + for (int i = 0; i < num_zero_outs; i++) { + int stream_id = i % num_streams; + cudaStream_t stream_i = nvte_get_compute_stream(stream_id); + void *dptr = zero_out_dptr_list[i]; + size_t count = zero_out_size_list[i]; + NVTE_CHECK_CUDA(cudaMemsetAsync(dptr, 0, count, stream_i)); + } + nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), - pre_gelu_list.data(), num_gemms, rhs_is_trans, lhs_is_trans, grad, - workspace_list.data(), accumulate, use_split_accumulator, - num_math_sm, stream); + pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, + lhs_is_trans, grad, workspace_list.data(), accumulate, + use_split_accumulator, num_math_sm, stream); return ffi_with_cuda_error_check(); } diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 8834f4f73c..a318bfef68 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -287,7 +287,13 @@ def _grouped_dense_fwd_rule( "and k_contracting_dims=(1,) for now, " f"got {x_contracting_dims=} and {k_contracting_dims=}" ) - k_contracting_dims = (0,) + scaling_mode = quantizer_set.x.scaling_mode + if scaling_mode.is_tensor_scaling(): + k_contracting_dims = (0,) + elif scaling_mode.is_1d_block_scaling(): + k_contracting_dims = (1,) + else: + raise ValueError(f"Unsupported scaling mode {scaling_mode.value} for grouped_dense") casted_x = tex.grouped_quantize( x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x @@ -385,7 +391,7 @@ def _grouped_dense_bwd_rule( dgrad_grad = 
casted_grad.get_rowwise_tensor() dgrad_kernel_T = ctx_kernel - # We need to use g_contracting_dim = (0,) and x_contracting_dim = (1,) to make it work + # We need to use g_contracting_dim = (0,) and x_contracting_dim = (0,) to make it work # after the extra transpose for FP8 in grouped_gemm # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? g_contracting_dim = (0,) From 4a16c2ddd654c6221cbc9059fb3335236c57125a Mon Sep 17 00:00:00 2001 From: Li Tao Date: Tue, 17 Jun 2025 01:03:56 +0800 Subject: [PATCH 24/39] [Pytorch] Bugfix in te fusion ce implementation (#1879) * Fix an issue when mcore uses te fusion ce implementation Signed-off-by: lit * simplify unit test code Signed-off-by: lit * Update tests/pytorch/test_parallel_cross_entropy.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: lit Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/test_parallel_cross_entropy.py | 18 +++++++++++------- .../pytorch/triton/cross_entropy.py | 16 ++++++++++++++-- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py index fdb9b7f0b9..dd6c6a3b0f 100644 --- a/tests/pytorch/test_parallel_cross_entropy.py +++ b/tests/pytorch/test_parallel_cross_entropy.py @@ -61,22 +61,26 @@ def one_iteration_test( test_loss = self.test_loss_func( self.input_test, self.tar_test, label_smoothing, reduce_loss, None ) - if reduce_loss: - test_loss.backward() ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref) + + # Handle backward pass based on the test scenario if reduce_loss: + test_loss.backward() ref_loss.backward() + else: + test_loss.sum().backward() + ref_loss.sum().backward() test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss - torch.testing.assert_close(test_loss, ref_loss, 
check_dtype=False) if ignore_idx: print(test_loss, ref_loss) - if reduce_loss: - torch.testing.assert_close( - torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad - ) + + # Compare gradients when backward pass was called + torch.testing.assert_close( + torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad + ) self.input_test = None self.input_ref = None diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index a8001d2b63..45ff9f9c53 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -97,6 +97,7 @@ def cross_entropy_kernel( ignore_idx, n_cols, n_non_ignore, + reduce_loss: tl.constexpr, label_smoothing: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): @@ -176,7 +177,13 @@ def cross_entropy_kernel( if label_smoothing > 0: # scale X beforehand to avoid overflow scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) - X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) + # Scale gradients based on reduction mode + # For reduce_loss=True: PyTorch will scale by 1/n_rows, so we need to scale by n_rows/n_non_ignore + # For reduce_loss=False: No additional scaling from PyTorch, so we don't scale here + if reduce_loss: + X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) + else: + X_block = tl.exp(X_block - m) / d - eps tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols) # We need tl.debug_barrier() to ensure the new result of X_ptr is written @@ -204,7 +211,11 @@ def cross_entropy_kernel( if y >= vocab_start_idx: if y < vocab_end_idx: X_y = tl.load(X_ptr + y - vocab_start_idx) - X_y += -(1 - label_smoothing) / (n_non_ignore) + # Apply the same conditional scaling logic for the target token + if reduce_loss: + X_y += -(1 - label_smoothing) / (n_non_ignore) + else: + X_y += -(1 - label_smoothing) tl.store(X_ptr + y - vocab_start_idx, 
X_y) tl.store(loss_ptr, loss) @@ -318,6 +329,7 @@ def cross_entropy_forward( ignore_idx=ignore_idx, n_cols=V, n_non_ignore=n_rows, + reduce_loss=reduce_loss, label_smoothing=label_smoothing, BLOCK_SIZE=BLOCK_SIZE, num_warps=32, From b894f69bffa93db24eb7819c14442189b89140f4 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 17 Jun 2025 08:58:59 -0400 Subject: [PATCH 25/39] [JAX] Fixes for L0_jax_distributed_unittest (#1884) * include previously accidentally excluded tests * Execute run_test_multiprocessing_encoder with nested bash + exit code for inner bash shell * Adapt run_test_multiprocessing to handle segfault Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- .../run_test_multiprocessing_encoder.sh | 21 +++++++++---------- qa/L0_jax_distributed_unittest/test.sh | 10 ++++----- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh index 9003fd1edf..a21d5ecb57 100644 --- a/examples/jax/encoder/run_test_multiprocessing_encoder.sh +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -6,13 +6,13 @@ NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)} # Define the test cases to run TEST_CASES=( -# "test_te_bf16" +"test_te_bf16" "test_te_delayed_scaling_fp8" -# "test_te_current_scaling_fp8" -# "test_te_mxfp8" -# "test_te_bf16_shardy" +"test_te_current_scaling_fp8" +"test_te_mxfp8" +"test_te_bf16_shardy" "test_te_delayed_scaling_fp8_shardy" -# "test_te_current_scaling_fp8_shardy" +"test_te_current_scaling_fp8_shardy" ) echo @@ -40,21 +40,20 @@ for TEST_CASE in "${TEST_CASES[@]}"; do wait tail -n +7 "${TEST_CASE}_gpu_0.log" - tail -n +7 "${TEST_CASE}_gpu_0.log" # Check and print the log content accordingly - if grep -q "FAILED" "${TEST_CASE}_gpu_0.log"; then - HAS_FAILURE=1 - echo "... 
$TEST_CASE FAILED" - elif grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then + if grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then echo "... $TEST_CASE SKIPPED" elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then echo "... $TEST_CASE PASSED" else - echo "Invalid ${TEST_CASE}_gpu_0.log" + HAS_FAILURE=1 + echo "... $TEST_CASE FAILED" fi # Remove the log file after processing it + wait rm ${TEST_CASE}_gpu_*.log done +wait exit $HAS_FAILURE diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index b3b1684799..d9c46347fd 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -24,11 +24,11 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa # Make encoder tests to have run-to-run deterministic to have the stable CI results export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -# python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py" -# wait -# python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" -# wait -. 
$TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py" +wait +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" +wait +TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" From 82bff478b2bf8076cd3d4d87b29323237803c9a2 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 18 Jun 2025 07:47:18 -0400 Subject: [PATCH 26/39] [JAX] TensorUsage + FP8 GEMM with all layouts handling on BW (#1844) * TensorUsage + FP8 GEMM with all layouts handling on BW Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- tests/jax/test_custom_call_compute.py | 12 +-- transformer_engine/jax/cpp_extensions/gemm.py | 59 +++++----- transformer_engine/jax/dense.py | 41 ++++--- transformer_engine/jax/layernorm_dense.py | 21 ++-- transformer_engine/jax/layernorm_mlp.py | 49 +++++---- transformer_engine/jax/quantize/__init__.py | 1 + .../jax/quantize/device_utils.py | 34 ++++++ transformer_engine/jax/quantize/helper.py | 15 +-- transformer_engine/jax/quantize/quantizer.py | 24 +++-- .../jax/quantize/scaling_modes.py | 101 +++++++++++++++++- transformer_engine/jax/quantize/tensor.py | 83 ++++++-------- 11 files changed, 283 insertions(+), 157 deletions(-) create mode 100644 transformer_engine/jax/quantize/device_utils.py diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 54ceecdab6..349916cafe 100644 --- 
a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -109,8 +109,8 @@ def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray): else: assert_allclose(a.dequantize(), b, dtype=a.data.dtype) elif isinstance(a, ScaledTensor2x): - assert_dequantized_scaled_tensor(a.get_rowwise_tensor(), b) - assert_dequantized_scaled_tensor(a.get_colwise_tensor(), b) + assert_dequantized_scaled_tensor(a.rowwise_tensor, b) + assert_dequantized_scaled_tensor(a.colwise_tensor, b) else: pytest.fail("a must be a ScaledTensor object") @@ -139,10 +139,10 @@ def assert_dequantized_grouped_scaled_tensor( dq_a_i = dq_a_i.reshape(b_i.shape) assert_allclose(dq_a_i, b_i, dtype=a.data.dtype) elif isinstance(a, ScaledTensor2x): - assert isinstance(a.get_rowwise_tensor(), GroupedScaledTensor1x) - assert isinstance(a.get_colwise_tensor(), GroupedScaledTensor1x) - assert_dequantized_grouped_scaled_tensor(a.get_rowwise_tensor(), b) - assert_dequantized_grouped_scaled_tensor(a.get_colwise_tensor(), b) + assert isinstance(a.rowwise_tensor, GroupedScaledTensor1x) + assert isinstance(a.colwise_tensor, GroupedScaledTensor1x) + assert_dequantized_grouped_scaled_tensor(a.rowwise_tensor, b) + assert_dequantized_grouped_scaled_tensor(a.colwise_tensor, b) else: pytest.fail("a must be a GroupedScaledTensor object") diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 94c05f5aa8..a6c58edb4a 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -24,10 +24,11 @@ QuantizerSet, QuantizeLayout, noop_quantizer_set, + is_fp8_gemm_with_all_layouts_supported, ) -__all__ = ["gemm", "grouped_gemm", "is_gemm_with_all_layouts_supported"] +__all__ = ["gemm", "grouped_gemm"] num_cublas_streams = get_num_compute_streams() @@ -40,11 +41,6 @@ def get_cublas_workspace_size_bytes() -> None: return 4_194_304 -def is_gemm_with_all_layouts_supported() -> False: - """Return 
True if using blackwell, False otherwise.""" - return get_device_compute_capability(0) >= 100 - - class GroupedGemmPrimitive(BasePrimitive): """ Primitive for grouped GEMM @@ -338,10 +334,15 @@ def _jax_gemm_fp8_impl(lhs, rhs): if not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor): if quantizer_set != noop_quantizer_set: assert type(quantizer_set.x) is type(quantizer_set.kernel) - (((lhs_contract_dim,), (rhs_contract_dim,)), _) = dim_nums - lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1 - rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1 - # Call JAX quantization so that XLA can do pattern matching (QDQ --> FP8 gemm) + if ( + quantizer_set.x.scaling_mode.is_tensor_scaling() + and is_fp8_gemm_with_all_layouts_supported() + ): + lhs_is_rowwise = rhs_is_rowwise = True + else: + (((lhs_contract_dim,), (rhs_contract_dim,)), _) = dim_nums + lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1 + rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1 lhs_q = quantizer_set.x.quantize( lhs, is_rowwise=lhs_is_rowwise, @@ -491,16 +492,13 @@ def grouped_gemm( assert type(quantizer_set.x) is type(quantizer_set.kernel) scaling_mode = quantizer_set.x.scaling_mode if ( - # TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later - # scaling_mode.is_tensor_scaling() - # and is_gemm_with_all_layouts_supported() - scaling_mode.is_1d_block_scaling() + quantizer_set.x.scaling_mode.is_tensor_scaling() + and is_fp8_gemm_with_all_layouts_supported() ): - lhs_is_rowwise = True - rhs_is_rowwise = False + lhs_is_rowwise = rhs_is_rowwise = True else: lhs_is_rowwise = not lhs_is_trans - rhs_is_rowwise = lhs_is_trans + rhs_is_rowwise = rhs_is_trans quantizer_set.x.q_layout = ( QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE ) @@ -515,6 +513,8 @@ def grouped_gemm( rhs_data = rhs_q.data lhs_scale_inv = lhs_q.scale_inv rhs_scale_inv = rhs_q.scale_inv + lhs_shape = lhs_q.original_shape + rhs_shape = rhs_q.original_shape assert not ( 
lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2 @@ -522,24 +522,35 @@ def grouped_gemm( # Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs # thus additional transpose is required - # TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later - if scaling_mode.is_tensor_scaling(): # and not is_gemm_with_all_layouts_supported(): - lhs_is_trans = False - rhs_is_trans = True + if scaling_mode.is_tensor_scaling() and not is_fp8_gemm_with_all_layouts_supported(): if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): lhs_layout_is_T = lhs.data_layout == "T" rhs_layout_is_T = rhs.data_layout == "T" else: lhs_layout_is_T = lhs_q.data_layout == "T" rhs_layout_is_T = rhs_q.data_layout == "T" + # we can't apply _shape_normalization on the grouped input + # thus we need to ensure that lhs is in N and rhs is in T + assert ( + lhs_is_trans == lhs_layout_is_T + ), "lhs input must be transposed before calling grouped_gemm" + assert ( + not rhs_is_trans == rhs_layout_is_T + ), "rhs input must be transposed before calling grouped_gemm" + lhs_is_trans = False + rhs_is_trans = True lhs_ndim = len(lhs_shape) rhs_ndim = len(rhs_shape) if lhs_layout_is_T: lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim) if rhs_layout_is_T: - rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) - lhs_data = _shape_normalization(lhs_data, (lhs_contract_dim, ()), not lhs_layout_is_T) - rhs_data = _shape_normalization(rhs_data, (rhs_contract_dim, ()), rhs_layout_is_T) + # For rhs [G, K, N], need to exclude the G dim from contract_dim + if group_sizes.size == rhs_shape[0]: + rhs_contract_dim = tuple( + (rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim + ) + else: + rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) # Calling GroupedGEMM Custom Call K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim) diff --git 
a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index a318bfef68..57170e85be 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -19,6 +19,7 @@ QuantizerSet, noop_quantizer_set, with_sharding_constraint_by_logical_axes, + TensorUsage, ) @@ -105,8 +106,8 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, # GEMM NN output = tex.gemm( - casted_x.get_rowwise_tensor(), - casted_kernel.get_colwise_tensor(), + casted_x.get_tensor(usage=TensorUsage.LHS), + casted_kernel.get_tensor(usage=TensorUsage.RHS), (x_contracting_dims, k_contracting_dims), ) @@ -116,8 +117,8 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, output += jnp.reshape(bias, bias_new_shape) ctx = ( - casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None, - casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None, + casted_x.get_tensor(usage=TensorUsage.LHS_TRANS), + casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS), x.shape, kernel.shape, use_bias, @@ -138,8 +139,8 @@ def _dense_bwd_rule( fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims ( - colwise_casted_x, - rowwise_casted_kernel, + casted_x_lhs, + casted_kernel_rhs, x_shape, kernel_shape, use_bias, @@ -161,8 +162,8 @@ def _dense_bwd_rule( dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims ) dgrad = tex.gemm( - casted_grad.get_rowwise_tensor(), - rowwise_casted_kernel, + casted_grad.get_tensor(usage=TensorUsage.LHS), + casted_kernel_rhs, (g_contracting_dim, k_contracting_dim), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) @@ -174,7 +175,9 @@ def _dense_bwd_rule( ) wgrad = tex.gemm( - colwise_casted_x, casted_grad.get_colwise_tensor(), (x_contracting_dim, g_contracting_dim) + casted_x_lhs, + casted_grad.get_tensor(usage=TensorUsage.RHS), + (x_contracting_dim, g_contracting_dim), ) wgrad = 
with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) @@ -287,13 +290,6 @@ def _grouped_dense_fwd_rule( "and k_contracting_dims=(1,) for now, " f"got {x_contracting_dims=} and {k_contracting_dims=}" ) - scaling_mode = quantizer_set.x.scaling_mode - if scaling_mode.is_tensor_scaling(): - k_contracting_dims = (0,) - elif scaling_mode.is_1d_block_scaling(): - k_contracting_dims = (1,) - else: - raise ValueError(f"Unsupported scaling mode {scaling_mode.value} for grouped_dense") casted_x = tex.grouped_quantize( x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x @@ -306,11 +302,10 @@ def _grouped_dense_fwd_rule( # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have # rowwise_casted_x.original_shape == (M, K) # colwise_casted_kernel.original_shape == (G, N, K) - grouped_gemm_x = casted_x.get_rowwise_tensor() - grouped_gemm_kernel = casted_kernel.get_colwise_tensor() - # TODO(Hua): Shall we give warning/error if not quantizer_set.x.is_2x2x()? - ctx_x = casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None - ctx_kernel = casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None + grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS) + grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS) + ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS) + ctx_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS) output = tex.grouped_gemm( grouped_gemm_x, @@ -388,7 +383,7 @@ def _grouped_dense_bwd_rule( g_contracting_dim = (1,) k_contracting_dim = (2,) dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim) - dgrad_grad = casted_grad.get_rowwise_tensor() + dgrad_grad = casted_grad.get_tensor(usage=TensorUsage.LHS) dgrad_kernel_T = ctx_kernel # We need to use g_contracting_dim = (0,) and x_contracting_dim = (0,) to make it work @@ -398,7 +393,7 @@ def _grouped_dense_bwd_rule( x_contracting_dim = (0,) wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) 
wgrad_x_T = ctx_x - wgrad_grad = casted_grad.get_colwise_tensor() + wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS) dgrad = tex.grouped_gemm( dgrad_grad, diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 727ff78c2d..ea66e78302 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -21,6 +21,7 @@ QuantizerSet, noop_quantizer_set, with_sharding_constraint_by_logical_axes, + TensorUsage, ) @@ -198,8 +199,8 @@ def _layernorm_dense_fwd_rule( # NN GEMM # (batch..., hidden_in) x (hidden_in, hidden_out...) output = tex.gemm( - casted_ln_out.get_rowwise_tensor(), - casted_kernel.get_colwise_tensor(), + casted_ln_out.get_tensor(TensorUsage.LHS), + casted_kernel.get_tensor(TensorUsage.RHS), (x_contracting_dims, k_contracting_dims), ) @@ -209,8 +210,8 @@ def _layernorm_dense_fwd_rule( output += jnp.reshape(bias, bias_new_shape) ctx = ( - casted_ln_out.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None, - casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None, + casted_ln_out.get_tensor(TensorUsage.LHS_TRANS), + casted_kernel.get_tensor(TensorUsage.RHS_TRANS), x.shape, kernel.shape, mu, @@ -250,8 +251,8 @@ def _layernorm_dense_bwd_rule( Tuple of gradients for all input parameters """ ( - colwise_casted_ln_out, - rowwise_casted_kernel, + casted_ln_out, + casted_kernel, x_shape, kernel_shape, mu, @@ -281,8 +282,8 @@ def _layernorm_dense_bwd_rule( # NT GEMM dgrad = tex.gemm( - casted_grad.get_rowwise_tensor(), - rowwise_casted_kernel, + casted_grad.get_tensor(TensorUsage.LHS), + casted_kernel, (g_constracting_dim, k_constracting_dim), ) @@ -294,8 +295,8 @@ def _layernorm_dense_bwd_rule( # TN GEMM wgrad = tex.gemm( - colwise_casted_ln_out, - casted_grad.get_colwise_tensor(), + casted_ln_out, + casted_grad.get_tensor(TensorUsage.RHS), (x_constracting_dim, g_constracting_dim), ) diff --git a/transformer_engine/jax/layernorm_mlp.py 
b/transformer_engine/jax/layernorm_mlp.py index e04b930233..18563fd255 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -22,7 +22,12 @@ from . import cpp_extensions as tex from .layernorm import canonicalize_norm_type -from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set +from .quantize import ( + with_sharding_constraint_by_logical_axes, + QuantizerSet, + noop_quantizer_set, + TensorUsage, +) from .sharding import get_non_contracting_logical_axes @@ -270,8 +275,8 @@ def _layernorm_mlp_fwd_rule( # NN GEMM # (batch..., hidden_in) x (hidden_in, hidden_out) dot_1_output = tex.gemm( - casted_ln_out.get_rowwise_tensor(), - casted_kernel_1.get_colwise_tensor(), + casted_ln_out.get_tensor(TensorUsage.LHS), + casted_kernel_1.get_tensor(TensorUsage.RHS), (x_contracting_dims, k_contracting_dims), ) @@ -299,8 +304,8 @@ def _layernorm_mlp_fwd_rule( # NN GEMM # (batch..., hidden_in) x (hidden_out, hidden_in) dot_2_output = tex.gemm( - casted_act_out.get_rowwise_tensor(), - casted_kernel_2.get_colwise_tensor(), + casted_act_out.get_tensor(TensorUsage.LHS), + casted_kernel_2.get_tensor(TensorUsage.RHS), (x_contracting_dims, k_contracting_dims), ) @@ -317,11 +322,11 @@ def _layernorm_mlp_fwd_rule( rsigma, gamma, beta, - casted_ln_out.get_colwise_tensor(), - casted_kernel_1.get_rowwise_tensor(), + casted_ln_out.get_tensor(TensorUsage.LHS_TRANS), + casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS), dot_1_output, - casted_act_out.get_colwise_tensor(), - casted_kernel_2.get_rowwise_tensor(), + casted_act_out.get_tensor(TensorUsage.LHS_TRANS), + casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS), x_contracting_dims, k_contracting_dims, kernel_1.shape, @@ -369,11 +374,11 @@ def _layernorm_mlp_bwd_rule( rsigma, gamma, beta, - colwise_casted_ln_out, - rowwise_casted_kernel_1, + casted_ln_out, + casted_kernel_1, dot_1_output, - colwise_casted_act_out, - rowwise_casted_kernel_2, + casted_act_out, 
+ casted_kernel_2, x_contracting_dims_in_fwd, k_contracting_dims_in_fwd, kernel_1_shape, @@ -404,8 +409,8 @@ def _layernorm_mlp_bwd_rule( # NT GEMM # (batch..., hidden_out) x (hidden_in, hidden_out) dgrad_2 = tex.gemm( - casted_grad.get_rowwise_tensor(), - rowwise_casted_kernel_2, + casted_grad.get_tensor(TensorUsage.LHS), + casted_kernel_2, (g_contracting_dims_2, k_contracting_dims_2), ) @@ -418,8 +423,8 @@ def _layernorm_mlp_bwd_rule( # TN GEMM # (hidden, batch...,) x (hidden, batch...) wgrad_2 = tex.gemm( - colwise_casted_act_out, - casted_grad.get_colwise_tensor(), + casted_act_out, + casted_grad.get_tensor(TensorUsage.RHS), (x_contracting_dims, g_contracting_dims), ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) @@ -433,7 +438,7 @@ def _layernorm_mlp_bwd_rule( ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim - dact_out_ndim = casted_dact_out.get_rowwise_tensor().data.ndim + dact_out_ndim = casted_dact_out.get_tensor(TensorUsage.LHS).data.ndim g_contracting_dims_1 = tuple( range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim) ) @@ -444,8 +449,8 @@ def _layernorm_mlp_bwd_rule( # NT GEMM dgrad_1 = tex.gemm( - casted_dact_out.get_rowwise_tensor(), - rowwise_casted_kernel_1, + casted_dact_out.get_tensor(TensorUsage.LHS), + casted_kernel_1, (g_contracting_dims_1, k_contracting_dims_1), ) @@ -454,8 +459,8 @@ def _layernorm_mlp_bwd_rule( # TN GEMM # (hidden, batch...) x (hidden, batch...) 
wgrad_1 = tex.gemm( - colwise_casted_ln_out, - casted_dact_out.get_colwise_tensor(), + casted_ln_out, + casted_dact_out.get_tensor(TensorUsage.RHS), (x_contracting_dims, g_contracting_dims), ) diff --git a/transformer_engine/jax/quantize/__init__.py b/transformer_engine/jax/quantize/__init__.py index aa36df7a2f..11f692917f 100644 --- a/transformer_engine/jax/quantize/__init__.py +++ b/transformer_engine/jax/quantize/__init__.py @@ -15,3 +15,4 @@ from .scaling_modes import * from .metadata import * from .helper import * +from .device_utils import * diff --git a/transformer_engine/jax/quantize/device_utils.py b/transformer_engine/jax/quantize/device_utils.py new file mode 100644 index 0000000000..9f5d2f4587 --- /dev/null +++ b/transformer_engine/jax/quantize/device_utils.py @@ -0,0 +1,34 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Device utility functions for JAX quantization. + +This module provides utility functions for checking device capabilities and compatibility +for quantization operations in JAX. +""" + +import functools + +import transformer_engine_jax + +__all__ = [ + "get_device_compute_capability", + "is_fp8_gemm_with_all_layouts_supported", +] + + +@functools.lru_cache(maxsize=None) +def get_device_compute_capability(gpu_id: int = 0) -> int: + """ + Get the compute capability of the device. 
+ """ + return transformer_engine_jax.get_device_compute_capability(gpu_id) + + +@functools.lru_cache(maxsize=None) +def is_fp8_gemm_with_all_layouts_supported() -> bool: + """Return True if using Blackwell architecture, False otherwise.""" + compute_capability = get_device_compute_capability() + return 100 <= compute_capability < 120 diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 13abf8bc06..c0617eafbb 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -15,17 +15,13 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict -from transformer_engine_jax import DType -from transformer_engine_jax import get_cublasLt_version -from transformer_engine_jax import ( - get_cuda_version, - get_device_compute_capability, -) +from transformer_engine_jax import DType, get_cublasLt_version, get_cuda_version from transformer_engine.common import recipe from transformer_engine.jax.sharding import global_shard_guard, MeshResource from .scaling_modes import ScalingMode from .. 
import cpp_extensions as tex +from .device_utils import get_device_compute_capability __all__ = [ "QuantizeConfig", @@ -203,7 +199,7 @@ class QuantizeConfig: FP8_2X_ACC_FPROP: Whether to use 2x accumulation for forward pass FP8_2X_ACC_DGRAD: Whether to use 2x accumulation for data gradients FP8_2X_ACC_WGRAD: Whether to use 2x accumulation for weight gradients - IF_QUANTIZE_2X: Whether 2x quantization is enabled + INFERENCE_MODE: Whether to enable optimization for inference SCALING_MODE: Scaling mode AMAX_HISTORY_LEN: Length of AMAX history for delayed scaling AMAX_COMPUTE_ALGO: Algorithm for AMAX computation @@ -218,7 +214,7 @@ class QuantizeConfig: FP8_2X_ACC_FPROP: bool = False FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False - IF_QUANTIZE_2X: bool = False + INFERENCE_MODE: bool = False SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING # DelayedScaling @@ -246,7 +242,6 @@ def initialize(cls, fp8_recipe: recipe.Recipe) -> None: cls.FP8_FORMAT = fp8_recipe.fp8_format cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) cls.SCALING_MODE = _get_scaling_mode(fp8_recipe) - cls.IF_QUANTIZE_2X = True @classmethod def finalize(cls) -> None: @@ -260,7 +255,7 @@ def finalize(cls) -> None: cls.FP8_2X_ACC_DGRAD = False cls.FP8_2X_ACC_WGRAD = False cls.SCALING_MODE = ScalingMode.NO_SCALING - cls.IF_QUANTIZE_2X = False + cls.INFERENCE_MODE = False # DelayedScaling cls.AMAX_HISTORY_LEN = 1024 cls.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index aaac66e65c..881f3a74bb 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -23,6 +23,7 @@ QuantizeConfig, AmaxComputeAlgo, ) +from .device_utils import is_fp8_gemm_with_all_layouts_supported __all__ = [ "QuantizeLayout", @@ -607,9 +608,10 @@ def tree_flatten(self): def __post_init__(self): if self.quantizers[0] is None: - self.quantizers = 
QuantizerFactory.create( + quantizers = QuantizerFactory.create( self.n_groups, self.scaling_mode, self.q_dtype, self.q_layout ) + self.quantizers = (quantizers,) if not isinstance(quantizers, tuple) else quantizers self.data_layout = self.quantizers[0].data_layout def _create_grouped_tensor_from_tensor_list( @@ -841,9 +843,11 @@ def _create_set( if is_2x2x: q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE else: - q_layout_x = QuantizeLayout.ROWWISE - q_layout_kernel = QuantizeLayout.COLWISE - q_layout_dgrad = None + q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE + if scaling_mode.is_1d_block_scaling(): + q_layout_kernel = QuantizeLayout.COLWISE + if QuantizeConfig.INFERENCE_MODE: + q_layout_dgrad = None if "quantize_meta_set" in kwargs: quantize_meta_set = kwargs.get("quantize_meta_set") @@ -898,7 +902,15 @@ def create_set( scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE - is_2x2x = is_2x2x or QuantizeConfig.IF_QUANTIZE_2X + if is_2x2x is None: + if scaling_mode.is_1d_block_scaling(): + is_2x2x = True + elif scaling_mode.is_tensor_scaling(): + is_2x2x = not is_fp8_gemm_with_all_layouts_supported() + else: # NO_SCALING ignores is_2x2x for now + is_2x2x = False + is_inference_mode = QuantizeConfig.INFERENCE_MODE + assert not is_inference_mode, "Inference mode is not supported yet!" 
q_set = [] for _ in range(n_quantizer_sets): @@ -911,4 +923,4 @@ def create_set( return q_set[0] if len(q_set) == 1 else tuple(q_set) -noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING) +noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING, is_2x2x=False) diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index c26802c39c..f45a05a399 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -13,7 +13,7 @@ from dataclasses import dataclass from enum import Enum from typing import Tuple, Dict -from functools import reduce +from functools import reduce, lru_cache import operator import numpy as np @@ -21,10 +21,44 @@ from jax.tree_util import register_pytree_node_class import jax.numpy as jnp -from transformer_engine_jax import JAXX_Scaling_Mode +from transformer_engine_jax import JAXX_Scaling_Mode, QuantizeLayout +from .device_utils import is_fp8_gemm_with_all_layouts_supported -__all__ = ["QuantizeShardyRules", "ScalingMode"] +__all__ = [ + "QuantizeShardyRules", + "ScalingMode", + "TensorUsage", +] + + +class TensorUsage(Enum): + """Enum indicating tensor usage in GEMM operations. + + Given a GEMM operation: C = A * B in which A and B can be in the normal or transposed form. + The tensor usage can be: + - LHS: A is in the normal form + - LHS_TRANS: A is in the transposed form + - RHS: B is in the normal form + - RHS_TRANS: B is in the transposed form + + The tensor usage is used in the ScaledTensor.get_tensor() method. 
+ """ + + # LHS: Left-hand side, RHS: Right-hand side + # LHS_TRANS: Left-hand side transposed, RHS_TRANS: Right-hand side transposed + LHS = 0 + LHS_TRANS = 1 + RHS = 2 + RHS_TRANS = 3 + + def __eq__(self, other): + if not isinstance(other, TensorUsage): + return False + return self.value == other.value + + def __hash__(self): + return hash(self.value) def DIVUP(a, b): @@ -104,6 +138,18 @@ def get_grouped_scale_shape( The shape for scale tensors """ + @lru_cache(maxsize=4) + @abstractmethod + def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: + """Get the quantize layout for the tensor usage. + + Args: + usage: The usage of the tensor + + Returns: + The quantize layout for the tensor usage + """ + @abstractmethod def get_shardy_sharding_rules( self, input_rank, unique_var, flatten_axis @@ -157,6 +203,23 @@ def get_scale_shape( return (0,) return (1,) + @lru_cache(maxsize=4) + def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: + """Get the quantize layout for the tensor usage. + + Args: + usage: The usage of the tensor + + Returns: + The quantize layout for the tensor usage + """ + if is_fp8_gemm_with_all_layouts_supported(): + return QuantizeLayout.ROWWISE + + if usage in (TensorUsage.LHS, TensorUsage.RHS_TRANS): + return QuantizeLayout.ROWWISE + return QuantizeLayout.COLWISE + def get_grouped_scale_shape( self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 ) -> Tuple[int]: @@ -321,6 +384,27 @@ def get_scale_shape( return (*first_dim_scale_shape, *last_dim_scale_shape) + @lru_cache(maxsize=4) + def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: + """Get the quantize layout for the tensor usage. 
+ + Args: + usage: The usage of the tensor + + Returns: + The quantize layout for the tensor usage + """ + # If we need to support 1x1x for inference in the future + # if QuantizeConfig.INFERENCE_MODE: + # assert usage not in (TensorUsage.LHS_TRANS, TensorUsage.RHS_TRANS), (f"Invalid usage {usage} as we are in MXFP8_1D_SCALING 1x1x (FWD only) mode so no transposed usage is needed!") + # if usage == TensorUsage.LHS: + # return QuantizeLayout.ROWWISE + # return QuantizeLayout.COLWISE + + if usage in (TensorUsage.LHS, TensorUsage.RHS_TRANS): + return QuantizeLayout.ROWWISE + return QuantizeLayout.COLWISE + def get_grouped_scale_shape( self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 ) -> Tuple[int]: @@ -506,6 +590,17 @@ def get_scale_shape( """ return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis) + def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: + """Get the quantize layout for the tensor usage. + + Args: + usage: The usage of the tensor + + Returns: + The quantize layout for the tensor usage + """ + return self._get_impl().get_quantize_layout(usage) + def get_shardy_sharding_rules( self, input_rank, unique_var, flatten_axis=-1 ) -> Tuple[Tuple[str]]: diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 02b1a1a99e..633be237f9 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -17,13 +17,14 @@ from transformer_engine_jax import QuantizeLayout -from .scaling_modes import ScalingMode +from .scaling_modes import ScalingMode, TensorUsage from .dequantizer import ScalingModeToDequantizerMap from ..sharding import ( with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes, ) __all__ = [ + "TensorUsage", "ScaledTensor", "ScaledTensor1x", "ScaledTensor2x", @@ -64,25 +65,15 @@ def dequantize(self): """ @abstractmethod - def 
get_rowwise_tensor(self): - """Returns the row-wise component of the tensor. + def get_tensor(self, usage: TensorUsage): + """Returns the appropriate tensor based on the tensor usage and the scaling mode. + If the tensor usage is not valid for the scaling mode, an error is raised. - Returns: - The row-wise tensor component - - Raises: - ValueError: If called on a tensor that doesn't support row-wise access - """ - - @abstractmethod - def get_colwise_tensor(self): - """Returns the column-wise component of the tensor. + Args: + usage: The usage of the tensor Returns: - The column-wise tensor component - - Raises: - ValueError: If called on a tensor that doesn't support column-wise access + The tensor based on the usage """ @abstractmethod @@ -181,33 +172,19 @@ def dequantize(self): """ return self._dq_func(self) - def get_rowwise_tensor(self): - """Returns the tensor if it's row-wise quantized. - - Returns: - The row-wise tensor + def get_tensor(self, usage: TensorUsage): + """Returns the tensor based on the tensor usage.""" + q_layout = self.scaling_mode.get_quantize_layout(usage) + colwise_usage_valid = q_layout == QuantizeLayout.COLWISE and self.is_colwise + rowwise_usage_valid = q_layout == QuantizeLayout.ROWWISE and not self.is_colwise - Raises: - ValueError: If called on a column-wise quantized tensor - """ - if not self.is_colwise: + if colwise_usage_valid or rowwise_usage_valid: return self - raise ValueError("Calling get_rowwise_tensor() from a colwise ScaledTensor1x!") - - def get_colwise_tensor(self): - """Returns the tensor if it's column-wise quantized. - - Returns: - The column-wise tensor - - Raises: - ValueError: If called on a row-wise quantized tensor - """ - if self.is_colwise: - return self - - raise ValueError("Calling get_colwise_tensor() from a rowwise ScaledTensor1x!") + raise ValueError( + f"Calling get_tensor() with usage {usage} is not valid for this tensor as" + f" self.is_colwise={self.is_colwise}!" 
+ ) def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): """Applies sharding constraints to a tensor based on logical axis names. @@ -378,21 +355,21 @@ def dequantize(self): """ return self.rowwise_tensor.dequantize() - def get_rowwise_tensor(self): - """Returns the row-wise quantized component. + def get_tensor(self, usage: TensorUsage): + """Returns the tensor based on the tensor usage.""" + q_layout_rowwise = self.rowwise_tensor.scaling_mode.get_quantize_layout(usage) + q_layout_colwise = self.colwise_tensor.scaling_mode.get_quantize_layout(usage) - Returns: - The row-wise tensor component - """ - return self.rowwise_tensor + if q_layout_rowwise == QuantizeLayout.ROWWISE: + return self.rowwise_tensor - def get_colwise_tensor(self): - """Returns the column-wise quantized component. + if q_layout_colwise == QuantizeLayout.COLWISE: + return self.colwise_tensor - Returns: - The column-wise tensor component - """ - return self.colwise_tensor + raise ValueError( + f"Calling get_tensor() with usage {usage} is not valid for this tensor as" + f" q_layout_rowwise={q_layout_rowwise} and q_layout_colwise={q_layout_colwise}!" + ) def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): """Applies sharding constraints to a tensor based on logical axis names. 
From 9192fb62c666102ba3c79edda2ef86e6d915166d Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 18 Jun 2025 17:20:50 -0700 Subject: [PATCH 27/39] [PyTorch] Use FP16 tols for distributed tests with TF32 compute (#1831) * Use FP16 tols for tests with TF32 Signed-off-by: Tim Moon * Use uniform init instead of constant init Signed-off-by: Tim Moon * Revert constant init test, but reduce value Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/distributed/run_numerics.py | 10 +++------- tests/pytorch/distributed/test_numerics.py | 2 +- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index 3c3c807a90..1e34b06632 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -47,11 +47,6 @@ ) -# Disable TF32 -torch.backends.cuda.matmul.allow_tf32 = False -torch.backends.cudnn.allow_tf32 = False - - # Quantization recipe setup def quantization_recipe() -> Recipe: if QUANTIZATION == "fp8": @@ -166,7 +161,7 @@ def backward(ctx, grad_output): def _constant(tensor): - return nn.init.constant_(tensor, 0.5) + return nn.init.constant_(tensor, 0.05) def dist_print(msg, src=None, end="\n", error=False): @@ -189,7 +184,8 @@ def _get_tolerances(dtype): if dtype == torch.bfloat16: return {"rtol": 1.6e-2, "atol": 1e-5} if dtype == torch.float32: - return {"rtol": 1.2e-4, "atol": 1e-4} + # TF32 has same mantissa bits as FP16 + return {"rtol": 1e-3, "atol": 1e-5} raise ValueError(f"Unsupported dtype ({dtype})") diff --git a/tests/pytorch/distributed/test_numerics.py b/tests/pytorch/distributed/test_numerics.py index 632f50e90a..1ff5aff997 100644 --- a/tests/pytorch/distributed/test_numerics.py +++ b/tests/pytorch/distributed/test_numerics.py @@ -56,7 +56,7 @@ def test_distributed(quantization): if quantization 
== "fp8" and not fp8_available: pytest.skip(reason_for_no_fp8) if quantization == "fp8_cs" and not fp8_available: - pytest.skip(fp8_available) + pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if quantization == "fp8_block_scaling" and not fp8_block_scaling_available: From 1e038827350ff7514d860f8b7879438357906a62 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Tue, 24 Jun 2025 17:11:57 -0700 Subject: [PATCH 28/39] Fix cppunittest test.sh for editable installs (#1869) * Fix cppunittest test.sh for editable installs Signed-off-by: Jeremy Berchtold * Update tests/cpp/CMakeLists.txt Signed-off-by: Kirthi Shankar Sivamani * Fixes Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Jeremy Berchtold Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani --- qa/L0_cppunittest/test.sh | 2 +- tests/cpp/CMakeLists.txt | 8 +++++--- tests/cpp/operator/CMakeLists.txt | 1 + 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/qa/L0_cppunittest/test.sh b/qa/L0_cppunittest/test.sh index df8a48b662..cd46b0b63c 100755 --- a/qa/L0_cppunittest/test.sh +++ b/qa/L0_cppunittest/test.sh @@ -6,7 +6,7 @@ set -e # Find TE : ${TE_PATH:=/opt/transformerengine} -TE_LIB_PATH=`pip3 show transformer-engine | grep Location | cut -d ' ' -f 2` +TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}') export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH # Set parallelization parameters diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index afc80cba43..eb2825ba41 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -26,11 +26,13 @@ enable_testing() include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) if(NOT DEFINED TE_LIB_PATH) - execute_process(COMMAND bash -c "pip3 show transformer-engine | grep Location | cut -d 
' ' -f 2 | tr -d '\n'" - OUTPUT_VARIABLE TE_LIB_PATH) + execute_process(COMMAND bash -c "python3 -c 'import transformer_engine as te; print(te.__file__)'" + OUTPUT_VARIABLE TE_LIB_FILE + OUTPUT_STRIP_TRAILING_WHITESPACE) + get_filename_component(TE_LIB_PATH ${TE_LIB_FILE} DIRECTORY) endif() -find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/transformer_engine" ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED) +find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED) message(STATUS "Found transformer_engine library: ${TE_LIB}") include_directories(../../transformer_engine/common/include) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 0b0e615495..b680389a35 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -22,6 +22,7 @@ add_executable(test_operator test_act.cu test_normalization.cu test_normalization_mxfp8.cu + test_memset.cu test_multi_cast_transpose.cu test_multi_padding.cu test_causal_softmax.cu From 6f6951e0d67d21743f52e5142c3a40bc5e4aa5f5 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu <42691305+zhongbozhu@users.noreply.github.com> Date: Wed, 25 Jun 2025 22:02:40 -0700 Subject: [PATCH 29/39] [PyTorch][MoE] Reduce CPU Overhead By Fuse Torch Empty Calls (#1793) * finish python ref impl for bulk alloc Signed-off-by: zhongboz * c++ bulk alloc worked, still draft version Signed-off-by: zhongboz * clean up Signed-off-by: zhongboz * resolve rebase conflict Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add license Signed-off-by: zhongboz * use shared_ptr to auto manage reference count Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * attempt to fix misc training error Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci * attempt to handle case where experts get zero token Signed-off-by: zhongboz * updated with fused C++ function calls Signed-off-by: zhongboz * clean up Signed-off-by: zhongboz * experiment with reducing py object construction time Signed-off-by: zhongboz * fix seg fault bug in inference mode Signed-off-by: zhongboz * fix lint Signed-off-by: zhongboz * fuse torch split into bulk alloc Signed-off-by: zhongboz * clean up Signed-off-by: zhongboz * rebase to latest main Signed-off-by: zhongboz * fix unit test failure Signed-off-by: zhongboz * fix lint error Signed-off-by: zhongboz * refactor create_tensor to use get_scale_shape Signed-off-by: zhongboz * refactor quantize to call quantize_cpp Signed-off-by: zhongboz * Implement separate functions for multi-tensor quantize and split + multi-tensor quantize Signed-off-by: Tim Moon * Update grouped linear module with fused split+quantize func Signed-off-by: Tim Moon * Move multi-tensor quantize func to cast.cpp Signed-off-by: Tim Moon * Do not expose quantizer helper function externally Signed-off-by: Tim Moon * Fix linter warnings Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revert cuDNN frontend commit Signed-off-by: Tim Moon * fix corner cases with zero tokens Signed-off-by: zhongboz * add comments Signed-off-by: zhongboz --------- Signed-off-by: zhongboz Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon --- benchmarks/linear/benchmark_grouped_linear.py | 241 +++++++++ transformer_engine/pytorch/csrc/common.h | 2 + transformer_engine/pytorch/csrc/extensions.h | 13 +- .../pytorch/csrc/extensions/cast.cpp | 482 +++++++++++++++--- .../pytorch/csrc/extensions/pybind.cpp | 13 +- .../pytorch/csrc/extensions/transpose.cpp | 77 +-- transformer_engine/pytorch/csrc/quantizer.cpp | 132 +++-- .../pytorch/module/grouped_linear.py | 105 ++-- 
.../_internal/float8_blockwise_tensor_base.py | 4 +- .../pytorch/tensor/float8_blockwise_tensor.py | 32 ++ 10 files changed, 864 insertions(+), 237 deletions(-) create mode 100644 benchmarks/linear/benchmark_grouped_linear.py diff --git a/benchmarks/linear/benchmark_grouped_linear.py b/benchmarks/linear/benchmark_grouped_linear.py new file mode 100644 index 0000000000..f4af193669 --- /dev/null +++ b/benchmarks/linear/benchmark_grouped_linear.py @@ -0,0 +1,241 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import argparse +import torch +import torch.utils.benchmark as benchmark +import pandas as pd +import pathlib + +from transformer_engine.pytorch.module import GroupedLinear +from transformer_engine.common.recipe import Float8BlockScaling +from transformer_engine.pytorch.fp8 import fp8_autocast +from contextlib import nullcontext + +RECIPES = { + "bf16": None, + "fp8_sub_channel": Float8BlockScaling(), +} + + +def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None): + assert mode in ["fwd_only", "fwd_bwd"] + fp8_context = ( + fp8_autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext() + ) + # print(f"fp8_context: {fp8_context} and is it nullcontext? 
{isinstance(fp8_context, nullcontext)}") + + if mode == "fwd_only": + with torch.no_grad(), fp8_context: + for i in range(run_num_steps): + y_q = layer.forward( + x, + m_splits, + is_first_microbatch=(i == 0), + ) + return y_q + else: + # reset gradients + layer.zero_grad() + x.grad = None + + with fp8_context: + for i in range(run_num_steps): + label = f"step_{i}" + torch.cuda.nvtx.range_push(label) + y_q = layer.forward( + x, + m_splits, + is_first_microbatch=(i == 0), + ) + y_q.backward(gradient) + torch.cuda.nvtx.range_pop() + + grads_q = [] + grads_q.append(x.grad) + # remaining derivatives are in respect to model parameters + for p in layer.parameters(): + if p.requires_grad: + grads_q.append(p.grad) + + return y_q, grads_q + + +def benchmark_linear( + x, + ws, + m_splits, + bias, + recipe_name, + mode, + num_gemms=4, +): + params_dtype = torch.bfloat16 + recipe = RECIPES[recipe_name] + + in_features = x.shape[1] + out_features = ws[0].shape[0] + gradient = torch.ones((x.shape[0], out_features), dtype=torch.bfloat16, device=x.device) + + layer = GroupedLinear( + num_gemms, + in_features, + out_features, + bias=bias is not None, + params_dtype=params_dtype, + ) + + layer = layer.to("cuda") + with torch.no_grad(): + for i in range(num_gemms): + weight_i = getattr(layer, f"weight{i}") + weight_i.copy_(ws[i]) + if bias is not None: + bias_i = getattr(layer, f"bias{i}") + bias_i.copy_(bias) + + num_microbatches = 32 + + label = f"{recipe_name}_{'grouped'}" + torch.cuda.nvtx.range_push(label) + timing = benchmark.Timer( + stmt=( + "run_linear_multiple_steps(layer, x, m_splits, mode, gradient, num_microbatches," + " recipe)" + ), + globals={ + "run_linear_multiple_steps": run_linear_multiple_steps, + "layer": layer, + "x": x, + "m_splits": m_splits, + "mode": mode, + "gradient": gradient, + "num_microbatches": num_microbatches, + "recipe": recipe, + }, + num_threads=1, + ).blocked_autorange(min_run_time=5) + print(f"{recipe_name}: {timing} \n") + timing_ms = 
timing.median * 1000 / num_microbatches + + return timing_ms + + +def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4): + data = [] + assert not use_bias, "Bias is not supported for GroupedLinear benchmark" + + print(f"========== Benchmarking {recipe_name} ==========") + for m, k, n in mkns: + device = "cuda" + x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True) + ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)] + assert m % num_gemms == 0 + m_splits = [m // num_gemms] * num_gemms + # Bias is not supported for GroupedLinear benchmark + bias = None + + # Run the benchmark + print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}") + + grouped_fwd_bwd_timing_ms = benchmark_linear( + x, + ws, + m_splits, + bias, + recipe_name, + mode="fwd_bwd", + num_gemms=num_gemms, + ) + + # Append the results + data.append( + [ + m, + k, + n, + recipe_name, + num_gemms, + grouped_fwd_bwd_timing_ms, + ] + ) + + df = pd.DataFrame( + data=data, + columns=[ + "m", + "k", + "n", + "recipe", + "num_gemms", + "grouped_fwd_bwd_time_ms", + ], + ) + + print(df, "\n") + return df + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("--profile", action="store_true", help="Enable profiling mode") + parser.add_argument( + "--output_dir", + type=str, + default="benchmark_output/", + help="output path for report", + ) + args = parser.parse_args() + + use_bias = False + # Set the MKN values to benchmark + mkns = [] + for m in [1024]: + # for m in [4096, 8192, 16384]: + # for n in [1024, 2048, 4096, 8192, 16384]: + for n in [3072]: + for k in [4096]: + mkns.append((m, k, n)) + + # recipe_list = [ + # "bf16", "fp8_sub_channel", + # ] + recipe_list = [ + "fp8_sub_channel", + ] + + # num_gemms_list = [16, 32] + num_gemms_list = [4] + + if args.profile: + # nsys profile --output=./benchmarks/linear/mkn_4096_4096_4096_numgemm_1_bf16 --trace=cuda,nvtx,cudnn,cublas python 
benchmarks/linear/benchmark_grouped_linear.py --profile + # nsys profile --output=./benchmarks/linear/mkn_8192_8192_8192_numgemm_32_bf16 --trace=cuda,nvtx,cudnn,cublas python benchmarks/linear/benchmark_grouped_linear.py --profile + # nsys profile --output=./benchmarks/linear/mkn_4096_4096_4096_numgemm_8_fp8_sub_channel --trace=cuda,nvtx,cudnn,cublas python benchmarks/linear/benchmark_grouped_linear.py --profile + # nsys profile --output=./benchmarks/linear/mkn_8192_8192_8192_numgemm_2_fp8_sub_channel --trace=cuda,nvtx,cudnn,cublas python benchmarks/linear/benchmark_grouped_linear.py --profile + mkns = [(4096, 4096, 4096)] + recipe_list = ["fp8_sub_channel"] + # recipe_list = ["bf16"] + num_gemms_list = [8] + torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__() + + # Initialize a dataframe to store the results + df_linears = pd.DataFrame() + + # Run the fp8 benchmarks + for num_gemms in num_gemms_list: + print(f"========== Benchmarking with num_gemms={num_gemms} ==========") + for recipe_name in recipe_list: + df = run_benchmark_linear( + mkns, + recipe_name, + use_bias, + num_gemms=num_gemms, + ) + df_linears = pd.concat([df_linears, df]) + + print(df_linears) + + if args.profile: + torch.autograd.profiler.emit_nvtx().__exit__(None, None, None) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 1dcb4e4e45..d8c08651f2 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -197,6 +197,8 @@ class Float8BlockQuantizer : public Quantizer { std::pair create_tensor( const std::vector& shape, DType dtype, std::optional rowwise_data = std::nullopt) const override; + + std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; class MXFP8Quantizer : public Quantizer { diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 72f6f27596..4af7576c5f 100644 --- 
a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -108,10 +108,6 @@ std::optional> te_general_grouped_gemm( * Transpose **************************************************************************************************/ -std::vector fused_multi_quantize(std::vector input_list, - std::optional> output_list, - std::vector quantizer_list, DType otype); - at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional output = std::nullopt); @@ -182,10 +178,17 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w **************************************************************************************************/ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, - std::optional noop); + std::optional noop_flag); py::object dequantize(const py::handle &input, DType otype); +std::vector multi_tensor_quantize(const std::vector &tensor_list, + std::vector quantizer_list); + +std::vector split_quantize(const at::Tensor &tensor, + const std::vector &split_sections, + std::vector quantizer_list); + /*************************************************************************************************** * Bias gradient fusions **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 93fae74b63..4be2a8880e 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -6,60 +6,51 @@ #include "transformer_engine/cast.h" +#include +#include +#include +#include +#include +#include + #include "../extensions.h" #include "common.h" #include "pybind.h" #include "transformer_engine/transformer_engine.h" -namespace transformer_engine::pytorch { - -py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::object& output, - 
std::optional noop) { - init_extension(); - auto my_quantizer = convert_quantizer(quantizer); - auto input_tensor = tensor.contiguous(); +namespace transformer_engine { +namespace pytorch { - const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); - const auto& te_input_shape = te_input.shape(); - std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); - auto fake_tensor_type = tensor.scalar_type(); - if (!detail::IsFloatingPointType(fake_tensor_type)) { - fake_tensor_type = at::kFloat; - } +namespace { - TensorWrapper te_output; - py::object out; - if (output.is_none()) { - DType fake_te_type = GetTransformerEngineDType(fake_tensor_type); - std::tie(te_output, out) = my_quantizer->create_tensor(input_shape, fake_te_type); - } else { - out = output; - te_output = makeTransformerEngineTensor(output, quantizer); - } +std::vector get_tensor_shape(const TensorWrapper &tensor) { + const auto &shape = tensor.shape(); + return std::vector(shape.data, shape.data + shape.ndim); +} - TensorWrapper te_noop; - if (noop.has_value()) { - te_noop = makeTransformerEngineTensor(*noop); - } else { - te_noop = TensorWrapper(); +void quantize_impl(const TensorWrapper &input, py::handle &quantizer_py, + std::unique_ptr &quantizer_cpp, TensorWrapper &output, + TensorWrapper &noop_flag) { + // Check tensor dims + NVTE_CHECK(get_tensor_shape(input) == get_tensor_shape(output), + "Input tensor (shape=", get_tensor_shape(input), + ") and output tensor (shape=", get_tensor_shape(output), ") do not match"); + if (input.numel() == 0) { + return; } - if (te_output.numel() == 0) return out; - + // Recipe-specific configuration QuantizationConfigWrapper quant_config; - quant_config.set_noop_tensor(te_noop.data()); - - if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // my_quantizer here has to be a Float8CurrentScalingQuantizer - auto my_quantizer_cs = static_cast(my_quantizer.get()); - NVTE_SCOPED_GIL_RELEASE({ - 
nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); - }); + quant_config.set_noop_tensor(noop_flag.data()); + if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) { + auto my_quantizer_cs = static_cast(quantizer_cpp.get()); + NVTE_SCOPED_GIL_RELEASE( + { nvte_compute_amax(input.data(), output.data(), at::cuda::getCurrentCUDAStream()); }); // check if we need to do amax reudction (depending on model parallel configs) if (my_quantizer_cs->with_amax_reduction) { c10::intrusive_ptr process_group_ptr = my_quantizer_cs->amax_reduction_group; // construct torch tesnor from NVTEBasicTensor without reallocating memory - at::Tensor& amax_tensor_torch = my_quantizer_cs->amax; + at::Tensor &amax_tensor_torch = my_quantizer_cs->amax; std::vector tensors = {amax_tensor_torch}; // allreduce amax tensor c10d::AllreduceOptions allreduce_opts; @@ -72,37 +63,70 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_scale_from_amax(te_output.data(), quant_config, - at::cuda::getCurrentCUDAStream()); + nvte_compute_scale_from_amax(output.data(), quant_config, at::cuda::getCurrentCUDAStream()); }); - // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel - te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); - } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { - auto my_quantizer_bw = static_cast(my_quantizer.get()); + // set amax ptr to null in output TensorWrapper to avoid atomic amax updates in kernel + output.set_amax(nullptr, DType::kFloat32, output.defaultShape); + } else if (detail::IsFloat8BlockwiseQuantizers(quantizer_py.ptr())) { + auto my_quantizer_bw = static_cast(quantizer_cpp.get()); quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); 
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); if (my_quantizer_bw->all_gather_usage) { quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT); } } + + // Perform quantization NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_v2(te_input.data(), te_output.data(), quant_config, - at::cuda::getCurrentCUDAStream()); + nvte_quantize_v2(input.data(), output.data(), quant_config, at::cuda::getCurrentCUDAStream()); }); +} - return out; +} // namespace + +py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, + std::optional noop_flag) { + // Convert quantizer to C++ object + auto quantizer_cpp = convert_quantizer(quantizer); + + // Convert input tensor to C++ object + auto input_contiguous = tensor.contiguous(); + const auto input_cpp = makeTransformerEngineTensor(input_contiguous); + + // Initialize output tensor + TensorWrapper output_cpp; + py::object output_py; + if (output.is_none()) { + const auto shape = get_tensor_shape(input_cpp); + const auto fake_dtype = input_cpp.dtype(); + std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype); + } else { + output_py = output; + output_cpp = makeTransformerEngineTensor(output_py, quantizer); + } + + // Initialize no-op flag + TensorWrapper noop_flag_cpp; + if (noop_flag.has_value()) { + noop_flag_cpp = makeTransformerEngineTensor(*noop_flag); + } + + // Perform quantization + quantize_impl(input_cpp, quantizer, quantizer_cpp, output_cpp, noop_flag_cpp); + + return output_py; } -py::object dequantize(const py::handle& input, transformer_engine::DType otype) { +py::object dequantize(const py::handle &input, transformer_engine::DType otype) { init_extension(); const auto none = py::none(); - const auto& input_tensor = makeTransformerEngineTensor(input, none); + const auto &input_tensor = makeTransformerEngineTensor(input, none); NoneQuantizer q(none); - const auto& shape = convertShape(input_tensor.shape()); + const auto 
&shape = convertShape(input_tensor.shape()); auto [out_tensor, out] = q.create_tensor(shape, otype); @@ -113,9 +137,348 @@ py::object dequantize(const py::handle& input, transformer_engine::DType otype) return out; } +namespace { + +void multi_tensor_quantize_impl(const std::vector &input_list, + std::vector &quantizer_py_list, + std::vector> &quantizer_cpp_list, + std::vector &output_list) { + // Check number of tensors + const size_t num_tensors = input_list.size(); + NVTE_CHECK(quantizer_py_list.size() == num_tensors, "Expected ", num_tensors, + " Python quantizers, but got ", quantizer_py_list.size()); + NVTE_CHECK(quantizer_cpp_list.size() == num_tensors, "Expected ", num_tensors, + " C++ quantizers, but got ", quantizer_cpp_list.size()); + NVTE_CHECK(output_list.size() == num_tensors, "Expected ", num_tensors, + " output tensors, but got ", output_list.size()); + + // Choose implementation + // Note: Currently only have fused kernel for FP8 delayed scaling + bool with_fused_kernel = true; + for (size_t i = 0; i < num_tensors; i++) { + if (!detail::IsFloat8Quantizers(quantizer_py_list[i].ptr())) { + with_fused_kernel = false; + break; + } + if (nvte_tensor_columnwise_data(output_list[i].data()) == nullptr) { + with_fused_kernel = false; + break; + } + } + + // Launch TE kernel + if (with_fused_kernel) { + // Fused kernel for multi-tensor quantize + std::vector nvte_tensor_input_list; + std::vector nvte_tensor_output_list; + for (size_t i = 0; i < num_tensors; ++i) { + nvte_tensor_input_list.push_back(input_list[i].data()); + nvte_tensor_output_list.push_back(output_list[i].data()); + } + NVTE_SCOPED_GIL_RELEASE({ + nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(), + nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream()); + }); + } else { + // Quantize kernels individually + TensorWrapper dummy_noop_flag; + for (size_t i = 0; i < num_tensors; ++i) { + quantize_impl(input_list[i], quantizer_py_list[i], 
quantizer_cpp_list[i], output_list[i], + dummy_noop_flag); + } + } +} + +} // namespace + +std::vector multi_tensor_quantize(const std::vector &tensor_list, + std::vector quantizer_list) { + // Check number of tensors + const size_t num_tensors = tensor_list.size(); + NVTE_CHECK(quantizer_list.size() == num_tensors, "Expected ", num_tensors, + " quantizers, but got ", quantizer_list.size()); + + // Convert quantizers to C++ objects + std::vector> quantizer_cpp_list; + for (size_t i = 0; i < num_tensors; i++) { + quantizer_cpp_list.push_back(convert_quantizer(quantizer_list[i])); + } + + // Initialize input and output tensors + std::vector input_cpp_list; + std::vector output_cpp_list; + std::vector output_py_list; + for (size_t i = 0; i < num_tensors; ++i) { + // Convert input tensor to C++ object + const auto &input_py = tensor_list[i]; + NVTE_CHECK(input_py.is_contiguous(), "Input tensor ", i, " is not contiguous"); + input_cpp_list.emplace_back(makeTransformerEngineTensor(input_py)); + const auto &input_cpp = input_cpp_list.back(); + const auto input_shape = input_cpp.shape(); + const auto input_dtype = GetTransformerEngineDType(input_py.scalar_type()); + + // Construct output tensor + std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); + auto [output_cpp, output_py] = quantizer_cpp_list[i]->create_tensor(output_shape, input_dtype); + output_cpp_list.emplace_back(std::move(output_cpp)); + output_py_list.emplace_back(std::move(output_py)); + } + + // Perform multi-tensor quantization + multi_tensor_quantize_impl(input_cpp_list, quantizer_list, quantizer_cpp_list, output_cpp_list); + + return output_py_list; +} + +namespace { + +std::tuple, std::vector> bulk_allocate_fp8_blockwise_tensors( + std::vector> &shape_list, std::vector &quantizer_py_list, + std::vector &quantizer_cpp_list) { + init_extension(); + std::tuple, std::vector> retval; + auto &tensor_py_list = std::get<0>(retval); + auto &tensor_cpp_list = std::get<1>(retval); + + 
// Number of tensors + const size_t num_tensors = shape_list.size(); + if (num_tensors == 0) { + return retval; + } + + // Quantization parameters + const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; + const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; + const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); + const auto is_2D_scaled = scaling_mode == NVTE_BLOCK_SCALING_2D; + const auto fp8_dtype = quantizer_cpp_list[0]->dtype; + constexpr size_t fp8_elem_size = 1; + constexpr size_t scale_elem_size = 4; + + // Helper function to construct tensor view + // Note: Deleter holds a shared_ptr for the buffer, so the buffer + // will survive until all views are deleted. + auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + size_t offset, at::ScalarType dtype) -> at::Tensor { + std::vector shape_int64(shape.begin(), shape.end()); + // in the case where full buffer is empty because local rank receives no tokens for all the experts + // then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob + // but in the case where some experts receive tokens, some not, we want to leverage from_blob + // as much as possible to avoid CPU overhead + if (buffer->data_ptr() == nullptr) { + return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); + } + return at::from_blob( + buffer->data_ptr() + offset, shape_int64, + [buffer](void *) {}, // deleter holds shared_ptr + at::device(at::kCUDA).dtype(dtype)); + }; + + // Allocate row-wise data + std::vector rowwise_data_list, rowwise_scale_list; + std::vector> rowwise_data_shapes, rowwise_scale_shapes; + if (rowwise_usage) { + // Tensor sizes + for (size_t i = 0; i < num_tensors; ++i) { + rowwise_data_shapes.emplace_back(shape_list[i]); + rowwise_scale_shapes.emplace_back( + quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); + } + + // Offsets in full buffer + size_t buffer_size = 0; + std::vector data_offsets, 
scale_offsets; + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 256); // align to 256B + data_offsets.push_back(buffer_size); + buffer_size += product(rowwise_data_shapes[i]) * fp8_elem_size; + } + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 16); // align to 16B + scale_offsets.push_back(buffer_size); + buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size; + } + + // Allocate full buffer + auto buffer = std::make_shared( + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + + // Construct tensor views + for (size_t i = 0; i < num_tensors; ++i) { + rowwise_data_list.emplace_back( + make_torch_view(buffer, rowwise_data_shapes[i], data_offsets[i], torch::kUInt8)); + rowwise_scale_list.emplace_back( + make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kFloat32)); + } + } + + // Allocate column-wise data + std::vector columnwise_data_list, columnwise_scale_list; + std::vector> columnwise_data_shapes, columnwise_scale_shapes; + if (columnwise_usage) { + // Tensor sizes + for (size_t i = 0; i < num_tensors; ++i) { + columnwise_data_shapes.emplace_back(); + auto &shape = columnwise_data_shapes.back(); + shape.push_back(shape_list[i].back()); + for (size_t j = 0; j < shape_list[i].size() - 1; ++j) { + shape.push_back(shape_list[i][j]); + } + columnwise_scale_shapes.emplace_back( + quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true)); + } + + // Offsets in full buffer + size_t buffer_size = 0; + std::vector data_offsets, scale_offsets; + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 256); // align to 256B + data_offsets.push_back(buffer_size); + buffer_size += product(columnwise_data_shapes[i]) * fp8_elem_size; + } + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 16); // align to 16B + scale_offsets.push_back(buffer_size); + buffer_size += 
product(columnwise_scale_shapes[i]) * scale_elem_size; + } + + // Allocate full buffer + auto buffer = std::make_shared( + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + + // Construct tensor views + for (size_t i = 0; i < num_tensors; ++i) { + columnwise_data_list.emplace_back( + make_torch_view(buffer, columnwise_data_shapes[i], data_offsets[i], torch::kUInt8)); + columnwise_scale_list.emplace_back( + make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kFloat32)); + } + } + + // Construct FP8 block-wise tensors + py::handle Float8BlockwiseQTensorClass( + reinterpret_cast(Float8BlockwiseQTensorBasePythonClass)); + for (size_t i = 0; i < num_tensors; ++i) { + // Create tensor objects with proper reference counting + py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none(); + py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none(); + py::object columnwise_data = + (columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none()); + py::object columnwise_scale = + (columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none()); + + // Construct Python tensor + tensor_py_list.emplace_back(Float8BlockwiseQTensorClass( + rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype, + quantizer_py_list[i], is_2D_scaled, Float8BlockScaleTensorFormat::GEMM_READY)); + + // Construct C++ tensor + tensor_cpp_list.emplace_back(makeTransformerEngineTensor( + rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, + columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, + rowwise_usage ? rowwise_data_shapes[i] : std::vector{}, + columnwise_usage ? columnwise_data_shapes[i] : std::vector{}, fp8_dtype, nullptr, + nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, + columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, + rowwise_usage ? rowwise_scale_shapes[i] : std::vector{}, + columnwise_usage ? 
columnwise_scale_shapes[i] : std::vector{}, scaling_mode)); + } + + return retval; +} + +} // namespace + +std::vector split_quantize(const at::Tensor &tensor, + const std::vector &split_sections, + std::vector quantizer_list) { + init_extension(); + + // Check number of tensors + const size_t num_splits = split_sections.size(); + NVTE_CHECK(quantizer_list.size() == num_splits, "Expected ", num_splits, " quantizers, but got ", + quantizer_list.size()); + if (num_splits == 0) { + return {}; + } + + // Input tensor properties + auto input_py = tensor.contiguous(); + uint8_t *input_dptr = reinterpret_cast(input_py.data_ptr()); + auto input_dtype = GetTransformerEngineDType(input_py.scalar_type()); + std::vector input_shape; + size_t input_size = 1; + for (const auto &d : input_py.sizes()) { + input_shape.push_back(d); + input_size *= d; + } + NVTE_CHECK(input_shape.size() > 0, "Input tensor has 0 dims"); + + // Split input tensor along dim 0 + std::vector input_list; + std::vector> split_shapes; + size_t dim0_offset = 0; + const size_t dim0_stride = + input_shape[0] == 0 ? 
0 : input_py.element_size() * input_size / input_shape[0]; + for (size_t i = 0; i < num_splits; ++i) { + NVTE_CHECK(split_sections[i] >= 0, "Attempted to split tensor with shape=", input_shape, + " along dim 0 with split_sections=", split_sections); + NVTE_CHECK(dim0_offset + split_sections[i] <= input_shape[0], + "Attempted to split tensor with shape=", input_shape, + " along dim 0 with split_sections=", split_sections); + split_shapes.push_back(input_shape); + auto &split_shape = split_shapes.back(); + split_shape[0] = split_sections[i]; + void *split_dptr = static_cast(input_dptr + dim0_offset * dim0_stride); + input_list.emplace_back(makeTransformerEngineTensor(split_dptr, split_shape, input_dtype)); + dim0_offset += split_sections[i]; + } + + // Convert quantizers to C++ objects + std::vector> quantizer_cpp_list; + for (size_t i = 0; i < num_splits; i++) { + quantizer_cpp_list.push_back(convert_quantizer(quantizer_list[i])); + } + + // For FP8 block-scaling, we construct output tensors with bulk allocations + bool use_fused_bulk_alloc = true; + for (size_t i = 0; i < quantizer_list.size(); i++) { + if (!detail::IsFloat8BlockwiseQuantizers(quantizer_list[i].ptr())) { + use_fused_bulk_alloc = false; + break; + } + } + + // Allocate output tensors + std::vector output_cpp_list; + std::vector output_py_list; + if (!use_fused_bulk_alloc) { + // Allocate output tensors individually + for (size_t i = 0; i < num_splits; ++i) { + auto [output_cpp, output_py] = + quantizer_cpp_list[i]->create_tensor(split_shapes[i], input_dtype); + output_cpp_list.emplace_back(std::move(output_cpp)); + output_py_list.emplace_back(std::move(output_py)); + } + } else { + // FP8 block-scaling: construct output tensors with bulk allocations + std::vector blockwise_quantizers; + for (auto &quantizer : quantizer_cpp_list) { + blockwise_quantizers.push_back(static_cast(quantizer.get())); + } + std::tie(output_py_list, output_cpp_list) = + bulk_allocate_fp8_blockwise_tensors(split_shapes, 
quantizer_list, blockwise_quantizers); + } + + // Perform multi-tensor quantization + multi_tensor_quantize_impl(input_list, quantizer_list, quantizer_cpp_list, output_cpp_list); + + return output_py_list; +} + template -std::vector dbias_dact(const at::Tensor& grad_output, const at::Tensor& act_input, +std::vector dbias_dact(const at::Tensor &grad_output, const at::Tensor &act_input, py::handle quantizer) { init_extension(); auto my_quantizer = convert_quantizer(quantizer); @@ -125,7 +488,7 @@ std::vector dbias_dact(const at::Tensor& grad_output, const at::Tens auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_tensor.dtype()); auto act_input_tensor = makeTransformerEngineTensor(act_input); - const auto& shape = convertShape(grad_tensor.shape()); + const auto &shape = convertShape(grad_tensor.shape()); auto [dact_tensor, dact] = my_quantizer->create_tensor(shape, act_input_tensor.dtype()); auto dbias_tensor = makeTransformerEngineTensor(grad_bias); @@ -149,29 +512,30 @@ std::vector dbias_dact(const at::Tensor& grad_output, const at::Tens return {py::cast(grad_bias), dact}; } -std::vector dbias_dgelu(const at::Tensor& grad_output, const at::Tensor& act_input, +std::vector dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input, py::handle quantizer) { return dbias_dact(grad_output, act_input, quantizer); } -std::vector dbias_dsilu(const at::Tensor& grad_output, const at::Tensor& act_input, +std::vector dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input, py::handle quantizer) { return dbias_dact(grad_output, act_input, quantizer); } -std::vector dbias_drelu(const at::Tensor& grad_output, const at::Tensor& act_input, +std::vector dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input, py::handle quantizer) { return dbias_dact(grad_output, act_input, quantizer); } -std::vector dbias_dqgelu(const at::Tensor& grad_output, const at::Tensor& act_input, +std::vector dbias_dqgelu(const at::Tensor &grad_output, 
const at::Tensor &act_input, py::handle quantizer) { return dbias_dact(grad_output, act_input, quantizer); } -std::vector dbias_dsrelu(const at::Tensor& grad_output, const at::Tensor& act_input, +std::vector dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input, py::handle quantizer) { return dbias_dact(grad_output, act_input, quantizer); } -} // namespace transformer_engine::pytorch +} // namespace pytorch +} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 63c3b434d3..8f06883807 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -12,7 +12,9 @@ #include #include -#include +#include +#include +#include #include "../common.h" #include "../extensions.h" @@ -199,10 +201,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("weight"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma")); m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm"); - m.def("fused_multi_quantize", &transformer_engine::pytorch::fused_multi_quantize, - "Fused Multi-tensor Cast + Transpose", py::arg("input_list"), py::arg("output_list"), - py::arg("quantizer_list"), py::arg("otype")); - + m.def("multi_tensor_quantize", &transformer_engine::pytorch::multi_tensor_quantize, + "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list")); + m.def("split_quantize", &transformer_engine::pytorch::split_quantize, + "Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"), + py::arg("quantizer_list")); m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm, "Grouped GEMM"); m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", diff --git 
a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 637dc7a94c..d2f7107fe5 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -4,80 +4,16 @@ * See LICENSE for license information. ************************************************************************/ +#include + #include +#include #include "../extensions.h" #include "pybind.h" -namespace transformer_engine::pytorch { - -std::vector fused_multi_quantize(std::vector input_list, - std::optional> output_list, - std::vector quantizer_list, DType otype) { - init_extension(); - std::vector nvte_tensor_input_list; - std::vector nvte_tensor_output_list; - std::vector py_output_objects_list; - std::vector tensor_wrappers; - if (output_list.has_value()) { - py_output_objects_list = output_list.value(); - } - - // Choose implementation - // Note: Currently only have fused kernel for FP8 cast-transpose - bool with_fused_kernel = true; - - // create TE tensors from input - for (size_t i = 0; i < input_list.size(); i++) { - auto input_tensor = makeTransformerEngineTensor(input_list[i]); - const NVTEShape input_shape = input_tensor.shape(); - - TensorWrapper output_tensor; - - if (!detail::IsFloat8Quantizers(quantizer_list[i].ptr())) { - with_fused_kernel = false; - } - if (output_list == std::nullopt) { - std::unique_ptr quantizer = convert_quantizer(quantizer_list[i]); - std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); - py::object o; - std::tie(output_tensor, o) = quantizer->create_tensor(output_shape, otype); - py_output_objects_list.push_back(o); - } else { - output_tensor = makeTransformerEngineTensor((*output_list)[i], quantizer_list[i]); - } - if (input_tensor.numel() == 0) continue; - - nvte_tensor_output_list.emplace_back(output_tensor.data()); - nvte_tensor_input_list.emplace_back(input_tensor.data()); - 
tensor_wrappers.emplace_back(std::move(input_tensor)); - tensor_wrappers.emplace_back(std::move(output_tensor)); - } - - // Check tensor lists - NVTE_CHECK(nvte_tensor_output_list.size() == nvte_tensor_input_list.size(), - "Number of input and output tensors must match"); - - for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) { - if (nvte_tensor_columnwise_data(nvte_tensor_output_list[i]) == nullptr) { - with_fused_kernel = false; - break; - } - } - - // Launch TE kernel - if (with_fused_kernel) { - NVTE_SCOPED_GIL_RELEASE({ - nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(), - nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream()); - }); - } else { - for (size_t i = 0; i < py_output_objects_list.size(); i++) { - quantize(input_list[i], quantizer_list[i], py_output_objects_list[i], std::nullopt); - } - } - return py_output_objects_list; -} +namespace transformer_engine { +namespace pytorch { at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional output) { init_extension(); @@ -108,4 +44,5 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional Float8BlockQuantizer::create_tensor( const std::vector& shape, DType dtype, std::optional rowwise_data) const { using namespace pybind11::literals; std::vector torch_shape; - size_t numel = 1; for (auto s : shape) { torch_shape.emplace_back(static_cast(s)); - numel *= s; } TensorWrapper tensor(this->get_scaling_mode()); @@ -296,10 +294,6 @@ std::pair Float8BlockQuantizer::create_tensor( opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); - size_t k_dim = torch_shape.size() == 0 ? 1u : torch_shape.back(); - size_t m_dim = numel / k_dim; - constexpr size_t kBlockLen = 128; - Float8BlockScaleTensorFormat data_format = (all_gather_usage ? 
Float8BlockScaleTensorFormat::COMPACT : Float8BlockScaleTensorFormat::GEMM_READY); @@ -310,30 +304,9 @@ std::pair Float8BlockQuantizer::create_tensor( } else { data_rowwise = at::empty(torch_shape, opts); } - size_t sinv0 = 0; - size_t sinv1 = 0; - if (block_scaling_dim == 2) { - // 2D scaling is always GEMM_READY for now - NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY, - "2D scaling is always GEMM_READY for now."); - sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; - sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4); - } else if (block_scaling_dim == 1) { - // 1D scaling can be GEMM_READY or COMPACT - bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT; - // default rowwise scaling factor shape already transpose the scaling factor so it's GEMM_READY - sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; - sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4); - // if the rowwise format is compact, the scaling factor is not be transposed - if (rowwise_compact) { - std::swap(sinv0, sinv1); - } - } else { - NVTE_ERROR( - "Unsupported block_scaling_dim in create_tensor rowwise. " - "Expected 1 or 2. 
Got ", - block_scaling_dim); - } + auto scale_shape = get_scale_shape(shape, false); + size_t sinv0 = scale_shape[0]; + size_t sinv1 = scale_shape[1]; scale_inv_rowwise = at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape); @@ -364,27 +337,9 @@ std::pair Float8BlockQuantizer::create_tensor( columnwise_shape = shape; } } - size_t sinv0 = 0; - size_t sinv1 = 0; - if (block_scaling_dim == 2) { - // 2D scaling is always GEMM_READY for now - NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY, - "2D scaling is always GEMM_READY for now."); - sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; - sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4); - } else if (block_scaling_dim == 1) { - bool columnwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT; - sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; - sinv1 = columnwise_compact ? k_dim : roundup(k_dim, 4); - // GEMM READY case: scaling factor is [sinv0, sinv1], already transposed here for CuBLAS - // for COMPACT case, since we apply 128x1 scaling here without transposing columnwise data, scaling factor is also [sinv0, sinv1] - // so no need to swap sinv0 and sinv1 here - } else { - NVTE_ERROR( - "Unsupported block_scaling_dim in create_tensor columnwise. " - "Expected 1 or 2. Got ", - block_scaling_dim); - } + auto scale_shape = get_scale_shape(shape, true); + size_t sinv0 = scale_shape[0]; + size_t sinv1 = scale_shape[1]; data_colwise = at::empty(torch_columnwise_shape, opts); scale_inv_colwise = at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); @@ -418,6 +373,81 @@ std::pair Float8BlockQuantizer::create_tensor( return {std::move(tensor), std::move(ret)}; } +std::vector Float8BlockQuantizer::get_scale_shape(const std::vector& shape, + bool columnwise) const { + size_t numel = 1; + for (auto s : shape) { + numel *= s; + } + + size_t k_dim = shape.size() == 0 ? 
1u : shape.back(); + size_t m_dim = numel / k_dim; + constexpr size_t kBlockLen = 128; + + Float8BlockScaleTensorFormat data_format = + (all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT + : Float8BlockScaleTensorFormat::GEMM_READY); + + std::vector scale_shape; + + bool rowwise_usage = !columnwise; + + if (rowwise_usage) { + // rowwise scaling factor shape + size_t sinv0 = 0; + size_t sinv1 = 0; + if (block_scaling_dim == 2) { + // 2D scaling is always GEMM_READY for now + NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY, + "2D scaling is always GEMM_READY for now."); + sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; + sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4); + } else if (block_scaling_dim == 1) { + // 1D scaling can be GEMM_READY or COMPACT + bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT; + // default rowwise scaling factor shape already transpose the scaling factor so it's GEMM_READY + sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; + sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4); + // if the rowwise format is compact, the scaling factor is not be transposed + if (rowwise_compact) { + std::swap(sinv0, sinv1); + } + } else { + NVTE_CHECK(false, + "Unsupported block_scaling_dim in create_tensor rowwise." + "Expected 1 or 2. 
Got ", + block_scaling_dim); + } + scale_shape = {sinv0, sinv1}; + } else { + // columnwise scaling factor shape + size_t sinv0 = 0; + size_t sinv1 = 0; + if (block_scaling_dim == 2) { + // 2D scaling is always GEMM_READY for now + NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY, + "2D scaling is always GEMM_READY for now."); + sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; + sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4); + } else if (block_scaling_dim == 1) { + // 1D scaling can be GEMM_READY or COMPACT + bool columnwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT; + sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; + sinv1 = columnwise_compact ? k_dim : roundup(k_dim, 4); + // GEMM READY case: scaling factor is [sinv0, sinv1], already transposed here for CuBLAS + // for COMPACT case, since we apply 128x1 scaling here without transposing columnwise data, scaling factor is also [sinv0, sinv1] + // so no need to swap sinv0 and sinv1 here + } else { + NVTE_CHECK(false, + "Unsupported block_scaling_dim in create_tensor columnwise." + "Expected 1 or 2. 
Got ", + block_scaling_dim); + } + scale_shape = {sinv0, sinv1}; + } + return scale_shape; +} + MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); } diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 5fe351578e..4b5148b771 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -24,7 +24,6 @@ from ..utils import ( divide, cast_if_needed, - assert_dim_for_fp8_exec, clear_tensor_data, init_method_constant, requires_grad, @@ -38,7 +37,7 @@ from ..cpp_extensions import ( general_grouped_gemm, ) -from ..constants import GemmParallelModes, dist_group_type, TE_DType +from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ..cpu_offload import is_cpu_offload_enabled @@ -87,20 +86,9 @@ def forward( weights = weights_and_biases[:num_gemms] biases = weights_and_biases[num_gemms:] device = inp.device - - # Make sure input dimensions are compatible - in_features = weights[0].shape[-1] - assert inp.shape[-1] == in_features, "GEMM not possible" - inputmats = torch.split(inp.view(-1, in_features), m_splits) - if fp8: - assert_dim_for_fp8_exec(*inputmats, *weights) - - # Cast input to expected dtype - inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats] - inputmats = [] - weight_requires_grad = weights[0].requires_grad + # Configure quantizers if input_quantizers[0] is not None: for input_quantizer in input_quantizers: input_quantizer.set_usage( @@ -120,17 +108,25 @@ def forward( for output_quantizer in output_quantizers: output_quantizer.set_usage(rowwise=True, columnwise=False) - fprop_gemm_use_split_accumulator = _2X_ACC_FPROP - if fp8: - recipe = FP8GlobalStateManager.get_fp8_recipe() - if hasattr(recipe, "fp8_gemm_fprop"): - 
fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator - inputmats = tex.fused_multi_quantize( - inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype] + # Initialize input tensors + in_features = weights[0].size(-1) + if inp.size(-1) != in_features: + raise ValueError( + f"Input tensor (shape={tuple(inp.size())}) is not compatible with " + f"weight tensor (shape={tuple(weights[0].size())})" ) - weights_fp8 = [] - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype + inp_view = inp.reshape(-1, in_features) + inputmats: list + if fp8: + inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers) + else: + inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits) + + # Initialize weights + weights_fp8: list + if fp8: # FP8 cast to workspace buffer + weights_fp8 = [] update_workspace = is_first_microbatch is None or is_first_microbatch for i in range(num_gemms): weight_fp8 = module.get_weight_workspace( @@ -143,18 +139,29 @@ def forward( weights_fp8.append(weight_fp8) else: - inputmats = inputmats_no_fp8 - bias_dtype = activation_dtype weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights] + # Initialize biases + bias_dtype = activation_dtype + if fp8 and activation_dtype == torch.float32: + bias_dtype = torch.bfloat16 # FP8 GEMM only supports BF16/FP16 bias biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases + # Initialize output tensor out = torch.empty( [sum(m_splits), weights_fp8[0].size(0)], dtype=activation_dtype, device=device, ) + # Choose whether to use split accumulator + use_split_accumulator = _2X_ACC_FPROP + if fp8: + recipe = FP8GlobalStateManager.get_fp8_recipe() + if hasattr(recipe, "fp8_gemm_fprop"): + use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator + + # Perform GEMM _ = general_grouped_gemm( weights_fp8, inputmats, @@ -165,7 +172,7 @@ def forward( m_splits=m_splits, 
bias=biases, use_bias=use_bias, - use_split_accumulator=fprop_gemm_use_split_accumulator, + use_split_accumulator=use_split_accumulator, ) if fp8_calibration: @@ -247,36 +254,44 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], w.main_grad = main_grads[i] weights[i] = w - # preprocess grad_output - - grad_output = grad_output.contiguous() - grad_output_mats = torch.split( - grad_output.view(-1, grad_output.shape[-1]), ctx.m_splits - ) + # Preprocess grad output + grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) grad_output = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms if ctx.fp8: if ctx.use_bias: - # unfuse bgrad for now until cast_transpose + dgrad calculation is ready - # for Float8BlockQuantizer. - if ctx.fp8_recipe.float8_block_scaling(): - for i in range(ctx.num_gemms): - grad_biases[i] = grad_output_mats[i].sum(dim=0) - grad_output[i] = ctx.grad_output_quantizers[i](grad_output_mats[i]) - else: + grad_output_mats = torch.split(grad_output_view, ctx.m_splits) + recipe = ctx.fp8_recipe + if recipe.delayed() or recipe.float8_current_scaling() or recipe.mxfp8(): + # Fused bias grad + quantize kernel for i in range(ctx.num_gemms): grad_biases[i], grad_output[i] = tex.bgrad_quantize( - grad_output_mats[i], ctx.grad_output_quantizers[i] + grad_output_mats[i], + ctx.grad_output_quantizers[i], ) + else: + # Unfused bias grad and multi-tensor quantize + for i in range(ctx.num_gemms): + grad_biases[i] = grad_output_mats[i].sum(dim=0) + grad_output = tex.split_quantize( + grad_output_view, + ctx.m_splits, + ctx.grad_output_quantizers, + ) else: - grad_output = tex.fused_multi_quantize( - grad_output_mats, - None, + # Multi-tensor quantize + grad_output = tex.split_quantize( + grad_output_view, + ctx.m_splits, ctx.grad_output_quantizers, - TE_DType[ctx.activation_dtype], ) else: - grad_output = grad_output_mats + # Only split grad output. Grad bias is fused with + # wgrad GEMM. 
+ grad_output = torch.split( + cast_if_needed(grad_output_view, ctx.activation_dtype), + ctx.m_splits, + ) if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index a5e15925e2..3635494ccc 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -42,7 +42,6 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): def __new__( cls, - *args, rowwise_data: Optional[torch.Tensor], rowwise_scale_inv: Optional[torch.Tensor], columnwise_data: Optional[torch.Tensor], @@ -50,7 +49,8 @@ def __new__( fp8_dtype: TE_DType, quantizer: Quantizer, is_2D_scaled: bool, - data_format: Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY, + data_format: Float8BlockScaleTensorFormat, + *args, **kwargs, ): instance = super().__new__(cls, *args, **kwargs) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 738bc3906f..bac7159491 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -10,6 +10,7 @@ import torch import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType +from transformer_engine_torch import Float8BlockScaleTensorFormat from transformer_engine.common.recipe import Float8BlockScaling, Recipe from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase @@ -294,6 +295,37 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): holds configuration about quantization and dequantization modes. 
""" + # NOTE: We reorder the *args so that we can instantiate a Float8BlockwiseQTensorBase with positional args, + # which significantly reduces the Pybind11 overhead when calling the constructor from C++. + def __new__( + cls, + *args, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: Optional[torch.Tensor], + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: Optional[torch.Tensor], + fp8_dtype: TE_DType, + quantizer: Quantizer, + is_2D_scaled: bool, + data_format: tex.Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY, + **kwargs, + ): + instance = super().__new__( + cls, + rowwise_data, + rowwise_scale_inv, + columnwise_data, + columnwise_scale_inv, + fp8_dtype, + quantizer, + is_2D_scaled, + data_format, + *args, + **kwargs, + ) + + return instance + def __repr__(self, *, tensor_contents=None): return ( f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype}," From 7b9d9a53952a5252f0fc38e9756e9bee88d64fcd Mon Sep 17 00:00:00 2001 From: xiaoxi-wangfj <690912414@qq.com> Date: Thu, 26 Jun 2025 15:37:21 +0800 Subject: [PATCH 30/39] [PyTorch|common] Optimize unpadding kernel for FP8 (#1866) * [PyTorch|common] Implement unpadding kernel for FP8 1. Add multi-tensor unpadding kernel 2. Replace split+cat with unpadding kernel in Fp8Padding and Fp8Unpadding 3. 
Add unpadding with padding unit tests Signed-off-by: xiaoxi-wangfj <690912414@qq.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add license Signed-off-by: Xin Yao * Update padding.cu Signed-off-by: Xin Yao --------- Signed-off-by: xiaoxi-wangfj <690912414@qq.com> Signed-off-by: Xin Yao Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao --- tests/cpp/operator/CMakeLists.txt | 1 + tests/cpp/operator/test_multi_unpadding.cu | 186 ++++++++++++++++++ .../include/transformer_engine/padding.h | 27 +++ transformer_engine/common/util/padding.cu | 163 +++++++++++++++ transformer_engine/pytorch/csrc/extensions.h | 3 + .../pytorch/csrc/extensions/padding.cpp | 73 +++++++ .../pytorch/csrc/extensions/pybind.cpp | 2 + .../pytorch/module/fp8_padding.py | 17 +- .../pytorch/module/fp8_unpadding.py | 11 +- 9 files changed, 471 insertions(+), 12 deletions(-) create mode 100644 tests/cpp/operator/test_multi_unpadding.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index b680389a35..ff889c2812 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -25,6 +25,7 @@ add_executable(test_operator test_memset.cu test_multi_cast_transpose.cu test_multi_padding.cu + test_multi_unpadding.cu test_causal_softmax.cu test_swizzle.cu ../test_common.cu) diff --git a/tests/cpp/operator/test_multi_unpadding.cu b/tests/cpp/operator/test_multi_unpadding.cu new file mode 100644 index 0000000000..ca685b9628 --- /dev/null +++ b/tests/cpp/operator/test_multi_unpadding.cu @@ -0,0 +1,186 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +template +void compute_unpadding_ref(const std::vector>& input_list, + std::vector>& output_list, + const std::vector& height_list, + const std::vector& width_list, + const std::vector& padded_height_list) { + using compute_t = float; + for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { + const auto& input = input_list[tensor_id]; + auto& output = output_list[tensor_id]; + const size_t height = height_list[tensor_id]; + const size_t width = width_list[tensor_id]; + const size_t padded_height = padded_height_list[tensor_id]; + + // Only copy the valid (unpadded) portion + for (size_t i = 0; i < height; ++i) { + for (size_t j = 0; j < width; ++j) { + const compute_t x = static_cast(input[i * width + j]); + const OutputType y = static_cast(x); + output[i * width + j] = y; + } + } + } +} + +template +void performUnpaddingTest() { + using namespace test; + + const DType itype = TypeInfo::dtype; + const DType otype = TypeInfo::dtype; + const std::vector> tensor_dims = {{1,1}, + {1,768}, + {768,1}, + {768,768}, + {43,43}, + {43,256}, + {256,43}, + {256,256}}; + const size_t num_tensors = tensor_dims.size(); + constexpr int align = 16; + + // Buffers for Transformer Engine implementation + std::vector padded_input_list, unpadded_output_list; + + // Buffers for reference implementation + std::vector> ref_padded_input_list; + std::vector> ref_unpadded_output_list; + std::vector ref_height_list(num_tensors), ref_width_list(num_tensors); + std::vector ref_padded_height_list(num_tensors); + + // Initialize buffers + for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { + const size_t original_height = tensor_dims[tensor_id].first; + const size_t width = 
tensor_dims[tensor_id].second; + const size_t padded_height = (original_height + align - 1) / align * align; + + // Input is padded tensor (padded_height x width) + padded_input_list.emplace_back( + Tensor("padded_input_" + std::to_string(tensor_id), + std::vector{padded_height, width}, itype)); + + // Output is unpadded tensor (original_height x width) + unpadded_output_list.emplace_back( + Tensor("unpadded_output_" + std::to_string(tensor_id), + std::vector{original_height, width}, otype)); + + auto& padded_input = padded_input_list.back(); + auto& unpadded_output = unpadded_output_list.back(); + + // Fill padded input with random data (including padding area) + fillUniform(&padded_input); + setRandomScale(&unpadded_output); + + // Initialize reference buffers + ref_padded_input_list.emplace_back(padded_height * width); + ref_unpadded_output_list.emplace_back(original_height * width); + + // Copy data to reference buffers + std::copy(padded_input.rowwise_cpu_dptr(), + padded_input.rowwise_cpu_dptr() + padded_height * width, + ref_padded_input_list.back().begin()); + + ref_height_list[tensor_id] = original_height; + ref_width_list[tensor_id] = width; + ref_padded_height_list[tensor_id] = padded_height; + } + + // Transformer Engine implementation + auto make_nvte_vector = [](std::vector& tensor_list) + -> std::vector { + std::vector nvte_tensor_list; + for (auto& tensor : tensor_list) { + nvte_tensor_list.emplace_back(tensor.data()); + } + return nvte_tensor_list; + }; + + // Convert height_list to int for the API + std::vector original_height_list_int(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + original_height_list_int[i] = static_cast(ref_height_list[i]); + } + + // Call unpadding API + nvte_multi_unpadding(num_tensors, + make_nvte_vector(padded_input_list).data(), + make_nvte_vector(unpadded_output_list).data(), + original_height_list_int.data(), + 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) 
<< cudaGetErrorString(err); + + // Reference implementation + compute_unpadding_ref(ref_padded_input_list, + ref_unpadded_output_list, + ref_height_list, + ref_width_list, + ref_padded_height_list); + + // Check correctness + for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { + auto [atol, rtol] = getTolerances(otype); + compareResults("unpadded_output", + unpadded_output_list[tensor_id], + ref_unpadded_output_list[tensor_id].data(), + true, + atol, rtol); + } +} + +} // namespace + +class MultiUnpaddingTestSuite + : public ::testing::TestWithParam {}; + +TEST_P(MultiUnpaddingTestSuite, TestMultiUnpadding) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = GetParam(); + const DType output_type = input_type; + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performUnpaddingTest(); + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + MultiUnpaddingTestSuite, + ::testing::ValuesIn(test::all_fp_types), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(info.param); + return name; + }); diff --git a/transformer_engine/common/include/transformer_engine/padding.h b/transformer_engine/common/include/transformer_engine/padding.h index 4258463b1b..0783fc2b21 100644 --- a/transformer_engine/common/include/transformer_engine/padding.h +++ b/transformer_engine/common/include/transformer_engine/padding.h @@ -44,6 +44,33 @@ extern "C" { void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list, const int* padded_num_rows_list, cudaStream_t stream); +/*! \brief Unpadding multiple tensors (reverse operation of padding). + * + * NOTE: Unpadding mode only removes bottom rows. + * + * For example, 4x3 matrix unpad to 3x3 matrix. 
+ * + * source + * | 1 | 2 | 3 | + * | 4 | 5 | 6 | + * | 7 | 8 | 9 | + * | 0 | 0 | 0 | + * + * destination + * | 1 | 2 | 3 | + * | 4 | 5 | 6 | + * | 7 | 8 | 9 | + * + * \param[in] num_tensors Number of tensors. + * \param[in] input_list List of 2D padded input tensors. + * \param[in,out] output_list List of unpadded tensors. Dimensions + * match original unpadded tensors. + * \param[in] unpadded_num_rows_list List of unpadded num rows corresponding to input tensors. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_multi_unpadding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list, + const int* unpadded_num_rows_list, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu index df11ddd3f6..a1899d5b10 100644 --- a/transformer_engine/common/util/padding.cu +++ b/transformer_engine/common/util/padding.cu @@ -126,6 +126,83 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP } } +template +__global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(MultiPaddingArgs args) { + using Vec = Vec; + + // Thread indices + // Note: Block is interpreted as a warp_size x num_warps grid + constexpr int bdimx = THREADS_PER_WARP; + constexpr int bdimy = n_warps_per_tile; + const int tid = threadIdx.x; + const int tidx = tid % bdimx; + const int tidy = tid / bdimx; + const int bid = blockIdx.x; + + // Input tensors are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles + constexpr int tile_dim_m = THREADS_PER_WARP * nvec; + constexpr int tile_dim_n = THREADS_PER_WARP * nvec; + + // Number of nvec x nvec subtiles for each thread to + // load/store + constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile; + + // Find tensor corresponding to block + int tensor_id = 0; + while (args.block_range[tensor_id + 1] <= bid) { + 
++tensor_id; + } + const Type* input = reinterpret_cast(args.input_list[tensor_id]); + Type* output = reinterpret_cast(args.output_list[tensor_id]); + const int num_rows = args.num_rows_list[tensor_id]; + const int row_length = args.row_length_list[tensor_id]; + + // Find position of tile within tensor + const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; + const int tile_id = bid - args.block_range[tensor_id]; + const int tile_id_m = tile_id / num_tiles_n; + const int tile_id_n = tile_id % num_tiles_n; + const int tile_row = tile_id_m * tile_dim_m; + const int tile_col = tile_id_n * tile_dim_n; + + // Load input and store to registers + // Note: Each thread loads n_iterations subtiles, casts to output + // type, and transposes in registers. + Type local_zero = static_cast(0.f); +#pragma unroll + for (int iter = 0; iter < n_iterations; ++iter) { + const int i1 = tidy + iter * bdimy; + const int j1 = tidx; +#pragma unroll + for (int i2 = 0; i2 < nvec; ++i2) { + const int row = tile_row + i1 * nvec + i2; + const int col = tile_col + j1 * nvec; + Vec local_input; + Vec local_output; + local_input.clear(); + if (row < num_rows) { + for (int j2 = 0; j2 < nvec; ++j2) { + if (col + j2 < row_length) { + local_input.data.elt[j2] = input[row * row_length + col + j2]; + } + } + } +#pragma unroll + for (int j2 = 0; j2 < nvec; ++j2) { + local_output.data.elt[j2] = local_input.data.elt[j2]; + } + if (row < num_rows) { + for (int j2 = 0; j2 < nvec; ++j2) { + if (col + j2 < row_length) { + output[row * row_length + col + j2] = local_output.data.elt[j2]; + } + } + } + } + } +} + } // namespace void multi_padding(const std::vector input_list, std::vector output_list, @@ -202,6 +279,78 @@ void multi_padding(const std::vector input_list, std::vector o } } +void multi_unpadding(const std::vector input_list, std::vector output_list, + const std::vector unpadded_num_rows_list, cudaStream_t stream) { + // Check that number of tensors is valid + 
NVTE_CHECK(output_list.size() == input_list.size(), + "Number of input and output tensors must match"); + if (input_list.empty()) { + return; + } + + // Check that tensor properties are valid + DType type = input_list[0]->data.dtype; + for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { + const auto& input = *input_list[tensor_id]; + const auto& output = *output_list[tensor_id]; + CheckInputTensor(input, "multi_unpadding_input_" + std::to_string(tensor_id)); + CheckInputTensor(output, "multi_unpadding_output_" + std::to_string(tensor_id)); + + NVTE_CHECK(input.data.dtype == type, "Input tensor types do not match."); + NVTE_CHECK(output.data.dtype == type, "Output tensor types do not match."); + + NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions."); + NVTE_CHECK(output.data.shape[0] == unpadded_num_rows_list[tensor_id], + "output tensor shape does not match padded input shape."); + } + + // Input matrices are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles + const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size / typeToSize(type); + const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size / typeToSize(type); + + // Add tensors to kernel argument struct + MultiPaddingArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.block_range[0] = 0; + for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { + // Launch kernel if argument struct is full + if (kernel_args.num_tensors == kMaxTensorsPerKernel) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type); + const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; + multi_unpadding_kernel + <<>>(kernel_args);); // NOLINT(*) + kernel_args.num_tensors = 0; + } + + // Calculate number of thread blocks needed for tensor + const int num_rows = unpadded_num_rows_list[tensor_id]; + const int row_length = 
input_list[tensor_id]->data.shape[1]; + const int num_tiles_m = (num_rows + tile_dim_m - 1) / tile_dim_m; + const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; + const int num_tiles = num_tiles_m * num_tiles_n; + + // Add tensor to kernel argument struct + const int pos = kernel_args.num_tensors; + kernel_args.input_list[pos] = const_cast(input_list[tensor_id]->data.dptr); + kernel_args.output_list[pos] = output_list[tensor_id]->data.dptr; + kernel_args.num_rows_list[pos] = num_rows; + kernel_args.row_length_list[pos] = row_length; + kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles; + kernel_args.num_tensors++; + } + + // Launch kernel + if (kernel_args.num_tensors > 0) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type); + const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; + multi_unpadding_kernel + <<>>(kernel_args);); // NOLINT(*) + } +} + } // namespace transformer_engine void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list, @@ -217,3 +366,17 @@ void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETe } multi_padding(input_list_, output_list_, padded_num_rows_list_, stream); } + +void nvte_multi_unpadding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list, + const int* unpadded_num_rows_list, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_unpadding); + using namespace transformer_engine; + std::vector input_list_, output_list_; + std::vector unpadded_num_rows_list_; + for (size_t i = 0; i < num_tensors; ++i) { + input_list_.push_back(convertNVTETensorCheck(input_list[i])); + output_list_.push_back(convertNVTETensorCheck(output_list[i])); + unpadded_num_rows_list_.push_back(unpadded_num_rows_list[i]); + } + multi_unpadding(input_list_, output_list_, unpadded_num_rows_list_, stream); +} diff --git a/transformer_engine/pytorch/csrc/extensions.h 
b/transformer_engine/pytorch/csrc/extensions.h index 4af7576c5f..835124be41 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -368,6 +368,9 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, std::vector input_row_list, std::vector padded_input_row_list); +void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, + std::vector input_row_list, + std::vector unpadded_input_row_list); /*************************************************************************************************** * NVSHMEM APIs **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cpp b/transformer_engine/pytorch/csrc/extensions/padding.cpp index f3c1b58cf2..d4b64a485c 100644 --- a/transformer_engine/pytorch/csrc/extensions/padding.cpp +++ b/transformer_engine/pytorch/csrc/extensions/padding.cpp @@ -81,4 +81,77 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, }); } +void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, + std::vector input_row_list, + std::vector unpadded_input_row_list) { + using namespace transformer_engine; + using namespace transformer_engine::pytorch; + + NVTE_CHECK(input_row_list.size() == unpadded_input_row_list.size(), + "Number of input row list and padded row list must match."); + NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2."); + NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2."); + + const auto num_tensors = input_row_list.size(); + // Extract properties from PyTorch tensors + std::vector input_dptr_list, output_dptr_list; + std::vector> input_shape_list, output_shape_list; + std::vector input_type_list; + void* d_input_ptr = reinterpret_cast(input.data_ptr()); + void* d_output_ptr = reinterpret_cast(output.data_ptr()); + for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { + 
input_dptr_list.push_back(d_input_ptr); + output_dptr_list.push_back(d_output_ptr); + + // Move the input pointer to the next split. + char* input_char_ptr = reinterpret_cast(d_input_ptr); + const size_t input_dptr_offset = + input_row_list[tensor_id] * input.size(1) * input.element_size(); + input_char_ptr += input_dptr_offset; + d_input_ptr = reinterpret_cast(input_char_ptr); + + input_shape_list.push_back({input_row_list[tensor_id], static_cast(input.size(1))}); + input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); + + // Move the output pointer to the next split. + char* output_char_ptr = reinterpret_cast(d_output_ptr); + const size_t output_dptr_offset = + unpadded_input_row_list[tensor_id] * output.size(1) * output.element_size(); + output_char_ptr += output_dptr_offset; + d_output_ptr = reinterpret_cast(output_char_ptr); + + output_shape_list.push_back( + {unpadded_input_row_list[tensor_id], static_cast(output.size(1))}); + } + + // Construct TE tensors + std::vector nvte_input_list, nvte_output_list; + std::vector tensor_wrappers; + auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, + transformer_engine::DType dtype) -> NVTETensor { + tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype)); + return tensor_wrappers.back().data(); + }; + + std::vector unpadded_num_rows_list; + for (size_t i = 0; i < input_dptr_list.size(); ++i) { + if (input_dptr_list[i] == nullptr || input_row_list[i] == 0) continue; + nvte_input_list.emplace_back( + make_tensor(input_dptr_list[i], input_shape_list[i], input_type_list[i])); + nvte_output_list.emplace_back( + make_tensor(output_dptr_list[i], output_shape_list[i], input_type_list[i])); + unpadded_num_rows_list.emplace_back(unpadded_input_row_list[i]); + } + + // Check tensor lists + NVTE_CHECK(nvte_output_list.size() == nvte_input_list.size(), + "Number of input and output tensors must match"); + NVTE_CHECK(unpadded_num_rows_list.size() == 
nvte_input_list.size() && + "Number of input and padded row list must match"); + + // Launch TE kernel + nvte_multi_unpadding(nvte_input_list.size(), nvte_input_list.data(), nvte_output_list.data(), + unpadded_num_rows_list.data(), at::cuda::getCurrentCUDAStream()); +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 8f06883807..83f5291177 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -232,6 +232,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("out_dtype"), py::call_guard()); m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding, "Fused Multi-tensor padding", py::call_guard()); + m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding, + "Fused Multi-tensor unpadding", py::call_guard()); // attention kernels m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd, diff --git a/transformer_engine/pytorch/module/fp8_padding.py b/transformer_engine/pytorch/module/fp8_padding.py index 9748408338..c2ec7b07b5 100644 --- a/transformer_engine/pytorch/module/fp8_padding.py +++ b/transformer_engine/pytorch/module/fp8_padding.py @@ -53,15 +53,16 @@ def backward(ctx, grad_output: torch.Tensor): if ctx.requires_dgrad: grad_output = grad_output.contiguous() - grad_output_mats = torch.split( - grad_output.view(-1, grad_output.shape[-1]), ctx.padded_m_splits + in_features = grad_output.shape[-1] + + # Allocate cast and transpose output tensor + total_row = sum(ctx.m_splits) + grad_input = torch.empty( + [total_row, in_features], dtype=grad_output.dtype, device=grad_output.device ) - grad_input = torch.cat( - [ - grad_output_mat[: ctx.m_splits[i]] - for i, grad_output_mat in enumerate(grad_output_mats) - ], - dim=0, + + tex.fused_multi_row_unpadding( + grad_output.view(-1, in_features), 
grad_input, ctx.padded_m_splits, ctx.m_splits ) return (grad_input, None, None, None) diff --git a/transformer_engine/pytorch/module/fp8_unpadding.py b/transformer_engine/pytorch/module/fp8_unpadding.py index 7e1fbcb2a3..4b4fbf25e9 100644 --- a/transformer_engine/pytorch/module/fp8_unpadding.py +++ b/transformer_engine/pytorch/module/fp8_unpadding.py @@ -29,10 +29,13 @@ def forward( is_grad_enabled: bool, ) -> torch.Tensor: # pylint: disable=missing-function-docstring - inputmats = torch.split(inp.view(-1, inp.shape[-1]), padded_m_splits) - out_ret = torch.cat( - [grad_output_mat[: m_splits[i]] for i, grad_output_mat in enumerate(inputmats)], dim=0 - ) + in_features = inp.shape[-1] + + # Allocate cast and transpose output tensor + total_row = sum(m_splits) + out_ret = torch.empty([total_row, in_features], dtype=inp.dtype, device=inp.device) + + tex.fused_multi_row_unpadding(inp.view(-1, in_features), out_ret, padded_m_splits, m_splits) if is_grad_enabled: ctx.m_splits = m_splits From c42614d0df2ccbe0ab6602779d560767f91b805b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Thu, 26 Jun 2025 10:44:04 +0200 Subject: [PATCH 31/39] [PyTorch Debug] Fix the issue with PP (#1894) * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Pawel Gadzinski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/debug/pytorch/debug_quantization.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 4d61757e1d..2b859800ae 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -62,6 +62,12 @@ def __init__( self.tp_group 
= tp_group # used in inspect_tensor calls self.iteration = debug_api.DEBUG_MANAGER._trainer_iteration_count + # .internal = True is slightly faster, but results + # in errors when caching the weights. + # Setting .internal = False is safer. + if parent_quantizer is not None: + parent_quantizer.internal = False + self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name] # The values of the inspect_tensor_enabled, inspect_tensor_postquantize_enabled, From 968eb0d7f2f2583e1142d2308b2c23aeb345e7d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Thu, 26 Jun 2025 10:47:40 +0200 Subject: [PATCH 32/39] [PyTorch Debug] Fixed the empty tensor bug in statistics computation (#1843) * fixed the bug Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * lint fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * test change Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Pawel Gadzinski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/debug/test_distributed.py | 4 +-- tests/pytorch/debug/test_numerics.py | 30 +++++++++++++++++++ .../debug/features/utils/stats_buffer.py | 7 +++++ .../debug/features/utils/stats_computation.py | 4 ++- 4 files changed, 42 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/debug/test_distributed.py b/tests/pytorch/debug/test_distributed.py index 7c072a0541..7333354ee3 100644 --- a/tests/pytorch/debug/test_distributed.py +++ 
b/tests/pytorch/debug/test_distributed.py @@ -34,6 +34,6 @@ def test_debug_distributed(feature_dirs): test_path = TEST_ROOT / "run_distributed.py" test_cmd = LAUNCH_CMD + [str(test_path), f"--feature_dirs={feature_dirs[0]}"] - result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) + result = subprocess.run(test_cmd, env=os.environ, check=False, text=True) if result.returncode != 0: - raise AssertionError(result.stderr.decode()) + raise AssertionError(f"torchrun exited with {result.returncode}") diff --git a/tests/pytorch/debug/test_numerics.py b/tests/pytorch/debug/test_numerics.py index 55c3ab9b7e..6a89149c7a 100644 --- a/tests/pytorch/debug/test_numerics.py +++ b/tests/pytorch/debug/test_numerics.py @@ -262,6 +262,18 @@ def _get_tensors(): return x, weight +LOGGING_CONFIG = """logging_config: + enabled: True + layers: + layer_types: [linear] + transformer_engine: + LogTensorStats: + enabled: True + tensors: [activation, gradient, weight, output, wgrad, dgrad] + stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range] +""" + + DISABLE_FP8_CONFIG = Template( """disable_fp8_config: enabled: True @@ -275,6 +287,24 @@ def _get_tensors(): ) +@create_config_file +def run_logging_zero_numel_tensor(feature_dirs, **kwargs): + kwargs["config_file"].write(LOGGING_CONFIG) + kwargs["config_file"].flush() + + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs) + + x, weight = _get_tensors() + x1 = x[:0, :] + model = _init_model(weight) + _ = _run_forward_backward(x1, model) + _ = _run_forward_backward(x, model) + + +def test_logging_zero_numel_tensor(feature_dirs): + run_logging_zero_numel_tensor(feature_dirs) + + @pytest.mark.parametrize("fprop_fp8", all_boolean) @pytest.mark.parametrize("dgrad_fp8", all_boolean) @pytest.mark.parametrize("wgrad_fp8", all_boolean) diff --git a/transformer_engine/debug/features/utils/stats_buffer.py b/transformer_engine/debug/features/utils/stats_buffer.py index 
2313484054..4be465f8e8 100644 --- a/transformer_engine/debug/features/utils/stats_buffer.py +++ b/transformer_engine/debug/features/utils/stats_buffer.py @@ -85,6 +85,13 @@ def feed(self, tensor, iteration): if self.modified[0] and not self.reduce_within_microbatch: return + if ( + tensor.numel() == 0 + if hasattr(tensor, "numel") + else all((t is None or t.numel() == 0) for t in tensor.get_data_tensors()) + ): + return + # save stats for tensor to tmp buffer for stat_name in self.stats_to_compute: fn, _ = STATS[stat_name] diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index d111e48903..ed32de1ae2 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -17,6 +17,8 @@ def _compute_dynamic_range_top(tensor): """Computes the log2 of the amax of the tensor""" tensor_abs = tensor.abs() tensor_abs = tensor_abs[tensor_abs != 0] + if tensor_abs.numel() == 0: + return torch.inf amax = tensor_abs.max().float() if not amax.all(): amax = torch.tensor(1, device=tensor.device).to(torch.float) @@ -125,7 +127,7 @@ def _get(buffers, stat_name): lambda buffers: min(_get(buffers, "dynamic_range_bottom")), ), "underflows_num": ( - lambda x: (x._data == 0).sum(), + lambda x: (x.get_data_tensors()[0] == 0).sum(), lambda buffers: sum(_get(buffers, "underflows_num")), ), "std": ( From 866953e09dcb1f74c41ae39d49f4fee178410b05 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Thu, 26 Jun 2025 12:56:10 -0700 Subject: [PATCH 33/39] [JAX] Use keyword args for jit in_shardings and out_shardings (#1898) Use keyword args for jit in_shardings and out_shardings Signed-off-by: Jeremy Berchtold --- examples/jax/encoder/test_model_parallel_encoder.py | 12 +++++++++--- examples/jax/encoder/test_multigpu_encoder.py | 12 +++++++++--- 
examples/jax/encoder/test_multiprocessing_encoder.py | 12 +++++++++--- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index b2bd18205f..1f45d10faf 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -307,7 +307,9 @@ def train_and_evaluate(args): key: params_sharding[PARAMS_KEY] if key is PARAMS_KEY else None for key in abs_var_collect } - jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) + jit_encoder_init = jax.jit( + encoder.init, in_shardings=in_shardings, out_shardings=out_shardings + ) var_collect = jit_encoder_init(init_rngs, inputs, masks) # Check if params are sufficiently sharded after initialization @@ -344,11 +346,15 @@ def train_and_evaluate(args): None, ) out_shardings = (state_sharding, None, None, None) - jit_train_step = jax.jit(train_step, in_shardings, out_shardings) + jit_train_step = jax.jit( + train_step, in_shardings=in_shardings, out_shardings=out_shardings + ) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) + jit_eval_step = jax.jit( + eval_step, in_shardings=in_shardings, out_shardings=out_shardings + ) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index b6f4db1084..12148b0e29 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -288,7 +288,9 @@ def train_and_evaluate(args): out_shardings = { key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect } - jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) + jit_encoder_init = jax.jit( + encoder.init, 
in_shardings=in_shardings, out_shardings=out_shardings + ) var_collect = jit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) @@ -312,11 +314,15 @@ def train_and_evaluate(args): None, ) out_shardings = (state_sharding, None, None, None) - jit_train_step = jax.jit(train_step, in_shardings, out_shardings) + jit_train_step = jax.jit( + train_step, in_shardings=in_shardings, out_shardings=out_shardings + ) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) + jit_eval_step = jax.jit( + eval_step, in_shardings=in_shardings, out_shardings=out_shardings + ) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index c7606c3ab0..580824cefa 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -412,7 +412,9 @@ def train_and_evaluate(args): out_shardings = { key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect } - jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) + jit_encoder_init = jax.jit( + encoder.init, in_shardings=in_shardings, out_shardings=out_shardings + ) var_collect = jit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) @@ -432,11 +434,15 @@ def train_and_evaluate(args): None, ) out_shardings = (state_sharding, None, None, None) - jit_train_step = jax.jit(train_step, in_shardings, out_shardings) + jit_train_step = jax.jit( + train_step, in_shardings=in_shardings, out_shardings=out_shardings + ) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) + jit_eval_step = jax.jit( + eval_step, 
in_shardings=in_shardings, out_shardings=out_shardings + ) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) From 8382eed6cccb1eb0602c96afc1cfbc707468257f Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 26 Jun 2025 15:00:45 -0700 Subject: [PATCH 34/39] [PyTorch] Skip KV cache for sm89 and cuDNN < 9.12 (#1895) * skip kv cache for sm89, cudnn < 9.12 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix test_numerics Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/test_numerics.py | 4 ++-- .../pytorch/attention/dot_product_attention/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 560b7ed7f9..ab3ca4c314 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2322,9 +2322,9 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, if ( backend == "FusedAttention" and get_device_compute_capability() == (8, 9) - and get_cudnn_version() < (9, 11, 0) + and get_cudnn_version() < (9, 12, 0) ): - pytest.skip("Skip KV cache for sm89 and cuDNN < 9.11") + pytest.skip("Skip KV cache for sm89 and cuDNN < 9.12") os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index d98dde0159..18a5e9a665 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -433,8 +433,8 @@ def get_attention_backend( # | FP8 | non-paged/paged | sm90 | thd | >= 1 # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 if inference_params is not None: - 
if device_compute_capability == (8, 9) and cudnn_version < (9, 11, 0): - logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN < 9.11") + if device_compute_capability == (8, 9) and cudnn_version < (9, 12, 0): + logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN < 9.12") use_fused_attention = False if context_parallel: logger.debug("Disabling all backends for KV caching with context parallelism") From f05f12c974b37c5bd8dfca3d2d294be53b66abfa Mon Sep 17 00:00:00 2001 From: yuzhongw-nvidia Date: Sun, 29 Jun 2025 00:14:38 +0800 Subject: [PATCH 35/39] Fix MLA CP Bugs (#1896) * fix: (1) UT ignores MLA; (2) bshd format runtime error. Ban fp8 mla attn + cp due to correctness problem Signed-off-by: Yuzhong Wang * only disable FP8 CP for MLA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Yuzhong Wang Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../fused_attn/run_fused_attn_with_cp.py | 35 ++++++++++++++----- .../fused_attn/test_fused_attn_with_cp.py | 2 ++ .../dot_product_attention/context_parallel.py | 8 ++--- .../attention/dot_product_attention/utils.py | 6 ++++ 4 files changed, 38 insertions(+), 13 deletions(-) diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index ad3bc32079..f1db30d992 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -89,7 +89,7 @@ def run_dpa_with_cp( # instantiate core attn module core_attn = DotProductAttention( config.num_heads, - config.head_dim_qk, + (config.head_dim_qk, config.head_dim_v), num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p, qkv_format=qkv_format, @@ -106,16 +106,22 @@ def run_dpa_with_cp( config.num_heads, config.head_dim_qk, ) - kv_input_shape = ( + k_input_shape = ( 
config.batch_size, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim_qk, ) + v_input_shape = ( + config.batch_size, + config.max_seqlen_kv, + config.num_gqa_groups, + config.head_dim_v, + ) attn_output_shape = ( config.batch_size, config.max_seqlen_q, - config.num_heads * config.head_dim_qk, + config.num_heads * config.head_dim_v, ) cu_seqlens_q = None cu_seqlens_kv = None @@ -128,16 +134,22 @@ def run_dpa_with_cp( config.num_heads, config.head_dim_qk, ) - kv_input_shape = ( + k_input_shape = ( config.max_seqlen_kv, config.batch_size, config.num_gqa_groups, config.head_dim_qk, ) + v_input_shape = ( + config.max_seqlen_kv, + config.batch_size, + config.num_gqa_groups, + config.head_dim_v, + ) attn_output_shape = ( config.max_seqlen_q, config.batch_size, - config.num_heads * config.head_dim_qk, + config.num_heads * config.head_dim_v, ) cu_seqlens_q = None cu_seqlens_kv = None @@ -149,14 +161,19 @@ def run_dpa_with_cp( config.num_heads, config.head_dim_qk, ) - kv_input_shape = ( + k_input_shape = ( config.batch_size * config.max_seqlen_q, config.num_gqa_groups, config.head_dim_qk, ) + v_input_shape = ( + config.batch_size * config.max_seqlen_q, + config.num_gqa_groups, + config.head_dim_v, + ) attn_output_shape = ( config.batch_size * config.max_seqlen_q, - config.num_heads * config.head_dim_qk, + config.num_heads * config.head_dim_v, ) seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32) seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2) @@ -177,8 +194,8 @@ def run_dpa_with_cp( assert False, f"{qkv_format} is an unsupported qkv_format!" 
q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda() - k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda() - v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda() + k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda() + v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda() dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda() dout_quantizer = Float8Quantizer( fp8_dtype=tex.DType.kFloat8E5M2, diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 4ecc54b530..458070c9b0 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -173,6 +173,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha pytest.skip("Only fp8 works with fp8_mha=True!") if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") + if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: + pytest.skip("MLA CP currently does not support FP8 attention!") subprocess.run( get_bash_arguments( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 9f4822784e..c6f4647c04 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -2559,8 +2559,8 @@ def backward(ctx, dout): if ctx.enable_mla: # [cp, b, 2, sk//2, np, hn] or [cp, 2, sk//2, b, np, hn] - dk_fp8 = dkv_fp8[: ctx.k_numel].view(cp_size, *ctx.k_shape) - dv_fp8 = dkv_fp8[ctx.k_numel :].view(cp_size, *ctx.v_shape) + dk_fp8 = dkv_fp8[:, : ctx.k_numel].view(cp_size, *ctx.k_shape) + dv_fp8 = dkv_fp8[:, ctx.k_numel :].view(cp_size, *ctx.v_shape) dk = ctx.dQKV_CP_quantizer.create_tensor_from_data( dk_fp8, fake_dtype=torch.float32, 
internal=True ) @@ -2586,8 +2586,8 @@ def backward(ctx, dout): dq = dq.view(dq.shape[0], -1, *dq.shape[-2:]) if ctx.enable_mla: # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - dk = dk.view(*dk.shape[0], -1, *dk.shape[-2:]) - dv = dv.view(*dv.shape[0], -1, *dv.shape[-2:]) + dk = dk.view(dk.shape[0], -1, *dk.shape[-2:]) + dv = dv.view(dv.shape[0], -1, *dv.shape[-2:]) else: # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:]) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 18a5e9a665..0e23e3a8ce 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -608,6 +608,12 @@ def get_attention_backend( " bias for THD format" ) use_fused_attention = False + elif fp8 and head_dim_qk != head_dim_v: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with FP8" + " MLA attention" + ) + use_fused_attention = False # Filter: Attention mask # attn_mask_type | attention_mask | supported backends From 94ac69f9872c2a71d962592bae50168caa27201d Mon Sep 17 00:00:00 2001 From: Shiqing Fan Date: Wed, 12 Mar 2025 01:43:14 -0700 Subject: [PATCH 36/39] Extended tensor parallelism support: support ETP+TP comm overlap related patterns; - ag->fc2_wgrad - ag->fc1_wgrad - fc1_dgrad->rs - ag->proj_wgrad - ag->qkv_wgrad - qkv_dgrad->rs --- .../distributed/run_gemm_with_overlap.py | 1 + .../comm_gemm_overlap/comm_gemm_overlap.cpp | 69 ++++++++++++++----- .../transformer_engine/comm_gemm_overlap.h | 7 +- .../pytorch/cpp_extensions/gemm.py | 7 ++ transformer_engine/pytorch/csrc/extensions.h | 2 +- .../pytorch/csrc/extensions/gemm.cpp | 6 +- .../pytorch/csrc/extensions/pybind.cpp | 1 + transformer_engine/pytorch/module/base.py | 32 ++++++++- 8 files changed, 100 insertions(+), 25 deletions(-) diff --git 
a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index 6d9e2f1526..28b299e5a2 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -51,6 +51,7 @@ def _mapped_argtype(opt, typemap): def _parse_args(argv=None, namespace=None): parser = argparse.ArgumentParser(description="Test comm+GEMM overlap with Userbuffers.") parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.") + parser.add_argument("--local-rank", type=int, help="Input batch size.") parser.add_argument("-s", "--seq-length", type=int, default=1024, help="Input sequence length.") parser.add_argument( "-n", "--num-heads", type=int, default=16, help="Number of attention heads." diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 40595ea988..392a3219a1 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -793,20 +793,28 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &B_copy, - cudaStream_t stream_main) { + bool use_split_accumulator, bool ag_on_B, + TensorWrapper &B_copy, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; - // Get GEMM dimensions between TN and NN input layouts + // Get GEMM dimensions between TN, NN and NT input layouts const size_t m = (transa) ? A.size(0) : A.size(1); const size_t k = (transa) ? A.size(1) : A.size(0); - const size_t n_chunk = _ubufs[0].size(0); + const size_t n = (transb) ? 
B.size(1) : B.size(0); + + // For TN or NN layout, we chunk on the n dimension. + const size_t n_chunk = (transb) ? n : (n / _tp_size); + // For NT layer, we chunk on the k dimension. + const size_t k_chunk = (transb) ? (k / _tp_size) : k; // Get communication and GEMM output chunk sizes const int comm_bytes = _ubufs[0].bytes(); const bool do_gelu = pre_gelu_out.numel() > 0; + size_t input_a_chunk_size = m * k_chunk; + size_t input_b_chunk_size = n_chunk * k_chunk; + size_t output_chunk_size = n_chunk * m; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); @@ -820,10 +828,12 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, // Chunk dims std::vector input_b_chunk_shape = - (transb ? std::vector{k, 2 * n_chunk} : std::vector{2 * n_chunk, k}); - std::vector output_chunk_shape = {2 * n_chunk, m}; - size_t input_b_chunk_size = 2 * n_chunk * k; - size_t output_chunk_size = 2 * n_chunk * m; + (transb ? std::vector{2 * k_chunk, n} : std::vector{2 * n_chunk, k}); + // (transb ? std::vector{k, 2 * n_chunk} : std::vector{2 * n_chunk, k}); + std::vector output_chunk_shape = {(transb ? 1 : 2) * n_chunk, m}; + input_a_chunk_size *= transb ? 2 : 1; + input_b_chunk_size *= 2; + output_chunk_size *= transb ? 1 : 2; // Initial 1X input chunk exchange between neighboring peers int send_chunk_id = _tp_id; @@ -851,17 +861,25 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, recv_offset = comm_bytes * recv_chunk_id; // GEMM + auto input_a_chunk = get_tensor_chunk(A, transb ? input_a_chunk_size * send_chunk_id : 0, + transb ? std::vector{k_chunk * 2, m} : std::vector{m, k}); auto input_b_chunk = get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id / 2, input_b_chunk_shape); auto output_chunk = - get_tensor_chunk(D, output_chunk_size * send_chunk_id / 2, output_chunk_shape); + get_tensor_chunk(D, transb ? 
0 : output_chunk_size * send_chunk_id / 2, output_chunk_shape); auto aux_chunk = (do_gelu) ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id / 2, - {n_chunk * 2, k}) + {2 * n_chunk, k}) : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); auto workspace_chunk = get_tensor_chunk( workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); + // For NT input layout, overwrite the output buffer, which perhaps has not been zeroed yet, with the + // first chunk's partial output, and then accumulate the partial sums for all + // the following chunked gemms. + if (transa == false && transb == true && accumulate == false) { + accumulate = (i == 0) ? false : true; + } nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, @@ -888,10 +906,8 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, } else { // Chunk dims std::vector input_b_chunk_shape = - (transb ? std::vector{k, n_chunk} : std::vector{n_chunk, k}); + (transb ? std::vector{k_chunk, n} : std::vector{n_chunk, k}); std::vector output_chunk_shape = {n_chunk, m}; - size_t input_b_chunk_size = n_chunk * k; - size_t output_chunk_size = n_chunk * m; for (int i = 0; i < _tp_size; i++) { // Set the userbuffer id. Buffer under send is the input for the current @@ -904,10 +920,22 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, int recv_offset = comm_bytes * recv_chunk_id; // GEMM - auto input_b_chunk = - get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape); + TensorWrapper input_a_chunk, input_b_chunk; + if (ag_on_B) { // AllGather is performed on input B tensor (default case). + // Use case: AG->{FC2, PROJ}_Wgrad, AG->{FC1, QKV}_FPROP. + input_a_chunk = get_tensor_chunk(A, transb ? input_a_chunk_size * send_chunk_id : 0, + transb ?
std::vector{k_chunk, m} : std::vector{m, k}); + input_b_chunk = + get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape); + } else { // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad. + assert(trana == false && transb == true); + input_a_chunk = get_buffer_chunk_like(A, input_a_chunk_size * send_chunk_id, + transb ? std::vector{k_chunk, m} : std::vector{m, k}); + input_b_chunk = + get_tensor_chunk(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape); + } auto output_chunk = - get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape); + get_tensor_chunk(D, transb ? 0 : output_chunk_size * send_chunk_id, output_chunk_shape); auto aux_chunk = (do_gelu) ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk, k}) @@ -915,7 +943,14 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, auto workspace_chunk = get_tensor_chunk( workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); - nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + // For NT input layout, overwrite the output buffer, which perhaps has not been zeroed yet, with the + // first chunk's partial output, and then accumulate the partial sums for all + // the following chunked gemms. + if (transa == false && transb == true && accumulate == false) { + accumulate = (i == 0) ?
false : true; + } + + nvte_cublas_gemm(input_a_chunk.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _stream_compute[i % _stream_compute.size()]); diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 293c57526d..97fdcc703b 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -129,7 +129,8 @@ class CommOverlapCore { virtual void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + bool accumulate, bool use_split_accumulator, + bool ag_on_B, TensorWrapper &B_copy, cudaStream_t stream_main) { NVTE_ERROR("Operation is not implemented."); } @@ -176,7 +177,7 @@ class CommOverlapBase : public CommOverlapCore { void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &B_copy, + bool use_split_accumulator, bool ag_on_B, TensorWrapper &B_copy, cudaStream_t stream_main) override { NVTE_ERROR("Operation not supported."); } @@ -257,7 +258,7 @@ class CommOverlapP2PBase : public CommOverlapCore { void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &B_copy, + bool use_split_accumulator, bool ag_on_B, TensorWrapper 
&B_copy, cudaStream_t stream_main) override; /* diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 9f3921d36b..1a7ba9afab 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -39,14 +39,20 @@ def general_gemm( ub_type: tex.CommOverlapType = None, extra_output: Optional[torch.Tensor] = None, bulk_overlap: bool = False, + ag_on_B: bool = True, ) -> Iterable[Optional[torch.Tensor]]: """GEMM supporting fp8 inputs.""" + # assert A.dim() == 2 and B.dim() == 2, f"TE requires 2D input tensors!" + assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." transa = layout[0] == "T" transb = layout[1] == "T" # assert quantization_params is None, "FP8 output not supported yet" + if layout == "NT": + assert gelu == False, "When layout='NT', gelu should be false." + if ub_type is not None: assert ub is not None, ( f"{'AG+GEMM' if ub_type == tex.CommOverlapType.AG else 'GEMM+RS'} overlap requires" @@ -102,6 +108,7 @@ def general_gemm( workspace.shape[0], accumulate, use_split_accumulator, + ag_on_B, # ag_on_B ) kwargs = { "comm_overlap": ub, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 835124be41..74cc579ae8 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -84,7 +84,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans py::handle quantizer, std::optional out_dtype, MaybeTensor bias, DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr, + bool use_split_accumulator, bool ag_on_B, CommOverlapCore *comm_overlap = nullptr, std::optional comm_type = std::nullopt, MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false); diff --git 
a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 99bb4e69fd..465c82c7c7 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -90,7 +90,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans py::handle quantizer, std::optional out_dtype, MaybeTensor bias, DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, CommOverlapCore* comm_overlap, + bool use_split_accumulator, bool ag_on_B, CommOverlapCore* comm_overlap, std::optional comm_type, MaybeTensor extra_output, bool bulk_overlap) { // Input tensors @@ -214,8 +214,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans NVTE_SCOPED_GIL_RELEASE({ comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, - accumulate, use_split_accumulator, extra_output_tensor, - main_stream); + accumulate, use_split_accumulator, ag_on_B, + extra_output_tensor, main_stream); }); } } else { diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 83f5291177..546e20d5a7 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -110,6 +110,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("quantizer"), py::arg("output_dtype"), py::arg("bias"), py::arg("bias_type"), py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"), py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"), + py::arg("ag_on_B"), py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt, py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false); m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", 
py::arg("input"), diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 3d06a47313..f5651c34aa 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -255,6 +255,7 @@ def initialize_ub( ] layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"] + layers_all_gather_wgrad_overlap = ["qkv_wgrad", "fc1_wgrad", "fc2_wgrad", "proj_wgrad"] # Default overlap methods for layers methods = { "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"], @@ -387,14 +388,41 @@ def add_ub( for name in dgrad_reduce_scatter_overlap: if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk": wgrad_name = name.replace("dgrad", "wgrad") - assert wgrad_name not in ub_cfgs + # assert wgrad_name not in ub_cfgs layers_reduce_scatter_overlap.remove(wgrad_name) layers_all_gather_overlap.remove(name) layers_reduce_scatter_overlap.append(name) methods["bulk"].remove(name) + methods["bulk"].remove(wgrad_name) new_method = ub_cfgs[name]["method"] methods[new_method].append(name) + # Loop over user configs and disable dgrad and wgrad bulk overlaps for + # every layer that has an all-gather wgrad overlap.
+ for name in layers_all_gather_wgrad_overlap: + if name in ub_cfgs and ub_cfgs[name]["method"] != "bulk": + dgrad_name = name.replace("wgrad", "dgrad") + if name in {"fc1_wgrad", "qkv_wgrad"}: + if name in layers_reduce_scatter_overlap: + layers_reduce_scatter_overlap.remove(name) + layers_reduce_scatter_overlap.append(dgrad_name) + if dgrad_name in layers_all_gather_overlap: + layers_all_gather_overlap.remove(dgrad_name) + layers_all_gather_overlap.append(name) + + name in methods["bulk"] and methods["bulk"].remove(name) + dgrad_name in methods["bulk"] and methods["bulk"].remove(dgrad_name) + + new_method = ub_cfgs[name]["method"] + if name not in methods[new_method]: + methods[new_method].append(name) + else: + assert name in {"fc2_wgrad", "proj_wgrad"} + # Replace default {fc2, proj}_dgrad AG overlap in configs with {fc2, proj}_wgrad. + methods["ring_exchange"].remove(dgrad_name) + methods["ring_exchange"].append(name) + + for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]: ub_cfg = get_default_config(name) if ub_cfgs is not None and name in ub_cfgs: @@ -403,6 +431,8 @@ def add_ub( ) ub_cfg.update(ub_cfgs[name]) ub_cfg["fp8_buf"] = fp8_buf + if torch.distributed.get_rank() == 0: + print(f"Registered ub_name={name}, ub_config={ub_cfg}") add_ub(name, **ub_cfg) From c58598f61ede7a97fa305b8c8d3d1cfd5e75966b Mon Sep 17 00:00:00 2001 From: Shiqing Fan Date: Wed, 30 Jul 2025 19:14:01 -0700 Subject: [PATCH 37/39] fix: fix aggregate=True for AG->Wgrad (layout=NT). 
--- .../common/comm_gemm_overlap/comm_gemm_overlap.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 392a3219a1..bf2a766d48 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -861,7 +861,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, recv_offset = comm_bytes * recv_chunk_id; // GEMM - auto input_a_chunk = get_tensor_chunk(A, transb ? input_a_chunk_size * send_chunk_id : 0, + auto input_a_chunk = get_tensor_chunk(A, transb ? input_a_chunk_size * send_chunk_id / 2 : 0, transb ? std::vector{k_chunk * 2, m} : std::vector{m, k}); auto input_b_chunk = get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id / 2, input_b_chunk_shape); @@ -880,7 +880,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, if (transa == false && transb == true && accumulate == false) { accumulate = (i == 0) ? 
false : true; } - nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + nvte_cublas_gemm(input_a_chunk.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _stream_compute[i % _stream_compute.size()]); From e23a81c7384e2e9252902b6c82fc5d1069d9103c Mon Sep 17 00:00:00 2001 From: Shiqing Fan Date: Wed, 30 Jul 2025 19:52:18 -0700 Subject: [PATCH 38/39] fix: fix aggregate=True with allgather on A tensor for AG->Wgrad (layout=NT) --- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index bf2a766d48..15b8188256 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -861,10 +861,21 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, recv_offset = comm_bytes * recv_chunk_id; // GEMM - auto input_a_chunk = get_tensor_chunk(A, transb ? input_a_chunk_size * send_chunk_id / 2 : 0, - transb ? std::vector{k_chunk * 2, m} : std::vector{m, k}); - auto input_b_chunk = - get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id / 2, input_b_chunk_shape); + TensorWrapper input_a_chunk, input_b_chunk; + if (ag_on_B) { // AllGather is performed on input B tensor (default case). + // Use case: AG->{FC2, PROJ}_Wgrad, AG->{FC1, QKV}_FPROP. + input_a_chunk = get_tensor_chunk(A, transb ? input_a_chunk_size * send_chunk_id / 2 : 0, + transb ? std::vector{k_chunk * 2, m} : std::vector{m, k}); + input_b_chunk = + get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id / 2, input_b_chunk_shape); + } else { // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad. 
+ assert(transa == false && transb == true); + input_a_chunk = get_buffer_chunk_like( + A, input_a_chunk_size * send_chunk_id / 2, std::vector{k_chunk * 2, m} + ); + input_b_chunk = + get_tensor_chunk(B, input_b_chunk_size * send_chunk_id / 2, input_b_chunk_shape); + } auto output_chunk = get_tensor_chunk(D, transb ? 0 : output_chunk_size * send_chunk_id / 2, output_chunk_shape); auto aux_chunk = (do_gelu) @@ -928,7 +939,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, input_b_chunk = get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape); } else { // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad. - assert(trana == false && transb == true); + assert(transa == false && transb == true); input_a_chunk = get_buffer_chunk_like(A, input_a_chunk_size * send_chunk_id, transb ? std::vector{k_chunk, m} : std::vector{m, k}); input_b_chunk = From 91077423afed60fd43367b8b4524348c34c86c64 Mon Sep 17 00:00:00 2001 From: Shiqing Fan Date: Mon, 11 Aug 2025 03:02:16 -0700 Subject: [PATCH 39/39] fix ag_gemm shape for A tensor. --- .../common/comm_gemm_overlap/comm_gemm_overlap.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 15b8188256..4fc21345b5 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -865,7 +865,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, if (ag_on_B) { // AllGather is performed on input B tensor (default case). // Use case: AG->{FC2, PROJ}_Wgrad, AG->{FC1, QKV}_FPROP. input_a_chunk = get_tensor_chunk(A, transb ? input_a_chunk_size * send_chunk_id / 2 : 0, - transb ? std::vector{k_chunk * 2, m} : std::vector{m, k}); + transb ? 
std::vector{k_chunk * 2, m} : shape_to_vector(A.shape())); input_b_chunk = get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id / 2, input_b_chunk_shape); } else { // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad. @@ -935,7 +935,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, if (ag_on_B) { // AllGather is performed on input B tensor (default case). // Use case: AG->{FC2, PROJ}_Wgrad, AG->{FC1, QKV}_FPROP. input_a_chunk = get_tensor_chunk(A, transb ? input_a_chunk_size * send_chunk_id : 0, - transb ? std::vector{k_chunk, m} : std::vector{m, k}); + transb ? std::vector{k_chunk, m} : shape_to_vector(A.shape())); input_b_chunk = get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape); } else { // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad.