Skip to content

Commit ceb0e9b

Browse files
authored
skip caching datetime functions (#82)
- added IsCacheNeeded function - added test case for TestPluginDateFunctionInQuery
1 parent 5497d4c commit ceb0e9b

File tree

2 files changed

+118
-1
lines changed

2 files changed

+118
-1
lines changed

plugin/plugin.go

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,22 @@ type CachePlugin struct {
4545
Impl Plugin
4646
}
4747

48+
// Define a set for PostgreSQL date/time functions
49+
// https://www.postgresql.org/docs/8.2/functions-datetime.html
50+
var pgDateTimeFunctions = map[string]struct{}{
51+
"AGE": {},
52+
"CLOCK_TIMESTAMP": {},
53+
"CURRENT_DATE": {},
54+
"CURRENT_TIME": {},
55+
"CURRENT_TIMESTAMP": {},
56+
"LOCALTIME": {},
57+
"LOCALTIMESTAMP": {},
58+
"NOW": {},
59+
"STATEMENT_TIMESTAMP": {},
60+
"TIMEOFDAY": {},
61+
"TRANSACTION_TIMESTAMP": {},
62+
}
63+
4864
// NewCachePlugin returns a new instance of the CachePlugin.
4965
func NewCachePlugin(impl Plugin) *CachePlugin {
5066
return &CachePlugin{
@@ -164,6 +180,18 @@ func (p *Plugin) OnTrafficFromClient(
164180
return req, nil
165181
}
166182

183+
// IsCacheNeeded determines if caching is needed.
184+
func IsCacheNeeded(upperQuery string) bool {
185+
// Iterate over each function name in the set of PostgreSQL date/time functions.
186+
for function := range pgDateTimeFunctions {
187+
if strings.Contains(upperQuery, function) {
188+
// If the query contains a date/time function, caching is not needed.
189+
return false
190+
}
191+
}
192+
return true
193+
}
194+
167195
func (p *Plugin) UpdateCache(ctx context.Context) {
168196
for {
169197
serverResponse, ok := <-p.UpdateCacheChannel
@@ -219,7 +247,7 @@ func (p *Plugin) UpdateCache(ctx context.Context) {
219247
}
220248

221249
cacheKey := strings.Join([]string{server["remote"], database, string(request)}, ":")
222-
if errorResponse == "" && rowDescription != "" && dataRow != nil && len(dataRow) > 0 {
250+
if errorResponse == "" && rowDescription != "" && dataRow != nil && len(dataRow) > 0 && IsCacheNeeded(cacheKey) {
223251
// The request was successful and the response contains data. Cache the response.
224252
if err := p.RedisClient.Set(ctx, cacheKey, response, p.Expiry).Err(); err != nil {
225253
CacheMissesCounter.Inc()

plugin/plugin_test.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,22 @@ func testQueryRequest() (string, []byte) {
2525
return query, queryBytes
2626
}
2727

28+
func testQueryRequestWithDateFucntion() (string, []byte) {
29+
query := `SELECT
30+
user_id,
31+
username,
32+
last_login,
33+
NOW() AS current_time
34+
FROM
35+
users
36+
WHERE
37+
last_login >= CURRENT_DATE;`
38+
queryMsg := pgproto3.Query{String: query}
39+
// Encode the data to base64.
40+
queryBytes, _ := queryMsg.Encode(nil)
41+
return query, queryBytes
42+
}
43+
2844
func testStartupRequest() []byte {
2945
startupMsg := pgproto3.StartupMessage{
3046
ProtocolVersion: 196608,
@@ -180,3 +196,76 @@ func Test_Plugin(t *testing.T) {
180196
assert.Equal(t, resultMap["response"], response)
181197
assert.Contains(t, resultMap, sdkAct.Signals)
182198
}
199+
200+
func TestPluginDateFunctionInQuery(t *testing.T) {
201+
// Initialize a new mock Redis server.
202+
mockRedisServer := miniredis.RunT(t)
203+
redisURL := "redis://" + mockRedisServer.Addr() + "/0"
204+
redisConfig, err := redis.ParseURL(redisURL)
205+
redisClient := redis.NewClient(redisConfig)
206+
207+
cacheUpdateChannel := make(chan *v1.Struct, 10)
208+
209+
// Create and initialize a new plugin.
210+
logger := hclog.New(&hclog.LoggerOptions{
211+
Level: logging.GetLogLevel("error"),
212+
Output: os.Stdout,
213+
})
214+
plugin := NewCachePlugin(Plugin{
215+
Logger: logger,
216+
RedisURL: redisURL,
217+
RedisClient: redisClient,
218+
UpdateCacheChannel: cacheUpdateChannel,
219+
})
220+
221+
// Use a WaitGroup to wait for the goroutine to finish.
222+
var wg sync.WaitGroup
223+
wg.Add(1)
224+
go func() {
225+
defer wg.Done()
226+
plugin.Impl.UpdateCache(context.Background())
227+
}()
228+
229+
// Test the plugin's OnTrafficFromClient method with a StartupMessage.
230+
clientArgs := map[string]interface{}{
231+
"request": testStartupRequest(),
232+
"client": map[string]interface{}{
233+
"local": "localhost:15432",
234+
"remote": "localhost:45320",
235+
},
236+
"server": map[string]interface{}{
237+
"local": "localhost:54321",
238+
"remote": "localhost:5432",
239+
},
240+
"error": "",
241+
}
242+
clientRequest, err := v1.NewStruct(clientArgs)
243+
plugin.Impl.OnTrafficFromClient(context.Background(), clientRequest)
244+
245+
// Test the plugin's OnTrafficFromServer method with a query request.
246+
_, queryRequest := testQueryRequestWithDateFucntion()
247+
queryResponse, err := base64.StdEncoding.DecodeString("VAAAABsAAWlkAAAAQAQAAQAAABcABP////8AAEQAAAALAAEAAAABMUMAAAANU0VMRUNUIDEAWgAAAAVJ")
248+
assert.Nil(t, err)
249+
queryArgs := map[string]interface{}{
250+
"request": queryRequest,
251+
"response": queryResponse,
252+
"client": map[string]interface{}{
253+
"local": "localhost:15432",
254+
"remote": "localhost:45320",
255+
},
256+
"server": map[string]interface{}{
257+
"local": "localhost:54321",
258+
"remote": "localhost:5432",
259+
},
260+
"error": "",
261+
}
262+
serverRequest, err := v1.NewStruct(queryArgs)
263+
plugin.Impl.OnTrafficFromServer(context.Background(), serverRequest)
264+
265+
// Close the channel and wait for the cache updater to return gracefully.
266+
close(cacheUpdateChannel)
267+
wg.Wait()
268+
269+
keys, _ := redisClient.Keys(context.Background(), "*").Result()
270+
assert.Equal(t, 1, len(keys)) // Only one key (representing the database name) should be present.
271+
}

0 commit comments

Comments
 (0)