Enable AOTriton BWD V3 API #382
base: dev
transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp
Note: this PR is not merge-ready yet. There are still some CI failures from the level 3 run that need to be corrected.
ff336b5 to 51da203
CI failures have been addressed (thanks @xinyazhang for your help!). Ready for review @wenchenvincent @ipanfilo @wangye805
    RoPE,
    is_training,
)
if len(fused_attn_backends) == 1:
or !IS_HIP_EXTENSION
Done
    is_training,
)
- elif len(fused_attn_backends) == 2:
+ elif len(fused_attn_backends) == 2 and IS_HIP_EXTENSION:
IS_HIP_EXTENSION here is redundant
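To illustrate why the extra check is redundant: assuming the first branch was updated to `len(fused_attn_backends) == 1 or not IS_HIP_EXTENSION` as suggested earlier in this thread, the `elif` is only reached when `IS_HIP_EXTENSION` is already true. The sketch below is not the test code from this PR; the function names and the returned labels are hypothetical stand-ins for the branch bodies.

```python
from itertools import product


def branch_with_check(n_backends: int, is_hip: bool) -> str:
    # Mirrors the reviewed code: the elif repeats the IS_HIP_EXTENSION test.
    if n_backends == 1 or not is_hip:
        return "single"
    elif n_backends == 2 and is_hip:
        return "pair"
    return "other"


def branch_without_check(n_backends: int, is_hip: bool) -> str:
    # Same logic with the repeated test dropped from the elif.
    if n_backends == 1 or not is_hip:
        return "single"
    elif n_backends == 2:
        return "pair"
    return "other"


def branches_agree() -> bool:
    # Exhaustively compare both variants over a small range of inputs.
    return all(
        branch_with_check(n, hip) == branch_without_check(n, hip)
        for n, hip in product(range(5), (False, True))
    )
```

Both variants pick the same branch for every combination, because any input with `is_hip` false is already consumed by the first condition.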
set(__AOTRITON_SUFFIX "_TEprivate")

if(NOT DEFINED AOTRITON_PATH)
  # If AOTRITON_PATH is not provided, we proceed to build the runtime
Your original changes used an environment variable to control this. On the other hand, this feature seems to be unused.
Description
This is a follow-up to the earlier AOTriton update PR (#360). This one focuses on adopting the AOTriton V3 API for the BWD pass, including introducing a LazyTensor wrapper implementation (driven by an eager tensor as usual).
Follow-up for #360
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: