diff --git a/cmd/all.go b/cmd/all.go index 84a06db..385ed5b 100644 --- a/cmd/all.go +++ b/cmd/all.go @@ -11,6 +11,7 @@ import ( "github.com/Snider/Borg/pkg/compress" "github.com/Snider/Borg/pkg/datanode" "github.com/Snider/Borg/pkg/github" + "github.com/Snider/Borg/pkg/httpclient" "github.com/Snider/Borg/pkg/tim" "github.com/Snider/Borg/pkg/trix" "github.com/Snider/Borg/pkg/ui" @@ -42,7 +43,15 @@ func NewAllCmd() *cobra.Command { return err } - repos, err := GithubClient.GetPublicRepos(cmd.Context(), owner) + totalTimeout, _ := cmd.Flags().GetDuration("timeout") + connectTimeout, _ := cmd.Flags().GetDuration("connect-timeout") + tlsTimeout, _ := cmd.Flags().GetDuration("tls-timeout") + headerTimeout, _ := cmd.Flags().GetDuration("header-timeout") + + httpClient := httpclient.NewClient(totalTimeout, connectTimeout, tlsTimeout, headerTimeout) + githubClient := github.NewGithubClient(httpClient) + + repos, err := githubClient.GetPublicRepos(cmd.Context(), owner) if err != nil { return err } diff --git a/cmd/all_test.go b/cmd/all_test.go index 66b4af1..659e396 100644 --- a/cmd/all_test.go +++ b/cmd/all_test.go @@ -23,7 +23,7 @@ func TestAllCmd_Good(t *testing.T) { }, }) oldNewAuthenticatedClient := github.NewAuthenticatedClient - github.NewAuthenticatedClient = func(ctx context.Context) *http.Client { + github.NewAuthenticatedClient = func(ctx context.Context, baseClient *http.Client) *http.Client { return mockGithubClient } defer func() { @@ -67,7 +67,7 @@ func TestAllCmd_Bad(t *testing.T) { }, }) oldNewAuthenticatedClient := github.NewAuthenticatedClient - github.NewAuthenticatedClient = func(ctx context.Context) *http.Client { + github.NewAuthenticatedClient = func(ctx context.Context, baseClient *http.Client) *http.Client { return mockGithubClient } defer func() { @@ -96,7 +96,7 @@ func TestAllCmd_Ugly(t *testing.T) { }, }) oldNewAuthenticatedClient := github.NewAuthenticatedClient - github.NewAuthenticatedClient = func(ctx context.Context) *http.Client { + github.NewAuthenticatedClient = func(ctx context.Context, baseClient *http.Client) *http.Client { return mockGithubClient } defer func() { diff --git a/cmd/collect.go b/cmd/collect.go index a45ab09..7e78a1c 100644 --- a/cmd/collect.go +++ b/cmd/collect.go @@ -1,6 +1,8 @@ package cmd import ( + "time" + "github.com/spf13/cobra" ) @@ -11,11 +13,18 @@ func init() { RootCmd.AddCommand(GetCollectCmd()) } func NewCollectCmd() *cobra.Command { - return &cobra.Command{ + cmd := &cobra.Command{ Use: "collect", Short: "Collect a resource from a URI.", Long: `Collect a resource from a URI and store it in a DataNode.`, } + + cmd.PersistentFlags().Duration("timeout", 0, "Total request timeout (e.g., 60s). 0 means no timeout.") + cmd.PersistentFlags().Duration("connect-timeout", 10*time.Second, "TCP connection establishment timeout (e.g., 10s)") + cmd.PersistentFlags().Duration("tls-timeout", 10*time.Second, "TLS handshake timeout (e.g., 10s)") + cmd.PersistentFlags().Duration("header-timeout", 30*time.Second, "Time to receive response headers timeout (e.g., 30s)") + + return cmd } func GetCollectCmd() *cobra.Command { diff --git a/cmd/collect_github_repo.go b/cmd/collect_github_repo.go index c25df3b..0010c2e 100644 --- a/cmd/collect_github_repo.go +++ b/cmd/collect_github_repo.go @@ -6,6 +6,7 @@ import ( "os" "github.com/Snider/Borg/pkg/compress" + "github.com/Snider/Borg/pkg/httpclient" "github.com/Snider/Borg/pkg/tim" "github.com/Snider/Borg/pkg/trix" "github.com/Snider/Borg/pkg/ui" @@ -44,6 +45,13 @@ func NewCollectGithubRepoCmd() *cobra.Command { return fmt.Errorf("invalid compression: %s (must be 'none', 'gz', or 'xz')", compression) } + totalTimeout, _ := cmd.Flags().GetDuration("timeout") + connectTimeout, _ := cmd.Flags().GetDuration("connect-timeout") + tlsTimeout, _ := cmd.Flags().GetDuration("tls-timeout") + headerTimeout, _ := cmd.Flags().GetDuration("header-timeout") + + _ = httpclient.NewClient(totalTimeout, connectTimeout, tlsTimeout, headerTimeout) + prompter := ui.NewNonInteractivePrompter(ui.GetVCSQuote) prompter.Start() defer prompter.Stop() diff --git a/cmd/collect_github_repos.go b/cmd/collect_github_repos.go index dfcd315..2898fb7 100644 --- a/cmd/collect_github_repos.go +++ b/cmd/collect_github_repos.go @@ -4,20 +4,24 @@ import ( "fmt" "github.com/Snider/Borg/pkg/github" + "github.com/Snider/Borg/pkg/httpclient" "github.com/spf13/cobra" ) -var ( - // GithubClient is the github client used by the command. It can be replaced for testing. - GithubClient = github.NewGithubClient() -) - var collectGithubReposCmd = &cobra.Command{ Use: "repos [user-or-org]", Short: "Collects all public repositories for a user or organization", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - repos, err := GithubClient.GetPublicRepos(cmd.Context(), args[0]) + totalTimeout, _ := cmd.Flags().GetDuration("timeout") + connectTimeout, _ := cmd.Flags().GetDuration("connect-timeout") + tlsTimeout, _ := cmd.Flags().GetDuration("tls-timeout") + headerTimeout, _ := cmd.Flags().GetDuration("header-timeout") + + httpClient := httpclient.NewClient(totalTimeout, connectTimeout, tlsTimeout, headerTimeout) + githubClient := github.NewGithubClient(httpClient) + + repos, err := githubClient.GetPublicRepos(cmd.Context(), args[0]) if err != nil { return err } diff --git a/cmd/collect_pwa.go b/cmd/collect_pwa.go index 8b5ef8c..6d8c2e5 100644 --- a/cmd/collect_pwa.go +++ b/cmd/collect_pwa.go @@ -5,6 +5,7 @@ import ( "os" "github.com/Snider/Borg/pkg/compress" + "github.com/Snider/Borg/pkg/httpclient" "github.com/Snider/Borg/pkg/pwa" "github.com/Snider/Borg/pkg/tim" "github.com/Snider/Borg/pkg/trix" @@ -13,17 +14,9 @@ import ( "github.com/spf13/cobra" ) -type CollectPWACmd struct { - cobra.Command - PWAClient pwa.PWAClient -} - // NewCollectPWACmd creates a new collect pwa command -func NewCollectPWACmd() *CollectPWACmd { - c := &CollectPWACmd{ - PWAClient: pwa.NewPWAClient(), - } - c.Command = cobra.Command{ +func NewCollectPWACmd() *cobra.Command { + cmd := &cobra.Command{ Use: "pwa [url]", Short: "Collect a single PWA using a URI", Long: `Collect a single PWA and store it in a DataNode. @@ -44,7 +37,15 @@ Examples: compression, _ := cmd.Flags().GetString("compression") password, _ := cmd.Flags().GetString("password") - finalPath, err := CollectPWA(c.PWAClient, pwaURL, outputFile, format, compression, password) + totalTimeout, _ := cmd.Flags().GetDuration("timeout") + connectTimeout, _ := cmd.Flags().GetDuration("connect-timeout") + tlsTimeout, _ := cmd.Flags().GetDuration("tls-timeout") + headerTimeout, _ := cmd.Flags().GetDuration("header-timeout") + + httpClient := httpclient.NewClient(totalTimeout, connectTimeout, tlsTimeout, headerTimeout) + pwaClient := pwa.NewPWAClient(httpClient) + + finalPath, err := CollectPWA(pwaClient, pwaURL, outputFile, format, compression, password) if err != nil { return err } @@ -52,16 +53,16 @@ Examples: return nil }, } - c.Flags().String("uri", "", "The URI of the PWA to collect (can also be passed as positional arg)") - c.Flags().String("output", "", "Output file for the DataNode") - c.Flags().String("format", "datanode", "Output format (datanode, tim, trix, or stim)") - c.Flags().String("compression", "none", "Compression format (none, gz, or xz)") - c.Flags().String("password", "", "Password for encryption (required for stim format)") - return c + cmd.Flags().String("uri", "", "The URI of the PWA to collect (can also be passed as positional arg)") + cmd.Flags().String("output", "", "Output file for the DataNode") + cmd.Flags().String("format", "datanode", "Output format (datanode, tim, trix, or stim)") + cmd.Flags().String("compression", "none", "Compression format (none, gz, or xz)") + cmd.Flags().String("password", "", "Password for encryption (required for stim format)") + return cmd } func init() { - collectCmd.AddCommand(&NewCollectPWACmd().Command) + collectCmd.AddCommand(NewCollectPWACmd()) } func CollectPWA(client pwa.PWAClient, pwaURL string, outputFile string, format string, compression string, password string) (string, error) { if pwaURL == "" { diff --git a/cmd/collect_website.go b/cmd/collect_website.go index 3811f32..cc77201 100644 --- a/cmd/collect_website.go +++ b/cmd/collect_website.go @@ -6,6 +6,7 @@ import ( "github.com/schollz/progressbar/v3" "github.com/Snider/Borg/pkg/compress" + "github.com/Snider/Borg/pkg/httpclient" "github.com/Snider/Borg/pkg/tim" "github.com/Snider/Borg/pkg/trix" "github.com/Snider/Borg/pkg/ui" @@ -51,7 +52,14 @@ func NewCollectWebsiteCmd() *cobra.Command { bar = ui.NewProgressBar(-1, "Crawling website") } - dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar) + totalTimeout, _ := cmd.Flags().GetDuration("timeout") + connectTimeout, _ := cmd.Flags().GetDuration("connect-timeout") + tlsTimeout, _ := cmd.Flags().GetDuration("tls-timeout") + headerTimeout, _ := cmd.Flags().GetDuration("header-timeout") + + client := httpclient.NewClient(totalTimeout, connectTimeout, tlsTimeout, headerTimeout) + + dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar, client) if err != nil { return fmt.Errorf("error downloading and packaging website: %w", err) } diff --git a/cmd/collect_website_test.go b/cmd/collect_website_test.go index 2c39674..a91201f 100644 --- a/cmd/collect_website_test.go +++ b/cmd/collect_website_test.go @@ -2,6 +2,7 @@ package cmd import ( "fmt" + "net/http" "path/filepath" "strings" "testing" @@ -14,7 +15,7 @@ import ( func TestCollectWebsiteCmd_Good(t *testing.T) { // Mock the website downloader oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite - website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) { + website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) { return datanode.New(), nil } defer func() { @@ -35,7 +36,7 @@ func TestCollectWebsiteCmd_Good(t *testing.T) { func TestCollectWebsiteCmd_Bad(t *testing.T) { // Mock the website downloader to return an error oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite - website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) { + website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) { return nil, fmt.Errorf("website error") } defer func() { diff --git a/examples/all/main.go b/examples/all/main.go index 6411baa..0ef735d 100644 --- a/examples/all/main.go +++ b/examples/all/main.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "net/http" "os" "github.com/Snider/Borg/pkg/github" @@ -13,7 +14,7 @@ import ( func main() { log.Println("Collecting all repositories for a user...") - repos, err := github.NewGithubClient().GetPublicRepos(context.Background(), "Snider") + repos, err := github.NewGithubClient(http.DefaultClient).GetPublicRepos(context.Background(), "Snider") if err != nil { log.Fatalf("Failed to get public repos: %v", err) } diff --git a/examples/collect_pwa/main.go b/examples/collect_pwa/main.go index 963ba62..1d0bd7e 100644 --- a/examples/collect_pwa/main.go +++ b/examples/collect_pwa/main.go @@ -2,6 +2,7 @@ package main import ( "log" + "net/http" "os" "github.com/Snider/Borg/pkg/pwa" @@ -10,7 +11,7 @@ import ( func main() { log.Println("Collecting PWA...") - client := pwa.NewPWAClient() + client := pwa.NewPWAClient(http.DefaultClient) pwaURL := "https://squoosh.app" manifestURL, err := client.FindManifest(pwaURL) diff --git a/examples/collect_website/main.go b/examples/collect_website/main.go index 2e2f606..ba02e6d 100644 --- a/examples/collect_website/main.go +++ b/examples/collect_website/main.go @@ -2,6 +2,7 @@ package main import ( "log" + "net/http" "os" "github.com/Snider/Borg/pkg/website" @@ -11,7 +12,7 @@ func main() { log.Println("Collecting website...") // Download and package the website. - dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil) + dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil, http.DefaultClient) if err != nil { log.Fatalf("Failed to collect website: %v", err) } diff --git a/pkg/github/github.go b/pkg/github/github.go index 2e2e832..bc5a8e7 100644 --- a/pkg/github/github.go +++ b/pkg/github/github.go @@ -21,22 +21,29 @@ type GithubClient interface { } // NewGithubClient creates a new GithubClient. -func NewGithubClient() GithubClient { - return &githubClient{} +func NewGithubClient(client *http.Client) GithubClient { + return &githubClient{client: client} } -type githubClient struct{} +type githubClient struct { + client *http.Client +} // NewAuthenticatedClient creates a new authenticated http client. -var NewAuthenticatedClient = func(ctx context.Context) *http.Client { +var NewAuthenticatedClient = func(ctx context.Context, baseClient *http.Client) *http.Client { token := os.Getenv("GITHUB_TOKEN") if token == "" { - return http.DefaultClient + return baseClient } ts := oauth2.StaticTokenSource( &oauth2.Token{AccessToken: token}, ) - return oauth2.NewClient(ctx, ts) + authedClient := oauth2.NewClient(ctx, ts) + authedClient.Transport = &oauth2.Transport{ + Base: baseClient.Transport, + Source: ts, + } + return authedClient } func (g *githubClient) GetPublicRepos(ctx context.Context, userOrOrg string) ([]string, error) { @@ -44,7 +51,7 @@ func (g *githubClient) GetPublicRepos(ctx context.Context, userOrOrg string) ([] } func (g *githubClient) getPublicReposWithAPIURL(ctx context.Context, apiURL, userOrOrg string) ([]string, error) { - client := NewAuthenticatedClient(ctx) + client := NewAuthenticatedClient(ctx, g.client) var allCloneURLs []string url := fmt.Sprintf("%s/users/%s/repos", apiURL, userOrOrg) isFirstRequest := true diff --git a/pkg/github/github_test.go b/pkg/github/github_test.go index 37857bd..78d9188 100644 --- a/pkg/github/github_test.go +++ b/pkg/github/github_test.go @@ -154,7 +154,7 @@ func TestFindNextURL_Ugly(t *testing.T) { func TestNewAuthenticatedClient_Good(t *testing.T) { t.Setenv("GITHUB_TOKEN", "test-token") - client := NewAuthenticatedClient(context.Background()) + client := NewAuthenticatedClient(context.Background(), http.DefaultClient) if client == http.DefaultClient { t.Error("expected an authenticated client, but got http.DefaultClient") } @@ -163,7 +163,7 @@ func TestNewAuthenticatedClient_Good(t *testing.T) { func TestNewAuthenticatedClient_Bad(t *testing.T) { // Unset the variable to ensure it's not present t.Setenv("GITHUB_TOKEN", "") - client := NewAuthenticatedClient(context.Background()) + client := NewAuthenticatedClient(context.Background(), http.DefaultClient) if client != http.DefaultClient { t.Error("expected http.DefaultClient when no token is set, but got something else") } @@ -171,9 +171,9 @@ func TestNewAuthenticatedClient_Bad(t *testing.T) { // setupMockClient is a helper function to inject a mock http.Client. func setupMockClient(t *testing.T, mock *http.Client) *githubClient { - client := &githubClient{} + client := &githubClient{client: mock} originalNewAuthenticatedClient := NewAuthenticatedClient - NewAuthenticatedClient = func(ctx context.Context) *http.Client { + NewAuthenticatedClient = func(ctx context.Context, baseClient *http.Client) *http.Client { return mock } // Restore the original function after the test diff --git a/pkg/httpclient/httpclient.go b/pkg/httpclient/httpclient.go new file mode 100644 index 0000000..372c272 --- /dev/null +++ b/pkg/httpclient/httpclient.go @@ -0,0 +1,23 @@ +package httpclient + +import ( + "net" + "net/http" + "time" +) + +// NewClient creates a new http.Client with configurable timeouts. +func NewClient(totalTimeout, connectTimeout, tlsTimeout, headerTimeout time.Duration) *http.Client { + transport := &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: connectTimeout, + }).DialContext, + TLSHandshakeTimeout: tlsTimeout, + ResponseHeaderTimeout: headerTimeout, + } + + return &http.Client{ + Timeout: totalTimeout, + Transport: transport, + } +} diff --git a/pkg/httpclient/httpclient_test.go b/pkg/httpclient/httpclient_test.go new file mode 100644 index 0000000..1f4a085 --- /dev/null +++ b/pkg/httpclient/httpclient_test.go @@ -0,0 +1,33 @@ +package httpclient + +import ( + "net/http" + "testing" + "time" +) + +func TestNewClient(t *testing.T) { + totalTimeout := 10 * time.Second + connectTimeout := 2 * time.Second + tlsTimeout := 3 * time.Second + headerTimeout := 5 * time.Second + + client := NewClient(totalTimeout, connectTimeout, tlsTimeout, headerTimeout) + + if client.Timeout != totalTimeout { + t.Errorf("expected total timeout %v, got %v", totalTimeout, client.Timeout) + } + + transport, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("expected client transport to be *http.Transport, got %T", client.Transport) + } + + if transport.TLSHandshakeTimeout != tlsTimeout { + t.Errorf("expected TLS handshake timeout %v, got %v", tlsTimeout, transport.TLSHandshakeTimeout) + } + + if transport.ResponseHeaderTimeout != headerTimeout { + t.Errorf("expected response header timeout %v, got %v", headerTimeout, transport.ResponseHeaderTimeout) + } +} diff --git a/pkg/pwa/pwa.go b/pkg/pwa/pwa.go index ce7af06..8b9abc4 100644 --- a/pkg/pwa/pwa.go +++ b/pkg/pwa/pwa.go @@ -31,8 +31,8 @@ type PWAClient interface { } // NewPWAClient creates a new PWAClient. -func NewPWAClient() PWAClient { - return &pwaClient{client: http.DefaultClient} +func NewPWAClient(client *http.Client) PWAClient { + return &pwaClient{client: client} } type pwaClient struct { diff --git a/pkg/pwa/pwa_test.go b/pkg/pwa/pwa_test.go index 4145cd9..365bc3d 100644 --- a/pkg/pwa/pwa_test.go +++ b/pkg/pwa/pwa_test.go @@ -20,7 +20,7 @@ func TestFindManifest_Good(t *testing.T) { })) defer server.Close() - client := NewPWAClient() + client := NewPWAClient(http.DefaultClient) expectedURL := server.URL + "/manifest.json" actualURL, err := client.FindManifest(server.URL) if err != nil { @@ -43,7 +43,7 @@ func TestFindManifest_Bad(t *testing.T) { } })) defer server.Close() - client := NewPWAClient() + client := NewPWAClient(http.DefaultClient) _, err := client.FindManifest(server.URL) if err == nil { t.Fatal("expected an error, but got none") @@ -55,7 +55,7 @@ func TestFindManifest_Bad(t *testing.T) { http.Error(w, "Internal Server Error", http.StatusInternalServerError) })) defer server.Close() - client := NewPWAClient() + client := NewPWAClient(http.DefaultClient) _, err := client.FindManifest(server.URL) if err == nil { t.Fatal("expected an error for server error, but got none") @@ -70,7 +70,7 @@ func TestFindManifest_Ugly(t *testing.T) { fmt.Fprint(w, `
`) })) defer server.Close() - client := NewPWAClient() + client := NewPWAClient(http.DefaultClient) // Should find the first one expectedURL := server.URL + "/first.json" actualURL, err := client.FindManifest(server.URL) @@ -98,7 +98,7 @@ func TestFindManifest_Ugly(t *testing.T) { } })) defer server.Close() - client := NewPWAClient() + client := NewPWAClient(http.DefaultClient) expectedURL := server.URL + "/manifest.json" actualURL, err := client.FindManifest(server.URL) if err != nil { @@ -123,7 +123,7 @@ func TestFindManifest_Ugly(t *testing.T) { } })) defer server.Close() - client := NewPWAClient() + client := NewPWAClient(http.DefaultClient) expectedURL := server.URL + "/site.webmanifest" actualURL, err := client.FindManifest(server.URL) if err != nil { @@ -141,7 +141,7 @@ func TestDownloadAndPackagePWA_Good(t *testing.T) { server := newPWATestServer() defer server.Close() - client := NewPWAClient() + client := NewPWAClient(http.DefaultClient) bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard)) dn, err := client.DownloadAndPackagePWA(server.URL, server.URL+"/manifest.json", bar) if err != nil { @@ -161,7 +161,7 @@ func TestDownloadAndPackagePWA_Bad(t *testing.T) { t.Run("Bad Manifest URL", func(t *testing.T) { server := newPWATestServer() defer server.Close() - client := NewPWAClient() + client := NewPWAClient(http.DefaultClient) _, err := client.DownloadAndPackagePWA(server.URL, server.URL+"/nonexistent-manifest.json", nil) if err == nil { t.Fatal("expected an error for bad manifest url, but got none") @@ -178,7 +178,7 @@ func TestDownloadAndPackagePWA_Bad(t *testing.T) { } })) defer server.Close() - client := NewPWAClient() + client := NewPWAClient(http.DefaultClient) _, err := client.DownloadAndPackagePWA(server.URL, server.URL+"/manifest.json", nil) if err == nil { t.Fatal("expected an error for asset 404, but got none") @@ -198,7 +198,7 @@ func TestDownloadAndPackagePWA_Ugly(t *testing.T) { })) defer server.Close() - client := NewPWAClient() + client := NewPWAClient(http.DefaultClient) dn, err := client.DownloadAndPackagePWA(server.URL, server.URL+"/manifest.json", nil) if err != nil { t.Fatalf("unexpected error for manifest with no assets: %v", err) @@ -214,7 +214,7 @@ func TestDownloadAndPackagePWA_Ugly(t *testing.T) { // --- Test Cases for resolveURL --- func TestResolveURL_Good(t *testing.T) { - client := NewPWAClient().(*pwaClient) + client := NewPWAClient(http.DefaultClient).(*pwaClient) tests := []struct { base string ref string @@ -239,7 +239,7 @@ func TestResolveURL_Good(t *testing.T) { } func TestResolveURL_Bad(t *testing.T) { - client := NewPWAClient().(*pwaClient) + client := NewPWAClient(http.DefaultClient).(*pwaClient) _, err := client.resolveURL("http://^invalid.com", "foo.html") if err == nil { t.Error("expected error for malformed base URL, but got nil") @@ -249,7 +249,7 @@ func TestResolveURL_Bad(t *testing.T) { // --- Test Cases for extractAssetsFromHTML --- func TestExtractAssetsFromHTML(t *testing.T) { - client := NewPWAClient().(*pwaClient) + client := NewPWAClient(http.DefaultClient).(*pwaClient) t.Run("extracts stylesheets", func(t *testing.T) { html := []byte(``) @@ -427,7 +427,7 @@ func TestDownloadAndPackagePWA_FullManifest(t *testing.T) { })) defer server.Close() - client := NewPWAClient() + client := NewPWAClient(http.DefaultClient) dn, err := client.DownloadAndPackagePWA(server.URL, server.URL+"/manifest.json", nil) if err != nil { t.Fatalf("DownloadAndPackagePWA failed: %v", err) @@ -495,7 +495,7 @@ func TestDownloadAndPackagePWA_ServiceWorker(t *testing.T) { })) defer server.Close() - client := NewPWAClient() + client := NewPWAClient(http.DefaultClient) dn, err := client.DownloadAndPackagePWA(server.URL, server.URL+"/manifest.json", nil) if err != nil { t.Fatalf("DownloadAndPackagePWA failed: %v", err) diff --git a/pkg/website/website.go b/pkg/website/website.go index b2bd517..5fbbd86 100644 --- a/pkg/website/website.go +++ b/pkg/website/website.go @@ -26,13 +26,8 @@ type Downloader struct { errors []error } -// NewDownloader creates a new Downloader. -func NewDownloader(maxDepth int) *Downloader { - return NewDownloaderWithClient(maxDepth, http.DefaultClient) -} - -// NewDownloaderWithClient creates a new Downloader with a custom http.Client. -func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader { +// NewDownloader creates a new Downloader with a custom http.Client. +func NewDownloader(maxDepth int, client *http.Client) *Downloader { return &Downloader{ dn: datanode.New(), visited: make(map[string]bool), @@ -43,13 +38,13 @@ func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader { } // downloadAndPackageWebsite downloads a website and packages it into a DataNode. -func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) { +func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) { baseURL, err := url.Parse(startURL) if err != nil { return nil, err } - d := NewDownloader(maxDepth) + d := NewDownloader(maxDepth, client) d.baseURL = baseURL d.progressBar = bar d.crawl(startURL, 0) diff --git a/pkg/website/website_test.go b/pkg/website/website_test.go index d3685e5..aad271b 100644 --- a/pkg/website/website_test.go +++ b/pkg/website/website_test.go @@ -20,7 +20,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) { defer server.Close() bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard)) - dn, err := DownloadAndPackageWebsite(server.URL, 2, bar) + dn, err := DownloadAndPackageWebsite(server.URL, 2, bar, http.DefaultClient) if err != nil { t.Fatalf("DownloadAndPackageWebsite failed: %v", err) } @@ -52,7 +52,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) { func TestDownloadAndPackageWebsite_Bad(t *testing.T) { t.Run("Invalid Start URL", func(t *testing.T) { - _, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil) + _, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil, http.DefaultClient) if err == nil { t.Fatal("Expected an error for an invalid start URL, but got nil") } @@ -63,7 +63,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) { http.Error(w, "Internal Server Error", http.StatusInternalServerError) })) defer server.Close() - _, err := DownloadAndPackageWebsite(server.URL, 1, nil) + _, err := DownloadAndPackageWebsite(server.URL, 1, nil, http.DefaultClient) if err == nil { t.Fatal("Expected an error for a server error on the start URL, but got nil") } @@ -80,7 +80,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) { })) defer server.Close() // We expect an error because the link is broken. - dn, err := DownloadAndPackageWebsite(server.URL, 1, nil) + dn, err := DownloadAndPackageWebsite(server.URL, 1, nil, http.DefaultClient) if err == nil { t.Fatal("Expected an error for a broken link, but got nil") } @@ -99,7 +99,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) { defer server.Close() bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard)) - dn, err := DownloadAndPackageWebsite(server.URL, 1, bar) // Max depth of 1 + dn, err := DownloadAndPackageWebsite(server.URL, 1, bar, http.DefaultClient) // Max depth of 1 if err != nil { t.Fatalf("DownloadAndPackageWebsite failed: %v", err) } @@ -122,7 +122,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) { fmt.Fprint(w, `External`) })) defer server.Close() - dn, err := DownloadAndPackageWebsite(server.URL, 1, nil) + dn, err := DownloadAndPackageWebsite(server.URL, 1, nil, http.DefaultClient) if err != nil { t.Fatalf("DownloadAndPackageWebsite failed: %v", err) } @@ -156,7 +156,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) { // For now, we'll just test that it doesn't hang forever. done := make(chan bool) go func() { - _, err := DownloadAndPackageWebsite(server.URL, 1, nil) + _, err := DownloadAndPackageWebsite(server.URL, 1, nil, http.DefaultClient) if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") { // We expect a timeout error, but other errors are failures. t.Errorf("unexpected error: %v", err)