Releases: ROCm/TransformerEngine
v2.4_rocm
What's Changed
- Enable MLA in TE JAX Extension
- Allocate dkv expanded buffer according to max_tokens_kv
- Support aiter build in multiple target/dockerfile
- Enable the JAX-side fused-attn pytests with sequence packing + SWA
- Update hipify_torch submodule to fix v2 mappings
- Ensure weight transpose is valid for FP8 training
- MXFP8 hipblasLt GEMM support
- Normalization kernels for MXFP8
- Enable FP8 GEMM gelu_aux_bias
- Install 'ninja' via pip if it is not found
- Add transpose cache to LayerNorm kernel
Fixes
- FIX Accumulate int overflow in workspace memory calculation
- Fix dropout when using a new-style rng
- FIX Update intra-seq padding detection in CK Fused Attention backend
- Fix NCCL error in test_torch_fsdp2
- [Fix] Ensure ln_out is not cached if wgrads
- [Fix] Increased tolerance and used FP32 compute for the unpermute kernel
Upstream release notes: https://github.com/NVIDIA/TransformerEngine/releases/tag/v2.4
Full Changelog: v2.2_rocm...v2.4_rocm
v2.2 ROCm
What's Changed
- Support math_sm_count for GEMM
- Added drop-in Triton replacements for layernorm and rmsnorm
- Added Triton MXFP8 quantize/dequantize
- Reduce the footprint of the FP8 weight transpose cache
- Switched to AOTriton 0.10c
- Switched from CK to AITER
- JAX 0.7 support
- FlashAttn 2.8.0.post2 support
- Add gfx950 as default target
- Fix building on ROCm 6.2
- Fix faults with current scaling
Upstream release notes: https://github.com/NVIDIA/TransformerEngine/releases/tag/v2.2
Full Changelog: v2.1_rocm...v2.2_rocm
v2.1 ROCm
What's Changed
- Enable Multi-latent attention
- Gfx950 support
- Add release wheels building support
- Remove rocBlas support
- Add layernorm Triton kernels
Upstream changes:
https://github.com/NVIDIA/TransformerEngine/releases/tag/v2.1
https://github.com/NVIDIA/TransformerEngine/releases/tag/v2.0
Full Changelog: v1.14_rocm...v2.1_rocm
v1.14 ROCm
[CI] deprecate praxis installation and tests
- Removed praxis installation and related test setup from `ci/jax.sh`
- Installed `flax>=0.7.1`, with `typing_extensions>=4.12.2`
v1.13 ROCm
[CI] deprecate praxis installation and tests
- Removed praxis installation and related test setup from `ci/jax.sh`
- Installed `flax>=0.7.1`, with `typing_extensions>=4.12.2`
v1.9 ROCm
[ROCm] backport rmsnorm triton kernels into rocm v1.9 (#169)
- [ROCm] backport rmsnorm triton kernels into rocm v1.9
- [ROCm] use single worker for CI
v1.12 ROCm
v1.12_rocm IFU release v1.12
v1.11 ROCm
[PyTorch] Drop FA as an installation requirement (#1226) (#125)
Upstream cherry-pick 161b1d9 + partially e762592
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>