Skip to content

Commit 9b57318

Browse files
add constexpr to conditions where possible (ifs that test the residual which is the template paramter) (#484)
use fmadd in avx512 if possible (in avx2 it requires another flag) (cherry picked from commit 60b290a) Co-authored-by: meiravgri <meirav.grimberg@redis.com>
1 parent 5ee2661 commit 9b57318

17 files changed

+47
-47
lines changed

src/VecSim/spaces/IP/IP_AVX2_BF16.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ float BF16_InnerProductSIMD32_AVX2(const void *pVect1v, const void *pVect2v, siz
108108
}
109109

110110
// Do a single step if residual >=16
111-
if (residual >= 16) {
111+
if constexpr (residual >= 16) {
112112
InnerProductStep(pVect1, pVect2, sum_prod);
113113
}
114114

src/VecSim/spaces/IP/IP_AVX512_FP16.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ float FP16_InnerProductSIMD32_AVX512(const void *pVect1v, const void *pVect2v, s
3131

3232
auto sum = _mm512_setzero_ps();
3333

34-
if (residual % 16) {
34+
if constexpr (residual % 16) {
3535
// Deal with remainder first. `dim` is more than 32, so we have at least one block of 32
3636
// 16-bit float so mask loading is guaranteed to be safe.
3737
__mmask16 constexpr residuals_mask = (1 << (residual % 16)) - 1;
@@ -46,7 +46,7 @@ float FP16_InnerProductSIMD32_AVX512(const void *pVect1v, const void *pVect2v, s
4646
pVect1 += residual % 16;
4747
pVect2 += residual % 16;
4848
}
49-
if (residual >= 16) {
49+
if constexpr (residual >= 16) {
5050
InnerProductStep(pVect1, pVect2, sum);
5151
}
5252

src/VecSim/spaces/IP/IP_AVX512_FP32.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ static inline void InnerProductStep(float *&pVect1, float *&pVect2, __m512 &sum5
1111
pVect1 += 16;
1212
__m512 v2 = _mm512_loadu_ps(pVect2);
1313
pVect2 += 16;
14-
sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(v1, v2));
14+
sum512 = _mm512_fmadd_ps(v1, v2, sum512);
1515
}
1616

1717
template <unsigned char residual> // 0..15
@@ -25,7 +25,7 @@ float FP32_InnerProductSIMD16_AVX512(const void *pVect1v, const void *pVect2v, s
2525

2626
// Deal with remainder first. `dim` is more than 16, so we have at least one 16-float block,
2727
// so mask loading is guaranteed to be safe
28-
if (residual) {
28+
if constexpr (residual) {
2929
__mmask16 constexpr mask = (1 << residual) - 1;
3030
__m512 v1 = _mm512_maskz_loadu_ps(mask, pVect1);
3131
pVect1 += residual;

src/VecSim/spaces/IP/IP_AVX512_FP64.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ static inline void InnerProductStep(double *&pVect1, double *&pVect2, __m512d &s
1111
pVect1 += 8;
1212
__m512d v2 = _mm512_loadu_pd(pVect2);
1313
pVect2 += 8;
14-
sum512 = _mm512_add_pd(sum512, _mm512_mul_pd(v1, v2));
14+
sum512 = _mm512_fmadd_pd(v1, v2, sum512);
1515
}
1616

1717
template <unsigned char residual> // 0..7
@@ -25,7 +25,7 @@ double FP64_InnerProductSIMD8_AVX512(const void *pVect1v, const void *pVect2v, s
2525

2626
// Deal with remainder first. `dim` is more than 8, so we have at least one 8-double block,
2727
// so mask loading is guaranteed to be safe
28-
if (residual) {
28+
if constexpr (residual) {
2929
__mmask8 constexpr mask = (1 << residual) - 1;
3030
__m512d v1 = _mm512_maskz_loadu_pd(mask, pVect1);
3131
pVect1 += residual;

src/VecSim/spaces/IP/IP_AVX_FP32.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ float FP32_InnerProductSIMD16_AVX(const void *pVect1v, const void *pVect2v, size
2626

2727
// Deal with 1-7 floats with mask loading, if needed. `dim` is >16, so we have at least one
2828
// 16-float block, so mask loading is guaranteed to be safe.
29-
if (residual % 8) {
29+
if constexpr (residual % 8) {
3030
__mmask8 constexpr mask = (1 << (residual % 8)) - 1;
3131
__m256 v1 = my_mm256_maskz_loadu_ps<mask>(pVect1);
3232
pVect1 += residual % 8;
@@ -36,7 +36,7 @@ float FP32_InnerProductSIMD16_AVX(const void *pVect1v, const void *pVect2v, size
3636
}
3737

3838
// If the reminder is >=8, have another step of 8 floats
39-
if (residual >= 8) {
39+
if constexpr (residual >= 8) {
4040
InnerProductStep(pVect1, pVect2, sum256);
4141
}
4242

src/VecSim/spaces/IP/IP_AVX_FP64.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ double FP64_InnerProductSIMD8_AVX(const void *pVect1v, const void *pVect2v, size
2626

2727
// Deal with 1-3 doubles with mask loading, if needed. `dim` is >8, so we have at least one
2828
// 8-double block, so mask loading is guaranteed to be safe.
29-
if (residual % 4) {
29+
if constexpr (residual % 4) {
3030
// _mm256_maskz_loadu_pd is not available in AVX
3131
__mmask8 constexpr mask = (1 << (residual % 4)) - 1;
3232
__m256d v1 = my_mm256_maskz_loadu_pd<mask>(pVect1);
@@ -37,7 +37,7 @@ double FP64_InnerProductSIMD8_AVX(const void *pVect1v, const void *pVect2v, size
3737
}
3838

3939
// If the reminder is >=4, have another step of 4 doubles
40-
if (residual >= 4) {
40+
if constexpr (residual >= 4) {
4141
InnerProductStep(pVect1, pVect2, sum256);
4242
}
4343

src/VecSim/spaces/IP/IP_F16C_FP16.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ float FP16_InnerProductSIMD32_F16C(const void *pVect1v, const void *pVect2v, siz
3131

3232
auto sum = _mm256_setzero_ps();
3333

34-
if (residual % 8) {
34+
if constexpr (residual % 8) {
3535
// Deal with remainder first. `dim` is more than 32, so we have at least one block of 32
3636
// 16-bit float so mask loading is guaranteed to be safe.
3737
__mmask16 constexpr residuals_mask = (1 << (residual % 8)) - 1;
@@ -47,12 +47,12 @@ float FP16_InnerProductSIMD32_F16C(const void *pVect1v, const void *pVect2v, siz
4747
pVect1 += residual % 8;
4848
pVect2 += residual % 8;
4949
}
50-
if (residual >= 8 && residual < 16) {
50+
if constexpr (residual >= 8 && residual < 16) {
5151
InnerProductStep(pVect1, pVect2, sum);
52-
} else if (residual >= 16 && residual < 24) {
52+
} else if constexpr (residual >= 16 && residual < 24) {
5353
InnerProductStep(pVect1, pVect2, sum);
5454
InnerProductStep(pVect1, pVect2, sum);
55-
} else if (residual >= 24) {
55+
} else if constexpr (residual >= 24) {
5656
InnerProductStep(pVect1, pVect2, sum);
5757
InnerProductStep(pVect1, pVect2, sum);
5858
InnerProductStep(pVect1, pVect2, sum);

src/VecSim/spaces/IP/IP_SSE_FP32.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,19 @@ float FP32_InnerProductSIMD16_SSE(const void *pVect1v, const void *pVect2v, size
2525

2626
// Deal with %4 remainder first. `dim` is >16, so we have at least one 16-float block,
2727
// so loading 4 floats and then masking them is safe.
28-
if (residual % 4) {
28+
if constexpr (residual % 4) {
2929
__m128 v1, v2;
30-
if (residual % 4 == 3) {
30+
if constexpr (residual % 4 == 3) {
3131
// Load 3 floats and set the last one to 0
3232
v1 = _mm_load_ss(pVect1); // load 1 float, set the rest to 0
3333
v2 = _mm_load_ss(pVect2);
3434
v1 = _mm_loadh_pi(v1, (__m64 *)(pVect1 + 1));
3535
v2 = _mm_loadh_pi(v2, (__m64 *)(pVect2 + 1));
36-
} else if (residual % 4 == 2) {
36+
} else if constexpr (residual % 4 == 2) {
3737
// Load 2 floats and set the last two to 0
3838
v1 = _mm_loadh_pi(_mm_setzero_ps(), (__m64 *)pVect1);
3939
v2 = _mm_loadh_pi(_mm_setzero_ps(), (__m64 *)pVect2);
40-
} else if (residual % 4 == 1) {
40+
} else if constexpr (residual % 4 == 1) {
4141
// Load 1 float and set the last three to 0
4242
v1 = _mm_load_ss(pVect1);
4343
v2 = _mm_load_ss(pVect2);
@@ -48,11 +48,11 @@ float FP32_InnerProductSIMD16_SSE(const void *pVect1v, const void *pVect2v, size
4848
}
4949

5050
// have another 1, 2 or 3 4-float steps according to residual
51-
if (residual >= 12)
51+
if constexpr (residual >= 12)
5252
InnerProductStep(pVect1, pVect2, sum_prod);
53-
if (residual >= 8)
53+
if constexpr (residual >= 8)
5454
InnerProductStep(pVect1, pVect2, sum_prod);
55-
if (residual >= 4)
55+
if constexpr (residual >= 4)
5656
InnerProductStep(pVect1, pVect2, sum_prod);
5757

5858
// We dealt with the residual part. We are left with some multiple of 16 floats.

src/VecSim/spaces/IP/IP_SSE_FP64.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ double FP64_InnerProductSIMD8_SSE(const void *pVect1v, const void *pVect2v, size
2525
__m128d sum_prod = _mm_setzero_pd();
2626

2727
// If residual is odd, we load 1 double and set the last one to 0
28-
if (residual % 2 == 1) {
28+
if constexpr (residual % 2 == 1) {
2929
__m128d v1 = _mm_load_sd(pVect1);
3030
pVect1++;
3131
__m128d v2 = _mm_load_sd(pVect2);
@@ -34,11 +34,11 @@ double FP64_InnerProductSIMD8_SSE(const void *pVect1v, const void *pVect2v, size
3434
}
3535

3636
// have another 1, 2 or 3 2-double steps according to residual
37-
if (residual >= 6)
37+
if constexpr (residual >= 6)
3838
InnerProductStep(pVect1, pVect2, sum_prod);
39-
if (residual >= 4)
39+
if constexpr (residual >= 4)
4040
InnerProductStep(pVect1, pVect2, sum_prod);
41-
if (residual >= 2)
41+
if constexpr (residual >= 2)
4242
InnerProductStep(pVect1, pVect2, sum_prod);
4343

4444
// We dealt with the residual part. We are left with some multiple of 8 doubles.

src/VecSim/spaces/L2/L2_AVX2_BF16.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ float BF16_L2SqrSIMD32_AVX2(const void *pVect1v, const void *pVect2v, size_t dim
106106
}
107107

108108
// Do a single step if residual >=16
109-
if (residual >= 16) {
109+
if constexpr (residual >= 16) {
110110
L2SqrStep(pVect1, pVect2, sum);
111111
}
112112

0 commit comments

Comments
 (0)