diff --git a/.github/scripts/Dockerfile.ci.deps b/.github/scripts/Dockerfile.ci.deps
new file mode 100644
index 000000000..fd7207536
--- /dev/null
+++ b/.github/scripts/Dockerfile.ci.deps
@@ -0,0 +1,35 @@
+# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
+#
+# See LICENSE for license information.
+
+## TE CI Dockerfile
+ARG BASE_DOCKER=registry-sc-harbor.amd.com/framework/compute-rocm-rel-7.2:57_ubuntu22.04_py3.11_pytorch_release-2.8_08d38866
+FROM $BASE_DOCKER
+WORKDIR /
+
+# Build arguments
+ARG FA_VERSION=v2.8.1
+ARG ROCM_VERSION=7.2
+ARG JAX_VERSION=0.8.0
+ARG PYTHON_VERSION=311
+
+RUN pip install setuptools wheel
+RUN pip install ipython pytest fire pydantic pybind11 ninja pandas
+RUN apt-get update && apt-get install -y vim
+
+# Install flash-attention
+ENV GPU_ARCHS=gfx90a;gfx950;gfx942
+RUN git clone --branch ${FA_VERSION} --depth 1 https://github.com/Dao-AILab/flash-attention.git \
+    && cd flash-attention \
+    && FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE FLASH_ATTENTION_SKIP_CK_BUILD=FALSE python setup.py install \
+    && cd ..
+
+# Install JAX
+RUN ROCM_MAJOR=$(echo "${ROCM_VERSION}" | cut -d. -f1) && pip install \
+    https://repo.radeon.com/rocm/manylinux/rocm-rel-${ROCM_VERSION}/jax_rocm${ROCM_MAJOR}_pjrt-${JAX_VERSION}%2Brocm${ROCM_VERSION}.0-py3-none-manylinux_2_28_x86_64.whl \
+    https://repo.radeon.com/rocm/manylinux/rocm-rel-${ROCM_VERSION}/jax_rocm${ROCM_MAJOR}_plugin-${JAX_VERSION}%2Brocm${ROCM_VERSION}.0-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux_2_28_x86_64.whl \
+    jax==${JAX_VERSION} \
+    https://repo.radeon.com/rocm/manylinux/rocm-rel-${ROCM_VERSION}/jaxlib-${JAX_VERSION}%2Brocm${ROCM_VERSION}.0-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux_2_27_x86_64.whl
+
+WORKDIR /workspace/
+CMD ["/bin/bash"]
\ No newline at end of file
diff --git a/ci/ci_config.json b/ci/ci_config.json
index a7b3d5d6c..4123a3ed2 100644
--- a/ci/ci_config.json
+++ b/ci/ci_config.json
@@ -1,6 +1,6 @@
 {
     "docker_images": {
-        "default": "registry-sc-harbor.amd.com/framework/te-ci:rocm-7.1.1_ubuntu22.04_py3.11_pytorch_release_2.8_63e525b2_jax_0.7.1_fa-2.8.0",
+        "default": "registry-sc-harbor.amd.com/framework/te-ci:rocm-7.2_ubuntu22.04_py3.11_pytorch_release-2.8_08d38866_jax_0.8.0_fa_2.8.1",
         "release_v1.13": "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/private/te-ci:rocm-6.4_0_ubuntu22_py310_torch25_jax0435qa_fa273",
         "release_v1.14": "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/private/te-ci:rocm-6.4_0_ubuntu22_py310_torch25_jax0435qa_fa273"
     }
diff --git a/tests/pytorch/attention/test_kv_cache.py b/tests/pytorch/attention/test_kv_cache.py
index ba5ab7f87..8d2518d3e 100644
--- a/tests/pytorch/attention/test_kv_cache.py
+++ b/tests/pytorch/attention/test_kv_cache.py
@@ -31,6 +31,7 @@
     init_method_normal,
     scaled_init_method_normal,
     is_bf16_compatible,
+    get_device_compute_capability,
 )
 
 _current_file = pathlib.Path(__file__).resolve()
@@ -378,7 +379,7 @@ def get_tols(config, module, backend, dtype):
     # With FA on ROCm it may not fit default tolerance
     if IS_HIP_EXTENSION and backend == "FlashAttention":
         tols = {
-            torch.half: (5e-3, 5e-3),
+            torch.half: (6e-3, 6e-3) if get_device_compute_capability() == (9, 4) else (5e-3, 5e-3),
             torch.bfloat16: (4e-2, 4e-2),
         }
     else: