Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion cmd/cli/commands/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"fmt"

"github.com/docker/go-units"
"github.com/docker/model-runner/cmd/cli/commands/formatter"
"github.com/docker/model-runner/cmd/cli/search"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -94,7 +95,7 @@ Examples:
func prettyPrintSearchResults(results []search.SearchResult) string {
var buf bytes.Buffer
table := newTable(&buf)
table.Header([]string{"NAME", "DESCRIPTION", "BACKEND", "DOWNLOADS", "STARS", "SOURCE"})
table.Header([]string{"NAME", "DESCRIPTION", "BACKEND", "SIZE", "DOWNLOADS", "STARS", "SOURCE"})

for _, r := range results {
name := r.Name
Expand All @@ -105,6 +106,7 @@ func prettyPrintSearchResults(results []search.SearchResult) string {
name,
r.Description,
r.Backend,
formatSize(r.Size),
formatCount(r.Downloads),
formatCount(r.Stars),
r.Source,
Expand All @@ -125,3 +127,11 @@ func formatCount(n int64) string {
}
return fmt.Sprintf("%d", n)
}

// formatSize formats a byte count as a human-readable size string
func formatSize(n int64) string {
if n <= 0 {
return "n/a"
}
return units.CustomSize("%.2f%s", float64(n), 1000.0, []string{"B", "kB", "MB", "GB", "TB"})
}
51 changes: 51 additions & 0 deletions cmd/cli/commands/search_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package commands

import (
"testing"
)

func TestFormatSize(t *testing.T) {
tests := []struct {
name string
input int64
want string
}{
{name: "zero returns n/a", input: 0, want: "n/a"},
{name: "negative returns n/a", input: -1, want: "n/a"},
{name: "bytes", input: 500, want: "500.00B"},
{name: "kilobytes", input: 1500, want: "1.50kB"},
{name: "megabytes", input: 2_500_000, want: "2.50MB"},
{name: "gigabytes", input: 4_300_000_000, want: "4.30GB"},
{name: "terabytes", input: 1_200_000_000_000, want: "1.20TB"},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := formatSize(tt.input); got != tt.want {
t.Errorf("formatSize(%d) = %q, want %q", tt.input, got, tt.want)
}
})
}
}

func TestFormatCount(t *testing.T) {
tests := []struct {
name string
input int64
want string
}{
{name: "zero", input: 0, want: "0"},
{name: "hundreds", input: 999, want: "999"},
{name: "thousands", input: 1_000, want: "1.0K"},
{name: "thousands with decimal", input: 45_600, want: "45.6K"},
{name: "millions", input: 1_200_000, want: "1.2M"},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := formatCount(tt.input); got != tt.want {
t.Errorf("formatCount(%d) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
43 changes: 25 additions & 18 deletions cmd/cli/search/backend_resolution.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const (
)

type backendResolver interface {
Resolve(ctx context.Context, target string) (string, error)
Resolve(ctx context.Context, target string) (backend string, size int64, err error)
}

type registryBackendResolver struct {
Expand All @@ -38,36 +38,42 @@ func newRegistryBackendResolver() *registryBackendResolver {
}
}

func (r *registryBackendResolver) Resolve(ctx context.Context, target string) (string, error) {
func (r *registryBackendResolver) Resolve(ctx context.Context, target string) (string, int64, error) {
model, err := r.lookup(ctx, withDefaultTag(target))
if err != nil {
return backendUnknown, err
return backendUnknown, 0, err
}

backend := backendUnknown
config, configErr := model.Config()
if configErr == nil {
if backend := backendFromFormat(config.GetFormat()); backend != backendUnknown {
return backend, nil
}
backend = backendFromFormat(config.GetFormat())
}

manifest, manifestErr := model.Manifest()
if manifestErr != nil {
if configErr != nil {
return backendUnknown, errors.Join(configErr, manifestErr)
return backendUnknown, 0, errors.Join(configErr, manifestErr)
}
return backendUnknown, manifestErr
return backend, 0, manifestErr
}

if backend == backendUnknown {
backend = backendFromManifestLayers(manifest)
}

if backend := backendFromManifestLayers(manifest); backend != backendUnknown {
return backend, nil
var totalSize int64
if manifest != nil {
for _, layer := range manifest.Layers {
totalSize += layer.Size
}
}
Comment thread
KeeTraxx marked this conversation as resolved.

if configErr != nil {
return backendUnknown, configErr
if backend == backendUnknown && configErr != nil {
return backendUnknown, totalSize, configErr
}

return backendUnknown, nil
return backend, totalSize, nil
}

type huggingFaceRepoBackendResolver struct {
Expand All @@ -81,12 +87,12 @@ func newHuggingFaceRepoBackendResolver() *huggingFaceRepoBackendResolver {
}
}

func (r *huggingFaceRepoBackendResolver) Resolve(ctx context.Context, target string) (string, error) {
func (r *huggingFaceRepoBackendResolver) Resolve(ctx context.Context, target string) (string, int64, error) {
repoFiles, err := r.listFiles(ctx, target, "main")
if err != nil {
return backendUnknown, err
return backendUnknown, 0, err
}
return backendFromRepoFiles(repoFiles), nil
return backendFromRepoFiles(repoFiles), distributionhf.TotalSize(repoFiles), nil
}

func backendFromFormat(format disttypes.Format) string {
Expand Down Expand Up @@ -152,7 +158,7 @@ func resolveSearchResultBackends(
ctx context.Context,
results []SearchResult,
resolveConcurrency int,
resolve func(context.Context, SearchResult) (string, error),
resolve func(context.Context, SearchResult) (string, int64, error),
) []SearchResult {
if len(results) == 0 {
return results
Expand All @@ -168,12 +174,13 @@ func resolveSearchResultBackends(

for i := range resolved {
group.Go(func() error {
backend, err := resolve(workerCtx, resolved[i])
backend, size, err := resolve(workerCtx, resolved[i])
if err != nil || backend == "" {
resolved[i].Backend = backendUnknown
return nil
}
resolved[i].Backend = backend
resolved[i].Size = size
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
return nil
})
}
Expand Down
6 changes: 3 additions & 3 deletions cmd/cli/search/backend_resolution_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,13 @@ func TestResolveSearchResultBackendsConcurrent(t *testing.T) {
}
}

resolve := func(_ context.Context, result SearchResult) (string, error) {
resolve := func(_ context.Context, result SearchResult) (string, int64, error) {
for i, r := range results {
if r.Name == result.Name {
return wantBackends[i], nil
return wantBackends[i], 0, nil
}
}
return backendUnknown, nil
return backendUnknown, 0, nil
}

resolved := resolveSearchResultBackends(t.Context(), results, numResults, resolve)
Expand Down
4 changes: 2 additions & 2 deletions cmd/cli/search/dockerhub.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,9 @@ func (c *DockerHubClient) Search(ctx context.Context, opts SearchOptions) ([]Sea
nextURL = response.Next
}

return resolveSearchResultBackends(ctx, results, c.resolveConcurrency, func(ctx context.Context, result SearchResult) (string, error) {
return resolveSearchResultBackends(ctx, results, c.resolveConcurrency, func(ctx context.Context, result SearchResult) (string, int64, error) {
if c.backendResolver == nil {
return backendUnknown, nil
return backendUnknown, 0, nil
}
return c.backendResolver.Resolve(ctx, result.Name)
}), nil
Expand Down
8 changes: 4 additions & 4 deletions cmd/cli/search/dockerhub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ type fakeBackendResolver struct {
errs map[string]error
}

func (f fakeBackendResolver) Resolve(_ context.Context, target string) (string, error) {
func (f fakeBackendResolver) Resolve(_ context.Context, target string) (string, int64, error) {
if err, ok := f.errs[target]; ok {
return backendUnknown, err
return backendUnknown, 0, err
}
if backend, ok := f.backends[target]; ok {
return backend, nil
return backend, 0, nil
}
return backendUnknown, nil
return backendUnknown, 0, nil
}

func TestDockerHubSearchUsesVerifiedBackend(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions cmd/cli/search/huggingface.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ func (c *HuggingFaceClient) Search(ctx context.Context, opts SearchOptions) ([]S
})
}

return resolveSearchResultBackends(ctx, results, c.resolveConcurrency, func(ctx context.Context, result SearchResult) (string, error) {
return resolveSearchResultBackends(ctx, results, c.resolveConcurrency, func(ctx context.Context, result SearchResult) (string, int64, error) {
if c.backendResolver == nil {
return backendUnknown, nil
return backendUnknown, 0, nil
}
return c.backendResolver.Resolve(ctx, result.Name)
}), nil
Expand Down
1 change: 1 addition & 0 deletions cmd/cli/search/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type SearchResult struct {
Official bool // Whether this is an official model
UpdatedAt string // Last update timestamp
Backend string // Backend type: "llama.cpp", "vllm", "diffusers", "unknown", or a comma-separated combination
Size int64 // Total size in bytes (0 if unknown)
}

// SearchOptions configures the search behavior
Expand Down