Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions ordvec-python/python/ordvec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
``SignBitmap``, plus the module-level rank-math primitives (``rank_transform``,
``rank_to_bucket``, ``bucket_ranks``, ``pack_buckets``, ``unpack_buckets``,
``rankquant_bytes_per_vec``, ``bucket_centre``, ``rank_norm``,
``rankquant_norm``), the byte-LUT scoring helper ``search_asymmetric_byte_lut``,
and the loader limit constants (``MAX_DIM``, ``MAX_SIGN_BITMAP_DIM``,
``MAX_VECTORS``). Together with the four classes' methods this mirrors the Rust
crate's public API; the low-level ``rank_io`` read/write functions are reached
through the classes' ``write()`` / ``load()`` methods rather than exposed as
standalone free functions.
``rankquant_norm``), the eval-only arbitrary-width scorer
``rankquant_eval_search``, the byte-LUT scoring helper
``search_asymmetric_byte_lut``, and the loader limit constants (``MAX_DIM``,
``MAX_SIGN_BITMAP_DIM``, ``MAX_VECTORS``). Together with the four classes'
methods this mirrors the Rust crate's public API; the low-level ``rank_io``
read/write functions are reached through the classes' ``write()`` / ``load()``
methods rather than exposed as standalone free functions.

The ``*Index`` names are back-compat aliases for the pre-0.2 turbovec-python
rank-mode classes; they are kept only to ease script migration and are not part
Expand Down Expand Up @@ -47,6 +48,7 @@
rank_norm,
rank_to_bucket,
rank_transform,
rankquant_eval_search,
rankquant_bytes_per_vec,
rankquant_norm,
search_asymmetric_byte_lut,
Expand Down Expand Up @@ -77,6 +79,7 @@
"bucket_centre",
"rank_norm",
"rankquant_norm",
"rankquant_eval_search",
"search_asymmetric_byte_lut",
# loader limit constants
"MAX_DIM",
Expand Down
68 changes: 68 additions & 0 deletions ordvec-python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,17 @@ fn check_bits_max7(bits: u8) -> PyResult<()> {
Ok(())
}

/// Eval-only RankQuant scoring supports non-byte-aligned widths but still needs
/// at least two buckets and a bucket alphabet representable by `u8`.
fn check_bits_1_7(bits: u8) -> PyResult<()> {
if !(1..=7).contains(&bits) {
return Err(pyo3::exceptions::PyValueError::new_err(
"bits must be in 1..=7",
));
}
Ok(())
}

fn not_contiguous_err() -> PyErr {
pyo3::exceptions::PyValueError::new_err(
"array must be C-contiguous; call np.ascontiguousarray() first",
Expand Down Expand Up @@ -267,6 +278,14 @@ fn axis_len(arr: &Bound<'_, PyAny>, axis: usize) -> PyResult<usize> {
arr.getattr("shape")?.get_item(axis)?.extract::<usize>()
}

fn infer_float_2d_width(arr: &Bound<'_, PyAny>) -> PyResult<usize> {
if let Ok(a) = arr.cast::<PyArray2<f32>>() {
return Ok(a.readonly().as_array().ncols());
}
gate_float_ndim(arr, 2)?;
axis_len(arr, 1)
}

/// Present an embedding vector as a 1-D `float32` `PyReadonlyArray`, converting at
/// the boundary. The premise of ordvec is *float vector in → rank/sign transform*,
/// so float32 is the internal working dtype, not a contract the caller must
Expand Down Expand Up @@ -1547,6 +1566,54 @@ fn search_asymmetric_byte_lut<'py>(
Ok((scores, indices))
}

/// Eval-only symmetric RankQuant-style search for arbitrary `bits` in `1..=7`.
///
/// This rank-transforms and buckets the raw `corpus`/`queries` matrices on the
/// fly, so it supports non-byte-aligned widths such as `bits=3` without changing
/// `RankQuant` storage or `.tvrq` persistence. Returns `(scores, indices)` with
/// the same shape contract as `RankQuant.search`.
#[pyfunction]
fn rankquant_eval_search<'py>(
py: Python<'py>,
corpus: &Bound<'py, PyAny>,
queries: &Bound<'py, PyAny>,
bits: u8,
k: usize,
) -> PyResult<SearchArrays<'py>> {
check_bits_1_7(bits)?;
let dim = infer_float_2d_width(corpus)?;
if !(2..=u16::MAX as usize).contains(&dim) {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"corpus width must be in [2, {}]",
u16::MAX
)));
}
let corpus = as_f32_2d(corpus, dim)?;
let queries = as_f32_2d(queries, dim)?;
let q_arr = queries.as_array();
let nq = q_arr.nrows();
let corpus_arr = corpus.as_array();
let corpus_slice = corpus_arr.as_slice().ok_or_else(|| {
pyo3::exceptions::PyValueError::new_err(
"array must be C-contiguous; call np.ascontiguousarray() first",
)
})?;
let query_slice = q_arr.as_slice().ok_or_else(|| {
pyo3::exceptions::PyValueError::new_err(
"array must be C-contiguous; call np.ascontiguousarray() first",
)
})?;
let results =
py.detach(|| ordvec_core::rankquant_eval_search(corpus_slice, query_slice, dim, bits, k));
let scores = numpy::ndarray::Array2::from_shape_vec((nq, results.k), results.scores)
.map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?
.into_pyarray(py);
let indices = numpy::ndarray::Array2::from_shape_vec((nq, results.k), results.indices)
.map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?
.into_pyarray(py);
Ok((scores, indices))
}

/// The native extension module backing the `ordvec` Python package.
#[pymodule]
fn _ordvec(m: &Bound<'_, PyModule>) -> PyResult<()> {
Expand All @@ -1567,6 +1634,7 @@ fn _ordvec(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(rank_norm, m)?)?;
m.add_function(wrap_pyfunction!(rankquant_norm, m)?)?;
m.add_function(wrap_pyfunction!(search_asymmetric_byte_lut, m)?)?;
m.add_function(wrap_pyfunction!(rankquant_eval_search, m)?)?;

// Loader/limit constants (parity with `ordvec::rank_io::*`).
m.add("MAX_DIM", ordvec_core::rank_io::MAX_DIM)?;
Expand Down
105 changes: 104 additions & 1 deletion ordvec-python/tests/test_rank_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import numpy as np
import pytest

from ordvec import RankQuant
from ordvec import RankQuant, rankquant_eval_search


def unit_vectors(n: int, dim: int, seed: int = 0) -> np.ndarray:
Expand All @@ -22,6 +22,51 @@ def unit_vectors(n: int, dim: int, seed: int = 0) -> np.ndarray:
return v


def rank_transform_reference(row: np.ndarray) -> np.ndarray:
order = np.lexsort((np.arange(row.size), row))
ranks = np.empty(row.size, dtype=np.uint16)
ranks[order] = np.arange(row.size, dtype=np.uint16)
return ranks


def rankquant_eval_reference(
corpus: np.ndarray, queries: np.ndarray, bits: int, k: int
) -> tuple[np.ndarray, np.ndarray]:
dim = corpus.shape[1]
n_buckets = 1 << bits
rank_positions = np.arange(dim, dtype=np.uint64)
bucket_by_rank = (rank_positions * n_buckets // dim).astype(np.uint8)
centre_by_rank = bucket_by_rank.astype(np.float32) - ((n_buckets - 1) / 2.0)
norm = np.sqrt(np.sum(centre_by_rank * centre_by_rank, dtype=np.float64)).astype(
np.float32
)

def centres(row: np.ndarray) -> np.ndarray:
ranks = rank_transform_reference(row)
buckets = (ranks.astype(np.uint64) * n_buckets // dim).astype(np.uint8)
return buckets.astype(np.float32) - ((n_buckets - 1) / 2.0)

k_eff = min(k, corpus.shape[0])
if k_eff == 0:
return (
np.empty((queries.shape[0], 0), dtype=np.float32),
np.empty((queries.shape[0], 0), dtype=np.int64),
)

doc_centres = np.vstack([centres(row) for row in corpus])
scores = np.empty((queries.shape[0], k_eff), dtype=np.float32)
indices = np.empty((queries.shape[0], k_eff), dtype=np.int64)
doc_ids = np.arange(corpus.shape[0], dtype=np.int64)
scale = np.float32(1.0) / (norm * norm)
for qi, query in enumerate(queries):
q_centres = centres(query)
row_scores = (doc_centres @ q_centres).astype(np.float32) * scale
order = np.lexsort((doc_ids, -row_scores))[:k_eff]
scores[qi] = row_scores[order]
indices[qi] = order
return scores, indices


@pytest.mark.parametrize("bits", [1, 2, 4])
def test_new_reports_dim_and_bits(bits):
# dim must be a multiple of 2^bits; 128 is divisible by 2, 4, 16.
Expand Down Expand Up @@ -68,6 +113,58 @@ def test_search_asymmetric_shape(bits):
assert indices.shape == (3, 10)


@pytest.mark.parametrize("bits", [1, 2, 4])
def test_rankquant_eval_search_matches_rankquant_search(bits):
vectors = unit_vectors(45, 128, seed=31 + bits)
queries = unit_vectors(5, 128, seed=41 + bits)
idx = RankQuant(dim=128, bits=bits)
idx.add(vectors)

packed_scores, packed_ids = idx.search(queries, k=8)
eval_scores, eval_ids = rankquant_eval_search(vectors, queries, bits=bits, k=8)

np.testing.assert_array_equal(eval_ids, packed_ids)
np.testing.assert_allclose(eval_scores, packed_scores, rtol=1e-6, atol=1e-6)

Comment thread
Fieldnote-Echo marked this conversation as resolved.

@pytest.mark.parametrize("bits", [1, 2, 3, 4])
def test_rankquant_eval_search_matches_numpy_reference(bits):
vectors = unit_vectors(36, 128, seed=51 + bits)
queries = unit_vectors(4, 128, seed=61 + bits)

scores, ids = rankquant_eval_search(vectors, queries, bits=bits, k=9)
ref_scores, ref_ids = rankquant_eval_reference(vectors, queries, bits=bits, k=9)

assert scores.shape == (4, 9)
assert ids.shape == (4, 9)
assert scores.dtype == np.float32
assert ids.dtype == np.int64
np.testing.assert_array_equal(ids, ref_ids)
np.testing.assert_allclose(scores, ref_scores, rtol=1e-6, atol=1e-6)


def test_rankquant_eval_search_empty_corpus_shape():
vectors = np.empty((0, 64), dtype=np.float32)
queries = unit_vectors(3, 64, seed=53)

scores, ids = rankquant_eval_search(vectors, queries, bits=3, k=10)

assert scores.shape == (3, 0)
assert ids.shape == (3, 0)


def test_rankquant_eval_search_empty_queries_shape():
vectors = unit_vectors(4, 64, seed=56)
queries = np.empty((0, 64), dtype=np.float32)

scores, ids = rankquant_eval_search(vectors, queries, bits=3, k=10)

assert scores.shape == (0, 4)
assert ids.shape == (0, 4)
assert scores.dtype == np.float32
assert ids.dtype == np.int64


@pytest.mark.parametrize("bits", [2, 4])
def test_self_query_recall_at_1(bits):
# 1-bit is too lossy for a strict per-row self-query at this dim;
Expand All @@ -86,6 +183,12 @@ def test_invalid_bits_rejected():
RankQuant(dim=64, bits=3)
with pytest.raises(ValueError, match="bits"):
RankQuant(dim=64, bits=8)
vectors = unit_vectors(4, 64, seed=54)
queries = unit_vectors(1, 64, seed=55)
with pytest.raises(ValueError, match="bits"):
rankquant_eval_search(vectors, queries, bits=0, k=2)
with pytest.raises(ValueError, match="bits"):
rankquant_eval_search(vectors, queries, bits=8, k=2)


def test_dim_not_multiple_of_two_pow_bits_rejected():
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub mod sign_bitmap;
mod util;

pub use bitmap::Bitmap;
pub use quant::RankQuant;
pub use quant::{rankquant_eval_search, RankQuant};
pub use rank::Rank;
pub use sign_bitmap::SignBitmap;

Expand Down
Loading
Loading