diff --git a/README.md b/README.md index 4167010..9565abb 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,11 @@ environment variables that you can set. | `BAD_GATEWAY_PAGE` | Path to an HTML file to serve when the backend server returns a 502 Bad Gateway error. If there is no file at the specific path, Thruster will serve an empty 502 response instead. Because Thruster boots very quickly, a custom page can be a useful way to show that your application is starting up. | `./public/502.html` | | `HTTP_PORT` | The port to listen on for HTTP traffic. | 80 | | `HTTPS_PORT` | The port to listen on for HTTPS traffic. | 443 | +| `HTTP_HEALTH_PATH` | The http health path to check before start port listening. | None | +| `HTTP_HEALTH_HOST` | The http health host to check before start port listening. | 127.0.0.1 | +| `HTTP_HEALTH_INTERVAL` | The http health path check interval (seconds). | 1 | +| `HTTP_HEALTH_TIMEOUT` | The http health path check timeout (seconds). | 1 | +| `HTTP_HEALTH_DEADLINE` | The http health path deadline interval (seconds), after which thruster will exit with error, if no success response. | 120 | | `HTTP_IDLE_TIMEOUT` | The maximum time in seconds that a client can be idle before the connection is closed. | 60 | | `HTTP_READ_TIMEOUT` | The maximum time in seconds that a client can take to send the request headers and body. | 30 | | `HTTP_WRITE_TIMEOUT` | The maximum time in seconds during which the client must read the response. | 30 | @@ -103,6 +108,32 @@ Thruster's environment variables can optionally be prefixed with `THRUSTER_`. For example, `TLS_DOMAIN` can also be written as `THRUSTER_TLS_DOMAIN`. Whenever a prefixed variable is set, it will take precedence over the unprefixed version. +### HTTP_HEALTH_PATH and rails + +When using `HTTP_HEALTH_PATH` for health check, this endpoint should work over HTTP protocol and return 200 status code. In rails you can add in `config/routes.rb` such route: + +```ruby +get '/health', to: 'rails/health#show', as: :rails_health_check +``` + +and add in `config/application.rb` such settings for hosts checks: + +```ruby +config.host_authorization = { + exclude: ->(request) { request.path == '/health' } +} +``` + +If your environment have `config.assume_ssl = true` (not handle http to https redirects), in this case you done. But if you doing http to https redirects on rails side (like need on heroku router), you need also add in `config/application.rb` such settings: + +```ruby +config.ssl_options = { + redirect: { + exclude: ->(request) { request.path == '/health' } + } +} +``` + ## Security ### BREACH Mitigation diff --git a/go.mod b/go.mod index da03248..d6df750 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,10 @@ module github.com/basecamp/thruster go 1.25.6 require ( - github.com/klauspost/compress v1.18.2 + github.com/klauspost/compress v1.18.3 github.com/stretchr/testify v1.8.4 - golang.org/x/crypto v0.46.0 - golang.org/x/net v0.48.0 + golang.org/x/crypto v0.47.0 + golang.org/x/net v0.49.0 ) require ( @@ -14,7 +14,7 @@ require ( github.com/kr/text v0.2.0 // indirect github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/text v0.32.0 // indirect + golang.org/x/text v0.33.0 // indirect gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index af9bfa8..d4e72d4 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= -github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw= +github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -13,12 +13,12 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= -golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= -golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= -golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= -golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= -golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= +golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/config.go b/internal/config.go index fd2887d..55d62e5 100644 --- a/internal/config.go +++ b/internal/config.go @@ -27,11 +27,15 @@ const ( defaultStoragePath = "./storage/thruster" defaultBadGatewayPage = "./public/502.html" - defaultHttpPort = 80 - defaultHttpsPort = 443 - defaultHttpIdleTimeout = 60 * time.Second - defaultHttpReadTimeout = 30 * time.Second - defaultHttpWriteTimeout = 30 * time.Second + defaultHttpPort = 80 + defaultHttpsPort = 443 + defaultHttpHealthHost = "127.0.0.1" + defaultHttpHealthTimeout = 1 * time.Second + defaultHttpHealthInterval = 1 * time.Second + defaultHttpHealthDeadline = 2 * time.Minute + defaultHttpIdleTimeout = 60 * time.Second + defaultHttpReadTimeout = 30 * time.Second + defaultHttpWriteTimeout = 30 * time.Second defaultH2CEnabled = false @@ -62,11 +66,16 @@ type Config struct { StoragePath string BadGatewayPage string - HttpPort int - HttpsPort int - HttpIdleTimeout time.Duration - HttpReadTimeout time.Duration - HttpWriteTimeout time.Duration + HttpPort int + HttpsPort int + HttpHealthHost string + HttpHealthPath string + HttpHealthTimeout time.Duration + HttpHealthInterval time.Duration + HttpHealthDeadline time.Duration + HttpIdleTimeout time.Duration + HttpReadTimeout time.Duration + HttpWriteTimeout time.Duration H2CEnabled bool @@ -89,7 +98,7 @@ func NewConfig() (*Config, error) { config := &Config{ TargetPort: getEnvInt("TARGET_PORT", defaultTargetPort), UpstreamCommand: os.Args[1], - UpstreamArgs: os.Args[2:], + UpstreamArgs: append([]string{}, os.Args[2:]...), CacheSizeBytes: getEnvInt("CACHE_SIZE", defaultCacheSize), MaxCacheItemSizeBytes: getEnvInt("MAX_CACHE_ITEM_SIZE", defaultMaxCacheItemSizeBytes), @@ -106,11 +115,16 @@ func NewConfig() (*Config, error) { StoragePath: getEnvString("STORAGE_PATH", defaultStoragePath), BadGatewayPage: getEnvString("BAD_GATEWAY_PAGE", defaultBadGatewayPage), - HttpPort: getEnvInt("HTTP_PORT", defaultHttpPort), - HttpsPort: getEnvInt("HTTPS_PORT", defaultHttpsPort), - HttpIdleTimeout: getEnvDuration("HTTP_IDLE_TIMEOUT", defaultHttpIdleTimeout), - HttpReadTimeout: getEnvDuration("HTTP_READ_TIMEOUT", defaultHttpReadTimeout), - HttpWriteTimeout: getEnvDuration("HTTP_WRITE_TIMEOUT", defaultHttpWriteTimeout), + HttpPort: getEnvInt("HTTP_PORT", defaultHttpPort), + HttpsPort: getEnvInt("HTTPS_PORT", defaultHttpsPort), + HttpHealthHost: getEnvString("HTTP_HEALTH_HOST", defaultHttpHealthHost), + HttpHealthPath: getEnvString("HTTP_HEALTH_PATH", ""), + HttpHealthInterval: getEnvDuration("HTTP_HEALTH_INTERVAL", defaultHttpHealthInterval), + HttpHealthTimeout: getEnvDuration("HTTP_HEALTH_TIMEOUT", defaultHttpHealthTimeout), + HttpHealthDeadline: getEnvDuration("HTTP_HEALTH_DEADLINE", defaultHttpHealthDeadline), + HttpIdleTimeout: getEnvDuration("HTTP_IDLE_TIMEOUT", defaultHttpIdleTimeout), + HttpReadTimeout: getEnvDuration("HTTP_READ_TIMEOUT", defaultHttpReadTimeout), + HttpWriteTimeout: getEnvDuration("HTTP_WRITE_TIMEOUT", defaultHttpWriteTimeout), H2CEnabled: getEnvBool("H2C_ENABLED", defaultH2CEnabled), diff --git a/internal/config_test.go b/internal/config_test.go index 5dc1d9b..f8c6100 100644 --- a/internal/config_test.go +++ b/internal/config_test.go @@ -105,6 +105,11 @@ func TestConfig_defaults(t *testing.T) { assert.Equal(t, "echo", c.UpstreamCommand) assert.Equal(t, defaultCacheSize, c.CacheSizeBytes) assert.Equal(t, slog.LevelInfo, c.LogLevel) + assert.Equal(t, "", c.HttpHealthPath) + assert.Equal(t, "127.0.0.1", c.HttpHealthHost) + assert.Equal(t, 1*time.Second, c.HttpHealthTimeout) + assert.Equal(t, 1*time.Second, c.HttpHealthInterval) + assert.Equal(t, 2*time.Minute, c.HttpHealthDeadline) assert.Equal(t, false, c.H2CEnabled) } @@ -118,6 +123,11 @@ func TestConfig_override_defaults_with_env_vars(t *testing.T) { usingEnvVar(t, "DEBUG", "1") usingEnvVar(t, "ACME_DIRECTORY", "https://acme-staging-v02.api.letsencrypt.org/directory") usingEnvVar(t, "LOG_REQUESTS", "false") + usingEnvVar(t, "HTTP_HEALTH_PATH", "/health") + usingEnvVar(t, "HTTP_HEALTH_HOST", "localhost") + usingEnvVar(t, "HTTP_HEALTH_INTERVAL", "3") + usingEnvVar(t, "HTTP_HEALTH_TIMEOUT", "4") + usingEnvVar(t, "HTTP_HEALTH_DEADLINE", "60") usingEnvVar(t, "H2C_ENABLED", "true") usingEnvVar(t, "GZIP_COMPRESSION_DISABLE_ON_AUTH", "true") usingEnvVar(t, "GZIP_COMPRESSION_JITTER", "64") @@ -132,6 +142,11 @@ func TestConfig_override_defaults_with_env_vars(t *testing.T) { assert.Equal(t, false, c.GzipCompressionEnabled) assert.Equal(t, slog.LevelDebug, c.LogLevel) assert.Equal(t, "https://acme-staging-v02.api.letsencrypt.org/directory", c.ACMEDirectoryURL) + assert.Equal(t, "/health", c.HttpHealthPath) + assert.Equal(t, "localhost", c.HttpHealthHost) + assert.Equal(t, 3*time.Second, c.HttpHealthInterval) + assert.Equal(t, 4*time.Second, c.HttpHealthTimeout) + assert.Equal(t, 60*time.Second, c.HttpHealthDeadline) assert.Equal(t, false, c.LogRequests) assert.Equal(t, true, c.H2CEnabled) assert.Equal(t, true, c.GzipCompressionDisableOnAuth) @@ -146,6 +161,11 @@ func TestConfig_override_defaults_with_env_vars_using_prefix(t *testing.T) { usingEnvVar(t, "THRUSTER_X_SENDFILE_ENABLED", "0") usingEnvVar(t, "THRUSTER_DEBUG", "1") usingEnvVar(t, "THRUSTER_LOG_REQUESTS", "0") + usingEnvVar(t, "THRUSTER_HTTP_HEALTH_PATH", "/health") + usingEnvVar(t, "THRUSTER_HTTP_HEALTH_HOST", "localhost") + usingEnvVar(t, "THRUSTER_HTTP_HEALTH_INTERVAL", "3") + usingEnvVar(t, "THRUSTER_HTTP_HEALTH_TIMEOUT", "4") + usingEnvVar(t, "THRUSTER_HTTP_HEALTH_DEADLINE", "60") usingEnvVar(t, "THRUSTER_H2C_ENABLED", "1") c, err := NewConfig() @@ -157,6 +177,11 @@ func TestConfig_override_defaults_with_env_vars_using_prefix(t *testing.T) { assert.Equal(t, false, c.XSendfileEnabled) assert.Equal(t, slog.LevelDebug, c.LogLevel) assert.Equal(t, false, c.LogRequests) + assert.Equal(t, "/health", c.HttpHealthPath) + assert.Equal(t, "localhost", c.HttpHealthHost) + assert.Equal(t, 3*time.Second, c.HttpHealthInterval) + assert.Equal(t, 4*time.Second, c.HttpHealthTimeout) + assert.Equal(t, 60*time.Second, c.HttpHealthDeadline) assert.Equal(t, true, c.H2CEnabled) } @@ -171,6 +196,20 @@ func TestConfig_prefixed_variables_take_precedence_over_non_prefixed(t *testing. assert.Equal(t, 4000, c.TargetPort) } +func TestConfig_defaults_are_used_if_strconv_fails(t *testing.T) { + usingProgramArgs(t, "thruster", "echo", "hello") + usingEnvVar(t, "TARGET_PORT", "should-be-an-int") + usingEnvVar(t, "HTTP_IDLE_TIMEOUT", "should-be-a-duration") + usingEnvVar(t, "X_SENDFILE_ENABLED", "should-be-a-bool") + + c, err := NewConfig() + require.NoError(t, err) + + assert.Equal(t, 3000, c.TargetPort) + assert.Equal(t, 60*time.Second, c.HttpIdleTimeout) + assert.Equal(t, true, c.XSendfileEnabled) +} + func TestConfig_return_error_when_no_upstream_command(t *testing.T) { usingProgramArgs(t, "thruster") diff --git a/internal/service.go b/internal/service.go index c34ba73..8e6fb4c 100644 --- a/internal/service.go +++ b/internal/service.go @@ -1,16 +1,27 @@ package internal import ( + "context" "fmt" "log/slog" + "net/http" "net/url" "os" + "os/signal" + "syscall" + "time" ) type Service struct { config *Config } +// Represents the result of the upstream process execution. +type upstreamResult struct { + exitCode int + err error +} + func NewService(config *Config) *Service { return &Service{ config: config, @@ -36,23 +47,136 @@ func (s *Service) Run() int { server := NewServer(s.config, handler) upstream := NewUpstreamProcess(s.config.UpstreamCommand, s.config.UpstreamArgs...) + s.setEnvironment() + + // Channel to receive the result from the upstream process goroutine. + resultChan := make(chan upstreamResult, 1) + + // Run the upstream process in a separate goroutine + // This allows us to perform health checks while it starts up + go func() { + exitCode, err := upstream.Run() + resultChan <- upstreamResult{exitCode: exitCode, err: err} + }() + + // If a health check path is configured, wait for the upstream to become healthy + if s.config.HttpHealthPath != "" { + if err := s.performHealthCheck(resultChan); err != nil { + slog.Error("Upstream health check failed", "error", err) + // At this point, the upstream process is running but unhealthy + if err := upstream.Signal(syscall.SIGTERM); err != nil { + slog.Error("Failed to send signal to upstream process", "error", err) + } + return 1 + } + slog.Info("Upstream service is healthy, starting proxy server.") + } + + // Now that the upstream is ready, start the main proxy server if err := server.Start(); err != nil { return 1 } defer server.Stop() - s.setEnvironment() + // Delegate the waiting and signal handling to the new function + return s.awaitTermination(upstream, resultChan) +} + +// Private + +func (s *Service) awaitTermination(upstream *UpstreamProcess, resultChan <-chan upstreamResult) int { + signalChan := make(chan os.Signal, 1) + signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM) + + select { + case result := <-resultChan: + // The upstream process finished on its own. + slog.Info("Wrapped process exited on its own.", "exit_code", result.exitCode) + if result.err != nil { + slog.Error("Wrapped process failed", "command", s.config.UpstreamCommand, "args", s.config.UpstreamArgs, "error", result.err) + return 1 + } + return result.exitCode + + case sig := <-signalChan: + // An OS signal was caught + slog.Info("Received signal, shutting down.", "signal", sig.String()) + + // Relay the signal to the child process to allow for graceful shutdown. + slog.Info("Relaying signal to upstream process...") + if err := upstream.Signal(sig); err != nil { + slog.Error("Failed to send signal to upstream process", "error", err) + } - exitCode, err := upstream.Run() - if err != nil { - slog.Error("Failed to start wrapped process", "command", s.config.UpstreamCommand, "args", s.config.UpstreamArgs, "error", err) + // Give the upstream process a moment to shut down gracefully + // before the defer server.Stop() forcefully cleans up. + select { + case <-resultChan: + slog.Info("Upstream process terminated gracefully after signal.") + case <-time.After(10 * time.Second): + slog.Warn("Upstream process did not terminate within 10 seconds of signal.") + } + + // Exit with a non-zero status code to indicate termination by signal. return 1 } - - return exitCode } -// Private +// performHealthCheck polls the health check endpoint until it gets a 200 OK +func (s *Service) performHealthCheck(resultChan <-chan upstreamResult) error { + // Create a context with a 2-minute timeout (default) for the entire health check process + ctx, cancel := context.WithTimeout(context.Background(), s.config.HttpHealthDeadline) + defer cancel() + + // We assume the upstream server binds to the target URL's host + healthCheckURL := fmt.Sprintf("http://%s:%d%s", s.config.HttpHealthHost, s.config.TargetPort, s.config.HttpHealthPath) + slog.Info("Starting health checks", "url", healthCheckURL) + + // Use a ticker to check every second (default) + ticker := time.NewTicker(s.config.HttpHealthInterval) + defer ticker.Stop() + + // Create an HTTP client with a short timeout for individual requests + client := &http.Client{ + Timeout: s.config.HttpHealthTimeout, + } + + for { + select { + case <-ctx.Done(): + // Deadline exceeded + return fmt.Errorf("health check timed out after %v", s.config.HttpHealthDeadline) + + case result := <-resultChan: + // The upstream process exited before it became healthy + return fmt.Errorf("upstream process exited prematurely with code %d: %w", result.exitCode, result.err) + + case <-ticker.C: + // Ticker fired, time to perform a check + req, err := http.NewRequestWithContext(ctx, "GET", healthCheckURL, nil) + if err != nil { + return fmt.Errorf("failed to create health check request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + // This is expected while the server is starting up (e.g., "connection refused") + slog.Debug("Health check attempt failed, retrying...", "error", err) + continue + } + + // Don't forget to close the body + resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + // Success! + return nil + } + + slog.Debug("Health check received non-200 status", "status_code", resp.StatusCode) + } + } +} func (s *Service) cache() Cache { return NewMemoryCache(s.config.CacheSizeBytes, s.config.MaxCacheItemSizeBytes)