[No Merge] Update AITER subcommit and refactor internal AITER/CK FA API usage #446
base: dev
Micky774 left a comment:
A few comments to hopefully help reviewers.
This helper script scans a semi-hard-coded set of files in the TE source code in order to directly compare AITER's internal API against our usage of it. The script is run during setup, through setup.py.
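To make the idea of such a check concrete, here is a minimal sketch of one way to compare a struct's declared members against the members a source file assigns. It is not the actual tools/check_aiter_mha_args_usage.py; the function names, the regex approach, and the sample header and source text are all assumptions made for illustration, and the sketch assumes a flat struct without brace initializers or nested types.

```python
import re


def struct_fields(header_text: str, struct_name: str) -> set[str]:
    """Collect member names declared in `struct <struct_name> { ... };`."""
    # Strip line comments so words inside them are not mistaken for members.
    text = re.sub(r"//[^\n]*", "", header_text)
    body = re.search(
        r"struct\s+" + re.escape(struct_name) + r"\s*\{(.*?)\};",
        text,
        flags=re.DOTALL,
    )
    if body is None:
        raise ValueError(f"struct {struct_name} not found")
    fields = set()
    for decl in body.group(1).split(";"):
        # Drop a default initializer such as `= 0`, then treat the last
        # identifier of the declaration as the member name.
        decl = decl.split("=", 1)[0]
        names = re.findall(r"[A-Za-z_]\w*", decl)
        if len(names) >= 2:
            fields.add(names[-1])
    return fields


def assigned_fields(source_text: str, var_name: str) -> set[str]:
    """Collect members assigned as `<var_name>.<member> = ...` (not `==`)."""
    pattern = r"\b" + re.escape(var_name) + r"\.([A-Za-z_]\w*)\s*=(?!=)"
    return set(re.findall(pattern, source_text))


def compare(header_text: str, source_text: str, struct_name: str, var_name: str) -> dict:
    """Report members never assigned and assignments to unknown members."""
    declared = struct_fields(header_text, struct_name)
    used = assigned_fields(source_text, var_name)
    return {"unset": sorted(declared - used), "unknown": sorted(used - declared)}


# Made-up sample input; these are NOT AITER's real struct members.
SAMPLE_HEADER = """
struct mha_bwd_args {
    int mask_type;  // mask kind
    int window_size_left;
    const void* q_ptr;
};
"""
SAMPLE_SOURCE = """
aiter::mha_bwd_args fmha_args{};
fmha_args.mask_type = 1;
fmha_args.q_ptr = q;
fmha_args.seqlen = 4;
"""
```

Calling `compare(SAMPLE_HEADER, SAMPLE_SOURCE, "mha_bwd_args", "fmha_args")` reports `window_size_left` as never set and `seqlen` as not declared. A real check would need a more robust parser than these regexes.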
| int ck_to_aiter_mask_type(mask_enum mask_type, ck_tile::index_t left, ck_tile::index_t right){ | ||
| if( | ||
| mask_type == mask_enum::no_mask || | ||
| mask_type == mask_enum::window_generic | ||
| ) return 0; | ||
| if(left == -1 && right == 0){ | ||
| return mask_type == mask_enum::mask_top_left ? 1 : 2; | ||
| } | ||
| return 3; | ||
| } |
This is based on their op_tests/cpp/mha/*, which isn't the most stable or reliable source, but it's the only concrete location where such a mapping is used or documented.
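For reviewers who want to eyeball the mapping, the following is a small Python transliteration of the C++ helper above, intended only as a truth-table aid. The `MaskEnum` class is a stand-in for CK's `mask_enum`; its `MASK_BOTTOM_RIGHT` member is an assumption, since the diff only names `no_mask`, `window_generic`, and `mask_top_left`.

```python
from enum import Enum


class MaskEnum(Enum):
    """Stand-in for ck's mask_enum (MASK_BOTTOM_RIGHT is assumed)."""
    NO_MASK = "no_mask"
    MASK_TOP_LEFT = "mask_top_left"
    MASK_BOTTOM_RIGHT = "mask_bottom_right"
    WINDOW_GENERIC = "window_generic"


def ck_to_aiter_mask_type(mask_type: MaskEnum, left: int, right: int) -> int:
    """Mirror of the C++ helper: map a CK mask description to an AITER code."""
    # No mask and generic windows both map to 0.
    if mask_type in (MaskEnum.NO_MASK, MaskEnum.WINDOW_GENERIC):
        return 0
    # Window (-1, 0): top-left maps to 1, anything else to 2.
    if left == -1 and right == 0:
        return 1 if mask_type is MaskEnum.MASK_TOP_LEFT else 2
    # Every other window maps to 3.
    return 3
```

Note that the window sizes only matter once the mask type is neither `no_mask` nor `window_generic`.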
| void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args, bool ck_log_config){ | ||
| if (!ck_log_config) { | ||
| return; | ||
| } | ||
| | ||
| auto log_value = [](const char* label, const auto& value) { | ||
| std::cout << label << ": " << value << "\n"; | ||
| }; |
This merely standardizes the logging so that it is easier to parse at a glance, while guaranteeing uniformity. The same approach is implemented in both files. Note that the signature has been reduced to cover only what isn't already stored in fmha_args.
| @@ -499,10 +508,13 @@ hipError_t ck_attn_bwd( | |||
| bool uses_bwd_v3, | |||
| bool is_v3_atomic_fp32, | |||
| int how_v3_bf16_cvt, | |||
| bool is_group_mode, | |||
| const char* func_name, | |||
| bool ck_log_config, | |||
| hipStream_t stream){ | |||
This abstracted implementation function has a signature that is a superset of those of the higher-level API functions in this file, so that each of them can route directly to it without affecting the API outside of this file.
| std::pair<const void*, const void*>{philox_seed_ptr, philox_offset_ptr}}; | ||
| }(); | ||
| aiter::mha_bwd_args fmha_args{}; | ||
| fmha_args.mask_type = ck_to_aiter_mask_type(mask_type, left, right); |
This is the AITER mask type, even though the same argument refers to the CK mask type in the FWD pass.
Note that there is a PR for removing this argument which is still pending.
| } | ||
| return hipSuccess; | ||
| } | ||
| hipError_t ck_attn_bwd( |
Both ck_attn_{varlen_}bwd are wrappers around _ck_attn_bwd_impl + TE-side post-processing kernels. Those kernels can probably be refactored similarly but that's a bit outside the scope here and of dubious value.
| # A custom develop command only used for ROCm builds | ||
| class EditableWheel(editable_wheel, SubCommand): | ||
| def run(self): | ||
| super().run() | ||
| if ( | ||
| int(os.getenv("NVTE_FUSED_ATTN_CK", "1")) and | ||
| int(os.getenv("NVTE_FUSED_ATTN", "1")) | ||
| ): | ||
| # Ensure that the AITER ASM kernels are properly available at runtime | ||
| # by creating a symlink to them. | ||
| project_dir = Path(__file__).parent | ||
| asm_src_dir = project_dir / '3rdparty' / 'aiter' / 'hsa' | ||
| # Must be synced with | ||
| # TransformerEngine/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp | ||
| asm_target_dir = project_dir / 'transformer_engine' / 'lib' / 'aiter' | ||
| if asm_src_dir.is_dir() and not asm_target_dir.is_dir(): | ||
| print(f"Setting up symlink for AITER ASM kernels: {asm_target_dir} -> {asm_src_dir}") | ||
| asm_target_dir.symlink_to(asm_src_dir) | ||
|
|
This is necessary to not break AITER ASM on editable installs -- we ran into this issue before.
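The symlink step in the diff can be exercised in isolation. The sketch below wraps the same `is_dir()` guard and `symlink_to` call in a helper; the name `ensure_symlink` and the creation of the parent directory are assumptions added for the example, not part of the PR.

```python
from pathlib import Path


def ensure_symlink(src_dir: Path, target_dir: Path) -> bool:
    """Create target_dir -> src_dir when src exists and target does not.

    Returns True if a link was created, False if nothing was done.
    """
    if src_dir.is_dir() and not target_dir.is_dir():
        # Assumption for this sketch: create missing parents so the link
        # can be placed even in a fresh checkout.
        target_dir.parent.mkdir(parents=True, exist_ok=True)
        target_dir.symlink_to(src_dir, target_is_directory=True)
        return True
    return False
```

Because `Path.is_dir()` follows symlinks, a second call is a no-op once the link resolves to a directory. A dangling link at the target path is a different case: `is_dir()` returns False for it and `symlink_to` then raises `FileExistsError`, so a stale link would have to be removed first.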
| try: | ||
| subprocess.run( | ||
| sys.executable + " tools/check_aiter_mha_args_usage.py --mode both", | ||
| shell=True, check=True | ||
| ) | ||
| except subprocess.CalledProcessError: | ||
| print("Error checking AITER mha_args usage.") | ||
| sys.exit(1) |
This explicitly checks our usage of the AITER API at build time and aborts setup if the check fails.
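One possible variation, offered as a suggestion rather than as what the PR does: passing the command as an argument list avoids `shell=True` and keeps working when `sys.executable` contains spaces. The helper name `run_check` is hypothetical.

```python
import subprocess
import sys


def run_check(script_args: list[str]) -> bool:
    """Run the current interpreter with script_args; True on exit code 0."""
    try:
        # A list argv is passed to the process directly, with no shell parsing.
        subprocess.run([sys.executable, *script_args], check=True)
    except subprocess.CalledProcessError:
        return False
    return True
```

In setup.py this could be used as `if not run_check(["tools/check_aiter_mha_args_usage.py", "--mode", "both"]): sys.exit(1)`, preserving the behaviour of the snippet above.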
| long_description_content_type="text/x-rst", | ||
| ext_modules=ext_modules, | ||
| - cmdclass={"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, | ||
| + cmdclass=cmdclass, |
There is no change when NVTE_BUILD_METAPACKAGE=1 -- this only affects the NVTE_BUILD_METAPACKAGE=0 case.
Description
Updates the AITER subcommit and refactors our internal usage of the API for greater clarity and explicitness. Pending the acceptance and merge of ROCm/aiter#1966.