diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index f09cce4..529b182 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -36,7 +36,10 @@ jobs: run: go build -v ./... - name: Run tests - run: go test -v ./... + run: make test + + - name: Run tests with race detector + run: make test-race - name: Run vet run: go vet ./... diff --git a/Makefile b/Makefile index 9f4fe69..6a51095 100644 --- a/Makefile +++ b/Makefile @@ -18,3 +18,10 @@ gen-certs: @echo "Generating TLS certificates..." ./scripts/gen-certs.sh +test: + @echo "Running tests..." + go test -v ./... + +test-race: + @echo "Running tests with race detector..." + go test -race -v ./... diff --git a/cmd/loadbalancer/connection/connection.go b/cmd/loadbalancer/connection/connection.go index d615c18..e3bfca6 100644 --- a/cmd/loadbalancer/connection/connection.go +++ b/cmd/loadbalancer/connection/connection.go @@ -2,6 +2,8 @@ package connection import ( "net" + + "github.com/gobwas/ws" ) // Connection combines tracking and proxying capabilities @@ -13,7 +15,12 @@ type Connection struct { // NewConnection creates a fully configured connection func NewConnection(user, upstreamHost, downstreamHost string, downstreamConn net.Conn) *Connection { tracker := NewTracker(user, upstreamHost, downstreamHost, downstreamConn) - proxier := NewWSProxier(tracker) + dialer := ws.Dialer{ + Header: ws.HandshakeHeaderHTTP{ + "ws-user-id": []string{user}, + }, + } + proxier := NewWSProxier(tracker, &dialer) return &Connection{ Tracker: tracker, @@ -23,8 +30,6 @@ func NewConnection(user, upstreamHost, downstreamHost string, downstreamConn net // Handle manages the connection lifecycle func (c *Connection) Handle() { - //defer c.Close() - proxiedConn, err := c.ProxyDownstreamToUpstream() if err != nil { return diff --git a/cmd/loadbalancer/connection/downstream_proxier.go b/cmd/loadbalancer/connection/downstream_proxier.go index aaed11b..a34a53e 100644 --- a/cmd/loadbalancer/connection/downstream_proxier.go +++ b/cmd/loadbalancer/connection/downstream_proxier.go @@ -10,18 +10,13 @@ import ( func (p *WSProxier) ProxyDownstreamToUpstream() (net.Conn, error) { host := p.tracker.UpstreamHost() - dialer := ws.Dialer{ - Header: ws.HandshakeHeaderHTTP{ - "ws-user-id": []string{p.tracker.User()}, - }, - } p.tracker.Debug("Dialing upstream") upstreamContext := p.tracker.UpstreamContext() upstreamCancelChan := p.tracker.UpstreamCancelChan() downstreamConn := p.tracker.DownstreamConn() - proxiedConn, _, _, err := dialer.Dial(context.Background(), "ws://"+host) + proxiedConn, _, _, err := p.dialer.Dial(context.Background(), "ws://"+host) if err != nil { p.tracker.Error("Failed to dial upstream", "error", err) return nil, err @@ -36,6 +31,10 @@ func (p *WSProxier) ProxyDownstreamToUpstream() (net.Conn, error) { default: p.tracker.Debug("No one waiting for cancellation signal, skipping") } + err := proxiedConn.Close() + if err != nil { + p.tracker.Error("Failed to close upstream connection", "error", err) + } } //TODO missing defer diff --git a/cmd/loadbalancer/connection/proxier.go b/cmd/loadbalancer/connection/proxier.go index aee3d9c..43a110f 100644 --- a/cmd/loadbalancer/connection/proxier.go +++ b/cmd/loadbalancer/connection/proxier.go @@ -1,7 +1,11 @@ package connection import ( + "bufio" + "context" "net" + + "github.com/gobwas/ws" ) // Proxier manages bidirectional proxying of connections @@ -10,14 +14,20 @@ type Proxier interface { ProxyDownstreamToUpstream() (net.Conn, error) } +type WSDialer interface { + Dial(ctx context.Context, urlstr string) (net.Conn, *bufio.Reader, ws.Handshake, error) +} + // WSProxier implements Proxier for WebSocket connections type WSProxier struct { tracker ConnectionTracker + dialer WSDialer } // NewWSProxier creates a new WebSocket proxier -func NewWSProxier(tracker ConnectionTracker) *WSProxier { +func NewWSProxier(tracker ConnectionTracker, dialer WSDialer) *WSProxier { return &WSProxier{ tracker: tracker, + dialer: dialer, } } diff --git a/cmd/loadbalancer/connection/tracker.go b/cmd/loadbalancer/connection/tracker.go index 5c34426..3d06b39 100644 --- a/cmd/loadbalancer/connection/tracker.go +++ b/cmd/loadbalancer/connection/tracker.go @@ -4,6 +4,7 @@ import ( "context" "log/slog" "net" + "sync" ) // Logger defines the logging behavior @@ -40,20 +41,73 @@ type Tracker struct { cancelFunc context.CancelFunc ctx context.Context cancelChan chan int + mu sync.RWMutex } // Create accessor methods without "Get" prefix (more idiomatic in Go) -func (t *Tracker) User() string { return t.user } -func (t *Tracker) UpstreamHost() string { return t.upstreamHost } -func (t *Tracker) DownstreamHost() string { return t.downstreamHost } -func (t *Tracker) UpstreamConn() net.Conn { return t.upstreamConn } -func (t *Tracker) DownstreamConn() net.Conn { return t.downstreamConn } -func (t *Tracker) UpstreamContext() context.Context { return t.ctx } -func (t *Tracker) UpstreamCancelChan() chan int { return t.cancelChan } -func (t *Tracker) SetUpstreamConn(conn net.Conn) { t.upstreamConn = conn } -func (t *Tracker) SetUpstreamHost(host string) { t.upstreamHost = host } -func (t *Tracker) SetDownstreamConn(conn net.Conn) { t.downstreamConn = conn } +func (t *Tracker) User() string { + t.mu.RLock() + defer t.mu.RUnlock() + return t.user +} + +func (t *Tracker) UpstreamHost() string { + t.mu.RLock() + defer t.mu.RUnlock() + return t.upstreamHost +} + +func (t *Tracker) DownstreamHost() string { + t.mu.RLock() + defer t.mu.RUnlock() + return t.downstreamHost +} + +func (t *Tracker) UpstreamConn() net.Conn { + t.mu.RLock() + defer t.mu.RUnlock() + return t.upstreamConn +} + +func (t *Tracker) DownstreamConn() net.Conn { + t.mu.RLock() + defer t.mu.RUnlock() + return t.downstreamConn +} + +func (t *Tracker) UpstreamContext() context.Context { + t.mu.RLock() + defer t.mu.RUnlock() + return t.ctx +} + +func (t *Tracker) UpstreamCancelChan() chan int { + t.mu.RLock() + defer t.mu.RUnlock() + return t.cancelChan +} + +func (t *Tracker) SetUpstreamConn(conn net.Conn) { + t.mu.Lock() + defer t.mu.Unlock() + t.upstreamConn = conn +} + +func (t *Tracker) SetUpstreamHost(host string) { + t.mu.Lock() + defer t.mu.Unlock() + t.upstreamHost = host +} + +func (t *Tracker) SetDownstreamConn(conn net.Conn) { + t.mu.Lock() + defer t.mu.Unlock() + t.downstreamConn = conn +} + func (t *Tracker) SwitchUpstreamHost(host string) { + t.mu.Lock() + defer t.mu.Unlock() t.cancelFunc() t.ctx, t.cancelFunc = context.WithCancel(context.Background()) t.upstreamHost = host @@ -61,38 +115,61 @@ func (t *Tracker) SwitchUpstreamHost(host string) { // Logging methods with chaining func (t *Tracker) Info(message string, args ...any) Logger { - slog.With("user", t.user). - With("upstreamHost", t.upstreamHost). - With("downstreamHost", t.downstreamHost). + t.mu.RLock() + user := t.user + upstreamHost := t.upstreamHost + downstreamHost := t.downstreamHost + t.mu.RUnlock() + + slog.With("user", user). + With("upstreamHost", upstreamHost). + With("downstreamHost", downstreamHost). With("component", "connection-tracker"). Info(message, args...) return t } func (t *Tracker) Error(message string, args ...any) Logger { - slog.With("user", t.user). - With("upstreamHost", t.upstreamHost). - With("downstreamHost", t.downstreamHost). + t.mu.RLock() + user := t.user + upstreamHost := t.upstreamHost + downstreamHost := t.downstreamHost + t.mu.RUnlock() + + slog.With("user", user). + With("upstreamHost", upstreamHost). + With("downstreamHost", downstreamHost). With("component", "connection-tracker"). Error(message, args...) return t } func (t *Tracker) Debug(message string, args ...any) Logger { - slog.With("user", t.user). - With("upstreamHost", t.upstreamHost). - With("downstreamHost", t.downstreamHost). + t.mu.RLock() + user := t.user + upstreamHost := t.upstreamHost + downstreamHost := t.downstreamHost + t.mu.RUnlock() + + slog.With("user", user). + With("upstreamHost", upstreamHost). + With("downstreamHost", downstreamHost). With("component", "connection-tracker"). Debug(message, args...) return t } func (t *Tracker) Close() { - if t.upstreamConn != nil { - t.upstreamConn.Close() + t.mu.Lock() + upstreamConn := t.upstreamConn + downstreamConn := t.downstreamConn + t.mu.Unlock() + + if upstreamConn != nil { + upstreamConn.Close() } - if t.downstreamConn != nil { - t.downstreamConn.Close() + if downstreamConn != nil { + downstreamConn.Close() } } diff --git a/cmd/loadbalancer/server/rebalance_test.go b/cmd/loadbalancer/server/rebalance_test.go index 8ed533b..41c2acb 100644 --- a/cmd/loadbalancer/server/rebalance_test.go +++ b/cmd/loadbalancer/server/rebalance_test.go @@ -1,11 +1,16 @@ package server import ( + "bufio" + "context" "log/slog" "lukas8219/websocket-operator/cmd/loadbalancer/connection" "net" + "sync" "testing" "time" + + "github.com/gobwas/ws" ) type MockRouter struct { @@ -24,29 +29,49 @@ func (m *MockRouter) GetAllUpstreamHosts() []string { } func (m *MockRouter) InitializeHosts() error { return nil } -type MockConnection struct { - *connection.Connection - upstreamHost string - user string - upstreamCancelChan chan struct{} +type NetConnectionMock struct { + net.Conn + remoteAddr net.Addr + isClosed bool +} + +func (m *NetConnectionMock) Read(b []byte) (int, error) { + return 0, nil +} + +func (m *NetConnectionMock) Write(b []byte) (int, error) { + return 0, nil } -type MockProxier struct{} +func (m *NetConnectionMock) RemoteAddr() net.Addr { + return m.remoteAddr +} -func (m *MockProxier) ProxyDownstreamToUpstream() (net.Conn, error) { - return nil, nil +func (m *NetConnectionMock) Close() error { + m.isClosed = true + return nil } -func (m *MockProxier) ProxyUpstreamToDownstream() { +type MockWSDialer struct { + dialCalls []string + connections []*NetConnectionMock + mu sync.RWMutex +} +func (m *MockWSDialer) Dial(ctx context.Context, urlstr string) (net.Conn, *bufio.Reader, ws.Handshake, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.dialCalls = append(m.dialCalls, urlstr) + mockConn := &NetConnectionMock{remoteAddr: &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8080}} + m.connections = append(m.connections, mockConn) + return mockConn, nil, ws.Handshake{}, nil } -func NewMockConnection(user, upstreamHost string) *connection.Connection { - conn := &net.TCPConn{} - tracker := connection.NewTracker(user, upstreamHost, "downstream", conn) +func NewMockConnection(user, upstreamHost string, downstreamConn net.Conn, wsDialer *MockWSDialer) *connection.Connection { + tracker := connection.NewTracker(user, upstreamHost, "downstream", downstreamConn) return &connection.Connection{ Tracker: tracker, - Proxier: &MockProxier{}, + Proxier: connection.NewWSProxier(tracker, wsDialer), } } @@ -58,32 +83,27 @@ func TestHandleRebalanceLoop(t *testing.T) { go handleRebalanceLoop(mockRouter, connections) - t.Run("No connection tracker found", func(t *testing.T) { - mockRouter.rebalanceChan <- [][2]string{{"non-existent", "new-host:3000"}} - }) - - t.Run("No need to rebalance - same host", func(t *testing.T) { - mockConn := NewMockConnection("user1", "same-host:3000") - connections["user1"] = mockConn - - mockRouter.rebalanceChan <- [][2]string{{"user1", "same-host:3000"}} - }) - - t.Run("Successfully received cancellation signal", func(t *testing.T) { - mockConn := NewMockConnection("user2", "old-host:3000") - connections["user2"] = mockConn - - go func() { - mockConn.Tracker.UpstreamCancelChan() <- 1 - }() - mockRouter.rebalanceChan <- [][2]string{{"user2", "new-host:3000"}} + t.Run("Sucessfully rebalanced", func(t *testing.T) { + mockDownstreamConn := &NetConnectionMock{remoteAddr: &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8080}} + mockWSDialer := &MockWSDialer{} + mockConn := NewMockConnection("user1", "old-host:3000", mockDownstreamConn, mockWSDialer) + go mockConn.Handle() + connections[mockConn.Tracker.User()] = mockConn + mockConn.Tracker.UpstreamCancelChan() <- 1 + mockRouter.rebalanceChan <- [][2]string{{mockConn.Tracker.User(), "new-host:3000"}} time.Sleep(100 * time.Millisecond) - //Expect Switch Host + Handle to be called if mockConn.UpstreamHost() != "new-host:3000" { t.Errorf("Expected host to be updated to new-host:3000, got %s", mockConn.UpstreamHost()) } + if len(mockWSDialer.dialCalls) != 2 { + t.Errorf("Expected dial to be called twice, got %d", len(mockWSDialer.dialCalls)) + return + } + if mockWSDialer.dialCalls[1] != "ws://new-host:3000" { + t.Errorf("Expected dial to be called with ws://new-host:3000, got %s", mockWSDialer.dialCalls[1]) + } + }) - //TODO: change interface to use io.ReadWriter and inject Dialer so we can test new connections }