Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion build_tools/build_ext.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand All @@ -18,7 +18,9 @@
from typing import List, Optional, Type

import setuptools
from setuptools.command.sdist import sdist as _sdist

from .te_version import te_version, is_local_version_used, version_file
from .utils import (
rocm_build,
rocm_path,
Expand Down Expand Up @@ -224,3 +226,15 @@ def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
super().build_extensions()

return _CMakeBuildExtension


class SdistWithLocalVersion(_sdist):
"""
Override sdist to modify the *staged* copy of VERSION.txt.
"""
def make_release_tree(self, base_dir, files):
# First let setuptools stage the files into base_dir
super().make_release_tree(base_dir, files)

if is_local_version_used():
version_file(base_dir).write_text(te_version() + "\n", encoding="utf-8")
16 changes: 13 additions & 3 deletions build_tools/te_version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand All @@ -8,6 +10,16 @@
import subprocess


def is_local_version_used() -> bool:
return not bool(int(os.getenv("NVTE_NO_LOCAL_VERSION", "0"))) and (
not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0")))
or bool(int(os.getenv("NVTE_USE_LOCAL_VERSION", "0"))))


def version_file(base: str | Path) -> Path:
return Path(base).resolve() / "build_tools" / "VERSION.txt"


def te_version() -> str:
"""Transformer Engine version string

Expand All @@ -18,9 +30,7 @@ def te_version() -> str:
root_path = Path(__file__).resolve().parent
with open(root_path / "VERSION.txt", "r") as f:
version = f.readline().strip()
if not int(os.getenv("NVTE_NO_LOCAL_VERSION", "0")) and not bool(
int(os.getenv("NVTE_RELEASE_BUILD", "0"))
):
if is_local_version_used():
try:
output = subprocess.run(
["git", "rev-parse", "--short", "HEAD"],
Expand Down
4 changes: 3 additions & 1 deletion transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand Down Expand Up @@ -473,9 +473,11 @@ if (USE_ROCM)
file(READ "${ROCM_PATH}/.info/version" ROCM_VER)
string(STRIP "${ROCM_VER}" ROCM_VER)
string(REGEX MATCH "^[0-9]+\\.[0-9]+" ROCM_VER "${ROCM_VER}")
get_git_commit("${TE}" TE_COMMIT_ID)
file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/build_info.txt"
"ROCM_VERSION: ${ROCM_VER}\n"
"GPU_TARGETS: ${CMAKE_HIP_ARCHITECTURES}\n"
"COMMIT_ID: ${TE_COMMIT_ID}\n"
)
install(FILES "${CMAKE_CURRENT_BINARY_DIR}/build_info.txt" DESTINATION "transformer_engine/")
endif()
6 changes: 3 additions & 3 deletions transformer_engine/jax/setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand Down Expand Up @@ -45,7 +45,7 @@
shutil.copytree(build_tools_dir, build_tools_copy)


from build_tools.build_ext import get_build_ext
from build_tools.build_ext import get_build_ext, SdistWithLocalVersion
from build_tools.utils import ( rocm_build, copy_common_headers, copy_hipify_tools,
clear_hipify_tools_copy)
from build_tools.te_version import te_version
Expand Down Expand Up @@ -103,7 +103,7 @@
version=te_version(),
description="Transformer acceleration library - Jax Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension},
cmdclass={"build_ext": CMakeBuildExtension, "sdist": SdistWithLocalVersion},
install_requires=install_requirements(),
tests_require=test_requirements(),
)
Expand Down
7 changes: 4 additions & 3 deletions transformer_engine/pytorch/setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand Down Expand Up @@ -46,7 +46,7 @@
shutil.copytree(build_tools_dir, build_tools_copy)


from build_tools.build_ext import get_build_ext
from build_tools.build_ext import get_build_ext, SdistWithLocalVersion
from build_tools.utils import (
rocm_build, copy_common_headers, copy_hipify_tools, clear_hipify_tools_copy )
from build_tools.te_version import te_version
Expand Down Expand Up @@ -156,7 +156,8 @@ def run(self):
version=te_version(),
description="Transformer acceleration library - Torch Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand},
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand,
"sdist": SdistWithLocalVersion},
install_requires=install_requirements(),
tests_require=test_requirements(),
)
Expand Down
Loading