Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ import org.apache.spark.sql.util.LanceArrowUtils
import org.apache.spark.sql.util.LanceSerializeUtil.{decode, encode}
import org.apache.spark.unsafe.types.UTF8String
import org.lance.{CommitBuilder, Dataset, Transaction}
import org.lance.index.{Index, IndexOptions, IndexParams, IndexType}
import org.lance.index.{DistanceType, Index, IndexOptions, IndexParams, IndexType}
import org.lance.index.scalar.{BTreeIndexParams, ScalarIndexParams}
import org.lance.index.vector.{IvfBuildParams, PQBuildParams, SQBuildParams, VectorIndexParams, VectorTrainer}
import org.lance.operation.{CreateIndex => AddIndexOperation}
import org.lance.spark.{BaseLanceNamespaceSparkCatalog, LanceDataset, LanceRuntime, LanceSparkReadOptions}
import org.lance.spark.arrow.LanceArrowWriter
Expand Down Expand Up @@ -217,6 +218,34 @@ case class AddIndexExec(
UTF8String.fromString(indexName))))
}

// IVF_* vector indexes also use the logical segment commit path. They differ from zonemap
// in that IVF centroids (and the PQ codebook for IVF_PQ) must be trained once on the
// driver before per-fragment segment builds: lance-core's distributed build rejects
// per-fragment-trained centroids, and shared artifacts keep all segments in the same
// query-time compatibility group.
if (IndexUtils.isVectorIndex(indexType)) {
if (canonicalColumns.size != 1) {
throw new IllegalArgumentException(
s"Vector index supports a single column only, got: $canonicalColumns")
}
val spec = VectorIndexSpec.fromArgs(indexType, args)
val vectorJob = new VectorIndexJob(
this.copy(columns = canonicalColumns),
readOptions,
indexType,
spec,
fragmentIds,
nsImpl,
nsProps,
tableId,
initialStorageOpts)
val segments = vectorJob.run()
commitIndexSegments(readOptions, canonicalColumns.head, segments)
return Seq(new GenericInternalRow(Array[Any](
fragmentIds.size.toLong,
UTF8String.fromString(indexName))))
}

// FTS/INVERTED still uses a caller-assigned UUID: each fragment writes
// partial metadata under that UUID and the driver merges them into a single
// index root before committing the resulting index transaction.
Expand Down Expand Up @@ -959,6 +988,273 @@ case class ZonemapIndexTask(
}
}

/**
 * Driver-side job that creates IVF-family vector indexes (IVF_FLAT, IVF_PQ, IVF_SQ)
 * via the logical segment commit path.
 *
 * IVF centroids and (for IVF_PQ) the PQ codebook are trained once on the driver and shipped
 * to executors so every per-fragment segment shares the same artifacts. lance-core's
 * distributed build path requires precomputed centroids, and sharing them ensures all
 * segments land in the same query-time compatibility group.
 *
 * One [[VectorIndexTask]] is created per fragment id. Each task builds an uncommitted index
 * segment restricted to its fragment; this job only collects and decodes those segments.
 * Committing them is left to the caller.
 *
 * @param addIndexExec       supplies the Spark session, the indexed column (the first entry of
 *                           `columns`) and the index name
 * @param readOptions        read options used to open the dataset on the driver and, in encoded
 *                           form, on the executors
 * @param indexType          vector index type to build
 * @param spec               user-supplied build parameters, before trained artifacts are attached
 * @param fragmentIds        fragments to index; one Spark task is launched per id
 * @param nsImpl             optional namespace implementation forwarded to the dataset builder
 * @param nsProps            optional namespace properties forwarded to the dataset builder
 * @param tableId            optional table identifier forwarded to the dataset builder
 * @param initialStorageOpts optional storage options forwarded to the dataset builder
 */
class VectorIndexJob(
    addIndexExec: AddIndexExec,
    readOptions: LanceSparkReadOptions,
    indexType: IndexType,
    spec: VectorIndexSpec,
    fragmentIds: List[Integer],
    nsImpl: Option[String],
    nsProps: Option[Map[String, String]],
    tableId: Option[List[String]],
    initialStorageOpts: Option[Map[String, String]])
  extends Logging {

  /**
   * Trains the shared artifacts on the driver, builds one segment per fragment on the
   * executors and returns the decoded, still uncommitted segments.
   *
   * @return one [[Index]] per entry of `fragmentIds`
   * @throws RuntimeException wrapping any non-fatal failure of the distributed segment build
   */
  def run(): Seq[Index] = {
    val column = addIndexExec.columns.head

    val (centroids, codebook) = trainArtifactsOnDriver(column)
    val trainedSpec = spec.copy(centroids = centroids, codebook = codebook)

    val encodedReadOptions = encode(readOptions)
    val tasks = fragmentIds.map { fid =>
      VectorIndexTask(
        encodedReadOptions,
        addIndexExec.columns.toList,
        indexType.name(),
        trainedSpec,
        addIndexExec.indexName,
        fid,
        nsImpl,
        nsProps,
        tableId,
        initialStorageOpts)
    }.toSeq

    // One slice per task so every fragment gets its own Spark task. Spark rejects a
    // non-positive slice count, so an empty fragment list is clamped to a single
    // (empty) slice and yields an empty result instead of an unrelated error.
    val numSlices = math.max(1, tasks.size)

    try {
      addIndexExec.session.sparkContext
        .parallelize(tasks, numSlices)
        .map(t => t.execute())
        .collect()
        .map(decode[Index])
        .toSeq
    } catch {
      // NonFatal keeps InterruptedException and fatal VM errors out of the wrapper so
      // cancellation and unrecoverable failures propagate unchanged.
      case scala.util.control.NonFatal(e) =>
        throw new RuntimeException(
          "Vector index segment build failed. Uncommitted segments are not " +
            "visible to readers and will not affect query correctness.",
          e)
    }
  }

  /**
   * Opens the dataset on the driver and trains the artifacts shared by all segments.
   *
   * @param column vector column to train on
   * @return `(centroids, codebook)`; `codebook` is `null` unless `indexType` is IVF_PQ
   */
  private def trainArtifactsOnDriver(column: String): (Array[Float], Array[Float]) = {
    val dataset = Utils.openDatasetBuilder(readOptions)
      .initialStorageOptions(initialStorageOpts.map(_.asJava).orNull)
      .runtimeNamespace(
        nsImpl.orNull,
        nsProps.map(_.asJava).orNull,
        tableId.map(_.asJava).orNull)
      .build()
    try {
      val ivfBuilder = new IvfBuildParams.Builder().setNumPartitions(spec.numPartitions)
      spec.sampleRate.foreach(ivfBuilder.setSampleRate)
      spec.maxIters.foreach(ivfBuilder.setMaxIters)
      val centroids = VectorTrainer.trainIvfCentroids(
        dataset,
        column,
        ivfBuilder.build(),
        spec.distanceType)

      val codebook: Array[Float] = indexType match {
        case IndexType.IVF_PQ =>
          val pqBuilder = new PQBuildParams.Builder()
          spec.numSubVectors.foreach(pqBuilder.setNumSubVectors)
          spec.pqNumBits.foreach(pqBuilder.setNumBits)
          spec.pqMaxIters.foreach(pqBuilder.setMaxIters)
          // The PQ trainer reuses the same sample rate as the IVF trainer.
          spec.sampleRate.foreach(pqBuilder.setSampleRate)
          VectorTrainer.trainPqCodebook(
            dataset,
            column,
            pqBuilder.build(),
            spec.distanceType)
        // Only IVF_PQ needs a codebook; VectorIndexSpec uses null to mean "absent".
        case _ => null
      }

      (centroids, codebook)
    } finally {
      dataset.close()
    }
  }
}

/**
 * Executor-side unit of work that builds a single uncommitted vector index segment
 * restricted to one fragment.
 *
 * The task is fully self-contained: read options arrive encoded, the index type arrives by
 * name, and the build parameters arrive as a [[VectorIndexSpec]], from which the native
 * parameters are rebuilt on the executor.
 *
 * NOTE(review): `indexName` is carried by the task but is not applied to the `IndexOptions`
 * built in [[execute]] - confirm the name is assigned when the driver commits the segments.
 */
case class VectorIndexTask(
    encodedReadOptions: String,
    columns: List[String],
    indexTypeName: String,
    spec: VectorIndexSpec,
    indexName: String,
    fragmentId: Int,
    namespaceImpl: Option[String],
    namespaceProperties: Option[Map[String, String]],
    tableId: Option[List[String]],
    initialStorageOptions: Option[Map[String, String]]) extends Serializable {

  /**
   * Builds the segment for `fragmentId` and returns it in encoded form.
   *
   * @return the `encode`d [[org.lance.index.Index]] produced by `Dataset#createIndex`
   */
  def execute(): String = {
    val decodedReadOptions = decode[LanceSparkReadOptions](encodedReadOptions)
    val resolvedType = IndexType.valueOf(indexTypeName)

    // Rebuild the native parameter objects from the serializable spec.
    val indexParams = IndexParams.builder()
      .setVectorIndexParams(spec.toVectorIndexParams(resolvedType))
      .build()

    val targetColumns = java.util.Arrays.asList(columns: _*)
    val targetFragments = Collections.singletonList(java.lang.Integer.valueOf(fragmentId))
    val options = IndexOptions
      .builder(targetColumns, resolvedType, indexParams)
      .replace(false)
      .withFragmentIds(targetFragments)
      .build()

    val storageOptions = initialStorageOptions.map(_.asJava).orNull
    val ds = Utils.openDatasetBuilder(decodedReadOptions)
      .initialStorageOptions(storageOptions)
      .runtimeNamespace(
        namespaceImpl.orNull,
        namespaceProperties.map(_.asJava).orNull,
        tableId.map(_.asJava).orNull)
      .build()

    // The dataset is closed whether or not the index build succeeds.
    try {
      val segment = ds.createIndex(options)
      encode(segment)
    } finally {
      ds.close()
    }
  }
}

/**
 * Serializable description of an IVF-family index build.
 *
 * An instance is parsed from the user's WITH-args on the driver, enriched there with the
 * driver-trained artifacts (`centroids`, `codebook`) and shipped to executors, where
 * [[toVectorIndexParams]] turns it back into a native [[VectorIndexParams]] (whose nested
 * builders are not Serializable).
 *
 * `centroids` and `codebook` use `null` to mean "not provided". Because they are arrays,
 * the generated `equals`/`hashCode` compare them by reference.
 */
case class VectorIndexSpec(
    metricType: String,
    numPartitions: Int,
    sampleRate: Option[Int],
    maxIters: Option[Int],
    // IVF_PQ
    numSubVectors: Option[Int],
    pqNumBits: Option[Int],
    pqMaxIters: Option[Int],
    // IVF_SQ
    sqNumBits: Option[Short],
    // Driver-trained artifacts populated by VectorIndexJob before task dispatch.
    centroids: Array[Float] = null,
    codebook: Array[Float] = null) extends Serializable {

  /** The distance type parsed from `metricType`. */
  def distanceType: DistanceType = VectorIndexSpec.parseDistanceType(metricType)

  /**
   * Rebuilds the native build parameters for `indexType`.
   *
   * @param indexType one of IVF_FLAT, IVF_PQ or IVF_SQ
   * @throws IllegalArgumentException for any other index type
   */
  def toVectorIndexParams(indexType: IndexType): VectorIndexParams = {
    val ivf = new IvfBuildParams.Builder().setNumPartitions(numPartitions)
    sampleRate.foreach(rate => ivf.setSampleRate(rate))
    maxIters.foreach(iters => ivf.setMaxIters(iters))
    Option(centroids).foreach(trained => ivf.setCentroids(trained))

    val params = new VectorIndexParams.Builder(ivf.build()).setDistanceType(distanceType)

    indexType match {
      case IndexType.IVF_FLAT => () // IVF parameters alone are sufficient
      case IndexType.IVF_PQ => params.setPqParams(buildPqParams())
      case IndexType.IVF_SQ => params.setSqParams(buildSqParams())
      case other =>
        throw new IllegalArgumentException(s"Unsupported vector index type: $other")
    }

    params.build()
  }

  // Product-quantization parameters; the pre-trained codebook is attached when present.
  private def buildPqParams() = {
    val pq = new PQBuildParams.Builder()
    numSubVectors.foreach(n => pq.setNumSubVectors(n))
    pqNumBits.foreach(bits => pq.setNumBits(bits))
    pqMaxIters.foreach(iters => pq.setMaxIters(iters))
    sampleRate.foreach(rate => pq.setSampleRate(rate))
    Option(codebook).foreach(trained => pq.setCodebook(trained))
    pq.build()
  }

  // Scalar-quantization parameters.
  private def buildSqParams() = {
    val sq = new SQBuildParams.Builder()
    sqNumBits.foreach(bits => sq.setNumBits(bits))
    sampleRate.foreach(rate => sq.setSampleRate(rate))
    sq.build()
  }
}

object VectorIndexSpec {

  /**
   * Parses the user's WITH-args into a [[VectorIndexSpec]] without trained artifacts.
   *
   * Recognised args: `num_partitions` (required), `metric_type` (default "l2"),
   * `sample_rate`, `max_iters`, `num_sub_vectors`, `pq_max_iters` and `num_bits`
   * (read for IVF_PQ and IVF_SQ only, default 8).
   *
   * All validation happens here, on the driver, so that a bad value fails before
   * training starts.
   *
   * @param indexType vector index type being built; decides how `num_bits` is read
   * @param args      named arguments from the WITH clause
   * @throws IllegalArgumentException if `use_residual` is present, `num_partitions` is
   *                                  missing or not positive, a numeric arg is not an integer
   *                                  or does not fit its target type, or `metric_type` is
   *                                  unknown
   */
  def fromArgs(indexType: IndexType, args: Seq[LanceNamedArgument]): VectorIndexSpec = {
    def argInt(name: String): Option[Int] = args.find(_.name == name).map { a =>
      a.value match {
        case i: java.lang.Integer => i.intValue()
        case l: java.lang.Long =>
          // Reject instead of narrowing: Long.intValue() silently wraps around.
          val wide = l.longValue()
          if (wide < Int.MinValue || wide > Int.MaxValue) {
            throw new IllegalArgumentException(
              s"Vector index arg '$name' is out of the 32-bit integer range, got: $wide")
          }
          wide.toInt
        case other =>
          throw new IllegalArgumentException(
            s"Vector index arg '$name' must be an integer, got: $other")
      }
    }
    def argShort(name: String): Option[Short] = argInt(name).map { v =>
      // Same reasoning as above: Int.toShort would silently wrap around.
      if (v < Short.MinValue || v > Short.MaxValue) {
        throw new IllegalArgumentException(
          s"Vector index arg '$name' is out of the 16-bit integer range, got: $v")
      }
      v.toShort
    }
    def argString(name: String): Option[String] = args.find(_.name == name).map { a =>
      a.value match {
        case s: java.lang.String => s
        case other => String.valueOf(other)
      }
    }

    // use_residual is a no-op upstream as of lance 7 — the JNI bridge does not pass it to
    // Rust and the Rust IvfBuildParams no longer carries the field. Reject the flag rather
    // than silently ignore it so users don't think it took effect.
    if (args.exists(_.name == "use_residual")) {
      throw new IllegalArgumentException(
        "WITH-arg 'use_residual' is not supported: lance-core's training path does not " +
          "honor it. Remove the option to proceed.")
    }

    val numPartitions = argInt("num_partitions").getOrElse(
      throw new IllegalArgumentException(
        "Vector index requires 'num_partitions' in WITH clause"))
    if (numPartitions <= 0) {
      throw new IllegalArgumentException(
        s"Vector index 'num_partitions' must be positive, got: $numPartitions")
    }

    val metric = argString("metric_type").getOrElse("l2")
    // Validate eagerly so a typo fails on the driver, not deep inside training.
    parseDistanceType(metric)

    // `num_bits` is shared by IVF_PQ and IVF_SQ but stored with a different width for each.
    val sqBits: Option[Short] = indexType match {
      case IndexType.IVF_SQ => argShort("num_bits").orElse(Some(8.toShort))
      case _ => None
    }
    val pqBits: Option[Int] = indexType match {
      case IndexType.IVF_PQ => argInt("num_bits").orElse(Some(8))
      case _ => None
    }

    VectorIndexSpec(
      metricType = metric,
      numPartitions = numPartitions,
      sampleRate = argInt("sample_rate"),
      maxIters = argInt("max_iters"),
      numSubVectors = argInt("num_sub_vectors"),
      pqNumBits = pqBits,
      pqMaxIters = argInt("pq_max_iters"),
      sqNumBits = sqBits)
  }

  /**
   * Maps a case-insensitive metric name to its [[DistanceType]].
   *
   * Accepted names: `l2`/`euclidean`, `cosine`, `dot`/`inner_product`/`ip`, `hamming`.
   *
   * @throws IllegalArgumentException for any other name
   */
  def parseDistanceType(s: String): DistanceType =
    // Locale.ROOT keeps the mapping independent of the JVM default locale; a Turkish
    // default would otherwise lower-case "I" to a dotless i and reject "HAMMING" or "IP".
    s.toLowerCase(java.util.Locale.ROOT) match {
      case "l2" | "euclidean" => DistanceType.L2
      case "cosine" => DistanceType.Cosine
      case "dot" | "inner_product" | "ip" => DistanceType.Dot
      case "hamming" => DistanceType.Hamming
      case other => throw new IllegalArgumentException(
          s"Unsupported metric_type '$other'; expected one of: l2, cosine, dot, hamming")
    }
}

/**
* Utility methods for working with index types.
*/
Expand Down Expand Up @@ -993,6 +1289,9 @@ object IndexUtils {
case "btree" => IndexType.BTREE
case "zonemap" => IndexType.ZONEMAP
case "fts" => IndexType.INVERTED
case "ivf_flat" => IndexType.IVF_FLAT
case "ivf_pq" => IndexType.IVF_PQ
case "ivf_sq" => IndexType.IVF_SQ
case other => throw new UnsupportedOperationException(s"Unsupported index method: $other")
}
}
Expand All @@ -1006,6 +1305,13 @@ object IndexUtils {
}
}

// Index types that are treated as vector indexes by isVectorIndex.
private val VectorIndexTypes: Set[IndexType] =
  Set(IndexType.IVF_FLAT, IndexType.IVF_PQ, IndexType.IVF_SQ)

/** Returns true when `indexType` is one of IVF_FLAT, IVF_PQ or IVF_SQ. */
def isVectorIndex(indexType: IndexType): Boolean = VectorIndexTypes(indexType)

def btreeBuildMode(indexType: IndexType, args: Seq[LanceNamedArgument]): Option[String] = {
if (indexType != IndexType.BTREE) {
None
Expand Down
Loading