diff --git a/lib/roast/cogs/agent/providers/claude.rb b/lib/roast/cogs/agent/providers/claude.rb index 2db0c619..4a0a8f90 100644 --- a/lib/roast/cogs/agent/providers/claude.rb +++ b/lib/roast/cogs/agent/providers/claude.rb @@ -20,7 +20,13 @@ def initialize(invocation_result) def invoke(input) invocations = [] #: Array[ClaudeInvocation] input.prompts.each do |prompt| - invocation = ClaudeInvocation.new(@config, prompt, invocations.last&.result&.session || input.session) + previous_session = invocations.last&.result&.session + invocation = ClaudeInvocation.new( + @config, + prompt, + previous_session || input.session, + fork_session: previous_session.nil?, + ) invocation.run! invocations << invocation break unless invocation.result.success diff --git a/lib/roast/cogs/agent/providers/claude/claude_invocation.rb b/lib/roast/cogs/agent/providers/claude/claude_invocation.rb index f2555877..1c6ad3d7 100644 --- a/lib/roast/cogs/agent/providers/claude/claude_invocation.rb +++ b/lib/roast/cogs/agent/providers/claude/claude_invocation.rb @@ -53,8 +53,8 @@ def initialize end end - #: (Agent::Config, String, String?) -> void - def initialize(config, prompt, session) + #: (Agent::Config, String, String?, ?fork_session: bool) -> void + def initialize(config, prompt, session, fork_session: true) @base_command = config.valid_command #: (String | Array[String])? @model = config.valid_model #: String? @append_system_prompt = config.valid_append_system_prompt #: String? @@ -69,6 +69,7 @@ def initialize(config, prompt, session) @show_response = config.show_response? #: bool @prompt = prompt @session = session + @fork_session = fork_session #: bool end #: () -> void @@ -184,7 +185,10 @@ def command_line command.push("--model", @model) if @model command.push("--system-prompt", @replace_system_prompt) if @replace_system_prompt command.push("--append-system-prompt", @append_system_prompt) if @append_system_prompt - command.push("--fork-session", "--resume", @session) if @session.present? + if @session.present? + command.push("--fork-session") if @fork_session + command.push("--resume", @session) + end command << "--dangerously-skip-permissions" unless @apply_permissions command end diff --git a/test/roast/cogs/agent/providers/claude/claude_invocation_test.rb b/test/roast/cogs/agent/providers/claude/claude_invocation_test.rb index 2314ede2..bbb2ef6c 100644 --- a/test/roast/cogs/agent/providers/claude/claude_invocation_test.rb +++ b/test/roast/cogs/agent/providers/claude/claude_invocation_test.rb @@ -233,7 +233,7 @@ def failure_status assert_equal "Additional instructions", command[prompt_index + 1] end - test "command_line includes session flags when session is set" do + test "command_line includes fork-session and resume when session is set" do invocation = ClaudeInvocation.new(@config, "Test prompt", "session_123") command = invocation.send(:command_line) @@ -244,6 +244,26 @@ def failure_status assert_equal "session_123", command[resume_index + 1] end + test "command_line includes resume without fork-session when fork_session is false" do + invocation = ClaudeInvocation.new(@config, "Test prompt", "session_123", fork_session: false) + + command = invocation.send(:command_line) + + refute_includes command, "--fork-session" + assert_includes command, "--resume" + resume_index = command.index("--resume") + assert_equal "session_123", command[resume_index + 1] + end + + test "command_line omits fork-session when no session is given even if fork_session is true" do + invocation = ClaudeInvocation.new(@config, "Test prompt", nil, fork_session: true) + + command = invocation.send(:command_line) + + refute_includes command, "--fork-session" + refute_includes command, "--resume" + end + test "command_line includes dangerously-skip-permissions when permissions skipped" do @config.skip_permissions! command = @invocation.send(:command_line) diff --git a/test/roast/cogs/agent/providers/claude_test.rb b/test/roast/cogs/agent/providers/claude_test.rb index 9237bece..9a441dd3 100644 --- a/test/roast/cogs/agent/providers/claude_test.rb +++ b/test/roast/cogs/agent/providers/claude_test.rb @@ -106,6 +106,43 @@ def mock_status(success:) assert_equal ["session_from_first"], sessions_seen end + test "invoke does not fork session for subsequent invocations" do + input = Agent::Input.new + input.prompts = ["Main task", "Finalizer"] + + fork_flags = [] + call_count = 0 + CommandRunner.stubs(:execute).with do |args, **kwargs| + call_count += 1 + fork_flags << args.include?("--fork-session") + result_json = { type: "result", subtype: "success", result: "done", session_id: "session_1" }.to_json + kwargs[:stdout_handler]&.call(result_json) + true + end.returns(["", "", mock_status(success: true)]) + + @provider.invoke(input) + + assert_equal [false, false], fork_flags + end + + test "invoke forks session for first invocation when input has session" do + input = Agent::Input.new + input.prompts = ["Main task", "Finalizer"] + input.session = "external_session" + + fork_flags = [] + CommandRunner.stubs(:execute).with do |args, **kwargs| + fork_flags << args.include?("--fork-session") + result_json = { type: "result", subtype: "success", result: "done", session_id: "new_session" }.to_json + kwargs[:stdout_handler]&.call(result_json) + true + end.returns(["", "", mock_status(success: true)]) + + @provider.invoke(input) + + assert_equal [true, false], fork_flags + end + test "invoke stops on first failed invocation" do input = Agent::Input.new input.prompts = ["Main task", "Finalizer 1", "Finalizer 2"]