diff --git a/README.md b/README.md index 226ed35..bf2668a 100644 --- a/README.md +++ b/README.md @@ -100,6 +100,7 @@ Crust auto-detects the provider from the model name and passes through your auth ```bash crust status # Check if running crust logs -f # Follow logs +crust doctor # Diagnose provider endpoints crust stop # Stop crust ``` diff --git a/docs/cli.md b/docs/cli.md index ebec481..121785f 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -23,6 +23,10 @@ crust lint-rules [FILE] # Validate rule syntax # ACP Proxy crust acp-wrap [flags] -- # ACP stdio proxy with security rules +# Diagnostics +crust doctor [--timeout 5s] [--retries N] # Check provider endpoint connectivity +crust doctor --report # Generate sanitized report for GitHub issues + # Other crust version [--json] # Show version crust completion [--install] # Install shell completion (bash/zsh/fish) @@ -52,6 +56,15 @@ crust uninstall # Complete removal |------|-------------| | `--api-addr HOST:PORT` | Connect to a remote daemon (e.g. Docker) over TCP instead of the local Unix socket | +## Doctor Flags + +| Flag | Description | +|------|-------------| +| `--timeout DURATION` | Timeout per provider check (default `5s`) | +| `--retries N` | Retries for connection errors (default `1`, use `0` to disable) | +| `--report` | Generate a sanitized markdown report for GitHub issues | +| `--config PATH` | Path to configuration file | + ## ACP Wrap Flags | Flag | Description | @@ -98,6 +111,11 @@ crust list-rules --json crust status --live --api-addr localhost:9090 crust list-rules --api-addr localhost:9090 +# Diagnostics — check all provider endpoints (no daemon needed) +crust doctor +crust doctor --timeout 3s --retries 0 +crust doctor --report # sanitized report for GitHub issues + # ACP proxy: wrap Codex for JetBrains/Zed crust acp-wrap -- codex acp crust acp-wrap --log-level debug -- goose acp diff --git a/internal/completion/completion.go b/internal/completion/completion.go index 9980d5a..ae4977e 100644 --- a/internal/completion/completion.go +++ b/internal/completion/completion.go @@ -47,6 +47,7 @@ var command = &complete.Command{ "list-rules": {Flags: map[string]complete.Predictor{"json": predict.Nothing, "api-addr": predict.Nothing}}, "reload-rules": {}, "lint-rules": {Flags: map[string]complete.Predictor{"info": predict.Nothing}, Args: predict.Files("*.yaml")}, + "doctor": {Flags: map[string]complete.Predictor{"config": predict.Files("*.yaml"), "timeout": predict.Nothing, "retries": predict.Nothing, "report": predict.Nothing}}, "uninstall": {}, "help": {}, "completion": {Flags: map[string]complete.Predictor{"install": predict.Nothing, "uninstall": predict.Nothing}}, diff --git a/internal/proxy/doctor.go b/internal/proxy/doctor.go new file mode 100644 index 0000000..aeef043 --- /dev/null +++ b/internal/proxy/doctor.go @@ -0,0 +1,299 @@ +package proxy + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "path" + "sort" + "strings" + "sync" + "time" + + "github.com/BakeLens/crust/internal/config" +) + +// DoctorStatus represents the result of a provider endpoint check. +type DoctorStatus int + +const ( + StatusOK DoctorStatus = iota // 200: endpoint and key valid + StatusAuthError // 401/403: endpoint OK, key issue + StatusPathError // 404: wrong URL path + StatusConnError // connection failed + StatusOtherError // unexpected status code +) + +// String returns a short label for the status (OK, AUTH, PATH, CONN, ERR). +func (s DoctorStatus) String() string { + switch s { + case StatusOK: + return "OK" + case StatusAuthError: + return "AUTH" + case StatusPathError: + return "PATH" + case StatusConnError: + return "CONN" + default: + return "ERR" + } +} + +// DoctorResult holds the outcome of checking a single provider. +type DoctorResult struct { + Name string + URL string + Diagnosis string + Status DoctorStatus + StatusCode int + Duration time.Duration + HasAPIKey bool + IsUser bool +} + +// DoctorOptions configures the doctor check. +type DoctorOptions struct { + Timeout time.Duration + Retries int // number of retries for CONN errors (default 1) + UserProviders map[string]config.ProviderConfig +} + +// providerEntry is an internal representation of a provider for checking. +type providerEntry struct { + name string + config config.ProviderConfig + isUser bool +} + +// RunDoctor checks all providers (builtin + user) concurrently and returns +// results sorted by provider name. CONN errors are retried up to +// opts.Retries times with a short backoff. +func RunDoctor(opts DoctorOptions) []DoctorResult { + providers := mergeProviders(opts.UserProviders) + retries := opts.Retries + + client := &http.Client{ + Timeout: opts.Timeout, + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, + TLSHandshakeTimeout: opts.Timeout, + DialContext: (&net.Dialer{Timeout: opts.Timeout}).DialContext, + }, + } + defer client.CloseIdleConnections() + + results := make([]DoctorResult, len(providers)) + var wg sync.WaitGroup + for i, entry := range providers { + wg.Add(1) + go func(i int, entry providerEntry) { + defer wg.Done() + r := checkProvider(client, entry) + for attempt := range retries { + if r.Status != StatusConnError { + break + } + time.Sleep(time.Duration(attempt+1) * 500 * time.Millisecond) + r = checkProvider(client, entry) + } + results[i] = r + }(i, entry) + } + wg.Wait() + return results +} + +// isAnthropicProvider reports whether a provider URL uses the Anthropic +// Messages API protocol. Reuses detectAPIType for URLs whose path contains +// "/anthropic" or "/v1/messages", and additionally checks the host for +// api.anthropic.com (which has no path marker). +func isAnthropicProvider(providerURL string) bool { + u, err := url.Parse(providerURL) + if err != nil { + return false + } + if u.Host == "api.anthropic.com" { + return true + } + return detectAPIType(u.Path).IsAnthropic() +} + +// buildTestURL constructs a lightweight test endpoint URL for a provider, +// using the same version-handling logic as buildUpstreamURL. +// For Anthropic-protocol providers it targets /v1/messages (the only +// guaranteed endpoint); for OpenAI-protocol providers it targets /v1/models. +func buildTestURL(providerURL string) (string, error) { + u, err := url.Parse(providerURL) + if err != nil { + return "", fmt.Errorf("invalid provider URL %q: %w", providerURL, err) + } + + // Pick the right test path: Anthropic providers have no /models endpoint. + testPath := "/v1/models" + if isAnthropicProvider(providerURL) { + testPath = "/v1/messages" + } + + // Same logic as buildUpstreamURL auto mode: + // strip client /v1 when provider path already has a version segment. + if pathHasVersion(u.Path) { + testPath = stripLeadingVersion(testPath) + } + + u.Path = path.Join(u.Path, testPath) + return u.String(), nil +} + +// checkProvider sends a lightweight request to verify a provider endpoint. +// It reuses detectAPIType and injectAuth from the proxy to ensure the same +// protocol detection and auth logic used for real requests. +func checkProvider(client *http.Client, entry providerEntry) DoctorResult { + result := DoctorResult{ + Name: entry.name, + HasAPIKey: entry.config.APIKey != "", + IsUser: entry.isUser, + } + + testURL, err := buildTestURL(entry.config.URL) + if err != nil { + result.URL = entry.config.URL + result.Status = StatusConnError + result.Diagnosis = fmt.Sprintf("invalid URL: %v", err) + return result + } + result.URL = testURL + + // Use isAnthropicProvider to decide HTTP method: + // Anthropic endpoints only support POST /v1/messages, not GET /v1/models. + isAnthropic := isAnthropicProvider(entry.config.URL) + method := http.MethodGet + if isAnthropic { + method = http.MethodPost + } + + // Anthropic POST needs a body; empty POST may cause 500 on some proxies. + var body io.Reader + if isAnthropic { + body = bytes.NewReader([]byte(`{}`)) + } + req, err := http.NewRequestWithContext(context.Background(), method, testURL, body) + if err != nil { + result.Status = StatusConnError + result.Diagnosis = fmt.Sprintf("bad request: %v", err) + return result + } + + // Reuse injectAuth from proxy — same auth header logic for real requests. + if entry.config.APIKey != "" { + injectAuth(req.Header, entry.config.APIKey, "", isAnthropic) + } + + start := time.Now() + resp, err := client.Do(req) //nolint:gosec // doctor checks known provider URLs, not user-tainted input + result.Duration = time.Since(start) + + if err != nil || resp == nil { + result.Status = StatusConnError + result.Diagnosis = classifyConnError(err) + return result + } + defer resp.Body.Close() + + result.StatusCode = resp.StatusCode + switch resp.StatusCode { + case http.StatusOK: + result.Status = StatusOK + result.Diagnosis = "endpoint OK, key valid" + if !result.HasAPIKey { + result.Diagnosis = "endpoint OK, no API key configured" + } + case http.StatusUnauthorized, http.StatusForbidden: + result.Status = StatusAuthError + if result.HasAPIKey { + result.Diagnosis = "endpoint OK, key invalid or expired" + } else { + result.Diagnosis = "endpoint OK, no API key configured" + } + case http.StatusNotFound: + result.Status = StatusPathError + result.Diagnosis = "endpoint NOT found (path may be wrong)" + case http.StatusMethodNotAllowed: + // 405 means the path exists but doesn't accept the method — path is correct + result.Status = StatusOK + result.Diagnosis = "endpoint exists (method not allowed, path OK)" + case http.StatusBadRequest: + // 400 = endpoint alive but rejected the probe (e.g. Anthropic empty body, + // Gemini without API key). Path is correct; treat as OK. + result.Status = StatusOK + result.Diagnosis = "endpoint OK (bad request, path OK)" + if !result.HasAPIKey { + result.Diagnosis = "endpoint OK, no API key configured" + } + default: + result.Status = StatusOtherError + result.Diagnosis = fmt.Sprintf("unexpected status %d", resp.StatusCode) + } + return result +} + +// mergeProviders combines builtin and user providers, deduped by normalized URL. +// User providers with the same key override builtins. Sorted by name. +func mergeProviders(userProviders map[string]config.ProviderConfig) []providerEntry { + seen := make(map[string]bool) // normalized URL → already added + var entries []providerEntry + + // User providers first (higher priority) + for name, prov := range userProviders { + norm := normalizeProviderURL(prov.URL) + if seen[norm] { + continue + } + seen[norm] = true + entries = append(entries, providerEntry{name: name, config: prov, isUser: true}) + } + + // Builtins (skip if URL already covered) + for name, prov := range builtinProviders { + norm := normalizeProviderURL(prov.URL) + if seen[norm] { + continue + } + seen[norm] = true + entries = append(entries, providerEntry{name: name, config: prov}) + } + + sort.Slice(entries, func(i, j int) bool { + return entries[i].name < entries[j].name + }) + return entries +} + +// normalizeProviderURL strips trailing slash and lowercases scheme+host for dedup. +func normalizeProviderURL(rawURL string) string { + u, err := url.Parse(rawURL) + if err != nil { + return rawURL + } + return strings.ToLower(u.Scheme+"://"+u.Host) + strings.TrimSuffix(u.Path, "/") +} + +// classifyConnError returns a human-readable diagnosis for a connection error. +func classifyConnError(err error) string { + var netErr net.Error + if ok := errors.As(err, &netErr); ok && netErr.Timeout() { + return "connection timed out" + } + var dnsErr *net.DNSError + if errors.As(err, &dnsErr) { + return "DNS lookup failed: " + dnsErr.Name + } + return fmt.Sprintf("connection error: %v", err) +} diff --git a/internal/proxy/doctor_test.go b/internal/proxy/doctor_test.go new file mode 100644 index 0000000..df82a52 --- /dev/null +++ b/internal/proxy/doctor_test.go @@ -0,0 +1,208 @@ +package proxy + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/BakeLens/crust/internal/config" +) + +func TestBuildTestURL(t *testing.T) { + tests := []struct { + name string + providerURL string + wantURL string + }{ + // OpenAI-protocol providers → /v1/models + {"OpenAI (no path)", "https://api.openai.com", "https://api.openai.com/v1/models"}, + {"DeepSeek (no path)", "https://api.deepseek.com", "https://api.deepseek.com/v1/models"}, + {"Mistral (no path)", "https://api.mistral.ai", "https://api.mistral.ai/v1/models"}, + {"Moonshot (no path)", "https://api.moonshot.ai", "https://api.moonshot.ai/v1/models"}, + {"GLM versioned /v4", "https://open.bigmodel.cn/api/paas/v4", "https://open.bigmodel.cn/api/paas/v4/models"}, + {"Gemini v1beta/openai", "https://generativelanguage.googleapis.com/v1beta/openai", "https://generativelanguage.googleapis.com/v1beta/openai/models"}, + {"Groq with /openai", "https://api.groq.com/openai", "https://api.groq.com/openai/v1/models"}, + {"Qwen compatible-mode", "https://dashscope.aliyuncs.com/compatible-mode", "https://dashscope.aliyuncs.com/compatible-mode/v1/models"}, + {"Codex backend", "https://chatgpt.com/backend-api/codex", "https://chatgpt.com/backend-api/codex/v1/models"}, + // Anthropic-protocol providers → /v1/messages (no /models endpoint) + {"Anthropic (no path)", "https://api.anthropic.com", "https://api.anthropic.com/v1/messages"}, + {"MiniMax /anthropic", "https://api.minimax.io/anthropic", "https://api.minimax.io/anthropic/v1/messages"}, + {"HF synthetic /anthropic", "https://api.synthetic.new/anthropic", "https://api.synthetic.new/anthropic/v1/messages"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := buildTestURL(tt.providerURL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tt.wantURL { + t.Errorf("buildTestURL(%q) = %q, want %q", tt.providerURL, got, tt.wantURL) + } + }) + } +} + +func TestCheckProvider_StatusCodes(t *testing.T) { + tests := []struct { + name string + statusCode int + wantStatus DoctorStatus + }{ + {"200 OK", http.StatusOK, StatusOK}, + {"401 Unauthorized", http.StatusUnauthorized, StatusAuthError}, + {"403 Forbidden", http.StatusForbidden, StatusAuthError}, + {"404 Not Found", http.StatusNotFound, StatusPathError}, + {"405 Method Not Allowed", http.StatusMethodNotAllowed, StatusOK}, + {"500 Internal Server Error", http.StatusInternalServerError, StatusOtherError}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(tt.statusCode) + })) + defer srv.Close() + + result := checkProvider(srv.Client(), providerEntry{ + name: "test", + config: config.ProviderConfig{URL: srv.URL}, + }) + if result.Status != tt.wantStatus { + t.Errorf("status = %v, want %v (diagnosis: %s)", result.Status, tt.wantStatus, result.Diagnosis) + } + if result.StatusCode != tt.statusCode { + t.Errorf("statusCode = %d, want %d", result.StatusCode, tt.statusCode) + } + }) + } +} + +func TestCheckProvider_ConnError(t *testing.T) { + result := checkProvider( + &http.Client{Timeout: 200 * time.Millisecond}, + providerEntry{ + name: "unreachable", + config: config.ProviderConfig{URL: "http://192.0.2.1:1"}, // TEST-NET + }, + ) + if result.Status != StatusConnError { + t.Errorf("status = %v, want StatusConnError (diagnosis: %s)", result.Status, result.Diagnosis) + } +} + +func TestCheckProvider_AuthHeader(t *testing.T) { + tests := []struct { + name string + provider string + // urlSuffix is appended to the test server URL to trigger protocol detection. + // "/anthropic" → Anthropic protocol (X-Api-Key); empty → OpenAI (Bearer). + urlSuffix string + apiKey string + wantAuth string + wantXKey string + }{ + {"bearer auth", "gpt", "", "sk-123", "Bearer sk-123", ""}, + {"anthropic x-api-key", "claude", "/anthropic", "sk-ant-123", "", "sk-ant-123"}, + {"no key", "gpt", "", "", "", ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var gotAuth, gotXKey string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + gotXKey = r.Header.Get("X-Api-Key") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + checkProvider(srv.Client(), providerEntry{ + name: tt.provider, + config: config.ProviderConfig{URL: srv.URL + tt.urlSuffix, APIKey: tt.apiKey}, + }) + if gotAuth != tt.wantAuth { + t.Errorf("Authorization = %q, want %q", gotAuth, tt.wantAuth) + } + if gotXKey != tt.wantXKey { + t.Errorf("X-Api-Key = %q, want %q", gotXKey, tt.wantXKey) + } + }) + } +} + +func TestCheckProvider_AnthropicProtocol(t *testing.T) { + // Anthropic-protocol providers should use POST and treat 400 as OK + // (empty body rejected = endpoint alive). + var gotMethod string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotMethod = r.Method + w.WriteHeader(http.StatusBadRequest) // empty POST body → 400 + })) + defer srv.Close() + + result := checkProvider(srv.Client(), providerEntry{ + name: "minimax", + config: config.ProviderConfig{URL: srv.URL + "/anthropic"}, + }) + if gotMethod != http.MethodPost { + t.Errorf("method = %q, want POST for Anthropic provider", gotMethod) + } + if result.Status != StatusOK { + t.Errorf("status = %v, want StatusOK for Anthropic 400 (diagnosis: %s)", result.Status, result.Diagnosis) + } +} + +func TestMergeProviders_Dedup(t *testing.T) { + user := map[string]config.ProviderConfig{ + "my-gpt": {URL: "https://api.openai.com", APIKey: "sk-user"}, + } + entries := mergeProviders(user) + + // "my-gpt" should be present (user), and builtin "gpt"/"openai"/"o1" etc. + // should be deduped because they share the same URL. + var openaiCount int + for _, e := range entries { + u, _ := url.Parse(e.config.URL) + if u != nil && u.Host == "api.openai.com" { + openaiCount++ + } + } + if openaiCount != 1 { + t.Errorf("expected 1 openai entry after dedup, got %d", openaiCount) + } + + // Verify user entry is the one that survived + for _, e := range entries { + if e.name == "my-gpt" { + if !e.isUser { + t.Error("expected my-gpt to be marked as user provider") + } + if e.config.APIKey != "sk-user" { + t.Error("expected user API key to be preserved") + } + return + } + } + t.Error("my-gpt entry not found in merged providers") +} + +func TestBuiltinProviders_Accessor(t *testing.T) { + providers := BuiltinProviders() + if len(providers) == 0 { + t.Fatal("BuiltinProviders() returned empty map") + } + + // Verify it's a copy — modifying it shouldn't affect the original + providers["test-mutation"] = config.ProviderConfig{URL: "http://mutated"} + fresh := BuiltinProviders() + if _, ok := fresh["test-mutation"]; ok { + t.Error("BuiltinProviders() returned a reference, not a copy") + } + + // Spot-check known providers + for _, key := range []string{"gpt", "claude", "glm", "gemini", "deepseek"} { + if _, ok := fresh[key]; !ok { + t.Errorf("expected builtin provider %q not found", key) + } + } +} diff --git a/internal/proxy/providers.go b/internal/proxy/providers.go index 50426c7..1187cad 100644 --- a/internal/proxy/providers.go +++ b/internal/proxy/providers.go @@ -1,6 +1,7 @@ package proxy import ( + "maps" "slices" "strings" @@ -23,12 +24,26 @@ var builtinProviders = map[string]config.ProviderConfig{ "qwen": {URL: "https://dashscope.aliyuncs.com/compatible-mode"}, "moonshot": {URL: "https://api.moonshot.ai"}, "kimi": {URL: "https://api.moonshot.ai"}, - "gemini": {URL: "https://generativelanguage.googleapis.com"}, - "mistral": {URL: "https://api.mistral.ai"}, - "groq": {URL: "https://api.groq.com/openai"}, - "llama": {URL: "https://api.groq.com/openai"}, - "minimax": {URL: "https://api.minimax.io/anthropic"}, - "hf:": {URL: "https://api.synthetic.new/anthropic"}, // HuggingFace + // Gemini's OpenAI-compatible endpoint is at /v1beta/openai/, not /v1/. + // With the default URL below, clients sending /v1/chat/completions get + // /v1/chat/completions which returns 404. Users must override this in + // config with the full base URL including the path prefix: + // gemini: https://generativelanguage.googleapis.com/v1beta/openai + "gemini": {URL: "https://generativelanguage.googleapis.com/v1beta/openai"}, + "glm": {URL: "https://open.bigmodel.cn/api/paas/v4"}, + "mistral": {URL: "https://api.mistral.ai"}, + "groq": {URL: "https://api.groq.com/openai"}, + "llama": {URL: "https://api.groq.com/openai"}, + "minimax": {URL: "https://api.minimax.io/anthropic"}, + "hf:": {URL: "https://api.synthetic.new/anthropic"}, // HuggingFace +} + +// BuiltinProviders returns a copy of the builtin provider map. +// Used by diagnostic tools (e.g., crust doctor) to enumerate all known providers. +func BuiltinProviders() map[string]config.ProviderConfig { + result := make(map[string]config.ProviderConfig, len(builtinProviders)) + maps.Copy(result, builtinProviders) + return result } // ResolveProvider resolves a model name to a provider config (URL + optional API key). diff --git a/internal/proxy/providers_test.go b/internal/proxy/providers_test.go index 7882379..cb7541e 100644 --- a/internal/proxy/providers_test.go +++ b/internal/proxy/providers_test.go @@ -45,7 +45,7 @@ func TestResolveProvider_BuiltinMatch(t *testing.T) { {"o1-preview", "https://api.openai.com"}, {"o3-mini", "https://api.openai.com"}, {"o4-mini", "https://api.openai.com"}, - {"gemini-pro", "https://generativelanguage.googleapis.com"}, + {"gemini-pro", "https://generativelanguage.googleapis.com/v1beta/openai"}, {"llama-3.3-70b-versatile", "https://api.groq.com/openai"}, {"mistral-large", "https://api.mistral.ai"}, {"moonshot-v1-8k", "https://api.moonshot.ai"}, diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 1f39e84..2565d00 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -812,8 +812,13 @@ func (p *Proxy) buildUpstreamURL(reqPath, model string) (url.URL, string, error) } u.Scheme = resolvedURL.Scheme u.Host = resolvedURL.Host - // path.Join concatenates provider path + request path and - // cleans the result (collapses double slashes, resolves dots). + // If the provider URL already contains an API version segment + // (e.g. /v4 in "open.bigmodel.cn/api/paas/v4"), strip the + // client's version prefix to avoid path duplication like + // /api/paas/v4/v1/chat/completions. + if pathHasVersion(resolvedURL.Path) { + reqPath = stripLeadingVersion(reqPath) + } u.Path = path.Join(resolvedURL.Path, reqPath) return u, result.APIKey, nil } @@ -831,11 +836,52 @@ func (p *Proxy) buildUpstreamURL(reqPath, model string) (url.URL, string, error) if basePath != "" && strings.HasPrefix(reqPath, basePath+"/") { u.Path = reqPath } else { + if pathHasVersion(basePath) { + reqPath = stripLeadingVersion(reqPath) + } u.Path = path.Join(u.Path, reqPath) } return u, "", nil } +// pathHasVersion reports whether any segment of the URL path starts with +// an API version prefix — "v" followed by at least one digit (e.g. "v1", +// "v4", "v1beta", "v2alpha1"). This detects provider URLs such as +// "open.bigmodel.cn/api/paas/v4" and "generativelanguage.googleapis.com/v1beta/openai". +// When detected, the client's redundant /vN prefix is stripped by +// stripLeadingVersion to avoid path duplication. +func pathHasVersion(p string) bool { + for seg := range strings.SplitSeq(p, "/") { + if len(seg) >= 2 && seg[0] == 'v' && seg[1] >= '0' && seg[1] <= '9' { + return true + } + } + return false +} + +// stripLeadingVersion removes a leading /vN segment from a request path. +// e.g. "/v1/chat/completions" → "/chat/completions". +// Returns the path unchanged if it does not start with a version segment. +func stripLeadingVersion(p string) string { + if len(p) < 3 || p[0] != '/' || p[1] != 'v' { + return p + } + i := 2 + for i < len(p) && p[i] >= '0' && p[i] <= '9' { + i++ + } + if i == 2 { + return p // no digits after 'v' + } + if i >= len(p) { + return "/" // entire path was just "/vN" + } + if p[i] == '/' { + return p[i:] + } + return p // not a pure version segment (e.g. "/v1beta/...") +} + // extractUsageAndBody extracts token usage and body from response func extractUsageAndBody(resp *http.Response, apiType types.APIType) (inputTokens, outputTokens int64, bodyBytes []byte) { if resp == nil { diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 45370a8..541d21a 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -185,7 +185,7 @@ func TestBuildUpstreamURL_EndpointMode(t *testing.T) { wantPath: "/v1/chat/completions", }, { - name: "base path preserved for non-matching prefix /v1 vs /v1beta", + name: "base /v1 with client /v1beta — no dedup, path appended", upstream: "http://localhost:11434/v1", reqPath: "/v1beta/completions", model: "x", @@ -224,6 +224,22 @@ func TestBuildUpstreamURL_EndpointMode(t *testing.T) { wantHost: "openrouter.ai", wantPath: "/api/v1/chat/completions", }, + { + name: "versioned provider path strips client /v1", + upstream: "https://open.bigmodel.cn/api/paas/v4", + reqPath: "/v1/chat/completions", + model: "glm-4-plus", + wantHost: "open.bigmodel.cn", + wantPath: "/api/paas/v4/chat/completions", + }, + { + name: "gemini v1beta/openai strips client /v1", + upstream: "https://generativelanguage.googleapis.com/v1beta/openai", + reqPath: "/v1/chat/completions", + model: "gemini-2.0-flash", + wantHost: "generativelanguage.googleapis.com", + wantPath: "/v1beta/openai/chat/completions", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -309,6 +325,30 @@ func TestBuildUpstreamURL_AutoMode(t *testing.T) { wantHost: "api.anthropic.com", wantPath: "/v1/messages", }, + { + name: "glm strips client /v1 for versioned provider path", + upstream: "http://fallback:8080", + reqPath: "/v1/chat/completions", + model: "glm-4-plus", + wantHost: "open.bigmodel.cn", + wantPath: "/api/paas/v4/chat/completions", + }, + { + name: "gemini routes to v1beta/openai, strips client /v1", + upstream: "http://fallback:8080", + reqPath: "/v1/chat/completions", + model: "gemini-2.0-flash", + wantHost: "generativelanguage.googleapis.com", + wantPath: "/v1beta/openai/chat/completions", + }, + { + name: "gemini messages endpoint", + upstream: "http://fallback:8080", + reqPath: "/v1/messages", + model: "gemini-pro", + wantHost: "generativelanguage.googleapis.com", + wantPath: "/v1beta/openai/messages", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -329,6 +369,59 @@ func TestBuildUpstreamURL_AutoMode(t *testing.T) { } } +func TestPathHasVersion(t *testing.T) { + tests := []struct { + path string + want bool + }{ + {"", false}, + {"/", false}, + {"/api/paas/v4", true}, + {"/v1", true}, + {"/v1/", true}, + {"/compatible-mode", false}, + {"/backend-api/codex", false}, + {"/openai", false}, + {"/v1beta", true}, // v + digit prefix (Gemini-style) + {"/v1beta/openai", true}, // Gemini OpenAI-compat path + {"/api/v1beta2", true}, // v + digit prefix + {"/api", false}, + {"/anthropic", false}, + {"/vendor", false}, // v but no digit after + {"/vpc/subnet", false}, // v but no digit after + } + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + if got := pathHasVersion(tt.path); got != tt.want { + t.Errorf("pathHasVersion(%q) = %v, want %v", tt.path, got, tt.want) + } + }) + } +} + +func TestStripLeadingVersion(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"/v1/chat/completions", "/chat/completions"}, + {"/v4/chat/completions", "/chat/completions"}, + {"/v1/messages", "/messages"}, + {"/v1", "/"}, + {"/v1beta/completions", "/v1beta/completions"}, // not pure version + {"/chat/completions", "/chat/completions"}, // no version + {"/responses", "/responses"}, + {"", ""}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + if got := stripLeadingVersion(tt.input); got != tt.want { + t.Errorf("stripLeadingVersion(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + // setupTestProxy creates a proxy pointed at the given upstream, with security enabled. func setupTestProxy(t *testing.T, upstream *httptest.Server) *Proxy { t.Helper() diff --git a/internal/tui/dashboard/common.go b/internal/tui/dashboard/common.go index 2652adb..51e765c 100644 --- a/internal/tui/dashboard/common.go +++ b/internal/tui/dashboard/common.go @@ -156,10 +156,8 @@ func RenderPlain(data StatusData) string { sb.WriteString("[crust] Security: disabled\n") } fmt.Fprintf(&sb, "[crust] Rules: %d loaded\n", data.RuleCount) - if data.Stats.TotalToolCalls > 0 { - pct := float64(data.Stats.BlockedCalls) / float64(data.Stats.TotalToolCalls) * 100 - fmt.Fprintf(&sb, "[crust] Calls: %d total, %d blocked (%.1f%%)\n", - data.Stats.TotalToolCalls, data.Stats.BlockedCalls, pct) + if data.Stats.BlockedCalls > 0 { + fmt.Fprintf(&sb, "[crust] Blocked: %d tool calls\n", data.Stats.BlockedCalls) } sb.WriteString("[crust] Logs: " + data.LogFile) } else { diff --git a/internal/tui/dashboard/common_test.go b/internal/tui/dashboard/common_test.go index 1aa7e55..cf113d3 100644 --- a/internal/tui/dashboard/common_test.go +++ b/internal/tui/dashboard/common_test.go @@ -152,7 +152,7 @@ func TestRenderPlain(t *testing.T) { RuleCount: 14, LogFile: "/tmp/crust.log", Stats: SecurityStats{TotalToolCalls: 100, BlockedCalls: 10, AllowedCalls: 90}, }, - []string{"PID 1234", "healthy", "enabled", "14 loaded", "100 total", "10 blocked", "/tmp/crust.log"}, + []string{"PID 1234", "healthy", "enabled", "14 loaded", "10 tool calls", "/tmp/crust.log"}, }, { "not running", diff --git a/internal/tui/dashboard/dashboard.go b/internal/tui/dashboard/dashboard.go index 611ed28..83209f5 100644 --- a/internal/tui/dashboard/dashboard.go +++ b/internal/tui/dashboard/dashboard.go @@ -9,7 +9,6 @@ import ( "strings" "time" - "github.com/charmbracelet/bubbles/progress" "github.com/charmbracelet/bubbles/spinner" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" @@ -52,8 +51,7 @@ type model struct { apiBase string proxyBaseURL string - blockBar progress.Model - spinner spinner.Model + spinner spinner.Model // shimmer triggers when blocked count increases shimmer tui.ShimmerState @@ -74,9 +72,6 @@ type model struct { } func newModel(mgmtClient *http.Client, apiBase string, proxyBaseURL string, pid int) model { - blockBar := progress.New(progress.WithGradient("#F5A623", "#E05A3A"), progress.WithoutPercentage(), progress.WithWidth(20)) - blockBar.EmptyColor = "#3D3228" - s := spinner.New() s.Spinner = spinner.Dot s.Style = lipgloss.NewStyle().Foreground(tui.ColorSuccess) @@ -89,7 +84,6 @@ func newModel(mgmtClient *http.Client, apiBase string, proxyBaseURL string, pid apiBase: apiBase, proxyBaseURL: proxyBaseURL, data: StatusData{Running: true, PID: pid}, - blockBar: blockBar, spinner: s, shimmer: tui.NewShimmer(shimCfg), width: 60, @@ -237,11 +231,6 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.spinner, cmd = m.spinner.Update(msg) return m, cmd - case progress.FrameMsg: - pm, cmd := m.blockBar.Update(msg) - m.blockBar = pm.(progress.Model) - return m, cmd - case tea.WindowSizeMsg: m.width = msg.Width m.height = msg.Height @@ -374,17 +363,11 @@ func (m model) renderOverview() string { // Metrics — aggregate across all sessions since startup metricsTitle := tui.Separator("Security Metrics (all sessions, since startup)") - total := d.Stats.TotalToolCalls blocked := d.Stats.BlockedCalls - var blockPct float64 - if total > 0 { - blockPct = float64(blocked) / float64(total) * 100 - } - toolCalls := fmt.Sprintf(" %s %s", tui.Faint("Tool Calls"), formatCount(total)) var blockedStr string if m.shimmer.Active { - label := fmt.Sprintf(" %s %s (%.1f%%)", tui.Faint("Blocked"), formatCount(blocked), blockPct) + label := fmt.Sprintf(" %s %s", tui.Faint("Blocked"), formatCount(blocked)) runes := []rune(label) var bb strings.Builder for i, r := range runes { @@ -394,23 +377,15 @@ func (m model) renderOverview() string { } blockedStr = bb.String() } else { - blockedStr = fmt.Sprintf(" %s %s (%.1f%%)", tui.Faint("Blocked"), formatCount(blocked), blockPct) - } - - var barPct float64 - if total > 0 { - barPct = float64(blocked) / float64(total) + blockedStr = fmt.Sprintf(" %s %s", tui.Faint("Blocked"), formatCount(blocked)) } - blockBarView := fmt.Sprintf(" %s %s %.1f%%", tui.Faint("Block Rate"), m.blockBar.ViewAs(barPct), blockPct) logStr := fmt.Sprintf(" %s %s", tui.Faint("Logs"), tui.Hyperlink("file://"+d.LogFile, d.LogFile)) var sb strings.Builder sb.WriteString(info + "\n\n") sb.WriteString(metricsTitle + "\n\n") - sb.WriteString(toolCalls + "\n") sb.WriteString(blockedStr + "\n\n") - sb.WriteString(blockBarView + "\n\n") sb.WriteString(logStr) return sb.String() } @@ -614,12 +589,9 @@ func RenderStatic(data StatusData) string { } fmt.Fprintf(&sb, " %s %d loaded\n", tui.Faint("Rules"), data.RuleCount) - if data.Stats.TotalToolCalls > 0 { - blocked := data.Stats.BlockedCalls - total := data.Stats.TotalToolCalls - pct := float64(blocked) / float64(total) * 100 - fmt.Fprintf(&sb, " %s %s total, %s blocked (%.1f%%)\n", - tui.Faint("Calls"), formatCount(total), formatCount(blocked), pct) + if data.Stats.BlockedCalls > 0 { + fmt.Fprintf(&sb, " %s %s blocked\n", + tui.Faint("Calls"), formatCount(data.Stats.BlockedCalls)) } fmt.Fprintf(&sb, " %s %s", diff --git a/main.go b/main.go index 14753a4..4e72f79 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,7 @@ import ( "os" "os/exec" "os/signal" + "runtime" "slices" "strconv" "strings" @@ -89,6 +90,9 @@ func main() { case "lint-rules": runLintRules(os.Args[2:]) return + case "doctor": + runDoctor(os.Args[2:]) + return case "acp-wrap": runAcpWrap(os.Args[2:]) return @@ -727,6 +731,7 @@ func printUsage() { fmt.Println(tui.Separator("Other")) fmt.Print(tui.AlignColumns([][2]string{ + {"crust doctor [--timeout 5s] [--report]", "Check provider endpoint connectivity"}, {"crust acp-wrap [flags] -- ", "ACP stdio proxy with security rules"}, {"crust completion [--install]", "Install shell completion (bash/zsh/fish)"}, {"crust uninstall", "Uninstall crust completely"}, @@ -1082,6 +1087,141 @@ func runLintRules(args []string) { } } +// runDoctor handles the doctor subcommand — checks provider endpoint connectivity. +func runDoctor(args []string) { + tui.WindowTitle("crust doctor") + doctorFlags := flag.NewFlagSet("doctor", flag.ExitOnError) + configPath := doctorFlags.String("config", config.DefaultConfigPath(), "Path to configuration file") + timeout := doctorFlags.Duration("timeout", 5*time.Second, "Timeout per provider check") + retries := doctorFlags.Int("retries", 1, "Retries for connection errors") + report := doctorFlags.Bool("report", false, "Generate a sanitized report for GitHub issues") + _ = doctorFlags.Parse(args) + + // Load config for user-defined providers (no daemon needed) + cfg, err := config.Load(*configPath) + if err != nil { + cfg = config.DefaultConfig() + } + + fmt.Println() + fmt.Println(tui.Separator("Provider Diagnostics")) + fmt.Println() + + results := proxy.RunDoctor(proxy.DoctorOptions{ + Timeout: *timeout, + Retries: *retries, + UserProviders: cfg.Upstream.Providers, + }) + + // Print each result + var okCount, warnCount, errCount int + for _, r := range results { + printDoctorResult(r) + switch r.Status { + case proxy.StatusOK: + okCount++ + case proxy.StatusAuthError: + warnCount++ + default: + errCount++ + } + } + + // Summary + fmt.Println() + switch { + case errCount > 0: + tui.PrintError(fmt.Sprintf("%d error(s), %d warning(s), %d ok", errCount, warnCount, okCount)) + case warnCount > 0: + tui.PrintWarning(fmt.Sprintf("%d warning(s), %d ok", warnCount, okCount)) + default: + tui.PrintSuccess(fmt.Sprintf("All %d providers ok", okCount)) + } + + if *report { + fmt.Println() + fmt.Println(buildDoctorReport(results, okCount, warnCount, errCount)) + } +} + +// printDoctorResult prints a single provider check result. +func printDoctorResult(r proxy.DoctorResult) { + tag := r.Status.String() + latency := fmt.Sprintf("(%s)", r.Duration.Round(time.Millisecond)) + + if tui.IsPlainMode() { + name := r.Name + if r.IsUser { + name += " *" + } + fmt.Printf(" [%-4s] %-14s %s\n", tag, name, r.URL) + fmt.Printf(" %s %s\n", r.Diagnosis, latency) + return + } + + // Styled output + var icon string + var style lipgloss.Style + switch r.Status { + case proxy.StatusOK: + icon = tui.IconCheck + style = tui.StyleSuccess + case proxy.StatusAuthError: + icon = tui.IconWarning + style = tui.StyleWarning + default: + icon = tui.IconCross + style = tui.StyleError + } + + // Pad raw name before styling so column width counts visible chars, not ANSI codes. + name := r.Name + if r.IsUser { + name += " *" + } + paddedName := fmt.Sprintf("%-14s", name) + fmt.Printf(" %s %s %s\n", style.Render(icon), tui.StyleBold.Render(paddedName), tui.Faint(r.URL)) + fmt.Printf(" %s %s %s\n", style.Render(tag), r.Diagnosis, tui.Faint(latency)) + fmt.Println() +} + +// buildDoctorReport generates a sanitized markdown report for GitHub issues. +// Privacy: user-defined provider URLs are masked to host-only; API keys are +// never included (DoctorResult doesn't carry them). +func buildDoctorReport(results []proxy.DoctorResult, okCount, warnCount, errCount int) string { + var sb strings.Builder + sb.WriteString("## Crust Doctor Report\n\n") + sb.WriteString("```\n") + fmt.Fprintf(&sb, "Version: %s\n", Version) + fmt.Fprintf(&sb, "OS: %s/%s\n", runtime.GOOS, runtime.GOARCH) + fmt.Fprintf(&sb, "Go: %s\n", runtime.Version()) + fmt.Fprintf(&sb, "Summary: %d ok, %d auth, %d error\n\n", okCount, warnCount, errCount) + + fmt.Fprintf(&sb, "%-14s %-5s %4s %-8s %s\n", "PROVIDER", "STAT", "CODE", "LATENCY", "DIAGNOSIS") + fmt.Fprintf(&sb, "%-14s %-5s %4s %-8s %s\n", "--------", "----", "----", "-------", "---------") + for _, r := range results { + code := "-" + if r.StatusCode > 0 { + code = strconv.Itoa(r.StatusCode) + } + name := r.Name + diagnosis := r.Diagnosis + if r.IsUser { + name += " *" + // Sanitize diagnosis for user-defined providers: + // connection errors may embed the full URL. + diagnosis = r.Status.String() + } + fmt.Fprintf(&sb, "%-14s %-5s %4s %-8s %s\n", + name, r.Status, code, + r.Duration.Round(time.Millisecond).String(), diagnosis, + ) + } + sb.WriteString("```\n") + sb.WriteString("\nPaste the block above into a GitHub issue at https://github.com/BakeLens/crust/issues/new\n") + return sb.String() +} + // runCompletion handles the completion subcommand func runCompletion(args []string) { compFlags := flag.NewFlagSet("completion", flag.ExitOnError)