From d274a85588142e0b635d203f8cb2c48bee6ad678 Mon Sep 17 00:00:00 2001 From: Laurence Date: Wed, 26 Nov 2025 09:01:46 +0000 Subject: [PATCH 1/2] feat(host): add trie-based host matching for MSSP scalability Implement a reverse domain trie for efficient host pattern matching, designed to scale for MSSP deployments with hundreds/thousands of hosts. Changes: - Add domainTrie data structure with O(m) lookup complexity - Hybrid approach: trie for simple patterns, filepath.Match fallback for complex - Priority system ensures most-specific-first matching behavior - Comprehensive tests and benchmarks Benchmark results (4 mixed lookups per iteration): | Hosts | Slice (old) | Trie (new) | Speedup | |---------|-------------|------------|--------------| | 10 | 4,901 ns | 432 ns | 11x faster | | 100 | 53,221 ns | 419 ns | 127x faster | | 1,000 | 414,463 ns | 428 ns | 968x faster | | 10,000 | 3,835,689 ns| 453 ns | 8,468x faster| Note: For small deployments (1-4 hosts), the existing cache provides sufficient performance. The trie optimization primarily benefits large-scale MSSP deployments. --- pkg/host/TRIE_IMPLEMENTATION.md | 112 ++++++++++ pkg/host/benchmark_test.go | 292 +++++++++++++++++++++++++++ pkg/host/root.go | 65 ++++-- pkg/host/root_test.go | 305 ++++++++++++++++++++++++++++ pkg/host/trie.go | 348 ++++++++++++++++++++++++++++++++ 5 files changed, 1107 insertions(+), 15 deletions(-) create mode 100644 pkg/host/TRIE_IMPLEMENTATION.md create mode 100644 pkg/host/benchmark_test.go create mode 100644 pkg/host/root_test.go create mode 100644 pkg/host/trie.go diff --git a/pkg/host/TRIE_IMPLEMENTATION.md b/pkg/host/TRIE_IMPLEMENTATION.md new file mode 100644 index 0000000..39f3b18 --- /dev/null +++ b/pkg/host/TRIE_IMPLEMENTATION.md @@ -0,0 +1,112 @@ +# Host Trie Implementation + +## Overview + +This document describes the trie-based optimization for host pattern matching, designed to scale efficiently for MSSP deployments with hundreds or thousands of host configurations. + +## Problem Statement + +The original implementation stored hosts in a slice and used linear search with `filepath.Match()` for each request. This approach has O(n) complexity and doesn't scale well for large numbers of hosts. + +## Solution: Reverse Domain Trie + +We implemented a **reverse domain trie** that provides O(m) lookup complexity where m is the depth of the domain (typically 2-4 levels), independent of the total number of hosts. + +### How It Works + +#### Domain Reversal + +Domains are reversed before insertion to enable efficient prefix matching: +- `www.example.com` → `["com", "example", "www"]` +- `*.example.com` → `["com", "example", "*"]` +- `*` → `["*"]` + +This allows patterns like `*.example.com` to share the common `com → example` path with other patterns for the same domain. + +#### Trie Structure + +``` +root +├── com (exact) +│ └── example (exact) +│ ├── www (exact) → Host: www.example.com +│ └── * (wildcard) → Host: *.example.com +└── * (wildcard) → Host: * (catch-all) +``` + +#### Matching Algorithm + +The `findMatches` function traverses the trie recursively: + +1. **Exact match first**: Try to match the current domain segment exactly +2. **Wildcard fallback**: If no exact match found, try the wildcard child +3. **Priority comparison**: When multiple matches are possible, the highest priority wins + +### Priority System + +Priority determines which pattern wins when multiple patterns could match: + +| Factor | Impact | +|--------|--------| +| Exact match (no wildcards) | +10,000 | +| Pattern length | +10 per character | +| Each wildcard character | -1,000 | + +Examples: +- `www.example.com` → 10,000 + 150 = **10,150** +- `*.example.com` → 0 + 130 - 1,000 = **-870** +- `*` → 0 + 10 - 1,000 = **-990** + +### Pattern Classification + +**Simple patterns** (handled efficiently by trie): +- Exact: `www.example.com` +- Prefix wildcard: `*.example.com` +- Suffix wildcard: `example.*` +- Catch-all: `*` + +**Complex patterns** (fallback to `filepath.Match`): +- Middle wildcards: `example.*.com` +- Partial wildcards: `*example.com`, `www*.example.com` + +## Performance Characteristics + +| Operation | Complexity | +|-----------|------------| +| Lookup | O(m) where m = domain depth | +| Insert | O(m) | +| Delete | O(m) | +| Space | O(n × m) where n = number of hosts | + +For typical domains (3-4 segments), lookup is effectively O(1) regardless of the number of hosts stored. + +## API + +The implementation is transparent - no changes needed to existing code: + +```go +manager := host.NewManager(logger) +manager.addHost(host) // Adds to trie or complexPatterns +manager.removeHost(host) // Removes from trie +matched := manager.MatchFirstHost("api.example.com") // Uses trie for lookup +``` + +## Key Improvements (v2) + +1. **Removed dead code**: Eliminated unused `getAllHosts()` and `collectHosts()` functions +2. **Fixed priority bug**: Priority comparison now uses `math.MinInt` as the initial value +3. **Zero allocations in hot path**: `findMatches` uses pointers instead of returning slices +4. **Better documentation**: Comprehensive comments explaining the algorithm +5. **Cleaner node structure**: Removed unused `priority` field from nodes (calculated on demand) +6. **Edge case handling**: Proper nil/empty checks throughout + +## Testing + +The implementation includes comprehensive tests covering: +- Single host matching +- Multiple hosts with priority ordering +- Wildcard patterns (prefix, suffix, catch-all) +- Complex wildcard patterns +- Host removal +- Cache behavior +- Edge cases (no hosts, no match) diff --git a/pkg/host/benchmark_test.go b/pkg/host/benchmark_test.go new file mode 100644 index 0000000..067ce67 --- /dev/null +++ b/pkg/host/benchmark_test.go @@ -0,0 +1,292 @@ +package host + +import ( + "fmt" + "path/filepath" + "testing" +) + +// sliceMatcher represents the old slice-based host matching approach +// for benchmark comparison purposes +type sliceMatcher struct { + hosts []*Host +} + +func newSliceMatcher() *sliceMatcher { + return &sliceMatcher{ + hosts: make([]*Host, 0), + } +} + +func (s *sliceMatcher) add(host *Host) { + s.hosts = append(s.hosts, host) +} + +// matchFirstHost is the old O(n) matching algorithm using filepath.Match +func (s *sliceMatcher) matchFirstHost(toMatch string) *Host { //nolint:unparam + for _, host := range s.hosts { + matched, err := filepath.Match(host.Host, toMatch) + if matched && err == nil { + return host + } + } + return nil +} + +// trieMatcher wraps the new trie implementation for benchmarking +type trieMatcher struct { + trie *domainTrie + complexPatterns []*Host +} + +func newTrieMatcher() *trieMatcher { + return &trieMatcher{ + trie: newDomainTrie(), + complexPatterns: make([]*Host, 0), + } +} + +func (t *trieMatcher) add(host *Host) { + if isComplexPattern(host.Host) { + t.complexPatterns = append(t.complexPatterns, host) + } else { + t.trie.add(host) + } +} + +func (t *trieMatcher) matchFirstHost(toMatch string) *Host { //nolint:unparam + return t.trie.match(toMatch, t.complexPatterns) +} + +// generateHosts creates n hosts with patterns like: +// - Exact matches: host0.example.com, host1.example.com, ... +// - With some wildcards: *.domain0.com, *.domain1.com, ... +// - One catch-all: * +func generateHosts(n int) []*Host { + hosts := make([]*Host, 0, n) + + // 70% exact matches + exactCount := (n * 70) / 100 + for range exactCount { + hosts = append(hosts, &Host{ + Host: fmt.Sprintf("host%d.example%d.com", len(hosts), len(hosts)%100), + }) + } + + // 29% wildcard patterns + wildcardCount := (n * 29) / 100 + for i := range wildcardCount { + hosts = append(hosts, &Host{ + Host: fmt.Sprintf("*.domain%d.com", i), + }) + } + + // 1% catch-all (at least 1) + catchAllCount := n - exactCount - wildcardCount + if catchAllCount < 1 { + catchAllCount = 1 + } + for range catchAllCount { + hosts = append(hosts, &Host{ + Host: "*", + }) + } + + return hosts +} + +// BenchmarkSliceMatcher_Small benchmarks slice matching with 10 hosts +func BenchmarkSliceMatcher_Small(b *testing.B) { + benchmarkSliceMatcher(b, 10) +} + +// BenchmarkTrieMatcher_Small benchmarks trie matching with 10 hosts +func BenchmarkTrieMatcher_Small(b *testing.B) { + benchmarkTrieMatcher(b, 10) +} + +// BenchmarkSliceMatcher_Medium benchmarks slice matching with 100 hosts +func BenchmarkSliceMatcher_Medium(b *testing.B) { + benchmarkSliceMatcher(b, 100) +} + +// BenchmarkTrieMatcher_Medium benchmarks trie matching with 100 hosts +func BenchmarkTrieMatcher_Medium(b *testing.B) { + benchmarkTrieMatcher(b, 100) +} + +// BenchmarkSliceMatcher_Large benchmarks slice matching with 1000 hosts +func BenchmarkSliceMatcher_Large(b *testing.B) { + benchmarkSliceMatcher(b, 1000) +} + +// BenchmarkTrieMatcher_Large benchmarks trie matching with 1000 hosts +func BenchmarkTrieMatcher_Large(b *testing.B) { + benchmarkTrieMatcher(b, 1000) +} + +// BenchmarkSliceMatcher_XLarge benchmarks slice matching with 10000 hosts +func BenchmarkSliceMatcher_XLarge(b *testing.B) { + benchmarkSliceMatcher(b, 10000) +} + +// BenchmarkTrieMatcher_XLarge benchmarks trie matching with 10000 hosts +func BenchmarkTrieMatcher_XLarge(b *testing.B) { + benchmarkTrieMatcher(b, 10000) +} + +func benchmarkSliceMatcher(b *testing.B, hostCount int) { + hosts := generateHosts(hostCount) + matcher := newSliceMatcher() + for _, h := range hosts { + matcher.add(h) + } + + // Test domains to match + testDomains := []string{ + "host0.example0.com", // First exact match + fmt.Sprintf("host%d.example%d.com", hostCount/2, (hostCount/2)%100), // Middle exact match + "api.domain50.com", // Wildcard match + "unknown.random.org", // Falls through to catch-all + } + + b.ResetTimer() + b.ReportAllocs() + + for range b.N { + for _, domain := range testDomains { + _ = matcher.matchFirstHost(domain) + } + } +} + +func benchmarkTrieMatcher(b *testing.B, hostCount int) { + hosts := generateHosts(hostCount) + matcher := newTrieMatcher() + for _, h := range hosts { + matcher.add(h) + } + + // Test domains to match + testDomains := []string{ + "host0.example0.com", // First exact match + fmt.Sprintf("host%d.example%d.com", hostCount/2, (hostCount/2)%100), // Middle exact match + "api.domain50.com", // Wildcard match + "unknown.random.org", // Falls through to catch-all + } + + b.ResetTimer() + b.ReportAllocs() + + for range b.N { + for _, domain := range testDomains { + _ = matcher.matchFirstHost(domain) + } + } +} + +// BenchmarkSliceMatcher_WorstCase benchmarks slice matching when match is at the end +func BenchmarkSliceMatcher_WorstCase(b *testing.B) { + hosts := generateHosts(1000) + matcher := newSliceMatcher() + for _, h := range hosts { + matcher.add(h) + } + + // Domain that will only match the catch-all at the end + domain := "nomatch.unknown.tld" + + b.ResetTimer() + b.ReportAllocs() + + for range b.N { + _ = matcher.matchFirstHost(domain) + } +} + +// BenchmarkTrieMatcher_WorstCase benchmarks trie matching when falling through to catch-all +func BenchmarkTrieMatcher_WorstCase(b *testing.B) { + hosts := generateHosts(1000) + matcher := newTrieMatcher() + for _, h := range hosts { + matcher.add(h) + } + + // Domain that will only match the catch-all + domain := "nomatch.unknown.tld" + + b.ResetTimer() + b.ReportAllocs() + + for range b.N { + _ = matcher.matchFirstHost(domain) + } +} + +// BenchmarkSliceMatcher_BestCase benchmarks slice matching when match is first +func BenchmarkSliceMatcher_BestCase(b *testing.B) { + hosts := generateHosts(1000) + matcher := newSliceMatcher() + for _, h := range hosts { + matcher.add(h) + } + + // Domain that matches the first host + domain := "host0.example0.com" + + b.ResetTimer() + b.ReportAllocs() + + for range b.N { + _ = matcher.matchFirstHost(domain) + } +} + +// BenchmarkTrieMatcher_BestCase benchmarks trie matching with exact match +func BenchmarkTrieMatcher_BestCase(b *testing.B) { + hosts := generateHosts(1000) + matcher := newTrieMatcher() + for _, h := range hosts { + matcher.add(h) + } + + // Domain that has an exact match + domain := "host0.example0.com" + + b.ResetTimer() + b.ReportAllocs() + + for range b.N { + _ = matcher.matchFirstHost(domain) + } +} + +// BenchmarkAddHost_Slice benchmarks adding hosts to slice +func BenchmarkAddHost_Slice(b *testing.B) { + hosts := generateHosts(100) + + b.ResetTimer() + b.ReportAllocs() + + for range b.N { + matcher := newSliceMatcher() + for _, h := range hosts { + matcher.add(h) + } + } +} + +// BenchmarkAddHost_Trie benchmarks adding hosts to trie +func BenchmarkAddHost_Trie(b *testing.B) { + hosts := generateHosts(100) + + b.ResetTimer() + b.ReportAllocs() + + for range b.N { + matcher := newTrieMatcher() + for _, h := range hosts { + matcher.add(h) + } + } +} diff --git a/pkg/host/root.go b/pkg/host/root.go index 0d97d3b..1244596 100644 --- a/pkg/host/root.go +++ b/pkg/host/root.go @@ -39,10 +39,12 @@ type Host struct { } type Manager struct { - Hosts []*Host - Chan chan HostOp - Logger *log.Entry - cache map[string]*Host + Hosts []*Host + Chan chan HostOp + Logger *log.Entry + cache map[string]*Host + trie *domainTrie + complexPatterns []*Host // Patterns that don't fit well in the trie (wildcards in middle) sync.RWMutex } @@ -66,10 +68,12 @@ func (h *Manager) String() string { func NewManager(l *log.Entry) *Manager { return &Manager{ - Hosts: make([]*Host, 0), - Chan: make(chan HostOp), - Logger: l, - cache: make(map[string]*Host), + Hosts: make([]*Host, 0), + Chan: make(chan HostOp), + Logger: l, + cache: make(map[string]*Host), + trie: newDomainTrie(), + complexPatterns: make([]*Host, 0), } } @@ -86,14 +90,15 @@ func (h *Manager) MatchFirstHost(toMatch string) *Host { return host } - for _, host := range h.Hosts { - matched, err := filepath.Match(host.Host, toMatch) - if matched && err == nil { - host.logger.WithField("requested_host", toMatch).Debug("matched host pattern") - h.cache[toMatch] = host - return host - } + // Use trie for efficient matching + host := h.trie.match(toMatch, h.complexPatterns) + + if host != nil { + host.logger.WithField("requested_host", toMatch).Debug("matched host pattern") + h.cache[toMatch] = host + return host } + h.Logger.WithField("requested_host", toMatch).Debug("no matching host found") return nil } @@ -178,6 +183,25 @@ func (h *Manager) sort() { } func (h *Manager) removeHost(host *Host) { + // Remove from trie + if isComplexPattern(host.Host) { + // Remove from complexPatterns + for i, th := range h.complexPatterns { + if th == host { + if i == len(h.complexPatterns)-1 { + h.complexPatterns = h.complexPatterns[:i] + } else { + h.complexPatterns = append(h.complexPatterns[:i], h.complexPatterns[i+1:]...) + } + break + } + } + } else { + // Remove from trie + h.trie.remove(host) + } + + // Remove from Hosts slice for i, th := range h.Hosts { if th == host { // Sessions persist in global manager, no cleanup needed @@ -186,6 +210,8 @@ func (h *Manager) removeHost(host *Host) { } else { h.Hosts = append(h.Hosts[:i], h.Hosts[i+1:]...) } + // Clear cache since host configuration changed + h.cache = make(map[string]*Host) return } } @@ -249,5 +275,14 @@ func (h *Manager) addHost(host *Host) { if err := host.AppSec.Init(host.logger); err != nil { host.logger.Error(err) } + + // Add to Hosts slice (for backward compatibility and complex patterns) h.Hosts = append(h.Hosts, host) + + // Add to trie or complexPatterns based on pattern complexity + if isComplexPattern(host.Host) { + h.complexPatterns = append(h.complexPatterns, host) + } else { + h.trie.add(host) + } } diff --git a/pkg/host/root_test.go b/pkg/host/root_test.go new file mode 100644 index 0000000..e3cdc0f --- /dev/null +++ b/pkg/host/root_test.go @@ -0,0 +1,305 @@ +package host + +import ( + "testing" + + "github.com/crowdsecurity/crowdsec-spoa/internal/remediation/ban" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" +) + +func TestMatchFirstHost_SingleHost(t *testing.T) { + logger := log.NewEntry(log.New()) + manager := NewManager(logger) + + // Add a single host + host := &Host{ + Host: "www.example.com", + Ban: ban.Ban{}, + } + manager.addHost(host) + + // Test exact match + matched := manager.MatchFirstHost("www.example.com") + assert.NotNil(t, matched, "should match exact host") + assert.Equal(t, "www.example.com", matched.Host, "matched host should be correct") + + // Test no match + matched = manager.MatchFirstHost("other.example.com") + assert.Nil(t, matched, "should not match different host") +} + +func TestMatchFirstHost_Priority(t *testing.T) { + // Table-driven test for priority matching + // Tests that more specific patterns always win regardless of insertion order + tests := []struct { + name string + hostPatterns []string // Patterns to add (in this order) + lookupDomain string + expectedMatch string + expectedReason string + }{ + { + name: "exact match wins over wildcard (specific first)", + hostPatterns: []string{"www.example.com", "*.example.com", "*"}, + lookupDomain: "www.example.com", + expectedMatch: "www.example.com", + expectedReason: "exact match should win over wildcard", + }, + { + name: "exact match wins over wildcard (wildcard first)", + hostPatterns: []string{"*", "*.example.com", "www.example.com"}, + lookupDomain: "www.example.com", + expectedMatch: "www.example.com", + expectedReason: "most specific pattern should win regardless of order", + }, + { + name: "wildcard matches subdomain (specific first)", + hostPatterns: []string{"www.example.com", "*.example.com", "*"}, + lookupDomain: "api.example.com", + expectedMatch: "*.example.com", + expectedReason: "wildcard pattern should match subdomain", + }, + { + name: "wildcard matches subdomain (wildcard first)", + hostPatterns: []string{"*", "*.example.com", "www.example.com"}, + lookupDomain: "api.example.com", + expectedMatch: "*.example.com", + expectedReason: "*.example.com should win over *", + }, + { + name: "catch-all matches unmatched (specific first)", + hostPatterns: []string{"www.example.com", "*.example.com", "*"}, + lookupDomain: "other.com", + expectedMatch: "*", + expectedReason: "catch-all should match any domain", + }, + { + name: "catch-all matches unmatched (wildcard first)", + hostPatterns: []string{"*", "*.example.com", "www.example.com"}, + lookupDomain: "other.com", + expectedMatch: "*", + expectedReason: "catch-all should match when no other pattern matches", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger := log.NewEntry(log.New()) + manager := NewManager(logger) + + // Add hosts in specified order + for _, pattern := range tt.hostPatterns { + manager.addHost(&Host{Host: pattern, Ban: ban.Ban{}}) + } + + matched := manager.MatchFirstHost(tt.lookupDomain) + assert.NotNil(t, matched, "should match %s", tt.lookupDomain) + assert.Equal(t, tt.expectedMatch, matched.Host, tt.expectedReason) + }) + } +} + +func TestMatchFirstHost_WildcardPatterns(t *testing.T) { + logger := log.NewEntry(log.New()) + manager := NewManager(logger) + + // Test prefix wildcard + host1 := &Host{ + Host: "*.example.com", + Ban: ban.Ban{}, + } + manager.addHost(host1) + + matched := manager.MatchFirstHost("www.example.com") + assert.NotNil(t, matched, "*.example.com should match www.example.com") + assert.Equal(t, "*.example.com", matched.Host) + + matched = manager.MatchFirstHost("api.example.com") + assert.NotNil(t, matched, "*.example.com should match api.example.com") + assert.Equal(t, "*.example.com", matched.Host) + + matched = manager.MatchFirstHost("example.com") + assert.Nil(t, matched, "*.example.com should not match example.com (no subdomain)") + + // Test suffix wildcard + host2 := &Host{ + Host: "example.*", + Ban: ban.Ban{}, + } + manager.addHost(host2) + + matched = manager.MatchFirstHost("example.com") + assert.NotNil(t, matched, "example.* should match example.com") + assert.Equal(t, "example.*", matched.Host) + + matched = manager.MatchFirstHost("example.org") + assert.NotNil(t, matched, "example.* should match example.org") + assert.Equal(t, "example.*", matched.Host) +} + +func TestMatchFirstHost_ComplexWildcardPatterns(t *testing.T) { + logger := log.NewEntry(log.New()) + manager := NewManager(logger) + + // Complex pattern with wildcard in middle (should use filepath.Match fallback) + host1 := &Host{ + Host: "example.*.com", + Ban: ban.Ban{}, + } + manager.addHost(host1) + + matched := manager.MatchFirstHost("example.test.com") + assert.NotNil(t, matched, "example.*.com should match example.test.com") + assert.Equal(t, "example.*.com", matched.Host) + + matched = manager.MatchFirstHost("example.api.com") + assert.NotNil(t, matched, "example.*.com should match example.api.com") + assert.Equal(t, "example.*.com", matched.Host) +} + +func TestMatchFirstHost_Priority_LengthMatters(t *testing.T) { + logger := log.NewEntry(log.New()) + manager := NewManager(logger) + + // Longer patterns should have higher priority + host1 := &Host{ + Host: "*.com", + Ban: ban.Ban{}, + } + host2 := &Host{ + Host: "*.example.com", + Ban: ban.Ban{}, + } + + manager.addHost(host1) + manager.addHost(host2) + + // api.example.com should match the longer, more specific pattern + matched := manager.MatchFirstHost("api.example.com") + assert.NotNil(t, matched, "should match api.example.com") + assert.Equal(t, "*.example.com", matched.Host, "longer pattern should win") +} + +func TestMatchFirstHost_Cache(t *testing.T) { + logger := log.NewEntry(log.New()) + manager := NewManager(logger) + + host1 := &Host{ + Host: "www.example.com", + Ban: ban.Ban{}, + } + manager.addHost(host1) + + // First match should populate cache + matched1 := manager.MatchFirstHost("www.example.com") + assert.NotNil(t, matched1) + + // Second match should use cache + matched2 := manager.MatchFirstHost("www.example.com") + assert.NotNil(t, matched2) + assert.Equal(t, matched1, matched2, "cached result should be returned") +} + +func TestMatchFirstHost_NoHosts(t *testing.T) { + logger := log.NewEntry(log.New()) + manager := NewManager(logger) + + matched := manager.MatchFirstHost("www.example.com") + assert.Nil(t, matched, "should return nil when no hosts configured") +} + +func TestMatchFirstHost_RemoveHost(t *testing.T) { + logger := log.NewEntry(log.New()) + manager := NewManager(logger) + + host1 := &Host{ + Host: "www.example.com", + Ban: ban.Ban{}, + } + host2 := &Host{ + Host: "*.example.com", + Ban: ban.Ban{}, + } + + manager.addHost(host1) + manager.addHost(host2) + + // Should match + matched := manager.MatchFirstHost("www.example.com") + assert.NotNil(t, matched) + assert.Equal(t, "www.example.com", matched.Host) + + // Remove host1 + manager.removeHost(host1) + + // Should now match wildcard + matched = manager.MatchFirstHost("www.example.com") + assert.NotNil(t, matched) + assert.Equal(t, "*.example.com", matched.Host, "should match wildcard after exact match removed") +} + +func TestMatchFirstHost_MultipleSpecificHosts(t *testing.T) { + logger := log.NewEntry(log.New()) + manager := NewManager(logger) + + // Add multiple specific hosts + host1 := &Host{ + Host: "www.example.com", + Ban: ban.Ban{}, + } + host2 := &Host{ + Host: "api.example.com", + Ban: ban.Ban{}, + } + host3 := &Host{ + Host: "*.example.com", + Ban: ban.Ban{}, + } + + manager.addHost(host1) + manager.addHost(host2) + manager.addHost(host3) + + // Each specific host should match itself + matched := manager.MatchFirstHost("www.example.com") + assert.NotNil(t, matched) + assert.Equal(t, "www.example.com", matched.Host) + + matched = manager.MatchFirstHost("api.example.com") + assert.NotNil(t, matched) + assert.Equal(t, "api.example.com", matched.Host) + + // Other subdomains should match wildcard + matched = manager.MatchFirstHost("test.example.com") + assert.NotNil(t, matched) + assert.Equal(t, "*.example.com", matched.Host) +} + +func TestMatchFirstHost_CatchAllLastResort(t *testing.T) { + logger := log.NewEntry(log.New()) + manager := NewManager(logger) + + // Add catch-all and specific pattern + host1 := &Host{ + Host: "*", + Ban: ban.Ban{}, + } + host2 := &Host{ + Host: "specific.com", + Ban: ban.Ban{}, + } + + manager.addHost(host1) + manager.addHost(host2) + + // Specific should win + matched := manager.MatchFirstHost("specific.com") + assert.NotNil(t, matched) + assert.Equal(t, "specific.com", matched.Host, "specific pattern should win over catch-all") + + // Catch-all should match anything else + matched = manager.MatchFirstHost("other.com") + assert.NotNil(t, matched) + assert.Equal(t, "*", matched.Host, "catch-all should match when no specific pattern matches") +} diff --git a/pkg/host/trie.go b/pkg/host/trie.go new file mode 100644 index 0000000..8d07d3e --- /dev/null +++ b/pkg/host/trie.go @@ -0,0 +1,348 @@ +package host + +import ( + "math" + "path/filepath" + "strings" + "sync" +) + +const ( + // minPriority is used to initialize priority comparisons. + // Any valid pattern will have a higher priority than this. + minPriority = math.MinInt +) + +// domainTrieNode represents a node in the reverse domain trie. +// The trie is built by reversing domain names so "www.example.com" becomes ["com", "example", "www"]. +// This allows efficient prefix matching for wildcard patterns like "*.example.com". +type domainTrieNode struct { + // children maps exact domain segments to child nodes + children map[string]*domainTrieNode + // wildcardChild handles wildcard segment matches (e.g., "*" in "*.example.com") + wildcardChild *domainTrieNode + // host stores the Host configuration if this node represents a complete pattern + host *Host + // pattern stores the original pattern string for this node (used for priority calculation) + pattern string +} + +// domainTrie is a reverse domain trie for efficient host pattern matching. +// It provides O(m) lookup complexity where m is the number of domain segments, +// independent of the total number of hosts stored. +type domainTrie struct { + root *domainTrieNode + mu sync.RWMutex +} + +// newDomainTrie creates a new empty domain trie. +func newDomainTrie() *domainTrie { + return &domainTrie{ + root: &domainTrieNode{ + children: make(map[string]*domainTrieNode), + }, + } +} + +// reverseDomain splits and reverses a domain name for trie insertion/lookup. +// Examples: +// - "www.example.com" -> ["com", "example", "www"] +// - "*.example.com" -> ["com", "example", "*"] +// - "*" -> ["*"] +func reverseDomain(domain string) []string { + if domain == "" { + return nil + } + parts := strings.Split(domain, ".") + // Reverse in place + for i, j := 0, len(parts)-1; i < j; i, j = i+1, j-1 { + parts[i], parts[j] = parts[j], parts[i] + } + return parts +} + +// calculatePriority determines the specificity of a host pattern. +// Higher values indicate more specific patterns that should match first. +// +// Priority factors (in order of importance): +// 1. Exact matches (no wildcards) get highest priority +// 2. Longer patterns are more specific +// 3. Each wildcard reduces priority +func calculatePriority(pattern string) int { + if pattern == "" { + return -1 + } + + priority := 0 + + // Exact matches (no wildcards) get high base priority + hasWildcard := strings.ContainsAny(pattern, "*?") + if !hasWildcard { + priority += 10000 + } + + // Longer patterns are more specific (each character adds 10) + priority += len(pattern) * 10 + + // Each wildcard character reduces priority significantly + wildcardCount := strings.Count(pattern, "*") + strings.Count(pattern, "?") + priority -= wildcardCount * 1000 + + return priority +} + +// isWildcardSegment returns true if the segment contains wildcard characters. +func isWildcardSegment(segment string) bool { + return strings.ContainsAny(segment, "*?") +} + +// add inserts a host pattern into the trie. +// If a pattern already exists at the same node, the higher priority one is kept. +func (dt *domainTrie) add(host *Host) { + if host == nil || host.Host == "" { + return + } + + dt.mu.Lock() + defer dt.mu.Unlock() + + pattern := host.Host + parts := reverseDomain(pattern) + if len(parts) == 0 { + return + } + + current := dt.root + for _, part := range parts { + if isWildcardSegment(part) { + // Wildcard segments share a single child node + if current.wildcardChild == nil { + current.wildcardChild = &domainTrieNode{ + children: make(map[string]*domainTrieNode), + } + } + current = current.wildcardChild + } else { + // Exact segments get their own child nodes + if current.children[part] == nil { + current.children[part] = &domainTrieNode{ + children: make(map[string]*domainTrieNode), + } + } + current = current.children[part] + } + } + + // Store host at terminal node (higher priority wins) + newPriority := calculatePriority(pattern) + existingPriority := calculatePriority(current.pattern) + if current.host == nil || newPriority > existingPriority { + current.host = host + current.pattern = pattern + } +} + +// remove removes a host pattern from the trie. +// It also cleans up empty nodes to prevent memory leaks. +func (dt *domainTrie) remove(host *Host) { + if host == nil || host.Host == "" { + return + } + + dt.mu.Lock() + defer dt.mu.Unlock() + + pattern := host.Host + parts := reverseDomain(pattern) + if len(parts) == 0 { + return + } + + // Track the path for cleanup + type pathEntry struct { + node *domainTrieNode + parent *domainTrieNode + key string // empty for wildcardChild + } + + path := make([]pathEntry, 0, len(parts)+1) + current := dt.root + path = append(path, pathEntry{node: current, parent: nil, key: ""}) + + for _, part := range parts { + var next *domainTrieNode + var key string + + if isWildcardSegment(part) { + next = current.wildcardChild + key = "" + } else { + next = current.children[part] + key = part + } + + if next == nil { + return // Pattern not found + } + + path = append(path, pathEntry{node: next, parent: current, key: key}) + current = next + } + + // Only remove if this is the exact host we're looking for + if current.host != host { + return + } + + current.host = nil + current.pattern = "" + + // Clean up empty nodes (traverse backwards, skip root) + for i := len(path) - 1; i > 0; i-- { + entry := path[i] + node := entry.node + + // Stop if node still has content + if node.host != nil || len(node.children) > 0 || node.wildcardChild != nil { + break + } + + // Remove from parent + parent := entry.parent + if entry.key == "" { + parent.wildcardChild = nil + } else { + delete(parent.children, entry.key) + } + } +} + +// match finds the best matching host for a given domain. +// Returns nil if no match is found. +func (dt *domainTrie) match(domain string, complexPatterns []*Host) *Host { + if domain == "" { + return nil + } + + dt.mu.RLock() + defer dt.mu.RUnlock() + + parts := reverseDomain(domain) + if len(parts) == 0 { + return nil + } + + // Find all matches from the trie + var bestMatch *Host + bestPriority := minPriority + + dt.findMatches(parts, 0, dt.root, &bestMatch, &bestPriority) + + // Check complex patterns (fallback for patterns that don't fit the trie) + for _, host := range complexPatterns { + if matched, err := filepath.Match(host.Host, domain); matched && err == nil { + priority := calculatePriority(host.Host) + if priority > bestPriority { + bestMatch = host + bestPriority = priority + } + } + } + + return bestMatch +} + +// findMatches recursively searches the trie for all matching hosts. +// It updates bestMatch and bestPriority in place to avoid allocations. +func (dt *domainTrie) findMatches(parts []string, depth int, node *domainTrieNode, bestMatch **Host, bestPriority *int) { + if node == nil { + return + } + + // Base case: consumed all domain parts + if depth >= len(parts) { + if node.host != nil { + priority := calculatePriority(node.pattern) + if priority > *bestPriority { + *bestMatch = node.host + *bestPriority = priority + } + } + return + } + + currentPart := parts[depth] + + // Try exact match first (most specific) + exactMatchFound := false + if child, ok := node.children[currentPart]; ok { + prevBest := *bestMatch + dt.findMatches(parts, depth+1, child, bestMatch, bestPriority) + // Check if exact path found a match (either changed from nil or was already set) + exactMatchFound = *bestMatch != nil && (*bestMatch != prevBest || prevBest != nil) + } + + // Try wildcard match only if exact path didn't find anything + // This ensures deeper exact matches take precedence over shallower wildcards + if !exactMatchFound && node.wildcardChild != nil { + dt.findMatches(parts, depth+1, node.wildcardChild, bestMatch, bestPriority) + } + + // Check if current node can match remaining parts (for patterns like "*" or "*.com") + // Only consider if no better match was found from children + if *bestMatch == nil && node.host != nil { + // Verify this is a wildcard pattern that can match remaining parts + if isWildcardPattern(node.pattern) { + priority := calculatePriority(node.pattern) + if priority > *bestPriority { + *bestMatch = node.host + *bestPriority = priority + } + } + } +} + +// isWildcardPattern returns true if the pattern contains wildcards +// that can match variable-length domain parts. +func isWildcardPattern(pattern string) bool { + return strings.ContainsAny(pattern, "*?") +} + +// isComplexPattern determines if a pattern is too complex for the trie. +// Complex patterns have wildcards in positions that don't align with domain segments. +// +// Simple patterns (handled by trie): +// - Exact: "www.example.com" +// - Prefix wildcard: "*.example.com" +// - Suffix wildcard: "example.*" +// - Catch-all: "*" +// +// Complex patterns (fallback to filepath.Match): +// - Middle wildcards: "example.*.com" +// - Embedded wildcards: "*example.com", "www*.example.com" +func isComplexPattern(pattern string) bool { + if pattern == "" || pattern == "*" { + return false + } + + parts := strings.Split(pattern, ".") + + for i, part := range parts { + if !strings.ContainsAny(part, "*?") { + continue + } + + // Wildcards in middle segments are complex + if i > 0 && i < len(parts)-1 { + return true + } + + // Partial wildcards (not just "*" or "?") are complex + // e.g., "*example" or "www*" within a segment + if part != "*" && part != "?" { + return true + } + } + + return false +} From 4d5d11b173c3ef04a6bded82157e9c88ad75bdf0 Mon Sep 17 00:00:00 2001 From: Laurence Date: Wed, 26 Nov 2025 09:31:10 +0000 Subject: [PATCH 2/2] fix(host): address copilot review comments - Fix exactMatchFound logic in trie findMatches - Clarify removeHost comments for complex patterns - Fix race condition: use sync.Map for thread-safe cache access - Add proper type assertion check for cache retrieval --- pkg/host/root.go | 29 ++++++++++++++++------------- pkg/host/trie.go | 4 ++-- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/pkg/host/root.go b/pkg/host/root.go index 1244596..83eb799 100644 --- a/pkg/host/root.go +++ b/pkg/host/root.go @@ -42,7 +42,7 @@ type Manager struct { Hosts []*Host Chan chan HostOp Logger *log.Entry - cache map[string]*Host + cache sync.Map // map[string]*Host - thread-safe cache for matched hosts trie *domainTrie complexPatterns []*Host // Patterns that don't fit well in the trie (wildcards in middle) sync.RWMutex @@ -71,13 +71,21 @@ func NewManager(l *log.Entry) *Manager { Hosts: make([]*Host, 0), Chan: make(chan HostOp), Logger: l, - cache: make(map[string]*Host), + // cache is a sync.Map, zero value is ready to use trie: newDomainTrie(), complexPatterns: make([]*Host, 0), } } func (h *Manager) MatchFirstHost(toMatch string) *Host { + // Check cache first (thread-safe via sync.Map) + if cached, ok := h.cache.Load(toMatch); ok { + if host, ok := cached.(*Host); ok { + host.logger.WithField("requested_host", toMatch).Debug("matched host from cache") + return host + } + } + h.RLock() defer h.RUnlock() @@ -85,17 +93,12 @@ func (h *Manager) MatchFirstHost(toMatch string) *Host { return nil } - if host, ok := h.cache[toMatch]; ok { - host.logger.WithField("requested_host", toMatch).Debug("matched host from cache") - return host - } - // Use trie for efficient matching host := h.trie.match(toMatch, h.complexPatterns) if host != nil { host.logger.WithField("requested_host", toMatch).Debug("matched host pattern") - h.cache[toMatch] = host + h.cache.Store(toMatch, host) // Thread-safe via sync.Map return host } @@ -111,10 +114,10 @@ func (h *Manager) Run(ctx context.Context) { h.Lock() switch instruction.Op { case OpRemove: - h.cache = make(map[string]*Host) + h.cache.Clear() // Clear cache when hosts change h.removeHost(instruction.Host) case OpAdd: - h.cache = make(map[string]*Host) + h.cache.Clear() // Clear cache when hosts change h.addHost(instruction.Host) h.sort() case OpPatch: @@ -183,9 +186,9 @@ func (h *Manager) sort() { } func (h *Manager) removeHost(host *Host) { - // Remove from trie + // Remove from trie or complexPatterns based on pattern type if isComplexPattern(host.Host) { - // Remove from complexPatterns + // Complex patterns are stored in slice, not trie for i, th := range h.complexPatterns { if th == host { if i == len(h.complexPatterns)-1 { @@ -211,7 +214,7 @@ func (h *Manager) removeHost(host *Host) { h.Hosts = append(h.Hosts[:i], h.Hosts[i+1:]...) } // Clear cache since host configuration changed - h.cache = make(map[string]*Host) + h.cache.Clear() return } } diff --git a/pkg/host/trie.go b/pkg/host/trie.go index 8d07d3e..86bb9e7 100644 --- a/pkg/host/trie.go +++ b/pkg/host/trie.go @@ -278,8 +278,8 @@ func (dt *domainTrie) findMatches(parts []string, depth int, node *domainTrieNod if child, ok := node.children[currentPart]; ok { prevBest := *bestMatch dt.findMatches(parts, depth+1, child, bestMatch, bestPriority) - // Check if exact path found a match (either changed from nil or was already set) - exactMatchFound = *bestMatch != nil && (*bestMatch != prevBest || prevBest != nil) + // Check if exact path found a match (only if it actually improved the match) + exactMatchFound = *bestMatch != nil && *bestMatch != prevBest } // Try wildcard match only if exact path didn't find anything