diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 34395cc..1ff33b9 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -15,7 +15,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: '1.25.5' + go-version: '1.26.3' - name: Run golangci-lint uses: golangci/golangci-lint-action@v7 diff --git a/.github/workflows/release-please.yml b/.github/workflows/release-please.yml index 488af95..fb043d2 100644 --- a/.github/workflows/release-please.yml +++ b/.github/workflows/release-please.yml @@ -28,7 +28,7 @@ jobs: if: ${{ steps.release.outputs.release_created }} uses: actions/setup-go@v5 with: - go-version: '1.25.5' + go-version: '1.26.3' - name: Run GoReleaser if: ${{ steps.release.outputs.release_created }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 6ce7b36..447a9f6 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -19,7 +19,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: '1.25.5' + go-version: '1.26.3' - name: Run GoReleaser uses: goreleaser/goreleaser-action@v6 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0bcb582..25e2fec 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: '1.25.5' + go-version: '1.26.3' - name: Run unit tests run: go test ./... -cover diff --git a/cmd/check.go b/cmd/check.go index 754f227..860ceeb 100644 --- a/cmd/check.go +++ b/cmd/check.go @@ -30,11 +30,12 @@ var ( strictMode bool // Ignore flags. - ignoreDomains []string - ignorePatterns []string - ignoreRegex []string - showIgnored bool - noConfig bool + ignoreDomains []string + ignorePatterns []string + ignoreRegex []string + showIgnored bool + noConfig bool + allowPrivateHosts bool ) // checkCmd represents the check command. @@ -144,6 +145,12 @@ func init() { "Show which URLs were ignored and why") checkCmd.Flags().BoolVar(&noConfig, "no-config", false, "Skip loading .gonerc.yaml config file") + + // Security options + checkCmd.Flags().BoolVar(&allowPrivateHosts, "allow-private-hosts", false, + "Allow requests to loopback, private, link-local and reserved IP "+ + "ranges. Default is to block them to prevent SSRF when scanning "+ + "untrusted documents.") } // runCheck is the main entry point for the check command. diff --git a/cmd/check_runner.go b/cmd/check_runner.go index a303238..f368ce2 100644 --- a/cmd/check_runner.go +++ b/cmd/check_runner.go @@ -11,23 +11,24 @@ import ( ) type checkOptions struct { - OutputFormat string - OutputFile string - Concurrency int - Timeout int - Retries int - ShowAlive bool - ShowWarnings bool - ShowDead bool - ShowAll bool - ShowStats bool - FileTypes []string - StrictMode bool - IgnoreDomains []string - IgnorePatterns []string - IgnoreRegex []string - ShowIgnored bool - NoConfig bool + OutputFormat string + OutputFile string + Concurrency int + Timeout int + Retries int + ShowAlive bool + ShowWarnings bool + ShowDead bool + ShowAll bool + ShowStats bool + FileTypes []string + StrictMode bool + IgnoreDomains []string + IgnorePatterns []string + IgnoreRegex []string + ShowIgnored bool + NoConfig bool + AllowPrivateHosts bool } type checkRunner struct { @@ -46,23 +47,24 @@ func newCheckRunner(opts checkOptions, env CommandEnv, streams IOStreams) *check func currentCheckOptions() checkOptions { return checkOptions{ - OutputFormat: outputFormat, - OutputFile: outputFile, - Concurrency: concurrency, - Timeout: timeout, - Retries: retries, - ShowAlive: showAlive, - ShowWarnings: showWarnings, - ShowDead: showDead, - ShowAll: showAll, - ShowStats: showStats, - FileTypes: append([]string{}, fileTypes...), - StrictMode: strictMode, - IgnoreDomains: append([]string{}, ignoreDomains...), - IgnorePatterns: append([]string{}, ignorePatterns...), - IgnoreRegex: append([]string{}, ignoreRegex...), - ShowIgnored: showIgnored, - NoConfig: noConfig, + OutputFormat: outputFormat, + OutputFile: outputFile, + Concurrency: concurrency, + Timeout: timeout, + Retries: retries, + ShowAlive: showAlive, + ShowWarnings: showWarnings, + ShowDead: showDead, + ShowAll: showAll, + ShowStats: showStats, + FileTypes: append([]string{}, fileTypes...), + StrictMode: strictMode, + IgnoreDomains: append([]string{}, ignoreDomains...), + IgnorePatterns: append([]string{}, ignorePatterns...), + IgnoreRegex: append([]string{}, ignoreRegex...), + ShowIgnored: showIgnored, + NoConfig: noConfig, + AllowPrivateHosts: allowPrivateHosts, } } @@ -244,7 +246,10 @@ func (r *checkRunner) checkLinksWithConfig( links []checker.Link, cfg *LoadedConfig, perf *stats.Stats, ) ([]checker.Result, checker.Summary) { perf.StartCheck() - c := r.env.NewChecker(cfg.BuildCheckerOptions(r.opts.Concurrency, r.opts.Timeout, r.opts.Retries)) + c := r.env.NewChecker(cfg.BuildCheckerOptions( + r.opts.Concurrency, r.opts.Timeout, r.opts.Retries, + r.opts.AllowPrivateHosts, + )) results := c.CheckAll(links) summary := checker.Summarize(results) perf.EndCheck() diff --git a/cmd/e2e_bench_test.go b/cmd/e2e_bench_test.go index 1f2df3d..226da0c 100644 --- a/cmd/e2e_bench_test.go +++ b/cmd/e2e_bench_test.go @@ -100,6 +100,7 @@ func BenchmarkPipeline_FullCheck(b *testing.B) { c := checker.New( checker.DefaultOptions(). + WithAllowPrivateHosts(true). WithConcurrency(16). WithTimeout(2 * time.Second). WithMaxRetries(0), @@ -135,6 +136,7 @@ func BenchmarkPipeline_FixDryRun(b *testing.B) { c := checker.New( checker.DefaultOptions(). + WithAllowPrivateHosts(true). WithConcurrency(16). WithTimeout(2 * time.Second). WithMaxRetries(0), diff --git a/cmd/fix.go b/cmd/fix.go index 3086f1c..5e86cec 100644 --- a/cmd/fix.go +++ b/cmd/fix.go @@ -22,10 +22,11 @@ var ( fixStrictMode bool // Ignore flags (shared with check). - fixIgnoreDomains []string - fixIgnorePatterns []string - fixIgnoreRegex []string - fixNoConfig bool + fixIgnoreDomains []string + fixIgnorePatterns []string + fixIgnoreRegex []string + fixNoConfig bool + fixAllowPrivateHosts bool ) // fixCmd represents the fix command. @@ -100,6 +101,11 @@ func init() { "Regex patterns to ignore (can be repeated)") fixCmd.Flags().BoolVar(&fixNoConfig, "no-config", false, "Skip loading .gonerc.yaml config file") + + // Security options + fixCmd.Flags().BoolVar(&fixAllowPrivateHosts, "allow-private-hosts", false, + "Allow requests to loopback, private, link-local and reserved IP "+ + "ranges. Default is to block them to prevent SSRF.") } // runFix is the main entry point for the fix command. diff --git a/cmd/fix_runner.go b/cmd/fix_runner.go index 0ddfcc9..7821370 100644 --- a/cmd/fix_runner.go +++ b/cmd/fix_runner.go @@ -10,18 +10,19 @@ import ( ) type fixOptions struct { - Yes bool - DryRun bool - Concurrency int - Timeout int - Retries int - ShowStats bool - FileTypes []string - StrictMode bool - IgnoreDomains []string - IgnorePatterns []string - IgnoreRegex []string - NoConfig bool + Yes bool + DryRun bool + Concurrency int + Timeout int + Retries int + ShowStats bool + FileTypes []string + StrictMode bool + IgnoreDomains []string + IgnorePatterns []string + IgnoreRegex []string + NoConfig bool + AllowPrivateHosts bool } type fixRunner struct { @@ -40,18 +41,19 @@ func newFixRunner(opts fixOptions, env CommandEnv, streams IOStreams) *fixRunner func currentFixOptions() fixOptions { return fixOptions{ - Yes: fixYes, - DryRun: fixDryRun, - Concurrency: fixConcurrency, - Timeout: fixTimeout, - Retries: fixRetries, - ShowStats: fixShowStats, - FileTypes: append([]string{}, fixFileTypes...), - StrictMode: fixStrictMode, - IgnoreDomains: append([]string{}, fixIgnoreDomains...), - IgnorePatterns: append([]string{}, fixIgnorePatterns...), - IgnoreRegex: append([]string{}, fixIgnoreRegex...), - NoConfig: fixNoConfig, + Yes: fixYes, + DryRun: fixDryRun, + Concurrency: fixConcurrency, + Timeout: fixTimeout, + Retries: fixRetries, + ShowStats: fixShowStats, + FileTypes: append([]string{}, fixFileTypes...), + StrictMode: fixStrictMode, + IgnoreDomains: append([]string{}, fixIgnoreDomains...), + IgnorePatterns: append([]string{}, fixIgnorePatterns...), + IgnoreRegex: append([]string{}, fixIgnoreRegex...), + NoConfig: fixNoConfig, + AllowPrivateHosts: fixAllowPrivateHosts, } } @@ -126,7 +128,10 @@ func (r *fixRunner) Run(args []string) int { perf.StartCheck() results := r.env.NewChecker( - loadedCfg.BuildCheckerOptions(r.opts.Concurrency, r.opts.Timeout, r.opts.Retries), + loadedCfg.BuildCheckerOptions( + r.opts.Concurrency, r.opts.Timeout, r.opts.Retries, + r.opts.AllowPrivateHosts, + ), ).CheckAll(links) perf.EndCheck() diff --git a/cmd/helpers.go b/cmd/helpers.go index 71b86b6..8b9393d 100644 --- a/cmd/helpers.go +++ b/cmd/helpers.go @@ -94,6 +94,15 @@ func (lc *LoadedConfig) GetTimeout(cliValue, defaultValue int) int { return defaultValue } +// GetAllowPrivateHosts returns the effective allow-private-hosts flag. +// CLI true overrides config; otherwise the config value is used. +func (lc *LoadedConfig) GetAllowPrivateHosts(cliValue bool) bool { + if cliValue { + return true + } + return lc.cfg.Check.AllowPrivateHosts +} + // GetRetries returns the effective retry count. // CLI overrides config if it differs from the default. func (lc *LoadedConfig) GetRetries(cliValue, defaultValue int) int { @@ -161,13 +170,19 @@ func (lc *LoadedConfig) GetShowStats(cliValue bool) bool { } // BuildCheckerOptions creates checker.Options from config and CLI values. -func (lc *LoadedConfig) BuildCheckerOptions(cliConcurrency, cliTimeout, cliRetries int) checker.Options { +// cliAllowPrivate reflects the --allow-private-hosts flag; when true the +// checker is permitted to contact loopback, private, and reserved IPs. +func (lc *LoadedConfig) BuildCheckerOptions( + cliConcurrency, cliTimeout, cliRetries int, + cliAllowPrivate bool, +) checker.Options { defaultOpts := checker.DefaultOptions() return defaultOpts. WithConcurrency(lc.GetConcurrency(cliConcurrency, checker.DefaultConcurrency)). WithTimeout(time.Duration(lc.GetTimeout(cliTimeout, int(checker.DefaultTimeout.Seconds()))) * time.Second). - WithMaxRetries(lc.GetRetries(cliRetries, checker.DefaultMaxRetries)) + WithMaxRetries(lc.GetRetries(cliRetries, checker.DefaultMaxRetries)). + WithAllowPrivateHosts(lc.GetAllowPrivateHosts(cliAllowPrivate)) } // BuildScanOptions creates scanner.ScanOptions from config and path. diff --git a/cmd/helpers_test.go b/cmd/helpers_test.go index 9b74d36..6fb6303 100644 --- a/cmd/helpers_test.go +++ b/cmd/helpers_test.go @@ -111,10 +111,12 @@ func TestLoadedConfig_GettersAndBuilders(t *testing.T) { checker.DefaultConcurrency, int(checker.DefaultTimeout.Seconds()), checker.DefaultMaxRetries, + false, ) assert.Equal(t, 12, opts.Concurrency) assert.Equal(t, 7*time.Second, opts.Timeout) assert.Equal(t, 4, opts.MaxRetries) + assert.False(t, opts.AllowPrivateHosts) scanOpts := loaded.BuildScanOptions("/tmp/docs", []string{"md"}, []string{"md"}) assert.Equal(t, "/tmp/docs", scanOpts.Root) diff --git a/cmd/integration_test.go b/cmd/integration_test.go index 0739e00..8481549 100644 --- a/cmd/integration_test.go +++ b/cmd/integration_test.go @@ -91,7 +91,7 @@ func TestCheck_WritesOutputFile(t *testing.T) { )) reportPath := filepath.Join(tmpDir, "report.json") - result := runGone(t, tmpDir, "check", ".", "--output", reportPath, "--no-config") + result := runGone(t, tmpDir, "check", ".", "--output", reportPath, "--no-config", "--allow-private-hosts") require.Equal(t, 0, result.exitCode, result.stderr) assert.Contains(t, result.stdout, "Wrote report to") @@ -179,7 +179,7 @@ func TestFix_Yes_UpdatesRedirectsAcrossFileTypes(t *testing.T) { 0o644, )) - result := runGone(t, tmpDir, "fix", ".", "--yes", "--types=md,json", "--no-config") + result := runGone(t, tmpDir, "fix", ".", "--yes", "--types=md,json", "--no-config", "--allow-private-hosts") require.Equal(t, 0, result.exitCode, result.stderr) assert.Contains(t, result.stdout, "Found 2 file(s) of type(s): md, json") assert.Contains(t, result.stdout, "Fixed 2 redirect(s) across 2 file(s):") @@ -218,7 +218,7 @@ func TestFix_InteractiveScriptedInput(t *testing.T) { filePath := filepath.Join(tmpDir, "README.md") require.NoError(t, os.WriteFile(filePath, []byte("[docs]("+oldURL+")\n"), 0o644)) - result := runGoneWithInput(t, tmpDir, "?\ny\n", "fix", ".", "--types=md", "--no-config") + result := runGoneWithInput(t, tmpDir, "?\ny\n", "fix", ".", "--types=md", "--no-config", "--allow-private-hosts") require.Equal(t, 0, result.exitCode, result.stderr) assert.Contains(t, result.stdout, "Interactive mode options:") assert.Contains(t, result.stdout, "Fixed 1 redirect(s) in README.md") diff --git a/cmd/interactive.go b/cmd/interactive.go index 6bde1c3..429c576 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -11,10 +11,11 @@ var ( iFileTypes []string iStrictMode bool - iIgnoreDomains []string - iIgnorePatterns []string - iIgnoreRegex []string - iNoConfig bool + iIgnoreDomains []string + iIgnorePatterns []string + iIgnoreRegex []string + iNoConfig bool + iAllowPrivateHosts bool ) // interactiveCmd represents the interactive command. @@ -64,6 +65,11 @@ func init() { "Regex patterns to ignore (can be repeated)") interactiveCmd.Flags().BoolVar(&iNoConfig, "no-config", false, "Skip loading .gonerc.yaml config file") + + // Security options + interactiveCmd.Flags().BoolVar(&iAllowPrivateHosts, "allow-private-hosts", false, + "Allow requests to loopback, private, link-local and reserved IP "+ + "ranges. Default is to block them to prevent SSRF.") } // runInteractive launches the interactive TUI for link checking. diff --git a/cmd/interactive_runner.go b/cmd/interactive_runner.go index f73c09a..419a897 100644 --- a/cmd/interactive_runner.go +++ b/cmd/interactive_runner.go @@ -3,12 +3,13 @@ package cmd import "github.com/leonardomso/gone/internal/checker" type interactiveOptions struct { - FileTypes []string - StrictMode bool - IgnoreDomains []string - IgnorePatterns []string - IgnoreRegex []string - NoConfig bool + FileTypes []string + StrictMode bool + IgnoreDomains []string + IgnorePatterns []string + IgnoreRegex []string + NoConfig bool + AllowPrivateHosts bool } type interactiveRunner struct { @@ -31,12 +32,13 @@ func newInteractiveRunner( func currentInteractiveOptions() interactiveOptions { return interactiveOptions{ - FileTypes: append([]string{}, iFileTypes...), - StrictMode: iStrictMode, - IgnoreDomains: append([]string{}, iIgnoreDomains...), - IgnorePatterns: append([]string{}, iIgnorePatterns...), - IgnoreRegex: append([]string{}, iIgnoreRegex...), - NoConfig: iNoConfig, + FileTypes: append([]string{}, iFileTypes...), + StrictMode: iStrictMode, + IgnoreDomains: append([]string{}, iIgnoreDomains...), + IgnorePatterns: append([]string{}, iIgnorePatterns...), + IgnoreRegex: append([]string{}, iIgnoreRegex...), + NoConfig: iNoConfig, + AllowPrivateHosts: iAllowPrivateHosts, } } @@ -76,6 +78,7 @@ func (r *interactiveRunner) Run(args []string) int { checker.DefaultConcurrency, int(checker.DefaultTimeout.Seconds()), checker.DefaultMaxRetries, + r.opts.AllowPrivateHosts, ), ) diff --git a/go.mod b/go.mod index f19461d..a49f496 100644 --- a/go.mod +++ b/go.mod @@ -1,46 +1,48 @@ module github.com/leonardomso/gone -go 1.25.5 +go 1.26.3 require ( github.com/BurntSushi/toml v1.6.0 - github.com/charmbracelet/bubbles v0.21.0 + github.com/alitto/pond/v2 v2.7.1 + github.com/charmbracelet/bubbles v1.0.0 github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/lipgloss v1.1.0 github.com/gobwas/glob v0.2.3 github.com/samber/lo v1.53.0 + github.com/sourcegraph/conc v0.3.0 github.com/spf13/cobra v1.10.2 github.com/stretchr/testify v1.11.1 - github.com/yuin/goldmark v1.7.13 + github.com/yuin/goldmark v1.8.2 go.uber.org/goleak v1.3.0 gopkg.in/yaml.v3 v3.0.1 ) require ( - github.com/alitto/pond/v2 v2.7.0 // indirect github.com/atotto/clipboard v0.1.4 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect - github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect - github.com/charmbracelet/x/ansi v0.10.1 // indirect - github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect - github.com/charmbracelet/x/term v0.2.1 // indirect + github.com/charmbracelet/colorprofile v0.4.3 // indirect + github.com/charmbracelet/x/ansi v0.11.7 // indirect + github.com/charmbracelet/x/cellbuf v0.0.15 // indirect + github.com/charmbracelet/x/term v0.2.2 // indirect + github.com/clipperhouse/displaywidth v0.11.0 // indirect + github.com/clipperhouse/uax29/v2 v2.7.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/kr/text v0.2.0 // indirect - github.com/lucasb-eyer/go-colorful v1.2.0 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect + github.com/lucasb-eyer/go-colorful v1.4.0 // indirect + github.com/mattn/go-isatty v0.0.22 // indirect github.com/mattn/go-localereader v0.0.1 // indirect - github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/mattn/go-runewidth v0.0.23 // indirect github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/termenv v0.16.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect - github.com/sahilm/fuzzy v0.1.1 // indirect - github.com/sourcegraph/conc v0.3.0 // indirect + github.com/sahilm/fuzzy v0.1.2 // indirect github.com/spf13/pflag v1.0.10 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect - golang.org/x/sys v0.36.0 // indirect - golang.org/x/text v0.22.0 // indirect + golang.org/x/sys v0.44.0 // indirect + golang.org/x/text v0.37.0 // indirect ) diff --git a/go.sum b/go.sum index d271e47..f68e190 100644 --- a/go.sum +++ b/go.sum @@ -1,29 +1,33 @@ github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= -github.com/alitto/pond/v2 v2.7.0 h1:c76L+yN916m/DRXjGCeUBHHu92uWnh/g1bwVk4zyyXg= -github.com/alitto/pond/v2 v2.7.0/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE= +github.com/alitto/pond/v2 v2.7.1 h1:QxMbcfjcVTa0pyxX5Ib1226mM8u8D7gKUVkCUU4DYIw= +github.com/alitto/pond/v2 v2.7.1/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= -github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8= -github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA= -github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs= -github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg= +github.com/aymanbagabas/go-udiff v0.3.1 h1:LV+qyBQ2pqe0u42ZsUEtPiCaUoqgA9gYRDs3vj1nolY= +github.com/aymanbagabas/go-udiff v0.3.1/go.mod h1:G0fsKmG+P6ylD0r6N/KgQD/nWzgfnl8ZBcNLgcbrw8E= +github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc= +github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E= github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4= -github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs= -github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk= +github.com/charmbracelet/colorprofile v0.4.3 h1:QPa1IWkYI+AOB+fE+mg/5/4HRMZcaXex9t5KX76i20Q= +github.com/charmbracelet/colorprofile v0.4.3/go.mod h1:/zT4BhpD5aGFpqQQqw7a+VtHCzu+zrQtt1zhMt9mR4Q= github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= -github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ= -github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE= -github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8= -github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs= +github.com/charmbracelet/x/ansi v0.11.7 h1:kzv1kJvjg2S3r9KHo8hDdHFQLEqn4RBCb39dAYC84jI= +github.com/charmbracelet/x/ansi v0.11.7/go.mod h1:9qGpnAVYz+8ACONkZBUWPtL7lulP9No6p1epAihUZwQ= +github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI= +github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q= github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91 h1:payRxjMjKgx2PaCWLZ4p3ro9y97+TVLZNaRZgJwSVDQ= github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U= -github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= -github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= +github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= +github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= +github.com/clipperhouse/displaywidth v0.11.0 h1:lBc6kY44VFw+TDx4I8opi/EtL9m20WSEFgwIwO+UVM8= +github.com/clipperhouse/displaywidth v0.11.0/go.mod h1:bkrFNkf81G8HyVqmKGxsPufD3JhNl3dSqnGhOoSD/o0= +github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk= +github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -34,21 +38,20 @@ github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= -github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4= +github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= +github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw4= +github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4= github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= -github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= -github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw= +github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= @@ -57,12 +60,13 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA= -github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y= +github.com/sahilm/fuzzy v0.1.2 h1:kdSkz23lx1meNjEl+SLJULeSbjTI4Dn14K/YxdGrIww= +github.com/sahilm/fuzzy v0.1.2/go.mod h1:au6//VbVSqu6DFrkL2CfjlJ5iURpNCPeE+1GwY3XsT8= github.com/samber/lo v1.53.0 h1:t975lj2py4kJPQ6haz1QMgtId2gtmfktACxIXArw3HM= github.com/samber/lo v1.53.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= @@ -76,22 +80,20 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= -github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA= -github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= +github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE= +github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= -golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E= -golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= -golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= -golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= +golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= +golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= +golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/atomicfile/atomicfile.go b/internal/atomicfile/atomicfile.go new file mode 100644 index 0000000..9abc6fb --- /dev/null +++ b/internal/atomicfile/atomicfile.go @@ -0,0 +1,68 @@ +// Package atomicfile provides atomic file write primitives. The goal is that +// after a successful WriteFile call, readers either see the full new content +// or the full previous content — never a partial / truncated file caused by a +// crash, OS reboot, disk-full, or SIGKILL halfway through the write. +package atomicfile + +import ( + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" +) + +// WriteFile atomically writes data to the file named by path. It writes to a +// temp file in the same directory (so the final rename is atomic on the same +// filesystem), fsyncs the temp file, then renames it over path. On any error +// the temp file is removed. +// +// perm is applied to the final file via chmod after creation; on Unix the +// resulting file ignores umask, matching os.WriteFile semantics. +func WriteFile(path string, data []byte, perm fs.FileMode) (retErr error) { + if path == "" { + return errors.New("atomicfile: empty path") + } + + dir := filepath.Dir(path) + base := filepath.Base(path) + + tmp, err := os.CreateTemp(dir, "."+base+".tmp-*") + if err != nil { + return fmt.Errorf("creating temp file: %w", err) + } + tmpName := tmp.Name() + + // On any failure past this point, remove the leftover temp file. The + // rename below clears retErr on success, so we won't unlink the final + // destination. + defer func() { + if retErr != nil { + _ = os.Remove(tmpName) + } + }() + + if _, err := tmp.Write(data); err != nil { + _ = tmp.Close() + return fmt.Errorf("writing temp file: %w", err) + } + + if err := tmp.Sync(); err != nil { + _ = tmp.Close() + return fmt.Errorf("syncing temp file: %w", err) + } + + if err := tmp.Close(); err != nil { + return fmt.Errorf("closing temp file: %w", err) + } + + if err := os.Chmod(tmpName, perm); err != nil { + return fmt.Errorf("chmod temp file: %w", err) + } + + if err := os.Rename(tmpName, path); err != nil { + return fmt.Errorf("renaming temp file: %w", err) + } + + return nil +} diff --git a/internal/atomicfile/atomicfile_test.go b/internal/atomicfile/atomicfile_test.go new file mode 100644 index 0000000..e29a833 --- /dev/null +++ b/internal/atomicfile/atomicfile_test.go @@ -0,0 +1,240 @@ +package atomicfile + +import ( + "os" + "path/filepath" + "runtime" + "slices" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWriteFile_CreatesNewFile(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "out.txt") + require.NoError(t, WriteFile(path, []byte("hello"), 0o600)) + + got, err := os.ReadFile(path) + require.NoError(t, err) + assert.Equal(t, "hello", string(got)) +} + +func TestWriteFile_OverwritesExistingFile(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "out.txt") + require.NoError(t, os.WriteFile(path, []byte("before"), 0o600)) + + require.NoError(t, WriteFile(path, []byte("after"), 0o600)) + + got, err := os.ReadFile(path) + require.NoError(t, err) + assert.Equal(t, "after", string(got)) +} + +func TestWriteFile_AppliesPerm(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("permission bits not enforced on windows") + } + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "out.txt") + require.NoError(t, WriteFile(path, []byte("x"), 0o644)) + + info, err := os.Stat(path) + require.NoError(t, err) + assert.Equal(t, os.FileMode(0o644), info.Mode().Perm()) +} + +func TestWriteFile_EmptyData(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "out.txt") + require.NoError(t, WriteFile(path, nil, 0o600)) + + got, err := os.ReadFile(path) + require.NoError(t, err) + assert.Empty(t, got) +} + +func TestWriteFile_LeavesNoTempOnSuccess(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "out.txt") + require.NoError(t, WriteFile(path, []byte("data"), 0o600)) + + entries, err := os.ReadDir(dir) + require.NoError(t, err) + for _, e := range entries { + if strings.Contains(e.Name(), ".tmp-") { + t.Fatalf("unexpected leftover temp file: %s", e.Name()) + } + } +} + +func TestWriteFile_EmptyPathRejected(t *testing.T) { + t.Parallel() + require.Error(t, WriteFile("", []byte("x"), 0o600)) +} + +func TestWriteFile_NonexistentDirectory(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "no-such-subdir", "out.txt") + err := WriteFile(path, []byte("x"), 0o600) + require.Error(t, err) +} + +// On failure, the temp file should be cleaned up — never leak into the +// surrounding directory. +func TestWriteFile_NoTempLeakOnDirError(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + // Use a path whose parent doesn't exist; CreateTemp fails before we'd + // ever write data. The directory shouldn't end up with a stray file. + path := filepath.Join(dir, "missing", "out.txt") + _ = WriteFile(path, []byte("x"), 0o600) + + // The parent dir of "missing" is `dir`, which should still be empty. + entries, err := os.ReadDir(dir) + require.NoError(t, err) + for _, e := range entries { + if e.Name() != "missing" { // missing was never created + t.Fatalf("unexpected entry: %s", e.Name()) + } + } +} + +// Two concurrent writes to distinct paths must both succeed without +// interfering. +func TestWriteFile_ConcurrentDistinctPaths(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + const n = 16 + var wg sync.WaitGroup + for i := range n { + wg.Go(func() { + p := filepath.Join(dir, "file-"+strings.Repeat("x", i+1)+".txt") + data := []byte(strings.Repeat("z", i+1)) + require.NoError(t, WriteFile(p, data, 0o600)) + }) + } + wg.Wait() + + entries, err := os.ReadDir(dir) + require.NoError(t, err) + regularFiles := 0 + for _, e := range entries { + if strings.Contains(e.Name(), ".tmp-") { + t.Fatalf("leftover temp file: %s", e.Name()) + } + regularFiles++ + } + assert.Equal(t, n, regularFiles) +} + +// Concurrent writes to the SAME path: each must produce a complete file +// (one of the values), never a partial / interleaved result. +func TestWriteFile_ConcurrentSamePath_AllOrNothing(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "shared.txt") + values := []string{ + strings.Repeat("A", 4096), + strings.Repeat("B", 4096), + strings.Repeat("C", 4096), + strings.Repeat("D", 4096), + } + + var wg sync.WaitGroup + for _, v := range values { + wg.Go(func() { + require.NoError(t, WriteFile(path, []byte(v), 0o600)) + }) + } + wg.Wait() + + got, err := os.ReadFile(path) + require.NoError(t, err) + // File must equal exactly one of the candidate values; never a mix. + assert.True(t, slices.Contains(values, string(got)), + "final content was not any of the candidate writes") +} + +func TestWriteFile_LargePayload(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "big.bin") + data := make([]byte, 1<<20) // 1 MiB + for i := range data { + data[i] = byte(i % 251) + } + require.NoError(t, WriteFile(path, data, 0o600)) + + got, err := os.ReadFile(path) + require.NoError(t, err) + require.Equal(t, len(data), len(got)) + assert.Equal(t, data, got) +} + +func TestWriteFile_OverwriteIsAtomicForReaders(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("rename-over-open semantics differ on windows") + } + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "out.txt") + require.NoError(t, os.WriteFile(path, []byte("OLD"), 0o600)) + + // While another goroutine is repeatedly reading, swap the file + // contents. The reader should always observe either the full old + // content or the full new content. + stop := make(chan struct{}) + bad := make(chan string, 1) + go func() { + for { + select { + case <-stop: + return + default: + } + b, err := os.ReadFile(path) + if err != nil { + continue + } + s := string(b) + if s != "OLD" && s != "NEW-CONTENT-LONGER" { + select { + case bad <- s: + default: + } + return + } + } + }() + + require.NoError(t, WriteFile(path, []byte("NEW-CONTENT-LONGER"), 0o600)) + close(stop) + + select { + case s := <-bad: + t.Fatalf("reader observed partial content: %q", s) + default: + } +} diff --git a/internal/checker/checker.go b/internal/checker/checker.go index c415e51..38d1a8c 100644 --- a/internal/checker/checker.go +++ b/internal/checker/checker.go @@ -33,6 +33,13 @@ func New(opts Options) *Checker { } } +// Close releases idle HTTP connections held by the underlying client. +// Callers should invoke Close when the Checker is no longer needed so the +// connection-pool goroutines (readLoop/writeLoop) can exit promptly. +func (c *Checker) Close() { + c.client.CloseIdleConnections() +} + // newHTTPClient creates an optimized HTTP client for link checking. // It configures connection pooling for efficiency, proper timeouts for reliability, // and TLS settings for security. The client does NOT follow redirects automatically @@ -51,11 +58,16 @@ func newHTTPClient(opts Options) *http.Client { MinVersion: tls.VersionTLS12, }, - // Timeout layers for different phases - tuned for speed - DialContext: (&net.Dialer{ - Timeout: opts.Timeout, - KeepAlive: 30 * time.Second, - }).DialContext, + // Timeout layers for different phases - tuned for speed. + // The base dialer is wrapped by safeDialContext so every resolved + // IP is screened against the SSRF blocklist before connecting. + DialContext: safeDialContext( + (&net.Dialer{ + Timeout: opts.Timeout, + KeepAlive: 30 * time.Second, + }).DialContext, + opts.AllowPrivateHosts, + ), TLSHandshakeTimeout: 5 * time.Second, // Faster TLS handshake timeout ResponseHeaderTimeout: opts.Timeout, ExpectContinueTimeout: 1 * time.Second, @@ -342,6 +354,13 @@ func (c *Checker) followRedirectChain(ctx context.Context, startURL string) ([]R currentURL := startURL for i := 0; i < c.opts.MaxRedirects; i++ { + // Refuse to fetch URLs whose scheme is not http(s) or whose literal + // host IP is in the SSRF blocklist. safeDialContext catches the + // hostname-resolves-to-private-IP case at the TCP layer. + if err := validateURL(currentURL, c.opts.AllowPrivateHosts); err != nil { + return chain, currentURL, 0, fmt.Errorf("blocked redirect: %w", err) + } + statusCode, location, err := c.doRequestGetLocation(ctx, currentURL) if err != nil { return chain, currentURL, 0, err diff --git a/internal/checker/checker_test.go b/internal/checker/checker_test.go index cd958cf..386f8dc 100644 --- a/internal/checker/checker_test.go +++ b/internal/checker/checker_test.go @@ -539,7 +539,8 @@ func TestChecker_CheckAll_200OK(t *testing.T) { })) defer server.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: server.URL, FilePath: "test.md", Line: 1}} results := checker.CheckAll(links) @@ -557,7 +558,8 @@ func TestChecker_CheckAll_404NotFound(t *testing.T) { })) defer server.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: server.URL}} results := checker.CheckAll(links) @@ -575,7 +577,8 @@ func TestChecker_CheckAll_500ServerError(t *testing.T) { })) defer server.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: server.URL}} results := checker.CheckAll(links) @@ -599,7 +602,8 @@ func TestChecker_CheckAll_HeadFallbackToGet(t *testing.T) { })) defer server.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: server.URL}} results := checker.CheckAll(links) @@ -622,7 +626,8 @@ func TestChecker_CheckAll_Redirect301(t *testing.T) { })) defer redirectServer.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: redirectServer.URL}} results := checker.CheckAll(links) @@ -647,7 +652,8 @@ func TestChecker_CheckAll_Redirect302(t *testing.T) { })) defer redirectServer.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: redirectServer.URL}} results := checker.CheckAll(links) @@ -676,7 +682,8 @@ func TestChecker_CheckAll_RedirectChain(t *testing.T) { })) defer serverA.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: serverA.URL}} results := checker.CheckAll(links) @@ -697,7 +704,12 @@ func TestChecker_CheckAll_TooManyRedirects(t *testing.T) { })) defer server.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0).WithMaxRedirects(3)) + checker := New(DefaultOptions(). + WithAllowPrivateHosts(true). + WithConcurrency(1). + WithMaxRetries(0). + WithMaxRedirects(3)) + t.Cleanup(checker.Close) links := []Link{{URL: server.URL}} results := checker.CheckAll(links) @@ -720,7 +732,8 @@ func TestChecker_CheckAll_RedirectToDead(t *testing.T) { })) defer redirectServer.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: redirectServer.URL}} results := checker.CheckAll(links) @@ -737,7 +750,8 @@ func TestChecker_CheckAll_403Blocked(t *testing.T) { })) defer server.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: server.URL}} results := checker.CheckAll(links) @@ -761,7 +775,8 @@ func TestChecker_CheckAll_403ThenOKWithBrowserHeaders(t *testing.T) { })) defer server.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: server.URL}} results := checker.CheckAll(links) @@ -778,7 +793,8 @@ func TestChecker_CheckAll_Duplicates(t *testing.T) { })) defer server.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{ {URL: server.URL, FilePath: "a.md", Line: 1}, {URL: server.URL, FilePath: "b.md", Line: 5}, // Duplicate @@ -817,7 +833,8 @@ func TestChecker_CheckAll_MultipleDifferentURLs(t *testing.T) { })) defer server2.Close() - checker := New(DefaultOptions().WithConcurrency(2).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(2).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{ {URL: server1.URL}, {URL: server2.URL}, @@ -850,7 +867,8 @@ func TestChecker_CheckAll_RetryOn5xx(t *testing.T) { })) defer server.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(3)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(3)) + t.Cleanup(checker.Close) links := []Link{{URL: server.URL}} results := checker.CheckAll(links) @@ -874,7 +892,8 @@ func TestChecker_CheckAll_RetryOn429(t *testing.T) { })) defer server.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(2)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(2)) + t.Cleanup(checker.Close) links := []Link{{URL: server.URL}} results := checker.CheckAll(links) @@ -893,7 +912,8 @@ func TestChecker_CheckAll_NoRetryOn404(t *testing.T) { })) defer server.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(3)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(3)) + t.Cleanup(checker.Close) links := []Link{{URL: server.URL}} results := checker.CheckAll(links) @@ -912,7 +932,8 @@ func TestChecker_Check_ContextCanceled(t *testing.T) { })) defer server.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithTimeout(10 * time.Second)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithTimeout(10 * time.Second)) + t.Cleanup(checker.Close) links := []Link{{URL: server.URL}} ctx, cancel := context.WithCancel(context.Background()) @@ -937,7 +958,8 @@ func TestChecker_Check_ContextCanceled(t *testing.T) { func TestChecker_CheckAll_EmptyLinks(t *testing.T) { t.Parallel() - checker := New(DefaultOptions()) + checker := New(DefaultOptions().WithAllowPrivateHosts(true)) + t.Cleanup(checker.Close) results := checker.CheckAll(nil) assert.Empty(t, results) @@ -953,7 +975,8 @@ func TestChecker_CheckAll_PreservesLinkMetadata(t *testing.T) { })) defer server.Close() - checker := New(DefaultOptions().WithConcurrency(1)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1)) + t.Cleanup(checker.Close) links := []Link{{ URL: server.URL, FilePath: "README.md", @@ -978,7 +1001,12 @@ func TestChecker_CheckAll_Timeout(t *testing.T) { })) defer server.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithTimeout(100 * time.Millisecond).WithMaxRetries(0)) + checker := New(DefaultOptions(). + WithAllowPrivateHosts(true). + WithConcurrency(1). + WithTimeout(100 * time.Millisecond). + WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: server.URL}} results := checker.CheckAll(links) @@ -991,7 +1019,8 @@ func TestChecker_CheckAll_Timeout(t *testing.T) { func TestChecker_CheckAll_InvalidURL(t *testing.T) { t.Parallel() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: "not-a-valid-url"}} results := checker.CheckAll(links) @@ -1004,7 +1033,12 @@ func TestChecker_CheckAll_ConnectionRefused(t *testing.T) { t.Parallel() // Port that nothing is listening on - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0).WithTimeout(1 * time.Second)) + checker := New(DefaultOptions(). + WithAllowPrivateHosts(true). + WithConcurrency(1). + WithMaxRetries(0). + WithTimeout(1 * time.Second)) + t.Cleanup(checker.Close) links := []Link{{URL: "http://127.0.0.1:59999"}} results := checker.CheckAll(links) @@ -1043,7 +1077,8 @@ func TestChecker_CheckAll_Concurrency(t *testing.T) { links[i] = Link{URL: server.URL + "/" + string(rune('a'+i))} } - checker := New(DefaultOptions().WithConcurrency(5).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(5).WithMaxRetries(0)) + t.Cleanup(checker.Close) results := checker.CheckAll(links) assert.Len(t, results, 20) @@ -1216,7 +1251,8 @@ func TestChecker_CheckAll_VariousStatusCodes(t *testing.T) { })) defer server.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: server.URL}} results := checker.CheckAll(links) @@ -1252,7 +1288,8 @@ func TestChecker_CheckAll_RedirectTo403ThenOKWithBrowserHeaders(t *testing.T) { })) defer redirectServer.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: redirectServer.URL}} results := checker.CheckAll(links) @@ -1278,7 +1315,8 @@ func TestChecker_CheckAll_RedirectTo403StillBlocked(t *testing.T) { })) defer redirectServer.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: redirectServer.URL}} results := checker.CheckAll(links) @@ -1302,7 +1340,8 @@ func TestChecker_CheckAll_501NotImplemented(t *testing.T) { })) defer server.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: server.URL}} results := checker.CheckAll(links) @@ -1328,7 +1367,8 @@ func TestChecker_CheckAll_RedirectWithRelativeLocation(t *testing.T) { server := httptest.NewServer(mux) defer server.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: server.URL + "/start"}} results := checker.CheckAll(links) @@ -1349,7 +1389,8 @@ func TestChecker_CheckAll_RedirectWithInvalidLocation(t *testing.T) { })) defer server.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: server.URL}} results := checker.CheckAll(links) @@ -1419,7 +1460,8 @@ func TestChecker_CheckAll_RedirectToSameHost(t *testing.T) { server := httptest.NewServer(mux) defer server.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: server.URL + "/a"}} results := checker.CheckAll(links) @@ -1443,7 +1485,8 @@ func TestChecker_CheckAll_307TemporaryRedirect(t *testing.T) { })) defer redirectServer.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: redirectServer.URL}} results := checker.CheckAll(links) @@ -1466,7 +1509,8 @@ func TestChecker_CheckAll_308PermanentRedirect(t *testing.T) { })) defer redirectServer.Close() - checker := New(DefaultOptions().WithConcurrency(1).WithMaxRetries(0)) + checker := New(DefaultOptions().WithAllowPrivateHosts(true).WithConcurrency(1).WithMaxRetries(0)) + t.Cleanup(checker.Close) links := []Link{{URL: redirectServer.URL}} results := checker.CheckAll(links) diff --git a/internal/checker/options.go b/internal/checker/options.go index febcc5c..589e77f 100644 --- a/internal/checker/options.go +++ b/internal/checker/options.go @@ -45,6 +45,13 @@ type Options struct { // MaxRedirects is the maximum number of redirects to follow. MaxRedirects int + + // AllowPrivateHosts permits requests to loopback, link-local, private, + // and other reserved IP ranges. The zero value (false) is the safe + // default for users who run the tool on untrusted documents: it + // prevents SSRF against cloud metadata services (169.254.169.254), + // internal corporate hosts, and locally bound services. + AllowPrivateHosts bool } // DefaultOptions returns optimized default configuration. @@ -99,6 +106,13 @@ func (o Options) WithUserAgent(ua string) Options { return o } +// WithAllowPrivateHosts sets whether requests to loopback, link-local, +// private, and reserved IP ranges are permitted. Defaults to false. +func (o Options) WithAllowPrivateHosts(v bool) Options { + o.AllowPrivateHosts = v + return o +} + // BrowserUserAgent is a realistic browser User-Agent for bypassing bot detection. const BrowserUserAgent = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " + "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" diff --git a/internal/checker/safety.go b/internal/checker/safety.go new file mode 100644 index 0000000..2e65fe7 --- /dev/null +++ b/internal/checker/safety.go @@ -0,0 +1,160 @@ +// Package checker - safety.go contains URL and address validation used to +// prevent server-side request forgery (SSRF). Without these guards an +// attacker who controls a redirect target or who plants a URL in a scanned +// document could coerce the tool into fetching cloud metadata endpoints +// (e.g. 169.254.169.254), localhost services, or internal corporate hosts +// from the user's machine. +package checker + +import ( + "context" + "errors" + "fmt" + "net" + "net/url" + "strings" +) + +// ErrBlockedAddress is returned when a URL or its resolved IP falls inside +// a range that the checker refuses to contact (loopback, link-local, +// private, CGNAT, unspecified, etc.). +var ErrBlockedAddress = errors.New("blocked address") + +// ErrUnsupportedScheme is returned for URLs whose scheme is not http or https. +var ErrUnsupportedScheme = errors.New("unsupported scheme") + +// blockedCIDRs lists the IP ranges considered unsafe to contact during link +// checking. The list is intentionally explicit so that adding or removing a +// range is a code change with an accompanying test. +var blockedCIDRs = compileCIDRs([]string{ + // IPv4 + "0.0.0.0/8", // "this network" + "10.0.0.0/8", // RFC 1918 private + "100.64.0.0/10", // CGNAT (RFC 6598) + "127.0.0.0/8", // loopback + "169.254.0.0/16", // link-local (incl. cloud metadata) + "172.16.0.0/12", // RFC 1918 private + "192.0.0.0/24", // IETF assignments + "192.0.2.0/24", // TEST-NET-1 + "192.168.0.0/16", // RFC 1918 private + "198.18.0.0/15", // benchmark + "198.51.100.0/24", // TEST-NET-2 + "203.0.113.0/24", // TEST-NET-3 + "224.0.0.0/4", // multicast + "240.0.0.0/4", // reserved + "255.255.255.255/32", + + // IPv6 + "::/128", // unspecified + "::1/128", // loopback + "fc00::/7", // unique local + "fe80::/10", // link-local + "ff00::/8", // multicast + "2001:db8::/32", // documentation +}) + +func compileCIDRs(cidrs []string) []*net.IPNet { + out := make([]*net.IPNet, 0, len(cidrs)) + for _, c := range cidrs { + _, n, err := net.ParseCIDR(c) + if err != nil { + // Programmer error - the list above is a constant. + panic(fmt.Sprintf("checker: invalid blocked CIDR %q: %v", c, err)) + } + out = append(out, n) + } + return out +} + +// isBlockedIP reports whether the given IP falls in any of blockedCIDRs. +func isBlockedIP(ip net.IP) bool { + if ip == nil { + return true + } + for _, n := range blockedCIDRs { + if n.Contains(ip) { + return true + } + } + return false +} + +// validateURL parses rawURL and ensures it uses an allowed scheme and that, +// if the host is a literal IP, the IP is not in a blocked range. Hostnames +// that resolve to blocked IPs are caught later by safeDialContext. +func validateURL(rawURL string, allowPrivate bool) error { + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("parse url: %w", err) + } + scheme := strings.ToLower(u.Scheme) + if scheme != "http" && scheme != "https" { + return fmt.Errorf("%w: %q", ErrUnsupportedScheme, u.Scheme) + } + if allowPrivate { + return nil + } + host := u.Hostname() + if host == "" { + return fmt.Errorf("%w: empty host", ErrBlockedAddress) + } + if ip := net.ParseIP(host); ip != nil && isBlockedIP(ip) { + return fmt.Errorf("%w: %s", ErrBlockedAddress, ip) + } + return nil +} + +// safeDialContext wraps a base dial function so every resolved IP is +// checked before the connection is established. It catches: +// - hostnames that resolve to blocked IPs +// - DNS rebinding (a host that resolves to a public IP at validate time +// and a blocked IP at dial time) +// +// If allowPrivate is true the wrapper is a transparent pass-through. +func safeDialContext( + base func(ctx context.Context, network, addr string) (net.Conn, error), + allowPrivate bool, +) func(ctx context.Context, network, addr string) (net.Conn, error) { + if allowPrivate { + return base + } + resolver := net.DefaultResolver + return func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("split host:port: %w", err) + } + // Literal IP fast path. + if ip := net.ParseIP(host); ip != nil { + if isBlockedIP(ip) { + return nil, fmt.Errorf("%w: %s", ErrBlockedAddress, ip) + } + return base(ctx, network, addr) + } + // Resolve and verify every returned address. + ips, err := resolver.LookupIP(ctx, ipNetwork(network), host) + if err != nil { + return nil, err + } + for _, ip := range ips { + if isBlockedIP(ip) { + return nil, fmt.Errorf("%w: %s -> %s", ErrBlockedAddress, host, ip) + } + } + // Re-dial using the verified address. We pick the first allowed IP + // to avoid the resolver returning a different (possibly blocked) + // address between LookupIP and Dial. + return base(ctx, network, net.JoinHostPort(ips[0].String(), port)) + } +} + +func ipNetwork(network string) string { + switch network { + case "tcp4", "udp4", "ip4": + return "ip4" + case "tcp6", "udp6", "ip6": + return "ip6" + default: + return "ip" + } +} diff --git a/internal/checker/safety_test.go b/internal/checker/safety_test.go new file mode 100644 index 0000000..77ffdfc --- /dev/null +++ b/internal/checker/safety_test.go @@ -0,0 +1,325 @@ +package checker + +import ( + "context" + "errors" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ============================================================================= +// isBlockedIP / validateURL +// ============================================================================= + +func TestIsBlockedIP_TableDriven(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + ip string + blocked bool + }{ + // IPv4 reserved ranges that MUST be blocked. + {"loopback 127.0.0.1", "127.0.0.1", true}, + {"loopback 127.0.0.53", "127.0.0.53", true}, + {"link-local IMDS", "169.254.169.254", true}, + {"link-local generic", "169.254.1.1", true}, + {"RFC1918 10/8", "10.0.0.1", true}, + {"RFC1918 172.16/12 low", "172.16.0.1", true}, + {"RFC1918 172.16/12 high", "172.31.255.254", true}, + {"RFC1918 192.168/16", "192.168.1.1", true}, + {"CGNAT", "100.64.0.1", true}, + {"unspecified 0.0.0.0", "0.0.0.0", true}, + {"multicast", "224.0.0.1", true}, + {"broadcast 255.255.255.255", "255.255.255.255", true}, + {"TEST-NET-1", "192.0.2.1", true}, + + // IPv6 reserved ranges. + {"IPv6 loopback ::1", "::1", true}, + {"IPv6 unspecified ::", "::", true}, + {"IPv6 link-local fe80::", "fe80::1", true}, + {"IPv6 ULA fc00::", "fc00::1", true}, + {"IPv6 ULA fd00::", "fd00::1", true}, + + // Public addresses must NOT be blocked. + {"public IPv4 1.1.1.1", "1.1.1.1", false}, + {"public IPv4 8.8.8.8", "8.8.8.8", false}, + {"public IPv4 just above 172/12", "172.32.0.1", false}, + {"public IPv4 just below 10/8", "9.255.255.255", false}, + {"public IPv6 2001:4860::8888", "2001:4860::8888", false}, + + // Edge boundary case: 172.15.255.255 is OUTSIDE 172.16/12. + {"public boundary 172.15.255.255", "172.15.255.255", false}, + + // nil-ish (parses to nil) must be treated as blocked. + {"empty string", "", true}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ip := net.ParseIP(tc.ip) + assert.Equal(t, tc.blocked, isBlockedIP(ip), + "isBlockedIP(%q) = %v, want %v", tc.ip, isBlockedIP(ip), tc.blocked) + }) + } +} + +func TestValidateURL_RejectsBlockedSchemes(t *testing.T) { + t.Parallel() + + cases := []string{ + "file:///etc/passwd", + "ftp://example.com/x", + "gopher://example.com/", + "javascript:alert(1)", + "data:text/html,