Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.PHONY: all build-sidecar build-controller gen-certs

all: build-sidecar build-controller gen-certs
all: build-sidecar build-controller build-loadbalancer

build-sidecar:
@echo "Building WebSocket Proxy Sidecar..."
Expand All @@ -9,7 +9,7 @@ build-sidecar:
build-controller:
@echo "Building WebSocket Operator Controller..."
./scripts/build-controller.sh

build-loadbalancer:
@echo "Building WebSocket Operator LoadBalancer..."
./scripts/build-loadbalancer.sh
Expand Down
20 changes: 10 additions & 10 deletions cmd/loadbalancer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,24 @@ package main
import (
"flag"
"lukas8219/websocket-operator/cmd/loadbalancer/server"
"lukas8219/websocket-operator/internal/consistent_hashing"
"lukas8219/websocket-operator/internal/logger"
"lukas8219/websocket-operator/internal/route"
)

var (
router route.RouterImpl
"lukas8219/websocket-operator/internal/peer_discovery"
"lukas8219/websocket-operator/internal/resolver"
)

func main() {
port := flag.String("port", "3000", "Port to listen on")
mode := flag.String("mode", "kubernetes", "Mode to use")
// mode := flag.String("mode", "kubernetes", "Mode to use")
debug := flag.Bool("debug", false, "Debug mode")
flag.Parse()
logger.SetupLogger(*debug)
router = route.NewRouter(route.RouterConfig{Mode: route.RouterConfigMode(*mode)})
router.InitializeHosts()
peerDiscovery := peer_discovery.NewKubernetes("default", "ws-headless-proxy")
go peerDiscovery.Initialize()
resolver := resolver.New(peerDiscovery, consistent_hashing.NewJumpHash(peerDiscovery))
go resolver.Initialize()
server.StartServer(server.ServerConfig{
Router: router,
Port: *port,
Resolver: resolver,
Port: *port,
})
}
17 changes: 8 additions & 9 deletions cmd/loadbalancer/server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@ package server
import (
"log/slog"
"lukas8219/websocket-operator/cmd/loadbalancer/connection"
"lukas8219/websocket-operator/internal/route"
"lukas8219/websocket-operator/internal/resolver"
"net/http"
"os"

"github.com/gobwas/ws"
)

func createHandler(router route.RouterImpl, connections map[string]*connection.Connection) http.HandlerFunc {
func createHandler(rslv resolver.Resolver, connections map[string]*connection.Connection) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
handleConnection(router, connections, w, r)
handleConnection(rslv, connections, w, r)
}
}

func handleConnection(router route.RouterImpl, connections map[string]*connection.Connection, w http.ResponseWriter, r *http.Request) {
func handleConnection(rslv resolver.Resolver, connections map[string]*connection.Connection, w http.ResponseWriter, r *http.Request) {
user := r.Header.Get("ws-user-id")
if user == "" {
slog.Error("No user id provided")
Expand All @@ -26,9 +26,8 @@ func handleConnection(router route.RouterImpl, connections map[string]*connectio
//TODO: we should only accept `NewConnection` already with client connection and host set.`
//As only the `connection` pkg should alter it`.

host := router.Route(user)
slog.With("user", user).Debug("New connection")
if host == "" {
host, err := rslv.Lookup([]byte(user))
if err != nil {
slog.Error("No host found for user")
w.WriteHeader(http.StatusBadRequest)
return
Expand All @@ -38,7 +37,7 @@ func handleConnection(router route.RouterImpl, connections map[string]*connectio
upgrader := ws.HTTPUpgrader{
Header: http.Header{
"x-ws-operator-proxy-instance": []string{os.Getenv("HOSTNAME")},
"x-ws-operator-upstream-host": []string{host},
"x-ws-operator-upstream-host": []string{host.Hostname()},
},
}
downstreamConn, _, _, err := upgrader.Upgrade(r, w)
Expand All @@ -49,7 +48,7 @@ func handleConnection(router route.RouterImpl, connections map[string]*connectio
return
}

proxiedConnection := connection.NewConnection(user, host, downstreamConn.RemoteAddr().String(), downstreamConn)
proxiedConnection := connection.NewConnection(user, host.SocketAddres(), downstreamConn.RemoteAddr().String(), downstreamConn)
connections[user] = proxiedConnection

proxiedConnection.Debug("New connection")
Expand Down
72 changes: 33 additions & 39 deletions cmd/loadbalancer/server/rebalance.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,51 +3,45 @@ package server
import (
"log/slog"
"lukas8219/websocket-operator/cmd/loadbalancer/connection"
"lukas8219/websocket-operator/internal/route"
rslv "lukas8219/websocket-operator/internal/resolver"
"time"
)

func handleRebalanceLoop(router route.RouterImpl, connections map[string]*connection.Connection) {
slog.Debug("Starting rebalance loop")
for {
select {
case hosts := <-router.RebalanceRequests():
slog.Debug("Received message to rebalance", "hosts", hosts)
upstreamHostsToConnectionTracker := make(map[string]*connection.Connection, len(connections))
slog.Debug("Flat mapping ConnectionTracker to upstreamHosts", "connections", connections)
for _, connectionTracker := range connections {
upstreamHostsToConnectionTracker[connectionTracker.User()] = connectionTracker
func handleRebalanceLoop(resolver rslv.Resolver, connections map[string]*connection.Connection) {
slog.Info("Starting rebalance loop")
for _ = range resolver.VersionUpgradeChannel() {
hosts, err := resolver.CurrentHosts()
if err != nil {
slog.Error(err.Error())
continue
}
slog.Debug("Received message to rebalance", "hosts", hosts)
upstreamHostsToConnectionTracker := make(map[string]*connection.Connection, len(connections))
for user, connectionTracker := range connections {
newHost, err := resolver.Lookup([]byte(user))
if err != nil {
panic(err) //TODO
}
for _, affectedHost := range hosts {
recipientId := affectedHost[0]
newHost := affectedHost[1]
connectionTracker := upstreamHostsToConnectionTracker[recipientId]
if connectionTracker == nil {
slog.Debug("No connection tracker found", "user", recipientId)
continue
}
oldHost := connectionTracker.UpstreamHost()
if connectionTracker.UpstreamHost() == newHost {
connectionTracker.Debug("No need to rebalance")
continue
}
connectionTracker.Debug("Waiting for upstream to cancel", "oldHost", oldHost)
connectionTracker.SwitchUpstreamHost(newHost)

select {
case <-connectionTracker.UpstreamCancelChan():
connectionTracker.Debug("Successfully received cancellation signal")
case <-time.After(5 * time.Second):
connectionTracker.Error("Timeout waiting for upstream cancellation, proceeding anyway")
}
previousHost := connectionTracker.UpstreamHost()
if previousHost == newHost.SocketAddres() {
connectionTracker.Debug("No need to rebalance")
continue
}
connectionTracker.Debug("Waiting for upstream to cancel", "previous", previousHost)
connectionTracker.SwitchUpstreamHost(newHost.SocketAddres())

connections[recipientId] = connectionTracker
//TODO: gut feeling here. either we move rebalance to the connection pkg or we re-design stuff
//connectionTracker.UpstreamContext, connectionTracker.CancelUpstream = context.WithCancel(context.Background())
connectionTracker.Info("Rebalancing connection from", "old", oldHost, "new", newHost)
//TODO: stopping down -> up could cause issues if this is mid read/write
go connectionTracker.Handle()
select {
case <-connectionTracker.UpstreamCancelChan():
connectionTracker.Debug("Successfully received cancellation signal")
case <-time.After(5 * time.Second):
connectionTracker.Error("Timeout waiting for upstream cancellation, proceeding anyway")
}
//TODO: gut feeling here. either we move rebalance to the connection pkg or we re-design stuff
//connectionTracker.UpstreamContext, connectionTracker.CancelUpstream = context.WithCancel(context.Background())
connectionTracker.Info("Rebalancing connection from", "previous", previousHost, "new", newHost)
//TODO: stopping down -> up could cause issues if this is mid read/write
go connectionTracker.Handle()
upstreamHostsToConnectionTracker[connectionTracker.User()] = connectionTracker
}
}
}
38 changes: 16 additions & 22 deletions cmd/loadbalancer/server/rebalance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"context"
"io"
"log"
"log/slog"
"lukas8219/websocket-operator/cmd/loadbalancer/connection"
"lukas8219/websocket-operator/internal/consistent_hashing"
"lukas8219/websocket-operator/internal/peer_discovery"
"lukas8219/websocket-operator/internal/resolver"
"net"
"sync"
"testing"
Expand All @@ -15,22 +17,6 @@ import (
"github.com/gobwas/ws"
)

type MockRouter struct {
rebalanceChan chan [][2]string
*slog.Logger
}

func (m *MockRouter) RebalanceRequests() <-chan [][2]string {
return m.rebalanceChan
}

func (m *MockRouter) Route(string) string { return "" }
func (m *MockRouter) Add([]string) {}
func (m *MockRouter) GetAllUpstreamHosts() []string {
return []string{}
}
func (m *MockRouter) InitializeHosts() error { return nil }

type NetConnectionMock struct {
net.Conn
remoteAddr net.Addr
Expand Down Expand Up @@ -118,12 +104,15 @@ func NewMockConnection(user, upstreamHost string, downstreamConn net.Conn, wsDia
}

func TestHandleRebalanceLoop(t *testing.T) {
mockRouter := &MockRouter{
rebalanceChan: make(chan [][2]string, 1),
}
memoryDiscoveryBackend := peer_discovery.NewInMemoryPeerDiscovery()
mockResolver := resolver.New(
memoryDiscoveryBackend,
consistent_hashing.NewJumpHash(memoryDiscoveryBackend),
)
go mockResolver.Init()
connections := make(map[string]*connection.Connection)

go handleRebalanceLoop(mockRouter, connections)
go handleRebalanceLoop(mockResolver, connections)

t.Run("Sucessfully rebalanced", func(t *testing.T) {
mockDownstreamConn := &NetConnectionMock{
Expand All @@ -139,7 +128,12 @@ func TestHandleRebalanceLoop(t *testing.T) {

mockConn.Tracker.UpstreamCancelChan() <- 1
time.Sleep(100 * time.Millisecond)
mockRouter.rebalanceChan <- [][2]string{{mockConn.Tracker.User(), "new-host:3000"}}
t.Log("Before atomic")
memoryDiscoveryBackend.AtomicOperation(
[]peer_discovery.Peer{peer_discovery.NewPeer("new-host", 3000)},
[]peer_discovery.Peer{peer_discovery.NewPeer("old-host", 3000)},
)
t.Log("After atomic")
time.Sleep(100 * time.Millisecond)

if mockConn.UpstreamHost() != "new-host:3000" {
Expand Down
12 changes: 6 additions & 6 deletions cmd/loadbalancer/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@ package server
import (
"log/slog"
"lukas8219/websocket-operator/cmd/loadbalancer/connection"
"lukas8219/websocket-operator/internal/route"
"lukas8219/websocket-operator/internal/resolver"
"net/http"
)

type ServerConfig struct {
Router route.RouterImpl
Port string
Resolver resolver.Resolver
Port string
}

func StartServer(config ServerConfig) {
slog.Info("Starting load balancer server", "port", config.Port)
router := config.Router
connections := make(map[string]*connection.Connection) //TODO: This could be a broadcast instead of a single recipient/connection
go handleRebalanceLoop(router, connections)

go handleRebalanceLoop(config.Resolver, connections)
//TODO how to properly test this - aka not having a server running at all
http.ListenAndServe("0.0.0.0:"+config.Port, createHandler(router, connections))
http.ListenAndServe("0.0.0.0:"+config.Port, createHandler(config.Resolver, connections))
}
14 changes: 11 additions & 3 deletions cmd/sidecar/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import (
"flag"
"io"
"log/slog"
"lukas8219/websocket-operator/internal/consistent_hashing"
"lukas8219/websocket-operator/internal/logger"
"lukas8219/websocket-operator/internal/peer_discovery"
"lukas8219/websocket-operator/internal/rendezvous"
"lukas8219/websocket-operator/internal/resolver"
"lukas8219/websocket-operator/internal/transports"
"net"
Expand Down Expand Up @@ -61,8 +61,16 @@ func main() {
logger.SetupLogger(*debug)
//TODO move to config
peerDiscovery := peer_discovery.NewKubernetes("default", "ws-headless-proxy")
resolver := resolver.New(peerDiscovery, rendezvous.NewDefault())
transport := transports.NewHTTPTransport(*resolver)
err := peerDiscovery.Initialize()
if err != nil {
panic(err) //TODO Better handling
}
resolver := resolver.New(peerDiscovery, consistent_hashing.NewJumpHash(peerDiscovery))
err = resolver.Initialize()
if err != nil {
panic(err)
}
transport := transports.NewHTTPTransport(resolver)

slog.Info("Starting server", "port", *port)
// Map to store active WebSocket connections
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ require (
github.com/buraksezer/consistent v0.10.0
github.com/gobwas/ws v1.4.0
github.com/hashicorp/go-set/v3 v3.0.1
github.com/lithammer/go-jump-consistent-hash v1.0.2
github.com/zeebo/xxh3 v1.0.2
k8s.io/api v0.32.3
k8s.io/apimachinery v0.32.3
Expand All @@ -30,6 +31,7 @@ require (
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/gofuzz v1.2.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.0.9 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/go-set/v3 v3.0.1 h1:ZwO15ZYmIrFYL9zSm2wBuwcRiHxVdp46m/XA/MUlM6I=
github.com/hashicorp/go-set/v3 v3.0.1/go.mod h1:0oPQqhtitglZeT2ZiWnRIfUG6gJAHnn7LzrS7SbgNY4=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
Expand All @@ -60,6 +62,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lithammer/go-jump-consistent-hash v1.0.2 h1:w74N9XiMa4dWZdoVnfLbnDhfpGOMCxlrudzt2e7wtyk=
github.com/lithammer/go-jump-consistent-hash v1.0.2/go.mod h1:4MD1WDikNGnb9D56hAtscaZaOWOiCG+lLbRR5ZN9JL0=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
Expand Down
6 changes: 6 additions & 0 deletions internal/consistent_hashing/consistent_hashing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package consistent_hashing

type ConsistentHashing[T any] interface {
Lookup([]byte) (T, error)
Transaction(Add, Remove []T)
}
39 changes: 39 additions & 0 deletions internal/consistent_hashing/jump_hash.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package consistent_hashing

import (
"fmt"
"lukas8219/websocket-operator/internal/peer_discovery"

"github.com/lithammer/go-jump-consistent-hash"
)

type JumpHashing struct {
peerDiscoveryBacked peer_discovery.PeerDiscovery
}

func NewJumpHash(peerDiscoveryBackend peer_discovery.PeerDiscovery) JumpHashing {
return JumpHashing{
peerDiscoveryBacked: peerDiscoveryBackend,
}
}

func (j JumpHashing) Lookup(Recipient []byte) (peer_discovery.Peer, error) {
//TODO handle it
currentHosts, err := j.peerDiscoveryBacked.CurrentHosts()
if err != nil {
//We need to handle errrs
return peer_discovery.Peer{}, err
}
var myUint64 uint64 = 1
index := jump.Hash(
myUint64,
int32(len(currentHosts)),
)
if index == -1 {
return peer_discovery.Peer{}, fmt.Errorf("Didn't find any Peer to route")
}
return currentHosts[index], nil
}

func (j JumpHashing) Transaction(Add, Remove []peer_discovery.Peer) {
}
Loading