diff --git a/.github/workflows/nightly_coverage.yml b/.github/workflows/nightly_coverage.yml index 81429b1eb..e7276cc07 100644 --- a/.github/workflows/nightly_coverage.yml +++ b/.github/workflows/nightly_coverage.yml @@ -48,8 +48,7 @@ jobs: - name: Install system dependencies run: | sudo apt-get update - sudo apt-get install -y --no-install-recommends \ - lcov libaio-dev + sudo apt-get install -y lcov shell: bash - name: Install dependencies diff --git a/.gitignore b/.gitignore index b2fd51657..65859ca54 100644 --- a/.gitignore +++ b/.gitignore @@ -60,3 +60,9 @@ allure-* !build_android.sh !build_ios.sh +# config +doc/ +config/ +examples/python/ +examples/c_api/ +logs/ \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 7b79b3c0b..708bac7a6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.26) +cmake_minimum_required(VERSION 3.13) cmake_policy(SET CMP0077 NEW) project(zvec) set(CC_CXX_STANDARD 17) diff --git a/examples/c/index_example.c b/examples/c/index_example.c index 403c0bef9..5cce395cf 100644 --- a/examples/c/index_example.c +++ b/examples/c/index_example.c @@ -85,6 +85,12 @@ int main() { zvec_index_params_set_metric_type(hnsw_params_fast, ZVEC_METRIC_TYPE_L2); zvec_index_params_set_hnsw_params(hnsw_params_fast, 16, 100); + // Demonstrate INT8 quantization with random rotation preprocessing + // (enable_rotate rotates vectors before INT8 quantization to reduce error) + zvec_index_params_set_quantize_type(hnsw_params_fast, + ZVEC_QUANTIZE_TYPE_INT8); + zvec_index_params_set_quantizer_enable_rotate(hnsw_params_fast, true); + zvec_index_params_t *hnsw_params_balanced = zvec_index_params_create(ZVEC_INDEX_TYPE_HNSW); if (!hnsw_params_balanced) { diff --git a/python/tests/test_params.py b/python/tests/test_params.py index 2d2ba2795..e3c3070ad 100644 --- a/python/tests/test_params.py +++ b/python/tests/test_params.py @@ -36,6 +36,7 @@ IndexType, MetricType, QuantizeType, + QuantizerParam,
DataType, VectorSchema, ) @@ -452,3 +453,88 @@ def test_isinstance_compatibility(self): warnings.simplefilter("always") vq = VectorQuery(field_name="embedding", id="doc123") assert isinstance(vq, Query) + + +# ---------------------------- +# QuantizerParam Test Case +# ---------------------------- + + +class TestQuantizerParam: + def test_default(self): + qp = QuantizerParam() + assert qp.enable_rotate is False + + def test_enable_rotate_true(self): + qp = QuantizerParam(enable_rotate=True) + assert qp.enable_rotate is True + + def test_enable_rotate_false(self): + qp = QuantizerParam(enable_rotate=False) + assert qp.enable_rotate is False + + def test_equality(self): + qp1 = QuantizerParam(enable_rotate=True) + qp2 = QuantizerParam(enable_rotate=True) + qp3 = QuantizerParam(enable_rotate=False) + assert qp1 == qp2 + assert qp1 != qp3 + + def test_to_dict(self): + qp = QuantizerParam(enable_rotate=True) + d = qp.to_dict() + assert isinstance(d, dict) + assert d.get("enable_rotate") is True + + def test_repr(self): + qp = QuantizerParam(enable_rotate=True) + r = repr(qp) + assert "enable_rotate" in r or "QuantizerParam" in r + + def test_pickle_roundtrip(self): + import pickle + + qp = QuantizerParam(enable_rotate=True) + data = pickle.dumps(qp) + qp2 = pickle.loads(data) + assert qp2.enable_rotate is True + assert qp == qp2 + + +# ---------------------------- +# HnswIndexParam with QuantizerParam +# ---------------------------- + + +class TestHnswIndexParamQuantizer: + def test_default_quantizer_param(self): + param = HnswIndexParam() + assert param.quantizer_param is not None + assert param.quantizer_param.enable_rotate is False + + def test_with_quantizer_param(self): + qp = QuantizerParam(enable_rotate=True) + param = HnswIndexParam( + metric_type=MetricType.L2, + quantize_type=QuantizeType.INT8, + quantizer_param=qp, + ) + assert param.quantizer_param.enable_rotate is True + assert param.quantize_type == QuantizeType.INT8 + + +# ---------------------------- +# 
FlatIndexParam with QuantizerParam +# ---------------------------- + + +class TestFlatIndexParamQuantizer: + def test_with_quantizer_param(self): + qp = QuantizerParam(enable_rotate=True) + param = FlatIndexParam( + metric_type=MetricType.L2, + quantize_type=QuantizeType.INT8, + quantizer_param=qp, + ) + assert param.quantizer_param.enable_rotate is True + assert param.quantize_type == QuantizeType.INT8 diff --git a/python/zvec/__init__.py b/python/zvec/__init__.py index 5fdf9732c..929c75fa2 100644 --- a/python/zvec/__init__.py +++ b/python/zvec/__init__.py @@ -108,6 +108,7 @@ IVFIndexParam, IVFQueryParam, OptimizeOption, + QuantizerParam, VamanaIndexParam, VamanaQueryParam, ) @@ -171,6 +172,7 @@ "HnswQueryParam", "HnswRabitqQueryParam", "IVFQueryParam", + "QuantizerParam", "VamanaIndexParam", "VamanaQueryParam", # Extensions diff --git a/python/zvec/__init__.pyi b/python/zvec/__init__.pyi index cefa15b01..3e75f931b 100644 --- a/python/zvec/__init__.pyi +++ b/python/zvec/__init__.pyi @@ -30,6 +30,7 @@ from .model.param import ( IVFIndexParam, IVFQueryParam, OptimizeOption, + QuantizerParam, VamanaIndexParam, VamanaQueryParam, ) @@ -79,6 +80,7 @@ __all__: list = [ "MetricType", "OptimizeOption", "QuantizeType", + "QuantizerParam", "Query", "ReRanker", "RrfReRanker", diff --git a/python/zvec/model/param/__init__.py b/python/zvec/model/param/__init__.py index 43fc1ddce..084a79d00 100644 --- a/python/zvec/model/param/__init__.py +++ b/python/zvec/model/param/__init__.py @@ -31,6 +31,7 @@ IVFIndexParam, IVFQueryParam, OptimizeOption, + QuantizerParam, VamanaIndexParam, VamanaQueryParam, ) @@ -53,6 +54,7 @@ "IndexOption", "InvertIndexParam", "OptimizeOption", + "QuantizerParam", "VamanaIndexParam", "VamanaQueryParam", ] diff --git a/python/zvec/model/param/__init__.pyi b/python/zvec/model/param/__init__.pyi index c1d227280..dc312ba8a 100644 --- a/python/zvec/model/param/__init__.pyi +++ b/python/zvec/model/param/__init__.pyi @@ -26,6 +26,7 @@ __all__: list[str] = [ 
"IndexParam", "InvertIndexParam", "OptimizeOption", + "QuantizerParam", "QueryParam", "SegmentOption", "VectorIndexParam", @@ -147,6 +148,8 @@ class FlatIndexParam(VectorIndexParam): quantize_type (QuantizeType): Optional quantization type for vector compression (e.g., FP16, INT8). Use ``QuantizeType.UNDEFINED`` to disable quantization. Default is ``QuantizeType.UNDEFINED``. + quantizer_param (QuantizerParam): Quantizer configuration (e.g., enable_rotate). + Default is ``QuantizerParam()``. Examples: >>> from zvec.typing import MetricType, QuantizeType @@ -163,6 +166,7 @@ class FlatIndexParam(VectorIndexParam): self, metric_type: _zvec.typing.MetricType = ..., quantize_type: _zvec.typing.QuantizeType = ..., + quantizer_param: QuantizerParam = ..., ) -> None: """ Constructs a FlatIndexParam instance. @@ -171,6 +175,8 @@ class FlatIndexParam(VectorIndexParam): metric_type (MetricType, optional): Distance metric. Defaults to MetricType.IP. quantize_type (QuantizeType, optional): Vector quantization type. Defaults to QuantizeType.UNDEFINED (no quantization). + quantizer_param (QuantizerParam, optional): Quantizer configuration. + Defaults to QuantizerParam(). """ def __repr__(self) -> str: ... @@ -226,6 +232,7 @@ class HnswIndexParam(VectorIndexParam): ef_construction: typing.SupportsInt = 500, quantize_type: _zvec.typing.QuantizeType = ..., use_contiguous_memory: bool = False, + quantizer_param: QuantizerParam = ..., ) -> None: ... def __repr__(self) -> str: ... def __setstate__(self, arg0: tuple) -> None: ... @@ -498,6 +505,7 @@ class IVFIndexParam(VectorIndexParam): n_iters: typing.SupportsInt = 10, use_soar: bool = False, quantize_type: _zvec.typing.QuantizeType = ..., + quantizer_param: QuantizerParam = ..., ) -> None: """ Constructs an IVFIndexParam instance. @@ -511,6 +519,8 @@ class IVFIndexParam(VectorIndexParam): use_soar (bool, optional): Enable SOAR optimization. Defaults to False. quantize_type (QuantizeType, optional): Vector quantization type. 
Defaults to QuantizeType.UNDEFINED. + quantizer_param (QuantizerParam, optional): Quantizer configuration. + Defaults to QuantizerParam(). """ def __repr__(self) -> str: ... @@ -1015,6 +1025,49 @@ class SegmentOption: bool: Whether the segment is read-only. """ +class QuantizerParam: + """ + + Parameters for quantizer configuration. + + Encapsulates quantization-related settings such as enable_rotate. + Designed for future extensibility. + + Attributes: + enable_rotate (bool): Whether to apply random rotation before INT8 + quantization to reduce quantization error. + Only effective with quantize_type=INT8. Defaults to False. + + Examples: + >>> qp = QuantizerParam(enable_rotate=True) + >>> print(qp.enable_rotate) + True + """ + + def __getstate__(self) -> tuple: ... + def __init__(self, enable_rotate: bool = False) -> None: + """ + Constructs a QuantizerParam instance. + + Args: + enable_rotate (bool, optional): Whether to apply random rotation + before INT8 quantization. Defaults to False. + """ + + def __repr__(self) -> str: ... + def __setstate__(self, arg0: tuple) -> None: ... + def __eq__(self, arg0: typing.Any) -> bool: ... + def to_dict(self) -> dict: + """ + Convert to dictionary with all fields + """ + + @property + def enable_rotate(self) -> bool: + """ + bool: Whether random rotation is enabled before INT8 quantization. + """ + class VectorIndexParam(IndexParam): """ @@ -1026,6 +1079,7 @@ class VectorIndexParam(IndexParam): type (IndexType): The specific vector index type (e.g., HNSW, FLAT). metric_type (MetricType): Distance metric used for similarity search. quantize_type (QuantizeType): Optional vector quantization type. + quantizer_param (QuantizerParam): Quantizer configuration (e.g., enable_rotate). """ def __getstate__(self) -> tuple: ... @@ -1047,6 +1101,12 @@ class VectorIndexParam(IndexParam): QuantizeType: Vector quantization type (e.g., FP16, INT8). 
""" + @property + def quantizer_param(self) -> QuantizerParam: + """ + QuantizerParam: Quantizer configuration including enable_rotate. + """ + class _SearchQuery: field_name: str filter: str diff --git a/src/ailego/math/fht.h b/src/ailego/math/fht.h new file mode 100644 index 000000000..8cfc93485 --- /dev/null +++ b/src/ailego/math/fht.h @@ -0,0 +1,38 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include + +namespace zvec { +namespace ailego { + +//! Flip the sign of elements based on a packed bit-array. +void fht_flip_sign(const uint8_t *flip, float *data, size_t dim); + +//! Kac random walk: butterfly add/sub between first and second halves. +void fht_kacs_walk(float *data, size_t len); + +//! Inverse Kac walk: undo butterfly add/sub with 0.5 factor. +void fht_inv_kacs_walk(float *data, size_t len); + +//! In-place Fast Hadamard Transform on a power-of-2 length array. +void fht_inplace(float *data, size_t n); + +//! Scale each element by a constant factor. 
+void fht_vec_rescale(float *data, size_t n, float factor); + +} // namespace ailego +} // namespace zvec diff --git a/src/ailego/math/fht_avx2.cc b/src/ailego/math/fht_avx2.cc new file mode 100644 index 000000000..857eefe7e --- /dev/null +++ b/src/ailego/math/fht_avx2.cc @@ -0,0 +1,114 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined(__AVX2__) + +#include +#include +#include +#include + +namespace zvec { +namespace ailego { + +void fht_flip_sign_avx2(const uint8_t *flip, float *data, size_t dim) { + size_t simd_end = dim & ~31u; + constexpr size_t kChunk = 32; + const __m256i bit_select = + _mm256_setr_epi32(0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80); + const __m256 sign_flip = _mm256_castsi256_ps(_mm256_set1_epi32(0x80000000)); + for (size_t i = 0; i < simd_end; i += kChunk) { + uint32_t mask_bits; + std::memcpy(&mask_bits, &flip[i / 8], sizeof(mask_bits)); + for (int b = 0; b < 4; ++b) { + __m256i mb = _mm256_set1_epi32((mask_bits >> (b * 8)) & 0xFF); + __m256i test = _mm256_and_si256(mb, bit_select); + __m256i cmp = _mm256_cmpeq_epi32(test, bit_select); + __m256 xor_mask = _mm256_and_ps(_mm256_castsi256_ps(cmp), sign_flip); + __m256 v = _mm256_loadu_ps(&data[i + b * 8]); + v = _mm256_xor_ps(v, xor_mask); + _mm256_storeu_ps(&data[i + b * 8], v); + } + } + // Scalar tail + for (size_t i = simd_end; i < dim; ++i) { + if (flip[i / 8] & (1u << (i % 8))) { + data[i] = -data[i]; + } + } +} + 
+void fht_kacs_walk_avx2(float *data, size_t len) { + size_t half = len / 2; + size_t half_end = half & ~7u; + for (size_t i = 0; i < half_end; i += 8) { + __m256 x = _mm256_loadu_ps(&data[i]); + __m256 y = _mm256_loadu_ps(&data[i + half]); + _mm256_storeu_ps(&data[i], _mm256_add_ps(x, y)); + _mm256_storeu_ps(&data[i + half], _mm256_sub_ps(x, y)); + } + // Scalar tail + for (size_t i = half_end; i < half; ++i) { + float x = data[i]; + float y = data[i + half]; + data[i] = x + y; + data[i + half] = x - y; + } +} + +void fht_inv_kacs_walk_avx2(float *data, size_t len) { + size_t half = len / 2; + size_t half_end = half & ~7u; + const __m256 half_fac = _mm256_set1_ps(0.5f); + for (size_t i = 0; i < half_end; i += 8) { + __m256 a = _mm256_loadu_ps(&data[i]); + __m256 b = _mm256_loadu_ps(&data[i + half]); + _mm256_storeu_ps(&data[i], _mm256_mul_ps(_mm256_add_ps(a, b), half_fac)); + _mm256_storeu_ps(&data[i + half], + _mm256_mul_ps(_mm256_sub_ps(a, b), half_fac)); + } + // Scalar tail + for (size_t i = half_end; i < half; ++i) { + float a = data[i]; + float b = data[i + half]; + data[i] = (a + b) * 0.5f; + data[i + half] = (a - b) * 0.5f; + } +} + +void fht_inplace_avx2(float *data, size_t n) { + for (size_t len = 1; len < n; len <<= 1) { + size_t step = len << 1; + size_t simd_end = len & ~7u; + for (size_t i = 0; i < n; i += step) { + for (size_t j = 0; j < simd_end; j += 8) { + __m256 u = _mm256_loadu_ps(&data[i + j]); + __m256 v = _mm256_loadu_ps(&data[i + j + len]); + _mm256_storeu_ps(&data[i + j], _mm256_add_ps(u, v)); + _mm256_storeu_ps(&data[i + j + len], _mm256_sub_ps(u, v)); + } + for (size_t j = simd_end; j < len; ++j) { + float u = data[i + j]; + float v = data[i + j + len]; + data[i + j] = u + v; + data[i + j + len] = u - v; + } + } + } +} + +} // namespace ailego +} // namespace zvec + +#endif // __AVX2__ diff --git a/src/ailego/math/fht_avx512.cc b/src/ailego/math/fht_avx512.cc new file mode 100644 index 000000000..e56cb9554 --- /dev/null +++ 
b/src/ailego/math/fht_avx512.cc @@ -0,0 +1,119 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined(__AVX512F__) + +#include +#include +#include +#include + +namespace zvec { +namespace ailego { + +void fht_flip_sign_avx512(const uint8_t *flip, float *data, size_t dim) { + size_t simd_end = dim & ~63u; + constexpr size_t kChunk = 64; + const __m512 sign_flip = _mm512_castsi512_ps(_mm512_set1_epi32(0x80000000)); + for (size_t i = 0; i < simd_end; i += kChunk) { + uint64_t mask_bits; + std::memcpy(&mask_bits, &flip[i / 8], sizeof(mask_bits)); + const __mmask16 m0 = _cvtu32_mask16(mask_bits & 0xFFFF); + const __mmask16 m1 = _cvtu32_mask16((mask_bits >> 16) & 0xFFFF); + const __mmask16 m2 = _cvtu32_mask16((mask_bits >> 32) & 0xFFFF); + const __mmask16 m3 = _cvtu32_mask16((mask_bits >> 48) & 0xFFFF); + __m512 v0 = _mm512_loadu_ps(&data[i]); + v0 = _mm512_mask_xor_ps(v0, m0, v0, sign_flip); + _mm512_storeu_ps(&data[i], v0); + __m512 v1 = _mm512_loadu_ps(&data[i + 16]); + v1 = _mm512_mask_xor_ps(v1, m1, v1, sign_flip); + _mm512_storeu_ps(&data[i + 16], v1); + __m512 v2 = _mm512_loadu_ps(&data[i + 32]); + v2 = _mm512_mask_xor_ps(v2, m2, v2, sign_flip); + _mm512_storeu_ps(&data[i + 32], v2); + __m512 v3 = _mm512_loadu_ps(&data[i + 48]); + v3 = _mm512_mask_xor_ps(v3, m3, v3, sign_flip); + _mm512_storeu_ps(&data[i + 48], v3); + } + // Scalar tail + for (size_t i = simd_end; i < dim; ++i) { + if (flip[i / 8] & 
(1u << (i % 8))) { + data[i] = -data[i]; + } + } +} + +void fht_kacs_walk_avx512(float *data, size_t len) { + size_t half = len / 2; + size_t half_end = half & ~15u; + for (size_t i = 0; i < half_end; i += 16) { + __m512 x = _mm512_loadu_ps(&data[i]); + __m512 y = _mm512_loadu_ps(&data[i + half]); + _mm512_storeu_ps(&data[i], _mm512_add_ps(x, y)); + _mm512_storeu_ps(&data[i + half], _mm512_sub_ps(x, y)); + } + // Scalar tail + for (size_t i = half_end; i < half; ++i) { + float x = data[i]; + float y = data[i + half]; + data[i] = x + y; + data[i + half] = x - y; + } +} + +void fht_inv_kacs_walk_avx512(float *data, size_t len) { + size_t half = len / 2; + size_t half_end = half & ~15u; + const __m512 half_fac = _mm512_set1_ps(0.5f); + for (size_t i = 0; i < half_end; i += 16) { + __m512 a = _mm512_loadu_ps(&data[i]); + __m512 b = _mm512_loadu_ps(&data[i + half]); + _mm512_storeu_ps(&data[i], _mm512_mul_ps(_mm512_add_ps(a, b), half_fac)); + _mm512_storeu_ps(&data[i + half], + _mm512_mul_ps(_mm512_sub_ps(a, b), half_fac)); + } + // Scalar tail + for (size_t i = half_end; i < half; ++i) { + float a = data[i]; + float b = data[i + half]; + data[i] = (a + b) * 0.5f; + data[i + half] = (a - b) * 0.5f; + } +} + +void fht_inplace_avx512(float *data, size_t n) { + for (size_t len = 1; len < n; len <<= 1) { + size_t step = len << 1; + size_t simd_end = len & ~15u; + for (size_t i = 0; i < n; i += step) { + for (size_t j = 0; j < simd_end; j += 16) { + __m512 u = _mm512_loadu_ps(&data[i + j]); + __m512 v = _mm512_loadu_ps(&data[i + j + len]); + _mm512_storeu_ps(&data[i + j], _mm512_add_ps(u, v)); + _mm512_storeu_ps(&data[i + j + len], _mm512_sub_ps(u, v)); + } + for (size_t j = simd_end; j < len; ++j) { + float u = data[i + j]; + float v = data[i + j + len]; + data[i + j] = u + v; + data[i + j + len] = u - v; + } + } + } +} + +} // namespace ailego +} // namespace zvec + +#endif // __AVX512F__ diff --git a/src/ailego/math/fht_dispatch.cc b/src/ailego/math/fht_dispatch.cc new 
file mode 100644 index 000000000..d9f80a1d2 --- /dev/null +++ b/src/ailego/math/fht_dispatch.cc @@ -0,0 +1,156 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "fht.h" + +namespace zvec { +namespace ailego { + +// ISA-specific forward declarations (implementations in +// fht_scalar/sse/avx2/avx512/neon) +void fht_flip_sign_scalar(const uint8_t *flip, float *data, size_t dim); +void fht_kacs_walk_scalar(float *data, size_t len); +void fht_inv_kacs_walk_scalar(float *data, size_t len); +void fht_inplace_scalar(float *data, size_t n); +#if defined(__SSE2__) +void fht_flip_sign_sse(const uint8_t *flip, float *data, size_t dim); +void fht_kacs_walk_sse(float *data, size_t len); +void fht_inv_kacs_walk_sse(float *data, size_t len); +#endif +#if defined(__AVX2__) +void fht_flip_sign_avx2(const uint8_t *flip, float *data, size_t dim); +void fht_kacs_walk_avx2(float *data, size_t len); +void fht_inv_kacs_walk_avx2(float *data, size_t len); +void fht_inplace_avx2(float *data, size_t n); +#endif +#if defined(__AVX512F__) +void fht_flip_sign_avx512(const uint8_t *flip, float *data, size_t dim); +void fht_kacs_walk_avx512(float *data, size_t len); +void fht_inv_kacs_walk_avx512(float *data, size_t len); +void fht_inplace_avx512(float *data, size_t n); +#endif +#if defined(__ARM_NEON) && defined(__aarch64__) +void fht_flip_sign_neon(const uint8_t *flip, float *data, size_t dim); +void 
fht_kacs_walk_neon(float *data, size_t len); +void fht_inv_kacs_walk_neon(float *data, size_t len); +#endif + +// ============================================================================ +// Runtime dispatch entry points +// ============================================================================ + +void fht_flip_sign(const uint8_t *flip, float *data, size_t dim) { +#if defined(__ARM_NEON) && defined(__aarch64__) + fht_flip_sign_neon(flip, data, dim); +#else +#if defined(__AVX512F__) + if (internal::CpuFeatures::static_flags_.AVX512F && + internal::CpuFeatures::static_flags_.AVX512DQ) { + fht_flip_sign_avx512(flip, data, dim); + return; + } +#endif +#if defined(__AVX2__) + if (internal::CpuFeatures::static_flags_.AVX2) { + fht_flip_sign_avx2(flip, data, dim); + return; + } +#endif +#if defined(__SSE2__) + if (internal::CpuFeatures::static_flags_.SSE2) { + fht_flip_sign_sse(flip, data, dim); + return; + } +#endif + fht_flip_sign_scalar(flip, data, dim); +#endif // __ARM_NEON +} + +void fht_kacs_walk(float *data, size_t len) { +#if defined(__ARM_NEON) && defined(__aarch64__) + fht_kacs_walk_neon(data, len); +#else +#if defined(__AVX512F__) + if (internal::CpuFeatures::static_flags_.AVX512F) { + fht_kacs_walk_avx512(data, len); + return; + } +#endif +#if defined(__AVX2__) + if (internal::CpuFeatures::static_flags_.AVX2) { + fht_kacs_walk_avx2(data, len); + return; + } +#endif +#if defined(__SSE2__) + if (internal::CpuFeatures::static_flags_.SSE2) { + fht_kacs_walk_sse(data, len); + return; + } +#endif + fht_kacs_walk_scalar(data, len); +#endif // __ARM_NEON +} + +void fht_inv_kacs_walk(float *data, size_t len) { +#if defined(__ARM_NEON) && defined(__aarch64__) + fht_inv_kacs_walk_neon(data, len); +#else +#if defined(__AVX512F__) + if (internal::CpuFeatures::static_flags_.AVX512F) { + fht_inv_kacs_walk_avx512(data, len); + return; + } +#endif +#if defined(__AVX2__) + if (internal::CpuFeatures::static_flags_.AVX2) { + fht_inv_kacs_walk_avx2(data, len); + return; 
+  }
+#endif
+#if defined(__SSE2__)
+  if (internal::CpuFeatures::static_flags_.SSE2) {
+    fht_inv_kacs_walk_sse(data, len);
+    return;
+  }
+#endif
+  fht_inv_kacs_walk_scalar(data, len);
+#endif // __ARM_NEON
+}
+
+void fht_inplace(float *data, size_t n) {
+#if defined(__AVX512F__)
+  if (internal::CpuFeatures::static_flags_.AVX512F) {
+    fht_inplace_avx512(data, n);
+    return;
+  }
+#endif
+#if defined(__AVX2__)
+  if (internal::CpuFeatures::static_flags_.AVX2) {
+    fht_inplace_avx2(data, n);
+    return;
+  }
+#endif
+  fht_inplace_scalar(data, n);
+}
+
+void fht_vec_rescale(float *data, size_t n, float factor) {
+  for (size_t i = 0; i < n; ++i) {
+    data[i] *= factor;
+  }
+}
+
+} // namespace ailego
+} // namespace zvec
diff --git a/src/ailego/math/fht_neon.cc b/src/ailego/math/fht_neon.cc
new file mode 100644
index 000000000..a3879f5be
--- /dev/null
+++ b/src/ailego/math/fht_neon.cc
@@ -0,0 +1,95 @@
+// Copyright 2025-present the zvec project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#if defined(__ARM_NEON) && defined(__aarch64__)
+
+#include
+#include
+#include
+#include
+
+namespace zvec {
+namespace ailego {
+
+//! Flip the sign of data[i] for every i in [0, dim) whose bit is set in the
+//! packed bit-array flip (bit i is bit i % 8 of flip[i / 8]). Handles any dim
+//! and reads only the first (dim + 7) / 8 bytes of flip.
+void fht_flip_sign_neon(const uint8_t *flip, float *data, size_t dim) {
+  const uint32x4_t sign_bit = vdupq_n_u32(0x80000000u);
+  // Vectorize whole groups of 4 floats only; the remainder goes through the
+  // scalar tail so data is never loaded or stored past dim.
+  const size_t simd_end = dim - (dim % 4);
+  for (size_t i = 0; i < simd_end; i += 4) {
+    // i is a multiple of 4, so bits i..i+3 share one nibble of flip[i / 8];
+    // a single-byte load never touches the byte after the last flip byte.
+    const uint32_t bits = flip[i / 8] >> (i % 8);
+    uint32_t b0 = bits & 1u;
+    uint32_t b1 = (bits >> 1) & 1u;
+    uint32_t b2 = (bits >> 2) & 1u;
+    uint32_t b3 = (bits >> 3) & 1u;
+    uint32x4_t bit_mask = {b0, b1, b2, b3};
+    uint32x4_t sign_mask = vmulq_u32(bit_mask, sign_bit);
+    float32x4_t v = vld1q_f32(&data[i]);
+    v = vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(v), sign_mask));
+    vst1q_f32(&data[i], v);
+  }
+  // Scalar tail
+  for (size_t i = simd_end; i < dim; ++i) {
+    if (flip[i / 8] & (1u << (i % 8))) {
+      data[i] = -data[i];
+    }
+  }
+}
+
+void fht_kacs_walk_neon(float *data, size_t len) {
+  size_t half = len / 2;
+  size_t half_end = half & ~3u;
+  for (size_t i = 0; i < half_end; i += 4) {
+    float32x4_t x = vld1q_f32(&data[i]);
+    float32x4_t y = vld1q_f32(&data[i + half]);
+    vst1q_f32(&data[i], vaddq_f32(x, y));
+    vst1q_f32(&data[i + half], vsubq_f32(x, y));
+  }
+  // Scalar tail
+  for (size_t i = half_end; i < half; ++i) {
+    float x = data[i];
+    float y = data[i + half];
+    data[i] = x + y;
+    data[i + half] = x - y;
+  }
+}
+
+void fht_inv_kacs_walk_neon(float *data, size_t len) {
+  size_t half = len / 2;
+  size_t half_end = half & ~3u;
+  const float32x4_t half_fac = vdupq_n_f32(0.5f);
+  for (size_t i = 0; i < half_end; i += 4) {
+    float32x4_t a = vld1q_f32(&data[i]);
+    float32x4_t b = vld1q_f32(&data[i + half]);
+    vst1q_f32(&data[i], vmulq_f32(vaddq_f32(a, b), half_fac));
+    vst1q_f32(&data[i + half], vmulq_f32(vsubq_f32(a, b), half_fac));
+  }
+  // Scalar tail
+  for (size_t i = half_end; i < half; ++i) {
+    float a = data[i];
+    float b = data[i + half];
+    data[i] = (a + b) * 0.5f;
+    data[i + half] = (a - b) * 0.5f;
+  }
+}
+
+} // namespace ailego
+} // namespace zvec
+
+#endif // __ARM_NEON && __aarch64__
diff --git
a/src/ailego/math/fht_scalar.cc b/src/ailego/math/fht_scalar.cc new file mode 100644 index 000000000..64dd073d4 --- /dev/null +++ b/src/ailego/math/fht_scalar.cc @@ -0,0 +1,63 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +namespace zvec { +namespace ailego { + +void fht_flip_sign_scalar(const uint8_t *flip, float *data, size_t dim) { + for (size_t i = 0; i < dim; ++i) { + if (flip[i / 8] & (1u << (i % 8))) { + data[i] = -data[i]; + } + } +} + +void fht_kacs_walk_scalar(float *data, size_t len) { + size_t half = len / 2; + for (size_t i = 0; i < half; ++i) { + float x = data[i]; + float y = data[i + half]; + data[i] = x + y; + data[i + half] = x - y; + } +} + +void fht_inv_kacs_walk_scalar(float *data, size_t len) { + size_t half = len / 2; + for (size_t i = 0; i < half; ++i) { + float a = data[i]; + float b = data[i + half]; + data[i] = (a + b) * 0.5f; + data[i + half] = (a - b) * 0.5f; + } +} + +void fht_inplace_scalar(float *data, size_t n) { + for (size_t len = 1; len < n; len <<= 1) { + for (size_t i = 0; i < n; i += len << 1) { + for (size_t j = i; j < i + len; ++j) { + float u = data[j]; + float v = data[j + len]; + data[j] = u + v; + data[j + len] = u - v; + } + } + } +} + +} // namespace ailego +} // namespace zvec diff --git a/src/ailego/math/fht_sse.cc b/src/ailego/math/fht_sse.cc new file mode 100644 index 000000000..09029c58f --- /dev/null +++ b/src/ailego/math/fht_sse.cc 
@@ -0,0 +1,82 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined(__SSE2__) + +#include +#include +#include +#include + +namespace zvec { +namespace ailego { + +void fht_flip_sign_sse(const uint8_t *flip, float *data, size_t dim) { + for (size_t i = 0; i < dim; i += 4) { + uint16_t bits16; + std::memcpy(&bits16, &flip[i / 8], sizeof(bits16)); + bits16 >>= (i % 8); + uint32_t b0 = bits16 & 1u; + uint32_t b1 = (bits16 >> 1) & 1u; + uint32_t b2 = (bits16 >> 2) & 1u; + uint32_t b3 = (bits16 >> 3) & 1u; + __m128i bit_mask = _mm_set_epi32(b3, b2, b1, b0); + __m128i sign_mask = _mm_slli_epi32(bit_mask, 31); + __m128 v = _mm_loadu_ps(&data[i]); + v = _mm_xor_ps(v, _mm_castsi128_ps(sign_mask)); + _mm_storeu_ps(&data[i], v); + } +} + +void fht_kacs_walk_sse(float *data, size_t len) { + size_t half = len / 2; + size_t half_end = half & ~3u; + for (size_t i = 0; i < half_end; i += 4) { + __m128 x = _mm_loadu_ps(&data[i]); + __m128 y = _mm_loadu_ps(&data[i + half]); + _mm_storeu_ps(&data[i], _mm_add_ps(x, y)); + _mm_storeu_ps(&data[i + half], _mm_sub_ps(x, y)); + } + // Scalar tail + for (size_t i = half_end; i < half; ++i) { + float x = data[i]; + float y = data[i + half]; + data[i] = x + y; + data[i + half] = x - y; + } +} + +void fht_inv_kacs_walk_sse(float *data, size_t len) { + size_t half = len / 2; + size_t half_end = half & ~3u; + const __m128 half_fac = _mm_set1_ps(0.5f); + for (size_t i = 0; i < 
half_end; i += 4) { + __m128 a = _mm_loadu_ps(&data[i]); + __m128 b = _mm_loadu_ps(&data[i + half]); + _mm_storeu_ps(&data[i], _mm_mul_ps(_mm_add_ps(a, b), half_fac)); + _mm_storeu_ps(&data[i + half], _mm_mul_ps(_mm_sub_ps(a, b), half_fac)); + } + // Scalar tail + for (size_t i = half_end; i < half; ++i) { + float a = data[i]; + float b = data[i + half]; + data[i] = (a + b) * 0.5f; + data[i + half] = (a - b) * 0.5f; + } +} + +} // namespace ailego +} // namespace zvec + +#endif // __SSE2__ diff --git a/src/binding/c/c_api.cc b/src/binding/c/c_api.cc index a81cc3864..6c7f226db 100644 --- a/src/binding/c/c_api.cc +++ b/src/binding/c/c_api.cc @@ -1476,6 +1476,60 @@ zvec_quantize_type_t zvec_index_params_get_quantize_type( return ZVEC_QUANTIZE_TYPE_UNDEFINED; } +/** + * @brief Set enable_rotate for quantizer parameters + * @param params Index parameters (must be vector index type) + * @param enable_rotate Whether to enable random rotation before quantization + * @return ZVEC_OK on success, error code on failure + */ +zvec_error_code_t zvec_index_params_set_quantizer_enable_rotate( + zvec_index_params_t *params, bool enable_rotate) { + if (!params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Index params pointer cannot be null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *cpp_params = reinterpret_cast(params); + + if (!cpp_params->is_vector_index_type()) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Index params is not a vector index type"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *vec_params = dynamic_cast(cpp_params); + if (!vec_params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Failed to cast to VectorIndexParams"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + zvec::QuantizerParam qp = vec_params->quantizer_param(); + qp.set_enable_rotate(enable_rotate); + vec_params->set_quantizer_param(qp); + return ZVEC_OK; +} + +/** + * @brief Get enable_rotate setting from quantizer parameters + * @param params Index parameters + * @return true 
if rotation is enabled, false otherwise + */ +bool zvec_index_params_get_quantizer_enable_rotate( + const zvec_index_params_t *params) { + if (!params) { + return false; + } + auto *cpp_params = reinterpret_cast(params); + + if (cpp_params->is_vector_index_type()) { + auto *vec_params = + dynamic_cast(cpp_params); + if (vec_params) { + return vec_params->quantizer_param().enable_rotate(); + } + } + return false; +} + /** * @brief Get index type from index parameters * @param params Index parameters diff --git a/src/binding/python/CMakeLists.txt b/src/binding/python/CMakeLists.txt index 0db6d75ff..41dc0f6a3 100644 --- a/src/binding/python/CMakeLists.txt +++ b/src/binding/python/CMakeLists.txt @@ -61,6 +61,7 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Linux") $ $ $ + $ $ -Wl,--no-whole-archive zvec @@ -98,6 +99,7 @@ elseif (APPLE) -Wl,-force_load,$ -Wl,-force_load,$ -Wl,-force_load,$ + -Wl,-force_load,$ -Wl,-force_load,$ zvec ) @@ -117,6 +119,7 @@ elseif (MSVC) core_metric_static core_utility_static core_quantizer_static + core_rotator core_plugin ) target_link_libraries(_zvec PRIVATE diff --git a/src/binding/python/model/param/python_param.cc b/src/binding/python/model/param/python_param.cc index 268214dbd..80dc203af 100644 --- a/src/binding/python/model/param/python_param.cc +++ b/src/binding/python/model/param/python_param.cc @@ -338,6 +338,61 @@ Constructs an FtsIndexParam instance. t[2].cast()); })); + // binding QuantizerParam + py::class_> quantizer_param( + m, "QuantizerParam", R"pbdoc( +Parameters for quantizer configuration. + +Encapsulates quantization-related settings such as enable_rotate. +Designed for future extensibility. + +Attributes: + enable_rotate (bool): Whether to apply random rotation before INT8 + quantization to reduce quantization error. + Only effective with quantize_type=INT8. Defaults to False. 
+ +Examples: + >>> qp = QuantizerParam(enable_rotate=True) + >>> print(qp.enable_rotate) + True +)pbdoc"); + quantizer_param.def(py::init(), py::arg("enable_rotate") = false) + .def_property_readonly( + "enable_rotate", + [](const QuantizerParam &self) -> bool { + return self.enable_rotate(); + }, + "bool: Whether random rotation is enabled before INT8 quantization.") + .def( + "to_dict", + [](const QuantizerParam &self) -> py::dict { + py::dict dict; + dict["enable_rotate"] = self.enable_rotate(); + return dict; + }, + "Convert to dictionary with all fields") + .def("__repr__", + [](const QuantizerParam &self) -> std::string { + return "{\"enable_rotate\":" + + std::string(self.enable_rotate() ? "true" : "false") + "}"; + }) + .def( + "__eq__", + [](const QuantizerParam &self, const py::object &other) { + if (!py::isinstance(other)) return false; + return self == other.cast(); + }, + py::is_operator()) + .def(py::pickle( + [](const QuantizerParam &self) { + return py::make_tuple(self.enable_rotate()); + }, + [](py::tuple t) { + if (t.size() != 1) + throw std::runtime_error("Invalid state for QuantizerParam"); + return std::make_shared(t[0].cast()); + })); + // binding base vector index params py::class_> vector_params(m, "VectorIndexParam", R"pbdoc( @@ -349,6 +404,7 @@ Encapsulates common settings for all vector index types. type (IndexType): The specific vector index type (e.g., HNSW, FLAT). metric_type (MetricType): Distance metric used for similarity search. quantize_type (QuantizeType): Optional vector quantization type. + quantizer_param (QuantizerParam): Quantizer configuration (e.g., enable_rotate). )pbdoc"); vector_params .def_property_readonly( @@ -363,6 +419,12 @@ Encapsulates common settings for all vector index types. 
return self.quantize_type(); }, "QuantizeType: Vector quantization type (e.g., FP16, INT8).") + .def_property_readonly( + "quantizer_param", + [](const VectorIndexParams &self) -> QuantizerParam { + return self.quantizer_param(); + }, + "QuantizerParam: Quantizer configuration including enable_rotate.") .def( "to_dict", [](const VectorIndexParams &self) -> py::dict { @@ -371,6 +433,9 @@ Encapsulates common settings for all vector index types. dict["metric_type"] = metric_type_to_string(self.metric_type()); dict["quantize_type"] = quantize_type_to_string(self.quantize_type()); + py::dict qp_dict; + qp_dict["enable_rotate"] = self.quantizer_param().enable_rotate(); + dict["quantizer_param"] = qp_dict; return dict; }, "Convert to dictionary with all fields") @@ -382,7 +447,7 @@ Encapsulates common settings for all vector index types. [](py::tuple t) { // __setstate__ if (t.size() != 3) throw std::runtime_error("Invalid state for VectorIndexParams"); - // 基类,不能直接实例化,用于子类 + // Base class, cannot instantiate directly, used by subclasses return std::shared_ptr(); })); @@ -421,13 +486,20 @@ encapsulates its construction hyperparameters. 
{'metric_type': 'IP', 'm': 16, 'ef_construction': 200, 'quantize_type': 'INT8', 'use_contiguous_memory': True} )pbdoc"); hnsw_params - .def(py::init(), + .def(py::init([](MetricType metric_type, int m, int ef_construction, + QuantizeType quantize_type, bool use_contiguous_memory, + QuantizerParam quantizer_param) { + return std::make_shared( + metric_type, m, ef_construction, quantize_type, + use_contiguous_memory, quantizer_param); + }), py::arg("metric_type") = MetricType::IP, py::arg("m") = core_interface::kDefaultHnswNeighborCnt, py::arg("ef_construction") = core_interface::kDefaultHnswEfConstruction, py::arg("quantize_type") = QuantizeType::UNDEFINED, - py::arg("use_contiguous_memory") = false) + py::arg("use_contiguous_memory") = false, + py::arg("quantizer_param") = QuantizerParam()) .def_property_readonly( "m", &HnswIndexParams::m, "int: Maximum number of neighbors per node in upper layers.") @@ -450,34 +522,43 @@ encapsulates its construction hyperparameters. dict["quantize_type"] = quantize_type_to_string(self.quantize_type()); dict["use_contiguous_memory"] = self.use_contiguous_memory(); + py::dict qp_dict; + qp_dict["enable_rotate"] = self.quantizer_param().enable_rotate(); + dict["quantizer_param"] = qp_dict; return dict; }, "Convert to dictionary with all fields") - .def("__repr__", - [](const HnswIndexParams &self) -> std::string { - return "{" - "\"metric_type\":" + - metric_type_to_string(self.metric_type()) + - ", \"m\":" + std::to_string(self.m()) + - ", \"ef_construction\":" + - std::to_string(self.ef_construction()) + - ", \"quantize_type\":" + - quantize_type_to_string(self.quantize_type()) + - ", \"use_contiguous_memory\":" + - (self.use_contiguous_memory() ? 
"true" : "false") + "}"; - }) + .def( + "__repr__", + [](const HnswIndexParams &self) -> std::string { + return "{" + "\"metric_type\":" + + metric_type_to_string(self.metric_type()) + + ", \"m\":" + std::to_string(self.m()) + + ", \"ef_construction\":" + + std::to_string(self.ef_construction()) + + ", \"quantize_type\":" + + quantize_type_to_string(self.quantize_type()) + + ", \"use_contiguous_memory\":" + + (self.use_contiguous_memory() ? "true" : "false") + + ", \"quantizer_param\":{" + "\"enable_rotate\":" + + (self.quantizer_param().enable_rotate() ? "true" : "false") + + "}}"; + }) .def(py::pickle( [](const HnswIndexParams &self) { return py::make_tuple(self.metric_type(), self.m(), self.ef_construction(), self.quantize_type(), - self.use_contiguous_memory()); + self.use_contiguous_memory(), + self.quantizer_param().enable_rotate()); }, [](py::tuple t) { - if (t.size() != 5) + if (t.size() != 5 && t.size() != 6) throw std::runtime_error("Invalid state for HnswIndexParams"); + QuantizerParam qp(t.size() >= 6 ? t[5].cast() : false); return std::make_shared( t[0].cast(), t[1].cast(), t[2].cast(), - t[3].cast(), t[4].cast()); + t[3].cast(), t[4].cast(), qp); })); // binding hnsw rabitq index params @@ -626,8 +707,16 @@ its construction hyperparameters. ... ) )pbdoc"); vamana_params - .def(py::init(), + .def(py::init([](MetricType metric_type, int max_degree, + int search_list_size, float alpha, bool saturate_graph, + bool use_contiguous_memory, bool use_id_map, + QuantizeType quantize_type, + QuantizerParam quantizer_param) { + return std::make_shared( + metric_type, max_degree, search_list_size, alpha, + saturate_graph, use_contiguous_memory, use_id_map, + quantize_type, quantizer_param); + }), py::arg("metric_type") = MetricType::IP, py::arg("max_degree") = core_interface::kDefaultVamanaMaxDegree, py::arg("search_list_size") = @@ -637,7 +726,8 @@ its construction hyperparameters. 
core_interface::kDefaultVamanaSaturateGraph, py::arg("use_contiguous_memory") = false, py::arg("use_id_map") = false, - py::arg("quantize_type") = QuantizeType::UNDEFINED) + py::arg("quantize_type") = QuantizeType::UNDEFINED, + py::arg("quantizer_param") = QuantizerParam()) .def_property_readonly( "max_degree", &VamanaIndexParams::max_degree, "int: Maximum out-degree (R) of every node in the Vamana graph.") @@ -673,45 +763,53 @@ its construction hyperparameters. dict["use_id_map"] = self.use_id_map(); dict["quantize_type"] = quantize_type_to_string(self.quantize_type()); + py::dict qp_dict; + qp_dict["enable_rotate"] = self.quantizer_param().enable_rotate(); + dict["quantizer_param"] = qp_dict; return dict; }, "Convert to dictionary with all fields") - .def("__repr__", - [](const VamanaIndexParams &self) -> std::string { - return "{" - "\"type\":\"" + - index_type_to_string(self.type()) + - "\", \"metric_type\":\"" + - metric_type_to_string(self.metric_type()) + - "\", \"max_degree\":" + std::to_string(self.max_degree()) + - ", \"search_list_size\":" + - std::to_string(self.search_list_size()) + - ", \"alpha\":" + std::to_string(self.alpha()) + - ", \"saturate_graph\":" + - std::string(self.saturate_graph() ? "true" : "false") + - ", \"use_contiguous_memory\":" + - std::string(self.use_contiguous_memory() ? "true" - : "false") + - ", \"use_id_map\":" + - std::string(self.use_id_map() ? "true" : "false") + - ", \"quantize_type\":\"" + - quantize_type_to_string(self.quantize_type()) + "\"}"; - }) + .def( + "__repr__", + [](const VamanaIndexParams &self) -> std::string { + return "{" + "\"type\":\"" + + index_type_to_string(self.type()) + + "\", \"metric_type\":\"" + + metric_type_to_string(self.metric_type()) + + "\", \"max_degree\":" + std::to_string(self.max_degree()) + + ", \"search_list_size\":" + + std::to_string(self.search_list_size()) + + ", \"alpha\":" + std::to_string(self.alpha()) + + ", \"saturate_graph\":" + + std::string(self.saturate_graph() ? 
"true" : "false") + + ", \"use_contiguous_memory\":" + + std::string(self.use_contiguous_memory() ? "true" + : "false") + + ", \"use_id_map\":" + + std::string(self.use_id_map() ? "true" : "false") + + ", \"quantize_type\":\"" + + quantize_type_to_string(self.quantize_type()) + + "\", \"quantizer_param\":{" + "\"enable_rotate\":" + + (self.quantizer_param().enable_rotate() ? "true" : "false") + + "}}"; + }) .def(py::pickle( [](const VamanaIndexParams &self) { - return py::make_tuple(self.metric_type(), self.max_degree(), - self.search_list_size(), self.alpha(), - self.saturate_graph(), - self.use_contiguous_memory(), - self.use_id_map(), self.quantize_type()); + return py::make_tuple( + self.metric_type(), self.max_degree(), self.search_list_size(), + self.alpha(), self.saturate_graph(), + self.use_contiguous_memory(), self.use_id_map(), + self.quantize_type(), self.quantizer_param().enable_rotate()); }, [](py::tuple t) { - if (t.size() != 8) + if (t.size() != 8 && t.size() != 9) throw std::runtime_error("Invalid state for VamanaIndexParams"); + QuantizerParam qp(t.size() >= 9 ? t[8].cast() : false); return std::make_shared( t[0].cast(), t[1].cast(), t[2].cast(), t[3].cast(), t[4].cast(), t[5].cast(), - t[6].cast(), t[7].cast()); + t[6].cast(), t[7].cast(), qp); })); // FlatIndexParams @@ -741,9 +839,14 @@ suitable for small to medium datasets or as a baseline. {'metric_type': 'L2', 'quantize_type': 'FP16'} )pbdoc"); flat_params - .def(py::init(), + .def(py::init([](MetricType metric_type, QuantizeType quantize_type, + QuantizerParam quantizer_param) { + return std::make_shared( + metric_type, quantize_type, quantizer_param); + }), py::arg("metric_type") = MetricType::IP, py::arg("quantize_type") = QuantizeType::UNDEFINED, + py::arg("quantizer_param") = QuantizerParam(), R"pbdoc( Constructs a FlatIndexParam instance. @@ -751,6 +854,8 @@ Constructs a FlatIndexParam instance. metric_type (MetricType, optional): Distance metric. Defaults to MetricType.IP. 
quantize_type (QuantizeType, optional): Vector quantization type. Defaults to QuantizeType.UNDEFINED (no quantization). + quantizer_param (QuantizerParam, optional): Quantizer configuration. + Defaults to QuantizerParam(). )pbdoc") .def( "to_dict", @@ -759,26 +864,35 @@ Constructs a FlatIndexParam instance. dict["metric_type"] = metric_type_to_string(self.metric_type()); dict["quantize_type"] = quantize_type_to_string(self.quantize_type()); + py::dict qp_dict; + qp_dict["enable_rotate"] = self.quantizer_param().enable_rotate(); + dict["quantizer_param"] = qp_dict; return dict; }, "Convert to dictionary with all fields") - .def("__repr__", - [](const FlatIndexParams &self) -> std::string { - return "{" - "\"metric_type\":" + - metric_type_to_string(self.metric_type()) + - ", \"quantize_type\":" + - quantize_type_to_string(self.quantize_type()) + "}"; - }) + .def( + "__repr__", + [](const FlatIndexParams &self) -> std::string { + return "{" + "\"metric_type\":" + + metric_type_to_string(self.metric_type()) + + ", \"quantize_type\":" + + quantize_type_to_string(self.quantize_type()) + + ", \"quantizer_param\":{" + "\"enable_rotate\":" + + (self.quantizer_param().enable_rotate() ? "true" : "false") + + "}}"; + }) .def(py::pickle( [](const FlatIndexParams &self) { - return py::make_tuple(self.metric_type(), self.quantize_type()); + return py::make_tuple(self.metric_type(), self.quantize_type(), + self.quantizer_param().enable_rotate()); }, [](py::tuple t) { - if (t.size() != 2) + if (t.size() != 2 && t.size() != 3) throw std::runtime_error("Invalid state for FlatIndexParams"); - return std::make_shared(t[0].cast(), - t[1].cast()); + QuantizerParam qp(t.size() >= 3 ? t[2].cast() : false); + return std::make_shared( + t[0].cast(), t[1].cast(), qp); })); // IVFIndexParams @@ -815,10 +929,17 @@ and accuracy. 
100 )pbdoc"); ivf_params - .def(py::init(), + .def(py::init([](MetricType metric_type, int n_list, int n_iters, + bool use_soar, QuantizeType quantize_type, + QuantizerParam quantizer_param) { + return std::make_shared( + metric_type, n_list, n_iters, use_soar, quantize_type, + quantizer_param); + }), py::arg("metric_type") = MetricType::IP, py::arg("n_list") = 10, py::arg("n_iters") = 10, py::arg("use_soar") = false, py::arg("quantize_type") = QuantizeType::UNDEFINED, + py::arg("quantizer_param") = QuantizerParam(), R"pbdoc( Constructs an IVFIndexParam instance. @@ -831,6 +952,8 @@ Constructs an IVFIndexParam instance. use_soar (bool, optional): Enable SOAR optimization. Defaults to False. quantize_type (QuantizeType, optional): Vector quantization type. Defaults to QuantizeType.UNDEFINED. + quantizer_param (QuantizerParam, optional): Quantizer configuration. + Defaults to QuantizerParam(). )pbdoc") .def_property_readonly("n_list", &IVFIndexParams::n_list, "int: Number of inverted lists.") @@ -850,32 +973,41 @@ Constructs an IVFIndexParam instance. 
dict["use_soar"] = self.use_soar(); dict["quantize_type"] = quantize_type_to_string(self.quantize_type()); + py::dict qp_dict; + qp_dict["enable_rotate"] = self.quantizer_param().enable_rotate(); + dict["quantizer_param"] = qp_dict; return dict; }, "Convert to dictionary with all fields") - .def("__repr__", - [](const IVFIndexParams &self) { - return "{" - "\"metric_type\":" + - metric_type_to_string(self.metric_type()) + - ", \"n_list\":" + std::to_string(self.n_list()) + - ", \"n_iters\":" + std::to_string(self.n_iters()) + - ", \"use_soar\":" + std::to_string(self.use_soar()) + - ", \"quantize_type\":" + - quantize_type_to_string(self.quantize_type()) + "}"; - }) + .def( + "__repr__", + [](const IVFIndexParams &self) { + return "{" + "\"metric_type\":" + + metric_type_to_string(self.metric_type()) + + ", \"n_list\":" + std::to_string(self.n_list()) + + ", \"n_iters\":" + std::to_string(self.n_iters()) + + ", \"use_soar\":" + std::to_string(self.use_soar()) + + ", \"quantize_type\":" + + quantize_type_to_string(self.quantize_type()) + + ", \"quantizer_param\":{" + "\"enable_rotate\":" + + (self.quantizer_param().enable_rotate() ? "true" : "false") + + "}}"; + }) .def(py::pickle( [](const IVFIndexParams &self) { return py::make_tuple(self.metric_type(), self.n_list(), self.n_iters(), self.use_soar(), - self.quantize_type()); + self.quantize_type(), + self.quantizer_param().enable_rotate()); }, [](py::tuple t) { - if (t.size() != 5) + if (t.size() != 5 && t.size() != 6) throw std::runtime_error("Invalid state for IVFIndexParams"); + QuantizerParam qp(t.size() >= 6 ? t[5].cast() : false); return std::make_shared( t[0].cast(), t[1].cast(), t[2].cast(), - t[3].cast(), t[4].cast()); + t[3].cast(), t[4].cast(), qp); })); // DiskAnnIndexParams @@ -915,10 +1047,17 @@ only compressed vector will be loaded into memory. 
By this way, search memory at 100 )pbdoc"); diskann_params - .def(py::init(), + .def(py::init([](MetricType metric_type, int max_degree, int list_size, + int pq_chunk_num, QuantizeType quantize_type, + QuantizerParam quantizer_param) { + return std::make_shared( + metric_type, max_degree, list_size, pq_chunk_num, + quantize_type, quantizer_param); + }), py::arg("metric_type") = MetricType::IP, py::arg("max_degree") = 100, py::arg("list_size") = 50, py::arg("pq_chunk_num") = 0, py::arg("quantize_type") = QuantizeType::UNDEFINED, + py::arg("quantizer_param") = QuantizerParam(), R"pbdoc( Constructs an DiskAnnIndexParams instance. @@ -933,6 +1072,8 @@ Constructs an DiskAnnIndexParams instance. Clamped to [1, 1024]. Defaults to 0. quantize_type (QuantizeType, optional): Vector quantization type. Defaults to QuantizeType.UNDEFINED. + quantizer_param (QuantizerParam, optional): Quantizer configuration. + Defaults to QuantizerParam(). )pbdoc") .def_property_readonly("max_degree", &DiskAnnIndexParams::max_degree, "int: max node degree.") @@ -955,6 +1096,9 @@ Constructs an DiskAnnIndexParams instance. dict["pq_chunk_num"] = self.pq_chunk_num(); dict["quantize_type"] = quantize_type_to_string(self.quantize_type()); + py::dict qp_dict; + qp_dict["enable_rotate"] = self.quantizer_param().enable_rotate(); + dict["quantizer_param"] = qp_dict; return dict; }, "Convert to dictionary with all fields") @@ -968,20 +1112,25 @@ Constructs an DiskAnnIndexParams instance. ", \"list_size\":" + std::to_string(self.list_size()) + ", \"pq_chunk_num\":" + std::to_string(self.pq_chunk_num()) + ", \"quantize_type\":" + - quantize_type_to_string(self.quantize_type()) + "}"; + quantize_type_to_string(self.quantize_type()) + + ", \"quantizer_param\":{" + "\"enable_rotate\":" + + (self.quantizer_param().enable_rotate() ? 
"true" : "false") + + "}}"; }) .def(py::pickle( [](const DiskAnnIndexParams &self) { return py::make_tuple(self.metric_type(), self.max_degree(), self.list_size(), self.pq_chunk_num(), - self.quantize_type()); + self.quantize_type(), + self.quantizer_param().enable_rotate()); }, [](py::tuple t) { - if (t.size() != 5) + if (t.size() != 5 && t.size() != 6) throw std::runtime_error("Invalid state for DiskAnnIndexParams"); + QuantizerParam qp(t.size() >= 6 ? t[5].cast() : false); return std::make_shared( t[0].cast(), t[1].cast(), t[2].cast(), - t[3].cast(), t[4].cast()); + t[3].cast(), t[4].cast(), qp); })); } diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 6405c3220..70bc0be30 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -66,12 +66,16 @@ endif() # (real on Linux x86_64, stub on other platforms). Including them here causes # duplicate symbols and missing -laio when test binaries link both zvec_core # (via zvec) and core_knn_diskann. +# Exclude quantizer/rotator files from zvec_core. They are built by the +# separate core_rotator library. +list(FILTER ALL_CORE_SRCS EXCLUDE REGEX ".*/quantizer/rotator/.*") + list(FILTER ALL_CORE_SRCS EXCLUDE REGEX ".*/algorithm/diskann/.*") if(NOT DISKANN_SUPPORTED) list(FILTER ALL_CORE_SRCS EXCLUDE REGEX ".*/interface/indexes/diskann_index\\.cc") endif() -set(ZVEC_CORE_LIBS zvec_ailego zvec_turbo sparsehash magic_enum rabitqlib) +set(ZVEC_CORE_LIBS zvec_ailego zvec_turbo sparsehash magic_enum rabitqlib core_rotator) # The plugin loader uses dlopen/dlsym, so link libdl on Linux. if(CMAKE_SYSTEM_NAME STREQUAL "Linux") list(APPEND ZVEC_CORE_LIBS ${CMAKE_DL_LIBS}) diff --git a/src/core/algorithm/ivf/ivf_entity.cc b/src/core/algorithm/ivf/ivf_entity.cc index 6dccc2b2c..decc86d22 100644 --- a/src/core/algorithm/ivf/ivf_entity.cc +++ b/src/core/algorithm/ivf/ivf_entity.cc @@ -71,6 +71,16 @@ int IVFEntity::IVFReformerWrapper::init(const IndexMeta &imeta) { return 0; } +//! Load reformer state (e.g. 
rotation matrix) from storage +int IVFEntity::IVFReformerWrapper::load(const IndexStorage::Pointer &storage) { + if (!reformer_) { + return 0; + } + int ret = reformer_->load(storage); + ivf_check_with_msg(ret, "Failed to load reformer state"); + return 0; +} + //! Update the params, Called by gpu searcher only int IVFEntity::IVFReformerWrapper::update(const IndexMeta &meta) { auto &name = meta.reformer_name(); @@ -503,6 +513,12 @@ int IVFEntity::load(const IndexStorage::Pointer &container) { //! Load the remaining segments container_ = container; + + //! Load reformer state (e.g. rotation matrix) from the main container, + //! which holds the rotator segment dumped at build time. + ret = reformer_.load(container); + ivf_check_error_code(ret); + size_t expect_size = header_.inverted_body_size; inverted_ = load_segment(IVF_INVERTED_BODY_SEG_ID, expect_size); if (!inverted_) { diff --git a/src/core/algorithm/ivf/ivf_entity.h b/src/core/algorithm/ivf/ivf_entity.h index e6fd4b6c4..e0265b6eb 100644 --- a/src/core/algorithm/ivf/ivf_entity.h +++ b/src/core/algorithm/ivf/ivf_entity.h @@ -267,6 +267,9 @@ class IVFEntity { //! Initialize int init(const IndexMeta &imeta); + //! Load reformer state (e.g. rotation matrix) from storage + int load(const IndexStorage::Pointer &storage); + //! 
Update int update(const IndexMeta &meta); diff --git a/src/core/algorithm/ivf/ivf_params.h b/src/core/algorithm/ivf/ivf_params.h index a33a7aa50..6cd66b474 100644 --- a/src/core/algorithm/ivf/ivf_params.h +++ b/src/core/algorithm/ivf/ivf_params.h @@ -62,6 +62,8 @@ static const std::string PARAM_IVF_BUILDER_BLOCK_VECTOR_COUNT( // searcher params static const std::string PARAM_IVF_SEARCHER_SCAN_RATIO( "proxima.ivf.searcher.scan_ratio"); +static const std::string PARAM_IVF_SEARCHER_NPROBE( + "proxima.ivf.searcher.nprobe"); static const std::string PARAM_IVF_SEARCHER_BRUTE_FORCE_THRESHOLD( "proxima.ivf.searcher.brute_force_threshold"); static const std::string PARAM_IVF_SEARCHER_OPTIMIZER( diff --git a/src/core/algorithm/ivf/ivf_searcher.cc b/src/core/algorithm/ivf/ivf_searcher.cc index 972fc8680..047046701 100644 --- a/src/core/algorithm/ivf/ivf_searcher.cc +++ b/src/core/algorithm/ivf/ivf_searcher.cc @@ -86,6 +86,13 @@ int IVFSearcher::load(IndexStorage::Pointer container, } auto reformer = centroid_index_->reformer(); + if (reformer) { + //! The centroid index is loaded from the centroid sub-segment which does + //! not contain the rotator segment. Load the reformer state (e.g. rotation + //! matrix) from the main container instead. + ret = reformer->load(container); + ivf_check_error_code(ret); + } params_.set(PARAM_IVF_SEARCHER_CONVERTER_REFORMER, reformer); //! 
load iverted index diff --git a/src/core/algorithm/ivf/ivf_searcher_context.h b/src/core/algorithm/ivf/ivf_searcher_context.h index d9ccc45c1..dbd2b7ae1 100644 --- a/src/core/algorithm/ivf/ivf_searcher_context.h +++ b/src/core/algorithm/ivf/ivf_searcher_context.h @@ -62,18 +62,27 @@ class IVFSearcherContext : public IndexSearcher::Context { params.get(PARAM_IVF_SEARCHER_BRUTE_FORCE_THRESHOLD, &bruteforce_threshold_); params.get(PARAM_IVF_SEARCHER_SCAN_RATIO, &scan_ratio_); + params.get(PARAM_IVF_SEARCHER_NPROBE, &nprobe_); if (scan_ratio_ <= 0.0) { LOG_ERROR("Invalid params %s=%f", PARAM_IVF_SEARCHER_SCAN_RATIO.c_str(), scan_ratio_); return IndexError_InvalidArgument; } - size_t topk_val = - std::max(static_cast( - std::round(entity_->inverted_list_count() * scan_ratio_)), - 1u); - centroid_searcher_ctx_->set_topk(topk_val); - max_scan_count_ = - static_cast(std::ceil(entity_->vector_count() * scan_ratio_)); + size_t nlist = entity_->inverted_list_count(); + size_t topk_val; + if (nprobe_ > 0) { + //! nprobe explicitly controls how many inverted lists (centroids) to + //! probe. Do not let max_scan_count_ cut off the probed lists. 
+ topk_val = std::min(static_cast(nprobe_), nlist); + topk_val = std::max(topk_val, static_cast(1)); + max_scan_count_ = static_cast(entity_->vector_count()); + } else { + topk_val = + std::max(static_cast(std::round(nlist * scan_ratio_)), 1u); + max_scan_count_ = static_cast( + std::ceil(entity_->vector_count() * scan_ratio_)); + } + centroid_searcher_ctx_->set_topk(static_cast(topk_val)); max_scan_count_ = std::max(bruteforce_threshold_, max_scan_count_); return 0; } @@ -215,6 +224,7 @@ class IVFSearcherContext : public IndexSearcher::Context { uint32_t topk_{0}; uint32_t magic_{0}; float scan_ratio_{kDefaultScanRatio}; + int nprobe_{0}; uint32_t max_scan_count_{0}; uint32_t bruteforce_threshold_{kDefaultBfThreshold}; }; diff --git a/src/core/algorithm/ivf/ivf_streamer.cc b/src/core/algorithm/ivf/ivf_streamer.cc index a2c924141..e42728e9a 100644 --- a/src/core/algorithm/ivf/ivf_streamer.cc +++ b/src/core/algorithm/ivf/ivf_streamer.cc @@ -86,6 +86,13 @@ int IVFStreamer::open(IndexStorage::Pointer storage) { } auto reformer = centroid_index_->reformer(); + if (reformer) { + //! The centroid index is loaded from the centroid sub-segment which does + //! not contain the rotator segment. Load the reformer state (e.g. rotation + //! matrix) from the main storage instead. + ret = reformer->load(storage); + ivf_check_error_code(ret); + } params_.set(PARAM_IVF_SEARCHER_CONVERTER_REFORMER, reformer); //! 
load iverted index diff --git a/src/core/interface/index.cc b/src/core/interface/index.cc index 332d4526f..4cb266d3a 100644 --- a/src/core/interface/index.cc +++ b/src/core/interface/index.cc @@ -182,6 +182,21 @@ int Index::CreateAndInitConverterReformer(const QuantizerParam ¶m, } } + // Pass enable_rotate to converter_params (only effective for INT8) + if (param.enable_rotate) { + if (param.type == QuantizerType::kInt8) { + if (index_param.metric_type == MetricType::kCosine) { + converter_params.set("cosine.converter.enable_rotate", true); + } else { + converter_params.set("integer_streaming.converter.enable_rotate", true); + } + } else { + LOG_WARN( + "enable_rotate is only supported for INT8 quantizer, " + "ignoring for current quantizer type"); + } + } + proxima_index_meta_.set_converter(converter_name, 0, converter_params); converter_ = core::IndexFactory::CreateConverter(converter_name); if (converter_ == nullptr || @@ -336,6 +351,21 @@ int Index::Open(const std::string &file_path, StorageOptions storage_options) { // converter/reformer/metric are created in IndexFactory::CreateIndex // TODO: init + // Load reformer data from storage (e.g., rotation matrix for + // IntegerStreaming) + if (reformer_ != nullptr) { + // When building a new index, dump converter state (e.g., rotator) to + // storage so the reformer can load it. This is needed for + // enable_rotate with INT8 quantization. 
+ if (storage_options.create_new && converter_ != nullptr) { + converter_->dump_to_storage(storage_); + } + if (reformer_->load(storage_) != 0) { + LOG_ERROR("Failed to load reformer, path: %s", file_path.c_str()); + return core::IndexError_Runtime; + } + } + // TODO: context pool if (!init_context()) { // to validate if any error, will be overwritten LOG_ERROR("Failed to init context"); diff --git a/src/core/interface/index_param.cc b/src/core/interface/index_param.cc index 9226eeed0..5d75276fd 100644 --- a/src/core/interface/index_param.cc +++ b/src/core/interface/index_param.cc @@ -251,12 +251,16 @@ ailego::JsonObject QuantizerParam::SerializeToJsonObject( json_obj.set("type", zvec::ailego::JsonValue(magic_enum::enum_name(type).data())); } + if (!omit_empty_value || enable_rotate) { + json_obj.set("enable_rotate", ailego::JsonValue(enable_rotate)); + } return json_obj; } bool QuantizerParam::DeserializeFromJsonObject( const ailego::JsonObject &json_obj) { DESERIALIZE_ENUM_FIELD(json_obj, type, QuantizerType); + DESERIALIZE_VALUE_FIELD(json_obj, enable_rotate); return true; } diff --git a/src/core/interface/indexes/ivf_index.cc b/src/core/interface/indexes/ivf_index.cc index d85acce62..6bd793b2a 100644 --- a/src/core/interface/indexes/ivf_index.cc +++ b/src/core/interface/indexes/ivf_index.cc @@ -121,6 +121,10 @@ int IVFIndex::Open(const std::string &file_path, LOG_ERROR("Failed to open streamer, path: %s", file_path_.c_str()); return core::IndexError_Runtime; } + // Load reformer data from storage (e.g., rotation matrix for INT8+rotate) + if (reformer_ != nullptr) { + reformer_->load(storage_); + } is_trained_ = true; } is_open_ = true; @@ -164,6 +168,10 @@ int IVFIndex::Train() { dumper->create(file_path_); builder_->dump(dumper); + // Dump converter state (e.g., rotator for INT8+rotate) to dumper + if (converter_) { + converter_->dump(dumper); + } dumper->close(); int ret = storage_->open(file_path_, false); if (ret != 0) { @@ -175,6 +183,10 @@ int 
IVFIndex::Train() { LOG_ERROR("Failed to open streamer, path: %s", file_path_.c_str()); return core::IndexError_Runtime; } + // Load reformer data from storage (e.g., rotation matrix) + if (reformer_ != nullptr) { + reformer_->load(storage_); + } is_trained_ = true; return 0; } @@ -209,11 +221,8 @@ int IVFIndex::_prepare_for_search( } if (ivf_search_param->nprobe > 0) { - // TODO: 1. sparse; 2. default ef ailego::Params params; - // need fix - params.set(core::PARAM_IVF_BUILDER_CENTROID_COUNT, - ivf_search_param->nprobe); + params.set(core::PARAM_IVF_SEARCHER_NPROBE, ivf_search_param->nprobe); context->update(params); } return 0; @@ -229,6 +238,10 @@ int IVFIndex::Merge(const std::vector &indexes, dumper->create(file_path_); builder_->dump(dumper); + // Dump converter state (e.g., rotator for INT8+rotate) to dumper + if (converter_) { + converter_->dump(dumper); + } dumper->close(); int ret = storage_->open(file_path_, false); if (ret != 0) { @@ -240,6 +253,10 @@ int IVFIndex::Merge(const std::vector &indexes, LOG_ERROR("Failed to open streamer, path: %s", file_path_.c_str()); return core::IndexError_Runtime; } + // Load reformer data from storage (e.g., rotation matrix) + if (reformer_ != nullptr) { + reformer_->load(storage_); + } is_trained_ = true; return 0; } diff --git a/src/core/quantizer/CMakeLists.txt b/src/core/quantizer/CMakeLists.txt index f5c9ad898..6907b495c 100644 --- a/src/core/quantizer/CMakeLists.txt +++ b/src/core/quantizer/CMakeLists.txt @@ -6,11 +6,13 @@ if(NOT APPLE) "-Wl,--exclude-libs,libparquet.a:libarrow.a:libarrow_bundled_dependencies.a") endif() +cc_directory(rotator) + cc_library( NAME core_quantizer STATIC SHARED STRICT ALWAYS_LINK SRCS *.cc - LIBS zvec_ailego zvec_turbo core_framework + LIBS zvec_ailego zvec_turbo core_framework core_rotator INCS . 
${PROJECT_ROOT_DIR}/src/core LDFLAGS "${CORE_QUANTIZER_LDFLAGS}" VERSION "${PROXIMA_ZVEC_VERSION}" diff --git a/src/core/quantizer/cosine_converter.cc b/src/core/quantizer/cosine_converter.cc index ded1e3eb5..8354ca269 100644 --- a/src/core/quantizer/cosine_converter.cc +++ b/src/core/quantizer/cosine_converter.cc @@ -18,6 +18,7 @@ #include #include #include +#include "rotator/rotator.h" #include "record_quantizer.h" #include "../metric/metric_params.h" @@ -54,6 +55,10 @@ class CosineConverterHolder : public IndexHolder { type_ == IndexMeta::DataType::DT_INT8) { buffer_.resize(element_size, 0); } + + if (owner_->rotator_) { + rotate_buffer_.resize(owner_->rotator_->dimension()); + } } this->convert_record(); @@ -116,17 +121,26 @@ class CosineConverterHolder : public IndexHolder { original_element_size); float *buf = reinterpret_cast(&normalize_buffer_[0]); + const float *vec = buf; + + if (owner_->rotator_) { + owner_->rotator_->rotate(vec, rotate_buffer_.data()); + vec = rotate_buffer_.data(); + } float norm = 0.0f; - ailego::Normalizer::L2(buf, original_dimension_, &norm); + ailego::Normalizer::L2(const_cast(vec), + original_dimension_, &norm); if (type_ == IndexMeta::DataType::DT_FP32) { + ::memcpy(reinterpret_cast(&normalize_buffer_[0]), vec, + original_dimension_ * sizeof(float)); ::memcpy(reinterpret_cast(&normalize_buffer_[0]) + original_dimension_, &norm, NORM_SIZE); } else if (type_ == IndexMeta::DataType::DT_FP16) { ailego::FloatHelper::ToFP16( - buf, original_dimension_, + const_cast(vec), original_dimension_, reinterpret_cast(&buffer_[0])); ::memcpy( @@ -134,9 +148,8 @@ class CosineConverterHolder : public IndexHolder { &norm, NORM_SIZE); } else if (type_ == IndexMeta::DataType::DT_INT4 || type_ == IndexMeta::DataType::DT_INT8) { - RecordQuantizer::quantize_record( - reinterpret_cast(normalize_buffer_.data()), - original_dimension_, type_, false, &buffer_[0]); + RecordQuantizer::quantize_record(vec, original_dimension_, type_, + false, &buffer_[0]); 
::memcpy(reinterpret_cast(&buffer_[0]) + element_size - NORM_SIZE, @@ -149,6 +162,7 @@ class CosineConverterHolder : public IndexHolder { const CosineConverterHolder *owner_{nullptr}; std::string buffer_{}; std::string normalize_buffer_{}; + std::vector rotate_buffer_; IndexHolder::Iterator::Pointer front_iter_{}; size_t dimension_{0u}; size_t original_dimension_{0u}; @@ -159,11 +173,13 @@ class CosineConverterHolder : public IndexHolder { //! Constructor CosineConverterHolder(IndexHolder::Pointer front, IndexMeta::DataType original_type, - IndexMeta::DataType type) + IndexMeta::DataType type, + std::shared_ptr rotator = nullptr) : front_(std::move(front)), original_type_(original_type), type_(type), - dimension_(front_->dimension()) {} + dimension_(front_->dimension()), + rotator_(std::move(rotator)) {} //! Retrieve count of elements in holder (-1 indicates unknown) size_t count(void) const override { @@ -222,6 +238,7 @@ class CosineConverterHolder : public IndexHolder { IndexMeta::DataType original_type_{}; IndexMeta::DataType type_{}; uint32_t dimension_{0}; + std::shared_ptr rotator_{}; }; /*! Converter of Cosine @@ -264,8 +281,17 @@ class CosineConverter : public IndexConverter { return IndexError_Unsupported; } + params.get(COSINE_CONVERTER_ENABLE_ROTATE, &enable_rotate_); + ailego::Params reformer_params; + if (enable_rotate_) { + size_t dim = index_meta.dimension(); + rotator_ = std::make_shared(); + rotator_->init(dim); + LOG_DEBUG("CosineConverter: rotation enabled, dim=%zu", dim); + } + if (dst_type_ == IndexMeta::DataType::DT_INT8) { meta_.set_converter("CosineInt8Converter", 0, params); meta_.set_reformer("CosineInt8Reformer", 0, reformer_params); @@ -333,12 +359,23 @@ class CosineConverter : public IndexConverter { *stats_.mutable_transformed_count() += holder->count(); holder_ = std::make_shared( - holder, holder->data_type(), dst_type_); + holder, holder->data_type(), dst_type_, rotator_); return 0; } //! 
Dump index into storage - int dump(const IndexDumper::Pointer & /*dumper*/) override { + int dump(const IndexDumper::Pointer &dumper) override { + if (rotator_) { + return rotator_->dump(dumper); + } + return 0; + } + + //! Dump converter state to storage + int dump_to_storage(const IndexStorage::Pointer &storage) override { + if (rotator_) { + return rotator_->dump(storage); + } return 0; } @@ -378,6 +415,8 @@ class CosineConverter : public IndexConverter { IndexHolder::Pointer holder_{}; IndexMeta::DataType original_type_{IndexMeta::DataType::DT_UNDEFINED}; IndexMeta::DataType dst_type_{IndexMeta::DataType::DT_UNDEFINED}; + bool enable_rotate_{false}; + std::shared_ptr rotator_{}; }; INDEX_FACTORY_REGISTER_CONVERTER_ALIAS(CosineNormalizeConverter, diff --git a/src/core/quantizer/cosine_reformer.cc b/src/core/quantizer/cosine_reformer.cc index d6080b8d9..b71d04fc6 100644 --- a/src/core/quantizer/cosine_reformer.cc +++ b/src/core/quantizer/cosine_reformer.cc @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. #include +#include #include #include #include #include #include +#include "rotator/rotator.h" #include "record_quantizer.h" namespace zvec { @@ -53,7 +55,24 @@ class CosineReformer : public IndexReformer { } //! Load index from container - int load(IndexStorage::Pointer) override { + //! Auto-detects rotation by checking for rotator segment in storage. 
+ int load(IndexStorage::Pointer storage) override { + if (enable_rotate_ || storage->get(RECORD_ROTATOR_SEG_ID)) { + rotator_ = std::make_shared(); + int ret = rotator_->open(storage); + if (ret != 0) { + if (enable_rotate_) { + LOG_ERROR("CosineReformer: load rotator failed, ret=%d", ret); + rotator_.reset(); + return ret; + } + rotator_.reset(); + } else { + enable_rotate_ = true; + LOG_DEBUG("CosineReformer: rotator auto-loaded, dim=%zu", + rotator_->dimension()); + } + } return 0; } @@ -83,28 +102,42 @@ class CosineReformer : public IndexReformer { ometa->set_meta(dst_type_, qmeta.dimension() + ExtraDimension(dst_type_)); out->resize(ometa->element_size()); - float norm = 0.0f; size_t origin_dimension = qmeta.dimension(); + const float *vec = reinterpret_cast(query); + float norm = 0.0f; + + // Fast path: no rotation — matches main branch behavior exactly std::string normalized_buffer(reinterpret_cast(query), qmeta.element_size()); - float *buf = reinterpret_cast(&normalized_buffer[0]); - ailego::Normalizer::L2(buf, origin_dimension, &norm); + if (enable_rotate_ && rotator_) { + // Rotate then normalize the rotated vector + std::vector rotate_buffer(rotator_->dimension()); + rotator_->rotate(vec, rotate_buffer.data()); + std::memcpy(buf, rotate_buffer.data(), + origin_dimension * sizeof(float)); + ailego::Normalizer::L2(buf, origin_dimension, &norm); + vec = buf; + } else { + ailego::Normalizer::L2(buf, origin_dimension, &norm); + vec = buf; + } ::memcpy(reinterpret_cast(&(*out)[0]) + ometa->element_size() - NORM_SIZE, &norm, NORM_SIZE); if (dst_type_ == IndexMeta::DataType::DT_FP32) { - ::memcpy(reinterpret_cast(&(*out)[0]), buf, + ::memcpy(reinterpret_cast(&(*out)[0]), vec, ometa->element_size() - NORM_SIZE); } else if (dst_type_ == IndexMeta::DataType::DT_FP16) { - RecordQuantizer::quantize_record(buf, origin_dimension, dst_type_, - false, &(*out)[0]); + RecordQuantizer::quantize_record(const_cast(vec), + qmeta.dimension(), dst_type_, false, + &(*out)[0]); 
} else if (dst_type_ == IndexMeta::DataType::DT_INT4 || dst_type_ == IndexMeta::DataType::DT_INT8) { - RecordQuantizer::quantize_record(buf, qmeta.dimension(), dst_type_, + RecordQuantizer::quantize_record(vec, qmeta.dimension(), dst_type_, false, &(*out)[0]); } } else if (type == IndexMeta::DataType::DT_FP16) { @@ -186,6 +219,11 @@ class CosineReformer : public IndexReformer { NORM_SIZE, NORM_SIZE); + // Rotation was applied in transform() for all FP32-origin paths (FP32, + // INT8, INT4 stored types). FP16 input path never rotates. + const bool need_inv_rotate = + (type != IndexMeta::DataType::DT_FP16 && enable_rotate_ && rotator_); + if (type == IndexMeta::DataType::DT_FP32) { if (dst_type_ != IndexMeta::DataType::DT_FP32) { return IndexError_Unsupported; @@ -195,6 +233,11 @@ class CosineReformer : public IndexReformer { const float *in_buf = reinterpret_cast(in); this->denormalize(in_buf, out_buf, qmeta, norm); + if (need_inv_rotate) { + std::vector tmp(dimension); + rotator_->unrotate(out_buf, tmp.data()); + std::memcpy(out_buf, tmp.data(), dimension * sizeof(float)); + } } else if (type == IndexMeta::DataType::DT_FP16) { if (dst_type_ != IndexMeta::DataType::DT_FP16) { return IndexError_Unsupported; @@ -210,6 +253,7 @@ class CosineReformer : public IndexReformer { RecordQuantizer::unquantize_record(in, dimension, dst_type_, out_buf); this->denormalize(out_buf, out_buf, qmeta, norm); + // FP16 type path: no rotation was applied, skip inverse } else { ailego::Float16 *out_buf = reinterpret_cast(&(*out)[0]); @@ -228,6 +272,11 @@ class CosineReformer : public IndexReformer { RecordQuantizer::unquantize_record(in, dimension, dst_type_, out_buf); this->denormalize(out_buf, out_buf, qmeta, norm); + if (need_inv_rotate) { + std::vector tmp(dimension); + rotator_->unrotate(out_buf, tmp.data()); + std::memcpy(out_buf, tmp.data(), dimension * sizeof(float)); + } } return 0; @@ -262,6 +311,8 @@ class CosineReformer : public IndexReformer { //! 
Members IndexMeta::DataType original_type_{IndexMeta::DataType::DT_UNDEFINED}; IndexMeta::DataType dst_type_{IndexMeta::DataType::DT_UNDEFINED}; + bool enable_rotate_{false}; + std::shared_ptr rotator_{}; }; INDEX_FACTORY_REGISTER_REFORMER_ALIAS(CosineNormalizeReformer, CosineReformer, diff --git a/src/core/quantizer/integer_quantizer_converter.cc b/src/core/quantizer/integer_quantizer_converter.cc index f812b6e3c..a949aa879 100644 --- a/src/core/quantizer/integer_quantizer_converter.cc +++ b/src/core/quantizer/integer_quantizer_converter.cc @@ -18,6 +18,7 @@ #include #include #include +#include "rotator/rotator.h" #include "record_quantizer.h" #include "../metric/metric_params.h" @@ -378,6 +379,7 @@ class IntegerStreamingConverter : public IndexConverter { meta_ = index_meta; params.get(INTEGER_STREAMING_CONVERTER_ENABLE_NORMALIZE, &enable_normalize_); + params.get(INTEGER_STREAMING_CONVERTER_ENABLE_ROTATE, &enable_rotate_); ailego::Params reformer_params; if (enable_normalize_) { reformer_params.set(INTEGER_STREAMING_REFORMER_ENABLE_NORMALIZE, true); @@ -390,6 +392,13 @@ class IntegerStreamingConverter : public IndexConverter { reformer_params.set(INTEGER_STREAMING_REFORMER_IS_EUCLIDEAN, true); } + // Create rotator if rotation is enabled + if (enable_rotate_) { + rotator_ = std::make_shared(); + rotator_->init(index_meta.dimension()); + LOG_DEBUG("IntegerStreamingConverter: rotation enabled, dim=%zu", + static_cast(index_meta.dimension())); + } if (data_type_ == IndexMeta::DataType::DT_INT8) { meta_.set_converter("Int8StreamingConverter", 0, params); @@ -433,12 +442,30 @@ class IntegerStreamingConverter : public IndexConverter { *stats_.mutable_transformed_count() += holder->count(); holder_ = std::make_shared( - holder, data_type_, enable_normalize_, is_euclidean_); + holder, data_type_, enable_normalize_, is_euclidean_, rotator_); return 0; } - //! Dump index into storage - int dump(const IndexDumper::Pointer & /*dumper*/) override { + //! 
Dump index into storage (writes rotator segment if rotate enabled) + int dump(const IndexDumper::Pointer &dumper) override { + if (enable_rotate_ && rotator_) { + return rotator_->dump(dumper); + } + return 0; + } + + //! Dump converter state to IndexStorage for streaming build + int dump_to_storage(const IndexStorage::Pointer &storage) override { + if (enable_rotate_ && rotator_) { + int ret = rotator_->dump(storage); + if (ret != 0) { + LOG_ERROR( + "IntegerStreamingConverter: dump rotator to storage failed, ret=%d", + ret); + return ret; + } + LOG_DEBUG("IntegerStreamingConverter: rotator dumped to storage"); + } return 0; } @@ -468,7 +495,8 @@ class IntegerStreamingConverter : public IndexConverter { IndexHolder::Iterator::Pointer &&iter) : owner_(owner), buffer_(owner->element_size(), 0), - normalize_buffer_(owner->front_->element_size(), 0), + normalize_buffer_(owner->dimension_ * sizeof(float), 0), + rotate_buffer_(owner->dimension_ * sizeof(float), 0), front_iter_(std::move(iter)) { this->encode_record(); } @@ -503,18 +531,24 @@ class IntegerStreamingConverter : public IndexConverter { if (front_iter_->is_valid()) { const float *vec = reinterpret_cast(front_iter_->data()); + size_t dim = owner_->dimension_; + if (owner_->rotator_) { + float *rotate_buf = + reinterpret_cast(rotate_buffer_.data()); + owner_->rotator_->rotate(vec, rotate_buf); + vec = rotate_buf; + } if (owner_->enable_normalize_) { float norm = 0.0; - memcpy((void *)normalize_buffer_.data(), vec, - owner_->front_->element_size()); + memcpy((void *)normalize_buffer_.data(), vec, dim * sizeof(float)); ailego::Normalizer::L2((float *)normalize_buffer_.data(), - owner_->dimension_, &norm); + dim, &norm); vec = (float *)normalize_buffer_.data(); } - RecordQuantizer::quantize_record( - vec, owner_->dimension_, owner_->data_type(), - owner_->is_euclidean_, buffer_.data()); + RecordQuantizer::quantize_record(vec, dim, owner_->data_type(), + owner_->is_euclidean_, + buffer_.data()); } } @@ -522,18 
+556,21 @@ class IntegerStreamingConverter : public IndexConverter { const IntegerStreamingConverterHolder *owner_{nullptr}; std::vector buffer_{}; std::string normalize_buffer_{}; + std::string rotate_buffer_{}; IndexHolder::Iterator::Pointer front_iter_{}; }; //! Constructor IntegerStreamingConverterHolder(IndexHolder::Pointer front, IndexMeta::DataType tp, - bool enable_normalize, bool is_euclidean) + bool enable_normalize, bool is_euclidean, + std::shared_ptr rotator) : front_(std::move(front)), data_type_(tp), dimension_(front_->dimension()), enable_normalize_(enable_normalize), - is_euclidean_(is_euclidean) {} + is_euclidean_(is_euclidean), + rotator_(std::move(rotator)) {} //! Retrieve count of elements in holder (-1 indicates unknown) size_t count(void) const override { @@ -576,6 +613,7 @@ class IntegerStreamingConverter : public IndexConverter { uint32_t dimension_{0}; bool enable_normalize_{false}; bool is_euclidean_{false}; + std::shared_ptr rotator_{}; }; static size_t ExtraDimension(IndexMeta::DataType type) { @@ -593,7 +631,9 @@ class IntegerStreamingConverter : public IndexConverter { IndexHolder::Pointer holder_{}; IndexMeta::DataType data_type_{}; bool enable_normalize_{false}; + bool enable_rotate_{false}; bool is_euclidean_{false}; + std::shared_ptr rotator_{}; }; INDEX_FACTORY_REGISTER_CONVERTER_ALIAS( diff --git a/src/core/quantizer/integer_quantizer_reformer.cc b/src/core/quantizer/integer_quantizer_reformer.cc index 4228d0fda..6be69f3b1 100644 --- a/src/core/quantizer/integer_quantizer_reformer.cc +++ b/src/core/quantizer/integer_quantizer_reformer.cc @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include #include #include #include +#include "rotator/rotator.h" #include "record_quantizer.h" namespace zvec { @@ -295,7 +297,30 @@ class IntegerStreamingReformer : public IndexReformer { } //! 
Load index from container - int load(IndexStorage::Pointer) override { + //! Auto-detects rotation by checking for rotator segment in storage. + //! No need for enable_rotate in search config. + int load(IndexStorage::Pointer storage) override { + // If config explicitly enables rotate but rotator not yet loaded, try + // storage If config doesn't enable rotate, still try storage (auto-detect) + if (enable_rotate_ || storage->get(RECORD_ROTATOR_SEG_ID)) { + rotator_ = std::make_shared(); + int ret = rotator_->open(storage); + if (ret != 0) { + if (enable_rotate_) { + // Config said enable_rotate but storage has no rotator — error + LOG_ERROR("IntegerStreamingReformer: load rotator failed, ret=%d", + ret); + rotator_.reset(); + return ret; + } + // No rotator in storage, rotation not available + rotator_.reset(); + } else { + enable_rotate_ = true; + LOG_DEBUG("IntegerStreamingReformer: rotator auto-loaded, dim=%zu", + rotator_->dimension()); + } + } return 0; } @@ -319,10 +344,16 @@ class IntegerStreamingReformer : public IndexReformer { ometa->set_meta(data_type_, qmeta.dimension() + extra_dimension_); out->resize(ometa->element_size()); const float *vec = reinterpret_cast(query); + std::unique_ptr rotate_buffer; + if (enable_rotate_ && rotator_) { + rotate_buffer.reset(new float[rotator_->dimension()]); + rotator_->rotate(vec, rotate_buffer.get()); + vec = rotate_buffer.get(); + } std::unique_ptr normalized; if (enable_normalize_) { normalized.reset(new float[qmeta.dimension()]); - vec = normalize(query, qmeta, normalized.get()); + vec = normalize(vec, qmeta, normalized.get()); } RecordQuantizer::quantize_record(vec, qmeta.dimension(), data_type_, @@ -344,13 +375,21 @@ class IntegerStreamingReformer : public IndexReformer { *ometa = qmeta; ometa->set_meta(data_type_, qmeta.dimension() + extra_dimension_); out->resize(count * ometa->element_size()); + std::unique_ptr rotate_buffer; std::unique_ptr normalized; + if (enable_rotate_ && rotator_) { + 
rotate_buffer.reset(new float[rotator_->dimension()]); + } if (enable_normalize_) { normalized.reset(new float[qmeta.dimension()]); } for (size_t i = 0; i < count; ++i) { const float *vec = reinterpret_cast(query) + i * qmeta.dimension(); + if (enable_rotate_ && rotator_) { + rotator_->rotate(vec, rotate_buffer.get()); + vec = rotate_buffer.get(); + } if (enable_normalize_) { vec = normalize(vec, qmeta, normalized.get()); } @@ -378,10 +417,16 @@ class IntegerStreamingReformer : public IndexReformer { ometa->set_meta(data_type_, rmeta.dimension() + extra_dimension_); out->resize(ometa->element_size()); const float *vec = reinterpret_cast(record); + std::unique_ptr rotate_buffer; + if (enable_rotate_ && rotator_) { + rotate_buffer.reset(new float[rotator_->dimension()]); + rotator_->rotate(vec, rotate_buffer.get()); + vec = rotate_buffer.get(); + } std::unique_ptr normalized; if (enable_normalize_) { normalized.reset(new float[rmeta.dimension()]); - vec = normalize(record, rmeta, normalized.get()); + vec = normalize(vec, rmeta, normalized.get()); } RecordQuantizer::quantize_record(vec, rmeta.dimension(), data_type_, @@ -404,13 +449,21 @@ class IntegerStreamingReformer : public IndexReformer { *ometa = rmeta; ometa->set_meta(data_type_, rmeta.dimension() + extra_dimension_); out->resize(count * ometa->element_size()); + std::unique_ptr rotate_buffer; std::unique_ptr normalized; + if (enable_rotate_ && rotator_) { + rotate_buffer.reset(new float[rotator_->dimension()]); + } if (enable_normalize_) { normalized.reset(new float[rmeta.dimension()]); } for (size_t i = 0; i < count; ++i) { const float *vec = reinterpret_cast(records) + i * rmeta.dimension(); + if (enable_rotate_ && rotator_) { + rotator_->rotate(vec, rotate_buffer.get()); + vec = rotate_buffer.get(); + } if (enable_normalize_) { vec = normalize(vec, rmeta, normalized.get()); } @@ -447,15 +500,23 @@ class IntegerStreamingReformer : public IndexReformer { std::string *out) const override { if 
(enable_normalize_) { LOG_ERROR("Unsupported revert for normalized value"); - return IndexError_Unsupported; } - out->resize((qmeta.dimension() - extra_dimension_) * sizeof(float)); - float *out_buf = reinterpret_cast(out->data()); + const size_t stored_dim = qmeta.dimension() - extra_dimension_; - RecordQuantizer::unquantize_record(in, qmeta.dimension() - extra_dimension_, - data_type_, out_buf); + // Step 1: Unquantize into out buffer (stored_dim floats) + out->resize(stored_dim * sizeof(float)); + float *out_buf = reinterpret_cast(out->data()); + RecordQuantizer::unquantize_record(in, stored_dim, data_type_, out_buf); + + // Step 2: Inverse rotate in-place if rotation was applied + if (enable_rotate_ && rotator_) { + std::vector tmp(rotator_->dimension()); + rotator_->unrotate(out_buf, tmp.data()); + out->assign(reinterpret_cast(tmp.data()), + tmp.size() * sizeof(float)); + } return 0; } @@ -465,6 +526,8 @@ class IntegerStreamingReformer : public IndexReformer { uint32_t extra_dimension_{0}; bool enable_normalize_{false}; bool is_euclidean_{false}; + bool enable_rotate_{false}; + std::shared_ptr rotator_{}; }; INDEX_FACTORY_REGISTER_REFORMER_ALIAS( diff --git a/src/core/quantizer/quantizer_params.h b/src/core/quantizer/quantizer_params.h index a089a2d9f..d56c8591d 100644 --- a/src/core/quantizer/quantizer_params.h +++ b/src/core/quantizer/quantizer_params.h @@ -100,6 +100,8 @@ static const std::string INT4_QUANTIZER_REFORMER_METRIC = //! CosineConverter static const std::string COSINE_CONVERTER_FORCED_HALF_FLOAT = "cosine.converter.forced_half_float"; +static const std::string COSINE_CONVERTER_ENABLE_ROTATE = + "cosine.converter.enable_rotate"; //! CosineReformer static const std::string COSINE_REFORMER_FORCED_HALF_FLOAT = @@ -108,8 +110,10 @@ static const std::string COSINE_REFORMER_FORCED_HALF_FLOAT = //! 
IntegerStreamingConverter static const std::string INTEGER_STREAMING_CONVERTER_ENABLE_NORMALIZE = "integer_streaming.converter.enable_normalize"; +static const std::string INTEGER_STREAMING_CONVERTER_ENABLE_ROTATE = + "integer_streaming.converter.enable_rotate"; -//! IntegerStreamingConverter +//! IntegerStreamingReformer static const std::string INTEGER_STREAMING_REFORMER_ENABLE_NORMALIZE = "integer_streaming.reformer.enable_normalize"; static const std::string INTEGER_STREAMING_REFORMER_IS_EUCLIDEAN = diff --git a/src/core/quantizer/rotator/CMakeLists.txt b/src/core/quantizer/rotator/CMakeLists.txt new file mode 100644 index 000000000..1403cc53d --- /dev/null +++ b/src/core/quantizer/rotator/CMakeLists.txt @@ -0,0 +1,16 @@ +include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) +include(${PROJECT_ROOT_DIR}/cmake/option.cmake) + +# No rabitqlib dependency: matrix_rotator.cc implements Householder QR +# manually with plain C++ to avoid ISA-sensitive Eigen inline functions +# that cause ODR violations (duplicate codegen with different -march flags) +# leading to SEGFAULT on linux-x64-clang. + +cc_library( + NAME core_rotator + STATIC STRICT ALWAYS_LINK + SRCS *.cc + LIBS zvec_ailego core_framework + INCS . .. ${PROJECT_ROOT_DIR}/src/core + VERSION "${PROXIMA_ZVEC_VERSION}" +) diff --git a/src/core/quantizer/rotator/fht_rotator.cc b/src/core/quantizer/rotator/fht_rotator.cc new file mode 100644 index 000000000..099328981 --- /dev/null +++ b/src/core/quantizer/rotator/fht_rotator.cc @@ -0,0 +1,156 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fht_rotator.h" +#include +#include +#include + +namespace zvec { +namespace core { + +// ============================================================================ +// FhtKacRotatorImpl method implementations +// ============================================================================ + +void FhtKacRotatorImpl::init(size_t dim) { + flip.resize(4 * dim / kByteLen); + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dist(0, 255); + for (auto &b : flip) b = static_cast(dist(gen)); +} + +void FhtKacRotatorImpl::rotate(const float *in, float *out, size_t dim) const { + std::memcpy(out, in, sizeof(float) * dim); + + if (trunc_dim == dim) { + // Exact power-of-2: 4 rounds of (flip -> FHT -> rescale) + ailego::fht_flip_sign(flip.data(), out, dim); + ailego::fht_inplace(out, trunc_dim); + ailego::fht_vec_rescale(out, trunc_dim, fac); + + ailego::fht_flip_sign(flip.data() + dim / kByteLen, out, dim); + ailego::fht_inplace(out, trunc_dim); + ailego::fht_vec_rescale(out, trunc_dim, fac); + + ailego::fht_flip_sign(flip.data() + 2 * dim / kByteLen, out, dim); + ailego::fht_inplace(out, trunc_dim); + ailego::fht_vec_rescale(out, trunc_dim, fac); + + ailego::fht_flip_sign(flip.data() + 3 * dim / kByteLen, out, dim); + ailego::fht_inplace(out, trunc_dim); + ailego::fht_vec_rescale(out, trunc_dim, fac); + + return; + } + + // Non-power-of-2 (64-aligned, e.g. 
192, 320): 4 rounds with kacs_walk + size_t start = dim - trunc_dim; + float *trunc_ptr = out + start; + + // Round 1: FHT on [0, trunc_dim) + ailego::fht_flip_sign(flip.data(), out, dim); + ailego::fht_inplace(out, trunc_dim); + ailego::fht_vec_rescale(out, trunc_dim, fac); + ailego::fht_kacs_walk(out, dim); + + // Round 2: FHT on [start, start + trunc_dim) + ailego::fht_flip_sign(flip.data() + dim / kByteLen, out, dim); + ailego::fht_inplace(trunc_ptr, trunc_dim); + ailego::fht_vec_rescale(trunc_ptr, trunc_dim, fac); + ailego::fht_kacs_walk(out, dim); + + // Round 3: FHT on [0, trunc_dim) + ailego::fht_flip_sign(flip.data() + 2 * dim / kByteLen, out, dim); + ailego::fht_inplace(out, trunc_dim); + ailego::fht_vec_rescale(out, trunc_dim, fac); + ailego::fht_kacs_walk(out, dim); + + // Round 4: FHT on [start, start + trunc_dim) + ailego::fht_flip_sign(flip.data() + 3 * dim / kByteLen, out, dim); + ailego::fht_inplace(trunc_ptr, trunc_dim); + ailego::fht_vec_rescale(trunc_ptr, trunc_dim, fac); + ailego::fht_kacs_walk(out, dim); + + // Final rescale: combine the 4 kacs_walk reductions + ailego::fht_vec_rescale(out, dim, 0.25f); +} + +void FhtKacRotatorImpl::unrotate(const float *in, float *out, + size_t dim) const { + // Copy input into working buffer + std::vector data(in, in + dim); + + if (trunc_dim == dim) { + // Exact power-of-2: reverse 4 rounds in reverse order. 
+ const float inv_fac = 1.0f / std::sqrt(static_cast(trunc_dim)); + for (int round = 3; round >= 0; --round) { + ailego::fht_inplace(data.data(), trunc_dim); + ailego::fht_vec_rescale(data.data(), trunc_dim, inv_fac); + ailego::fht_flip_sign(flip.data() + round * dim / kByteLen, data.data(), + dim); + } + std::memcpy(out, data.data(), dim * sizeof(float)); + return; + } + + // Non-power-of-2: undo final rescale(0.25) first + ailego::fht_vec_rescale(data.data(), dim, 4.0f); + + const float inv_fac = 1.0f / std::sqrt(static_cast(trunc_dim)); + size_t start = dim - trunc_dim; + float *trunc_ptr = data.data() + start; + + // Undo Round 4 (FHT on [start, start+trunc_dim)) + ailego::fht_inv_kacs_walk(data.data(), dim); + ailego::fht_inplace(trunc_ptr, trunc_dim); + ailego::fht_vec_rescale(trunc_ptr, trunc_dim, inv_fac); + ailego::fht_flip_sign(flip.data() + 3 * dim / kByteLen, data.data(), dim); + + // Undo Round 3 (FHT on [0, trunc_dim)) + ailego::fht_inv_kacs_walk(data.data(), dim); + ailego::fht_inplace(data.data(), trunc_dim); + ailego::fht_vec_rescale(data.data(), trunc_dim, inv_fac); + ailego::fht_flip_sign(flip.data() + 2 * dim / kByteLen, data.data(), dim); + + // Undo Round 2 (FHT on [start, start+trunc_dim)) + ailego::fht_inv_kacs_walk(data.data(), dim); + ailego::fht_inplace(trunc_ptr, trunc_dim); + ailego::fht_vec_rescale(trunc_ptr, trunc_dim, inv_fac); + ailego::fht_flip_sign(flip.data() + dim / kByteLen, data.data(), dim); + + // Undo Round 1 (FHT on [0, trunc_dim)) + ailego::fht_inv_kacs_walk(data.data(), dim); + ailego::fht_inplace(data.data(), trunc_dim); + ailego::fht_vec_rescale(data.data(), trunc_dim, inv_fac); + ailego::fht_flip_sign(flip.data(), data.data(), dim); + + std::memcpy(out, data.data(), dim * sizeof(float)); +} + +void FhtKacRotatorImpl::save(char *data) const { + std::memcpy(data, flip.data(), flip.size()); +} + +void FhtKacRotatorImpl::load(const char *data) { + std::memcpy(flip.data(), data, flip.size()); +} + +size_t 
FhtKacRotatorImpl::dump_bytes() const { + return flip.size(); +} + +} // namespace core +} // namespace zvec diff --git a/src/core/quantizer/rotator/fht_rotator.h b/src/core/quantizer/rotator/fht_rotator.h new file mode 100644 index 000000000..91925d67e --- /dev/null +++ b/src/core/quantizer/rotator/fht_rotator.h @@ -0,0 +1,48 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include +#include + +namespace zvec { +namespace core { + +// ============================================================================ +// FhtKacRotatorImpl - O(d log d) FHT-based Kac random rotation +// +// Requires dimension % 4 == 0 (scalar tails handle SIMD remainder). +// When dimension is a power of 2, uses 4 rounds of (flip -> FHT -> rescale). +// When dimension is NOT a power of 2, uses kacs_walk reduction. 
+// ============================================================================ + +struct FhtKacRotatorImpl { + std::vector flip; + size_t trunc_dim{0}; + float fac{0}; + + static constexpr size_t kByteLen = 8; + + void init(size_t dim); + void rotate(const float *in, float *out, size_t dim) const; + void unrotate(const float *in, float *out, size_t dim) const; + void save(char *data) const; + void load(const char *data); + size_t dump_bytes() const; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/quantizer/rotator/matrix_rotator.cc b/src/core/quantizer/rotator/matrix_rotator.cc new file mode 100644 index 000000000..d5e63f348 --- /dev/null +++ b/src/core/quantizer/rotator/matrix_rotator.cc @@ -0,0 +1,168 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "matrix_rotator.h" +#include +#include +#include +#include +#include + +namespace zvec { +namespace core { + +namespace { + +// Generate a dim x dim random Gaussian matrix (row-major) without Eigen. +void random_gaussian_matrix(float *mat, size_t dim) { + static std::random_device rd; + static std::mt19937 gen(rd()); + std::normal_distribution dist(0.0f, 1.0f); + for (size_t i = 0; i < dim * dim; ++i) { + mat[i] = dist(gen); + } +} + +// Householder QR decomposition: A = Q * R. +// Computes the orthogonal matrix Q from input matrix A (row-major, dim x dim). +// Result is stored in q (row-major, dim x dim). 
+// +// Implemented manually to avoid rabitqlib/Eigen dependency whose ISA-sensitive +// inline functions cause ODR violations (duplicate codegen with different +// -march flags) leading to SEGFAULT on linux-x64-clang. +void householder_qr(const float *A_in, float *q, size_t dim) { + // R starts as a copy of A + std::vector R(A_in, A_in + dim * dim); + + // Q starts as identity + std::fill(q, q + dim * dim, 0.0f); + for (size_t i = 0; i < dim; ++i) { + q[i * dim + i] = 1.0f; + } + + std::vector v(dim); + + for (size_t k = 0; k < dim; ++k) { + // x = R[k:dim, k] (sub-column below and including diagonal) + float norm_x_sq = 0.0f; + for (size_t i = k; i < dim; ++i) { + norm_x_sq += R[i * dim + k] * R[i * dim + k]; + } + if (norm_x_sq == 0.0f) continue; + + float norm_x = std::sqrt(norm_x_sq); + + // alpha = -sign(R[k][k]) * ||x|| (choose sign to avoid cancellation) + float alpha = (R[k * dim + k] >= 0.0f) ? -norm_x : norm_x; + + // v = x - alpha * e1 (only the sub-vector [k, dim) is non-zero) + for (size_t i = k; i < dim; ++i) { + v[i - k] = R[i * dim + k]; + } + v[0] -= alpha; + + // Normalize v + float v_norm_sq = 0.0f; + for (size_t i = 0; i < dim - k; ++i) { + v_norm_sq += v[i] * v[i]; + } + if (v_norm_sq == 0.0f) continue; + float inv_v_norm = 1.0f / std::sqrt(v_norm_sq); + for (size_t i = 0; i < dim - k; ++i) { + v[i] *= inv_v_norm; + } + + // Apply Householder reflection to R: R[k:dim, k:dim] -= 2*v*(v^T * R) + for (size_t j = k; j < dim; ++j) { + float dot = 0.0f; + for (size_t i = 0; i < dim - k; ++i) { + dot += v[i] * R[(k + i) * dim + j]; + } + dot *= 2.0f; + for (size_t i = 0; i < dim - k; ++i) { + R[(k + i) * dim + j] -= v[i] * dot; + } + } + + // Accumulate Q: Q[:, k:dim] -= 2*(Q[:, k:dim] * v) * v^T + for (size_t i = 0; i < dim; ++i) { + float dot = 0.0f; + for (size_t j = 0; j < dim - k; ++j) { + dot += q[i * dim + k + j] * v[j]; + } + dot *= 2.0f; + for (size_t j = 0; j < dim - k; ++j) { + q[i * dim + k + j] -= dot * v[j]; + } + } + } +} + +} // 
anonymous namespace + +void MatrixRotatorImpl::init(size_t dim) { + // Generate dim x dim random Gaussian matrix + std::vector rand_mat(dim * dim); + random_gaussian_matrix(rand_mat.data(), dim); + + // Householder QR: A = Q * R, use Q^T as the rotation matrix + std::vector Q(dim * dim); + householder_qr(rand_mat.data(), Q.data(), dim); + + // Store Q^T (transpose) as the rotation matrix + matrix.resize(dim * dim); + for (size_t i = 0; i < dim; ++i) { + for (size_t j = 0; j < dim; ++j) { + matrix[j * dim + i] = Q[i * dim + j]; + } + } +} + +void MatrixRotatorImpl::rotate(const float *in, float *out, size_t dim) const { + // out = in * matrix (1 x dim) * (dim x dim) -> (1 x dim) + for (size_t j = 0; j < dim; ++j) { + float sum = 0.0f; + for (size_t i = 0; i < dim; ++i) { + sum += in[i] * matrix[i * dim + j]; + } + out[j] = sum; + } +} + +void MatrixRotatorImpl::unrotate(const float *in, float *out, + size_t dim) const { + // out = in * matrix^T (1 x dim) * (dim x dim)^T -> (1 x dim) + for (size_t j = 0; j < dim; ++j) { + float sum = 0.0f; + for (size_t i = 0; i < dim; ++i) { + sum += in[i] * matrix[j * dim + i]; + } + out[j] = sum; + } +} + +void MatrixRotatorImpl::save(char *data) const { + std::memcpy(data, matrix.data(), matrix.size() * sizeof(float)); +} + +void MatrixRotatorImpl::load(const char *data) { + std::memcpy(matrix.data(), data, matrix.size() * sizeof(float)); +} + +size_t MatrixRotatorImpl::dump_bytes() const { + return matrix.size() * sizeof(float); +} + +} // namespace core +} // namespace zvec diff --git a/src/core/quantizer/rotator/matrix_rotator.h b/src/core/quantizer/rotator/matrix_rotator.h new file mode 100644 index 000000000..90b43ceca --- /dev/null +++ b/src/core/quantizer/rotator/matrix_rotator.h @@ -0,0 +1,41 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include + +namespace zvec { +namespace core { + +// ============================================================================ +// MatrixRotatorImpl - O(d^2) random orthogonal matrix rotation +// +// No alignment requirement on dimension. Uses a dim x dim square orthogonal +// matrix generated via Householder QR on a random Gaussian matrix. +// ============================================================================ + +struct MatrixRotatorImpl { + std::vector matrix; // dim x dim, row-major + + void init(size_t dim); + void rotate(const float *in, float *out, size_t dim) const; + void unrotate(const float *in, float *out, size_t dim) const; + void save(char *data) const; + void load(const char *data); + size_t dump_bytes() const; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/quantizer/rotator/rotator.cc b/src/core/quantizer/rotator/rotator.cc new file mode 100644 index 000000000..6b13ab9b3 --- /dev/null +++ b/src/core/quantizer/rotator/rotator.cc @@ -0,0 +1,468 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rotator.h" +#include +#include +#include +#include +#include "zvec/core/framework/index_error.h" +#include "zvec/core/framework/index_logger.h" +#include "fht_rotator.h" +#include "matrix_rotator.h" + +namespace zvec { +namespace core { + +namespace { + +//! Largest power-of-2 not exceeding n. +size_t floor_pow2(size_t n) { + size_t p = 1; + while ((p << 1) <= n) p <<= 1; + return p; +} + +//! Read a little-endian uint32 from raw bytes. +uint32_t read_u32_le(const char *p) { + return static_cast(static_cast(p[0])) | + (static_cast(static_cast(p[1])) << 8) | + (static_cast(static_cast(p[2])) << 16) | + (static_cast(static_cast(p[3])) << 24); +} + +//! Write a uint32 in little-endian to raw bytes. +void write_u32_le(char *p, uint32_t v) { + p[0] = static_cast(v & 0xFF); + p[1] = static_cast((v >> 8) & 0xFF); + p[2] = static_cast((v >> 16) & 0xFF); + p[3] = static_cast((v >> 24) & 0xFF); +} + +//! Read a little-endian uint16 from raw bytes. +uint16_t read_u16_le(const char *p) { + return static_cast(static_cast(p[0])) | + (static_cast(static_cast(p[1])) << 8); +} + +//! Write a uint16 in little-endian to raw bytes. +void write_u16_le(char *p, uint16_t v) { + p[0] = static_cast(v & 0xFF); + p[1] = static_cast((v >> 8) & 0xFF); +} + +} // anonymous namespace + +// ============================================================================ +// RecordRotator::Impl +// ============================================================================ + +struct RecordRotator::Impl { + //! New header layout (24 bytes, self-describing with magic): + //! magic(4B) + version(2B) + rotator_type(2B) + in_dim(4B) + //! + out_dim(4B) + payload_size(4B) + reserved(4B) = 24B + //! Legacy 12B format is auto-detected via magic mismatch in open(). 
+ static constexpr size_t kHeaderSize = 24; + static constexpr size_t kLegacyHeaderSize = 12; + static constexpr uint32_t kMagic = 0x52544F52; // "ROTR" + static constexpr uint16_t kVersion = 1; + + struct Header { + uint32_t magic; + uint16_t version; + uint16_t rotator_type; // serialized: 0=Matrix, 1=Fht + uint32_t in_dim; + uint32_t out_dim; + uint32_t payload_size; + uint32_t reserved; + + //! RecordRotatorType -> serialized rotator_type + static uint16_t type_to_ser(RecordRotatorType t) { + return t == RecordRotatorType::Matrix ? 0 : 1; + } + //! serialized rotator_type -> RecordRotatorType + static RecordRotatorType ser_to_type(uint16_t s) { + return s == 0 ? RecordRotatorType::Matrix : RecordRotatorType::FhtKac; + } + + void write_to(char *buf) const { + write_u32_le(buf + 0, magic); + write_u16_le(buf + 4, version); + write_u16_le(buf + 6, rotator_type); + write_u32_le(buf + 8, in_dim); + write_u32_le(buf + 12, out_dim); + write_u32_le(buf + 16, payload_size); + write_u32_le(buf + 20, reserved); + } + + void read_from(const char *buf) { + magic = read_u32_le(buf + 0); + version = read_u16_le(buf + 4); + rotator_type = read_u16_le(buf + 6); + in_dim = read_u32_le(buf + 8); + out_dim = read_u32_le(buf + 12); + payload_size = read_u32_le(buf + 16); + reserved = read_u32_le(buf + 20); + } + }; + + size_t dimension{0}; + RecordRotatorType type{RecordRotatorType::FhtKac}; + + std::unique_ptr fht_impl; + std::unique_ptr mat_impl; + + void do_rotate(const float *in, float *out) const { + if (fht_impl) { + fht_impl->rotate(in, out, dimension); + } else { + mat_impl->rotate(in, out, dimension); + } + } + + void do_unrotate(const float *in, float *out) const { + if (fht_impl) { + fht_impl->unrotate(in, out, dimension); + } else { + mat_impl->unrotate(in, out, dimension); + } + } + + size_t blob_bytes() const { + if (fht_impl) return fht_impl->dump_bytes(); + return mat_impl->dump_bytes(); + } + + void save_blob(char *data) const { + if (fht_impl) { + 
fht_impl->save(data); + } else { + mat_impl->save(data); + } + } + + void load_blob(const char *data) { + if (fht_impl) { + fht_impl->load(data); + } else { + mat_impl->load(data); + } + } +}; + +// ============================================================================ +// RecordRotator public methods +// ============================================================================ + +RecordRotator::RecordRotator() : impl_(std::make_unique()) {} + +RecordRotator::~RecordRotator() = default; + +RecordRotator::RecordRotator(RecordRotator &&) noexcept = default; +RecordRotator &RecordRotator::operator=(RecordRotator &&) noexcept = default; + +void RecordRotator::init(size_t dimension, RecordRotatorType rotator_type) { + impl_->dimension = dimension; + + // Auto-select implementation based on dimension alignment when FhtKac + // is requested. FhtKac requires the dimension to be a multiple of 4; + // scalar tails handle the SIMD remainder. When the dimension is not + // 4-aligned we transparently fall back to the O(d^2) Matrix rotator. 
+  bool use_fht =
+      (rotator_type == RecordRotatorType::FhtKac) && (dimension % 4 == 0);
+
+  if (use_fht) {
+    impl_->type = RecordRotatorType::FhtKac;
+    impl_->fht_impl = std::make_unique<FhtKacRotatorImpl>();
+    // Release any Matrix rotator left by an earlier init()/load()/open() so
+    // the two implementations never coexist.
+    impl_->mat_impl.reset();
+    impl_->fht_impl->trunc_dim = floor_pow2(dimension);
+    impl_->fht_impl->fac =
+        1.0f / std::sqrt(static_cast<float>(impl_->fht_impl->trunc_dim));
+    impl_->fht_impl->init(dimension);
+  } else {
+    if (rotator_type == RecordRotatorType::FhtKac) {
+      LOG_DEBUG(
+          "RecordRotator::init: dimension %zu is not 4-aligned, "
+          "falling back from FhtKac to Matrix rotator",
+          dimension);
+    }
+    impl_->type = RecordRotatorType::Matrix;
+    // Impl::do_rotate()/do_unrotate() pick fht_impl whenever it is set, so a
+    // stale FHT rotator must be released or it would shadow the new matrix.
+    impl_->fht_impl.reset();
+    impl_->mat_impl = std::make_unique<MatrixRotatorImpl>();
+    impl_->mat_impl->init(dimension);
+  }
+}
+
+//! Rotate @p in into the caller-provided buffer @p out.
+//! Logs an error and leaves @p out untouched when the rotator has not been
+//! initialized (same contract as unrotate()).
+void RecordRotator::rotate(const float *in, float *out) const {
+  if (!impl_->fht_impl && !impl_->mat_impl) {
+    LOG_ERROR("RecordRotator::rotate: rotator not initialized");
+    return;
+  }
+  impl_->do_rotate(in, out);
+}
+
+//! Rotate @p in into a freshly allocated vector of size dimension().
+//! Goes through the guarded overload so an uninitialized rotator is reported
+//! instead of dereferencing a null implementation.
+std::vector<float> RecordRotator::rotate(const float *in) const {
+  std::vector<float> out(impl_->dimension);
+  rotate(in, out.data());
+  return out;
+}
+
+//! Inverse-rotate @p in into the caller-provided buffer @p out.
+//! Logs an error and leaves @p out untouched when the rotator has not been
+//! initialized.
+void RecordRotator::unrotate(const float *in, float *out) const {
+  if (!impl_->fht_impl && !impl_->mat_impl) {
+    LOG_ERROR("RecordRotator::unrotate: rotator not initialized");
+    return;
+  }
+  impl_->do_unrotate(in, out);
+}
+
+//! Inverse-rotate @p in into a freshly allocated vector of size dimension().
+std::vector<float> RecordRotator::unrotate(const float *in) const {
+  std::vector<float> out(impl_->dimension);
+  unrotate(in, out.data());
+  return out;
+}
+
+//! Serialized size in bytes (24B header + payload blob).
+//! Returns 0 when the rotator is not initialized: there is nothing to dump,
+//! and Impl::blob_bytes() would otherwise dereference a null implementation.
+size_t RecordRotator::dump_bytes() const {
+  if (!impl_->fht_impl && !impl_->mat_impl) {
+    return 0;
+  }
+  return Impl::kHeaderSize + impl_->blob_bytes();
+}
+
+int RecordRotator::dump(const IndexStorage::Pointer &storage,
+                        const std::string &seg_id) const {
+  if (!storage) {
+    LOG_ERROR("RecordRotator::dump(storage): null storage");
+    return IndexError_InvalidArgument;
+  }
+  if (!impl_->fht_impl && !impl_->mat_impl) {
+    LOG_ERROR("RecordRotator::dump(storage): rotator not initialized");
+    return IndexError_NoReady;
+  }
+
+  auto align_size = [](size_t size) -> size_t {
+    return (size + 0x1F) & (~0x1F);
+  };
+
+  // Serialize: [RotatorSerHeader (24B)] [payload blob]
+  const size_t blob_size = impl_->blob_bytes();
+  const size_t
data_size = Impl::kHeaderSize + blob_size; + const size_t total_size = align_size(data_size); + std::vector buffer(data_size); + + Impl::Header header; + header.magic = Impl::kMagic; + header.version = Impl::kVersion; + header.rotator_type = Impl::Header::type_to_ser(impl_->type); + header.in_dim = static_cast(impl_->dimension); + header.out_dim = static_cast(impl_->dimension); + header.payload_size = static_cast(blob_size); + header.reserved = 0; + header.write_to(buffer.data()); + impl_->save_blob(buffer.data() + Impl::kHeaderSize); + + // Append segment to storage + int ret = storage->append(seg_id, total_size); + if (ret != 0) { + LOG_ERROR( + "RecordRotator::dump(storage): append segment '%s' failed, ret=%d", + seg_id.c_str(), ret); + return ret; + } + + auto segment = storage->get(seg_id); + if (!segment) { + LOG_ERROR("RecordRotator::dump(storage): get segment '%s' failed", + seg_id.c_str()); + return IndexError_WriteData; + } + + size_t written = segment->write(0, buffer.data(), data_size); + if (written != data_size) { + LOG_ERROR( + "RecordRotator::dump(storage): write failed, written=%zu, expected=%zu", + written, data_size); + return IndexError_WriteData; + } + segment->resize(data_size); + segment->update_data_crc(ailego::Crc32c::Hash(buffer.data(), data_size, 0)); + + LOG_DEBUG( + "RecordRotator::dump(storage) done: seg=%s, data_size=%zu, total=%zu", + seg_id.c_str(), data_size, total_size); + return 0; +} + +int RecordRotator::dump(const IndexDumper::Pointer &dumper, + const std::string &seg_id) const { + if (!dumper) { + LOG_ERROR("RecordRotator::dump(dumper): null dumper"); + return IndexError_InvalidArgument; + } + if (!impl_->fht_impl && !impl_->mat_impl) { + LOG_ERROR("RecordRotator::dump(dumper): rotator not initialized"); + return IndexError_NoReady; + } + + // Serialize: [RotatorSerHeader (24B)] [payload blob] + const size_t blob_size = impl_->blob_bytes(); + const size_t data_size = Impl::kHeaderSize + blob_size; + const size_t total_size = 
(data_size + 0x1F) & (~0x1F);
+
+  std::vector<char> buffer(total_size, 0);
+  Impl::Header header;
+  header.magic = Impl::kMagic;
+  header.version = Impl::kVersion;
+  header.rotator_type = Impl::Header::type_to_ser(impl_->type);
+  header.in_dim = static_cast<uint32_t>(impl_->dimension);
+  header.out_dim = static_cast<uint32_t>(impl_->dimension);
+  header.payload_size = static_cast<uint32_t>(blob_size);
+  header.reserved = 0;
+  header.write_to(buffer.data());
+  impl_->save_blob(buffer.data() + Impl::kHeaderSize);
+
+  const uint32_t crc = ailego::Crc32c::Hash(buffer.data(), data_size, 0);
+  const size_t padding_size = total_size - data_size;
+
+  // Write data + padding to dumper
+  if (dumper->write(buffer.data(), total_size) != total_size) {
+    LOG_ERROR("RecordRotator::dump(dumper): write failed, seg=%s",
+              seg_id.c_str());
+    return IndexError_WriteData;
+  }
+
+  // Register segment
+  int ret = dumper->append(seg_id, data_size, padding_size, crc);
+  if (ret != 0) {
+    LOG_ERROR("RecordRotator::dump(dumper): append failed, seg=%s, ret=%d",
+              seg_id.c_str(), ret);
+    return ret;
+  }
+
+  LOG_DEBUG(
+      "RecordRotator::dump(dumper) done: seg=%s, data_size=%zu, padding=%zu",
+      seg_id.c_str(), data_size, padding_size);
+  return 0;
+}
+
+//! Open the rotator from a storage segment.
+//! Accepts both the 24B magic-tagged header and the legacy 12B header.
+//! The header is validated against the segment size before anything is
+//! allocated or copied, and the rotator state is only replaced once the
+//! segment has been accepted, so a failed open() leaves it unchanged.
+//! @return 0 on success, IndexError_InvalidArgument for a null storage,
+//!         IndexError_InvalidFormat for a missing or malformed segment.
+int RecordRotator::open(IndexStorage::Pointer storage,
+                        const std::string &seg_id) {
+  if (!storage) {
+    LOG_ERROR("RecordRotator::open: null storage");
+    return IndexError_InvalidArgument;
+  }
+
+  auto segment = storage->get(seg_id);
+  if (!segment) {
+    LOG_ERROR("RecordRotator::open: segment '%s' not found", seg_id.c_str());
+    return IndexError_InvalidFormat;
+  }
+
+  // Read the rotator data from the segment (header + blob)
+  const size_t data_size = segment->data_size();
+  if (data_size <= Impl::kLegacyHeaderSize) {
+    LOG_ERROR("RecordRotator::open: data too small (%zu bytes)", data_size);
+    return IndexError_InvalidFormat;
+  }
+
+  IndexStorage::MemoryBlock block;
+  size_t read_size = segment->read(0, block, data_size);
+  if (read_size != data_size) {
+    LOG_ERROR("RecordRotator::open: read failed, read=%zu, expected=%zu",
+              read_size, data_size);
+    return IndexError_InvalidFormat;
+  }
+
+  // Verify CRC if available (covers header + blob)
+  uint32_t expected_crc = segment->data_crc();
+  if (expected_crc != 0) {
+    uint32_t actual_crc = ailego::Crc32c::Hash(block.data(), data_size, 0);
+    if (actual_crc != expected_crc) {
+      LOG_ERROR(
+          "RecordRotator::open: CRC mismatch, expected=0x%08x, actual=0x%08x",
+          expected_crc, actual_crc);
+      return IndexError_InvalidFormat;
+    }
+  }
+
+  // Detect format version via magic, then parse header accordingly.
+  // Parsed values stay in locals until the payload has been validated.
+  const char *raw = reinterpret_cast<const char *>(block.data());
+  uint32_t maybe_magic = read_u32_le(raw);
+  size_t header_size = 0;
+  RecordRotatorType parsed_type = RecordRotatorType::FhtKac;
+  size_t parsed_dim = 0;
+
+  if (maybe_magic == Impl::kMagic) {
+    // New format (24B header)
+    if (data_size <= Impl::kHeaderSize) {
+      LOG_ERROR("RecordRotator::open: new-format data too small (%zu bytes)",
+                data_size);
+      return IndexError_InvalidFormat;
+    }
+    Impl::Header header;
+    header.read_from(raw);
+    parsed_type = Impl::Header::ser_to_type(header.rotator_type);
+    parsed_dim = static_cast<size_t>(header.in_dim);
+    header_size = Impl::kHeaderSize;
+  } else {
+    // Legacy format fallback (12B header)
+    parsed_type =
+        static_cast<RecordRotatorType>(static_cast<uint8_t>(raw[0]));
+    parsed_dim = static_cast<size_t>(read_u32_le(raw + 4));
+    header_size = Impl::kLegacyHeaderSize;
+  }
+
+  // A zero dimension cannot describe a usable rotator (load() rejects it too).
+  if (parsed_dim == 0) {
+    LOG_ERROR("RecordRotator::open: invalid dimension 0 in header");
+    return IndexError_InvalidFormat;
+  }
+
+  // Bytes that follow the header; data_size > header_size was checked above.
+  const size_t payload_avail = data_size - header_size;
+
+  // Reconstruct the rotator from header info and load blob
+  if (parsed_type == RecordRotatorType::FhtKac) {
+    const size_t flip_len = 4 * parsed_dim / FhtKacRotatorImpl::kByteLen;
+    // FhtKacRotatorImpl::dump_bytes() reports flip.size() bytes, so the blob
+    // is expected to hold at least flip_len bytes.
+    if (flip_len > payload_avail) {
+      LOG_ERROR(
+          "RecordRotator::open: payload too small for FhtKac dim=%zu "
+          "(need %zu bytes, have %zu)",
+          parsed_dim, flip_len, payload_avail);
+      return IndexError_InvalidFormat;
+    }
+    impl_->fht_impl = std::make_unique<FhtKacRotatorImpl>();
+    // Impl::do_rotate() prefers fht_impl; never keep both implementations.
+    impl_->mat_impl.reset();
+    impl_->fht_impl->flip.resize(flip_len);
+    impl_->fht_impl->trunc_dim = floor_pow2(parsed_dim);
+    impl_->fht_impl->fac =
+        1.0f / std::sqrt(static_cast<float>(impl_->fht_impl->trunc_dim));
+    impl_->fht_impl->load(raw + header_size);
+  } else {
+    // Need parsed_dim * parsed_dim floats; the division form cannot overflow.
+    if (parsed_dim > payload_avail / sizeof(float) / parsed_dim) {
+      LOG_ERROR(
+          "RecordRotator::open: payload too small for Matrix dim=%zu "
+          "(have %zu bytes)",
+          parsed_dim, payload_avail);
+      return IndexError_InvalidFormat;
+    }
+    impl_->mat_impl = std::make_unique<MatrixRotatorImpl>();
+    // A stale FHT rotator would shadow the matrix in Impl::do_rotate().
+    impl_->fht_impl.reset();
+    impl_->mat_impl->matrix.resize(parsed_dim * parsed_dim);
+    impl_->mat_impl->load(raw + header_size);
+  }
+  impl_->type = parsed_type;
+  impl_->dimension = parsed_dim;
+
+  LOG_DEBUG("RecordRotator::open done: seg=%s, dim=%zu, data_size=%zu",
+            seg_id.c_str(), impl_->dimension, data_size);
+
+  return 0;
+}
+
+//! Load a user-specified rotation matrix (row-major, dimension x dimension).
+//! Always switches the rotator to the Matrix implementation.
+//! @return 0 on success, IndexError_InvalidArgument for a null matrix or a
+//!         zero dimension.
+int RecordRotator::load(const float *matrix, size_t dimension) {
+  if (!matrix) {
+    LOG_ERROR("RecordRotator::load: null matrix");
+    return IndexError_InvalidArgument;
+  }
+  if (dimension == 0) {
+    LOG_ERROR("RecordRotator::load: invalid dim %zu", dimension);
+    return IndexError_InvalidArgument;
+  }
+
+  impl_->dimension = dimension;
+  impl_->type = RecordRotatorType::Matrix;
+  // Impl::do_rotate() prefers fht_impl whenever it is set; release a rotator
+  // left by an earlier init()/open() so the supplied matrix is actually used.
+  impl_->fht_impl.reset();
+  impl_->mat_impl = std::make_unique<MatrixRotatorImpl>();
+  impl_->mat_impl->matrix.resize(dimension * dimension);
+  impl_->mat_impl->load(reinterpret_cast<const char *>(matrix));
+
+  LOG_DEBUG("RecordRotator::load done: dim=%zu", dimension);
+
+  return 0;
+}
+
+//! Vector dimension the rotator was initialized, opened or loaded with.
+size_t RecordRotator::dimension() const {
+  return impl_->dimension;
+}
+
+//! Rotation algorithm currently selected.
+RecordRotatorType RecordRotator::rotator_type() const {
+  return impl_->type;
+}
+
+//! True once either implementation (FHT or Matrix) has been created.
+bool RecordRotator::initialized() const {
+  return impl_->fht_impl != nullptr || impl_->mat_impl != nullptr;
+}
+
+}  // namespace core
+}  // namespace zvec
diff --git a/src/core/quantizer/rotator/rotator.h b/src/core/quantizer/rotator/rotator.h
new file mode 100644
index 000000000..038f87e78
--- /dev/null
+++ b/src/core/quantizer/rotator/rotator.h
@@ -0,0 +1,132 @@
+// Copyright 2025-present the zvec project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once + +#include +#include +#include +#include +#include "zvec/core/framework/index_dumper.h" +#include "zvec/core/framework/index_storage.h" + +namespace zvec { +namespace core { + +//! Segment ID used when dumping/loading the rotator data +inline const std::string RECORD_ROTATOR_SEG_ID{"enable_rotate"}; + +//! Rotator type exposed without rabitqlib dependency +enum class RecordRotatorType : uint8_t { + FhtKac = 0, //!< O(d log d) FHT-based Kac random rotation (default) + Matrix = 1, //!< O(d^2) explicit random matrix rotation +}; + +/*! RecordRotator provides per-vector rotation without external dependencies. + * + * All rotation algorithms are implemented inline (FHT-based Kac walk and + * explicit random matrix), so no rabitqlib headers are required. + * + * Auto-selects the rotation algorithm based on dimension alignment: + * - dimension % 4 == 0 -> FhtKac (O(d log d), with scalar tails) + * - otherwise -> Matrix (O(d^2), no alignment requirement) + * + * Rotation preserves dimension: output size == input size (no padding). + * + * Used by IntegerStreamingConverter/Reformer and CosineConverter/Reformer + * when enable_rotate is true. + */ +class RecordRotator { + public: + RecordRotator(); + ~RecordRotator(); + + //! Move-only (pimpl with unique_ptr) + RecordRotator(RecordRotator &&) noexcept; + RecordRotator &operator=(RecordRotator &&) noexcept; + RecordRotator(const RecordRotator &) = delete; + RecordRotator &operator=(const RecordRotator &) = delete; + + //! Initialize the rotator. + //! Auto-selects FhtKac when dimension is 4-aligned, else falls back to + //! Matrix. The @p rotator_type parameter can force Matrix explicitly. + //! @param dimension vector dimension (input and output size) + //! @param rotator_type rotation algorithm (default: FhtKac, auto-degrades + //! to Matrix when dimension is not 4-aligned) + void init(size_t dimension, + RecordRotatorType rotator_type = RecordRotatorType::FhtKac); + + //! Rotate a single vector + //! 
@param in input vector of size >= dimension + //! @param out output buffer of size >= dimension + void rotate(const float *in, float *out) const; + + //! Rotate a single vector into a managed buffer + //! @param in input vector of size >= dimension + //! @return vector of size dimension containing rotated result + std::vector rotate(const float *in) const; + + //! Inverse-rotate a single vector (from rotated space back to original) + //! @param in input vector of size >= dimension (rotated vector) + //! @param out output buffer of size >= dimension (original space) + void unrotate(const float *in, float *out) const; + + //! Inverse-rotate a single vector into a managed buffer + //! @param in input vector of size >= dimension (rotated vector) + //! @return vector of size dimension containing inverse-rotated + //! result + std::vector unrotate(const float *in) const; + + //! Return the serialized size of the rotator in bytes (header + blob) + size_t dump_bytes() const; + + //! Dump the rotator to an IndexStorage as a named segment. + //! Same self-describing format as the dumper variant. + int dump(const IndexStorage::Pointer &storage, + const std::string &seg_id = RECORD_ROTATOR_SEG_ID) const; + + //! Dump the rotator to an IndexDumper as a named segment. + //! Format: [RotatorSerHeader (24B): magic|version|rotator_type|in_dim| + //! out_dim|payload_size|reserved] [payload blob] + //! Appends padding for 32-byte alignment. + int dump(const IndexDumper::Pointer &dumper, + const std::string &seg_id = RECORD_ROTATOR_SEG_ID) const; + + //! Open the rotator from an IndexStorage segment (self-describing, no init + //! needed). Parses header to get type/dimension, then reconstructs the + //! rotator. + int open(IndexStorage::Pointer storage, + const std::string &seg_id = RECORD_ROTATOR_SEG_ID); + + //! Load a user-specified rotation matrix. + //! Always uses MatrixRotator internally. + //! @param matrix row-major square matrix of shape dimension x dimension + //! 
@param dimension vector dimension + int load(const float *matrix, size_t dimension); + + //! Return the vector dimension + size_t dimension() const; + + //! Return the rotator type + RecordRotatorType rotator_type() const; + + //! Check if the rotator is initialized + bool initialized() const; + + private: + struct Impl; + std::unique_ptr impl_; +}; + +} // namespace core +} // namespace zvec diff --git a/src/db/index/column/vector_column/engine_helper.hpp b/src/db/index/column/vector_column/engine_helper.hpp index f7b727b7d..7fa0c16d3 100644 --- a/src/db/index/column/vector_column/engine_helper.hpp +++ b/src/db/index/column/vector_column/engine_helper.hpp @@ -352,6 +352,7 @@ class ProximaEngineHelper { return tl::make_unexpected( Status::InvalidArgument("unsupported quantize type")); } + index_param_builder->WithEnableRotate(db_index_params->enable_rotate()); return index_param_builder; } diff --git a/src/db/index/common/proto_converter.cc b/src/db/index/common/proto_converter.cc index faf0cf0e3..80b4c61ca 100644 --- a/src/db/index/common/proto_converter.cc +++ b/src/db/index/common/proto_converter.cc @@ -18,11 +18,12 @@ namespace zvec { HnswIndexParams::OPtr ProtoConverter::FromPb( const proto::HnswIndexParams ¶ms_pb) { + bool enable_rotate = params_pb.base().quantizer_param().enable_rotate(); auto params = std::make_shared( MetricTypeCodeBook::Get(params_pb.base().metric_type()), params_pb.m(), params_pb.ef_construction(), QuantizeTypeCodeBook::Get(params_pb.base().quantize_type()), - params_pb.use_contiguous_memory()); + params_pb.use_contiguous_memory(), QuantizerParam(enable_rotate)); return params; } @@ -33,6 +34,8 @@ proto::HnswIndexParams ProtoConverter::ToPb(const HnswIndexParams *params) { MetricTypeCodeBook::Get(params->metric_type())); params_pb.mutable_base()->set_quantize_type( QuantizeTypeCodeBook::Get(params->quantize_type())); + params_pb.mutable_base()->mutable_quantizer_param()->set_enable_rotate( + params->quantizer_param().enable_rotate()); 
params_pb.set_ef_construction(params->ef_construction()); params_pb.set_m(params->m()); params_pb.set_use_contiguous_memory(params->use_contiguous_memory()); @@ -68,9 +71,11 @@ proto::HnswRabitqIndexParams ProtoConverter::ToPb( // FlatIndexParams FlatIndexParams::OPtr ProtoConverter::FromPb( const proto::FlatIndexParams ¶ms_pb) { + bool enable_rotate = params_pb.base().quantizer_param().enable_rotate(); return std::make_shared( MetricTypeCodeBook::Get(params_pb.base().metric_type()), - QuantizeTypeCodeBook::Get(params_pb.base().quantize_type())); + QuantizeTypeCodeBook::Get(params_pb.base().quantize_type()), + QuantizerParam(enable_rotate)); } proto::FlatIndexParams ProtoConverter::ToPb(const FlatIndexParams *params) { @@ -79,16 +84,20 @@ proto::FlatIndexParams ProtoConverter::ToPb(const FlatIndexParams *params) { MetricTypeCodeBook::Get(params->metric_type())); params_pb.mutable_base()->set_quantize_type( QuantizeTypeCodeBook::Get(params->quantize_type())); + params_pb.mutable_base()->mutable_quantizer_param()->set_enable_rotate( + params->quantizer_param().enable_rotate()); return params_pb; } // IVFIndexParams IVFIndexParams::OPtr ProtoConverter::FromPb( const proto::IVFIndexParams ¶ms_pb) { + bool enable_rotate = params_pb.base().quantizer_param().enable_rotate(); return std::make_shared( MetricTypeCodeBook::Get(params_pb.base().metric_type()), params_pb.n_list(), params_pb.n_iters(), params_pb.use_soar(), - QuantizeTypeCodeBook::Get(params_pb.base().quantize_type())); + QuantizeTypeCodeBook::Get(params_pb.base().quantize_type()), + QuantizerParam(enable_rotate)); } proto::IVFIndexParams ProtoConverter::ToPb(const IVFIndexParams *params) { @@ -97,6 +106,8 @@ proto::IVFIndexParams ProtoConverter::ToPb(const IVFIndexParams *params) { MetricTypeCodeBook::Get(params->metric_type())); params_pb.mutable_base()->set_quantize_type( QuantizeTypeCodeBook::Get(params->quantize_type())); + params_pb.mutable_base()->mutable_quantizer_param()->set_enable_rotate( + 
params->quantizer_param().enable_rotate()); params_pb.set_n_list(params->n_list()); params_pb.set_n_iters(params->n_iters()); params_pb.set_use_soar(params->use_soar()); @@ -106,12 +117,14 @@ proto::IVFIndexParams ProtoConverter::ToPb(const IVFIndexParams *params) { // VamanaIndexParams VamanaIndexParams::OPtr ProtoConverter::FromPb( const proto::VamanaIndexParams ¶ms_pb) { + bool enable_rotate = params_pb.base().quantizer_param().enable_rotate(); return std::make_shared( MetricTypeCodeBook::Get(params_pb.base().metric_type()), params_pb.max_degree(), params_pb.search_list_size(), params_pb.alpha(), params_pb.saturate_graph(), params_pb.use_contiguous_memory(), params_pb.use_id_map(), - QuantizeTypeCodeBook::Get(params_pb.base().quantize_type())); + QuantizeTypeCodeBook::Get(params_pb.base().quantize_type()), + QuantizerParam(enable_rotate)); } proto::VamanaIndexParams ProtoConverter::ToPb(const VamanaIndexParams *params) { @@ -120,6 +133,8 @@ proto::VamanaIndexParams ProtoConverter::ToPb(const VamanaIndexParams *params) { MetricTypeCodeBook::Get(params->metric_type())); params_pb.mutable_base()->set_quantize_type( QuantizeTypeCodeBook::Get(params->quantize_type())); + params_pb.mutable_base()->mutable_quantizer_param()->set_enable_rotate( + params->quantizer_param().enable_rotate()); params_pb.set_max_degree(params->max_degree()); params_pb.set_search_list_size(params->search_list_size()); params_pb.set_alpha(params->alpha()); @@ -147,10 +162,12 @@ proto::InvertIndexParams ProtoConverter::ToPb(const InvertIndexParams *params) { // DiskAnnIndexParams DiskAnnIndexParams::OPtr ProtoConverter::FromPb( const proto::DiskAnnIndexParams ¶ms_pb) { + bool enable_rotate = params_pb.base().quantizer_param().enable_rotate(); return std::make_shared( MetricTypeCodeBook::Get(params_pb.base().metric_type()), params_pb.max_degree(), params_pb.list_size(), params_pb.pq_chunk_num(), - QuantizeTypeCodeBook::Get(params_pb.base().quantize_type())); + 
QuantizeTypeCodeBook::Get(params_pb.base().quantize_type()), + QuantizerParam(enable_rotate)); } proto::DiskAnnIndexParams ProtoConverter::ToPb( @@ -160,6 +177,8 @@ proto::DiskAnnIndexParams ProtoConverter::ToPb( MetricTypeCodeBook::Get(params->metric_type())); params_pb.mutable_base()->set_quantize_type( QuantizeTypeCodeBook::Get(params->quantize_type())); + params_pb.mutable_base()->mutable_quantizer_param()->set_enable_rotate( + params->quantizer_param().enable_rotate()); params_pb.set_max_degree(params->max_degree()); params_pb.set_list_size(params->list_size()); params_pb.set_pq_chunk_num(params->pq_chunk_num()); diff --git a/src/db/index/segment/segment.cc b/src/db/index/segment/segment.cc index ca5d3adb3..e21ececa2 100644 --- a/src/db/index/segment/segment.cc +++ b/src/db/index/segment/segment.cc @@ -4036,7 +4036,8 @@ Status SegmentImpl::load_vector_index_blocks() { if (!segment_meta_->vector_indexed(column)) { new_field_params.set_index_params(MakeDefaultQuantVectorIndexParams( vector_index_params->metric_type(), - vector_index_params->quantize_type())); + vector_index_params->quantize_type(), + vector_index_params->quantizer_param())); } } @@ -4171,7 +4172,8 @@ Status SegmentImpl::init_memory_components() { block_id = allocate_block_id(); FieldSchema normal_quant_field(*field); normal_quant_field.set_index_params(MakeDefaultQuantVectorIndexParams( - index_params->metric_type(), index_params->quantize_type())); + index_params->metric_type(), index_params->quantize_type(), + index_params->quantizer_param())); auto quant_vector_indexer = create_vector_indexer( field->name(), normal_quant_field, block_id, true); diff --git a/src/db/proto/zvec.proto b/src/db/proto/zvec.proto index ad6cfb158..f2c18f5ad 100644 --- a/src/db/proto/zvec.proto +++ b/src/db/proto/zvec.proto @@ -87,9 +87,19 @@ message InvertIndexParams { bool enable_range_optimization = 1; }; +// Quantizer-related parameters for vector indexes. +// Designed for future extensibility. 
+message QuantizerParam { + // When enabled, vectors are rotated before INT8 quantization to reduce + // quantization error. Only effective with quantize_type=INT8. + bool enable_rotate = 1; +}; + message BaseIndexParams { MetricType metric_type = 1; QuantizeType quantize_type = 2; + // Quantizer parameters (enable_rotate, etc.) + QuantizerParam quantizer_param = 4; }; message HnswIndexParams { diff --git a/src/include/zvec/c_api.h b/src/include/zvec/c_api.h index d02335cb3..3f3e38638 100644 --- a/src/include/zvec/c_api.h +++ b/src/include/zvec/c_api.h @@ -961,6 +961,30 @@ ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_index_params_set_quantize_type( ZVEC_EXPORT zvec_quantize_type_t ZVEC_CALL zvec_index_params_get_quantize_type(const zvec_index_params_t *params); +/** + * @brief Set enable_rotate for quantizer (only effective with INT8 quantize + * type) + * + * When enabled, vectors are randomly rotated before INT8 quantization to + * reduce quantization error. The rotation matrix is stored with the index + * and automatically applied to query vectors at search time. 
+ * + * @param params Index parameters (must be vector index type) + * @param enable_rotate Whether to enable random rotation before quantization + * @return ZVEC_OK on success, error code on failure + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_index_params_set_quantizer_enable_rotate(zvec_index_params_t *params, + bool enable_rotate); + +/** + * @brief Get enable_rotate setting from quantizer parameters + * @param params Index parameters (must not be NULL) + * @return true if rotation is enabled, false otherwise (default) + */ +ZVEC_EXPORT bool ZVEC_CALL zvec_index_params_get_quantizer_enable_rotate( + const zvec_index_params_t *params); + /** * @brief Set HNSW specific parameters * @param params Index parameters (must be HNSW type) diff --git a/src/include/zvec/core/framework/index_converter.h b/src/include/zvec/core/framework/index_converter.h index 53ac1c7a2..4dc26468f 100644 --- a/src/include/zvec/core/framework/index_converter.h +++ b/src/include/zvec/core/framework/index_converter.h @@ -18,6 +18,7 @@ #include #include #include +#include #include "zvec/core/framework/index_reformer.h" namespace zvec { @@ -196,6 +197,13 @@ class IndexConverter : public IndexModule { //! Dump index into storage virtual int dump(const IndexDumper::Pointer &dumper) = 0; + //! Dump converter state (e.g. rotator) to IndexStorage for streaming build. + //! Default is no-op; override in subclasses that need storage persistence. + virtual int dump_to_storage(const IndexStorage::Pointer &storage) { + (void)storage; + return 0; + } + //! 
Retrieve statistics virtual const Stats &stats(void) const = 0; diff --git a/src/include/zvec/core/interface/index_param.h b/src/include/zvec/core/interface/index_param.h index 186c160f6..5d4e8a206 100644 --- a/src/include/zvec/core/interface/index_param.h +++ b/src/include/zvec/core/interface/index_param.h @@ -122,12 +122,17 @@ struct QuantizerParam : public SerializableBase { QuantizerType type = QuantizerType::kNone; int num_subquantizers = 8; // M int num_bits = 8; // bits per subquantizer + bool enable_rotate = + false; // rotate vectors before quantization to reduce error // Constructors // QuantizerParam() = default; QuantizerParam(QuantizerType t = QuantizerType::kNone, int subquantizers = 8, - int bits = 8) - : type(t), num_subquantizers(subquantizers), num_bits(bits) {} + int bits = 8, bool rotate = false) + : type(t), + num_subquantizers(subquantizers), + num_bits(bits), + enable_rotate(rotate) {} protected: diff --git a/src/include/zvec/core/interface/index_param_builders.h b/src/include/zvec/core/interface/index_param_builders.h index 328e60b11..d88057a93 100644 --- a/src/include/zvec/core/interface/index_param_builders.h +++ b/src/include/zvec/core/interface/index_param_builders.h @@ -87,6 +87,11 @@ class BaseIndexParamBuilder { // : public return static_cast(*this); } + ActualIndexParamBuilderType &WithEnableRotate(bool enable_rotate) { + param->quantizer_param.enable_rotate = enable_rotate; + return static_cast(*this); + } + ActualIndexParamBuilderType &WithUseExternalVector(bool use_external_vector) { param->use_external_vector = use_external_vector; return static_cast(*this); diff --git a/src/include/zvec/db/index_params.h b/src/include/zvec/db/index_params.h index c19cf8028..a4c2654d9 100644 --- a/src/include/zvec/db/index_params.h +++ b/src/include/zvec/db/index_params.h @@ -118,16 +118,50 @@ class InvertIndexParams : public IndexParams { bool enable_extended_wildcard_{false}; }; +/* + * Quantizer parameters for vector indexes. 
+ * Encapsulates quantization-related settings such as enable_rotate. + * Designed for future extensibility (e.g., num_bits, calibration_size). + */ +class QuantizerParam { + public: + QuantizerParam() = default; + explicit QuantizerParam(bool enable_rotate) : enable_rotate_(enable_rotate) {} + + bool enable_rotate() const { + return enable_rotate_; + } + + void set_enable_rotate(bool v) { + enable_rotate_ = v; + } + + bool operator==(const QuantizerParam &other) const { + return enable_rotate_ == other.enable_rotate_; + } + + bool operator!=(const QuantizerParam &other) const { + return !(*this == other); + } + + private: + // When enabled, vectors are rotated before INT8 quantization to reduce + // quantization error. Only effective with quantize_type=INT8. + bool enable_rotate_{false}; +}; + /* * Column index params */ class VectorIndexParams : public IndexParams { public: VectorIndexParams(IndexType type, MetricType metric_type, - QuantizeType quantize_type = QuantizeType::UNDEFINED) + QuantizeType quantize_type = QuantizeType::UNDEFINED, + QuantizerParam quantizer_param = {}) : IndexParams(type), metric_type_(metric_type), - quantize_type_(quantize_type) {} + quantize_type_(quantize_type), + quantizer_param_(quantizer_param) {} ~VectorIndexParams() override = default; @@ -151,9 +185,23 @@ class VectorIndexParams : public IndexParams { quantize_type_ = quantize_type; } + const QuantizerParam &quantizer_param() const { + return quantizer_param_; + } + + void set_quantizer_param(const QuantizerParam &quantizer_param) { + quantizer_param_ = quantizer_param; + } + + // Convenience getter for internal use (engine_helper, segment, etc.) 
+ bool enable_rotate() const { + return quantizer_param_.enable_rotate(); + } + protected: MetricType metric_type_; QuantizeType quantize_type_; + QuantizerParam quantizer_param_; }; /* @@ -165,8 +213,9 @@ class HnswIndexParams : public VectorIndexParams { MetricType metric_type, int m = core_interface::kDefaultHnswNeighborCnt, int ef_construction = core_interface::kDefaultHnswEfConstruction, QuantizeType quantize_type = QuantizeType::UNDEFINED, - bool use_contiguous_memory = false) - : VectorIndexParams(IndexType::HNSW, metric_type, quantize_type), + bool use_contiguous_memory = false, QuantizerParam quantizer_param = {}) + : VectorIndexParams(IndexType::HNSW, metric_type, quantize_type, + quantizer_param), m_(m), ef_construction_(ef_construction), use_contiguous_memory_(use_contiguous_memory) {} @@ -175,9 +224,9 @@ class HnswIndexParams : public VectorIndexParams { public: Ptr clone() const override { - return std::make_shared(metric_type_, m_, ef_construction_, - quantize_type_, - use_contiguous_memory_); + return std::make_shared( + metric_type_, m_, ef_construction_, quantize_type_, + use_contiguous_memory_, quantizer_param_); } std::string to_string() const override { @@ -186,7 +235,8 @@ class HnswIndexParams : public VectorIndexParams { std::ostringstream oss; oss << base_str << ",m:" << m_ << ",ef_construction:" << ef_construction_ << ",use_contiguous_memory:" - << (use_contiguous_memory_ ? "true" : "false") << "}"; + << (use_contiguous_memory_ ? "true" : "false") << ",enable_rotate:" + << (quantizer_param_.enable_rotate() ? 
"true" : "false") << "}"; return oss.str(); } @@ -200,7 +250,9 @@ class HnswIndexParams : public VectorIndexParams { quantize_type() == static_cast(other).quantize_type() && use_contiguous_memory_ == static_cast(other) - .use_contiguous_memory_; + .use_contiguous_memory_ && + quantizer_param_ == + static_cast(other).quantizer_param_; } void set_m(int m) { @@ -348,21 +400,25 @@ class HnswRabitqIndexParams : public VectorIndexParams { class FlatIndexParams : public VectorIndexParams { public: FlatIndexParams(MetricType metric_type, - QuantizeType quantize_type = QuantizeType::UNDEFINED) - : VectorIndexParams(IndexType::FLAT, metric_type, quantize_type) {} + QuantizeType quantize_type = QuantizeType::UNDEFINED, + QuantizerParam quantizer_param = {}) + : VectorIndexParams(IndexType::FLAT, metric_type, quantize_type, + quantizer_param) {} using OPtr = std::shared_ptr; public: Ptr clone() const override { - return std::make_shared(metric_type_, quantize_type_); + return std::make_shared(metric_type_, quantize_type_, + quantizer_param_); } std::string to_string() const override { auto base_str = vector_index_params_to_string("FlatIndexParams", metric_type_, quantize_type_); std::ostringstream oss; - oss << base_str << "}"; + oss << base_str << ",enable_rotate:" + << (quantizer_param_.enable_rotate() ? 
"true" : "false") << "}"; return oss.str(); } @@ -371,7 +427,9 @@ class FlatIndexParams : public VectorIndexParams { metric_type() == static_cast(other).metric_type() && quantize_type() == - static_cast(other).quantize_type(); + static_cast(other).quantize_type() && + quantizer_param() == + static_cast(other).quantizer_param(); } }; @@ -383,16 +441,19 @@ inline FlatIndexParams MakeDefaultVectorIndexParams(MetricType metric_type) { } inline FlatIndexParams MakeDefaultQuantVectorIndexParams( - MetricType metric_type, QuantizeType quantize_type) { - return FlatIndexParams(metric_type, quantize_type); + MetricType metric_type, QuantizeType quantize_type, + QuantizerParam quantizer_param = {}) { + return FlatIndexParams(metric_type, quantize_type, quantizer_param); } class IVFIndexParams : public VectorIndexParams { public: IVFIndexParams(MetricType metric_type, int n_list = 1024, int n_iters = 10, bool use_soar = false, - QuantizeType quantize_type = QuantizeType::UNDEFINED) - : VectorIndexParams(IndexType::IVF, metric_type, quantize_type), + QuantizeType quantize_type = QuantizeType::UNDEFINED, + QuantizerParam quantizer_param = {}) + : VectorIndexParams(IndexType::IVF, metric_type, quantize_type, + quantizer_param), n_list_(n_list), n_iters_(n_iters), use_soar_(use_soar) {} @@ -402,14 +463,17 @@ class IVFIndexParams : public VectorIndexParams { public: Ptr clone() const override { return std::make_shared(metric_type_, n_list_, n_iters_, - use_soar_, quantize_type_); + use_soar_, quantize_type_, + quantizer_param_); } std::string to_string() const override { auto base_str = vector_index_params_to_string("IVFIndexParams", metric_type_, quantize_type_); std::ostringstream oss; - oss << base_str << ",n_list:" << n_list_ << ",n_iters:" << n_iters_ << "}"; + oss << base_str << ",n_list:" << n_list_ << ",n_iters:" << n_iters_ + << ",enable_rotate:" + << (quantizer_param_.enable_rotate() ? 
"true" : "false") << "}"; return oss.str(); } @@ -445,7 +509,9 @@ class IVFIndexParams : public VectorIndexParams { n_iters_ == static_cast(other).n_iters_ && use_soar_ == static_cast(other).use_soar_ && quantize_type() == - static_cast(other).quantize_type(); + static_cast(other).quantize_type() && + quantizer_param_ == + static_cast(other).quantizer_param_; } private: @@ -458,8 +524,10 @@ class DiskAnnIndexParams : public VectorIndexParams { public: DiskAnnIndexParams(MetricType metric_type, int max_degree = 100, int list_size = 50, int pq_chunk_num = 0, - QuantizeType quantize_type = QuantizeType::UNDEFINED) - : VectorIndexParams(IndexType::DISKANN, metric_type, quantize_type), + QuantizeType quantize_type = QuantizeType::UNDEFINED, + QuantizerParam quantizer_param = {}) + : VectorIndexParams(IndexType::DISKANN, metric_type, quantize_type, + quantizer_param), max_degree_{max_degree}, list_size_{list_size}, pq_chunk_num_{pq_chunk_num} {} @@ -469,7 +537,8 @@ class DiskAnnIndexParams : public VectorIndexParams { public: Ptr clone() const override { return std::make_shared( - metric_type_, max_degree_, list_size_, pq_chunk_num_, quantize_type_); + metric_type_, max_degree_, list_size_, pq_chunk_num_, quantize_type_, + quantizer_param_); } std::string to_string() const override { @@ -478,7 +547,8 @@ class DiskAnnIndexParams : public VectorIndexParams { std::ostringstream oss; oss << base_str << ",max_degree:" << max_degree_ << ",list_size:" << list_size_ << ", pq_chunk_num:" << pq_chunk_num_ - << "}"; + << ",enable_rotate:" + << (quantizer_param_.enable_rotate() ? 
"true" : "false") << "}"; return oss.str(); } @@ -517,7 +587,9 @@ class DiskAnnIndexParams : public VectorIndexParams { pq_chunk_num_ == static_cast(other).pq_chunk_num_ && quantize_type() == - static_cast(other).quantize_type(); + static_cast(other).quantize_type() && + quantizer_param_ == + static_cast(other).quantizer_param_; } private: @@ -538,8 +610,10 @@ class VamanaIndexParams : public VectorIndexParams { float alpha = core_interface::kDefaultVamanaAlpha, bool saturate_graph = core_interface::kDefaultVamanaSaturateGraph, bool use_contiguous_memory = false, bool use_id_map = false, - QuantizeType quantize_type = QuantizeType::UNDEFINED) - : VectorIndexParams(IndexType::VAMANA, metric_type, quantize_type), + QuantizeType quantize_type = QuantizeType::UNDEFINED, + QuantizerParam quantizer_param = {}) + : VectorIndexParams(IndexType::VAMANA, metric_type, quantize_type, + quantizer_param), max_degree_(max_degree), search_list_size_(search_list_size), alpha_(alpha), @@ -553,7 +627,7 @@ class VamanaIndexParams : public VectorIndexParams { Ptr clone() const override { return std::make_shared( metric_type_, max_degree_, search_list_size_, alpha_, saturate_graph_, - use_contiguous_memory_, use_id_map_, quantize_type_); + use_contiguous_memory_, use_id_map_, quantize_type_, quantizer_param_); } std::string to_string() const override { @@ -565,7 +639,9 @@ class VamanaIndexParams : public VectorIndexParams { << ",saturate_graph:" << (saturate_graph_ ? "true" : "false") << ",use_contiguous_memory:" << (use_contiguous_memory_ ? "true" : "false") - << ",use_id_map:" << (use_id_map_ ? "true" : "false") << "}"; + << ",use_id_map:" << (use_id_map_ ? "true" : "false") + << ",enable_rotate:" + << (quantizer_param_.enable_rotate() ? 
"true" : "false") << "}"; return oss.str(); } @@ -580,7 +656,8 @@ class VamanaIndexParams : public VectorIndexParams { search_list_size_ == rhs.search_list_size_ && alpha_ == rhs.alpha_ && saturate_graph_ == rhs.saturate_graph_ && use_contiguous_memory_ == rhs.use_contiguous_memory_ && - use_id_map_ == rhs.use_id_map_; + use_id_map_ == rhs.use_id_map_ && + quantizer_param_ == rhs.quantizer_param_; } int max_degree() const { diff --git a/tests/c/c_api_test.c b/tests/c/c_api_test.c index 8670ff845..b19dfed24 100644 --- a/tests/c/c_api_test.c +++ b/tests/c/c_api_test.c @@ -3491,6 +3491,179 @@ void test_index_params_functions(void) { TEST_END(); } +void test_quantizer_enable_rotate(void) { + TEST_START(); + + // Test 1: set enable_rotate=true on HNSW params and verify + zvec_index_params_t *hnsw_params = + zvec_index_params_create(ZVEC_INDEX_TYPE_HNSW); + TEST_ASSERT(hnsw_params != NULL); + + // Default should be false + TEST_ASSERT(zvec_index_params_get_quantizer_enable_rotate(hnsw_params) == + false); + + // Set to true and verify + zvec_error_code_t err = + zvec_index_params_set_quantizer_enable_rotate(hnsw_params, true); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(zvec_index_params_get_quantizer_enable_rotate(hnsw_params) == + true); + + // Set back to false and verify + err = zvec_index_params_set_quantizer_enable_rotate(hnsw_params, false); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(zvec_index_params_get_quantizer_enable_rotate(hnsw_params) == + false); + + zvec_index_params_destroy(hnsw_params); + + // Test 2: set enable_rotate on FLAT index params (also a vector index) + zvec_index_params_t *flat_params = + zvec_index_params_create(ZVEC_INDEX_TYPE_FLAT); + TEST_ASSERT(flat_params != NULL); + err = zvec_index_params_set_quantizer_enable_rotate(flat_params, true); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(zvec_index_params_get_quantizer_enable_rotate(flat_params) == + true); + zvec_index_params_destroy(flat_params); + + // Test 3: set enable_rotate on 
non-vector index (INVERT) should fail + zvec_index_params_t *invert_params = + zvec_index_params_create(ZVEC_INDEX_TYPE_INVERT); + TEST_ASSERT(invert_params != NULL); + err = zvec_index_params_set_quantizer_enable_rotate(invert_params, true); + TEST_ASSERT(err != ZVEC_OK); + zvec_index_params_destroy(invert_params); + + // Test 4: NULL params should return false for getter + TEST_ASSERT(zvec_index_params_get_quantizer_enable_rotate(NULL) == false); + + // Test 5: NULL params should return error for setter + err = zvec_index_params_set_quantizer_enable_rotate(NULL, true); + TEST_ASSERT(err != ZVEC_OK); + + TEST_END(); +} + +void test_int8_rotate_e2e(void) { + TEST_START(); + + char temp_dir[] = "./zvec_test_int8_rotate_e2e"; + const size_t dim = 128; + const size_t cnt = 2000; + const size_t topk = 10; + + // Create schema with HNSW + INT8 + enable_rotate + zvec_collection_schema_t *schema = + zvec_collection_schema_create("int8_rotate_test"); + TEST_ASSERT(schema != NULL); + + // Add ID field + zvec_field_schema_t *id_field = + zvec_field_schema_create("id", ZVEC_DATA_TYPE_INT64, false, 0); + zvec_collection_schema_add_field(schema, id_field); + + // Add vector field with HNSW + INT8 + rotate + zvec_index_params_t *hnsw_params = + zvec_index_params_create(ZVEC_INDEX_TYPE_HNSW); + TEST_ASSERT(hnsw_params != NULL); + zvec_index_params_set_metric_type(hnsw_params, ZVEC_METRIC_TYPE_L2); + zvec_index_params_set_hnsw_params(hnsw_params, 16, 100); + zvec_index_params_set_quantize_type(hnsw_params, ZVEC_QUANTIZE_TYPE_INT8); + zvec_index_params_set_quantizer_enable_rotate(hnsw_params, true); + + zvec_field_schema_t *vec_field = zvec_field_schema_create( + "embedding", ZVEC_DATA_TYPE_VECTOR_FP32, false, dim); + zvec_field_schema_set_index_params(vec_field, hnsw_params); + zvec_collection_schema_add_field(schema, vec_field); + zvec_index_params_destroy(hnsw_params); + + // Create and open collection + zvec_collection_t *collection = NULL; + zvec_error_code_t err = + 
zvec_collection_create_and_open(temp_dir, schema, NULL, &collection); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(collection != NULL); + + // Insert 2000 random vectors + srand(42); + for (size_t i = 0; i < cnt; i++) { + float *vec = (float *)malloc(dim * sizeof(float)); + TEST_ASSERT(vec != NULL); + for (size_t j = 0; j < dim; j++) { + vec[j] = (float)rand() / (float)RAND_MAX * 2.0f - 1.0f; + } + + zvec_doc_t *doc = zvec_doc_create(); + zvec_doc_set_pk(doc, zvec_test_make_pk(i + 1)); + zvec_doc_add_field_by_value(doc, "id", ZVEC_DATA_TYPE_INT64, + &(int64_t){(int64_t)(i + 1)}, sizeof(int64_t)); + zvec_doc_add_field_by_value(doc, "embedding", ZVEC_DATA_TYPE_VECTOR_FP32, + vec, dim * sizeof(float)); + + size_t success_count, error_count; + const zvec_doc_t *docs[] = {doc}; + err = zvec_collection_insert(collection, docs, 1, &success_count, + &error_count); + TEST_ASSERT(err == ZVEC_OK); + zvec_doc_destroy(doc); + free(vec); + } + + // Flush to build index + zvec_collection_flush(collection); + + // Search + float *query = (float *)malloc(dim * sizeof(float)); + TEST_ASSERT(query != NULL); + for (size_t j = 0; j < dim; j++) { + query[j] = (float)rand() / (float)RAND_MAX * 2.0f - 1.0f; + } + + zvec_vector_query_t *vq = zvec_vector_query_create(); + TEST_ASSERT(vq != NULL); + zvec_vector_query_set_field_name(vq, "embedding"); + zvec_vector_query_set_query_vector(vq, query, dim * sizeof(float)); + zvec_vector_query_set_topk(vq, topk); + + zvec_doc_t **results = NULL; + size_t result_count = 0; + err = zvec_collection_query(collection, vq, &results, &result_count); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(result_count > 0); + printf(" [int8_rotate_e2e] first search returned %zu results\n", + result_count); + zvec_docs_free(results, result_count); + + // Close and reopen + zvec_collection_close(collection); + collection = NULL; + + err = zvec_collection_open(temp_dir, NULL, &collection); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(collection != NULL); + + // Search 
again after reopen (rotator should auto-load from storage) + results = NULL; + result_count = 0; + err = zvec_collection_query(collection, vq, &results, &result_count); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(result_count > 0); + printf(" [int8_rotate_e2e] reopen search returned %zu results\n", + result_count); + zvec_docs_free(results, result_count); + + // Cleanup + zvec_vector_query_destroy(vq); + zvec_collection_destroy(collection); + zvec_collection_schema_destroy(schema); + free(query); + cleanup_temp_directory(temp_dir); + + TEST_END(); +} + void test_index_params_api_functions(void) { TEST_START(); @@ -5953,6 +6126,8 @@ int main(void) { // Index tests test_index_params(); test_index_params_functions(); + test_quantizer_enable_rotate(); + test_int8_rotate_e2e(); test_index_params_api_functions(); test_index_creation_and_management(); diff --git a/tests/core/algorithm/flat/flat_streamer_test.cc b/tests/core/algorithm/flat/flat_streamer_test.cc index b6e57a7a2..2f614f218 100644 --- a/tests/core/algorithm/flat/flat_streamer_test.cc +++ b/tests/core/algorithm/flat/flat_streamer_test.cc @@ -1209,6 +1209,100 @@ TEST_F(FlatStreamerTest, TestAddAndSearchWithID2) { EXPECT_GT(topk1Recall, 0.80f); } +// Test Flat + INT8 quantization + rotation end-to-end +TEST_F(FlatStreamerTest, TestInt8WithRotate) { + constexpr size_t kTestDim = 128; + constexpr size_t kCnt = 2000U; + constexpr size_t kTopk = 10; + + IndexStreamer::Pointer streamer = + IndexFactory::CreateStreamer("FlatStreamer"); + ASSERT_NE(nullptr, streamer); + + Params params; + + IndexMeta index_meta_raw(IndexMeta::DataType::DT_FP32, kTestDim); + index_meta_raw.set_metric("SquaredEuclidean", 0, Params()); + + // Create INT8 converter with rotation enabled + Params converter_params; + converter_params.set("integer_streaming.converter.enable_rotate", true); + auto converter = IndexFactory::CreateConverter("Int8StreamingConverter"); + ASSERT_NE(nullptr, converter); + ASSERT_EQ(0, 
converter->init(index_meta_raw, converter_params)); + + IndexMeta index_meta = converter->meta(); + + auto reformer = IndexFactory::CreateReformer(index_meta.reformer_name()); + ASSERT_NE(nullptr, reformer); + ASSERT_EQ(0, reformer->init(index_meta.reformer_params())); + + Params stg_params; + auto storage = IndexFactory::CreateStorage("MMapFileStorage"); + ASSERT_NE(nullptr, storage); + ASSERT_EQ(0, storage->init(stg_params)); + ASSERT_EQ(0, storage->open(dir_ + "TestInt8WithRotate.index", true)); + ASSERT_EQ(0, streamer->init(index_meta, params)); + ASSERT_EQ(0, streamer->open(storage)); + + // Add 2000 vectors + auto ctx = streamer->create_context(); + ASSERT_TRUE(!!ctx); + IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, kTestDim); + + std::mt19937 gen(42); + std::uniform_real_distribution dist(-1.0f, 1.0f); + for (size_t i = 0; i < kCnt; i++) { + NumericalVector vec(kTestDim); + for (size_t j = 0; j < kTestDim; ++j) vec[j] = dist(gen); + + std::string new_vec; + IndexQueryMeta new_meta; + ASSERT_EQ(0, reformer->convert(vec.data(), qmeta, &new_vec, &new_meta)); + ASSERT_EQ(0, streamer->add_impl(i, new_vec.data(), new_meta, ctx)); + } + + streamer->flush(0UL); + streamer.reset(); + storage.reset(); + + // Reopen: reformer should auto-detect rotator from storage + auto storage2 = IndexFactory::CreateStorage("MMapFileStorage"); + ASSERT_NE(nullptr, storage2); + ASSERT_EQ(0, storage2->init(stg_params)); + ASSERT_EQ(0, storage2->open(dir_ + "TestInt8WithRotate.index", false)); + + auto streamer2 = IndexFactory::CreateStreamer("FlatStreamer"); + ASSERT_NE(nullptr, streamer2); + ASSERT_EQ(0, streamer2->init(index_meta, params)); + ASSERT_EQ(0, streamer2->open(storage2)); + + auto reformer2 = IndexFactory::CreateReformer(index_meta.reformer_name()); + ASSERT_NE(nullptr, reformer2); + ASSERT_EQ(0, reformer2->init(index_meta.reformer_params())); + ASSERT_EQ(0, reformer2->load(storage2)); + + // Search: verify results are non-empty + auto knnCtx = 
streamer2->create_context(); + knnCtx->set_topk(kTopk); + auto linearCtx = streamer2->create_context(); + linearCtx->set_topk(kTopk); + + NumericalVector query(kTestDim); + for (size_t j = 0; j < kTestDim; ++j) query[j] = dist(gen); + + std::string new_query; + IndexQueryMeta new_qmeta; + ASSERT_EQ(0, + reformer2->transform(query.data(), qmeta, &new_query, &new_qmeta)); + ASSERT_EQ(0, streamer2->search_impl(new_query.data(), new_qmeta, knnCtx)); + ASSERT_EQ(0, + streamer2->search_bf_impl(new_query.data(), new_qmeta, linearCtx)); + + EXPECT_EQ(kTopk, knnCtx->result().size()); + EXPECT_EQ(kTopk, linearCtx->result().size()); +} + #if defined(__GNUC__) || defined(__GNUG__) #pragma GCC diagnostic pop #endif \ No newline at end of file diff --git a/tests/core/algorithm/hnsw/hnsw_streamer_test.cc b/tests/core/algorithm/hnsw/hnsw_streamer_test.cc index 8ef70cb77..bb57a95e9 100644 --- a/tests/core/algorithm/hnsw/hnsw_streamer_test.cc +++ b/tests/core/algorithm/hnsw/hnsw_streamer_test.cc @@ -3779,6 +3779,105 @@ TEST_F(HnswStreamerTest, TestContiguousMultiThreadSearch) { s3.wait(); } +// Test HNSW + INT8 quantization + rotation end-to-end +TEST_F(HnswStreamerTest, TestInt8WithRotate) { + constexpr size_t kTestDim = 128; + constexpr size_t kCnt = 2000U; + constexpr size_t kTopk = 10; + + IndexStreamer::Pointer streamer = + IndexFactory::CreateStreamer("HnswStreamer"); + ASSERT_NE(nullptr, streamer); + + ailego::Params params; + params.set(PARAM_HNSW_STREAMER_MAX_NEIGHBOR_COUNT, 16U); + params.set(PARAM_HNSW_STREAMER_SCALING_FACTOR, 5U); + params.set(PARAM_HNSW_STREAMER_EFCONSTRUCTION, 100); + params.set(PARAM_HNSW_STREAMER_EF, 100); + params.set(PARAM_HNSW_STREAMER_BRUTE_FORCE_THRESHOLD, 1000U); + + IndexMeta index_meta_raw(IndexMeta::DataType::DT_FP32, kTestDim); + index_meta_raw.set_metric("SquaredEuclidean", 0, ailego::Params()); + + // Create INT8 converter with rotation enabled + ailego::Params converter_params; + 
converter_params.set("integer_streaming.converter.enable_rotate", true); + auto converter = IndexFactory::CreateConverter("Int8StreamingConverter"); + ASSERT_NE(nullptr, converter); + ASSERT_EQ(0, converter->init(index_meta_raw, converter_params)); + + IndexMeta index_meta = converter->meta(); + + auto reformer = IndexFactory::CreateReformer(index_meta.reformer_name()); + ASSERT_NE(nullptr, reformer); + ASSERT_EQ(0, reformer->init(index_meta.reformer_params())); + + ailego::Params stg_params; + auto storage = IndexFactory::CreateStorage("MMapFileStorage"); + ASSERT_NE(nullptr, storage); + ASSERT_EQ(0, storage->init(stg_params)); + ASSERT_EQ(0, storage->open(dir_ + "TestInt8WithRotate.index", true)); + ASSERT_EQ(0, streamer->init(index_meta, params)); + ASSERT_EQ(0, streamer->open(storage)); + + // Add 2000 vectors + auto ctx = streamer->create_context(); + ASSERT_TRUE(!!ctx); + IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, kTestDim); + + std::mt19937 gen(42); + std::uniform_real_distribution dist(-1.0f, 1.0f); + for (size_t i = 0; i < kCnt; i++) { + NumericalVector vec(kTestDim); + for (size_t j = 0; j < kTestDim; ++j) vec[j] = dist(gen); + + std::string new_vec; + IndexQueryMeta new_meta; + ASSERT_EQ(0, reformer->convert(vec.data(), qmeta, &new_vec, &new_meta)); + ASSERT_EQ(0, streamer->add_impl(i, new_vec.data(), new_meta, ctx)); + } + + streamer->flush(0UL); + streamer.reset(); + storage.reset(); + + // Reopen: reformer should auto-detect rotator from storage + auto storage2 = IndexFactory::CreateStorage("MMapFileStorage"); + ASSERT_NE(nullptr, storage2); + ASSERT_EQ(0, storage2->init(stg_params)); + ASSERT_EQ(0, storage2->open(dir_ + "TestInt8WithRotate.index", false)); + + auto streamer2 = IndexFactory::CreateStreamer("HnswStreamer"); + ASSERT_NE(nullptr, streamer2); + ASSERT_EQ(0, streamer2->init(index_meta, params)); + ASSERT_EQ(0, streamer2->open(storage2)); + + auto reformer2 = IndexFactory::CreateReformer(index_meta.reformer_name()); + 
ASSERT_NE(nullptr, reformer2); + ASSERT_EQ(0, reformer2->init(index_meta.reformer_params())); + ASSERT_EQ(0, reformer2->load(storage2)); + + // Search: verify knn results are non-empty + auto knnCtx = streamer2->create_context(); + knnCtx->set_topk(kTopk); + auto linearCtx = streamer2->create_context(); + linearCtx->set_topk(kTopk); + + NumericalVector query(kTestDim); + for (size_t j = 0; j < kTestDim; ++j) query[j] = dist(gen); + + std::string new_query; + IndexQueryMeta new_qmeta; + ASSERT_EQ(0, + reformer2->transform(query.data(), qmeta, &new_query, &new_qmeta)); + ASSERT_EQ(0, streamer2->search_impl(new_query.data(), new_qmeta, knnCtx)); + ASSERT_EQ(0, + streamer2->search_bf_impl(new_query.data(), new_qmeta, linearCtx)); + + EXPECT_EQ(kTopk, knnCtx->result().size()); + EXPECT_EQ(kTopk, linearCtx->result().size()); +} + } // namespace core } // namespace zvec diff --git a/tests/core/algorithm/vamana/vamana_streamer_test.cc b/tests/core/algorithm/vamana/vamana_streamer_test.cc index cdae281e0..2454d64b3 100644 --- a/tests/core/algorithm/vamana/vamana_streamer_test.cc +++ b/tests/core/algorithm/vamana/vamana_streamer_test.cc @@ -785,6 +785,105 @@ TEST_F(VamanaStreamerTest, TestConcurrentBuild) { ASSERT_GT(result.size(), 0UL); } +// Test Vamana + INT8 quantization + rotation end-to-end +TEST_F(VamanaStreamerTest, TestInt8WithRotate) { + constexpr size_t kTestDim = 128; + constexpr size_t kCnt = 2000U; + constexpr size_t kTopk = 10; + + IndexStreamer::Pointer streamer = + IndexFactory::CreateStreamer("VamanaStreamer"); + ASSERT_NE(nullptr, streamer); + + Params params; + params.set(PARAM_VAMANA_STREAMER_MAX_DEGREE, 32U); + params.set(PARAM_VAMANA_STREAMER_SEARCH_LIST_SIZE, 100U); + params.set(PARAM_VAMANA_STREAMER_ALPHA, 1.2f); + params.set(PARAM_VAMANA_STREAMER_EF, 64U); + params.set(PARAM_VAMANA_STREAMER_BRUTE_FORCE_THRESHOLD, 500U); + + IndexMeta index_meta_raw(IndexMeta::DataType::DT_FP32, kTestDim); + index_meta_raw.set_metric("SquaredEuclidean", 0, Params()); 
+ + // Create INT8 converter with rotation enabled + Params converter_params; + converter_params.set("integer_streaming.converter.enable_rotate", true); + auto converter = IndexFactory::CreateConverter("Int8StreamingConverter"); + ASSERT_NE(nullptr, converter); + ASSERT_EQ(0, converter->init(index_meta_raw, converter_params)); + + IndexMeta index_meta = converter->meta(); + + auto reformer = IndexFactory::CreateReformer(index_meta.reformer_name()); + ASSERT_NE(nullptr, reformer); + ASSERT_EQ(0, reformer->init(index_meta.reformer_params())); + + Params stg_params; + auto storage = IndexFactory::CreateStorage("MMapFileStorage"); + ASSERT_NE(nullptr, storage); + ASSERT_EQ(0, storage->init(stg_params)); + ASSERT_EQ(0, storage->open(dir_ + "TestInt8WithRotate.index", true)); + ASSERT_EQ(0, streamer->init(index_meta, params)); + ASSERT_EQ(0, streamer->open(storage)); + + // Add 2000 vectors + auto ctx = streamer->create_context(); + ASSERT_TRUE(!!ctx); + IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, kTestDim); + + std::mt19937 gen(42); + std::uniform_real_distribution dist(-1.0f, 1.0f); + for (size_t i = 0; i < kCnt; i++) { + NumericalVector vec(kTestDim); + for (size_t j = 0; j < kTestDim; ++j) vec[j] = dist(gen); + + std::string new_vec; + IndexQueryMeta new_meta; + ASSERT_EQ(0, reformer->convert(vec.data(), qmeta, &new_vec, &new_meta)); + ASSERT_EQ(0, streamer->add_impl(i, new_vec.data(), new_meta, ctx)); + } + + streamer->flush(0UL); + streamer.reset(); + storage.reset(); + + // Reopen: reformer should auto-detect rotator from storage + auto storage2 = IndexFactory::CreateStorage("MMapFileStorage"); + ASSERT_NE(nullptr, storage2); + ASSERT_EQ(0, storage2->init(stg_params)); + ASSERT_EQ(0, storage2->open(dir_ + "TestInt8WithRotate.index", false)); + + auto streamer2 = IndexFactory::CreateStreamer("VamanaStreamer"); + ASSERT_NE(nullptr, streamer2); + ASSERT_EQ(0, streamer2->init(index_meta, params)); + ASSERT_EQ(0, streamer2->open(storage2)); + + auto reformer2 = 
IndexFactory::CreateReformer(index_meta.reformer_name()); + ASSERT_NE(nullptr, reformer2); + ASSERT_EQ(0, reformer2->init(index_meta.reformer_params())); + ASSERT_EQ(0, reformer2->load(storage2)); + + // Search: verify knn results are non-empty + auto knnCtx = streamer2->create_context(); + knnCtx->set_topk(kTopk); + auto linearCtx = streamer2->create_context(); + linearCtx->set_topk(kTopk); + + NumericalVector query(kTestDim); + for (size_t j = 0; j < kTestDim; ++j) query[j] = dist(gen); + + std::string new_query; + IndexQueryMeta new_qmeta; + ASSERT_EQ(0, + reformer2->transform(query.data(), qmeta, &new_query, &new_qmeta)); + ASSERT_EQ(0, streamer2->search_impl(new_query.data(), new_qmeta, knnCtx)); + ASSERT_EQ(0, + streamer2->search_bf_impl(new_query.data(), new_qmeta, linearCtx)); + + EXPECT_EQ(kTopk, knnCtx->result().size()); + EXPECT_EQ(kTopk, linearCtx->result().size()); +} + } // namespace core } // namespace zvec diff --git a/tests/core/quantizer/CMakeLists.txt b/tests/core/quantizer/CMakeLists.txt index d28e9e5cd..05004ac58 100644 --- a/tests/core/quantizer/CMakeLists.txt +++ b/tests/core/quantizer/CMakeLists.txt @@ -7,7 +7,7 @@ foreach(CC_SRCS ${ALL_TEST_SRCS}) cc_gtest( NAME ${CC_TARGET} STRICT - LIBS zvec_ailego core_framework core_quantizer + LIBS zvec_ailego core_framework core_utility core_quantizer SRCS ${CC_SRCS} INCS . ${PROJECT_ROOT_DIR}/src/core/ ) diff --git a/tests/core/quantizer/integer_quantizer_reformer_test.cc b/tests/core/quantizer/integer_quantizer_reformer_test.cc index 21967bb23..3ddcad4c7 100644 --- a/tests/core/quantizer/integer_quantizer_reformer_test.cc +++ b/tests/core/quantizer/integer_quantizer_reformer_test.cc @@ -12,10 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include #include #include +#include #include #include +#include "quantizer/rotator/rotator.h" +#include "tests/test_util.h" #include "zvec/core/framework/index_factory.h" #include "zvec/core/framework/index_holder.h" @@ -821,3 +825,162 @@ TEST(IntegerReformer, Int4InitConverterWithTrainedParams) { EXPECT_EQ(buffer, buffer2); } } + +// Test FhtKac rotator (dim=200, 4-aligned, non-power-of-2 kacs_walk path) +TEST(RecordRotatorTest, RotateUnrotateFhtKac_Dim200) { + const size_t dim = 200; + RecordRotator rotator; + rotator.init(dim); + EXPECT_EQ(rotator.rotator_type(), RecordRotatorType::FhtKac); + + std::mt19937 gen(42); + std::uniform_real_distribution dist(-10.0f, 10.0f); + + std::vector original(dim); + for (size_t j = 0; j < dim; ++j) original[j] = dist(gen); + + std::vector rotated(dim); + rotator.rotate(original.data(), rotated.data()); + + std::vector recovered(dim); + rotator.unrotate(rotated.data(), recovered.data()); + + float max_err = 0.0f; + for (size_t j = 0; j < dim; ++j) + max_err = std::max(max_err, std::abs(recovered[j] - original[j])); + std::cout << "FhtKac (dim=200) max error: " << max_err << std::endl; + EXPECT_LT(max_err, 1e-3f); +} + +// Test FhtKac rotator (dim=96, 32-aligned but not 64-aligned, kacs_walk path) +TEST(RecordRotatorTest, RotateUnrotateFhtKac_Dim96) { + const size_t dim = 96; + RecordRotator rotator; + rotator.init(dim); + EXPECT_EQ(rotator.rotator_type(), RecordRotatorType::FhtKac); + + std::mt19937 gen(42); + std::uniform_real_distribution dist(-10.0f, 10.0f); + + std::vector original(dim); + for (size_t j = 0; j < dim; ++j) original[j] = dist(gen); + + std::vector rotated(dim); + rotator.rotate(original.data(), rotated.data()); + + std::vector recovered(dim); + rotator.unrotate(rotated.data(), recovered.data()); + + float max_err = 0.0f; + for (size_t j = 0; j < dim; ++j) + max_err = std::max(max_err, std::abs(recovered[j] - original[j])); + std::cout << "FhtKac (dim=96) max error: " << max_err << std::endl; + 
EXPECT_LT(max_err, 1e-3f); +} + +// Test FhtKac rotator (dim=768, real-world embedding dimension, kacs_walk) +TEST(RecordRotatorTest, RotateUnrotateFhtKac_Dim768) { + const size_t dim = 768; + RecordRotator rotator; + rotator.init(dim); + EXPECT_EQ(rotator.rotator_type(), RecordRotatorType::FhtKac); + + std::mt19937 gen(42); + std::uniform_real_distribution dist(-10.0f, 10.0f); + + std::vector original(dim); + for (size_t j = 0; j < dim; ++j) original[j] = dist(gen); + + std::vector rotated(dim); + rotator.rotate(original.data(), rotated.data()); + + std::vector recovered(dim); + rotator.unrotate(rotated.data(), recovered.data()); + + float max_err = 0.0f; + for (size_t j = 0; j < dim; ++j) + max_err = std::max(max_err, std::abs(recovered[j] - original[j])); + std::cout << "FhtKac (dim=768) max error: " << max_err << std::endl; + EXPECT_LT(max_err, 1e-3f); +} + +// Test FhtKac rotator (dim=128, power-of-2, pure FHT path) +TEST(RecordRotatorTest, RotateUnrotateFhtKac_Dim128) { + const size_t dim = 128; + RecordRotator rotator; + rotator.init(dim); + EXPECT_EQ(rotator.rotator_type(), RecordRotatorType::FhtKac); + + std::mt19937 gen(42); + std::uniform_real_distribution dist(-10.0f, 10.0f); + + std::vector original(dim); + for (size_t j = 0; j < dim; ++j) original[j] = dist(gen); + + std::vector rotated(dim); + rotator.rotate(original.data(), rotated.data()); + + std::vector recovered(dim); + rotator.unrotate(rotated.data(), recovered.data()); + + float max_err = 0.0f; + for (size_t j = 0; j < dim; ++j) + max_err = std::max(max_err, std::abs(recovered[j] - original[j])); + std::cout << "FhtKac (dim=128) max error: " << max_err << std::endl; + EXPECT_LT(max_err, 1e-3f); +} + +// Test dump/open roundtrip: serialize then deserialize, verify rotate output +// matches. 
+TEST(RecordRotatorTest, DumpOpenRoundtrip) { + const std::string test_dir = "record_rotator_dump_test_dir/"; + zvec::test_util::RemoveTestPath(test_dir); + + const size_t dim = 128; + + // Build and dump original rotator + RecordRotator original; + original.init(dim); + EXPECT_EQ(original.rotator_type(), RecordRotatorType::FhtKac); + + auto storage = IndexFactory::CreateStorage("MMapFileStorage"); + ASSERT_NE(storage, nullptr); + zvec::ailego::Params stg_params; + ASSERT_EQ(0, storage->init(stg_params)); + ASSERT_EQ(0, storage->open(test_dir + "rotator.index", true)); + ASSERT_EQ(0, original.dump(storage)); + + // Close and reopen storage + storage.reset(); + + auto storage2 = IndexFactory::CreateStorage("MMapFileStorage"); + ASSERT_NE(storage2, nullptr); + ASSERT_EQ(0, storage2->init(stg_params)); + ASSERT_EQ(0, storage2->open(test_dir + "rotator.index", false)); + + // Load rotator from storage + RecordRotator loaded; + ASSERT_EQ(0, loaded.open(storage2)); + + // Verify metadata + EXPECT_EQ(original.rotator_type(), loaded.rotator_type()); + EXPECT_EQ(original.dimension(), loaded.dimension()); + EXPECT_TRUE(loaded.initialized()); + + // Verify rotate output matches + std::mt19937 gen(42); + std::uniform_real_distribution dist(-10.0f, 10.0f); + std::vector vec(dim); + for (size_t j = 0; j < dim; ++j) vec[j] = dist(gen); + + auto rotated_orig = original.rotate(vec.data()); + auto rotated_loaded = loaded.rotate(vec.data()); + + float max_err = 0.0f; + for (size_t j = 0; j < dim; ++j) + max_err = std::max(max_err, std::abs(rotated_orig[j] - rotated_loaded[j])); + std::cout << "DumpOpen roundtrip max error: " << max_err << std::endl; + EXPECT_EQ(max_err, 0.0f); + + zvec::test_util::RemoveTestPath(test_dir); +} diff --git a/tests/db/index/common/db_proto_converter_test.cc b/tests/db/index/common/db_proto_converter_test.cc index dff93e9dd..9c71c3c89 100644 --- a/tests/db/index/common/db_proto_converter_test.cc +++ b/tests/db/index/common/db_proto_converter_test.cc @@ 
-470,4 +470,81 @@ TEST(ConverterTest, SegmentMetaWithEmptyFields) { EXPECT_EQ(pb_result.persisted_blocks_size(), 0); EXPECT_FALSE(pb_result.has_writing_forward_block()); EXPECT_EQ(pb_result.indexed_vector_fields_size(), 0); +} + +// ==================== enable_rotate roundtrip tests ==================== + +TEST(ConverterTest, HnswIndexParamsWithEnableRotate) { + // C++ -> PB -> C++ roundtrip with enable_rotate = true + HnswIndexParams original(MetricType::COSINE, 16, 200, QuantizeType::INT8, + false, QuantizerParam(true)); + EXPECT_TRUE(original.quantizer_param().enable_rotate()); + + auto pb = ProtoConverter::ToPb(&original); + EXPECT_TRUE(pb.base().quantizer_param().enable_rotate()); + + auto restored = ProtoConverter::FromPb(pb); + ASSERT_NE(restored, nullptr); + EXPECT_TRUE(restored->quantizer_param().enable_rotate()); + EXPECT_TRUE(restored->enable_rotate()); // convenience getter + EXPECT_EQ(restored->metric_type(), MetricType::COSINE); + EXPECT_EQ(restored->m(), 16); + EXPECT_EQ(restored->ef_construction(), 200); + EXPECT_EQ(restored->quantize_type(), QuantizeType::INT8); + + // C++ -> PB -> C++ roundtrip with enable_rotate = false + HnswIndexParams original_no_rot(MetricType::L2, 32, 100, QuantizeType::FP16); + auto pb2 = ProtoConverter::ToPb(&original_no_rot); + EXPECT_FALSE(pb2.base().quantizer_param().enable_rotate()); + auto restored2 = ProtoConverter::FromPb(pb2); + ASSERT_NE(restored2, nullptr); + EXPECT_FALSE(restored2->quantizer_param().enable_rotate()); +} + +TEST(ConverterTest, FlatIndexParamsWithEnableRotate) { + FlatIndexParams original(MetricType::IP, QuantizeType::INT8, + QuantizerParam(true)); + EXPECT_TRUE(original.quantizer_param().enable_rotate()); + + auto pb = ProtoConverter::ToPb(&original); + EXPECT_TRUE(pb.base().quantizer_param().enable_rotate()); + + auto restored = ProtoConverter::FromPb(pb); + ASSERT_NE(restored, nullptr); + EXPECT_TRUE(restored->quantizer_param().enable_rotate()); + EXPECT_EQ(restored->metric_type(), 
MetricType::IP); + EXPECT_EQ(restored->quantize_type(), QuantizeType::INT8); + + // enable_rotate = false + FlatIndexParams original_no_rot(MetricType::L2, QuantizeType::FP16); + auto pb2 = ProtoConverter::ToPb(&original_no_rot); + EXPECT_FALSE(pb2.base().quantizer_param().enable_rotate()); + auto restored2 = ProtoConverter::FromPb(pb2); + EXPECT_FALSE(restored2->quantizer_param().enable_rotate()); +} + +TEST(ConverterTest, IVFIndexParamsWithEnableRotate) { + IVFIndexParams original(MetricType::COSINE, 256, 20, true, QuantizeType::INT8, + QuantizerParam(true)); + EXPECT_TRUE(original.quantizer_param().enable_rotate()); + + auto pb = ProtoConverter::ToPb(&original); + EXPECT_TRUE(pb.base().quantizer_param().enable_rotate()); + + auto restored = ProtoConverter::FromPb(pb); + ASSERT_NE(restored, nullptr); + EXPECT_TRUE(restored->quantizer_param().enable_rotate()); + EXPECT_EQ(restored->metric_type(), MetricType::COSINE); + EXPECT_EQ(restored->n_list(), 256); + EXPECT_EQ(restored->n_iters(), 20); + EXPECT_TRUE(restored->use_soar()); + EXPECT_EQ(restored->quantize_type(), QuantizeType::INT8); + + // enable_rotate = false + IVFIndexParams original_no_rot(MetricType::L2, 128, 10, false, + QuantizeType::FP16); + auto pb2 = ProtoConverter::ToPb(&original_no_rot); + EXPECT_FALSE(pb2.base().quantizer_param().enable_rotate()); + auto restored2 = ProtoConverter::FromPb(pb2); + EXPECT_FALSE(restored2->quantizer_param().enable_rotate()); } \ No newline at end of file diff --git a/tests/db/index/common/index_params_test.cc b/tests/db/index/common/index_params_test.cc index af67e7398..d5a85aeb9 100644 --- a/tests/db/index/common/index_params_test.cc +++ b/tests/db/index/common/index_params_test.cc @@ -186,4 +186,96 @@ TEST(IndexParamsTest, DynamicPointerCast) { IndexParams &base_ref = *base_ptr; auto &hnsw_ref = dynamic_cast(base_ref); EXPECT_EQ(hnsw_ref.type(), IndexType::HNSW); +} + +// ==================== QuantizerParam tests ==================== + +TEST(IndexParamsTest, 
QuantizerParamBasic) { + // Default constructor: enable_rotate should be false + QuantizerParam qp_default; + EXPECT_FALSE(qp_default.enable_rotate()); + + // Constructor with true + QuantizerParam qp_true(true); + EXPECT_TRUE(qp_true.enable_rotate()); + + // Constructor with false + QuantizerParam qp_false(false); + EXPECT_FALSE(qp_false.enable_rotate()); + + // Setter + qp_default.set_enable_rotate(true); + EXPECT_TRUE(qp_default.enable_rotate()); + qp_default.set_enable_rotate(false); + EXPECT_FALSE(qp_default.enable_rotate()); + + // Equality + EXPECT_TRUE(qp_true == QuantizerParam(true)); + EXPECT_TRUE(qp_false == QuantizerParam(false)); + EXPECT_FALSE(qp_true == qp_false); + + // Inequality + EXPECT_TRUE(qp_true != qp_false); + EXPECT_FALSE(qp_true != QuantizerParam(true)); +} + +TEST(IndexParamsTest, QuantizerParamWithVectorIndex) { + // HnswIndexParams + { + HnswIndexParams params(MetricType::COSINE, 16, 100, QuantizeType::INT8); + EXPECT_FALSE(params.quantizer_param().enable_rotate()); + EXPECT_FALSE(params.enable_rotate()); // convenience getter + + params.set_quantizer_param(QuantizerParam(true)); + EXPECT_TRUE(params.quantizer_param().enable_rotate()); + EXPECT_TRUE(params.enable_rotate()); + + // Clone preserves quantizer_param + auto cloned = params.clone(); + auto *cloned_hnsw = dynamic_cast(cloned.get()); + ASSERT_NE(cloned_hnsw, nullptr); + EXPECT_TRUE(cloned_hnsw->quantizer_param().enable_rotate()); + EXPECT_TRUE(*cloned == params); + + // Equality: different enable_rotate -> not equal + HnswIndexParams params2(MetricType::COSINE, 16, 100, QuantizeType::INT8); + params2.set_quantizer_param(QuantizerParam(false)); + EXPECT_FALSE(params == params2); + } + + // FlatIndexParams + { + FlatIndexParams params(MetricType::L2, QuantizeType::INT8); + EXPECT_FALSE(params.quantizer_param().enable_rotate()); + + params.set_quantizer_param(QuantizerParam(true)); + EXPECT_TRUE(params.quantizer_param().enable_rotate()); + EXPECT_TRUE(params.enable_rotate()); + + 
auto cloned = params.clone(); + auto *cloned_flat = dynamic_cast(cloned.get()); + ASSERT_NE(cloned_flat, nullptr); + EXPECT_TRUE(cloned_flat->quantizer_param().enable_rotate()); + + FlatIndexParams params2(MetricType::L2, QuantizeType::INT8); + EXPECT_FALSE(params == params2); + } + + // IVFIndexParams + { + IVFIndexParams params(MetricType::IP, 128, 10, false, QuantizeType::INT8); + EXPECT_FALSE(params.quantizer_param().enable_rotate()); + + params.set_quantizer_param(QuantizerParam(true)); + EXPECT_TRUE(params.quantizer_param().enable_rotate()); + EXPECT_TRUE(params.enable_rotate()); + + auto cloned = params.clone(); + auto *cloned_ivf = dynamic_cast(cloned.get()); + ASSERT_NE(cloned_ivf, nullptr); + EXPECT_TRUE(cloned_ivf->quantizer_param().enable_rotate()); + + IVFIndexParams params2(MetricType::IP, 128, 10, false, QuantizeType::INT8); + EXPECT_FALSE(params == params2); + } } \ No newline at end of file diff --git a/tools/core/local_builder.cc b/tools/core/local_builder.cc index 52ae8321d..7d3b7bf0a 100644 --- a/tools/core/local_builder.cc +++ b/tools/core/local_builder.cc @@ -35,7 +35,6 @@ #include "zvec/core/framework/index_reformer.h" #include "zvec/core/framework/index_streamer.h" #include "index_meta_helper.h" -#include "meta_segment_common.h" #include "vecs_index_holder.h" #ifdef __clang__ @@ -206,10 +205,6 @@ bool check_config(YAML::Node &config_root) { return false; } } - if (!common["DumpPath"]) { - LOG_ERROR("Can not find [DumpPath] in config"); - return false; - } if (!config_root["BuilderParams"]) { LOG_ERROR("Can not find [BuilderParams] in config"); return false; @@ -217,75 +212,6 @@ bool check_config(YAML::Node &config_root) { return true; } -static inline size_t AlignSize(size_t size) { - return (size + 0x1F) & (~0x1F); -} - -bool dump_meta_segment(const IndexDumper::Pointer &dumper, - const std::string &segment_id, const void *data, - size_t size, size_t &writes) { - size_t len = dumper->write(data, size); - if (len != size) { - LOG_ERROR("Dump 
segment %s data failed, expect: %lu, actual: %lu", - segment_id.c_str(), size, len); - return false; - } - - size_t padding_size = AlignSize(size) - size; - if (padding_size > 0) { - std::string padding(padding_size, '\0'); - if (dumper->write(padding.data(), padding_size) != padding_size) { - LOG_ERROR("Append padding failed, size %lu", padding_size); - return false; - } - } - - uint32_t crc = ailego::Crc32c::Hash(data, size); - int ret = dumper->append(segment_id, size, padding_size, crc); - if (ret != 0) { - LOG_ERROR("Dump segment %s meta failed, ret=%d", segment_id.c_str(), ret); - return false; - } - - writes = len + padding_size; - - return true; -} - -int dump_taglist(IndexDumper::Pointer dumper, size_t num_vecs, - const void *key_base, const void *taglist_data, - uint64_t taglist_size) { - TagListHeader taglist_header; - - taglist_header.num_vecs = num_vecs; - - size_t total_writes; - - bool ret = - dump_meta_segment(dumper, TAGLIST_HEADER_SEGMENT_NAME, &taglist_header, - sizeof(TagListHeader), total_writes); - if (ret == false) { - LOG_ERROR("dump taglist meta failed"); - return IndexError_WriteData; - } - - ret = dump_meta_segment(dumper, TAGLIST_KEY_SEGMENT_NAME, key_base, - num_vecs * sizeof(uint64_t), total_writes); - if (ret == false) { - LOG_ERROR("dump taglist key failed"); - return IndexError_WriteData; - } - - ret = dump_meta_segment(dumper, TAGLIST_DATA_SEGMENT_NAME, taglist_data, - taglist_size, total_writes); - if (ret == false) { - LOG_ERROR("dump taglist data failed"); - return IndexError_WriteData; - } - - return 0; -} - int do_build_sparse_by_streamer(IndexStreamer::Pointer &streamer, uint32_t thread_count) { int ret; @@ -422,7 +348,8 @@ int do_build_sparse_by_streamer(IndexStreamer::Pointer &streamer, } int build_sparse_by_streamer(IndexStreamer::Pointer &streamer, - YAML::Node &config_common) { + YAML::Node &config_common, + const IndexConverter::Pointer &converter) { if (!config_common["IndexPath"]) { LOG_ERROR("Miss params IndexPath 
for Streamer"); return IndexError_InvalidArgument; @@ -451,6 +378,15 @@ int build_sparse_by_streamer(IndexStreamer::Pointer &streamer, return IndexError_Runtime; } + // Dump converter state (e.g. rotator) to storage for streaming build + if (converter) { + ret = converter->dump_to_storage(storage); + if (ret != 0) { + LOG_ERROR("Failed to dump converter to storage, ret=%d", ret); + return ret; + } + } + size_t thread_count = config_common["ThreadCount"] ? config_common["ThreadCount"].as() : std::thread::hardware_concurrency(); @@ -464,7 +400,8 @@ int build_sparse_by_streamer(IndexStreamer::Pointer &streamer, } int do_build_by_streamer(IndexStreamer::Pointer &streamer, - uint32_t thread_count, RetrievalMode retrieval_mode) { + uint32_t thread_count, RetrievalMode retrieval_mode, + const IndexStorage::Pointer &storage = nullptr) { int ret; ailego::ThreadPool pool(thread_count, false); std::atomic finished{0}; @@ -486,6 +423,14 @@ int do_build_by_streamer(IndexStreamer::Pointer &streamer, return IndexError_NoExist; } reformer->init(meta.reformer_params()); + // Load reformer state from storage (e.g. rotator for IntegerStreaming) + if (storage) { + ret = reformer->load(storage); + if (ret != 0) { + LOG_ERROR("Failed to load reformer from storage, ret=%d", ret); + return ret; + } + } } } @@ -593,7 +538,8 @@ int do_build_by_streamer(IndexStreamer::Pointer &streamer, } int build_by_streamer(IndexStreamer::Pointer &streamer, - YAML::Node &config_common) { + YAML::Node &config_common, + const IndexConverter::Pointer &converter) { if (!config_common["IndexPath"]) { LOG_ERROR("Miss params IndexPath for Streamer"); return IndexError_InvalidArgument; @@ -624,6 +570,15 @@ int build_by_streamer(IndexStreamer::Pointer &streamer, return IndexError_Runtime; } + // Dump converter state (e.g. 
rotator) to storage for streaming build + if (converter) { + ret = converter->dump_to_storage(storage); + if (ret != 0) { + LOG_ERROR("Failed to dump converter to storage, ret=%d", ret); + return ret; + } + } + size_t thread_count = config_common["ThreadCount"] ? config_common["ThreadCount"].as() : std::thread::hardware_concurrency(); @@ -639,14 +594,15 @@ int build_by_streamer(IndexStreamer::Pointer &streamer, LOG_DEBUG("thread count: %zu, retrieval mode: %s", thread_count, retrieval_mode == 1 ? "Dense" : "Sparse"); - do_build_by_streamer(streamer, thread_count, retrieval_mode); + do_build_by_streamer(streamer, thread_count, retrieval_mode, storage); return 0; } IndexSparseHolder::Pointer convert_sparse_holder( const std::string &name, const ailego::Params ¶ms, - VecsIndexSparseHolder::Pointer &in_holder, IndexMeta &index_meta) { + VecsIndexSparseHolder::Pointer &in_holder, IndexMeta &index_meta, + IndexConverter::Pointer *out_converter) { IndexSparseHolder::Pointer cast_holder = std::dynamic_pointer_cast(in_holder); if (name.empty()) { @@ -679,13 +635,17 @@ IndexSparseHolder::Pointer convert_sparse_holder( index_meta = converter->meta(); + if (out_converter) { + *out_converter = converter; + } return converter->sparse_result(); } IndexHolder::Pointer convert_holder(const std::string &name, const ailego::Params ¶ms, VecsIndexHolder::Pointer &in_holder, - IndexMeta &index_meta) { + IndexMeta &index_meta, + IndexConverter::Pointer *out_converter) { IndexHolder::Pointer cast_holder = std::dynamic_pointer_cast(in_holder); if (name.empty()) { @@ -718,6 +678,9 @@ IndexHolder::Pointer convert_holder(const std::string &name, index_meta = converter->meta(); + if (out_converter) { + *out_converter = converter; + } return converter->result(); } @@ -782,8 +745,9 @@ int do_build_sparse(YAML::Node &config_root, YAML::Node &config_common) { } cout << "Created builder " << builder_class << endl; + IndexConverter::Pointer build_converter; IndexSparseHolder::Pointer cv_build_holder 
= convert_sparse_holder( - converter_name, converter_params, build_holder, meta); + converter_name, converter_params, build_holder, meta, &build_converter); if (!cv_build_holder) { LOG_ERROR("Convert holder failed."); return -1; @@ -819,7 +783,7 @@ int do_build_sparse(YAML::Node &config_root, YAML::Node &config_common) { } IndexSparseHolder::Pointer cv_train_holder = convert_sparse_holder( - converter_name, converter_params, train_holder, meta); + converter_name, converter_params, train_holder, meta, nullptr); if (!cv_train_holder) { LOG_ERROR("Convert train holder failed."); return -1; @@ -846,7 +810,7 @@ int do_build_sparse(YAML::Node &config_root, YAML::Node &config_common) { if (builder != nullptr) { ret = builder->build(std::move(cv_build_holder)); } else { - ret = build_sparse_by_streamer(streamer, config_common); + ret = build_sparse_by_streamer(streamer, config_common, build_converter); } size_t build_time = timer.milli_seconds(); if (ret < 0) { @@ -856,45 +820,6 @@ int do_build_sparse(YAML::Node &config_root, YAML::Node &config_common) { cout << "Build finished, consume " << build_time << "ms." << endl; signal(SIGINT, SIG_DFL); - // DUMP - IndexDumper::Pointer dumper = IndexFactory::CreateDumper("FileDumper"); - if (!dumper) { - LOG_ERROR("Failed to create FileDumper."); - return -1; - } - string dump_prefix = config_common["DumpPath"].as(); - ret = dumper->create(dump_prefix); - if (ret != 0) { - LOG_ERROR("Failed to create in dumper, ret=%d", ret); - return -1; - } - timer.reset(); - ret = streamer ? 
streamer->dump(dumper) : builder->dump(dumper); - size_t dump_time = timer.milli_seconds(); - if (ret == IndexError_NotImplemented) { - LOG_WARN("Dump index not implemented"); - } else if (ret < 0) { - LOG_ERROR("Failed to dump in builder, ret=%d", ret); - return -1; - } - - if (build_holder->has_taglist()) { - size_t taglist_size{0}; - const void *taglist_data = build_holder->get_taglist_data(taglist_size); - const void *key_base = build_holder->get_key_base(); - - dump_taglist(dumper, build_holder->get_num_vecs(), key_base, taglist_data, - taglist_size); - } - - ret = dumper->close(); - if (ret != 0) { - LOG_ERROR("Dumper failed to close, ret=%d", ret); - return -1; - } - std::cout << "Dump to [" << dump_prefix << "] finished, consume " << dump_time - << "ms." << std::endl; - if (builder) { auto &stats = reinterpret_cast(builder.get())->stats(); @@ -987,8 +912,9 @@ int do_build(YAML::Node &config_root, YAML::Node &config_common) { cout << "Created builder " << builder_class << endl; - IndexHolder::Pointer cv_build_holder = - convert_holder(converter_name, converter_params, build_holder, meta); + IndexConverter::Pointer build_converter; + IndexHolder::Pointer cv_build_holder = convert_holder( + converter_name, converter_params, build_holder, meta, &build_converter); if (!cv_build_holder) { LOG_ERROR("Convert holder failed."); return -1; @@ -1079,8 +1005,8 @@ int do_build(YAML::Node &config_root, YAML::Node &config_common) { // support fp16 convert - IndexHolder::Pointer cv_train_holder = - convert_holder(converter_name, converter_params, train_holder, meta); + IndexHolder::Pointer cv_train_holder = convert_holder( + converter_name, converter_params, train_holder, meta, nullptr); if (!cv_train_holder) { LOG_ERROR("Convert train holder failed."); return -1; @@ -1136,8 +1062,8 @@ int do_build(YAML::Node &config_root, YAML::Node &config_common) { if (!metric_name.empty()) { train_holder->set_metric(metric_name, metric_params); } - IndexHolder::Pointer cv_train_holder 
= - convert_holder(converter_name, converter_params, train_holder, meta); + IndexHolder::Pointer cv_train_holder = convert_holder( + converter_name, converter_params, train_holder, meta, nullptr); if (!cv_train_holder) { LOG_ERROR("Convert train holder failed."); return -1; @@ -1177,7 +1103,7 @@ int do_build(YAML::Node &config_root, YAML::Node &config_common) { retrieval_mode = "dense"; } - ret = build_by_streamer(streamer, config_common); + ret = build_by_streamer(streamer, config_common, build_converter); } size_t build_time = timer.milli_seconds(); if (ret < 0) { @@ -1187,45 +1113,6 @@ int do_build(YAML::Node &config_root, YAML::Node &config_common) { cout << "Build finished, consume " << build_time << "ms." << endl; signal(SIGINT, SIG_DFL); - // DUMP - IndexDumper::Pointer dumper = IndexFactory::CreateDumper("FileDumper"); - if (!dumper) { - LOG_ERROR("Failed to create FileDumper."); - return -1; - } - string dump_prefix = config_common["DumpPath"].as(); - ret = dumper->create(dump_prefix); - if (ret != 0) { - LOG_ERROR("Failed to create in dumper, ret=%d", ret); - return -1; - } - timer.reset(); - ret = streamer ? streamer->dump(dumper) : builder->dump(dumper); - size_t dump_time = timer.milli_seconds(); - if (ret == IndexError_NotImplemented) { - LOG_WARN("Dump index not implemented"); - } else if (ret < 0) { - LOG_ERROR("Failed to dump in builder, ret=%d", ret); - return -1; - } - - if (build_holder->has_taglist()) { - size_t taglist_size{0}; - const void *taglist_data = build_holder->get_taglist_data(taglist_size); - const void *key_base = build_holder->get_key_base(); - - dump_taglist(dumper, build_holder->get_num_vecs(), key_base, taglist_data, - taglist_size); - } - - ret = dumper->close(); - if (ret != 0) { - LOG_ERROR("Dumper failed to close, ret=%d", ret); - return -1; - } - std::cout << "Dump to [" << dump_prefix << "] finished, consume " << dump_time - << "ms." << std::endl; - if (builder) { auto &stats = reinterpret_cast(builder.get())->stats();