Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,4 @@ artifacts/
**/times.csv
transformer_engine/build_info.txt
transformer_engine/common/util/hip_nvml.*
transformer_engine/lib/aiter
2 changes: 1 addition & 1 deletion 3rdparty/aiter
Submodule aiter updated 1130 files
36 changes: 33 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pathlib import Path
from typing import List, Tuple
import subprocess
import sys

import setuptools
from setuptools.command.egg_info import egg_info
Expand All @@ -34,6 +35,8 @@


from setuptools.command.build_ext import build_ext as BuildExtension
from setuptools.command.editable_wheel import editable_wheel
from setuptools.command.build import SubCommand

os.environ["NVTE_PROJECT_BUILDING"] = "1"

Expand All @@ -56,6 +59,25 @@ def run(self):
if not rocm_build():
archs = cuda_archs()

# A custom develop command only used for ROCm builds
class EditableWheel(editable_wheel, SubCommand):
def run(self):
super().run()
if (
int(os.getenv("NVTE_FUSED_ATTN_CK", "1")) and
int(os.getenv("NVTE_FUSED_ATTN", "1"))
):
# Ensure that the AITER ASM kernels are properly available at runtime
# by creating a symlink to them.
project_dir = Path(__file__).parent
asm_src_dir = project_dir / '3rdparty' / 'aiter' / 'hsa'
# Must be synced with
# TransformerEngine/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp
asm_target_dir = project_dir / 'transformer_engine' / 'lib' / 'aiter'
if asm_src_dir.is_dir() and not asm_target_dir.is_dir():
print(f"Setting up symlink for AITER ASM kernels: {asm_target_dir} -> {asm_src_dir}")
asm_target_dir.symlink_to(asm_src_dir)

Comment on lines +62 to +80
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is necessary to not break AITER ASM on editable installs -- we ran into this issue before.

class TimedBdist(bdist_wheel):
"""Helper class to measure build time"""

Expand Down Expand Up @@ -88,6 +110,14 @@ def setup_common_extension() -> CMakeExtension:
cmake_flags.append("-DUSE_FUSED_ATTN_CK=OFF")
elif os.getenv("NVTE_FUSED_ATTN_CK") or os.getenv("NVTE_FUSED_ATTN"):
cmake_flags.append("-DUSE_FUSED_ATTN_CK=ON")
try:
subprocess.run(
sys.executable + " tools/check_aiter_mha_args_usage.py --mode both",
shell=True, check=True
)
except subprocess.CalledProcessError:
print("Error checking AITER mha_args usage.")
sys.exit(1)
Comment on lines +113 to +120
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explicitly checks the AITER API usage


if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))) and os.getenv("NVTE_ENABLE_ROCSHMEM") is None:
os.environ["NVTE_ENABLE_ROCSHMEM"] = '1'
Expand Down Expand Up @@ -177,14 +207,14 @@ def setup_requirements() -> Tuple[List[str], List[str]]:
with open("README.rst", encoding="utf-8") as f:
long_description = f.read()

cmdclass = {"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}
# Settings for building top level empty package for dependency management.
if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
assert bool(
int(os.getenv("NVTE_RELEASE_BUILD", "0"))
), "NVTE_RELEASE_BUILD env must be set for metapackage build."
te_cuda_vers = "rocm" if rocm_build() else "cu12"
ext_modules = []
cmdclass = {}
package_data = {}
include_package_data = False
install_requires = ([f"transformer_engine_{te_cuda_vers}=={__version__}"],)
Expand All @@ -195,7 +225,7 @@ def setup_requirements() -> Tuple[List[str], List[str]]:
else:
install_requires, test_requires = setup_requirements()
ext_modules = [setup_common_extension()]
cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}
cmdclass["editable_wheel"] = EditableWheel
package_data = {
"": ["VERSION.txt"],
"transformer_engine.pytorch.triton_kernels.gmm": ["configs/*.json"],
Expand Down Expand Up @@ -241,7 +271,7 @@ def setup_requirements() -> Tuple[List[str], List[str]]:
long_description=long_description,
long_description_content_type="text/x-rst",
ext_modules=ext_modules,
cmdclass={"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
cmdclass=cmdclass,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no change when NVTE_BUILD_METAPACKAGE=1 -- this only affects the NVTE_BUILD_METAPACKAGE=0 case.

python_requires=">=3.8",
classifiers=["Programming Language :: Python :: 3"],
install_requires=install_requires,
Expand Down
97 changes: 97 additions & 0 deletions tools/check_aiter_mha_args_usage.py
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This helper script scans semi-hard-coded files wrt TE source-code in order to directly compare AITER's internal API and our attempt at utilizing it. This script is run during setup through setup.py

Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import argparse
import re
from pathlib import Path
from typing import List, Set
import sys

def parse_with_skip_comments(buffer, line, regex, outputs):
# skip comments
stripped = line.strip()
if not stripped or stripped.startswith("//"):
return
line_no_comment = re.sub(r"//.*", "", line)
buffer[0] += " " + line_no_comment.strip()
if ";" not in line_no_comment:
return
match = regex.search(buffer[0])
if match:
outputs.append(match.group(1))
buffer[0] = ""


def extract_fields_from_header(text: str, struct_name: str) -> List[str]:
struct_field_re = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*)?;\s*$")
struct_end_re = re.compile(r"^\s*};\s*$")

struct_start_re = re.compile(rf"\bstruct\s+{re.escape(struct_name)}\b")
lines = text.splitlines()
in_struct = False
fields: List[str] = []
buffer = [""]
for line in lines:
if not in_struct:
if struct_start_re.search(line):
in_struct = True
continue
if struct_end_re.search(line):
break
parse_with_skip_comments(buffer, line, struct_field_re, fields)
return fields


def extract_usage_from_source(text: str, var_name: str) -> Set[str]:
assign_re = re.compile(rf"\b{re.escape(var_name)}\.([A-Za-z_][A-Za-z0-9_]*)\b\s*=")
assignments = []
lines = text.splitlines()
buffer = [""]
for line in lines:
parse_with_skip_comments(buffer, line, assign_re, assignments)
return set(assignments)


def main() -> int:
parser = argparse.ArgumentParser(description="Check aiter args usage vs header definition")
parser.add_argument("--mode", choices=["fwd", "bwd", "both"], required=True, help="Mode: fwd, bwd, or both")
args = parser.parse_args()
modes = ["fwd", "bwd"] if args.mode == "both" else [args.mode]
mismatch = 0
for mode in modes:
header_path = Path(f"3rdparty/aiter/csrc/include/mha_{mode}.h")
source_path = Path(f"transformer_engine/common/ck_fused_attn/src/ck_fused_attn_{mode}.cpp")
header_text = header_path.read_text(encoding="utf-8")
source_text = source_path.read_text(encoding="utf-8")

header_fields = extract_fields_from_header(header_text, f"mha_{mode}_args")
header_set = set(header_fields)
used_fields = extract_usage_from_source(source_text, f"fmha_args")

missing_in_usage = sorted(header_set - used_fields)
unknown_in_header = sorted(used_fields - header_set)
mismatch += len(missing_in_usage) + len(unknown_in_header)

print(f"\nAnalyzing mha_{mode}_args\n")
print(f"mha_{mode}_args fields in header:", len(header_set))
print(f"mha_{mode}_args fields referenced in source:", len(used_fields))

if missing_in_usage:
print("\nFields present in header but not referenced in source:")
for name in missing_in_usage:
print(f" - {name}")
else:
print("\nAll header fields are referenced in source.")

if unknown_in_header:
print("\nFields referenced in source but not in header:")
for name in unknown_in_header:
print(f" - {name}")
else:
print("\nNo unknown fields referenced in source.")

if mismatch:
print(f"\nTotal mismatched fields: {mismatch}")
return 1
return 0


if __name__ == "__main__":
sys.exit(main())
Loading
Loading