diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 1a5eba9de4759..b8ab7142b6b35 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -796,11 +796,6 @@ else() ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp ${MLAS_SRC_DIR}/intrinsics/avx2/saturation_check_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit.h - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit_blklen64.h - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit_blklen128.h - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit_blklen32.h ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.h ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.cpp ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h @@ -847,9 +842,22 @@ endif() ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx512Core.S ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit.h + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit_blklen64.h + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit_blklen128.h + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit_blklen32.h ) set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl") + # Strip -mavx512vnni from the W2 scalar oracle / pack-helper TU so the + # compiler cannot autovectorize its int8 dot-product loops to vpdpbusd. + # Needed because this helper runs at model load on AVX-512-only + # (non-VNNI) hosts via the AVX-512 W2 dispatch. TU is pure C++ -- no + # AVX-512 intrinsics inside. + set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit.cpp PROPERTIES + COMPILE_FLAGS "-mfma -mavx512bw -mavx512dq -mavx512vl") + set(mlas_platform_srcs_avx512vnni ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp ${MLAS_SRC_DIR}/qkv_quant_kernel_avx512vnni.cpp diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm_2bit_gemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm_2bit_gemm.cpp index 7690b32c3c6a8..9e5a83d4fad99 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm_2bit_gemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm_2bit_gemm.cpp @@ -310,6 +310,9 @@ void RunW2Case(size_t M, size_t N, size_t K, bool WithBias, uint32_t seed, // handler in a follow-up. // TEST(MlasSq2BitTest, Scalar_BlkLen64) { + if (!GetMlasPlatform().Avx512Supported_) { + GTEST_SKIP() << "AVX-512 not available on this host"; + } struct Shape { size_t M, N, K; }; @@ -351,6 +354,9 @@ TEST(MlasSq2BitTest, Scalar_BlkLen64) { // Same coverage with per-block non-default zero points. // TEST(MlasSq2BitTest, Scalar_BlkLen64_WithZeroPoints) { + if (!GetMlasPlatform().Avx512Supported_) { + GTEST_SKIP() << "AVX-512 not available on this host"; + } struct Shape { size_t M, N, K; }; @@ -806,6 +812,9 @@ constexpr struct { } // namespace TEST(MlasSq2BitTest, Scalar_BlkLen128) { + if (!GetMlasPlatform().Avx512Supported_) { + GTEST_SKIP() << "AVX-512 not available on this host"; + } for (uint32_t seed : {0xC0FFEEu, 0xBADC0DEu}) { for (const auto& s : kSimdShapes_BlkLen128) { for (bool bias : {false, true}) { @@ -819,6 +828,9 @@ TEST(MlasSq2BitTest, Scalar_BlkLen128) { } TEST(MlasSq2BitTest, Scalar_BlkLen128_WithZeroPoints) { + if (!GetMlasPlatform().Avx512Supported_) { + GTEST_SKIP() << "AVX-512 not available on this host"; + } for (uint32_t seed : {0xC0FFEEu, 0xBADC0DEu}) { for (const auto& s : kSimdShapes_BlkLen128) { for (bool bias : {false, true}) { @@ -1160,6 +1172,9 @@ constexpr struct { } // namespace TEST(MlasSq2BitTest, Scalar_BlkLen32) { + if (!GetMlasPlatform().Avx512Supported_) { + GTEST_SKIP() << "AVX-512 not available on this host"; + } for (uint32_t seed : {0xC0FFEEu, 0xBADC0DEu}) { for (const auto& s : kSimdShapes_BlkLen32) { for (bool bias : {false, true}) { @@ -1173,6 +1188,9 @@ TEST(MlasSq2BitTest, Scalar_BlkLen32) { } TEST(MlasSq2BitTest, Scalar_BlkLen32_WithZeroPoints) { + if (!GetMlasPlatform().Avx512Supported_) { + GTEST_SKIP() << "AVX-512 not available on this host"; + } for (uint32_t seed : {0xC0FFEEu, 0xBADC0DEu}) { for (const auto& s : kSimdShapes_BlkLen32) { for (bool bias : {false, true}) {