From f4729f7007f2e475012bbff547a4f395aa9e2d4b Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Thu, 5 Feb 2026 20:00:21 -0500 Subject: [PATCH] Support release wheels with local version --- build_tools/build_ext.py | 16 +++++++++++++++- build_tools/te_version.py | 16 +++++++++++++--- transformer_engine/common/CMakeLists.txt | 4 +++- transformer_engine/jax/setup.py | 6 +++--- transformer_engine/pytorch/setup.py | 7 ++++--- 5 files changed, 38 insertions(+), 11 deletions(-) diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index 8bcfc5a69..26ab03a7e 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -18,7 +18,9 @@ from typing import List, Optional, Type import setuptools +from setuptools.command.sdist import sdist as _sdist +from .te_version import te_version, is_local_version_used, version_file from .utils import ( rocm_build, rocm_path, @@ -224,3 +226,15 @@ def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None: super().build_extensions() return _CMakeBuildExtension + + +class SdistWithLocalVersion(_sdist): + """ + Override sdist to modify the *staged* copy of VERSION.txt. + """ + def make_release_tree(self, base_dir, files): + # First let setuptools stage the files into base_dir + super().make_release_tree(base_dir, files) + + if is_local_version_used(): + version_file(base_dir).write_text(te_version() + "\n", encoding="utf-8") diff --git a/build_tools/te_version.py b/build_tools/te_version.py index 0aee63f64..1af09c9a6 100644 --- a/build_tools/te_version.py +++ b/build_tools/te_version.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -8,6 +10,16 @@ import subprocess +def is_local_version_used() -> bool: + return not bool(int(os.getenv("NVTE_NO_LOCAL_VERSION", "0"))) and ( + not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) + or bool(int(os.getenv("NVTE_USE_LOCAL_VERSION", "0")))) + + +def version_file(base: str | Path) -> Path: + return Path(base).resolve() / "build_tools" / "VERSION.txt" + + def te_version() -> str: """Transformer Engine version string @@ -18,9 +30,7 @@ def te_version() -> str: root_path = Path(__file__).resolve().parent with open(root_path / "VERSION.txt", "r") as f: version = f.readline().strip() - if not int(os.getenv("NVTE_NO_LOCAL_VERSION", "0")) and not bool( - int(os.getenv("NVTE_RELEASE_BUILD", "0")) - ): + if is_local_version_used(): try: output = subprocess.run( ["git", "rev-parse", "--short", "HEAD"], diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index cefec6d06..312df074d 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -473,9 +473,11 @@ if (USE_ROCM) file(READ "${ROCM_PATH}/.info/version" ROCM_VER) string(STRIP "${ROCM_VER}" ROCM_VER) string(REGEX MATCH "^[0-9]+\\.[0-9]+" ROCM_VER "${ROCM_VER}") + get_git_commit("${TE}" TE_COMMIT_ID) file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/build_info.txt" "ROCM_VERSION: ${ROCM_VER}\n" "GPU_TARGETS: ${CMAKE_HIP_ARCHITECTURES}\n" + "COMMIT_ID: ${TE_COMMIT_ID}\n" ) install(FILES "${CMAKE_CURRENT_BINARY_DIR}/build_info.txt" DESTINATION "transformer_engine/") endif() diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index b58d2df7f..7a4b7b06c 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -45,7 +45,7 @@ shutil.copytree(build_tools_dir, build_tools_copy) -from build_tools.build_ext import get_build_ext +from build_tools.build_ext import get_build_ext, SdistWithLocalVersion from build_tools.utils import ( rocm_build, copy_common_headers, copy_hipify_tools, clear_hipify_tools_copy) from build_tools.te_version import te_version @@ -103,7 +103,7 @@ version=te_version(), description="Transformer acceleration library - Jax Lib", ext_modules=ext_modules, - cmdclass={"build_ext": CMakeBuildExtension}, + cmdclass={"build_ext": CMakeBuildExtension, "sdist": SdistWithLocalVersion}, install_requires=install_requirements(), tests_require=test_requirements(), ) diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index e86873b12..b999535ac 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -46,7 +46,7 @@ shutil.copytree(build_tools_dir, build_tools_copy) -from build_tools.build_ext import get_build_ext +from build_tools.build_ext import get_build_ext, SdistWithLocalVersion from build_tools.utils import ( rocm_build, copy_common_headers, copy_hipify_tools, clear_hipify_tools_copy ) from build_tools.te_version import te_version @@ -156,7 +156,8 @@ def run(self): version=te_version(), description="Transformer acceleration library - Torch Lib", ext_modules=ext_modules, - cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand}, + cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand, + "sdist": SdistWithLocalVersion}, install_requires=install_requirements(), tests_require=test_requirements(), )