Skip to content

Commit 99877a9

Browse files
authored
Merge pull request #73 from gatewayd-io/66-write-to-cache-async
Update Redis cache asynhronously
2 parents aa362ef + 726cf90 commit 99877a9

File tree

5 files changed

+117
-70
lines changed

5 files changed

+117
-70
lines changed

gatewayd_plugin.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,5 @@ plugins:
3030
- API_ADDRESS=localhost:18080
3131
- EXIT_ON_STARTUP_ERROR=False
3232
- SENTRY_DSN=https://70eb1abcd32e41acbdfc17bc3407a543@o4504550475038720.ingest.sentry.io/4505342961123328
33+
- CACHE_CHANNEL_BUFFER_SIZE=100
3334
checksum: 3988e10aefce2cd9b30888eddd2ec93a431c9018a695aea1cea0dac46ba91cae

main.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/gatewayd-io/gatewayd-plugin-sdk/logging"
1212
"github.com/gatewayd-io/gatewayd-plugin-sdk/metrics"
1313
p "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin"
14+
v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1"
1415
"github.com/getsentry/sentry-go"
1516
"github.com/go-redis/redis/v8"
1617
"github.com/hashicorp/go-hclog"
@@ -52,6 +53,14 @@ func main() {
5253
go metrics.ExposeMetrics(metricsConfig, logger)
5354
}
5455

56+
cacheBufferSize := cast.ToUint(cfg["cacheBufferSize"])
57+
if cacheBufferSize <= 0 {
58+
cacheBufferSize = 100 // default value
59+
}
60+
61+
pluginInstance.Impl.UpdateCacheChannel = make(chan *v1.Struct, cacheBufferSize)
62+
go pluginInstance.Impl.UpdateCache(context.Background())
63+
5564
pluginInstance.Impl.RedisURL = cast.ToString(cfg["redisURL"])
5665
pluginInstance.Impl.Expiry = cast.ToDuration(cfg["expiry"])
5766
pluginInstance.Impl.DefaultDBName = cast.ToString(cfg["defaultDBName"])
@@ -93,6 +102,8 @@ func main() {
93102
}
94103
}
95104

105+
defer close(pluginInstance.Impl.UpdateCacheChannel)
106+
96107
goplugin.Serve(&goplugin.ServeConfig{
97108
HandshakeConfig: goplugin.HandshakeConfig{
98109
ProtocolVersion: 1,

plugin/module.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ var (
4545
"PERIODIC_INVALIDATOR_INTERVAL", "1m"),
4646
"apiAddress": sdkConfig.GetEnv("API_ADDRESS", "localhost:8080"),
4747
"exitOnStartupError": sdkConfig.GetEnv("EXIT_ON_STARTUP_ERROR", "false"),
48+
"cacheBufferSize": sdkConfig.GetEnv("CACHE_CHANNEL_BUFFER_SIZE", "100"),
4849
},
4950
"hooks": []interface{}{
5051
int32(v1.HookName_HOOK_NAME_ON_CLOSED),

plugin/plugin.go

Lines changed: 82 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ type Plugin struct {
3030
ScanCount int64
3131
ExitOnStartupError bool
3232

33+
UpdateCacheChannel chan *v1.Struct
34+
3335
// Periodic invalidator configuration.
3436
PeriodicInvalidatorEnabled bool
3537
PeriodicInvalidatorStartDelay time.Duration
@@ -144,87 +146,103 @@ func (p *Plugin) OnTrafficFromClient(
144146
return req, nil
145147
}
146148

147-
// OnTrafficFromServer is called when a response is received by GatewayD from the server.
148-
func (p *Plugin) OnTrafficFromServer(
149-
ctx context.Context, resp *v1.Struct,
150-
) (*v1.Struct, error) {
151-
OnTrafficFromServerCounter.Inc()
152-
resp, err := postgres.HandleServerMessage(resp, p.Logger)
153-
if err != nil {
154-
p.Logger.Info("Failed to handle server message", "error", err)
155-
}
156-
157-
rowDescription := cast.ToString(sdkPlugin.GetAttr(resp, "rowDescription", ""))
158-
dataRow := cast.ToStringSlice(sdkPlugin.GetAttr(resp, "dataRow", []interface{}{}))
159-
errorResponse := cast.ToString(sdkPlugin.GetAttr(resp, "errorResponse", ""))
160-
request, ok := sdkPlugin.GetAttr(resp, "request", nil).([]byte)
161-
if !ok {
162-
request = []byte{}
163-
}
164-
response, ok := sdkPlugin.GetAttr(resp, "response", nil).([]byte)
165-
if !ok {
166-
response = []byte{}
167-
}
168-
server := cast.ToStringMapString(sdkPlugin.GetAttr(resp, "server", ""))
149+
func (p *Plugin) UpdateCache(ctx context.Context) {
150+
for {
151+
serverResponse, ok := <-p.UpdateCacheChannel
152+
if !ok {
153+
p.Logger.Info("Channel closed, returning from function")
154+
return
155+
}
169156

170-
// This is used as a fallback if the database is not found in the startup message.
171-
database := p.DefaultDBName
172-
if database == "" {
173-
client := cast.ToStringMapString(sdkPlugin.GetAttr(resp, "client", ""))
174-
if client != nil && client["remote"] != "" {
175-
database, err = p.RedisClient.Get(ctx, client["remote"]).Result()
176-
if err != nil {
177-
CacheMissesCounter.Inc()
178-
p.Logger.Debug("Failed to get cached response", "error", err)
179-
}
180-
CacheGetsCounter.Inc()
157+
OnTrafficFromServerCounter.Inc()
158+
resp, err := postgres.HandleServerMessage(serverResponse, p.Logger)
159+
if err != nil {
160+
p.Logger.Info("Failed to handle server message", "error", err)
181161
}
182-
}
183162

184-
// If the database is still not found, return the response as is without caching.
185-
// This might also happen if the cache is cleared while the client is still connected.
186-
// In this case, the client should reconnect and the error will go away.
187-
if database == "" {
188-
p.Logger.Debug("Database name not found or set in cache, startup message or plugin config. Skipping cache")
189-
p.Logger.Debug("Consider setting the database name in the plugin config or disabling the plugin if you don't need it")
190-
return resp, nil
191-
}
163+
rowDescription := cast.ToString(sdkPlugin.GetAttr(resp, "rowDescription", ""))
164+
dataRow := cast.ToStringSlice(sdkPlugin.GetAttr(resp, "dataRow", []interface{}{}))
165+
errorResponse := cast.ToString(sdkPlugin.GetAttr(resp, "errorResponse", ""))
166+
request, isOk := sdkPlugin.GetAttr(resp, "request", nil).([]byte)
167+
if !isOk {
168+
request = []byte{}
169+
}
192170

193-
cacheKey := strings.Join([]string{server["remote"], database, string(request)}, ":")
194-
if errorResponse == "" && rowDescription != "" && dataRow != nil && len(dataRow) > 0 {
195-
// The request was successful and the response contains data. Cache the response.
196-
if err := p.RedisClient.Set(ctx, cacheKey, response, p.Expiry).Err(); err != nil {
197-
CacheMissesCounter.Inc()
198-
p.Logger.Debug("Failed to set cache", "error", err)
171+
response, isOk := sdkPlugin.GetAttr(resp, "response", nil).([]byte)
172+
if !isOk {
173+
response = []byte{}
199174
}
200-
CacheSetsCounter.Inc()
175+
server := cast.ToStringMapString(sdkPlugin.GetAttr(resp, "server", ""))
201176

202-
// Cache the query as well.
203-
query, err := postgres.GetQueryFromRequest(request)
204-
if err != nil {
205-
p.Logger.Debug("Failed to get query from request", "error", err)
206-
return resp, nil
177+
// This is used as a fallback if the database is not found in the startup message.
178+
179+
database := p.DefaultDBName
180+
if database == "" {
181+
client := cast.ToStringMapString(sdkPlugin.GetAttr(resp, "client", ""))
182+
if client != nil && client["remote"] != "" {
183+
database, err = p.RedisClient.Get(ctx, client["remote"]).Result()
184+
if err != nil {
185+
CacheMissesCounter.Inc()
186+
p.Logger.Debug("Failed to get cached response", "error", err)
187+
}
188+
CacheGetsCounter.Inc()
189+
}
207190
}
208191

209-
tables, err := postgres.GetTablesFromQuery(query)
210-
if err != nil {
211-
p.Logger.Debug("Failed to get tables from query", "error", err)
212-
return resp, nil
192+
// If the database is still not found, return the response as is without caching.
193+
// This might also happen if the cache is cleared while the client is still connected.
194+
// In this case, the client should reconnect and the error will go away.
195+
if database == "" {
196+
p.Logger.Debug("Database name not found or set in cache, startup message or plugin config. " +
197+
"Skipping cache")
198+
p.Logger.Debug("Consider setting the database name in the " +
199+
"plugin config or disabling the plugin if you don't need it")
200+
return
213201
}
214202

215-
// Cache the table(s) used in each cached request. This is used to invalidate
216-
// the cache when a rows is inserted, updated or deleted into that table.
217-
for _, table := range tables {
218-
requestQueryCacheKey := strings.Join([]string{table, cacheKey}, ":")
219-
if err := p.RedisClient.Set(
220-
ctx, requestQueryCacheKey, "", p.Expiry).Err(); err != nil {
203+
cacheKey := strings.Join([]string{server["remote"], database, string(request)}, ":")
204+
if errorResponse == "" && rowDescription != "" && dataRow != nil && len(dataRow) > 0 {
205+
// The request was successful and the response contains data. Cache the response.
206+
if err := p.RedisClient.Set(ctx, cacheKey, response, p.Expiry).Err(); err != nil {
221207
CacheMissesCounter.Inc()
222208
p.Logger.Debug("Failed to set cache", "error", err)
223209
}
224210
CacheSetsCounter.Inc()
211+
212+
// Cache the query as well.
213+
query, err := postgres.GetQueryFromRequest(request)
214+
if err != nil {
215+
p.Logger.Debug("Failed to get query from request", "error", err)
216+
return
217+
}
218+
219+
tables, err := postgres.GetTablesFromQuery(query)
220+
if err != nil {
221+
p.Logger.Debug("Failed to get tables from query", "error", err)
222+
return
223+
}
224+
225+
// Cache the table(s) used in each cached request. This is used to invalidate
226+
// the cache when a rows is inserted, updated or deleted into that table.
227+
for _, table := range tables {
228+
requestQueryCacheKey := strings.Join([]string{table, cacheKey}, ":")
229+
if err := p.RedisClient.Set(
230+
ctx, requestQueryCacheKey, "", p.Expiry).Err(); err != nil {
231+
CacheMissesCounter.Inc()
232+
p.Logger.Debug("Failed to set cache", "error", err)
233+
}
234+
CacheSetsCounter.Inc()
235+
}
225236
}
226237
}
238+
}
227239

240+
// OnTrafficFromServer is called when a response is received by GatewayD from the server.
241+
func (p *Plugin) OnTrafficFromServer(
242+
_ context.Context, resp *v1.Struct,
243+
) (*v1.Struct, error) {
244+
p.Logger.Debug("Traffic is coming from the server side")
245+
p.UpdateCacheChannel <- resp
228246
return resp, nil
229247
}
230248

plugin/plugin_test.go

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@ package plugin
33
import (
44
"context"
55
"encoding/base64"
6-
"os"
7-
"testing"
8-
96
miniredis "github.com/alicebob/miniredis/v2"
107
"github.com/gatewayd-io/gatewayd-plugin-sdk/logging"
118
v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1"
129
"github.com/go-redis/redis/v8"
1310
"github.com/hashicorp/go-hclog"
1411
pgproto3 "github.com/jackc/pgx/v5/pgproto3"
1512
"github.com/stretchr/testify/assert"
13+
"os"
14+
"sync"
15+
"testing"
1616
)
1717

1818
func testQueryRequest() (string, []byte) {
@@ -44,16 +44,28 @@ func Test_Plugin(t *testing.T) {
4444
redisClient := redis.NewClient(redisConfig)
4545
assert.NotNil(t, redisClient)
4646

47+
updateCacheChannel := make(chan *v1.Struct, 10)
48+
4749
// Create and initialize a new plugin.
4850
logger := hclog.New(&hclog.LoggerOptions{
4951
Level: logging.GetLogLevel("error"),
5052
Output: os.Stdout,
5153
})
5254
p := NewCachePlugin(Plugin{
53-
Logger: logger,
54-
RedisURL: redisURL,
55-
RedisClient: redisClient,
55+
Logger: logger,
56+
RedisURL: redisURL,
57+
RedisClient: redisClient,
58+
UpdateCacheChannel: updateCacheChannel,
5659
})
60+
61+
// Use a WaitGroup to wait for the goroutine to finish
62+
var wg sync.WaitGroup
63+
wg.Add(1)
64+
go func() {
65+
defer wg.Done()
66+
p.Impl.UpdateCache(context.Background())
67+
}()
68+
5769
assert.NotNil(t, p)
5870

5971
// Test the plugin's GetPluginConfig method.
@@ -146,6 +158,10 @@ func Test_Plugin(t *testing.T) {
146158
assert.NotNil(t, result)
147159
assert.Equal(t, result, resp)
148160

161+
// Close the channel and wait for the cache updater to return gracefully
162+
close(updateCacheChannel)
163+
wg.Wait()
164+
149165
// Check that the query and response was cached.
150166
cachedResponse, err := redisClient.Get(
151167
context.Background(), "localhost:5432:postgres:"+string(request)).Bytes()

0 commit comments

Comments
 (0)