diff --git a/cmd/sidecar/main.go b/cmd/sidecar/main.go index 9b34a0a..0e5acd8 100644 --- a/cmd/sidecar/main.go +++ b/cmd/sidecar/main.go @@ -6,8 +6,11 @@ import ( "flag" "io" "log/slog" - "lukas8219/websocket-operator/cmd/sidecar/proxy" "lukas8219/websocket-operator/internal/logger" + "lukas8219/websocket-operator/internal/peer_discovery" + "lukas8219/websocket-operator/internal/rendezvous" + "lukas8219/websocket-operator/internal/resolver" + "lukas8219/websocket-operator/internal/transports" "net" "net/http" "os" @@ -23,6 +26,7 @@ type ConnectionTracker struct { downstreamHost string upstreamConn net.Conn downstreamConn net.Conn + transports.Transport } func (c *ConnectionTracker) Info(message string, args ...any) *ConnectionTracker { @@ -51,11 +55,15 @@ var incomingMessageStruct = reflect.StructOf([]reflect.StructField{ func main() { port := flag.String("port", "3000", "Port to listen on") targetPort := flag.String("targetPort", "3001", "Port to target") - mode := flag.String("mode", "kubernetes", "Mode to use") + // mode := flag.String("mode", "kubernetes", "Mode to use") debug := flag.Bool("debug", false, "Debug mode") flag.Parse() logger.SetupLogger(*debug) - proxy.InitializeProxy(*mode) + //TODO move to config + peerDiscovery := peer_discovery.NewKubernetes("default", "ws-headless-proxy") + resolver := resolver.New(peerDiscovery, rendezvous.NewDefault()) + transport := transports.NewHTTPTransport(*resolver) + slog.Info("Starting server", "port", *port) // Map to store active WebSocket connections // Key: user ID, Value: ConnectionTracker @@ -115,6 +123,7 @@ func main() { downstreamHost: r.RemoteAddr, upstreamConn: proxiedConn, downstreamConn: clientConn, + Transport: &transport, } connections[user] = connectionTracker if err != nil { @@ -143,7 +152,6 @@ func proxySidecarServerToClient(deferClose func(), connectionTracker *Connection connectionTracker.Error("Failed to read from server", "error", err) return } - //TODO: we might need to handle `recipientId` routing messages here also //Write as client - to the proxied connection @@ -190,7 +198,7 @@ func handleIncomingMessagesToProxy(connections map[string]*ConnectionTracker, de slog.Debug("Message recipient", "recipientId", recipientIdString, "recipientConnection", recipientConnection) if recipientConnection == nil { slog.Debug("No recipient found in-memory. Routing message to the correct target.", "recipientId", recipientIdString) - err := proxy.SendProxiedMessage(recipientIdString, msg, op) + err := connectionTracker.Write(rawBytes, op, msg) if err != nil { connectionTracker.Error("Failed to route message", "error", err) } diff --git a/cmd/sidecar/proxy/proxy.go b/cmd/sidecar/proxy/proxy.go deleted file mode 100644 index 029e458..0000000 --- a/cmd/sidecar/proxy/proxy.go +++ /dev/null @@ -1,58 +0,0 @@ -package proxy - -import ( - "bytes" - "context" - "errors" - "log/slog" - "lukas8219/websocket-operator/internal/route" - "net/http" - "time" - - "github.com/gobwas/ws" -) - -var ( - router route.RouterImpl -) - -func InitializeProxy(mode string) { - router = route.NewRouter(route.RouterConfig{Mode: route.RouterConfigMode(mode)}) - err := router.InitializeHosts() - if err != nil { - slog.Error("failed to initialize hosts", "error", err) - //TODO should we panic here? - } -} - -func SendProxiedMessage(recipientId string, message []byte, opCode ws.OpCode) error { - host := router.Route(recipientId) - slog := slog.With("recipientId", recipientId).With("opCode", opCode).With("host", host).With("component", "proxy") - if host == "" { - slog.Error("no host found") - return errors.New("no host found") - } - slog.Debug("Routing message") - //TODO hardcoded 5 seconds to debug DNS resolve issues - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - messageWithOpCode := append([]byte{byte(opCode)}, message...) - url := "http://" + host + "/message" - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(messageWithOpCode)) - if err != nil { - slog.Error("failed to create request", "error", err) - return errors.Join(errors.New("failed to create request"), err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("ws-user-id", recipientId) - - slog.Debug("POST request", "url", url) - resp, err := http.DefaultClient.Do(req) - if err != nil { - slog.Error("Error sending request", "error", err) - return err - } - slog.Debug("Received response", "status", resp.Status) - defer resp.Body.Close() - return nil -} diff --git a/go.mod b/go.mod index ec41bc0..0fa0adb 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ toolchain go1.23.7 require ( github.com/buraksezer/consistent v0.10.0 github.com/gobwas/ws v1.4.0 + github.com/hashicorp/go-set/v3 v3.0.1 github.com/zeebo/xxh3 v1.0.2 k8s.io/api v0.32.3 k8s.io/apimachinery v0.32.3 diff --git a/go.sum b/go.sum index 50ee823..09925db 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,8 @@ github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db h1:097atOisP2aRj7vFgY github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/go-set/v3 v3.0.1 h1:ZwO15ZYmIrFYL9zSm2wBuwcRiHxVdp46m/XA/MUlM6I= +github.com/hashicorp/go-set/v3 v3.0.1/go.mod h1:0oPQqhtitglZeT2ZiWnRIfUG6gJAHnn7LzrS7SbgNY4= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -78,6 +80,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/shoenig/test v1.12.1 h1:mLHfnMv7gmhhP44WrvT+nKSxKkPDiNkIuHGdIGI9RLU= +github.com/shoenig/test v1.12.1/go.mod h1:UxJ6u/x2v/TNs/LoLxBNJRV9DiwBBKYxXSyczsBHFoI= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/internal/diff/diff.go b/internal/diff/diff.go new file mode 100644 index 0000000..8e1d110 --- /dev/null +++ b/internal/diff/diff.go @@ -0,0 +1,22 @@ +package diff + +import "github.com/hashicorp/go-set/v3" + +type DifferenceOutput[T comparable] struct { + Added []T + Removed []T +} + +// before[1,2,3] + [4] -> currentState[1,2,3,4] = Before|currentState [4] +// before[1,2,3] - [3] -> currentState[1,2,4] = currentState|before [3] +func Difference[T comparable](currentState *set.Set[T], NewEntries []T, ToRemoveEntries []T) DifferenceOutput[T] { + beforeUpdate := currentState.Copy() + currentState.InsertSlice(NewEntries) + currentState.RemoveSlice(ToRemoveEntries) + added := currentState.Difference(beforeUpdate) + removed := beforeUpdate.Difference(currentState) + return DifferenceOutput[T]{ + Added: added.Slice(), + Removed: removed.Slice(), + } +} diff --git a/internal/diff/diff_test.go b/internal/diff/diff_test.go new file mode 100644 index 0000000..b319bb9 --- /dev/null +++ b/internal/diff/diff_test.go @@ -0,0 +1,25 @@ +package diff_test + +import ( + "lukas8219/websocket-operator/internal/diff" + "testing" + + "github.com/hashicorp/go-set/v3" +) + +func TestDifferenceOuputOnAtomicUpsert(t *testing.T) { + initialState := set.New[int](10) + initialState.InsertSlice([]int{1, 2, 3}) + output := diff.Difference(initialState, []int{4, 3, 2, 5}, []int{3, 1}) + if !set.From(output.Added).EqualSlice([]int{4, 5}) { + t.Error("Output expected to have added int 4,5") + } + + if !set.From(output.Removed).EqualSlice([]int{3, 1}) { + t.Error("Output expected to have remove 3,1") + } + + if !initialState.EqualSlice([]int{2, 4, 5}) { + t.Error("Final state expected is 2,4,5", initialState.String()) + } +} diff --git a/internal/peer_discovery/dns.go b/internal/peer_discovery/dns.go new file mode 100644 index 0000000..42f0649 --- /dev/null +++ b/internal/peer_discovery/dns.go @@ -0,0 +1,89 @@ +package peer_discovery + +import ( + "context" + "log/slog" + "net" + "os" + "time" +) + +type DnsPeerDiscovery struct { + srvRecord string + notificationChannel chan []Peer + PeerDiscovery +} + +func (r *DnsPeerDiscovery) Initialize() error { + return nil +} + +func NewDNS(srvRecord string) *DnsPeerDiscovery { + return &DnsPeerDiscovery{ + srvRecord: srvRecord, + } +} + +func (r *DnsPeerDiscovery) NotificationChannel() chan []Peer { + return r.notificationChannel +} + +func (r *DnsPeerDiscovery) CurrentHosts() ([]Peer, error) { + resolver := createResolver() + slog.Debug("Getting random SRV host for service", "service", r.srvRecord) + _, addrs, err := resolver.LookupSRV(context.Background(), "", "", r.srvRecord) + if err != nil { + return nil, err + } + + if len(addrs) == 0 { + return []Peer{}, nil + } + + peers := make([]Peer, len(addrs)) + for i, srv := range addrs { + addr, err := resolver.LookupIP(context.Background(), "ip", srv.Target) + if err != nil { + slog.Warn("Failed to resolve ip - skipping", "ip", srv.Target) + continue + } + hostname := addr[0].String() + port := srv.Port + peers[i] = Peer{ + hostname, + port, + } + } + return peers, nil +} + +func (d *DnsPeerDiscovery) Mode() PeerDiscoveryMode { + return PeerDiscoveryModeDns +} + +// TODO review +func createResolver() *net.Resolver { + if os.Getenv("KUBERNETES_SERVICE_HOST") != "" { + return &net.Resolver{} + } + // Create a custom resolver that first tries localhost:53 (for testing) + // and falls back to the system resolver if that fails + r := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + // First try localhost:53 + d := net.Dialer{} + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + slog.Debug("Looking for address on localhost:53", "address", address) + conn, err := d.DialContext(ctx, "udp", "0.0.0.0:53") + if err != nil { + slog.Debug("Failed to connect to localhost:53, falling back to system resolver", "error", err) + return d.DialContext(ctx, network, address) + } + return conn, nil + }, + } + + return r +} diff --git a/internal/peer_discovery/kubernetes.go b/internal/peer_discovery/kubernetes.go new file mode 100644 index 0000000..69f6508 --- /dev/null +++ b/internal/peer_discovery/kubernetes.go @@ -0,0 +1,133 @@ +package peer_discovery + +import ( + "fmt" + "log/slog" + "lukas8219/websocket-operator/internal/diff" + "os" + "path/filepath" + + goset "github.com/hashicorp/go-set/v3" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/fields" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/cache" + clientcmd "k8s.io/client-go/tools/clientcmd" +) + +type KubernetesPeerDiscovery struct { + k8sClient *kubernetes.Clientset + cacheStore cache.Store + currentHosts *goset.Set[string] + targetK8sServiceName string + k8sNamespace string + notificationChannel chan diff.DifferenceOutput[Peer] + PeerDiscovery +} + +func NewKubernetes(namespace string, service string) *KubernetesPeerDiscovery { + return &KubernetesPeerDiscovery{ + k8sNamespace: namespace, + targetK8sServiceName: service, + } +} + +func (k *KubernetesPeerDiscovery) NotificationChannel() chan diff.DifferenceOutput[Peer] { + return k.notificationChannel +} + +var ( + EMPTY_ARRAY = make([]string, 0) +) + +const ( + MAX_KUBERNETES_ENDPOINT_SLICE_SIZE = 1000 +) + +// Remove OR MOVE +func createClient() *kubernetes.Clientset { + config, err := rest.InClusterConfig() + if err != nil { + kubeconfig := filepath.Join( + os.Getenv("HOME"), ".kube", "config", + ) + config, err = clientcmd.BuildConfigFromFlags("", kubeconfig) + slog.Info("Failed to get in-cluster config, using empty config") + } + + return kubernetes.NewForConfigOrDie(config) +} + +func (k *KubernetesPeerDiscovery) Initialize() error { + watchList := cache.NewListWatchFromClient(k.k8sClient.CoreV1().RESTClient(), "endpoints", k.k8sNamespace, + fields.OneTermEqualSelector("metadata.name", k.targetK8sServiceName), + ) + + store, controller := cache.NewInformerWithOptions(cache.InformerOptions{ + ListerWatcher: watchList, + ObjectType: &v1.Endpoints{}, + Handler: cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj interface{}) { + hosts := getAllAddressesFromEndpoint(obj.(*v1.Endpoints)) + k.updateHostsArray(hosts, EMPTY_ARRAY) + }, + UpdateFunc: func(oldObj, newObj interface{}) { + hosts := getAllAddressesFromEndpoint(newObj.(*v1.Endpoints)) + k.updateHostsArray(hosts, EMPTY_ARRAY) + }, + DeleteFunc: func(obj interface{}) { + hosts := getAllAddressesFromEndpoint(obj.(*v1.Endpoints)) + k.updateHostsArray(EMPTY_ARRAY, hosts) + }, + }, + }) + stop := make(chan struct{}) + go controller.Run(stop) + if !cache.WaitForCacheSync(stop, controller.HasSynced) { + slog.Error("Timed out waiting for caches to sync") + return fmt.Errorf("timed out waiting for caches to sync") + } + k.cacheStore = store + return nil +} + +func getAllAddressesFromEndpoint(endpoint *v1.Endpoints) []string { + hosts := make([]string, 256) + for _, address := range endpoint.Subsets { + for _, address := range address.Addresses { + if address.IP != "" { + hosts = append(hosts, address.IP) + } + } + } + return hosts +} + +func (k *KubernetesPeerDiscovery) updateHostsArray(NewHosts []string, ToRemoveHosts []string) { + difference := diff.Difference[string](k.currentHosts, NewHosts, ToRemoveHosts) + k.notificationChannel <- diff.DifferenceOutput[Peer]{ + Added: mapToPeers(difference.Added), + Removed: mapToPeers(difference.Removed), + } +} + +func mapToPeers(hosts []string) []Peer { + mappedHosts := make([]Peer, len(hosts)) + for _, address := range hosts { + mappedHosts = append(mappedHosts, Peer{ + hostname: address, + port: 3000, + }) + } + return mappedHosts +} + +func (k *KubernetesPeerDiscovery) GetCurrentHosts() ([]Peer, error) { + mappedHosts := mapToPeers(k.currentHosts.Slice()) + return mappedHosts, nil +} + +func (k *KubernetesPeerDiscovery) Mode() PeerDiscoveryMode { + return PeerDiscoveryModeKubernetes +} diff --git a/internal/peer_discovery/peer_discovery.go b/internal/peer_discovery/peer_discovery.go new file mode 100644 index 0000000..153b681 --- /dev/null +++ b/internal/peer_discovery/peer_discovery.go @@ -0,0 +1,50 @@ +package peer_discovery + +import ( + "fmt" + "lukas8219/websocket-operator/internal/diff" +) + +type Peer struct { + hostname string + port uint16 +} + +func (p Peer) Hostname() string { + return p.hostname +} + +func (p *Peer) String() string { + return p.hostname +} + +func (p Peer) Port() uint16 { + return p.port +} + +func (p Peer) SocketAddres() string { + return fmt.Sprintf("%s:%d", p.hostname, p.port) +} + +type PeerDiscovery interface { + Initialize() error + CurrentHosts() ([]Peer, error) + NotificationChannel() chan diff.DifferenceOutput[Peer] + Mode() PeerDiscoveryMode +} + +func NewPeer(hostname string, port uint16) Peer { + return Peer{hostname, port} +} + +type PeerDiscoveryMode string + +type PeerDiscoveryConfig struct { + Mode PeerDiscoveryMode + ConfigMeta interface{} +} + +const ( + PeerDiscoveryModeDns PeerDiscoveryMode = "dns" + PeerDiscoveryModeKubernetes PeerDiscoveryMode = "kubernetes" +) diff --git a/internal/rendezvous/rendezvous.go b/internal/rendezvous/rendezvous.go index ea4dcae..cf1dae7 100644 --- a/internal/rendezvous/rendezvous.go +++ b/internal/rendezvous/rendezvous.go @@ -1,6 +1,7 @@ package rendezvous import ( + "lukas8219/websocket-operator/internal/peer_discovery" "math" "sync" @@ -27,6 +28,10 @@ type WeightedMember struct { weight float64 } +func (w *WeightedMember) GetMember() string { + return w.member +} + // Config represents a structure to control the rendezvous package. type Config struct { Hasher Hasher @@ -42,6 +47,17 @@ type Rendezvous struct { ring map[uint64]*WeightedMember } +func (r *Rendezvous) Transaction(NewNodes, RemoveNodes []peer_discovery.Peer) { + r.mu.Lock() + defer r.mu.Unlock() + for _, Node := range NewNodes { + r.Add(Node.String()) + } + for _, Node := range RemoveNodes { + r.Remove(Node.String()) + } +} + // New creates and returns a new Rendezvous object func New(members []WeightedMember, config Config) *Rendezvous { r := &Rendezvous{ diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go new file mode 100644 index 0000000..a730f0e --- /dev/null +++ b/internal/resolver/resolver.go @@ -0,0 +1,36 @@ +package resolver + +import ( + peerDiscovery "lukas8219/websocket-operator/internal/peer_discovery" + "lukas8219/websocket-operator/internal/rendezvous" +) + +type Resolver struct { + hashingAlgorithm *rendezvous.Rendezvous + peerDiscovery.PeerDiscovery +} + +func New( + peerDiscovery peerDiscovery.PeerDiscovery, + hashingAlgorithm *rendezvous.Rendezvous, +) *Resolver { + return nil +} + +func (r *Resolver) Init() { + for event := range r.NotificationChannel() { + r.hashingAlgorithm.Transaction( + event.Added, + event.Removed, + ) + } +} + +func (r *Resolver) Lookup(Recipient []byte) (peerDiscovery.Peer, error) { + _, error := r.CurrentHosts() + if error != nil { + return peerDiscovery.Peer{}, error + } + member := r.hashingAlgorithm.LocateKey(Recipient) + return peerDiscovery.NewPeer(member.GetMember(), 3000), nil +} diff --git a/internal/transports/http.go b/internal/transports/http.go new file mode 100644 index 0000000..663bc84 --- /dev/null +++ b/internal/transports/http.go @@ -0,0 +1,57 @@ +package transports + +import ( + "bytes" + "context" + "errors" + "log/slog" + rslv "lukas8219/websocket-operator/internal/resolver" + "net/http" + "time" + + "github.com/gobwas/ws" +) + +type HttpTransport struct { + resolver rslv.Resolver +} + +func NewHTTPTransport( + resolver rslv.Resolver, +) HttpTransport { + return HttpTransport{resolver} +} + +func (h *HttpTransport) Write( + Recipient []byte, + OpCode ws.OpCode, + Data []byte, +) error { + peer, error := h.resolver.Lookup(Recipient) + if error != nil { + return error + } + slog := slog.With("recipientId", Recipient).With("opCode", OpCode).With("peer", peer).With("component", "proxy") + //TODO hardcoded 5 seconds to debug DNS resolve issues + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + messageWithOpCode := append([]byte{byte(OpCode)}, Data...) + url := "http://" + peer.SocketAddres() + "/message" + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(messageWithOpCode)) + if err != nil { + slog.Error("failed to create request", "error", err) + return errors.Join(errors.New("failed to create request"), err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("ws-user-id", string(Recipient)) + + slog.Debug("POST request", "url", url) + resp, err := http.DefaultClient.Do(req) + if err != nil { + slog.Error("Error sending request", "error", err) + return err + } + slog.Debug("Received response", "status", resp.Status) + defer resp.Body.Close() + return nil +} diff --git a/internal/transports/transports.go b/internal/transports/transports.go new file mode 100644 index 0000000..3bd1965 --- /dev/null +++ b/internal/transports/transports.go @@ -0,0 +1,7 @@ +package transports + +import "github.com/gobwas/ws" + +type Transport interface { + Write(Recipient []byte, OpCode ws.OpCode, Data []byte) error +} diff --git a/internal/utils/slice_utils.go b/internal/utils/slice_utils.go new file mode 100644 index 0000000..d240980 --- /dev/null +++ b/internal/utils/slice_utils.go @@ -0,0 +1,9 @@ +package utils + +func ConvertSlice[T any, R any](in []T, f func(T) R) []R { + out := make([]R, len(in)) + for i, v := range in { + out[i] = f(v) + } + return out +}