Skip to content

Commit e2ec3e8

Browse files
authored
Merge pull request #97 from remind101/http-handler-compat
Ensure Context arg is the same as request Context
2 parents a2aea2a + 8f46cf0 commit e2ec3e8

File tree

7 files changed

+23
-1
lines changed

7 files changed

+23
-1
lines changed

httpx/middleware/header.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ func (h *Header) ServeHTTPContext(ctx context.Context, w http.ResponseWriter, r
2828
value := e(r)
2929

3030
ctx = httpx.WithHeader(ctx, h.key, value)
31+
r = r.WithContext(ctx)
32+
3133
return h.handler.ServeHTTPContext(ctx, w, r)
3234
}
3335

httpx/middleware/header_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"testing"
77

88
"context"
9+
910
"github.com/remind101/pkg/httpx"
1011
)
1112

@@ -24,7 +25,11 @@ func TestHeader(t *testing.T) {
2425
m := ExtractHeader(
2526
httpx.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
2627
data := httpx.Header(ctx, tt.key)
28+
if got, want := data, tt.val; got != want {
29+
t.Fatalf("%s => %s; want %s", tt.key, got, want)
30+
}
2731

32+
data = httpx.Header(r.Context(), tt.key)
2833
if got, want := data, tt.val; got != want {
2934
t.Fatalf("%s => %s; want %s", tt.key, got, want)
3035
}

httpx/middleware/logger.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,10 @@ func LogTo(h httpx.Handler, g loggerGenerator) httpx.Handler {
4343
func InsertLogger(h httpx.Handler, g loggerGenerator) httpx.Handler {
4444
return httpx.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
4545
l := g(ctx, r)
46+
4647
ctx = logger.WithLogger(ctx, l)
48+
r = r.WithContext(ctx)
49+
4750
return h.ServeHTTPContext(ctx, w, r)
4851
})
4952
}

httpx/middleware/opentracing.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ func (h *OpentracingTracer) ServeHTTPContext(ctx context.Context, w http.Respons
4545

4646
defer span.Finish()
4747
ctx = opentracing.ContextWithSpan(ctx, span)
48+
r = r.WithContext(ctx)
4849

4950
rw := NewResponseWriter(w)
5051
reqErr := h.handler.ServeHTTPContext(ctx, rw, r)

httpx/middleware/reporter.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ func (m *Reporter) ServeHTTPContext(ctx context.Context, w http.ResponseWriter,
2626
// Add the request id to reporter context.
2727
ctx = errors.WithInfo(ctx, "request_id", httpx.RequestID(ctx))
2828

29+
r = r.WithContext(ctx)
30+
2931
return m.handler.ServeHTTPContext(ctx, w, r)
3032
}
3133

httpx/middleware/request_id.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"net/http"
55

66
"context"
7+
78
"github.com/remind101/pkg/httpx"
89
)
910

@@ -39,5 +40,7 @@ func (h *RequestID) ServeHTTPContext(ctx context.Context, w http.ResponseWriter,
3940
requestID := e(r)
4041

4142
ctx = httpx.WithRequestID(ctx, requestID)
43+
r = r.WithContext(ctx)
44+
4245
return h.handler.ServeHTTPContext(ctx, w, r)
4346
}

httpx/middleware/request_id_test.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ import (
55
"net/http/httptest"
66
"testing"
77

8-
"github.com/remind101/pkg/httpx"
98
"context"
9+
10+
"github.com/remind101/pkg/httpx"
1011
)
1112

1213
func TestRequestID(t *testing.T) {
@@ -23,7 +24,12 @@ func TestRequestID(t *testing.T) {
2324
m := &RequestID{
2425
handler: httpx.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
2526
requestID := httpx.RequestID(ctx)
27+
if got, want := requestID, tt.id; got != want {
28+
t.Fatalf("RequestID => %s; want %s", got, want)
29+
}
2630

31+
// From request.Context()
32+
requestID = httpx.RequestID(r.Context())
2733
if got, want := requestID, tt.id; got != want {
2834
t.Fatalf("RequestID => %s; want %s", got, want)
2935
}

0 commit comments

Comments
 (0)