Skip to content

Commit 6a86d91

Browse files
committed
add degrade handler in circuit breaker
1 parent be1218e commit 6a86d91

File tree

5 files changed

+86
-31
lines changed

5 files changed

+86
-31
lines changed

pkg/gin/middleware/breaker.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ type circuitBreakerOptions struct {
2020
group *group.Group
2121
// http code for circuit breaker, default already includes 500 and 503
2222
validCodes map[int]struct{}
23+
// degrade func
24+
degradeHandler func(c *gin.Context)
2325
}
2426

2527
func defaultCircuitBreakerOptions() *circuitBreakerOptions {
@@ -59,6 +61,13 @@ func WithValidCode(code ...int) CircuitBreakerOption {
5961
}
6062
}
6163

64+
// WithDegradeHandler set degrade handler function
65+
func WithDegradeHandler(handler func(c *gin.Context)) CircuitBreakerOption {
66+
return func(o *circuitBreakerOptions) {
67+
o.degradeHandler = handler
68+
}
69+
}
70+
6271
// CircuitBreaker a circuit breaker middleware
6372
func CircuitBreaker(opts ...CircuitBreakerOption) gin.HandlerFunc {
6473
o := defaultCircuitBreakerOptions()
@@ -69,7 +78,11 @@ func CircuitBreaker(opts ...CircuitBreakerOption) gin.HandlerFunc {
6978
if err := breaker.Allow(); err != nil {
7079
// NOTE: when client reject request locally, keep adding counter let the drop ratio higher.
7180
breaker.MarkFailed()
72-
response.Output(c, http.StatusServiceUnavailable, err.Error())
81+
if o.degradeHandler != nil {
82+
o.degradeHandler(c)
83+
} else {
84+
response.Output(c, http.StatusServiceUnavailable, err.Error())
85+
}
7386
c.Abort()
7487
return
7588
}

pkg/gin/middleware/breaker_test.go

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package middleware
33
import (
44
"math/rand"
55
"net/http"
6-
"strings"
76
"sync"
87
"sync/atomic"
98
"testing"
@@ -21,19 +20,24 @@ import (
2120
func runCircuitBreakerHTTPServer() string {
2221
serverAddr, requestAddr := utils.GetLocalHTTPAddrPairs()
2322

23+
degradeHandler := func(c *gin.Context) {
24+
response.Output(c, http.StatusOK, "degrade")
25+
}
26+
2427
gin.SetMode(gin.ReleaseMode)
2528
r := gin.New()
2629
r.Use(CircuitBreaker(WithGroup(group.NewGroup(func() interface{} {
2730
return circuitbreaker.NewBreaker()
2831
})),
2932
WithValidCode(http.StatusForbidden),
33+
WithDegradeHandler(degradeHandler),
3034
))
3135

3236
r.GET("/hello", func(c *gin.Context) {
3337
if rand.Int()%2 == 0 {
3438
response.Output(c, http.StatusInternalServerError)
3539
} else {
36-
response.Success(c, "hello "+c.ClientIP())
40+
response.Success(c, "localhost"+serverAddr)
3741
}
3842
})
3943

@@ -51,27 +55,32 @@ func runCircuitBreakerHTTPServer() string {
5155
func TestCircuitBreaker(t *testing.T) {
5256
requestAddr := runCircuitBreakerHTTPServer()
5357

54-
var success, failures, countBreaker int32
58+
var success, failures, degradeCount int32
5559
for j := 0; j < 5; j++ {
5660
wg := &sync.WaitGroup{}
5761
wg.Add(1)
5862
go func() {
5963
defer wg.Done()
6064
for i := 0; i < 100; i++ {
6165
result := &httpcli.StdResult{}
62-
if err := httpcli.Get(result, requestAddr+"/hello"); err != nil {
63-
if strings.Contains(err.Error(), ErrNotAllowed.Error()) {
64-
atomic.AddInt32(&countBreaker, 1)
65-
}
66+
err := httpcli.Get(result, requestAddr+"/hello")
67+
if err != nil {
68+
//if errors.Is(err, ErrNotAllowed) {
69+
// atomic.AddInt32(&countBreaker, 1)
70+
//}
6671
atomic.AddInt32(&failures, 1)
72+
continue
73+
}
74+
if result.Data == "degrade" {
75+
atomic.AddInt32(&degradeCount, 1)
6776
} else {
6877
atomic.AddInt32(&success, 1)
6978
}
7079
}
7180
}()
7281

7382
wg.Wait()
74-
t.Logf("%s success: %d, failures: %d, breakerOpen: %d\n",
75-
time.Now().Format(time.RFC3339Nano), success, failures, countBreaker)
83+
t.Logf("%s success: %d, failures: %d, degradeCount: %d\n",
84+
time.Now().Format(time.RFC3339Nano), success, failures, degradeCount)
7685
}
7786
}

pkg/gin/response/response.go

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -62,41 +62,43 @@ func respJSONWithStatusCode(c *gin.Context, code int, msg string, data ...interf
6262
writeJSON(c, code, resp)
6363
}
6464

65-
// Output return json data by http status code
66-
func Output(c *gin.Context, code int, msg ...interface{}) {
65+
// Output standard HTTP status codes and data
66+
func Output(c *gin.Context, code int, data ...interface{}) {
6767
switch code {
6868
case http.StatusOK:
69-
respJSONWithStatusCode(c, http.StatusOK, "ok", msg...)
69+
respJSONWithStatusCode(c, http.StatusOK, "ok", data...)
7070
case http.StatusBadRequest:
71-
respJSONWithStatusCode(c, http.StatusBadRequest, errcode.InvalidParams.Msg(), msg...)
71+
respJSONWithStatusCode(c, http.StatusBadRequest, errcode.InvalidParams.Msg(), data...)
7272
case http.StatusUnauthorized:
73-
respJSONWithStatusCode(c, http.StatusUnauthorized, errcode.Unauthorized.Msg(), msg...)
73+
respJSONWithStatusCode(c, http.StatusUnauthorized, errcode.Unauthorized.Msg(), data...)
7474
case http.StatusForbidden:
75-
respJSONWithStatusCode(c, http.StatusForbidden, errcode.Forbidden.Msg(), msg...)
75+
respJSONWithStatusCode(c, http.StatusForbidden, errcode.Forbidden.Msg(), data...)
7676
case http.StatusNotFound:
77-
respJSONWithStatusCode(c, http.StatusNotFound, errcode.NotFound.Msg(), msg...)
77+
respJSONWithStatusCode(c, http.StatusNotFound, errcode.NotFound.Msg(), data...)
7878
case http.StatusRequestTimeout:
79-
respJSONWithStatusCode(c, http.StatusRequestTimeout, errcode.Timeout.Msg(), msg...)
79+
respJSONWithStatusCode(c, http.StatusRequestTimeout, errcode.Timeout.Msg(), data...)
8080
case http.StatusConflict:
81-
respJSONWithStatusCode(c, http.StatusConflict, errcode.Conflict.Msg(), msg...)
81+
respJSONWithStatusCode(c, http.StatusConflict, errcode.Conflict.Msg(), data...)
8282
case http.StatusInternalServerError:
83-
respJSONWithStatusCode(c, http.StatusInternalServerError, errcode.InternalServerError.Msg(), msg...)
83+
respJSONWithStatusCode(c, http.StatusInternalServerError, errcode.InternalServerError.Msg(), data...)
8484
case http.StatusTooManyRequests:
85-
respJSONWithStatusCode(c, http.StatusTooManyRequests, errcode.LimitExceed.Msg(), msg...)
85+
respJSONWithStatusCode(c, http.StatusTooManyRequests, errcode.LimitExceed.Msg(), data...)
8686
case http.StatusServiceUnavailable:
87-
respJSONWithStatusCode(c, http.StatusServiceUnavailable, errcode.ServiceUnavailable.Msg(), msg...)
87+
respJSONWithStatusCode(c, http.StatusServiceUnavailable, errcode.ServiceUnavailable.Msg(), data...)
8888

8989
default:
90-
respJSONWithStatusCode(c, code, http.StatusText(code), msg...)
90+
respJSONWithStatusCode(c, code, http.StatusText(code), data...)
9191
}
9292
}
9393

94-
// Out return json data by http status code, converted by errcode
94+
// Out HTTP standard status code which is converted from errcode.Error
9595
func Out(c *gin.Context, err *errcode.Error, data ...interface{}) {
9696
code := err.ToHTTPCode()
9797
switch code {
9898
case http.StatusOK:
9999
respJSONWithStatusCode(c, http.StatusOK, "ok", data...)
100+
case http.StatusInternalServerError:
101+
respJSONWithStatusCode(c, http.StatusInternalServerError, err.Msg(), data...)
100102
case http.StatusBadRequest:
101103
respJSONWithStatusCode(c, http.StatusBadRequest, err.Msg(), data...)
102104
case http.StatusUnauthorized:
@@ -109,8 +111,6 @@ func Out(c *gin.Context, err *errcode.Error, data ...interface{}) {
109111
respJSONWithStatusCode(c, http.StatusRequestTimeout, err.Msg(), data...)
110112
case http.StatusConflict:
111113
respJSONWithStatusCode(c, http.StatusConflict, err.Msg(), data...)
112-
case http.StatusInternalServerError:
113-
respJSONWithStatusCode(c, http.StatusInternalServerError, err.Msg(), data...)
114114
case http.StatusTooManyRequests:
115115
respJSONWithStatusCode(c, http.StatusTooManyRequests, err.Msg(), data...)
116116
case http.StatusServiceUnavailable:

pkg/grpc/interceptor/breaker.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ type circuitBreakerOptions struct {
2323
group *group.Group
2424
// rpc code for circuit breaker, default already includes codes.Internal and codes.Unavailable
2525
validCodes map[codes.Code]struct{}
26+
27+
// degrade handler for unary server
28+
unaryServerDegradeHandler func(ctx context.Context, req interface{}) (reply interface{}, error error)
2629
}
2730

2831
func defaultCircuitBreakerOptions() *circuitBreakerOptions {
@@ -62,6 +65,13 @@ func WithValidCode(code ...codes.Code) CircuitBreakerOption {
6265
}
6366
}
6467

68+
// WithUnaryServerDegradeHandler unary server degrade handler function
69+
func WithUnaryServerDegradeHandler(handler func(ctx context.Context, req interface{}) (reply interface{}, error error)) CircuitBreakerOption {
70+
return func(o *circuitBreakerOptions) {
71+
o.unaryServerDegradeHandler = handler
72+
}
73+
}
74+
6575
// UnaryClientCircuitBreaker client-side unary circuit breaker interceptor
6676
func UnaryClientCircuitBreaker(opts ...CircuitBreakerOption) grpc.UnaryClientInterceptor {
6777
o := defaultCircuitBreakerOptions()
@@ -130,6 +140,10 @@ func UnaryServerCircuitBreaker(opts ...CircuitBreakerOption) grpc.UnaryServerInt
130140
if err := breaker.Allow(); err != nil {
131141
// NOTE: when client reject request locally, keep adding let the drop ratio higher.
132142
breaker.MarkFailed()
143+
144+
if o.unaryServerDegradeHandler != nil {
145+
return o.unaryServerDegradeHandler(ctx, req)
146+
}
133147
return nil, errcode.StatusServiceUnavailable.ToRPCErr(err.Error())
134148
}
135149

pkg/grpc/interceptor/breaker_test.go

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,22 +58,41 @@ func TestSteamClientCircuitBreaker(t *testing.T) {
5858
}
5959

6060
func TestUnaryServerCircuitBreaker(t *testing.T) {
61-
interceptor := UnaryServerCircuitBreaker()
61+
degradeHandler := func(ctx context.Context, req interface{}) (reply interface{}, error error) {
62+
return "degrade", errcode.StatusSuccess.ToRPCErr()
63+
}
64+
interceptor := UnaryServerCircuitBreaker(WithUnaryServerDegradeHandler(degradeHandler))
6265
assert.NotNil(t, interceptor)
6366

67+
count := 0
6468
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
69+
count++
70+
if count%2 == 0 {
71+
return nil, errcode.StatusSuccess.ToRPCErr()
72+
}
6573
return nil, errcode.StatusInternalServerError.ToRPCErr()
6674
}
67-
for i := 0; i < 110; i++ {
68-
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{FullMethod: "/test"}, handler)
69-
assert.Error(t, err)
75+
76+
successCount, failCount, degradeCount := 0, 0, 0
77+
for i := 0; i < 1000; i++ {
78+
reply, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{FullMethod: "/test"}, handler)
79+
if err != nil {
80+
failCount++
81+
continue
82+
}
83+
if reply == "degrade" {
84+
degradeCount++
85+
} else {
86+
successCount++
87+
}
7088
}
89+
t.Logf("successCount: %d, failCount: %d, degradeCount: %d", successCount, failCount, degradeCount)
7190

7291
handler = func(ctx context.Context, req interface{}) (interface{}, error) {
7392
return nil, errcode.StatusInvalidParams.Err()
7493
}
7594
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{FullMethod: "/test"}, handler)
76-
assert.Error(t, err)
95+
t.Log(err)
7796
}
7897

7998
func TestSteamServerCircuitBreaker(t *testing.T) {

0 commit comments

Comments
 (0)