diff --git a/internal/remediation/root.go b/internal/remediation/root.go index 888f02c..1c61478 100644 --- a/internal/remediation/root.go +++ b/internal/remediation/root.go @@ -1,37 +1,140 @@ package remediation -// The order matters since we use slices.Max to get the max value +import ( + "sync" +) + +// Default weights for built-in remediations +// Allow=0, Unknown=1, then expand others to allow custom remediations to slot in between +const ( + WeightAllow = 0 + WeightUnknown = 1 + WeightCaptcha = 10 + WeightBan = 20 +) + +// Remediation represents a remediation type as a string +type Remediation string + +// Built-in remediation constants const ( - Allow Remediation = iota // Allow remediation - Unknown // Unknown remediation (Unknown is used to have a value for remediation we don't support EG "MFA") - Captcha // Captcha remediation - Ban // Ban remediation + Allow Remediation = "allow" + Unknown Remediation = "unknown" + Captcha Remediation = "captcha" + Ban Remediation = "ban" ) -type Remediation uint8 // Remediation type is smallest uint to save space +// registry manages remediation weights +type registry struct { + mu sync.RWMutex + weights map[string]int // Maps remediation name to its weight +} + +var globalRegistry = ®istry{ + weights: make(map[string]int), +} + +//nolint:gochecknoinits // init() is required to initialize default weights +func init() { + // Initialize built-in remediations with default weights + globalRegistry.mu.Lock() + defer globalRegistry.mu.Unlock() + + globalRegistry.weights["allow"] = WeightAllow + globalRegistry.weights["unknown"] = WeightUnknown + globalRegistry.weights["captcha"] = WeightCaptcha + globalRegistry.weights["ban"] = WeightBan +} + +// SetWeight sets a custom weight for a remediation (for configuration) +func SetWeight(name string, weight int) { + globalRegistry.mu.Lock() + defer globalRegistry.mu.Unlock() + + globalRegistry.weights[name] = weight +} + +// GetWeight returns the weight for a remediation name +func GetWeight(name string) int { + globalRegistry.mu.RLock() + defer globalRegistry.mu.RUnlock() + + if weight, exists := globalRegistry.weights[name]; exists { + return weight + } + // Default to Unknown weight for unknown remediations + return WeightUnknown +} + +// LoadWeights loads weights for multiple remediations at once (for startup initialization) +func LoadWeights(weights map[string]int) { + globalRegistry.mu.Lock() + defer globalRegistry.mu.Unlock() + for name, weight := range weights { + globalRegistry.weights[name] = weight + } +} + +// New creates a new Remediation from a string. +func New(name string) Remediation { + return Remediation(name) +} + +// String returns the remediation name func (r Remediation) String() string { - switch r { - case Ban: - return "ban" - case Captcha: - return "captcha" - case Unknown: - return "unknown" - default: - return "allow" + if r == "" { + return "allow" // Default fallback } + return string(r) +} + +// Compare returns: +// - negative if a < b +// - zero if a == b +// - positive if a > b +func Compare(a, b Remediation) int { + weightA := GetWeight(a.String()) + weightB := GetWeight(b.String()) + return weightA - weightB +} + +// IsHigher returns true if a has a higher weight than b +func IsHigher(a, b Remediation) bool { + return Compare(a, b) > 0 +} + +// IsLower returns true if a has a lower weight than b +func IsLower(a, b Remediation) bool { + return Compare(a, b) < 0 +} + +// IsEqual returns true if a represents the same remediation as b. +// This compares the remediation names (strings). +func IsEqual(a, b Remediation) bool { + return a == b +} + +// HasSameWeight returns true if a has the same weight as b. +// This is useful for checking if two different remediations have the same priority. +// Note: Two remediations with the same weight will be compared by name (alphabetical) +// as a tie-breaker when determining priority. +func HasSameWeight(a, b Remediation) bool { + return Compare(a, b) == 0 } +// IsWeighted returns true if r is not Allow (has weight > Allow) +// This is useful for checking if a remediation should be applied +func IsWeighted(r Remediation) bool { + return GetWeight(r.String()) > WeightAllow +} + +// FromString creates a Remediation from a string (alias for New for backward compatibility) func FromString(s string) Remediation { - switch s { - case "ban": - return Ban - case "captcha": - return Captcha - case "allow": - return Allow - default: - return Unknown - } + return New(s) +} + +// IsZero returns true if the remediation is zero-valued +func (r Remediation) IsZero() bool { + return r == "" } diff --git a/pkg/cfg/config.go b/pkg/cfg/config.go index 3f5b923..e31e5f6 100644 --- a/pkg/cfg/config.go +++ b/pkg/cfg/config.go @@ -7,6 +7,7 @@ import ( "gopkg.in/yaml.v2" "github.com/crowdsecurity/crowdsec-spoa/internal/geo" + "github.com/crowdsecurity/crowdsec-spoa/internal/remediation" "github.com/crowdsecurity/crowdsec-spoa/pkg/host" cslogging "github.com/crowdsecurity/crowdsec-spoa/pkg/logging" "github.com/crowdsecurity/go-cs-lib/csyaml" @@ -36,6 +37,32 @@ type BouncerConfig struct { ListenUnix string `yaml:"listen_unix"` PrometheusConfig PrometheusConfig `yaml:"prometheus"` PprofConfig PprofConfig `yaml:"pprof"` + // RemediationWeights allows users to configure custom weights for remediations. + // + // Format: + // remediation_weights: + // : + // + // Example: + // remediation_weights: + // mfa: 15 # slots between captcha (10) and ban (20) + // + // Valid weight range: integer values >= 0. Lower values are less severe; higher values are more severe. + // Recommended: Use values between 0 and 100. + // + // Built-in defaults: + // allow=0, unknown=1, captcha=10, ban=20 + // + // Custom weights override or supplement built-in remediations. If a custom remediation is defined, + // its weight will be used for ordering and severity. Custom remediations can slot between built-in + // ones by choosing an appropriate weight value. + // + // Tie-breaking: If two remediations have the same weight, alphabetical order of the remediation + // name is used as a deterministic tie-breaker when determining priority. + // + // Note: Custom weights for built-in remediations (allow, unknown, captcha, ban) must be set + // before package initialization. After init(), package-level constants already have cached weights. + RemediationWeights map[string]int `yaml:"remediation_weights,omitempty"` } // MergedConfig() returns the byte content of the patched configuration file (with .yaml.local). @@ -67,6 +94,11 @@ func NewConfig(reader io.Reader) (*BouncerConfig, error) { return nil, fmt.Errorf("failed to setup logging: %w", err) } + // Load custom remediation weights if configured (loads all weights at once on startup) + if config.RemediationWeights != nil { + remediation.LoadWeights(config.RemediationWeights) + } + if err := config.Validate(); err != nil { return nil, err } diff --git a/pkg/dataset/bart_types.go b/pkg/dataset/bart_types.go index eac4f54..81b82da 100644 --- a/pkg/dataset/bart_types.go +++ b/pkg/dataset/bart_types.go @@ -15,7 +15,7 @@ import ( type BartAddOp struct { Prefix netip.Prefix Origin string - R remediation.Remediation + R string // Remediation name as string IPType string Scope string } @@ -23,7 +23,7 @@ type BartAddOp struct { // BartRemoveOp represents a single prefix removal operation for batch processing type BartRemoveOp struct { Prefix netip.Prefix - R remediation.Remediation + R string // Remediation name as string Origin string IPType string Scope string @@ -91,7 +91,7 @@ func (s *BartRangeSet) initializeBatch(operations []BartAddOp) { // Only build logging fields if trace level is enabled var valueLog *log.Entry if s.logger.Logger.IsLevelEnabled(log.TraceLevel) { - valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R.String()) + valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R) valueLog.Trace("initial load: collecting prefix operations") } @@ -101,7 +101,7 @@ func (s *BartRangeSet) initializeBatch(operations []BartAddOp) { data = RemediationMap{} } // Add the remediation (this handles merging if prefix already seen) - data.Add(valueLog, op.R, op.Origin) + data.Add(valueLog, remediation.FromString(op.R), op.Origin) prefixMap[prefix] = data } @@ -134,7 +134,7 @@ func (s *BartRangeSet) updateBatch(cur *bart.Table[RemediationMap], operations [ // Only build logging fields if trace level is enabled var valueLog *log.Entry if s.logger.Logger.IsLevelEnabled(log.TraceLevel) { - valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R.String()) + valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R) valueLog.Trace("adding to bart trie") } @@ -146,7 +146,7 @@ func (s *BartRangeSet) updateBatch(cur *bart.Table[RemediationMap], operations [ valueLog.Trace("exact prefix exists, merging remediations") } // bart already cloned via our Cloner interface, modify directly - existingData.Add(valueLog, op.R, op.Origin) + existingData.Add(valueLog, remediation.FromString(op.R), op.Origin) return existingData, false // false = don't delete } if valueLog != nil { @@ -154,7 +154,7 @@ func (s *BartRangeSet) updateBatch(cur *bart.Table[RemediationMap], operations [ } // Create new data newData := make(RemediationMap) - newData.Add(valueLog, op.R, op.Origin) + newData.Add(valueLog, remediation.FromString(op.R), op.Origin) return newData, false // false = don't delete }) } @@ -193,7 +193,7 @@ func (s *BartRangeSet) RemoveBatch(operations []BartRemoveOp) []*BartRemoveOp { // Only build logging fields if trace level is enabled var valueLog *log.Entry if s.logger.Logger.IsLevelEnabled(log.TraceLevel) { - valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R.String()) + valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R) valueLog.Trace("removing from bart trie") } @@ -210,11 +210,12 @@ func (s *BartRangeSet) RemoveBatch(operations []BartRemoveOp) []*BartRemoveOp { // Check if the remediation exists with the matching origin before removing // This prevents removing decisions when the origin has been overwritten (e.g., by CAPI) - if !existingData.HasRemediationWithOrigin(op.R, op.Origin) { + if !existingData.HasRemediationWithOrigin(remediation.FromString(op.R), op.Origin) { // Origin doesn't match - this decision was likely overwritten by another origin // Don't remove it, as it's not the decision we're trying to delete if valueLog != nil { - storedOrigin, exists := existingData[op.R] + r := remediation.FromString(op.R) + storedOrigin, exists := existingData[r] if exists { valueLog.Tracef("remediation exists but origin mismatch (stored: %s, requested: %s), skipping removal", storedOrigin, op.Origin) } else { @@ -228,7 +229,7 @@ func (s *BartRangeSet) RemoveBatch(operations []BartRemoveOp) []*BartRemoveOp { // bart already cloned via our Cloner interface, modify directly // Remove returns an error if remediation doesn't exist (duplicate delete) // We already checked origin above, so this should succeed - err := existingData.Remove(valueLog, op.R) + err := existingData.Remove(valueLog, remediation.FromString(op.R)) if errors.Is(err, ErrRemediationNotFound) { // This shouldn't happen since we checked above, but handle it gracefully if valueLog != nil { @@ -288,11 +289,11 @@ func (s *BartRangeSet) Contains(ip netip.Addr) (remediation.Remediation, string) return remediation.Allow, "" } - remediationResult, origin := data.GetRemediationAndOrigin() + r, origin := data.GetRemediationAndOrigin() if valueLog != nil { - valueLog.Tracef("bart result: %s (data: %+v)", remediationResult.String(), data) + valueLog.Tracef("bart result: %s (data: %+v)", r.String(), data) } - return remediationResult, origin + return r, origin } // HasRemediation checks if an exact prefix has a specific remediation with a specific origin. diff --git a/pkg/dataset/benchmark_test.go b/pkg/dataset/benchmark_test.go index 6ed3a43..98bda65 100644 --- a/pkg/dataset/benchmark_test.go +++ b/pkg/dataset/benchmark_test.go @@ -174,13 +174,13 @@ func TestCorrectness(t *testing.T) { continue } - // Basic sanity check - result should be valid - if result < remediation.Allow { + // Basic sanity check - result should be valid (not less than Allow weight) + if remediation.Compare(result, remediation.Allow) < 0 { t.Errorf("Invalid result for IP %s: %v", testIP.String(), result) } - // Origin should be non-empty if we have a match - if result > remediation.Allow && origin == "" { + // Origin should be non-empty if we have a match (result > Allow) + if remediation.Compare(result, remediation.Allow) > 0 && origin == "" { t.Errorf("Empty origin for IP %s with result %v", testIP.String(), result) } } @@ -192,38 +192,38 @@ func TestLongestPrefixMatch(t *testing.T) { // Add individual IP to IPMap and ranges to RangeSet dataset.IPMap.AddBatch([]IPAddOp{ - {IP: netip.MustParseAddr("192.168.1.1"), Origin: "test", R: remediation.Allow, IPType: "ipv4"}, + {IP: netip.MustParseAddr("192.168.1.1"), Origin: "test", R: remediation.Allow.String(), IPType: "ipv4"}, }) dataset.RangeSet.AddBatch([]BartAddOp{ - {Prefix: netip.MustParsePrefix("192.168.0.0/16"), Origin: "test", R: remediation.Ban, IPType: "ipv4", Scope: "range"}, - {Prefix: netip.MustParsePrefix("192.168.1.0/24"), Origin: "test", R: remediation.Captcha, IPType: "ipv4", Scope: "range"}, + {Prefix: netip.MustParsePrefix("192.168.0.0/16"), Origin: "test", R: remediation.Ban.String(), IPType: "ipv4", Scope: "range"}, + {Prefix: netip.MustParsePrefix("192.168.1.0/24"), Origin: "test", R: remediation.Captcha.String(), IPType: "ipv4", Scope: "range"}, }) // Test that individual IP from IPMap wins (checked first before RangeSet) ip1 := netip.MustParseAddr("192.168.1.1") result, _, _ := dataset.CheckIP(ip1) - if result != remediation.Allow { + if !remediation.IsEqual(result, remediation.Allow) { t.Errorf("Expected Allow for 192.168.1.1 (from IPMap), got %v", result) } // Test that we get the LPM from RangeSet (Captcha /24 wins over Ban /16) ip2 := netip.MustParseAddr("192.168.1.2") result, _, _ = dataset.CheckIP(ip2) - if result != remediation.Captcha { + if !remediation.IsEqual(result, remediation.Captcha) { t.Errorf("Expected Captcha for 192.168.1.2 (LPM from RangeSet), got %v", result) } // Test that we get the broadest match from RangeSet ip3 := netip.MustParseAddr("192.168.2.1") result, _, _ = dataset.CheckIP(ip3) - if result != remediation.Ban { + if !remediation.IsEqual(result, remediation.Ban) { t.Errorf("Expected Ban for 192.168.2.1 (from RangeSet), got %v", result) } // Test that we get no match ip4 := netip.MustParseAddr("10.0.0.1") result, _, _ = dataset.CheckIP(ip4) - if result != remediation.Allow { + if !remediation.IsEqual(result, remediation.Allow) { t.Errorf("Expected Allow for 10.0.0.1 (no match), got %v", result) } } @@ -354,7 +354,7 @@ func BenchmarkHybridVsBartOnly(b *testing.B) { ops = append(ops, BartAddOp{ Prefix: netip.PrefixFrom(ip, prefixLen), Origin: *d.Origin, - R: remediation.Ban, + R: remediation.Ban.String(), IPType: "ipv4", Scope: "ip", }) @@ -384,13 +384,13 @@ func BenchmarkLookupHybrid(b *testing.B) { byte(i % 256), }) dataset.IPMap.AddBatch([]IPAddOp{ - {IP: ip, Origin: "test", R: remediation.Ban, IPType: "ipv4"}, + {IP: ip, Origin: "test", R: remediation.Ban.String(), IPType: "ipv4"}, }) } // Add some ranges to RangeSet dataset.RangeSet.AddBatch([]BartAddOp{ - {Prefix: netip.MustParsePrefix("192.168.0.0/16"), Origin: "test", R: remediation.Ban, IPType: "ipv4", Scope: "range"}, + {Prefix: netip.MustParsePrefix("192.168.0.0/16"), Origin: "test", R: remediation.Ban.String(), IPType: "ipv4", Scope: "range"}, }) // Test IPs - some in IPMap, some in RangeSet, some not found diff --git a/pkg/dataset/ipmap.go b/pkg/dataset/ipmap.go index 96a6338..747727f 100644 --- a/pkg/dataset/ipmap.go +++ b/pkg/dataset/ipmap.go @@ -50,14 +50,14 @@ func NewIPMap(logAlias string) *IPMap { type IPAddOp struct { IP netip.Addr Origin string - R remediation.Remediation + R string // Remediation name as string IPType string } // IPRemoveOp represents a remove operation for an individual IP type IPRemoveOp struct { IP netip.Addr - R remediation.Remediation + R string // Remediation name as string Origin string IPType string } @@ -78,7 +78,7 @@ func (m *IPMap) AddBatch(operations []IPAddOp) { func (m *IPMap) add(op IPAddOp) { var valueLog *log.Entry if m.logger.Logger.IsLevelEnabled(log.TraceLevel) { - valueLog = m.logger.WithField("ip", op.IP.String()).WithField("remediation", op.R.String()) + valueLog = m.logger.WithField("ip", op.IP.String()).WithField("remediation", op.R) valueLog.Trace("adding IP to map") } @@ -102,7 +102,7 @@ func (m *IPMap) add(op IPAddOp) { // Empty or nil - no need to clone, just create new map newData = make(RemediationMap) } - newData.Add(valueLog, op.R, op.Origin) + newData.Add(valueLog, remediation.FromString(op.R), op.Origin) entry.data.Store(&newData) return } @@ -111,7 +111,7 @@ func (m *IPMap) add(op IPAddOp) { // Create new entry with data // Store directly (no LoadOrStore race needed since application uses single writer) newData := make(RemediationMap) - newData.Add(valueLog, op.R, op.Origin) + newData.Add(valueLog, remediation.FromString(op.R), op.Origin) entry := &ipEntry{} entry.data.Store(&newData) ipMap.Store(op.IP, entry) @@ -139,7 +139,7 @@ func (m *IPMap) RemoveBatch(operations []IPRemoveOp) []*IPRemoveOp { func (m *IPMap) remove(op IPRemoveOp) bool { var valueLog *log.Entry if m.logger.Logger.IsLevelEnabled(log.TraceLevel) { - valueLog = m.logger.WithField("ip", op.IP.String()).WithField("remediation", op.R.String()) + valueLog = m.logger.WithField("ip", op.IP.String()).WithField("remediation", op.R) valueLog.Trace("removing IP from map") } @@ -174,11 +174,12 @@ func (m *IPMap) remove(op IPRemoveOp) bool { // Check if the remediation exists with the matching origin before removing // This prevents removing decisions when the origin has been overwritten (e.g., by CAPI) - if !current.HasRemediationWithOrigin(op.R, op.Origin) { + if !current.HasRemediationWithOrigin(remediation.FromString(op.R), op.Origin) { // Origin doesn't match - this decision was likely overwritten by another origin // Don't remove it, as it's not the decision we're trying to delete if valueLog != nil { - storedOrigin, exists := (*current)[op.R] + r := remediation.FromString(op.R) + storedOrigin, exists := (*current)[r] if exists { valueLog.Tracef("remediation exists but origin mismatch (stored: %s, requested: %s), skipping removal", storedOrigin, op.Origin) } else { @@ -195,7 +196,7 @@ func (m *IPMap) remove(op IPRemoveOp) bool { // Remove returns an error if remediation doesn't exist (duplicate delete) // We already checked origin above, so this should succeed - err := newData.Remove(valueLog, op.R) + err := newData.Remove(valueLog, remediation.FromString(op.R)) if errors.Is(err, ErrRemediationNotFound) { // This shouldn't happen since we checked above, but handle it gracefully if valueLog != nil { diff --git a/pkg/dataset/metrics_test.go b/pkg/dataset/metrics_test.go index 2885ef9..c79ed6e 100644 --- a/pkg/dataset/metrics_test.go +++ b/pkg/dataset/metrics_test.go @@ -746,7 +746,7 @@ func TestMetrics_NoOp_DuplicateDecisions(t *testing.T) { r, foundOrigin, found := dataSet.IPMap.Contains(ip) assert.True(t, found, "IP should still exist") assert.Equal(t, origin, foundOrigin, "origin should match") - assert.Equal(t, remediation.Ban, r, "remediation should be ban") + assert.True(t, remediation.IsEqual(r, remediation.Ban), "remediation should be ban") }) t.Run("Duplicate range decision is no-op", func(t *testing.T) { @@ -775,7 +775,7 @@ func TestMetrics_NoOp_DuplicateDecisions(t *testing.T) { require.NoError(t, err) r, foundOrigin := dataSet.RangeSet.Contains(testIP) assert.Equal(t, origin, foundOrigin, "origin should match") - assert.Equal(t, remediation.Ban, r, "remediation should be ban") + assert.True(t, remediation.IsEqual(r, remediation.Ban), "remediation should be ban") }) t.Run("Duplicate country decision is no-op", func(t *testing.T) { @@ -802,7 +802,7 @@ func TestMetrics_NoOp_DuplicateDecisions(t *testing.T) { // Verify decision still exists r, foundOrigin := dataSet.CNSet.Contains("US") assert.Equal(t, origin, foundOrigin, "origin should match") - assert.Equal(t, remediation.Ban, r, "remediation should be ban") + assert.True(t, remediation.IsEqual(r, remediation.Ban), "remediation should be ban") }) t.Run("Same IP different remediation is not no-op", func(t *testing.T) { @@ -866,7 +866,7 @@ func TestMetrics_OriginOverwriteAndDelete(t *testing.T) { require.NoError(t, err) r, storedOrigin, found := dataSet.IPMap.Contains(ip) assert.True(t, found, "IP should exist") - assert.Equal(t, remediation.Ban, r, "IP should have ban remediation") + assert.True(t, remediation.IsEqual(r, remediation.Ban), "IP should have ban remediation") assert.Equal(t, origin, storedOrigin, "IP should have CAPI origin") // Step 2: Add captcha from same CAPI origin (overwrites ban) @@ -890,7 +890,7 @@ func TestMetrics_OriginOverwriteAndDelete(t *testing.T) { // Both ban and captcha exist in the map, but ban is returned as highest priority r2, storedOrigin2, found2 := dataSet.IPMap.Contains(ip) assert.True(t, found2, "IP should still exist") - assert.Equal(t, remediation.Ban, r2, "IP should have ban remediation (highest priority, ban > captcha)") + assert.True(t, remediation.IsEqual(r2, remediation.Ban), "IP should have ban remediation (highest priority, ban > captcha)") assert.Equal(t, origin, storedOrigin2, "IP should still have CAPI origin") // Step 3: Delete ban from CAPI @@ -913,7 +913,7 @@ func TestMetrics_OriginOverwriteAndDelete(t *testing.T) { // Verify IP now has captcha (ban was removed, captcha is now active) r3, storedOrigin3, found3 := dataSet.IPMap.Contains(ip) assert.True(t, found3, "IP should still exist") - assert.Equal(t, remediation.Captcha, r3, "IP should now have captcha remediation (ban was removed)") + assert.True(t, remediation.IsEqual(r3, remediation.Captcha), "IP should now have captcha remediation (ban was removed)") assert.Equal(t, origin, storedOrigin3, "IP should still have CAPI origin") }) } diff --git a/pkg/dataset/root.go b/pkg/dataset/root.go index 770ee6e..5ac94d6 100644 --- a/pkg/dataset/root.go +++ b/pkg/dataset/root.go @@ -64,6 +64,7 @@ func (d *DataSet) Add(decisions models.GetDecisionsResponse) { scope := strings.ToLower(*decision.Scope) r := remediation.FromString(*decision.Type) + remediationName := r.String() // Convert to string for operations structs switch scope { case "ip": @@ -82,13 +83,13 @@ func (d *DataSet) Add(decisions models.GetDecisionsResponse) { ipType = "ipv6" } // Check if we're overwriting an existing decision with different origin - if existingR, existingOrigin, found := d.IPMap.Contains(ip); found && existingR == r && existingOrigin != origin { + if existingR, existingOrigin, found := d.IPMap.Contains(ip); found && remediation.IsEqual(existingR, r) && existingOrigin != origin { // Decrement old origin's metric before incrementing new one // Label order: origin, ip_type, scope (as defined in metrics.go) metrics.TotalActiveDecisions.WithLabelValues(existingOrigin, ipType, "ip").Dec() } // Individual IPs go to IPMap for memory efficiency - ipOps = append(ipOps, IPAddOp{IP: ip, Origin: origin, R: r, IPType: ipType}) + ipOps = append(ipOps, IPAddOp{IP: ip, Origin: origin, R: remediationName, IPType: ipType}) // Label order: origin, ip_type, scope (as defined in metrics.go) metrics.TotalActiveDecisions.WithLabelValues(origin, ipType, "ip").Inc() case "range": @@ -113,7 +114,7 @@ func (d *DataSet) Add(decisions models.GetDecisionsResponse) { metrics.TotalActiveDecisions.WithLabelValues(existingOrigin, ipType, "range").Dec() } // Ranges go to BART for LPM support - rangeOps = append(rangeOps, BartAddOp{Prefix: prefix, Origin: origin, R: r, IPType: ipType, Scope: "range"}) + rangeOps = append(rangeOps, BartAddOp{Prefix: prefix, Origin: origin, R: remediationName, IPType: ipType, Scope: "range"}) // Label order: origin, ip_type, scope (as defined in metrics.go) metrics.TotalActiveDecisions.WithLabelValues(origin, ipType, "range").Inc() case "country": @@ -125,7 +126,7 @@ func (d *DataSet) Add(decisions models.GetDecisionsResponse) { continue } // Check if we're overwriting an existing decision with different origin - if existingR, existingOrigin := d.CNSet.Contains(cn); existingR == r && existingOrigin != "" && existingOrigin != origin { + if existingR, existingOrigin := d.CNSet.Contains(cn); remediation.IsEqual(existingR, r) && existingOrigin != "" && existingOrigin != origin { // Decrement old origin's metric before incrementing new one // Label order: origin, ip_type, scope (as defined in metrics.go) metrics.TotalActiveDecisions.WithLabelValues(existingOrigin, "", "country").Dec() @@ -177,15 +178,15 @@ func (d *DataSet) Remove(decisions models.GetDecisionsResponse) { } log.Infof("Processing %d deleted decisions", len(decisions)) + // Separate operations by type + // Note: We don't pre-allocate capacity here because many decisions might be no-ops + // (duplicates) and would waste allocated memory. Let Go handle dynamic growth. type cnOp struct { cn string - r remediation.Remediation origin string + r remediation.Remediation } - // Separate operations by type - // Note: We don't pre-allocate capacity here because many decisions might be no-ops - // (duplicates) and would waste allocated memory. Let Go handle dynamic growth. ipOps := make([]IPRemoveOp, 0) rangeOps := make([]BartRemoveOp, 0) cnOps := make([]cnOp, 0) @@ -202,6 +203,7 @@ func (d *DataSet) Remove(decisions models.GetDecisionsResponse) { scope := strings.ToLower(*decision.Scope) r := remediation.FromString(*decision.Type) + remediationName := r.String() // Convert to string for storage switch scope { case "ip": @@ -214,7 +216,7 @@ func (d *DataSet) Remove(decisions models.GetDecisionsResponse) { if ip.Is6() { ipType = "ipv6" } - ipOps = append(ipOps, IPRemoveOp{IP: ip, R: r, Origin: origin, IPType: ipType}) + ipOps = append(ipOps, IPRemoveOp{IP: ip, R: remediationName, Origin: origin, IPType: ipType}) case "range": prefix, err := netip.ParsePrefix(*decision.Value) if err != nil { @@ -225,7 +227,7 @@ func (d *DataSet) Remove(decisions models.GetDecisionsResponse) { if prefix.Addr().Is6() { ipType = "ipv6" } - rangeOps = append(rangeOps, BartRemoveOp{Prefix: prefix, R: r, Origin: origin, IPType: ipType, Scope: "range"}) + rangeOps = append(rangeOps, BartRemoveOp{Prefix: prefix, R: remediationName, Origin: origin, IPType: ipType, Scope: "range"}) case "country": // Clone country code to break reference to Decision struct memory cnOps = append(cnOps, cnOp{cn: strings.Clone(*decision.Value), r: r, origin: origin}) @@ -306,7 +308,8 @@ func (d *DataSet) CheckIP(ip netip.Addr) (remediation.Remediation, string, error } func (d *DataSet) CheckCN(cn string) (remediation.Remediation, string) { - return d.CNSet.Contains(cn) + r, origin := d.CNSet.Contains(cn) + return r, origin } // Helper method for CN operations (still needed for country scope) diff --git a/pkg/dataset/root_test.go b/pkg/dataset/root_test.go index 3242fb5..15e186b 100644 --- a/pkg/dataset/root_test.go +++ b/pkg/dataset/root_test.go @@ -16,7 +16,7 @@ type toCheck struct { Value string // IP, Country Scope string // IP, Country Origin string - Type remediation.Remediation + Type remediation.Remediation // remediation type } func TestDataSet(t *testing.T) { @@ -169,7 +169,7 @@ func TestDataSet(t *testing.T) { t.Fatalf("unknown scope %s", tt.toCheck.Scope) } require.NoError(t, err) - assert.Equal(t, r, tt.toCheck.Type) + assert.True(t, remediation.IsEqual(r, tt.toCheck.Type), "remediation should match: got %s, expected %s", r.String(), tt.toCheck.Type.String()) assert.Equal(t, origin, tt.toCheck.Origin) }) } diff --git a/pkg/dataset/types.go b/pkg/dataset/types.go index 4ba4f3e..bb934cb 100644 --- a/pkg/dataset/types.go +++ b/pkg/dataset/types.go @@ -12,7 +12,7 @@ import ( // ErrRemediationNotFound is returned when attempting to remove a remediation that doesn't exist. var ErrRemediationNotFound = errors.New("remediation not found") -// RemediationMap stores one origin string per remediation type. +// RemediationMap stores one origin string per remediation type (using Remediation as keys). // ID is not tracked since LAPI behavior ensures we only have the longest decision. // // LAPI behavior: @@ -20,6 +20,9 @@ var ErrRemediationNotFound = errors.New("remediation not found") // - Stream: Only returns NEW decisions if they're LONGER than current // - Deletions: Delete means user wants to allow the IP - just remove the remediation entry. // Duplicate deletes are safely ignored (entry already gone). +// +// Keys are remediation.Remediation types (string type). +// Weight comparison is done via remediation.Compare() when determining priority. type RemediationMap map[remediation.Remediation]string // Remove removes a remediation entry (deletion means user wants to allow the IP). @@ -54,16 +57,32 @@ func (rM RemediationMap) Add(clog *log.Entry, r remediation.Remediation, origin } // GetRemediationAndOrigin returns the highest priority remediation and its origin. +// Priority is determined by comparing weights using remediation.Compare(). +// If two remediations have the same weight, alphabetical order of the name is used as a tie-breaker +// to ensure deterministic behavior. func (rM RemediationMap) GetRemediationAndOrigin() (remediation.Remediation, string) { var maxRemediation remediation.Remediation var maxOrigin string first := true - for k, v := range rM { - if first || k > maxRemediation { - maxRemediation = k - maxOrigin = v + for r, origin := range rM { + if first { + maxRemediation = r + maxOrigin = origin first = false + continue + } + + // Compare by weight first + if remediation.IsHigher(r, maxRemediation) { + maxRemediation = r + maxOrigin = origin + } else if remediation.HasSameWeight(r, maxRemediation) { + // Tie-breaker: use alphabetical order of name for deterministic behavior + if r.String() < maxRemediation.String() { + maxRemediation = r + maxOrigin = origin + } } } diff --git a/pkg/spoa/root.go b/pkg/spoa/root.go index 9d73c90..8155f66 100644 --- a/pkg/spoa/root.go +++ b/pkg/spoa/root.go @@ -185,7 +185,7 @@ type HTTPRequestData struct { // First stage is to check the host header and determine if the remediation from handleIpRequest is still valid // Second stage is to check if AppSec is enabled and then forward to the component if needed func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { - r := remediation.Allow + r := remediation.Allow // Default to allow var origin string shouldCountMetrics := false @@ -208,7 +208,7 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { } if rstring != nil { - r = remediation.FromString(*rstring) + r = remediation.FromString(*rstring) // Convert to Remediation type // Remediation came from IP check, already counted } @@ -218,12 +218,11 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { // defer a function that always add the remediation to the request at end of processing defer func() { - if matchedHost == nil && r == remediation.Captcha { + if matchedHost == nil && remediation.IsEqual(r, remediation.Captcha) { s.logger.Warn("remediation is captcha, no matching host was found cannot issue captcha remediation reverting to ban") r = remediation.Ban } - rString := r.String() - req.Actions.SetVar(action.ScopeTransaction, "remediation", rString) + req.Actions.SetVar(action.ScopeTransaction, "remediation", r.String()) // Count metrics if this is the only handler (upstream proxy mode) if shouldCountMetrics { @@ -242,8 +241,9 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { // Count processed request - use WithLabelValues to avoid map allocation on hot path metrics.TotalProcessedRequests.WithLabelValues(ipTypeLabel).Inc() - // Count blocked request if remediation applied - if r > remediation.Unknown { + // Count blocked request if remediation applied (not Allow) + // This includes Unknown, Captcha, Ban, and any custom remediations + if remediation.IsWeighted(r) { // Label order: origin, ip_type, remediation (as defined in metrics.go) metrics.TotalBlockedRequests.WithLabelValues(origin, ipTypeLabel, r.String()).Inc() } @@ -314,13 +314,19 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { r, httpData = s.handleCaptchaRemediation(req, mes, matchedHost) // If remediation changed to fallback, return early // If it became Allow, continue for AppSec processing + // Note: Direct struct comparison works because Remediation is comparable if r != remediation.Captcha && r != remediation.Allow { return } + default: + // Unknown or custom remediation: currently, only HTTP data is parsed for AppSec processing. + // If a custom remediation requires special handling (like Ban), this must be implemented explicitly. + httpData = parseHTTPData(s.logger, mes) } - // If remediation is ban/captcha we dont need to create a request to send to appsec unless always send is on - if r > remediation.Unknown && !matchedHost.AppSec.AlwaysSend { + // If remediation is not allow, we dont need to create a request to send to appsec unless always send is on + // This includes Unknown, Captcha, Ban, and any custom remediations + if remediation.IsWeighted(r) && !matchedHost.AppSec.AlwaysSend { return } // !TODO APPSEC STUFF - httpData contains parsed URL, Method, Body, Headers for reuse @@ -640,10 +646,10 @@ func (s *Spoa) getIPRemediation(req *request.Request, ip netip.Addr) (remediatio // Always set the ISO code variable when available req.Actions.SetVar(action.ScopeTransaction, "isocode", iso) - // If no IP-specific remediation, check country-based remediation - if r < remediation.Unknown { + // If no IP-specific remediation (Allow), check country-based remediation + if remediation.IsEqual(r, remediation.Allow) { cnR, cnOrigin := s.dataset.CheckCN(iso) - if cnR > remediation.Unknown { + if remediation.IsHigher(cnR, remediation.Allow) { r = cnR origin = cnOrigin } @@ -680,8 +686,9 @@ func (s *Spoa) handleIPRequest(req *request.Request, mes *message.Message) { // Check IP directly against dataset r, origin := s.getIPRemediation(req, ipAddr) - // Count blocked requests - if r > remediation.Unknown { + // Count blocked requests (not Allow) + // This includes Unknown, Captcha, Ban, and any custom remediations + if remediation.IsWeighted(r) { // Label order: origin, ip_type, remediation (as defined in metrics.go) metrics.TotalBlockedRequests.WithLabelValues(origin, ipTypeLabel, r.String()).Inc() }