From cc901397e004df2fa16ca57e354d8100202abd8e Mon Sep 17 00:00:00 2001 From: Dmitry Nikitenko Date: Tue, 17 Feb 2026 13:11:57 +0600 Subject: [PATCH] Log IP address extracted according to trust proxy setting --- cmd/tgfeed/main.go | 2 +- .../api/rest/{ipfilter.go => firewall.go} | 77 +++---------------- .../{ipfilter_test.go => firewall_test.go} | 0 internal/api/rest/ip.go | 67 ++++++++++++++++ internal/api/rest/middleware.go | 8 +- internal/api/rest/server.go | 38 ++++----- 6 files changed, 103 insertions(+), 89 deletions(-) rename internal/api/rest/{ipfilter.go => firewall.go} (62%) rename internal/api/rest/{ipfilter_test.go => firewall_test.go} (100%) create mode 100644 internal/api/rest/ip.go diff --git a/cmd/tgfeed/main.go b/cmd/tgfeed/main.go index 29e6a63..34435b5 100644 --- a/cmd/tgfeed/main.go +++ b/cmd/tgfeed/main.go @@ -83,7 +83,7 @@ func main() { generator := feed.NewGenerator() // Initialize and run the HTTP server - server := rest.NewServer(c, scraper, generator, ipFilter, port) + server := rest.NewServer(c, scraper, generator, ipFilter, port, trustProxy) if err := server.Run(ctx); err != nil { logger.Error("Server error", "error", err) diff --git a/internal/api/rest/ipfilter.go b/internal/api/rest/firewall.go similarity index 62% rename from internal/api/rest/ipfilter.go rename to internal/api/rest/firewall.go index fb22563..d6b9e74 100644 --- a/internal/api/rest/ipfilter.go +++ b/internal/api/rest/firewall.go @@ -20,9 +20,6 @@ type Firewall struct { // The allowedIPsStr parameter accepts a comma-separated list of IP addresses // and/or CIDR ranges (e.g., "10.0.0.0/24,192.168.1.1,2001:db8::/32"). // If allowedIPsStr is empty, all IP addresses are allowed by default. -// When trustProxy is true, the firewall will check X-Real-IP and X-Forwarded-For -// headers to determine the client's IP address, which is necessary when the -// application runs behind a reverse proxy. // Returns an error if any IP address or CIDR notation is invalid. func NewFirewall(allowedIPsStr string, trustProxy bool) (*Firewall, error) { if allowedIPsStr == "" { @@ -45,8 +42,7 @@ func NewFirewall(allowedIPsStr string, trustProxy bool) (*Firewall, error) { } // IsAllowed checks if the request originates from an allowed IP address. -// When trustProxy is enabled, it first checks X-Real-IP and X-Forwarded-For headers -// before falling back to RemoteAddr. If no IP restrictions are configured (empty allowlist), +// If no IP restrictions are configured (empty allowlist), // all requests are allowed. Returns false if the IP cannot be extracted or is not in the allowlist. func (f *Firewall) IsAllowed(r *http.Request) bool { if len(f.allowedNets) == 0 { @@ -59,58 +55,24 @@ func (f *Firewall) IsAllowed(r *http.Request) bool { return false } - return isIPAllowed(clientIP, f.allowedNets) + return f.isIPAllowed(clientIP) } -// extractClientIP extracts the client IP address from the request -func extractClientIP(r *http.Request, trustProxy bool) (string, error) { - if trustProxy { - if clientIP := tryExtractFromProxyHeaders(r); clientIP != "" { - return clientIP, nil - } - } - - return extractFromRemoteAddr(r.RemoteAddr) -} +// isIPAllowed checks if an IP address is in the allowed networks +func (f *Firewall) isIPAllowed(ipStr string) bool { + ip := net.ParseIP(ipStr) -// tryExtractFromProxyHeaders attempts to extract IP from proxy headers -func tryExtractFromProxyHeaders(r *http.Request) string { - if xRealIP := r.Header.Get("X-Real-IP"); xRealIP != "" { - if ip := net.ParseIP(xRealIP); ip != nil { - return ip.String() - } + if ip == nil { + return false } - if xff := r.Header.Get("X-Forwarded-For"); xff != "" { - ips := strings.Split(xff, ",") - - if len(ips) > 0 { - clientIP := strings.TrimSpace(ips[0]) - - if ip := net.ParseIP(clientIP); ip != nil { - return ip.String() - } + for _, ipNet := range f.allowedNets { + if ipNet.Contains(ip) { + return true } } - return "" -} - -// extractFromRemoteAddr extracts IP from RemoteAddr -func extractFromRemoteAddr(remoteAddr string) (string, error) { - host, _, err := net.SplitHostPort(remoteAddr) - - if err != nil { - return "", fmt.Errorf("invalid remote address: %w", err) - } - - ip := net.ParseIP(host) - - if ip == nil { - return "", fmt.Errorf("invalid IP address: %s", host) - } - - return ip.String(), nil + return false } // parseAllowedIPs parses a comma-separated list of IP addresses and CIDR ranges @@ -169,20 +131,3 @@ func parseIPOrCIDR(part string) (*net.IPNet, error) { return ipNet, nil } - -// isIPAllowed checks if an IP address is in the allowed networks -func isIPAllowed(ipStr string, allowedNets []*net.IPNet) bool { - ip := net.ParseIP(ipStr) - - if ip == nil { - return false - } - - for _, ipNet := range allowedNets { - if ipNet.Contains(ip) { - return true - } - } - - return false -} diff --git a/internal/api/rest/ipfilter_test.go b/internal/api/rest/firewall_test.go similarity index 100% rename from internal/api/rest/ipfilter_test.go rename to internal/api/rest/firewall_test.go diff --git a/internal/api/rest/ip.go b/internal/api/rest/ip.go new file mode 100644 index 0000000..63cefb7 --- /dev/null +++ b/internal/api/rest/ip.go @@ -0,0 +1,67 @@ +package rest + +import ( + "fmt" + "net" + "net/http" + "strings" +) + +// extractClientIP extracts the client IP address from the request. +// When trustProxy is true, the firewall will check X-Real-IP and +// X-Forwarded-For headers to determine the client's IP address, +// which is necessary when the application runs behind a reverse proxy. +func extractClientIP(r *http.Request, trustProxy bool) (string, error) { + if trustProxy { + if clientIP := tryExtractFromProxyHeaders(r); clientIP != "" { + return clientIP, nil + } + } + + return extractFromRemoteAddr(r.RemoteAddr) +} + +// mustExtractClientIP behaves exactly like extractClientIP except it +// doesn't return an error, ignoring it instead. +func mustExtractClientIP(r *http.Request, trustProxy bool) string { + ip, _ := extractClientIP(r, trustProxy) + + return ip +} + +// tryExtractFromProxyHeaders attempts to extract IP from proxy headers +func tryExtractFromProxyHeaders(r *http.Request) string { + if xRealIP := r.Header.Get("X-Real-IP"); xRealIP != "" { + if ip := net.ParseIP(xRealIP); ip != nil { + return ip.String() + } + } + + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + ips := strings.Split(xff, ",") + clientIP := strings.TrimSpace(ips[0]) + + if ip := net.ParseIP(clientIP); ip != nil { + return ip.String() + } + } + + return "" +} + +// extractFromRemoteAddr extracts IP from RemoteAddr +func extractFromRemoteAddr(remoteAddr string) (string, error) { + host, _, err := net.SplitHostPort(remoteAddr) + + if err != nil { + return "", fmt.Errorf("invalid remote address: %w", err) + } + + ip := net.ParseIP(host) + + if ip == nil { + return "", fmt.Errorf("invalid IP address: %s", host) + } + + return ip.String(), nil +} diff --git a/internal/api/rest/middleware.go b/internal/api/rest/middleware.go index 9bc791d..339b46b 100644 --- a/internal/api/rest/middleware.go +++ b/internal/api/rest/middleware.go @@ -9,7 +9,7 @@ import ( ) // Logger wraps an http.Handler with request/response logging -func Logger(next http.Handler) http.Handler { +func Logger(next http.Handler, trustProxy bool) http.Handler { logger := app.Logger() return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -25,7 +25,7 @@ func Logger(next http.Handler) http.Handler { "method", r.Method, "path", r.URL.Path, "query", r.URL.RawQuery, - "remote_addr", r.RemoteAddr, + "remote_addr", mustExtractClientIP(r, trustProxy), "user_agent", r.UserAgent(), ) @@ -73,7 +73,7 @@ func (lrw *loggingResponseWriter) Unwrap() http.ResponseWriter { // When a filter is provided, each request is validated using filter.IsAllowed. // Denied requests receive a 403 Forbidden response with a JSON error message. // The middleware logs warnings for denied requests including the remote address and path. -func IPFilterMiddleware(filter IPFilter) func(http.Handler) http.Handler { +func IPFilterMiddleware(filter IPFilter, trustProxy bool) func(http.Handler) http.Handler { logger := app.Logger() return func(next http.Handler) http.Handler { @@ -88,7 +88,7 @@ func IPFilterMiddleware(filter IPFilter) func(http.Handler) http.Handler { } logger.Warn("IP not allowed", - "remote_addr", r.RemoteAddr, + "remote_addr", mustExtractClientIP(r, trustProxy), "path", r.URL.Path, ) diff --git a/internal/api/rest/server.go b/internal/api/rest/server.go index 25ec6f8..25566c9 100644 --- a/internal/api/rest/server.go +++ b/internal/api/rest/server.go @@ -14,32 +14,34 @@ import ( // Server represents the REST API server type Server struct { - mux *http.ServeMux - server *http.Server - logger *slog.Logger - cache cache.Cache - scraper Scraper - generator Generator - ipFilter IPFilter - port string + mux *http.ServeMux + server *http.Server + logger *slog.Logger + cache cache.Cache + scraper Scraper + generator Generator + ipFilter IPFilter + port string + trustProxy bool } // NewServer creates a new REST API server with the specified dependencies. // The ipFilter parameter controls IP-based access restrictions; pass nil to disable filtering. // The port parameter specifies the TCP port to listen on (e.g., "8080"). // The server is pre-configured with secure timeout values to mitigate common attacks. -func NewServer(c cache.Cache, s Scraper, g Generator, ipFilter IPFilter, port string) *Server { +func NewServer(c cache.Cache, s Scraper, g Generator, ipFilter IPFilter, port string, trustProxy bool) *Server { mux := http.NewServeMux() logger := app.Logger() server := &Server{ - mux: mux, - logger: logger, - cache: c, - scraper: s, - generator: g, - ipFilter: ipFilter, - port: port, + mux: mux, + logger: logger, + cache: c, + scraper: s, + generator: g, + ipFilter: ipFilter, + port: port, + trustProxy: trustProxy, server: &http.Server{ Addr: ":" + port, Handler: nil, // Will be set in Run @@ -65,8 +67,8 @@ func (s *Server) registerHandlers() { func (s *Server) Run(ctx context.Context) error { // Apply middleware chain handler := http.Handler(s.mux) - handler = IPFilterMiddleware(s.ipFilter)(handler) - handler = Logger(handler) + handler = IPFilterMiddleware(s.ipFilter, s.trustProxy)(handler) + handler = Logger(handler, s.trustProxy) // Set the handler with middleware s.server.Handler = handler