35 changes: 35 additions & 0 deletions .github/scripts/Dockerfile.ci.deps
@@ -0,0 +1,35 @@
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.

## TE CI Dockerfile
ARG BASE_DOCKER=registry-sc-harbor.amd.com/framework/compute-rocm-rel-7.2:57_ubuntu22.04_py3.11_pytorch_release-2.8_08d38866
FROM $BASE_DOCKER
WORKDIR /

# Build arguments
ARG FA_VERSION=v2.8.1
ARG ROCM_VERSION=7.2
ARG JAX_VERSION=0.8.0
ARG PYTHON_VERSION=311

RUN pip install setuptools wheel
RUN pip install ipython pytest fire pydantic pybind11 ninja pandas
RUN apt-get update && apt-get install -y --no-install-recommends vim && rm -rf /var/lib/apt/lists/*

# Install flash-attention
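# GPU_ARCHS restricts the AMD GPU targets the flash-attention build compiles for.
# The FLASH_ATTENTION_* flags are set inline on the setup.py command so they
# apply to that invocation.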
ENV GPU_ARCHS="gfx90a;gfx950;gfx942"
RUN git clone --branch ${FA_VERSION} --depth 1 https://github.com/Dao-AILab/flash-attention.git \
&& cd flash-attention \
&& FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE FLASH_ATTENTION_SKIP_CK_BUILD=FALSE python setup.py install \
&& cd ..

# Install JAX
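# The "+" in each wheel's local version tag (e.g. "+rocm7.2.0") is URL-encoded
# as %2B so pip can fetch the file by URL.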
RUN ROCM_MAJOR=$(echo "${ROCM_VERSION}" | cut -d. -f1) && pip install \
https://repo.radeon.com/rocm/manylinux/rocm-rel-${ROCM_VERSION}/jax_rocm${ROCM_MAJOR}_pjrt-${JAX_VERSION}%2Brocm${ROCM_VERSION}.0-py3-none-manylinux_2_28_x86_64.whl \
https://repo.radeon.com/rocm/manylinux/rocm-rel-${ROCM_VERSION}/jax_rocm${ROCM_MAJOR}_plugin-${JAX_VERSION}%2Brocm${ROCM_VERSION}.0-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux_2_28_x86_64.whl \
jax==${JAX_VERSION} \
https://repo.radeon.com/rocm/manylinux/rocm-rel-${ROCM_VERSION}/jaxlib-${JAX_VERSION}%2Brocm${ROCM_VERSION}.0-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux_2_27_x86_64.whl

WORKDIR /workspace/
CMD ["/bin/bash"]
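For local testing, the image can be built by passing the same build arguments the Dockerfile declares. A minimal sketch, not part of this PR; the output tag `te-ci-deps:local` is illustrative:

```python
import subprocess

# Build the CI dependency image with the Dockerfile's declared build args.
# The tag is an example; CI may name the image differently.
subprocess.run(
    [
        "docker", "build",
        "-f", ".github/scripts/Dockerfile.ci.deps",
        "--build-arg", "FA_VERSION=v2.8.1",
        "--build-arg", "ROCM_VERSION=7.2",
        "--build-arg", "JAX_VERSION=0.8.0",
        "--build-arg", "PYTHON_VERSION=311",
        "-t", "te-ci-deps:local",
        ".",
    ],
    check=True,
)
```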
2 changes: 1 addition & 1 deletion ci/ci_config.json
@@ -1,6 +1,6 @@
{
"docker_images": {
"default": "registry-sc-harbor.amd.com/framework/te-ci:rocm-7.1.1_ubuntu22.04_py3.11_pytorch_release_2.8_63e525b2_jax_0.7.1_fa-2.8.0",
"default": "registry-sc-harbor.amd.com/framework/te-ci:rocm-7.2_ubuntu22.04_py3.11_pytorch_release-2.8_08d38866_jax_0.8.0_fa_2.8.1",
"release_v1.13": "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/private/te-ci:rocm-6.4_0_ubuntu22_py310_torch25_jax0435qa_fa273",
"release_v1.14": "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/private/te-ci:rocm-6.4_0_ubuntu22_py310_torch25_jax0435qa_fa273"
}
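How this config is consumed is not shown in the diff; a hypothetical sketch of a lookup helper (the function name is illustrative, not from this repo):

```python
import json

# Hypothetical helper: pick the CI image for a release key, falling back
# to the "default" entry when the key is not pinned.
def resolve_ci_image(key, config_path="ci/ci_config.json"):
    with open(config_path) as f:
        images = json.load(f)["docker_images"]
    return images.get(key, images["default"])

# e.g. resolve_ci_image("release_v1.13") -> the pinned release_v1.13 image
```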
3 changes: 2 additions & 1 deletion tests/pytorch/attention/test_kv_cache.py
@@ -31,6 +31,7 @@
init_method_normal,
scaled_init_method_normal,
is_bf16_compatible,
+ get_device_compute_capability,
)

_current_file = pathlib.Path(__file__).resolve()
@@ -378,7 +379,7 @@ def get_tols(config, module, backend, dtype):
# With FA on ROCm it may not fit default tolerance
if IS_HIP_EXTENSION and backend == "FlashAttention":
tols = {
- torch.half: (5e-3, 5e-3),
+ torch.half: (6e-3, 6e-3) if get_device_compute_capability() == (9, 4) else (5e-3, 5e-3),
torch.bfloat16: (4e-2, 4e-2),
}
else:
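The (9, 4) check presumably targets gfx942-class devices, which PyTorch on ROCm reports as compute capability (9, 4), so FP16 FlashAttention results get the looser 6e-3 tolerance there. A sketch of how such a tolerance pair is typically consumed; the `assert_close` call and the (rtol, atol) ordering are assumptions, not shown in this diff:

```python
import torch

# Sketch: gate the comparison tolerance on dtype and device capability.
# Helper name is hypothetical; tuple order (rtol, atol) is an assumption.
def assert_kv_cache_close(out, ref, dtype, compute_cap):
    tols = {
        torch.half: (6e-3, 6e-3) if compute_cap == (9, 4) else (5e-3, 5e-3),
        torch.bfloat16: (4e-2, 4e-2),
    }
    rtol, atol = tols[dtype]
    torch.testing.assert_close(out, ref, rtol=rtol, atol=atol)
```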