From 116ef5b29fc0bab336f6ee5a996ac4be17ed83a1 Mon Sep 17 00:00:00 2001 From: Laurence Date: Thu, 4 Dec 2025 15:12:01 +0000 Subject: [PATCH 01/11] Refactor remediations to use string-based keys with configurable weights - Replace iota-based Remediation enum with string-based system - Add weight system: Allow=0, Unknown=1, Captcha=10, Ban=20 - Implement string deduplication using pointers to reduce allocations - Add remediation_weights configuration option for custom remediations - Update RemediationMap to use string keys instead of enum - Update GetRemediationAndOrigin to compare weights for priority - Update all dataset and SPOA code to use string-based remediations - Fix metrics and AppSec checks to use > WeightAllow instead of > WeightUnknown to properly include custom remediations that default to Unknown weight Fixes #136 --- internal/remediation/root.go | 152 +++++++++++++++++++++++++++++------ pkg/cfg/config.go | 13 +++ pkg/dataset/bart_types.go | 27 +++---- pkg/dataset/ipmap.go | 23 +++--- pkg/dataset/root.go | 44 +++++----- pkg/dataset/types.go | 75 +++++++++-------- pkg/spoa/root.go | 71 ++++++++-------- 7 files changed, 265 insertions(+), 140 deletions(-) diff --git a/internal/remediation/root.go b/internal/remediation/root.go index 888f02c..a8c1e58 100644 --- a/internal/remediation/root.go +++ b/internal/remediation/root.go @@ -1,37 +1,139 @@ package remediation -// The order matters since we use slices.Max to get the max value +import ( + "sync" +) + +// Default weights for built-in remediations +// Allow=0, Unknown=1, then expand others to allow custom remediations to slot in between const ( - Allow Remediation = iota // Allow remediation - Unknown // Unknown remediation (Unknown is used to have a value for remediation we don't support EG "MFA") - Captcha // Captcha remediation - Ban // Ban remediation + WeightAllow = 0 + WeightUnknown = 1 + WeightCaptcha = 10 + WeightBan = 20 ) -type Remediation uint8 // Remediation type is smallest uint to save space +// Remediation represents a remediation type as a string +// We use string pointers for deduplication to reduce allocations +type Remediation struct { + name *string // Pointer to deduplicated string + weight int // Weight for comparison (higher = more severe) +} + +// registry manages deduplicated remediation strings and their weights +type registry struct { + mu sync.RWMutex + strings map[string]*string // Maps string to its deduplicated pointer + weights map[string]int // Maps remediation name to its weight +} + +var globalRegistry = ®istry{ + strings: make(map[string]*string), + weights: make(map[string]int), +} + +func init() { + // Initialize built-in remediations with default weights + globalRegistry.mu.Lock() + defer globalRegistry.mu.Unlock() + + globalRegistry.weights["allow"] = WeightAllow + globalRegistry.weights["unknown"] = WeightUnknown + globalRegistry.weights["captcha"] = WeightCaptcha + globalRegistry.weights["ban"] = WeightBan + // Pre-create deduplicated strings for built-in remediations + for name := range globalRegistry.weights { + deduped := name + globalRegistry.strings[name] = &deduped + } +} + +// SetWeight sets a custom weight for a remediation (for configuration) +func SetWeight(name string, weight int) { + globalRegistry.mu.Lock() + defer globalRegistry.mu.Unlock() + + globalRegistry.weights[name] = weight + // Ensure deduplicated string exists + if _, exists := globalRegistry.strings[name]; !exists { + deduped := name + globalRegistry.strings[name] = &deduped + } +} + +// GetWeight returns the weight for a remediation name +func GetWeight(name string) int { + globalRegistry.mu.RLock() + defer globalRegistry.mu.RUnlock() + + if weight, exists := globalRegistry.weights[name]; exists { + return weight + } + // Default to Unknown weight for unknown remediations + return WeightUnknown +} + +// Built-in remediation constants (for convenience) +var ( + Allow = New("allow") + Unknown = New("unknown") + Captcha = New("captcha") + Ban = New("ban") +) + +// New creates a new Remediation from a string +// Uses deduplicated string pointers to reduce allocations +func New(name string) Remediation { + globalRegistry.mu.Lock() + defer globalRegistry.mu.Unlock() + + // Get or create deduplicated string pointer + deduped, exists := globalRegistry.strings[name] + if !exists { + // Create new deduplicated string + deduped = &name + globalRegistry.strings[name] = deduped + // Set default weight if not configured + if _, hasWeight := globalRegistry.weights[name]; !hasWeight { + globalRegistry.weights[name] = WeightUnknown + } + } + + weight := globalRegistry.weights[name] + return Remediation{ + name: deduped, + weight: weight, + } +} + +// String returns the remediation name func (r Remediation) String() string { - switch r { - case Ban: - return "ban" - case Captcha: - return "captcha" - case Unknown: - return "unknown" - default: - return "allow" + if r.name == nil { + return "allow" // Default fallback } + return *r.name +} + +// Weight returns the weight of the remediation +func (r Remediation) Weight() int { + return r.weight +} + +// Compare returns: +// - negative if r < other +// - zero if r == other +// - positive if r > other +func (r Remediation) Compare(other Remediation) int { + return r.weight - other.weight } +// FromString creates a Remediation from a string (alias for New for backward compatibility) func FromString(s string) Remediation { - switch s { - case "ban": - return Ban - case "captcha": - return Captcha - case "allow": - return Allow - default: - return Unknown - } + return New(s) +} + +// IsZero returns true if the remediation is zero-valued +func (r Remediation) IsZero() bool { + return r.name == nil } diff --git a/pkg/cfg/config.go b/pkg/cfg/config.go index 3f5b923..7eec17e 100644 --- a/pkg/cfg/config.go +++ b/pkg/cfg/config.go @@ -7,6 +7,7 @@ import ( "gopkg.in/yaml.v2" "github.com/crowdsecurity/crowdsec-spoa/internal/geo" + "github.com/crowdsecurity/crowdsec-spoa/internal/remediation" "github.com/crowdsecurity/crowdsec-spoa/pkg/host" cslogging "github.com/crowdsecurity/crowdsec-spoa/pkg/logging" "github.com/crowdsecurity/go-cs-lib/csyaml" @@ -36,6 +37,11 @@ type BouncerConfig struct { ListenUnix string `yaml:"listen_unix"` PrometheusConfig PrometheusConfig `yaml:"prometheus"` PprofConfig PprofConfig `yaml:"pprof"` + // RemediationWeights allows users to configure custom weights for remediations + // Format: map[string]int where key is remediation name and value is weight + // Built-in defaults: allow=0, unknown=1, captcha=10, ban=20 + // Custom remediations can slot between these values + RemediationWeights map[string]int `yaml:"remediation_weights,omitempty"` } // MergedConfig() returns the byte content of the patched configuration file (with .yaml.local). @@ -67,6 +73,13 @@ func NewConfig(reader io.Reader) (*BouncerConfig, error) { return nil, fmt.Errorf("failed to setup logging: %w", err) } + // Apply custom remediation weights if configured + if config.RemediationWeights != nil { + for remediationName, weight := range config.RemediationWeights { + remediation.SetWeight(remediationName, weight) + } + } + if err := config.Validate(); err != nil { return nil, err } diff --git a/pkg/dataset/bart_types.go b/pkg/dataset/bart_types.go index eac4f54..46d6cf7 100644 --- a/pkg/dataset/bart_types.go +++ b/pkg/dataset/bart_types.go @@ -6,7 +6,6 @@ import ( "sync" "sync/atomic" - "github.com/crowdsecurity/crowdsec-spoa/internal/remediation" "github.com/gaissmai/bart" log "github.com/sirupsen/logrus" ) @@ -15,7 +14,7 @@ import ( type BartAddOp struct { Prefix netip.Prefix Origin string - R remediation.Remediation + R string // Remediation name as string IPType string Scope string } @@ -23,7 +22,7 @@ type BartAddOp struct { // BartRemoveOp represents a single prefix removal operation for batch processing type BartRemoveOp struct { Prefix netip.Prefix - R remediation.Remediation + R string // Remediation name as string Origin string IPType string Scope string @@ -91,7 +90,7 @@ func (s *BartRangeSet) initializeBatch(operations []BartAddOp) { // Only build logging fields if trace level is enabled var valueLog *log.Entry if s.logger.Logger.IsLevelEnabled(log.TraceLevel) { - valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R.String()) + valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R) valueLog.Trace("initial load: collecting prefix operations") } @@ -134,7 +133,7 @@ func (s *BartRangeSet) updateBatch(cur *bart.Table[RemediationMap], operations [ // Only build logging fields if trace level is enabled var valueLog *log.Entry if s.logger.Logger.IsLevelEnabled(log.TraceLevel) { - valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R.String()) + valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R) valueLog.Trace("adding to bart trie") } @@ -193,7 +192,7 @@ func (s *BartRangeSet) RemoveBatch(operations []BartRemoveOp) []*BartRemoveOp { // Only build logging fields if trace level is enabled var valueLog *log.Entry if s.logger.Logger.IsLevelEnabled(log.TraceLevel) { - valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R.String()) + valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R) valueLog.Trace("removing from bart trie") } @@ -263,13 +262,13 @@ func (s *BartRangeSet) RemoveBatch(operations []BartRemoveOp) []*BartRemoveOp { // Contains checks if an IP address matches any prefix in the bart table. // Returns the longest matching prefix's remediation and origin. // This method uses lock-free reads via atomic pointer for optimal performance. -func (s *BartRangeSet) Contains(ip netip.Addr) (remediation.Remediation, string) { +func (s *BartRangeSet) Contains(ip netip.Addr) (string, string) { // Lock-free read: atomically load the current table pointer table := s.tableAtomicPtr.Load() // Check for nil table (not yet initialized) if table == nil { - return remediation.Allow, "" + return "allow", "" } // Only build logging fields if trace level is enabled @@ -285,12 +284,12 @@ func (s *BartRangeSet) Contains(ip netip.Addr) (remediation.Remediation, string) if valueLog != nil { valueLog.Trace("no match found") } - return remediation.Allow, "" + return "allow", "" } remediationResult, origin := data.GetRemediationAndOrigin() if valueLog != nil { - valueLog.Tracef("bart result: %s (data: %+v)", remediationResult.String(), data) + valueLog.Tracef("bart result: %s (data: %+v)", remediationResult, data) } return remediationResult, origin } @@ -299,7 +298,7 @@ func (s *BartRangeSet) Contains(ip netip.Addr) (remediation.Remediation, string) // Uses Get() for exact prefix lookup (not LPM like Contains/Lookup). // Returns true if the exact prefix exists and has the given remediation with the given origin. // This method uses lock-free reads via atomic pointer for optimal performance. -func (s *BartRangeSet) HasRemediation(prefix netip.Prefix, r remediation.Remediation, origin string) bool { +func (s *BartRangeSet) HasRemediation(prefix netip.Prefix, remediationName string, origin string) bool { // Lock-free read: atomically load the current table pointer table := s.tableAtomicPtr.Load() @@ -315,14 +314,14 @@ func (s *BartRangeSet) HasRemediation(prefix netip.Prefix, r remediation.Remedia return false } - return data.HasRemediationWithOrigin(r, origin) + return data.HasRemediationWithOrigin(remediationName, origin) } // GetOriginForRemediation returns the origin for a specific remediation on an exact prefix. // Uses Get() for exact prefix lookup (not LPM). // Returns the origin and true if the exact prefix exists and has the given remediation, false otherwise. // This method uses lock-free reads via atomic pointer for optimal performance. -func (s *BartRangeSet) GetOriginForRemediation(prefix netip.Prefix, r remediation.Remediation) (string, bool) { +func (s *BartRangeSet) GetOriginForRemediation(prefix netip.Prefix, remediationName string) (string, bool) { // Lock-free read: atomically load the current table pointer table := s.tableAtomicPtr.Load() @@ -339,7 +338,7 @@ func (s *BartRangeSet) GetOriginForRemediation(prefix netip.Prefix, r remediatio } // Check if the remediation exists and return its origin - if existingOrigin, ok := data[r]; ok { + if existingOrigin, ok := data[remediationName]; ok { return existingOrigin, true } diff --git a/pkg/dataset/ipmap.go b/pkg/dataset/ipmap.go index 96a6338..949f074 100644 --- a/pkg/dataset/ipmap.go +++ b/pkg/dataset/ipmap.go @@ -6,7 +6,6 @@ import ( "sync" "sync/atomic" - "github.com/crowdsecurity/crowdsec-spoa/internal/remediation" log "github.com/sirupsen/logrus" ) @@ -50,14 +49,14 @@ func NewIPMap(logAlias string) *IPMap { type IPAddOp struct { IP netip.Addr Origin string - R remediation.Remediation + R string // Remediation name as string IPType string } // IPRemoveOp represents a remove operation for an individual IP type IPRemoveOp struct { IP netip.Addr - R remediation.Remediation + R string // Remediation name as string Origin string IPType string } @@ -78,7 +77,7 @@ func (m *IPMap) AddBatch(operations []IPAddOp) { func (m *IPMap) add(op IPAddOp) { var valueLog *log.Entry if m.logger.Logger.IsLevelEnabled(log.TraceLevel) { - valueLog = m.logger.WithField("ip", op.IP.String()).WithField("remediation", op.R.String()) + valueLog = m.logger.WithField("ip", op.IP.String()).WithField("remediation", op.R) valueLog.Trace("adding IP to map") } @@ -139,7 +138,7 @@ func (m *IPMap) RemoveBatch(operations []IPRemoveOp) []*IPRemoveOp { func (m *IPMap) remove(op IPRemoveOp) bool { var valueLog *log.Entry if m.logger.Logger.IsLevelEnabled(log.TraceLevel) { - valueLog = m.logger.WithField("ip", op.IP.String()).WithField("remediation", op.R.String()) + valueLog = m.logger.WithField("ip", op.IP.String()).WithField("remediation", op.R) valueLog.Trace("removing IP from map") } @@ -226,7 +225,7 @@ func (m *IPMap) remove(op IPRemoveOp) bool { // Contains checks if an IP address exists in the map // Returns the remediation and origin if found // This method is completely lock-free - SPOA handlers never block -func (m *IPMap) Contains(ip netip.Addr) (remediation.Remediation, string, bool) { +func (m *IPMap) Contains(ip netip.Addr) (string, string, bool) { var valueLog *log.Entry if m.logger.Logger.IsLevelEnabled(log.TraceLevel) { valueLog = m.logger.WithField("ip", ip.String()) @@ -244,23 +243,23 @@ func (m *IPMap) Contains(ip netip.Addr) (remediation.Remediation, string, bool) if valueLog != nil { valueLog.Trace("IP not found in map") } - return remediation.Allow, "", false + return "allow", "", false } entry, ok := existing.(*ipEntry) if !ok { - return remediation.Allow, "", false + return "allow", "", false } // Lock-free read via atomic pointer data := entry.data.Load() if data == nil { - return remediation.Allow, "", false + return "allow", "", false } r, origin := data.GetRemediationAndOrigin() if valueLog != nil { - valueLog.Tracef("found IP with remediation: %s", r.String()) + valueLog.Tracef("found IP with remediation: %s", r) } return r, origin, true } @@ -272,7 +271,7 @@ func (m *IPMap) Count() (ipv4 int64, ipv6 int64) { // HasRemediation checks if an IP has a specific remediation with a specific origin. // Returns true if the IP exists and has the given remediation with the given origin. -func (m *IPMap) HasRemediation(ip netip.Addr, r remediation.Remediation, origin string) bool { +func (m *IPMap) HasRemediation(ip netip.Addr, remediationName string, origin string) bool { // Select the appropriate map based on IP version ipMap := &m.ipv4 if ip.Is6() { @@ -295,5 +294,5 @@ func (m *IPMap) HasRemediation(ip netip.Addr, r remediation.Remediation, origin return false } - return data.HasRemediationWithOrigin(r, origin) + return data.HasRemediationWithOrigin(remediationName, origin) } diff --git a/pkg/dataset/root.go b/pkg/dataset/root.go index 770ee6e..c3ea77f 100644 --- a/pkg/dataset/root.go +++ b/pkg/dataset/root.go @@ -40,7 +40,7 @@ func (d *DataSet) Add(decisions models.GetDecisionsResponse) { type cnOp struct { cn string origin string - r remediation.Remediation + r string // Remediation name as string } // Separate operations by type: @@ -64,6 +64,7 @@ func (d *DataSet) Add(decisions models.GetDecisionsResponse) { scope := strings.ToLower(*decision.Scope) r := remediation.FromString(*decision.Type) + remediationName := r.String() // Convert to string for storage switch scope { case "ip": @@ -73,7 +74,7 @@ func (d *DataSet) Add(decisions models.GetDecisionsResponse) { continue } // Check for no-op: same IP, same remediation, same origin already exists - if d.IPMap.HasRemediation(ip, r, origin) { + if d.IPMap.HasRemediation(ip, remediationName, origin) { // Exact duplicate - skip processing (no-op) continue } @@ -82,13 +83,13 @@ func (d *DataSet) Add(decisions models.GetDecisionsResponse) { ipType = "ipv6" } // Check if we're overwriting an existing decision with different origin - if existingR, existingOrigin, found := d.IPMap.Contains(ip); found && existingR == r && existingOrigin != origin { + if existingR, existingOrigin, found := d.IPMap.Contains(ip); found && existingR == remediationName && existingOrigin != origin { // Decrement old origin's metric before incrementing new one // Label order: origin, ip_type, scope (as defined in metrics.go) metrics.TotalActiveDecisions.WithLabelValues(existingOrigin, ipType, "ip").Dec() } // Individual IPs go to IPMap for memory efficiency - ipOps = append(ipOps, IPAddOp{IP: ip, Origin: origin, R: r, IPType: ipType}) + ipOps = append(ipOps, IPAddOp{IP: ip, Origin: origin, R: remediationName, IPType: ipType}) // Label order: origin, ip_type, scope (as defined in metrics.go) metrics.TotalActiveDecisions.WithLabelValues(origin, ipType, "ip").Inc() case "range": @@ -98,7 +99,7 @@ func (d *DataSet) Add(decisions models.GetDecisionsResponse) { continue } // Check for no-op: same prefix, same remediation, same origin already exists - if d.RangeSet.HasRemediation(prefix, r, origin) { + if d.RangeSet.HasRemediation(prefix, remediationName, origin) { // Exact duplicate - skip processing (no-op) continue } @@ -107,30 +108,30 @@ func (d *DataSet) Add(decisions models.GetDecisionsResponse) { ipType = "ipv6" } // Check if we're overwriting an existing decision with different origin - if existingOrigin, found := d.RangeSet.GetOriginForRemediation(prefix, r); found && existingOrigin != origin { + if existingOrigin, found := d.RangeSet.GetOriginForRemediation(prefix, remediationName); found && existingOrigin != origin { // Decrement old origin's metric before incrementing new one // Label order: origin, ip_type, scope (as defined in metrics.go) metrics.TotalActiveDecisions.WithLabelValues(existingOrigin, ipType, "range").Dec() } // Ranges go to BART for LPM support - rangeOps = append(rangeOps, BartAddOp{Prefix: prefix, Origin: origin, R: r, IPType: ipType, Scope: "range"}) + rangeOps = append(rangeOps, BartAddOp{Prefix: prefix, Origin: origin, R: remediationName, IPType: ipType, Scope: "range"}) // Label order: origin, ip_type, scope (as defined in metrics.go) metrics.TotalActiveDecisions.WithLabelValues(origin, ipType, "range").Inc() case "country": // Clone country code to break reference to Decision struct memory cn := strings.Clone(*decision.Value) // Check for no-op: same country, same remediation, same origin already exists - if d.CNSet.HasRemediation(cn, r, origin) { + if d.CNSet.HasRemediation(cn, remediationName, origin) { // Exact duplicate - skip processing (no-op) continue } // Check if we're overwriting an existing decision with different origin - if existingR, existingOrigin := d.CNSet.Contains(cn); existingR == r && existingOrigin != "" && existingOrigin != origin { + if existingR, existingOrigin := d.CNSet.Contains(cn); existingR == remediationName && existingOrigin != "" && existingOrigin != origin { // Decrement old origin's metric before incrementing new one // Label order: origin, ip_type, scope (as defined in metrics.go) metrics.TotalActiveDecisions.WithLabelValues(existingOrigin, "", "country").Dec() } - cnOps = append(cnOps, cnOp{cn: cn, origin: origin, r: r}) + cnOps = append(cnOps, cnOp{cn: cn, origin: origin, r: remediationName}) default: log.Errorf("Unknown scope %s", *decision.Scope) } @@ -179,7 +180,7 @@ func (d *DataSet) Remove(decisions models.GetDecisionsResponse) { type cnOp struct { cn string - r remediation.Remediation + r string // Remediation name as string origin string } @@ -202,6 +203,7 @@ func (d *DataSet) Remove(decisions models.GetDecisionsResponse) { scope := strings.ToLower(*decision.Scope) r := remediation.FromString(*decision.Type) + remediationName := r.String() // Convert to string for storage switch scope { case "ip": @@ -214,7 +216,7 @@ func (d *DataSet) Remove(decisions models.GetDecisionsResponse) { if ip.Is6() { ipType = "ipv6" } - ipOps = append(ipOps, IPRemoveOp{IP: ip, R: r, Origin: origin, IPType: ipType}) + ipOps = append(ipOps, IPRemoveOp{IP: ip, R: remediationName, Origin: origin, IPType: ipType}) case "range": prefix, err := netip.ParsePrefix(*decision.Value) if err != nil { @@ -225,10 +227,10 @@ func (d *DataSet) Remove(decisions models.GetDecisionsResponse) { if prefix.Addr().Is6() { ipType = "ipv6" } - rangeOps = append(rangeOps, BartRemoveOp{Prefix: prefix, R: r, Origin: origin, IPType: ipType, Scope: "range"}) + rangeOps = append(rangeOps, BartRemoveOp{Prefix: prefix, R: remediationName, Origin: origin, IPType: ipType, Scope: "range"}) case "country": // Clone country code to break reference to Decision struct memory - cnOps = append(cnOps, cnOp{cn: strings.Clone(*decision.Value), r: r, origin: origin}) + cnOps = append(cnOps, cnOp{cn: strings.Clone(*decision.Value), r: remediationName, origin: origin}) default: log.Errorf("Unknown scope %s", *decision.Scope) } @@ -290,9 +292,9 @@ func (d *DataSet) Remove(decisions models.GetDecisionsResponse) { log.Infof("Finished processing %d deleted decisions", len(decisions)) } -func (d *DataSet) CheckIP(ip netip.Addr) (remediation.Remediation, string, error) { +func (d *DataSet) CheckIP(ip netip.Addr) (string, string, error) { if !ip.IsValid() { - return remediation.Allow, "", fmt.Errorf("invalid IP address") + return "allow", "", fmt.Errorf("invalid IP address") } // First check the IPMap for exact IP match (O(1) lookup) @@ -305,23 +307,23 @@ func (d *DataSet) CheckIP(ip netip.Addr) (remediation.Remediation, string, error return r, origin, nil } -func (d *DataSet) CheckCN(cn string) (remediation.Remediation, string) { +func (d *DataSet) CheckCN(cn string) (string, string) { return d.CNSet.Contains(cn) } // Helper method for CN operations (still needed for country scope) -func (d *DataSet) addCN(cn string, origin string, r remediation.Remediation) error { +func (d *DataSet) addCN(cn string, origin string, remediationName string) error { if cn == "" { return fmt.Errorf("empty CN") } - d.CNSet.Add(cn, origin, r) + d.CNSet.Add(cn, origin, remediationName) return nil } -func (d *DataSet) removeCN(cn string, r remediation.Remediation) (bool, error) { +func (d *DataSet) removeCN(cn string, remediationName string) (bool, error) { if cn == "" { return false, fmt.Errorf("empty CN") } - removed := d.CNSet.Remove(cn, r) + removed := d.CNSet.Remove(cn, remediationName) return removed, nil } diff --git a/pkg/dataset/types.go b/pkg/dataset/types.go index 4ba4f3e..7044aef 100644 --- a/pkg/dataset/types.go +++ b/pkg/dataset/types.go @@ -12,7 +12,7 @@ import ( // ErrRemediationNotFound is returned when attempting to remove a remediation that doesn't exist. var ErrRemediationNotFound = errors.New("remediation not found") -// RemediationMap stores one origin string per remediation type. +// RemediationMap stores one origin string per remediation type (using string keys). // ID is not tracked since LAPI behavior ensures we only have the longest decision. // // LAPI behavior: @@ -20,49 +20,56 @@ var ErrRemediationNotFound = errors.New("remediation not found") // - Stream: Only returns NEW decisions if they're LONGER than current // - Deletions: Delete means user wants to allow the IP - just remove the remediation entry. // Duplicate deletes are safely ignored (entry already gone). -type RemediationMap map[remediation.Remediation]string +// +// Keys are strings (remediation names) to support custom remediations. +// Weight comparison is done via remediation.GetWeight() when determining priority. +type RemediationMap map[string]string // Remove removes a remediation entry (deletion means user wants to allow the IP). // Returns ErrRemediationNotFound if the remediation doesn't exist (duplicate delete). -func (rM RemediationMap) Remove(clog *log.Entry, r remediation.Remediation) error { - _, ok := rM[r] +func (rM RemediationMap) Remove(clog *log.Entry, remediationName string) error { + _, ok := rM[remediationName] if !ok { // Remediation not found - duplicate delete if clog != nil && clog.Logger.IsLevelEnabled(log.TraceLevel) { - clog.Tracef("remediation %s not found, duplicate delete", r.String()) + clog.Tracef("remediation %s not found, duplicate delete", remediationName) } return ErrRemediationNotFound } if clog != nil && clog.Logger.IsLevelEnabled(log.TraceLevel) { - clog.Tracef("removing remediation %s", r.String()) + clog.Tracef("removing remediation %s", remediationName) } - delete(rM, r) + delete(rM, remediationName) return nil } // Add adds or updates a decision for the given remediation type. // If a decision already exists, it's overwritten (since only one decision per remediation+value). -func (rM RemediationMap) Add(clog *log.Entry, r remediation.Remediation, origin string) { +func (rM RemediationMap) Add(clog *log.Entry, remediationName string, origin string) { if clog != nil && clog.Logger.IsLevelEnabled(log.TraceLevel) { - if _, exists := rM[r]; exists { - clog.Tracef("remediation %s found, updating", r.String()) + if _, exists := rM[remediationName]; exists { + clog.Tracef("remediation %s found, updating", remediationName) } else { - clog.Tracef("remediation %s not found, creating", r.String()) + clog.Tracef("remediation %s not found, creating", remediationName) } } - rM[r] = origin + rM[remediationName] = origin } // GetRemediationAndOrigin returns the highest priority remediation and its origin. -func (rM RemediationMap) GetRemediationAndOrigin() (remediation.Remediation, string) { - var maxRemediation remediation.Remediation +// Priority is determined by comparing weights using remediation.GetWeight(). +func (rM RemediationMap) GetRemediationAndOrigin() (string, string) { + var maxRemediation string var maxOrigin string + var maxWeight int first := true - for k, v := range rM { - if first || k > maxRemediation { - maxRemediation = k - maxOrigin = v + for remediationName, origin := range rM { + weight := remediation.GetWeight(remediationName) + if first || weight > maxWeight { + maxRemediation = remediationName + maxOrigin = origin + maxWeight = weight first = false } } @@ -77,8 +84,8 @@ func (rM RemediationMap) IsEmpty() bool { // HasRemediationWithOrigin checks if a specific remediation exists with the given origin. // Returns true if the remediation exists and has the same origin. -func (rM RemediationMap) HasRemediationWithOrigin(r remediation.Remediation, origin string) bool { - existingOrigin, exists := rM[r] +func (rM RemediationMap) HasRemediationWithOrigin(remediationName string, origin string) bool { + existingOrigin, exists := rM[remediationName] return exists && existingOrigin == origin } @@ -123,14 +130,14 @@ func NewCNSet(logAlias string) *CNSet { return s } -func (s *CNSet) Add(cn string, origin string, r remediation.Remediation) { +func (s *CNSet) Add(cn string, origin string, remediationName string) { s.writeMu.Lock() defer s.writeMu.Unlock() // Only build logging fields if trace level is enabled var valueLog *log.Entry if s.logger.Logger.IsLevelEnabled(log.TraceLevel) { - valueLog = s.logger.WithField("value", cn).WithField("remediation", r.String()) + valueLog = s.logger.WithField("value", cn).WithField("remediation", remediationName) valueLog.Trace("adding") } @@ -154,27 +161,27 @@ func (s *CNSet) Add(cn string, origin string, r remediation.Remediation) { if valueLog != nil { valueLog.Trace("already exists") } - v.Add(valueLog, r, origin) + v.Add(valueLog, remediationName, origin) } else { if valueLog != nil { valueLog.Trace("not found, creating new entry") } newItems[cn] = make(RemediationMap) - newItems[cn].Add(valueLog, r, origin) + newItems[cn].Add(valueLog, remediationName, origin) } // Atomic swap - readers see old or new, never partial s.items.Store(&newItems) } -func (s *CNSet) Remove(cn string, r remediation.Remediation) bool { +func (s *CNSet) Remove(cn string, remediationName string) bool { s.writeMu.Lock() defer s.writeMu.Unlock() // Only build logging fields if trace level is enabled var valueLog *log.Entry if s.logger.Logger.IsLevelEnabled(log.TraceLevel) { - valueLog = s.logger.WithField("value", cn).WithField("remediation", r.String()) + valueLog = s.logger.WithField("value", cn).WithField("remediation", remediationName) } current := s.items.Load() @@ -206,7 +213,7 @@ func (s *CNSet) Remove(cn string, r remediation.Remediation) bool { // Modify the cloned entry // Remove returns an error if remediation doesn't exist (duplicate delete) - err := newItems[cn].Remove(valueLog, r) + err := newItems[cn].Remove(valueLog, remediationName) if errors.Is(err, ErrRemediationNotFound) { // Duplicate delete - remediation not found, nothing to remove if valueLog != nil { @@ -229,7 +236,7 @@ func (s *CNSet) Remove(cn string, r remediation.Remediation) bool { // Contains checks if a country code has a decision. // This method is completely lock-free - SPOA handlers never block. -func (s *CNSet) Contains(toCheck string) (remediation.Remediation, string) { +func (s *CNSet) Contains(toCheck string) (string, string) { // Only build logging fields if trace level is enabled var valueLog *log.Entry if s.logger.Logger.IsLevelEnabled(log.TraceLevel) { @@ -237,7 +244,7 @@ func (s *CNSet) Contains(toCheck string) (remediation.Remediation, string) { valueLog.Trace("checking value") } - r := remediation.Allow + remediationName := "allow" origin := "" // Lock-free read via atomic pointer @@ -247,18 +254,18 @@ func (s *CNSet) Contains(toCheck string) (remediation.Remediation, string) { if valueLog != nil { valueLog.Trace("found") } - r, origin = v.GetRemediationAndOrigin() + remediationName, origin = v.GetRemediationAndOrigin() } } if valueLog != nil { - valueLog.Tracef("remediation: %s", r.String()) + valueLog.Tracef("remediation: %s", remediationName) } - return r, origin + return remediationName, origin } // HasRemediation checks if a country code has a specific remediation with a specific origin. // Returns true if the country code exists and has the given remediation with the given origin. -func (s *CNSet) HasRemediation(cn string, r remediation.Remediation, origin string) bool { +func (s *CNSet) HasRemediation(cn string, remediationName string, origin string) bool { // Lock-free read via atomic pointer items := s.items.Load() if items == nil { @@ -266,7 +273,7 @@ func (s *CNSet) HasRemediation(cn string, r remediation.Remediation, origin stri } if v, ok := (*items)[cn]; ok { - return v.HasRemediationWithOrigin(r, origin) + return v.HasRemediationWithOrigin(remediationName, origin) } return false } diff --git a/pkg/spoa/root.go b/pkg/spoa/root.go index 9d73c90..137fb86 100644 --- a/pkg/spoa/root.go +++ b/pkg/spoa/root.go @@ -185,7 +185,7 @@ type HTTPRequestData struct { // First stage is to check the host header and determine if the remediation from handleIpRequest is still valid // Second stage is to check if AppSec is enabled and then forward to the component if needed func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { - r := remediation.Allow + r := "allow" // Default to allow var origin string shouldCountMetrics := false @@ -208,7 +208,7 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { } if rstring != nil { - r = remediation.FromString(*rstring) + r = *rstring // Use string directly // Remediation came from IP check, already counted } @@ -218,12 +218,11 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { // defer a function that always add the remediation to the request at end of processing defer func() { - if matchedHost == nil && r == remediation.Captcha { + if matchedHost == nil && r == "captcha" { s.logger.Warn("remediation is captcha, no matching host was found cannot issue captcha remediation reverting to ban") - r = remediation.Ban + r = "ban" } - rString := r.String() - req.Actions.SetVar(action.ScopeTransaction, "remediation", rString) + req.Actions.SetVar(action.ScopeTransaction, "remediation", r) // Count metrics if this is the only handler (upstream proxy mode) if shouldCountMetrics { @@ -242,10 +241,11 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { // Count processed request - use WithLabelValues to avoid map allocation on hot path metrics.TotalProcessedRequests.WithLabelValues(ipTypeLabel).Inc() - // Count blocked request if remediation applied - if r > remediation.Unknown { + // Count blocked request if remediation applied (check weight > Allow weight) + // This includes Unknown, Captcha, Ban, and any custom remediations + if remediation.GetWeight(r) > remediation.WeightAllow { // Label order: origin, ip_type, remediation (as defined in metrics.go) - metrics.TotalBlockedRequests.WithLabelValues(origin, ipTypeLabel, r.String()).Inc() + metrics.TotalBlockedRequests.WithLabelValues(origin, ipTypeLabel, r).Inc() } } }() @@ -268,7 +268,7 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { var httpData HTTPRequestData switch r { - case remediation.Allow: + case "allow": // If user has a captcha cookie but decision is Allow, generate unset cookie // We don't set captcha_status, so HAProxy knows to clear the cookie cookieB64, err := readKeyFromMessage[string](mes, "crowdsec_captcha_cookie") @@ -305,22 +305,23 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { } // Parse HTTP data for AppSec processing httpData = parseHTTPData(s.logger, mes) - case remediation.Ban: + case "ban": //Handle ban matchedHost.Ban.InjectKeyValues(&req.Actions) // Parse HTTP data for AppSec processing httpData = parseHTTPData(s.logger, mes) - case remediation.Captcha: + case "captcha": r, httpData = s.handleCaptchaRemediation(req, mes, matchedHost) // If remediation changed to fallback, return early // If it became Allow, continue for AppSec processing - if r != remediation.Captcha && r != remediation.Allow { + if r != "captcha" && r != "allow" { return } } - // If remediation is ban/captcha we dont need to create a request to send to appsec unless always send is on - if r > remediation.Unknown && !matchedHost.AppSec.AlwaysSend { + // If remediation is not allow, we dont need to create a request to send to appsec unless always send is on + // This includes Unknown, Captcha, Ban, and any custom remediations + if remediation.GetWeight(r) > remediation.WeightAllow && !matchedHost.AppSec.AlwaysSend { return } // !TODO APPSEC STUFF - httpData contains parsed URL, Method, Body, Headers for reuse @@ -432,9 +433,9 @@ func (s *Spoa) createNewSessionAndCookie(req *request.Request, mes *message.Mess // handleCaptchaRemediation handles all captcha-related logic including cookie validation, // session management, captcha validation, and status updates. // Returns the remediation and parsed HTTP request data for reuse in AppSec processing. -func (s *Spoa) handleCaptchaRemediation(req *request.Request, mes *message.Message, matchedHost *host.Host) (remediation.Remediation, HTTPRequestData) { +func (s *Spoa) handleCaptchaRemediation(req *request.Request, mes *message.Message, matchedHost *host.Host) (string, HTTPRequestData) { if err := matchedHost.Captcha.InjectKeyValues(&req.Actions); err != nil { - return remediation.FromString(matchedHost.Captcha.FallbackRemediation), HTTPRequestData{} + return matchedHost.Captcha.FallbackRemediation, HTTPRequestData{} } cookieB64, err := readKeyFromMessage[string](mes, "crowdsec_captcha_cookie") @@ -470,7 +471,7 @@ func (s *Spoa) handleCaptchaRemediation(req *request.Request, mes *message.Messa "host": matchedHost.Host, "error": err, }).Error("Failed to create new session and cookie, falling back to fallback remediation") - return remediation.FromString(matchedHost.Captcha.FallbackRemediation), HTTPRequestData{} + return matchedHost.Captcha.FallbackRemediation, HTTPRequestData{} } } @@ -478,7 +479,7 @@ func (s *Spoa) handleCaptchaRemediation(req *request.Request, mes *message.Messa // We should never hit this but safety net // As a fallback we set the remediation to the fallback remediation s.logger.Error("failed to get uuid from cookie") - return remediation.FromString(matchedHost.Captcha.FallbackRemediation), HTTPRequestData{} + return matchedHost.Captcha.FallbackRemediation, HTTPRequestData{} } // Get the session only if we didn't just create it (i.e., we have an existing cookie) @@ -499,7 +500,7 @@ func (s *Spoa) handleCaptchaRemediation(req *request.Request, mes *message.Messa "host": matchedHost.Host, "error": err, }).Error("Failed to create new session after reload, falling back to fallback remediation") - return remediation.FromString(matchedHost.Captcha.FallbackRemediation), HTTPRequestData{} + return matchedHost.Captcha.FallbackRemediation, HTTPRequestData{} } } } @@ -543,7 +544,7 @@ func (s *Spoa) handleCaptchaRemediation(req *request.Request, mes *message.Messa "key": "method", "host": matchedHost.Host, }).Error("failed to read method from message, cannot validate captcha form submission - ensure HAProxy is sending the 'method' variable in crowdsec-http message") - return remediation.Captcha, HTTPRequestData{URL: url} // Return partial data + return "captcha", HTTPRequestData{URL: url} // Return partial data } headersType, err := readKeyFromMessage[string](mes, "headers") @@ -553,7 +554,7 @@ func (s *Spoa) handleCaptchaRemediation(req *request.Request, mes *message.Messa "key": "headers", "host": matchedHost.Host, }).Error("failed to read headers from message, cannot validate captcha form submission - ensure HAProxy is sending the 'headers' variable in crowdsec-http message") - return remediation.Captcha, HTTPRequestData{URL: url, Method: method} // Return partial data + return "captcha", HTTPRequestData{URL: url, Method: method} // Return partial data } headers, err := readHeaders(*headersType) @@ -577,7 +578,7 @@ func (s *Spoa) handleCaptchaRemediation(req *request.Request, mes *message.Messa "host": matchedHost.Host, "session": uuid, }).Error("failed to read body from message, cannot validate captcha response - ensure HAProxy is sending the 'body' variable in crowdsec-http message for POST requests") - return remediation.Captcha, httpData // Return data without body + return "captcha", httpData // Return data without body } httpData.Body = body @@ -605,15 +606,15 @@ func (s *Spoa) handleCaptchaRemediation(req *request.Request, mes *message.Messa // Delete the URI from the session so we dont redirect loop ses.Delete(session.URI) } - return remediation.Allow, httpData + return "allow", httpData } - return remediation.Captcha, httpData + return "captcha", httpData } // getIPRemediation performs IP and geo/country remediation checks // Returns the final remediation after checking IP, geo, and country -func (s *Spoa) getIPRemediation(req *request.Request, ip netip.Addr) (remediation.Remediation, string) { +func (s *Spoa) getIPRemediation(req *request.Request, ip netip.Addr) (string, string) { var origin string // Check IP directly against dataset r, origin, err := s.dataset.CheckIP(ip) @@ -622,7 +623,7 @@ func (s *Spoa) getIPRemediation(req *request.Request, ip netip.Addr) (remediatio "ip": ip.String(), "error": err, }).Error("Failed to get IP remediation") - return remediation.Allow, "" // Safe default + return "allow", "" // Safe default } // Always try to get and set ISO code if geo database is available @@ -640,10 +641,11 @@ func (s *Spoa) getIPRemediation(req *request.Request, ip netip.Addr) (remediatio // Always set the ISO code variable when available req.Actions.SetVar(action.ScopeTransaction, "isocode", iso) - // If no IP-specific remediation, check country-based remediation - if r < remediation.Unknown { + // If no IP-specific remediation (Allow), check country-based remediation + // Compare weights instead of direct comparison + if remediation.GetWeight(r) == remediation.WeightAllow { cnR, cnOrigin := s.dataset.CheckCN(iso) - if cnR > remediation.Unknown { + if remediation.GetWeight(cnR) > remediation.WeightAllow { r = cnR origin = cnOrigin } @@ -680,13 +682,14 @@ func (s *Spoa) handleIPRequest(req *request.Request, mes *message.Message) { // Check IP directly against dataset r, origin := s.getIPRemediation(req, ipAddr) - // Count blocked requests - if r > remediation.Unknown { + // Count blocked requests (check weight > Allow weight) + // This includes Unknown, Captcha, Ban, and any custom remediations + if remediation.GetWeight(r) > remediation.WeightAllow { // Label order: origin, ip_type, remediation (as defined in metrics.go) - metrics.TotalBlockedRequests.WithLabelValues(origin, ipTypeLabel, r.String()).Inc() + metrics.TotalBlockedRequests.WithLabelValues(origin, ipTypeLabel, r).Inc() } - req.Actions.SetVar(action.ScopeTransaction, "remediation", r.String()) + req.Actions.SetVar(action.ScopeTransaction, "remediation", r) } func handlerWrapper(s *Spoa) func(req *request.Request) { From 093fb806d08170e55a83468a4b7e3ab6fe92cbe3 Mon Sep 17 00:00:00 2001 From: Laurence Date: Thu, 4 Dec 2025 15:55:53 +0000 Subject: [PATCH 02/11] refactor: use remediation.Remediation as map keys for automatic deduplication Change RemediationMap from map[string]string to map[remediation.Remediation]string. This leverages the deduplicated string pointers in Remediation structs to automatically reduce memory allocations when storing remediations in maps. Benefits: - Automatic string deduplication via shared *string pointers - Reduced memory allocations in RemediationMap - Type-safe map keys - Efficient pointer-based comparisons Updated all methods (Add, Remove, HasRemediationWithOrigin, GetRemediationAndOrigin) and Contains methods to use remediation.Remediation types directly. --- internal/remediation/root.go | 68 +++++++++++++++++++++++++------ pkg/dataset/bart_types.go | 34 ++++++++-------- pkg/dataset/benchmark_test.go | 28 ++++++------- pkg/dataset/ipmap.go | 26 ++++++------ pkg/dataset/metrics_test.go | 12 +++--- pkg/dataset/root.go | 45 +++++++++++---------- pkg/dataset/root_test.go | 4 +- pkg/dataset/types.go | 76 +++++++++++++++++------------------ pkg/spoa/root.go | 59 +++++++++++++-------------- 9 files changed, 198 insertions(+), 154 deletions(-) diff --git a/internal/remediation/root.go b/internal/remediation/root.go index a8c1e58..5e83f93 100644 --- a/internal/remediation/root.go +++ b/internal/remediation/root.go @@ -32,21 +32,45 @@ var globalRegistry = ®istry{ weights: make(map[string]int), } +// Built-in remediation constants (for convenience) +// Initialized to nil, will be set in init() +var ( + Allow Remediation + Unknown Remediation + Captcha Remediation + Ban Remediation +) + +//nolint:gochecknoinits // init() is required to initialize package-level vars after weights are set func init() { // Initialize built-in remediations with default weights globalRegistry.mu.Lock() defer globalRegistry.mu.Unlock() + // Set weights FIRST before creating strings globalRegistry.weights["allow"] = WeightAllow globalRegistry.weights["unknown"] = WeightUnknown globalRegistry.weights["captcha"] = WeightCaptcha globalRegistry.weights["ban"] = WeightBan // Pre-create deduplicated strings for built-in remediations - for name := range globalRegistry.weights { - deduped := name - globalRegistry.strings[name] = &deduped - } + // Must create new string variables for each to avoid pointer aliasing + allowStr := "allow" + unknownStr := "unknown" + captchaStr := "captcha" + banStr := "ban" + + globalRegistry.strings["allow"] = &allowStr + globalRegistry.strings["unknown"] = &unknownStr + globalRegistry.strings["captcha"] = &captchaStr + globalRegistry.strings["ban"] = &banStr + + // Now initialize the package-level vars directly (we already hold the lock) + // This avoids deadlock since New() would try to acquire the lock again + Allow = Remediation{name: &allowStr, weight: WeightAllow} + Unknown = Remediation{name: &unknownStr, weight: WeightUnknown} + Captcha = Remediation{name: &captchaStr, weight: WeightCaptcha} + Ban = Remediation{name: &banStr, weight: WeightBan} } // SetWeight sets a custom weight for a remediation (for configuration) @@ -74,14 +98,6 @@ func GetWeight(name string) int { return WeightUnknown } -// Built-in remediation constants (for convenience) -var ( - Allow = New("allow") - Unknown = New("unknown") - Captcha = New("captcha") - Ban = New("ban") -) - // New creates a new Remediation from a string // Uses deduplicated string pointers to reduce allocations func New(name string) Remediation { @@ -100,7 +116,12 @@ func New(name string) Remediation { } } - weight := globalRegistry.weights[name] + // Read weight from registry (may have been set in init() or SetWeight()) + weight, ok := globalRegistry.weights[name] + if !ok { + // Weight not found, default to Unknown + weight = WeightUnknown + } return Remediation{ name: deduped, weight: weight, @@ -128,6 +149,27 @@ func (r Remediation) Compare(other Remediation) int { return r.weight - other.weight } +// IsHigher returns true if r has a higher weight than other +func (r Remediation) IsHigher(other Remediation) bool { + return r.weight > other.weight +} + +// IsLower returns true if r has a lower weight than other +func (r Remediation) IsLower(other Remediation) bool { + return r.weight < other.weight +} + +// IsEqual returns true if r has the same weight as other +func (r Remediation) IsEqual(other Remediation) bool { + return r.weight == other.weight +} + +// IsWeighted returns true if r is not Allow (has weight > Allow) +// This is useful for checking if a remediation should be applied +func (r Remediation) IsWeighted() bool { + return r.weight > WeightAllow +} + // FromString creates a Remediation from a string (alias for New for backward compatibility) func FromString(s string) Remediation { return New(s) diff --git a/pkg/dataset/bart_types.go b/pkg/dataset/bart_types.go index 46d6cf7..81b82da 100644 --- a/pkg/dataset/bart_types.go +++ b/pkg/dataset/bart_types.go @@ -6,6 +6,7 @@ import ( "sync" "sync/atomic" + "github.com/crowdsecurity/crowdsec-spoa/internal/remediation" "github.com/gaissmai/bart" log "github.com/sirupsen/logrus" ) @@ -100,7 +101,7 @@ func (s *BartRangeSet) initializeBatch(operations []BartAddOp) { data = RemediationMap{} } // Add the remediation (this handles merging if prefix already seen) - data.Add(valueLog, op.R, op.Origin) + data.Add(valueLog, remediation.FromString(op.R), op.Origin) prefixMap[prefix] = data } @@ -145,7 +146,7 @@ func (s *BartRangeSet) updateBatch(cur *bart.Table[RemediationMap], operations [ valueLog.Trace("exact prefix exists, merging remediations") } // bart already cloned via our Cloner interface, modify directly - existingData.Add(valueLog, op.R, op.Origin) + existingData.Add(valueLog, remediation.FromString(op.R), op.Origin) return existingData, false // false = don't delete } if valueLog != nil { @@ -153,7 +154,7 @@ func (s *BartRangeSet) updateBatch(cur *bart.Table[RemediationMap], operations [ } // Create new data newData := make(RemediationMap) - newData.Add(valueLog, op.R, op.Origin) + newData.Add(valueLog, remediation.FromString(op.R), op.Origin) return newData, false // false = don't delete }) } @@ -209,11 +210,12 @@ func (s *BartRangeSet) RemoveBatch(operations []BartRemoveOp) []*BartRemoveOp { // Check if the remediation exists with the matching origin before removing // This prevents removing decisions when the origin has been overwritten (e.g., by CAPI) - if !existingData.HasRemediationWithOrigin(op.R, op.Origin) { + if !existingData.HasRemediationWithOrigin(remediation.FromString(op.R), op.Origin) { // Origin doesn't match - this decision was likely overwritten by another origin // Don't remove it, as it's not the decision we're trying to delete if valueLog != nil { - storedOrigin, exists := existingData[op.R] + r := remediation.FromString(op.R) + storedOrigin, exists := existingData[r] if exists { valueLog.Tracef("remediation exists but origin mismatch (stored: %s, requested: %s), skipping removal", storedOrigin, op.Origin) } else { @@ -227,7 +229,7 @@ func (s *BartRangeSet) RemoveBatch(operations []BartRemoveOp) []*BartRemoveOp { // bart already cloned via our Cloner interface, modify directly // Remove returns an error if remediation doesn't exist (duplicate delete) // We already checked origin above, so this should succeed - err := existingData.Remove(valueLog, op.R) + err := existingData.Remove(valueLog, remediation.FromString(op.R)) if errors.Is(err, ErrRemediationNotFound) { // This shouldn't happen since we checked above, but handle it gracefully if valueLog != nil { @@ -262,13 +264,13 @@ func (s *BartRangeSet) RemoveBatch(operations []BartRemoveOp) []*BartRemoveOp { // Contains checks if an IP address matches any prefix in the bart table. // Returns the longest matching prefix's remediation and origin. // This method uses lock-free reads via atomic pointer for optimal performance. -func (s *BartRangeSet) Contains(ip netip.Addr) (string, string) { +func (s *BartRangeSet) Contains(ip netip.Addr) (remediation.Remediation, string) { // Lock-free read: atomically load the current table pointer table := s.tableAtomicPtr.Load() // Check for nil table (not yet initialized) if table == nil { - return "allow", "" + return remediation.Allow, "" } // Only build logging fields if trace level is enabled @@ -284,21 +286,21 @@ func (s *BartRangeSet) Contains(ip netip.Addr) (string, string) { if valueLog != nil { valueLog.Trace("no match found") } - return "allow", "" + return remediation.Allow, "" } - remediationResult, origin := data.GetRemediationAndOrigin() + r, origin := data.GetRemediationAndOrigin() if valueLog != nil { - valueLog.Tracef("bart result: %s (data: %+v)", remediationResult, data) + valueLog.Tracef("bart result: %s (data: %+v)", r.String(), data) } - return remediationResult, origin + return r, origin } // HasRemediation checks if an exact prefix has a specific remediation with a specific origin. // Uses Get() for exact prefix lookup (not LPM like Contains/Lookup). // Returns true if the exact prefix exists and has the given remediation with the given origin. // This method uses lock-free reads via atomic pointer for optimal performance. -func (s *BartRangeSet) HasRemediation(prefix netip.Prefix, remediationName string, origin string) bool { +func (s *BartRangeSet) HasRemediation(prefix netip.Prefix, r remediation.Remediation, origin string) bool { // Lock-free read: atomically load the current table pointer table := s.tableAtomicPtr.Load() @@ -314,14 +316,14 @@ func (s *BartRangeSet) HasRemediation(prefix netip.Prefix, remediationName strin return false } - return data.HasRemediationWithOrigin(remediationName, origin) + return data.HasRemediationWithOrigin(r, origin) } // GetOriginForRemediation returns the origin for a specific remediation on an exact prefix. // Uses Get() for exact prefix lookup (not LPM). // Returns the origin and true if the exact prefix exists and has the given remediation, false otherwise. // This method uses lock-free reads via atomic pointer for optimal performance. -func (s *BartRangeSet) GetOriginForRemediation(prefix netip.Prefix, remediationName string) (string, bool) { +func (s *BartRangeSet) GetOriginForRemediation(prefix netip.Prefix, r remediation.Remediation) (string, bool) { // Lock-free read: atomically load the current table pointer table := s.tableAtomicPtr.Load() @@ -338,7 +340,7 @@ func (s *BartRangeSet) GetOriginForRemediation(prefix netip.Prefix, remediationN } // Check if the remediation exists and return its origin - if existingOrigin, ok := data[remediationName]; ok { + if existingOrigin, ok := data[r]; ok { return existingOrigin, true } diff --git a/pkg/dataset/benchmark_test.go b/pkg/dataset/benchmark_test.go index 6ed3a43..a6e63b1 100644 --- a/pkg/dataset/benchmark_test.go +++ b/pkg/dataset/benchmark_test.go @@ -174,13 +174,13 @@ func TestCorrectness(t *testing.T) { continue } - // Basic sanity check - result should be valid - if result < remediation.Allow { + // Basic sanity check - result should be valid (not less than Allow weight) + if result.Compare(remediation.Allow) < 0 { t.Errorf("Invalid result for IP %s: %v", testIP.String(), result) } - // Origin should be non-empty if we have a match - if result > remediation.Allow && origin == "" { + // Origin should be non-empty if we have a match (result > Allow) + if result.Compare(remediation.Allow) > 0 && origin == "" { t.Errorf("Empty origin for IP %s with result %v", testIP.String(), result) } } @@ -192,38 +192,38 @@ func TestLongestPrefixMatch(t *testing.T) { // Add individual IP to IPMap and ranges to RangeSet dataset.IPMap.AddBatch([]IPAddOp{ - {IP: netip.MustParseAddr("192.168.1.1"), Origin: "test", R: remediation.Allow, IPType: "ipv4"}, + {IP: netip.MustParseAddr("192.168.1.1"), Origin: "test", R: remediation.Allow.String(), IPType: "ipv4"}, }) dataset.RangeSet.AddBatch([]BartAddOp{ - {Prefix: netip.MustParsePrefix("192.168.0.0/16"), Origin: "test", R: remediation.Ban, IPType: "ipv4", Scope: "range"}, - {Prefix: netip.MustParsePrefix("192.168.1.0/24"), Origin: "test", R: remediation.Captcha, IPType: "ipv4", Scope: "range"}, + {Prefix: netip.MustParsePrefix("192.168.0.0/16"), Origin: "test", R: remediation.Ban.String(), IPType: "ipv4", Scope: "range"}, + {Prefix: netip.MustParsePrefix("192.168.1.0/24"), Origin: "test", R: remediation.Captcha.String(), IPType: "ipv4", Scope: "range"}, }) // Test that individual IP from IPMap wins (checked first before RangeSet) ip1 := netip.MustParseAddr("192.168.1.1") result, _, _ := dataset.CheckIP(ip1) - if result != remediation.Allow { + if !result.IsEqual(remediation.Allow) { t.Errorf("Expected Allow for 192.168.1.1 (from IPMap), got %v", result) } // Test that we get the LPM from RangeSet (Captcha /24 wins over Ban /16) ip2 := netip.MustParseAddr("192.168.1.2") result, _, _ = dataset.CheckIP(ip2) - if result != remediation.Captcha { + if !result.IsEqual(remediation.Captcha) { t.Errorf("Expected Captcha for 192.168.1.2 (LPM from RangeSet), got %v", result) } // Test that we get the broadest match from RangeSet ip3 := netip.MustParseAddr("192.168.2.1") result, _, _ = dataset.CheckIP(ip3) - if result != remediation.Ban { + if !result.IsEqual(remediation.Ban) { t.Errorf("Expected Ban for 192.168.2.1 (from RangeSet), got %v", result) } // Test that we get no match ip4 := netip.MustParseAddr("10.0.0.1") result, _, _ = dataset.CheckIP(ip4) - if result != remediation.Allow { + if !result.IsEqual(remediation.Allow) { t.Errorf("Expected Allow for 10.0.0.1 (no match), got %v", result) } } @@ -354,7 +354,7 @@ func BenchmarkHybridVsBartOnly(b *testing.B) { ops = append(ops, BartAddOp{ Prefix: netip.PrefixFrom(ip, prefixLen), Origin: *d.Origin, - R: remediation.Ban, + R: remediation.Ban.String(), IPType: "ipv4", Scope: "ip", }) @@ -384,13 +384,13 @@ func BenchmarkLookupHybrid(b *testing.B) { byte(i % 256), }) dataset.IPMap.AddBatch([]IPAddOp{ - {IP: ip, Origin: "test", R: remediation.Ban, IPType: "ipv4"}, + {IP: ip, Origin: "test", R: remediation.Ban.String(), IPType: "ipv4"}, }) } // Add some ranges to RangeSet dataset.RangeSet.AddBatch([]BartAddOp{ - {Prefix: netip.MustParsePrefix("192.168.0.0/16"), Origin: "test", R: remediation.Ban, IPType: "ipv4", Scope: "range"}, + {Prefix: netip.MustParsePrefix("192.168.0.0/16"), Origin: "test", R: remediation.Ban.String(), IPType: "ipv4", Scope: "range"}, }) // Test IPs - some in IPMap, some in RangeSet, some not found diff --git a/pkg/dataset/ipmap.go b/pkg/dataset/ipmap.go index 949f074..747727f 100644 --- a/pkg/dataset/ipmap.go +++ b/pkg/dataset/ipmap.go @@ -6,6 +6,7 @@ import ( "sync" "sync/atomic" + "github.com/crowdsecurity/crowdsec-spoa/internal/remediation" log "github.com/sirupsen/logrus" ) @@ -101,7 +102,7 @@ func (m *IPMap) add(op IPAddOp) { // Empty or nil - no need to clone, just create new map newData = make(RemediationMap) } - newData.Add(valueLog, op.R, op.Origin) + newData.Add(valueLog, remediation.FromString(op.R), op.Origin) entry.data.Store(&newData) return } @@ -110,7 +111,7 @@ func (m *IPMap) add(op IPAddOp) { // Create new entry with data // Store directly (no LoadOrStore race needed since application uses single writer) newData := make(RemediationMap) - newData.Add(valueLog, op.R, op.Origin) + newData.Add(valueLog, remediation.FromString(op.R), op.Origin) entry := &ipEntry{} entry.data.Store(&newData) ipMap.Store(op.IP, entry) @@ -173,11 +174,12 @@ func (m *IPMap) remove(op IPRemoveOp) bool { // Check if the remediation exists with the matching origin before removing // This prevents removing decisions when the origin has been overwritten (e.g., by CAPI) - if !current.HasRemediationWithOrigin(op.R, op.Origin) { + if !current.HasRemediationWithOrigin(remediation.FromString(op.R), op.Origin) { // Origin doesn't match - this decision was likely overwritten by another origin // Don't remove it, as it's not the decision we're trying to delete if valueLog != nil { - storedOrigin, exists := (*current)[op.R] + r := remediation.FromString(op.R) + storedOrigin, exists := (*current)[r] if exists { valueLog.Tracef("remediation exists but origin mismatch (stored: %s, requested: %s), skipping removal", storedOrigin, op.Origin) } else { @@ -194,7 +196,7 @@ func (m *IPMap) remove(op IPRemoveOp) bool { // Remove returns an error if remediation doesn't exist (duplicate delete) // We already checked origin above, so this should succeed - err := newData.Remove(valueLog, op.R) + err := newData.Remove(valueLog, remediation.FromString(op.R)) if errors.Is(err, ErrRemediationNotFound) { // This shouldn't happen since we checked above, but handle it gracefully if valueLog != nil { @@ -225,7 +227,7 @@ func (m *IPMap) remove(op IPRemoveOp) bool { // Contains checks if an IP address exists in the map // Returns the remediation and origin if found // This method is completely lock-free - SPOA handlers never block -func (m *IPMap) Contains(ip netip.Addr) (string, string, bool) { +func (m *IPMap) Contains(ip netip.Addr) (remediation.Remediation, string, bool) { var valueLog *log.Entry if m.logger.Logger.IsLevelEnabled(log.TraceLevel) { valueLog = m.logger.WithField("ip", ip.String()) @@ -243,23 +245,23 @@ func (m *IPMap) Contains(ip netip.Addr) (string, string, bool) { if valueLog != nil { valueLog.Trace("IP not found in map") } - return "allow", "", false + return remediation.Allow, "", false } entry, ok := existing.(*ipEntry) if !ok { - return "allow", "", false + return remediation.Allow, "", false } // Lock-free read via atomic pointer data := entry.data.Load() if data == nil { - return "allow", "", false + return remediation.Allow, "", false } r, origin := data.GetRemediationAndOrigin() if valueLog != nil { - valueLog.Tracef("found IP with remediation: %s", r) + valueLog.Tracef("found IP with remediation: %s", r.String()) } return r, origin, true } @@ -271,7 +273,7 @@ func (m *IPMap) Count() (ipv4 int64, ipv6 int64) { // HasRemediation checks if an IP has a specific remediation with a specific origin. // Returns true if the IP exists and has the given remediation with the given origin. -func (m *IPMap) HasRemediation(ip netip.Addr, remediationName string, origin string) bool { +func (m *IPMap) HasRemediation(ip netip.Addr, r remediation.Remediation, origin string) bool { // Select the appropriate map based on IP version ipMap := &m.ipv4 if ip.Is6() { @@ -294,5 +296,5 @@ func (m *IPMap) HasRemediation(ip netip.Addr, remediationName string, origin str return false } - return data.HasRemediationWithOrigin(remediationName, origin) + return data.HasRemediationWithOrigin(r, origin) } diff --git a/pkg/dataset/metrics_test.go b/pkg/dataset/metrics_test.go index 2885ef9..fe81b15 100644 --- a/pkg/dataset/metrics_test.go +++ b/pkg/dataset/metrics_test.go @@ -746,7 +746,7 @@ func TestMetrics_NoOp_DuplicateDecisions(t *testing.T) { r, foundOrigin, found := dataSet.IPMap.Contains(ip) assert.True(t, found, "IP should still exist") assert.Equal(t, origin, foundOrigin, "origin should match") - assert.Equal(t, remediation.Ban, r, "remediation should be ban") + assert.Equal(t, remediation.Ban.String(), r, "remediation should be ban") }) t.Run("Duplicate range decision is no-op", func(t *testing.T) { @@ -775,7 +775,7 @@ func TestMetrics_NoOp_DuplicateDecisions(t *testing.T) { require.NoError(t, err) r, foundOrigin := dataSet.RangeSet.Contains(testIP) assert.Equal(t, origin, foundOrigin, "origin should match") - assert.Equal(t, remediation.Ban, r, "remediation should be ban") + assert.Equal(t, remediation.Ban.String(), r, "remediation should be ban") }) t.Run("Duplicate country decision is no-op", func(t *testing.T) { @@ -802,7 +802,7 @@ func TestMetrics_NoOp_DuplicateDecisions(t *testing.T) { // Verify decision still exists r, foundOrigin := dataSet.CNSet.Contains("US") assert.Equal(t, origin, foundOrigin, "origin should match") - assert.Equal(t, remediation.Ban, r, "remediation should be ban") + assert.Equal(t, remediation.Ban.String(), r, "remediation should be ban") }) t.Run("Same IP different remediation is not no-op", func(t *testing.T) { @@ -866,7 +866,7 @@ func TestMetrics_OriginOverwriteAndDelete(t *testing.T) { require.NoError(t, err) r, storedOrigin, found := dataSet.IPMap.Contains(ip) assert.True(t, found, "IP should exist") - assert.Equal(t, remediation.Ban, r, "IP should have ban remediation") + assert.True(t, r.IsEqual(remediation.Ban), "IP should have ban remediation") assert.Equal(t, origin, storedOrigin, "IP should have CAPI origin") // Step 2: Add captcha from same CAPI origin (overwrites ban) @@ -890,7 +890,7 @@ func TestMetrics_OriginOverwriteAndDelete(t *testing.T) { // Both ban and captcha exist in the map, but ban is returned as highest priority r2, storedOrigin2, found2 := dataSet.IPMap.Contains(ip) assert.True(t, found2, "IP should still exist") - assert.Equal(t, remediation.Ban, r2, "IP should have ban remediation (highest priority, ban > captcha)") + assert.True(t, r2.IsEqual(remediation.Ban), "IP should have ban remediation (highest priority, ban > captcha)") assert.Equal(t, origin, storedOrigin2, "IP should still have CAPI origin") // Step 3: Delete ban from CAPI @@ -913,7 +913,7 @@ func TestMetrics_OriginOverwriteAndDelete(t *testing.T) { // Verify IP now has captcha (ban was removed, captcha is now active) r3, storedOrigin3, found3 := dataSet.IPMap.Contains(ip) assert.True(t, found3, "IP should still exist") - assert.Equal(t, remediation.Captcha, r3, "IP should now have captcha remediation (ban was removed)") + assert.True(t, r3.IsEqual(remediation.Captcha), "IP should now have captcha remediation (ban was removed)") assert.Equal(t, origin, storedOrigin3, "IP should still have CAPI origin") }) } diff --git a/pkg/dataset/root.go b/pkg/dataset/root.go index c3ea77f..f85add3 100644 --- a/pkg/dataset/root.go +++ b/pkg/dataset/root.go @@ -40,7 +40,7 @@ func (d *DataSet) Add(decisions models.GetDecisionsResponse) { type cnOp struct { cn string origin string - r string // Remediation name as string + r remediation.Remediation } // Separate operations by type: @@ -64,7 +64,7 @@ func (d *DataSet) Add(decisions models.GetDecisionsResponse) { scope := strings.ToLower(*decision.Scope) r := remediation.FromString(*decision.Type) - remediationName := r.String() // Convert to string for storage + remediationName := r.String() // Convert to string for operations structs switch scope { case "ip": @@ -74,7 +74,7 @@ func (d *DataSet) Add(decisions models.GetDecisionsResponse) { continue } // Check for no-op: same IP, same remediation, same origin already exists - if d.IPMap.HasRemediation(ip, remediationName, origin) { + if d.IPMap.HasRemediation(ip, r, origin) { // Exact duplicate - skip processing (no-op) continue } @@ -83,7 +83,7 @@ func (d *DataSet) Add(decisions models.GetDecisionsResponse) { ipType = "ipv6" } // Check if we're overwriting an existing decision with different origin - if existingR, existingOrigin, found := d.IPMap.Contains(ip); found && existingR == remediationName && existingOrigin != origin { + if existingR, existingOrigin, found := d.IPMap.Contains(ip); found && existingR.IsEqual(r) && existingOrigin != origin { // Decrement old origin's metric before incrementing new one // Label order: origin, ip_type, scope (as defined in metrics.go) metrics.TotalActiveDecisions.WithLabelValues(existingOrigin, ipType, "ip").Dec() @@ -99,7 +99,7 @@ func (d *DataSet) Add(decisions models.GetDecisionsResponse) { continue } // Check for no-op: same prefix, same remediation, same origin already exists - if d.RangeSet.HasRemediation(prefix, remediationName, origin) { + if d.RangeSet.HasRemediation(prefix, r, origin) { // Exact duplicate - skip processing (no-op) continue } @@ -108,7 +108,7 @@ func (d *DataSet) Add(decisions models.GetDecisionsResponse) { ipType = "ipv6" } // Check if we're overwriting an existing decision with different origin - if existingOrigin, found := d.RangeSet.GetOriginForRemediation(prefix, remediationName); found && existingOrigin != origin { + if existingOrigin, found := d.RangeSet.GetOriginForRemediation(prefix, r); found && existingOrigin != origin { // Decrement old origin's metric before incrementing new one // Label order: origin, ip_type, scope (as defined in metrics.go) metrics.TotalActiveDecisions.WithLabelValues(existingOrigin, ipType, "range").Dec() @@ -121,17 +121,17 @@ func (d *DataSet) Add(decisions models.GetDecisionsResponse) { // Clone country code to break reference to Decision struct memory cn := strings.Clone(*decision.Value) // Check for no-op: same country, same remediation, same origin already exists - if d.CNSet.HasRemediation(cn, remediationName, origin) { + if d.CNSet.HasRemediation(cn, r, origin) { // Exact duplicate - skip processing (no-op) continue } // Check if we're overwriting an existing decision with different origin - if existingR, existingOrigin := d.CNSet.Contains(cn); existingR == remediationName && existingOrigin != "" && existingOrigin != origin { + if existingR, existingOrigin := d.CNSet.Contains(cn); existingR.IsEqual(r) && existingOrigin != "" && existingOrigin != origin { // Decrement old origin's metric before incrementing new one // Label order: origin, ip_type, scope (as defined in metrics.go) metrics.TotalActiveDecisions.WithLabelValues(existingOrigin, "", "country").Dec() } - cnOps = append(cnOps, cnOp{cn: cn, origin: origin, r: remediationName}) + cnOps = append(cnOps, cnOp{cn: cn, origin: origin, r: r}) default: log.Errorf("Unknown scope %s", *decision.Scope) } @@ -178,15 +178,15 @@ func (d *DataSet) Remove(decisions models.GetDecisionsResponse) { } log.Infof("Processing %d deleted decisions", len(decisions)) + // Separate operations by type + // Note: We don't pre-allocate capacity here because many decisions might be no-ops + // (duplicates) and would waste allocated memory. Let Go handle dynamic growth. type cnOp struct { cn string - r string // Remediation name as string origin string + r remediation.Remediation } - // Separate operations by type - // Note: We don't pre-allocate capacity here because many decisions might be no-ops - // (duplicates) and would waste allocated memory. Let Go handle dynamic growth. ipOps := make([]IPRemoveOp, 0) rangeOps := make([]BartRemoveOp, 0) cnOps := make([]cnOp, 0) @@ -230,7 +230,7 @@ func (d *DataSet) Remove(decisions models.GetDecisionsResponse) { rangeOps = append(rangeOps, BartRemoveOp{Prefix: prefix, R: remediationName, Origin: origin, IPType: ipType, Scope: "range"}) case "country": // Clone country code to break reference to Decision struct memory - cnOps = append(cnOps, cnOp{cn: strings.Clone(*decision.Value), r: remediationName, origin: origin}) + cnOps = append(cnOps, cnOp{cn: strings.Clone(*decision.Value), r: r, origin: origin}) default: log.Errorf("Unknown scope %s", *decision.Scope) } @@ -292,9 +292,9 @@ func (d *DataSet) Remove(decisions models.GetDecisionsResponse) { log.Infof("Finished processing %d deleted decisions", len(decisions)) } -func (d *DataSet) CheckIP(ip netip.Addr) (string, string, error) { +func (d *DataSet) CheckIP(ip netip.Addr) (remediation.Remediation, string, error) { if !ip.IsValid() { - return "allow", "", fmt.Errorf("invalid IP address") + return remediation.Allow, "", fmt.Errorf("invalid IP address") } // First check the IPMap for exact IP match (O(1) lookup) @@ -307,23 +307,24 @@ func (d *DataSet) CheckIP(ip netip.Addr) (string, string, error) { return r, origin, nil } -func (d *DataSet) CheckCN(cn string) (string, string) { - return d.CNSet.Contains(cn) +func (d *DataSet) CheckCN(cn string) (remediation.Remediation, string) { + r, origin := d.CNSet.Contains(cn) + return r, origin } // Helper method for CN operations (still needed for country scope) -func (d *DataSet) addCN(cn string, origin string, remediationName string) error { +func (d *DataSet) addCN(cn string, origin string, r remediation.Remediation) error { if cn == "" { return fmt.Errorf("empty CN") } - d.CNSet.Add(cn, origin, remediationName) + d.CNSet.Add(cn, origin, r) return nil } -func (d *DataSet) removeCN(cn string, remediationName string) (bool, error) { +func (d *DataSet) removeCN(cn string, r remediation.Remediation) (bool, error) { if cn == "" { return false, fmt.Errorf("empty CN") } - removed := d.CNSet.Remove(cn, remediationName) + removed := d.CNSet.Remove(cn, r) return removed, nil } diff --git a/pkg/dataset/root_test.go b/pkg/dataset/root_test.go index 3242fb5..42c1a78 100644 --- a/pkg/dataset/root_test.go +++ b/pkg/dataset/root_test.go @@ -16,7 +16,7 @@ type toCheck struct { Value string // IP, Country Scope string // IP, Country Origin string - Type remediation.Remediation + Type remediation.Remediation // remediation type } func TestDataSet(t *testing.T) { @@ -169,7 +169,7 @@ func TestDataSet(t *testing.T) { t.Fatalf("unknown scope %s", tt.toCheck.Scope) } require.NoError(t, err) - assert.Equal(t, r, tt.toCheck.Type) + assert.True(t, r.IsEqual(tt.toCheck.Type), "remediation should match: got %s, expected %s", r.String(), tt.toCheck.Type.String()) assert.Equal(t, origin, tt.toCheck.Origin) }) } diff --git a/pkg/dataset/types.go b/pkg/dataset/types.go index 7044aef..13c40c4 100644 --- a/pkg/dataset/types.go +++ b/pkg/dataset/types.go @@ -12,7 +12,7 @@ import ( // ErrRemediationNotFound is returned when attempting to remove a remediation that doesn't exist. var ErrRemediationNotFound = errors.New("remediation not found") -// RemediationMap stores one origin string per remediation type (using string keys). +// RemediationMap stores one origin string per remediation type (using Remediation as keys). // ID is not tracked since LAPI behavior ensures we only have the longest decision. // // LAPI behavior: @@ -21,55 +21,53 @@ var ErrRemediationNotFound = errors.New("remediation not found") // - Deletions: Delete means user wants to allow the IP - just remove the remediation entry. // Duplicate deletes are safely ignored (entry already gone). // -// Keys are strings (remediation names) to support custom remediations. -// Weight comparison is done via remediation.GetWeight() when determining priority. -type RemediationMap map[string]string +// Keys are remediation.Remediation types, which use deduplicated string pointers internally. +// This automatically benefits from string deduplication without extra complexity. +// Weight comparison is done via remediation.Compare() when determining priority. +type RemediationMap map[remediation.Remediation]string // Remove removes a remediation entry (deletion means user wants to allow the IP). // Returns ErrRemediationNotFound if the remediation doesn't exist (duplicate delete). -func (rM RemediationMap) Remove(clog *log.Entry, remediationName string) error { - _, ok := rM[remediationName] +func (rM RemediationMap) Remove(clog *log.Entry, r remediation.Remediation) error { + _, ok := rM[r] if !ok { // Remediation not found - duplicate delete if clog != nil && clog.Logger.IsLevelEnabled(log.TraceLevel) { - clog.Tracef("remediation %s not found, duplicate delete", remediationName) + clog.Tracef("remediation %s not found, duplicate delete", r.String()) } return ErrRemediationNotFound } if clog != nil && clog.Logger.IsLevelEnabled(log.TraceLevel) { - clog.Tracef("removing remediation %s", remediationName) + clog.Tracef("removing remediation %s", r.String()) } - delete(rM, remediationName) + delete(rM, r) return nil } // Add adds or updates a decision for the given remediation type. // If a decision already exists, it's overwritten (since only one decision per remediation+value). -func (rM RemediationMap) Add(clog *log.Entry, remediationName string, origin string) { +func (rM RemediationMap) Add(clog *log.Entry, r remediation.Remediation, origin string) { if clog != nil && clog.Logger.IsLevelEnabled(log.TraceLevel) { - if _, exists := rM[remediationName]; exists { - clog.Tracef("remediation %s found, updating", remediationName) + if _, exists := rM[r]; exists { + clog.Tracef("remediation %s found, updating", r.String()) } else { - clog.Tracef("remediation %s not found, creating", remediationName) + clog.Tracef("remediation %s not found, creating", r.String()) } } - rM[remediationName] = origin + rM[r] = origin } // GetRemediationAndOrigin returns the highest priority remediation and its origin. -// Priority is determined by comparing weights using remediation.GetWeight(). -func (rM RemediationMap) GetRemediationAndOrigin() (string, string) { - var maxRemediation string +// Priority is determined by comparing weights using remediation.Compare(). +func (rM RemediationMap) GetRemediationAndOrigin() (remediation.Remediation, string) { + var maxRemediation remediation.Remediation var maxOrigin string - var maxWeight int first := true - for remediationName, origin := range rM { - weight := remediation.GetWeight(remediationName) - if first || weight > maxWeight { - maxRemediation = remediationName + for r, origin := range rM { + if first || r.Compare(maxRemediation) > 0 { + maxRemediation = r maxOrigin = origin - maxWeight = weight first = false } } @@ -84,8 +82,8 @@ func (rM RemediationMap) IsEmpty() bool { // HasRemediationWithOrigin checks if a specific remediation exists with the given origin. // Returns true if the remediation exists and has the same origin. -func (rM RemediationMap) HasRemediationWithOrigin(remediationName string, origin string) bool { - existingOrigin, exists := rM[remediationName] +func (rM RemediationMap) HasRemediationWithOrigin(r remediation.Remediation, origin string) bool { + existingOrigin, exists := rM[r] return exists && existingOrigin == origin } @@ -130,14 +128,14 @@ func NewCNSet(logAlias string) *CNSet { return s } -func (s *CNSet) Add(cn string, origin string, remediationName string) { +func (s *CNSet) Add(cn string, origin string, r remediation.Remediation) { s.writeMu.Lock() defer s.writeMu.Unlock() // Only build logging fields if trace level is enabled var valueLog *log.Entry if s.logger.Logger.IsLevelEnabled(log.TraceLevel) { - valueLog = s.logger.WithField("value", cn).WithField("remediation", remediationName) + valueLog = s.logger.WithField("value", cn).WithField("remediation", r.String()) valueLog.Trace("adding") } @@ -161,27 +159,27 @@ func (s *CNSet) Add(cn string, origin string, remediationName string) { if valueLog != nil { valueLog.Trace("already exists") } - v.Add(valueLog, remediationName, origin) + v.Add(valueLog, r, origin) } else { if valueLog != nil { valueLog.Trace("not found, creating new entry") } newItems[cn] = make(RemediationMap) - newItems[cn].Add(valueLog, remediationName, origin) + newItems[cn].Add(valueLog, r, origin) } // Atomic swap - readers see old or new, never partial s.items.Store(&newItems) } -func (s *CNSet) Remove(cn string, remediationName string) bool { +func (s *CNSet) Remove(cn string, r remediation.Remediation) bool { s.writeMu.Lock() defer s.writeMu.Unlock() // Only build logging fields if trace level is enabled var valueLog *log.Entry if s.logger.Logger.IsLevelEnabled(log.TraceLevel) { - valueLog = s.logger.WithField("value", cn).WithField("remediation", remediationName) + valueLog = s.logger.WithField("value", cn).WithField("remediation", r.String()) } current := s.items.Load() @@ -213,7 +211,7 @@ func (s *CNSet) Remove(cn string, remediationName string) bool { // Modify the cloned entry // Remove returns an error if remediation doesn't exist (duplicate delete) - err := newItems[cn].Remove(valueLog, remediationName) + err := newItems[cn].Remove(valueLog, r) if errors.Is(err, ErrRemediationNotFound) { // Duplicate delete - remediation not found, nothing to remove if valueLog != nil { @@ -236,7 +234,7 @@ func (s *CNSet) Remove(cn string, remediationName string) bool { // Contains checks if a country code has a decision. // This method is completely lock-free - SPOA handlers never block. -func (s *CNSet) Contains(toCheck string) (string, string) { +func (s *CNSet) Contains(toCheck string) (remediation.Remediation, string) { // Only build logging fields if trace level is enabled var valueLog *log.Entry if s.logger.Logger.IsLevelEnabled(log.TraceLevel) { @@ -244,7 +242,7 @@ func (s *CNSet) Contains(toCheck string) (string, string) { valueLog.Trace("checking value") } - remediationName := "allow" + r := remediation.Allow origin := "" // Lock-free read via atomic pointer @@ -254,18 +252,18 @@ func (s *CNSet) Contains(toCheck string) (string, string) { if valueLog != nil { valueLog.Trace("found") } - remediationName, origin = v.GetRemediationAndOrigin() + r, origin = v.GetRemediationAndOrigin() } } if valueLog != nil { - valueLog.Tracef("remediation: %s", remediationName) + valueLog.Tracef("remediation: %s", r.String()) } - return remediationName, origin + return r, origin } // HasRemediation checks if a country code has a specific remediation with a specific origin. // Returns true if the country code exists and has the given remediation with the given origin. -func (s *CNSet) HasRemediation(cn string, remediationName string, origin string) bool { +func (s *CNSet) HasRemediation(cn string, r remediation.Remediation, origin string) bool { // Lock-free read via atomic pointer items := s.items.Load() if items == nil { @@ -273,7 +271,7 @@ func (s *CNSet) HasRemediation(cn string, remediationName string, origin string) } if v, ok := (*items)[cn]; ok { - return v.HasRemediationWithOrigin(remediationName, origin) + return v.HasRemediationWithOrigin(r, origin) } return false } diff --git a/pkg/spoa/root.go b/pkg/spoa/root.go index 137fb86..9f45373 100644 --- a/pkg/spoa/root.go +++ b/pkg/spoa/root.go @@ -185,7 +185,7 @@ type HTTPRequestData struct { // First stage is to check the host header and determine if the remediation from handleIpRequest is still valid // Second stage is to check if AppSec is enabled and then forward to the component if needed func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { - r := "allow" // Default to allow + r := remediation.Allow // Default to allow var origin string shouldCountMetrics := false @@ -208,7 +208,7 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { } if rstring != nil { - r = *rstring // Use string directly + r = remediation.FromString(*rstring) // Convert to Remediation type // Remediation came from IP check, already counted } @@ -218,11 +218,11 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { // defer a function that always add the remediation to the request at end of processing defer func() { - if matchedHost == nil && r == "captcha" { + if matchedHost == nil && r.String() == "captcha" { s.logger.Warn("remediation is captcha, no matching host was found cannot issue captcha remediation reverting to ban") - r = "ban" + r = remediation.Ban } - req.Actions.SetVar(action.ScopeTransaction, "remediation", r) + req.Actions.SetVar(action.ScopeTransaction, "remediation", r.String()) // Count metrics if this is the only handler (upstream proxy mode) if shouldCountMetrics { @@ -241,11 +241,11 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { // Count processed request - use WithLabelValues to avoid map allocation on hot path metrics.TotalProcessedRequests.WithLabelValues(ipTypeLabel).Inc() - // Count blocked request if remediation applied (check weight > Allow weight) + // Count blocked request if remediation applied (not Allow) // This includes Unknown, Captcha, Ban, and any custom remediations - if remediation.GetWeight(r) > remediation.WeightAllow { + if r.IsWeighted() { // Label order: origin, ip_type, remediation (as defined in metrics.go) - metrics.TotalBlockedRequests.WithLabelValues(origin, ipTypeLabel, r).Inc() + metrics.TotalBlockedRequests.WithLabelValues(origin, ipTypeLabel, r.String()).Inc() } } }() @@ -267,7 +267,7 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { var httpData HTTPRequestData - switch r { + switch r.String() { case "allow": // If user has a captcha cookie but decision is Allow, generate unset cookie // We don't set captcha_status, so HAProxy knows to clear the cookie @@ -314,14 +314,14 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { r, httpData = s.handleCaptchaRemediation(req, mes, matchedHost) // If remediation changed to fallback, return early // If it became Allow, continue for AppSec processing - if r != "captcha" && r != "allow" { + if r.String() != "captcha" && r.String() != "allow" { return } } // If remediation is not allow, we dont need to create a request to send to appsec unless always send is on // This includes Unknown, Captcha, Ban, and any custom remediations - if remediation.GetWeight(r) > remediation.WeightAllow && !matchedHost.AppSec.AlwaysSend { + if r.IsWeighted() && !matchedHost.AppSec.AlwaysSend { return } // !TODO APPSEC STUFF - httpData contains parsed URL, Method, Body, Headers for reuse @@ -433,9 +433,9 @@ func (s *Spoa) createNewSessionAndCookie(req *request.Request, mes *message.Mess // handleCaptchaRemediation handles all captcha-related logic including cookie validation, // session management, captcha validation, and status updates. // Returns the remediation and parsed HTTP request data for reuse in AppSec processing. -func (s *Spoa) handleCaptchaRemediation(req *request.Request, mes *message.Message, matchedHost *host.Host) (string, HTTPRequestData) { +func (s *Spoa) handleCaptchaRemediation(req *request.Request, mes *message.Message, matchedHost *host.Host) (remediation.Remediation, HTTPRequestData) { if err := matchedHost.Captcha.InjectKeyValues(&req.Actions); err != nil { - return matchedHost.Captcha.FallbackRemediation, HTTPRequestData{} + return remediation.FromString(matchedHost.Captcha.FallbackRemediation), HTTPRequestData{} } cookieB64, err := readKeyFromMessage[string](mes, "crowdsec_captcha_cookie") @@ -471,7 +471,7 @@ func (s *Spoa) handleCaptchaRemediation(req *request.Request, mes *message.Messa "host": matchedHost.Host, "error": err, }).Error("Failed to create new session and cookie, falling back to fallback remediation") - return matchedHost.Captcha.FallbackRemediation, HTTPRequestData{} + return remediation.FromString(matchedHost.Captcha.FallbackRemediation), HTTPRequestData{} } } @@ -479,7 +479,7 @@ func (s *Spoa) handleCaptchaRemediation(req *request.Request, mes *message.Messa // We should never hit this but safety net // As a fallback we set the remediation to the fallback remediation s.logger.Error("failed to get uuid from cookie") - return matchedHost.Captcha.FallbackRemediation, HTTPRequestData{} + return remediation.FromString(matchedHost.Captcha.FallbackRemediation), HTTPRequestData{} } // Get the session only if we didn't just create it (i.e., we have an existing cookie) @@ -500,7 +500,7 @@ func (s *Spoa) handleCaptchaRemediation(req *request.Request, mes *message.Messa "host": matchedHost.Host, "error": err, }).Error("Failed to create new session after reload, falling back to fallback remediation") - return matchedHost.Captcha.FallbackRemediation, HTTPRequestData{} + return remediation.FromString(matchedHost.Captcha.FallbackRemediation), HTTPRequestData{} } } } @@ -544,7 +544,7 @@ func (s *Spoa) handleCaptchaRemediation(req *request.Request, mes *message.Messa "key": "method", "host": matchedHost.Host, }).Error("failed to read method from message, cannot validate captcha form submission - ensure HAProxy is sending the 'method' variable in crowdsec-http message") - return "captcha", HTTPRequestData{URL: url} // Return partial data + return remediation.Captcha, HTTPRequestData{URL: url} // Return partial data } headersType, err := readKeyFromMessage[string](mes, "headers") @@ -554,7 +554,7 @@ func (s *Spoa) handleCaptchaRemediation(req *request.Request, mes *message.Messa "key": "headers", "host": matchedHost.Host, }).Error("failed to read headers from message, cannot validate captcha form submission - ensure HAProxy is sending the 'headers' variable in crowdsec-http message") - return "captcha", HTTPRequestData{URL: url, Method: method} // Return partial data + return remediation.Captcha, HTTPRequestData{URL: url, Method: method} // Return partial data } headers, err := readHeaders(*headersType) @@ -578,7 +578,7 @@ func (s *Spoa) handleCaptchaRemediation(req *request.Request, mes *message.Messa "host": matchedHost.Host, "session": uuid, }).Error("failed to read body from message, cannot validate captcha response - ensure HAProxy is sending the 'body' variable in crowdsec-http message for POST requests") - return "captcha", httpData // Return data without body + return remediation.Captcha, httpData // Return data without body } httpData.Body = body @@ -606,15 +606,15 @@ func (s *Spoa) handleCaptchaRemediation(req *request.Request, mes *message.Messa // Delete the URI from the session so we dont redirect loop ses.Delete(session.URI) } - return "allow", httpData + return remediation.Allow, httpData } - return "captcha", httpData + return remediation.Captcha, httpData } // getIPRemediation performs IP and geo/country remediation checks // Returns the final remediation after checking IP, geo, and country -func (s *Spoa) getIPRemediation(req *request.Request, ip netip.Addr) (string, string) { +func (s *Spoa) getIPRemediation(req *request.Request, ip netip.Addr) (remediation.Remediation, string) { var origin string // Check IP directly against dataset r, origin, err := s.dataset.CheckIP(ip) @@ -623,7 +623,7 @@ func (s *Spoa) getIPRemediation(req *request.Request, ip netip.Addr) (string, st "ip": ip.String(), "error": err, }).Error("Failed to get IP remediation") - return "allow", "" // Safe default + return remediation.Allow, "" // Safe default } // Always try to get and set ISO code if geo database is available @@ -642,10 +642,9 @@ func (s *Spoa) getIPRemediation(req *request.Request, ip netip.Addr) (string, st req.Actions.SetVar(action.ScopeTransaction, "isocode", iso) // If no IP-specific remediation (Allow), check country-based remediation - // Compare weights instead of direct comparison - if remediation.GetWeight(r) == remediation.WeightAllow { + if r.IsEqual(remediation.Allow) { cnR, cnOrigin := s.dataset.CheckCN(iso) - if remediation.GetWeight(cnR) > remediation.WeightAllow { + if cnR.IsHigher(remediation.Allow) { r = cnR origin = cnOrigin } @@ -682,14 +681,14 @@ func (s *Spoa) handleIPRequest(req *request.Request, mes *message.Message) { // Check IP directly against dataset r, origin := s.getIPRemediation(req, ipAddr) - // Count blocked requests (check weight > Allow weight) + // Count blocked requests (not Allow) // This includes Unknown, Captcha, Ban, and any custom remediations - if remediation.GetWeight(r) > remediation.WeightAllow { + if r.IsWeighted() { // Label order: origin, ip_type, remediation (as defined in metrics.go) - metrics.TotalBlockedRequests.WithLabelValues(origin, ipTypeLabel, r).Inc() + metrics.TotalBlockedRequests.WithLabelValues(origin, ipTypeLabel, r.String()).Inc() } - req.Actions.SetVar(action.ScopeTransaction, "remediation", r) + req.Actions.SetVar(action.ScopeTransaction, "remediation", r.String()) } func handlerWrapper(s *Spoa) func(req *request.Request) { From 4bca4dcc85304f2a2552c990222cadf284d3d876 Mon Sep 17 00:00:00 2001 From: Laurence Date: Thu, 4 Dec 2025 16:02:15 +0000 Subject: [PATCH 03/11] fix: update test assertions to use IsEqual for remediation.Remediation comparisons Tests were comparing remediation.Remediation with strings after Contains methods were updated to return remediation.Remediation directly. Updated assertions to use IsEqual() method for proper type-safe comparisons. --- pkg/dataset/metrics_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/dataset/metrics_test.go b/pkg/dataset/metrics_test.go index fe81b15..7a351f5 100644 --- a/pkg/dataset/metrics_test.go +++ b/pkg/dataset/metrics_test.go @@ -746,7 +746,7 @@ func TestMetrics_NoOp_DuplicateDecisions(t *testing.T) { r, foundOrigin, found := dataSet.IPMap.Contains(ip) assert.True(t, found, "IP should still exist") assert.Equal(t, origin, foundOrigin, "origin should match") - assert.Equal(t, remediation.Ban.String(), r, "remediation should be ban") + assert.True(t, r.IsEqual(remediation.Ban), "remediation should be ban") }) t.Run("Duplicate range decision is no-op", func(t *testing.T) { @@ -775,7 +775,7 @@ func TestMetrics_NoOp_DuplicateDecisions(t *testing.T) { require.NoError(t, err) r, foundOrigin := dataSet.RangeSet.Contains(testIP) assert.Equal(t, origin, foundOrigin, "origin should match") - assert.Equal(t, remediation.Ban.String(), r, "remediation should be ban") + assert.True(t, r.IsEqual(remediation.Ban), "remediation should be ban") }) t.Run("Duplicate country decision is no-op", func(t *testing.T) { @@ -802,7 +802,7 @@ func TestMetrics_NoOp_DuplicateDecisions(t *testing.T) { // Verify decision still exists r, foundOrigin := dataSet.CNSet.Contains("US") assert.Equal(t, origin, foundOrigin, "origin should match") - assert.Equal(t, remediation.Ban.String(), r, "remediation should be ban") + assert.True(t, r.IsEqual(remediation.Ban), "remediation should be ban") }) t.Run("Same IP different remediation is not no-op", func(t *testing.T) { From 727b7b5d5ab8125695b9784a779af5faa1547d05 Mon Sep 17 00:00:00 2001 From: Laurence Date: Thu, 4 Dec 2025 16:03:06 +0000 Subject: [PATCH 04/11] refactor: replace string comparisons with IsEqual() in spoa package Replace all r.String() == "..." comparisons with r.IsEqual(remediation.X) for type-safe remediation comparisons. Updated switch statement to use if/else chain with IsEqual() checks. --- pkg/spoa/root.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pkg/spoa/root.go b/pkg/spoa/root.go index 9f45373..2d47ed8 100644 --- a/pkg/spoa/root.go +++ b/pkg/spoa/root.go @@ -218,7 +218,7 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { // defer a function that always add the remediation to the request at end of processing defer func() { - if matchedHost == nil && r.String() == "captcha" { + if matchedHost == nil && r.IsEqual(remediation.Captcha) { s.logger.Warn("remediation is captcha, no matching host was found cannot issue captcha remediation reverting to ban") r = remediation.Ban } @@ -267,8 +267,7 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { var httpData HTTPRequestData - switch r.String() { - case "allow": + if r.IsEqual(remediation.Allow) { // If user has a captcha cookie but decision is Allow, generate unset cookie // We don't set captcha_status, so HAProxy knows to clear the cookie cookieB64, err := readKeyFromMessage[string](mes, "crowdsec_captcha_cookie") @@ -305,16 +304,16 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { } // Parse HTTP data for AppSec processing httpData = parseHTTPData(s.logger, mes) - case "ban": + } else if r.IsEqual(remediation.Ban) { //Handle ban matchedHost.Ban.InjectKeyValues(&req.Actions) // Parse HTTP data for AppSec processing httpData = parseHTTPData(s.logger, mes) - case "captcha": + } else if r.IsEqual(remediation.Captcha) { r, httpData = s.handleCaptchaRemediation(req, mes, matchedHost) // If remediation changed to fallback, return early // If it became Allow, continue for AppSec processing - if r.String() != "captcha" && r.String() != "allow" { + if !r.IsEqual(remediation.Captcha) && !r.IsEqual(remediation.Allow) { return } } From c9640a5525a90014a067da6e65e16e6e8a6b2279 Mon Sep 17 00:00:00 2001 From: Laurence Date: Thu, 4 Dec 2025 16:05:12 +0000 Subject: [PATCH 05/11] fix: add nolint comment for if-else chain in spoa remediation handling Add nolint comment to suppress gocritic ifElseChain warning. We use if-else with IsEqual() for type-safe remediation comparisons instead of switch on String() to maintain type safety. --- pkg/spoa/root.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/spoa/root.go b/pkg/spoa/root.go index 2d47ed8..5dafe1f 100644 --- a/pkg/spoa/root.go +++ b/pkg/spoa/root.go @@ -267,6 +267,7 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { var httpData HTTPRequestData + //nolint:gocritic // Using if-else chain with IsEqual() for type-safe remediation comparisons instead of switch on String() if r.IsEqual(remediation.Allow) { // If user has a captcha cookie but decision is Allow, generate unset cookie // We don't set captcha_status, so HAProxy knows to clear the cookie From fffda45a11bcd865576341f21a8b30790fa0207b Mon Sep 17 00:00:00 2001 From: Laurence Date: Thu, 4 Dec 2025 16:06:47 +0000 Subject: [PATCH 06/11] refactor: use switch statement on remediation.Remediation directly Replace if-else chain with switch statement on remediation.Remediation struct. Since Remediation is comparable (all fields are comparable), we can switch on it directly using the known remediation constants (Allow, Ban, Captcha). --- pkg/spoa/root.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pkg/spoa/root.go b/pkg/spoa/root.go index 5dafe1f..2cccd2d 100644 --- a/pkg/spoa/root.go +++ b/pkg/spoa/root.go @@ -267,8 +267,8 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { var httpData HTTPRequestData - //nolint:gocritic // Using if-else chain with IsEqual() for type-safe remediation comparisons instead of switch on String() - if r.IsEqual(remediation.Allow) { + switch r { + case remediation.Allow: // If user has a captcha cookie but decision is Allow, generate unset cookie // We don't set captcha_status, so HAProxy knows to clear the cookie cookieB64, err := readKeyFromMessage[string](mes, "crowdsec_captcha_cookie") @@ -305,18 +305,21 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { } // Parse HTTP data for AppSec processing httpData = parseHTTPData(s.logger, mes) - } else if r.IsEqual(remediation.Ban) { + case remediation.Ban: //Handle ban matchedHost.Ban.InjectKeyValues(&req.Actions) // Parse HTTP data for AppSec processing httpData = parseHTTPData(s.logger, mes) - } else if r.IsEqual(remediation.Captcha) { + case remediation.Captcha: r, httpData = s.handleCaptchaRemediation(req, mes, matchedHost) // If remediation changed to fallback, return early // If it became Allow, continue for AppSec processing - if !r.IsEqual(remediation.Captcha) && !r.IsEqual(remediation.Allow) { + if r != remediation.Captcha && r != remediation.Allow { return } + default: + // Unknown or custom remediation - parse HTTP data for AppSec processing + httpData = parseHTTPData(s.logger, mes) } // If remediation is not allow, we dont need to create a request to send to appsec unless always send is on From e3737d1228d703c365059db232c68b8c06ab17d8 Mon Sep 17 00:00:00 2001 From: Laurence Date: Thu, 4 Dec 2025 16:12:45 +0000 Subject: [PATCH 07/11] refactor: use switch statement on remediation.Remediation directly Replace if-else chain with switch statement on remediation.Remediation struct. Since Remediation is comparable (all fields are comparable), we can switch on it directly using the known remediation constants (Allow, Ban, Captcha). This satisfies the linter's ifElseChain warning while maintaining type safety. --- pkg/dataset/types.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/dataset/types.go b/pkg/dataset/types.go index 13c40c4..a3b6e61 100644 --- a/pkg/dataset/types.go +++ b/pkg/dataset/types.go @@ -65,7 +65,7 @@ func (rM RemediationMap) GetRemediationAndOrigin() (remediation.Remediation, str first := true for r, origin := range rM { - if first || r.Compare(maxRemediation) > 0 { + if first || r.IsHigher(maxRemediation) { maxRemediation = r maxOrigin = origin first = false From 5277f9e313b430ef251419641af124ea545dbdf1 Mon Sep 17 00:00:00 2001 From: Laurence Date: Thu, 4 Dec 2025 16:23:02 +0000 Subject: [PATCH 08/11] fix: address Copilot review comments - Fix IsEqual() to compare both name pointer and weight for true equality - Add comprehensive documentation for RemediationWeights configuration - Add clarifying comments about struct comparison and custom remediation handling - Note: Direct struct comparison (r != remediation.X) works because Remediation is comparable (all fields are comparable: *string and int) --- internal/remediation/root.go | 6 ++++-- pkg/cfg/config.go | 26 ++++++++++++++++++++++---- pkg/spoa/root.go | 4 +++- 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/internal/remediation/root.go b/internal/remediation/root.go index 5e83f93..fbc759e 100644 --- a/internal/remediation/root.go +++ b/internal/remediation/root.go @@ -159,9 +159,11 @@ func (r Remediation) IsLower(other Remediation) bool { return r.weight < other.weight } -// IsEqual returns true if r has the same weight as other +// IsEqual returns true if r represents the same remediation as other. +// This compares both the name pointer (for deduplicated identity) and weight. +// Two remediations are equal if they have the same name pointer and weight. func (r Remediation) IsEqual(other Remediation) bool { - return r.weight == other.weight + return r.name == other.name && r.weight == other.weight } // IsWeighted returns true if r is not Allow (has weight > Allow) diff --git a/pkg/cfg/config.go b/pkg/cfg/config.go index 7eec17e..46cd3e1 100644 --- a/pkg/cfg/config.go +++ b/pkg/cfg/config.go @@ -37,10 +37,28 @@ type BouncerConfig struct { ListenUnix string `yaml:"listen_unix"` PrometheusConfig PrometheusConfig `yaml:"prometheus"` PprofConfig PprofConfig `yaml:"pprof"` - // RemediationWeights allows users to configure custom weights for remediations - // Format: map[string]int where key is remediation name and value is weight - // Built-in defaults: allow=0, unknown=1, captcha=10, ban=20 - // Custom remediations can slot between these values + // RemediationWeights allows users to configure custom weights for remediations. + // + // Format: + // remediation_weights: + // : + // + // Example: + // remediation_weights: + // mfa: 15 # slots between captcha (10) and ban (20) + // + // Valid weight range: integer values >= 0. Lower values are less severe; higher values are more severe. + // Recommended: Use values between 0 and 100. + // + // Built-in defaults: + // allow=0, unknown=1, captcha=10, ban=20 + // + // Custom weights override or supplement built-in remediations. If a custom remediation is defined, + // its weight will be used for ordering and severity. Custom remediations can slot between built-in + // ones by choosing an appropriate weight value. + // + // Note: Custom weights for built-in remediations (allow, unknown, captcha, ban) must be set + // before package initialization. After init(), package-level constants already have cached weights. RemediationWeights map[string]int `yaml:"remediation_weights,omitempty"` } diff --git a/pkg/spoa/root.go b/pkg/spoa/root.go index 2cccd2d..8df5a85 100644 --- a/pkg/spoa/root.go +++ b/pkg/spoa/root.go @@ -314,11 +314,13 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { r, httpData = s.handleCaptchaRemediation(req, mes, matchedHost) // If remediation changed to fallback, return early // If it became Allow, continue for AppSec processing + // Note: Direct struct comparison works because Remediation is comparable if r != remediation.Captcha && r != remediation.Allow { return } default: - // Unknown or custom remediation - parse HTTP data for AppSec processing + // Unknown or custom remediation: currently, only HTTP data is parsed for AppSec processing. + // If a custom remediation requires special handling (like Ban), this must be implemented explicitly. httpData = parseHTTPData(s.logger, mes) } From c3ba1d85c22bc3f806cfc3a98ab08ea55a383589 Mon Sep 17 00:00:00 2001 From: Laurence Date: Thu, 4 Dec 2025 16:25:23 +0000 Subject: [PATCH 09/11] feat: add HasSameWeight() method and handle equal-weight remediations - Add HasSameWeight() method as suggested by Copilot for checking weight equality - Update GetRemediationAndOrigin() to handle remediations with same weight using alphabetical order as a deterministic tie-breaker - Add documentation explaining tie-breaking behavior when remediations have same weight - This ensures deterministic behavior when users configure multiple remediations with the same weight value --- internal/remediation/root.go | 8 ++++++++ pkg/cfg/config.go | 3 +++ pkg/dataset/types.go | 17 ++++++++++++++++- 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/internal/remediation/root.go b/internal/remediation/root.go index fbc759e..354e5a1 100644 --- a/internal/remediation/root.go +++ b/internal/remediation/root.go @@ -166,6 +166,14 @@ func (r Remediation) IsEqual(other Remediation) bool { return r.name == other.name && r.weight == other.weight } +// HasSameWeight returns true if r has the same weight as other. +// This is useful for checking if two different remediations have the same priority. +// Note: Two remediations with the same weight will be compared by name (alphabetical) +// as a tie-breaker when determining priority. +func (r Remediation) HasSameWeight(other Remediation) bool { + return r.weight == other.weight +} + // IsWeighted returns true if r is not Allow (has weight > Allow) // This is useful for checking if a remediation should be applied func (r Remediation) IsWeighted() bool { diff --git a/pkg/cfg/config.go b/pkg/cfg/config.go index 46cd3e1..6deb9bf 100644 --- a/pkg/cfg/config.go +++ b/pkg/cfg/config.go @@ -57,6 +57,9 @@ type BouncerConfig struct { // its weight will be used for ordering and severity. Custom remediations can slot between built-in // ones by choosing an appropriate weight value. // + // Tie-breaking: If two remediations have the same weight, alphabetical order of the remediation + // name is used as a deterministic tie-breaker when determining priority. + // // Note: Custom weights for built-in remediations (allow, unknown, captcha, ban) must be set // before package initialization. After init(), package-level constants already have cached weights. RemediationWeights map[string]int `yaml:"remediation_weights,omitempty"` diff --git a/pkg/dataset/types.go b/pkg/dataset/types.go index a3b6e61..b9f9c1b 100644 --- a/pkg/dataset/types.go +++ b/pkg/dataset/types.go @@ -59,16 +59,31 @@ func (rM RemediationMap) Add(clog *log.Entry, r remediation.Remediation, origin // GetRemediationAndOrigin returns the highest priority remediation and its origin. // Priority is determined by comparing weights using remediation.Compare(). +// If two remediations have the same weight, alphabetical order of the name is used as a tie-breaker +// to ensure deterministic behavior. func (rM RemediationMap) GetRemediationAndOrigin() (remediation.Remediation, string) { var maxRemediation remediation.Remediation var maxOrigin string first := true for r, origin := range rM { - if first || r.IsHigher(maxRemediation) { + if first { maxRemediation = r maxOrigin = origin first = false + continue + } + + // Compare by weight first + if r.IsHigher(maxRemediation) { + maxRemediation = r + maxOrigin = origin + } else if r.HasSameWeight(maxRemediation) { + // Tie-breaker: use alphabetical order of name for deterministic behavior + if r.String() < maxRemediation.String() { + maxRemediation = r + maxOrigin = origin + } } } From 04de61bf93b58c950dcefc8184082f85fb010330 Mon Sep 17 00:00:00 2001 From: Laurence Date: Thu, 4 Dec 2025 16:29:36 +0000 Subject: [PATCH 10/11] docs: add documentation about map key equality requirements Add documentation explaining that all Remediation instances must be created via New() or FromString() to ensure proper deduplication for map key equality. Direct struct initialization will create different string pointers, causing map lookups to fail. This addresses Copilot's concern about map key comparison - the deduplication mechanism ensures all Remediation instances with the same name share the same *string pointer, making map key equality work correctly. --- internal/remediation/root.go | 23 +++++++++++++++++------ pkg/dataset/types.go | 6 ++++++ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/internal/remediation/root.go b/internal/remediation/root.go index 354e5a1..f1925df 100644 --- a/internal/remediation/root.go +++ b/internal/remediation/root.go @@ -81,8 +81,9 @@ func SetWeight(name string, weight int) { globalRegistry.weights[name] = weight // Ensure deduplicated string exists if _, exists := globalRegistry.strings[name]; !exists { - deduped := name - globalRegistry.strings[name] = &deduped + // Create new string copy on heap to ensure pointer remains valid + nameCopy := name + globalRegistry.strings[name] = &nameCopy } } @@ -98,8 +99,16 @@ func GetWeight(name string) int { return WeightUnknown } -// New creates a new Remediation from a string -// Uses deduplicated string pointers to reduce allocations +// New creates a new Remediation from a string. +// Uses deduplicated string pointers to reduce allocations and ensure map key equality. +// +// IMPORTANT: All Remediation instances must be created via New() or FromString() to ensure +// proper deduplication. Direct struct initialization will create different string pointers, +// causing map lookups (e.g., in RemediationMap) to fail even for the same remediation name. +// +// The deduplication works by storing a single *string pointer per unique remediation name +// in globalRegistry.strings. All subsequent calls with the same name return Remediation +// instances with the same name pointer, ensuring map key equality works correctly. func New(name string) Remediation { globalRegistry.mu.Lock() defer globalRegistry.mu.Unlock() @@ -107,8 +116,10 @@ func New(name string) Remediation { // Get or create deduplicated string pointer deduped, exists := globalRegistry.strings[name] if !exists { - // Create new deduplicated string - deduped = &name + // Create new deduplicated string. The variable escapes to heap when stored in + // the package-level map, ensuring the pointer remains valid for map key comparisons. + nameCopy := name + deduped = &nameCopy globalRegistry.strings[name] = deduped // Set default weight if not configured if _, hasWeight := globalRegistry.weights[name]; !hasWeight { diff --git a/pkg/dataset/types.go b/pkg/dataset/types.go index b9f9c1b..d3d9cc2 100644 --- a/pkg/dataset/types.go +++ b/pkg/dataset/types.go @@ -24,6 +24,12 @@ var ErrRemediationNotFound = errors.New("remediation not found") // Keys are remediation.Remediation types, which use deduplicated string pointers internally. // This automatically benefits from string deduplication without extra complexity. // Weight comparison is done via remediation.Compare() when determining priority. +// +// IMPORTANT: Map key comparison uses Go's == operator which compares all struct fields, +// including the name pointer. For map lookups to work correctly, all Remediation instances +// MUST be created via remediation.New() or remediation.FromString() to ensure proper +// deduplication. Direct struct initialization will result in different pointers and +// cause map lookups to fail. type RemediationMap map[remediation.Remediation]string // Remove removes a remediation entry (deletion means user wants to allow the IP). From 8346aa6c2e376af69a0212ff20d86cf8a36fc998 Mon Sep 17 00:00:00 2001 From: Laurence Date: Thu, 11 Dec 2025 13:06:29 +0000 Subject: [PATCH 11/11] Refactor remediation to string type with registry-level weight management - Change Remediation from struct to string type - Move weight storage to registry level only - Add LoadWeights() function for startup initialization - Convert all comparison methods to registry-level functions - Update all call sites to use new function signatures - Simplify config loading to use LoadWeights() - All tests passing --- internal/remediation/root.go | 152 ++++++++++------------------------ pkg/cfg/config.go | 6 +- pkg/dataset/benchmark_test.go | 12 +-- pkg/dataset/metrics_test.go | 12 +-- pkg/dataset/root.go | 4 +- pkg/dataset/root_test.go | 2 +- pkg/dataset/types.go | 13 +-- pkg/spoa/root.go | 12 +-- 8 files changed, 71 insertions(+), 142 deletions(-) diff --git a/internal/remediation/root.go b/internal/remediation/root.go index f1925df..1c61478 100644 --- a/internal/remediation/root.go +++ b/internal/remediation/root.go @@ -14,63 +14,36 @@ const ( ) // Remediation represents a remediation type as a string -// We use string pointers for deduplication to reduce allocations -type Remediation struct { - name *string // Pointer to deduplicated string - weight int // Weight for comparison (higher = more severe) -} +type Remediation string -// registry manages deduplicated remediation strings and their weights +// Built-in remediation constants +const ( + Allow Remediation = "allow" + Unknown Remediation = "unknown" + Captcha Remediation = "captcha" + Ban Remediation = "ban" +) + +// registry manages remediation weights type registry struct { mu sync.RWMutex - strings map[string]*string // Maps string to its deduplicated pointer - weights map[string]int // Maps remediation name to its weight + weights map[string]int // Maps remediation name to its weight } var globalRegistry = ®istry{ - strings: make(map[string]*string), weights: make(map[string]int), } -// Built-in remediation constants (for convenience) -// Initialized to nil, will be set in init() -var ( - Allow Remediation - Unknown Remediation - Captcha Remediation - Ban Remediation -) - -//nolint:gochecknoinits // init() is required to initialize package-level vars after weights are set +//nolint:gochecknoinits // init() is required to initialize default weights func init() { // Initialize built-in remediations with default weights globalRegistry.mu.Lock() defer globalRegistry.mu.Unlock() - // Set weights FIRST before creating strings globalRegistry.weights["allow"] = WeightAllow globalRegistry.weights["unknown"] = WeightUnknown globalRegistry.weights["captcha"] = WeightCaptcha globalRegistry.weights["ban"] = WeightBan - - // Pre-create deduplicated strings for built-in remediations - // Must create new string variables for each to avoid pointer aliasing - allowStr := "allow" - unknownStr := "unknown" - captchaStr := "captcha" - banStr := "ban" - - globalRegistry.strings["allow"] = &allowStr - globalRegistry.strings["unknown"] = &unknownStr - globalRegistry.strings["captcha"] = &captchaStr - globalRegistry.strings["ban"] = &banStr - - // Now initialize the package-level vars directly (we already hold the lock) - // This avoids deadlock since New() would try to acquire the lock again - Allow = Remediation{name: &allowStr, weight: WeightAllow} - Unknown = Remediation{name: &unknownStr, weight: WeightUnknown} - Captcha = Remediation{name: &captchaStr, weight: WeightCaptcha} - Ban = Remediation{name: &banStr, weight: WeightBan} } // SetWeight sets a custom weight for a remediation (for configuration) @@ -79,12 +52,6 @@ func SetWeight(name string, weight int) { defer globalRegistry.mu.Unlock() globalRegistry.weights[name] = weight - // Ensure deduplicated string exists - if _, exists := globalRegistry.strings[name]; !exists { - // Create new string copy on heap to ensure pointer remains valid - nameCopy := name - globalRegistry.strings[name] = &nameCopy - } } // GetWeight returns the weight for a remediation name @@ -99,96 +66,67 @@ func GetWeight(name string) int { return WeightUnknown } -// New creates a new Remediation from a string. -// Uses deduplicated string pointers to reduce allocations and ensure map key equality. -// -// IMPORTANT: All Remediation instances must be created via New() or FromString() to ensure -// proper deduplication. Direct struct initialization will create different string pointers, -// causing map lookups (e.g., in RemediationMap) to fail even for the same remediation name. -// -// The deduplication works by storing a single *string pointer per unique remediation name -// in globalRegistry.strings. All subsequent calls with the same name return Remediation -// instances with the same name pointer, ensuring map key equality works correctly. -func New(name string) Remediation { +// LoadWeights loads weights for multiple remediations at once (for startup initialization) +func LoadWeights(weights map[string]int) { globalRegistry.mu.Lock() defer globalRegistry.mu.Unlock() - // Get or create deduplicated string pointer - deduped, exists := globalRegistry.strings[name] - if !exists { - // Create new deduplicated string. The variable escapes to heap when stored in - // the package-level map, ensuring the pointer remains valid for map key comparisons. - nameCopy := name - deduped = &nameCopy - globalRegistry.strings[name] = deduped - // Set default weight if not configured - if _, hasWeight := globalRegistry.weights[name]; !hasWeight { - globalRegistry.weights[name] = WeightUnknown - } + for name, weight := range weights { + globalRegistry.weights[name] = weight } +} - // Read weight from registry (may have been set in init() or SetWeight()) - weight, ok := globalRegistry.weights[name] - if !ok { - // Weight not found, default to Unknown - weight = WeightUnknown - } - return Remediation{ - name: deduped, - weight: weight, - } +// New creates a new Remediation from a string. +func New(name string) Remediation { + return Remediation(name) } // String returns the remediation name func (r Remediation) String() string { - if r.name == nil { + if r == "" { return "allow" // Default fallback } - return *r.name -} - -// Weight returns the weight of the remediation -func (r Remediation) Weight() int { - return r.weight + return string(r) } // Compare returns: -// - negative if r < other -// - zero if r == other -// - positive if r > other -func (r Remediation) Compare(other Remediation) int { - return r.weight - other.weight +// - negative if a < b +// - zero if a == b +// - positive if a > b +func Compare(a, b Remediation) int { + weightA := GetWeight(a.String()) + weightB := GetWeight(b.String()) + return weightA - weightB } -// IsHigher returns true if r has a higher weight than other -func (r Remediation) IsHigher(other Remediation) bool { - return r.weight > other.weight +// IsHigher returns true if a has a higher weight than b +func IsHigher(a, b Remediation) bool { + return Compare(a, b) > 0 } -// IsLower returns true if r has a lower weight than other -func (r Remediation) IsLower(other Remediation) bool { - return r.weight < other.weight +// IsLower returns true if a has a lower weight than b +func IsLower(a, b Remediation) bool { + return Compare(a, b) < 0 } -// IsEqual returns true if r represents the same remediation as other. -// This compares both the name pointer (for deduplicated identity) and weight. -// Two remediations are equal if they have the same name pointer and weight. -func (r Remediation) IsEqual(other Remediation) bool { - return r.name == other.name && r.weight == other.weight +// IsEqual returns true if a represents the same remediation as b. +// This compares the remediation names (strings). +func IsEqual(a, b Remediation) bool { + return a == b } -// HasSameWeight returns true if r has the same weight as other. +// HasSameWeight returns true if a has the same weight as b. // This is useful for checking if two different remediations have the same priority. // Note: Two remediations with the same weight will be compared by name (alphabetical) // as a tie-breaker when determining priority. -func (r Remediation) HasSameWeight(other Remediation) bool { - return r.weight == other.weight +func HasSameWeight(a, b Remediation) bool { + return Compare(a, b) == 0 } // IsWeighted returns true if r is not Allow (has weight > Allow) // This is useful for checking if a remediation should be applied -func (r Remediation) IsWeighted() bool { - return r.weight > WeightAllow +func IsWeighted(r Remediation) bool { + return GetWeight(r.String()) > WeightAllow } // FromString creates a Remediation from a string (alias for New for backward compatibility) @@ -198,5 +136,5 @@ func FromString(s string) Remediation { // IsZero returns true if the remediation is zero-valued func (r Remediation) IsZero() bool { - return r.name == nil + return r == "" } diff --git a/pkg/cfg/config.go b/pkg/cfg/config.go index 6deb9bf..e31e5f6 100644 --- a/pkg/cfg/config.go +++ b/pkg/cfg/config.go @@ -94,11 +94,9 @@ func NewConfig(reader io.Reader) (*BouncerConfig, error) { return nil, fmt.Errorf("failed to setup logging: %w", err) } - // Apply custom remediation weights if configured + // Load custom remediation weights if configured (loads all weights at once on startup) if config.RemediationWeights != nil { - for remediationName, weight := range config.RemediationWeights { - remediation.SetWeight(remediationName, weight) - } + remediation.LoadWeights(config.RemediationWeights) } if err := config.Validate(); err != nil { diff --git a/pkg/dataset/benchmark_test.go b/pkg/dataset/benchmark_test.go index a6e63b1..98bda65 100644 --- a/pkg/dataset/benchmark_test.go +++ b/pkg/dataset/benchmark_test.go @@ -175,12 +175,12 @@ func TestCorrectness(t *testing.T) { } // Basic sanity check - result should be valid (not less than Allow weight) - if result.Compare(remediation.Allow) < 0 { + if remediation.Compare(result, remediation.Allow) < 0 { t.Errorf("Invalid result for IP %s: %v", testIP.String(), result) } // Origin should be non-empty if we have a match (result > Allow) - if result.Compare(remediation.Allow) > 0 && origin == "" { + if remediation.Compare(result, remediation.Allow) > 0 && origin == "" { t.Errorf("Empty origin for IP %s with result %v", testIP.String(), result) } } @@ -202,28 +202,28 @@ func TestLongestPrefixMatch(t *testing.T) { // Test that individual IP from IPMap wins (checked first before RangeSet) ip1 := netip.MustParseAddr("192.168.1.1") result, _, _ := dataset.CheckIP(ip1) - if !result.IsEqual(remediation.Allow) { + if !remediation.IsEqual(result, remediation.Allow) { t.Errorf("Expected Allow for 192.168.1.1 (from IPMap), got %v", result) } // Test that we get the LPM from RangeSet (Captcha /24 wins over Ban /16) ip2 := netip.MustParseAddr("192.168.1.2") result, _, _ = dataset.CheckIP(ip2) - if !result.IsEqual(remediation.Captcha) { + if !remediation.IsEqual(result, remediation.Captcha) { t.Errorf("Expected Captcha for 192.168.1.2 (LPM from RangeSet), got %v", result) } // Test that we get the broadest match from RangeSet ip3 := netip.MustParseAddr("192.168.2.1") result, _, _ = dataset.CheckIP(ip3) - if !result.IsEqual(remediation.Ban) { + if !remediation.IsEqual(result, remediation.Ban) { t.Errorf("Expected Ban for 192.168.2.1 (from RangeSet), got %v", result) } // Test that we get no match ip4 := netip.MustParseAddr("10.0.0.1") result, _, _ = dataset.CheckIP(ip4) - if !result.IsEqual(remediation.Allow) { + if !remediation.IsEqual(result, remediation.Allow) { t.Errorf("Expected Allow for 10.0.0.1 (no match), got %v", result) } } diff --git a/pkg/dataset/metrics_test.go b/pkg/dataset/metrics_test.go index 7a351f5..c79ed6e 100644 --- a/pkg/dataset/metrics_test.go +++ b/pkg/dataset/metrics_test.go @@ -746,7 +746,7 @@ func TestMetrics_NoOp_DuplicateDecisions(t *testing.T) { r, foundOrigin, found := dataSet.IPMap.Contains(ip) assert.True(t, found, "IP should still exist") assert.Equal(t, origin, foundOrigin, "origin should match") - assert.True(t, r.IsEqual(remediation.Ban), "remediation should be ban") + assert.True(t, remediation.IsEqual(r, remediation.Ban), "remediation should be ban") }) t.Run("Duplicate range decision is no-op", func(t *testing.T) { @@ -775,7 +775,7 @@ func TestMetrics_NoOp_DuplicateDecisions(t *testing.T) { require.NoError(t, err) r, foundOrigin := dataSet.RangeSet.Contains(testIP) assert.Equal(t, origin, foundOrigin, "origin should match") - assert.True(t, r.IsEqual(remediation.Ban), "remediation should be ban") + assert.True(t, remediation.IsEqual(r, remediation.Ban), "remediation should be ban") }) t.Run("Duplicate country decision is no-op", func(t *testing.T) { @@ -802,7 +802,7 @@ func TestMetrics_NoOp_DuplicateDecisions(t *testing.T) { // Verify decision still exists r, foundOrigin := dataSet.CNSet.Contains("US") assert.Equal(t, origin, foundOrigin, "origin should match") - assert.True(t, r.IsEqual(remediation.Ban), "remediation should be ban") + assert.True(t, remediation.IsEqual(r, remediation.Ban), "remediation should be ban") }) t.Run("Same IP different remediation is not no-op", func(t *testing.T) { @@ -866,7 +866,7 @@ func TestMetrics_OriginOverwriteAndDelete(t *testing.T) { require.NoError(t, err) r, storedOrigin, found := dataSet.IPMap.Contains(ip) assert.True(t, found, "IP should exist") - assert.True(t, r.IsEqual(remediation.Ban), "IP should have ban remediation") + assert.True(t, remediation.IsEqual(r, remediation.Ban), "IP should have ban remediation") assert.Equal(t, origin, storedOrigin, "IP should have CAPI origin") // Step 2: Add captcha from same CAPI origin (overwrites ban) @@ -890,7 +890,7 @@ func TestMetrics_OriginOverwriteAndDelete(t *testing.T) { // Both ban and captcha exist in the map, but ban is returned as highest priority r2, storedOrigin2, found2 := dataSet.IPMap.Contains(ip) assert.True(t, found2, "IP should still exist") - assert.True(t, r2.IsEqual(remediation.Ban), "IP should have ban remediation (highest priority, ban > captcha)") + assert.True(t, remediation.IsEqual(r2, remediation.Ban), "IP should have ban remediation (highest priority, ban > captcha)") assert.Equal(t, origin, storedOrigin2, "IP should still have CAPI origin") // Step 3: Delete ban from CAPI @@ -913,7 +913,7 @@ func TestMetrics_OriginOverwriteAndDelete(t *testing.T) { // Verify IP now has captcha (ban was removed, captcha is now active) r3, storedOrigin3, found3 := dataSet.IPMap.Contains(ip) assert.True(t, found3, "IP should still exist") - assert.True(t, r3.IsEqual(remediation.Captcha), "IP should now have captcha remediation (ban was removed)") + assert.True(t, remediation.IsEqual(r3, remediation.Captcha), "IP should now have captcha remediation (ban was removed)") assert.Equal(t, origin, storedOrigin3, "IP should still have CAPI origin") }) } diff --git a/pkg/dataset/root.go b/pkg/dataset/root.go index f85add3..5ac94d6 100644 --- a/pkg/dataset/root.go +++ b/pkg/dataset/root.go @@ -83,7 +83,7 @@ func (d *DataSet) Add(decisions models.GetDecisionsResponse) { ipType = "ipv6" } // Check if we're overwriting an existing decision with different origin - if existingR, existingOrigin, found := d.IPMap.Contains(ip); found && existingR.IsEqual(r) && existingOrigin != origin { + if existingR, existingOrigin, found := d.IPMap.Contains(ip); found && remediation.IsEqual(existingR, r) && existingOrigin != origin { // Decrement old origin's metric before incrementing new one // Label order: origin, ip_type, scope (as defined in metrics.go) metrics.TotalActiveDecisions.WithLabelValues(existingOrigin, ipType, "ip").Dec() @@ -126,7 +126,7 @@ func (d *DataSet) Add(decisions models.GetDecisionsResponse) { continue } // Check if we're overwriting an existing decision with different origin - if existingR, existingOrigin := d.CNSet.Contains(cn); existingR.IsEqual(r) && existingOrigin != "" && existingOrigin != origin { + if existingR, existingOrigin := d.CNSet.Contains(cn); remediation.IsEqual(existingR, r) && existingOrigin != "" && existingOrigin != origin { // Decrement old origin's metric before incrementing new one // Label order: origin, ip_type, scope (as defined in metrics.go) metrics.TotalActiveDecisions.WithLabelValues(existingOrigin, "", "country").Dec() diff --git a/pkg/dataset/root_test.go b/pkg/dataset/root_test.go index 42c1a78..15e186b 100644 --- a/pkg/dataset/root_test.go +++ b/pkg/dataset/root_test.go @@ -169,7 +169,7 @@ func TestDataSet(t *testing.T) { t.Fatalf("unknown scope %s", tt.toCheck.Scope) } require.NoError(t, err) - assert.True(t, r.IsEqual(tt.toCheck.Type), "remediation should match: got %s, expected %s", r.String(), tt.toCheck.Type.String()) + assert.True(t, remediation.IsEqual(r, tt.toCheck.Type), "remediation should match: got %s, expected %s", r.String(), tt.toCheck.Type.String()) assert.Equal(t, origin, tt.toCheck.Origin) }) } diff --git a/pkg/dataset/types.go b/pkg/dataset/types.go index d3d9cc2..bb934cb 100644 --- a/pkg/dataset/types.go +++ b/pkg/dataset/types.go @@ -21,15 +21,8 @@ var ErrRemediationNotFound = errors.New("remediation not found") // - Deletions: Delete means user wants to allow the IP - just remove the remediation entry. // Duplicate deletes are safely ignored (entry already gone). // -// Keys are remediation.Remediation types, which use deduplicated string pointers internally. -// This automatically benefits from string deduplication without extra complexity. +// Keys are remediation.Remediation types (string type). // Weight comparison is done via remediation.Compare() when determining priority. -// -// IMPORTANT: Map key comparison uses Go's == operator which compares all struct fields, -// including the name pointer. For map lookups to work correctly, all Remediation instances -// MUST be created via remediation.New() or remediation.FromString() to ensure proper -// deduplication. Direct struct initialization will result in different pointers and -// cause map lookups to fail. type RemediationMap map[remediation.Remediation]string // Remove removes a remediation entry (deletion means user wants to allow the IP). @@ -81,10 +74,10 @@ func (rM RemediationMap) GetRemediationAndOrigin() (remediation.Remediation, str } // Compare by weight first - if r.IsHigher(maxRemediation) { + if remediation.IsHigher(r, maxRemediation) { maxRemediation = r maxOrigin = origin - } else if r.HasSameWeight(maxRemediation) { + } else if remediation.HasSameWeight(r, maxRemediation) { // Tie-breaker: use alphabetical order of name for deterministic behavior if r.String() < maxRemediation.String() { maxRemediation = r diff --git a/pkg/spoa/root.go b/pkg/spoa/root.go index 8df5a85..8155f66 100644 --- a/pkg/spoa/root.go +++ b/pkg/spoa/root.go @@ -218,7 +218,7 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { // defer a function that always add the remediation to the request at end of processing defer func() { - if matchedHost == nil && r.IsEqual(remediation.Captcha) { + if matchedHost == nil && remediation.IsEqual(r, remediation.Captcha) { s.logger.Warn("remediation is captcha, no matching host was found cannot issue captcha remediation reverting to ban") r = remediation.Ban } @@ -243,7 +243,7 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { // Count blocked request if remediation applied (not Allow) // This includes Unknown, Captcha, Ban, and any custom remediations - if r.IsWeighted() { + if remediation.IsWeighted(r) { // Label order: origin, ip_type, remediation (as defined in metrics.go) metrics.TotalBlockedRequests.WithLabelValues(origin, ipTypeLabel, r.String()).Inc() } @@ -326,7 +326,7 @@ func (s *Spoa) handleHTTPRequest(req *request.Request, mes *message.Message) { // If remediation is not allow, we dont need to create a request to send to appsec unless always send is on // This includes Unknown, Captcha, Ban, and any custom remediations - if r.IsWeighted() && !matchedHost.AppSec.AlwaysSend { + if remediation.IsWeighted(r) && !matchedHost.AppSec.AlwaysSend { return } // !TODO APPSEC STUFF - httpData contains parsed URL, Method, Body, Headers for reuse @@ -647,9 +647,9 @@ func (s *Spoa) getIPRemediation(req *request.Request, ip netip.Addr) (remediatio req.Actions.SetVar(action.ScopeTransaction, "isocode", iso) // If no IP-specific remediation (Allow), check country-based remediation - if r.IsEqual(remediation.Allow) { + if remediation.IsEqual(r, remediation.Allow) { cnR, cnOrigin := s.dataset.CheckCN(iso) - if cnR.IsHigher(remediation.Allow) { + if remediation.IsHigher(cnR, remediation.Allow) { r = cnR origin = cnOrigin } @@ -688,7 +688,7 @@ func (s *Spoa) handleIPRequest(req *request.Request, mes *message.Message) { // Count blocked requests (not Allow) // This includes Unknown, Captcha, Ban, and any custom remediations - if r.IsWeighted() { + if remediation.IsWeighted(r) { // Label order: origin, ip_type, remediation (as defined in metrics.go) metrics.TotalBlockedRequests.WithLabelValues(origin, ipTypeLabel, r.String()).Inc() }