Skip to content

Commit bfb0a81

Browse files
authored
Add Weighted Round Robin Algorithm to Load Balancer (#591)
1 parent 856ba96 commit bfb0a81

File tree

11 files changed

+493
-10
lines changed

11 files changed

+493
-10
lines changed

cmd/run.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,7 @@ var runCmd = &cobra.Command{
851851
proxies[configGroupName][configBlockName] = network.NewProxy(
852852
runCtx,
853853
network.Proxy{
854+
Name: configBlockName,
854855
AvailableConnections: pools[configGroupName][configBlockName],
855856
PluginRegistry: pluginRegistry,
856857
HealthCheckPeriod: cfg.HealthCheckPeriod,
@@ -918,6 +919,7 @@ var runCmd = &cobra.Command{
918919
KeyFile: cfg.KeyFile,
919920
HandshakeTimeout: cfg.HandshakeTimeout,
920921
LoadbalancerStrategyName: cfg.LoadBalancer.Strategy,
922+
LoadbalancerRules: cfg.LoadBalancer.LoadBalancingRules,
921923
},
922924
)
923925

config/config.go

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
goerrors "errors"
77
"fmt"
88
"log"
9+
"math"
910
"os"
1011
"reflect"
1112
"sort"
@@ -681,16 +682,26 @@ func (c *Config) ValidateGlobalConfig(ctx context.Context) *gerr.GatewayDError {
681682
}
682683

683684
for configGroup := range globalConfig.Servers {
684-
if globalConfig.Servers[configGroup] == nil {
685+
serverConfig := globalConfig.Servers[configGroup]
686+
687+
if serverConfig == nil {
685688
err := fmt.Errorf("\"servers.%s\" is nil or empty", configGroup)
686689
span.RecordError(err)
687690
errors = append(errors, gerr.ErrValidationFailed.Wrap(err))
691+
continue
688692
}
689693
if configGroup != strings.ToLower(configGroup) {
690694
err := fmt.Errorf(`"servers.%s" is not lowercase`, configGroup)
691695
span.RecordError(err)
692696
errors = append(errors, gerr.ErrValidationFailed.Wrap(err))
693697
}
698+
699+
// Validate Load Balancing Rules
700+
validatelBRulesErrors := ValidateLoadBalancingRules(serverConfig, configGroup, clientConfigGroups)
701+
for _, err := range validatelBRulesErrors {
702+
span.RecordError(err)
703+
errors = append(errors, gerr.ErrValidationFailed.Wrap(err))
704+
}
694705
}
695706

696707
if len(globalConfig.Servers) > 1 {
@@ -797,3 +808,113 @@ func generateTagMapping(structs []interface{}, tagMapping map[string]string) {
797808
}
798809
}
799810
}
811+
812+
// ValidateLoadBalancingRules validates the load balancing rules in the server configuration.
813+
func ValidateLoadBalancingRules(
814+
serverConfig *Server,
815+
configGroup string,
816+
clientConfigGroups map[string]map[string]bool,
817+
) []error {
818+
var errors []error
819+
820+
// Return early if there are no load balancing rules
821+
if serverConfig.LoadBalancer.LoadBalancingRules == nil {
822+
return errors
823+
}
824+
825+
// Validate each load balancing rule
826+
for _, rule := range serverConfig.LoadBalancer.LoadBalancingRules {
827+
// Validate the condition of the rule
828+
if err := validateRuleCondition(rule.Condition, configGroup); err != nil {
829+
errors = append(errors, err)
830+
}
831+
832+
// Validate the distribution of the rule
833+
if err := validateDistribution(
834+
rule.Distribution,
835+
configGroup,
836+
rule.Condition,
837+
clientConfigGroups,
838+
); err != nil {
839+
errors = append(errors, err)
840+
}
841+
}
842+
843+
return errors
844+
}
845+
846+
// validateRuleCondition checks if the rule condition is empty for LoadBalancingRules.
847+
func validateRuleCondition(condition string, configGroup string) error {
848+
if condition == "" {
849+
err := fmt.Errorf(`"servers.%s.loadBalancer.loadBalancingRules.condition" is nil or empty`, configGroup)
850+
return err
851+
}
852+
return nil
853+
}
854+
855+
// validateDistribution checks if the distribution list is valid for LoadBalancingRules.
856+
func validateDistribution(
857+
distributionList []Distribution,
858+
configGroup string,
859+
condition string,
860+
clientConfigGroups map[string]map[string]bool,
861+
) error {
862+
// Check if the distribution list is empty
863+
if len(distributionList) == 0 {
864+
return fmt.Errorf(
865+
`"servers.%s.loadBalancer.loadBalancingRules.distribution" is empty`,
866+
configGroup,
867+
)
868+
}
869+
870+
var totalWeight int
871+
for _, distribution := range distributionList {
872+
// Validate each distribution entry
873+
if err := validateDistributionEntry(distribution, configGroup, condition, clientConfigGroups); err != nil {
874+
return err
875+
}
876+
877+
// Check if adding the weight would exceed the maximum integer value
878+
if totalWeight > math.MaxInt-distribution.Weight {
879+
return fmt.Errorf(
880+
`"servers.%s.loadBalancer.loadBalancingRules.%s" total weight exceeds maximum int value`,
881+
configGroup,
882+
condition,
883+
)
884+
}
885+
886+
totalWeight += distribution.Weight
887+
}
888+
889+
return nil
890+
}
891+
892+
// validateDistributionEntry validates a single distribution entry for LoadBalancingRules.
893+
func validateDistributionEntry(
894+
distribution Distribution,
895+
configGroup string,
896+
condition string,
897+
clientConfigGroups map[string]map[string]bool,
898+
) error {
899+
// Check if the distribution.ProxyName is referenced in the proxy configuration
900+
if !clientConfigGroups[configGroup][distribution.ProxyName] {
901+
return fmt.Errorf(
902+
`"servers.%s.loadBalancer.loadBalancingRules.%s.%s" not referenced in proxy configuration`,
903+
configGroup,
904+
condition,
905+
distribution.ProxyName,
906+
)
907+
}
908+
909+
// Ensure that the distribution weight is positive
910+
if distribution.Weight <= 0 {
911+
return fmt.Errorf(
912+
`"servers.%s.loadBalancer.loadBalancingRules.%s.%s.weight" must be positive`,
913+
configGroup,
914+
condition,
915+
distribution.ProxyName,
916+
)
917+
}
918+
919+
return nil
920+
}

config/constants.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,12 @@ const (
9090
DefaultHealthCheckPeriod = 60 * time.Second // This must match PostgreSQL authentication timeout.
9191

9292
// Server constants.
93-
DefaultListenNetwork = "tcp"
94-
DefaultListenAddress = "0.0.0.0:15432"
95-
DefaultTickInterval = 5 * time.Second
96-
DefaultHandshakeTimeout = 5 * time.Second
97-
DefaultLoadBalancerStrategy = "ROUND_ROBIN"
93+
DefaultListenNetwork = "tcp"
94+
DefaultListenAddress = "0.0.0.0:15432"
95+
DefaultTickInterval = 5 * time.Second
96+
DefaultHandshakeTimeout = 5 * time.Second
97+
DefaultLoadBalancerStrategy = "ROUND_ROBIN"
98+
DefaultLoadBalancerCondition = "DEFAULT"
9899

99100
// Utility constants.
100101
DefaultSeed = 1000
@@ -129,6 +130,7 @@ const (
129130

130131
// Load balancing strategies.
131132
const (
132-
RoundRobinStrategy = "ROUND_ROBIN"
133-
RANDOMStrategy = "RANDOM"
133+
RoundRobinStrategy = "ROUND_ROBIN"
134+
RANDOMStrategy = "RANDOM"
135+
WeightedRoundRobinStrategy = "WEIGHTED_ROUND_ROBIN"
134136
)

config/types.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,19 @@ type Proxy struct {
9696
HealthCheckPeriod time.Duration `json:"healthCheckPeriod" jsonschema:"oneof_type=string;integer" yaml:"healthCheckPeriod"`
9797
}
9898

99+
type Distribution struct {
100+
ProxyName string `json:"proxyName"`
101+
Weight int `json:"weight"`
102+
}
103+
104+
type LoadBalancingRule struct {
105+
Condition string `json:"condition"`
106+
Distribution []Distribution `json:"distribution"`
107+
}
108+
99109
type LoadBalancer struct {
100-
Strategy string `json:"strategy"`
110+
Strategy string `json:"strategy"`
111+
LoadBalancingRules []LoadBalancingRule `json:"loadBalancingRules"`
101112
}
102113

103114
type Server struct {

errors/errors.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ const (
5555
ErrCodePublishAsyncAction
5656
ErrCodeLoadBalancerStrategyNotFound
5757
ErrCodeNoProxiesAvailable
58+
ErrCodeNoLoadBalancerRules
5859
)
5960

6061
var (
@@ -204,6 +205,10 @@ var (
204205
ErrCodeNoProxiesAvailable, "No proxies available to select.", nil,
205206
}
206207

208+
ErrNoLoadBalancerRules = &GatewayDError{
209+
ErrCodeNoLoadBalancerRules, "No load balancer rules provided.", nil,
210+
}
211+
207212
// Unwrapped errors.
208213
ErrLoggerRequired = errors.New("terminate action requires a logger parameter")
209214
)

gatewayd.yaml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,15 @@ servers:
8383
address: 0.0.0.0:15432
8484
loadBalancer:
8585
# Load balancer strategies can be found in config/constants.go
86-
strategy: RANDOM # ROUND_ROBIN, RANDOM
86+
strategy: ROUND_ROBIN # ROUND_ROBIN, RANDOM, WEIGHTED_ROUND_ROBIN
87+
# Optional configuration for strategies that support rules (e.g., WEIGHTED_ROUND_ROBIN)
88+
# loadBalancingRules:
89+
# - condition: "DEFAULT" # Currently, only the "DEFAULT" condition is supported
90+
# distribution:
91+
# - proxyName: "writes"
92+
# weight: 70
93+
# - proxyName: "reads"
94+
# weight: 30
8795
enableTicker: False
8896
tickInterval: 5s # duration
8997
enableTLS: False

network/loadbalancer.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,35 @@ type LoadBalancerStrategy interface {
99
NextProxy() (IProxy, *gerr.GatewayDError)
1010
}
1111

12+
// NewLoadBalancerStrategy returns a LoadBalancerStrategy based on the server's load balancer strategy name.
13+
// If the server's load balancer strategy is weighted round-robin,
14+
// it selects a load balancer rule before returning the strategy.
15+
// Returns an error if the strategy is not found or if there are no load balancer rules when required.
1216
func NewLoadBalancerStrategy(server *Server) (LoadBalancerStrategy, *gerr.GatewayDError) {
1317
switch server.LoadbalancerStrategyName {
1418
case config.RoundRobinStrategy:
1519
return NewRoundRobin(server), nil
1620
case config.RANDOMStrategy:
1721
return NewRandom(server), nil
22+
case config.WeightedRoundRobinStrategy:
23+
if server.LoadbalancerRules == nil {
24+
return nil, gerr.ErrNoLoadBalancerRules
25+
}
26+
loadbalancerRule := selectLoadBalancerRule(server.LoadbalancerRules)
27+
return NewWeightedRoundRobin(server, loadbalancerRule), nil
1828
default:
1929
return nil, gerr.ErrLoadBalancerStrategyNotFound
2030
}
2131
}
32+
33+
// selectLoadBalancerRule selects and returns the first load balancer rule that matches the default condition.
34+
// If no rule matches, it returns the first rule in the list as a fallback.
35+
func selectLoadBalancerRule(rules []config.LoadBalancingRule) config.LoadBalancingRule {
36+
for _, rule := range rules {
37+
if rule.Condition == config.DefaultLoadBalancerCondition {
38+
return rule
39+
}
40+
}
41+
// Return the first rule as a fallback
42+
return rules[0]
43+
}

network/proxy.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,11 @@ type IProxy interface {
3636
Shutdown()
3737
AvailableConnectionsString() []string
3838
BusyConnectionsString() []string
39+
GetName() string
3940
}
4041

4142
type Proxy struct {
43+
Name string
4244
AvailableConnections pool.IPool
4345
busyConnections pool.IPool
4446
Logger zerolog.Logger
@@ -136,6 +138,10 @@ func NewProxy(
136138
return &proxy
137139
}
138140

141+
func (pr *Proxy) GetName() string {
142+
return pr.Name
143+
}
144+
139145
// Connect maps a server connection from the available connection pool to a incoming connection.
140146
// It returns an error if the pool is exhausted.
141147
func (pr *Proxy) Connect(conn *ConnWrapper) *gerr.GatewayDError {

network/server.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ type Server struct {
7777
// loadbalancer
7878
loadbalancerStrategy LoadBalancerStrategy
7979
LoadbalancerStrategyName string
80+
LoadbalancerRules []config.LoadBalancingRule
8081
connectionToProxyMap map[*ConnWrapper]IProxy
8182
}
8283

@@ -696,6 +697,7 @@ func NewServer(
696697
stopServer: make(chan struct{}),
697698
connectionToProxyMap: make(map[*ConnWrapper]IProxy),
698699
LoadbalancerStrategyName: srv.LoadbalancerStrategyName,
700+
LoadbalancerRules: srv.LoadbalancerRules,
699701
}
700702

701703
// Try to resolve the address and log an error if it can't be resolved.

0 commit comments

Comments
 (0)