Skip to content

Commit a658707

Browse files
authored
Merge pull request #47 from gatewayd-io/move-query-parsing-to-sdk
Move query parsing to the SDK
2 parents ee83a0f + 10e7dc0 commit a658707

File tree

6 files changed

+16
-415
lines changed

6 files changed

+16
-415
lines changed

go.mod

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@ go 1.20
44

55
require (
66
github.com/alicebob/miniredis/v2 v2.30.1
7-
github.com/gatewayd-io/gatewayd-plugin-sdk v0.0.17
7+
github.com/gatewayd-io/gatewayd-plugin-sdk v0.0.18
88
github.com/go-co-op/gocron v1.19.0
99
github.com/go-redis/redis/v8 v8.11.5
1010
github.com/hashicorp/go-hclog v1.5.0
1111
github.com/hashicorp/go-plugin v1.4.9
1212
github.com/jackc/pgx/v5 v5.3.1
13-
github.com/pganalyze/pg_query_go/v2 v2.2.0
1413
github.com/prometheus/client_golang v1.14.0
1514
github.com/spf13/cast v1.5.0
1615
github.com/stretchr/testify v1.8.2
@@ -33,6 +32,7 @@ require (
3332
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
3433
github.com/mitchellh/go-testing-interface v1.14.1 // indirect
3534
github.com/oklog/run v1.1.0 // indirect
35+
github.com/pganalyze/pg_query_go/v2 v2.2.0 // indirect
3636
github.com/pmezard/go-difflib v1.0.0 // indirect
3737
github.com/prometheus/client_model v0.3.0 // indirect
3838
github.com/prometheus/common v0.42.0 // indirect

go.sum

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs=
1919
github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw=
2020
github.com/frankban/quicktest v1.14.3 h1:FJKSZTDHjyhriyC81FLQ0LY93eSai0ZyR/ZIkd3ZUKE=
2121
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
22-
github.com/gatewayd-io/gatewayd-plugin-sdk v0.0.17 h1:GH/m+MBBtxX8+OeVrOjXUOnJzPOyl/zEqj8x95uafRA=
23-
github.com/gatewayd-io/gatewayd-plugin-sdk v0.0.17/go.mod h1:XMLNt13q7KDXs/x4V7f3uEp3RJ8bAp6hvvKYlULj9OY=
22+
github.com/gatewayd-io/gatewayd-plugin-sdk v0.0.18 h1:Qr0vwC99Ov1Xd4+NQPWMGlLlnfVCocQjjdCFqdSH9lg=
23+
github.com/gatewayd-io/gatewayd-plugin-sdk v0.0.18/go.mod h1:qnxOD6QQQ7OWqa1JdGMLGtlvmFkK8Xf7eH0qUD1N0n0=
2424
github.com/go-co-op/gocron v1.19.0 h1:XlPLqNnxnKblmCRLdfcWV1UgbukQaU54QdNeR1jtgak=
2525
github.com/go-co-op/gocron v1.19.0/go.mod h1:UqVyvM90I1q/R1qGEX6cBORI6WArLuEgYlbncLMvzRM=
2626
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
@@ -87,7 +87,7 @@ github.com/prometheus/procfs v0.9.0 h1:wzCHvIvM5SxWqYvwgVL7yJY8Lz3PKn49KQtpgMYJf
8787
github.com/prometheus/procfs v0.9.0/go.mod h1:+pB4zwohETzFnmlpe6yd2lSc+0/46IYZRB/chUwxUZY=
8888
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
8989
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
90-
github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k=
90+
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
9191
github.com/spf13/cast v1.5.0 h1:rj3WzYc11XZaIZMPKmwP96zkFEnnAmV8s6XbB2aY32w=
9292
github.com/spf13/cast v1.5.0/go.mod h1:SpXXQ5YoyJw6s3/6cMTQuxvgRl3PCJiyaX9p6b155UU=
9393
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

plugin/plugin.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,13 +192,13 @@ func (p *Plugin) OnTrafficFromServer(
192192
CacheSetsCounter.Inc()
193193

194194
// Cache the query as well.
195-
query, err := GetQueryFromRequest(request)
195+
query, err := postgres.GetQueryFromRequest(request)
196196
if err != nil {
197197
p.Logger.Debug("Failed to get query from request", "error", err)
198198
return resp, nil
199199
}
200200

201-
tables, err := GetTablesFromQuery(query)
201+
tables, err := postgres.GetTablesFromQuery(query)
202202
if err != nil {
203203
p.Logger.Debug("Failed to get tables from query", "error", err)
204204
return resp, nil
@@ -256,7 +256,7 @@ func (p *Plugin) invalidateDML(ctx context.Context, query string) {
256256
return
257257
}
258258

259-
tables, err := GetTablesFromQuery(queryMessage["String"])
259+
tables, err := postgres.GetTablesFromQuery(queryMessage["String"])
260260
if err != nil {
261261
p.Logger.Debug("Failed to get tables from query", "error", err)
262262
return

plugin/plugin_test.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ import (
1515
"google.golang.org/protobuf/types/known/structpb"
1616
)
1717

18+
func testQueryRequest() (string, string) {
19+
query := "SELECT * FROM users"
20+
queryMsg := pgproto3.Query{String: query}
21+
// Encode the data to base64.
22+
return query, base64.StdEncoding.EncodeToString(queryMsg.Encode(nil))
23+
}
24+
1825
func testStartupRequest() string {
1926
startupMsg := pgproto3.StartupMessage{
2027
ProtocolVersion: 196608,

plugin/utils.go

Lines changed: 1 addition & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -1,142 +1,12 @@
11
package plugin
22

33
import (
4-
"encoding/base64"
54
"net"
65
"strconv"
76
"strings"
8-
9-
pgQuery "github.com/pganalyze/pg_query_go/v2"
10-
)
11-
12-
const (
13-
MinPgSQLMessageLength = 5
14-
AddressPortPairLength = 2
157
)
168

17-
// GetQueryFromRequest decodes the request and returns the query.
18-
func GetQueryFromRequest(req string) (string, error) {
19-
requestDecoded, err := base64.StdEncoding.DecodeString(req)
20-
if err != nil {
21-
return "", err
22-
}
23-
24-
if len(requestDecoded) < MinPgSQLMessageLength {
25-
return "", nil
26-
}
27-
28-
// The first byte is the message type.
29-
// The next 4 bytes are the length of the message.
30-
// The rest of the message is the query.
31-
// See https://www.postgresql.org/docs/13/protocol-message-formats.html
32-
// for more information.
33-
size := int(requestDecoded[1])<<24 + int(requestDecoded[2])<<16 + int(requestDecoded[3])<<8 + int(requestDecoded[4])
34-
return string(requestDecoded[MinPgSQLMessageLength:size]), nil
35-
}
36-
37-
// isMulti checks if the query is a union, intersect, or except.
38-
func isMulti(stmt *pgQuery.SelectStmt) bool {
39-
return stmt.GetOp() == pgQuery.SetOperation_SETOP_UNION ||
40-
stmt.GetOp() == pgQuery.SetOperation_SETOP_INTERSECT ||
41-
stmt.GetOp() == pgQuery.SetOperation_SETOP_EXCEPT
42-
}
43-
44-
// getSingleTable returns the tables used in a query.
45-
func getSingleTable(stmt *pgQuery.SelectStmt) []string {
46-
tables := []string{}
47-
for _, from := range stmt.FromClause {
48-
rangeVar := from.GetRangeVar()
49-
if rangeVar != nil {
50-
tables = append(tables, rangeVar.Relname)
51-
}
52-
}
53-
54-
return tables
55-
}
56-
57-
// getMultiTable returns the tables used in a union, intersect, or except query.
58-
func getMultiTable(stmt *pgQuery.SelectStmt) []string {
59-
tables := []string{}
60-
// Get the tables from the left side.
61-
left := stmt.GetLarg()
62-
tables = append(tables, getSingleTable(left)...)
63-
// Get the tables from the right side.
64-
right := stmt.GetRarg()
65-
tables = append(tables, getSingleTable(right)...)
66-
67-
return tables
68-
}
69-
70-
// GetTablesFromQuery returns the tables used in a query.
71-
func GetTablesFromQuery(query string) ([]string, error) {
72-
stmt, err := pgQuery.Parse(query)
73-
if err != nil {
74-
return nil, err
75-
}
76-
77-
if len(stmt.Stmts) == 0 {
78-
return nil, nil
79-
}
80-
81-
tables := []string{}
82-
83-
for _, stmt := range stmt.Stmts {
84-
// Get the tables from the left and right side of the complex query.
85-
selectStatement := stmt.Stmt.GetSelectStmt()
86-
if isMulti(selectStatement) {
87-
tables = append(tables, getMultiTable(selectStatement)...)
88-
}
89-
90-
// Get the table from the WITH clause.
91-
if withClause := stmt.Stmt.GetSelectStmt().GetWithClause(); withClause != nil {
92-
for _, cte := range withClause.Ctes {
93-
selectStmt := cte.GetCommonTableExpr().Ctequery.GetSelectStmt()
94-
if isMulti(selectStmt) {
95-
tables = append(tables, getMultiTable(selectStmt)...)
96-
} else {
97-
tables = append(tables, getSingleTable(selectStmt)...)
98-
}
99-
}
100-
} else {
101-
// Get the table from the FROM clause.
102-
if selectStatement := stmt.Stmt.GetSelectStmt(); selectStatement != nil {
103-
tables = append(tables, getSingleTable(selectStatement)...)
104-
}
105-
}
106-
107-
if insertQuery := stmt.Stmt.GetInsertStmt(); insertQuery != nil {
108-
tables = append(tables, insertQuery.Relation.Relname)
109-
}
110-
111-
if updateQuery := stmt.Stmt.GetUpdateStmt(); updateQuery != nil {
112-
tables = append(tables, updateQuery.Relation.Relname)
113-
}
114-
115-
if deleteQuery := stmt.Stmt.GetDeleteStmt(); deleteQuery != nil {
116-
tables = append(tables, deleteQuery.Relation.Relname)
117-
}
118-
119-
if truncateQuery := stmt.Stmt.GetTruncateStmt(); truncateQuery != nil {
120-
for _, relation := range truncateQuery.Relations {
121-
tables = append(tables, relation.GetRangeVar().Relname)
122-
}
123-
}
124-
125-
if dropTable := stmt.Stmt.GetDropStmt(); dropTable != nil {
126-
for _, object := range dropTable.GetObjects() {
127-
for _, table := range object.GetList().GetItems() {
128-
tables = append(tables, table.GetString_().Str)
129-
}
130-
}
131-
}
132-
133-
if alterTable := stmt.Stmt.GetAlterTableStmt(); alterTable != nil {
134-
tables = append(tables, alterTable.Relation.Relname)
135-
}
136-
}
137-
138-
return tables, nil
139-
}
9+
const AddressPortPairLength = 2
14010

14111
// validateIP checks if an IP address is valid.
14212
func validateIP(ip net.IP) bool {

0 commit comments

Comments
 (0)