From 6c5c44ebb29a6cd8b4529b6539dc5d29c5d02285 Mon Sep 17 00:00:00 2001 From: Nelson Spence Date: Thu, 28 May 2026 10:38:52 -0500 Subject: [PATCH 1/3] feat: add RankQuant eval scorer Signed-off-by: Nelson Spence --- ordvec-python/python/ordvec/__init__.py | 15 ++-- ordvec-python/src/lib.rs | 68 +++++++++++++++ ordvec-python/tests/test_rank_quant.py | 92 +++++++++++++++++++- src/lib.rs | 2 +- src/quant.rs | 109 ++++++++++++++++++++++++ tests/index/quant.rs | 104 +++++++++++++++++++++- 6 files changed, 381 insertions(+), 9 deletions(-) diff --git a/ordvec-python/python/ordvec/__init__.py b/ordvec-python/python/ordvec/__init__.py index 6010dcb..5026e64 100644 --- a/ordvec-python/python/ordvec/__init__.py +++ b/ordvec-python/python/ordvec/__init__.py @@ -7,12 +7,13 @@ ``SignBitmap``, plus the module-level rank-math primitives (``rank_transform``, ``rank_to_bucket``, ``bucket_ranks``, ``pack_buckets``, ``unpack_buckets``, ``rankquant_bytes_per_vec``, ``bucket_centre``, ``rank_norm``, -``rankquant_norm``), the byte-LUT scoring helper ``search_asymmetric_byte_lut``, -and the loader limit constants (``MAX_DIM``, ``MAX_SIGN_BITMAP_DIM``, -``MAX_VECTORS``). Together with the four classes' methods this mirrors the Rust -crate's public API; the low-level ``rank_io`` read/write functions are reached -through the classes' ``write()`` / ``load()`` methods rather than exposed as -standalone free functions. +``rankquant_norm``), the eval-only arbitrary-width scorer +``rankquant_eval_search``, the byte-LUT scoring helper +``search_asymmetric_byte_lut``, and the loader limit constants (``MAX_DIM``, +``MAX_SIGN_BITMAP_DIM``, ``MAX_VECTORS``). Together with the four classes' +methods this mirrors the Rust crate's public API; the low-level ``rank_io`` +read/write functions are reached through the classes' ``write()`` / ``load()`` +methods rather than exposed as standalone free functions. The ``*Index`` names are back-compat aliases for the pre-0.2 turbovec-python rank-mode classes; they are kept only to ease script migration and are not part @@ -47,6 +48,7 @@ rank_norm, rank_to_bucket, rank_transform, + rankquant_eval_search, rankquant_bytes_per_vec, rankquant_norm, search_asymmetric_byte_lut, @@ -77,6 +79,7 @@ "bucket_centre", "rank_norm", "rankquant_norm", + "rankquant_eval_search", "search_asymmetric_byte_lut", # loader limit constants "MAX_DIM", diff --git a/ordvec-python/src/lib.rs b/ordvec-python/src/lib.rs index a770cb8..adfe7a2 100644 --- a/ordvec-python/src/lib.rs +++ b/ordvec-python/src/lib.rs @@ -95,6 +95,17 @@ fn check_bits_max7(bits: u8) -> PyResult<()> { Ok(()) } +/// Eval-only RankQuant scoring supports non-byte-aligned widths but still needs +/// at least two buckets and a bucket alphabet representable by `u8`. +fn check_bits_1_7(bits: u8) -> PyResult<()> { + if !(1..=7).contains(&bits) { + return Err(pyo3::exceptions::PyValueError::new_err( + "bits must be in 1..=7", + )); + } + Ok(()) +} + fn not_contiguous_err() -> PyErr { pyo3::exceptions::PyValueError::new_err( "array must be C-contiguous; call np.ascontiguousarray() first", @@ -265,6 +276,14 @@ fn axis_len(arr: &Bound<'_, PyAny>, axis: usize) -> PyResult { arr.getattr("shape")?.get_item(axis)?.extract::() } +fn infer_float_2d_width(arr: &Bound<'_, PyAny>) -> PyResult { + if let Ok(a) = arr.cast::>() { + return Ok(a.readonly().as_array().ncols()); + } + gate_float_ndim(arr, 2)?; + axis_len(arr, 1) +} + /// Present an embedding vector as a 1-D `float32` `PyReadonlyArray`, converting at /// the boundary. The premise of ordvec is *float vector in → rank/sign transform*, /// so float32 is the internal working dtype, not a contract the caller must @@ -1501,6 +1520,54 @@ fn search_asymmetric_byte_lut<'py>( Ok((scores, indices)) } +/// Eval-only symmetric RankQuant-style search for arbitrary `bits` in `1..=7`. +/// +/// This rank-transforms and buckets the raw `corpus`/`queries` matrices on the +/// fly, so it supports non-byte-aligned widths such as `bits=3` without changing +/// `RankQuant` storage or `.tvrq` persistence. Returns `(scores, indices)` with +/// the same shape contract as `RankQuant.search`. +#[pyfunction] +fn rankquant_eval_search<'py>( + py: Python<'py>, + corpus: &Bound<'py, PyAny>, + queries: &Bound<'py, PyAny>, + bits: u8, + k: usize, +) -> PyResult> { + check_bits_1_7(bits)?; + let dim = infer_float_2d_width(corpus)?; + if !(2..=u16::MAX as usize).contains(&dim) { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "corpus width must be in [2, {}]", + u16::MAX + ))); + } + let corpus = as_f32_2d(corpus, dim)?; + let queries = as_f32_2d(queries, dim)?; + let q_arr = queries.as_array(); + let nq = q_arr.nrows(); + let corpus_arr = corpus.as_array(); + let corpus_slice = corpus_arr.as_slice().ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err( + "array must be C-contiguous; call np.ascontiguousarray() first", + ) + })?; + let query_slice = q_arr.as_slice().ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err( + "array must be C-contiguous; call np.ascontiguousarray() first", + ) + })?; + let results = + py.detach(|| ordvec_core::rankquant_eval_search(corpus_slice, query_slice, dim, bits, k)); + let scores = numpy::ndarray::Array2::from_shape_vec((nq, results.k), results.scores) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))? + .into_pyarray(py); + let indices = numpy::ndarray::Array2::from_shape_vec((nq, results.k), results.indices) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))? + .into_pyarray(py); + Ok((scores, indices)) +} + /// The native extension module backing the `ordvec` Python package. #[pymodule] fn _ordvec(m: &Bound<'_, PyModule>) -> PyResult<()> { @@ -1521,6 +1588,7 @@ fn _ordvec(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(rank_norm, m)?)?; m.add_function(wrap_pyfunction!(rankquant_norm, m)?)?; m.add_function(wrap_pyfunction!(search_asymmetric_byte_lut, m)?)?; + m.add_function(wrap_pyfunction!(rankquant_eval_search, m)?)?; // Loader/limit constants (parity with `ordvec::rank_io::*`). m.add("MAX_DIM", ordvec_core::rank_io::MAX_DIM)?; diff --git a/ordvec-python/tests/test_rank_quant.py b/ordvec-python/tests/test_rank_quant.py index 1f1023b..701c812 100644 --- a/ordvec-python/tests/test_rank_quant.py +++ b/ordvec-python/tests/test_rank_quant.py @@ -12,7 +12,7 @@ import numpy as np import pytest -from ordvec import RankQuant +from ordvec import RankQuant, rankquant_eval_search def unit_vectors(n: int, dim: int, seed: int = 0) -> np.ndarray: @@ -22,6 +22,51 @@ def unit_vectors(n: int, dim: int, seed: int = 0) -> np.ndarray: return v +def rank_transform_reference(row: np.ndarray) -> np.ndarray: + order = np.lexsort((np.arange(row.size), row)) + ranks = np.empty(row.size, dtype=np.uint16) + ranks[order] = np.arange(row.size, dtype=np.uint16) + return ranks + + +def rankquant_eval_reference( + corpus: np.ndarray, queries: np.ndarray, bits: int, k: int +) -> tuple[np.ndarray, np.ndarray]: + dim = corpus.shape[1] + n_buckets = 1 << bits + rank_positions = np.arange(dim, dtype=np.uint64) + bucket_by_rank = (rank_positions * n_buckets // dim).astype(np.uint8) + centre_by_rank = bucket_by_rank.astype(np.float32) - ((n_buckets - 1) / 2.0) + norm = np.sqrt(np.sum(centre_by_rank * centre_by_rank, dtype=np.float64)).astype( + np.float32 + ) + + def centres(row: np.ndarray) -> np.ndarray: + ranks = rank_transform_reference(row) + buckets = (ranks.astype(np.uint64) * n_buckets // dim).astype(np.uint8) + return buckets.astype(np.float32) - ((n_buckets - 1) / 2.0) + + k_eff = min(k, corpus.shape[0]) + if k_eff == 0: + return ( + np.empty((queries.shape[0], 0), dtype=np.float32), + np.empty((queries.shape[0], 0), dtype=np.int64), + ) + + doc_centres = np.vstack([centres(row) for row in corpus]) + scores = np.empty((queries.shape[0], k_eff), dtype=np.float32) + indices = np.empty((queries.shape[0], k_eff), dtype=np.int64) + doc_ids = np.arange(corpus.shape[0], dtype=np.int64) + scale = np.float32(1.0) / (norm * norm) + for qi, query in enumerate(queries): + q_centres = centres(query) + row_scores = (doc_centres @ q_centres).astype(np.float32) * scale + order = np.lexsort((doc_ids, -row_scores))[:k_eff] + scores[qi] = row_scores[order] + indices[qi] = order + return scores, indices + + @pytest.mark.parametrize("bits", [1, 2, 4]) def test_new_reports_dim_and_bits(bits): # dim must be a multiple of 2^bits; 128 is divisible by 2, 4, 16. @@ -68,6 +113,45 @@ def test_search_asymmetric_shape(bits): assert indices.shape == (3, 10) +@pytest.mark.parametrize("bits", [1, 2, 4]) +def test_rankquant_eval_search_matches_rankquant_search(bits): + vectors = unit_vectors(45, 128, seed=31 + bits) + queries = unit_vectors(5, 128, seed=41 + bits) + idx = RankQuant(dim=128, bits=bits) + idx.add(vectors) + + packed_scores, packed_ids = idx.search(queries, k=8) + eval_scores, eval_ids = rankquant_eval_search(vectors, queries, bits=bits, k=8) + + np.testing.assert_array_equal(eval_ids, packed_ids) + np.testing.assert_allclose(eval_scores, packed_scores, rtol=1e-6, atol=1e-6) + + +def test_rankquant_eval_search_b3_matches_numpy_reference(): + vectors = unit_vectors(36, 128, seed=51) + queries = unit_vectors(4, 128, seed=52) + + scores, ids = rankquant_eval_search(vectors, queries, bits=3, k=9) + ref_scores, ref_ids = rankquant_eval_reference(vectors, queries, bits=3, k=9) + + assert scores.shape == (4, 9) + assert ids.shape == (4, 9) + assert scores.dtype == np.float32 + assert ids.dtype == np.int64 + np.testing.assert_array_equal(ids, ref_ids) + np.testing.assert_allclose(scores, ref_scores, rtol=1e-6, atol=1e-6) + + +def test_rankquant_eval_search_empty_corpus_shape(): + vectors = np.empty((0, 64), dtype=np.float32) + queries = unit_vectors(3, 64, seed=53) + + scores, ids = rankquant_eval_search(vectors, queries, bits=3, k=10) + + assert scores.shape == (3, 0) + assert ids.shape == (3, 0) + + @pytest.mark.parametrize("bits", [2, 4]) def test_self_query_recall_at_1(bits): # 1-bit is too lossy for a strict per-row self-query at this dim; @@ -86,6 +170,12 @@ def test_invalid_bits_rejected(): RankQuant(dim=64, bits=3) with pytest.raises(ValueError, match="bits"): RankQuant(dim=64, bits=8) + vectors = unit_vectors(4, 64, seed=54) + queries = unit_vectors(1, 64, seed=55) + with pytest.raises(ValueError, match="bits"): + rankquant_eval_search(vectors, queries, bits=0, k=2) + with pytest.raises(ValueError, match="bits"): + rankquant_eval_search(vectors, queries, bits=8, k=2) def test_dim_not_multiple_of_two_pow_bits_rejected(): diff --git a/src/lib.rs b/src/lib.rs index a2134e0..3972dec 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,7 +52,7 @@ pub mod sign_bitmap; mod util; pub use bitmap::Bitmap; -pub use quant::RankQuant; +pub use quant::{rankquant_eval_search, RankQuant}; pub use rank::Rank; pub use sign_bitmap::SignBitmap; diff --git a/src/quant.rs b/src/quant.rs index e2e6907..bffb1a8 100644 --- a/src/quant.rs +++ b/src/quant.rs @@ -26,6 +26,32 @@ use crate::rank::{ use crate::util::{assert_all_finite, l2_normalise, result_buffer_len, TopK}; use crate::SearchResults; +fn check_eval_bits(bits: u8) { + assert!((1..=7).contains(&bits), "bits must be in 1..=7"); +} + +fn rankquant_eval_norm(dim: usize, bits: u8) -> f32 { + check_eval_bits(bits); + assert!(dim >= 2, "dim must be >= 2"); + assert!(dim <= u16::MAX as usize, "dim must fit in u16"); + let mut acc = 0.0f64; + for rank in 0..dim { + let b = rank_to_bucket(rank as u16, dim, bits); + let c = bucket_centre(b, bits) as f64; + acc += c * c; + } + acc.sqrt() as f32 +} + +fn rankquant_eval_centres(v: &[f32], bits: u8, out: &mut [f32]) { + debug_assert_eq!(v.len(), out.len()); + let ranks = rank_transform(v); + for (dst, rank) in out.iter_mut().zip(ranks) { + let bucket = rank_to_bucket(rank, v.len(), bits); + *dst = bucket_centre(bucket, bits); + } +} + /// `B`-bit RankQuant index. /// /// Each document is encoded by bucketing its rank vector into @@ -643,6 +669,89 @@ impl RankQuant { } } +/// Standalone symmetric RankQuant-style eval search for arbitrary bit widths. +/// +/// This does **not** use [`RankQuant`] storage and does not change the `.tvrq` +/// packing contract. It rank-transforms `corpus` and `queries`, buckets each +/// rank into `1 << bits` equal-width bins, mean-centres bucket ids, normalises +/// by the analytical norm for that `(dim, bits)`, and returns top-`k` results. +/// +/// Intended for research/eval sweeps where non-byte-aligned widths such as +/// `bits = 3` need to be scored without inventing a persistent packed format. +pub fn rankquant_eval_search( + corpus: &[f32], + queries: &[f32], + dim: usize, + bits: u8, + k: usize, +) -> SearchResults { + check_eval_bits(bits); + assert!(dim >= 2, "dim must be >= 2"); + assert!(dim <= u16::MAX as usize, "dim must fit in u16"); + let n = corpus.len() / dim; + let nq = queries.len() / dim; + assert_eq!( + corpus.len(), + n * dim, + "corpus length must be a multiple of dim" + ); + assert_eq!( + queries.len(), + nq * dim, + "queries length must be a multiple of dim" + ); + assert_all_finite(corpus); + assert_all_finite(queries); + + let k = k.min(n); + let k_eff = k; + let buf_len = result_buffer_len(nq, k); + if k_eff == 0 { + return SearchResults { + scores: vec![0.0; buf_len], + indices: vec![-1; buf_len], + nq, + k, + }; + } + + let norm = rankquant_eval_norm(dim, bits); + let inv_norm_sq = 1.0_f32 / (norm * norm); + let mut doc_centres = vec![0.0f32; n * dim]; + doc_centres + .par_chunks_mut(dim) + .zip(corpus.par_chunks(dim)) + .for_each(|(out, doc)| rankquant_eval_centres(doc, bits, out)); + + let mut scores_flat = vec![0.0f32; buf_len]; + let mut indices_flat = vec![-1i64; buf_len]; + queries + .par_chunks(dim) + .zip(scores_flat.par_chunks_mut(k)) + .zip(indices_flat.par_chunks_mut(k)) + .for_each(|((q, out_scores), out_indices)| { + let mut q_centres = vec![0.0f32; dim]; + rankquant_eval_centres(q, bits, &mut q_centres); + let mut top = TopK::new(k_eff); + for di in 0..n { + let doc = &doc_centres[di * dim..(di + 1) * dim]; + let mut acc = 0.0f32; + for d in 0..dim { + acc += q_centres[d] * doc[d]; + } + top.maybe_insert(acc * inv_norm_sq, di); + } + top.finalize_into(out_scores, out_indices); + }); + + SearchResults { + scores: scores_flat, + indices: indices_flat, + nq, + k, + } +} + // ------------------------------------------------------------------- // Byte-LUT scoring (asymmetric, B = 2 and B = 4). // diff --git a/tests/index/quant.rs b/tests/index/quant.rs index 79d9e6e..30ea9de 100644 --- a/tests/index/quant.rs +++ b/tests/index/quant.rs @@ -1,11 +1,37 @@ //! RankQuant (B-bit bucket-packed) integration tests. -use ordvec::RankQuant; +use ordvec::rank::{bucket_centre, rank_to_bucket, rank_transform}; +use ordvec::{rankquant_eval_search, RankQuant}; use rand::{RngExt, SeedableRng}; use rand_chacha::ChaCha8Rng; use crate::{make_corpus, ref_rankquant_asymmetric, D, N}; +fn ref_rankquant_eval_norm(dim: usize, bits: u8) -> f32 { + let mut acc = 0.0f32; + for rank in 0..dim { + let b = rank_to_bucket(rank as u16, dim, bits); + let c = bucket_centre(b, bits); + acc += c * c; + } + acc.sqrt() +} + +fn ref_rankquant_eval_symmetric(a: &[f32], b: &[f32], bits: u8) -> f32 { + let dim = a.len(); + let ra = rank_transform(a); + let rb = rank_transform(b); + let norm = ref_rankquant_eval_norm(dim, bits); + let inv_norm_sq = 1.0f32 / (norm * norm); + let mut acc = 0.0f32; + for d in 0..dim { + let ba = rank_to_bucket(ra[d], dim, bits); + let bb = rank_to_bucket(rb[d], dim, bits); + acc += bucket_centre(ba, bits) * bucket_centre(bb, bits); + } + acc * inv_norm_sq +} + #[test] fn rankquant_asymmetric_matches_reference_b2() { rankquant_asymmetric_matches_reference(2); @@ -21,6 +47,82 @@ fn rankquant_asymmetric_matches_reference_b1() { rankquant_asymmetric_matches_reference(1); } +#[test] +fn rankquant_eval_search_matches_rankquant_search_for_packed_widths() { + let corpus = make_corpus(71); + let mut rng = ChaCha8Rng::seed_from_u64(72); + let nq = 5; + let queries: Vec = (0..nq * D).map(|_| rng.random_range(-1.0..1.0)).collect(); + + for bits in [1u8, 2, 4] { + let mut idx = RankQuant::new(D, bits); + idx.add(&corpus); + + let packed = idx.search(&queries, 12); + let eval = rankquant_eval_search(&corpus, &queries, D, bits, 12); + + assert_eq!(eval.nq, packed.nq); + assert_eq!(eval.k, packed.k); + assert_eq!( + eval.indices, packed.indices, + "eval search top-k diverged from RankQuant::search at bits={bits}", + ); + for (slot, (&a, &b)) in eval.scores.iter().zip(&packed.scores).enumerate() { + assert!( + (a - b).abs() < 1e-6, + "bits={bits} slot {slot}: eval score {a} vs packed score {b}", + ); + } + } +} + +#[test] +fn rankquant_eval_search_b3_matches_scalar_reference() { + let corpus = make_corpus(73); + let mut rng = ChaCha8Rng::seed_from_u64(74); + let nq = 4; + let queries: Vec = (0..nq * D).map(|_| rng.random_range(-1.0..1.0)).collect(); + let res = rankquant_eval_search(&corpus, &queries, D, 3, 10); + + assert_eq!(res.nq, nq); + assert_eq!(res.k, 10); + for qi in 0..nq { + let q = &queries[qi * D..(qi + 1) * D]; + let mut reference: Vec<(f32, i64)> = (0..N) + .map(|di| { + ( + ref_rankquant_eval_symmetric(q, &corpus[di * D..(di + 1) * D], 3), + di as i64, + ) + }) + .collect(); + reference.sort_unstable_by(|a, b| b.0.total_cmp(&a.0).then_with(|| a.1.cmp(&b.1))); + let ref_top = &reference[..10]; + let ref_ids: Vec = ref_top.iter().map(|&(_, di)| di).collect(); + assert_eq!( + res.indices_for_query(qi), + ref_ids.as_slice(), + "b=3 eval top-k ids diverged for query {qi}", + ); + for (slot, &(s_ref, _)) in ref_top.iter().enumerate() { + let s = res.scores_for_query(qi)[slot]; + assert!( + (s - s_ref).abs() < 1e-6, + "query {qi} slot {slot}: b=3 eval score {s} vs reference {s_ref}", + ); + } + } +} + +#[test] +fn rankquant_constructor_still_rejects_b3() { + let err = std::panic::catch_unwind(|| RankQuant::new(D, 3)); + assert!( + err.is_err(), + "RankQuant::new must keep the packed-width domain" + ); +} + fn rankquant_asymmetric_matches_reference(bits: u8) { let corpus = make_corpus(3 + bits as u64); let mut idx = RankQuant::new(D, bits); From dc24254c0fd165ec0036927fdcfea11f8ccbeb06 Mon Sep 17 00:00:00 2001 From: Nelson Spence Date: Thu, 28 May 2026 11:15:00 -0500 Subject: [PATCH 2/3] fix rankquant eval empty-query path Signed-off-by: Nelson Spence --- ordvec-python/tests/test_rank_quant.py | 12 ++++++++++++ src/quant.rs | 10 +++------- tests/index/quant.rs | 12 ++++++++++++ 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/ordvec-python/tests/test_rank_quant.py b/ordvec-python/tests/test_rank_quant.py index 701c812..413ff5d 100644 --- a/ordvec-python/tests/test_rank_quant.py +++ b/ordvec-python/tests/test_rank_quant.py @@ -152,6 +152,18 @@ def test_rankquant_eval_search_empty_corpus_shape(): assert ids.shape == (3, 0) +def test_rankquant_eval_search_empty_queries_shape(): + vectors = unit_vectors(4, 64, seed=56) + queries = np.empty((0, 64), dtype=np.float32) + + scores, ids = rankquant_eval_search(vectors, queries, bits=3, k=10) + + assert scores.shape == (0, 4) + assert ids.shape == (0, 4) + assert scores.dtype == np.float32 + assert ids.dtype == np.int64 + + @pytest.mark.parametrize("bits", [2, 4]) def test_self_query_recall_at_1(bits): # 1-bit is too lossy for a strict per-row self-query at this dim; diff --git a/src/quant.rs b/src/quant.rs index bffb1a8..c74236a 100644 --- a/src/quant.rs +++ b/src/quant.rs @@ -706,7 +706,7 @@ pub fn rankquant_eval_search( let k = k.min(n); let k_eff = k; let buf_len = result_buffer_len(nq, k); - if k_eff == 0 { + if nq == 0 || k_eff == 0 { return SearchResults { scores: vec![0.0; buf_len], indices: vec![-1; buf_len], @@ -733,12 +733,8 @@ pub fn rankquant_eval_search( let mut q_centres = vec![0.0f32; dim]; rankquant_eval_centres(q, bits, &mut q_centres); let mut top = TopK::new(k_eff); - for di in 0..n { - let doc = &doc_centres[di * dim..(di + 1) * dim]; - let mut acc = 0.0f32; - for d in 0..dim { - acc += q_centres[d] * doc[d]; - } + for (di, doc) in doc_centres.chunks_exact(dim).enumerate() { + let acc: f32 = q_centres.iter().zip(doc).map(|(q, d)| q * d).sum(); top.maybe_insert(acc * inv_norm_sq, di); } top.finalize_into(out_scores, out_indices); diff --git a/tests/index/quant.rs b/tests/index/quant.rs index 30ea9de..bf99a50 100644 --- a/tests/index/quant.rs +++ b/tests/index/quant.rs @@ -114,6 +114,18 @@ fn rankquant_eval_search_b3_matches_scalar_reference() { } } +#[test] +fn rankquant_eval_search_empty_queries_does_not_transform_corpus() { + let corpus = make_corpus(75); + let queries: Vec = Vec::new(); + let res = rankquant_eval_search(&corpus, &queries, D, 3, 10); + + assert_eq!(res.nq, 0); + assert_eq!(res.k, 10); + assert!(res.scores.is_empty()); + assert!(res.indices.is_empty()); +} + #[test] fn rankquant_constructor_still_rejects_b3() { let err = std::panic::catch_unwind(|| RankQuant::new(D, 3)); From 2ba013ad24ebefc91cc54afaddc3719c667038f1 Mon Sep 17 00:00:00 2001 From: Nelson Spence Date: Thu, 28 May 2026 11:16:48 -0500 Subject: [PATCH 3/3] test rankquant eval against numpy reference Signed-off-by: Nelson Spence --- ordvec-python/tests/test_rank_quant.py | 11 ++++++----- src/quant.rs | 25 ++++++++++++++++++++----- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/ordvec-python/tests/test_rank_quant.py b/ordvec-python/tests/test_rank_quant.py index 413ff5d..cc2893e 100644 --- a/ordvec-python/tests/test_rank_quant.py +++ b/ordvec-python/tests/test_rank_quant.py @@ -127,12 +127,13 @@ def test_rankquant_eval_search_matches_rankquant_search(bits): np.testing.assert_allclose(eval_scores, packed_scores, rtol=1e-6, atol=1e-6) -def test_rankquant_eval_search_b3_matches_numpy_reference(): - vectors = unit_vectors(36, 128, seed=51) - queries = unit_vectors(4, 128, seed=52) +@pytest.mark.parametrize("bits", [1, 2, 3, 4]) +def test_rankquant_eval_search_matches_numpy_reference(bits): + vectors = unit_vectors(36, 128, seed=51 + bits) + queries = unit_vectors(4, 128, seed=61 + bits) - scores, ids = rankquant_eval_search(vectors, queries, bits=3, k=9) - ref_scores, ref_ids = rankquant_eval_reference(vectors, queries, bits=3, k=9) + scores, ids = rankquant_eval_search(vectors, queries, bits=bits, k=9) + ref_scores, ref_ids = rankquant_eval_reference(vectors, queries, bits=bits, k=9) assert scores.shape == (4, 9) assert ids.shape == (4, 9) diff --git a/src/quant.rs b/src/quant.rs index c74236a..2d9562d 100644 --- a/src/quant.rs +++ b/src/quant.rs @@ -52,6 +52,14 @@ fn rankquant_eval_centres(v: &[f32], bits: u8, out: &mut [f32]) { } } +fn rankquant_eval_buckets(v: &[f32], bits: u8, out: &mut [u8]) { + debug_assert_eq!(v.len(), out.len()); + let ranks = rank_transform(v); + for (dst, rank) in out.iter_mut().zip(ranks) { + *dst = rank_to_bucket(rank, v.len(), bits); + } +} + /// `B`-bit RankQuant index. /// /// Each document is encoded by bucketing its rank vector into @@ -717,11 +725,14 @@ pub fn rankquant_eval_search( let norm = rankquant_eval_norm(dim, bits); let inv_norm_sq = 1.0_f32 / (norm * norm); - let mut doc_centres = vec![0.0f32; n * dim]; - doc_centres + let centres: Vec = (0..(1usize << bits)) + .map(|bucket| bucket_centre(bucket as u8, bits)) + .collect(); + let mut doc_buckets = vec![0u8; n * dim]; + doc_buckets .par_chunks_mut(dim) .zip(corpus.par_chunks(dim)) - .for_each(|(out, doc)| rankquant_eval_centres(doc, bits, out)); + .for_each(|(out, doc)| rankquant_eval_buckets(doc, bits, out)); let mut scores_flat = vec![0.0f32; buf_len]; let mut indices_flat = vec![-1i64; buf_len]; @@ -733,8 +744,12 @@ pub fn rankquant_eval_search( let mut q_centres = vec![0.0f32; dim]; rankquant_eval_centres(q, bits, &mut q_centres); let mut top = TopK::new(k_eff); - for (di, doc) in doc_centres.chunks_exact(dim).enumerate() { - let acc: f32 = q_centres.iter().zip(doc).map(|(q, d)| q * d).sum(); + for (di, doc) in doc_buckets.chunks_exact(dim).enumerate() { + let acc: f32 = q_centres + .iter() + .zip(doc) + .map(|(q, &bucket)| q * centres[bucket as usize]) + .sum(); top.maybe_insert(acc * inv_norm_sq, di); } top.finalize_into(out_scores, out_indices);