-
Notifications
You must be signed in to change notification settings - Fork 23
[No Merge] Update AITER subcommit and refactor internal AITER/CK FA API usage #446
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
f02c48e
0b0ad93
c198cbd
a52bb32
77f0a05
1637266
568e9b5
2cb6d82
cf4aa9e
e25cea8
2122479
4817e72
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |
| from pathlib import Path | ||
| from typing import List, Tuple | ||
| import subprocess | ||
| import sys | ||
|
|
||
| import setuptools | ||
| from setuptools.command.egg_info import egg_info | ||
|
|
@@ -34,6 +35,8 @@ | |
|
|
||
|
|
||
| from setuptools.command.build_ext import build_ext as BuildExtension | ||
| from setuptools.command.editable_wheel import editable_wheel | ||
| from setuptools.command.build import SubCommand | ||
|
|
||
| os.environ["NVTE_PROJECT_BUILDING"] = "1" | ||
|
|
||
|
|
@@ -56,6 +59,25 @@ def run(self): | |
| if not rocm_build(): | ||
| archs = cuda_archs() | ||
|
|
||
# Custom editable-install command (original note: only used for ROCm builds).
class EditableWheel(editable_wheel, SubCommand):
    """Editable-wheel command that also exposes the AITER ASM kernels.

    After the regular ``editable_wheel`` step has run, a symlink
    ``transformer_engine/lib/aiter -> 3rdparty/aiter/hsa`` is created so the
    AITER ASM kernels can be found at runtime from an editable install.
    The link is only set up when both ``NVTE_FUSED_ATTN_CK`` and
    ``NVTE_FUSED_ATTN`` evaluate to a non-zero integer (both default to "1").
    """

    def run(self):
        super().run()
        ck_enabled = int(os.getenv("NVTE_FUSED_ATTN_CK", "1"))
        fused_attn_enabled = int(os.getenv("NVTE_FUSED_ATTN", "1"))
        if not (ck_enabled and fused_attn_enabled):
            return

        # Resolve to an absolute path: a relative link target would be
        # interpreted relative to the link's own directory and dangle.
        project_dir = Path(__file__).resolve().parent
        asm_src_dir = project_dir / "3rdparty" / "aiter" / "hsa"
        # Must be synced with
        # TransformerEngine/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp
        asm_target_dir = project_dir / "transformer_engine" / "lib" / "aiter"

        # Nothing to do without a source tree, or when the target already
        # resolves to a directory (existing link or real directory).
        if not asm_src_dir.is_dir() or asm_target_dir.is_dir():
            return

        if asm_target_dir.is_symlink():
            # A dangling link left over from a previous checkout would make
            # symlink_to() fail with FileExistsError; replace it.
            asm_target_dir.unlink()
        # The lib/ directory may not exist yet in a fresh source tree.
        asm_target_dir.parent.mkdir(parents=True, exist_ok=True)

        print(f"Setting up symlink for AITER ASM kernels: {asm_target_dir} -> {asm_src_dir}")
        asm_target_dir.symlink_to(asm_src_dir, target_is_directory=True)
|
|
||
| class TimedBdist(bdist_wheel): | ||
| """Helper class to measure build time""" | ||
|
|
||
|
|
@@ -88,6 +110,14 @@ def setup_common_extension() -> CMakeExtension: | |
| cmake_flags.append("-DUSE_FUSED_ATTN_CK=OFF") | ||
| elif os.getenv("NVTE_FUSED_ATTN_CK") or os.getenv("NVTE_FUSED_ATTN"): | ||
| cmake_flags.append("-DUSE_FUSED_ATTN_CK=ON") | ||
| try: | ||
| subprocess.run( | ||
| sys.executable + " tools/check_aiter_mha_args_usage.py --mode both", | ||
| shell=True, check=True | ||
| ) | ||
| except subprocess.CalledProcessError: | ||
| print("Error checking AITER mha_args usage.") | ||
| sys.exit(1) | ||
|
Comment on lines
+113
to
+120
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Explicitly checks the AITER API usage. |
||
|
|
||
| if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))) and os.getenv("NVTE_ENABLE_ROCSHMEM") is None: | ||
| os.environ["NVTE_ENABLE_ROCSHMEM"] = '1' | ||
|
|
@@ -177,14 +207,14 @@ def setup_requirements() -> Tuple[List[str], List[str]]: | |
| with open("README.rst", encoding="utf-8") as f: | ||
| long_description = f.read() | ||
|
|
||
| cmdclass = {"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist} | ||
| # Settings for building top level empty package for dependency management. | ||
| if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))): | ||
| assert bool( | ||
| int(os.getenv("NVTE_RELEASE_BUILD", "0")) | ||
| ), "NVTE_RELEASE_BUILD env must be set for metapackage build." | ||
| te_cuda_vers = "rocm" if rocm_build() else "cu12" | ||
| ext_modules = [] | ||
| cmdclass = {} | ||
| package_data = {} | ||
| include_package_data = False | ||
| install_requires = ([f"transformer_engine_{te_cuda_vers}=={__version__}"],) | ||
|
|
@@ -195,7 +225,7 @@ def setup_requirements() -> Tuple[List[str], List[str]]: | |
| else: | ||
| install_requires, test_requires = setup_requirements() | ||
| ext_modules = [setup_common_extension()] | ||
| cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist} | ||
| cmdclass["editable_wheel"] = EditableWheel | ||
| package_data = { | ||
| "": ["VERSION.txt"], | ||
| "transformer_engine.pytorch.triton_kernels.gmm": ["configs/*.json"], | ||
|
|
@@ -241,7 +271,7 @@ def setup_requirements() -> Tuple[List[str], List[str]]: | |
| long_description=long_description, | ||
| long_description_content_type="text/x-rst", | ||
| ext_modules=ext_modules, | ||
| cmdclass={"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, | ||
| cmdclass=cmdclass, | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is no change when |
||
| python_requires=">=3.8", | ||
| classifiers=["Programming Language :: Python :: 3"], | ||
| install_requires=install_requires, | ||
|
|
||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This helper script scans semi-hard-coded files with respect to the TE source code in order to directly compare AITER's internal API with our attempt at utilizing it. This script is run during setup through |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| import argparse | ||
| import re | ||
| from pathlib import Path | ||
| from typing import List, Set | ||
| import sys | ||
|
|
||
def parse_with_skip_comments(buffer, line, regex, outputs):
    """Accumulate one C/C++ source line and harvest regex matches per statement.

    Lines are appended to ``buffer[0]`` until a line containing ``;`` is seen;
    the accumulated text is then searched with ``regex`` and group 1 of every
    match is appended to ``outputs``. The buffer is cleared afterwards.

    Args:
        buffer: Single-element list whose only item is the text accumulated
            so far; it is mutated in place so state survives between calls.
        line: The next raw source line.
        regex: Compiled pattern with at least one capturing group.
        outputs: List that receives ``match.group(1)`` for each match.

    Limitations: block comments spanning several lines and ``//`` sequences
    inside string literals are not recognised.
    """
    # Remove same-line /* ... */ comments first so a trailing one cannot hide
    # the statement terminator from end-anchored patterns, then drop any
    # // comment up to the end of the line.
    code = re.sub(r"/\*.*?\*/", " ", line)
    code = re.sub(r"//.*", "", code).strip()
    if not code:
        # Blank or comment-only line: nothing to accumulate.
        return

    buffer[0] += " " + code
    if ";" not in code:
        # Statement continues on a following line.
        return

    # Collect every match, not just the first one, so several statements
    # written on one line (e.g. "a.x = 1; a.y = 2;") are all reported.
    outputs.extend(found.group(1) for found in regex.finditer(buffer[0]))
    buffer[0] = ""
|
|
||
|
|
||
def extract_fields_from_header(text: str, struct_name: str) -> List[str]:
    """Return the member names declared in ``struct <struct_name>``.

    The header text is scanned line by line. Collection starts on the line
    after the one that opens the struct definition and stops at the first
    line consisting solely of ``};`` (optionally followed by a ``//``
    comment). Each statement in between contributes the identifier that
    directly precedes its ``;`` or its ``= default`` initialiser.

    Args:
        text: Full contents of the C/C++ header.
        struct_name: Name of the struct to inspect.

    Returns:
        Member names in declaration order; empty if the struct is not found.
    """
    struct_field_re = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*)?;\s*$")
    struct_end_re = re.compile(r"^\s*};\s*$")
    struct_start_re = re.compile(rf"\bstruct\s+{re.escape(struct_name)}\b")
    # "struct NAME;" is only a forward declaration, not the definition.
    forward_decl_re = re.compile(rf"\bstruct\s+{re.escape(struct_name)}\s*;")

    inside = False
    fields: List[str] = []
    buffer = [""]
    for line in text.splitlines():
        # Decide on struct boundaries using the line without its // comment,
        # so "}; // end of struct" still terminates the scan and a comment
        # that merely mentions the struct does not start it.
        code = re.sub(r"//.*", "", line)
        if not inside:
            if struct_start_re.search(code) and not forward_decl_re.search(code):
                inside = True
            continue
        if struct_end_re.search(code):
            break
        parse_with_skip_comments(buffer, line, struct_field_re, fields)
    return fields
|
|
||
|
|
||
def extract_usage_from_source(text: str, var_name: str) -> Set[str]:
    """Return the member names of ``var_name`` that are assigned in ``text``.

    A member counts as used when the source contains ``<var_name>.<member> =``.
    Equality comparisons (``==``) are deliberately excluded, so a check such as
    ``if (args.x == 0)`` is not mistaken for an assignment to ``x``.

    Args:
        text: Full contents of the C/C++ source file.
        var_name: Name of the struct variable whose member assignments are
            collected.

    Returns:
        The set of assigned member names.
    """
    # The (?!=) lookahead rejects "==" so only real assignments are reported.
    assign_re = re.compile(
        rf"\b{re.escape(var_name)}\.([A-Za-z_][A-Za-z0-9_]*)\b\s*=(?!=)"
    )
    assigned: List[str] = []
    buffer = [""]
    for line in text.splitlines():
        parse_with_skip_comments(buffer, line, assign_re, assigned)
    return set(assigned)
|
|
||
|
|
||
def main() -> int:
    """Compare AITER ``mha_<mode>_args`` header fields with their use in TE.

    For each requested mode (``fwd``, ``bwd`` or both) the struct members
    declared in ``3rdparty/aiter/csrc/include/mha_<mode>.h`` are compared with
    the members of ``fmha_args`` assigned in
    ``transformer_engine/common/ck_fused_attn/src/ck_fused_attn_<mode>.cpp``.
    Both paths are relative to the current working directory.

    Returns:
        0 when every header field is assigned and no unknown field is used;
        1 on any mismatch, when an input file cannot be read, or when the
        struct definition cannot be located in the header.
    """
    parser = argparse.ArgumentParser(description="Check aiter args usage vs header definition")
    parser.add_argument("--mode", choices=["fwd", "bwd", "both"], required=True, help="Mode: fwd, bwd, or both")
    args = parser.parse_args()
    modes = ["fwd", "bwd"] if args.mode == "both" else [args.mode]
    mismatch = 0
    for mode in modes:
        header_path = Path(f"3rdparty/aiter/csrc/include/mha_{mode}.h")
        source_path = Path(f"transformer_engine/common/ck_fused_attn/src/ck_fused_attn_{mode}.cpp")
        try:
            header_text = header_path.read_text(encoding="utf-8")
            source_text = source_path.read_text(encoding="utf-8")
        except OSError as err:
            # Report a readable message instead of a traceback; the exit
            # status stays non-zero so callers still see a failure.
            print(f"Error: cannot read input file for mha_{mode}_args check: {err}", file=sys.stderr)
            return 1

        header_fields = extract_fields_from_header(header_text, f"mha_{mode}_args")
        header_set = set(header_fields)
        used_fields = extract_usage_from_source(source_text, "fmha_args")

        if not header_set:
            # Without any parsed field the comparison below could pass
            # vacuously; treat a missing/empty struct as a failed check.
            print(f"Error: no fields found for struct mha_{mode}_args in {header_path}", file=sys.stderr)
            return 1

        missing_in_usage = sorted(header_set - used_fields)
        unknown_in_header = sorted(used_fields - header_set)
        mismatch += len(missing_in_usage) + len(unknown_in_header)

        print(f"\nAnalyzing mha_{mode}_args\n")
        print(f"mha_{mode}_args fields in header:", len(header_set))
        print(f"mha_{mode}_args fields referenced in source:", len(used_fields))

        if missing_in_usage:
            print("\nFields present in header but not referenced in source:")
            for name in missing_in_usage:
                print(f" - {name}")
        else:
            print("\nAll header fields are referenced in source.")

        if unknown_in_header:
            print("\nFields referenced in source but not in header:")
            for name in unknown_in_header:
                print(f" - {name}")
        else:
            print("\nNo unknown fields referenced in source.")

    if mismatch:
        print(f"\nTotal mismatched fields: {mismatch}")
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is necessary to not break AITER ASM on editable installs -- we ran into this issue before.