|
1 | 1 | package plugin |
2 | 2 |
|
3 | 3 | import ( |
4 | | - "encoding/base64" |
5 | 4 | "net" |
6 | 5 | "strconv" |
7 | 6 | "strings" |
8 | | - |
9 | | - pgQuery "github.com/pganalyze/pg_query_go/v2" |
10 | | -) |
11 | | - |
12 | | -const ( |
13 | | - MinPgSQLMessageLength = 5 |
14 | | - AddressPortPairLength = 2 |
15 | 7 | ) |
16 | 8 |
|
17 | | -// GetQueryFromRequest decodes the request and returns the query. |
18 | | -func GetQueryFromRequest(req string) (string, error) { |
19 | | - requestDecoded, err := base64.StdEncoding.DecodeString(req) |
20 | | - if err != nil { |
21 | | - return "", err |
22 | | - } |
23 | | - |
24 | | - if len(requestDecoded) < MinPgSQLMessageLength { |
25 | | - return "", nil |
26 | | - } |
27 | | - |
28 | | - // The first byte is the message type. |
29 | | - // The next 4 bytes are the length of the message. |
30 | | - // The rest of the message is the query. |
31 | | - // See https://www.postgresql.org/docs/13/protocol-message-formats.html |
32 | | - // for more information. |
33 | | - size := int(requestDecoded[1])<<24 + int(requestDecoded[2])<<16 + int(requestDecoded[3])<<8 + int(requestDecoded[4]) |
34 | | - return string(requestDecoded[MinPgSQLMessageLength:size]), nil |
35 | | -} |
36 | | - |
37 | | -// isMulti checks if the query is a union, intersect, or except. |
38 | | -func isMulti(stmt *pgQuery.SelectStmt) bool { |
39 | | - return stmt.GetOp() == pgQuery.SetOperation_SETOP_UNION || |
40 | | - stmt.GetOp() == pgQuery.SetOperation_SETOP_INTERSECT || |
41 | | - stmt.GetOp() == pgQuery.SetOperation_SETOP_EXCEPT |
42 | | -} |
43 | | - |
44 | | -// getSingleTable returns the tables used in a query. |
45 | | -func getSingleTable(stmt *pgQuery.SelectStmt) []string { |
46 | | - tables := []string{} |
47 | | - for _, from := range stmt.FromClause { |
48 | | - rangeVar := from.GetRangeVar() |
49 | | - if rangeVar != nil { |
50 | | - tables = append(tables, rangeVar.Relname) |
51 | | - } |
52 | | - } |
53 | | - |
54 | | - return tables |
55 | | -} |
56 | | - |
57 | | -// getMultiTable returns the tables used in a union, intersect, or except query. |
58 | | -func getMultiTable(stmt *pgQuery.SelectStmt) []string { |
59 | | - tables := []string{} |
60 | | - // Get the tables from the left side. |
61 | | - left := stmt.GetLarg() |
62 | | - tables = append(tables, getSingleTable(left)...) |
63 | | - // Get the tables from the right side. |
64 | | - right := stmt.GetRarg() |
65 | | - tables = append(tables, getSingleTable(right)...) |
66 | | - |
67 | | - return tables |
68 | | -} |
69 | | - |
70 | | -// GetTablesFromQuery returns the tables used in a query. |
71 | | -func GetTablesFromQuery(query string) ([]string, error) { |
72 | | - stmt, err := pgQuery.Parse(query) |
73 | | - if err != nil { |
74 | | - return nil, err |
75 | | - } |
76 | | - |
77 | | - if len(stmt.Stmts) == 0 { |
78 | | - return nil, nil |
79 | | - } |
80 | | - |
81 | | - tables := []string{} |
82 | | - |
83 | | - for _, stmt := range stmt.Stmts { |
84 | | - // Get the tables from the left and right side of the complex query. |
85 | | - selectStatement := stmt.Stmt.GetSelectStmt() |
86 | | - if isMulti(selectStatement) { |
87 | | - tables = append(tables, getMultiTable(selectStatement)...) |
88 | | - } |
89 | | - |
90 | | - // Get the table from the WITH clause. |
91 | | - if withClause := stmt.Stmt.GetSelectStmt().GetWithClause(); withClause != nil { |
92 | | - for _, cte := range withClause.Ctes { |
93 | | - selectStmt := cte.GetCommonTableExpr().Ctequery.GetSelectStmt() |
94 | | - if isMulti(selectStmt) { |
95 | | - tables = append(tables, getMultiTable(selectStmt)...) |
96 | | - } else { |
97 | | - tables = append(tables, getSingleTable(selectStmt)...) |
98 | | - } |
99 | | - } |
100 | | - } else { |
101 | | - // Get the table from the FROM clause. |
102 | | - if selectStatement := stmt.Stmt.GetSelectStmt(); selectStatement != nil { |
103 | | - tables = append(tables, getSingleTable(selectStatement)...) |
104 | | - } |
105 | | - } |
106 | | - |
107 | | - if insertQuery := stmt.Stmt.GetInsertStmt(); insertQuery != nil { |
108 | | - tables = append(tables, insertQuery.Relation.Relname) |
109 | | - } |
110 | | - |
111 | | - if updateQuery := stmt.Stmt.GetUpdateStmt(); updateQuery != nil { |
112 | | - tables = append(tables, updateQuery.Relation.Relname) |
113 | | - } |
114 | | - |
115 | | - if deleteQuery := stmt.Stmt.GetDeleteStmt(); deleteQuery != nil { |
116 | | - tables = append(tables, deleteQuery.Relation.Relname) |
117 | | - } |
118 | | - |
119 | | - if truncateQuery := stmt.Stmt.GetTruncateStmt(); truncateQuery != nil { |
120 | | - for _, relation := range truncateQuery.Relations { |
121 | | - tables = append(tables, relation.GetRangeVar().Relname) |
122 | | - } |
123 | | - } |
124 | | - |
125 | | - if dropTable := stmt.Stmt.GetDropStmt(); dropTable != nil { |
126 | | - for _, object := range dropTable.GetObjects() { |
127 | | - for _, table := range object.GetList().GetItems() { |
128 | | - tables = append(tables, table.GetString_().Str) |
129 | | - } |
130 | | - } |
131 | | - } |
132 | | - |
133 | | - if alterTable := stmt.Stmt.GetAlterTableStmt(); alterTable != nil { |
134 | | - tables = append(tables, alterTable.Relation.Relname) |
135 | | - } |
136 | | - } |
137 | | - |
138 | | - return tables, nil |
139 | | -} |
| 9 | +const AddressPortPairLength = 2 |
140 | 10 |
|
141 | 11 | // validateIP checks if an IP address is valid. |
142 | 12 | func validateIP(ip net.IP) bool { |
|
0 commit comments