Skip to content

Commit 4e3effd

Browse files
feat: implementation of non-transparent proxy
1 parent 5e39fc0 commit 4e3effd

File tree

1 file changed

+246
-0
lines changed

1 file changed

+246
-0
lines changed

proxy/proxy.go

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,11 @@ func (p *Server) handleHTTPConnection(conn net.Conn) {
218218
return
219219
}
220220

221+
if req.Method == http.MethodConnect {
222+
p.handleCONNECT(conn, req)
223+
return
224+
}
225+
221226
p.logger.Debug("🌐 HTTP Request: %s %s", req.Method, req.URL.String())
222227
p.processHTTPRequest(conn, req, false)
223228
}
@@ -423,6 +428,247 @@ For more help: https://github.com/coder/boundary
423428
p.logger.Debug("Successfully wrote to connection")
424429
}
425430

431+
// handleCONNECT handles HTTP CONNECT requests for tunneling
432+
func (p *Server) handleCONNECT(conn net.Conn, req *http.Request) {
433+
// Extract target from CONNECT request
434+
// CONNECT requests have the target in req.Host (format: hostname:port)
435+
target := req.Host
436+
if target == "" {
437+
target = req.URL.Host
438+
}
439+
440+
p.logger.Debug("🔌 CONNECT request", "target", target)
441+
442+
// Check if target is allowed
443+
// Use "CONNECT" as method and target as the URL for evaluation
444+
result := p.ruleEngine.Evaluate("CONNECT", target)
445+
446+
// Audit the CONNECT request
447+
p.auditor.AuditRequest(audit.Request{
448+
Method: "CONNECT",
449+
URL: target,
450+
Host: target,
451+
Allowed: result.Allowed,
452+
Rule: result.Rule,
453+
})
454+
455+
if !result.Allowed {
456+
p.logger.Debug("CONNECT request blocked", "target", target)
457+
p.writeBlockedCONNECTResponse(conn, target)
458+
return
459+
}
460+
461+
// Send 200 Connection established response
462+
response := "HTTP/1.1 200 Connection established\r\n\r\n"
463+
_, err := conn.Write([]byte(response))
464+
if err != nil {
465+
p.logger.Error("Failed to send CONNECT response", "error", err)
466+
return
467+
}
468+
469+
p.logger.Debug("CONNECT tunnel established", "target", target)
470+
471+
// Handle the tunnel - decrypt TLS and process each HTTP request
472+
p.handleCONNECTTunnel(conn, target)
473+
}
474+
475+
// handleCONNECTTunnel handles the tunnel after CONNECT is established
476+
// It decrypts TLS traffic and processes each HTTP request separately
477+
func (p *Server) handleCONNECTTunnel(conn net.Conn, target string) {
478+
defer func() {
479+
err := conn.Close()
480+
if err != nil {
481+
p.logger.Error("Failed to close CONNECT tunnel", "error", err)
482+
}
483+
}()
484+
485+
// Wrap connection with TLS server to decrypt traffic
486+
tlsConn := tls.Server(conn, p.tlsConfig)
487+
488+
// Perform TLS handshake
489+
if err := tlsConn.Handshake(); err != nil {
490+
p.logger.Error("TLS handshake failed in CONNECT tunnel", "error", err)
491+
return
492+
}
493+
494+
p.logger.Debug("✅ TLS handshake successful in CONNECT tunnel")
495+
496+
// Process HTTP requests in a loop
497+
reader := bufio.NewReader(tlsConn)
498+
for {
499+
// Read HTTP request from tunnel
500+
req, err := http.ReadRequest(reader)
501+
if err != nil {
502+
if err == io.EOF {
503+
p.logger.Debug("CONNECT tunnel closed by client")
504+
break
505+
}
506+
p.logger.Error("Failed to read HTTP request from CONNECT tunnel", "error", err)
507+
break
508+
}
509+
510+
p.logger.Debug("🔒 HTTP Request in CONNECT tunnel", "method", req.Method, "url", req.URL.String(), "target", target)
511+
512+
// Process this request - check if allowed and forward to target
513+
p.processTunnelRequest(tlsConn, req, target)
514+
}
515+
}
516+
517+
// processTunnelRequest processes a single HTTP request from the CONNECT tunnel
518+
func (p *Server) processTunnelRequest(conn net.Conn, req *http.Request, targetHost string) {
519+
// Check if request should be allowed
520+
// Use the original request URL but evaluate against rules
521+
urlStr := req.Host + req.URL.String()
522+
result := p.ruleEngine.Evaluate(req.Method, urlStr)
523+
524+
// Audit the request
525+
p.auditor.AuditRequest(audit.Request{
526+
Method: req.Method,
527+
URL: req.URL.String(),
528+
Host: req.Host,
529+
Allowed: result.Allowed,
530+
Rule: result.Rule,
531+
})
532+
533+
if !result.Allowed {
534+
p.logger.Debug("Request in CONNECT tunnel blocked", "method", req.Method, "url", urlStr)
535+
p.writeBlockedResponse(conn, req)
536+
return
537+
}
538+
539+
// Forward request to target
540+
// The target is the original CONNECT target, but we use the request's host/path
541+
p.forwardTunnelRequest(conn, req, targetHost)
542+
}
543+
544+
// forwardTunnelRequest forwards a request from the tunnel to the target
545+
func (p *Server) forwardTunnelRequest(conn net.Conn, req *http.Request, targetHost string) {
546+
// Create HTTP client
547+
client := &http.Client{
548+
CheckRedirect: func(req *http.Request, via []*http.Request) error {
549+
return http.ErrUseLastResponse // Don't follow redirects
550+
},
551+
}
552+
553+
// Parse target host to get hostname and port
554+
hostname := targetHost
555+
port := "443" // Default HTTPS port
556+
if strings.Contains(targetHost, ":") {
557+
parts := strings.Split(targetHost, ":")
558+
hostname = parts[0]
559+
port = parts[1]
560+
}
561+
562+
// Determine scheme based on port
563+
scheme := "https"
564+
if port == "80" {
565+
scheme = "http"
566+
}
567+
568+
// Build target URL using the request's path but the CONNECT target's host
569+
targetURL := &url.URL{
570+
Scheme: scheme,
571+
Host: targetHost,
572+
Path: req.URL.Path,
573+
RawQuery: req.URL.RawQuery,
574+
}
575+
576+
var body = req.Body
577+
if req.Method == http.MethodGet || req.Method == http.MethodHead {
578+
body = nil
579+
}
580+
581+
newReq, err := http.NewRequest(req.Method, targetURL.String(), body)
582+
if err != nil {
583+
p.logger.Error("can't create HTTP request for tunnel", "error", err)
584+
return
585+
}
586+
587+
// Copy headers
588+
for name, values := range req.Header {
589+
// Skip connection-specific headers
590+
if strings.ToLower(name) == "connection" || strings.ToLower(name) == "proxy-connection" {
591+
continue
592+
}
593+
for _, value := range values {
594+
newReq.Header.Add(name, value)
595+
}
596+
}
597+
598+
// Make request to destination
599+
resp, err := client.Do(newReq)
600+
if err != nil {
601+
p.logger.Error("Failed to forward request from CONNECT tunnel", "error", err)
602+
return
603+
}
604+
605+
p.logger.Debug("Response from target", "status", resp.StatusCode, "target", targetHost)
606+
607+
// Read the body and set Content-Length
608+
bodyBytes, err := io.ReadAll(resp.Body)
609+
if err != nil {
610+
p.logger.Error("can't read response body from tunnel", "error", err)
611+
return
612+
}
613+
resp.Header.Set("Content-Length", strconv.Itoa(len(bodyBytes)))
614+
resp.ContentLength = int64(len(bodyBytes))
615+
err = resp.Body.Close()
616+
if err != nil {
617+
p.logger.Error("Failed to close response body", "error", err)
618+
return
619+
}
620+
resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
621+
622+
// Normalize to HTTP/1.1
623+
resp.Proto = "HTTP/1.1"
624+
resp.ProtoMajor = 1
625+
resp.ProtoMinor = 1
626+
627+
// Write response back to tunnel
628+
err = resp.Write(conn)
629+
if err != nil {
630+
p.logger.Error("Failed to write response to CONNECT tunnel", "error", err)
631+
return
632+
}
633+
634+
p.logger.Debug("Successfully forwarded response in CONNECT tunnel")
635+
}
636+
637+
// writeBlockedCONNECTResponse writes a blocked response for CONNECT requests
638+
func (p *Server) writeBlockedCONNECTResponse(conn net.Conn, target string) {
639+
resp := &http.Response{
640+
Status: "403 Forbidden",
641+
StatusCode: http.StatusForbidden,
642+
Proto: "HTTP/1.1",
643+
ProtoMajor: 1,
644+
ProtoMinor: 1,
645+
Header: make(http.Header),
646+
Body: nil,
647+
ContentLength: 0,
648+
}
649+
650+
resp.Header.Set("Content-Type", "text/plain")
651+
652+
body := fmt.Sprintf(`🚫 CONNECT Request Blocked by Boundary
653+
654+
Target: %s
655+
656+
To allow this CONNECT request, restart boundary with:
657+
--allow "domain=%s"
658+
659+
For more help: https://github.com/coder/boundary
660+
`, target, target)
661+
662+
resp.Body = io.NopCloser(strings.NewReader(body))
663+
resp.ContentLength = int64(len(body))
664+
665+
err := resp.Write(conn)
666+
if err != nil {
667+
p.logger.Error("Failed to write blocked CONNECT response", "error", err)
668+
return
669+
}
670+
}
671+
426672
// connectionWrapper lets us "unread" the peeked byte
427673
type connectionWrapper struct {
428674
net.Conn

0 commit comments

Comments
 (0)