diff --git a/Makefile b/Makefile index 6a51095..0a26d0d 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ .PHONY: all build-sidecar build-controller gen-certs -all: build-sidecar build-controller gen-certs +all: build-sidecar build-controller build-loadbalancer build-sidecar: @echo "Building WebSocket Proxy Sidecar..." @@ -9,7 +9,7 @@ build-sidecar: build-controller: @echo "Building WebSocket Operator Controller..." ./scripts/build-controller.sh - + build-loadbalancer: @echo "Building WebSocket Operator LoadBalancer..." ./scripts/build-loadbalancer.sh diff --git a/cmd/loadbalancer/main.go b/cmd/loadbalancer/main.go index 913c6fd..0be3b81 100644 --- a/cmd/loadbalancer/main.go +++ b/cmd/loadbalancer/main.go @@ -3,24 +3,24 @@ package main import ( "flag" "lukas8219/websocket-operator/cmd/loadbalancer/server" + "lukas8219/websocket-operator/internal/consistent_hashing" "lukas8219/websocket-operator/internal/logger" - "lukas8219/websocket-operator/internal/route" -) - -var ( - router route.RouterImpl + "lukas8219/websocket-operator/internal/peer_discovery" + "lukas8219/websocket-operator/internal/resolver" ) func main() { port := flag.String("port", "3000", "Port to listen on") - mode := flag.String("mode", "kubernetes", "Mode to use") + // mode := flag.String("mode", "kubernetes", "Mode to use") debug := flag.Bool("debug", false, "Debug mode") flag.Parse() logger.SetupLogger(*debug) - router = route.NewRouter(route.RouterConfig{Mode: route.RouterConfigMode(*mode)}) - router.InitializeHosts() + peerDiscovery := peer_discovery.NewKubernetes("default", "ws-headless-proxy") + go peerDiscovery.Initialize() + resolver := resolver.New(peerDiscovery, consistent_hashing.NewJumpHash(peerDiscovery)) + go resolver.Initialize() server.StartServer(server.ServerConfig{ - Router: router, - Port: *port, + Resolver: resolver, + Port: *port, }) } diff --git a/cmd/loadbalancer/server/handler.go b/cmd/loadbalancer/server/handler.go index 5ee8d8a..b4dc3eb 100644 --- a/cmd/loadbalancer/server/handler.go +++ b/cmd/loadbalancer/server/handler.go @@ -3,20 +3,20 @@ package server import ( "log/slog" "lukas8219/websocket-operator/cmd/loadbalancer/connection" - "lukas8219/websocket-operator/internal/route" + "lukas8219/websocket-operator/internal/resolver" "net/http" "os" "github.com/gobwas/ws" ) -func createHandler(router route.RouterImpl, connections map[string]*connection.Connection) http.HandlerFunc { +func createHandler(rslv resolver.Resolver, connections map[string]*connection.Connection) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - handleConnection(router, connections, w, r) + handleConnection(rslv, connections, w, r) } } -func handleConnection(router route.RouterImpl, connections map[string]*connection.Connection, w http.ResponseWriter, r *http.Request) { +func handleConnection(rslv resolver.Resolver, connections map[string]*connection.Connection, w http.ResponseWriter, r *http.Request) { user := r.Header.Get("ws-user-id") if user == "" { slog.Error("No user id provided") @@ -26,9 +26,8 @@ func handleConnection(router route.RouterImpl, connections map[string]*connectio //TODO: we should only accept `NewConnection` already with client connection and host set.` //As only the `connection` pkg should alter it`. - host := router.Route(user) - slog.With("user", user).Debug("New connection") - if host == "" { + host, err := rslv.Lookup([]byte(user)) + if err != nil { slog.Error("No host found for user") w.WriteHeader(http.StatusBadRequest) return @@ -38,7 +37,7 @@ func handleConnection(router route.RouterImpl, connections map[string]*connectio upgrader := ws.HTTPUpgrader{ Header: http.Header{ "x-ws-operator-proxy-instance": []string{os.Getenv("HOSTNAME")}, - "x-ws-operator-upstream-host": []string{host}, + "x-ws-operator-upstream-host": []string{host.Hostname()}, }, } downstreamConn, _, _, err := upgrader.Upgrade(r, w) @@ -49,7 +48,7 @@ func handleConnection(router route.RouterImpl, connections map[string]*connectio return } - proxiedConnection := connection.NewConnection(user, host, downstreamConn.RemoteAddr().String(), downstreamConn) + proxiedConnection := connection.NewConnection(user, host.SocketAddres(), downstreamConn.RemoteAddr().String(), downstreamConn) connections[user] = proxiedConnection proxiedConnection.Debug("New connection") diff --git a/cmd/loadbalancer/server/rebalance.go b/cmd/loadbalancer/server/rebalance.go index 1122a9a..dfd4a2c 100644 --- a/cmd/loadbalancer/server/rebalance.go +++ b/cmd/loadbalancer/server/rebalance.go @@ -3,51 +3,45 @@ package server import ( "log/slog" "lukas8219/websocket-operator/cmd/loadbalancer/connection" - "lukas8219/websocket-operator/internal/route" + rslv "lukas8219/websocket-operator/internal/resolver" "time" ) -func handleRebalanceLoop(router route.RouterImpl, connections map[string]*connection.Connection) { - slog.Debug("Starting rebalance loop") - for { - select { - case hosts := <-router.RebalanceRequests(): - slog.Debug("Received message to rebalance", "hosts", hosts) - upstreamHostsToConnectionTracker := make(map[string]*connection.Connection, len(connections)) - slog.Debug("Flat mapping ConnectionTracker to upstreamHosts", "connections", connections) - for _, connectionTracker := range connections { - upstreamHostsToConnectionTracker[connectionTracker.User()] = connectionTracker +func handleRebalanceLoop(resolver rslv.Resolver, connections map[string]*connection.Connection) { + slog.Info("Starting rebalance loop") + for _ = range resolver.VersionUpgradeChannel() { + hosts, err := resolver.CurrentHosts() + if err != nil { + slog.Error(err.Error()) + continue + } + slog.Debug("Received message to rebalance", "hosts", hosts) + upstreamHostsToConnectionTracker := make(map[string]*connection.Connection, len(connections)) + for user, connectionTracker := range connections { + newHost, err := resolver.Lookup([]byte(user)) + if err != nil { + panic(err) //TODO } - for _, affectedHost := range hosts { - recipientId := affectedHost[0] - newHost := affectedHost[1] - connectionTracker := upstreamHostsToConnectionTracker[recipientId] - if connectionTracker == nil { - slog.Debug("No connection tracker found", "user", recipientId) - continue - } - oldHost := connectionTracker.UpstreamHost() - if connectionTracker.UpstreamHost() == newHost { - connectionTracker.Debug("No need to rebalance") - continue - } - connectionTracker.Debug("Waiting for upstream to cancel", "oldHost", oldHost) - connectionTracker.SwitchUpstreamHost(newHost) - - select { - case <-connectionTracker.UpstreamCancelChan(): - connectionTracker.Debug("Successfully received cancellation signal") - case <-time.After(5 * time.Second): - connectionTracker.Error("Timeout waiting for upstream cancellation, proceeding anyway") - } + previousHost := connectionTracker.UpstreamHost() + if previousHost == newHost.SocketAddres() { + connectionTracker.Debug("No need to rebalance") + continue + } + connectionTracker.Debug("Waiting for upstream to cancel", "previous", previousHost) + connectionTracker.SwitchUpstreamHost(newHost.SocketAddres()) - connections[recipientId] = connectionTracker - //TODO: gut feeling here. either we move rebalance to the connection pkg or we re-design stuff - //connectionTracker.UpstreamContext, connectionTracker.CancelUpstream = context.WithCancel(context.Background()) - connectionTracker.Info("Rebalancing connection from", "old", oldHost, "new", newHost) - //TODO: stopping down -> up could cause issues if this is mid read/write - go connectionTracker.Handle() + select { + case <-connectionTracker.UpstreamCancelChan(): + connectionTracker.Debug("Successfully received cancellation signal") + case <-time.After(5 * time.Second): + connectionTracker.Error("Timeout waiting for upstream cancellation, proceeding anyway") } + //TODO: gut feeling here. either we move rebalance to the connection pkg or we re-design stuff + //connectionTracker.UpstreamContext, connectionTracker.CancelUpstream = context.WithCancel(context.Background()) + connectionTracker.Info("Rebalancing connection from", "previous", previousHost, "new", newHost) + //TODO: stopping down -> up could cause issues if this is mid read/write + go connectionTracker.Handle() + upstreamHostsToConnectionTracker[connectionTracker.User()] = connectionTracker } } } diff --git a/cmd/loadbalancer/server/rebalance_test.go b/cmd/loadbalancer/server/rebalance_test.go index d2501cc..3899fa7 100644 --- a/cmd/loadbalancer/server/rebalance_test.go +++ b/cmd/loadbalancer/server/rebalance_test.go @@ -5,8 +5,10 @@ import ( "context" "io" "log" - "log/slog" "lukas8219/websocket-operator/cmd/loadbalancer/connection" + "lukas8219/websocket-operator/internal/consistent_hashing" + "lukas8219/websocket-operator/internal/peer_discovery" + "lukas8219/websocket-operator/internal/resolver" "net" "sync" "testing" @@ -15,22 +17,6 @@ import ( "github.com/gobwas/ws" ) -type MockRouter struct { - rebalanceChan chan [][2]string - *slog.Logger -} - -func (m *MockRouter) RebalanceRequests() <-chan [][2]string { - return m.rebalanceChan -} - -func (m *MockRouter) Route(string) string { return "" } -func (m *MockRouter) Add([]string) {} -func (m *MockRouter) GetAllUpstreamHosts() []string { - return []string{} -} -func (m *MockRouter) InitializeHosts() error { return nil } - type NetConnectionMock struct { net.Conn remoteAddr net.Addr @@ -118,12 +104,15 @@ func NewMockConnection(user, upstreamHost string, downstreamConn net.Conn, wsDia } func TestHandleRebalanceLoop(t *testing.T) { - mockRouter := &MockRouter{ - rebalanceChan: make(chan [][2]string, 1), - } + memoryDiscoveryBackend := peer_discovery.NewInMemoryPeerDiscovery() + mockResolver := resolver.New( + memoryDiscoveryBackend, + consistent_hashing.NewJumpHash(memoryDiscoveryBackend), + ) + go mockResolver.Init() connections := make(map[string]*connection.Connection) - go handleRebalanceLoop(mockRouter, connections) + go handleRebalanceLoop(mockResolver, connections) t.Run("Sucessfully rebalanced", func(t *testing.T) { mockDownstreamConn := &NetConnectionMock{ @@ -139,7 +128,12 @@ func TestHandleRebalanceLoop(t *testing.T) { mockConn.Tracker.UpstreamCancelChan() <- 1 time.Sleep(100 * time.Millisecond) - mockRouter.rebalanceChan <- [][2]string{{mockConn.Tracker.User(), "new-host:3000"}} + t.Log("Before atomic") + memoryDiscoveryBackend.AtomicOperation( + []peer_discovery.Peer{peer_discovery.NewPeer("new-host", 3000)}, + []peer_discovery.Peer{peer_discovery.NewPeer("old-host", 3000)}, + ) + t.Log("After atomic") time.Sleep(100 * time.Millisecond) if mockConn.UpstreamHost() != "new-host:3000" { diff --git a/cmd/loadbalancer/server/server.go b/cmd/loadbalancer/server/server.go index aba44a3..060fbd6 100644 --- a/cmd/loadbalancer/server/server.go +++ b/cmd/loadbalancer/server/server.go @@ -3,20 +3,20 @@ package server import ( "log/slog" "lukas8219/websocket-operator/cmd/loadbalancer/connection" - "lukas8219/websocket-operator/internal/route" + "lukas8219/websocket-operator/internal/resolver" "net/http" ) type ServerConfig struct { - Router route.RouterImpl - Port string + Resolver resolver.Resolver + Port string } func StartServer(config ServerConfig) { slog.Info("Starting load balancer server", "port", config.Port) - router := config.Router connections := make(map[string]*connection.Connection) //TODO: This could be a broadcast instead of a single recipient/connection - go handleRebalanceLoop(router, connections) + + go handleRebalanceLoop(config.Resolver, connections) //TODO how to properly test this - aka not having a server running at all - http.ListenAndServe("0.0.0.0:"+config.Port, createHandler(router, connections)) + http.ListenAndServe("0.0.0.0:"+config.Port, createHandler(config.Resolver, connections)) } diff --git a/cmd/sidecar/main.go b/cmd/sidecar/main.go index 0e5acd8..b938696 100644 --- a/cmd/sidecar/main.go +++ b/cmd/sidecar/main.go @@ -6,9 +6,9 @@ import ( "flag" "io" "log/slog" + "lukas8219/websocket-operator/internal/consistent_hashing" "lukas8219/websocket-operator/internal/logger" "lukas8219/websocket-operator/internal/peer_discovery" - "lukas8219/websocket-operator/internal/rendezvous" "lukas8219/websocket-operator/internal/resolver" "lukas8219/websocket-operator/internal/transports" "net" @@ -61,8 +61,16 @@ func main() { logger.SetupLogger(*debug) //TODO move to config peerDiscovery := peer_discovery.NewKubernetes("default", "ws-headless-proxy") - resolver := resolver.New(peerDiscovery, rendezvous.NewDefault()) - transport := transports.NewHTTPTransport(*resolver) + err := peerDiscovery.Initialize() + if err != nil { + panic(err) //TODO Better handling + } + resolver := resolver.New(peerDiscovery, consistent_hashing.NewJumpHash(peerDiscovery)) + err = resolver.Initialize() + if err != nil { + panic(err) + } + transport := transports.NewHTTPTransport(resolver) slog.Info("Starting server", "port", *port) // Map to store active WebSocket connections diff --git a/go.mod b/go.mod index 0fa0adb..17b1550 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/buraksezer/consistent v0.10.0 github.com/gobwas/ws v1.4.0 github.com/hashicorp/go-set/v3 v3.0.1 + github.com/lithammer/go-jump-consistent-hash v1.0.2 github.com/zeebo/xxh3 v1.0.2 k8s.io/api v0.32.3 k8s.io/apimachinery v0.32.3 @@ -30,6 +31,7 @@ require ( github.com/google/go-cmp v0.6.0 // indirect github.com/google/gofuzz v1.2.0 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect diff --git a/go.sum b/go.sum index 09925db..66c1344 100644 --- a/go.sum +++ b/go.sum @@ -45,6 +45,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/go-set/v3 v3.0.1 h1:ZwO15ZYmIrFYL9zSm2wBuwcRiHxVdp46m/XA/MUlM6I= github.com/hashicorp/go-set/v3 v3.0.1/go.mod h1:0oPQqhtitglZeT2ZiWnRIfUG6gJAHnn7LzrS7SbgNY4= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -60,6 +62,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lithammer/go-jump-consistent-hash v1.0.2 h1:w74N9XiMa4dWZdoVnfLbnDhfpGOMCxlrudzt2e7wtyk= +github.com/lithammer/go-jump-consistent-hash v1.0.2/go.mod h1:4MD1WDikNGnb9D56hAtscaZaOWOiCG+lLbRR5ZN9JL0= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= diff --git a/internal/consistent_hashing/consistent_hashing.go b/internal/consistent_hashing/consistent_hashing.go new file mode 100644 index 0000000..9b0b120 --- /dev/null +++ b/internal/consistent_hashing/consistent_hashing.go @@ -0,0 +1,6 @@ +package consistent_hashing + +type ConsistentHashing[T any] interface { + Lookup([]byte) (T, error) + Transaction(Add, Remove []T) +} diff --git a/internal/consistent_hashing/jump_hash.go b/internal/consistent_hashing/jump_hash.go new file mode 100644 index 0000000..b274c65 --- /dev/null +++ b/internal/consistent_hashing/jump_hash.go @@ -0,0 +1,39 @@ +package consistent_hashing + +import ( + "fmt" + "lukas8219/websocket-operator/internal/peer_discovery" + + "github.com/lithammer/go-jump-consistent-hash" +) + +type JumpHashing struct { + peerDiscoveryBacked peer_discovery.PeerDiscovery +} + +func NewJumpHash(peerDiscoveryBackend peer_discovery.PeerDiscovery) JumpHashing { + return JumpHashing{ + peerDiscoveryBacked: peerDiscoveryBackend, + } +} + +func (j JumpHashing) Lookup(Recipient []byte) (peer_discovery.Peer, error) { + //TODO handle it + currentHosts, err := j.peerDiscoveryBacked.CurrentHosts() + if err != nil { + //We need to handle errrs + return peer_discovery.Peer{}, err + } + var myUint64 uint64 = 1 + index := jump.Hash( + myUint64, + int32(len(currentHosts)), + ) + if index == -1 { + return peer_discovery.Peer{}, fmt.Errorf("Didn't find any Peer to route") + } + return currentHosts[index], nil +} + +func (j JumpHashing) Transaction(Add, Remove []peer_discovery.Peer) { +} diff --git a/internal/consistent_hashing/rendezvous.go b/internal/consistent_hashing/rendezvous.go new file mode 100644 index 0000000..7d5d561 --- /dev/null +++ b/internal/consistent_hashing/rendezvous.go @@ -0,0 +1,185 @@ +package consistent_hashing + +import ( + "lukas8219/websocket-operator/internal/peer_discovery" + "math" + "sync" + + "github.com/buraksezer/consistent" + "github.com/zeebo/xxh3" +) + +var FiftyThreeOnes = uint64(0xFFFFFFFFFFFFFFFF >> (64 - 53)) +var FiftyThreeZeros = float64(1 << 53) + +var ErrInsufficientMemberCount = consistent.ErrInsufficientMemberCount + +type Hasher consistent.Hasher + +type DefaultHasher struct { +} + +func (h *DefaultHasher) Sum64(b []byte) uint64 { + return xxh3.Hash(b) +} + +type WeightedMember struct { + member string + weight float64 +} + +func (w *WeightedMember) GetMember() string { + return w.member +} + +// Config represents a structure to control the rendezvous package. +type Config struct { + Hasher Hasher +} + +// Rendezvous holds the information about the members of the consistent hash circle. +type Rendezvous struct { + mu sync.RWMutex + + config Config + hasher Hasher + members map[string]*WeightedMember + ring map[uint64]*WeightedMember +} + +func (r *Rendezvous) Transaction(NewNodes, RemoveNodes []peer_discovery.Peer) { + r.mu.Lock() + defer r.mu.Unlock() + for _, Node := range NewNodes { + r.Add(Node.String()) + } + for _, Node := range RemoveNodes { + r.Remove(Node.String()) + } +} + +// New creates and returns a new Rendezvous object +func New(members []WeightedMember, config Config) *Rendezvous { + r := &Rendezvous{ + config: config, + members: make(map[string]*WeightedMember), + ring: make(map[uint64]*WeightedMember), + } + + if config.Hasher == nil { + // Use the Default Hasher + r.hasher = &DefaultHasher{} + } else { + r.hasher = config.Hasher + } + for _, member := range members { + r.add(member) + } + return r +} + +func NewDefault() *Rendezvous { + return New([]WeightedMember{}, Config{}) +} + +// IntToFloat is a golang port of the python implementation mentioned here +// https://en.wikipedia.org/wiki/Rendezvous_hashing#Weighted_rendezvous_hash +func IntToFloat(value uint64) (float_value float64) { + return float64((value & FiftyThreeOnes)) / FiftyThreeZeros +} + +func (r *Rendezvous) ComputeWeightedScore(m WeightedMember, key []byte) (score float64) { + hash := r.hasher.Sum64(append([]byte(m.member), key...)) + score = 1.0 / math.Log(IntToFloat(hash)) + return m.weight * score +} + +func (r *Rendezvous) LocateKey(key []byte) (member WeightedMember) { + r.mu.RLock() + defer r.mu.RUnlock() + lowest_score := 1.0 + for _, _member := range r.members { + score := r.ComputeWeightedScore(*_member, key) + if score < lowest_score { + lowest_score = score + member = *_member + } + } + return member +} + +func (r *Rendezvous) Lookup(node string) string { + foundNode := r.LocateKey([]byte(node)) + return foundNode.member +} + +type byScore []struct { + string + float64 +} + +func (scores byScore) Len() int { + return len(scores) +} + +func (scores byScore) Swap(i, j int) { + scores[i], scores[j] = scores[j], scores[i] +} + +func (scores byScore) Less(i, j int) bool { + return scores[i].float64 < scores[j].float64 +} + +func (r *Rendezvous) add(member WeightedMember) { + r.members[member.member] = &member +} + +func (r *Rendezvous) AddMember(member WeightedMember) { + r.mu.Lock() + defer r.mu.Unlock() + if _, ok := r.members[member.member]; ok { + // We already have this member. Quit immediately. + return + } + r.add(member) +} + +func (r *Rendezvous) Add(node string) { + r.AddMember(WeightedMember{ + member: node, + weight: 1.0, + }) +} + +func (r *Rendezvous) Remove(name string) { + r.mu.Lock() + defer r.mu.Unlock() + + if _, ok := r.members[name]; !ok { + // There is no member with that name. Quit immediately. + return + } + + delete(r.members, name) +} + +// GetMembers returns a thread-safe copy of members. +func (r *Rendezvous) GetNodes() (members []WeightedMember) { + r.mu.RLock() + defer r.mu.RUnlock() + + // Create a thread-safe copy of member list. + members = make([]WeightedMember, 0, len(r.members)) + for _, member := range r.members { + members = append(members, *member) + } + return +} + +func (r *Rendezvous) GetAllHosts() []string { + hosts := make([]string, 0) + for _, member := range r.members { + hosts = append(hosts, member.member) + } + return hosts +} diff --git a/internal/dns/dns.go b/internal/dns/dns.go deleted file mode 100644 index a424904..0000000 --- a/internal/dns/dns.go +++ /dev/null @@ -1,107 +0,0 @@ -package dns - -import ( - "context" - "log/slog" - "lukas8219/websocket-operator/internal/rendezvous" - "net" - "os" - "strconv" - "time" -) - -type DnsRouter struct { - loadbalancer *rendezvous.Rendezvous -} - -func (r *DnsRouter) Info(msg string, args ...any) { - slog.With("component", "router").With("mode", "dns").Info(msg, args...) -} - -func (r *DnsRouter) Debug(msg string, args ...any) { - slog.With("component", "router").With("mode", "dns").Debug(msg, args...) -} - -func (r *DnsRouter) Error(msg string, args ...any) { - slog.With("component", "router").With("mode", "dns").Error(msg, args...) -} - -func WithDns(loadbalancer *rendezvous.Rendezvous) *DnsRouter { - return &DnsRouter{ - loadbalancer, - } -} - -func (r *DnsRouter) InitializeHosts() error { - srvRecord := os.Getenv("WS_OPERATOR_SRV_DNS_RECORD") - if srvRecord == "" { - srvRecord = "ws-operator.local" - } - hosts, err := r.getCurrentHosts(srvRecord) - if err != nil { - return err - } - for _, host := range hosts { - r.loadbalancer.Add(host) - } - return nil -} - -func createResolver() *net.Resolver { - if os.Getenv("KUBERNETES_SERVICE_HOST") != "" { - return &net.Resolver{} - } - // Create a custom resolver that first tries localhost:53 (for testing) - // and falls back to the system resolver if that fails - r := &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - // First try localhost:53 - d := net.Dialer{} - ctx, cancel := context.WithTimeout(ctx, time.Second*5) - defer cancel() - slog.Debug("Looking for address on localhost:53", "address", address) - conn, err := d.DialContext(ctx, "udp", "0.0.0.0:53") - if err != nil { - slog.Debug("Failed to connect to localhost:53, falling back to system resolver", "error", err) - return d.DialContext(ctx, network, address) - } - return conn, nil - }, - } - - return r -} - -func (r *DnsRouter) getCurrentHosts(service string) ([]string, error) { - resolver := createResolver() - slog.Debug("Getting random SRV host for service", "service", service) - _, addrs, err := resolver.LookupSRV(context.Background(), "", "", service) - if err != nil { - return nil, err - } - - if len(addrs) == 0 { - return []string{}, nil - } - - // Create a slice of tuples [addr,port] from the SRV records - addrPorts := make([]string, len(addrs)) - for i, srv := range addrs { - addr, err := resolver.LookupIP(context.Background(), "ip", srv.Target) - if err != nil { - return nil, err - } - addrPorts[i] = net.JoinHostPort(addr[0].String(), strconv.Itoa(int(srv.Port))) - } - slog.Debug("Found addresses", "addrs", addrPorts) - return addrPorts, nil -} - -func (r *DnsRouter) Route(recipientId string) string { - return r.loadbalancer.Lookup(recipientId) -} - -func (r *DnsRouter) RebalanceRequests() <-chan [][2]string { - return nil -} diff --git a/internal/kubernetes/kubernetes.go b/internal/kubernetes/kubernetes.go deleted file mode 100644 index 76a722f..0000000 --- a/internal/kubernetes/kubernetes.go +++ /dev/null @@ -1,177 +0,0 @@ -package kubernetes - -import ( - "fmt" - "log/slog" - "os" - "path/filepath" - - "lukas8219/websocket-operator/internal/rendezvous" - - v1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/fields" - "k8s.io/client-go/kubernetes" - "k8s.io/client-go/rest" - "k8s.io/client-go/tools/cache" - clientcmd "k8s.io/client-go/tools/clientcmd" -) - -type KubernetesRouter struct { - k8sClient *kubernetes.Clientset - cacheStore cache.Store - loadbalancer *rendezvous.Rendezvous - alreadyCalculatedRecipients map[string]string - handleUpdatedEndpoints func([]string) - handleCreatedEnpoints func([]string) - handleDeletedEnpoints func([]string) - rebalanceRequest chan [][2]string -} - -func (k *KubernetesRouter) Info(msg string, args ...any) { - slog.With("component", "router").With("mode", "kubernetes").Info(msg, args...) -} - -func (k *KubernetesRouter) Debug(msg string, args ...any) { - slog.With("component", "router").With("mode", "kubernetes").Debug(msg, args...) -} - -func (k *KubernetesRouter) Error(msg string, args ...any) { - slog.With("component", "router").With("mode", "kubernetes").Error(msg, args...) -} - -func (k *KubernetesRouter) GetAllUpstreamHosts() []string { - return k.loadbalancer.GetAllHosts() -} - -var ( - addedHosts = make(map[string]bool) -) - -func NewRouter(loadbalancer *rendezvous.Rendezvous) *KubernetesRouter { - client := createClient() - return &KubernetesRouter{ - k8sClient: client, - alreadyCalculatedRecipients: make(map[string]string), - loadbalancer: loadbalancer, - rebalanceRequest: make(chan [][2]string, 1), - } -} - -func createClient() *kubernetes.Clientset { - config, err := rest.InClusterConfig() - if err != nil { - kubeconfig := filepath.Join( - os.Getenv("HOME"), ".kube", "config", - ) - config, err = clientcmd.BuildConfigFromFlags("", kubeconfig) - slog.Info("Failed to get in-cluster config, using empty config") - } - - return kubernetes.NewForConfigOrDie(config) -} - -func (k *KubernetesRouter) Route(recipientId string) string { - k.Debug("Lookup", "recipientId", recipientId, "nodes", k.loadbalancer.GetNodes()) - host := k.loadbalancer.Lookup(recipientId) - if host == "" { - k.Debug("No host found", "recipientId", recipientId, "nodes", k.loadbalancer.GetNodes()) - return "" - } - k.alreadyCalculatedRecipients[recipientId] = host - k.Debug("Host found", "recipientId", recipientId, "host", host) - host = fmt.Sprintf("%s:3000", host) - return host -} - -func (k *KubernetesRouter) Add(host []string) { - return -} - -func (k *KubernetesRouter) InitializeHosts() error { - watchList := cache.NewListWatchFromClient(k.k8sClient.CoreV1().RESTClient(), "endpoints", "default", - fields.OneTermEqualSelector("metadata.name", "ws-proxy-headless"), - ) - - store, controller := cache.NewInformerWithOptions(cache.InformerOptions{ - ListerWatcher: watchList, - ObjectType: &v1.Endpoints{}, - Handler: cache.ResourceEventHandlerFuncs{ - AddFunc: func(obj interface{}) { - hosts := make([]string, 0) - for _, address := range obj.(*v1.Endpoints).Subsets { - for _, address := range address.Addresses { - if address.IP != "" { - hosts = append(hosts, address.IP) - } - } - } - for _, host := range hosts { - k.loadbalancer.Add(host) - addedHosts[host] = true - } - k.Info("Added addresses", "hosts", hosts) - }, - UpdateFunc: func(oldObj, newObj interface{}) { - hosts := make([]string, 0) - //This is nuts, yes. But i'll look into re-writing the Rendezvous to be customized for this use case - for _, subset := range oldObj.(*v1.Endpoints).Subsets { - for _, address := range subset.Addresses { - k.loadbalancer.Remove(address.IP) - delete(addedHosts, address.IP) - } - } - for _, subset := range newObj.(*v1.Endpoints).Subsets { - for _, address := range subset.Addresses { - hosts = append(hosts, address.IP) - k.loadbalancer.Add(address.IP) - } - } - k.Info("Updated addresses", "hosts", hosts) - rebalanceHosts := make([][2]string, 0) - //re-calculate computed recipients to check re-balancing - for recipientId, host := range k.alreadyCalculatedRecipients { - newlyCalculatedHost := k.loadbalancer.Lookup(recipientId) - k.Debug("Checking rebalance", "recipientId", recipientId, "oldHost", host, "newHost", newlyCalculatedHost) - if newlyCalculatedHost == host { - continue - } - newlyCalculatedHostWithPort := fmt.Sprintf("%s:3000", newlyCalculatedHost) - rebalanceHosts = append(rebalanceHosts, [2]string{recipientId, newlyCalculatedHostWithPort}) - } - if len(rebalanceHosts) > 0 { - k.Info("Rebalancing hosts", "hosts", rebalanceHosts) - k.triggerRebalance(rebalanceHosts) - } else { - k.Debug("No rebalancing hosts found") - } - }, - DeleteFunc: func(obj interface{}) { - hosts := make([]string, 1) - for _, address := range obj.(*v1.Endpoints).Subsets[0].Addresses { - hosts = append(hosts, address.IP) - } - for _, host := range hosts { - k.loadbalancer.Remove(host) - } - k.Info("Deleted addresses", "hosts", hosts) - }, - }, - }) - stop := make(chan struct{}) - go controller.Run(stop) - if !cache.WaitForCacheSync(stop, controller.HasSynced) { - k.Error("Timed out waiting for caches to sync") - return fmt.Errorf("timed out waiting for caches to sync") - } - k.cacheStore = store - return nil -} - -func (k *KubernetesRouter) triggerRebalance(hosts [][2]string) { - k.Debug("Sending rebalance request", "hosts", hosts) - k.rebalanceRequest <- hosts -} - -func (k *KubernetesRouter) RebalanceRequests() <-chan [][2]string { - return k.rebalanceRequest -} diff --git a/internal/peer_discovery/in-memory.go b/internal/peer_discovery/in-memory.go new file mode 100644 index 0000000..2e078c7 --- /dev/null +++ b/internal/peer_discovery/in-memory.go @@ -0,0 +1,58 @@ +package peer_discovery + +import ( + "lukas8219/websocket-operator/internal/diff" + + "github.com/hashicorp/go-set/v3" +) + +type InMemoryPeerDiscovery struct { + availablePeers *set.Set[Peer] + channel chan diff.DifferenceOutput[Peer] +} + +func NewInMemoryPeerDiscovery() InMemoryPeerDiscovery { + return InMemoryPeerDiscovery{ + availablePeers: set.New[Peer](1), + channel: make(chan diff.DifferenceOutput[Peer], 0), + } +} + +func (m InMemoryPeerDiscovery) Initialize() error { + return nil +} + +func (m InMemoryPeerDiscovery) CurrentHosts() ([]Peer, error) { + return m.availablePeers.Slice(), nil +} + +func (m InMemoryPeerDiscovery) Mode() PeerDiscoveryMode { + return PeerDiscoveryModeInMemory +} + +func (m InMemoryPeerDiscovery) NotificationChannel() chan diff.DifferenceOutput[Peer] { + return m.channel +} + +func (m InMemoryPeerDiscovery) AddPeer(peer Peer) { + if !m.availablePeers.Insert(peer) { + return + } + m.channel <- diff.DifferenceOutput[Peer]{ + Added: []Peer{peer}, + } +} + +func (m InMemoryPeerDiscovery) RemovePeer(peer Peer) { + if !m.availablePeers.Remove(peer) { + return + } + m.channel <- diff.DifferenceOutput[Peer]{ + Removed: []Peer{peer}, + } +} + +func (m InMemoryPeerDiscovery) AtomicOperation(NewPeers, RemovePeers []Peer) { + diff := diff.Difference(m.availablePeers, NewPeers, RemovePeers) + m.channel <- diff +} diff --git a/internal/peer_discovery/kubernetes.go b/internal/peer_discovery/kubernetes.go index 69f6508..080250a 100644 --- a/internal/peer_discovery/kubernetes.go +++ b/internal/peer_discovery/kubernetes.go @@ -30,6 +30,7 @@ func NewKubernetes(namespace string, service string) *KubernetesPeerDiscovery { return &KubernetesPeerDiscovery{ k8sNamespace: namespace, targetK8sServiceName: service, + k8sClient: createClient(), } } diff --git a/internal/peer_discovery/peer_discovery.go b/internal/peer_discovery/peer_discovery.go index 153b681..0a79d4f 100644 --- a/internal/peer_discovery/peer_discovery.go +++ b/internal/peer_discovery/peer_discovery.go @@ -47,4 +47,5 @@ type PeerDiscoveryConfig struct { const ( PeerDiscoveryModeDns PeerDiscoveryMode = "dns" PeerDiscoveryModeKubernetes PeerDiscoveryMode = "kubernetes" + PeerDiscoveryModeInMemory PeerDiscoveryMode = "in-memory" ) diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go index a730f0e..058d18b 100644 --- a/internal/resolver/resolver.go +++ b/internal/resolver/resolver.go @@ -1,36 +1,81 @@ package resolver import ( + "fmt" + "lukas8219/websocket-operator/internal/consistent_hashing" peerDiscovery "lukas8219/websocket-operator/internal/peer_discovery" - "lukas8219/websocket-operator/internal/rendezvous" + "sync/atomic" + + "k8s.io/utils/lru" ) -type Resolver struct { - hashingAlgorithm *rendezvous.Rendezvous +type ResolverVersion = uint32 + +type Resolver interface { + peerDiscovery.PeerDiscovery + Init() + VersionUpgradeChannel() chan ResolverVersion + Lookup([]byte) (peerDiscovery.Peer, error) +} + +type ResolverImpl struct { + consistentHashingAlgorithm consistent_hashing.ConsistentHashing[peerDiscovery.Peer] peerDiscovery.PeerDiscovery + version atomic.Uint32 + cache *lru.Cache + versionUpgradeChannel chan ResolverVersion } func New( peerDiscovery peerDiscovery.PeerDiscovery, - hashingAlgorithm *rendezvous.Rendezvous, -) *Resolver { - return nil + consistentHashing consistent_hashing.ConsistentHashing[peerDiscovery.Peer], +) Resolver { + return &ResolverImpl{ + consistentHashingAlgorithm: consistentHashing, + PeerDiscovery: peerDiscovery, + cache: lru.New(1024), + version: atomic.Uint32{}, + versionUpgradeChannel: make(chan ResolverVersion), + } } -func (r *Resolver) Init() { +func (r *ResolverImpl) VersionUpgradeChannel() chan ResolverVersion { + return r.versionUpgradeChannel +} + +func (r *ResolverImpl) Init() { for event := range r.NotificationChannel() { - r.hashingAlgorithm.Transaction( + r.consistentHashingAlgorithm.Transaction( event.Added, event.Removed, ) + new := r.version.Add(1) + r.versionUpgradeChannel <- new } } -func (r *Resolver) Lookup(Recipient []byte) (peerDiscovery.Peer, error) { +func (r *ResolverImpl) Lookup(Recipient []byte) (peerDiscovery.Peer, error) { _, error := r.CurrentHosts() if error != nil { return peerDiscovery.Peer{}, error } - member := r.hashingAlgorithm.LocateKey(Recipient) - return peerDiscovery.NewPeer(member.GetMember(), 3000), nil + // TODO: investigate how to implemeny :relaxed memory access to prevent mem ordering here + version := r.version.Load() + cachedEntry, found := r.cache.Get(createCacheKey(version, Recipient)) + if found { + return cachedEntry.(peerDiscovery.Peer), nil + } + peer, err := r.consistentHashingAlgorithm.Lookup(Recipient) + if err != nil { + return peerDiscovery.Peer{}, err + } + r.cache.Add( + createCacheKey(version, Recipient), + peer, + ) + return peer, nil +} + +func createCacheKey(version uint32, Recipient []byte) string { + return fmt.Sprintf("%d:%s", version, Recipient) } diff --git a/internal/route/router.go b/internal/route/router.go deleted file mode 100644 index 3fd09bb..0000000 --- a/internal/route/router.go +++ /dev/null @@ -1,47 +0,0 @@ -package route - -import ( - "log/slog" - "lukas8219/websocket-operator/internal/dns" - "lukas8219/websocket-operator/internal/kubernetes" - - "lukas8219/websocket-operator/internal/rendezvous" -) - -type Logger interface { - Info(msg string, args ...any) - Debug(msg string, args ...any) - Error(msg string, args ...any) -} - -type RouterImpl interface { - InitializeHosts() error - Route(recipientId string) string - RebalanceRequests() <-chan [][2]string - Logger -} - -type RouterConfigMode string - -type RouterConfig struct { - Mode RouterConfigMode - ConfigMeta interface{} -} - -const ( - RouterConfigModeDns RouterConfigMode = "dns" - RouterConfigModeKubernetes RouterConfigMode = "kubernetes" -) - -func NewRouter(config RouterConfig) RouterImpl { - slog.With("component", "router").With("mode", config.Mode).Info("New router") - rendezvous := rendezvous.NewDefault() - switch config.Mode { - case RouterConfigModeDns: - return dns.WithDns(rendezvous) - case RouterConfigModeKubernetes: - return kubernetes.NewRouter(rendezvous) - default: - panic("invalid router mode") - } -}