From abca4ff6d48d1e6e27433c58d1d6b3f70fa541fe Mon Sep 17 00:00:00 2001 From: Juniper Alanna <201364921+juniper-shopify@users.noreply.github.com> Date: Tue, 10 Mar 2026 20:36:23 -0400 Subject: [PATCH] Add multiple prompt support to agent cog --- lib/roast/cogs/agent.rb | 2 +- lib/roast/cogs/agent/input.rb | 42 ++-- lib/roast/cogs/agent/providers/claude.rb | 11 +- .../providers/claude/claude_invocation.rb | 8 +- lib/roast/cogs/agent/providers/pi.rb | 17 +- .../cogs/agent/providers/pi/pi_invocation.rb | 12 +- .../functional/roast_examples_test.rb | 4 + test/roast/cogs/agent/input_test.rb | 83 ++++--- .../claude/claude_invocation_test.rb | 21 +- .../roast/cogs/agent/providers/claude_test.rb | 170 ++++++++++++++ .../agent/providers/pi/pi_invocation_test.rb | 21 +- test/roast/cogs/agent/providers/pi_test.rb | 214 ++++++++++++++++++ 12 files changed, 513 insertions(+), 92 deletions(-) create mode 100644 test/roast/cogs/agent/providers/claude_test.rb create mode 100644 test/roast/cogs/agent/providers/pi_test.rb diff --git a/lib/roast/cogs/agent.rb b/lib/roast/cogs/agent.rb index 0e43f5a2..05e57614 100644 --- a/lib/roast/cogs/agent.rb +++ b/lib/roast/cogs/agent.rb @@ -47,7 +47,7 @@ class MissingPromptError < AgentCogError; end # #: (Input) -> Output def execute(input) - puts "[USER PROMPT] #{input.valid_prompt!}" if config.show_prompt? + puts "[USER PROMPT] #{input.prompts.first}" if config.show_prompt? output = provider.invoke(input) # NOTE: If progress is displayed, the agent's response will always be the last progress message, # so showing it again is duplicative. diff --git a/lib/roast/cogs/agent/input.rb b/lib/roast/cogs/agent/input.rb index b947f62e..4edb82cd 100644 --- a/lib/roast/cogs/agent/input.rb +++ b/lib/roast/cogs/agent/input.rb @@ -9,10 +9,14 @@ class Agent < Cog # The agent cog requires a prompt that will be sent to the agent for processing. # Optionally, a session identifier can be provided to maintain context across multiple invocations. class Input < Cog::Input - # The prompt to send to the agent for processing + # The prompts to send to the agent for processing # - #: String? - attr_accessor :prompt + # When multiple prompts are specified, each subsequent prompt is passed to the agent as soon as it completes + # the previous one, in the same session throughout. This can be useful for helping to ensure the agent produces + # final outputs in the form you desire after performing a long and complex task. + # + #: Array[String] + attr_accessor :prompts # Optional session identifier for maintaining conversation context # @@ -28,7 +32,7 @@ class Input < Cog::Input #: () -> void def initialize super - @prompt = nil #: String? + @prompts = [] #: Array[String] end # Validate that the input has all required parameters @@ -37,41 +41,35 @@ def initialize # # #### See Also # - `coerce` - # - `valid_prompt!` # #: () -> void def validate! - valid_prompt! + raise Cog::Input::InvalidInputError, "At least one prompt is required" unless prompts.present? + raise Cog::Input::InvalidInputError, "Blank prompts are not allowed" if prompts.any?(&:blank?) end # Coerce the input from the return value of the input block # # If the input block returns a String, it will be used as the prompt value. + # If the input block returns an Array of Strings, the first will be used as the prompt and the + # rest will be used as finalizers. # # #### See Also # - `validate!` # #: (untyped) -> void def coerce(input_return_value) - if input_return_value.is_a?(String) - self.prompt ||= input_return_value + case input_return_value + when String + self.prompts = [input_return_value] + when Array + self.prompts = input_return_value.map(&:to_s) end end - # Get the validated prompt value - # - # Returns the prompt if it is present, otherwise raises an `InvalidInputError`. - # - # #### See Also - # - `prompt` - # - `validate!` - # - #: () -> String - def valid_prompt! - valid_prompt = @prompt - raise Cog::Input::InvalidInputError, "'prompt' is required" unless valid_prompt.present? - - valid_prompt + #: (String) -> void + def prompt=(prompt) + @prompts = [prompt] end end end diff --git a/lib/roast/cogs/agent/providers/claude.rb b/lib/roast/cogs/agent/providers/claude.rb index adfe045f..3ad52ceb 100644 --- a/lib/roast/cogs/agent/providers/claude.rb +++ b/lib/roast/cogs/agent/providers/claude.rb @@ -18,9 +18,14 @@ def initialize(invocation_result) #: (Agent::Input) -> Agent::Output def invoke(input) - invocation = ClaudeInvocation.new(@config, input) - invocation.run! - Output.new(invocation.result) + invocations = [] #: Array[ClaudeInvocation] + input.prompts.each do |prompt| + invocation = ClaudeInvocation.new(@config, prompt, invocations.last&.result&.session || input.session) + invocation.run! + invocations << invocation + break unless invocation.result.success + end + Output.new(invocations.last.not_nil!.result) end end end diff --git a/lib/roast/cogs/agent/providers/claude/claude_invocation.rb b/lib/roast/cogs/agent/providers/claude/claude_invocation.rb index 391d67bd..d7080289 100644 --- a/lib/roast/cogs/agent/providers/claude/claude_invocation.rb +++ b/lib/roast/cogs/agent/providers/claude/claude_invocation.rb @@ -53,20 +53,20 @@ def initialize end end - #: (Agent::Config, Agent::Input) -> void - def initialize(config, input) + #: (Agent::Config, String, String?) -> void + def initialize(config, prompt, session) @base_command = config.valid_command #: (String | Array[String])? @model = config.valid_model #: String? @append_system_prompt = config.valid_append_system_prompt #: String? @replace_system_prompt = config.valid_replace_system_prompt #: String? @apply_permissions = config.apply_permissions? #: bool @working_directory = config.valid_working_directory #: Pathname? - @prompt = input.valid_prompt! #: String - @session = input.session #: String? @context = Context.new #: Context @result = Result.new #: Result @raw_dump_file = config.valid_dump_raw_agent_messages_to_path #: Pathname? @show_progress = config.show_progress? #: bool + @prompt = prompt + @session = session end #: () -> void diff --git a/lib/roast/cogs/agent/providers/pi.rb b/lib/roast/cogs/agent/providers/pi.rb index e79aced5..c297b4ac 100644 --- a/lib/roast/cogs/agent/providers/pi.rb +++ b/lib/roast/cogs/agent/providers/pi.rb @@ -18,9 +18,20 @@ def initialize(invocation_result) #: (Agent::Input) -> Agent::Output def invoke(input) - invocation = PiInvocation.new(@config, input) - invocation.run! - Output.new(invocation.result) + invocations = [] #: Array[PiInvocation] + input.prompts.each do |prompt| + previous_session = invocations.last&.result&.session + invocation = PiInvocation.new( + @config, + prompt, + previous_session || input.session, + ) + invocation.run! + invocations << invocation + break unless invocation.result.success + end + final_result = invocations.last.not_nil!.result + Output.new(final_result) end end end diff --git a/lib/roast/cogs/agent/providers/pi/pi_invocation.rb b/lib/roast/cogs/agent/providers/pi/pi_invocation.rb index cce898c9..8fe915b5 100644 --- a/lib/roast/cogs/agent/providers/pi/pi_invocation.rb +++ b/lib/roast/cogs/agent/providers/pi/pi_invocation.rb @@ -53,19 +53,21 @@ def initialize end end - #: (Agent::Config, Agent::Input) -> void - def initialize(config, input) + #: (Agent::Config, String, String?) -> void + def initialize(config, prompt, session) @base_command = config.valid_command #: (String | Array[String])? @model = config.valid_model #: String? @append_system_prompt = config.valid_append_system_prompt #: String? @replace_system_prompt = config.valid_replace_system_prompt #: String? @working_directory = config.valid_working_directory #: Pathname? - @prompt = input.valid_prompt! #: String - @session = input.session #: String? + @prompt = prompt #: String + @session = session #: String? @context = Context.new #: Context @result = Result.new #: Result @raw_dump_file = config.valid_dump_raw_agent_messages_to_path #: Pathname? + @show_prompt = config.show_prompt? #: bool @show_progress = config.show_progress? #: bool + @show_response = config.show_response? #: bool @num_turns = 0 #: Integer @total_cost = 0.0 #: Float @model_usage_accumulator = {} #: Hash[String, Hash[Symbol, Numeric]] @@ -78,6 +80,7 @@ def run! raise PiAlreadyStartedError if started? @started = true + puts "[USER PROMPT] #{@prompt}" if @show_prompt @start_time_ms = (Process.clock_gettime(Process::CLOCK_MONOTONIC) * 1000).to_i _stdout, stderr, status = CommandRunner.execute( command_line, @@ -91,6 +94,7 @@ def run! @completed = true @result.success = true finalize_stats! + puts "[AGENT RESPONSE] #{@result.response}" if @show_response else @failed = true @result.success = false diff --git a/test/examples/functional/roast_examples_test.rb b/test/examples/functional/roast_examples_test.rb index 558fe004..8f081a3f 100644 --- a/test/examples/functional/roast_examples_test.rb +++ b/test/examples/functional/roast_examples_test.rb @@ -555,10 +555,14 @@ class RoastExamplesTest < FunctionalTest # When show_progress is enabled (the default), text blocks are accumulated and printed # as a single unit, and [AGENT RESPONSE] is suppressed to avoid duplication expected_stdout = <<~STDOUT + [USER PROMPT] What is the world's largest lake? [USER PROMPT] What is the world's largest lake? Caspian spreads wide— Ancient waters vast and deep, World's largest lake gleams. + [AGENT RESPONSE] Caspian spreads wide— + Ancient waters vast and deep, + World's largest lake gleams. [AGENT STATS] Turns: 1 Duration: 0 seconds Cost (USD): $0.024634 diff --git a/test/roast/cogs/agent/input_test.rb b/test/roast/cogs/agent/input_test.rb index 861fc133..c2b89c60 100644 --- a/test/roast/cogs/agent/input_test.rb +++ b/test/roast/cogs/agent/input_test.rb @@ -10,14 +10,20 @@ def setup @input = Input.new end - test "initialize sets prompt to nil" do - assert_nil @input.prompt + test "initialize sets prompts to empty array" do + assert_equal [], @input.prompts end - test "prompt can be set" do + test "prompt= sets prompts to single-element array" do @input.prompt = "What is 2+2?" - assert_equal "What is 2+2?", @input.prompt + assert_equal ["What is 2+2?"], @input.prompts + end + + test "prompts can be set directly" do + @input.prompts = ["First", "Second"] + + assert_equal ["First", "Second"], @input.prompts end test "session can be set" do @@ -26,79 +32,94 @@ def setup assert_equal "session-123", @input.session end - test "validate! raises error when prompt is nil" do + test "validate! raises error when prompts is empty" do error = assert_raises(Cog::Input::InvalidInputError) do @input.validate! end - assert_equal "'prompt' is required", error.message + assert_equal "At least one prompt is required", error.message end - test "validate! raises error when prompt is empty string" do - @input.prompt = "" + test "validate! raises error when any prompt is blank" do + @input.prompts = ["Valid prompt", " ", "Another"] error = assert_raises(Cog::Input::InvalidInputError) do @input.validate! end - assert_equal "'prompt' is required", error.message + assert_equal "Blank prompts are not allowed", error.message end - test "validate! raises error when prompt is whitespace only" do - @input.prompt = " " + test "validate! succeeds when prompts has at least one element" do + @input.prompt = "What is 2+2?" - error = assert_raises(Cog::Input::InvalidInputError) do + assert_nothing_raised do @input.validate! end - - assert_equal "'prompt' is required", error.message end - test "validate! succeeds when prompt is present" do - @input.prompt = "What is 2+2?" + test "validate! succeeds with multiple prompts" do + @input.prompts = ["First", "Second"] assert_nothing_raised do @input.validate! end end - test "coerce sets prompt from string" do + test "coerce sets prompts from string" do @input.coerce("What is the meaning of life?") - assert_equal "What is the meaning of life?", @input.prompt + assert_equal ["What is the meaning of life?"], @input.prompts end - test "coerce does not override existing prompt" do + test "coerce overrides existing prompts" do @input.prompt = "Original prompt" @input.coerce("New prompt") - assert_equal "Original prompt", @input.prompt + assert_equal ["New prompt"], @input.prompts end - test "coerce does nothing for non-string values" do + test "coerce does nothing for non-string non-array values" do @input.coerce(42) - assert_nil @input.prompt + assert_equal [], @input.prompts end test "coerce does nothing for nil" do @input.coerce(nil) - assert_nil @input.prompt + assert_equal [], @input.prompts end - test "valid_prompt! returns prompt when present" do - @input.prompt = "Test prompt" + test "coerce with array sets all prompts" do + @input.coerce(["Main prompt", "Finalizer 1", "Finalizer 2"]) - assert_equal "Test prompt", @input.valid_prompt! + assert_equal ["Main prompt", "Finalizer 1", "Finalizer 2"], @input.prompts end - test "valid_prompt! raises error when prompt is nil" do - error = assert_raises(Cog::Input::InvalidInputError) do - @input.valid_prompt! - end + test "coerce with single-element array" do + @input.coerce(["Only prompt"]) + + assert_equal ["Only prompt"], @input.prompts + end + + test "coerce with array converts elements to strings" do + @input.coerce(["Main prompt", 42, :symbol]) + + assert_equal ["Main prompt", "42", "symbol"], @input.prompts + end + + test "coerce with empty array sets prompts to empty" do + @input.coerce([]) + + assert_equal [], @input.prompts + end + + test "prompt= overrides all existing prompts" do + @input.prompts = ["First", "Second", "Third"] + @input.prompt = "Only" - assert_equal "'prompt' is required", error.message + assert_equal ["Only"], @input.prompts end end end diff --git a/test/roast/cogs/agent/providers/claude/claude_invocation_test.rb b/test/roast/cogs/agent/providers/claude/claude_invocation_test.rb index e63f3917..a5e4e3ef 100644 --- a/test/roast/cogs/agent/providers/claude/claude_invocation_test.rb +++ b/test/roast/cogs/agent/providers/claude/claude_invocation_test.rb @@ -11,9 +11,7 @@ class ClaudeInvocationTest < ActiveSupport::TestCase def setup @config = Agent::Config.new @config.no_show_progress! - @input = Agent::Input.new - @input.prompt = "Test prompt" - @invocation = ClaudeInvocation.new(@config, @input) + @invocation = ClaudeInvocation.new(@config, "Test prompt", nil) end def success_status @@ -184,7 +182,7 @@ def failure_status test "command_line uses custom command when configured as string" do @config.command("custom-claude --flag") - invocation = ClaudeInvocation.new(@config, @input) + invocation = ClaudeInvocation.new(@config, "Test prompt", nil) command = invocation.send(:command_line) @@ -194,7 +192,7 @@ def failure_status test "command_line uses custom command when configured as array" do @config.command(["my-claude", "--opt"]) - invocation = ClaudeInvocation.new(@config, @input) + invocation = ClaudeInvocation.new(@config, "Test prompt", nil) command = invocation.send(:command_line) @@ -204,7 +202,7 @@ def failure_status test "command_line includes model when configured" do @config.model("claude-opus-4-5-20251101") - invocation = ClaudeInvocation.new(@config, @input) + invocation = ClaudeInvocation.new(@config, "Test prompt", nil) command = invocation.send(:command_line) @@ -215,7 +213,7 @@ def failure_status test "command_line includes replace_system_prompt when configured" do @config.replace_system_prompt("Custom system prompt") - invocation = ClaudeInvocation.new(@config, @input) + invocation = ClaudeInvocation.new(@config, "Test prompt", nil) command = invocation.send(:command_line) @@ -226,7 +224,7 @@ def failure_status test "command_line includes append_system_prompt when configured" do @config.append_system_prompt("Additional instructions") - invocation = ClaudeInvocation.new(@config, @input) + invocation = ClaudeInvocation.new(@config, "Test prompt", nil) command = invocation.send(:command_line) @@ -236,8 +234,7 @@ def failure_status end test "command_line includes session flags when session is set" do - @input.session = "session_123" - invocation = ClaudeInvocation.new(@config, @input) + invocation = ClaudeInvocation.new(@config, "Test prompt", "session_123") command = invocation.send(:command_line) @@ -256,7 +253,7 @@ def failure_status test "command_line omits dangerously-skip-permissions when permissions applied" do @config.apply_permissions! - invocation = ClaudeInvocation.new(@config, @input) + invocation = ClaudeInvocation.new(@config, "Test prompt", nil) command = invocation.send(:command_line) @@ -264,7 +261,7 @@ def failure_status end test "command_line omits dangerously-skip-permissions by default" do - invocation = ClaudeInvocation.new(@config, @input) + invocation = ClaudeInvocation.new(@config, "Test prompt", nil) command = invocation.send(:command_line) diff --git a/test/roast/cogs/agent/providers/claude_test.rb b/test/roast/cogs/agent/providers/claude_test.rb new file mode 100644 index 00000000..6e84702c --- /dev/null +++ b/test/roast/cogs/agent/providers/claude_test.rb @@ -0,0 +1,170 @@ +# frozen_string_literal: true + +require "test_helper" + +module Roast + module Cogs + class Agent < Cog + module Providers + class ClaudeTest < ActiveSupport::TestCase + def setup + @config = Agent::Config.new + @config.no_show_progress! + @provider = Claude.new(@config) + end + + def mock_status(success:) + status = stub("process_status") + status.stubs(success?: success) + status + end + + test "invoke with single prompt runs a single invocation" do + input = Agent::Input.new + input.prompt = "Do something" + + CommandRunner.stubs(:execute).returns(["", "", mock_status(success: true)]) + + output = @provider.invoke(input) + + assert_kind_of Agent::Output, output + end + + test "invoke passes prompt as stdin to the invocation" do + input = Agent::Input.new + input.prompt = "Do something" + + stdin_received = nil + CommandRunner.stubs(:execute).with do |_args, **kwargs| + stdin_received = kwargs[:stdin_content] + true + end.returns(["", "", mock_status(success: true)]) + + @provider.invoke(input) + + assert_equal "Do something", stdin_received + end + + test "invoke passes session from input to first invocation" do + input = Agent::Input.new + input.prompt = "Do something" + input.session = "existing_session" + + args_received = nil + CommandRunner.stubs(:execute).with do |args, **_kwargs| + args_received = args + true + end.returns(["", "", mock_status(success: true)]) + + @provider.invoke(input) + + assert_includes args_received, "--resume" + resume_index = args_received.index("--resume") + assert_equal "existing_session", args_received[resume_index + 1] + end + + test "invoke runs all prompts in order" do + input = Agent::Input.new + input.prompts = ["Main task", "Summarize", "Format as JSON"] + + prompts_received = [] + CommandRunner.stubs(:execute).with do |_args, **kwargs| + prompts_received << kwargs[:stdin_content] + result_json = { type: "result", subtype: "success", result: "ok" }.to_json + kwargs[:stdout_handler]&.call(result_json) + true + end.returns(["", "", mock_status(success: true)]) + + @provider.invoke(input) + + assert_equal ["Main task", "Summarize", "Format as JSON"], prompts_received + end + + test "invoke chains session from previous invocation result to next" do + input = Agent::Input.new + input.prompts = ["Main task", "Finalizer"] + + sessions_seen = [] + call_count = 0 + CommandRunner.stubs(:execute).with do |args, **kwargs| + call_count += 1 + if call_count == 1 + result_json = { type: "result", subtype: "success", result: "done", session_id: "session_from_first" }.to_json + else + if args.include?("--resume") + resume_index = args.index("--resume") + sessions_seen << args[resume_index + 1] + end + result_json = { type: "result", subtype: "success", result: "finalized" }.to_json + end + kwargs[:stdout_handler]&.call(result_json) + true + end.returns(["", "", mock_status(success: true)]) + + @provider.invoke(input) + + assert_equal ["session_from_first"], sessions_seen + end + + test "invoke stops on first failed invocation" do + input = Agent::Input.new + input.prompts = ["Main task", "Finalizer 1", "Finalizer 2"] + + call_count = 0 + CommandRunner.stubs(:execute).with do |_args, **_kwargs| + call_count += 1 + true + end.returns(["", "Error occurred", mock_status(success: false)]) + + assert_raises(Claude::ClaudeInvocation::ClaudeFailedError) do + @provider.invoke(input) + end + + assert_equal 1, call_count + end + + test "invoke returns output from last invocation" do + input = Agent::Input.new + input.prompts = ["Main task", "Finalize it"] + + call_count = 0 + CommandRunner.stubs(:execute).with do |_args, **kwargs| + call_count += 1 + result_text = call_count == 1 ? "intermediate" : "final result" + result_json = { type: "result", subtype: "success", result: result_text }.to_json + kwargs[:stdout_handler]&.call(result_json) + true + end.returns(["", "", mock_status(success: true)]) + + output = @provider.invoke(input) + + assert_equal "final result", output.response + end + + test "invoke uses input session when no previous invocation session exists" do + input = Agent::Input.new + input.prompts = ["Main task", "Finalize"] + input.session = "input_session" + + sessions_for_calls = [] + CommandRunner.stubs(:execute).with do |args, **kwargs| + if args.include?("--resume") + resume_index = args.index("--resume") + sessions_for_calls << args[resume_index + 1] + else + sessions_for_calls << nil + end + result_json = { type: "result", subtype: "success", result: "done" }.to_json + kwargs[:stdout_handler]&.call(result_json) + true + end.returns(["", "", mock_status(success: true)]) + + @provider.invoke(input) + + assert_equal ["input_session", "input_session"], sessions_for_calls + end + end + end + end + end +end diff --git a/test/roast/cogs/agent/providers/pi/pi_invocation_test.rb b/test/roast/cogs/agent/providers/pi/pi_invocation_test.rb index 9deb3903..8a95bb74 100644 --- a/test/roast/cogs/agent/providers/pi/pi_invocation_test.rb +++ b/test/roast/cogs/agent/providers/pi/pi_invocation_test.rb @@ -11,10 +11,8 @@ class PiInvocationTest < ActiveSupport::TestCase def setup @config = Agent::Config.new @config.provider(:pi) - @config.no_show_progress! - @input = Agent::Input.new - @input.prompt = "Test prompt" - @invocation = PiInvocation.new(@config, @input) + @config.no_display! + @invocation = PiInvocation.new(@config, "Test prompt", nil) end def success_status @@ -177,7 +175,7 @@ def failure_status test "command_line uses custom command when configured as string" do @config.command("custom-pi --flag") - invocation = PiInvocation.new(@config, @input) + invocation = PiInvocation.new(@config, "Test prompt", nil) command = invocation.send(:command_line) assert_equal "custom-pi", command.first @@ -186,7 +184,7 @@ def failure_status test "command_line uses custom command when configured as array" do @config.command(["my-pi", "--opt"]) - invocation = PiInvocation.new(@config, @input) + invocation = PiInvocation.new(@config, "Test prompt", nil) command = invocation.send(:command_line) assert_equal "my-pi", command.first @@ -195,7 +193,7 @@ def failure_status test "command_line includes model when configured" do @config.model("anthropic/claude-sonnet-4-20250514") - invocation = PiInvocation.new(@config, @input) + invocation = PiInvocation.new(@config, "Test prompt", nil) command = invocation.send(:command_line) model_index = command.index("--model") @@ -205,7 +203,7 @@ def failure_status test "command_line includes replace_system_prompt when configured" do @config.replace_system_prompt("Custom system prompt") - invocation = PiInvocation.new(@config, @input) + invocation = PiInvocation.new(@config, "Test prompt", nil) command = invocation.send(:command_line) prompt_index = command.index("--system-prompt") @@ -215,7 +213,7 @@ def failure_status test "command_line includes append_system_prompt when configured" do @config.append_system_prompt("Additional instructions") - invocation = PiInvocation.new(@config, @input) + invocation = PiInvocation.new(@config, "Test prompt", nil) command = invocation.send(:command_line) prompt_index = command.index("--append-system-prompt") @@ -224,8 +222,7 @@ def failure_status end test "command_line includes fork flag when session is set" do - @input.session = "93b0c56b-b6a9-4b33-8dff-ce0fabceae6d" - invocation = PiInvocation.new(@config, @input) + invocation = PiInvocation.new(@config, "Test prompt", "93b0c56b-b6a9-4b33-8dff-ce0fabceae6d") command = invocation.send(:command_line) assert_includes command, "--fork" @@ -235,7 +232,7 @@ def failure_status end test "command_line includes --no-session when no session is set" do - invocation = PiInvocation.new(@config, @input) + invocation = PiInvocation.new(@config, "Test prompt", nil) command = invocation.send(:command_line) assert_includes command, "--no-session" diff --git a/test/roast/cogs/agent/providers/pi_test.rb b/test/roast/cogs/agent/providers/pi_test.rb new file mode 100644 index 00000000..89e43f7d --- /dev/null +++ b/test/roast/cogs/agent/providers/pi_test.rb @@ -0,0 +1,214 @@ +# frozen_string_literal: true + +require "test_helper" + +module Roast + module Cogs + class Agent < Cog + module Providers + class PiTest < ActiveSupport::TestCase + def setup + @config = Agent::Config.new + @config.provider(:pi) + @config.no_display! + @provider = Pi.new(@config) + end + + def mock_status(success:) + status = stub("process_status") + status.stubs(success?: success) + status + end + + test "invoke with single prompt runs a single invocation" do + input = Agent::Input.new + input.prompt = "Do something" + + CommandRunner.stubs(:execute).returns(["", "", mock_status(success: true)]) + + output = @provider.invoke(input) + + assert_kind_of Agent::Output, output + end + + test "invoke passes prompt as stdin to the invocation" do + input = Agent::Input.new + input.prompt = "Do something" + + stdin_received = nil + CommandRunner.stubs(:execute).with do |_args, **kwargs| + stdin_received = kwargs[:stdin_content] + true + end.returns(["", "", mock_status(success: true)]) + + @provider.invoke(input) + + assert_equal "Do something", stdin_received + end + + test "invoke passes session from input to first invocation" do + input = Agent::Input.new + input.prompt = "Do something" + input.session = "existing_session" + + args_received = nil + CommandRunner.stubs(:execute).with do |args, **_kwargs| + args_received = args + true + end.returns(["", "", mock_status(success: true)]) + + @provider.invoke(input) + + assert_includes args_received, "--fork" + fork_index = args_received.index("--fork") + assert_equal "existing_session", args_received[fork_index + 1] + end + + test "invoke runs all prompts in order" do + input = Agent::Input.new + input.prompts = ["Main task", "Summarize", "Format as JSON"] + + prompts_received = [] + CommandRunner.stubs(:execute).with do |_args, **kwargs| + prompts_received << kwargs[:stdin_content] + session_json = { type: "session", id: "session_#{prompts_received.size}" }.to_json + kwargs[:stdout_handler]&.call(session_json) + true + end.returns(["", "", mock_status(success: true)]) + + @provider.invoke(input) + + assert_equal ["Main task", "Summarize", "Format as JSON"], prompts_received + end + + test "invoke chains session from previous invocation result to next" do + input = Agent::Input.new + input.prompts = ["Main task", "Finalizer"] + + sessions_seen = [] + call_count = 0 + CommandRunner.stubs(:execute).with do |args, **kwargs| + call_count += 1 + if call_count > 1 && args.include?("--fork") + fork_index = args.index("--fork") + sessions_seen << args[fork_index + 1] + end + session_id = call_count == 1 ? "session_from_first" : "session_from_second" + session_json = { type: "session", id: session_id }.to_json + kwargs[:stdout_handler]&.call(session_json) + true + end.returns(["", "", mock_status(success: true)]) + + @provider.invoke(input) + + assert_equal ["session_from_first"], sessions_seen + end + + test "invoke always uses --fork for session chaining" do + input = Agent::Input.new + input.prompts = ["Main task", "Finalizer"] + + fork_flags = [] + CommandRunner.stubs(:execute).with do |args, **kwargs| + fork_flags << args.include?("--fork") + session_json = { type: "session", id: "session_1" }.to_json + kwargs[:stdout_handler]&.call(session_json) + true + end.returns(["", "", mock_status(success: true)]) + + @provider.invoke(input) + + # First invocation has no session, so --no-session (no --fork). + # Second invocation forks from the first session. + assert_equal [false, true], fork_flags + end + + test "invoke uses --fork for first invocation when input has session" do + input = Agent::Input.new + input.prompts = ["Main task", "Finalizer"] + input.session = "external_session" + + fork_flags = [] + CommandRunner.stubs(:execute).with do |args, **kwargs| + fork_flags << args.include?("--fork") + session_json = { type: "session", id: "new_session" }.to_json + kwargs[:stdout_handler]&.call(session_json) + true + end.returns(["", "", mock_status(success: true)]) + + @provider.invoke(input) + + assert_equal [true, true], fork_flags + end + + test "invoke stops on first failed invocation" do + input = Agent::Input.new + input.prompts = ["Main task", "Finalizer 1", "Finalizer 2"] + + call_count = 0 + CommandRunner.stubs(:execute).with do |_args, **_kwargs| + call_count += 1 + true + end.returns(["", "Error occurred", mock_status(success: false)]) + + assert_raises(Pi::PiInvocation::PiFailedError) do + @provider.invoke(input) + end + + assert_equal 1, call_count + end + + test "invoke returns output from last invocation" do + input = Agent::Input.new + input.prompts = ["Main task", "Finalize it"] + + call_count = 0 + CommandRunner.stubs(:execute).with do |_args, **kwargs| + call_count += 1 + result_text = call_count == 1 ? "intermediate" : "final result" + # Pi uses agent_end to extract the final response + agent_end_json = { + type: "agent_end", + messages: [ + { role: "assistant", content: [{ type: "text", text: result_text }] }, + ], + }.to_json + kwargs[:stdout_handler]&.call(agent_end_json) + # Also emit a session so the chain can continue + session_json = { type: "session", id: "session_#{call_count}" }.to_json + kwargs[:stdout_handler]&.call(session_json) + true + end.returns(["", "", mock_status(success: true)]) + + output = @provider.invoke(input) + + assert_equal "final result", output.response + end + + test "invoke uses input session when no previous invocation session exists" do + input = Agent::Input.new + input.prompts = ["Main task", "Finalize"] + input.session = "input_session" + + sessions_for_calls = [] + CommandRunner.stubs(:execute).with do |args, **_kwargs| + if args.include?("--fork") + fork_index = args.index("--fork") + sessions_for_calls << args[fork_index + 1] + else + sessions_for_calls << nil + end + # Don't emit a session event, so no session chains forward + true + end.returns(["", "", mock_status(success: true)]) + + @provider.invoke(input) + + # Both calls should use the input session since no session was returned + assert_equal ["input_session", "input_session"], sessions_for_calls + end + end + end + end + end +end