Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion cmd/all.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/Snider/Borg/pkg/compress"
"github.com/Snider/Borg/pkg/datanode"
"github.com/Snider/Borg/pkg/github"
"github.com/Snider/Borg/pkg/httpclient"
"github.com/Snider/Borg/pkg/tim"
"github.com/Snider/Borg/pkg/trix"
"github.com/Snider/Borg/pkg/ui"
Expand Down Expand Up @@ -42,7 +43,15 @@ func NewAllCmd() *cobra.Command {
return err
}

repos, err := GithubClient.GetPublicRepos(cmd.Context(), owner)
totalTimeout, _ := cmd.Flags().GetDuration("timeout")
connectTimeout, _ := cmd.Flags().GetDuration("connect-timeout")
tlsTimeout, _ := cmd.Flags().GetDuration("tls-timeout")
headerTimeout, _ := cmd.Flags().GetDuration("header-timeout")
Comment on lines +46 to +49

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

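The errors returned from cmd.Flags().GetDuration() are being ignored. Note that a malformed duration string (e.g., "--timeout=10xyz") is already rejected by the flag parser before RunE runs, so that case is reported to the user. GetDuration itself returns an error when the named flag is not defined on the command or is not a duration flag, and in that case it yields a zero duration. Discarding the error therefore lets a missing or misregistered timeout flag silently fall back to zero instead of surfacing the problem. These errors should be handled to provide better feedback.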

This same issue is present in other files where these flags are parsed, such as cmd/collect_github_repo.go, cmd/collect_github_repos.go, cmd/collect_pwa.go, and cmd/collect_website.go.

Suggested change
totalTimeout, _ := cmd.Flags().GetDuration("timeout")
connectTimeout, _ := cmd.Flags().GetDuration("connect-timeout")
tlsTimeout, _ := cmd.Flags().GetDuration("tls-timeout")
headerTimeout, _ := cmd.Flags().GetDuration("header-timeout")
totalTimeout, err := cmd.Flags().GetDuration("timeout")
if err != nil {
return fmt.Errorf("invalid value for 'timeout': %w", err)
}
connectTimeout, err := cmd.Flags().GetDuration("connect-timeout")
if err != nil {
return fmt.Errorf("invalid value for 'connect-timeout': %w", err)
}
tlsTimeout, err := cmd.Flags().GetDuration("tls-timeout")
if err != nil {
return fmt.Errorf("invalid value for 'tls-timeout': %w", err)
}
headerTimeout, err := cmd.Flags().GetDuration("header-timeout")
if err != nil {
return fmt.Errorf("invalid value for 'header-timeout': %w", err)
}


httpClient := httpclient.NewClient(totalTimeout, connectTimeout, tlsTimeout, headerTimeout)
githubClient := github.NewGithubClient(httpClient)

repos, err := githubClient.GetPublicRepos(cmd.Context(), owner)
if err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions cmd/all_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestAllCmd_Good(t *testing.T) {
},
})
oldNewAuthenticatedClient := github.NewAuthenticatedClient
github.NewAuthenticatedClient = func(ctx context.Context) *http.Client {
github.NewAuthenticatedClient = func(ctx context.Context, baseClient *http.Client) *http.Client {
return mockGithubClient
}
defer func() {
Expand Down Expand Up @@ -67,7 +67,7 @@ func TestAllCmd_Bad(t *testing.T) {
},
})
oldNewAuthenticatedClient := github.NewAuthenticatedClient
github.NewAuthenticatedClient = func(ctx context.Context) *http.Client {
github.NewAuthenticatedClient = func(ctx context.Context, baseClient *http.Client) *http.Client {
return mockGithubClient
}
defer func() {
Expand Down Expand Up @@ -96,7 +96,7 @@ func TestAllCmd_Ugly(t *testing.T) {
},
})
oldNewAuthenticatedClient := github.NewAuthenticatedClient
github.NewAuthenticatedClient = func(ctx context.Context) *http.Client {
github.NewAuthenticatedClient = func(ctx context.Context, baseClient *http.Client) *http.Client {
return mockGithubClient
}
defer func() {
Expand Down
11 changes: 10 additions & 1 deletion cmd/collect.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package cmd

import (
"time"

"github.com/spf13/cobra"
)

Expand All @@ -11,11 +13,18 @@ func init() {
RootCmd.AddCommand(GetCollectCmd())
}
func NewCollectCmd() *cobra.Command {
return &cobra.Command{
cmd := &cobra.Command{
Use: "collect",
Short: "Collect a resource from a URI.",
Long: `Collect a resource from a URI and store it in a DataNode.`,
}

cmd.PersistentFlags().Duration("timeout", 0, "Total request timeout (e.g., 60s). 0 means no timeout.")
cmd.PersistentFlags().Duration("connect-timeout", 10*time.Second, "TCP connection establishment timeout (e.g., 10s)")
cmd.PersistentFlags().Duration("tls-timeout", 10*time.Second, "TLS handshake timeout (e.g., 10s)")
cmd.PersistentFlags().Duration("header-timeout", 30*time.Second, "Time to receive response headers timeout (e.g., 30s)")

return cmd
}

func GetCollectCmd() *cobra.Command {
Expand Down
8 changes: 8 additions & 0 deletions cmd/collect_github_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os"

"github.com/Snider/Borg/pkg/compress"
"github.com/Snider/Borg/pkg/httpclient"
"github.com/Snider/Borg/pkg/tim"
"github.com/Snider/Borg/pkg/trix"
"github.com/Snider/Borg/pkg/ui"
Expand Down Expand Up @@ -44,6 +45,13 @@ func NewCollectGithubRepoCmd() *cobra.Command {
return fmt.Errorf("invalid compression: %s (must be 'none', 'gz', or 'xz')", compression)
}

totalTimeout, _ := cmd.Flags().GetDuration("timeout")
connectTimeout, _ := cmd.Flags().GetDuration("connect-timeout")
tlsTimeout, _ := cmd.Flags().GetDuration("tls-timeout")
headerTimeout, _ := cmd.Flags().GetDuration("header-timeout")

_ = httpclient.NewClient(totalTimeout, connectTimeout, tlsTimeout, headerTimeout)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The http.Client created here with the specified timeouts is immediately discarded by assigning it to the blank identifier _. As a result, the subsequent call to GitCloner.CloneGitRepository uses a default HTTP client, and the new timeout flags have no effect for this command. This is a critical bug as the feature does not work as advertised for collect repo.

To fix this, the vcs.GitCloner implementation needs to be updated to accept and use a custom http.Client. You would then instantiate a local cloner here with the httpClient you've created, rather than using the global GitCloner variable.

Suggested change
_ = httpclient.NewClient(totalTimeout, connectTimeout, tlsTimeout, headerTimeout)
httpClient := httpclient.NewClient(totalTimeout, connectTimeout, tlsTimeout, headerTimeout)
// TODO: Pass this httpClient to the GitCloner. The cloner needs to be updated to use it.


prompter := ui.NewNonInteractivePrompter(ui.GetVCSQuote)
prompter.Start()
defer prompter.Stop()
Expand Down
16 changes: 10 additions & 6 deletions cmd/collect_github_repos.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,24 @@ import (
"fmt"

"github.com/Snider/Borg/pkg/github"
"github.com/Snider/Borg/pkg/httpclient"
"github.com/spf13/cobra"
)

var (
// GithubClient is the github client used by the command. It can be replaced for testing.
GithubClient = github.NewGithubClient()
)

var collectGithubReposCmd = &cobra.Command{
Use: "repos [user-or-org]",
Short: "Collects all public repositories for a user or organization",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
repos, err := GithubClient.GetPublicRepos(cmd.Context(), args[0])
totalTimeout, _ := cmd.Flags().GetDuration("timeout")
connectTimeout, _ := cmd.Flags().GetDuration("connect-timeout")
tlsTimeout, _ := cmd.Flags().GetDuration("tls-timeout")
headerTimeout, _ := cmd.Flags().GetDuration("header-timeout")

httpClient := httpclient.NewClient(totalTimeout, connectTimeout, tlsTimeout, headerTimeout)
githubClient := github.NewGithubClient(httpClient)

repos, err := githubClient.GetPublicRepos(cmd.Context(), args[0])
if err != nil {
return err
}
Expand Down
37 changes: 19 additions & 18 deletions cmd/collect_pwa.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"os"

"github.com/Snider/Borg/pkg/compress"
"github.com/Snider/Borg/pkg/httpclient"
"github.com/Snider/Borg/pkg/pwa"
"github.com/Snider/Borg/pkg/tim"
"github.com/Snider/Borg/pkg/trix"
Expand All @@ -13,17 +14,9 @@ import (
"github.com/spf13/cobra"
)

type CollectPWACmd struct {
cobra.Command
PWAClient pwa.PWAClient
}

// NewCollectPWACmd creates a new collect pwa command
func NewCollectPWACmd() *CollectPWACmd {
c := &CollectPWACmd{
PWAClient: pwa.NewPWAClient(),
}
c.Command = cobra.Command{
func NewCollectPWACmd() *cobra.Command {
cmd := &cobra.Command{
Use: "pwa [url]",
Short: "Collect a single PWA using a URI",
Long: `Collect a single PWA and store it in a DataNode.
Expand All @@ -44,24 +37,32 @@ Examples:
compression, _ := cmd.Flags().GetString("compression")
password, _ := cmd.Flags().GetString("password")

finalPath, err := CollectPWA(c.PWAClient, pwaURL, outputFile, format, compression, password)
totalTimeout, _ := cmd.Flags().GetDuration("timeout")
connectTimeout, _ := cmd.Flags().GetDuration("connect-timeout")
tlsTimeout, _ := cmd.Flags().GetDuration("tls-timeout")
headerTimeout, _ := cmd.Flags().GetDuration("header-timeout")

httpClient := httpclient.NewClient(totalTimeout, connectTimeout, tlsTimeout, headerTimeout)
pwaClient := pwa.NewPWAClient(httpClient)

finalPath, err := CollectPWA(pwaClient, pwaURL, outputFile, format, compression, password)
if err != nil {
return err
}
fmt.Fprintln(cmd.OutOrStdout(), "PWA saved to", finalPath)
return nil
},
}
c.Flags().String("uri", "", "The URI of the PWA to collect (can also be passed as positional arg)")
c.Flags().String("output", "", "Output file for the DataNode")
c.Flags().String("format", "datanode", "Output format (datanode, tim, trix, or stim)")
c.Flags().String("compression", "none", "Compression format (none, gz, or xz)")
c.Flags().String("password", "", "Password for encryption (required for stim format)")
return c
cmd.Flags().String("uri", "", "The URI of the PWA to collect (can also be passed as positional arg)")
cmd.Flags().String("output", "", "Output file for the DataNode")
cmd.Flags().String("format", "datanode", "Output format (datanode, tim, trix, or stim)")
cmd.Flags().String("compression", "none", "Compression format (none, gz, or xz)")
cmd.Flags().String("password", "", "Password for encryption (required for stim format)")
return cmd
}

func init() {
collectCmd.AddCommand(&NewCollectPWACmd().Command)
collectCmd.AddCommand(NewCollectPWACmd())
}
func CollectPWA(client pwa.PWAClient, pwaURL string, outputFile string, format string, compression string, password string) (string, error) {
if pwaURL == "" {
Expand Down
10 changes: 9 additions & 1 deletion cmd/collect_website.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/schollz/progressbar/v3"
"github.com/Snider/Borg/pkg/compress"
"github.com/Snider/Borg/pkg/httpclient"
"github.com/Snider/Borg/pkg/tim"
"github.com/Snider/Borg/pkg/trix"
"github.com/Snider/Borg/pkg/ui"
Expand Down Expand Up @@ -51,7 +52,14 @@ func NewCollectWebsiteCmd() *cobra.Command {
bar = ui.NewProgressBar(-1, "Crawling website")
}

dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
totalTimeout, _ := cmd.Flags().GetDuration("timeout")
connectTimeout, _ := cmd.Flags().GetDuration("connect-timeout")
tlsTimeout, _ := cmd.Flags().GetDuration("tls-timeout")
headerTimeout, _ := cmd.Flags().GetDuration("header-timeout")

client := httpclient.NewClient(totalTimeout, connectTimeout, tlsTimeout, headerTimeout)

dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar, client)
if err != nil {
return fmt.Errorf("error downloading and packaging website: %w", err)
}
Expand Down
5 changes: 3 additions & 2 deletions cmd/collect_website_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cmd

import (
"fmt"
"net/http"
"path/filepath"
"strings"
"testing"
Expand All @@ -14,7 +15,7 @@ import (
func TestCollectWebsiteCmd_Good(t *testing.T) {
// Mock the website downloader
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) {
return datanode.New(), nil
}
defer func() {
Expand All @@ -35,7 +36,7 @@ func TestCollectWebsiteCmd_Good(t *testing.T) {
func TestCollectWebsiteCmd_Bad(t *testing.T) {
// Mock the website downloader to return an error
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar, client *http.Client) (*datanode.DataNode, error) {
return nil, fmt.Errorf("website error")
}
defer func() {
Expand Down
3 changes: 2 additions & 1 deletion examples/all/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log"
"net/http"
"os"

"github.com/Snider/Borg/pkg/github"
Expand All @@ -13,7 +14,7 @@ import (
func main() {
log.Println("Collecting all repositories for a user...")

repos, err := github.NewGithubClient().GetPublicRepos(context.Background(), "Snider")
repos, err := github.NewGithubClient(http.DefaultClient).GetPublicRepos(context.Background(), "Snider")
if err != nil {
log.Fatalf("Failed to get public repos: %v", err)
}
Expand Down
3 changes: 2 additions & 1 deletion examples/collect_pwa/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"log"
"net/http"
"os"

"github.com/Snider/Borg/pkg/pwa"
Expand All @@ -10,7 +11,7 @@ import (
func main() {
log.Println("Collecting PWA...")

client := pwa.NewPWAClient()
client := pwa.NewPWAClient(http.DefaultClient)
pwaURL := "https://squoosh.app"

manifestURL, err := client.FindManifest(pwaURL)
Expand Down
3 changes: 2 additions & 1 deletion examples/collect_website/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"log"
"net/http"
"os"

"github.com/Snider/Borg/pkg/website"
Expand All @@ -11,7 +12,7 @@ func main() {
log.Println("Collecting website...")

// Download and package the website.
dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil)
dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil, http.DefaultClient)
if err != nil {
log.Fatalf("Failed to collect website: %v", err)
}
Expand Down
21 changes: 14 additions & 7 deletions pkg/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,37 @@ type GithubClient interface {
}

// NewGithubClient creates a new GithubClient.
func NewGithubClient() GithubClient {
return &githubClient{}
func NewGithubClient(client *http.Client) GithubClient {
return &githubClient{client: client}
}

type githubClient struct{}
type githubClient struct {
client *http.Client
}

// NewAuthenticatedClient creates a new authenticated http client.
var NewAuthenticatedClient = func(ctx context.Context) *http.Client {
var NewAuthenticatedClient = func(ctx context.Context, baseClient *http.Client) *http.Client {
token := os.Getenv("GITHUB_TOKEN")
if token == "" {
return http.DefaultClient
return baseClient
}
ts := oauth2.StaticTokenSource(
&oauth2.Token{AccessToken: token},
)
return oauth2.NewClient(ctx, ts)
authedClient := oauth2.NewClient(ctx, ts)
authedClient.Transport = &oauth2.Transport{
Base: baseClient.Transport,
Source: ts,
}
return authedClient
}

// GetPublicRepos looks up the public repositories of userOrOrg against the
// production GitHub API endpoint. It is a thin wrapper that pins the API base
// URL and hands the request to getPublicReposWithAPIURL, returning that
// call's results and error unchanged.
func (g *githubClient) GetPublicRepos(ctx context.Context, userOrOrg string) ([]string, error) {
	// The base URL is fixed here; the lower-level method takes it as a parameter.
	const apiBaseURL = "https://api.github.com"

	return g.getPublicReposWithAPIURL(ctx, apiBaseURL, userOrOrg)
}

func (g *githubClient) getPublicReposWithAPIURL(ctx context.Context, apiURL, userOrOrg string) ([]string, error) {
client := NewAuthenticatedClient(ctx)
client := NewAuthenticatedClient(ctx, g.client)
var allCloneURLs []string
url := fmt.Sprintf("%s/users/%s/repos", apiURL, userOrOrg)
isFirstRequest := true
Expand Down
8 changes: 4 additions & 4 deletions pkg/github/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func TestFindNextURL_Ugly(t *testing.T) {

func TestNewAuthenticatedClient_Good(t *testing.T) {
t.Setenv("GITHUB_TOKEN", "test-token")
client := NewAuthenticatedClient(context.Background())
client := NewAuthenticatedClient(context.Background(), http.DefaultClient)
if client == http.DefaultClient {
t.Error("expected an authenticated client, but got http.DefaultClient")
}
Expand All @@ -163,17 +163,17 @@ func TestNewAuthenticatedClient_Good(t *testing.T) {
func TestNewAuthenticatedClient_Bad(t *testing.T) {
// Unset the variable to ensure it's not present
t.Setenv("GITHUB_TOKEN", "")
client := NewAuthenticatedClient(context.Background())
client := NewAuthenticatedClient(context.Background(), http.DefaultClient)
if client != http.DefaultClient {
t.Error("expected http.DefaultClient when no token is set, but got something else")
}
}

// setupMockClient is a helper function to inject a mock http.Client.
func setupMockClient(t *testing.T, mock *http.Client) *githubClient {
client := &githubClient{}
client := &githubClient{client: mock}
originalNewAuthenticatedClient := NewAuthenticatedClient
NewAuthenticatedClient = func(ctx context.Context) *http.Client {
NewAuthenticatedClient = func(ctx context.Context, baseClient *http.Client) *http.Client {
return mock
}
// Restore the original function after the test
Expand Down
Loading
Loading