diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java index 0aee891df..15c32ed1b 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java @@ -69,6 +69,12 @@ public class GraphSearcher implements Closeable { private boolean pruneSearch; private boolean asyncPipelineEnabled; private final ScoreTracker.ScoreTrackerFactory scoreTrackerFactory; + private ScoreTracker scoreTracker; + + // Reusable callbacks for view.processNeighbors so we don't allocate a fresh + // lambda / bound method-reference for every candidate popped from the heap. + private final ImmutableGraphIndex.IntMarker visitedAdder; + private final ImmutableGraphIndex.NeighborProcessor neighborProcessor; private int visitedCount; private int expandedCount; @@ -94,6 +100,12 @@ protected GraphSearcher(ImmutableGraphIndex.View view) { this.pruneSearch = true; this.scoreTrackerFactory = new ScoreTracker.ScoreTrackerFactory(); + this.visitedAdder = visited::add; + this.neighborProcessor = (node2, score) -> { + scoreTracker.track(score); + candidates.push(node2, score); + visitedCount++; + }; } protected int getVisitedCount() { @@ -453,7 +465,7 @@ public void searchOneLayer(SearchScoreProvider scoreProvider, approximateResults.setMaxSize(rerankK); // track scores to predict when we are done with threshold queries - var scoreTracker = scoreTrackerFactory.getScoreTracker(pruneSearch, rerankK, threshold); + this.scoreTracker = scoreTrackerFactory.getScoreTracker(pruneSearch, rerankK, threshold); // the main search loop while (candidates.size() > 0) { @@ -480,12 +492,7 @@ public void searchOneLayer(SearchScoreProvider scoreProvider, // score the neighbors of the top candidate and add them to the queue var scoreFunction = scoreProvider.scoreFunction(); - ImmutableGraphIndex.NeighborProcessor neighborProcessor = (node2, score) -> { - scoreTracker.track(score); - candidates.push(node2, score); - visitedCount++; - }; - view.processNeighbors(level, topCandidateNode, scoreFunction, visited::add, neighborProcessor); + view.processNeighbors(level, topCandidateNode, scoreFunction, visitedAdder, neighborProcessor); } } catch (Throwable t) { // clear scratch structures if terminated via throwable, as they may not have been drained diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedPQ.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedPQ.java index 840650ba5..b0fbec976 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedPQ.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedPQ.java @@ -44,6 +44,7 @@ public class FusedPQ extends AbstractFeature implements FusedFeature { private final ProductQuantization pq; private final int maxDegree; private final ThreadLocal> reusableResults; + private final ThreadLocal> reusableCenteredQuery; private final ExplicitThreadLocal> reusableNeighborCodes; private final ExplicitThreadLocal> pqCodeScratch; @@ -54,6 +55,7 @@ public FusedPQ(int maxDegree, ProductQuantization pq) { this.maxDegree = maxDegree; this.pq = pq; this.reusableResults = ThreadLocal.withInitial(() -> vectorTypeSupport.createFloatVector(maxDegree)); + this.reusableCenteredQuery = ThreadLocal.withInitial(() -> vectorTypeSupport.createFloatVector(pq.getOriginalDimension())); this.reusableNeighborCodes = ExplicitThreadLocal.withInitial(() -> vectorTypeSupport.createByteSequence(pq.compressedVectorSize() * maxDegree)); this.pqCodeScratch = ExplicitThreadLocal.withInitial(() -> vectorTypeSupport.createByteSequence(pq.compressedVectorSize())); } @@ -93,7 +95,7 @@ static FusedPQ load(CommonHeader header, RandomAccessReader reader) { public ScoreFunction.ApproximateScoreFunction approximateScoreFunctionFor(VectorFloat queryVector, VectorSimilarityFunction vsf, OnDiskGraphIndex.View view, ScoreFunction.ExactScoreFunction esf) { var neighbors = new PackedNeighbors(view); var hierarchyCachedFeatures = view.getInlineSourceFeatures(); - return FusedPQDecoder.newDecoder(neighbors, pq, hierarchyCachedFeatures, queryVector, reusableNeighborCodes.get(), reusableResults.get(), vsf, esf); + return FusedPQDecoder.newDecoder(neighbors, pq, hierarchyCachedFeatures, queryVector, reusableNeighborCodes.get(), reusableResults.get(), reusableCenteredQuery.get(), vsf, esf); } @Override diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedPQDecoder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedPQDecoder.java index 04050fb35..2436496fe 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedPQDecoder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedPQDecoder.java @@ -49,7 +49,8 @@ public abstract class FusedPQDecoder implements ScoreFunction.ApproximateScoreFu protected FusedPQDecoder(ProductQuantization pq, Int2ObjectHashMap hierarchyCachedFeatures, VectorFloat query, FusedPQ.PackedNeighbors packedNeighbors, - ByteSequence neighborCodes, VectorFloat results, ExactScoreFunction esf, + ByteSequence neighborCodes, VectorFloat results, + VectorFloat reusableCenteredQuery, ExactScoreFunction esf, VectorSimilarityFunction vsf) { this.pq = pq; this.hierarchyCachedFeatures = hierarchyCachedFeatures; @@ -66,7 +67,13 @@ protected FusedPQDecoder(ProductQuantization pq, partialSums = pq.reusablePartialSums(); if (vsf != VectorSimilarityFunction.COSINE) { VectorFloat center = pq.globalCentroid; - var centeredQuery = center == null ? query : VectorUtil.sub(query, center); + VectorFloat centeredQuery; + if (center == null) { + centeredQuery = query; + } else { + centeredQuery = reusableCenteredQuery; + VectorUtil.subInto(centeredQuery, query, center); + } for (var i = 0; i < pq.getSubspaceCount(); i++) { int offset = pq.subvectorSizesAndOffsets[i][1]; int size = pq.subvectorSizesAndOffsets[i][0]; @@ -124,8 +131,8 @@ static class DotProductDecoder extends FusedPQDecoder { public DotProductDecoder(FusedPQ.PackedNeighbors neighbors, ProductQuantization pq, Int2ObjectHashMap hierarchyCachedFeatures, VectorFloat query, ByteSequence neighborCodes, VectorFloat results, - ExactScoreFunction esf) { - super(pq, hierarchyCachedFeatures, query, neighbors, neighborCodes, results, esf, VectorSimilarityFunction.DOT_PRODUCT); + VectorFloat reusableCenteredQuery, ExactScoreFunction esf) { + super(pq, hierarchyCachedFeatures, query, neighbors, neighborCodes, results, reusableCenteredQuery, esf, VectorSimilarityFunction.DOT_PRODUCT); } @Override @@ -138,8 +145,8 @@ static class EuclideanDecoder extends FusedPQDecoder { public EuclideanDecoder(FusedPQ.PackedNeighbors neighbors, ProductQuantization pq, Int2ObjectHashMap hierarchyCachedFeatures, VectorFloat query, ByteSequence neighborCodes, VectorFloat results, - ExactScoreFunction esf) { - super(pq, hierarchyCachedFeatures, query, neighbors, neighborCodes, results, esf, VectorSimilarityFunction.EUCLIDEAN); + VectorFloat reusableCenteredQuery, ExactScoreFunction esf) { + super(pq, hierarchyCachedFeatures, query, neighbors, neighborCodes, results, reusableCenteredQuery, esf, VectorSimilarityFunction.EUCLIDEAN); } @Override @@ -158,8 +165,8 @@ static class CosineDecoder extends FusedPQDecoder { protected CosineDecoder(FusedPQ.PackedNeighbors neighbors, ProductQuantization pq, Int2ObjectHashMap hierarchyCachedFeatures, VectorFloat query, ByteSequence neighborCodes, VectorFloat results, - ExactScoreFunction esf) { - super(pq, hierarchyCachedFeatures, query, neighbors, neighborCodes, results, esf, VectorSimilarityFunction.COSINE); + VectorFloat reusableCenteredQuery, ExactScoreFunction esf) { + super(pq, hierarchyCachedFeatures, query, neighbors, neighborCodes, results, reusableCenteredQuery, esf, VectorSimilarityFunction.COSINE); // this part is not query-dependent, so we can cache it partialSquaredMagnitudes = pq.partialSquaredMagnitudes().updateAndGet(current -> { @@ -186,7 +193,13 @@ protected CosineDecoder(FusedPQ.PackedNeighbors neighbors, ProductQuantization p // compute partialSums VectorFloat center = pq.globalCentroid; float queryMagSum = 0.0f; - var centeredQuery = center == null ? query : VectorUtil.sub(query, center); + VectorFloat centeredQuery; + if (center == null) { + centeredQuery = query; + } else { + centeredQuery = reusableCenteredQuery; + VectorUtil.subInto(centeredQuery, query, center); + } for (var i = 0; i < pq.getSubspaceCount(); i++) { int offset = pq.subvectorSizesAndOffsets[i][1]; int size = pq.subvectorSizesAndOffsets[i][0]; @@ -229,14 +242,15 @@ protected float distanceToScore(float distance) { public static FusedPQDecoder newDecoder(FusedPQ.PackedNeighbors neighbors, ProductQuantization pq, Int2ObjectHashMap hierarchyCachedFeatures, VectorFloat query, ByteSequence reusableNeighborCodes, VectorFloat results, + VectorFloat reusableCenteredQuery, VectorSimilarityFunction similarityFunction, ExactScoreFunction esf) { switch (similarityFunction) { case DOT_PRODUCT: - return new DotProductDecoder(neighbors, pq, hierarchyCachedFeatures, query, reusableNeighborCodes, results, esf); + return new DotProductDecoder(neighbors, pq, hierarchyCachedFeatures, query, reusableNeighborCodes, results, reusableCenteredQuery, esf); case EUCLIDEAN: - return new EuclideanDecoder(neighbors, pq, hierarchyCachedFeatures, query, reusableNeighborCodes, results, esf); + return new EuclideanDecoder(neighbors, pq, hierarchyCachedFeatures, query, reusableNeighborCodes, results, reusableCenteredQuery, esf); case COSINE: - return new CosineDecoder(neighbors, pq, hierarchyCachedFeatures, query, reusableNeighborCodes, results, esf); + return new CosineDecoder(neighbors, pq, hierarchyCachedFeatures, query, reusableNeighborCodes, results, reusableCenteredQuery, esf); default: throw new IllegalArgumentException("Unsupported similarity function: " + similarityFunction); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java index 83cb5885b..d16f5835a 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java @@ -150,6 +150,15 @@ public static VectorFloat sub(VectorFloat lhs, VectorFloat rhs) { return impl.sub(lhs, rhs); } + /** + * Compute {@code lhs - rhs} into {@code dest} without allocating a new vector. + * {@code dest} must have at least {@code lhs.length()} slots; its existing contents are overwritten. + */ + public static void subInto(VectorFloat dest, VectorFloat lhs, VectorFloat rhs) { + dest.copyFrom(lhs, 0, 0, lhs.length()); + impl.subInPlace(dest, rhs); + } + public static VectorFloat sub(VectorFloat lhs, float value) { return impl.sub(lhs, value); }