diff --git a/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddIndexExec.scala b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddIndexExec.scala
index 266772fdb..735efacb3 100755
--- a/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddIndexExec.scala
+++ b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddIndexExec.scala
@@ -28,8 +28,9 @@ import org.apache.spark.sql.util.LanceArrowUtils
 import org.apache.spark.sql.util.LanceSerializeUtil.{decode, encode}
 import org.apache.spark.unsafe.types.UTF8String
 import org.lance.{CommitBuilder, Dataset, Transaction}
-import org.lance.index.{Index, IndexOptions, IndexParams, IndexType}
+import org.lance.index.{DistanceType, Index, IndexOptions, IndexParams, IndexType}
 import org.lance.index.scalar.{BTreeIndexParams, ScalarIndexParams}
+import org.lance.index.vector.{IvfBuildParams, PQBuildParams, SQBuildParams, VectorIndexParams, VectorTrainer}
 import org.lance.operation.{CreateIndex => AddIndexOperation}
 import org.lance.spark.{BaseLanceNamespaceSparkCatalog, LanceDataset, LanceRuntime, LanceSparkReadOptions}
 import org.lance.spark.arrow.LanceArrowWriter
@@ -217,6 +218,34 @@ case class AddIndexExec(
         UTF8String.fromString(indexName))))
     }
 
+    // IVF_* vector indexes share the logical segment commit path with zonemap, but need one
+    // extra step first: the IVF centroids (and, for IVF_PQ, the PQ codebook) are trained once
+    // on the driver before any per-fragment segment is built. lance-core's distributed build
+    // rejects per-fragment-trained centroids, and shared artifacts keep all segments in the
+    // same query-time compatibility group.
+    if (IndexUtils.isVectorIndex(indexType)) {
+      if (canonicalColumns.size != 1) {
+        throw new IllegalArgumentException(
+          s"Vector index supports a single column only, got: $canonicalColumns")
+      }
+      val vectorSpec = VectorIndexSpec.fromArgs(indexType, args)
+      val execForJob = this.copy(columns = canonicalColumns)
+      val builtSegments = new VectorIndexJob(
+        execForJob,
+        readOptions,
+        indexType,
+        vectorSpec,
+        fragmentIds,
+        nsImpl,
+        nsProps,
+        tableId,
+        initialStorageOpts).run()
+      commitIndexSegments(readOptions, canonicalColumns.head, builtSegments)
+      return Seq(new GenericInternalRow(Array[Any](
+        fragmentIds.size.toLong,
+        UTF8String.fromString(indexName))))
+    }
+
     // FTS/INVERTED still uses a caller-assigned UUID: each fragment writes
     // partial metadata under that UUID and the driver merges them into a single
     // index root before committing the resulting index transaction.
@@ -959,6 +988,273 @@ case class ZonemapIndexTask(
   }
 }
 
+/**
+ * A job implementation for creating IVF-family vector indexes (IVF_FLAT, IVF_PQ, IVF_SQ)
+ * via the logical segment commit path.
+ *
+ * IVF centroids and (for IVF_PQ) the PQ codebook are trained once on the driver and shipped
+ * to executors so every per-fragment segment shares the same artifacts. lance-core's
+ * distributed build path requires precomputed centroids, and sharing them ensures all
+ * segments land in the same query-time compatibility group.
+ *
+ * Each Spark task calls [[org.lance.Dataset#createIndex]] with `withFragmentIds(List(fid))`
+ * and vector index params and returns an uncommitted index segment. The driver collects
+ * the segments and publishes them atomically via [[org.lance.Dataset#commitExistingIndexSegments]].
+ */
+class VectorIndexJob(
+    addIndexExec: AddIndexExec,
+    readOptions: LanceSparkReadOptions,
+    indexType: IndexType,
+    spec: VectorIndexSpec,
+    fragmentIds: List[Integer],
+    nsImpl: Option[String],
+    nsProps: Option[Map[String, String]],
+    tableId: Option[List[String]],
+    initialStorageOpts: Option[Map[String, String]])
+  extends Logging {
+  /** Trains the shared artifacts once, then builds one uncommitted segment per fragment. */
+  def run(): Seq[Index] = {
+    val column = addIndexExec.columns.head
+    val (centroids, codebook) = trainArtifactsOnDriver(column)
+    val trainedSpec = spec.copy(centroids = centroids, codebook = codebook)
+
+    val encodedReadOptions = encode(readOptions)
+    val tasks = fragmentIds.map { fid =>
+      VectorIndexTask(
+        encodedReadOptions,
+        addIndexExec.columns.toList,
+        indexType.name(),
+        trainedSpec,
+        addIndexExec.indexName,
+        fid,
+        nsImpl,
+        nsProps,
+        tableId,
+        initialStorageOpts)
+    }
+
+    try {
+      addIndexExec.session.sparkContext
+        .parallelize(tasks, math.max(tasks.size, 1)) // numSlices must be >= 1, even with no tasks
+        .map(t => t.execute())
+        .collect()
+        .map(decode[Index])
+        .toSeq
+    } catch {
+      case scala.util.control.NonFatal(e) => // let interrupts and fatal errors propagate as-is
+        throw new RuntimeException(
+          "Vector index segment build failed. Uncommitted segments are not " +
+            "visible to readers and will not affect query correctness.",
+          e)
+    }
+  }
+
+  /** Returns (IVF centroids, PQ codebook); the codebook is null unless building IVF_PQ. */
+  private def trainArtifactsOnDriver(column: String): (Array[Float], Array[Float]) = {
+    val dataset = Utils.openDatasetBuilder(readOptions)
+      .initialStorageOptions(initialStorageOpts.map(_.asJava).orNull)
+      .runtimeNamespace(
+        nsImpl.orNull,
+        nsProps.map(_.asJava).orNull,
+        tableId.map(_.asJava).orNull)
+      .build()
+    try {
+      val ivfBuilder = new IvfBuildParams.Builder().setNumPartitions(spec.numPartitions)
+      spec.sampleRate.foreach(ivfBuilder.setSampleRate)
+      spec.maxIters.foreach(ivfBuilder.setMaxIters)
+      val centroids = VectorTrainer.trainIvfCentroids(
+        dataset,
+        column,
+        ivfBuilder.build(),
+        spec.distanceType)
+
+      val codebook: Array[Float] = indexType match {
+        case IndexType.IVF_PQ =>
+          val pqBuilder = new PQBuildParams.Builder()
+          spec.numSubVectors.foreach(pqBuilder.setNumSubVectors)
+          spec.pqNumBits.foreach(pqBuilder.setNumBits)
+          spec.pqMaxIters.foreach(pqBuilder.setMaxIters)
+          spec.sampleRate.foreach(pqBuilder.setSampleRate)
+          VectorTrainer.trainPqCodebook(
+            dataset,
+            column,
+            pqBuilder.build(),
+            spec.distanceType)
+        case _ => null
+      }
+
+      (centroids, codebook)
+    } finally {
+      dataset.close()
+    }
+  }
+}
+
+/**
+ * Executor-side task that builds one uncommitted vector index segment for one fragment.
+ * Returns a Kryo-encoded [[org.lance.index.Index]] for the driver to commit.
+ */
+case class VectorIndexTask(
+    encodedReadOptions: String,
+    columns: List[String],
+    indexTypeName: String,
+    spec: VectorIndexSpec,
+    indexName: String,
+    fragmentId: Int,
+    namespaceImpl: Option[String],
+    namespaceProperties: Option[Map[String, String]],
+    tableId: Option[List[String]],
+    initialStorageOptions: Option[Map[String, String]]) extends Serializable {
+  /** Opens the dataset, builds this fragment's segment, and returns the encoded [[Index]]. */
+  def execute(): String = {
+    val readOptions = decode[LanceSparkReadOptions](encodedReadOptions)
+    val indexType = IndexType.valueOf(indexTypeName)
+    val vectorParams = spec.toVectorIndexParams(indexType)
+    val params = IndexParams.builder().setVectorIndexParams(vectorParams).build()
+
+    val indexOptions = IndexOptions
+      .builder(java.util.Arrays.asList(columns: _*), indexType, params)
+      .replace(false)
+      .withFragmentIds(Collections.singletonList(java.lang.Integer.valueOf(fragmentId)))
+      .build()
+
+    val dataset = Utils.openDatasetBuilder(readOptions)
+      .initialStorageOptions(initialStorageOptions.map(_.asJava).orNull)
+      .runtimeNamespace(
+        namespaceImpl.orNull,
+        namespaceProperties.map(_.asJava).orNull,
+        tableId.map(_.asJava).orNull)
+      .build()
+
+    try {
+      encode(dataset.createIndex(indexOptions))
+    } finally {
+      dataset.close()
+    }
+  }
+}
+
+/**
+ * Serializable carrier for IVF-family build parameters. Parsed from user WITH-args on the
+ * driver, populated with driver-trained artifacts (centroids, codebook), and shipped to
+ * executors where [[toVectorIndexParams]] rebuilds the native [[VectorIndexParams]] (whose
+ * nested builders are not Serializable).
+ */
+case class VectorIndexSpec(
+    metricType: String,
+    numPartitions: Int,
+    sampleRate: Option[Int],
+    maxIters: Option[Int],
+    // IVF_PQ
+    numSubVectors: Option[Int],
+    pqNumBits: Option[Int],
+    pqMaxIters: Option[Int],
+    // IVF_SQ
+    sqNumBits: Option[Short],
+    // Driver-trained artifacts populated by VectorIndexJob before task dispatch; null until then.
+    centroids: Array[Float] = null,
+    codebook: Array[Float] = null) extends Serializable {
+  /** The [[DistanceType]] parsed from `metricType`; throws if the name is unsupported. */
+  def distanceType: DistanceType = VectorIndexSpec.parseDistanceType(metricType)
+
+  /** Rebuilds the native params for `indexType`, attaching trained artifacts when present. */
+  def toVectorIndexParams(indexType: IndexType): VectorIndexParams = {
+    val ivfBuilder = new IvfBuildParams.Builder().setNumPartitions(numPartitions)
+    sampleRate.foreach(ivfBuilder.setSampleRate)
+    maxIters.foreach(ivfBuilder.setMaxIters)
+    if (centroids != null) ivfBuilder.setCentroids(centroids)
+
+    val builder = new VectorIndexParams.Builder(ivfBuilder.build()).setDistanceType(distanceType)
+
+    indexType match {
+      case IndexType.IVF_FLAT =>
+      case IndexType.IVF_PQ =>
+        val pqBuilder = new PQBuildParams.Builder()
+        numSubVectors.foreach(pqBuilder.setNumSubVectors)
+        pqNumBits.foreach(pqBuilder.setNumBits)
+        pqMaxIters.foreach(pqBuilder.setMaxIters)
+        sampleRate.foreach(pqBuilder.setSampleRate)
+        if (codebook != null) pqBuilder.setCodebook(codebook)
+        builder.setPqParams(pqBuilder.build())
+      case IndexType.IVF_SQ =>
+        val sqBuilder = new SQBuildParams.Builder()
+        sqNumBits.foreach(b => sqBuilder.setNumBits(b))
+        sampleRate.foreach(sqBuilder.setSampleRate)
+        builder.setSqParams(sqBuilder.build())
+      case other =>
+        throw new IllegalArgumentException(s"Unsupported vector index type: $other")
+    }
+
+    builder.build()
+  }
+}
+
+object VectorIndexSpec {
+  /** Parses the WITH-clause args into a spec, rejecting invalid values before any training. */
+  def fromArgs(indexType: IndexType, args: Seq[LanceNamedArgument]): VectorIndexSpec = {
+    def argInt(name: String): Option[Int] = args.find(_.name == name).map(_.value).map {
+      case i: java.lang.Integer => i.intValue()
+      case l: java.lang.Long if l.longValue().isValidInt => l.intValue()
+      case other =>
+        throw new IllegalArgumentException(
+          s"Vector index arg '$name' must be a 32-bit integer, got: $other")
+    }
+    def argString(name: String): Option[String] =
+      args.find(_.name == name).map(a => String.valueOf(a.value))
+
+    // use_residual is a no-op upstream as of lance 7 — the JNI bridge does not pass it to
+    // Rust and the Rust IvfBuildParams no longer carries the field. Reject the flag rather
+    // than silently ignore it so users don't think it took effect.
+    if (args.exists(_.name == "use_residual")) {
+      throw new IllegalArgumentException(
+        "WITH-arg 'use_residual' is not supported: lance-core's training path does not " +
+          "honor it. Remove the option to proceed.")
+    }
+
+    val numPartitions = argInt("num_partitions").getOrElse(
+      throw new IllegalArgumentException(
+        "Vector index requires 'num_partitions' in WITH clause"))
+    require(numPartitions > 0, s"'num_partitions' must be positive, got: $numPartitions")
+
+    val metric = argString("metric_type").getOrElse("l2")
+    // Validate eagerly so a typo fails on the driver, not deep inside training.
+    parseDistanceType(metric)
+
+    val sqBits: Option[Short] = indexType match {
+      case IndexType.IVF_SQ =>
+        // sqNumBits is a Short; refuse values that toShort would silently wrap.
+        val bits = argInt("num_bits").getOrElse(8)
+        require(bits.isValidShort, s"'num_bits' is out of range for IVF_SQ, got: $bits")
+        Some(bits.toShort)
+      case _ => None
+    }
+    val pqBits: Option[Int] = indexType match {
+      case IndexType.IVF_PQ => argInt("num_bits").orElse(Some(8))
+      case _ => None
+    }
+
+    VectorIndexSpec(
+      metricType = metric,
+      numPartitions = numPartitions,
+      sampleRate = argInt("sample_rate"),
+      maxIters = argInt("max_iters"),
+      numSubVectors = argInt("num_sub_vectors"),
+      pqNumBits = pqBits,
+      pqMaxIters = argInt("pq_max_iters"),
+      sqNumBits = sqBits)
+  }
+
+  /** Maps a metric name (case-insensitive) to a [[DistanceType]]; throws on unknown names. */
+  def parseDistanceType(s: String): DistanceType = s.toLowerCase(java.util.Locale.ROOT) match {
+    case "l2" | "euclidean" => DistanceType.L2
+    case "cosine" => DistanceType.Cosine
+    case "dot" | "inner_product" | "ip" => DistanceType.Dot
+    case "hamming" => DistanceType.Hamming
+    case other => throw new IllegalArgumentException(
+        s"Unsupported metric_type '$other'; expected one of: l2, cosine, dot, hamming")
+  }
+}
+
 /**
  * Utility methods for working with index types.
*/ @@ -993,6 +1289,9 @@ object IndexUtils { case "btree" => IndexType.BTREE case "zonemap" => IndexType.ZONEMAP case "fts" => IndexType.INVERTED + case "ivf_flat" => IndexType.IVF_FLAT + case "ivf_pq" => IndexType.IVF_PQ + case "ivf_sq" => IndexType.IVF_SQ case other => throw new UnsupportedOperationException(s"Unsupported index method: $other") } } @@ -1006,6 +1305,13 @@ object IndexUtils { } } + private val VectorIndexTypes: Set[IndexType] = Set( + IndexType.IVF_FLAT, + IndexType.IVF_PQ, + IndexType.IVF_SQ) + + def isVectorIndex(indexType: IndexType): Boolean = VectorIndexTypes.contains(indexType) + def btreeBuildMode(indexType: IndexType, args: Seq[LanceNamedArgument]): Option[String] = { if (indexType != IndexType.BTREE) { None diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/update/BaseAddIndexTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/update/BaseAddIndexTest.java index 6cac53e07..11ca52d6f 100755 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/update/BaseAddIndexTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/update/BaseAddIndexTest.java @@ -36,6 +36,7 @@ import java.nio.file.Path; import java.util.Collections; import java.util.List; +import java.util.Random; import java.util.Set; import java.util.UUID; import java.util.stream.Collectors; @@ -1160,6 +1161,31 @@ private Index checkIndex(String indexName) { } } + /** + * Asserts a vector index has the requested IVF subtype. {@link Index#indexType()} reports the + * umbrella {@code VECTOR} for all IVF variants; the concrete {@code IVF_FLAT}/{@code + * IVF_PQ}/{@code IVF_SQ} subtype is only exposed via {@link org.lance.Dataset#describeIndices()}. + * Asserting the subtype catches regressions where the connector silently builds the wrong variant + * for the requested USING method. 
+   */
+  private void assertVectorSubtype(String indexName, String expectedSubtype) {
+    org.lance.Dataset lanceDataset = org.lance.Dataset.open().uri(tableDir).build();
+    try {
+      IndexDescription desc =
+          lanceDataset.describeIndices().stream()
+              .filter(d -> indexName.equals(d.getName()))
+              .findFirst()
+              .orElseThrow(
+                  () -> new AssertionError("Index description for '" + indexName + "' not found"));
+      Assertions.assertEquals(
+          expectedSubtype,
+          desc.getIndexType().toUpperCase(java.util.Locale.ROOT), // locale-independent casing
+          "Vector index '" + indexName + "' subtype mismatch");
+    } finally {
+      lanceDataset.close();
+    }
+  }
+
   private void checkFtsIndex(String indexName) {
     Index index = checkIndex(indexName);
     Assertions.assertEquals(IndexType.INVERTED, index.indexType());
@@ -1172,4 +1198,362 @@ private void checkFtsIndex(String indexName) {
   private int fieldId(org.lance.Dataset dataset, String path) {
     return FieldPathUtils.resolveLeafField(dataset.getLanceSchema(), path).getId();
   }
+
+  private static final int VECTOR_DIM = 16;
+  // PQ training with num_bits=8 needs >= 256 rows (2^8 centroids per subspace).
+  private static final int VECTOR_ROWS_PER_FRAGMENT = 256;
+
+  /**
+   * Build a 2-fragment vector dataset with an `embedding` fixed-size-list column. Vectors are
+   * deterministic per seed so tests are reproducible across runs.
+   */
+  private void prepareVectorDataset() {
+    spark.sql(
+        String.format(
+            "create table %s (id int, embedding array<float>) using lance "
+                + "TBLPROPERTIES ('embedding.arrow.fixed-size-list.size' = '%d')",
+            fullTable, VECTOR_DIM));
+    insertVectors(0, VECTOR_ROWS_PER_FRAGMENT, 42L);
+    insertVectors(VECTOR_ROWS_PER_FRAGMENT, 2 * VECTOR_ROWS_PER_FRAGMENT, 43L);
+  }
+
+  private void insertVectors(int fromId, int toId, long seed) {
+    Random rng = new Random(seed);
+    String values =
+        IntStream.range(fromId, toId)
+            .mapToObj(
+                i -> {
+                  StringBuilder sb = new StringBuilder("array(");
+                  for (int j = 0; j < VECTOR_DIM; j++) {
+                    if (j > 0) sb.append(", ");
+                    sb.append(rng.nextFloat());
+                  }
+                  sb.append(")");
+                  return String.format("(%d, %s)", i, sb);
+                })
+            .collect(Collectors.joining(","));
+    spark.sql(String.format("insert into %s (id, embedding) values %s", fullTable, values));
+  }
+
+  @Test
+  public void testCreateIvfFlatIndex() {
+    prepareVectorDataset();
+
+    Dataset<Row> result =
+        spark.sql(
+            String.format(
+                "alter table %s create index test_ivf_flat using ivf_flat (embedding) "
+                    + "with (num_partitions=4, metric_type='l2')",
+                fullTable));
+
+    Assertions.assertEquals(
+        "StructType(StructField(fragments_indexed,LongType,true),StructField(index_name,StringType,true))",
+        result.schema().toString());
+
+    Row row = result.collectAsList().get(0);
+    long fragmentsIndexed = row.getLong(0);
+    String indexName = row.getString(1);
+
+    Assertions.assertTrue(fragmentsIndexed >= 2, "Expected at least 2 fragments indexed");
+    Assertions.assertEquals("test_ivf_flat", indexName);
+
+    // Manifest Index#indexType() reports the umbrella VECTOR. The concrete subtype
+    // (IVF_FLAT/IVF_PQ/IVF_SQ) is exposed via describeIndices.getIndexType() — assert
+    // both, so a regression that, say, builds IVF_FLAT when the user asked for IVF_PQ
+    // would fail here.
+    checkIndex("test_ivf_flat");
+    assertVectorSubtype("test_ivf_flat", "IVF_FLAT");
+  }
+
+  @Test
+  public void testCreateIvfPqIndex() {
+    prepareVectorDataset();
+
+    Dataset<Row> result =
+        spark.sql(
+            String.format(
+                "alter table %s create index test_ivf_pq using ivf_pq (embedding) "
+                    + "with (num_partitions=4, num_sub_vectors=4, num_bits=8, metric_type='l2')",
+                fullTable));
+
+    Row row = result.collectAsList().get(0);
+    long fragmentsIndexed = row.getLong(0);
+    Assertions.assertTrue(fragmentsIndexed >= 2);
+    Assertions.assertEquals("test_ivf_pq", row.getString(1));
+
+    checkIndex("test_ivf_pq");
+    assertVectorSubtype("test_ivf_pq", "IVF_PQ");
+  }
+
+  @Test
+  public void testCreateIvfSqIndex() {
+    prepareVectorDataset();
+
+    // metric_type='cosine' specifically: lance 6 silently built L2 centroids here, recall
+    // would have degraded. Lance 7 honors DistanceType end-to-end. This test guards that.
+    Dataset<Row> result =
+        spark.sql(
+            String.format(
+                "alter table %s create index test_ivf_sq using ivf_sq (embedding) "
+                    + "with (num_partitions=4, num_bits=8, metric_type='cosine')",
+                fullTable));
+
+    Row row = result.collectAsList().get(0);
+    long fragmentsIndexed = row.getLong(0);
+    Assertions.assertTrue(fragmentsIndexed >= 2);
+    Assertions.assertEquals("test_ivf_sq", row.getString(1));
+
+    checkIndex("test_ivf_sq");
+    assertVectorSubtype("test_ivf_sq", "IVF_SQ");
+  }
+
+  @Test
+  public void testRecreateIvfPqIndexReplacesOld() {
+    prepareVectorDataset();
+
+    String createSql =
+        String.format(
+            "alter table %s create index test_ivf_pq_repeat using ivf_pq (embedding) "
+                + "with (num_partitions=4, num_sub_vectors=4, num_bits=8, metric_type='l2')",
+            fullTable);
+    spark.sql(createSql);
+    checkIndex("test_ivf_pq_repeat");
+
+    // Capture segment UUIDs after first run (mirrors
+    // testRepeatedCreateZonemapIndexReplacesExistingSegments).
+    Set<UUID> firstRunUuids;
+    org.lance.Dataset ds1 = org.lance.Dataset.open().uri(tableDir).build();
+    try {
+      firstRunUuids =
+          ds1.getIndexes().stream()
+              .filter(idx -> "test_ivf_pq_repeat".equals(idx.name()))
+              .map(Index::uuid)
+              .collect(Collectors.toSet());
+    } finally {
+      ds1.close();
+    }
+    Assertions.assertFalse(firstRunUuids.isEmpty(), "First run should produce segments");
+
+    spark.sql(createSql);
+    checkIndex("test_ivf_pq_repeat");
+
+    // Verify replace semantics: fresh UUIDs, segments cover every fragment exactly once,
+    // no accumulation.
+    org.lance.Dataset lanceDataset = org.lance.Dataset.open().uri(tableDir).build();
+    try {
+      int fragmentCount = lanceDataset.getFragments().size();
+      List<Index> segments =
+          lanceDataset.getIndexes().stream()
+              .filter(idx -> "test_ivf_pq_repeat".equals(idx.name()))
+              .collect(Collectors.toList());
+      Set<UUID> secondRunUuids = segments.stream().map(Index::uuid).collect(Collectors.toSet());
+      int coveredFragments =
+          segments.stream()
+              .map(idx -> idx.fragments().orElse(Collections.emptyList()).size())
+              .mapToInt(Integer::intValue)
+              .sum();
+      Assertions.assertEquals(
+          fragmentCount,
+          coveredFragments,
+          "Recreated IVF segments must cover all fragments exactly once (saw "
+              + segments.size()
+              + " segments covering "
+              + coveredFragments
+              + " slots, expected "
+              + fragmentCount
+              + ")");
+      Assertions.assertTrue(
+          Collections.disjoint(firstRunUuids, secondRunUuids),
+          "Recreate must produce fresh segment UUIDs, not reuse first-run ones");
+    } finally {
+      lanceDataset.close();
+    }
+  }
+
+  @Test
+  public void testIvfRejectsMultipleColumns() {
+    prepareVectorDataset();
+    Assertions.assertThrows(
+        Exception.class,
+        () ->
+            spark.sql(
+                String.format(
+                    "alter table %s create index idx_multi_vec using ivf_pq (id, embedding) "
+                        + "with (num_partitions=4)",
+                    fullTable)));
+  }
+
+  @Test
+  public void testCreateIvfPqWithBadMetricTypeFails() {
+    prepareVectorDataset();
+
+    RuntimeException exception =
+        Assertions.assertThrows(
+            RuntimeException.class,
+            () ->
+                spark
+                    .sql(
+                        String.format(
+                            "alter table %s create index test_ivf_pq_bad_metric using ivf_pq (embedding) "
+                                + "with (num_partitions=4, num_sub_vectors=4, metric_type='manhattan')",
+                            fullTable))
+                    .collect());
+
+    Assertions.assertTrue(
+        exception.getMessage().toLowerCase(java.util.Locale.ROOT).contains("metric")
+            || exception.getMessage().contains("manhattan"),
+        "Expected error to mention bad metric_type, got: " + exception.getMessage());
+  }
+
+  @Test
+  public void testCreateIvfPqWithoutNumPartitionsFails() {
+    prepareVectorDataset();
+
+    RuntimeException exception =
+        Assertions.assertThrows(
+            RuntimeException.class,
+            () ->
+                spark
+                    .sql(
+                        String.format(
+                            "alter table %s create index test_ivf_pq_missing using ivf_pq (embedding) "
+                                + "with (num_sub_vectors=4)",
+                            fullTable))
+                    .collect());
+
+    Assertions.assertTrue(
+        exception.getMessage().contains("num_partitions"),
+        "Expected error to mention missing num_partitions, got: " + exception.getMessage());
+  }
+
+  @Test
+  public void testIvfPqCosineRecallOnClusteredData() {
+    // Real recall test: build IVF_PQ with metric_type='cosine' on clustered embeddings,
+    // query with each cluster centroid via VECTOR_SEARCH, assert top-K is mostly the
+    // matching cluster. This is the test that would have caught the lance 6 silent-L2
+    // bug — without it, the metric_type plumbing is unverified end-to-end.
+    int dim = 8;
+    int numClusters = 8;
+    int rowsPerCluster = 64;
+    int totalRows = numClusters * rowsPerCluster; // 512
+
+    String table =
+        catalogName + ".default.recall_test_" + UUID.randomUUID().toString().replace("-", "");
+    spark.sql(
+        String.format(
+            "create table %s (id int, cluster_id int, vector array<float>) using lance "
+                + "TBLPROPERTIES ('vector.arrow.fixed-size-list.size' = '%d')",
+            table, dim));
+
+    Random rng = new Random(7L);
+    float[][] centers = new float[numClusters][dim];
+    for (int c = 0; c < numClusters; c++) {
+      for (int j = 0; j < dim; j++) {
+        centers[c][j] = (rng.nextFloat() - 0.5f) * 10.0f;
+      }
+    }
+
+    // Two inserts to get two fragments.
+    insertClustered(table, centers, 0, totalRows / 2, dim, rowsPerCluster, 100L);
+    insertClustered(table, centers, totalRows / 2, totalRows, dim, rowsPerCluster, 200L);
+
+    // num_partitions=8 matches cluster count; num_sub_vectors=4 divides dim=8 evenly.
+    spark.sql(
+        String.format(
+            "alter table %s create index recall_idx using ivf_pq (vector) "
+                + "with (num_partitions=8, num_sub_vectors=4, num_bits=8, metric_type='cosine')",
+            table));
+
+    int k = 10;
+    int hits = 0;
+    int totalQueried = 0;
+    for (int c = 0; c < numClusters; c++) {
+      final int cluster = c;
+      String queryVec =
+          "array("
+              + IntStream.range(0, dim)
+                  .mapToObj(j -> Float.toString(centers[cluster][j]))
+                  .collect(Collectors.joining(", "))
+              + ")";
+      List<Row> results =
+          spark
+              .sql(
+                  "SELECT cluster_id FROM VECTOR_SEARCH('"
+                      + table
+                      + "', "
+                      + queryVec
+                      + ", "
+                      + k
+                      + ")")
+              .collectAsList();
+      Assertions.assertEquals(k, results.size(), "VECTOR_SEARCH must return k results");
+      for (Row r : results) {
+        if (r.getInt(0) == cluster) hits++;
+        totalQueried++;
+      }
+    }
+    double recall = (double) hits / totalQueried;
+    Assertions.assertTrue(
+        recall >= 0.5,
+        "IVF_PQ cosine recall on clustered data must be >= 0.5, got "
+            + String.format("%.3f", recall)
+            + " ("
+            + hits
+            + "/"
+            + totalQueried
+            + " correct cluster hits)");
+  }
+
+  private void insertClustered(
+      String table,
+      float[][] centers,
+      int idStart,
+      int idEnd,
+      int dim,
+      int rowsPerCluster,
+      long seed) {
+    Random rng = new Random(seed);
+    int numClusters = centers.length;
+    String values =
+        IntStream.range(idStart, idEnd)
+            .mapToObj(
+                i -> {
+                  int cluster = (i / rowsPerCluster) % numClusters;
+                  StringBuilder sb = new StringBuilder("array(");
+                  for (int j = 0; j < dim; j++) {
+                    if (j > 0) sb.append(", ");
+                    // Tight cluster (sigma=0.1) so cosine direction strongly correlates
+                    // with cluster — recall on this should be near 1 for a working index.
+                    float v = centers[cluster][j] + (float) (rng.nextGaussian() * 0.1);
+                    sb.append(v);
+                  }
+                  sb.append(")");
+                  return String.format("(%d, %d, %s)", i, cluster, sb);
+                })
+            .collect(Collectors.joining(","));
+    spark.sql(String.format("insert into %s (id, cluster_id, vector) values %s", table, values));
+  }
+
+  @Test
+  public void testUseResidualWithArgIsRejected() {
+    prepareVectorDataset();
+
+    // use_residual is a no-op upstream and we reject it explicitly to avoid silent recall
+    // degradation. If lance-core later honors the flag this test's expectation should flip.
+    RuntimeException exception =
+        Assertions.assertThrows(
+            RuntimeException.class,
+            () ->
+                spark
+                    .sql(
+                        String.format(
+                            "alter table %s create index test_ivf_pq_resid using ivf_pq (embedding) "
+                                + "with (num_partitions=4, use_residual=true)",
+                            fullTable))
+                    .collect());
+
+    Assertions.assertTrue(
+        exception.getMessage().contains("use_residual"),
+        "Expected error to mention use_residual, got: " + exception.getMessage());
+  }
 }