Skip to content

Commit 4e2cef1

Browse files
fix: geo database initialization bug and ISO code header issue (#121)
* Fix geo database initialization bug and ISO code header issue - Fix geo database initialization: remove meaningless IsValid() check in open() that always passed on first call, add proper path validation before loading - Improve code redundancy: extract common file checking logic into checkAndUpdateModTime() function, reduce duplication in WatchFiles() - Add comprehensive test coverage for Init() with various error cases - Fix ISO code header: always set isocode variable when available, regardless of remediation status, so X-Crowdsec-IsoCode header is populated correctly * Fix linting issues in geo tests - Replace assert.Error with require.Error for error assertions (testifylint) - Use t.TempDir() instead of empty string for os.CreateTemp (usetesting)
1 parent fdf9b12 commit 4e2cef1

File tree

3 files changed

+247
-45
lines changed

3 files changed

+247
-45
lines changed

internal/geo/geo_test.go

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@ package geo
33
import (
44
"context"
55
"net/netip"
6+
"os"
67
"path/filepath"
78
"testing"
9+
"time"
810

911
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
1013
)
1114

1215
func TestGetCityAndASN(t *testing.T) {
@@ -20,6 +23,9 @@ func TestGetCityAndASN(t *testing.T) {
2023

2124
g.Init(ctx)
2225

26+
// Verify Init succeeded
27+
assert.True(t, g.IsValid(), "GeoDatabase should be valid after successful Init")
28+
2329
ip := netip.MustParseAddr("2.125.160.216")
2430
city, err := g.GetCity(ip)
2531
if err != nil {
@@ -51,3 +57,136 @@ func TestGetCityAndASN(t *testing.T) {
5157
continentName = city.Continent.Names.English
5258
assert.Empty(t, continentName, "Expected empty continent name, got '%s'", continentName)
5359
}
60+
61+
func TestInit_EmptyPaths(t *testing.T) {
62+
ctx, cancel := context.WithCancel(context.Background())
63+
defer cancel()
64+
65+
g := &GeoDatabase{
66+
ASNPath: "",
67+
CityPath: "",
68+
}
69+
70+
g.Init(ctx)
71+
72+
// Should be invalid when both paths are empty
73+
assert.False(t, g.IsValid(), "GeoDatabase should be invalid when both paths are empty")
74+
75+
// GetCity should return error
76+
ip := netip.MustParseAddr("1.1.1.1")
77+
_, err := g.GetCity(ip)
78+
require.Error(t, err)
79+
assert.Equal(t, ErrNotValidConfig, err)
80+
}
81+
82+
func TestInit_MissingFiles(t *testing.T) {
83+
ctx, cancel := context.WithCancel(context.Background())
84+
defer cancel()
85+
86+
g := &GeoDatabase{
87+
ASNPath: "/nonexistent/path/to/ASN.mmdb",
88+
CityPath: "/nonexistent/path/to/City.mmdb",
89+
}
90+
91+
g.Init(ctx)
92+
93+
// Should be invalid when files don't exist
94+
assert.False(t, g.IsValid(), "GeoDatabase should be invalid when files don't exist")
95+
96+
// GetCity should return error
97+
ip := netip.MustParseAddr("1.1.1.1")
98+
_, err := g.GetCity(ip)
99+
require.Error(t, err)
100+
assert.Equal(t, ErrNotValidConfig, err)
101+
}
102+
103+
func TestInit_OneValidPath(t *testing.T) {
104+
ctx, cancel := context.WithCancel(context.Background())
105+
defer cancel()
106+
107+
// Test with only ASN path
108+
g := &GeoDatabase{
109+
ASNPath: filepath.Join("test_data", "GeoLite2-ASN.mmdb"),
110+
CityPath: "",
111+
}
112+
113+
g.Init(ctx)
114+
assert.True(t, g.IsValid(), "GeoDatabase should be valid with only ASN path")
115+
116+
ip := netip.MustParseAddr("1.0.0.1")
117+
asn, err := g.GetASN(ip)
118+
require.NoError(t, err)
119+
assert.Equal(t, uint(15169), asn.AutonomousSystemNumber)
120+
121+
// City should return empty record, not error
122+
city, err := g.GetCity(ip)
123+
require.NoError(t, err)
124+
assert.NotNil(t, city)
125+
}
126+
127+
func TestInit_InvalidPathIsDirectory(t *testing.T) {
128+
ctx, cancel := context.WithCancel(context.Background())
129+
defer cancel()
130+
131+
// Use test_data directory as an invalid path (it's a directory, not a file)
132+
g := &GeoDatabase{
133+
ASNPath: "test_data",
134+
CityPath: filepath.Join("test_data", "GeoLite2-City.mmdb"),
135+
}
136+
137+
g.Init(ctx)
138+
139+
// Should be invalid when path is a directory
140+
assert.False(t, g.IsValid(), "GeoDatabase should be invalid when path is a directory")
141+
}
142+
143+
func TestInit_WatchFilesGoroutine(t *testing.T) {
144+
ctx, cancel := context.WithCancel(context.Background())
145+
defer cancel()
146+
147+
g := &GeoDatabase{
148+
ASNPath: filepath.Join("test_data", "GeoLite2-ASN.mmdb"),
149+
CityPath: filepath.Join("test_data", "GeoLite2-City.mmdb"),
150+
}
151+
152+
g.Init(ctx)
153+
assert.True(t, g.IsValid(), "GeoDatabase should be valid after successful Init")
154+
155+
// Give WatchFiles goroutine a moment to initialize
156+
time.Sleep(100 * time.Millisecond)
157+
158+
// Verify lastModTime map is initialized
159+
g.RLock()
160+
assert.NotNil(t, g.lastModTime, "lastModTime map should be initialized")
161+
g.RUnlock()
162+
163+
// Cancel context to stop WatchFiles goroutine
164+
cancel()
165+
166+
// Give it a moment to clean up
167+
time.Sleep(100 * time.Millisecond)
168+
169+
// Database should still be valid after context cancellation
170+
assert.True(t, g.IsValid(), "GeoDatabase should remain valid after context cancellation")
171+
}
172+
173+
func TestInit_EmptyFile(t *testing.T) {
174+
// Create a temporary empty file
175+
tmpFile, err := os.CreateTemp(t.TempDir(), "empty-*.mmdb")
176+
require.NoError(t, err)
177+
tmpPath := tmpFile.Name()
178+
tmpFile.Close()
179+
180+
ctx, cancel := context.WithCancel(context.Background())
181+
defer cancel()
182+
183+
g := &GeoDatabase{
184+
ASNPath: tmpPath,
185+
CityPath: "",
186+
}
187+
188+
g.Init(ctx)
189+
190+
// Should be invalid when file is empty
191+
assert.False(t, g.IsValid(), "GeoDatabase should be invalid when file is empty")
192+
}

internal/geo/root.go

Lines changed: 95 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ func (g *GeoDatabase) Init(ctx context.Context) {
3434
return
3535
}
3636

37+
// Validate paths exist before attempting to load
38+
if err := g.validatePaths(); err != nil {
39+
log.Errorf("geo database path validation failed: %s", err)
40+
g.loadFailed = true
41+
return
42+
}
43+
3744
if err := g.open(); err != nil {
3845
log.Errorf("failed to open databases: %s", err)
3946
g.loadFailed = true
@@ -50,25 +57,65 @@ func (g *GeoDatabase) reload() error {
5057
return g.open()
5158
}
5259

53-
func (g *GeoDatabase) open() error {
54-
if !g.IsValid() {
55-
return ErrNotValidConfig
60+
// validatePaths checks if the configured database paths exist and are readable
61+
func (g *GeoDatabase) validatePaths() error {
62+
if g.ASNPath != "" {
63+
if err := g.validatePath(g.ASNPath, "ASN"); err != nil {
64+
return err
65+
}
66+
}
67+
68+
if g.CityPath != "" {
69+
if err := g.validatePath(g.CityPath, "City"); err != nil {
70+
return err
71+
}
5672
}
5773

74+
return nil
75+
}
76+
77+
// validatePath checks if a single database path exists and is readable
78+
func (g *GeoDatabase) validatePath(path, dbType string) error {
79+
info, err := os.Stat(path)
80+
if err != nil {
81+
if os.IsNotExist(err) {
82+
return fmt.Errorf("%s database file does not exist: %s", dbType, path)
83+
}
84+
return fmt.Errorf("failed to stat %s database file %s: %w", dbType, path, err)
85+
}
86+
87+
if info.IsDir() {
88+
return fmt.Errorf("%s database path is a directory, not a file: %s", dbType, path)
89+
}
90+
91+
if info.Size() == 0 {
92+
return fmt.Errorf("%s database file is empty: %s", dbType, path)
93+
}
94+
95+
return nil
96+
}
97+
98+
func (g *GeoDatabase) open() error {
5899
g.Lock()
59100
defer g.Unlock()
101+
60102
var err error
61103
if g.asnReader == nil && g.ASNPath != "" {
62104
g.asnReader, err = geoip2.Open(g.ASNPath)
63105
if err != nil {
64-
return err
106+
return fmt.Errorf("failed to open ASN database: %w", err)
65107
}
66108
}
67109

68110
if g.cityReader == nil && g.CityPath != "" {
69111
g.cityReader, err = geoip2.Open(g.CityPath)
70112
if err != nil {
71-
return err
113+
// Clean up ASN reader if it was opened successfully
114+
if g.asnReader != nil {
115+
g.asnReader.Close()
116+
g.asnReader = nil
117+
}
118+
return fmt.Errorf("failed to open City database: %w", err)
72119
}
73120
}
74121

@@ -158,51 +205,29 @@ func GetIsoCodeFromRecord(record *geoip2.City) string {
158205
func (g *GeoDatabase) WatchFiles(ctx context.Context) {
159206

160207
ticker := time.NewTicker(1 * time.Minute)
208+
defer ticker.Stop()
209+
161210
for {
162211
select {
163212
case <-ctx.Done():
164-
ticker.Stop()
165213
return
166214
case <-ticker.C:
167215
shouldUpdate := false
168-
if asnLastModTime, ok := g.lastModTime[g.ASNPath]; ok {
169-
info, err := os.Stat(g.ASNPath)
170-
if err != nil {
171-
log.Warnf("failed to stat ASN database: %s", err)
172-
continue
173-
}
174-
if info.ModTime().After(asnLastModTime) {
175-
log.Infof("ASN database has been updated, reloading")
216+
217+
// Check ASN database
218+
if g.ASNPath != "" {
219+
if updated := g.checkAndUpdateModTime(g.ASNPath, "ASN"); updated {
176220
shouldUpdate = true
177-
g.lastModTime[g.ASNPath] = info.ModTime()
178-
}
179-
} else {
180-
info, err := os.Stat(g.ASNPath)
181-
if err != nil {
182-
log.Warnf("failed to stat ASN database: %s", err)
183-
continue
184221
}
185-
g.lastModTime[g.ASNPath] = info.ModTime()
186222
}
187-
if cityLastModTime, ok := g.lastModTime[g.CityPath]; ok {
188-
info, err := os.Stat(g.CityPath)
189-
if err != nil {
190-
log.Warnf("failed to stat city database: %s", err)
191-
continue
192-
}
193-
if info.ModTime().After(cityLastModTime) {
194-
log.Infof("City database has been updated, reloading")
223+
224+
// Check City database
225+
if g.CityPath != "" {
226+
if updated := g.checkAndUpdateModTime(g.CityPath, "City"); updated {
195227
shouldUpdate = true
196-
g.lastModTime[g.CityPath] = info.ModTime()
197-
}
198-
} else {
199-
info, err := os.Stat(g.CityPath)
200-
if err != nil {
201-
log.Warnf("failed to stat city database: %s", err)
202-
continue
203228
}
204-
g.lastModTime[g.CityPath] = info.ModTime()
205229
}
230+
206231
if shouldUpdate {
207232
if err := g.reload(); err != nil {
208233
log.Warnf("failed to reload databases: %s", err)
@@ -211,3 +236,35 @@ func (g *GeoDatabase) WatchFiles(ctx context.Context) {
211236
}
212237
}
213238
}
239+
240+
// checkAndUpdateModTime checks if a database file has been modified and updates the lastModTime
241+
// Returns true if the file was updated (needs reload), false otherwise
242+
func (g *GeoDatabase) checkAndUpdateModTime(path, dbType string) bool {
243+
info, err := os.Stat(path)
244+
if err != nil {
245+
log.Warnf("failed to stat %s database: %s", dbType, err)
246+
return false
247+
}
248+
249+
g.RLock()
250+
lastModTime, exists := g.lastModTime[path]
251+
g.RUnlock()
252+
253+
if !exists {
254+
// First time checking this file, just record the mod time
255+
g.Lock()
256+
g.lastModTime[path] = info.ModTime()
257+
g.Unlock()
258+
return false
259+
}
260+
261+
if info.ModTime().After(lastModTime) {
262+
log.Infof("%s database has been updated, reloading", dbType)
263+
g.Lock()
264+
g.lastModTime[path] = info.ModTime()
265+
g.Unlock()
266+
return true
267+
}
268+
269+
return false
270+
}

pkg/spoa/root.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -629,8 +629,9 @@ func (s *Spoa) getIPRemediation(req *request.Request, ip netip.Addr) (remediatio
629629
return remediation.Allow, "" // Safe default
630630
}
631631

632-
// If no IP-specific remediation, check country-based
633-
if r < remediation.Unknown && s.geoDatabase.IsValid() {
632+
// Always try to get and set ISO code if geo database is available
633+
// This allows upstream services to use the ISO code regardless of remediation status
634+
if s.geoDatabase.IsValid() {
634635
record, err := s.geoDatabase.GetCity(ip)
635636
if err != nil && !errors.Is(err, geo.ErrNotValidConfig) {
636637
s.logger.WithFields(log.Fields{
@@ -640,12 +641,17 @@ func (s *Spoa) getIPRemediation(req *request.Request, ip netip.Addr) (remediatio
640641
} else if record != nil {
641642
iso := geo.GetIsoCodeFromRecord(record)
642643
if iso != "" {
643-
cnR, cnOrigin := s.dataset.CheckCN(iso)
644-
if cnR > remediation.Unknown {
645-
r = cnR
646-
origin = cnOrigin
647-
}
644+
// Always set the ISO code variable when available
648645
req.Actions.SetVar(action.ScopeTransaction, "isocode", iso)
646+
647+
// If no IP-specific remediation, check country-based remediation
648+
if r < remediation.Unknown {
649+
cnR, cnOrigin := s.dataset.CheckCN(iso)
650+
if cnR > remediation.Unknown {
651+
r = cnR
652+
origin = cnOrigin
653+
}
654+
}
649655
}
650656
}
651657
}

0 commit comments

Comments
 (0)