|
/*
 * Copyright Redis Ltd. 2021 - present
 * Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
 * the Server Side Public License v1 (SSPLv1).
 */
| 6 | + |
| 7 | +#include "VecSim/spaces/space_includes.h" |
| 8 | +#include <arm_neon.h> |
| 9 | + |
// Accumulate one 4-float chunk of squared L2 distance into `sum`, then
// advance both input cursors past the chunk. Cursors are taken by reference
// so the caller's pointers move forward as chunks are consumed.
static inline void L2SquareStep(float *&pVect1, float *&pVect2, float32x4_t &sum) {
    const float32x4_t delta = vsubq_f32(vld1q_f32(pVect1), vld1q_f32(pVect2));
    // sum += delta * delta, element-wise multiply-accumulate.
    sum = vmlaq_f32(sum, delta, delta);
    pVect1 += 4;
    pVect2 += 4;
}
| 21 | + |
/**
 * Squared L2 distance between two float32 vectors using NEON.
 *
 * @tparam residual  dimension % 16, i.e. the 0..15 elements left over after
 *                   the 16-wide main loop; resolved at compile time so the
 *                   tail handling costs no runtime branches.
 * @param pVect1v    first vector (contiguous float32, length `dimension`)
 * @param pVect2v    second vector (same layout)
 * @param dimension  number of elements in each vector
 * @return           sum over i of (v1[i] - v2[i])^2
 */
template <unsigned char residual> // 0..15
float FP32_L2SqrSIMD16_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) {
    float *cur1 = (float *)pVect1v;
    float *cur2 = (float *)pVect2v;

    // Four independent accumulators let consecutive multiply-accumulates
    // overlap instead of serializing on a single register dependency.
    float32x4_t acc0 = vdupq_n_f32(0.0f);
    float32x4_t acc1 = vdupq_n_f32(0.0f);
    float32x4_t acc2 = vdupq_n_f32(0.0f);
    float32x4_t acc3 = vdupq_n_f32(0.0f);

    // Main loop: every iteration consumes 16 floats, 4 per accumulator.
    for (size_t chunk = dimension / 16; chunk != 0; chunk--) {
        L2SquareStep(cur1, cur2, acc0);
        L2SquareStep(cur1, cur2, acc1);
        L2SquareStep(cur1, cur2, acc2);
        L2SquareStep(cur1, cur2, acc3);
    }

    // Whole 4-float groups inside the residual (compile-time count, 0..3).
    if constexpr (residual / 4 >= 1) {
        L2SquareStep(cur1, cur2, acc0);
    }
    if constexpr (residual / 4 >= 2) {
        L2SquareStep(cur1, cur2, acc1);
    }
    if constexpr (residual / 4 == 3) {
        L2SquareStep(cur1, cur2, acc2);
    }

    // Trailing 0..3 scalars: load them lane by lane into zero-filled
    // registers so the unused lanes contribute exactly 0 to the distance.
    if constexpr (residual % 4 != 0) {
        float32x4_t tail1 = vdupq_n_f32(0.0f);
        float32x4_t tail2 = vdupq_n_f32(0.0f);

        if constexpr (residual % 4 >= 1) {
            tail1 = vld1q_lane_f32(cur1, tail1, 0);
            tail2 = vld1q_lane_f32(cur2, tail2, 0);
        }
        if constexpr (residual % 4 >= 2) {
            tail1 = vld1q_lane_f32(cur1 + 1, tail1, 1);
            tail2 = vld1q_lane_f32(cur2 + 1, tail2, 1);
        }
        if constexpr (residual % 4 == 3) {
            tail1 = vld1q_lane_f32(cur1 + 2, tail1, 2);
            tail2 = vld1q_lane_f32(cur2 + 2, tail2, 2);
        }

        const float32x4_t tail_diff = vsubq_f32(tail1, tail2);
        acc3 = vmlaq_f32(acc3, tail_diff, tail_diff);
    }

    // Pairwise reduce the four accumulators, then horizontally sum the four
    // lanes. The addition order matches (acc0+acc1)+(acc2+acc3) followed by
    // (low+high) pairwise adds, keeping float rounding deterministic.
    const float32x4_t total = vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3));
    const float32x2_t halves = vadd_f32(vget_low_f32(total), vget_high_f32(total));
    return vget_lane_f32(vpadd_f32(halves, halves), 0);
}