Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ jobs:
run: go build -v ./...

- name: Run tests
run: go test -v ./...
run: make test

- name: Run tests with race detector
run: make test-race

- name: Run vet
run: go vet ./...
Expand Down
7 changes: 7 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,10 @@ gen-certs:
@echo "Generating TLS certificates..."
./scripts/gen-certs.sh

test:
@echo "Running tests..."
go test -v ./...

test-race:
@echo "Running tests with race detector..."
go test -race -v ./...
11 changes: 8 additions & 3 deletions cmd/loadbalancer/connection/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package connection

import (
"net"

"github.com/gobwas/ws"
)

// Connection combines tracking and proxying capabilities
Expand All @@ -13,7 +15,12 @@ type Connection struct {
// NewConnection creates a fully configured connection
func NewConnection(user, upstreamHost, downstreamHost string, downstreamConn net.Conn) *Connection {
tracker := NewTracker(user, upstreamHost, downstreamHost, downstreamConn)
proxier := NewWSProxier(tracker)
dialer := ws.Dialer{
Header: ws.HandshakeHeaderHTTP{
"ws-user-id": []string{user},
},
}
proxier := NewWSProxier(tracker, &dialer)

return &Connection{
Tracker: tracker,
Expand All @@ -23,8 +30,6 @@ func NewConnection(user, upstreamHost, downstreamHost string, downstreamConn net

// Handle manages the connection lifecycle
func (c *Connection) Handle() {
//defer c.Close()

proxiedConn, err := c.ProxyDownstreamToUpstream()
if err != nil {
return
Expand Down
11 changes: 5 additions & 6 deletions cmd/loadbalancer/connection/downstream_proxier.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,13 @@ import (

func (p *WSProxier) ProxyDownstreamToUpstream() (net.Conn, error) {
host := p.tracker.UpstreamHost()
dialer := ws.Dialer{
Header: ws.HandshakeHeaderHTTP{
"ws-user-id": []string{p.tracker.User()},
},
}

p.tracker.Debug("Dialing upstream")
upstreamContext := p.tracker.UpstreamContext()
upstreamCancelChan := p.tracker.UpstreamCancelChan()
downstreamConn := p.tracker.DownstreamConn()

proxiedConn, _, _, err := dialer.Dial(context.Background(), "ws://"+host)
proxiedConn, _, _, err := p.dialer.Dial(context.Background(), "ws://"+host)
if err != nil {
p.tracker.Error("Failed to dial upstream", "error", err)
return nil, err
Expand All @@ -36,6 +31,10 @@ func (p *WSProxier) ProxyDownstreamToUpstream() (net.Conn, error) {
default:
p.tracker.Debug("No one waiting for cancellation signal, skipping")
}
err := proxiedConn.Close()
if err != nil {
p.tracker.Error("Failed to close upstream connection", "error", err)
}
}

//TODO missing defer
Expand Down
12 changes: 11 additions & 1 deletion cmd/loadbalancer/connection/proxier.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package connection

import (
"bufio"
"context"
"net"

"github.com/gobwas/ws"
)

// Proxier manages bidirectional proxying of connections
Expand All @@ -10,14 +14,20 @@ type Proxier interface {
ProxyDownstreamToUpstream() (net.Conn, error)
}

type WSDialer interface {
Dial(ctx context.Context, urlstr string) (net.Conn, *bufio.Reader, ws.Handshake, error)
}

// WSProxier implements Proxier for WebSocket connections
type WSProxier struct {
tracker ConnectionTracker
dialer WSDialer
}

// NewWSProxier creates a new WebSocket proxier
func NewWSProxier(tracker ConnectionTracker) *WSProxier {
func NewWSProxier(tracker ConnectionTracker, dialer WSDialer) *WSProxier {
return &WSProxier{
tracker: tracker,
dialer: dialer,
}
}
123 changes: 100 additions & 23 deletions cmd/loadbalancer/connection/tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"log/slog"
"net"
"sync"
)

// Logger defines the logging behavior
Expand Down Expand Up @@ -40,59 +41,135 @@ type Tracker struct {
cancelFunc context.CancelFunc
ctx context.Context
cancelChan chan int
mu sync.RWMutex
}

// Create accessor methods without "Get" prefix (more idiomatic in Go)
func (t *Tracker) User() string { return t.user }
func (t *Tracker) UpstreamHost() string { return t.upstreamHost }
func (t *Tracker) DownstreamHost() string { return t.downstreamHost }
func (t *Tracker) UpstreamConn() net.Conn { return t.upstreamConn }
func (t *Tracker) DownstreamConn() net.Conn { return t.downstreamConn }
func (t *Tracker) UpstreamContext() context.Context { return t.ctx }
func (t *Tracker) UpstreamCancelChan() chan int { return t.cancelChan }
func (t *Tracker) SetUpstreamConn(conn net.Conn) { t.upstreamConn = conn }
func (t *Tracker) SetUpstreamHost(host string) { t.upstreamHost = host }
func (t *Tracker) SetDownstreamConn(conn net.Conn) { t.downstreamConn = conn }
func (t *Tracker) User() string {
t.mu.RLock()
defer t.mu.RUnlock()
return t.user
}

func (t *Tracker) UpstreamHost() string {
t.mu.RLock()
defer t.mu.RUnlock()
return t.upstreamHost
}

func (t *Tracker) DownstreamHost() string {
t.mu.RLock()
defer t.mu.RUnlock()
return t.downstreamHost
}

func (t *Tracker) UpstreamConn() net.Conn {
t.mu.RLock()
defer t.mu.RUnlock()
return t.upstreamConn
}

func (t *Tracker) DownstreamConn() net.Conn {
t.mu.RLock()
defer t.mu.RUnlock()
return t.downstreamConn
}

func (t *Tracker) UpstreamContext() context.Context {
t.mu.RLock()
defer t.mu.RUnlock()
return t.ctx
}

func (t *Tracker) UpstreamCancelChan() chan int {
t.mu.RLock()
defer t.mu.RUnlock()
return t.cancelChan
}

func (t *Tracker) SetUpstreamConn(conn net.Conn) {
t.mu.Lock()
defer t.mu.Unlock()
t.upstreamConn = conn
}

func (t *Tracker) SetUpstreamHost(host string) {
t.mu.Lock()
defer t.mu.Unlock()
t.upstreamHost = host
}

func (t *Tracker) SetDownstreamConn(conn net.Conn) {
t.mu.Lock()
defer t.mu.Unlock()
t.downstreamConn = conn
}

func (t *Tracker) SwitchUpstreamHost(host string) {
t.mu.Lock()
defer t.mu.Unlock()
t.cancelFunc()
t.ctx, t.cancelFunc = context.WithCancel(context.Background())
t.upstreamHost = host
}

// Logging methods with chaining
func (t *Tracker) Info(message string, args ...any) Logger {
slog.With("user", t.user).
With("upstreamHost", t.upstreamHost).
With("downstreamHost", t.downstreamHost).
t.mu.RLock()
user := t.user
upstreamHost := t.upstreamHost
downstreamHost := t.downstreamHost
t.mu.RUnlock()

slog.With("user", user).
With("upstreamHost", upstreamHost).
With("downstreamHost", downstreamHost).
With("component", "connection-tracker").
Info(message, args...)
return t
}

func (t *Tracker) Error(message string, args ...any) Logger {
slog.With("user", t.user).
With("upstreamHost", t.upstreamHost).
With("downstreamHost", t.downstreamHost).
t.mu.RLock()
user := t.user
upstreamHost := t.upstreamHost
downstreamHost := t.downstreamHost
t.mu.RUnlock()

slog.With("user", user).
With("upstreamHost", upstreamHost).
With("downstreamHost", downstreamHost).
With("component", "connection-tracker").
Error(message, args...)
return t
}

func (t *Tracker) Debug(message string, args ...any) Logger {
slog.With("user", t.user).
With("upstreamHost", t.upstreamHost).
With("downstreamHost", t.downstreamHost).
t.mu.RLock()
user := t.user
upstreamHost := t.upstreamHost
downstreamHost := t.downstreamHost
t.mu.RUnlock()

slog.With("user", user).
With("upstreamHost", upstreamHost).
With("downstreamHost", downstreamHost).
With("component", "connection-tracker").
Debug(message, args...)
return t
}

func (t *Tracker) Close() {
if t.upstreamConn != nil {
t.upstreamConn.Close()
t.mu.Lock()
upstreamConn := t.upstreamConn
downstreamConn := t.downstreamConn
t.mu.Unlock()

if upstreamConn != nil {
upstreamConn.Close()
}
if t.downstreamConn != nil {
t.downstreamConn.Close()
if downstreamConn != nil {
downstreamConn.Close()
}
}

Expand Down
Loading