diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CompactWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CompactWriter.java index c0cc684e7..ad76a15b8 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CompactWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CompactWriter.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.channels.FileChannel; import java.nio.file.Path; import java.util.ArrayList; import java.util.Collections; @@ -66,6 +67,7 @@ final class CompactWriter implements AutoCloseable { private final ThreadLocal> zeroPQ; private final boolean fusedPQEnabled; private final Path outputPath; + private volatile FileChannel inlineChannel; private final List configuredLayerInfo; private final List configuredLayerDegrees; private final List level1FeatureRecords; @@ -174,6 +176,16 @@ public Path getOutputPath() { return outputPath; } + /** + * Sets the {@link FileChannel} that {@link #writeInlineNodeRecord} will write base-layer + * records into directly from worker threads. Must be set before the first call to + * {@code writeInlineNodeRecord}; clear by passing {@code null} once the level-0 phase + * is finished. Lifetime of the channel is managed by the caller. + */ + public void setInlineChannel(FileChannel fc) { + this.inlineChannel = fc; + } + public void setEntryNodePqCode(ByteSequence code) { this.entryNodePqCode = code; } @@ -200,7 +212,7 @@ public void close() throws IOException { writer.flush(); } - public WriteResult writeInlineNodeRecord(int ordinal, VectorFloat vec, SelectedVecCache selectedCache, ByteSequence pqCode) throws IOException + public WriteResult writeInlineNodeRecord(int ordinal, VectorFloat vec, VectorFloat encodeScratch, SelectedVecCache selectedCache, ByteSequence pqCode) throws IOException { var bwriter = new ByteBufferIndexWriter(bufferPerThread.get()); @@ -219,7 +231,7 @@ public WriteResult writeInlineNodeRecord(int ordinal, VectorFloat vec, Select int k = 0; for (; k < selectedCache.size; k++) { pqCode.zero(); - pq.encodeTo(selectedCache.vecs[k], pqCode); + pq.encodeTo(selectedCache.vecs[k], encodeScratch, pqCode); vectorTypeSupport.writeByteSequence(bwriter, pqCode); } for (; k < baseDegree; k++) { @@ -245,9 +257,18 @@ public WriteResult writeInlineNodeRecord(int ordinal, VectorFloat vec, Select ordinal, recordSize, bwriter.bytesWritten(), baseDegree)); } - ByteBuffer dataCopy = bwriter.cloneBuffer(); + FileChannel fc = inlineChannel; + if (fc == null) { + throw new IllegalStateException("inline channel not set; call setInlineChannel before writeInlineNodeRecord"); + } + ByteBuffer buf = bufferPerThread.get(); + buf.position(0).limit(recordSize); + long pos = fileOffset; + while (buf.hasRemaining()) { + pos += fc.write(buf, pos); + } - return new WriteResult(ordinal, fileOffset, dataCopy); + return new WriteResult(ordinal, fileOffset); } static final class UpperLayerFeatureRecord { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexCompactor.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexCompactor.java index 66cde90bb..c30eeda8b 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexCompactor.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexCompactor.java @@ -19,8 +19,8 @@ import java.io.FileNotFoundException; import java.io.IOException; import java.io.UncheckedIOException; + import java.nio.file.Path; -import java.nio.ByteBuffer; import java.nio.channels.FileChannel; import java.nio.file.StandardOpenOption; import java.util.*; @@ -419,27 +419,18 @@ private void compactLevels(CompactWriter writer, var wropts = EnumSet.of(StandardOpenOption.WRITE, StandardOpenOption.READ); try (FileChannel fc = FileChannel.open(writer.getOutputPath(), wropts)) { - - runBatchesWithBackpressure( - batches, - ecs, - submitOne, - (results) -> { - try { - for (WriteResult r : results) { - ByteBuffer b = r.data; - long pos = r.fileOffset; - while (b.hasRemaining()) { - int n = fc.write(b, pos); - pos += n; - } - } - } catch (IOException e) { - throw new UncheckedIOException(e); - } - }, - progressListener - ); + writer.setInlineChannel(fc); + try { + runBatchesWithBackpressure( + batches, + ecs, + submitOne, + (results) -> { /* records were written directly by workers */ }, + progressListener + ); + } finally { + writer.setInlineChannel(null); + } } writer.offsetAfterInline(); @@ -705,6 +696,7 @@ private WriteResult processBaseNode( return writer.writeInlineNodeRecord( newOrdinal, scratch.baseVec, + scratch.tmpVec, selected, scratch.pqCode ); @@ -1225,12 +1217,10 @@ private static int partition(int[] order, float[] score, int lo, int hi) { static final class WriteResult { final int newOrdinal; final long fileOffset; - final ByteBuffer data; - WriteResult(int newOrdinal, long fileOffset, ByteBuffer data) { + WriteResult(int newOrdinal, long fileOffset) { this.newOrdinal = newOrdinal; this.fileOffset = fileOffset; - this.data = data; } }; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java index d3cbef30e..7692000fe 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java @@ -435,8 +435,23 @@ public ByteSequence encode(VectorFloat vector) { @Override public void encodeTo(VectorFloat vector, ByteSequence dest) { + encodeTo(vector, null, dest); + } + + /** + * Allocation-free variant of {@link #encodeTo(VectorFloat, ByteSequence)}: when a global + * centroid is configured, the centered vector is written into {@code scratch} (which must + * have at least {@code vector.length()} components) rather than into a freshly allocated + * buffer. Pass {@code null} for {@code scratch} to fall back to the allocating path. + */ + public void encodeTo(VectorFloat vector, VectorFloat scratch, ByteSequence dest) { if (globalCentroid != null) { - vector = sub(vector, globalCentroid); + if (scratch == null) { + vector = sub(vector, globalCentroid); + } else { + VectorUtil.subInto(scratch, vector, globalCentroid); + vector = scratch; + } } if (anisotropicThreshold > UNWEIGHTED)