@Micky774 Micky774 commented Feb 9, 2026

Description

Updates the AITER subcommit and refactors our internal usage of the API for greater clarity and explicitness. Pending the acceptance and merge of ROCm/aiter#1966.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Updates the AITER subcommit
  • Refactors the batch-mode and group-mode fwd/bwd functions to route through central implementation functions
  • Refactors to explicit, by-name assignment of arguments (rather than relying on positional order)
  • Includes the new AITER mask type conversion as part of the arguments required by the new API
  • Streamlines logging
  • Adds a helper Python script that explicitly compares our usage of the AITER API against the API's declaration in its header files
  • Runs the helper script during install as a sanity check on our usage of the AITER API
  • WIP pending tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Contributor Author

@Micky774 Micky774 left a comment

A few comments to hopefully help reviewers.

This helper script scans a semi-hard-coded set of files in the TE source tree in order to directly compare AITER's internal API against our attempt at using it. The script is run during setup through setup.py.
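
To make the idea concrete for reviewers, here is a minimal sketch of this kind of check. It is not the script added in this PR: the function names are hypothetical, the parsing is regex-based, and it assumes a flat struct body with no nested braces or block comments. It extracts the member names declared in a struct and compares them with the members assigned on a given variable.

```python
import re


def declared_fields(header_text: str, struct_name: str) -> set[str]:
    """Collect member names from a flat `struct NAME { ... };` body."""
    match = re.search(
        r"struct\s+" + re.escape(struct_name) + r"\s*\{(.*?)\};",
        header_text,
        re.S,
    )
    if match is None:
        raise ValueError(f"struct {struct_name} not found")
    body = re.sub(r"//[^\n]*", "", match.group(1))  # strip line comments
    fields = set()
    for statement in body.split(";"):
        statement = statement.split("=")[0].strip()  # drop default initializers
        if statement:
            # The member name is the last identifier in the declaration.
            fields.add(re.findall(r"[A-Za-z_]\w*", statement)[-1])
    return fields


def assigned_fields(source_text: str, var_name: str) -> set[str]:
    """Collect members written as `var_name.member = ...` (not `==`)."""
    pattern = r"\b" + re.escape(var_name) + r"\.(\w+)\s*=(?!=)"
    return set(re.findall(pattern, source_text))


def compare_usage(header_text, source_text, struct_name, var_name):
    """Return (declared but never assigned, assigned but not declared)."""
    declared = declared_fields(header_text, struct_name)
    assigned = assigned_fields(source_text, var_name)
    return sorted(declared - assigned), sorted(assigned - declared)
```

Called with the header text, the TE source text, a struct name such as mha_bwd_args, and the variable name fmha_args, it returns two sorted lists; a non-empty list in either position would indicate a mismatch worth failing the build on.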

Comment on lines +19 to +28
int ck_to_aiter_mask_type(mask_enum mask_type, ck_tile::index_t left, ck_tile::index_t right){
    if(
        mask_type == mask_enum::no_mask ||
        mask_type == mask_enum::window_generic
    ) return 0;
    if(left == -1 && right == 0){
        return mask_type == mask_enum::mask_top_left ? 1 : 2;
    }
    return 3;
}

This is based on their op_tests/cpp/mha/*, which isn't the most stable or reliable source, but it's the only concrete location where such a mapping is used or documented.
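
For anyone who wants to tabulate the mapping without building the C++ code, below is a Python transliteration of the function quoted above. It is only a sketch: the enum is a stand-in whose numeric values are arbitrary, MASK_BOTTOM_RIGHT is an assumed name for the remaining causal variant (the snippet only names the other three members), and the meaning AITER attaches to each returned integer is not stated in this thread.

```python
from enum import Enum, auto


class MaskEnum(Enum):
    """Stand-in for the C++ mask_enum; the values here are arbitrary."""
    NO_MASK = auto()
    MASK_TOP_LEFT = auto()
    MASK_BOTTOM_RIGHT = auto()  # assumed name, not shown in the snippet
    WINDOW_GENERIC = auto()


def ck_to_aiter_mask_type(mask_type: MaskEnum, left: int, right: int) -> int:
    """Mirror of the C++ branch structure, line for line."""
    if mask_type in (MaskEnum.NO_MASK, MaskEnum.WINDOW_GENERIC):
        return 0
    if left == -1 and right == 0:
        return 1 if mask_type is MaskEnum.MASK_TOP_LEFT else 2
    return 3


if __name__ == "__main__":
    # Print the result for every mask type over a few window settings.
    for mask in MaskEnum:
        for left, right in [(-1, 0), (-1, -1), (8, 0)]:
            print(mask.name, left, right, ck_to_aiter_mask_type(mask, left, right))
```

Note that the window check only runs for the two non-trivial mask types, so no_mask and window_generic map to 0 regardless of left and right.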

Comment on lines +347 to +354
void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args, bool ck_log_config){
    if (!ck_log_config) {
        return;
    }

    auto log_value = [](const char* label, const auto& value) {
        std::cout << label << ": " << value << "\n";
    };

This merely standardizes the logging so that it is easier to parse at a glance, while guaranteeing uniformity. The same approach is implemented in both files. Note that the signature has been reduced to cover only what isn't already stored in fmha_args.

Comment on lines 467 to 514
@@ -499,10 +508,13 @@ hipError_t ck_attn_bwd(
bool uses_bwd_v3,
bool is_v3_atomic_fp32,
int how_v3_bf16_cvt,
bool is_group_mode,
const char* func_name,
bool ck_log_config,
hipStream_t stream){

The signature of this shared implementation function is a superset of the signatures of the higher-level API functions in this file, so both can route directly to it without affecting the API outside of this file.

std::pair<const void*, const void*>{philox_seed_ptr, philox_offset_ptr}};
}();
aiter::mha_bwd_args fmha_args{};
fmha_args.mask_type = ck_to_aiter_mask_type(mask_type, left, right);

This is the AITER mask type, even though the same argument refers to the CK mask type in the FWD pass.

Note that a PR to remove this argument is still pending.

}
return hipSuccess;
}
hipError_t ck_attn_bwd(

Both ck_attn_bwd and ck_attn_varlen_bwd are wrappers around _ck_attn_bwd_impl plus TE-side post-processing kernels. Those kernels could probably be refactored similarly, but that is outside the scope of this PR and of dubious value.

@Micky774 Micky774 marked this pull request as ready for review February 11, 2026 16:01
Comment on lines +62 to +80
# A custom develop command only used for ROCm builds
class EditableWheel(editable_wheel, SubCommand):
    def run(self):
        super().run()
        if (
            int(os.getenv("NVTE_FUSED_ATTN_CK", "1")) and
            int(os.getenv("NVTE_FUSED_ATTN", "1"))
        ):
            # Ensure that the AITER ASM kernels are properly available at runtime
            # by creating a symlink to them.
            project_dir = Path(__file__).parent
            asm_src_dir = project_dir / '3rdparty' / 'aiter' / 'hsa'
            # Must be synced with
            # TransformerEngine/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp
            asm_target_dir = project_dir / 'transformer_engine' / 'lib' / 'aiter'
            if asm_src_dir.is_dir() and not asm_target_dir.is_dir():
                print(f"Setting up symlink for AITER ASM kernels: {asm_target_dir} -> {asm_src_dir}")
                asm_target_dir.symlink_to(asm_src_dir)

This is necessary to not break AITER ASM on editable installs -- we ran into this issue before.
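
As a possible hardening of the symlink step, the sketch below shows one way to make it tolerant of a missing parent directory and of a stale link left over from an earlier checkout. The helper name ensure_dir_symlink is hypothetical and not part of this PR; it uses only standard pathlib calls.

```python
from pathlib import Path


def ensure_dir_symlink(link: Path, target: Path) -> bool:
    """Create `link` -> `target` if needed; return True when a link was created."""
    if not target.is_dir():
        return False  # nothing to point at (e.g. submodule not checked out)
    if link.is_symlink() and not link.exists():
        link.unlink()  # dangling link whose old target no longer exists
    if link.exists():
        return False  # a directory or a valid link is already in place
    link.parent.mkdir(parents=True, exist_ok=True)
    link.symlink_to(target, target_is_directory=True)
    return True
```

In the snippet above, the call would take asm_target_dir as the link and asm_src_dir as the target. Whether the extra cases matter in practice depends on how the lib directory is created during the editable build, which this thread does not show.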

Comment on lines +113 to +120
try:
    subprocess.run(
        sys.executable + " tools/check_aiter_mha_args_usage.py --mode both",
        shell=True, check=True
    )
except subprocess.CalledProcessError:
    print("Error checking AITER mha_args usage.")
    sys.exit(1)

Explicitly checks the AITER API usage.
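
One possible refinement, offered as a sketch rather than a required change: passing the command as an argument list avoids shell=True and keeps working when sys.executable or the script path contains spaces. The wrapper name run_api_check is hypothetical; the --mode both flag is taken from the snippet above.

```python
import subprocess
import sys


def run_api_check(script_path: str, mode: str = "both") -> int:
    """Run the checker with the current interpreter and return its exit code."""
    # An argument list needs no shell quoting, unlike a concatenated string.
    cmd = [sys.executable, script_path, "--mode", mode]
    return subprocess.run(cmd, check=False).returncode
```

In setup.py this could be used as: if run_api_check("tools/check_aiter_mha_args_usage.py") != 0, print the error and call sys.exit(1), which matches the behavior of the try/except block above.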

 long_description_content_type="text/x-rst",
 ext_modules=ext_modules,
-cmdclass={"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
+cmdclass=cmdclass,

There is no change when NVTE_BUILD_METAPACKAGE=1 -- this only affects the NVTE_BUILD_METAPACKAGE=0 case.

@Micky774 Micky774 changed the title [WIP] Update AITER subcommit and refactor internal AITER/CK FA API usage [No Merge] Update AITER subcommit and refactor internal AITER/CK FA API usage Feb 11, 2026