From c9e5c4915a107afbe91ab6eb036fc47ec6f40cfd Mon Sep 17 00:00:00 2001 From: Mikhail Batukhtin <6481198+remdev@users.noreply.github.com> Date: Mon, 15 Jun 2026 16:24:19 +0300 Subject: [PATCH 1/6] feat(reviewbackend): add Cursor and chat-completions review backends Introduce internal/reviewbackend so main review work runs through a shared Backend interface. ChatCompletionsBackend preserves the existing OpenAI and Anthropic tool-call loop; CursorAgentBackend runs local Cursor agents with OCR custom tools. Wire both paths through review and llm commands. --- cmd/opencodereview/llm_cmd.go | 41 ++- cmd/opencodereview/review_cmd.go | 19 +- go.mod | 1 + go.sum | 2 + internal/agent/agent.go | 180 +++++------- internal/agent/backend_test.go | 84 ++++++ internal/reviewbackend/backend.go | 84 ++++++ internal/reviewbackend/chat_completions.go | 215 ++++++++++++++ .../reviewbackend/chat_completions_test.go | 135 +++++++++ internal/reviewbackend/cursor_agent.go | 264 ++++++++++++++++++ internal/reviewbackend/cursor_agent_test.go | 107 +++++++ internal/reviewbackend/cursor_usage.go | 84 ++++++ internal/reviewbackend/cursor_usage_test.go | 71 +++++ internal/reviewbackend/factory.go | 62 ++++ internal/reviewbackend/resolver.go | 99 +++++++ internal/reviewbackend/resolver_test.go | 125 +++++++++ 16 files changed, 1449 insertions(+), 124 deletions(-) create mode 100644 internal/agent/backend_test.go create mode 100644 internal/reviewbackend/backend.go create mode 100644 internal/reviewbackend/chat_completions.go create mode 100644 internal/reviewbackend/chat_completions_test.go create mode 100644 internal/reviewbackend/cursor_agent.go create mode 100644 internal/reviewbackend/cursor_agent_test.go create mode 100644 internal/reviewbackend/cursor_usage.go create mode 100644 internal/reviewbackend/cursor_usage_test.go create mode 100644 internal/reviewbackend/factory.go create mode 100644 internal/reviewbackend/resolver.go create mode 100644 internal/reviewbackend/resolver_test.go diff --git a/cmd/opencodereview/llm_cmd.go b/cmd/opencodereview/llm_cmd.go index 95cf8618..294ff9c2 100644 --- a/cmd/opencodereview/llm_cmd.go +++ b/cmd/opencodereview/llm_cmd.go @@ -3,10 +3,12 @@ package main import ( "context" "fmt" + "os" "time" "github.com/open-code-review/open-code-review/internal/config/testconnection" "github.com/open-code-review/open-code-review/internal/llm" + "github.com/open-code-review/open-code-review/internal/reviewbackend" ) func runLLM(args []string) error { @@ -37,9 +39,28 @@ func runLLMTest() error { return fmt.Errorf("load config: %w", err) } - ep, err := llm.ResolveEndpoint(cfgPath) + resolved, err := reviewbackend.ResolveBackend(cfgPath) if err != nil { - return fmt.Errorf("resolve LLM endpoint: %w", err) + return fmt.Errorf("resolve review backend: %w", err) + } + + repoDir, err := os.Getwd() + if err != nil { + return err + } + + backend, err := reviewbackend.New(context.Background(), resolved, repoDir) + if err != nil { + return fmt.Errorf("create review backend: %w", err) + } + + llmClient := reviewbackend.TextClient(backend) + + model := resolved.Endpoint.Model + source := resolved.Endpoint.Source + if resolved.Kind == reviewbackend.KindCursorAgent { + model = resolved.Cursor.Model + source = resolved.Cursor.Source } task, err := testconnection.LoadDefault() @@ -55,8 +76,6 @@ func runLLMTest() error { timeout = time.Duration(task.Timeout) * time.Second } - llmClient := llm.NewLLMClient(ep) - messages := make([]llm.Message, 0, len(task.Messages)) for _, m := range task.Messages { messages = append(messages, llm.Message{Role: m.Role, Content: m.Content}) @@ -66,7 +85,7 @@ func runLLMTest() error { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() return llmClient.CompletionsWithCtx(ctx, llm.ChatRequest{ - Model: ep.Model, + Model: model, Messages: messages, MaxTokens: 256, }) @@ -75,13 +94,15 @@ func runLLMTest() error { return fmt.Errorf("llm request failed: %w", err) } - model := ep.Model + outModel := model if resp.Model != "" { - model = resp.Model + outModel = resp.Model + } + fmt.Printf("Source: %s\n", source) + if resolved.Kind == reviewbackend.KindChatCompletions { + fmt.Printf("URL: %s\n", resolved.Endpoint.URL) } - fmt.Printf("Source: %s\n", ep.Source) - fmt.Printf("URL: %s\n", ep.URL) - fmt.Printf("Model: %s\n", model) + fmt.Printf("Model: %s\n", outModel) fmt.Printf("%s\n", resp.Content()) return nil } diff --git a/cmd/opencodereview/review_cmd.go b/cmd/opencodereview/review_cmd.go index bf9f3c5a..00ac043a 100644 --- a/cmd/opencodereview/review_cmd.go +++ b/cmd/opencodereview/review_cmd.go @@ -14,7 +14,7 @@ import ( "github.com/open-code-review/open-code-review/internal/config/toolsconfig" "github.com/open-code-review/open-code-review/internal/diff" "github.com/open-code-review/open-code-review/internal/gitcmd" - "github.com/open-code-review/open-code-review/internal/llm" + "github.com/open-code-review/open-code-review/internal/reviewbackend" "github.com/open-code-review/open-code-review/internal/stdout" "github.com/open-code-review/open-code-review/internal/telemetry" "github.com/open-code-review/open-code-review/internal/tool" @@ -88,13 +88,21 @@ func runReview(args []string) error { tpl.ApplyLanguage(appCfg.Language) } - ep, err := llm.ResolveEndpoint(cfgPath) + resolved, err := reviewbackend.ResolveBackend(cfgPath) if err != nil { - return fmt.Errorf("resolve LLM endpoint: %w", err) + return fmt.Errorf("resolve review backend: %w", err) } - llmClient := llm.NewLLMClient(ep) - model := ep.Model + backend, err := reviewbackend.New(context.Background(), resolved, repoDir) + if err != nil { + return fmt.Errorf("create review backend: %w", err) + } + + llmClient := reviewbackend.TextClient(backend) + model := resolved.Endpoint.Model + if resolved.Kind == reviewbackend.KindCursorAgent { + model = resolved.Cursor.Model + } gitRunner := gitcmd.New(opts.maxGitProcs) @@ -118,6 +126,7 @@ func runReview(args []string) error { SystemRule: resolver, FileFilter: fileFilter, LLMClient: llmClient, + Backend: backend, Tools: tools, PlanToolDefs: planToolDefs, MainToolDefs: mainToolDefs, diff --git a/go.mod b/go.mod index c6dcda13..1a2fab16 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/bmatcuk/doublestar/v4 v4.10.0 github.com/openai/openai-go/v3 v3.39.0 github.com/pkoukk/tiktoken-go v0.1.8 + github.com/remdev/cursor-go-sdk v0.0.0-20260614191545-6e5b35765d3b go.opentelemetry.io/otel v1.43.0 go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0 diff --git a/go.sum b/go.sum index 291ab0b8..25fefa6a 100644 --- a/go.sum +++ b/go.sum @@ -73,6 +73,8 @@ github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/Q github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remdev/cursor-go-sdk v0.0.0-20260614191545-6e5b35765d3b h1:Ajn05OTxZ9VM7oVX2QAg4ADt+yOmy1iDifTtRiw4Q4Q= +github.com/remdev/cursor-go-sdk v0.0.0-20260614191545-6e5b35765d3b/go.mod h1:8+NZymOCljHsqHoor5ZU9o4kuwy573NuY7U8n6jajAY= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/standard-webhooks/standard-webhooks/libraries v0.0.1 h1:uOfcYT+3QungH6tIGSVCR/Y3KJmgJiHcojJbMTPDZAI= diff --git a/internal/agent/agent.go b/internal/agent/agent.go index a71bc072..9d45e3ac 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -18,6 +18,7 @@ import ( "github.com/open-code-review/open-code-review/internal/gitcmd" "github.com/open-code-review/open-code-review/internal/llm" "github.com/open-code-review/open-code-review/internal/model" + "github.com/open-code-review/open-code-review/internal/reviewbackend" "github.com/open-code-review/open-code-review/internal/session" "github.com/open-code-review/open-code-review/internal/stdout" "github.com/open-code-review/open-code-review/internal/telemetry" @@ -71,9 +72,13 @@ type Args struct { // When nil, only the default extension and path filters apply. FileFilter *rules.FileFilter - // LLM client for model inference. + // LLM client for text completions (relocation, compression). When Backend is set, + // callers should also set LLMClient to reviewbackend.TextClient(Backend). LLMClient llm.LLMClient + // Backend executes plan/main review model work (chat completions or Cursor agent). + Backend reviewbackend.Backend + // Tool registry mapping tool aliases to implementations. Tools *tool.Registry @@ -893,119 +898,76 @@ func formatToolDefs(toolDefs []llm.ToolDef) string { return sb.String() } -// performLlmCodeReview drives the main LLM conversation loop for a single file. -// It sends messages with tool definitions, handles tool calls returned by the model, -// and collects review comments until task_done is called or limits are reached. +// performLlmCodeReview drives the main review conversation for a single file via the configured backend. func (a *Agent) performLlmCodeReview(ctx context.Context, messages []llm.Message, newPath string) error { - toolReqCount := a.args.Template.MaxToolRequestTimes - const maxConsecutiveEmptyRounds = 3 - consecutiveEmptyRounds := 0 - - for toolReqCount > 0 { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - toolReqCount-- - - fs := a.session.GetOrCreateFileSession(newPath) - rec := fs.AppendTaskRecord(session.MainTask, append([]llm.Message(nil), messages...)) - startTime := time.Now() - - resp, err := a.args.LLMClient.CompletionsWithCtx(ctx, llm.ChatRequest{ - Model: a.args.Model, - Messages: messages, - Tools: a.args.MainToolDefs, - MaxTokens: a.args.Template.MaxTokens, - }) - duration := time.Since(startTime) - if err != nil { - rec.SetError(err, duration) - telemetry.RecordLLMRequest(ctx, a.args.Model, duration, 0, "error") - return fmt.Errorf("LLM completion error: %w", err) - } - rec.SetResponse(resp, duration) - // Record LLM metrics with token info from API response usage field. - totalTokens := int64(0) - if resp.Usage != nil { - totalTokens = resp.Usage.TotalTokens - atomic.AddInt64(&a.totalInputTokens, resp.Usage.PromptTokens) - atomic.AddInt64(&a.totalOutputTokens, resp.Usage.CompletionTokens) - atomic.AddInt64(&a.totalCacheReadTokens, resp.Usage.CacheReadTokens) - atomic.AddInt64(&a.totalCacheWriteTokens, resp.Usage.CacheWriteTokens) - } - telemetry.RecordLLMRequest(ctx, a.args.Model, duration, totalTokens, "ok") - - content := resp.Content() - calls := resp.ToolCalls() + if a.args.Backend == nil { + return fmt.Errorf("review backend is not configured") + } - if len(calls) == 0 { - // No tool calls - remind the model - fmt.Fprintf(stdout.Writer(), "[ocr] No tool calls parsed for %s, retrying...\n", newPath) - messages = append(messages, llm.NewTextMessage("user", "You did not successfully call any tools. Please try again or use task_done if finished.")) - if content != "" { - messages = append(messages[:len(messages)-1], llm.NewTextMessage("assistant", content), messages[len(messages)-1]) - } - continue + var mainRec *session.TaskRecord + executor := func(execCtx context.Context, call reviewbackend.ToolCallInput) reviewbackend.ToolCallOutput { + if execCtx == nil { + execCtx = ctx } - - var results []tool.ToolCallResult - taskCompleted := false - hasValidResult := false - - for _, call := range calls { - cp := a.executeToolCall(ctx, newPath, call, rec) - if cp.Completed { - results = append(results, tool.ToolCallResult{ - ToolCallID: call.ID, - Name: call.Function.Name, - Result: "Task completed successfully.", - }) - taskCompleted = true - } else if cp.Data != "" { - results = append(results, tool.ToolCallResult{ - ToolCallID: call.ID, - Name: call.Function.Name, - Result: cp.Data, - }) - hasValidResult = true - } else { - results = append(results, tool.ToolCallResult{ - ToolCallID: call.ID, - Name: call.Function.Name, - Result: "Error: Tool execution returned no result.", - }) + cp := a.executeToolCall(execCtx, newPath, llm.ToolCall{ + ID: call.ID, + Type: "function", + Function: llm.FunctionCall{ + Name: call.Name, + Arguments: call.Arguments, + }, + }, mainRec) + return reviewbackend.ToolCallOutput{Result: cp.Data, Completed: cp.Completed} + } + + hooks := &reviewbackend.ReviewHooks{ + AppendTaskRecord: func(taskType session.TaskType, msgs []llm.Message) *session.TaskRecord { + fs := a.session.GetOrCreateFileSession(newPath) + mainRec = fs.AppendTaskRecord(taskType, msgs) + return mainRec + }, + SetResponse: func(rec *session.TaskRecord, resp *llm.ChatResponse, durationMs int64) { + rec.SetResponse(resp, time.Duration(durationMs)*time.Millisecond) + }, + SetError: func(rec *session.TaskRecord, err error, durationMs int64) { + rec.SetError(err, time.Duration(durationMs)*time.Millisecond) + }, + RecordUsage: func(usage *llm.UsageInfo) { + if usage == nil { + return } - } - - if taskCompleted { - break - } - if !hasValidResult { - consecutiveEmptyRounds++ - if consecutiveEmptyRounds >= maxConsecutiveEmptyRounds { - fmt.Fprintf(stdout.Writer(), "[ocr] Too many empty retries for %s, stopping.\n", newPath) - break + atomic.AddInt64(&a.totalInputTokens, usage.PromptTokens) + atomic.AddInt64(&a.totalOutputTokens, usage.CompletionTokens) + atomic.AddInt64(&a.totalCacheReadTokens, usage.CacheReadTokens) + atomic.AddInt64(&a.totalCacheWriteTokens, usage.CacheWriteTokens) + }, + RecordLLMRequest: func(durationMs int64, totalTokens int64, status string) { + telemetry.RecordLLMRequest(ctx, a.args.Model, time.Duration(durationMs)*time.Millisecond, totalTokens, status) + }, + AppendRound: func(assistantContent string, calls []llm.ToolCall, results []reviewbackend.ToolRoundResult, msgs *[]llm.Message) bool { + toolResults := make([]tool.ToolCallResult, len(results)) + for i, r := range results { + toolResults[i] = tool.ToolCallResult{ + ToolCallID: r.ToolCallID, + Name: r.Name, + Result: r.Result, + } } - fmt.Fprintf(stdout.Writer(), "[ocr] No valid tool results for %s, retrying...\n", newPath) - } else { - consecutiveEmptyRounds = 0 - } - - succeed := a.addNextMessage(ctx, content, calls, results, &messages, newPath) - if !succeed { - fmt.Fprintf(stdout.Writer(), "[ocr] Context compression exceeded threshold for %s, stopping.\n", newPath) - break - } - } - - if toolReqCount <= 0 { - fmt.Fprintf(stdout.Writer(), "[ocr] Max tool requests reached for %s.\n", newPath) - } - - return nil + return a.addNextMessage(ctx, assistantContent, calls, toolResults, msgs, newPath) + }, + Logf: func(format string, args ...any) { + fmt.Fprintf(stdout.Writer(), format, args...) + }, + } + + return a.args.Backend.ReviewFile(ctx, reviewbackend.ReviewFileRequest{ + Model: a.args.Model, + Messages: messages, + Tools: a.args.MainToolDefs, + MaxTokens: a.args.Template.MaxTokens, + MaxToolRounds: a.args.Template.MaxToolRequestTimes, + FilePath: newPath, + }, executor, hooks) } // executeToolCall executes a single tool call from the LLM response and records diff --git a/internal/agent/backend_test.go b/internal/agent/backend_test.go new file mode 100644 index 00000000..92197a1d --- /dev/null +++ b/internal/agent/backend_test.go @@ -0,0 +1,84 @@ +package agent + +import ( + "context" + "testing" + + "github.com/open-code-review/open-code-review/internal/config/template" + "github.com/open-code-review/open-code-review/internal/llm" + "github.com/open-code-review/open-code-review/internal/reviewbackend" + "github.com/open-code-review/open-code-review/internal/session" + "github.com/open-code-review/open-code-review/internal/tool" +) + +type recordingBackend struct { + reviewCalled bool + lastReq reviewbackend.ReviewFileRequest + executor reviewbackend.ToolExecutor +} + +func (b *recordingBackend) Kind() reviewbackend.Kind { return reviewbackend.KindCursorAgent } +func (b *recordingBackend) Model() string { return "test-model" } +func (b *recordingBackend) Source() string { return "test" } +func (b *recordingBackend) Complete(context.Context, reviewbackend.CompleteRequest) (*reviewbackend.CompleteResponse, error) { + return &reviewbackend.CompleteResponse{}, nil +} +func (b *recordingBackend) ReviewFile(ctx context.Context, req reviewbackend.ReviewFileRequest, exec reviewbackend.ToolExecutor, _ *reviewbackend.ReviewHooks) error { + b.reviewCalled = true + b.lastReq = req + b.executor = exec + if exec != nil { + exec(ctx, reviewbackend.ToolCallInput{ + ID: "tc-1", + Name: "code_comment", + Arguments: `{"body":"note"}`, + }) + } + return nil +} + +func TestPerformLlmCodeReview_UsesBackend(t *testing.T) { + backend := &recordingBackend{} + reg := tool.NewRegistry() + a := &Agent{ + args: Args{ + Model: "test-model", + Backend: backend, + Tools: reg, + MainToolDefs: []llm.ToolDef{{ + Type: "function", + Function: llm.FunctionDef{Name: "code_comment"}, + }}, + Template: templateWithMaxRounds(2), + }, + session: session.New(t.TempDir(), "main", "test-model", session.SessionOptions{}), + } + + err := a.performLlmCodeReview(context.Background(), []llm.Message{ + llm.NewTextMessage("user", "review file"), + }, "src/main.go") + if err != nil { + t.Fatalf("performLlmCodeReview: %v", err) + } + if !backend.reviewCalled { + t.Fatal("backend.ReviewFile was not called") + } + if backend.lastReq.FilePath != "src/main.go" { + t.Errorf("FilePath = %q", backend.lastReq.FilePath) + } + if backend.lastReq.Model != "test-model" { + t.Errorf("Model = %q", backend.lastReq.Model) + } +} + +func TestPerformLlmCodeReview_RequiresBackend(t *testing.T) { + a := &Agent{args: Args{}} + err := a.performLlmCodeReview(context.Background(), nil, "x.go") + if err == nil { + t.Fatal("expected error when backend is nil") + } +} + +func templateWithMaxRounds(n int) template.Template { + return template.Template{MaxToolRequestTimes: n, MaxTokens: 1024} +} diff --git a/internal/reviewbackend/backend.go b/internal/reviewbackend/backend.go new file mode 100644 index 00000000..b358919b --- /dev/null +++ b/internal/reviewbackend/backend.go @@ -0,0 +1,84 @@ +package reviewbackend + +import ( + "context" + + "github.com/open-code-review/open-code-review/internal/llm" + "github.com/open-code-review/open-code-review/internal/session" +) + +// Kind identifies how review tasks are executed. +type Kind string + +const ( + KindChatCompletions Kind = "chat_completions" + KindCursorAgent Kind = "cursor_agent" +) + +// CompleteRequest is a text-only completion (plan, filter, compression, llm test). +type CompleteRequest struct { + Model string + Messages []llm.Message + MaxTokens int +} + +// CompleteResponse is the result of a text-only completion. +type CompleteResponse struct { + Content string + Model string + Usage *llm.UsageInfo + Raw *llm.ChatResponse +} + +// ReviewFileRequest drives the main per-file review task with tools. +type ReviewFileRequest struct { + Model string + Messages []llm.Message + Tools []llm.ToolDef + MaxTokens int + MaxToolRounds int + FilePath string +} + +// ToolCallInput is passed to the tool executor from any backend. +type ToolCallInput struct { + ID string + Name string + Arguments string +} + +// ToolCallOutput is returned by the tool executor. +type ToolCallOutput struct { + Result string + Completed bool +} + +// ToolExecutor runs a single tool call on behalf of the backend. +type ToolExecutor func(ctx context.Context, call ToolCallInput) ToolCallOutput + +// ReviewHooks wires agent-level session, telemetry, and compression into a backend loop. +type ReviewHooks struct { + AppendTaskRecord func(taskType session.TaskType, messages []llm.Message) *session.TaskRecord + SetResponse func(rec *session.TaskRecord, resp *llm.ChatResponse, durationMs int64) + SetError func(rec *session.TaskRecord, err error, durationMs int64) + RecordUsage func(usage *llm.UsageInfo) + RecordLLMRequest func(durationMs int64, totalTokens int64, status string) + AppendRound func(assistantContent string, calls []llm.ToolCall, results []ToolRoundResult, messages *[]llm.Message) bool + Logf func(format string, args ...any) +} + +// ToolRoundResult is a single tool result within a review round. +type ToolRoundResult struct { + ToolCallID string + Name string + Result string +} + +// Backend executes review-related model work. +type Backend interface { + Kind() Kind + Model() string + Source() string + Complete(ctx context.Context, req CompleteRequest) (*CompleteResponse, error) + ReviewFile(ctx context.Context, req ReviewFileRequest, exec ToolExecutor, hooks *ReviewHooks) error +} diff --git a/internal/reviewbackend/chat_completions.go b/internal/reviewbackend/chat_completions.go new file mode 100644 index 00000000..cd7dbd70 --- /dev/null +++ b/internal/reviewbackend/chat_completions.go @@ -0,0 +1,215 @@ +package reviewbackend + +import ( + "context" + "fmt" + "time" + + "github.com/open-code-review/open-code-review/internal/llm" + "github.com/open-code-review/open-code-review/internal/session" +) + +// ChatCompletionsBackend runs review via OpenAI/Anthropic chat completions and tool_calls. +type ChatCompletionsBackend struct { + client llm.LLMClient + ep llm.ResolvedEndpoint +} + +// NewChatCompletionsBackend creates a backend from a resolved LLM endpoint. +func NewChatCompletionsBackend(ep llm.ResolvedEndpoint) *ChatCompletionsBackend { + return &ChatCompletionsBackend{ + client: llm.NewLLMClient(ep), + ep: ep, + } +} + +func (b *ChatCompletionsBackend) Kind() Kind { return KindChatCompletions } + +func (b *ChatCompletionsBackend) Model() string { return b.ep.Model } + +func (b *ChatCompletionsBackend) Source() string { return b.ep.Source } + +func (b *ChatCompletionsBackend) Complete(ctx context.Context, req CompleteRequest) (*CompleteResponse, error) { + model := req.Model + if model == "" { + model = b.ep.Model + } + start := time.Now() + resp, err := b.client.CompletionsWithCtx(ctx, llm.ChatRequest{ + Model: model, + Messages: req.Messages, + MaxTokens: req.MaxTokens, + }) + _ = start + if err != nil { + return nil, err + } + outModel := model + if resp.Model != "" { + outModel = resp.Model + } + return &CompleteResponse{ + Content: resp.Content(), + Model: outModel, + Usage: resp.Usage, + Raw: resp, + }, nil +} + +// ReviewFile runs the chat-completions tool loop until task_done or limits are hit. +func (b *ChatCompletionsBackend) ReviewFile(ctx context.Context, req ReviewFileRequest, exec ToolExecutor, hooks *ReviewHooks) error { + if hooks == nil { + hooks = &ReviewHooks{} + } + logf := hooks.Logf + if logf == nil { + logf = func(string, ...any) {} + } + + model := req.Model + if model == "" { + model = b.ep.Model + } + + messages := append([]llm.Message(nil), req.Messages...) + toolReqCount := req.MaxToolRounds + if toolReqCount <= 0 { + toolReqCount = 1 + } + + const maxConsecutiveEmptyRounds = 3 + consecutiveEmptyRounds := 0 + + for toolReqCount > 0 { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + toolReqCount-- + + var rec *session.TaskRecord + if hooks.AppendTaskRecord != nil { + rec = hooks.AppendTaskRecord(session.MainTask, append([]llm.Message(nil), messages...)) + } + + start := time.Now() + resp, err := b.client.CompletionsWithCtx(ctx, llm.ChatRequest{ + Model: model, + Messages: messages, + Tools: req.Tools, + MaxTokens: req.MaxTokens, + }) + duration := time.Since(start) + + if err != nil { + if hooks.SetError != nil && rec != nil { + hooks.SetError(rec, err, duration.Milliseconds()) + } + if hooks.RecordLLMRequest != nil { + hooks.RecordLLMRequest(duration.Milliseconds(), 0, "error") + } + return fmt.Errorf("LLM completion error: %w", err) + } + + if hooks.SetResponse != nil && rec != nil { + hooks.SetResponse(rec, resp, duration.Milliseconds()) + } + totalTokens := int64(0) + if resp.Usage != nil { + totalTokens = resp.Usage.TotalTokens + if hooks.RecordUsage != nil { + hooks.RecordUsage(resp.Usage) + } + } + if hooks.RecordLLMRequest != nil { + hooks.RecordLLMRequest(duration.Milliseconds(), totalTokens, "ok") + } + + content := resp.Content() + calls := resp.ToolCalls() + + if len(calls) == 0 { + logf("[ocr] No tool calls parsed for %s, retrying...\n", req.FilePath) + messages = append(messages, llm.NewTextMessage("user", "You did not successfully call any tools. Please try again or use task_done if finished.")) + if content != "" { + messages = append(messages[:len(messages)-1], llm.NewTextMessage("assistant", content), messages[len(messages)-1]) + } + continue + } + + var results []ToolRoundResult + taskCompleted := false + hasValidResult := false + + for _, call := range calls { + out := exec(ctx, ToolCallInput{ + ID: call.ID, + Name: call.Function.Name, + Arguments: call.Function.Arguments, + }) + if out.Completed { + results = append(results, ToolRoundResult{ + ToolCallID: call.ID, + Name: call.Function.Name, + Result: "Task completed successfully.", + }) + taskCompleted = true + } else if out.Result != "" { + results = append(results, ToolRoundResult{ + ToolCallID: call.ID, + Name: call.Function.Name, + Result: out.Result, + }) + hasValidResult = true + } else { + results = append(results, ToolRoundResult{ + ToolCallID: call.ID, + Name: call.Function.Name, + Result: "Error: Tool execution returned no result.", + }) + } + } + + if taskCompleted { + break + } + if !hasValidResult { + consecutiveEmptyRounds++ + if consecutiveEmptyRounds >= maxConsecutiveEmptyRounds { + logf("[ocr] Too many empty retries for %s, stopping.\n", req.FilePath) + break + } + logf("[ocr] No valid tool results for %s, retrying...\n", req.FilePath) + } else { + consecutiveEmptyRounds = 0 + } + + if hooks.AppendRound != nil { + if !hooks.AppendRound(content, calls, results, &messages) { + logf("[ocr] Context compression exceeded threshold for %s, stopping.\n", req.FilePath) + break + } + } else { + appendRoundMessages(content, calls, results, &messages) + } + } + + if toolReqCount <= 0 { + logf("[ocr] Max tool requests reached for %s.\n", req.FilePath) + } + + return nil +} + +func appendRoundMessages(assistantContent string, toolCalls []llm.ToolCall, results []ToolRoundResult, messages *[]llm.Message) { + if len(toolCalls) > 0 { + *messages = append(*messages, llm.NewToolCallMessage(assistantContent, toolCalls)) + } else if assistantContent != "" { + *messages = append(*messages, llm.NewTextMessage("assistant", assistantContent)) + } + for _, r := range results { + *messages = append(*messages, llm.NewToolResultMessage(r.ToolCallID, r.Result)) + } +} diff --git a/internal/reviewbackend/chat_completions_test.go b/internal/reviewbackend/chat_completions_test.go new file mode 100644 index 00000000..cc97120b --- /dev/null +++ b/internal/reviewbackend/chat_completions_test.go @@ -0,0 +1,135 @@ +package reviewbackend + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/open-code-review/open-code-review/internal/llm" +) + +type mockLLMClient struct { + responses []*llm.ChatResponse + err error + calls int +} + +func (m *mockLLMClient) CompletionsWithCtx(_ context.Context, _ llm.ChatRequest) (*llm.ChatResponse, error) { + if m.err != nil { + return nil, m.err + } + if m.calls >= len(m.responses) { + return &llm.ChatResponse{}, nil + } + resp := m.responses[m.calls] + m.calls++ + return resp, nil +} + +func toolCallResponse(name, id string) *llm.ChatResponse { + content := "working" + return &llm.ChatResponse{ + Choices: []llm.Choice{{ + Message: llm.ResponseMessage{ + Role: "assistant", + Content: &content, + ToolCalls: []llm.ToolCall{{ + ID: id, + Type: "function", + Function: llm.FunctionCall{ + Name: name, + Arguments: `{}`, + }, + }}, + }, + }}, + } +} + +func TestChatCompletionsBackend_ReviewFile_TaskDone(t *testing.T) { + mock := &mockLLMClient{ + responses: []*llm.ChatResponse{ + toolCallResponse("task_done", "call-1"), + }, + } + backend := &ChatCompletionsBackend{ + client: mock, + ep: llm.ResolvedEndpoint{Model: "test-model"}, + } + + var executed bool + err := backend.ReviewFile(context.Background(), ReviewFileRequest{ + Model: "test-model", + Messages: []llm.Message{llm.NewTextMessage("user", "review")}, + MaxToolRounds: 3, + FilePath: "foo.go", + }, func(_ context.Context, call ToolCallInput) ToolCallOutput { + executed = true + if call.Name != "task_done" { + t.Errorf("tool name = %q", call.Name) + } + return ToolCallOutput{Completed: true} + }, nil) + if err != nil { + t.Fatalf("ReviewFile: %v", err) + } + if !executed { + t.Fatal("executor was not called") + } + if mock.calls != 1 { + t.Errorf("LLM calls = %d, want 1", mock.calls) + } +} + +func TestChatCompletionsBackend_ReviewFile_LLMError(t *testing.T) { + mock := &mockLLMClient{err: errors.New("network down")} + backend := &ChatCompletionsBackend{ + client: mock, + ep: llm.ResolvedEndpoint{Model: "test-model"}, + } + + err := backend.ReviewFile(context.Background(), ReviewFileRequest{ + Messages: []llm.Message{llm.NewTextMessage("user", "review")}, + MaxToolRounds: 1, + FilePath: "foo.go", + }, func(context.Context, ToolCallInput) ToolCallOutput { + t.Fatal("executor should not run on LLM error") + return ToolCallOutput{} + }, nil) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "network down") { + t.Errorf("unexpected error: %v", err) + } +} + +func TestTextClient_CompleteAdapter(t *testing.T) { + backend := &fakeTextBackend{content: "hello", model: "m1"} + client := TextClient(backend) + resp, err := client.CompletionsWithCtx(context.Background(), llm.ChatRequest{ + Messages: []llm.Message{llm.NewTextMessage("user", "hi")}, + }) + if err != nil { + t.Fatalf("CompletionsWithCtx: %v", err) + } + if resp.Content() != "hello" { + t.Errorf("content = %q", resp.Content()) + } +} + +type fakeTextBackend struct { + content string + model string +} + +func (f *fakeTextBackend) Kind() Kind { return KindCursorAgent } +func (f *fakeTextBackend) Model() string { return f.model } +func (f *fakeTextBackend) Source() string { return "test" } +func (f *fakeTextBackend) Complete(_ context.Context, _ CompleteRequest) (*CompleteResponse, error) { + return &CompleteResponse{Content: f.content, Model: f.model}, nil +} +func (f *fakeTextBackend) ReviewFile(context.Context, ReviewFileRequest, ToolExecutor, *ReviewHooks) error { + return nil +} diff --git a/internal/reviewbackend/cursor_agent.go b/internal/reviewbackend/cursor_agent.go new file mode 100644 index 00000000..349173e1 --- /dev/null +++ b/internal/reviewbackend/cursor_agent.go @@ -0,0 +1,264 @@ +package reviewbackend + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/open-code-review/open-code-review/internal/llm" + "github.com/open-code-review/open-code-review/internal/session" + "github.com/remdev/cursor-go-sdk/cursor" +) + +const bridgeSetupHint = "install Cursor bridge: go run github.com/remdev/cursor-go-sdk/cmd/setup@latest " + + "or npm install -g @cursor-go-sdk/cursor-sdk-bridge@0.0.2" + +// CursorAgentBackend runs review via Cursor Agent SDK local runtime and custom tools. +type CursorAgentBackend struct { + cfg CursorConfig + repoDir string +} + +// NewCursorAgentBackend validates prerequisites and returns a Cursor backend. +func NewCursorAgentBackend(ctx context.Context, cfg CursorConfig, repoDir string) (*CursorAgentBackend, error) { + if err := cursor.EnsureBridgeInstalled(ctx); err != nil { + return nil, fmt.Errorf("cursor bridge not ready: %w; %s", err, bridgeSetupHint) + } + if cfg.APIKey == "" { + return nil, errors.New("cursor api key is required (set providers.cursor.api_key or CURSOR_API_KEY)") + } + if cfg.Model == "" { + return nil, errors.New("cursor model is required") + } + if repoDir == "" { + return nil, errors.New("cursor local agent requires repo directory") + } + return &CursorAgentBackend{cfg: cfg, repoDir: repoDir}, nil +} + +func (b *CursorAgentBackend) Kind() Kind { return KindCursorAgent } + +func (b *CursorAgentBackend) Model() string { return b.cfg.Model } + +func (b *CursorAgentBackend) Source() string { return b.cfg.Source } + +func (b *CursorAgentBackend) Complete(ctx context.Context, req CompleteRequest) (*CompleteResponse, error) { + model := req.Model + if model == "" { + model = b.cfg.Model + } + + agent, err := cursor.CreateAgent(ctx, b.agentOptions(model)) + if err != nil { + return nil, wrapCursorError(err) + } + defer func() { + closeCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + _ = agent.Close(closeCtx) + }() + + usageAcc := &cursorUsageAccumulator{} + run, err := agent.Send(ctx, messagesToPrompt(req.Messages), cursor.SendOptions{ + OnDelta: usageAcc.callback(), + }) + if err != nil { + return nil, wrapCursorError(err) + } + + result, err := run.Wait(ctx) + if err != nil { + return nil, wrapCursorError(err) + } + if result.Status == cursor.RunStatusError { + return nil, fmt.Errorf("cursor prompt failed: %s", result.Result) + } + + return &CompleteResponse{ + Content: result.Result, + Model: model, + Usage: usageAcc.usage(), + }, nil +} + +func (b *CursorAgentBackend) ReviewFile(ctx context.Context, req ReviewFileRequest, exec ToolExecutor, hooks *ReviewHooks) error { + if hooks == nil { + hooks = &ReviewHooks{} + } + + model := req.Model + if model == "" { + model = b.cfg.Model + } + + customTools := toolDefsToCustomTools(ctx, req.Tools, exec) + + agent, err := cursor.CreateAgent(ctx, b.agentOptions(model)) + if err != nil { + return wrapCursorError(err) + } + defer func() { + closeCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + _ = agent.Close(closeCtx) + }() + + var rec *session.TaskRecord + if hooks.AppendTaskRecord != nil { + rec = hooks.AppendTaskRecord(session.MainTask, append([]llm.Message(nil), req.Messages...)) + } + + start := time.Now() + usageAcc := &cursorUsageAccumulator{} + run, err := agent.Send(ctx, messagesToPrompt(req.Messages), cursor.SendOptions{ + Local: &cursor.LocalSendOptions{ + CustomTools: customTools, + }, + OnDelta: usageAcc.callback(), + }) + if err != nil { + if hooks.SetError != nil && rec != nil { + hooks.SetError(rec, err, time.Since(start).Milliseconds()) + } + return wrapCursorError(err) + } + + result, err := run.Wait(ctx) + duration := time.Since(start) + if err != nil { + if hooks.SetError != nil && rec != nil { + hooks.SetError(rec, err, duration.Milliseconds()) + } + return wrapCursorError(err) + } + + if hooks.SetResponse != nil && rec != nil { + content := result.Result + hooks.SetResponse(rec, &llm.ChatResponse{ + Model: model, + Choices: []llm.Choice{{ + Message: llm.ResponseMessage{Role: "assistant", Content: &content}, + FinishReason: string(result.Status), + }}, + Usage: usageAcc.usage(), + }, duration.Milliseconds()) + } + usage := usageAcc.usage() + if usage != nil && hooks.RecordUsage != nil { + hooks.RecordUsage(usage) + } + totalTokens := int64(0) + if usage != nil { + totalTokens = usage.TotalTokens + } + if hooks.RecordLLMRequest != nil { + hooks.RecordLLMRequest(duration.Milliseconds(), totalTokens, string(result.Status)) + } + + switch result.Status { + case cursor.RunStatusFinished: + return nil + case cursor.RunStatusCancelled: + return fmt.Errorf("cursor review cancelled") + case cursor.RunStatusExpired: + return fmt.Errorf("cursor review expired") + case cursor.RunStatusError: + if result.Result != "" { + return fmt.Errorf("cursor review failed: %s", result.Result) + } + return errors.New("cursor review failed") + default: + return nil + } +} + +func (b *CursorAgentBackend) agentOptions(model string) cursor.AgentOptions { + sandboxEnabled := true + return cursor.AgentOptions{ + Model: model, + APIKey: b.cfg.APIKey, + Local: &cursor.LocalAgentOptions{ + CWD: []string{b.repoDir}, + SettingSources: nil, + SandboxOptions: &cursor.SandboxOptions{Enabled: &sandboxEnabled}, + CustomTools: nil, + }, + } +} + +func toolDefsToCustomTools(ctx context.Context, defs []llm.ToolDef, exec ToolExecutor) map[string]cursor.CustomTool { + if len(defs) == 0 { + return nil + } + out := make(map[string]cursor.CustomTool, len(defs)) + for _, def := range defs { + name := def.Function.Name + fn := def.Function + out[name] = cursor.CustomTool{ + Description: fn.Description, + InputSchema: fn.Parameters, + Execute: func(args map[string]any, tctx cursor.CustomToolContext) (any, error) { + raw, err := json.Marshal(args) + if err != nil { + return nil, err + } + result := exec(ctx, ToolCallInput{ + ID: tctx.ToolCallID, + Name: name, + Arguments: string(raw), + }) + if result.Completed { + return "Task completed successfully.", nil + } + if result.Result == "" { + return "Error: Tool execution returned no result.", nil + } + return result.Result, nil + }, + } + } + return out +} + +func messagesToPrompt(msgs []llm.Message) string { + var sb strings.Builder + for _, m := range msgs { + role := m.Role + switch role { + case "system": + role = "System" + case "assistant": + role = "Assistant" + case "user": + role = "User" + case "tool": + role = "Tool" + default: + if role == "" { + role = "User" + } else { + role = strings.ToUpper(role[:1]) + role[1:] + } + } + sb.WriteString(role) + sb.WriteString(":\n") + sb.WriteString(m.ExtractText()) + sb.WriteString("\n\n") + } + return strings.TrimSpace(sb.String()) +} + +func wrapCursorError(err error) error { + var ae *cursor.AgentError + if errors.As(err, &ae) { + return fmt.Errorf("cursor agent error (%s): %s; %s", ae.Code, ae.Message, bridgeSetupHint) + } + if strings.Contains(strings.ToLower(err.Error()), "sandbox") || + strings.Contains(strings.ToLower(err.Error()), "configuration") { + return fmt.Errorf("%w; cursor sandbox may be unavailable on this host — %s", err, bridgeSetupHint) + } + return err +} diff --git a/internal/reviewbackend/cursor_agent_test.go b/internal/reviewbackend/cursor_agent_test.go new file mode 100644 index 00000000..84844c87 --- /dev/null +++ b/internal/reviewbackend/cursor_agent_test.go @@ -0,0 +1,107 @@ +package reviewbackend + +import ( + "context" + "strings" + "testing" + + "github.com/open-code-review/open-code-review/internal/llm" + "github.com/remdev/cursor-go-sdk/cursor" +) + +func TestMessagesToPrompt(t *testing.T) { + prompt := messagesToPrompt([]llm.Message{ + llm.NewTextMessage("system", "You are a reviewer."), + llm.NewTextMessage("user", "Review this file."), + llm.NewTextMessage("assistant", "Checking..."), + }) + + if !strings.Contains(prompt, "System:\nYou are a reviewer.") { + t.Errorf("missing system block: %q", prompt) + } + if !strings.Contains(prompt, "User:\nReview this file.") { + t.Errorf("missing user block: %q", prompt) + } + if !strings.Contains(prompt, "Assistant:\nChecking...") { + t.Errorf("missing assistant block: %q", prompt) + } +} + +func TestToolDefsToCustomTools_Executor(t *testing.T) { + var gotName, gotArgs string + tools := toolDefsToCustomTools(context.Background(), []llm.ToolDef{{ + Type: "function", + Function: llm.FunctionDef{ + Name: "code_comment", + Description: "Leave a comment", + Parameters: map[string]any{"type": "object"}, + }, + }}, func(_ context.Context, call ToolCallInput) ToolCallOutput { + gotName = call.Name + gotArgs = call.Arguments + return ToolCallOutput{Result: `{"ok":true}`} + }) + + tool, ok := tools["code_comment"] + if !ok { + t.Fatal("code_comment tool not mapped") + } + if tool.Description != "Leave a comment" { + t.Errorf("Description = %q", tool.Description) + } + + out, err := tool.Execute(map[string]any{"line": float64(10)}, cursor.CustomToolContext{ToolCallID: "tc-1"}) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if gotName != "code_comment" { + t.Errorf("executor name = %q", gotName) + } + if !strings.Contains(gotArgs, `"line"`) { + t.Errorf("executor args = %q", gotArgs) + } + if out != `{"ok":true}` { + t.Errorf("result = %v, want json payload", out) + } +} + +func TestToolDefsToCustomTools_TaskDone(t *testing.T) { + tools := toolDefsToCustomTools(context.Background(), []llm.ToolDef{{ + Type: "function", + Function: llm.FunctionDef{ + Name: "task_done", + Description: "Finish review", + }, + }}, func(_ context.Context, _ ToolCallInput) ToolCallOutput { + return ToolCallOutput{Completed: true} + }) + + out, err := tools["task_done"].Execute(nil, cursor.CustomToolContext{ToolCallID: "done-1"}) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if out != "Task completed successfully." { + t.Errorf("result = %v", out) + } +} + +func TestWrapCursorError_AgentError(t *testing.T) { + err := wrapCursorError(&cursor.AgentError{Code: "auth", Message: "invalid key"}) + if err == nil { + t.Fatal("expected error") + } + msg := err.Error() + if !strings.Contains(msg, "auth") || !strings.Contains(msg, "invalid key") { + t.Errorf("unexpected message: %s", msg) + } + if !strings.Contains(msg, bridgeSetupHint) { + t.Errorf("missing bridge hint: %s", msg) + } +} + +func TestNewCursorAgentBackend_MissingAPIKey(t *testing.T) { + _, err := NewCursorAgentBackend(context.Background(), CursorConfig{Model: "auto"}, t.TempDir()) + if err == nil { + t.Fatal("expected error for missing api key") + } +} diff --git a/internal/reviewbackend/cursor_usage.go b/internal/reviewbackend/cursor_usage.go new file mode 100644 index 00000000..1c05b68b --- /dev/null +++ b/internal/reviewbackend/cursor_usage.go @@ -0,0 +1,84 @@ +package reviewbackend + +import ( + "sync" + + "github.com/open-code-review/open-code-review/internal/llm" + "github.com/remdev/cursor-go-sdk/cursor" +) + +// cursorUsageAccumulator collects token usage from Cursor stream deltas. +type cursorUsageAccumulator struct { + mu sync.Mutex + prompt int64 + completion int64 + cacheRead int64 + cacheWrite int64 + deltaCompletion int64 + hasTurnUsage bool +} + +func (a *cursorUsageAccumulator) observe(update cursor.InteractionUpdate) { + switch update.Type { + case "turn-ended": + if ui := llm.UsageFromMap(update.Usage); ui != nil { + a.mu.Lock() + if !a.hasTurnUsage { + a.prompt = 0 + a.completion = 0 + a.cacheRead = 0 + a.cacheWrite = 0 + a.hasTurnUsage = true + } + a.mu.Unlock() + a.merge(ui) + } + case "token-delta": + if update.Tokens > 0 { + a.mu.Lock() + if !a.hasTurnUsage { + a.deltaCompletion += int64(update.Tokens) + } + a.mu.Unlock() + } + } +} + +func (a *cursorUsageAccumulator) merge(ui *llm.UsageInfo) { + a.mu.Lock() + defer a.mu.Unlock() + a.prompt += ui.PromptTokens + a.completion += ui.CompletionTokens + a.cacheRead += ui.CacheReadTokens + a.cacheWrite += ui.CacheWriteTokens +} + +func (a *cursorUsageAccumulator) usage() *llm.UsageInfo { + a.mu.Lock() + defer a.mu.Unlock() + + prompt := a.prompt + completion := a.completion + cacheRead := a.cacheRead + cacheWrite := a.cacheWrite + + if !a.hasTurnUsage && a.deltaCompletion > 0 { + completion = a.deltaCompletion + } + + if prompt == 0 && completion == 0 && cacheRead == 0 && cacheWrite == 0 { + return nil + } + total := prompt + completion + cacheRead + cacheWrite + return &llm.UsageInfo{ + TotalTokens: total, + PromptTokens: prompt, + CompletionTokens: completion, + CacheReadTokens: cacheRead, + CacheWriteTokens: cacheWrite, + } +} + +func (a *cursorUsageAccumulator) callback() func(cursor.InteractionUpdate) { + return a.observe +} diff --git a/internal/reviewbackend/cursor_usage_test.go b/internal/reviewbackend/cursor_usage_test.go new file mode 100644 index 00000000..5bf4544e --- /dev/null +++ b/internal/reviewbackend/cursor_usage_test.go @@ -0,0 +1,71 @@ +package reviewbackend + +import ( + "testing" + + "github.com/remdev/cursor-go-sdk/cursor" +) + +func TestCursorUsageAccumulator_TurnEnded(t *testing.T) { + acc := &cursorUsageAccumulator{} + acc.observe(cursor.InteractionUpdate{ + Type: "turn-ended", + Usage: map[string]any{ + "inputTokens": 1200, + "outputTokens": 340, + }, + }) + acc.observe(cursor.InteractionUpdate{ + Type: "turn-ended", + Usage: map[string]any{ + "prompt_tokens": 800, + "completion_tokens": 120, + }, + }) + + usage := acc.usage() + if usage == nil { + t.Fatal("expected usage") + } + if usage.PromptTokens != 2000 { + t.Errorf("PromptTokens = %d, want 2000", usage.PromptTokens) + } + if usage.CompletionTokens != 460 { + t.Errorf("CompletionTokens = %d, want 460", usage.CompletionTokens) + } + if usage.TotalTokens != 2460 { + t.Errorf("TotalTokens = %d, want 2460", usage.TotalTokens) + } +} + +func TestCursorUsageAccumulator_TokenDeltaFallback(t *testing.T) { + acc := &cursorUsageAccumulator{} + acc.observe(cursor.InteractionUpdate{Type: "token-delta", Tokens: 50}) + acc.observe(cursor.InteractionUpdate{Type: "token-delta", Tokens: 25}) + + usage := acc.usage() + if usage == nil { + t.Fatal("expected usage") + } + if usage.CompletionTokens != 75 { + t.Errorf("CompletionTokens = %d, want 75", usage.CompletionTokens) + } +} + +func TestCursorUsageAccumulator_TurnEndedDisablesTokenDelta(t *testing.T) { + acc := &cursorUsageAccumulator{} + acc.observe(cursor.InteractionUpdate{Type: "token-delta", Tokens: 100}) + acc.observe(cursor.InteractionUpdate{ + Type: "turn-ended", + Usage: map[string]any{ + "inputTokens": 10, + "outputTokens": 5, + }, + }) + acc.observe(cursor.InteractionUpdate{Type: "token-delta", Tokens: 999}) + + usage := acc.usage() + if usage.PromptTokens != 10 || usage.CompletionTokens != 5 { + t.Errorf("usage = %+v, want prompt=10 completion=5", usage) + } +} diff --git a/internal/reviewbackend/factory.go b/internal/reviewbackend/factory.go new file mode 100644 index 00000000..53e92e77 --- /dev/null +++ b/internal/reviewbackend/factory.go @@ -0,0 +1,62 @@ +package reviewbackend + +import ( + "context" + "fmt" + + "github.com/open-code-review/open-code-review/internal/llm" +) + +// New creates a Backend from resolved configuration. +func New(ctx context.Context, resolved ResolvedBackend, repoDir string) (Backend, error) { + switch resolved.Kind { + case KindCursorAgent: + return NewCursorAgentBackend(ctx, resolved.Cursor, repoDir) + case KindChatCompletions: + return NewChatCompletionsBackend(resolved.Endpoint), nil + default: + return nil, fmt.Errorf("unsupported backend kind %q", resolved.Kind) + } +} + +// TextClient returns an llm.LLMClient that delegates text completions to the backend. +func TextClient(b Backend) llm.LLMClient { + if cc, ok := b.(*ChatCompletionsBackend); ok { + return cc.client + } + return &completeAdapter{backend: b} +} + +type completeAdapter struct { + backend Backend +} + +func (a *completeAdapter) CompletionsWithCtx(ctx context.Context, req llm.ChatRequest) (*llm.ChatResponse, error) { + model := req.Model + if model == "" { + model = a.backend.Model() + } + resp, err := a.backend.Complete(ctx, CompleteRequest{ + Model: model, + Messages: req.Messages, + MaxTokens: req.MaxTokens, + }) + if err != nil { + return nil, err + } + if resp.Raw != nil { + return resp.Raw, nil + } + content := resp.Content + return &llm.ChatResponse{ + Model: resp.Model, + Choices: []llm.Choice{{ + Message: llm.ResponseMessage{ + Role: "assistant", + Content: &content, + }, + FinishReason: "stop", + }}, + Usage: resp.Usage, + }, nil +} diff --git a/internal/reviewbackend/resolver.go b/internal/reviewbackend/resolver.go new file mode 100644 index 00000000..7b6ec697 --- /dev/null +++ b/internal/reviewbackend/resolver.go @@ -0,0 +1,99 @@ +package reviewbackend + +import ( + "encoding/json" + "fmt" + "os" + "strings" + + "github.com/open-code-review/open-code-review/internal/llm" +) + +// CursorConfig holds resolved Cursor Agent SDK settings. +type CursorConfig struct { + APIKey string + Model string + Source string +} + +// ResolvedBackend is the outcome of configuration resolution. +type ResolvedBackend struct { + Kind Kind + Endpoint llm.ResolvedEndpoint + Cursor CursorConfig +} + +// ResolveBackend reads OCR config and returns the appropriate backend kind. +func ResolveBackend(configPath string) (ResolvedBackend, error) { + cfg, err := readConfigFile(configPath) + if err != nil { + return ResolvedBackend{}, err + } + if cfg != nil && strings.EqualFold(cfg.Provider, "cursor") { + return resolveCursorProvider(cfg) + } + + ep, err := llm.ResolveEndpoint(configPath) + if err != nil { + return ResolvedBackend{}, err + } + return ResolvedBackend{Kind: KindChatCompletions, Endpoint: ep}, nil +} + +type providerEntry struct { + APIKey string `json:"api_key,omitempty"` + Model string `json:"model,omitempty"` +} + +type configFile struct { + Provider string `json:"provider,omitempty"` + Model string `json:"model,omitempty"` + Providers map[string]providerEntry `json:"providers,omitempty"` +} + +func readConfigFile(path string) (*configFile, error) { + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + var cfg configFile + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parse config: %w", err) + } + return &cfg, nil +} + +func resolveCursorProvider(cfg *configFile) (ResolvedBackend, error) { + entry, ok := cfg.Providers["cursor"] + if !ok { + return ResolvedBackend{}, fmt.Errorf("provider %q is set but not configured in providers section", cfg.Provider) + } + + apiKey := entry.APIKey + if apiKey == "" { + apiKey = os.Getenv("CURSOR_API_KEY") + } + if apiKey == "" { + return ResolvedBackend{}, fmt.Errorf("provider %q has no api_key configured and CURSOR_API_KEY is not set", cfg.Provider) + } + + model := cfg.Model + if entry.Model != "" { + model = entry.Model + } + if model == "" { + return ResolvedBackend{}, fmt.Errorf("provider %q has no model configured; run 'ocr config model' to select one", cfg.Provider) + } + + return ResolvedBackend{ + Kind: KindCursorAgent, + Cursor: CursorConfig{ + APIKey: apiKey, + Model: model, + Source: "provider:cursor", + }, + }, nil +} diff --git a/internal/reviewbackend/resolver_test.go b/internal/reviewbackend/resolver_test.go new file mode 100644 index 00000000..79fca8dd --- /dev/null +++ b/internal/reviewbackend/resolver_test.go @@ -0,0 +1,125 @@ +package reviewbackend + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" +) + +func writeConfig(t *testing.T, dir string, cfg any) string { + t.Helper() + path := filepath.Join(dir, "config.json") + data, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("marshal config: %v", err) + } + if err := os.WriteFile(path, data, 0o644); err != nil { + t.Fatalf("write config: %v", err) + } + return path +} + +func TestResolveBackend_CursorProvider(t *testing.T) { + t.Setenv("CURSOR_API_KEY", "") + cfgPath := writeConfig(t, t.TempDir(), map[string]any{ + "provider": "cursor", + "providers": map[string]any{ + "cursor": map[string]any{ + "api_key": "cursor-test-key", + "model": "composer-2.5", + }, + }, + }) + + resolved, err := ResolveBackend(cfgPath) + if err != nil { + t.Fatalf("ResolveBackend: %v", err) + } + if resolved.Kind != KindCursorAgent { + t.Fatalf("Kind = %q, want %q", resolved.Kind, KindCursorAgent) + } + if resolved.Cursor.APIKey != "cursor-test-key" { + t.Errorf("APIKey = %q, want cursor-test-key", resolved.Cursor.APIKey) + } + if resolved.Cursor.Model != "composer-2.5" { + t.Errorf("Model = %q, want composer-2.5", resolved.Cursor.Model) + } + if resolved.Cursor.Source != "provider:cursor" { + t.Errorf("Source = %q, want provider:cursor", resolved.Cursor.Source) + } + if resolved.Endpoint.URL != "" { + t.Errorf("Endpoint.URL = %q, want empty for cursor", resolved.Endpoint.URL) + } +} + +func TestResolveBackend_CursorEnvAPIKeyFallback(t *testing.T) { + t.Setenv("CURSOR_API_KEY", "env-cursor-key") + cfgPath := writeConfig(t, t.TempDir(), map[string]any{ + "provider": "cursor", + "providers": map[string]any{ + "cursor": map[string]any{ + "model": "auto", + }, + }, + }) + + resolved, err := ResolveBackend(cfgPath) + if err != nil { + t.Fatalf("ResolveBackend: %v", err) + } + if resolved.Cursor.APIKey != "env-cursor-key" { + t.Errorf("APIKey = %q, want env-cursor-key", resolved.Cursor.APIKey) + } +} + +func TestResolveBackend_CursorMissingAPIKey(t *testing.T) { + t.Setenv("CURSOR_API_KEY", "") + cfgPath := writeConfig(t, t.TempDir(), map[string]any{ + "provider": "cursor", + "providers": map[string]any{ + "cursor": map[string]any{ + "model": "auto", + }, + }, + }) + + _, err := ResolveBackend(cfgPath) + if err == nil { + t.Fatal("expected error for missing api key") + } +} + +func TestResolveBackend_CursorMissingModel(t *testing.T) { + cfgPath := writeConfig(t, t.TempDir(), map[string]any{ + "provider": "cursor", + "providers": map[string]any{ + "cursor": map[string]any{ + "api_key": "key", + }, + }, + }) + + _, err := ResolveBackend(cfgPath) + if err == nil { + t.Fatal("expected error for missing model") + } +} + +func TestResolveBackend_ChatCompletionsUnchanged(t *testing.T) { + t.Setenv("OCR_LLM_URL", "https://api.example.com/v1/chat/completions") + t.Setenv("OCR_LLM_TOKEN", "test-token") + t.Setenv("OCR_LLM_MODEL", "gpt-4o") + + cfgPath := filepath.Join(t.TempDir(), "missing.json") + resolved, err := ResolveBackend(cfgPath) + if err != nil { + t.Fatalf("ResolveBackend: %v", err) + } + if resolved.Kind != KindChatCompletions { + t.Fatalf("Kind = %q, want %q", resolved.Kind, KindChatCompletions) + } + if resolved.Endpoint.Model != "gpt-4o" { + t.Errorf("Model = %q, want gpt-4o", resolved.Endpoint.Model) + } +} From a448264e66493f898cf0db5e9f7f30aaf411edc8 Mon Sep 17 00:00:00 2001 From: Mikhail Batukhtin <6481198+remdev@users.noreply.github.com> Date: Mon, 15 Jun 2026 16:24:20 +0300 Subject: [PATCH 2/6] feat(providers): add Cursor preset and token usage parsing Register cursor as a built-in provider and parse usage maps from Cursor stream events so review summaries report token counts. --- internal/llm/providers.go | 18 +++++++++++++- internal/llm/providers_cursor_test.go | 25 +++++++++++++++++++ internal/llm/usage_from_map.go | 22 ++++++++++++++++ internal/llm/usage_from_map_test.go | 36 +++++++++++++++++++++++++++ internal/llm/usage_resolver.go | 6 +++++ 5 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 internal/llm/providers_cursor_test.go create mode 100644 internal/llm/usage_from_map.go create mode 100644 internal/llm/usage_from_map_test.go diff --git a/internal/llm/providers.go b/internal/llm/providers.go index 6f9bc77c..f22d66e1 100644 --- a/internal/llm/providers.go +++ b/internal/llm/providers.go @@ -9,13 +9,18 @@ import ( type Provider struct { Name string DisplayName string - Protocol string // "anthropic" or "openai" + Protocol string // "anthropic", "openai", or "cursor" BaseURL string AuthHeader string // Anthropic-only; empty for OpenAI-compatible EnvVar string // environment variable name for API key fallback Models []string } +// IsCursorAgent reports whether the provider uses the Cursor Agent SDK backend. +func (p Provider) IsCursorAgent() bool { + return p.Protocol == "cursor" +} + var registry = []Provider{ { Name: "anthropic", @@ -102,6 +107,17 @@ var registry = []Provider{ "glm-4.7", }, }, + { + Name: "cursor", + DisplayName: "Cursor Agent SDK", + Protocol: "cursor", + EnvVar: "CURSOR_API_KEY", + Models: []string{ + "auto", + "composer-2.5", + "composer-2", + }, + }, { Name: "mimo", DisplayName: "Xiaomi MiMo API", diff --git a/internal/llm/providers_cursor_test.go b/internal/llm/providers_cursor_test.go new file mode 100644 index 00000000..c085b6ad --- /dev/null +++ b/internal/llm/providers_cursor_test.go @@ -0,0 +1,25 @@ +package llm + +import "testing" + +func TestLookupProvider_CursorDetails(t *testing.T) { + p, ok := LookupProvider("cursor") + if !ok { + t.Fatal("cursor not found in registry") + } + if p.Protocol != "cursor" { + t.Errorf("Protocol = %q, want cursor", p.Protocol) + } + if p.EnvVar != "CURSOR_API_KEY" { + t.Errorf("EnvVar = %q, want CURSOR_API_KEY", p.EnvVar) + } + if !p.IsCursorAgent() { + t.Error("IsCursorAgent() = false, want true") + } + if p.BaseURL != "" { + t.Errorf("BaseURL = %q, want empty for cursor agent backend", p.BaseURL) + } + if len(p.Models) == 0 { + t.Fatal("expected cursor models") + } +} diff --git a/internal/llm/usage_from_map.go b/internal/llm/usage_from_map.go new file mode 100644 index 00000000..a33dccee --- /dev/null +++ b/internal/llm/usage_from_map.go @@ -0,0 +1,22 @@ +package llm + +import "encoding/json" + +// UsageFromMap extracts token usage from a loosely-typed map (e.g. Cursor turn-ended usage). +func UsageFromMap(m map[string]any) *UsageInfo { + if len(m) == 0 { + return nil + } + raw, err := json.Marshal(m) + if err != nil { + return nil + } + if ui := resolveUsage(raw); ui != nil { + return ui + } + raw, err = json.Marshal(map[string]any{"usage": m}) + if err != nil { + return nil + } + return resolveUsage(raw) +} diff --git a/internal/llm/usage_from_map_test.go b/internal/llm/usage_from_map_test.go new file mode 100644 index 00000000..8d19162a --- /dev/null +++ b/internal/llm/usage_from_map_test.go @@ -0,0 +1,36 @@ +package llm + +import "testing" + +func TestUsageFromMap_CursorCamelCase(t *testing.T) { + ui := UsageFromMap(map[string]any{ + "inputTokens": 100, + "outputTokens": 42, + }) + if ui == nil { + t.Fatal("expected usage") + } + if ui.PromptTokens != 100 { + t.Errorf("PromptTokens = %d", ui.PromptTokens) + } + if ui.CompletionTokens != 42 { + t.Errorf("CompletionTokens = %d", ui.CompletionTokens) + } + if ui.TotalTokens != 142 { + t.Errorf("TotalTokens = %d, want 142", ui.TotalTokens) + } +} + +func TestUsageFromMap_OpenAISnakeCase(t *testing.T) { + ui := UsageFromMap(map[string]any{ + "prompt_tokens": 50, + "completion_tokens": 10, + "total_tokens": 60, + }) + if ui == nil { + t.Fatal("expected usage") + } + if ui.TotalTokens != 60 { + t.Errorf("TotalTokens = %d", ui.TotalTokens) + } +} diff --git a/internal/llm/usage_resolver.go b/internal/llm/usage_resolver.go index 13454207..3bd2cae6 100644 --- a/internal/llm/usage_resolver.go +++ b/internal/llm/usage_resolver.go @@ -18,12 +18,16 @@ var promptTokensPaths = []string{ "usage.prompt_tokens", // OpenAI standard "prompt_tokens", // flat at root "data.usage.prompt_tokens", // wrapped in data layer + "usage.inputTokens", // Cursor / camelCase + "inputTokens", } var completionTokensPaths = []string{ "usage.completion_tokens", // OpenAI standard "completion_tokens", // flat at root "data.usage.completion_tokens", // wrapped in data layer + "usage.outputTokens", // Cursor / camelCase + "outputTokens", } var cacheReadTokensPaths = []string{ @@ -45,6 +49,8 @@ var totalTokensPaths = []string{ "usage.total_tokens", // OpenAI standard "total_tokens", // flat at root "data.usage.total_tokens", // wrapped in data layer (some proxy APIs) + "usage.totalTokens", // Cursor / camelCase + "totalTokens", } // resolveUsage parses raw JSON bytes into a map and extracts token usage From 9e8824102badafb4896796e4891ed72019e53b57 Mon Sep 17 00:00:00 2001 From: Mikhail Batukhtin <6481198+remdev@users.noreply.github.com> Date: Mon, 15 Jun 2026 16:24:20 +0300 Subject: [PATCH 3/6] feat(config): add Cursor setup guidance to provider TUI Show Cursor in the official preset list with SDK bridge hints during API key setup and after saving provider configuration. --- cmd/opencodereview/config_cmd_test.go | 28 ++++++++++++++ cmd/opencodereview/provider_cmd.go | 5 +++ cmd/opencodereview/provider_tui.go | 22 +++++++++++ cmd/opencodereview/provider_tui_test.go | 51 +++++++++++++++++++++++++ 4 files changed, 106 insertions(+) diff --git a/cmd/opencodereview/config_cmd_test.go b/cmd/opencodereview/config_cmd_test.go index 34b42344..e1bad354 100644 --- a/cmd/opencodereview/config_cmd_test.go +++ b/cmd/opencodereview/config_cmd_test.go @@ -126,6 +126,34 @@ func TestSetConfigValueProviderEntryExtraBody(t *testing.T) { } } +func TestSetConfigValueProviderCursor(t *testing.T) { + cfg := &Config{} + + if err := setConfigValue(cfg, "provider", "cursor"); err != nil { + t.Fatalf("setConfigValue provider: %v", err) + } + if cfg.Provider != "cursor" { + t.Errorf("Provider = %q, want cursor", cfg.Provider) + } + if cfg.Providers["cursor"].APIKey != "" { + t.Error("expected empty cursor provider entry after provider switch") + } + + if err := setConfigValue(cfg, "providers.cursor.api_key", "cursor-key"); err != nil { + t.Fatalf("setConfigValue api_key: %v", err) + } + if cfg.Providers["cursor"].APIKey != "cursor-key" { + t.Errorf("api_key = %q, want cursor-key", cfg.Providers["cursor"].APIKey) + } + + if err := setConfigValue(cfg, "providers.cursor.model", "composer-2.5"); err != nil { + t.Fatalf("setConfigValue model: %v", err) + } + if cfg.Providers["cursor"].Model != "composer-2.5" { + t.Errorf("model = %q, want composer-2.5", cfg.Providers["cursor"].Model) + } +} + func TestSetConfigValueModelWithCustomProvider(t *testing.T) { cfg := &Config{ Provider: "my-gateway", diff --git a/cmd/opencodereview/provider_cmd.go b/cmd/opencodereview/provider_cmd.go index e62a6b75..c204ddc6 100644 --- a/cmd/opencodereview/provider_cmd.go +++ b/cmd/opencodereview/provider_cmd.go @@ -171,6 +171,11 @@ func applyOfficialProviderConfig(configPath string, cfg *Config, result provider fmt.Printf("\nProvider set to: %s\n", result.provider) fmt.Printf("Model: %s\n", result.model) + if isPreset && preset.IsCursorAgent() { + fmt.Println("\nCursor backend requires the SDK bridge before the first review.") + fmt.Printf("Run: %s\n", cursorBridgeSetupHint) + } + fmt.Println("\nTesting connection...") if err := runLLMTest(); err != nil { fmt.Fprintf(os.Stderr, "Connection test failed: %v\n", err) diff --git a/cmd/opencodereview/provider_tui.go b/cmd/opencodereview/provider_tui.go index 4075360c..7777dbe6 100644 --- a/cmd/opencodereview/provider_tui.go +++ b/cmd/opencodereview/provider_tui.go @@ -903,6 +903,9 @@ func (m providerTUIModel) viewOfficialTab(s *strings.Builder) { s.WriteString(cursor + tuiItemStyle.Render(name)) } s.WriteString("\n") + if subtitle := providerOfficialSubtitle(p); subtitle != "" { + s.WriteString(" " + tuiDimStyle.Render(subtitle) + "\n") + } } } @@ -1109,6 +1112,13 @@ func (m providerTUIModel) viewAPIKey(s *strings.Builder) { if m.activeTab == tabOfficial { provider := m.currentProvider() + if provider.IsCursorAgent() { + s.WriteString("\n") + s.WriteString(tuiDimStyle.Render(" Requires Cursor SDK bridge (Node.js >= 18, npm).")) + s.WriteString("\n") + s.WriteString(tuiDimStyle.Render(" One-time setup: " + cursorBridgeSetupHint)) + s.WriteString("\n") + } if envKey := os.Getenv(provider.EnvVar); envKey != "" { s.WriteString("\n") s.WriteString(tuiDimStyle.Render(fmt.Sprintf(" $%s is set", provider.EnvVar))) @@ -1127,8 +1137,20 @@ func (m providerTUIModel) viewAPIKey(s *strings.Builder) { // --- Styles --- +const cursorBridgeSetupHint = "go run github.com/remdev/cursor-go-sdk/cmd/setup@latest" + const tuiCursor = "▸" +func providerOfficialSubtitle(p llm.Provider) string { + if p.IsCursorAgent() { + return "local agent via Cursor SDK (no HTTP endpoint)" + } + if p.BaseURL != "" { + return p.BaseURL + } + return "" +} + var ( tuiTitleStyle = lipgloss.NewStyle(). Bold(true). diff --git a/cmd/opencodereview/provider_tui_test.go b/cmd/opencodereview/provider_tui_test.go index 2dc6ab4a..96507fef 100644 --- a/cmd/opencodereview/provider_tui_test.go +++ b/cmd/opencodereview/provider_tui_test.go @@ -136,6 +136,57 @@ func TestProviderTUI_TabSwitchOnlyOnStepProvider(t *testing.T) { // --- Official tab tests (updated from original) --- +func TestProviderTUI_OfficialTabIncludesCursor(t *testing.T) { + m := newProviderTUI(&Config{}) + + found := false + for _, p := range m.providers { + if p.Name == "cursor" { + found = true + if !p.IsCursorAgent() { + t.Error("cursor preset should use Cursor agent backend") + } + if subtitle := providerOfficialSubtitle(p); subtitle == "" { + t.Error("expected subtitle for cursor preset") + } + break + } + } + if !found { + t.Fatal("cursor preset missing from official provider list") + } +} + +func TestProviderTUI_CursorProviderSelectsModels(t *testing.T) { + m := newProviderTUI(&Config{}) + idx := -1 + for i, p := range m.providers { + if p.Name == "cursor" { + idx = i + break + } + } + if idx < 0 { + t.Fatal("cursor preset not found") + } + + m.officialIdx = idx + result, _ := m.Update(enterKey()) + m2 := result.(providerTUIModel) + if m2.step != stepModel { + t.Fatalf("step = %d, want stepModel", m2.step) + } + + models := m2.models() + want := map[string]bool{"auto": true, "composer-2.5": true, "composer-2": true} + for _, model := range models { + delete(want, model) + } + if len(want) > 0 { + t.Errorf("missing cursor models: %v", want) + } +} + func TestProviderTUI_EscFromModelGoesBackToProvider(t *testing.T) { m := newProviderTUI(&Config{}) From 579782c69722b3a1d6a4e520818c2cb8fd383216 Mon Sep 17 00:00:00 2001 From: Mikhail Batukhtin <6481198+remdev@users.noreply.github.com> Date: Mon, 15 Jun 2026 16:24:20 +0300 Subject: [PATCH 4/6] chore: bump cursor-go-sdk to v0.0.2 Update provider list test to include cursor in the sorted registry. --- go.mod | 2 +- go.sum | 4 ++-- internal/llm/providers_test.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/go.mod b/go.mod index 1a2fab16..ef306195 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/bmatcuk/doublestar/v4 v4.10.0 github.com/openai/openai-go/v3 v3.39.0 github.com/pkoukk/tiktoken-go v0.1.8 - github.com/remdev/cursor-go-sdk v0.0.0-20260614191545-6e5b35765d3b + github.com/remdev/cursor-go-sdk v0.0.2 go.opentelemetry.io/otel v1.43.0 go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0 diff --git a/go.sum b/go.sum index 25fefa6a..5e459449 100644 --- a/go.sum +++ b/go.sum @@ -73,8 +73,8 @@ github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/Q github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/remdev/cursor-go-sdk v0.0.0-20260614191545-6e5b35765d3b h1:Ajn05OTxZ9VM7oVX2QAg4ADt+yOmy1iDifTtRiw4Q4Q= -github.com/remdev/cursor-go-sdk v0.0.0-20260614191545-6e5b35765d3b/go.mod h1:8+NZymOCljHsqHoor5ZU9o4kuwy573NuY7U8n6jajAY= +github.com/remdev/cursor-go-sdk v0.0.2 h1:2bB/CKYuDWvqImaMvnlv6IfrMBkSouPegVdf+cnVjeY= +github.com/remdev/cursor-go-sdk v0.0.2/go.mod h1:8+NZymOCljHsqHoor5ZU9o4kuwy573NuY7U8n6jajAY= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/standard-webhooks/standard-webhooks/libraries v0.0.1 h1:uOfcYT+3QungH6tIGSVCR/Y3KJmgJiHcojJbMTPDZAI= diff --git a/internal/llm/providers_test.go b/internal/llm/providers_test.go index d10d4e10..1ad3afc6 100644 --- a/internal/llm/providers_test.go +++ b/internal/llm/providers_test.go @@ -40,7 +40,7 @@ func TestListProviders_Order(t *testing.T) { if len(providers) < 3 { t.Fatalf("expected at least 3 providers, got %d", len(providers)) } - expected := []string{"anthropic", "dashscope", "deepseek", "kimi", "mimo", "minimax", "openai", "volcengine", "z-ai"} + expected := []string{"anthropic", "cursor", "dashscope", "deepseek", "kimi", "mimo", "minimax", "openai", "volcengine", "z-ai"} if len(providers) != len(expected) { t.Fatalf("expected %d providers, got %d", len(expected), len(providers)) } From c3ead9530d0dffc1d4490fa5247e49d18f573a81 Mon Sep 17 00:00:00 2001 From: Mikhail Batukhtin <6481198+remdev@users.noreply.github.com> Date: Mon, 15 Jun 2026 20:44:03 +0300 Subject: [PATCH 5/6] fix(cursor): harden review backend loop and address PR review feedback Improve Cursor agent tool replay, usage telemetry, and comment parsing; apply PR #140 suggestions for backend resolution and error paths. --- cmd/opencodereview/llm_cmd.go | 8 +- cmd/opencodereview/review_cmd.go | 5 +- go.mod | 2 +- go.sum | 4 +- internal/agent/agent.go | 6 + internal/llm/resolver.go | 17 +- internal/reviewbackend/backend.go | 4 + internal/reviewbackend/chat_completions.go | 10 +- internal/reviewbackend/cursor_agent.go | 224 +++++++++++--- internal/reviewbackend/cursor_agent_test.go | 4 +- internal/reviewbackend/cursor_prompt.go | 84 ++++++ internal/reviewbackend/cursor_prompt_test.go | 37 +++ internal/reviewbackend/cursor_text_tools.go | 276 ++++++++++++++++++ .../cursor_text_tools_integration_test.go | 52 ++++ .../reviewbackend/cursor_text_tools_test.go | 60 ++++ internal/reviewbackend/cursor_usage.go | 14 +- internal/reviewbackend/factory.go | 3 + internal/reviewbackend/resolver.go | 22 +- internal/tool/code_comment.go | 98 ++++++- internal/tool/code_comment_test.go | 28 ++ 20 files changed, 875 insertions(+), 83 deletions(-) create mode 100644 internal/reviewbackend/cursor_prompt.go create mode 100644 internal/reviewbackend/cursor_prompt_test.go create mode 100644 internal/reviewbackend/cursor_text_tools.go create mode 100644 internal/reviewbackend/cursor_text_tools_integration_test.go create mode 100644 internal/reviewbackend/cursor_text_tools_test.go create mode 100644 internal/tool/code_comment_test.go diff --git a/cmd/opencodereview/llm_cmd.go b/cmd/opencodereview/llm_cmd.go index 294ff9c2..3216311a 100644 --- a/cmd/opencodereview/llm_cmd.go +++ b/cmd/opencodereview/llm_cmd.go @@ -56,12 +56,8 @@ func runLLMTest() error { llmClient := reviewbackend.TextClient(backend) - model := resolved.Endpoint.Model - source := resolved.Endpoint.Source - if resolved.Kind == reviewbackend.KindCursorAgent { - model = resolved.Cursor.Model - source = resolved.Cursor.Source - } + model := backend.Model() + source := backend.Source() task, err := testconnection.LoadDefault() if err != nil { diff --git a/cmd/opencodereview/review_cmd.go b/cmd/opencodereview/review_cmd.go index 00ac043a..566a3316 100644 --- a/cmd/opencodereview/review_cmd.go +++ b/cmd/opencodereview/review_cmd.go @@ -99,10 +99,7 @@ func runReview(args []string) error { } llmClient := reviewbackend.TextClient(backend) - model := resolved.Endpoint.Model - if resolved.Kind == reviewbackend.KindCursorAgent { - model = resolved.Cursor.Model - } + model := backend.Model() gitRunner := gitcmd.New(opts.maxGitProcs) diff --git a/go.mod b/go.mod index ef306195..46b3f523 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/bmatcuk/doublestar/v4 v4.10.0 github.com/openai/openai-go/v3 v3.39.0 github.com/pkoukk/tiktoken-go v0.1.8 - github.com/remdev/cursor-go-sdk v0.0.2 + github.com/remdev/cursor-go-sdk v0.0.3 go.opentelemetry.io/otel v1.43.0 go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0 diff --git a/go.sum b/go.sum index 5e459449..c8bb6327 100644 --- a/go.sum +++ b/go.sum @@ -73,8 +73,8 @@ github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/Q github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/remdev/cursor-go-sdk v0.0.2 h1:2bB/CKYuDWvqImaMvnlv6IfrMBkSouPegVdf+cnVjeY= -github.com/remdev/cursor-go-sdk v0.0.2/go.mod h1:8+NZymOCljHsqHoor5ZU9o4kuwy573NuY7U8n6jajAY= +github.com/remdev/cursor-go-sdk v0.0.3 h1:iDkAip8KdmXg5OID1F9wNbufmInlIcGEE74+8BantFE= +github.com/remdev/cursor-go-sdk v0.0.3/go.mod h1:8+NZymOCljHsqHoor5ZU9o4kuwy573NuY7U8n6jajAY= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/standard-webhooks/standard-webhooks/libraries v0.0.1 h1:uOfcYT+3QungH6tIGSVCR/Y3KJmgJiHcojJbMTPDZAI= diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 9d45e3ac..80022fb0 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -960,10 +960,16 @@ func (a *Agent) performLlmCodeReview(ctx context.Context, messages []llm.Message }, } + toolsPrompt := "" + if a.args.Backend.Kind() == reviewbackend.KindCursorAgent { + toolsPrompt = reviewbackend.FormatCursorToolDefs(a.args.MainToolDefs) + } + return a.args.Backend.ReviewFile(ctx, reviewbackend.ReviewFileRequest{ Model: a.args.Model, Messages: messages, Tools: a.args.MainToolDefs, + ToolsPrompt: toolsPrompt, MaxTokens: a.args.Template.MaxTokens, MaxToolRounds: a.args.Template.MaxToolRequestTimes, FilePath: newPath, diff --git a/internal/llm/resolver.go b/internal/llm/resolver.go index 6bb1d05d..4f3eb1f6 100644 --- a/internal/llm/resolver.go +++ b/internal/llm/resolver.go @@ -139,17 +139,28 @@ func tryOCRConfig(path string) (ResolvedEndpoint, bool, error) { } return ResolvedEndpoint{}, false, err } + return TryOCRConfigBytes(data) +} +// TryOCRConfigBytes resolves an endpoint from already-read OCR config JSON. +func TryOCRConfigBytes(data []byte) (ResolvedEndpoint, bool, error) { var cfg configFile if err := json.Unmarshal(data, &cfg); err != nil { return ResolvedEndpoint{}, false, fmt.Errorf("parse config: %w", err) } + var ep ResolvedEndpoint + var ok bool + var err error if cfg.Provider != "" { - return tryProviderConfig(cfg) + ep, ok, err = tryProviderConfig(cfg) + } else { + ep, ok, err = tryLegacyLlmConfig(cfg) } - - return tryLegacyLlmConfig(cfg) + if ok { + ep.Model = stripModelSuffix(ep.Model) + } + return ep, ok, err } // tryProviderConfig resolves an endpoint from the provider-based configuration. diff --git a/internal/reviewbackend/backend.go b/internal/reviewbackend/backend.go index b358919b..09265a4e 100644 --- a/internal/reviewbackend/backend.go +++ b/internal/reviewbackend/backend.go @@ -38,6 +38,8 @@ type ReviewFileRequest struct { MaxTokens int MaxToolRounds int FilePath string + // ToolsPrompt is human-readable tool guidance for Cursor MCP custom tools. + ToolsPrompt string } // ToolCallInput is passed to the tool executor from any backend. @@ -58,6 +60,8 @@ type ToolExecutor func(ctx context.Context, call ToolCallInput) ToolCallOutput // ReviewHooks wires agent-level session, telemetry, and compression into a backend loop. type ReviewHooks struct { + // AppendTaskRecord must be called at the start of each review round, before exec + // invokes tools, so tool results are recorded on the active task record. AppendTaskRecord func(taskType session.TaskType, messages []llm.Message) *session.TaskRecord SetResponse func(rec *session.TaskRecord, resp *llm.ChatResponse, durationMs int64) SetError func(rec *session.TaskRecord, err error, durationMs int64) diff --git a/internal/reviewbackend/chat_completions.go b/internal/reviewbackend/chat_completions.go index cd7dbd70..c38712be 100644 --- a/internal/reviewbackend/chat_completions.go +++ b/internal/reviewbackend/chat_completions.go @@ -34,13 +34,11 @@ func (b *ChatCompletionsBackend) Complete(ctx context.Context, req CompleteReque if model == "" { model = b.ep.Model } - start := time.Now() resp, err := b.client.CompletionsWithCtx(ctx, llm.ChatRequest{ Model: model, Messages: req.Messages, MaxTokens: req.MaxTokens, }) - _ = start if err != nil { return nil, err } @@ -132,9 +130,13 @@ func (b *ChatCompletionsBackend) ReviewFile(ctx context.Context, req ReviewFileR if len(calls) == 0 { logf("[ocr] No tool calls parsed for %s, retrying...\n", req.FilePath) - messages = append(messages, llm.NewTextMessage("user", "You did not successfully call any tools. Please try again or use task_done if finished.")) if content != "" { - messages = append(messages[:len(messages)-1], llm.NewTextMessage("assistant", content), messages[len(messages)-1]) + messages = append(messages, + llm.NewTextMessage("assistant", content), + llm.NewTextMessage("user", "You did not successfully call any tools. Please try again or use task_done if finished."), + ) + } else { + messages = append(messages, llm.NewTextMessage("user", "You did not successfully call any tools. Please try again or use task_done if finished.")) } continue } diff --git a/internal/reviewbackend/cursor_agent.go b/internal/reviewbackend/cursor_agent.go index 349173e1..a6132a66 100644 --- a/internal/reviewbackend/cursor_agent.go +++ b/internal/reviewbackend/cursor_agent.go @@ -14,7 +14,7 @@ import ( ) const bridgeSetupHint = "install Cursor bridge: go run github.com/remdev/cursor-go-sdk/cmd/setup@latest " + - "or npm install -g @cursor-go-sdk/cursor-sdk-bridge@0.0.2" + "or npm install -g @cursor-go-sdk/cursor-sdk-bridge@0.0.3" // CursorAgentBackend runs review via Cursor Agent SDK local runtime and custom tools. type CursorAgentBackend struct { @@ -51,7 +51,7 @@ func (b *CursorAgentBackend) Complete(ctx context.Context, req CompleteRequest) model = b.cfg.Model } - agent, err := cursor.CreateAgent(ctx, b.agentOptions(model)) + agent, err := cursor.CreateAgent(ctx, b.agentOptions(model, nil)) if err != nil { return nil, wrapCursorError(err) } @@ -88,15 +88,40 @@ func (b *CursorAgentBackend) ReviewFile(ctx context.Context, req ReviewFileReque if hooks == nil { hooks = &ReviewHooks{} } + logf := hooks.Logf + if logf == nil { + logf = func(string, ...any) {} + } model := req.Model if model == "" { model = b.cfg.Model } - customTools := toolDefsToCustomTools(ctx, req.Tools, exec) + tracker := &cursorReviewTracker{} + mcpExec := func(callCtx context.Context, call ToolCallInput) ToolCallOutput { + tracker.markTool(call.Name) + if call.Name == "code_comment" { + tracker.mcpCodeComment = true + } + out := exec(callCtx, call) + if out.Completed { + tracker.taskDone = true + } + return out + } + replayExec := func(callCtx context.Context, call ToolCallInput) ToolCallOutput { + tracker.markTool(call.Name) + out := exec(callCtx, call) + if out.Completed { + tracker.taskDone = true + } + return out + } + + customTools := toolDefsToCustomTools(ctx, req.Tools, req.FilePath, mcpExec) - agent, err := cursor.CreateAgent(ctx, b.agentOptions(model)) + agent, err := cursor.CreateAgent(ctx, b.agentOptions(model, customTools)) if err != nil { return wrapCursorError(err) } @@ -106,58 +131,163 @@ func (b *CursorAgentBackend) ReviewFile(ctx context.Context, req ReviewFileReque _ = agent.Close(closeCtx) }() - var rec *session.TaskRecord - if hooks.AppendTaskRecord != nil { - rec = hooks.AppendTaskRecord(session.MainTask, append([]llm.Message(nil), req.Messages...)) + maxRounds := req.MaxToolRounds + if maxRounds <= 0 { + maxRounds = 1 } - start := time.Now() - usageAcc := &cursorUsageAccumulator{} - run, err := agent.Send(ctx, messagesToPrompt(req.Messages), cursor.SendOptions{ - Local: &cursor.LocalSendOptions{ - CustomTools: customTools, - }, - OnDelta: usageAcc.callback(), - }) - if err != nil { - if hooks.SetError != nil && rec != nil { - hooks.SetError(rec, err, time.Since(start).Milliseconds()) + prompt := buildCursorReviewPrompt(req.Messages, req.ToolsPrompt) + basePrompt := prompt + consecutiveNoTools := 0 + + for round := 0; round < maxRounds; round++ { + select { + case <-ctx.Done(): + return ctx.Err() + default: } - return wrapCursorError(err) - } - result, err := run.Wait(ctx) - duration := time.Since(start) - if err != nil { - if hooks.SetError != nil && rec != nil { - hooks.SetError(rec, err, duration.Milliseconds()) + usageAcc := &cursorUsageAccumulator{} + sendOpts := cursor.SendOptions{ + Mode: cursor.AgentModeAgent, + OnDelta: usageAcc.callback(), } - return wrapCursorError(err) + + var rec *session.TaskRecord + if hooks.AppendTaskRecord != nil { + msgs := req.Messages + if round > 0 { + msgs = []llm.Message{llm.NewTextMessage("user", prompt)} + } + rec = hooks.AppendTaskRecord(session.MainTask, append([]llm.Message(nil), msgs...)) + } + + start := time.Now() + run, err := agent.Send(ctx, prompt, sendOpts) + if err != nil { + durationMs := time.Since(start).Milliseconds() + if hooks.SetError != nil && rec != nil { + hooks.SetError(rec, err, durationMs) + } + recordCursorRoundMetrics(hooks, usageAcc, durationMs, "error") + return wrapCursorError(err) + } + + result, err := run.Wait(ctx) + durationMs := time.Since(start).Milliseconds() + if err != nil { + if hooks.SetError != nil && rec != nil { + hooks.SetError(rec, err, durationMs) + } + recordCursorRoundMetrics(hooks, usageAcc, durationMs, "error") + return wrapCursorError(err) + } + + if hooks.SetResponse != nil && rec != nil { + content := result.Result + hooks.SetResponse(rec, &llm.ChatResponse{ + Model: model, + Choices: []llm.Choice{{ + Message: llm.ResponseMessage{Role: "assistant", Content: &content}, + FinishReason: string(result.Status), + }}, + Usage: usageAcc.usage(), + }, durationMs) + } + + if err := cursorRunStatusError(result); err != nil { + if hooks.SetError != nil && rec != nil { + hooks.SetError(rec, err, durationMs) + } + recordCursorRoundMetrics(hooks, usageAcc, durationMs, "error") + return err + } + recordCursorRoundMetrics(hooks, usageAcc, durationMs, string(result.Status)) + + if result.Result != "" { + replayCursorTextToolCalls(ctx, result.Result, req.FilePath, replayExec, logf, tracker.mcpCodeComment) + } + + if tracker.taskDone { + return nil + } + + invokedTools := tracker.roundToolInvoked + tracker.roundToolInvoked = false + tracker.mcpCodeComment = false + + if !invokedTools { + consecutiveNoTools++ + if consecutiveNoTools >= 3 { + logf("[ocr] Cursor agent did not call tools for %s after %d attempts, stopping.\n", req.FilePath, consecutiveNoTools) + break + } + logf("[ocr] Cursor agent replied without tools for %s, retrying...\n", req.FilePath) + prompt = basePrompt + "\n\n" + cursorReviewNudge(tracker, true) + continue + } + + consecutiveNoTools = 0 + if tracker.commentCalls > 0 { + logf("[ocr] Cursor agent left comments for %s without task_done, retrying...\n", req.FilePath) + } else { + logf("[ocr] Cursor agent called tools for %s without code_comment, retrying...\n", req.FilePath) + } + prompt = basePrompt + "\n\n" + cursorReviewNudge(tracker, false) } - if hooks.SetResponse != nil && rec != nil { - content := result.Result - hooks.SetResponse(rec, &llm.ChatResponse{ - Model: model, - Choices: []llm.Choice{{ - Message: llm.ResponseMessage{Role: "assistant", Content: &content}, - FinishReason: string(result.Status), - }}, - Usage: usageAcc.usage(), - }, duration.Milliseconds()) + if tracker.taskDone { + return nil + } + if tracker.commentCalls > 0 { + logf("[ocr] Cursor review for %s finished with %d comment(s) without task_done.\n", req.FilePath, tracker.commentCalls) + return nil } - usage := usageAcc.usage() - if usage != nil && hooks.RecordUsage != nil { + logf("[ocr] Max tool requests reached for %s without review comments.\n", req.FilePath) + return fmt.Errorf("cursor review incomplete for %s", req.FilePath) +} + +type cursorReviewTracker struct { + taskDone bool + commentCalls int + roundToolInvoked bool + mcpCodeComment bool +} + +func (t *cursorReviewTracker) markTool(name string) { + t.roundToolInvoked = true + if name == "code_comment" { + t.commentCalls++ + } +} + +func recordCursorRoundMetrics(hooks *ReviewHooks, usageAcc *cursorUsageAccumulator, durationMs int64, status string) { + if hooks == nil { + return + } + if usage := usageAcc.usage(); usage != nil && hooks.RecordUsage != nil { hooks.RecordUsage(usage) } totalTokens := int64(0) - if usage != nil { + if usage := usageAcc.usage(); usage != nil { totalTokens = usage.TotalTokens } if hooks.RecordLLMRequest != nil { - hooks.RecordLLMRequest(duration.Milliseconds(), totalTokens, string(result.Status)) + hooks.RecordLLMRequest(durationMs, totalTokens, status) } +} +func cursorReviewNudge(tracker *cursorReviewTracker, noTools bool) string { + if noTools { + return "You must not reply with a markdown review. For each confirmed issue in the diff, call the code_comment tool with structured comments (path, start_line, end_line, content). When finished, call task_done." + } + if tracker.commentCalls > 0 { + return "Comments received. Call task_done if the review is complete, or call code_comment for any remaining issues." + } + return "Call code_comment for each confirmed issue, then call task_done when finished." +} + +func cursorRunStatusError(result cursor.RunResult) error { switch result.Status { case cursor.RunStatusFinished: return nil @@ -175,21 +305,24 @@ func (b *CursorAgentBackend) ReviewFile(ctx context.Context, req ReviewFileReque } } -func (b *CursorAgentBackend) agentOptions(model string) cursor.AgentOptions { - sandboxEnabled := true +func (b *CursorAgentBackend) agentOptions(model string, customTools map[string]cursor.CustomTool) cursor.AgentOptions { + // Sandbox is disabled so MCP custom-user-tools can run. Operators should treat + // the local Cursor agent as trusted code with full repo access (see Cursor SDK docs). + sandboxEnabled := false return cursor.AgentOptions{ Model: model, APIKey: b.cfg.APIKey, + Mode: cursor.AgentModeAgent, Local: &cursor.LocalAgentOptions{ CWD: []string{b.repoDir}, SettingSources: nil, SandboxOptions: &cursor.SandboxOptions{Enabled: &sandboxEnabled}, - CustomTools: nil, + CustomTools: customTools, }, } } -func toolDefsToCustomTools(ctx context.Context, defs []llm.ToolDef, exec ToolExecutor) map[string]cursor.CustomTool { +func toolDefsToCustomTools(ctx context.Context, defs []llm.ToolDef, filePath string, exec ToolExecutor) map[string]cursor.CustomTool { if len(defs) == 0 { return nil } @@ -201,6 +334,9 @@ func toolDefsToCustomTools(ctx context.Context, defs []llm.ToolDef, exec ToolExe Description: fn.Description, InputSchema: fn.Parameters, Execute: func(args map[string]any, tctx cursor.CustomToolContext) (any, error) { + if name == "code_comment" { + args = normalizeCodeCommentArgs(args, filePath) + } raw, err := json.Marshal(args) if err != nil { return nil, err diff --git a/internal/reviewbackend/cursor_agent_test.go b/internal/reviewbackend/cursor_agent_test.go index 84844c87..5b3e66a1 100644 --- a/internal/reviewbackend/cursor_agent_test.go +++ b/internal/reviewbackend/cursor_agent_test.go @@ -36,7 +36,7 @@ func TestToolDefsToCustomTools_Executor(t *testing.T) { Description: "Leave a comment", Parameters: map[string]any{"type": "object"}, }, - }}, func(_ context.Context, call ToolCallInput) ToolCallOutput { + }}, "foo.go", func(_ context.Context, call ToolCallInput) ToolCallOutput { gotName = call.Name gotArgs = call.Arguments return ToolCallOutput{Result: `{"ok":true}`} @@ -72,7 +72,7 @@ func TestToolDefsToCustomTools_TaskDone(t *testing.T) { Name: "task_done", Description: "Finish review", }, - }}, func(_ context.Context, _ ToolCallInput) ToolCallOutput { + }}, "foo.go", func(_ context.Context, _ ToolCallInput) ToolCallOutput { return ToolCallOutput{Completed: true} }) diff --git a/internal/reviewbackend/cursor_prompt.go b/internal/reviewbackend/cursor_prompt.go new file mode 100644 index 00000000..37d761ac --- /dev/null +++ b/internal/reviewbackend/cursor_prompt.go @@ -0,0 +1,84 @@ +package reviewbackend + +import ( + "fmt" + "sort" + "strings" + + "github.com/open-code-review/open-code-review/internal/llm" +) + +// buildCursorReviewPrompt formats OCR messages for Cursor agent.send. +// Custom tools are registered as MCP server "custom-user-tools"; the prompt must +// tell the model to call them instead of emitting a markdown review. +func buildCursorReviewPrompt(msgs []llm.Message, toolsPrompt string) string { + var sb strings.Builder + sb.WriteString(messagesToPrompt(msgs)) + sb.WriteString("\n\n## OCR review tools (MCP: custom-user-tools)\n") + sb.WriteString("Do not write a markdown code review. Use the tools below:\n") + sb.WriteString("- Call `code_comment` for each confirmed issue (use the comments[] schema with existing_code from the diff).\n") + sb.WriteString("- Call `task_done` when finished.\n") + sb.WriteString("- Use context tools only when needed to confirm an issue in the current diff.\n") + if toolsPrompt != "" { + sb.WriteString("\n") + sb.WriteString(toolsPrompt) + } + return strings.TrimSpace(sb.String()) +} + +// normalizeCursorToolName strips MCP server prefixes (e.g. custom-user-tools/code_comment). +func normalizeCursorToolName(name string) string { + if i := strings.LastIndex(name, "/"); i >= 0 { + name = name[i+1:] + } + if i := strings.LastIndex(name, "__"); i >= 0 { + name = name[i+2:] + } + return name +} + +// FormatCursorToolDefs renders main-task tool definitions for Cursor MCP prompts. +func FormatCursorToolDefs(defs []llm.ToolDef) string { + return formatCursorToolDefs(defs) +} + +func formatCursorToolDefs(defs []llm.ToolDef) string { + if len(defs) == 0 { + return "" + } + var sb strings.Builder + for _, td := range defs { + fn := &td.Function + sb.WriteString(fmt.Sprintf("- **%s**: %s\n", fn.Name, fn.Description)) + if params, ok := fn.Parameters["properties"].(map[string]any); ok && len(params) > 0 { + sb.WriteString(" Parameters:\n") + required := make(map[string]bool) + if reqList, ok := fn.Parameters["required"].([]any); ok { + for _, r := range reqList { + if s, ok := r.(string); ok { + required[s] = true + } + } + } + names := make([]string, 0, len(params)) + for name := range params { + names = append(names, name) + } + sort.Strings(names) + for _, name := range names { + p := params[name] + suffix := "" + if required[name] { + suffix = " (required)" + } + if pm, ok := p.(map[string]any); ok { + desc, _ := pm["description"].(string) + sb.WriteString(fmt.Sprintf(" - %s: %s%s\n", name, desc, suffix)) + } else { + sb.WriteString(fmt.Sprintf(" - %s%s\n", name, suffix)) + } + } + } + } + return sb.String() +} diff --git a/internal/reviewbackend/cursor_prompt_test.go b/internal/reviewbackend/cursor_prompt_test.go new file mode 100644 index 00000000..5bb6458d --- /dev/null +++ b/internal/reviewbackend/cursor_prompt_test.go @@ -0,0 +1,37 @@ +package reviewbackend + +import ( + "strings" + "testing" + + "github.com/open-code-review/open-code-review/internal/llm" +) + +func TestBuildCursorReviewPrompt_IncludesTools(t *testing.T) { + prompt := buildCursorReviewPrompt([]llm.Message{ + llm.NewTextMessage("user", "Review this diff."), + }, "- **code_comment**: leave feedback\n") + + if !strings.Contains(prompt, "custom-user-tools") { + t.Fatalf("missing MCP hint: %q", prompt) + } + if !strings.Contains(prompt, "code_comment") { + t.Fatalf("missing tool guidance: %q", prompt) + } + if !strings.Contains(prompt, "Review this diff.") { + t.Fatalf("missing user content: %q", prompt) + } +} + +func TestNormalizeCursorToolName(t *testing.T) { + cases := map[string]string{ + "code_comment": "code_comment", + "custom-user-tools/code_comment": "code_comment", + "custom-user-tools__code_comment": "code_comment", + } + for in, want := range cases { + if got := normalizeCursorToolName(in); got != want { + t.Fatalf("%q => %q, want %q", in, got, want) + } + } +} diff --git a/internal/reviewbackend/cursor_text_tools.go b/internal/reviewbackend/cursor_text_tools.go new file mode 100644 index 00000000..5cc92375 --- /dev/null +++ b/internal/reviewbackend/cursor_text_tools.go @@ -0,0 +1,276 @@ +package reviewbackend + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" +) + +const maxCursorTextReplayCalls = 20 + +var cursorReplayAllowedTools = map[string]bool{ + "code_comment": true, + "task_done": true, +} + +func replayCursorTextToolCalls(ctx context.Context, text, filePath string, exec ToolExecutor, logf func(string, ...any), skipMCPCodeComment bool) { + objects := extractJSONObjectStrings(text) + codeCommentCalls := 0 + var deferredDone *ToolCallInput + + for i, obj := range objects { + select { + case <-ctx.Done(): + return + default: + } + + var payload map[string]any + if err := json.Unmarshal([]byte(obj), &payload); err != nil { + logf("[ocr] cursor text replay: skip invalid JSON: %v\n", err) + continue + } + name := toolNameFromPayload(payload) + if name == "" || !cursorReplayAllowedTools[name] { + continue + } + + args := toolArgsFromPayload(payload) + if name == "code_comment" { + if skipMCPCodeComment { + continue + } + if codeCommentCalls >= maxCursorTextReplayCalls { + logf("[ocr] cursor text replay: skipping code_comment beyond limit %d\n", maxCursorTextReplayCalls) + continue + } + args = normalizeCodeCommentArgs(args, filePath) + if _, ok := args["comments"].([]any); !ok { + continue + } + codeCommentCalls++ + } + + raw, err := json.Marshal(args) + if err != nil { + logf("[ocr] cursor text replay: marshal %s args: %v\n", name, err) + continue + } + + call := ToolCallInput{ + ID: fmt.Sprintf("cursor-text-replay-%d", i), + Name: name, + Arguments: string(raw), + } + if name == "task_done" { + deferred := call + deferredDone = &deferred + continue + } + + exec(ctx, call) + } + + if deferredDone != nil { + out := exec(ctx, *deferredDone) + if out.Completed { + return + } + } +} + +func toolNameFromPayload(payload map[string]any) string { + if name, ok := payload["tool"].(string); ok && name != "" { + return normalizeCursorToolName(name) + } + if name, ok := payload["name"].(string); ok && name != "" { + return normalizeCursorToolName(name) + } + return "" +} + +func toolArgsFromPayload(payload map[string]any) map[string]any { + if args, ok := payload["arguments"].(map[string]any); ok { + return args + } + if raw, ok := payload["arguments"].(string); ok && raw != "" { + var args map[string]any + if err := json.Unmarshal([]byte(raw), &args); err == nil { + return args + } + } + out := cloneMap(payload) + delete(out, "tool") + delete(out, "name") + delete(out, "state") + return out +} + +func normalizeCodeCommentArgs(args map[string]any, defaultPath string) map[string]any { + args = cloneMap(args) + if comments, ok := args["comments"].([]any); ok && len(comments) > 0 { + scopeCommentPaths(args, defaultPath) + return args + } + + path, _ := args["path"].(string) + if path == "" { + path = defaultPath + } + content, _ := args["content"].(string) + if content == "" { + return args + } + + comment := map[string]any{"content": content} + if existing, ok := args["existing_code"].(string); ok && existing != "" { + comment["existing_code"] = existing + } else if startLine, ok := args["start_line"]; ok { + comment["existing_code"] = lineAnchor(startLine) + } else { + comment["existing_code"] = content + } + if suggestion, ok := args["suggestion_code"].(string); ok && suggestion != "" { + comment["suggestion_code"] = suggestion + } + if start, ok := intFromAny(args["start_line"]); ok { + comment["start_line"] = start + comment["end_line"] = start + } + if end, ok := intFromAny(args["end_line"]); ok { + if start, hasStart := intFromAny(args["start_line"]); hasStart { + comment["start_line"] = start + } + comment["end_line"] = end + } + + out := map[string]any{ + "path": path, + "comments": []any{comment}, + } + return out +} + +func scopeCommentPaths(args map[string]any, defaultPath string) { + if defaultPath == "" { + return + } + if p, ok := args["path"].(string); !ok || p == "" { + args["path"] = defaultPath + } + comments, ok := args["comments"].([]any) + if !ok { + return + } + cloned := make([]any, len(comments)) + for i, raw := range comments { + item, ok := raw.(map[string]any) + if !ok { + cloned[i] = raw + continue + } + copyItem := cloneMap(item) + if p, ok := copyItem["path"].(string); !ok || p == "" { + copyItem["path"] = defaultPath + } + cloned[i] = copyItem + } + args["comments"] = cloned +} + +func cloneMap(in map[string]any) map[string]any { + if in == nil { + return nil + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func lineAnchor(line any) string { + if n, ok := intFromAny(line); ok { + return "+// line " + strconv.Itoa(n) + } + return "+// review anchor" +} + +func intFromAny(v any) (int, bool) { + switch t := v.(type) { + case int: + if t <= 0 { + return 0, false + } + return t, true + case int64: + if t <= 0 { + return 0, false + } + return int(t), true + case float64: + if t <= 0 { + return 0, false + } + return int(t), true + case json.Number: + i, err := t.Int64() + if err != nil || i <= 0 { + return 0, false + } + return int(i), true + case string: + i, err := strconv.Atoi(strings.TrimSpace(t)) + if err != nil || i <= 0 { + return 0, false + } + return i, true + default: + return 0, false + } +} + +func extractJSONObjectStrings(text string) []string { + var out []string + depth := 0 + start := -1 + inString := false + escaped := false + for i, ch := range text { + if inString { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == '"' { + inString = false + } + continue + } + switch ch { + case '"': + inString = true + case '{': + if depth == 0 { + start = i + } + depth++ + case '}': + if depth == 0 { + continue + } + depth-- + if depth == 0 && start >= 0 { + out = append(out, text[start:i+1]) + start = -1 + } + } + } + return out +} diff --git a/internal/reviewbackend/cursor_text_tools_integration_test.go b/internal/reviewbackend/cursor_text_tools_integration_test.go new file mode 100644 index 00000000..f4977ae5 --- /dev/null +++ b/internal/reviewbackend/cursor_text_tools_integration_test.go @@ -0,0 +1,52 @@ +package reviewbackend + +import ( + "context" + "encoding/json" + "testing" + + "github.com/open-code-review/open-code-review/internal/tool" +) + +func TestReplayCursorTextToolCalls_CollectsComments(t *testing.T) { + collector := tool.NewCommentCollector() + provider := &tool.CodeCommentProvider{Collector: collector} + tracker := &cursorReviewTracker{} + baseExec := func(_ context.Context, call ToolCallInput) ToolCallOutput { + var args map[string]any + if err := json.Unmarshal([]byte(call.Arguments), &args); err != nil { + t.Fatalf("unmarshal args: %v", err) + } + result, err := provider.Execute(context.Background(), args) + if err != nil { + t.Fatalf("execute: %v", err) + } + if call.Name == "task_done" { + return ToolCallOutput{Completed: true, Result: result} + } + return ToolCallOutput{Result: result} + } + exec := func(ctx context.Context, call ToolCallInput) ToolCallOutput { + tracker.markTool(call.Name) + out := baseExec(ctx, call) + if out.Completed { + tracker.taskDone = true + } + return out + } + + text := `{"tool":"code_comment","arguments":{"path":"internal/reviewbackend/review_probe.go","start_line":12,"end_line":18,"content":"Hardcoded API key fallback"}} +{"name":"task_done","arguments":{"state":"DONE"}}` + + replayCursorTextToolCalls(context.Background(), text, "internal/reviewbackend/review_probe.go", exec, func(string, ...any) {}, false) + + if len(collector.Comments()) == 0 { + t.Fatal("expected comments in collector") + } + if collector.Comments()[0].Content != "Hardcoded API key fallback" { + t.Fatalf("unexpected content: %q", collector.Comments()[0].Content) + } + if !tracker.taskDone { + t.Fatal("expected task_done via replay") + } +} diff --git a/internal/reviewbackend/cursor_text_tools_test.go b/internal/reviewbackend/cursor_text_tools_test.go new file mode 100644 index 00000000..36c69518 --- /dev/null +++ b/internal/reviewbackend/cursor_text_tools_test.go @@ -0,0 +1,60 @@ +package reviewbackend + +import ( + "context" + "strings" + "testing" +) + +func TestExtractJSONObjectStrings(t *testing.T) { + text := `prefix {"tool":"code_comment","arguments":{"content":"issue"}} suffix {"name":"task_done","arguments":{"state":"DONE"}}` + objs := extractJSONObjectStrings(text) + if len(objs) != 2 { + t.Fatalf("expected 2 objects, got %d", len(objs)) + } + if !strings.Contains(objs[0], "code_comment") { + t.Fatalf("unexpected first object: %s", objs[0]) + } +} + +func TestNormalizeCodeCommentArgs_FlatFormat(t *testing.T) { + args := normalizeCodeCommentArgs(map[string]any{ + "path": "foo.go", + "start_line": float64(12), + "content": "SQL injection risk", + }, "default.go") + + if args["path"] != "foo.go" { + t.Fatalf("path = %v", args["path"]) + } + comments, ok := args["comments"].([]any) + if !ok || len(comments) != 1 { + t.Fatalf("comments = %T %#v", args["comments"], args["comments"]) + } + item := comments[0].(map[string]any) + if item["content"] != "SQL injection risk" { + t.Fatalf("content = %v", item["content"]) + } +} + +func TestReplayCursorTextToolCalls(t *testing.T) { + var calls []string + text := `{"tool":"code_comment","arguments":{"path":"foo.go","start_line":1,"content":"bad"}}` + tracker := &cursorReviewTracker{} + exec := func(_ context.Context, call ToolCallInput) ToolCallOutput { + calls = append(calls, call.Name) + tracker.markTool(call.Name) + if call.Name == "task_done" { + return ToolCallOutput{Completed: true} + } + return ToolCallOutput{Result: "ok"} + } + replayCursorTextToolCalls(context.Background(), text, "foo.go", exec, func(string, ...any) {}, false) + + if len(calls) != 1 || calls[0] != "code_comment" { + t.Fatalf("calls = %v", calls) + } + if tracker.commentCalls != 1 { + t.Fatalf("commentCalls = %d", tracker.commentCalls) + } +} diff --git a/internal/reviewbackend/cursor_usage.go b/internal/reviewbackend/cursor_usage.go index 1c05b68b..b39925b9 100644 --- a/internal/reviewbackend/cursor_usage.go +++ b/internal/reviewbackend/cursor_usage.go @@ -30,8 +30,11 @@ func (a *cursorUsageAccumulator) observe(update cursor.InteractionUpdate) { a.cacheWrite = 0 a.hasTurnUsage = true } + a.prompt += ui.PromptTokens + a.completion += ui.CompletionTokens + a.cacheRead += ui.CacheReadTokens + a.cacheWrite += ui.CacheWriteTokens a.mu.Unlock() - a.merge(ui) } case "token-delta": if update.Tokens > 0 { @@ -44,15 +47,6 @@ func (a *cursorUsageAccumulator) observe(update cursor.InteractionUpdate) { } } -func (a *cursorUsageAccumulator) merge(ui *llm.UsageInfo) { - a.mu.Lock() - defer a.mu.Unlock() - a.prompt += ui.PromptTokens - a.completion += ui.CompletionTokens - a.cacheRead += ui.CacheReadTokens - a.cacheWrite += ui.CacheWriteTokens -} - func (a *cursorUsageAccumulator) usage() *llm.UsageInfo { a.mu.Lock() defer a.mu.Unlock() diff --git a/internal/reviewbackend/factory.go b/internal/reviewbackend/factory.go index 53e92e77..91348562 100644 --- a/internal/reviewbackend/factory.go +++ b/internal/reviewbackend/factory.go @@ -32,6 +32,9 @@ type completeAdapter struct { } func (a *completeAdapter) CompletionsWithCtx(ctx context.Context, req llm.ChatRequest) (*llm.ChatResponse, error) { + if len(req.Tools) > 0 { + return nil, fmt.Errorf("completeAdapter does not support tools; use Backend.ReviewFile instead") + } model := req.Model if model == "" { model = a.backend.Model() diff --git a/internal/reviewbackend/resolver.go b/internal/reviewbackend/resolver.go index 7b6ec697..40b51bee 100644 --- a/internal/reviewbackend/resolver.go +++ b/internal/reviewbackend/resolver.go @@ -25,7 +25,7 @@ type ResolvedBackend struct { // ResolveBackend reads OCR config and returns the appropriate backend kind. func ResolveBackend(configPath string) (ResolvedBackend, error) { - cfg, err := readConfigFile(configPath) + data, cfg, err := readConfigBytes(configPath) if err != nil { return ResolvedBackend{}, err } @@ -33,6 +33,16 @@ func ResolveBackend(configPath string) (ResolvedBackend, error) { return resolveCursorProvider(cfg) } + if len(data) > 0 { + ep, ok, err := llm.TryOCRConfigBytes(data) + if err != nil { + return ResolvedBackend{}, err + } + if ok { + return ResolvedBackend{Kind: KindChatCompletions, Endpoint: ep}, nil + } + } + ep, err := llm.ResolveEndpoint(configPath) if err != nil { return ResolvedBackend{}, err @@ -51,19 +61,19 @@ type configFile struct { Providers map[string]providerEntry `json:"providers,omitempty"` } -func readConfigFile(path string) (*configFile, error) { +func readConfigBytes(path string) ([]byte, *configFile, error) { data, err := os.ReadFile(path) if err != nil { if os.IsNotExist(err) { - return nil, nil + return nil, nil, nil } - return nil, err + return nil, nil, err } var cfg configFile if err := json.Unmarshal(data, &cfg); err != nil { - return nil, fmt.Errorf("parse config: %w", err) + return nil, nil, fmt.Errorf("parse config: %w", err) } - return &cfg, nil + return data, &cfg, nil } func resolveCursorProvider(cfg *configFile) (ResolvedBackend, error) { diff --git a/internal/tool/code_comment.go b/internal/tool/code_comment.go index 5fda7f6f..f495645c 100644 --- a/internal/tool/code_comment.go +++ b/internal/tool/code_comment.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "fmt" + "strconv" + "strings" "github.com/open-code-review/open-code-review/internal/model" ) @@ -43,10 +45,47 @@ func ParseComments(args map[string]any) ([]model.LlmComment, string) { } } if len(rawComments) == 0 { + if content, ok := args["content"].(string); ok && content != "" { + cm := model.LlmComment{Content: content} + if path, ok := args["path"].(string); ok { + cm.Path = path + } + if start, ok := toInt(args["start_line"]); ok { + cm.StartLine = start + cm.EndLine = start + } + if end, ok := toInt(args["end_line"]); ok { + if cm.StartLine > 0 { + cm.EndLine = end + } else { + cm.EndLine = end + cm.StartLine = end + } + } + if existing, ok := args["existing_code"].(string); ok { + cm.ExistingCode = existing + } + if suggestion, ok := args["suggestion_code"].(string); ok { + cm.SuggestionCode = suggestion + } + if thinking, ok := args["thinking"].(string); ok { + cm.Thinking = thinking + } + normalizeCommentLines(&cm) + if cm.Content != "" && cm.Path != "" { + return []model.LlmComment{cm}, "" + } + if content != "" && cm.Path == "" { + return nil, "Error: flat code_comment requires path" + } + } raw, _ := json.Marshal(args) return nil, fmt.Sprintf("Error: 'comments' array is required. Got args: %s", string(raw)) } + topPath, _ := args["path"].(string) + topStart, hasTopStart := toInt(args["start_line"]) + topEnd, hasTopEnd := toInt(args["end_line"]) var comments []model.LlmComment for _, raw := range rawComments { obj, ok := raw.(map[string]any) @@ -68,15 +107,72 @@ func ParseComments(args map[string]any) ([]model.LlmComment, string) { if thinking, ok := obj["thinking"].(string); ok { cm.Thinking = thinking } - if path, ok := args["path"].(string); ok { + if path, ok := obj["path"].(string); ok && path != "" { cm.Path = path + } else if topPath != "" { + cm.Path = topPath + } + if start, ok := toInt(obj["start_line"]); ok { + cm.StartLine = start + if cm.EndLine == 0 { + cm.EndLine = start + } + } else if hasTopStart { + cm.StartLine = topStart + cm.EndLine = topStart + } + if end, ok := toInt(obj["end_line"]); ok { + if cm.StartLine > 0 { + cm.EndLine = end + } else { + cm.StartLine = end + cm.EndLine = end + } + } else if hasTopEnd && cm.StartLine > 0 { + cm.EndLine = topEnd } + normalizeCommentLines(&cm) + if cm.Path == "" || cm.Content == "" { continue } comments = append(comments, cm) } + if len(rawComments) > 0 && len(comments) == 0 { + return nil, "Error: no valid comments parsed from comments[] array" + } return comments, "" } + +func normalizeCommentLines(cm *model.LlmComment) { + if cm.StartLine > 0 && cm.EndLine > 0 && cm.EndLine < cm.StartLine { + cm.EndLine = cm.StartLine + } +} + +func toInt(v any) (int, bool) { + var n int + var ok bool + switch t := v.(type) { + case int: + n, ok = t, true + case int64: + n, ok = int(t), true + case float64: + n, ok = int(t), true + case json.Number: + i, err := t.Int64() + n, ok = int(i), err == nil + case string: + i, err := strconv.Atoi(strings.TrimSpace(t)) + n, ok = i, err == nil + default: + return 0, false + } + if !ok || n <= 0 { + return 0, false + } + return n, true +} diff --git a/internal/tool/code_comment_test.go b/internal/tool/code_comment_test.go new file mode 100644 index 00000000..7a19d891 --- /dev/null +++ b/internal/tool/code_comment_test.go @@ -0,0 +1,28 @@ +package tool + +import "testing" + +func TestParseComments_FlatTopLevelFormat(t *testing.T) { + comments, errMsg := ParseComments(map[string]any{ + "path": "internal/reviewbackend/review_probe.go", + "start_line": float64(12), + "end_line": float64(18), + "content": "Hardcoded API key fallback", + }) + if errMsg != "" { + t.Fatalf("unexpected error: %s", errMsg) + } + if len(comments) != 1 { + t.Fatalf("expected 1 comment, got %d", len(comments)) + } + cm := comments[0] + if cm.Path != "internal/reviewbackend/review_probe.go" { + t.Fatalf("path = %q", cm.Path) + } + if cm.StartLine != 12 || cm.EndLine != 18 { + t.Fatalf("lines = %d-%d", cm.StartLine, cm.EndLine) + } + if cm.Content != "Hardcoded API key fallback" { + t.Fatalf("content = %q", cm.Content) + } +} From b84cceb44ef7c8c04cc9fcbcccdf424ff6454381 Mon Sep 17 00:00:00 2001 From: Mikhail Batukhtin <6481198+remdev@users.noreply.github.com> Date: Mon, 15 Jun 2026 21:20:37 +0300 Subject: [PATCH 6/6] refactor(cursor): drop text replay and simplify MCP-only review loop Remove aggressive fallbacks that inflated comment volume; keep MCP tools as the sole collection path with a single no-tools nudge per round. --- cmd/opencodereview/llm_cmd.go | 37 ++-- internal/agent/agent.go | 2 + internal/reviewbackend/backend.go | 2 + internal/reviewbackend/cursor_agent.go | 68 ++------ internal/reviewbackend/cursor_prompt.go | 4 +- internal/reviewbackend/cursor_prompt_test.go | 2 + internal/reviewbackend/cursor_text_tools.go | 158 +----------------- .../cursor_text_tools_integration_test.go | 52 ------ .../reviewbackend/cursor_text_tools_test.go | 42 +---- internal/reviewbackend/cursor_usage.go | 15 +- internal/reviewbackend/resolver.go | 17 +- internal/reviewbackend/resolver_test.go | 20 +++ internal/tool/code_comment.go | 4 + 13 files changed, 107 insertions(+), 316 deletions(-) delete mode 100644 internal/reviewbackend/cursor_text_tools_integration_test.go diff --git a/cmd/opencodereview/llm_cmd.go b/cmd/opencodereview/llm_cmd.go index 3216311a..933a90de 100644 --- a/cmd/opencodereview/llm_cmd.go +++ b/cmd/opencodereview/llm_cmd.go @@ -49,16 +49,6 @@ func runLLMTest() error { return err } - backend, err := reviewbackend.New(context.Background(), resolved, repoDir) - if err != nil { - return fmt.Errorf("create review backend: %w", err) - } - - llmClient := reviewbackend.TextClient(backend) - - model := backend.Model() - source := backend.Source() - task, err := testconnection.LoadDefault() if err != nil { return fmt.Errorf("load test task config: %w", err) @@ -72,20 +62,29 @@ func runLLMTest() error { timeout = time.Duration(task.Timeout) * time.Second } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + backend, err := reviewbackend.New(ctx, resolved, repoDir) + if err != nil { + return fmt.Errorf("create review backend: %w", err) + } + + llmClient := reviewbackend.TextClient(backend) + + model := backend.Model() + source := backend.Source() + messages := make([]llm.Message, 0, len(task.Messages)) for _, m := range task.Messages { messages = append(messages, llm.Message{Role: m.Role, Content: m.Content}) } - resp, err := func() (*llm.ChatResponse, error) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - return llmClient.CompletionsWithCtx(ctx, llm.ChatRequest{ - Model: model, - Messages: messages, - MaxTokens: 256, - }) - }() + resp, err := llmClient.CompletionsWithCtx(ctx, llm.ChatRequest{ + Model: model, + Messages: messages, + MaxTokens: 256, + }) if err != nil { return fmt.Errorf("llm request failed: %w", err) } diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 80022fb0..fd39d143 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -905,6 +905,8 @@ func (a *Agent) performLlmCodeReview(ctx context.Context, messages []llm.Message } var mainRec *session.TaskRecord + // mainRec is set by AppendTaskRecord at the start of each backend round. + // Backends must invoke tools sequentially after recording the round task. executor := func(execCtx context.Context, call reviewbackend.ToolCallInput) reviewbackend.ToolCallOutput { if execCtx == nil { execCtx = ctx diff --git a/internal/reviewbackend/backend.go b/internal/reviewbackend/backend.go index 09265a4e..e15b941f 100644 --- a/internal/reviewbackend/backend.go +++ b/internal/reviewbackend/backend.go @@ -50,6 +50,8 @@ type ToolCallInput struct { } // ToolCallOutput is returned by the tool executor. +// Backends treat non-empty Result strings as model-visible tool output; errors +// are conventionally prefixed with "Error:" by tool providers. type ToolCallOutput struct { Result string Completed bool diff --git a/internal/reviewbackend/cursor_agent.go b/internal/reviewbackend/cursor_agent.go index a6132a66..3f539e2a 100644 --- a/internal/reviewbackend/cursor_agent.go +++ b/internal/reviewbackend/cursor_agent.go @@ -73,8 +73,8 @@ func (b *CursorAgentBackend) Complete(ctx context.Context, req CompleteRequest) if err != nil { return nil, wrapCursorError(err) } - if result.Status == cursor.RunStatusError { - return nil, fmt.Errorf("cursor prompt failed: %s", result.Result) + if err := cursorRunStatusError(result); err != nil { + return nil, err } return &CompleteResponse{ @@ -100,17 +100,6 @@ func (b *CursorAgentBackend) ReviewFile(ctx context.Context, req ReviewFileReque tracker := &cursorReviewTracker{} mcpExec := func(callCtx context.Context, call ToolCallInput) ToolCallOutput { - tracker.markTool(call.Name) - if call.Name == "code_comment" { - tracker.mcpCodeComment = true - } - out := exec(callCtx, call) - if out.Completed { - tracker.taskDone = true - } - return out - } - replayExec := func(callCtx context.Context, call ToolCallInput) ToolCallOutput { tracker.markTool(call.Name) out := exec(callCtx, call) if out.Completed { @@ -138,9 +127,10 @@ func (b *CursorAgentBackend) ReviewFile(ctx context.Context, req ReviewFileReque prompt := buildCursorReviewPrompt(req.Messages, req.ToolsPrompt) basePrompt := prompt - consecutiveNoTools := 0 for round := 0; round < maxRounds; round++ { + tracker.roundToolInvoked = false + select { case <-ctx.Done(): return ctx.Err() @@ -204,46 +194,23 @@ func (b *CursorAgentBackend) ReviewFile(ctx context.Context, req ReviewFileReque } recordCursorRoundMetrics(hooks, usageAcc, durationMs, string(result.Status)) - if result.Result != "" { - replayCursorTextToolCalls(ctx, result.Result, req.FilePath, replayExec, logf, tracker.mcpCodeComment) - } - if tracker.taskDone { return nil } - - invokedTools := tracker.roundToolInvoked - tracker.roundToolInvoked = false - tracker.mcpCodeComment = false - - if !invokedTools { - consecutiveNoTools++ - if consecutiveNoTools >= 3 { - logf("[ocr] Cursor agent did not call tools for %s after %d attempts, stopping.\n", req.FilePath, consecutiveNoTools) - break - } - logf("[ocr] Cursor agent replied without tools for %s, retrying...\n", req.FilePath) - prompt = basePrompt + "\n\n" + cursorReviewNudge(tracker, true) - continue + if tracker.roundToolInvoked { + break } - - consecutiveNoTools = 0 - if tracker.commentCalls > 0 { - logf("[ocr] Cursor agent left comments for %s without task_done, retrying...\n", req.FilePath) - } else { - logf("[ocr] Cursor agent called tools for %s without code_comment, retrying...\n", req.FilePath) + if round+1 >= maxRounds { + break } - prompt = basePrompt + "\n\n" + cursorReviewNudge(tracker, false) + logf("[ocr] Cursor agent replied without MCP tools for %s, retrying...\n", req.FilePath) + prompt = basePrompt + "\n\n" + cursorReviewNudgeNoTools() } - if tracker.taskDone { - return nil - } - if tracker.commentCalls > 0 { - logf("[ocr] Cursor review for %s finished with %d comment(s) without task_done.\n", req.FilePath, tracker.commentCalls) + if tracker.taskDone || tracker.commentCalls > 0 { return nil } - logf("[ocr] Max tool requests reached for %s without review comments.\n", req.FilePath) + logf("[ocr] Cursor review for %s produced no comments.\n", req.FilePath) return fmt.Errorf("cursor review incomplete for %s", req.FilePath) } @@ -251,7 +218,6 @@ type cursorReviewTracker struct { taskDone bool commentCalls int roundToolInvoked bool - mcpCodeComment bool } func (t *cursorReviewTracker) markTool(name string) { @@ -277,14 +243,8 @@ func recordCursorRoundMetrics(hooks *ReviewHooks, usageAcc *cursorUsageAccumulat } } -func cursorReviewNudge(tracker *cursorReviewTracker, noTools bool) string { - if noTools { - return "You must not reply with a markdown review. For each confirmed issue in the diff, call the code_comment tool with structured comments (path, start_line, end_line, content). When finished, call task_done." - } - if tracker.commentCalls > 0 { - return "Comments received. Call task_done if the review is complete, or call code_comment for any remaining issues." - } - return "Call code_comment for each confirmed issue, then call task_done when finished." +func cursorReviewNudgeNoTools() string { + return "You must not reply with a markdown review. For each confirmed issue in the diff, call the code_comment tool with structured comments (path, start_line, end_line, content). When finished, call task_done." } func cursorRunStatusError(result cursor.RunResult) error { diff --git a/internal/reviewbackend/cursor_prompt.go b/internal/reviewbackend/cursor_prompt.go index 37d761ac..c22a88bd 100644 --- a/internal/reviewbackend/cursor_prompt.go +++ b/internal/reviewbackend/cursor_prompt.go @@ -31,8 +31,8 @@ func normalizeCursorToolName(name string) string { if i := strings.LastIndex(name, "/"); i >= 0 { name = name[i+1:] } - if i := strings.LastIndex(name, "__"); i >= 0 { - name = name[i+2:] + if strings.HasPrefix(name, "custom-user-tools__") { + return strings.TrimPrefix(name, "custom-user-tools__") } return name } diff --git a/internal/reviewbackend/cursor_prompt_test.go b/internal/reviewbackend/cursor_prompt_test.go index 5bb6458d..fca8a108 100644 --- a/internal/reviewbackend/cursor_prompt_test.go +++ b/internal/reviewbackend/cursor_prompt_test.go @@ -28,6 +28,8 @@ func TestNormalizeCursorToolName(t *testing.T) { "code_comment": "code_comment", "custom-user-tools/code_comment": "code_comment", "custom-user-tools__code_comment": "code_comment", + "other__code_comment": "other__code_comment", + "other__task_done": "other__task_done", } for in, want := range cases { if got := normalizeCursorToolName(in); got != want { diff --git a/internal/reviewbackend/cursor_text_tools.go b/internal/reviewbackend/cursor_text_tools.go index 5cc92375..46403596 100644 --- a/internal/reviewbackend/cursor_text_tools.go +++ b/internal/reviewbackend/cursor_text_tools.go @@ -1,113 +1,13 @@ package reviewbackend import ( - "context" "encoding/json" - "fmt" + "math" "strconv" "strings" ) -const maxCursorTextReplayCalls = 20 - -var cursorReplayAllowedTools = map[string]bool{ - "code_comment": true, - "task_done": true, -} - -func replayCursorTextToolCalls(ctx context.Context, text, filePath string, exec ToolExecutor, logf func(string, ...any), skipMCPCodeComment bool) { - objects := extractJSONObjectStrings(text) - codeCommentCalls := 0 - var deferredDone *ToolCallInput - - for i, obj := range objects { - select { - case <-ctx.Done(): - return - default: - } - - var payload map[string]any - if err := json.Unmarshal([]byte(obj), &payload); err != nil { - logf("[ocr] cursor text replay: skip invalid JSON: %v\n", err) - continue - } - name := toolNameFromPayload(payload) - if name == "" || !cursorReplayAllowedTools[name] { - continue - } - - args := toolArgsFromPayload(payload) - if name == "code_comment" { - if skipMCPCodeComment { - continue - } - if codeCommentCalls >= maxCursorTextReplayCalls { - logf("[ocr] cursor text replay: skipping code_comment beyond limit %d\n", maxCursorTextReplayCalls) - continue - } - args = normalizeCodeCommentArgs(args, filePath) - if _, ok := args["comments"].([]any); !ok { - continue - } - codeCommentCalls++ - } - - raw, err := json.Marshal(args) - if err != nil { - logf("[ocr] cursor text replay: marshal %s args: %v\n", name, err) - continue - } - - call := ToolCallInput{ - ID: fmt.Sprintf("cursor-text-replay-%d", i), - Name: name, - Arguments: string(raw), - } - if name == "task_done" { - deferred := call - deferredDone = &deferred - continue - } - - exec(ctx, call) - } - - if deferredDone != nil { - out := exec(ctx, *deferredDone) - if out.Completed { - return - } - } -} - -func toolNameFromPayload(payload map[string]any) string { - if name, ok := payload["tool"].(string); ok && name != "" { - return normalizeCursorToolName(name) - } - if name, ok := payload["name"].(string); ok && name != "" { - return normalizeCursorToolName(name) - } - return "" -} - -func toolArgsFromPayload(payload map[string]any) map[string]any { - if args, ok := payload["arguments"].(map[string]any); ok { - return args - } - if raw, ok := payload["arguments"].(string); ok && raw != "" { - var args map[string]any - if err := json.Unmarshal([]byte(raw), &args); err == nil { - return args - } - } - out := cloneMap(payload) - delete(out, "tool") - delete(out, "name") - delete(out, "state") - return out -} - +// normalizeCodeCommentArgs adapts flat or batch code_comment payloads for ParseComments. func normalizeCodeCommentArgs(args map[string]any, defaultPath string) map[string]any { args = cloneMap(args) if comments, ok := args["comments"].([]any); ok && len(comments) > 0 { @@ -128,9 +28,9 @@ func normalizeCodeCommentArgs(args map[string]any, defaultPath string) map[strin if existing, ok := args["existing_code"].(string); ok && existing != "" { comment["existing_code"] = existing } else if startLine, ok := args["start_line"]; ok { - comment["existing_code"] = lineAnchor(startLine) - } else { - comment["existing_code"] = content + if _, hasLine := intFromAny(startLine); hasLine { + comment["existing_code"] = lineAnchor(startLine) + } } if suggestion, ok := args["suggestion_code"].(string); ok && suggestion != "" { comment["suggestion_code"] = suggestion @@ -146,11 +46,10 @@ func normalizeCodeCommentArgs(args map[string]any, defaultPath string) map[strin comment["end_line"] = end } - out := map[string]any{ + return map[string]any{ "path": path, "comments": []any{comment}, } - return out } func scopeCommentPaths(args map[string]any, defaultPath string) { @@ -211,7 +110,7 @@ func intFromAny(v any) (int, bool) { } return int(t), true case float64: - if t <= 0 { + if t <= 0 || t != math.Trunc(t) || t > float64(math.MaxInt) { return 0, false } return int(t), true @@ -231,46 +130,3 @@ func intFromAny(v any) (int, bool) { return 0, false } } - -func extractJSONObjectStrings(text string) []string { - var out []string - depth := 0 - start := -1 - inString := false - escaped := false - for i, ch := range text { - if inString { - if escaped { - escaped = false - continue - } - if ch == '\\' { - escaped = true - continue - } - if ch == '"' { - inString = false - } - continue - } - switch ch { - case '"': - inString = true - case '{': - if depth == 0 { - start = i - } - depth++ - case '}': - if depth == 0 { - continue - } - depth-- - if depth == 0 && start >= 0 { - out = append(out, text[start:i+1]) - start = -1 - } - } - } - return out -} diff --git a/internal/reviewbackend/cursor_text_tools_integration_test.go b/internal/reviewbackend/cursor_text_tools_integration_test.go deleted file mode 100644 index f4977ae5..00000000 --- a/internal/reviewbackend/cursor_text_tools_integration_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package reviewbackend - -import ( - "context" - "encoding/json" - "testing" - - "github.com/open-code-review/open-code-review/internal/tool" -) - -func TestReplayCursorTextToolCalls_CollectsComments(t *testing.T) { - collector := tool.NewCommentCollector() - provider := &tool.CodeCommentProvider{Collector: collector} - tracker := &cursorReviewTracker{} - baseExec := func(_ context.Context, call ToolCallInput) ToolCallOutput { - var args map[string]any - if err := json.Unmarshal([]byte(call.Arguments), &args); err != nil { - t.Fatalf("unmarshal args: %v", err) - } - result, err := provider.Execute(context.Background(), args) - if err != nil { - t.Fatalf("execute: %v", err) - } - if call.Name == "task_done" { - return ToolCallOutput{Completed: true, Result: result} - } - return ToolCallOutput{Result: result} - } - exec := func(ctx context.Context, call ToolCallInput) ToolCallOutput { - tracker.markTool(call.Name) - out := baseExec(ctx, call) - if out.Completed { - tracker.taskDone = true - } - return out - } - - text := `{"tool":"code_comment","arguments":{"path":"internal/reviewbackend/review_probe.go","start_line":12,"end_line":18,"content":"Hardcoded API key fallback"}} -{"name":"task_done","arguments":{"state":"DONE"}}` - - replayCursorTextToolCalls(context.Background(), text, "internal/reviewbackend/review_probe.go", exec, func(string, ...any) {}, false) - - if len(collector.Comments()) == 0 { - t.Fatal("expected comments in collector") - } - if collector.Comments()[0].Content != "Hardcoded API key fallback" { - t.Fatalf("unexpected content: %q", collector.Comments()[0].Content) - } - if !tracker.taskDone { - t.Fatal("expected task_done via replay") - } -} diff --git a/internal/reviewbackend/cursor_text_tools_test.go b/internal/reviewbackend/cursor_text_tools_test.go index 36c69518..5a150c52 100644 --- a/internal/reviewbackend/cursor_text_tools_test.go +++ b/internal/reviewbackend/cursor_text_tools_test.go @@ -1,21 +1,6 @@ package reviewbackend -import ( - "context" - "strings" - "testing" -) - -func TestExtractJSONObjectStrings(t *testing.T) { - text := `prefix {"tool":"code_comment","arguments":{"content":"issue"}} suffix {"name":"task_done","arguments":{"state":"DONE"}}` - objs := extractJSONObjectStrings(text) - if len(objs) != 2 { - t.Fatalf("expected 2 objects, got %d", len(objs)) - } - if !strings.Contains(objs[0], "code_comment") { - t.Fatalf("unexpected first object: %s", objs[0]) - } -} +import "testing" func TestNormalizeCodeCommentArgs_FlatFormat(t *testing.T) { args := normalizeCodeCommentArgs(map[string]any{ @@ -35,26 +20,13 @@ func TestNormalizeCodeCommentArgs_FlatFormat(t *testing.T) { if item["content"] != "SQL injection risk" { t.Fatalf("content = %v", item["content"]) } -} - -func TestReplayCursorTextToolCalls(t *testing.T) { - var calls []string - text := `{"tool":"code_comment","arguments":{"path":"foo.go","start_line":1,"content":"bad"}}` - tracker := &cursorReviewTracker{} - exec := func(_ context.Context, call ToolCallInput) ToolCallOutput { - calls = append(calls, call.Name) - tracker.markTool(call.Name) - if call.Name == "task_done" { - return ToolCallOutput{Completed: true} - } - return ToolCallOutput{Result: "ok"} + if item["start_line"] != 12 { + t.Fatalf("start_line = %v", item["start_line"]) } - replayCursorTextToolCalls(context.Background(), text, "foo.go", exec, func(string, ...any) {}, false) +} - if len(calls) != 1 || calls[0] != "code_comment" { - t.Fatalf("calls = %v", calls) - } - if tracker.commentCalls != 1 { - t.Fatalf("commentCalls = %d", tracker.commentCalls) +func TestIntFromAny_RejectsFractionalFloat(t *testing.T) { + if _, ok := intFromAny(12.5); ok { + t.Fatal("expected fractional float to be rejected") } } diff --git a/internal/reviewbackend/cursor_usage.go b/internal/reviewbackend/cursor_usage.go index b39925b9..e81da4f5 100644 --- a/internal/reviewbackend/cursor_usage.go +++ b/internal/reviewbackend/cursor_usage.go @@ -16,6 +16,7 @@ type cursorUsageAccumulator struct { cacheWrite int64 deltaCompletion int64 hasTurnUsage bool + reportedTotal int64 } func (a *cursorUsageAccumulator) observe(update cursor.InteractionUpdate) { @@ -34,6 +35,9 @@ func (a *cursorUsageAccumulator) observe(update cursor.InteractionUpdate) { a.completion += ui.CompletionTokens a.cacheRead += ui.CacheReadTokens a.cacheWrite += ui.CacheWriteTokens + if ui.TotalTokens > 0 { + a.reportedTotal += ui.TotalTokens + } a.mu.Unlock() } case "token-delta": @@ -55,15 +59,22 @@ func (a *cursorUsageAccumulator) usage() *llm.UsageInfo { completion := a.completion cacheRead := a.cacheRead cacheWrite := a.cacheWrite + reportedTotal := a.reportedTotal if !a.hasTurnUsage && a.deltaCompletion > 0 { completion = a.deltaCompletion } - if prompt == 0 && completion == 0 && cacheRead == 0 && cacheWrite == 0 { + componentTotal := prompt + completion + cacheRead + cacheWrite + if componentTotal == 0 && reportedTotal == 0 { return nil } - total := prompt + completion + cacheRead + cacheWrite + total := componentTotal + if total == 0 { + total = reportedTotal + } else if reportedTotal > total { + total = reportedTotal + } return &llm.UsageInfo{ TotalTokens: total, PromptTokens: prompt, diff --git a/internal/reviewbackend/resolver.go b/internal/reviewbackend/resolver.go index 40b51bee..29a29a27 100644 --- a/internal/reviewbackend/resolver.go +++ b/internal/reviewbackend/resolver.go @@ -77,7 +77,7 @@ func readConfigBytes(path string) ([]byte, *configFile, error) { } func resolveCursorProvider(cfg *configFile) (ResolvedBackend, error) { - entry, ok := cfg.Providers["cursor"] + entry, ok := providerEntryCI(cfg.Providers, "cursor") if !ok { return ResolvedBackend{}, fmt.Errorf("provider %q is set but not configured in providers section", cfg.Provider) } @@ -107,3 +107,18 @@ func resolveCursorProvider(cfg *configFile) (ResolvedBackend, error) { }, }, nil } + +func providerEntryCI(m map[string]providerEntry, key string) (providerEntry, bool) { + if m == nil { + return providerEntry{}, false + } + if entry, ok := m[key]; ok { + return entry, true + } + for k, entry := range m { + if strings.EqualFold(k, key) { + return entry, true + } + } + return providerEntry{}, false +} diff --git a/internal/reviewbackend/resolver_test.go b/internal/reviewbackend/resolver_test.go index 79fca8dd..200dcc74 100644 --- a/internal/reviewbackend/resolver_test.go +++ b/internal/reviewbackend/resolver_test.go @@ -53,6 +53,26 @@ func TestResolveBackend_CursorProvider(t *testing.T) { } } +func TestResolveBackend_CursorProviderKeyCaseInsensitive(t *testing.T) { + cfgPath := writeConfig(t, t.TempDir(), map[string]any{ + "provider": "cursor", + "providers": map[string]any{ + "Cursor": map[string]any{ + "api_key": "cursor-test-key", + "model": "composer-2.5", + }, + }, + }) + + resolved, err := ResolveBackend(cfgPath) + if err != nil { + t.Fatalf("ResolveBackend: %v", err) + } + if resolved.Cursor.APIKey != "cursor-test-key" { + t.Errorf("APIKey = %q, want cursor-test-key", resolved.Cursor.APIKey) + } +} + func TestResolveBackend_CursorEnvAPIKeyFallback(t *testing.T) { t.Setenv("CURSOR_API_KEY", "env-cursor-key") cfgPath := writeConfig(t, t.TempDir(), map[string]any{ diff --git a/internal/tool/code_comment.go b/internal/tool/code_comment.go index f495645c..0bd90aae 100644 --- a/internal/tool/code_comment.go +++ b/internal/tool/code_comment.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "math" "strconv" "strings" @@ -161,6 +162,9 @@ func toInt(v any) (int, bool) { case int64: n, ok = int(t), true case float64: + if t <= 0 || t != math.Trunc(t) || t > float64(math.MaxInt) { + return 0, false + } n, ok = int(t), true case json.Number: i, err := t.Int64()