Skip to content

Commit 37c50f8

Browse files
authored
Brute force KNN vector search pushdown (#29621)
1 parent 951a8f5 commit 37c50f8

File tree

19 files changed

+1321
-51
lines changed

19 files changed

+1321
-51
lines changed

ydb/core/base/kmeans_clusters.cpp

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ namespace {
4848
return Ydb::Table::VectorIndexSettings::METRIC_UNSPECIFIED;
4949
}
5050
};
51-
51+
5252
Ydb::Table::VectorIndexSettings_Metric ParseSimilarity(const TString& similarity_, TString& error) {
5353
const TString similarity = to_lower(similarity_);
5454
if (similarity == "cosine")
@@ -60,7 +60,7 @@ namespace {
6060
return Ydb::Table::VectorIndexSettings::METRIC_UNSPECIFIED;
6161
}
6262
};
63-
63+
6464
Ydb::Table::VectorIndexSettings_VectorType ParseVectorType(const TString& vectorType_, TString& error) {
6565
const TString vectorType = to_lower(vectorType_);
6666
if (vectorType == "float")
@@ -491,6 +491,62 @@ std::unique_ptr<IClusters> CreateClusters(const Ydb::Table::VectorIndexSettings&
491491
}
492492
}
493493

494+
std::unique_ptr<IClusters> CreateClustersAutoDetect(Ydb::Table::VectorIndexSettings settings, const TStringBuf& targetVector, ui32 maxRounds, TString& error) {
495+
if (targetVector.empty()) {
496+
error = "Target vector is empty";
497+
return nullptr;
498+
}
499+
500+
const auto setLinearType = [&](Ydb::Table::VectorIndexSettings::VectorType type, size_t elementSize, TStringBuf typeName) -> bool {
501+
if (targetVector.size() < HeaderLen + elementSize) {
502+
error = TStringBuilder() << "Target vector too short for " << typeName << " type";
503+
return false;
504+
}
505+
settings.set_vector_type(type);
506+
settings.set_vector_dimension((targetVector.size() - HeaderLen) / elementSize);
507+
return true;
508+
};
509+
510+
const ui8 formatByte = static_cast<ui8>(targetVector.back());
511+
switch (formatByte) {
512+
case EFormat::FloatVector:
513+
if (!setLinearType(Ydb::Table::VectorIndexSettings::VECTOR_TYPE_FLOAT, sizeof(float), "float")) {
514+
return nullptr;
515+
}
516+
break;
517+
case EFormat::Uint8Vector:
518+
if (!setLinearType(Ydb::Table::VectorIndexSettings::VECTOR_TYPE_UINT8, sizeof(ui8), "uint8")) {
519+
return nullptr;
520+
}
521+
break;
522+
case EFormat::Int8Vector:
523+
if (!setLinearType(Ydb::Table::VectorIndexSettings::VECTOR_TYPE_INT8, sizeof(i8), "int8")) {
524+
return nullptr;
525+
}
526+
break;
527+
case EFormat::BitVector: {
528+
if (targetVector.size() < HeaderLen + 2) {
529+
error = "Target vector too short for bit type";
530+
return nullptr;
531+
}
532+
const ui8 paddingBits = static_cast<ui8>(targetVector[targetVector.size() - 2]);
533+
const size_t payloadBits = (targetVector.size() - HeaderLen - 1) * 8;
534+
if (payloadBits < paddingBits) {
535+
error = "Invalid bit vector padding";
536+
return nullptr;
537+
}
538+
settings.set_vector_type(Ydb::Table::VectorIndexSettings::VECTOR_TYPE_BIT);
539+
settings.set_vector_dimension(payloadBits - paddingBits);
540+
break;
541+
}
542+
default:
543+
error = TStringBuilder() << "Unknown vector format byte: " << static_cast<int>(formatByte);
544+
return nullptr;
545+
}
546+
547+
return CreateClusters(settings, maxRounds, error);
548+
}
549+
494550
bool ValidateSettings(const Ydb::Table::KMeansTreeSettings& settings, TString& error) {
495551
error = "";
496552

@@ -503,16 +559,16 @@ bool ValidateSettings(const Ydb::Table::KMeansTreeSettings& settings, TString& e
503559
return false;
504560
}
505561

506-
if (!ValidateSettingInRange("levels",
507-
settings.has_levels() ? std::optional<ui64>(settings.levels()) : std::nullopt,
562+
if (!ValidateSettingInRange("levels",
563+
settings.has_levels() ? std::optional<ui64>(settings.levels()) : std::nullopt,
508564
MinLevels, MaxLevels,
509565
error))
510566
{
511567
return false;
512568
}
513569

514-
if (!ValidateSettingInRange("clusters",
515-
settings.has_clusters() ? std::optional<ui64>(settings.clusters()) : std::nullopt,
570+
if (!ValidateSettingInRange("clusters",
571+
settings.has_clusters() ? std::optional<ui64>(settings.clusters()) : std::nullopt,
516572
MinClusters, MaxClusters,
517573
error))
518574
{
@@ -529,7 +585,7 @@ bool ValidateSettings(const Ydb::Table::KMeansTreeSettings& settings, TString& e
529585
}
530586

531587
if (settings.settings().vector_dimension() * settings.clusters() > MaxVectorDimensionMultiplyClusters) {
532-
error = TStringBuilder() << "Invalid vector_dimension*clusters: " << settings.settings().vector_dimension() << "*" << settings.clusters()
588+
error = TStringBuilder() << "Invalid vector_dimension*clusters: " << settings.settings().vector_dimension() << "*" << settings.clusters()
533589
<< " should be less than " << MaxVectorDimensionMultiplyClusters;
534590
return false;
535591
}
@@ -557,8 +613,8 @@ bool ValidateSettings(const Ydb::Table::VectorIndexSettings& settings, TString&
557613
return false;
558614
}
559615

560-
if (!ValidateSettingInRange("vector_dimension",
561-
settings.has_vector_dimension() ? std::optional<ui64>(settings.vector_dimension()) : std::nullopt,
616+
if (!ValidateSettingInRange("vector_dimension",
617+
settings.has_vector_dimension() ? std::optional<ui64>(settings.vector_dimension()) : std::nullopt,
562618
MinVectorDimension, MaxVectorDimension,
563619
error))
564620
{

ydb/core/base/kmeans_clusters.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ class IClusters {
4949

5050
std::unique_ptr<IClusters> CreateClusters(const Ydb::Table::VectorIndexSettings& settings, ui32 maxRounds, TString& error);
5151

52+
// Auto-detect vector type and dimension from target vector when settings have dimension=0
53+
std::unique_ptr<IClusters> CreateClustersAutoDetect(Ydb::Table::VectorIndexSettings settings, const TStringBuf& targetVector, ui32 maxRounds, TString& error);
54+
5255
bool ValidateSettings(const Ydb::Table::VectorIndexSettings& settings, TString& error);
5356
bool ValidateSettings(const Ydb::Table::KMeansTreeSettings& settings, TString& error);
5457
bool FillSetting(Ydb::Table::KMeansTreeSettings& settings, const TString& name, const TString& value, TString& error);

ydb/core/kqp/common/kqp_yql.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,18 @@ TKqpReadTableSettings ParseInternal(const TCoNameValueTupleList& node) {
182182
for(const auto& kv: lv) {
183183
settings.IndexSelectionInfo.emplace(kv.Name().Value(), kv.Value().Cast<TCoAtom>().Value());
184184
}
185-
185+
} else if (name == TKqpReadTableSettings::VectorTopKColumnSettingName) {
186+
YQL_ENSURE(tuple.Value().Maybe<TCoAtom>());
187+
settings.VectorTopKColumn = tuple.Value().Cast<TCoAtom>().Value();
188+
} else if (name == TKqpReadTableSettings::VectorTopKMetricSettingName) {
189+
YQL_ENSURE(tuple.Value().Maybe<TCoAtom>());
190+
settings.VectorTopKMetric = tuple.Value().Cast<TCoAtom>().Value();
191+
} else if (name == TKqpReadTableSettings::VectorTopKTargetSettingName) {
192+
YQL_ENSURE(tuple.Value().IsValid());
193+
settings.VectorTopKTarget = tuple.Value().Cast().Ptr();
194+
} else if (name == TKqpReadTableSettings::VectorTopKLimitSettingName) {
195+
YQL_ENSURE(tuple.Value().IsValid());
196+
settings.VectorTopKLimit = tuple.Value().Cast().Ptr();
186197
} else {
187198
YQL_ENSURE(false, "Unknown KqpReadTable setting name '" << name << "'");
188199
}
@@ -317,6 +328,38 @@ NNodes::TCoNameValueTupleList TKqpReadTableSettings::BuildNode(TExprContext& ctx
317328
.Done());
318329
}
319330

331+
if (VectorTopKColumn) {
332+
settings.emplace_back(
333+
Build<TCoNameValueTuple>(ctx, pos)
334+
.Name().Build(VectorTopKColumnSettingName)
335+
.Value<TCoAtom>().Build(VectorTopKColumn)
336+
.Done());
337+
}
338+
339+
if (VectorTopKMetric) {
340+
settings.emplace_back(
341+
Build<TCoNameValueTuple>(ctx, pos)
342+
.Name().Build(VectorTopKMetricSettingName)
343+
.Value<TCoAtom>().Build(VectorTopKMetric)
344+
.Done());
345+
}
346+
347+
if (VectorTopKTarget) {
348+
settings.emplace_back(
349+
Build<TCoNameValueTuple>(ctx, pos)
350+
.Name().Build(VectorTopKTargetSettingName)
351+
.Value(VectorTopKTarget)
352+
.Done());
353+
}
354+
355+
if (VectorTopKLimit) {
356+
settings.emplace_back(
357+
Build<TCoNameValueTuple>(ctx, pos)
358+
.Name().Build(VectorTopKLimitSettingName)
359+
.Value(VectorTopKLimit)
360+
.Done());
361+
}
362+
320363
return Build<TCoNameValueTupleList>(ctx, pos)
321364
.Add(settings)
322365
.Done();

ydb/core/kqp/common/kqp_yql.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ struct TKqpReadTableSettings: public TSortingOperator<ERequestSorting::NONE> {
130130
static constexpr TStringBuf TabletIdName = "TabletId";
131131
static constexpr TStringBuf PointPrefixLenSettingName = "PointPrefixLen";
132132
static constexpr TStringBuf IndexSelectionDebugInfoSettingName = "IndexSelectionDebugInfo";
133+
static constexpr TStringBuf VectorTopKColumnSettingName = "VectorTopKColumn";
134+
static constexpr TStringBuf VectorTopKMetricSettingName = "VectorTopKMetric";
135+
static constexpr TStringBuf VectorTopKTargetSettingName = "VectorTopKTarget";
136+
static constexpr TStringBuf VectorTopKLimitSettingName = "VectorTopKLimit";
133137

134138
TVector<TString> SkipNullKeys;
135139
TExprNode::TPtr ItemsLimit;
@@ -139,6 +143,12 @@ struct TKqpReadTableSettings: public TSortingOperator<ERequestSorting::NONE> {
139143
ui64 PointPrefixLen = 0;
140144
THashMap<TString, TString> IndexSelectionInfo;
141145

146+
// Vector top-K pushdown settings for brute force vector search
147+
TString VectorTopKColumn;
148+
TString VectorTopKMetric;
149+
TExprNode::TPtr VectorTopKTarget;
150+
TExprNode::TPtr VectorTopKLimit;
151+
142152
void AddSkipNullKey(const TString& key);
143153
void SetItemsLimit(const TExprNode::TPtr& expr) { ItemsLimit = expr; }
144154

ydb/core/kqp/executer_actor/kqp_tasks_graph.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2674,6 +2674,16 @@ TMaybe<size_t> TKqpTasksGraph::BuildScanTasksFromSource(TStageInfo& stageInfo, b
26742674
settings->SetItemsLimit(itemsLimit);
26752675
}
26762676

2677+
if (source.HasVectorTopK()) {
2678+
const auto& in = source.GetVectorTopK();
2679+
auto& out = *settings->MutableVectorTopK();
2680+
out.SetColumn(in.GetColumn());
2681+
*out.MutableSettings() = in.GetSettings();
2682+
auto target = ExtractPhyValue(stageInfo, in.GetTargetVector(), TxAlloc->HolderFactory, TxAlloc->TypeEnv, NUdf::TUnboxedValuePod());
2683+
out.SetTargetVector(TString(target.AsStringRef()));
2684+
out.SetLimit((ui32)ExtractPhyValue(stageInfo, in.GetLimit(), TxAlloc->HolderFactory, TxAlloc->TypeEnv, NUdf::TUnboxedValuePod()).Get<ui64>());
2685+
}
2686+
26772687
auto& lockTxId = GetMeta().LockTxId;
26782688
if (lockTxId) {
26792689
settings->SetLockTxId(*lockTxId);

ydb/core/kqp/opt/kqp_opt_build_txs.cpp

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -378,20 +378,26 @@ class TKqpBuildTxTransformer : public TSyncTransformerBase {
378378
return parameter;
379379
};
380380

381+
// Helper to collect TKqpTxResultBinding nodes and replace them with parameters
381382
TNodeOnNodeOwnedMap sourceReplaceMap;
383+
auto collectBindings = [&](const TExprNode::TPtr& root) {
384+
VisitExpr(root,
385+
[&](const TExprNode::TPtr& node) {
386+
TExprBase expr(node);
387+
if (auto binding = expr.Maybe<TKqpTxResultBinding>()) {
388+
sourceReplaceMap.emplace(node.Get(), makeParameterBinding(binding.Cast(), node->Pos()).Ptr());
389+
}
390+
return true;
391+
});
392+
};
393+
382394
for (ui32 i = 0; i < stage.Inputs().Size(); ++i) {
383395
const auto& input = stage.Inputs().Item(i);
384396
const auto& inputArg = stage.Program().Args().Arg(i);
385397

398+
// Scan inputs that may contain TKqpTxResultBinding
386399
if (input.Maybe<TDqSource>() || input.Maybe<TKqpCnStreamLookup>()) {
387-
VisitExpr(input.Ptr(),
388-
[&](const TExprNode::TPtr& node) {
389-
TExprBase expr(node);
390-
if (auto binding = expr.Maybe<TKqpTxResultBinding>()) {
391-
sourceReplaceMap.emplace(node.Get(), makeParameterBinding(binding.Cast(), node->Pos()).Ptr());
392-
}
393-
return true;
394-
});
400+
collectBindings(input.Ptr());
395401
}
396402

397403
auto maybeBinding = input.Maybe<TKqpTxResultBinding>();
@@ -407,6 +413,9 @@ class TKqpBuildTxTransformer : public TSyncTransformerBase {
407413
argsMap.emplace(inputArg.Raw(), makeParameterBinding(maybeBinding.Cast(), input.Pos()).Ptr());
408414
}
409415

416+
// Scan program body for TKqpTxResultBinding (e.g. in TKqpReadTableRanges VectorTopK settings)
417+
collectBindings(stage.Program().Body().Ptr());
418+
410419
auto inputs = Build<TExprList>(ctx, stage.Pos())
411420
.Add(newInputs)
412421
.Done();
@@ -415,7 +424,7 @@ class TKqpBuildTxTransformer : public TSyncTransformerBase {
415424
.Inputs(ctx.ReplaceNodes(ctx.ReplaceNodes(inputs.Ptr(), stagesMap), sourceReplaceMap))
416425
.Program()
417426
.Args(newArgs)
418-
.Body(ctx.ReplaceNodes(stage.Program().Body().Ptr(), argsMap))
427+
.Body(ctx.ReplaceNodes(ctx.ReplaceNodes(stage.Program().Body().Ptr(), argsMap), sourceReplaceMap))
419428
.Build()
420429
.Settings(stage.Settings())
421430
.Outputs(stage.Outputs())
@@ -463,34 +472,39 @@ class TKqpBuildTxTransformer : public TSyncTransformerBase {
463472

464473
TVector<TDqPhyPrecompute> PrecomputeInputs(const TDqStage& stage) {
465474
TVector<TDqPhyPrecompute> result;
475+
476+
// Helper to collect precomputes from an expression tree
477+
auto collectPrecomputes = [&result](const TExprNode::TPtr& root, bool checkConnections = false) {
478+
VisitExpr(root,
479+
[&](const TExprNode::TPtr& ptr) {
480+
TExprBase node(ptr);
481+
if (auto maybePrecompute = node.Maybe<TDqPhyPrecompute>()) {
482+
result.push_back(maybePrecompute.Cast());
483+
return false;
484+
}
485+
if (checkConnections) {
486+
if (auto maybeConnection = node.Maybe<TDqConnection>()) {
487+
YQL_ENSURE(false, "unexpected connection in source");
488+
}
489+
}
490+
return true;
491+
});
492+
};
493+
494+
// Scan stage inputs for precomputes
466495
for (const auto& input : stage.Inputs()) {
467496
if (auto maybePrecompute = input.Maybe<TDqPhyPrecompute>()) {
468497
result.push_back(maybePrecompute.Cast());
469498
} else if (auto maybeSource = input.Maybe<TDqSource>()) {
470-
VisitExpr(maybeSource.Cast().Ptr(),
471-
[&] (const TExprNode::TPtr& ptr) {
472-
TExprBase node(ptr);
473-
if (auto maybePrecompute = node.Maybe<TDqPhyPrecompute>()) {
474-
result.push_back(maybePrecompute.Cast());
475-
return false;
476-
}
477-
if (auto maybeConnection = node.Maybe<TDqConnection>()) {
478-
YQL_ENSURE(false, "unexpected connection in source");
479-
}
480-
return true;
481-
});
499+
collectPrecomputes(maybeSource.Cast().Ptr(), /* checkConnections */ true);
482500
} else if (auto maybeStreamLookup = input.Maybe<TKqpCnStreamLookup>()) {
483-
VisitExpr(maybeStreamLookup.Cast().Settings().Ptr(),
484-
[&] (const TExprNode::TPtr& ptr) {
485-
TExprBase node(ptr);
486-
if (auto maybePrecompute = node.Maybe<TDqPhyPrecompute>()) {
487-
result.push_back(maybePrecompute.Cast());
488-
return false;
489-
}
490-
return true;
491-
});
501+
collectPrecomputes(maybeStreamLookup.Cast().Settings().Ptr());
492502
}
493503
}
504+
505+
// Scan program body for precomputes (e.g. in TKqpReadTableRanges VectorTopK settings)
506+
collectPrecomputes(stage.Program().Body().Ptr());
507+
494508
return result;
495509
}
496510

ydb/core/kqp/opt/physical/kqp_opt_phy.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ class TKqpPhysicalOptTransformer : public TOptimizeTransformerBase {
4545
AddHandler(0, IsSort, HNDL(RemoveRedundantSortOverReadTable));
4646
AddHandler(0, &TCoTake::Match, HNDL(ApplyLimitToReadTable));
4747
AddHandler(0, &TCoTopSort::Match, HNDL(ApplyLimitToOlapReadTable));
48+
AddHandler(0, &TCoTopSort::Match, HNDL(ApplyVectorTopKToReadTable));
49+
AddHandler(0, &TDqStage::Match, HNDL(ApplyVectorTopKToStageWithSource));
4850
AddHandler(0, &TCoFlatMap::Match, HNDL(PushOlapFilter));
4951
AddHandler(0, &TCoFlatMap::Match, HNDL(PushOlapProjections));
5052
AddHandler(0, &TCoAggregateCombine::Match, HNDL(PushAggregateCombineToStage));
@@ -206,7 +208,7 @@ class TKqpPhysicalOptTransformer : public TOptimizeTransformerBase {
206208
else {
207209
TExprBase output = KqpBuildStreamIdxLookupJoinStagesKeepSorted(node, ctx, TypesCtx, true);
208210
DumpAppliedRule("BuildStreamIdxLookupJoinStagesKeepSorted", node.Ptr(), output.Ptr(), ctx);
209-
return output;
211+
return output;
210212
}
211213
}
212214

@@ -260,6 +262,18 @@ class TKqpPhysicalOptTransformer : public TOptimizeTransformerBase {
260262
return output;
261263
}
262264

265+
TMaybeNode<TExprBase> ApplyVectorTopKToReadTable(TExprBase node, TExprContext& ctx) {
266+
TExprBase output = KqpApplyVectorTopKToReadTable(node, ctx, KqpCtx);
267+
DumpAppliedRule("ApplyVectorTopKToReadTable", node.Ptr(), output.Ptr(), ctx);
268+
return output;
269+
}
270+
271+
TMaybeNode<TExprBase> ApplyVectorTopKToStageWithSource(TExprBase node, TExprContext& ctx) {
272+
TExprBase output = KqpApplyVectorTopKToStageWithSource(node, ctx, KqpCtx);
273+
DumpAppliedRule("ApplyVectorTopKToStageWithSource", node.Ptr(), output.Ptr(), ctx);
274+
return output;
275+
}
276+
263277
TMaybeNode<TExprBase> PushOlapFilter(TExprBase node, TExprContext& ctx) {
264278
TExprBase output = KqpPushOlapFilter(node, ctx, KqpCtx, TypesCtx, *TypeAnnTransformer.Get());
265279
DumpAppliedRule("PushOlapFilter", node.Ptr(), output.Ptr(), ctx);

ydb/core/kqp/opt/physical/kqp_opt_phy_impl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ NYql::NNodes::TMaybeNode<NYql::NNodes::TDqPhyPrecompute> BuildLookupKeysPrecompu
1919
NYql::NNodes::TCoAtomList BuildColumnsList(const THashSet<TStringBuf>& columns, NYql::TPositionHandle pos,
2020
NYql::TExprContext& ctx);
2121

22+
NYql::NNodes::TExprBase KqpPrecomputeParameter(NYql::NNodes::TExprBase param, NYql::TExprContext& ctx);
23+
2224
NYql::NNodes::TCoAtomList BuildColumnsList(const TVector<TStringBuf>& columns, NYql::TPositionHandle pos,
2325
NYql::TExprContext& ctx);
2426

0 commit comments

Comments
 (0)