diff --git a/LICENSE b/LICENSE index bbe6990..54a5ec3 100644 --- a/LICENSE +++ b/LICENSE @@ -1,5 +1,6 @@ MIT License +Copyright (c) 2022 Elliot Lunness Copyright (c) 2018 Victor Springer Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/README.md b/README.md index 0722a81..c82e202 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,15 @@ It is simple, super fast, thread safe and gives the possibility to choose the ad The memory adapter minimizes GC overhead to near zero and supports some options of caching algorithms (LRU, MRU, LFU, MFU). This way, it is able to store plenty of gigabytes of responses, keeping great performance and being free of leaks. -**Note:** Some tests are currently disabled as they weren't updated when the library was updated for use with echo. +**Note:** Some tests are currently disabled as they weren't updated when the library was updated for use with echo. I plan to fix this soon. + +## Original Credit + +Project has been detached from the original repository as it's not maintained anymore and github defaults PRs to the original repository. + +* [echo-http-cache](https://github.com/SporkHubr/echo-http-cache) +* [http-cache](https://github.com/victorspringer/http-cache) + ## Getting Started @@ -84,6 +92,7 @@ import ( cache.ClientWithAdapter(redis.NewAdapter(ringOpt)), cache.ClientWithTTL(10 * time.Minute), cache.ClientWithRefreshKey("opn"), + cache.ClientWithStatusCodeFilter(func(code int) bool { return code != 400 }), // Default ) ... diff --git a/cache.go b/cache.go index 6613452..e0af5c9 100644 --- a/cache.go +++ b/cache.go @@ -52,6 +52,9 @@ type Response struct { // Header is the cached response header. Header http.Header + // StatusCode is the cached response status code. + StatusCode int + // Expiration is the cached response expiration date. Expiration time.Time @@ -66,11 +69,12 @@ type Response struct { // Client data structure for HTTP cache middleware. type Client struct { - adapter Adapter - ttl time.Duration - refreshKey string - methods []string - restrictedPaths []string + adapter Adapter + ttl time.Duration + refreshKey string + methods []string + restrictedPaths []string + statusCodeFilter func(int) bool } type bodyDumpResponseWriter struct { @@ -152,10 +156,10 @@ func (client *Client) Middleware() echo.MiddlewareFunc { response.Frequency++ client.adapter.Set(key, response.Bytes(), response.Expiration) - //w.WriteHeader(http.StatusNotModified) for k, v := range response.Header { c.Response().Header().Set(k, strings.Join(v, ",")) } + c.Response().WriteHeader(response.StatusCode) c.Response().WriteHeader(http.StatusOK) c.Response().Write(response.Value) return nil @@ -175,7 +179,7 @@ func (client *Client) Middleware() echo.MiddlewareFunc { statusCode := writer.statusCode value := resBody.Bytes() - if statusCode < 400 { + if client.statusCodeFilter(statusCode) { now := time.Now() response := Response{ @@ -288,10 +292,22 @@ func NewClient(opts ...ClientOption) (*Client, error) { if c.methods == nil { c.methods = []string{http.MethodGet} } + if c.statusCodeFilter == nil { + c.statusCodeFilter = func(code int) bool { return code < 400 } + } return c, nil } +// ClientWithStatusCodeFilter sets the acceptable status codes to be cached. +// Optional setting. If not set, default filter allows caching of every response with status code below 400. +func ClientWithStatusCodeFilter(filter func(int) bool) ClientOption { + return func(c *Client) error { + c.statusCodeFilter = filter + return nil + } +} + // ClientWithAdapter sets the adapter type for the HTTP cache // middleware client. func ClientWithAdapter(a Adapter) ClientOption { diff --git a/cache_test.go b/cache_test.go index 5b819c7..2d2a731 100644 --- a/cache_test.go +++ b/cache_test.go @@ -52,18 +52,22 @@ func (errReader) Read(p []byte) (n int, err error) { // store: map[uint64][]byte{ // 14974843192121052621: Response{ // Value: []byte("value 1"), +// StatusCode: 200, // Expiration: time.Now().Add(1 * time.Minute), // }.Bytes(), // 14974839893586167988: Response{ // Value: []byte("value 2"), +// StatusCode: 200, // Expiration: time.Now().Add(1 * time.Minute), // }.Bytes(), // 14974840993097796199: Response{ // Value: []byte("value 3"), +// StatusCode: 200, // Expiration: time.Now().Add(-1 * time.Minute), // }.Bytes(), // 10956846073361780255: Response{ // Value: []byte("value 4"), +// StatusCode: 200, // Expiration: time.Now().Add(-1 * time.Minute), // }.Bytes(), // }, @@ -470,9 +474,61 @@ func TestNewClient(t *testing.T) { t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr) return } + if tt.want != nil { + got.statusCodeFilter = nil + tt.want.statusCodeFilter = nil + } if !reflect.DeepEqual(got, tt.want) { t.Errorf("NewClient() = %v, want %v", got, tt.want) } }) } } + +func TestNewClientWithStatusCodeFilter(t *testing.T) { + adapter := &adapterMock{} + + tests := []struct { + name string + opts []ClientOption + wantCache []int + wantSkip []int + }{ + { + "returns new client with status code filter", + []ClientOption{ + ClientWithAdapter(adapter), + ClientWithTTL(1 * time.Millisecond), + }, + []int{200, 300}, + []int{400, 500}, + }, + { + "returns new client with status code filter", + []ClientOption{ + ClientWithAdapter(adapter), + ClientWithTTL(1 * time.Millisecond), + ClientWithStatusCodeFilter(func(code int) bool { return code < 350 || code > 450 }), + }, + []int{200, 300, 500}, + []int{400}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, _ := NewClient(tt.opts...) + + for _, c := range tt.wantCache { + if got.statusCodeFilter(c) == false { + t.Errorf("NewClient() allows caching of status code %v, don't want it to", c) + } + } + + for _, c := range tt.wantSkip { + if got.statusCodeFilter(c) == true { + t.Errorf("NewClient() doesn't allow caching of status code %v, want it to", c) + } + } + }) + } +}