diff --git a/cmd/collect_github_repos.go b/cmd/collect_github_repos.go
index dfcd315..cdb78d9 100644
--- a/cmd/collect_github_repos.go
+++ b/cmd/collect_github_repos.go
@@ -2,8 +2,14 @@ package cmd
 
 import (
 	"fmt"
+	"os"
 
+	"github.com/Snider/Borg/pkg/compress"
 	"github.com/Snider/Borg/pkg/github"
+	"github.com/Snider/Borg/pkg/tim"
+	"github.com/Snider/Borg/pkg/trix"
+	"github.com/Snider/Borg/pkg/ui"
+	"github.com/schollz/progressbar/v3"
 	"github.com/spf13/cobra"
 )
@@ -17,17 +23,80 @@ var collectGithubReposCmd = &cobra.Command{
 	Short: "Collects all public repositories for a user or organization",
 	Args:  cobra.ExactArgs(1),
 	RunE: func(cmd *cobra.Command, args []string) error {
+		parallel, _ := cmd.Flags().GetInt("parallel")
+		outputFile, _ := cmd.Flags().GetString("output")
+		format, _ := cmd.Flags().GetString("format")
+		compression, _ := cmd.Flags().GetString("compression")
+		password, _ := cmd.Flags().GetString("password")
+
 		repos, err := GithubClient.GetPublicRepos(cmd.Context(), args[0])
 		if err != nil {
 			return err
 		}
-		for _, repo := range repos {
-			fmt.Fprintln(cmd.OutOrStdout(), repo)
+
+		prompter := ui.NewNonInteractivePrompter(ui.GetVCSQuote)
+		prompter.Start()
+		defer prompter.Stop()
+		var bar *progressbar.ProgressBar
+		if prompter.IsInteractive() {
+			bar = ui.NewProgressBar(len(repos), "Cloning repositories")
 		}
+
+		downloader := github.NewDownloader(parallel, bar)
+		dn, err := downloader.DownloadRepositories(cmd.Context(), repos)
+		if err != nil {
+			return err
+		}
+
+		var data []byte
+		if format == "tim" {
+			tim, err := tim.FromDataNode(dn)
+			if err != nil {
+				return fmt.Errorf("error creating tim: %w", err)
+			}
+			data, err = tim.ToTar()
+			if err != nil {
+				return fmt.Errorf("error serializing tim: %w", err)
+			}
+		} else if format == "trix" {
+			data, err = trix.ToTrix(dn, password)
+			if err != nil {
+				return fmt.Errorf("error serializing trix: %w", err)
+			}
+		} else {
+			data, err = dn.ToTar()
+			if err != nil {
+				return fmt.Errorf("error serializing DataNode: %w", err)
+			}
+		}
+
+		compressedData, err := compress.Compress(data, compression)
+		if err != nil {
+			return fmt.Errorf("error compressing data: %w", err)
+		}
+
+		if outputFile == "" {
+			outputFile = args[0] + "." + format
+			if compression != "none" {
+				outputFile += "." + compression
+			}
+		}
+
+		err = os.WriteFile(outputFile, compressedData, 0644)
+		if err != nil {
+			return fmt.Errorf("error writing repos to file: %w", err)
+		}
+
+		fmt.Fprintln(cmd.OutOrStdout(), "Repositories saved to", outputFile)
 		return nil
 	},
 }
 
 func init() {
 	collectGithubCmd.AddCommand(collectGithubReposCmd)
+	collectGithubReposCmd.PersistentFlags().Int("parallel", 1, "Number of concurrent workers")
+	collectGithubReposCmd.PersistentFlags().String("output", "", "Output file for the DataNode")
+	collectGithubReposCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)")
+	collectGithubReposCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
+	collectGithubReposCmd.PersistentFlags().String("password", "", "Password for encryption")
 }
diff --git a/cmd/collect_website.go b/cmd/collect_website.go
index 3811f32..c908fd4 100644
--- a/cmd/collect_website.go
+++ b/cmd/collect_website.go
@@ -35,6 +35,8 @@ func NewCollectWebsiteCmd() *cobra.Command {
 			websiteURL := args[0]
 			outputFile, _ := cmd.Flags().GetString("output")
 			depth, _ := cmd.Flags().GetInt("depth")
+			parallel, _ := cmd.Flags().GetInt("parallel")
+			rateLimit, _ := cmd.Flags().GetFloat64("rate-limit")
 			format, _ := cmd.Flags().GetString("format")
 			compression, _ := cmd.Flags().GetString("compression")
 			password, _ := cmd.Flags().GetString("password")
@@ -51,7 +53,7 @@ func NewCollectWebsiteCmd() *cobra.Command {
 				bar = ui.NewProgressBar(-1, "Crawling website")
 			}
 
-			dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
+			dn, err := website.DownloadAndPackageWebsite(cmd.Context(), websiteURL, depth, parallel, rateLimit, bar)
 			if err != nil {
 				return fmt.Errorf("error downloading and packaging website: %w", err)
 			}
@@ -101,6 +103,8 @@ func NewCollectWebsiteCmd() *cobra.Command {
 	}
 	collectWebsiteCmd.PersistentFlags().String("output", "", "Output file for the DataNode")
 	collectWebsiteCmd.PersistentFlags().Int("depth", 2, "Recursion depth for downloading")
+	collectWebsiteCmd.PersistentFlags().Int("parallel", 1, "Number of concurrent workers")
+	collectWebsiteCmd.PersistentFlags().Float64("rate-limit", 0, "Max requests per second per domain")
 	collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)")
 	collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
 	collectWebsiteCmd.PersistentFlags().String("password", "", "Password for encryption")
diff --git a/cmd/collect_website_test.go b/cmd/collect_website_test.go
index 2c39674..edc759c 100644
--- a/cmd/collect_website_test.go
+++ b/cmd/collect_website_test.go
@@ -11,10 +11,14 @@ import (
 	"github.com/schollz/progressbar/v3"
 )
 
+import (
+	"context"
+)
+
 func TestCollectWebsiteCmd_Good(t *testing.T) {
 	// Mock the website downloader
 	oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
-	website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
+	website.DownloadAndPackageWebsite = func(ctx context.Context, startURL string, maxDepth, parallel int, rateLimit float64, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
 		return datanode.New(), nil
 	}
 	defer func() {
@@ -35,7 +39,7 @@ func TestCollectWebsiteCmd_Good(t *testing.T) {
 func TestCollectWebsiteCmd_Bad(t *testing.T) {
 	// Mock the website downloader to return an error
 	oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
-	website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
+	website.DownloadAndPackageWebsite = func(ctx context.Context, startURL string, maxDepth, parallel int, rateLimit float64, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
 		return nil, fmt.Errorf("website error")
 	}
 	defer func() {
diff --git a/examples/collect_website/main.go b/examples/collect_website/main.go
index 2e2f606..23d1df4 100644
--- a/examples/collect_website/main.go
+++ b/examples/collect_website/main.go
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"context"
 	"log"
 	"os"
 
@@ -11,7 +12,7 @@ func main() {
 	log.Println("Collecting website...")
 
 	// Download and package the website.
-	dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil)
+	dn, err := website.DownloadAndPackageWebsite(context.Background(), "https://example.com", 2, 1, 0, nil)
 	if err != nil {
 		log.Fatalf("Failed to collect website: %v", err)
 	}
diff --git a/go.mod b/go.mod
index d1c5f08..f9b3ff8 100644
--- a/go.mod
+++ b/go.mod
@@ -64,5 +64,6 @@ require (
 	golang.org/x/sys v0.38.0 // indirect
 	golang.org/x/term v0.37.0 // indirect
 	golang.org/x/text v0.31.0 // indirect
+	golang.org/x/time v0.8.0 // indirect
 	gopkg.in/warnings.v0 v0.1.2 // indirect
 )
diff --git a/go.sum b/go.sum
index 2a41157..95ab1a1 100644
--- a/go.sum
+++ b/go.sum
@@ -192,6 +192,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
 golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
+golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg=
+golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
diff --git a/pkg/github/downloader.go b/pkg/github/downloader.go
new file mode 100644
index 0000000..e5462e3
--- /dev/null
+++ b/pkg/github/downloader.go
@@ -0,0 +1,128 @@
+package github
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"io/fs"
+	"net/url"
+	"strings"
+	"sync"
+
+	"github.com/Snider/Borg/pkg/datanode"
+	"github.com/Snider/Borg/pkg/vcs"
+	"github.com/schollz/progressbar/v3"
+)
+
+// Downloader manages a pool of workers for cloning repositories.
+type Downloader struct {
+	parallel int
+	bar      *progressbar.ProgressBar
+	cloner   vcs.GitCloner
+}
+
+// NewDownloader creates a new Downloader.
+func NewDownloader(parallel int, bar *progressbar.ProgressBar) *Downloader {
+	return &Downloader{
+		parallel: parallel,
+		bar:      bar,
+		cloner:   vcs.NewGitCloner(),
+	}
+}
+
+// DownloadRepositories downloads a list of repositories in parallel.
+func (d *Downloader) DownloadRepositories(ctx context.Context, repos []string) (*datanode.DataNode, error) {
+	var wg sync.WaitGroup
+	repoChan := make(chan string, len(repos))
+	errChan := make(chan error, len(repos))
+	mergedDN := datanode.New()
+	var mu sync.Mutex
+
+	for i := 0; i < d.parallel; i++ {
+		wg.Add(1)
+		go d.worker(ctx, &wg, repoChan, mergedDN, &mu, errChan)
+	}
+
+	for _, repo := range repos {
+		select {
+		case repoChan <- repo:
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		}
+	}
+	close(repoChan)
+
+	wg.Wait()
+	close(errChan)
+
+	var errs []error
+	for err := range errChan {
+		errs = append(errs, err)
+	}
+	if len(errs) > 0 {
+		return nil, fmt.Errorf("errors cloning repositories: %v", errs)
+	}
+
+	return mergedDN, nil
+}
+
+func (d *Downloader) worker(ctx context.Context, wg *sync.WaitGroup, repoChan <-chan string, mergedDN *datanode.DataNode, mu *sync.Mutex, errChan chan<- error) {
+	defer wg.Done()
+	for repoURL := range repoChan {
+		select {
+		case <-ctx.Done():
+			return
+		default:
+		}
+
+		repoName, err := GetRepoNameFromURL(repoURL)
+		if err != nil {
+			errChan <- err
+			continue
+		}
+
+		dn, err := d.cloner.CloneGitRepository(repoURL, nil)
+		if err != nil {
+			errChan <- fmt.Errorf("error cloning %s: %w", repoURL, err)
+			continue
+		}
+
+		err = dn.Walk(".", func(path string, de fs.DirEntry, err error) error {
+			if err != nil {
+				return err
+			}
+			if !de.IsDir() {
+				file, err := dn.Open(path)
+				if err != nil {
+					return err
+				}
+				defer file.Close()
+				content, err := io.ReadAll(file)
+				if err != nil {
+					return err
+				}
+				mu.Lock()
+				mergedDN.AddData(fmt.Sprintf("%s/%s", repoName, path), content)
+				mu.Unlock()
+			}
+			return nil
+		})
+		if err != nil {
+			errChan <- err
+		}
+
+		if d.bar != nil {
+			d.bar.Add(1)
+		}
+	}
+}
+
+// GetRepoNameFromURL extracts the repository name from a Git URL.
+func GetRepoNameFromURL(repoURL string) (string, error) {
+	u, err := url.Parse(repoURL)
+	if err != nil {
+		return "", err
+	}
+	path := strings.TrimSuffix(u.Path, ".git")
+	return strings.TrimPrefix(path, "/"), nil
+}
diff --git a/pkg/website/website.go b/pkg/website/website.go
index b2bd517..59922c4 100644
--- a/pkg/website/website.go
+++ b/pkg/website/website.go
@@ -1,6 +1,7 @@
 package website
 
 import (
+	"context"
 	"fmt"
 	"io"
 	"net/http"
@@ -9,8 +10,10 @@ import (
 
 	"github.com/Snider/Borg/pkg/datanode"
 	"github.com/schollz/progressbar/v3"
 	"golang.org/x/net/html"
+	"golang.org/x/time/rate"
+	"sync"
 )
 
 var DownloadAndPackageWebsite = downloadAndPackageWebsite
@@ -21,38 +23,51 @@ type Downloader struct {
 	dn          *datanode.DataNode
 	visited     map[string]bool
 	maxDepth    int
+	parallel    int
 	progressBar *progressbar.ProgressBar
 	client      *http.Client
 	errors      []error
+	mu          sync.Mutex
+	limiter     *rate.Limiter
 }
 
 // NewDownloader creates a new Downloader.
-func NewDownloader(maxDepth int) *Downloader {
-	return NewDownloaderWithClient(maxDepth, http.DefaultClient)
+func NewDownloader(maxDepth, parallel int, rateLimit float64) *Downloader {
+	return NewDownloaderWithClient(maxDepth, parallel, rateLimit, http.DefaultClient)
 }
 
 // NewDownloaderWithClient creates a new Downloader with a custom http.Client.
-func NewDownloaderWithClient(maxDepth int, client *http.Client) *Downloader {
+func NewDownloaderWithClient(maxDepth, parallel int, rateLimit float64, client *http.Client) *Downloader {
+	var limiter *rate.Limiter
+	if rateLimit > 0 {
+		limiter = rate.NewLimiter(rate.Limit(rateLimit), 1)
+	}
 	return &Downloader{
 		dn:       datanode.New(),
 		visited:  make(map[string]bool),
 		maxDepth: maxDepth,
+		parallel: parallel,
 		client:   client,
 		errors:   make([]error, 0),
+		limiter:  limiter,
 	}
 }
 
 // downloadAndPackageWebsite downloads a website and packages it into a DataNode.
-func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
+func downloadAndPackageWebsite(ctx context.Context, startURL string, maxDepth, parallel int, rateLimit float64, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
 	baseURL, err := url.Parse(startURL)
 	if err != nil {
 		return nil, err
 	}
-	d := NewDownloader(maxDepth)
+	d := NewDownloader(maxDepth, parallel, rateLimit)
 	d.baseURL = baseURL
 	d.progressBar = bar
-	d.crawl(startURL, 0)
+	d.crawl(ctx, startURL)
+
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
 
 	if len(d.errors) > 0 {
 		var errs []string
@@ -65,102 +80,136 @@ func downloadAndPackageWebsite(startURL string, maxDepth int, bar *progressbar.P
-func (d *Downloader) crawl(pageURL string, depth int) {
-	if depth > d.maxDepth || d.visited[pageURL] {
-		return
-	}
-	d.visited[pageURL] = true
-	if d.progressBar != nil {
-		d.progressBar.Add(1)
-	}
+type crawlJob struct {
+	url   string
+	depth int
+}
 
-	resp, err := d.client.Get(pageURL)
-	if err != nil {
-		d.errors = append(d.errors, fmt.Errorf("Error getting %s: %w", pageURL, err))
-		return
-	}
-	defer resp.Body.Close()
+func (d *Downloader) crawl(ctx context.Context, startURL string) {
+	var wg sync.WaitGroup
+	var jobWg sync.WaitGroup
+	jobChan := make(chan crawlJob, 100)
 
-	if resp.StatusCode >= 400 {
-		d.errors = append(d.errors, fmt.Errorf("bad status for %s: %s", pageURL, resp.Status))
-		return
+	for i := 0; i < d.parallel; i++ {
+		wg.Add(1)
+		go d.worker(ctx, &wg, &jobWg, jobChan)
 	}
 
-	body, err := io.ReadAll(resp.Body)
-	if err != nil {
-		d.errors = append(d.errors, fmt.Errorf("Error reading body of %s: %w", pageURL, err))
-		return
-	}
+	jobWg.Add(1)
+	jobChan <- crawlJob{url: startURL, depth: 0}
 
-	relPath := d.getRelativePath(pageURL)
-	d.dn.AddData(relPath, body)
+	go func() {
+		jobWg.Wait()
+		close(jobChan)
+	}()
 
-	// Don't try to parse non-html content
-	if !strings.HasPrefix(resp.Header.Get("Content-Type"), "text/html") {
-		return
-	}
+	wg.Wait()
+}
 
-	doc, err := html.Parse(strings.NewReader(string(body)))
-	if err != nil {
-		d.errors = append(d.errors, fmt.Errorf("Error parsing HTML of %s: %w", pageURL, err))
-		return
-	}
+func (d *Downloader) worker(ctx context.Context, wg *sync.WaitGroup, jobWg *sync.WaitGroup, jobChan chan crawlJob) {
+	defer wg.Done()
+	for job := range jobChan {
+		func() {
+			defer jobWg.Done()
 
-	var f func(*html.Node)
-	f = func(n *html.Node) {
-		if n.Type == html.ElementNode {
-			for _, a := range n.Attr {
-				if a.Key == "href" || a.Key == "src" {
-					link, err := d.resolveURL(pageURL, a.Val)
-					if err != nil {
-						continue
-					}
-					if d.isLocal(link) {
-						if isAsset(link) {
-							d.downloadAsset(link)
-						} else {
-							d.crawl(link, depth+1)
-						}
-					}
-				}
+			select {
+			case <-ctx.Done():
+				return
+			default:
 			}
-		}
-		for c := n.FirstChild; c != nil; c = c.NextSibling {
-			f(c)
-		}
-	}
-	f(doc)
-}
 
-func (d *Downloader) downloadAsset(assetURL string) {
-	if d.visited[assetURL] {
-		return
-	}
-	d.visited[assetURL] = true
-	if d.progressBar != nil {
-		d.progressBar.Add(1)
-	}
+			if job.depth > d.maxDepth {
+				return
+			}
 
-	resp, err := d.client.Get(assetURL)
-	if err != nil {
-		d.errors = append(d.errors, fmt.Errorf("Error getting asset %s: %w", assetURL, err))
-		return
-	}
-	defer resp.Body.Close()
+			d.mu.Lock()
+			if d.visited[job.url] {
+				d.mu.Unlock()
+				return
+			}
+			d.visited[job.url] = true
+			d.mu.Unlock()
 
-	if resp.StatusCode >= 400 {
-		d.errors = append(d.errors, fmt.Errorf("bad status for asset %s: %s", assetURL, resp.Status))
-		return
-	}
+			if d.progressBar != nil {
+				d.progressBar.Add(1)
+			}
 
-	body, err := io.ReadAll(resp.Body)
-	if err != nil {
-		d.errors = append(d.errors, fmt.Errorf("Error reading body of asset %s: %w", assetURL, err))
-		return
+			if d.limiter != nil {
+				d.limiter.Wait(ctx)
+			}
+
+			req, err := http.NewRequestWithContext(ctx, "GET", job.url, nil)
+			if err != nil {
+				d.addError(fmt.Errorf("Error creating request for %s: %w", job.url, err))
+				return
+			}
+			resp, err := d.client.Do(req)
+			if err != nil {
+				d.addError(fmt.Errorf("Error getting %s: %w", job.url, err))
+				return
+			}
+			defer resp.Body.Close()
+
+			if resp.StatusCode >= 400 {
+				d.addError(fmt.Errorf("bad status for %s: %s", job.url, resp.Status))
+				return
+			}
+
+			body, err := io.ReadAll(resp.Body)
+			if err != nil {
+				d.addError(fmt.Errorf("Error reading body of %s: %w", job.url, err))
+				return
+			}
+
+			relPath := d.getRelativePath(job.url)
+			d.mu.Lock()
+			d.dn.AddData(relPath, body)
+			d.mu.Unlock()
+
+			if !strings.HasPrefix(resp.Header.Get("Content-Type"), "text/html") {
+				return
+			}
+
+			doc, err := html.Parse(strings.NewReader(string(body)))
+			if err != nil {
+				d.addError(fmt.Errorf("Error parsing HTML of %s: %w", job.url, err))
+				return
+			}
+
+			var f func(*html.Node)
+			f = func(n *html.Node) {
+				if n.Type == html.ElementNode {
+					for _, a := range n.Attr {
+						if a.Key == "href" || a.Key == "src" {
+							link, err := d.resolveURL(job.url, a.Val)
+							if err != nil {
+								continue
+							}
+							if d.isLocal(link) {
+								select {
+								case <-ctx.Done():
+									return
+								default:
+									jobWg.Add(1)
+									jobChan <- crawlJob{url: link, depth: job.depth + 1}
+								}
+							}
+						}
+					}
+				}
+				for c := n.FirstChild; c != nil; c = c.NextSibling {
+					f(c)
+				}
+			}
+			f(doc)
+		}()
 	}
+}
 
-	relPath := d.getRelativePath(assetURL)
-	d.dn.AddData(relPath, body)
+func (d *Downloader) addError(err error) {
+	d.mu.Lock()
+	d.errors = append(d.errors, err)
+	d.mu.Unlock()
 }
 
 func (d *Downloader) getRelativePath(pageURL string) string {
diff --git a/pkg/website/website_test.go b/pkg/website/website_test.go
index d3685e5..bd71c77 100644
--- a/pkg/website/website_test.go
+++ b/pkg/website/website_test.go
@@ -1,6 +1,7 @@
 package website
 
 import (
+	"context"
 	"fmt"
 	"io"
 	"io/fs"
@@ -20,7 +21,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
 	defer server.Close()
 
 	bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
-	dn, err := DownloadAndPackageWebsite(server.URL, 2, bar)
+	dn, err := DownloadAndPackageWebsite(context.TODO(), server.URL, 2, 1, 0, bar)
 	if err != nil {
 		t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
 	}
@@ -52,7 +53,7 @@ func TestDownloadAndPackageWebsite_Good(t *testing.T) {
 
 func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
 	t.Run("Invalid Start URL", func(t *testing.T) {
-		_, err := DownloadAndPackageWebsite("http://invalid-url", 1, nil)
+		_, err := DownloadAndPackageWebsite(context.TODO(), "http://invalid-url", 1, 1, 0, nil)
 		if err == nil {
 			t.Fatal("Expected an error for an invalid start URL, but got nil")
 		}
@@ -63,7 +64,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
 			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
 		}))
 		defer server.Close()
-		_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+		_, err := DownloadAndPackageWebsite(context.TODO(), server.URL, 1, 1, 0, nil)
 		if err == nil {
 			t.Fatal("Expected an error for a server error on the start URL, but got nil")
 		}
@@ -80,7 +81,7 @@ func TestDownloadAndPackageWebsite_Bad(t *testing.T) {
 		}))
 		defer server.Close()
 		// We expect an error because the link is broken.
-		dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+		dn, err := DownloadAndPackageWebsite(context.TODO(), server.URL, 1, 1, 0, nil)
 		if err == nil {
 			t.Fatal("Expected an error for a broken link, but got nil")
 		}
@@ -99,7 +100,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
 	defer server.Close()
 
 	bar := progressbar.NewOptions(1, progressbar.OptionSetWriter(io.Discard))
-	dn, err := DownloadAndPackageWebsite(server.URL, 1, bar) // Max depth of 1
+	dn, err := DownloadAndPackageWebsite(context.TODO(), server.URL, 1, 1, 0, bar) // Max depth of 1
 	if err != nil {
 		t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
 	}
@@ -122,7 +123,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
 			fmt.Fprint(w, `External`)
 		}))
 		defer server.Close()
-		dn, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+		dn, err := DownloadAndPackageWebsite(context.TODO(), server.URL, 1, 1, 0, nil)
 		if err != nil {
 			t.Fatalf("DownloadAndPackageWebsite failed: %v", err)
 		}
@@ -156,7 +157,7 @@ func TestDownloadAndPackageWebsite_Ugly(t *testing.T) {
 		// For now, we'll just test that it doesn't hang forever.
 		done := make(chan bool)
 		go func() {
-			_, err := DownloadAndPackageWebsite(server.URL, 1, nil)
+			_, err := DownloadAndPackageWebsite(context.TODO(), server.URL, 1, 1, 0, nil)
 			if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") {
 				// We expect a timeout error, but other errors are failures.
 				t.Errorf("unexpected error: %v", err)