From d77d227b567d44012c19400ccb6f2e67f734830c Mon Sep 17 00:00:00 2001
From: SXX
Date: Sun, 11 May 2025 14:03:32 +0800
Subject: [PATCH] sgemm: Reduce unnecessary AVX register permutations

- Removed the four per-iteration `_mm256_permute2f128_ps` lane swaps of the
  broadcast `a` vectors.
- Swapped the 128-bit lanes of `bv` once instead (`bv_lh`) and reused it for
  the four multiply-adds that previously consumed the swapped `a` vectors, so
  the affected accumulators now hold their halves in reverse order.
- Adjusted the de-striping permute masks (`PERM128_02`, `PERM128_31`) in the
  writeback so the final output layout is unchanged.
- This reduces the permute instruction count and register pressure per
  iteration without altering the computed results.
---
 src/sgemm_kernel.rs | 42 +++++++++++++++++++++++---------------------
 1 file changed, 19 insertions(+), 23 deletions(-)

diff --git a/src/sgemm_kernel.rs b/src/sgemm_kernel.rs
index fc9e998..28fe8ed 100644
--- a/src/sgemm_kernel.rs
+++ b/src/sgemm_kernel.rs
@@ -404,21 +404,17 @@ unsafe fn kernel_x86_avx(k: usize, alpha: T, a: *const T, b: *const T,
     let a1357 = _mm256_movehdup_ps(av); // Load: a1 a1 a3 a3 a5 a5 a7 a7
     let a3175 = _mm256_permute_ps(a1357, PERM32_2301);
 
-    let a4602 = _mm256_permute2f128_ps(a0246, a0246, PERM128_30);
-    let a6420 = _mm256_permute2f128_ps(a2064, a2064, PERM128_30);
-
-    let a5713 = _mm256_permute2f128_ps(a1357, a1357, PERM128_30);
-    let a7531 = _mm256_permute2f128_ps(a3175, a3175, PERM128_30);
+    let bv_lh = _mm256_permute2f128_ps(bv, bv, PERM128_30);
 
     ab[0] = MA::multiply_add(a0246, bv, ab[0]);
     ab[1] = MA::multiply_add(a2064, bv, ab[1]);
-    ab[2] = MA::multiply_add(a4602, bv, ab[2]);
-    ab[3] = MA::multiply_add(a6420, bv, ab[3]);
+    ab[2] = MA::multiply_add(a0246, bv_lh, ab[2]);
+    ab[3] = MA::multiply_add(a2064, bv_lh, ab[3]);
 
     ab[4] = MA::multiply_add(a1357, bv, ab[4]);
     ab[5] = MA::multiply_add(a3175, bv, ab[5]);
-    ab[6] = MA::multiply_add(a5713, bv, ab[6]);
-    ab[7] = MA::multiply_add(a7531, bv, ab[7]);
+    ab[6] = MA::multiply_add(a1357, bv_lh, ab[6]);
+    ab[7] = MA::multiply_add(a3175, bv_lh, ab[7]);
 
     if !is_last {
         a = a.add(MR);
@@ -441,19 +437,19 @@ unsafe fn kernel_x86_avx(k: usize, alpha: T, a: *const T, b: *const T,
 
     let ab0246 = ab[0];
     let ab2064 = ab[1];
-    let ab4602 = ab[2];
-    let ab6420 = ab[3];
+    let ab4602 = ab[2]; // reverse order
+    let ab6420 = ab[3]; // reverse order
 
     let ab1357 = ab[4];
     let ab3175 = ab[5];
-    let ab5713 = ab[6];
-    let ab7531 = ab[7];
+    let ab5713 = ab[6]; // reverse order
+    let ab7531 = ab[7]; // reverse order
 
     const SHUF_0123: i32 = shuffle_mask!(3, 2, 1, 0);
     debug_assert_eq!(SHUF_0123, 0xE4);
 
-    const PERM128_03: i32 = permute2f128_mask!(3, 0);
-    const PERM128_21: i32 = permute2f128_mask!(1, 2);
+    const PERM128_02: i32 = permute2f128_mask!(2, 0);
+    const PERM128_31: i32 = permute2f128_mask!(1, 3);
 
     // No elements are "shuffled" in truth, they all stay at their index
     // but we combine vectors to de-stripe them.
@@ -480,17 +476,17 @@ unsafe fn kernel_x86_avx(k: usize, alpha: T, a: *const T, b: *const T,
     let ab5511 = _mm256_shuffle_ps(ab5713, ab7531, SHUF_0123);
     let ab7733 = _mm256_shuffle_ps(ab7531, ab5713, SHUF_0123);
 
-    let ab0000 = _mm256_permute2f128_ps(ab0044, ab4400, PERM128_03);
-    let ab4444 = _mm256_permute2f128_ps(ab0044, ab4400, PERM128_21);
+    let ab0000 = _mm256_permute2f128_ps(ab0044, ab4400, PERM128_02);
+    let ab4444 = _mm256_permute2f128_ps(ab0044, ab4400, PERM128_31);
 
-    let ab2222 = _mm256_permute2f128_ps(ab2266, ab6622, PERM128_03);
-    let ab6666 = _mm256_permute2f128_ps(ab2266, ab6622, PERM128_21);
+    let ab2222 = _mm256_permute2f128_ps(ab2266, ab6622, PERM128_02);
+    let ab6666 = _mm256_permute2f128_ps(ab2266, ab6622, PERM128_31);
 
-    let ab1111 = _mm256_permute2f128_ps(ab1155, ab5511, PERM128_03);
-    let ab5555 = _mm256_permute2f128_ps(ab1155, ab5511, PERM128_21);
+    let ab1111 = _mm256_permute2f128_ps(ab1155, ab5511, PERM128_02);
+    let ab5555 = _mm256_permute2f128_ps(ab1155, ab5511, PERM128_31);
 
-    let ab3333 = _mm256_permute2f128_ps(ab3377, ab7733, PERM128_03);
-    let ab7777 = _mm256_permute2f128_ps(ab3377, ab7733, PERM128_21);
+    let ab3333 = _mm256_permute2f128_ps(ab3377, ab7733, PERM128_02);
+    let ab7777 = _mm256_permute2f128_ps(ab3377, ab7733, PERM128_31);
 
     ab[0] = ab0000;
     ab[1] = ab1111;
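
Note (not part of the patch): below is a minimal standalone sketch, for an x86-64
target with AVX, of why the single bv_lh permute can stand in for the four removed
`a`-vector permutes. It builds one broadcast both ways and checks that the new
product vector is simply the old one with its 128-bit halves exchanged, which is
exactly what the "reverse order" accumulators and the adjusted PERM128_02 /
PERM128_31 writeback masks account for. The helper name and the literal 0x01
immediate (current std::arch const-generic form, standing in for the crate's
PERM128_30 mask) are illustrative assumptions, not taken from the patched file.

    use std::arch::x86_64::*;

    #[target_feature(enable = "avx")]
    unsafe fn check_lane_swap_equivalence() {
        let a: [f32; 8] = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let b: [f32; 8] = [10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0];

        let av = _mm256_loadu_ps(a.as_ptr());
        let bv = _mm256_loadu_ps(b.as_ptr());

        // a0 a0 a2 a2 a4 a4 a6 a6 -- the same broadcast the kernel calls a0246
        let a0246 = _mm256_moveldup_ps(av);

        // Old path: swap the 128-bit halves of the broadcast (a4 a4 a6 a6 a0 a0 a2 a2)
        let a4602 = _mm256_permute2f128_ps::<0x01>(a0246, a0246);
        let old_ab2 = _mm256_mul_ps(a4602, bv);

        // New path: swap the halves of bv once (bv_lh) and keep a0246 unswapped
        let bv_lh = _mm256_permute2f128_ps::<0x01>(bv, bv);
        let new_ab2 = _mm256_mul_ps(a0246, bv_lh);

        // new_ab2 holds the same eight products with its halves exchanged
        // ("reverse order"); swapping them back must reproduce old_ab2 exactly.
        let new_ab2_swapped = _mm256_permute2f128_ps::<0x01>(new_ab2, new_ab2);

        let mut old_out = [0.0f32; 8];
        let mut new_out = [0.0f32; 8];
        _mm256_storeu_ps(old_out.as_mut_ptr(), old_ab2);
        _mm256_storeu_ps(new_out.as_mut_ptr(), new_ab2_swapped);
        assert_eq!(old_out, new_out);
    }

    fn main() {
        if is_x86_feature_detected!("avx") {
            unsafe { check_lane_swap_equivalence() };
            println!("bv_lh path matches the old per-`a` permute path");
        }
    }

The same argument applies to the odd-indexed broadcasts (a1357/a3175), so each
k-iteration issues one cross-lane permute instead of four, which is where the
instruction-count and register-pressure savings in the commit message come from.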