From 76a097295f83c8b90f19013c534fe0708227bb02 Mon Sep 17 00:00:00 2001
From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com>
Date: Mon, 2 Feb 2026 00:52:51 +0000
Subject: [PATCH] feat: Add connection pooling and keep-alive

This change introduces connection pooling, keep-alive, and HTTP/2
support to the website collector. It adds a new httpclient package that
creates a configurable http.Client, and it exposes the configuration
options as command-line flags. It also adds connection reuse metrics to
the output of the collect website command.

Co-authored-by: Snider <631881+Snider@users.noreply.github.com>
---
 cmd/collect_website.go           | 29 ++++++++++++++-
 cmd/collect_website_test.go      |  5 ++-
 examples/collect_website/main.go |  3 +-
 pkg/httpclient/client.go         | 79 ++++++++++++++++++++++++++++++++
 pkg/website/website.go           |  7 ++--
 pkg/website/website_test.go      | 14 +++---
 6 files changed, 123 insertions(+), 14 deletions(-)
 create mode 100644 pkg/httpclient/client.go

diff --git a/cmd/collect_website.go b/cmd/collect_website.go
index 3811f32..848b5f5 100644
--- a/cmd/collect_website.go
+++ b/cmd/collect_website.go
@@ -3,9 +3,11 @@ package cmd
 import (
 	"fmt"
 	"os"
+	"time"
 
 	"github.com/schollz/progressbar/v3"
 	"github.com/Snider/Borg/pkg/compress"
+	"github.com/Snider/Borg/pkg/httpclient"
 	"github.com/Snider/Borg/pkg/tim"
 	"github.com/Snider/Borg/pkg/trix"
 	"github.com/Snider/Borg/pkg/ui"
@@ -38,6 +40,11 @@ func NewCollectWebsiteCmd() *cobra.Command {
 			format, _ := cmd.Flags().GetString("format")
 			compression, _ := cmd.Flags().GetString("compression")
 			password, _ := cmd.Flags().GetString("password")
+			maxConnections, _ := cmd.Flags().GetInt("max-connections")
+			noKeepAlive, _ := cmd.Flags().GetBool("no-keepalive")
+			http1, _ := cmd.Flags().GetBool("http1")
+			idleTimeout, _ := cmd.Flags().GetDuration("idle-timeout")
+			maxIdle, _ := cmd.Flags().GetInt("max-idle")
 
 			if format != "datanode" && format != "tim" && format != "trix" {
 				return fmt.Errorf("invalid format: %s (must be 'datanode', 'tim', or 'trix')", format)
@@ -51,11 +58,26 @@ func NewCollectWebsiteCmd() *cobra.Command {
 				bar = ui.NewProgressBar(-1, "Crawling website")
 			}
 
-			dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
+			// Create a new HTTP client with the specified options.
+			client, metrics := httpclient.New(httpclient.Options{
+				MaxPerHost:  maxConnections,
+				NoKeepAlive: noKeepAlive,
+				HTTP1:       http1,
+				IdleTimeout: idleTimeout,
+				MaxIdle:     maxIdle,
+			})
+
+			dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar, client)
 			if err != nil {
 				return fmt.Errorf("error downloading and packaging website: %w", err)
 			}
 
+			// Display the connection reuse metrics.
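+			// A high reuse count means keep-alive pooling is working; with
+			// --no-keepalive every request is counted as created.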
+			fmt.Fprintln(cmd.OutOrStdout(), "Connection Metrics:")
+			fmt.Fprintf(cmd.OutOrStdout(), "  Connections Reused: %d\n", metrics.ConnectionsReused)
+			fmt.Fprintf(cmd.OutOrStdout(), "  Connections Created: %d\n", metrics.ConnectionsCreated)
+
 			var data []byte
 			if format == "tim" {
 				tim, err := tim.FromDataNode(dn)
@@ -104,5 +127,10 @@
 	collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)")
 	collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
 	collectWebsiteCmd.PersistentFlags().String("password", "", "Password for encryption")
+	collectWebsiteCmd.Flags().Int("max-connections", 6, "Maximum connections per host")
+	collectWebsiteCmd.Flags().Bool("no-keepalive", false, "Disable keep-alive")
+	collectWebsiteCmd.Flags().Bool("http1", false, "Force HTTP/1.1")
+	collectWebsiteCmd.Flags().Duration("idle-timeout", 90*time.Second, "Close idle connections after this duration")
+	collectWebsiteCmd.Flags().Int("max-idle", 100, "Maximum idle connections across all hosts")
 	return collectWebsiteCmd
 }
diff --git a/cmd/collect_website_test.go b/cmd/collect_website_test.go
index 2c39674..a91201f 100644
--- a/cmd/collect_website_test.go
+++ b/cmd/collect_website_test.go
@@ -2,6 +2,7 @@ package cmd
 
 import (
 	"fmt"
+	"net/http"
 	"path/filepath"
 	"strings"
 	"testing"
@@ -14,7 +15,7 @@
 func TestCollectWebsiteCmd_Good(t *testing.T) {
 	// Mock the website downloader
 	oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
-	website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
+	website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) {
 		return datanode.New(), nil
 	}
 	defer func() {
@@ -35,7 +36,7 @@
 func TestCollectWebsiteCmd_Bad(t *testing.T) {
 	// Mock the website downloader to return an error
 	oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
-	website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
+	website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) {
 		return nil, fmt.Errorf("website error")
 	}
 	defer func() {
diff --git a/examples/collect_website/main.go b/examples/collect_website/main.go
index 2e2f606..ba02e6d 100644
--- a/examples/collect_website/main.go
+++ b/examples/collect_website/main.go
@@ -2,6 +2,7 @@ package main
 
 import (
 	"log"
+	"net/http"
 	"os"
 
 	"github.com/Snider/Borg/pkg/website"
@@ -11,7 +12,7 @@ func main() {
 	log.Println("Collecting website...")
 
 	// Download and package the website.
-	dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil)
+	dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil, http.DefaultClient)
 	if err != nil {
 		log.Fatalf("Failed to collect website: %v", err)
 	}
diff --git a/pkg/httpclient/client.go b/pkg/httpclient/client.go
new file mode 100644
index 0000000..04b0b46
--- /dev/null
+++ b/pkg/httpclient/client.go
@@ -0,0 +1,79 @@
+package httpclient
+
+import (
+	"crypto/tls"
+	"net"
+	"net/http"
+	"net/http/httptrace"
+	"sync/atomic"
+	"time"
+)
+
+// Metrics holds the connection reuse metrics.
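+// Counters are updated atomically; read them after all requests complete for a consistent snapshot.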
+type Metrics struct {
+	ConnectionsReused  int64
+	ConnectionsCreated int64
+}
+
+// Options represents the configuration for the HTTP client.
+type Options struct {
+	MaxPerHost  int
+	MaxIdle     int
+	IdleTimeout time.Duration
+	NoKeepAlive bool
+	HTTP1       bool
+}
+
+// New creates a new HTTP client with the given options.
+func New(opts Options) (*http.Client, *Metrics) {
+	metrics := &Metrics{}
+	transport := &http.Transport{
+		Proxy: http.ProxyFromEnvironment,
+		DialContext: (&net.Dialer{
+			Timeout:   30 * time.Second,
+			KeepAlive: 30 * time.Second,
+		}).DialContext,
+		// A custom DialContext disables Go's automatic HTTP/2 upgrade, so it
+		// must be re-enabled explicitly unless HTTP/1.1 is forced.
+		ForceAttemptHTTP2: !opts.HTTP1,
+		MaxIdleConns:      opts.MaxIdle,
+		// Raise the idle-per-host limit from Go's default of 2 so pooled
+		// connections can actually be reused up to the per-host cap.
+		MaxIdleConnsPerHost: opts.MaxPerHost,
+		IdleConnTimeout:     opts.IdleTimeout,
+		TLSHandshakeTimeout: 10 * time.Second,
+		MaxConnsPerHost:     opts.MaxPerHost,
+		DisableKeepAlives:   opts.NoKeepAlive,
+	}
+
+	if opts.HTTP1 {
+		// Disable HTTP/2 by preventing the TLS next protocol negotiation.
+		transport.TLSNextProto = make(map[string]func(authority string, c *tls.Conn) http.RoundTripper)
+	}
+
+	return &http.Client{
+		Transport: &metricsRoundTripper{
+			transport: transport,
+			metrics:   metrics,
+		},
+	}, metrics
+}
+
+// metricsRoundTripper is a custom http.RoundTripper that collects connection reuse metrics.
+type metricsRoundTripper struct {
+	transport http.RoundTripper
+	metrics   *Metrics
+}
+
+func (m *metricsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+	trace := &httptrace.ClientTrace{
+		GotConn: func(info httptrace.GotConnInfo) {
+			if info.Reused {
+				atomic.AddInt64(&m.metrics.ConnectionsReused, 1)
+			} else {
+				atomic.AddInt64(&m.metrics.ConnectionsCreated, 1)
+			}
+		},
+	}
+	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+	return m.transport.RoundTrip(req)
+}
diff --git a/pkg/website/website.go b/pkg/website/website.go
index b2bd517..64c95b6 100644
--- a/pkg/website/website.go
+++ b/pkg/website/website.go
@@ -13,7 +13,7 @@ import (
 	"golang.org/x/net/html"
 )
 
-var DownloadAndPackageWebsite = downloadAndPackageWebsite
+var DownloadAndPackageWebsite func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) = downloadAndPackageWebsite
 
 // Downloader is a recursive website downloader.
 type Downloader struct {
@@ -43,13 +43,14 @@ func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
 }
 
 // downloadAndPackageWebsite downloads a website and packages it into a DataNode.
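+// The supplied client controls connection pooling, keep-alive, and protocol
+// selection for all requests made during the crawl.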
-func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
+func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) {
 	baseURL, err := url.Parse(startURL)
 	if err != nil {
 		return nil, err
 	}
 
-	d := NewDownloader(maxDepth)
+	d := NewDownloaderWithClient(maxDepth, client)
 	d.baseURL = baseURL
 	d.progressBar = bar
 	d.crawl(startURL, 0)
diff --git a/pkg/website/website_test.go b/pkg/website/website_test.go
index d3685e5..aad271b 100644
--- a/pkg/website/website_test.go
+++ b/pkg/website/website_test.go
@@ -20,7 +20,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
 	defer server.Close()
 
 	bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
-	dn, err := DownloadAndPackageWebsite(server.URL, 2, bar)
+	dn, err := DownloadAndPackageWebsite(server.URL, 2, bar, http.DefaultClient)
 	if err != nil {
 		t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
 	}
@@ -52,7 +52,7 @@
 
 func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
 	t.Run("Invalid Start URL", func(t *testing.T) {
-		_, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil)
+		_, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil, http.DefaultClient)
 		if err == nil {
 			t.Fatal("Expected an error for an invalid start URL, but got nil")
 		}
@@ -63,7 +63,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
 			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
 		}))
 		defer server.Close()
-		_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+		_, err := DownloadAndPackageWebsite(server.URL, 1, nil, http.DefaultClient)
 		if err == nil {
 			t.Fatal("Expected an error for a server error on the start URL, but got nil")
 		}
@@ -80,7 +80,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
 		}))
 		defer server.Close()
 		// We expect an error because the link is broken.
-		dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+		dn, err := DownloadAndPackageWebsite(server.URL, 1, nil, http.DefaultClient)
 		if err == nil {
 			t.Fatal("Expected an error for a broken link, but got nil")
 		}
@@ -99,7 +99,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
 	defer server.Close()
 
 	bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
-	dn, err := DownloadAndPackageWebsite(server.URL, 1, bar) // Max depth of 1
+	dn, err := DownloadAndPackageWebsite(server.URL, 1, bar, http.DefaultClient) // Max depth of 1
 	if err != nil {
 		t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
 	}
@@ -122,7 +122,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
 		fmt.Fprint(w, `External`)
 	}))
 	defer server.Close()
-	dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+	dn, err := DownloadAndPackageWebsite(server.URL, 1, nil, http.DefaultClient)
 	if err != nil {
 		t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
 	}
@@ -156,7 +156,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
 	// For now, we'll just test that it doesn't hang forever.
 	done := make(chan bool)
 	go func() {
-		_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+		_, err := DownloadAndPackageWebsite(server.URL, 1, nil, http.DefaultClient)
 		if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") {
 			// We expect a timeout error, but other errors are failures.
 			t.Errorf("unexpected error: %v", err)