diff --git a/cmd/cli/commands/search.go b/cmd/cli/commands/search.go index 6c8962f72..5c78204d2 100644 --- a/cmd/cli/commands/search.go +++ b/cmd/cli/commands/search.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" + "github.com/docker/go-units" "github.com/docker/model-runner/cmd/cli/commands/formatter" "github.com/docker/model-runner/cmd/cli/search" "github.com/spf13/cobra" @@ -94,7 +95,7 @@ Examples: func prettyPrintSearchResults(results []search.SearchResult) string { var buf bytes.Buffer table := newTable(&buf) - table.Header([]string{"NAME", "DESCRIPTION", "BACKEND", "DOWNLOADS", "STARS", "SOURCE"}) + table.Header([]string{"NAME", "DESCRIPTION", "BACKEND", "SIZE", "DOWNLOADS", "STARS", "SOURCE"}) for _, r := range results { name := r.Name @@ -105,6 +106,7 @@ func prettyPrintSearchResults(results []search.SearchResult) string { name, r.Description, r.Backend, + formatSize(r.Size), formatCount(r.Downloads), formatCount(r.Stars), r.Source, @@ -125,3 +127,11 @@ func formatCount(n int64) string { } return fmt.Sprintf("%d", n) } + +// formatSize formats a byte count as a human-readable size string +func formatSize(n int64) string { + if n <= 0 { + return "n/a" + } + return units.CustomSize("%.2f%s", float64(n), 1000.0, []string{"B", "kB", "MB", "GB", "TB"}) +} diff --git a/cmd/cli/commands/search_test.go b/cmd/cli/commands/search_test.go new file mode 100644 index 000000000..a38157d69 --- /dev/null +++ b/cmd/cli/commands/search_test.go @@ -0,0 +1,51 @@ +package commands + +import ( + "testing" +) + +func TestFormatSize(t *testing.T) { + tests := []struct { + name string + input int64 + want string + }{ + {name: "zero returns n/a", input: 0, want: "n/a"}, + {name: "negative returns n/a", input: -1, want: "n/a"}, + {name: "bytes", input: 500, want: "500.00B"}, + {name: "kilobytes", input: 1500, want: "1.50kB"}, + {name: "megabytes", input: 2_500_000, want: "2.50MB"}, + {name: "gigabytes", input: 4_300_000_000, want: "4.30GB"}, + {name: "terabytes", input: 1_200_000_000_000, want: "1.20TB"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := formatSize(tt.input); got != tt.want { + t.Errorf("formatSize(%d) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestFormatCount(t *testing.T) { + tests := []struct { + name string + input int64 + want string + }{ + {name: "zero", input: 0, want: "0"}, + {name: "hundreds", input: 999, want: "999"}, + {name: "thousands", input: 1_000, want: "1.0K"}, + {name: "thousands with decimal", input: 45_600, want: "45.6K"}, + {name: "millions", input: 1_200_000, want: "1.2M"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := formatCount(tt.input); got != tt.want { + t.Errorf("formatCount(%d) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} diff --git a/cmd/cli/search/backend_resolution.go b/cmd/cli/search/backend_resolution.go index bc0b30298..3e7075f58 100644 --- a/cmd/cli/search/backend_resolution.go +++ b/cmd/cli/search/backend_resolution.go @@ -24,7 +24,7 @@ const ( ) type backendResolver interface { - Resolve(ctx context.Context, target string) (string, error) + Resolve(ctx context.Context, target string) (backend string, size int64, err error) } type registryBackendResolver struct { @@ -38,36 +38,42 @@ func newRegistryBackendResolver() *registryBackendResolver { } } -func (r *registryBackendResolver) Resolve(ctx context.Context, target string) (string, error) { +func (r *registryBackendResolver) Resolve(ctx context.Context, target string) (string, int64, error) { model, err := r.lookup(ctx, withDefaultTag(target)) if err != nil { - return backendUnknown, err + return backendUnknown, 0, err } + backend := backendUnknown config, configErr := model.Config() if configErr == nil { - if backend := backendFromFormat(config.GetFormat()); backend != backendUnknown { - return backend, nil - } + backend = backendFromFormat(config.GetFormat()) } manifest, manifestErr := model.Manifest() if manifestErr != nil { if configErr != nil { - return backendUnknown, errors.Join(configErr, manifestErr) + return backendUnknown, 0, errors.Join(configErr, manifestErr) } - return backendUnknown, manifestErr + return backend, 0, manifestErr + } + + if backend == backendUnknown { + backend = backendFromManifestLayers(manifest) } - if backend := backendFromManifestLayers(manifest); backend != backendUnknown { - return backend, nil + var totalSize int64 + if manifest != nil { + for _, layer := range manifest.Layers { + totalSize += layer.Size + } } - if configErr != nil { - return backendUnknown, configErr + if backend == backendUnknown && configErr != nil { + return backendUnknown, totalSize, configErr } - return backendUnknown, nil + return backend, totalSize, nil } type huggingFaceRepoBackendResolver struct { @@ -81,12 +87,12 @@ func newHuggingFaceRepoBackendResolver() *huggingFaceRepoBackendResolver { } } -func (r *huggingFaceRepoBackendResolver) Resolve(ctx context.Context, target string) (string, error) { +func (r *huggingFaceRepoBackendResolver) Resolve(ctx context.Context, target string) (string, int64, error) { repoFiles, err := r.listFiles(ctx, target, "main") if err != nil { - return backendUnknown, err + return backendUnknown, 0, err } - return backendFromRepoFiles(repoFiles), nil + return backendFromRepoFiles(repoFiles), distributionhf.TotalSize(repoFiles), nil } func backendFromFormat(format disttypes.Format) string { @@ -152,7 +158,7 @@ func resolveSearchResultBackends( ctx context.Context, results []SearchResult, resolveConcurrency int, - resolve func(context.Context, SearchResult) (string, error), + resolve func(context.Context, SearchResult) (string, int64, error), ) []SearchResult { if len(results) == 0 { return results @@ -168,12 +174,13 @@ func resolveSearchResultBackends( for i := range resolved { group.Go(func() error { - backend, err := resolve(workerCtx, resolved[i]) + backend, size, err := resolve(workerCtx, resolved[i]) if err != nil || backend == "" { resolved[i].Backend = backendUnknown return nil } resolved[i].Backend = backend + resolved[i].Size = size return nil }) } diff --git a/cmd/cli/search/backend_resolution_test.go b/cmd/cli/search/backend_resolution_test.go index 9f0f6c032..4b8b6ba9e 100644 --- a/cmd/cli/search/backend_resolution_test.go +++ b/cmd/cli/search/backend_resolution_test.go @@ -180,13 +180,13 @@ func TestResolveSearchResultBackendsConcurrent(t *testing.T) { } } - resolve := func(_ context.Context, result SearchResult) (string, error) { + resolve := func(_ context.Context, result SearchResult) (string, int64, error) { for i, r := range results { if r.Name == result.Name { - return wantBackends[i], nil + return wantBackends[i], 0, nil } } - return backendUnknown, nil + return backendUnknown, 0, nil } resolved := resolveSearchResultBackends(t.Context(), results, numResults, resolve) diff --git a/cmd/cli/search/dockerhub.go b/cmd/cli/search/dockerhub.go index e9ca14112..ff178cc87 100644 --- a/cmd/cli/search/dockerhub.go +++ b/cmd/cli/search/dockerhub.go @@ -156,9 +156,9 @@ func (c *DockerHubClient) Search(ctx context.Context, opts SearchOptions) ([]Sea nextURL = response.Next } - return resolveSearchResultBackends(ctx, results, c.resolveConcurrency, func(ctx context.Context, result SearchResult) (string, error) { + return resolveSearchResultBackends(ctx, results, c.resolveConcurrency, func(ctx context.Context, result SearchResult) (string, int64, error) { if c.backendResolver == nil { - return backendUnknown, nil + return backendUnknown, 0, nil } return c.backendResolver.Resolve(ctx, result.Name) }), nil diff --git a/cmd/cli/search/dockerhub_test.go b/cmd/cli/search/dockerhub_test.go index 462548b26..238879b85 100644 --- a/cmd/cli/search/dockerhub_test.go +++ b/cmd/cli/search/dockerhub_test.go @@ -14,14 +14,14 @@ type fakeBackendResolver struct { errs map[string]error } -func (f fakeBackendResolver) Resolve(_ context.Context, target string) (string, error) { +func (f fakeBackendResolver) Resolve(_ context.Context, target string) (string, int64, error) { if err, ok := f.errs[target]; ok { - return backendUnknown, err + return backendUnknown, 0, err } if backend, ok := f.backends[target]; ok { - return backend, nil + return backend, 0, nil } - return backendUnknown, nil + return backendUnknown, 0, nil } func TestDockerHubSearchUsesVerifiedBackend(t *testing.T) { diff --git a/cmd/cli/search/huggingface.go b/cmd/cli/search/huggingface.go index a1a44d47f..7d5d5cd2d 100644 --- a/cmd/cli/search/huggingface.go +++ b/cmd/cli/search/huggingface.go @@ -125,9 +125,9 @@ func (c *HuggingFaceClient) Search(ctx context.Context, opts SearchOptions) ([]S }) } - return resolveSearchResultBackends(ctx, results, c.resolveConcurrency, func(ctx context.Context, result SearchResult) (string, error) { + return resolveSearchResultBackends(ctx, results, c.resolveConcurrency, func(ctx context.Context, result SearchResult) (string, int64, error) { if c.backendResolver == nil { - return backendUnknown, nil + return backendUnknown, 0, nil } return c.backendResolver.Resolve(ctx, result.Name) }), nil diff --git a/cmd/cli/search/types.go b/cmd/cli/search/types.go index 4c7e6fa06..2267ed812 100644 --- a/cmd/cli/search/types.go +++ b/cmd/cli/search/types.go @@ -18,6 +18,7 @@ type SearchResult struct { Official bool // Whether this is an official model UpdatedAt string // Last update timestamp Backend string // Backend type: "llama.cpp", "vllm", "diffusers", "unknown", or a comma-separated combination + Size int64 // Total size in bytes (0 if unknown) } // SearchOptions configures the search behavior