Skip to content

Commit 72500bb

Browse files
[EXP-3458] Append User Agent and XFF header to proxied requests (#44)
* Add http context for X-Forwarded-For and UserAgent * Add agent headers and XFF from context * Add http context to server * Integration test GitOrigin-RevId: 75ffabc
1 parent d76949b commit 72500bb

6 files changed

Lines changed: 428 additions & 5 deletions

File tree

cmd/server.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/render-oss/render-mcp-server/pkg/cfg"
1010
"github.com/render-oss/render-mcp-server/pkg/client"
1111
"github.com/render-oss/render-mcp-server/pkg/deploy"
12+
"github.com/render-oss/render-mcp-server/pkg/httpcontext"
1213
"github.com/render-oss/render-mcp-server/pkg/keyvalue"
1314
"github.com/render-oss/render-mcp-server/pkg/logs"
1415
"github.com/render-oss/render-mcp-server/pkg/metrics"
@@ -56,6 +57,7 @@ func Serve(transport string) *server.MCPServer {
5657
NewStreamableHTTPServer(s, server.WithHTTPContextFunc(multicontext.MultiHTTPContextFunc(
5758
session.ContextWithHTTPSession(sessionStore),
5859
authn.ContextWithAPITokenFromHeader,
60+
httpcontext.ContextWithHTTPRequest,
5961
))).
6062
Start(":10000")
6163
if err != nil {

pkg/cfg/cfg.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,12 @@ func GetAPIKey() string {
2121
return os.Getenv("RENDER_API_KEY")
2222
}
2323

24-
func AddUserAgent(header http.Header) http.Header {
25-
header.Add("user-agent", fmt.Sprintf("render-mcp-server/%s (%s)", Version, getOSInfoOnce()))
24+
func AddUserAgent(header http.Header, clientUserAgent string) http.Header {
25+
ua := fmt.Sprintf("render-mcp-server/%s (%s)", Version, getOSInfoOnce())
26+
if clientUserAgent != "" {
27+
ua = ua + " " + clientUserAgent
28+
}
29+
header.Add("user-agent", ua)
2630
return header
2731
}
2832

pkg/client/client.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/render-oss/render-mcp-server/pkg/authn"
1212
"github.com/render-oss/render-mcp-server/pkg/cfg"
1313
"github.com/render-oss/render-mcp-server/pkg/config"
14+
"github.com/render-oss/render-mcp-server/pkg/httpcontext"
1415
)
1516

1617
var ErrUnauthorized = errors.New("unauthorized")
@@ -24,9 +25,13 @@ func NewDefaultClient() (*ClientWithResponses, error) {
2425
return clientWithAuth(&http.Client{}, apiCfg)
2526
}
2627

27-
func AddHeaders(header http.Header, token string) http.Header {
28-
header = cfg.AddUserAgent(header)
28+
func AddHeaders(ctx context.Context, header http.Header, token string) http.Header {
29+
hc := httpcontext.FromContext(ctx)
30+
header = cfg.AddUserAgent(header, hc.UserAgent)
2931
header.Add("authorization", fmt.Sprintf("Bearer %s", token))
32+
if hc.ForwardedFor != "" {
33+
header.Add("X-Forwarded-For", hc.ForwardedFor)
34+
}
3035
return header
3136
}
3237

@@ -93,7 +98,7 @@ func firstNonNilErrorField(response any) *ErrorWithCode {
9398

9499
func clientWithAuth(httpClient *http.Client, apiCfg config.APIConfig) (*ClientWithResponses, error) {
95100
insertAuth := func(ctx context.Context, req *http.Request) error {
96-
req.Header = AddHeaders(req.Header, authn.APITokenFromContext(ctx))
101+
req.Header = AddHeaders(ctx, req.Header, authn.APITokenFromContext(ctx))
97102
return nil
98103
}
99104

pkg/httpcontext/httpcontext.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package httpcontext
2+
3+
import (
4+
"context"
5+
"log"
6+
"net"
7+
"net/http"
8+
"strings"
9+
)
10+
11+
// HTTPContext stores HTTP metadata extracted from incoming requests.
12+
type HTTPContext struct {
13+
UserAgent string // Client's User-Agent header
14+
ForwardedFor string // X-Forwarded-For chain
15+
}
16+
17+
type ctxKey struct{}
18+
19+
// FromContext retrieves HTTPContext from context. Returns empty HTTPContext if not present.
20+
func FromContext(ctx context.Context) HTTPContext {
21+
if hc, ok := ctx.Value(ctxKey{}).(HTTPContext); ok {
22+
return hc
23+
}
24+
return HTTPContext{}
25+
}
26+
27+
// ContextWithHTTPContext stores HTTPContext in the context.
28+
func ContextWithHTTPContext(ctx context.Context, hc HTTPContext) context.Context {
29+
return context.WithValue(ctx, ctxKey{}, hc)
30+
}
31+
32+
// ContextWithHTTPRequest extracts HTTP metadata from a request and stores it in context.
33+
func ContextWithHTTPRequest(ctx context.Context, req *http.Request) context.Context {
34+
hc := HTTPContext{
35+
UserAgent: req.Header.Get("User-Agent"),
36+
ForwardedFor: buildXFF(req.Header.Get("X-Forwarded-For"), req.RemoteAddr),
37+
}
38+
return ContextWithHTTPContext(ctx, hc)
39+
}
40+
41+
// buildXFF constructs the X-Forwarded-For chain.
42+
// It appends the client IP from RemoteAddr to the existing XFF header,
43+
// avoiding consecutive duplicates at the end.
44+
func buildXFF(existingXFF, remoteAddr string) string {
45+
clientIP := getClientIP(remoteAddr)
46+
if clientIP == "" {
47+
return existingXFF
48+
}
49+
50+
if existingXFF == "" {
51+
return clientIP
52+
}
53+
54+
// Check if clientIP is already the last entry to avoid consecutive duplicates
55+
lastEntry := lastXFFEntry(existingXFF)
56+
if lastEntry == clientIP {
57+
return existingXFF
58+
}
59+
60+
return existingXFF + ", " + clientIP
61+
}
62+
63+
// getClientIP extracts the IP address from RemoteAddr, stripping the port if present.
64+
func getClientIP(remoteAddr string) string {
65+
if remoteAddr == "" {
66+
return ""
67+
}
68+
69+
// net.SplitHostPort handles both IPv4 and IPv6 addresses
70+
host, _, err := net.SplitHostPort(remoteAddr)
71+
if err != nil {
72+
// RemoteAddr doesn't have standard host:port format.
73+
// This can happen with non-TCP transports or unusual proxy configurations.
74+
// Log for debugging but use the raw value.
75+
log.Printf("httpcontext: could not parse RemoteAddr %q: %v", remoteAddr, err)
76+
return remoteAddr
77+
}
78+
return host
79+
}
80+
81+
// lastXFFEntry returns the last IP in the X-Forwarded-For chain.
82+
func lastXFFEntry(xff string) string {
83+
parts := strings.Split(xff, ",")
84+
// strings.Split always returns at least one element, even for empty string
85+
return strings.TrimSpace(parts[len(parts)-1])
86+
}
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
package httpcontext
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"testing"
7+
)
8+
9+
func TestFromContext_EmptyContext(t *testing.T) {
10+
ctx := context.Background()
11+
hc := FromContext(ctx)
12+
13+
if hc.UserAgent != "" {
14+
t.Errorf("expected empty UserAgent, got %q", hc.UserAgent)
15+
}
16+
if hc.ForwardedFor != "" {
17+
t.Errorf("expected empty ForwardedFor, got %q", hc.ForwardedFor)
18+
}
19+
}
20+
21+
func TestContextWithHTTPContext_RoundTrip(t *testing.T) {
22+
ctx := context.Background()
23+
expected := HTTPContext{
24+
UserAgent: "TestAgent/1.0",
25+
ForwardedFor: "10.0.0.1, 192.168.1.1",
26+
}
27+
28+
ctx = ContextWithHTTPContext(ctx, expected)
29+
actual := FromContext(ctx)
30+
31+
if actual.UserAgent != expected.UserAgent {
32+
t.Errorf("expected UserAgent %q, got %q", expected.UserAgent, actual.UserAgent)
33+
}
34+
if actual.ForwardedFor != expected.ForwardedFor {
35+
t.Errorf("expected ForwardedFor %q, got %q", expected.ForwardedFor, actual.ForwardedFor)
36+
}
37+
}
38+
39+
func TestContextWithHTTPRequest_UserAgent(t *testing.T) {
40+
ctx := context.Background()
41+
req, _ := http.NewRequest("GET", "/", nil)
42+
req.Header.Set("User-Agent", "Claude-Desktop/1.2.3")
43+
req.RemoteAddr = "192.168.1.100:54321"
44+
45+
ctx = ContextWithHTTPRequest(ctx, req)
46+
hc := FromContext(ctx)
47+
48+
if hc.UserAgent != "Claude-Desktop/1.2.3" {
49+
t.Errorf("expected UserAgent %q, got %q", "Claude-Desktop/1.2.3", hc.UserAgent)
50+
}
51+
}
52+
53+
func TestContextWithHTTPRequest_XFFChainBuilding(t *testing.T) {
54+
tests := []struct {
55+
name string
56+
existingXFF string
57+
remoteAddr string
58+
expectedXFF string
59+
}{
60+
{
61+
name: "no existing XFF",
62+
existingXFF: "",
63+
remoteAddr: "192.168.1.100:54321",
64+
expectedXFF: "192.168.1.100",
65+
},
66+
{
67+
name: "with existing XFF",
68+
existingXFF: "10.0.0.1",
69+
remoteAddr: "192.168.1.100:54321",
70+
expectedXFF: "10.0.0.1, 192.168.1.100",
71+
},
72+
{
73+
name: "multiple entries in existing XFF",
74+
existingXFF: "10.0.0.1, 172.16.0.1",
75+
remoteAddr: "192.168.1.100:54321",
76+
expectedXFF: "10.0.0.1, 172.16.0.1, 192.168.1.100",
77+
},
78+
}
79+
80+
for _, tt := range tests {
81+
t.Run(tt.name, func(t *testing.T) {
82+
ctx := context.Background()
83+
req, _ := http.NewRequest("GET", "/", nil)
84+
if tt.existingXFF != "" {
85+
req.Header.Set("X-Forwarded-For", tt.existingXFF)
86+
}
87+
req.RemoteAddr = tt.remoteAddr
88+
89+
ctx = ContextWithHTTPRequest(ctx, req)
90+
hc := FromContext(ctx)
91+
92+
if hc.ForwardedFor != tt.expectedXFF {
93+
t.Errorf("expected ForwardedFor %q, got %q", tt.expectedXFF, hc.ForwardedFor)
94+
}
95+
})
96+
}
97+
}
98+
99+
func TestContextWithHTTPRequest_LastEntryDeduplication(t *testing.T) {
100+
ctx := context.Background()
101+
req, _ := http.NewRequest("GET", "/", nil)
102+
req.Header.Set("X-Forwarded-For", "10.0.0.1, 192.168.1.100")
103+
req.RemoteAddr = "192.168.1.100:54321" // Same as last entry in XFF
104+
105+
ctx = ContextWithHTTPRequest(ctx, req)
106+
hc := FromContext(ctx)
107+
108+
// Should not add duplicate
109+
expected := "10.0.0.1, 192.168.1.100"
110+
if hc.ForwardedFor != expected {
111+
t.Errorf("expected ForwardedFor %q (no duplicate), got %q", expected, hc.ForwardedFor)
112+
}
113+
}
114+
115+
func TestContextWithHTTPRequest_EarlierDuplicatePreserved(t *testing.T) {
116+
ctx := context.Background()
117+
req, _ := http.NewRequest("GET", "/", nil)
118+
// IP appears earlier in chain but not as last entry
119+
req.Header.Set("X-Forwarded-For", "192.168.1.100, 10.0.0.1")
120+
req.RemoteAddr = "192.168.1.100:54321"
121+
122+
ctx = ContextWithHTTPRequest(ctx, req)
123+
hc := FromContext(ctx)
124+
125+
// Should add it since it's not the last entry
126+
expected := "192.168.1.100, 10.0.0.1, 192.168.1.100"
127+
if hc.ForwardedFor != expected {
128+
t.Errorf("expected ForwardedFor %q (earlier duplicate preserved), got %q", expected, hc.ForwardedFor)
129+
}
130+
}
131+
132+
func TestGetClientIP_PortStripping(t *testing.T) {
133+
tests := []struct {
134+
name string
135+
remoteAddr string
136+
expected string
137+
}{
138+
{
139+
name: "IPv4 with port",
140+
remoteAddr: "192.168.1.100:54321",
141+
expected: "192.168.1.100",
142+
},
143+
{
144+
name: "IPv6 with port",
145+
remoteAddr: "[::1]:54321",
146+
expected: "::1",
147+
},
148+
{
149+
name: "IPv4 without port",
150+
remoteAddr: "192.168.1.100",
151+
expected: "192.168.1.100",
152+
},
153+
{
154+
name: "empty string",
155+
remoteAddr: "",
156+
expected: "",
157+
},
158+
}
159+
160+
for _, tt := range tests {
161+
t.Run(tt.name, func(t *testing.T) {
162+
result := getClientIP(tt.remoteAddr)
163+
if result != tt.expected {
164+
t.Errorf("expected %q, got %q", tt.expected, result)
165+
}
166+
})
167+
}
168+
}
169+
170+
func TestBuildXFF(t *testing.T) {
171+
tests := []struct {
172+
name string
173+
existingXFF string
174+
remoteAddr string
175+
expected string
176+
}{
177+
{
178+
name: "empty XFF, valid remoteAddr",
179+
existingXFF: "",
180+
remoteAddr: "192.168.1.100:54321",
181+
expected: "192.168.1.100",
182+
},
183+
{
184+
name: "existing XFF, valid remoteAddr",
185+
existingXFF: "10.0.0.1",
186+
remoteAddr: "192.168.1.100:54321",
187+
expected: "10.0.0.1, 192.168.1.100",
188+
},
189+
{
190+
name: "empty remoteAddr",
191+
existingXFF: "10.0.0.1",
192+
remoteAddr: "",
193+
expected: "10.0.0.1",
194+
},
195+
{
196+
name: "both empty",
197+
existingXFF: "",
198+
remoteAddr: "",
199+
expected: "",
200+
},
201+
{
202+
name: "XFF with spaces",
203+
existingXFF: "10.0.0.1, 172.16.0.1",
204+
remoteAddr: "192.168.1.100:54321",
205+
expected: "10.0.0.1, 172.16.0.1, 192.168.1.100",
206+
},
207+
}
208+
209+
for _, tt := range tests {
210+
t.Run(tt.name, func(t *testing.T) {
211+
result := buildXFF(tt.existingXFF, tt.remoteAddr)
212+
if result != tt.expected {
213+
t.Errorf("expected %q, got %q", tt.expected, result)
214+
}
215+
})
216+
}
217+
}
218+

0 commit comments

Comments
 (0)