Skip to content
Draft
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 169 additions & 25 deletions internal/remediation/root.go
Original file line number Diff line number Diff line change
@@ -1,37 +1,181 @@
package remediation

// The order matters since we use slices.Max to get the max value
import (
"sync"
)

// Default weights for built-in remediations
// Allow=0, Unknown=1, then expand others to allow custom remediations to slot in between
const (
Allow Remediation = iota // Allow remediation
Unknown // Unknown remediation (Unknown is used to have a value for remediation we don't support EG "MFA")
Captcha // Captcha remediation
Ban // Ban remediation
WeightAllow = 0
WeightUnknown = 1
WeightCaptcha = 10
WeightBan = 20
)

// Remediation represents a remediation type as a string
// We use string pointers for deduplication to reduce allocations
type Remediation struct {
name *string // Pointer to deduplicated string
weight int // Weight for comparison (higher = more severe)
}
Comment on lines +18 to +21
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The weight is stored redundantly in both the Remediation struct and the globalRegistry.weights map. This creates potential for inconsistency and doesn't align with the claimed "deduplication" benefits.

Issues:

  1. When SetWeight() is called after a Remediation has been created, existing instances retain their old weight while new instances get the new weight.
  2. This breaks the invariant that all instances with the same name should behave identically.
  3. The PR description claims "Automatic deduplication" and "Reduced allocations", but storing weight in the struct doesn't achieve this.

Recommendation: Remove the weight field from the struct and always look it up from the registry:

func (r Remediation) Weight() int {
    return GetWeight(*r.name)
}

This ensures consistency and truly achieves deduplication - only the name pointer is stored per instance.

Copilot uses AI. Check for mistakes.

// registry manages deduplicated remediation strings and their weights
type registry struct {
mu sync.RWMutex
strings map[string]*string // Maps string to its deduplicated pointer
weights map[string]int // Maps remediation name to its weight
}

var globalRegistry = &registry{
strings: make(map[string]*string),
weights: make(map[string]int),
}

// Built-in remediation constants (for convenience)
// Initialized to nil, will be set in init()
var (
Allow Remediation
Unknown Remediation
Captcha Remediation
Ban Remediation
)

type Remediation uint8 // Remediation type is smallest uint to save space
//nolint:gochecknoinits // init() is required to initialize package-level vars after weights are set
func init() {
// Initialize built-in remediations with default weights
globalRegistry.mu.Lock()
defer globalRegistry.mu.Unlock()

// Set weights FIRST before creating strings
globalRegistry.weights["allow"] = WeightAllow
globalRegistry.weights["unknown"] = WeightUnknown
globalRegistry.weights["captcha"] = WeightCaptcha
globalRegistry.weights["ban"] = WeightBan

// Pre-create deduplicated strings for built-in remediations
// Must create new string variables for each to avoid pointer aliasing
allowStr := "allow"
unknownStr := "unknown"
captchaStr := "captcha"
banStr := "ban"

globalRegistry.strings["allow"] = &allowStr
globalRegistry.strings["unknown"] = &unknownStr
globalRegistry.strings["captcha"] = &captchaStr
globalRegistry.strings["ban"] = &banStr

// Now initialize the package-level vars directly (we already hold the lock)
// This avoids deadlock since New() would try to acquire the lock again
Allow = Remediation{name: &allowStr, weight: WeightAllow}
Unknown = Remediation{name: &unknownStr, weight: WeightUnknown}
Captcha = Remediation{name: &captchaStr, weight: WeightCaptcha}
Ban = Remediation{name: &banStr, weight: WeightBan}
}

// SetWeight sets a custom weight for a remediation (for configuration)
func SetWeight(name string, weight int) {
globalRegistry.mu.Lock()
defer globalRegistry.mu.Unlock()

globalRegistry.weights[name] = weight
// Ensure deduplicated string exists
if _, exists := globalRegistry.strings[name]; !exists {
deduped := name
globalRegistry.strings[name] = &deduped
}
}

// GetWeight returns the weight for a remediation name
func GetWeight(name string) int {
globalRegistry.mu.RLock()
defer globalRegistry.mu.RUnlock()

if weight, exists := globalRegistry.weights[name]; exists {
return weight
}
// Default to Unknown weight for unknown remediations
return WeightUnknown
}

// New creates a new Remediation from a string
// Uses deduplicated string pointers to reduce allocations
func New(name string) Remediation {
globalRegistry.mu.Lock()
defer globalRegistry.mu.Unlock()

// Get or create deduplicated string pointer
deduped, exists := globalRegistry.strings[name]
if !exists {
// Create new deduplicated string
deduped = &name
globalRegistry.strings[name] = deduped
// Set default weight if not configured
if _, hasWeight := globalRegistry.weights[name]; !hasWeight {
globalRegistry.weights[name] = WeightUnknown
}
}

// Read weight from registry (may have been set in init() or SetWeight())
weight, ok := globalRegistry.weights[name]
if !ok {
// Weight not found, default to Unknown
weight = WeightUnknown
}
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Data race: The globalRegistry.weights map is accessed without proper synchronization. In New(), lines 131-134 read from weights after releasing the lock on line 114. Between these lines, another goroutine could call SetWeight() and modify the map, causing a concurrent map read/write panic.

Problematic sequence:

defer globalRegistry.mu.Unlock()  // Line 114 - lock is released
// ... (lines 115-130)
weight, ok := globalRegistry.weights[name]  // Line 131 - unsynchronized read!

Solution: Move the weight lookup inside the critical section before unlocking:

globalRegistry.mu.Lock()
defer globalRegistry.mu.Unlock()
// ... get/create deduped string ...
weight := globalRegistry.weights[name]
if weight == 0 {
    weight = WeightUnknown
}
return Remediation{name: deduped, weight: weight}
Suggested change
}
}

Copilot uses AI. Check for mistakes.
return Remediation{
name: deduped,
weight: weight,
}
}

// String returns the remediation name
func (r Remediation) String() string {
switch r {
case Ban:
return "ban"
case Captcha:
return "captcha"
case Unknown:
return "unknown"
default:
return "allow"
if r.name == nil {
return "allow" // Default fallback
}
return *r.name
}

// Weight returns the weight of the remediation
func (r Remediation) Weight() int {
return r.weight
}

// Compare returns:
// - negative if r < other
// - zero if r == other
// - positive if r > other
func (r Remediation) Compare(other Remediation) int {
return r.weight - other.weight
}

// IsHigher returns true if r has a higher weight than other
func (r Remediation) IsHigher(other Remediation) bool {
return r.weight > other.weight
}

// IsLower returns true if r has a lower weight than other
func (r Remediation) IsLower(other Remediation) bool {
return r.weight < other.weight
}

// IsEqual returns true if r has the same weight as other
func (r Remediation) IsEqual(other Remediation) bool {
return r.weight == other.weight
}

// IsWeighted returns true if r is not Allow (has weight > Allow)
// This is useful for checking if a remediation should be applied
func (r Remediation) IsWeighted() bool {
return r.weight > WeightAllow
}

// FromString creates a Remediation from a string (alias for New for backward compatibility)
func FromString(s string) Remediation {
switch s {
case "ban":
return Ban
case "captcha":
return Captcha
case "allow":
return Allow
default:
return Unknown
}
return New(s)
}

// IsZero returns true if the remediation is zero-valued
func (r Remediation) IsZero() bool {
return r.name == nil
}
13 changes: 13 additions & 0 deletions pkg/cfg/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"gopkg.in/yaml.v2"

"github.com/crowdsecurity/crowdsec-spoa/internal/geo"
"github.com/crowdsecurity/crowdsec-spoa/internal/remediation"
"github.com/crowdsecurity/crowdsec-spoa/pkg/host"
cslogging "github.com/crowdsecurity/crowdsec-spoa/pkg/logging"
"github.com/crowdsecurity/go-cs-lib/csyaml"
Expand Down Expand Up @@ -36,6 +37,11 @@ type BouncerConfig struct {
ListenUnix string `yaml:"listen_unix"`
PrometheusConfig PrometheusConfig `yaml:"prometheus"`
PprofConfig PprofConfig `yaml:"pprof"`
// RemediationWeights allows users to configure custom weights for remediations
// Format: map[string]int where key is remediation name and value is weight
// Built-in defaults: allow=0, unknown=1, captcha=10, ban=20
// Custom remediations can slot between these values
RemediationWeights map[string]int `yaml:"remediation_weights,omitempty"`
}

// MergedConfig() returns the byte content of the patched configuration file (with .yaml.local).
Expand Down Expand Up @@ -67,6 +73,13 @@ func NewConfig(reader io.Reader) (*BouncerConfig, error) {
return nil, fmt.Errorf("failed to setup logging: %w", err)
}

// Apply custom remediation weights if configured
if config.RemediationWeights != nil {
for remediationName, weight := range config.RemediationWeights {
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The configuration loading order creates a race condition. Custom weights are applied in NewConfig() (line 97-102) after the init() function has already initialized the built-in constants with default weights. This means:

  1. If the config sets a custom weight for "ban", it updates the registry but Ban constant already has weight 20 cached in its struct
  2. New Remediation instances created via FromString("ban") will get the custom weight
  3. But comparisons between the constant Ban and new instances will compare different weights

Example failure scenario:

// Config sets ban weight to 25
remediation.SetWeight("ban", 25)
r := remediation.FromString("ban") // gets weight 25
if r.IsEqual(remediation.Ban) {    // false! Ban has weight 20
    // This never executes
}

Solution: Document that custom weights for built-in remediations are not supported, or redesign to not cache weights in the struct (see comment on lines 18-21).

Suggested change
for remediationName, weight := range config.RemediationWeights {
for remediationName, weight := range config.RemediationWeights {
// Prevent custom weights for built-in remediations to avoid inconsistent behavior
switch remediationName {
case remediation.Ban.Name, remediation.Captcha.Name, remediation.Bypass.Name, remediation.Stream.Name:
return nil, fmt.Errorf("custom weights for built-in remediation '%s' are not supported", remediationName)
}

Copilot uses AI. Check for mistakes.
remediation.SetWeight(remediationName, weight)
}
}

if err := config.Validate(); err != nil {
return nil, err
}
Expand Down
29 changes: 15 additions & 14 deletions pkg/dataset/bart_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ import (
type BartAddOp struct {
Prefix netip.Prefix
Origin string
R remediation.Remediation
R string // Remediation name as string
IPType string
Scope string
}

// BartRemoveOp represents a single prefix removal operation for batch processing
type BartRemoveOp struct {
Prefix netip.Prefix
R remediation.Remediation
R string // Remediation name as string
Origin string
IPType string
Scope string
Expand Down Expand Up @@ -91,7 +91,7 @@ func (s *BartRangeSet) initializeBatch(operations []BartAddOp) {
// Only build logging fields if trace level is enabled
var valueLog *log.Entry
if s.logger.Logger.IsLevelEnabled(log.TraceLevel) {
valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R.String())
valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R)
valueLog.Trace("initial load: collecting prefix operations")
}

Expand All @@ -101,7 +101,7 @@ func (s *BartRangeSet) initializeBatch(operations []BartAddOp) {
data = RemediationMap{}
}
// Add the remediation (this handles merging if prefix already seen)
data.Add(valueLog, op.R, op.Origin)
data.Add(valueLog, remediation.FromString(op.R), op.Origin)
prefixMap[prefix] = data
}

Expand Down Expand Up @@ -134,7 +134,7 @@ func (s *BartRangeSet) updateBatch(cur *bart.Table[RemediationMap], operations [
// Only build logging fields if trace level is enabled
var valueLog *log.Entry
if s.logger.Logger.IsLevelEnabled(log.TraceLevel) {
valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R.String())
valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R)
valueLog.Trace("adding to bart trie")
}

Expand All @@ -146,15 +146,15 @@ func (s *BartRangeSet) updateBatch(cur *bart.Table[RemediationMap], operations [
valueLog.Trace("exact prefix exists, merging remediations")
}
// bart already cloned via our Cloner interface, modify directly
existingData.Add(valueLog, op.R, op.Origin)
existingData.Add(valueLog, remediation.FromString(op.R), op.Origin)
return existingData, false // false = don't delete
}
if valueLog != nil {
valueLog.Trace("creating new entry")
}
// Create new data
newData := make(RemediationMap)
newData.Add(valueLog, op.R, op.Origin)
newData.Add(valueLog, remediation.FromString(op.R), op.Origin)
return newData, false // false = don't delete
})
}
Expand Down Expand Up @@ -193,7 +193,7 @@ func (s *BartRangeSet) RemoveBatch(operations []BartRemoveOp) []*BartRemoveOp {
// Only build logging fields if trace level is enabled
var valueLog *log.Entry
if s.logger.Logger.IsLevelEnabled(log.TraceLevel) {
valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R.String())
valueLog = s.logger.WithField("prefix", prefix.String()).WithField("remediation", op.R)
valueLog.Trace("removing from bart trie")
}

Expand All @@ -210,11 +210,12 @@ func (s *BartRangeSet) RemoveBatch(operations []BartRemoveOp) []*BartRemoveOp {

// Check if the remediation exists with the matching origin before removing
// This prevents removing decisions when the origin has been overwritten (e.g., by CAPI)
if !existingData.HasRemediationWithOrigin(op.R, op.Origin) {
if !existingData.HasRemediationWithOrigin(remediation.FromString(op.R), op.Origin) {
// Origin doesn't match - this decision was likely overwritten by another origin
// Don't remove it, as it's not the decision we're trying to delete
if valueLog != nil {
storedOrigin, exists := existingData[op.R]
r := remediation.FromString(op.R)
storedOrigin, exists := existingData[r]
if exists {
valueLog.Tracef("remediation exists but origin mismatch (stored: %s, requested: %s), skipping removal", storedOrigin, op.Origin)
} else {
Expand All @@ -228,7 +229,7 @@ func (s *BartRangeSet) RemoveBatch(operations []BartRemoveOp) []*BartRemoveOp {
// bart already cloned via our Cloner interface, modify directly
// Remove returns an error if remediation doesn't exist (duplicate delete)
// We already checked origin above, so this should succeed
err := existingData.Remove(valueLog, op.R)
err := existingData.Remove(valueLog, remediation.FromString(op.R))
if errors.Is(err, ErrRemediationNotFound) {
// This shouldn't happen since we checked above, but handle it gracefully
if valueLog != nil {
Expand Down Expand Up @@ -288,11 +289,11 @@ func (s *BartRangeSet) Contains(ip netip.Addr) (remediation.Remediation, string)
return remediation.Allow, ""
}

remediationResult, origin := data.GetRemediationAndOrigin()
r, origin := data.GetRemediationAndOrigin()
if valueLog != nil {
valueLog.Tracef("bart result: %s (data: %+v)", remediationResult.String(), data)
valueLog.Tracef("bart result: %s (data: %+v)", r.String(), data)
}
return remediationResult, origin
return r, origin
}

// HasRemediation checks if an exact prefix has a specific remediation with a specific origin.
Expand Down
Loading