diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index 8bcfc5a69..5d96a9287 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -25,7 +25,6 @@ cmake_bin, debug_build_enabled, found_ninja, - get_frameworks, nvcc_path, get_max_jobs_for_parallel_build, ) @@ -158,8 +157,11 @@ def run(self) -> None: def build_extensions(self): # For core lib + JAX install, fix build_ext from pybind11.setup_helpers # to handle CUDA files correctly. + # Upstream uses get_frameworks() here which is incorrectly works when install from + # release (sdist) wheel on a system with both frameworks installed. ext_names = [ext.name for ext in self.extensions] - if "transformer_engine_pytorch" not in ext_names: + if ("transformer_engine_torch" not in ext_names and + "transformer_engine_rocm_torch" not in ext_names): # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when # extra_compile_args is a dict. for ext in self.extensions: diff --git a/build_tools/utils.py b/build_tools/utils.py index e3c5b6be8..2f9b2e031 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -208,6 +208,7 @@ def rocm_build() -> bool: # If neither ROCm nor CUDA is detected, raise an error raise FileNotFoundError("Could not detect ROCm or CUDA platform") + @functools.lru_cache(maxsize=None) def rocm_path() -> Tuple[str, str]: """ROCm root path and HIPCC binary path as a tuple""" @@ -227,6 +228,18 @@ def rocm_path() -> Tuple[str, str]: return rocm_home, hipcc_bin +def rocm_version() -> Tuple[int, ...]: + """ROCm version as a (major, minor) tuple. + Try to get ROCm version by parsing .info/version. + """ + rocm_home, _ = rocm_path() + try: + with open(rocm_home / ".info" / "version", "r") as f: + rocm_version= f.read().strip().split('.')[:2] + return tuple(int(v) for v in rocm_version) + except FileNotFoundError: + raise RuntimeError("Could not determine ROCm version.") + def cuda_toolkit_include_path() -> Tuple[str, str]: """Returns root path for cuda toolkit includes. @@ -495,10 +508,13 @@ def uninstall_te_wheel_packages(): "pip", "uninstall", "-y", - "transformer_engine_rocm", # te_cuda_vers for ROCm build + "transformer_engine", "transformer_engine_cu12", "transformer_engine_torch", "transformer_engine_jax", + "transformer_engine_rocm", + "transformer_engine_rocm_jax", + "transformer_engine_rocm_torch", ] ) diff --git a/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 b/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 index dc0cd112b..318a0696f 100644 --- a/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 +++ b/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 @@ -1,4 +1,4 @@ -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # # See LICENSE for license information. @@ -45,4 +45,4 @@ COPY build_wheels.sh / WORKDIR /TransformerEngine/ RUN git clone https://github.com/ROCm/TransformerEngine.git /TransformerEngine -CMD ["/bin/bash", "/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true"] +CMD ["/bin/bash", "/build_wheels.sh", "manylinux_2_28_x86_64", "false", "true", "true", "true"] diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index 4a6653479..076273c0b 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -14,11 +14,13 @@ BUILD_JAX=${5:-true} export NVTE_RELEASE_BUILD=1 export TARGET_BRANCH=${TARGET_BRANCH:-} -mkdir -p /wheelhouse/logs + +WHEEL_ROOT=${WHEEL_ROOT:-/wheelhouse} +mkdir -p $WHEEL_ROOT/logs # Generate wheels for common library. -git config --global --add safe.directory /TransformerEngine -cd /TransformerEngine +TE_ROOT=${TE_ROOT:-/TransformerEngine} +cd $TE_ROOT #If there is default Python installation, use it PYTHON=`which python || true` @@ -29,90 +31,102 @@ else fi ROCM_BUILD=`${PYBINDIR}python -c "import build_tools.utils as u; print(int(u.rocm_build()))"` - -if [ "$LOCAL_TREE_BUILD" != "1" ]; then - if [ "$ROCM_BUILD" = "1" ]; then - git pull - fi - git checkout $TARGET_BRANCH - git submodule update --init --recursive +if [ "$ROCM_BUILD" = "1" ]; then + ROCM_BUILD=true else - git submodule status --recursive | cut -d' ' -f3 | xargs -l -P1 -I_SUB_ git config --global --add safe.directory /TransformerEngine/_SUB_ + ROCM_BUILD=false fi -if [ "$ROCM_BUILD" = "1" ]; then - ${PYBINDIR}pip install setuptools wheel +if [ "$LOCAL_TREE_BUILD" != "1" ]; then + git config --global --add safe.directory $TE_ROOT + if [ "$NO_REPO_UPDATE" = "1" ]; then + git submodule status --recursive | cut -d' ' -f3 | xargs -l -P1 -I_SUB_ git config --global --add safe.directory $TE_ROOT/_SUB_ + else + if [ $ROCM_BUILD ]; then + git pull + fi + git checkout $TARGET_BRANCH + git submodule update --init --recursive + fi fi # Install deps -if [ "$ROCM_BUILD" = "1" ]; then - ${PYBINDIR}pip install pybind11[global] ninja +if [ $ROCM_BUILD ]; then + ${PYBINDIR}pip install setuptools wheel pybind11[global] ninja else ${PYBINDIR}pip install cmake pybind11[global] ninja fi if $BUILD_METAPACKAGE ; then - cd /TransformerEngine - if [ "$ROCM_BUILD" != "1" ]; then + cd $TE_ROOT + if [ ! $ROCM_BUILD ]; then PYBINDIR=/opt/python/cp310-cp310/bin/ fi - NVTE_BUILD_METAPACKAGE=1 ${PYBINDIR}python setup.py bdist_wheel 2>&1 | tee /wheelhouse/logs/metapackage.txt - mv dist/* /wheelhouse/ + NVTE_BUILD_METAPACKAGE=1 ${PYBINDIR}python setup.py bdist_wheel 2>&1 | tee $WHEEL_ROOT/logs/metapackage.txt + mv dist/* $WHEEL_ROOT/ fi -if $BUILD_COMMON ; then +if $BUILD_COMMON -a $ROCM_BUILD; then + VERSION=`cat build_tools/VERSION.txt` + WHL_BASE="transformer_engine_rocm-${VERSION}" + #dataclasses, psutil are needed for AITER + ${PYBINDIR}pip install dataclasses psutil + #hipify expects python in PATH, also ninja may be installed to python bindir + test -n "$PYBINDIR" && PATH="$PYBINDIR:$PATH" || true + + # Create the wheel. + ${PYBINDIR}python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee $WHEEL_ROOT/logs/common.txt + + # Rename the wheel to make it python version agnostic. + whl_name=$(basename dist/*) + IFS='-' read -ra whl_parts <<< "$whl_name" + whl_name_target="${whl_parts[0]}-${whl_parts[1]}-py3-none-${whl_parts[4]}" + mv dist/*.whl $WHEEL_ROOT/"$whl_name_target" + +elif $BUILD_COMMON; then VERSION=`cat build_tools/VERSION.txt` WHL_BASE="transformer_engine-${VERSION}" - if [ "$ROCM_BUILD" = "1" ]; then - TE_CUDA_VERS="rocm" - #dataclasses, psutil are needed for AITER - ${PYBINDIR}pip install dataclasses psutil - #hipify expects python in PATH, also ninja may be installed to python bindir - test -n "$PYBINDIR" && PATH="$PYBINDIR:$PATH" || true - else - TE_CUDA_VERS="cu12" - PYBINDIR=/opt/python/cp38-cp38/bin/ - fi + PYBINDIR=/opt/python/cp38-cp38/bin/ # Create the wheel. - ${PYBINDIR}python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt + ${PYBINDIR}python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee $WHEEL_ROOT/logs/common.txt # Repack the wheel for cuda specific package, i.e. cu12. ${PYBINDIR}wheel unpack dist/* # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). - sed -i "s/Name: transformer-engine/Name: transformer-engine-${TE_CUDA_VERS}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - sed -i "s/Name: transformer_engine/Name: transformer_engine_${TE_CUDA_VERS}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_${TE_CUDA_VERS}-${VERSION}.dist-info" + sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" ${PYBINDIR}wheel pack ${WHL_BASE} # Rename the wheel to make it python version agnostic. whl_name=$(basename dist/*) IFS='-' read -ra whl_parts <<< "$whl_name" - whl_name_target="${whl_parts[0]}_${TE_CUDA_VERS}-${whl_parts[1]}-py3-none-${whl_parts[4]}" + whl_name_target="${whl_parts[0]}_cu12-${whl_parts[1]}-py3-none-${whl_parts[4]}" rm -rf $WHL_BASE dist - mv *.whl /wheelhouse/"$whl_name_target" + mv *.whl $WHEEL_ROOT/"$whl_name_target" fi if $BUILD_PYTORCH ; then - cd /TransformerEngine/transformer_engine/pytorch - if [ "$ROCM_BUILD" = "1" ]; then - ${PYBINDIR}pip install torch --index-url https://download.pytorch.org/whl/rocm6.3 + cd $TE_ROOT/transformer_engine/pytorch + if [ $ROCM_BUILD ]; then + ${PYBINDIR}pip install torch --index-url https://download.pytorch.org/whl/cpu else PYBINDIR=/opt/python/cp38-cp38/bin/ ${PYBINDIR}pip install torch fi - ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt - cp dist/* /wheelhouse/ + ${PYBINDIR}python setup.py sdist 2>&1 | tee $WHEEL_ROOT/logs/torch.txt + cp dist/* $WHEEL_ROOT/ fi if $BUILD_JAX ; then - cd /TransformerEngine/transformer_engine/jax - if [ "$ROCM_BUILD" = "1" ]; then + cd $TE_ROOT/transformer_engine/jax + if [ $ROCM_BUILD ]; then ${PYBINDIR}pip install jax else PYBINDIR=/opt/python/cp310-cp310/bin/ ${PYBINDIR}pip install "jax[cuda12_local]" jaxlib fi - ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt - cp dist/* /wheelhouse/ + ${PYBINDIR}python setup.py sdist 2>&1 | tee $WHEEL_ROOT/logs/jax.txt + cp dist/* $WHEEL_ROOT/ fi diff --git a/setup.py b/setup.py index 1ae476311..bdde17b38 100644 --- a/setup.py +++ b/setup.py @@ -182,12 +182,11 @@ def setup_requirements() -> Tuple[List[str], List[str]]: assert bool( int(os.getenv("NVTE_RELEASE_BUILD", "0")) ), "NVTE_RELEASE_BUILD env must be set for metapackage build." - te_cuda_vers = "rocm" if rocm_build() else "cu12" ext_modules = [] cmdclass = {} package_data = {} include_package_data = False - install_requires = ([f"transformer_engine_{te_cuda_vers}=={__version__}"],) + install_requires = ([f"transformer_engine_cu12=={__version__}"],) extras_require = { "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"], @@ -222,9 +221,21 @@ def setup_requirements() -> Tuple[List[str], List[str]]: ) ) + PACKAGE_NAME="transformer_engine" + if rocm_build(): + if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): + if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))): + install_requires = ([f"transformer_engine_rocm=={__version__}"],) + else: + PACKAGE_NAME="transformer_engine_rocm" + #On ROCm make add extra to core package too so it can be installed w/o metapackage + extras_require = { + "pytorch": [f"transformer_engine_rocm_torch=={__version__}"], + "jax": [f"transformer_engine_rocm_jax=={__version__}"], + } # Configure package setuptools.setup( - name="transformer_engine", + name=PACKAGE_NAME, version=__version__, packages=setuptools.find_packages( include=[ @@ -239,7 +250,7 @@ def setup_requirements() -> Tuple[List[str], List[str]]: long_description_content_type="text/x-rst", ext_modules=ext_modules, cmdclass={"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, - python_requires=">=3.8", + python_requires=">=3.9", classifiers=["Programming Language :: Python :: 3"], install_requires=install_requires, license_files=("LICENSE",), diff --git a/transformer_engine/__init__.py b/transformer_engine/__init__.py index 050abc8f7..da8c33749 100644 --- a/transformer_engine/__init__.py +++ b/transformer_engine/__init__.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -83,4 +83,9 @@ category=RuntimeWarning, ) -__version__ = str(metadata.version("transformer_engine")) +try: + __version__ = str(metadata.version("transformer_engine")) +except metadata.PackageNotFoundError: + if not transformer_engine.common.te_rocm_build: + raise + __version__ = str(metadata.version("transformer_engine_rocm")) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index cefec6d06..312df074d 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -473,9 +473,11 @@ if (USE_ROCM) file(READ "${ROCM_PATH}/.info/version" ROCM_VER) string(STRIP "${ROCM_VER}" ROCM_VER) string(REGEX MATCH "^[0-9]+\\.[0-9]+" ROCM_VER "${ROCM_VER}") + get_git_commit("${TE}" TE_COMMIT_ID) file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/build_info.txt" "ROCM_VERSION: ${ROCM_VER}\n" "GPU_TARGETS: ${CMAKE_HIP_ARCHITECTURES}\n" + "COMMIT_ID: ${TE_COMMIT_ID}\n" ) install(FILES "${CMAKE_CURRENT_BINARY_DIR}/build_info.txt" DESTINATION "transformer_engine/") endif() diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 26672bafd..02497fcb5 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -133,7 +133,7 @@ def load_framework_extension(framework: str) -> None: if framework == "torch": extra_dep_name = "pytorch" - te_cuda_vers = "rocm" if te_rocm_build else "cu12" + te_core_tag = "rocm" if te_rocm_build else "cu12" # If the framework extension pip package is installed, it means that TE is installed via # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework @@ -143,24 +143,24 @@ def load_framework_extension(framework: str) -> None: "transformer_engine" ), "Could not find `transformer-engine`." assert _is_pip_package_installed( - f"transformer_engine_{te_cuda_vers}" - ), f"Could not find `transformer-engine-{te_cuda_vers}`." + f"transformer_engine_{te_core_tag}" + ), f"Could not find `transformer-engine-{te_core_tag}`." assert ( version(module_name) == version("transformer-engine") - == version(f"transformer-engine-{te_cuda_vers}") + == version(f"transformer-engine-{te_core_tag}") ), ( "TransformerEngine package version mismatch. Found" f" {module_name} v{version(module_name)}, transformer-engine" - f" v{version('transformer-engine')}, and transformer-engine-{te_cuda_vers}" - f" v{version(f'transformer-engine-{te_cuda_vers}')}. Install transformer-engine using " + f" v{version('transformer-engine')}, and transformer-engine-{te_core_tag}" + f" v{version(f'transformer-engine-{te_core_tag}')}. Install transformer-engine using " f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'" ) # If the core package is installed via PyPI, log if # the framework extension is not found from PyPI. # Note: Should we error? This is a rare use case. - if _is_pip_package_installed(f"transformer-engine-{te_cuda_vers}"): + if _is_pip_package_installed(f"transformer-engine-{te_core_tag}"): if not _is_pip_package_installed(module_name): _logger.info( "Could not find package %s. Install transformer-engine using " diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index b58d2df7f..28a0ec029 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -99,7 +99,7 @@ # Configure package setuptools.setup( - name="transformer_engine_jax", + name="transformer_engine_rocm_jax" if rocm_build() else "transformer_engine_jax", version=te_version(), description="Transformer acceleration library - Jax Lib", ext_modules=ext_modules, diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index e86873b12..da6906476 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -56,6 +56,9 @@ test_requirements, ) +if rocm_build(): + PACKAGE_NAME = "transformer_engine_rocm_torch" + os.environ["NVTE_PROJECT_BUILDING"] = "1" CMakeBuildExtension = get_build_ext(BuildExtension, True) @@ -112,6 +115,10 @@ class CachedWheelsCommand(_bdist_wheel): """ def run(self): + if rocm_build(): + print("ROCm build detected, building from source...") + return super().run() + if FORCE_BUILD: super().run()