diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..9c2c8bf --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,9 @@ +services: + neon-proxy: + build: . + environment: + - ALLOW_ADDR_REGEX=.* + - LOG_TRAFFIC=true + - TLS_SKIP_VERIFY=true + ports: + - '5433:80' \ No newline at end of file diff --git a/main.go b/main.go index d14d6a2..12b2333 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "crypto/tls" "fmt" "io" "log" @@ -42,6 +43,9 @@ type Config struct { UseHostHeader bool `env:"USE_HOST_HEADER" envDefault:"false"` LogTraffic bool `env:"LOG_TRAFFIC" envDefault:"false"` LogConnInfo bool `env:"LOG_CONN_INFO" envDefault:"true"` + UseTLS bool `env:"USE_TLS" envDefault:"true"` + TLSSkipVerify bool `env:"TLS_SKIP_VERIFY" envDefault:"false"` + TLSServerName string `env:"TLS_SERVER_NAME" envDefault:""` } var upgrader = websocket.Upgrader{ @@ -142,9 +146,71 @@ func (h *ProxyHandler) HandleWS(conn *websocket.Conn, addr string) error { activeConnections.Inc() defer activeConnections.Dec() - socket, err := net.Dial("tcp", addr) - if err != nil { - return err + var socket net.Conn + var err error + + if h.cfg.UseTLS { + // First establish a plain TCP connection + socket, err = net.Dial("tcp", addr) + if err != nil { + return fmt.Errorf("failed to establish TCP connection: %w", err) + } + + // Send PostgreSQL SSL request + sslRequest := []byte{0x00, 0x00, 0x00, 0x08, 0x04, 0xd2, 0x16, 0x2f} + _, err = socket.Write(sslRequest) + if err != nil { + socket.Close() + return fmt.Errorf("failed to send SSL request: %w", err) + } + + // Read SSL response (1 byte) + response := make([]byte, 1) + _, err = socket.Read(response) + if err != nil { + socket.Close() + return fmt.Errorf("failed to read SSL response: %w", err) + } + + if response[0] == 'S' { + // Server supports SSL, upgrade the connection + serverName := h.cfg.TLSServerName + if serverName == "" { + // Extract hostname from address if TLS_SERVER_NAME is not set + host, _, err := net.SplitHostPort(addr) + if err != nil { + // If SplitHostPort fails, use the full address as hostname + serverName = addr + } else { + serverName = host + } + } + + tlsConfig := &tls.Config{ + ServerName: serverName, + InsecureSkipVerify: h.cfg.TLSSkipVerify, + } + tlsConn := tls.Client(socket, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + socket.Close() + return fmt.Errorf("failed to complete TLS handshake: %w", err) + } + socket = tlsConn + } else if response[0] == 'N' { + // Server doesn't support SSL + if h.cfg.LogConnInfo { + log.Printf("PostgreSQL server doesn't support SSL, continuing with plain connection") + } + } else { + socket.Close() + return fmt.Errorf("unexpected SSL response from server: %c", response[0]) + } + } else { + socket, err = net.Dial("tcp", addr) + if err != nil { + return fmt.Errorf("failed to establish TCP connection: %w", err) + } } defer socket.Close()