Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ public class GraphSearcher implements Closeable {
private boolean pruneSearch;
private boolean asyncPipelineEnabled;
private final ScoreTracker.ScoreTrackerFactory scoreTrackerFactory;
private ScoreTracker scoreTracker;

// Reusable callbacks for view.processNeighbors so we don't allocate a fresh
// lambda / bound method-reference for every candidate popped from the heap.
private final ImmutableGraphIndex.IntMarker visitedAdder;
private final ImmutableGraphIndex.NeighborProcessor neighborProcessor;

private int visitedCount;
private int expandedCount;
Expand All @@ -94,6 +100,12 @@ protected GraphSearcher(ImmutableGraphIndex.View view) {

this.pruneSearch = true;
this.scoreTrackerFactory = new ScoreTracker.ScoreTrackerFactory();
this.visitedAdder = visited::add;
this.neighborProcessor = (node2, score) -> {
scoreTracker.track(score);
candidates.push(node2, score);
visitedCount++;
};
}

protected int getVisitedCount() {
Expand Down Expand Up @@ -453,7 +465,7 @@ public void searchOneLayer(SearchScoreProvider scoreProvider,
approximateResults.setMaxSize(rerankK);

// track scores to predict when we are done with threshold queries
var scoreTracker = scoreTrackerFactory.getScoreTracker(pruneSearch, rerankK, threshold);
this.scoreTracker = scoreTrackerFactory.getScoreTracker(pruneSearch, rerankK, threshold);

// the main search loop
while (candidates.size() > 0) {
Expand All @@ -480,12 +492,7 @@ public void searchOneLayer(SearchScoreProvider scoreProvider,

// score the neighbors of the top candidate and add them to the queue
var scoreFunction = scoreProvider.scoreFunction();
ImmutableGraphIndex.NeighborProcessor neighborProcessor = (node2, score) -> {
scoreTracker.track(score);
candidates.push(node2, score);
visitedCount++;
};
view.processNeighbors(level, topCandidateNode, scoreFunction, visited::add, neighborProcessor);
view.processNeighbors(level, topCandidateNode, scoreFunction, visitedAdder, neighborProcessor);
}
} catch (Throwable t) {
// clear scratch structures if terminated via throwable, as they may not have been drained
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public class FusedPQ extends AbstractFeature implements FusedFeature {
private final ProductQuantization pq;
private final int maxDegree;
private final ThreadLocal<VectorFloat<?>> reusableResults;
private final ThreadLocal<VectorFloat<?>> reusableCenteredQuery;
private final ExplicitThreadLocal<ByteSequence<?>> reusableNeighborCodes;
private final ExplicitThreadLocal<ByteSequence<?>> pqCodeScratch;

Expand All @@ -54,6 +55,7 @@ public FusedPQ(int maxDegree, ProductQuantization pq) {
this.maxDegree = maxDegree;
this.pq = pq;
this.reusableResults = ThreadLocal.withInitial(() -> vectorTypeSupport.createFloatVector(maxDegree));
this.reusableCenteredQuery = ThreadLocal.withInitial(() -> vectorTypeSupport.createFloatVector(pq.getOriginalDimension()));
this.reusableNeighborCodes = ExplicitThreadLocal.withInitial(() -> vectorTypeSupport.createByteSequence(pq.compressedVectorSize() * maxDegree));
this.pqCodeScratch = ExplicitThreadLocal.withInitial(() -> vectorTypeSupport.createByteSequence(pq.compressedVectorSize()));
}
Expand Down Expand Up @@ -93,7 +95,7 @@ static FusedPQ load(CommonHeader header, RandomAccessReader reader) {
// Builds the approximate score function for queryVector by delegating to
// FusedPQDecoder.newDecoder with this feature's ProductQuantization.
//
// queryVector - the query to score against
// vsf         - similarity function; newDecoder rejects unsupported values
// view        - on-disk graph view used for the packed neighbors and the inline source features
// esf         - exact score function handed through to the decoder
//
// The neighbor-code, result and centered-query buffers are obtained from this
// feature's reusable holders instead of being created here on every call.
// NOTE(review): the same centered-query buffer is handed to every decoder built
// through the same holder; confirm a decoder never needs it after construction.
public ScoreFunction.ApproximateScoreFunction approximateScoreFunctionFor(VectorFloat<?> queryVector, VectorSimilarityFunction vsf, OnDiskGraphIndex.View view, ScoreFunction.ExactScoreFunction esf) {
    var neighbors = new PackedNeighbors(view);
    var hierarchyCachedFeatures = view.getInlineSourceFeatures();
    // Single return: the decoder factory takes the centered-query scratch vector
    // between the results buffer and the similarity function.
    return FusedPQDecoder.newDecoder(neighbors, pq, hierarchyCachedFeatures, queryVector,
            reusableNeighborCodes.get(), reusableResults.get(), reusableCenteredQuery.get(), vsf, esf);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ public abstract class FusedPQDecoder implements ScoreFunction.ApproximateScoreFu
protected FusedPQDecoder(ProductQuantization pq,
Int2ObjectHashMap<FusedFeature.InlineSource> hierarchyCachedFeatures,
VectorFloat<?> query, FusedPQ.PackedNeighbors packedNeighbors,
ByteSequence<?> neighborCodes, VectorFloat<?> results, ExactScoreFunction esf,
ByteSequence<?> neighborCodes, VectorFloat<?> results,
VectorFloat<?> reusableCenteredQuery, ExactScoreFunction esf,
VectorSimilarityFunction vsf) {
this.pq = pq;
this.hierarchyCachedFeatures = hierarchyCachedFeatures;
Expand All @@ -66,7 +67,13 @@ protected FusedPQDecoder(ProductQuantization pq,
partialSums = pq.reusablePartialSums();
if (vsf != VectorSimilarityFunction.COSINE) {
VectorFloat<?> center = pq.globalCentroid;
var centeredQuery = center == null ? query : VectorUtil.sub(query, center);
VectorFloat<?> centeredQuery;
if (center == null) {
centeredQuery = query;
} else {
centeredQuery = reusableCenteredQuery;
VectorUtil.subInto(centeredQuery, query, center);
}
for (var i = 0; i < pq.getSubspaceCount(); i++) {
int offset = pq.subvectorSizesAndOffsets[i][1];
int size = pq.subvectorSizesAndOffsets[i][0];
Expand Down Expand Up @@ -124,8 +131,8 @@ static class DotProductDecoder extends FusedPQDecoder {
/**
 * Creates a decoder bound to {@link VectorSimilarityFunction#DOT_PRODUCT}.
 * Every argument is forwarded unchanged to the {@code FusedPQDecoder} constructor.
 *
 * @param neighbors               packed neighbor source passed to the base decoder
 * @param pq                      product quantization used by the base decoder
 * @param hierarchyCachedFeatures inline source features passed to the base decoder
 * @param query                   the query vector
 * @param neighborCodes           reusable buffer for neighbor codes
 * @param results                 reusable buffer for results
 * @param reusableCenteredQuery   scratch vector; the base constructor writes
 *                                {@code query - pq.globalCentroid} into it when a global centroid is present
 * @param esf                     exact score function passed to the base decoder
 */
public DotProductDecoder(FusedPQ.PackedNeighbors neighbors, ProductQuantization pq,
                         Int2ObjectHashMap<FusedFeature.InlineSource> hierarchyCachedFeatures,
                         VectorFloat<?> query, ByteSequence<?> neighborCodes, VectorFloat<?> results,
                         VectorFloat<?> reusableCenteredQuery, ExactScoreFunction esf) {
    super(pq, hierarchyCachedFeatures, query, neighbors, neighborCodes, results, reusableCenteredQuery, esf, VectorSimilarityFunction.DOT_PRODUCT);
}

@Override
Expand All @@ -138,8 +145,8 @@ static class EuclideanDecoder extends FusedPQDecoder {
/**
 * Creates a decoder bound to {@link VectorSimilarityFunction#EUCLIDEAN}.
 * Every argument is forwarded unchanged to the {@code FusedPQDecoder} constructor.
 *
 * @param neighbors               packed neighbor source passed to the base decoder
 * @param pq                      product quantization used by the base decoder
 * @param hierarchyCachedFeatures inline source features passed to the base decoder
 * @param query                   the query vector
 * @param neighborCodes           reusable buffer for neighbor codes
 * @param results                 reusable buffer for results
 * @param reusableCenteredQuery   scratch vector; the base constructor writes
 *                                {@code query - pq.globalCentroid} into it when a global centroid is present
 * @param esf                     exact score function passed to the base decoder
 */
public EuclideanDecoder(FusedPQ.PackedNeighbors neighbors, ProductQuantization pq,
                        Int2ObjectHashMap<FusedFeature.InlineSource> hierarchyCachedFeatures,
                        VectorFloat<?> query, ByteSequence<?> neighborCodes, VectorFloat<?> results,
                        VectorFloat<?> reusableCenteredQuery, ExactScoreFunction esf) {
    super(pq, hierarchyCachedFeatures, query, neighbors, neighborCodes, results, reusableCenteredQuery, esf, VectorSimilarityFunction.EUCLIDEAN);
}

@Override
Expand All @@ -158,8 +165,8 @@ static class CosineDecoder extends FusedPQDecoder {
protected CosineDecoder(FusedPQ.PackedNeighbors neighbors, ProductQuantization pq,
Int2ObjectHashMap<FusedFeature.InlineSource> hierarchyCachedFeatures,
VectorFloat<?> query, ByteSequence<?> neighborCodes, VectorFloat<?> results,
ExactScoreFunction esf) {
super(pq, hierarchyCachedFeatures, query, neighbors, neighborCodes, results, esf, VectorSimilarityFunction.COSINE);
VectorFloat<?> reusableCenteredQuery, ExactScoreFunction esf) {
super(pq, hierarchyCachedFeatures, query, neighbors, neighborCodes, results, reusableCenteredQuery, esf, VectorSimilarityFunction.COSINE);

// this part is not query-dependent, so we can cache it
partialSquaredMagnitudes = pq.partialSquaredMagnitudes().updateAndGet(current -> {
Expand All @@ -186,7 +193,13 @@ protected CosineDecoder(FusedPQ.PackedNeighbors neighbors, ProductQuantization p
// compute partialSums
VectorFloat<?> center = pq.globalCentroid;
float queryMagSum = 0.0f;
var centeredQuery = center == null ? query : VectorUtil.sub(query, center);
VectorFloat<?> centeredQuery;
if (center == null) {
centeredQuery = query;
} else {
centeredQuery = reusableCenteredQuery;
VectorUtil.subInto(centeredQuery, query, center);
}
for (var i = 0; i < pq.getSubspaceCount(); i++) {
int offset = pq.subvectorSizesAndOffsets[i][1];
int size = pq.subvectorSizesAndOffsets[i][0];
Expand Down Expand Up @@ -229,14 +242,15 @@ protected float distanceToScore(float distance) {
public static FusedPQDecoder newDecoder(FusedPQ.PackedNeighbors neighbors, ProductQuantization pq,
Int2ObjectHashMap<FusedFeature.InlineSource> hierarchyCachedFeatures, VectorFloat<?> query,
ByteSequence<?> reusableNeighborCodes, VectorFloat<?> results,
VectorFloat<?> reusableCenteredQuery,
VectorSimilarityFunction similarityFunction, ExactScoreFunction esf) {
switch (similarityFunction) {
case DOT_PRODUCT:
return new DotProductDecoder(neighbors, pq, hierarchyCachedFeatures, query, reusableNeighborCodes, results, esf);
return new DotProductDecoder(neighbors, pq, hierarchyCachedFeatures, query, reusableNeighborCodes, results, reusableCenteredQuery, esf);
case EUCLIDEAN:
return new EuclideanDecoder(neighbors, pq, hierarchyCachedFeatures, query, reusableNeighborCodes, results, esf);
return new EuclideanDecoder(neighbors, pq, hierarchyCachedFeatures, query, reusableNeighborCodes, results, reusableCenteredQuery, esf);
case COSINE:
return new CosineDecoder(neighbors, pq, hierarchyCachedFeatures, query, reusableNeighborCodes, results, esf);
return new CosineDecoder(neighbors, pq, hierarchyCachedFeatures, query, reusableNeighborCodes, results, reusableCenteredQuery, esf);
default:
throw new IllegalArgumentException("Unsupported similarity function: " + similarityFunction);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,15 @@ public static VectorFloat<?> sub(VectorFloat<?> lhs, VectorFloat<?> rhs) {
return impl.sub(lhs, rhs);
}

/**
 * Writes {@code lhs - rhs} into {@code dest} so that no result vector has to be allocated.
 * <p>
 * The first {@code lhs.length()} slots of {@code dest} are overwritten with {@code lhs},
 * after which {@code rhs} is subtracted from {@code dest} in place.
 * <p>
 * Because {@code dest} is filled before {@code rhs} is read, {@code dest} must not be
 * the same vector as {@code rhs}.
 * NOTE(review): the in-place subtraction is applied to {@code dest} as a whole; confirm the
 * behavior when {@code dest} is longer than {@code rhs}.
 *
 * @param dest destination vector with at least {@code lhs.length()} slots
 * @param lhs  the minuend; not modified unless it is also {@code dest}
 * @param rhs  the subtrahend
 */
public static void subInto(VectorFloat<?> dest, VectorFloat<?> lhs, VectorFloat<?> rhs) {
    final int count = lhs.length();
    // Step 1: dest := lhs
    dest.copyFrom(lhs, 0, 0, count);
    // Step 2: dest := dest - rhs
    impl.subInPlace(dest, rhs);
}

/**
 * Subtracts a scalar from a vector by delegating to the active vector implementation.
 *
 * @param lhs   the vector operand
 * @param value the scalar operand
 * @return whatever {@code impl.sub(lhs, value)} produces (presumably a vector holding
 *         {@code lhs[i] - value} for each element; confirm against the implementation)
 */
public static VectorFloat<?> sub(VectorFloat<?> lhs, float value) {
    final VectorFloat<?> difference = impl.sub(lhs, value);
    return difference;
}
Expand Down
Loading