Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion cmd/collect_website.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package cmd
import (
"fmt"
"os"
"time"

"github.com/schollz/progressbar/v3"
"github.com/Snider/Borg/pkg/compress"
"github.com/Snider/Borg/pkg/httpclient"
"github.com/Snider/Borg/pkg/tim"
"github.com/Snider/Borg/pkg/trix"
"github.com/Snider/Borg/pkg/ui"
Expand Down Expand Up @@ -38,6 +40,11 @@ func NewCollectWebsiteCmd() *cobra.Command {
format, _ := cmd.Flags().GetString("format")
compression, _ := cmd.Flags().GetString("compression")
password, _ := cmd.Flags().GetString("password")
maxConnections, _ := cmd.Flags().GetInt("max-connections")
noKeepAlive, _ := cmd.Flags().GetBool("no-keepalive")
http1, _ := cmd.Flags().GetBool("http1")
idleTimeout, _ := cmd.Flags().GetDuration("idle-timeout")
maxIdle, _ := cmd.Flags().GetInt("max-idle")
Comment on lines +43 to +47

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The errors returned from parsing these new command-line flags are being ignored. If a user provides an invalid value for any of these flags (e.g., --max-connections=foo), the error will be discarded, and the variable will be assigned its zero value. This can lead to the application running with an unintended configuration without any warning. The errors should be checked and returned to the user, so they are aware of the issue.

			maxConnections, err := cmd.Flags().GetInt("max-connections")
			if err != nil {
				return err
			}
			noKeepAlive, err := cmd.Flags().GetBool("no-keepalive")
			if err != nil {
				return err
			}
			http1, err := cmd.Flags().GetBool("http1")
			if err != nil {
				return err
			}
			idleTimeout, err := cmd.Flags().GetDuration("idle-timeout")
			if err != nil {
				return err
			}
			maxIdle, err := cmd.Flags().GetInt("max-idle")
			if err != nil {
				return err
			}


if format != "datanode" && format != "tim" && format != "trix" {
return fmt.Errorf("invalid format: %s (must be 'datanode', 'tim', or 'trix')", format)
Expand All @@ -51,11 +58,25 @@ func NewCollectWebsiteCmd() *cobra.Command {
bar = ui.NewProgressBar(-1, "Crawling website")
}

dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
// Create a new HTTP client with the specified options.
client, metrics := httpclient.New(httpclient.Options{
MaxPerHost: maxConnections,
NoKeepAlive: noKeepAlive,
HTTP1: http1,
IdleTimeout: idleTimeout,
MaxIdle: maxIdle,
})

dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar, client)
if err != nil {
return fmt.Errorf("error downloading and packaging website: %w", err)
}

// Display the connection reuse metrics.
fmt.Fprintln(cmd.OutOrStdout(), "Connection Metrics:")
fmt.Fprintf(cmd.OutOrStdout(), " Connections Reused: %d\n", metrics.ConnectionsReused)
fmt.Fprintf(cmd.OutOrStdout(), " Connections Created: %d\n", metrics.ConnectionsCreated)

var data []byte
if format == "tim" {
tim, err := tim.FromDataNode(dn)
Expand Down Expand Up @@ -104,5 +125,10 @@ func NewCollectWebsiteCmd() *cobra.Command {
collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)")
collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
collectWebsiteCmd.PersistentFlags().String("password", "", "Password for encryption")
collectWebsiteCmd.Flags().Int("max-connections", 6, "Max connections per domain")
collectWebsiteCmd.Flags().Bool("no-keepalive", false, "Disable keep-alive")
collectWebsiteCmd.Flags().Bool("http1", false, "Force HTTP/1.1")
collectWebsiteCmd.Flags().Duration("idle-timeout", 90*time.Second, "Close idle connections after")
collectWebsiteCmd.Flags().Int("max-idle", 100, "Max idle connections total")
return collectWebsiteCmd
}
5 changes: 3 additions & 2 deletions cmd/collect_website_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cmd

import (
"fmt"
"net/http"
"path/filepath"
"strings"
"testing"
Expand All @@ -14,7 +15,7 @@ import (
func TestCollectWebsiteCmd_Good(t *testing.T) {
// Mock the website downloader
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) {
return datanode.New(), nil
}
defer func() {
Expand All @@ -35,7 +36,7 @@ func TestCollectWebsiteCmd_Good(t *testing.T) {
func TestCollectWebsiteCmd_Bad(t *testing.T) {
// Mock the website downloader to return an error
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) {
return nil, fmt.Errorf("website error")
}
defer func() {
Expand Down
3 changes: 2 additions & 1 deletion examples/collect_website/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"log"
"net/http"
"os"

"github.com/Snider/Borg/pkg/website"
Expand All @@ -11,7 +12,7 @@ func main() {
log.Println("Collecting website...")

// Download and package the website.
dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil)
dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil, http.DefaultClient)
if err != nil {
log.Fatalf("Failed to collect website: %v", err)
}
Expand Down
74 changes: 74 additions & 0 deletions pkg/httpclient/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package httpclient

import (
"crypto/tls"
"net"
"net/http"
"net/http/httptrace"
"sync/atomic"
"time"
)

// Metrics holds the connection reuse metrics.
type Metrics struct {
ConnectionsReused int64
ConnectionsCreated int64
}

// Options represents the configuration for the HTTP client.
type Options struct {
MaxPerHost int
MaxIdle int
IdleTimeout time.Duration
NoKeepAlive bool
HTTP1 bool
}

// New creates a new HTTP client with the given options.
func New(opts Options) (*http.Client, *Metrics) {
metrics := &Metrics{}
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
MaxIdleConns: opts.MaxIdle,
IdleConnTimeout: opts.IdleTimeout,
TLSHandshakeTimeout: 10 * time.Second,
MaxConnsPerHost: opts.MaxPerHost,
DisableKeepAlives: opts.NoKeepAlive,
}
Comment on lines +30 to +41

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Instead of creating a new http.Transport from scratch, it's better practice to clone http.DefaultTransport. This ensures that you start with a known-good configuration with sensible defaults (like ForceAttemptHTTP2 for HTTP/2 support, which this PR aims to add) and then customize it. Re-implementing the defaults can lead to missing out on important settings or future improvements to DefaultTransport.

	transport := http.DefaultTransport.(*http.Transport).Clone()
	transport.MaxIdleConns = opts.MaxIdle
	transport.IdleConnTimeout = opts.IdleTimeout
	transport.MaxConnsPerHost = opts.MaxPerHost
	transport.DisableKeepAlives = opts.NoKeepAlive


if opts.HTTP1 {
// Disable HTTP/2 by preventing the TLS next protocol negotiation.
transport.TLSNextProto = make(map[string]func(authority string, c *tls.Conn) http.RoundTripper)
}

return &http.Client{
Transport: &metricsRoundTripper{
transport: transport,
metrics: metrics,
},
}, metrics
}

// metricsRoundTripper is a custom http.RoundTripper that collects connection reuse metrics.
type metricsRoundTripper struct {
transport http.RoundTripper
metrics *Metrics
}

func (m *metricsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
trace := &httptrace.ClientTrace{
GotConn: func(info httptrace.GotConnInfo) {
if info.Reused {
atomic.AddInt64(&m.metrics.ConnectionsReused, 1)
} else {
atomic.AddInt64(&m.metrics.ConnectionsCreated, 1)
}
},
}
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
return m.transport.RoundTrip(req)
}
6 changes: 3 additions & 3 deletions pkg/website/website.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"golang.org/x/net/html"
)

var DownloadAndPackageWebsite = downloadAndPackageWebsite
// DownloadAndPackageWebsite is a package-level function variable (defaulting
// to downloadAndPackageWebsite) so tests can substitute a mock implementation.
var DownloadAndPackageWebsite func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) = downloadAndPackageWebsite

// Downloader is a recursive website downloader.
type Downloader struct {
Expand Down Expand Up @@ -43,13 +43,13 @@ func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
}

// downloadAndPackageWebsite downloads a website and packages it into a DataNode.
func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) {
baseURL, err := url.Parse(startURL)
if err != nil {
return nil, err
}

d := NewDownloader(maxDepth)
d := NewDownloaderWithClient(maxDepth, client)
d.baseURL = baseURL
d.progressBar = bar
d.crawl(startURL, 0)
Expand Down
14 changes: 7 additions & 7 deletions pkg/website/website_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
defer server.Close()

bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
dn, err := DownloadAndPackageWebsite(server.URL, 2, bar)
dn, err := DownloadAndPackageWebsite(server.URL, 2, bar, http.DefaultClient)
if err != nil {
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
}
Expand Down Expand Up @@ -52,7 +52,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {

func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
t.Run("Invalid Start URL", func(t *testing.T) {
_, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil)
_, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil, http.DefaultClient)
if err == nil {
t.Fatal("Expected an error for an invalid start URL, but got nil")
}
Expand All @@ -63,7 +63,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}))
defer server.Close()
_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
_, err := DownloadAndPackageWebsite(server.URL, 1, nil, http.DefaultClient)
if err == nil {
t.Fatal("Expected an error for a server error on the start URL, but got nil")
}
Expand All @@ -80,7 +80,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
}))
defer server.Close()
// We expect an error because the link is broken.
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil, http.DefaultClient)
if err == nil {
t.Fatal("Expected an error for a broken link, but got nil")
}
Expand All @@ -99,7 +99,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
defer server.Close()

bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
dn, err := DownloadAndPackageWebsite(server.URL, 1, bar) // Max depth of 1
dn, err := DownloadAndPackageWebsite(server.URL, 1, bar, http.DefaultClient) // Max depth of 1
if err != nil {
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
}
Expand All @@ -122,7 +122,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
fmt.Fprint(w, `<a href="http://externalsite.com/page.html">External</a>`)
}))
defer server.Close()
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
dn, err := DownloadAndPackageWebsite(server.URL, 1, nil, http.DefaultClient)
if err != nil {
t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
}
Expand Down Expand Up @@ -156,7 +156,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
// For now, we'll just test that it doesn't hang forever.
done := make(chan bool)
go func() {
_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
_, err := DownloadAndPackageWebsite(server.URL, 1, nil, http.DefaultClient)
if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") {
// We expect a timeout error, but other errors are failures.
t.Errorf("unexpected error: %v", err)
Expand Down
Loading