Skip to content
Merged
18 changes: 13 additions & 5 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -796,11 +796,6 @@ else()
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
${MLAS_SRC_DIR}/intrinsics/avx2/saturation_check_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit.h
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit_blklen64.h
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit_blklen128.h
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit_blklen32.h
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.h
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h
Expand Down Expand Up @@ -847,9 +842,22 @@ endif()
${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx512Core.S
${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit.h
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit.cpp
Comment thread
hariharans29 marked this conversation as resolved.
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit_blklen64.h
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit_blklen128.h
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit_blklen32.h
)
set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl")

# Strip -mavx512vnni from the W2 scalar oracle / pack-helper TU so the
Comment thread
hariharans29 marked this conversation as resolved.
# compiler cannot autovectorize its int8 dot-product loops to vpdpbusd.
# Needed because this helper runs at model load on AVX-512-only
# (non-VNNI) hosts via the AVX-512 W2 dispatch. TU is pure C++ -- no
# AVX-512 intrinsics inside.
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512_2bit.cpp PROPERTIES
COMPILE_FLAGS "-mfma -mavx512bw -mavx512dq -mavx512vl")

set(mlas_platform_srcs_avx512vnni
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp
${MLAS_SRC_DIR}/qkv_quant_kernel_avx512vnni.cpp
Expand Down
18 changes: 18 additions & 0 deletions onnxruntime/test/mlas/unittest/test_sqnbitgemm_2bit_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,9 @@ void RunW2Case(size_t M, size_t N, size_t K, bool WithBias, uint32_t seed,
// handler in a follow-up.
//
TEST(MlasSq2BitTest, Scalar_BlkLen64) {
if (!GetMlasPlatform().Avx512Supported_) {
GTEST_SKIP() << "AVX-512 not available on this host";
}
struct Shape {
size_t M, N, K;
};
Expand Down Expand Up @@ -351,6 +354,9 @@ TEST(MlasSq2BitTest, Scalar_BlkLen64) {
// Same coverage with per-block non-default zero points.
//
TEST(MlasSq2BitTest, Scalar_BlkLen64_WithZeroPoints) {
if (!GetMlasPlatform().Avx512Supported_) {
GTEST_SKIP() << "AVX-512 not available on this host";
}
struct Shape {
size_t M, N, K;
};
Expand Down Expand Up @@ -806,6 +812,9 @@ constexpr struct {
} // namespace

TEST(MlasSq2BitTest, Scalar_BlkLen128) {
if (!GetMlasPlatform().Avx512Supported_) {
GTEST_SKIP() << "AVX-512 not available on this host";
}
for (uint32_t seed : {0xC0FFEEu, 0xBADC0DEu}) {
for (const auto& s : kSimdShapes_BlkLen128) {
for (bool bias : {false, true}) {
Expand All @@ -819,6 +828,9 @@ TEST(MlasSq2BitTest, Scalar_BlkLen128) {
}

TEST(MlasSq2BitTest, Scalar_BlkLen128_WithZeroPoints) {
if (!GetMlasPlatform().Avx512Supported_) {
GTEST_SKIP() << "AVX-512 not available on this host";
}
for (uint32_t seed : {0xC0FFEEu, 0xBADC0DEu}) {
for (const auto& s : kSimdShapes_BlkLen128) {
for (bool bias : {false, true}) {
Expand Down Expand Up @@ -1160,6 +1172,9 @@ constexpr struct {
} // namespace

TEST(MlasSq2BitTest, Scalar_BlkLen32) {
if (!GetMlasPlatform().Avx512Supported_) {
GTEST_SKIP() << "AVX-512 not available on this host";
}
for (uint32_t seed : {0xC0FFEEu, 0xBADC0DEu}) {
for (const auto& s : kSimdShapes_BlkLen32) {
for (bool bias : {false, true}) {
Expand All @@ -1173,6 +1188,9 @@ TEST(MlasSq2BitTest, Scalar_BlkLen32) {
}

TEST(MlasSq2BitTest, Scalar_BlkLen32_WithZeroPoints) {
if (!GetMlasPlatform().Avx512Supported_) {
GTEST_SKIP() << "AVX-512 not available on this host";
}
for (uint32_t seed : {0xC0FFEEu, 0xBADC0DEu}) {
for (const auto& s : kSimdShapes_BlkLen32) {
for (bool bias : {false, true}) {
Expand Down
Loading