Disable TE cross entropy loss fusion #5115
Greptile Summary
This PR disables the Transformer Engine implementation of cross-entropy loss fusion by adding a guard that raises an error when cross_entropy_loss_fusion=True is combined with cross_entropy_fusion_impl='te'.
Confidence Score: 4/5
The change correctly blocks the unstable TE fusion path; the only concern is that the guard uses assert rather than raise ValueError, so it is skipped when Python runs with the -O flag. Affected files: megatron/core/model_parallel_config.py and tests/unit_tests/test_model_parallel_config.py (the test expects AssertionError and will need updating alongside the fix).
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[User sets cross_entropy_loss_fusion=True] --> B{cross_entropy_fusion_impl?}
B -- "'te'" --> C["ModelParallelConfig.__post_init__\nassert / raise ValueError"]
C --> D[AssertionError / ValueError\n'TE fusion disabled due to stability issues']
B -- "'native'" --> E[Validation passes]
E --> F[Native cross-entropy fusion used]
G[CLI: --cross-entropy-loss-fusion] --> H{--cross-entropy-fusion-impl?}
H -- "te" --> I["validate_args assert"]
I --> J[AssertionError\n'Use --cross-entropy-fusion-impl native']
H -- "native" --> K[Validation passes]
K --> L[Native cross-entropy fusion used]
Reviews (1): Last reviewed commit: "Disable TE cross entropy loss fusion"
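The config-side branch of the flowchart can be sketched with a small stand-in dataclass. This is only an illustration under stated assumptions, not the real ModelParallelConfig: the class name FusionConfigSketch, the helper is_accepted, and the field defaults are hypothetical, while the field names and the error message are taken from the PR. The sketch uses raise ValueError, one of the two options the flowchart lists.

```python
from dataclasses import dataclass


@dataclass
class FusionConfigSketch:
    """Hypothetical stand-in for the two fields involved; not the real ModelParallelConfig."""

    cross_entropy_loss_fusion: bool = False  # default value is an assumption
    cross_entropy_fusion_impl: str = "native"  # default value is an assumption

    def __post_init__(self):
        # Reject only the fused TE path; every other combination passes validation.
        if self.cross_entropy_loss_fusion and self.cross_entropy_fusion_impl == "te":
            raise ValueError(
                "Transformer Engine cross entropy loss fusion is disabled due to stability issues. "
                "Use cross_entropy_fusion_impl='native', or disable cross_entropy_loss_fusion."
            )


def is_accepted(fusion: bool, impl: str) -> bool:
    """Return True if the combination passes validation in the sketch."""
    try:
        FusionConfigSketch(cross_entropy_loss_fusion=fusion, cross_entropy_fusion_impl=impl)
    except ValueError:
        return False
    return True
```

In this sketch only the pair (True, 'te') is rejected; selecting 'te' while fusion is off still passes, which matches the condition in the guard.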
assert not (
    self.cross_entropy_loss_fusion and self.cross_entropy_fusion_impl == 'te'
), (
    "Transformer Engine cross entropy loss fusion is disabled due to stability issues. "
    "Use cross_entropy_fusion_impl='native', or disable cross_entropy_loss_fusion."
)
Every other validation in __post_init__ uses raise ValueError(...), but this guard uses assert. Running Python with the -O (optimize) flag silently disables all assert statements, meaning the TE fusion combination would be permitted and the stability regression would reappear without any error. Use raise ValueError here to match the surrounding code and make the guard unconditional.
- assert not (
-     self.cross_entropy_loss_fusion and self.cross_entropy_fusion_impl == 'te'
- ), (
-     "Transformer Engine cross entropy loss fusion is disabled due to stability issues. "
-     "Use cross_entropy_fusion_impl='native', or disable cross_entropy_loss_fusion."
- )
+ if self.cross_entropy_loss_fusion and self.cross_entropy_fusion_impl == 'te':
+     raise ValueError(
+         "Transformer Engine cross entropy loss fusion is disabled due to stability issues. "
+         "Use cross_entropy_fusion_impl='native', or disable cross_entropy_loss_fusion."
+     )
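The difference between the two guard styles under optimization can be reproduced without launching a separate interpreter, since the built-in compile() accepts an optimize argument that mirrors the -O flag. The snippet below is a minimal sketch: the guard sources are simplified stand-ins for the real __post_init__ code, and the names ASSERT_GUARD, RAISE_GUARD, and outcome are hypothetical.

```python
# Simplified stand-ins for the two guard styles discussed in the review.
ASSERT_GUARD = """
fusion, impl = True, 'te'
assert not (fusion and impl == 'te'), 'TE fusion disabled'
"""

RAISE_GUARD = """
fusion, impl = True, 'te'
if fusion and impl == 'te':
    raise ValueError('TE fusion disabled')
"""


def outcome(source: str, optimize: int) -> str:
    """Compile and run `source` at the given optimization level; report what happened."""
    # optimize=0 keeps assert statements; optimize=1 strips them, like `python -O`.
    code = compile(source, "<guard>", "exec", optimize=optimize)
    try:
        exec(code, {})
    except Exception as exc:
        return type(exc).__name__
    return "no error"


print(outcome(ASSERT_GUARD, 0))  # AssertionError
print(outcome(ASSERT_GUARD, 1))  # no error
print(outcome(RAISE_GUARD, 1))   # ValueError
```

The assert-based guard disappears at optimization level 1, while the explicit raise fires at every level.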
def test_te_cross_entropy_loss_fusion_is_disabled():
    with pytest.raises(AssertionError, match="Transformer Engine cross entropy loss fusion"):
        ModelParallelConfig(cross_entropy_loss_fusion=True, cross_entropy_fusion_impl='te')
If the guard in model_parallel_config.py is changed from assert to raise ValueError (as suggested), this test will need to match ValueError instead of AssertionError.
  def test_te_cross_entropy_loss_fusion_is_disabled():
-     with pytest.raises(AssertionError, match="Transformer Engine cross entropy loss fusion"):
+     with pytest.raises(ValueError, match="Transformer Engine cross entropy loss fusion"):
          ModelParallelConfig(cross_entropy_loss_fusion=True, cross_entropy_fusion_impl='te')
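The reason the expected exception type has to change can be shown with a standalone sketch that uses the standard-library unittest module in place of pytest, so it runs without Megatron or torch installed. Everything here is an assumption for illustration: build_config is a hypothetical stand-in for ModelParallelConfig(...) with the suggested ValueError guard, and GuardTest and run_guard_tests are invented names.

```python
import io
import unittest

MESSAGE = (
    "Transformer Engine cross entropy loss fusion is disabled due to stability issues. "
    "Use cross_entropy_fusion_impl='native', or disable cross_entropy_loss_fusion."
)


def build_config(cross_entropy_loss_fusion, cross_entropy_fusion_impl):
    """Hypothetical stand-in for ModelParallelConfig(...) with the ValueError guard."""
    if cross_entropy_loss_fusion and cross_entropy_fusion_impl == "te":
        raise ValueError(MESSAGE)
    return (cross_entropy_loss_fusion, cross_entropy_fusion_impl)


class GuardTest(unittest.TestCase):
    def test_rejects_te_fusion_with_value_error(self):
        # The pattern is searched within the message, so a leading fragment is enough.
        with self.assertRaisesRegex(ValueError, "Transformer Engine cross entropy loss fusion"):
            build_config(True, "te")

    def test_old_expectation_no_longer_matches(self):
        # ValueError is not a subclass of AssertionError, so the inner check
        # lets it propagate and the outer check observes it instead.
        with self.assertRaises(ValueError):
            with self.assertRaises(AssertionError):
                build_config(True, "te")


def run_guard_tests() -> bool:
    """Run both tests quietly and report whether they all passed."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(GuardTest)
    result = unittest.TextTestRunner(stream=io.StringIO()).run(suite)
    return result.wasSuccessful()
```

A test that still expects AssertionError would fail with an uncaught ValueError once the guard is switched, so the guard change and the test change belong in the same commit.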
Force-pushed from 2fac253 to 89daa0f.
/ok to test eab65fb
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/26906455119
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/26914187574
[X] I, the PR author, have personally reviewed every line of this PR.
What does this PR do?
Disables the Transformer Engine implementation of cross entropy loss fusion with an assertion due to observed training stability issues, while keeping native cross entropy fusion available.
Issue tracking
Linked issue: N/A, small stability bug fix.
Contribution process
Pre-checks
Validation
cross_entropy_loss_fusion=True with cross_entropy_fusion_impl='te'.
native.
python -m py_compile on changed Python files.
git diff --check.
torch is not installed.