diff --git a/Makefile b/Makefile index e94ed1f..c5433bb 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ CXX := g++ -CXXFLAGS := -std=c++17 -O3 -fPIC -fopenmp +CXXFLAGS := -std=c++17 -O3 -fPIC -march=native -fopenmp # Python / pybind11 include flags PYBIND11_INCLUDES := $(shell python3 -m pybind11 --includes) diff --git a/include/zenann/SimdUtils.h b/include/zenann/SimdUtils.h index af787b5..c1e522e 100644 --- a/include/zenann/SimdUtils.h +++ b/include/zenann/SimdUtils.h @@ -1,16 +1,39 @@ #pragma once #include +#include namespace zenann { -inline float l2_naive(const float* a, - const float* b, - size_t dim) { +inline float l2_simd(const float* __restrict a, + const float* __restrict b, + size_t dim) { +#if defined(__AVX2__) + const size_t step = 8; // 8 × 32-bit floats + __m256 acc = _mm256_setzero_ps(); + size_t i = 0; + for (; i + step - 1 < dim; i += step) { + __m256 va = _mm256_loadu_ps(a + i); + __m256 vb = _mm256_loadu_ps(b + i); + __m256 diff = _mm256_sub_ps(va, vb); + acc = _mm256_fmadd_ps(diff, diff, acc); // acc += diff² + } + float buf[step]; + _mm256_storeu_ps(buf, acc); + float d = 0.f; + for (int j = 0; j < step; ++j) d += buf[j]; + + for (; i < dim; ++i) { + float diff = a[i] - b[i]; + d += diff * diff; + } + return d; +#else float d = 0.f; for (size_t i = 0; i < dim; ++i) { float diff = a[i] - b[i]; d += diff * diff; } return d; +#endif } } diff --git a/src/IVFFlatIndex.cpp b/src/IVFFlatIndex.cpp index 53e3386..45465ba 100644 --- a/src/IVFFlatIndex.cpp +++ b/src/IVFFlatIndex.cpp @@ -51,7 +51,7 @@ SearchResult IVFFlatIndex::search(const Vector& query, size_t k) const { // Calculate distance from query to all centroids (parallelized) #pragma omp parallel for schedule(static) for (size_t c = 0; c < nlist_; ++c) { - float d = l2_naive(query.data(), centroids_[c].data(), dimension_); + float d = l2_simd(query.data(), centroids_[c].data(), dimension_); cdist[c] = {d, c}; } @@ -77,7 +77,7 @@ SearchResult IVFFlatIndex::search(const Vector& query, size_t k) const { // Search within this cluster's inverted list for (size_t id : lists_[c]) { - float dist = l2_naive(query.data(), data[id].data(), dimension_); + float dist = l2_simd(query.data(), data[id].data(), dimension_); if (local.size() < k) { local.emplace_back(dist, id);