Skip to content

Commit 4243190

Browse files
authored
Limit should not be pushed down if there are unconverted restrictions. Closes #291
1 parent 9646a61 commit 4243190

File tree

4 files changed

+47
-39
lines changed

4 files changed

+47
-39
lines changed

fdw.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ func goFdwBeginForeignScan(node *C.ForeignScanState, eflags C.int) {
222222

223223
// create a wrapper struct for cinfos
224224
cinfos := newConversionInfos(execState)
225-
quals := restrictionsToQuals(node, cinfos)
225+
quals, unhandledRestrictions := restrictionsToQuals(node, cinfos)
226226

227227
// start the plugin hub
228228
var err error
@@ -238,7 +238,7 @@ func goFdwBeginForeignScan(node *C.ForeignScanState, eflags C.int) {
238238
}
239239
// if we are NOT explaining, create an iterator to scan for us
240240
if !explain {
241-
iter, err := pluginHub.GetIterator(columns, quals, int64(execState.limit), opts)
241+
iter, err := pluginHub.GetIterator(columns, quals, unhandledRestrictions, int64(execState.limit), opts)
242242
if err != nil {
243243
log.Printf("[WARN] pluginHub.GetIterator FAILED: %s", err)
244244
FdwError(err)

hub/hub.go

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ func (h *Hub) GetSchema(remoteSchema string, localSchema string) (*proto.Schema,
279279
}
280280

281281
// GetIterator creates and returns an iterator
282-
func (h *Hub) GetIterator(columns []string, quals *proto.Quals, limit int64, opts types.Options) (Iterator, error) {
282+
func (h *Hub) GetIterator(columns []string, quals *proto.Quals, unhandledRestrictions int, limit int64, opts types.Options) (Iterator, error) {
283283
logging.LogTime("GetIterator start")
284284
qualMap, err := h.buildQualMap(quals)
285285
connectionName := opts["connection"]
@@ -296,10 +296,10 @@ func (h *Hub) GetIterator(columns []string, quals *proto.Quals, limit int64, opt
296296
var iterator Iterator
297297
// if this is a legacy aggregator connection, create a group iterator
298298
if h.IsLegacyAggregatorConnection(connectionName) {
299-
iterator, err = newLegacyGroupIterator(connectionName, table, qualMap, columns, limit, h, scanTraceCtx)
299+
iterator, err = newLegacyGroupIterator(connectionName, table, qualMap, unhandledRestrictions, columns, limit, h, scanTraceCtx)
300300
log.Printf("[TRACE] Hub GetIterator() created aggregate iterator (%p)", iterator)
301301
} else {
302-
iterator, err = h.startScanForConnection(connectionName, table, qualMap, columns, limit, scanTraceCtx)
302+
iterator, err = h.startScanForConnection(connectionName, table, qualMap, unhandledRestrictions, columns, limit, scanTraceCtx)
303303
log.Printf("[TRACE] Hub GetIterator() created iterator (%p)", iterator)
304304
}
305305

@@ -473,7 +473,7 @@ func (h *Hub) traceContextForScan(table string, columns []string, limit int64, q
473473
}
474474

475475
// startScanForConnection starts a scan for a single connection, using a scanIterator or a legacyScanIterator
476-
func (h *Hub) startScanForConnection(connectionName string, table string, qualMap map[string]*proto.Quals, columns []string, limit int64, scanTraceCtx *telemetry.TraceCtx) (_ Iterator, err error) {
476+
func (h *Hub) startScanForConnection(connectionName string, table string, qualMap map[string]*proto.Quals, unhandledRestrictions int, columns []string, limit int64, scanTraceCtx *telemetry.TraceCtx) (_ Iterator, err error) {
477477
defer func() {
478478
if err != nil {
479479
// close the span in case of errir
@@ -489,7 +489,7 @@ func (h *Hub) startScanForConnection(connectionName string, table string, qualMa
489489
}
490490
// if this is a legacy plugin, create legacy iterator
491491
if !connectionPlugin.SupportedOperations.MultipleConnections {
492-
return h.startScanForLegacyConnection(connectionName, table, qualMap, columns, limit, scanTraceCtx)
492+
return h.startScanForLegacyConnection(connectionName, table, qualMap, unhandledRestrictions, columns, limit, scanTraceCtx)
493493
}
494494

495495
// ok so this is a multi connection plugin, build list of connections,
@@ -506,7 +506,7 @@ func (h *Hub) startScanForConnection(connectionName string, table string, qualMa
506506
connectionNames = connectionConfig.GetResolveConnectionNames()
507507
}
508508
// for each connection, determine whether to pushdown the limit
509-
connectionLimitMap, err := h.buildConnectionLimitMap(table, qualMap, connectionNames, limit, connectionPlugin)
509+
connectionLimitMap, err := h.buildConnectionLimitMap(table, qualMap, unhandledRestrictions, connectionNames, limit, connectionPlugin)
510510
if err != nil {
511511
return nil, err
512512
}
@@ -524,7 +524,7 @@ func (h *Hub) startScanForConnection(connectionName string, table string, qualMa
524524
return iterator, nil
525525
}
526526

527-
func (h *Hub) buildConnectionLimitMap(table string, qualMap map[string]*proto.Quals, connectionNames []string, limit int64, connectionPlugin *steampipeconfig.ConnectionPlugin) (map[string]int64, error) {
527+
func (h *Hub) buildConnectionLimitMap(table string, qualMap map[string]*proto.Quals, unhandledRestrictions int, connectionNames []string, limit int64, connectionPlugin *steampipeconfig.ConnectionPlugin) (map[string]int64, error) {
528528
log.Printf("[TRACE] buildConnectionLimitMap, table: '%s', %d %s, limit: %d", table, len(connectionNames), utils.Pluralize("connection", len(connectionNames)), limit)
529529

530530
connectionSchema, err := connectionPlugin.GetSchema(connectionNames[0])
@@ -538,7 +538,7 @@ func (h *Hub) buildConnectionLimitMap(table string, qualMap map[string]*proto.Qu
538538
// check once whether we should push down
539539
if limit != -1 && schemaMode == plugin.SchemaModeStatic {
540540
log.Printf("[TRACE] static schema - using same limit for all connections")
541-
if !h.shouldPushdownLimit(table, qualMap, connectionSchema) {
541+
if !h.shouldPushdownLimit(table, qualMap, unhandledRestrictions, connectionSchema) {
542542
limit = -1
543543
}
544544
}
@@ -548,7 +548,7 @@ func (h *Hub) buildConnectionLimitMap(table string, qualMap map[string]*proto.Qu
548548
for _, c := range connectionNames {
549549
connectionLimit := limit
550550
// if schema mode is dynamic, check whether we should push down for each connection
551-
if schemaMode == plugin.SchemaModeDynamic && !h.shouldPushdownLimit(table, qualMap, connectionSchema) {
551+
if schemaMode == plugin.SchemaModeDynamic && !h.shouldPushdownLimit(table, qualMap, unhandledRestrictions, connectionSchema) {
552552
log.Printf("[INFO] not pushing limit down for connection %s", c)
553553
connectionLimit = -1
554554
}
@@ -558,10 +558,10 @@ func (h *Hub) buildConnectionLimitMap(table string, qualMap map[string]*proto.Qu
558558
return connectionLimitMap, nil
559559
}
560560

561-
func (h *Hub) startScanForLegacyConnection(connectionName string, table string, qualMap map[string]*proto.Quals, columns []string, limit int64, scanTraceCtx *telemetry.TraceCtx) (_ Iterator, err error) {
561+
func (h *Hub) startScanForLegacyConnection(connectionName string, table string, qualMap map[string]*proto.Quals, unhandledRestrictions int, columns []string, limit int64, scanTraceCtx *telemetry.TraceCtx) (_ Iterator, err error) {
562562
// if this is an aggregate connection, create a group iterator
563563
if h.IsLegacyAggregatorConnection(connectionName) {
564-
return newLegacyGroupIterator(connectionName, table, qualMap, columns, limit, h, scanTraceCtx)
564+
return newLegacyGroupIterator(connectionName, table, qualMap, unhandledRestrictions, columns, limit, h, scanTraceCtx)
565565
}
566566

567567
connectionPlugin, err := h.getConnectionPlugin(connectionName)
@@ -576,7 +576,7 @@ func (h *Hub) startScanForLegacyConnection(connectionName string, table string,
576576
// determine whether to include the limit, based on the quals
577577
// we ONLY pushdown the limit is all quals have corresponding key columns,
578578
// and if the qual operator is supported by the key column
579-
if limit != -1 && !h.shouldPushdownLimit(table, qualMap, connectionSchema) {
579+
if limit != -1 && !h.shouldPushdownLimit(table, qualMap, unhandledRestrictions, connectionSchema) {
580580
limit = -1
581581
}
582582

@@ -602,7 +602,12 @@ func (h *Hub) startScanForLegacyConnection(connectionName string, table string,
602602
// determine whether to include the limit, based on the quals
603603
// we ONLY pushdown the limit if all quals have corresponding key columns,
604604
// and if the qual operator is supported by the key column
605-
func (h *Hub) shouldPushdownLimit(table string, qualMap map[string]*proto.Quals, connectionSchema *proto.Schema) bool {
605+
func (h *Hub) shouldPushdownLimit(table string, qualMap map[string]*proto.Quals, unhandledRestrictions int, connectionSchema *proto.Schema) bool {
606+
// if we have any unhandled restrictions, we CANNOT push limit down
607+
if unhandledRestrictions > 0 {
608+
return false
609+
}
610+
606611
// build a map of all key columns
607612
tableSchema, ok := connectionSchema.Schema[table]
608613
if !ok {

hub/legacy_group_iterator.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ type legacyGroupIterator struct {
2424
traceCtx *telemetry.TraceCtx
2525
}
2626

27-
func newLegacyGroupIterator(name string, table string, qualMap map[string]*proto.Quals, columns []string, limit int64, h *Hub, scanTraceCtx *telemetry.TraceCtx) (Iterator, error) {
27+
func newLegacyGroupIterator(name string, table string, qualMap map[string]*proto.Quals, unhandledRestrictions int, columns []string, limit int64, h *Hub, scanTraceCtx *telemetry.TraceCtx) (Iterator, error) {
2828
res := &legacyGroupIterator{
2929
Name: name,
3030
// create a buffered channel
@@ -48,7 +48,7 @@ func newLegacyGroupIterator(name string, table string, qualMap map[string]*proto
4848
)
4949
connectionTraceCtx := &telemetry.TraceCtx{Ctx: ctx, Span: span}
5050

51-
iterator, err := h.startScanForLegacyConnection(connectionName, table, qualMap, columns, limit, connectionTraceCtx)
51+
iterator, err := h.startScanForLegacyConnection(connectionName, table, qualMap, unhandledRestrictions, columns, limit, connectionTraceCtx)
5252
if err != nil {
5353
errors = append(errors, err)
5454
} else {

quals.go

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import "C"
99

1010
import (
1111
"fmt"
12+
"github.com/gertd/go-pluralize"
1213
"log"
1314
"net"
1415
"unsafe"
@@ -19,7 +20,7 @@ import (
1920
"github.com/turbot/steampipe-plugin-sdk/v5/plugin/quals"
2021
)
2122

22-
func restrictionsToQuals(node *C.ForeignScanState, cinfos *conversionInfos) *proto.Quals {
23+
func restrictionsToQuals(node *C.ForeignScanState, cinfos *conversionInfos) (qualsList *proto.Quals, unhandledRestrictions int) {
2324
defer func() {
2425
if r := recover(); r != nil {
2526
log.Printf("[WARN] restrictionsToQuals recovered from panic: %v", r)
@@ -29,51 +30,53 @@ func restrictionsToQuals(node *C.ForeignScanState, cinfos *conversionInfos) *pro
2930
plan := (*C.ForeignScan)(unsafe.Pointer(node.ss.ps.plan))
3031
restrictions := plan.fdw_exprs
3132

32-
qualsList := &proto.Quals{}
33+
qualsList = &proto.Quals{}
3334
if restrictions == nil {
34-
return qualsList
35+
return qualsList, 0
3536
}
3637

3738
for it := C.list_head(restrictions); it != nil; it = C.lnext(restrictions, it) {
3839
restriction := C.cellGetExpr(it)
3940

4041
log.Printf("[TRACE] RestrictionsToQuals: restriction %s", C.GoString(C.tagTypeToString(C.fdw_nodeTag(restriction))))
4142

43+
var q *proto.Qual
4244
switch C.fdw_nodeTag(restriction) {
4345
case C.T_OpExpr:
44-
if q := qualFromOpExpr(C.cellGetOpExpr(it), node, cinfos); q != nil {
45-
qualsList.Append(q)
46-
}
46+
q = qualFromOpExpr(C.cellGetOpExpr(it), node, cinfos)
4747
case C.T_Var:
48-
q := qualFromVar(C.cellGetVar(it), node, cinfos)
49-
qualsList.Append(q)
48+
q = qualFromVar(C.cellGetVar(it), node, cinfos)
5049

5150
case C.T_ScalarArrayOpExpr:
52-
if q := qualFromScalarOpExpr(C.cellGetScalarArrayOpExpr(it), node, cinfos); q != nil {
53-
qualsList.Append(q)
54-
}
51+
q = qualFromScalarOpExpr(C.cellGetScalarArrayOpExpr(it), node, cinfos)
5552
case C.T_NullTest:
56-
q := qualFromNullTest(C.cellGetNullTest(it), node, cinfos)
57-
if q != nil {
58-
qualsList.Append(q)
59-
}
53+
q = qualFromNullTest(C.cellGetNullTest(it), node, cinfos)
54+
6055
//extractClauseFromNullTest(base_relids, (NullTest *) node, qualsList);
6156
case C.T_BooleanTest:
62-
if q := qualFromBooleanTest((*C.BooleanTest)(unsafe.Pointer(restriction)), node, cinfos); q != nil {
63-
qualsList.Append(q)
64-
}
57+
q = qualFromBooleanTest((*C.BooleanTest)(unsafe.Pointer(restriction)), node, cinfos)
6558
case C.T_BoolExpr:
66-
if q := qualFromBoolExpr((*C.BoolExpr)(unsafe.Pointer(restriction)), node, cinfos); q != nil {
67-
qualsList.Append(q)
68-
}
59+
q = qualFromBoolExpr((*C.BoolExpr)(unsafe.Pointer(restriction)), node, cinfos)
60+
}
61+
62+
if q != nil {
63+
qualsList.Append(q)
64+
} else {
65+
// we failed to handle this restriction
66+
unhandledRestrictions++
6967
}
7068

7169
}
7270
log.Printf("[TRACE] RestrictionsToQuals: converted postgres restrictions protobuf quals")
7371
for _, q := range qualsList.Quals {
7472
log.Printf("[TRACE] %s", grpc.QualToString(q))
7573
}
76-
return qualsList
74+
if unhandledRestrictions > 0 {
75+
log.Printf("[WARN] RestrictionsToQuals: failed to convert %s %s to quals",
76+
unhandledRestrictions,
77+
pluralize.NewClient().Pluralize("restriction", unhandledRestrictions, false))
78+
}
79+
return qualsList, unhandledRestrictions
7780
}
7881

7982
// build a protobuf qual from an OpExpr

0 commit comments

Comments
 (0)