Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 19 additions & 23 deletions src/sgemm_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,21 +404,17 @@ unsafe fn kernel_x86_avx<MA>(k: usize, alpha: T, a: *const T, b: *const T,
let a1357 = _mm256_movehdup_ps(av); // Load: a1 a1 a3 a3 a5 a5 a7 a7
let a3175 = _mm256_permute_ps(a1357, PERM32_2301);

let a4602 = _mm256_permute2f128_ps(a0246, a0246, PERM128_30);
let a6420 = _mm256_permute2f128_ps(a2064, a2064, PERM128_30);

let a5713 = _mm256_permute2f128_ps(a1357, a1357, PERM128_30);
let a7531 = _mm256_permute2f128_ps(a3175, a3175, PERM128_30);
let bv_lh = _mm256_permute2f128_ps(bv, bv, PERM128_30);

ab[0] = MA::multiply_add(a0246, bv, ab[0]);
ab[1] = MA::multiply_add(a2064, bv, ab[1]);
ab[2] = MA::multiply_add(a4602, bv, ab[2]);
ab[3] = MA::multiply_add(a6420, bv, ab[3]);
ab[2] = MA::multiply_add(a0246, bv_lh, ab[2]);
ab[3] = MA::multiply_add(a2064, bv_lh, ab[3]);

ab[4] = MA::multiply_add(a1357, bv, ab[4]);
ab[5] = MA::multiply_add(a3175, bv, ab[5]);
ab[6] = MA::multiply_add(a5713, bv, ab[6]);
ab[7] = MA::multiply_add(a7531, bv, ab[7]);
ab[6] = MA::multiply_add(a1357, bv_lh, ab[6]);
ab[7] = MA::multiply_add(a3175, bv_lh, ab[7]);

if !is_last {
a = a.add(MR);
Expand All @@ -441,19 +437,19 @@ unsafe fn kernel_x86_avx<MA>(k: usize, alpha: T, a: *const T, b: *const T,

let ab0246 = ab[0];
let ab2064 = ab[1];
let ab4602 = ab[2];
let ab6420 = ab[3];
let ab4602 = ab[2]; // reverse order
let ab6420 = ab[3]; // reverse order

let ab1357 = ab[4];
let ab3175 = ab[5];
let ab5713 = ab[6];
let ab7531 = ab[7];
let ab5713 = ab[6]; // reverse order
let ab7531 = ab[7]; // reverse order

const SHUF_0123: i32 = shuffle_mask!(3, 2, 1, 0);
debug_assert_eq!(SHUF_0123, 0xE4);

const PERM128_03: i32 = permute2f128_mask!(3, 0);
const PERM128_21: i32 = permute2f128_mask!(1, 2);
const PERM128_02: i32 = permute2f128_mask!(2, 0);
const PERM128_31: i32 = permute2f128_mask!(1, 3);

// No elements are "shuffled" in truth, they all stay at their index
// but we combine vectors to de-stripe them.
Expand All @@ -480,17 +476,17 @@ unsafe fn kernel_x86_avx<MA>(k: usize, alpha: T, a: *const T, b: *const T,
let ab5511 = _mm256_shuffle_ps(ab5713, ab7531, SHUF_0123);
let ab7733 = _mm256_shuffle_ps(ab7531, ab5713, SHUF_0123);

let ab0000 = _mm256_permute2f128_ps(ab0044, ab4400, PERM128_03);
let ab4444 = _mm256_permute2f128_ps(ab0044, ab4400, PERM128_21);
let ab0000 = _mm256_permute2f128_ps(ab0044, ab4400, PERM128_02);
let ab4444 = _mm256_permute2f128_ps(ab0044, ab4400, PERM128_31);

let ab2222 = _mm256_permute2f128_ps(ab2266, ab6622, PERM128_03);
let ab6666 = _mm256_permute2f128_ps(ab2266, ab6622, PERM128_21);
let ab2222 = _mm256_permute2f128_ps(ab2266, ab6622, PERM128_02);
let ab6666 = _mm256_permute2f128_ps(ab2266, ab6622, PERM128_31);

let ab1111 = _mm256_permute2f128_ps(ab1155, ab5511, PERM128_03);
let ab5555 = _mm256_permute2f128_ps(ab1155, ab5511, PERM128_21);
let ab1111 = _mm256_permute2f128_ps(ab1155, ab5511, PERM128_02);
let ab5555 = _mm256_permute2f128_ps(ab1155, ab5511, PERM128_31);

let ab3333 = _mm256_permute2f128_ps(ab3377, ab7733, PERM128_03);
let ab7777 = _mm256_permute2f128_ps(ab3377, ab7733, PERM128_21);
let ab3333 = _mm256_permute2f128_ps(ab3377, ab7733, PERM128_02);
let ab7777 = _mm256_permute2f128_ps(ab3377, ab7733, PERM128_31);

ab[0] = ab0000;
ab[1] = ab1111;
Expand Down
Loading