From 602aa8f647389ddc56292ef4fad836976bff541d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E7=8E=AE=20=28Jade=20Lin=29?= Date: Fri, 16 Jan 2026 20:40:08 +0800 Subject: [PATCH 1/3] feat: add retry mechanism for combo reply tool calls --- crates/coco-tui/src/components/chat.rs | 192 +++++++++++-- src/agent.rs | 360 +++++++++++++------------ src/config.rs | 3 + src/config/presets.toml | 7 + src/config/provider.rs | 6 + 5 files changed, 379 insertions(+), 189 deletions(-) diff --git a/crates/coco-tui/src/components/chat.rs b/crates/coco-tui/src/components/chat.rs index 6e528b4..2af28b3 100644 --- a/crates/coco-tui/src/components/chat.rs +++ b/crates/coco-tui/src/components/chat.rs @@ -19,6 +19,7 @@ use ratatui::{ use serde::{Deserialize, Serialize}; use serde_json::Value; +use snafu::prelude::*; use std::{ path::PathBuf, sync::{ @@ -1562,6 +1563,58 @@ Do not output any other text or explanation. Only call the bash tool with the co ) } +fn build_offload_reply_retry_directive(schemas: &[code_combo::PromptSchema]) -> String { + let directive = build_offload_reply_directive(schemas); + format!("The previous response did not produce a valid coco reply. 
Retry.\n\n{directive}") +} + +#[derive(Debug, Snafu)] +enum ComboReplyError { + #[snafu(display("prompt reply cancelled"))] + Cancelled, + #[snafu(display("chat failed: {message}"))] + ChatFailed { message: String }, + #[snafu(display("LLM did not return a bash tool call for coco reply"))] + MissingBashToolUse, + #[snafu(display("expected coco reply command, got: {command}"))] + UnexpectedCommand { command: String }, + #[snafu(display("bash execution failed: {message}"))] + BashExecutionFailed { message: String }, + #[snafu(display("bash command did not produce output"))] + BashOutputMissing, + #[snafu(display("bash command failed: {output}"))] + BashCommandFailed { output: String }, + #[snafu(display("coco reply failed with exit code {exit_code}: {stderr}"))] + ReplyCommandFailed { exit_code: u8, stderr: String }, + #[snafu(display("coco reply output is not valid JSON: {message}"))] + ReplyOutputInvalidJson { message: String }, + #[snafu(display("coco reply output must be a JSON object"))] + ReplyOutputNotObject, + #[snafu(display("missing required fields: {fields}"))] + ReplyOutputMissingFields { fields: String }, + #[snafu(display("unexpected message output: {message}"))] + UnexpectedMessageOutput { message: String }, +} + +impl ComboReplyError { + fn should_retry(&self) -> bool { + matches!( + self, + ComboReplyError::MissingBashToolUse + | ComboReplyError::UnexpectedCommand { .. } + | ComboReplyError::ReplyCommandFailed { .. } + | ComboReplyError::ReplyOutputInvalidJson { .. } + | ComboReplyError::ReplyOutputNotObject + | ComboReplyError::ReplyOutputMissingFields { .. } + | ComboReplyError::UnexpectedMessageOutput { .. } + ) + } +} + +fn should_retry_offload_combo_reply(error: &ComboReplyError) -> bool { + error.should_retry() +} + /// Check if a bash tool_use input contains a `coco reply` command. 
fn is_coco_reply_command(input: &Value) -> bool { let Some(command) = input.get("command").and_then(|v| v.as_str()) else { @@ -1579,7 +1632,7 @@ fn is_coco_reply_command(input: &Value) -> bool { fn parse_coco_reply_output( output: &Final, schemas: &[code_combo::PromptSchema], -) -> Result { +) -> Result { match output { Final::Json(value) => { // BashOutput structure: { stdout, stderr, exit_code, timed_out } @@ -1589,9 +1642,10 @@ fn parse_coco_reply_output( .unwrap_or(255) as u8; if exit_code != 0 { let stderr = value.get("stderr").and_then(|v| v.as_str()).unwrap_or(""); - return Err(format!( - "coco reply failed with exit code {exit_code}: {stderr}" - )); + return Err(ComboReplyError::ReplyCommandFailed { + exit_code, + stderr: stderr.to_string(), + }); } let stdout = value .get("stdout") @@ -1599,47 +1653,95 @@ fn parse_coco_reply_output( .unwrap_or("") .trim(); // Validate it's valid JSON - let parsed: serde_json::Value = serde_json::from_str(stdout) - .map_err(|e| format!("coco reply output is not valid JSON: {e}"))?; + let parsed: serde_json::Value = serde_json::from_str(stdout).map_err(|e| { + ComboReplyError::ReplyOutputInvalidJson { + message: e.to_string(), + } + })?; // Validate all required fields are present let obj = parsed .as_object() - .ok_or_else(|| "coco reply output must be a JSON object".to_string())?; + .ok_or(ComboReplyError::ReplyOutputNotObject)?; let missing: Vec<&str> = schemas .iter() .filter(|s| !obj.contains_key(&s.name)) .map(|s| s.name.as_str()) .collect(); if !missing.is_empty() { - return Err(format!("missing required fields: {}", missing.join(", "))); + return Err(ComboReplyError::ReplyOutputMissingFields { + fields: missing.join(", "), + }); } Ok(stdout.to_string()) } - Final::Message(msg) => Err(format!("unexpected message output: {msg}")), + Final::Message(msg) => Err(ComboReplyError::UnexpectedMessageOutput { + message: msg.to_string(), + }), + } +} + +/// Handle combo reply via offload to bash `coco reply` command. 
+/// Returns Ok(json_response) on success, Err(error) on failure. +async fn handle_offload_combo_reply_with_retry( + agent: &mut Agent, + schemas: &[code_combo::PromptSchema], + combo_name: &str, + cancel_token: CancellationToken, + tx: tokio::sync::mpsc::UnboundedSender, +) -> Result { + let max_retries = agent.combo_reply_retries(); + let mut attempt = 0usize; + loop { + if cancel_token.is_cancelled() { + return Err(ComboReplyError::Cancelled); + } + let directive = if attempt == 0 { + build_offload_reply_directive(schemas) + } else { + build_offload_reply_retry_directive(schemas) + }; + let response = handle_offload_combo_reply( + agent, + schemas, + combo_name, + cancel_token.clone(), + tx.clone(), + &directive, + ) + .await; + match response { + Ok(result) => return Ok(result), + Err(err) => { + if attempt >= max_retries || !should_retry_offload_combo_reply(&err) { + return Err(err); + } + attempt += 1; + } + } } } /// Handle combo reply via offload to bash `coco reply` command. -/// Returns Ok(json_response) on success, Err(error_message) on failure. +/// Returns Ok(json_response) on success, Err(error) on failure. 
async fn handle_offload_combo_reply( agent: &mut Agent, schemas: &[code_combo::PromptSchema], combo_name: &str, cancel_token: CancellationToken, tx: tokio::sync::mpsc::UnboundedSender, -) -> Result { + directive: &str, +) -> Result { use code_combo::tools::BASH_TOOL_NAME; if cancel_token.is_cancelled() { - return Err("prompt reply cancelled".to_string()); + return Err(ComboReplyError::Cancelled); } // Build and append the directive prompt - let directive = build_offload_reply_directive(schemas); agent - .append_message(ChatMessage::user(ChatContent::Text(directive))) + .append_message(ChatMessage::user(ChatContent::Text(directive.to_string()))) .await; // Call chat to get LLM response with streaming for thinking updates @@ -1666,7 +1768,9 @@ async fn handle_offload_combo_reply( .ok(); }) .await - .map_err(|e| format!("chat failed: {e}"))?; + .map_err(|e| ComboReplyError::ChatFailed { + message: e.to_string(), + })?; // Extract Bash tool_use from response let blocks = match &chat_response.message.content { @@ -1684,7 +1788,7 @@ async fn handle_offload_combo_reply( } None }) - .ok_or_else(|| "LLM did not return a bash tool call for coco reply".to_string())?; + .ok_or(ComboReplyError::MissingBashToolUse)?; // Validate it's a coco reply command if !is_coco_reply_command(&bash_tool_use.input) { @@ -1693,7 +1797,9 @@ async fn handle_offload_combo_reply( .get("command") .and_then(|v| v.as_str()) .unwrap_or(""); - return Err(format!("expected coco reply command, got: {command}")); + return Err(ComboReplyError::UnexpectedCommand { + command: command.to_string(), + }); } // Auto-grant and execute the bash command @@ -1729,19 +1835,19 @@ async fn handle_offload_combo_reply( }, ) .await - .map_err(|e| format!("bash execution failed: {e}"))?; + .map_err(|e| ComboReplyError::BashExecutionFailed { + message: e.to_string(), + })?; if cancel_token.is_cancelled() { - return Err("prompt reply cancelled".to_string()); + return Err(ComboReplyError::Cancelled); } // Parse the output and 
send result event let (output, is_error) = match final_output { Some(Output::Success(output)) => (output, false), Some(Output::Failure(output)) => (output, true), - _ => { - return Err("bash command did not produce output".to_string()); - } + _ => return Err(ComboReplyError::BashOutputMissing), }; agent @@ -1762,7 +1868,9 @@ async fn handle_offload_combo_reply( .ok(); if is_error { - return Err(format!("bash command failed: {:?}", output)); + return Err(ComboReplyError::BashCommandFailed { + output: format!("{output:?}"), + }); } parse_coco_reply_output(&output, schemas) @@ -1925,7 +2033,7 @@ async fn task_combo_execute( // Check if offload_combo_reply is enabled for the current provider let response = if agent.offload_combo_reply() { // Offload path: use bash tool to call `coco reply` - handle_offload_combo_reply( + handle_offload_combo_reply_with_retry( &mut agent, &schemas, &name, @@ -1933,6 +2041,7 @@ async fn task_combo_execute( tx.clone(), ) .await + .map_err(|err| err.to_string()) } else { // Original path: use combo_reply tool let disable_stream = agent.disable_stream_for_current_model(); @@ -2648,4 +2757,39 @@ mod tests { Some(true) ); } + + #[test] + fn should_retry_offload_combo_reply_matches_model_errors() { + let retryable = [ + ComboReplyError::MissingBashToolUse, + ComboReplyError::UnexpectedCommand { + command: "ls".to_string(), + }, + ComboReplyError::ReplyCommandFailed { + exit_code: 1, + stderr: "bad args".to_string(), + }, + ComboReplyError::ReplyOutputInvalidJson { + message: "invalid".to_string(), + }, + ComboReplyError::ReplyOutputNotObject, + ComboReplyError::ReplyOutputMissingFields { + fields: "foo, bar".to_string(), + }, + ComboReplyError::UnexpectedMessageOutput { + message: "hi".to_string(), + }, + ]; + for case in retryable { + assert!(should_retry_offload_combo_reply(&case), "case: {case}"); + } + assert!(!should_retry_offload_combo_reply( + &ComboReplyError::Cancelled + )); + assert!(!should_retry_offload_combo_reply( + 
&ComboReplyError::ChatFailed { + message: "network".to_string() + } + )); + } } diff --git a/src/agent.rs b/src/agent.rs index b620089..47cd29f 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -38,6 +38,7 @@ pub use config::{ }; const PROMPT_REPLY_TOOL_NAME: &str = "combo_reply"; +const REPLY_TOOL_MISSING_ERROR: &str = "reply tool use not found in response"; #[derive(Clone)] pub struct Agent { @@ -569,17 +570,7 @@ impl Agent { !request_options.disable_tools, "reply tool disabled by request options" ); - let reply_tool = build_reply_tool(&schemas)?; let (_, client) = self.pick_provider()?; - let messages = { - let mut history = self.messages.lock().await; - if use_tool_choice_fallback { - let new_message = build_reply_prompt_message(&schemas); - history.push(new_message); - } - history.clone() - }; - let messages = self.prepare_messages_for_request(messages, &request_options); let tool_choice = ToolChoice::tool(PROMPT_REPLY_TOOL_NAME, None); let system_prompt = system_prompt.trim(); let system_prompt = if system_prompt.is_empty() { @@ -588,73 +579,92 @@ impl Agent { Some(system_prompt) }; let thinking = self.thinking_payload_with_override(thinking.as_ref()); - let response = if request_options.disable_tool_choice { - ensure_whatever!( - request_options.tool_choice_fallback, - "tool_choice disabled without prompt fallback" - ); - client - .messages( - system_prompt, - messages, - vec![reply_tool], - thinking, - &request_options, - ) - .await - .whatever_context_display("failed to request prompt reply")? - } else { - client - .messages_with_tool_choice( - system_prompt, - messages, - vec![reply_tool], - tool_choice, - thinking, - &request_options, - ) - .await - .whatever_context_display("failed to request prompt reply")? 
- }; - let stop_reason = response.stop_reason.clone(); - if !response.content.is_empty() { - let mut history = self.messages.lock().await; - history.push(Message::assistant(Content::Multiple( - response.content.clone(), - ))); - } - let mut thinking = Vec::new(); - let mut reply_tool = None; - for block in response.content.into_iter() { - match block { - Block::Thinking { thinking: text, .. } => { - thinking.push(text); + let mut attempt = 0usize; + loop { + let reply_tool = build_reply_tool(&schemas)?; + let messages = { + let mut history = self.messages.lock().await; + if attempt == 0 && use_tool_choice_fallback { + let new_message = build_reply_prompt_message(&schemas); + history.push(new_message); + } else if attempt > 0 { + history.push(build_reply_retry_message(&schemas)); } - Block::ToolUse(tool_use) if tool_use.name == PROMPT_REPLY_TOOL_NAME => { - reply_tool = Some(tool_use); + history.clone() + }; + let messages = self.prepare_messages_for_request(messages, &request_options); + let response = if request_options.disable_tool_choice { + ensure_whatever!( + request_options.tool_choice_fallback, + "tool_choice disabled without prompt fallback" + ); + client + .messages( + system_prompt, + messages, + vec![reply_tool], + thinking.clone(), + &request_options, + ) + .await + .whatever_context_display("failed to request prompt reply")? + } else { + client + .messages_with_tool_choice( + system_prompt, + messages, + vec![reply_tool], + tool_choice.clone(), + thinking.clone(), + &request_options, + ) + .await + .whatever_context_display("failed to request prompt reply")? + }; + let stop_reason = response.stop_reason.clone(); + if !response.content.is_empty() { + let mut history = self.messages.lock().await; + history.push(Message::assistant(Content::Multiple( + response.content.clone(), + ))); + } + let mut thinking = Vec::new(); + let mut reply_tool_use = None; + for block in response.content.into_iter() { + match block { + Block::Thinking { thinking: text, .. 
} => { + thinking.push(text); + } + Block::ToolUse(tool_use) if tool_use.name == PROMPT_REPLY_TOOL_NAME => { + reply_tool_use = Some(tool_use); + } + _ => (), } - _ => (), } + let Some(tool_use) = reply_tool_use else { + if attempt >= request_options.combo_reply_retries { + whatever!("{}", REPLY_TOOL_MISSING_ERROR); + } + attempt += 1; + continue; + }; + { + let mut history = self.messages.lock().await; + history.push(Message::user(Content::Multiple(vec![Block::tool_result( + &tool_use.id, + None, + Content::Text("ok".to_string()), + )]))); + } + let response = serde_json::to_string(&tool_use.input) + .whatever_context("failed to serialize reply tool input")?; + self.mark_thinking_cleanup_pending(stop_reason.as_ref()); + return Ok(PromptReply { + tool_use, + response, + thinking, + }); } - let Some(tool_use) = reply_tool else { - whatever!("reply tool use not found in response"); - }; - { - let mut history = self.messages.lock().await; - history.push(Message::user(Content::Multiple(vec![Block::tool_result( - &tool_use.id, - None, - Content::Text("ok".to_string()), - )]))); - } - let response = serde_json::to_string(&tool_use.input) - .whatever_context("failed to serialize reply tool input")?; - self.mark_thinking_cleanup_pending(stop_reason.as_ref()); - Ok(PromptReply { - tool_use, - response, - thinking, - }) } pub async fn reply_prompt_stream_with_thinking( @@ -675,17 +685,7 @@ impl Agent { !request_options.disable_tools, "reply tool disabled by request options" ); - let reply_tool = build_reply_tool(&schemas)?; let (_, client) = self.pick_provider()?; - let messages = { - let mut history = self.messages.lock().await; - if use_tool_choice_fallback { - let new_message = build_reply_prompt_message(&schemas); - history.push(new_message); - } - history.clone() - }; - let messages = self.prepare_messages_for_request(messages, &request_options); let tool_choice = ToolChoice::tool(PROMPT_REPLY_TOOL_NAME, None); let system_prompt = system_prompt.trim(); let system_prompt 
= if system_prompt.is_empty() { @@ -694,94 +694,113 @@ impl Agent { Some(system_prompt) }; let thinking = self.thinking_payload_with_override(thinking.as_ref()); - let mut stream = if request_options.disable_tool_choice { - ensure_whatever!( - request_options.tool_choice_fallback, - "tool_choice disabled without prompt fallback" - ); - client - .messages_stream( - system_prompt, - messages, - vec![reply_tool], - thinking, - &request_options, - ) - .await - .inspect_err(|err| { - warn!("send prompt reply stream error: {err:?}"); - }) - .whatever_context_display("failed to request prompt reply stream")? - } else { - client - .messages_stream_with_tool_choice( - system_prompt, - messages, - vec![reply_tool], - tool_choice, - thinking, - &request_options, - ) - .await - .inspect_err(|err| { - warn!("send prompt reply stream error: {err:?}"); - }) - .whatever_context_display("failed to request prompt reply stream")? - }; - - let mut accumulator = StreamAccumulator::new(); - while let Some(event) = tokio::select! 
{ - _ = cancel_token.cancelled() => { - whatever!("prompt reply stream cancelled"); - } - event = stream.next() => event, - } { - let event = event.whatever_context_display("read prompt reply stream error")?; - let action = accumulator - .handle_event(event, &mut on_update) - .whatever_context("parse prompt reply stream error")?; - if matches!(action, StreamAction::Stop) { - break; + let mut attempt = 0usize; + loop { + let reply_tool = build_reply_tool(&schemas)?; + let messages = { + let mut history = self.messages.lock().await; + if attempt == 0 && use_tool_choice_fallback { + let new_message = build_reply_prompt_message(&schemas); + history.push(new_message); + } else if attempt > 0 { + history.push(build_reply_retry_message(&schemas)); + } + history.clone() + }; + let messages = self.prepare_messages_for_request(messages, &request_options); + let mut stream = if request_options.disable_tool_choice { + ensure_whatever!( + request_options.tool_choice_fallback, + "tool_choice disabled without prompt fallback" + ); + client + .messages_stream( + system_prompt, + messages, + vec![reply_tool], + thinking.clone(), + &request_options, + ) + .await + .inspect_err(|err| { + warn!("send prompt reply stream error: {err:?}"); + }) + .whatever_context_display("failed to request prompt reply stream")? + } else { + client + .messages_stream_with_tool_choice( + system_prompt, + messages, + vec![reply_tool], + tool_choice.clone(), + thinking.clone(), + &request_options, + ) + .await + .inspect_err(|err| { + warn!("send prompt reply stream error: {err:?}"); + }) + .whatever_context_display("failed to request prompt reply stream")? + }; + + let mut accumulator = StreamAccumulator::new(); + while let Some(event) = tokio::select! 
{ + _ = cancel_token.cancelled() => { + whatever!("prompt reply stream cancelled"); + } + event = stream.next() => event, + } { + let event = event.whatever_context_display("read prompt reply stream error")?; + let action = accumulator + .handle_event(event, &mut on_update) + .whatever_context("parse prompt reply stream error")?; + if matches!(action, StreamAction::Stop) { + break; + } } - } - let (blocks, stop_reason) = accumulator.finish(); - if !blocks.is_empty() { - let msg = Message::assistant(Content::Multiple(blocks.clone())); - self.messages.lock().await.push(msg); - } - let mut thinking = Vec::new(); - let mut reply_tool = None; - for block in &blocks { - match block { - Block::Thinking { thinking: text, .. } => { - thinking.push(text.clone()); + let (blocks, stop_reason) = accumulator.finish(); + if !blocks.is_empty() { + let msg = Message::assistant(Content::Multiple(blocks.clone())); + self.messages.lock().await.push(msg); + } + let mut thinking = Vec::new(); + let mut reply_tool_use = None; + for block in &blocks { + match block { + Block::Thinking { thinking: text, .. 
} => { + thinking.push(text.clone()); + } + Block::ToolUse(tool_use) if tool_use.name == PROMPT_REPLY_TOOL_NAME => { + reply_tool_use = Some(tool_use.clone()); + } + _ => (), } - Block::ToolUse(tool_use) if tool_use.name == PROMPT_REPLY_TOOL_NAME => { - reply_tool = Some(tool_use.clone()); + } + let Some(tool_use) = reply_tool_use else { + if attempt >= request_options.combo_reply_retries { + whatever!("{}", REPLY_TOOL_MISSING_ERROR); } - _ => (), + attempt += 1; + continue; + }; + { + let mut history = self.messages.lock().await; + history.push(Message::user(Content::Multiple(vec![Block::tool_result( + &tool_use.id, + None, + Content::Text("ok".to_string()), + )]))); } + let response = serde_json::to_string(&tool_use.input) + .whatever_context("failed to serialize reply tool input")?; + self.mark_thinking_cleanup_pending(stop_reason.as_ref()); + return Ok(PromptReply { + tool_use, + response, + thinking, + }); } - let Some(tool_use) = reply_tool else { - whatever!("reply tool use not found in response"); - }; - { - let mut history = self.messages.lock().await; - history.push(Message::user(Content::Multiple(vec![Block::tool_result( - &tool_use.id, - None, - Content::Text("ok".to_string()), - )]))); - } - let response = serde_json::to_string(&tool_use.input) - .whatever_context("failed to serialize reply tool input")?; - self.mark_thinking_cleanup_pending(stop_reason.as_ref()); - Ok(PromptReply { - tool_use, - response, - thinking, - }) } pub fn grant_once(&mut self, id: &str, name: &str) { @@ -813,6 +832,10 @@ impl Agent { self.request_options_for_current_model().disable_stream } + pub fn combo_reply_retries(&self) -> usize { + self.request_options_for_current_model().combo_reply_retries + } + pub fn current_model(&self) -> String { let selected_model = self.selected_model(); match Self::select_provider_index(selected_model.as_deref(), &self.config.providers) { @@ -1062,6 +1085,13 @@ fn build_reply_prompt_message(schemas: &[PromptSchema]) -> Message { 
Message::user(Content::Text(build_reply_tool_directive(schemas))) } +fn build_reply_retry_message(schemas: &[PromptSchema]) -> Message { + let directive = build_reply_tool_directive(schemas); + Message::user(Content::Text(format!( + "The previous response did not call the required tool. {directive}" + ))) +} + fn build_reply_tool(schemas: &[PromptSchema]) -> Result { let mut properties = JsonMap::new(); let mut required = Vec::new(); diff --git a/src/config.rs b/src/config.rs index d21d467..a8b7cf6 100644 --- a/src/config.rs +++ b/src/config.rs @@ -641,6 +641,7 @@ safe_commands_mode = \"override\"\n\ let options = config.request_options_for_model("kimi-k2-thinking"); assert!(options.include_reasoning_content); assert_eq!(options.offload_combo_reply, Some(true)); + assert_eq!(options.combo_reply_retries, 1); assert_eq!(options.temperature, Some(1.0)); assert_eq!(options.max_tokens, Some(16000)); } @@ -652,12 +653,14 @@ safe_commands_mode = \"override\"\n\ model: "deepseek-reasoner".to_string(), disable_tool_choice: Some(false), tool_choice_fallback: Some(false), + combo_reply_retries: Some(0), max_tokens: Some(2048), ..ModelRequestConfig::default() }); let options = config.request_options_for_model("deepseek-reasoner"); assert!(!options.disable_tool_choice); assert!(!options.tool_choice_fallback); + assert_eq!(options.combo_reply_retries, 0); assert_eq!(options.max_tokens, Some(2048)); } } diff --git a/src/config/presets.toml b/src/config/presets.toml index d057f3d..777faa4 100644 --- a/src/config/presets.toml +++ b/src/config/presets.toml @@ -2,6 +2,7 @@ model = "deepseek-chat" disable_tool_choice = true tool_choice_fallback = true +combo_reply_retries = 1 context_window = 128000 can_reason = false max_tokens = 4000 @@ -10,6 +11,7 @@ max_tokens = 4000 model = "deepseek-reasoner" disable_tool_choice = true tool_choice_fallback = true +combo_reply_retries = 1 context_window = 128000 can_reason = true max_tokens = 32000 @@ -18,6 +20,7 @@ max_tokens = 32000 model = 
"kimi-k2-thinking" include_reasoning_content = true offload_combo_reply = true +combo_reply_retries = 1 temperature = 1.0 max_tokens = 16000 @@ -28,6 +31,7 @@ context_window = 204800 can_reason = true include_reasoning_content = true offload_combo_reply = true +combo_reply_retries = 1 [[model_presets]] model = "glm-4.6" @@ -35,6 +39,7 @@ max_tokens = 131072 context_window = 204800 can_reason = true offload_combo_reply = true +combo_reply_retries = 1 [[model_presets]] model = "glm-4.5" @@ -42,6 +47,7 @@ max_tokens = 98304 context_window = 131072 can_reason = true offload_combo_reply = true +combo_reply_retries = 1 [[model_presets]] model = "glm-4.5-air" @@ -49,3 +55,4 @@ max_tokens = 98304 context_window = 131072 can_reason = true offload_combo_reply = true +combo_reply_retries = 1 diff --git a/src/config/provider.rs b/src/config/provider.rs index 94c624d..37d585e 100644 --- a/src/config/provider.rs +++ b/src/config/provider.rs @@ -26,6 +26,8 @@ pub struct ModelRequestConfig { #[serde(default, skip_serializing_if = "Option::is_none")] pub offload_combo_reply: Option, #[serde(default, skip_serializing_if = "Option::is_none")] + pub combo_reply_retries: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] pub context_window: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub can_reason: Option, @@ -43,6 +45,7 @@ pub struct RequestOptions { pub include_reasoning_content: bool, pub disable_stream: bool, pub offload_combo_reply: Option, + pub combo_reply_retries: usize, pub context_window: Option, pub can_reason: Option, pub temperature: Option, @@ -69,6 +72,9 @@ impl RequestOptions { if let Some(value) = override_config.offload_combo_reply { self.offload_combo_reply = Some(value); } + if let Some(value) = override_config.combo_reply_retries { + self.combo_reply_retries = value; + } if let Some(value) = override_config.context_window { self.context_window = Some(value); } From 65aa0e7104ed20bb18f8448aa43a247ca584791c Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E7=8E=AE=20=28Jade=20Lin=29?= Date: Fri, 16 Jan 2026 21:34:52 +0800 Subject: [PATCH 2/3] refactor: improve session server connection handling and reply validation --- crates/coco-tui/src/components/chat.rs | 95 +++- src/cmd/prompt.rs | 50 +- src/combo/session.rs | 10 +- src/combo/starter.rs | 672 +++++++++++++++++-------- 4 files changed, 555 insertions(+), 272 deletions(-) diff --git a/crates/coco-tui/src/components/chat.rs b/crates/coco-tui/src/components/chat.rs index 2af28b3..70c27b5 100644 --- a/crates/coco-tui/src/components/chat.rs +++ b/crates/coco-tui/src/components/chat.rs @@ -1,5 +1,5 @@ use coco_macro::ComponentExt; -use code_combo::tools::Final; +use code_combo::tools::{BashInput, Final}; use code_combo::{ Agent, Block as ChatBlock, ChatResponse, ChatStreamUpdate, Config, Content as ChatContent, Message as ChatMessage, Output, RuntimeOverrides, SessionEnv, Starter, StarterCommand, @@ -21,7 +21,7 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use snafu::prelude::*; use std::{ - path::PathBuf, + path::{Path, PathBuf}, sync::{ Arc, atomic::{AtomicBool, Ordering}, @@ -1576,6 +1576,8 @@ enum ComboReplyError { ChatFailed { message: String }, #[snafu(display("LLM did not return a bash tool call for coco reply"))] MissingBashToolUse, + #[snafu(display("failed to parse bash tool input: {message}"))] + InvalidBashInput { message: String }, #[snafu(display("expected coco reply command, got: {command}"))] UnexpectedCommand { command: String }, #[snafu(display("bash execution failed: {message}"))] @@ -1601,6 +1603,7 @@ impl ComboReplyError { matches!( self, ComboReplyError::MissingBashToolUse + | ComboReplyError::InvalidBashInput { .. } | ComboReplyError::UnexpectedCommand { .. } | ComboReplyError::ReplyCommandFailed { .. } | ComboReplyError::ReplyOutputInvalidJson { .. 
} @@ -1615,17 +1618,50 @@ fn should_retry_offload_combo_reply(error: &ComboReplyError) -> bool { error.should_retry() } -/// Check if a bash tool_use input contains a `coco reply` command. -fn is_coco_reply_command(input: &Value) -> bool { - let Some(command) = input.get("command").and_then(|v| v.as_str()) else { +fn is_env_assignment(token: &str) -> bool { + let Some((key, _value)) = token.split_once('=') else { return false; }; - let trimmed = command.trim(); - // Check if command starts with "coco reply" (with possible path prefix) - trimmed == "coco reply" - || trimmed.starts_with("coco reply ") - || trimmed.ends_with("/coco reply") - || trimmed.contains("/coco reply ") + if key.is_empty() { + return false; + } + if key.starts_with('-') { + return false; + } + key.chars() + .all(|ch| ch == '_' || ch.is_ascii_alphanumeric()) +} + +/// Check if a command string starts with `coco reply` (allowing env assignments). +fn is_coco_reply_command(command: &str) -> bool { + let mut parts = command.split_whitespace(); + let mut token = match parts.next() { + Some(token) => token, + None => return false, + }; + + loop { + if token == "env" { + token = match parts.next() { + Some(next) => next, + None => return false, + }; + continue; + } + if is_env_assignment(token) { + token = match parts.next() { + Some(next) => next, + None => return false, + }; + continue; + } + break; + } + + let command_token = token; + let reply_token = parts.next(); + let is_coco = command_token == "coco" || command_token.ends_with("/coco"); + is_coco && matches!(reply_token, Some("reply")) } /// Parse coco reply stdout output as JSON fields and validate required schemas. 
@@ -1688,6 +1724,7 @@ async fn handle_offload_combo_reply_with_retry( agent: &mut Agent, schemas: &[code_combo::PromptSchema], combo_name: &str, + session_socket_path: &Path, cancel_token: CancellationToken, tx: tokio::sync::mpsc::UnboundedSender, ) -> Result { @@ -1706,6 +1743,7 @@ async fn handle_offload_combo_reply_with_retry( agent, schemas, combo_name, + session_socket_path, cancel_token.clone(), tx.clone(), &directive, @@ -1729,6 +1767,7 @@ async fn handle_offload_combo_reply( agent: &mut Agent, schemas: &[code_combo::PromptSchema], combo_name: &str, + session_socket_path: &Path, cancel_token: CancellationToken, tx: tokio::sync::mpsc::UnboundedSender, directive: &str, @@ -1790,18 +1829,31 @@ async fn handle_offload_combo_reply( }) .ok_or(ComboReplyError::MissingBashToolUse)?; + let mut bash_input: BashInput = + serde_json::from_value(bash_tool_use.input.clone()).map_err(|err| { + ComboReplyError::InvalidBashInput { + message: err.to_string(), + } + })?; + // Validate it's a coco reply command - if !is_coco_reply_command(&bash_tool_use.input) { - let command = bash_tool_use - .input - .get("command") - .and_then(|v| v.as_str()) - .unwrap_or(""); + if !is_coco_reply_command(&bash_input.command) { return Err(ComboReplyError::UnexpectedCommand { - command: command.to_string(), + command: bash_input.command.clone(), }); } + if !bash_input.command.contains("COCO_SESSION_SOCK=") { + let socket_path = session_socket_path.to_string_lossy(); + let escaped_path = shell_escape(socket_path.as_ref()); + let command = bash_input.command.trim_start(); + bash_input.command = format!("COCO_SESSION_SOCK={} {}", escaped_path, command); + } + let bash_input_value = + serde_json::to_value(&bash_input).map_err(|err| ComboReplyError::InvalidBashInput { + message: err.to_string(), + })?; + // Auto-grant and execute the bash command agent.grant_once(&bash_tool_use.id, BASH_TOOL_NAME); @@ -1824,7 +1876,7 @@ async fn handle_offload_combo_reply( .execute_with_output( &bash_tool_use.id, 
BASH_TOOL_NAME, - code_combo::Input::Starter(bash_tool_use.input.clone()), + code_combo::Input::Starter(bash_input_value), cancel_token.clone(), |out| { if let Output::Success(output) = out { @@ -1915,6 +1967,7 @@ async fn task_combo_execute( let session_env = SessionEnv::builder() .build() .expect("failed to build session"); + let session_socket_path = session_env.socket_path().to_path_buf(); let mcp_envs = match code_combo::tools::prepare_mcp_envs().await { Ok(envs) => envs, Err(err) => { @@ -2037,6 +2090,7 @@ async fn task_combo_execute( &mut agent, &schemas, &name, + &session_socket_path, cancel_token.clone(), tx.clone(), ) @@ -2762,6 +2816,9 @@ mod tests { fn should_retry_offload_combo_reply_matches_model_errors() { let retryable = [ ComboReplyError::MissingBashToolUse, + ComboReplyError::InvalidBashInput { + message: "bad input".to_string(), + }, ComboReplyError::UnexpectedCommand { command: "ls".to_string(), }, diff --git a/src/cmd/prompt.rs b/src/cmd/prompt.rs index 3e16aab..61b73c4 100644 --- a/src/cmd/prompt.rs +++ b/src/cmd/prompt.rs @@ -1,11 +1,9 @@ -use std::collections::HashMap; - use serde_json::Value; use snafu::prelude::*; use tokio::io::AsyncReadExt; use tracing::info; -use crate::{PromptPayload, PromptSchema, SessionSocketClient, error::Result}; +use crate::{PromptPayload, PromptSchema, ReplyPayload, SessionSocketClient, error::Result}; pub async fn handle_ask(prompt: String, schemas: Vec) -> Result<()> { let prompt = resolve_prompt(prompt).await?; @@ -68,35 +66,35 @@ pub async fn handle_tell(prompt: String) -> Result<()> { /// Fields are provided as --field=value format. /// Validation is done by the parent process (TUI) which knows the required schemas. pub async fn handle_reply(fields: Vec) -> Result<()> { - let parsed_fields = parse_reply_fields(&fields)?; + let Some(client) = SessionSocketClient::from_env() + .await + .whatever_context("failed to new from env COCO_SESSION_SOCK")? 
+ else { + whatever!("env COCO_SESSION_SOCK is not set"); + }; + + let validation = client + .send_reply_wait_validation(ReplyPayload { fields }) + .await + .whatever_context("failed to send reply to session socket")?; - // Output the fields as JSON for bash result parsing - let output = serde_json::to_string(&parsed_fields) - .whatever_context("failed to serialize reply fields")?; - println!("{output}"); + if !validation.success { + let error = validation + .error + .unwrap_or_else(|| "reply validation failed".to_string()); + whatever!("{error}"); + } + + let Some(response) = validation.response else { + whatever!("reply validation succeeded without response"); + }; + + println!("{response}"); info!("reply output generated"); Ok(()) } -fn parse_reply_fields(fields: &[String]) -> Result> { - let mut parsed = HashMap::new(); - for field in fields { - // Handle --field=value format - let field = field.strip_prefix("--").unwrap_or(field); - let Some((key, value)) = field.split_once('=') else { - whatever!("invalid field format {field:?}, expected --field=value"); - }; - let key = key.trim(); - ensure_whatever!(!key.is_empty(), "field key cannot be empty"); - if parsed.contains_key(key) { - whatever!("duplicate field key: {key}"); - } - parsed.insert(key.to_string(), value.to_string()); - } - Ok(parsed) -} - async fn resolve_prompt(prompt: String) -> Result { if !prompt.trim().is_empty() { return Ok(prompt); diff --git a/src/combo/session.rs b/src/combo/session.rs index e72939b..a6ee92f 100644 --- a/src/combo/session.rs +++ b/src/combo/session.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, path::Path, path::PathBuf, sync::Arc}; +use std::{path::Path, path::PathBuf, sync::Arc}; use serde::{Deserialize, Serialize}; use snafu::prelude::*; @@ -110,11 +110,11 @@ pub struct ThinkingConfig { } /// Payload for combo reply via bash command offload. -/// Contains the field values extracted by the LLM. +/// Contains raw `--field=value` args for server-side parsing. 
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct ReplyPayload { - /// Field name to value mapping - pub fields: HashMap, + /// Raw field args, e.g. "--message=hello" + pub fields: Vec, } /// Server response to validate reply fields against required schemas. @@ -123,6 +123,8 @@ pub struct ReplyValidation { pub success: bool, #[serde(skip_serializing_if = "Option::is_none")] pub error: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub response: Option, } #[derive(Debug, Snafu)] diff --git a/src/combo/starter.rs b/src/combo/starter.rs index a2e596b..eafe273 100644 --- a/src/combo/starter.rs +++ b/src/combo/starter.rs @@ -1,4 +1,5 @@ use std::{ + collections::{HashMap, HashSet}, io::ErrorKind, path::{Path, PathBuf}, pin::Pin, @@ -12,7 +13,7 @@ use futures_util::StreamExt; use snafu::prelude::*; use time::OffsetDateTime; use tokio::{ - sync::{mpsc, oneshot}, + sync::{Mutex as AsyncMutex, mpsc, oneshot}, task::{self, JoinHandle}, }; use tokio_util::sync::CancellationToken; @@ -20,9 +21,9 @@ use tracing::{debug, info, warn}; use crate::tools::{BASH_TOOL_NAME, BashInput, BashOutput, Final}; use crate::{ - ClientMessage, Combo, ComboMetadata, ControlAction, MetadataPayload, MetadataResponse, - PromptPayload, PromptSchema, RecordControl, RecordEndPayload, ServerMessage, SessionEnv, - SessionSocketServer, StreamKind, ThinkingConfig, ToolUse, + ClientMessage, Combo, ComboMetadata, MetadataPayload, MetadataResponse, PromptPayload, + PromptSchema, RecordEndPayload, ReplyValidation, ServerConnection, ServerMessage, SessionEnv, + SessionServerError, SessionSocketServer, StreamKind, ThinkingConfig, ToolUse, exec::{ChunkConfig, ExecCommand, OutputChunk, ProcessEvent}, }; use serde_json::json; @@ -114,9 +115,23 @@ impl std::fmt::Debug for PromptResponseSender { } } -#[derive(Default, Debug)] +#[derive(Default, Debug, Clone)] struct SessionState { metadata: Option, + pending_reply_schemas: Option>, + event_index: usize, +} + +impl SessionState 
{ + fn next_event_index(&mut self) -> usize { + let index = self.event_index; + self.event_index = self.event_index.saturating_add(1); + index + } + + fn bump_event_index(&mut self) { + self.event_index = self.event_index.saturating_add(1); + } } #[derive(Debug)] @@ -285,6 +300,45 @@ fn resolve_prompt_thinking( }) } +fn parse_reply_fields( + fields: &[String], + schemas: &[PromptSchema], +) -> Result, String> { + let mut parsed = HashMap::new(); + let schema_names: HashSet<&str> = schemas.iter().map(|schema| schema.name.as_str()).collect(); + + for field in fields { + let field = field.strip_prefix("--").unwrap_or(field); + let Some((key, value)) = field.split_once('=') else { + return Err(format!( + "invalid field format {field:?}, expected --field=value" + )); + }; + let key = key.trim(); + if key.is_empty() { + return Err("field key cannot be empty".to_string()); + } + if !schema_names.contains(key) { + return Err(format!("unexpected field key: {key}")); + } + if parsed.contains_key(key) { + return Err(format!("duplicate field key: {key}")); + } + parsed.insert(key.to_string(), value.to_string()); + } + + let missing: Vec<&str> = schemas + .iter() + .filter(|schema| !parsed.contains_key(&schema.name)) + .map(|schema| schema.name.as_str()) + .collect(); + if !missing.is_empty() { + return Err(format!("missing required fields: {}", missing.join(", "))); + } + + Ok(parsed) +} + fn record_output(record: &RecordedCommand) -> BashOutput { let stdout = if record.stdout.is_empty() { String::new() @@ -643,70 +697,90 @@ async fn spawn_session_server( }) } -async fn run_session_server( - server: SessionSocketServer, +async fn clear_pending_reply(state: &Arc>) { + let mut guard = state.lock().await; + guard.pending_reply_schemas = None; +} + +async fn handle_session_connection( + mut conn: ServerConnection, discovery: bool, event_tx: mpsc::Sender, - mut shutdown_rx: oneshot::Receiver<()>, -) -> Result { - let mut state = SessionState::default(); - let mut metadata_seen = 
false; + state: Arc>, + cancel_token: CancellationToken, +) -> Result<(), StarterError> { let mut first_message = true; - let mut event_index: usize = 0; + let mut current_record: Option = None; loop { - let accept = tokio::select! { - _ = &mut shutdown_rx => break, - accept = server.accept() => accept, + let message = tokio::select! { + _ = cancel_token.cancelled() => break, + message = conn.read_client_message() => message, }; - let mut conn = accept.context(AcceptSessionConnectionSnafu)?; - let mut current_record: Option = None; + let message = match message { + Ok(message) => message, + Err(err) => { + if matches!(&err, SessionServerError::Receive { source, .. } if source.kind() == ErrorKind::UnexpectedEof) + { + break; + } + return Err(InvalidSnafu { + reason: format!("failed to read session message: {err}"), + } + .build()); + } + }; - loop { - let message = tokio::select! { - _ = &mut shutdown_rx => { - if !metadata_seen { - return Err(InvalidSnafu { - reason: "metadata not received from session".to_string(), - } - .build()); + match message { + ClientMessage::Metadata(payload) => { + if !first_message { + let _ = conn.interrupt().await; + return Err(InvalidSnafu { + reason: "metadata must be the first and only metadata message".to_string(), } - return Ok(state); - }, - message = conn.read_client_message() => message, - }; - - match message { - Ok(ClientMessage::Metadata(payload)) => { - if !first_message || metadata_seen { - let _ = conn - .send_server_message(&ServerMessage::RecordControl(RecordControl { - action: ControlAction::Interrupt, - })) - .await; + .build()); + } + { + let mut guard = state.lock().await; + if guard.metadata.is_some() { + let _ = conn.interrupt().await; return Err(InvalidSnafu { reason: "metadata must be the first and only metadata message" .to_string(), } .build()); } - metadata_seen = true; - state.metadata = Some(payload); - let _ = conn - .send_server_message(&ServerMessage::Metadata(MetadataResponse { - discovery, - })) - 
.await; - first_message = false; + guard.metadata = Some(payload); } - Ok(ClientMessage::RecordStart(payload)) => { - if discovery || !metadata_seen { - let _ = conn - .send_server_message(&ServerMessage::RecordControl(RecordControl { - action: ControlAction::Interrupt, - })) - .await; + conn.send_server_message(&ServerMessage::Metadata(MetadataResponse { discovery })) + .await + .map_err(|err| { + InvalidSnafu { + reason: format!("failed to send metadata response: {err}"), + } + .build() + })?; + first_message = false; + } + ClientMessage::RecordStart(payload) => { + if discovery { + let _ = conn.interrupt().await; + return Err(InvalidSnafu { + reason: "record commands are not allowed in discovery or before metadata" + .to_string(), + } + .build()); + } + let (record_index, name) = { + let mut guard = state.lock().await; + let metadata_name = guard + .metadata + .as_ref() + .map(|metadata| metadata.name.clone()); + if metadata_name.is_none() { + drop(guard); + let _ = conn.interrupt().await; return Err(InvalidSnafu { reason: "record commands are not allowed in discovery or before metadata" @@ -714,216 +788,329 @@ async fn run_session_server( } .build()); } - let record_index = event_index; - event_index = event_index.saturating_add(1); - let name = state - .metadata - .as_ref() - .map(|metadata| metadata.name.as_str()) - .unwrap_or("combo"); - let tool_use_id = format!("combo_record_{name}_{record_index}"); - let input = BashInput::new(payload.command.join(" ")); - let tool_use = ToolUse { - id: tool_use_id.clone(), - name: BASH_TOOL_NAME.to_string(), - input: serde_json::to_value(&input).unwrap_or_else(|_| { - json!({ - "command": input.command, - "timeout": input.timeout, - }) - }), - }; - current_record = Some(RecordedCommand { - tool_use_id: tool_use_id.clone(), - stdout: Vec::new(), - stderr: Vec::new(), - exit_code: None, - }); - let _ = conn - .send_server_message(&ServerMessage::RecordControl(RecordControl { - action: ControlAction::Allow, - })) - .await; 
- event_tx - .send(StarterEvent::RecordStart { tool_use }) - .await - .ok(); - first_message = false; + let record_index = guard.next_event_index(); + (record_index, metadata_name.unwrap()) + }; + let tool_use_id = format!("combo_record_{name}_{record_index}"); + let input = BashInput::new(payload.command.join(" ")); + let tool_use = ToolUse { + id: tool_use_id.clone(), + name: BASH_TOOL_NAME.to_string(), + input: serde_json::to_value(&input).unwrap_or_else(|_| { + json!({ + "command": input.command, + "timeout": input.timeout, + }) + }), + }; + current_record = Some(RecordedCommand { + tool_use_id: tool_use_id.clone(), + stdout: Vec::new(), + stderr: Vec::new(), + exit_code: None, + }); + let _ = conn.allow().await; + event_tx + .send(StarterEvent::RecordStart { tool_use }) + .await + .ok(); + first_message = false; + } + ClientMessage::RecordChunk(chunk) => { + if discovery { + let _ = conn.interrupt().await; + return Err(InvalidSnafu { + reason: "record chunk is not allowed during discovery".to_string(), + } + .build()); } - Ok(ClientMessage::RecordChunk(chunk)) => { - if discovery { - let _ = conn - .send_server_message(&ServerMessage::RecordControl(RecordControl { - action: ControlAction::Interrupt, - })) - .await; - return Err(InvalidSnafu { - reason: "record chunk is not allowed during discovery".to_string(), - } - .build()); + let stream = chunk.stream; + let lines = chunk.lines; + let Some(record) = current_record.as_mut() else { + continue; + }; + let tool_use_id = record.tool_use_id.clone(); + + match stream { + StreamKind::Stdout => record.stdout.extend(lines.clone()), + StreamKind::Stderr => record.stderr.extend(lines.clone()), + } + event_tx + .send(StarterEvent::RecordOutput { + tool_use_id, + chunk: OutputChunk { + timestamp: OffsetDateTime::now_utc().unix_timestamp(), + stream, + lines, + }, + }) + .await + .ok(); + } + ClientMessage::RecordEnd(RecordEndPayload { + exit_code, + stdout, + stderr, + .. 
+ }) => { + if discovery { + let _ = conn.interrupt().await; + return Err(InvalidSnafu { + reason: "record end is not allowed during discovery".to_string(), } - let stream = chunk.stream; - let lines = chunk.lines; - let Some(record) = current_record.as_mut() else { - continue; - }; - let tool_use_id = record.tool_use_id.clone(); + .build()); + } + let Some(mut record) = current_record.take() else { + continue; + }; - match stream { - StreamKind::Stdout => record.stdout.extend(lines.clone()), - StreamKind::Stderr => record.stderr.extend(lines.clone()), + if let Some(stdout) = stdout { + record.stdout.push(stdout); + } + if let Some(stderr) = stderr { + record.stderr.push(stderr); + } + record.exit_code = exit_code; + let tool_use_id = record.tool_use_id.clone(); + let output = record_output(&record); + let is_error = output.exit_code != 0; + let output_value = + serde_json::to_value(&output).expect("failed to encode record output"); + event_tx + .send(StarterEvent::RecordEnd { + tool_use_id, + is_error, + output: Final::from(output_value), + }) + .await + .ok(); + } + ClientMessage::Prompt(payload) => { + let metadata = { state.lock().await.metadata.clone() }; + if metadata.is_none() { + return Err(InvalidSnafu { + reason: "prompt is not allowed before metadata".to_string(), } - event_tx - .send(StarterEvent::RecordOutput { - tool_use_id, - chunk: OutputChunk { - timestamp: OffsetDateTime::now_utc().unix_timestamp(), - stream, - lines, - }, - }) - .await - .ok(); + .build()); } - Ok(ClientMessage::RecordEnd(RecordEndPayload { - exit_code, - stdout, - stderr, - .. 
- })) => { + if payload.reply { if discovery { - let _ = conn - .send_server_message(&ServerMessage::RecordControl(RecordControl { - action: ControlAction::Interrupt, - })) - .await; return Err(InvalidSnafu { - reason: "record end is not allowed during discovery".to_string(), + reason: "prompt reply is not allowed during discovery".to_string(), } .build()); } - let Some(mut record) = current_record.take() else { - continue; - }; - - if let Some(stdout) = stdout { - record.stdout.push(stdout); + if payload.schemas.is_empty() { + return Err(InvalidSnafu { + reason: "prompt reply requires schemas".to_string(), + } + .build()); } - if let Some(stderr) = stderr { - record.stderr.push(stderr); + { + let mut guard = state.lock().await; + if guard.pending_reply_schemas.is_some() { + return Err(InvalidSnafu { + reason: "prompt reply already in progress".to_string(), + } + .build()); + } + guard.pending_reply_schemas = Some(payload.schemas.clone()); } - record.exit_code = exit_code; - let tool_use_id = record.tool_use_id.clone(); - let output = record_output(&record); - let is_error = output.exit_code != 0; - let output_value = - serde_json::to_value(&output).expect("failed to encode record output"); - event_tx - .send(StarterEvent::RecordEnd { - tool_use_id, - is_error, - output: Final::from(output_value), + let (response_tx, response_rx) = oneshot::channel(); + let responder = PromptResponseSender::new(response_tx); + let thinking = resolve_prompt_thinking(metadata.as_ref(), &payload); + if event_tx + .send(StarterEvent::PromptRequest { + prompt: payload.prompt, + schemas: payload.schemas, + thinking, + responder, }) .await - .ok(); - } - Ok(ClientMessage::Prompt(payload)) => { - if !metadata_seen { + .is_err() + { + clear_pending_reply(&state).await; return Err(InvalidSnafu { - reason: "prompt is not allowed before metadata".to_string(), + reason: "prompt responder is not available".to_string(), } .build()); } - if payload.reply { - if discovery { + { + let mut guard = 
state.lock().await; + guard.bump_event_index(); + } + let response = match response_rx.await { + Ok(response) => response, + Err(_) => { + clear_pending_reply(&state).await; return Err(InvalidSnafu { - reason: "prompt reply is not allowed during discovery".to_string(), + reason: "prompt responder dropped response".to_string(), } .build()); } - if payload.schemas.is_empty() { + }; + let response = match response { + Ok(response) => response, + Err(err) => { + clear_pending_reply(&state).await; return Err(InvalidSnafu { - reason: "prompt reply requires schemas".to_string(), + reason: format!("prompt responder failed: {err}"), } .build()); } - let (response_tx, response_rx) = oneshot::channel(); - let responder = PromptResponseSender::new(response_tx); - let thinking = resolve_prompt_thinking(state.metadata.as_ref(), &payload); - event_tx - .send(StarterEvent::PromptRequest { - prompt: payload.prompt, - schemas: payload.schemas, - thinking, - responder, - }) - .await - .map_err(|_| { - InvalidSnafu { - reason: "prompt responder is not available".to_string(), - } - .build() - })?; - event_index = event_index.saturating_add(1); - let response = response_rx.await.map_err(|_| { + }; + clear_pending_reply(&state).await; + conn.send_server_message(&ServerMessage::PromptResponse(response)) + .await + .map_err(|err| { InvalidSnafu { - reason: "prompt responder dropped response".to_string(), + reason: format!("failed to send prompt response: {err}"), } .build() })?; - let response = response.map_err(|err| { - InvalidSnafu { - reason: format!("prompt responder failed: {err}"), - } - .build() - })?; - conn.send_server_message(&ServerMessage::PromptResponse(response)) - .await - .map_err(|err| { - InvalidSnafu { - reason: format!("failed to send prompt response: {err}"), - } - .build() - })?; - } else if !discovery { - event_tx - .send(StarterEvent::Prompt { - prompt: payload.prompt, - }) - .await - .ok(); - event_index = event_index.saturating_add(1); + } else if !discovery { + 
event_tx + .send(StarterEvent::Prompt { + prompt: payload.prompt, + }) + .await + .ok(); + { + let mut guard = state.lock().await; + guard.bump_event_index(); } - first_message = false; } - Ok(ClientMessage::Reply(_)) => { - // Reply messages are handled locally by `coco reply` command, - // they should not be sent to the session server. - return Err(InvalidSnafu { - reason: "reply is not expected in combo session".to_string(), - } - .build()); + first_message = false; + } + ClientMessage::Reply(payload) => { + if discovery { + let _ = conn + .send_server_message(&ServerMessage::ReplyValidation(ReplyValidation { + success: false, + error: Some("reply is not allowed during discovery".to_string()), + response: None, + })) + .await; + continue; } - Ok(ClientMessage::Mcp(_)) => { - return Err(InvalidSnafu { - reason: "mcp request is not allowed in combo session".to_string(), + let schemas = { state.lock().await.pending_reply_schemas.clone() }; + let Some(schemas) = schemas else { + let _ = conn + .send_server_message(&ServerMessage::ReplyValidation(ReplyValidation { + success: false, + error: Some("reply is not expected in combo session".to_string()), + response: None, + })) + .await; + continue; + }; + let parsed = match parse_reply_fields(&payload.fields, &schemas) { + Ok(parsed) => parsed, + Err(err) => { + let _ = conn + .send_server_message(&ServerMessage::ReplyValidation(ReplyValidation { + success: false, + error: Some(err), + response: None, + })) + .await; + continue; } - .build()); - } - Err(err) => { - if !metadata_seen { - return Err(InvalidSnafu { - reason: format!("failed before receiving metadata: {err}"), - } - .build()); + }; + let response = match serde_json::to_string(&parsed) { + Ok(value) => value, + Err(err) => { + let _ = conn + .send_server_message(&ServerMessage::ReplyValidation(ReplyValidation { + success: false, + error: Some(format!("failed to serialize reply output: {err}")), + response: None, + })) + .await; + continue; } - break; + }; + 
conn.send_server_message(&ServerMessage::ReplyValidation(ReplyValidation { + success: true, + error: None, + response: Some(response), + })) + .await + .map_err(|err| { + InvalidSnafu { + reason: format!("failed to send reply validation: {err}"), + } + .build() + })?; + first_message = false; + } + ClientMessage::Mcp(_) => { + return Err(InvalidSnafu { + reason: "mcp request is not allowed in combo session".to_string(), } + .build()); + } + } + } + + Ok(()) +} + +async fn run_session_server( + server: SessionSocketServer, + discovery: bool, + event_tx: mpsc::Sender, + mut shutdown_rx: oneshot::Receiver<()>, +) -> Result { + let state = Arc::new(AsyncMutex::new(SessionState::default())); + let cancel_token = CancellationToken::new(); + let (error_tx, mut error_rx) = mpsc::unbounded_channel(); + let mut handles: Vec> = Vec::new(); + let mut fatal_error: Option = None; + + loop { + tokio::select! { + _ = &mut shutdown_rx => break, + Some(err) = error_rx.recv() => { + fatal_error = Some(err); + break; } + accept = server.accept() => { + let conn = accept.context(AcceptSessionConnectionSnafu)?; + let task_state = state.clone(); + let task_event_tx = event_tx.clone(); + let task_cancel = cancel_token.clone(); + let task_error_tx = error_tx.clone(); + handles.push(tokio::spawn(async move { + if let Err(err) = handle_session_connection( + conn, + discovery, + task_event_tx, + task_state, + task_cancel, + ) + .await + { + let _ = task_error_tx.send(err); + } + })); + } + } + } + + cancel_token.cancel(); + for handle in handles { + if let Err(err) = handle.await { + warn!(?err, "session connection task failed"); } } - if !metadata_seen { + if let Some(err) = fatal_error { + return Err(err); + } + + let state = state.lock().await.clone(); + if state.metadata.is_none() { return Err(InvalidSnafu { reason: "metadata not received from session".to_string(), } @@ -1305,6 +1492,45 @@ mod tests { Ok(()) } + #[test] + fn parse_reply_fields_accepts_required_fields() { + let schemas = 
vec![ + PromptSchema { + name: "message".to_string(), + description: "reply message".to_string(), + }, + PromptSchema { + name: "status".to_string(), + description: "status".to_string(), + }, + ]; + let fields = vec!["--message=hello".to_string(), "--status=ok".to_string()]; + + let parsed = parse_reply_fields(&fields, &schemas).expect("expected parse success"); + assert_eq!(parsed.get("message"), Some(&"hello".to_string())); + assert_eq!(parsed.get("status"), Some(&"ok".to_string())); + } + + #[test] + fn parse_reply_fields_rejects_unknown_or_missing_fields() { + let schemas = vec![PromptSchema { + name: "message".to_string(), + description: "reply message".to_string(), + }]; + let fields = vec!["--extra=value".to_string()]; + + let err = parse_reply_fields(&fields, &schemas).expect_err("expected parse failure"); + assert!(err.contains("unexpected field key")); + + let fields = vec!["--message=hello".to_string(), "--extra=value".to_string()]; + let err = parse_reply_fields(&fields, &schemas).expect_err("expected parse failure"); + assert!(err.contains("unexpected field key")); + + let fields = Vec::new(); + let err = parse_reply_fields(&fields, &schemas).expect_err("expected parse failure"); + assert!(err.contains("missing required fields")); + } + #[tokio::test] async fn discovery_server_interrupts_record() -> Result<(), Box> { let session_env = session_env_with_coco(); From b5c0fa6845ddd967689f122574a54d39e2729078 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E7=8E=AE=20=28Jade=20Lin=29?= Date: Fri, 16 Jan 2026 22:51:02 +0800 Subject: [PATCH 3/3] feat: add command classification and guidance for combo reply offload --- crates/coco-tui/src/components/chat.rs | 161 ++++++++++++++++--------- src/agent.rs | 4 +- src/agent/bash_executor.rs | 33 +++++ 3 files changed, 143 insertions(+), 55 deletions(-) diff --git a/crates/coco-tui/src/components/chat.rs b/crates/coco-tui/src/components/chat.rs index 70c27b5..1e26f8b 100644 --- 
a/crates/coco-tui/src/components/chat.rs +++ b/crates/coco-tui/src/components/chat.rs @@ -3,8 +3,8 @@ use code_combo::tools::{BashInput, Final}; use code_combo::{ Agent, Block as ChatBlock, ChatResponse, ChatStreamUpdate, Config, Content as ChatContent, Message as ChatMessage, Output, RuntimeOverrides, SessionEnv, Starter, StarterCommand, - StarterError, StarterEvent, StopReason, TextEdit, ToolUse, discover_starters, - load_runtime_overrides, save_runtime_overrides, + StarterError, StarterEvent, StopReason, TextEdit, ToolUse, bash_unsafe_ranges, + discover_starters, load_runtime_overrides, parse_primary_command, save_runtime_overrides, }; use crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; use futures::StreamExt; @@ -1568,6 +1568,48 @@ fn build_offload_reply_retry_directive(schemas: &[code_combo::PromptSchema]) -> format!("The previous response did not produce a valid coco reply. Retry.\n\n{directive}") } +enum OffloadCommandKind { + Coco, + Safe, + Unsafe, +} + +fn classify_offload_command(command: &str) -> OffloadCommandKind { + let is_coco_reply = is_coco_reply_command(command); + if is_coco_reply { + return OffloadCommandKind::Coco; + } + if is_safe_command(command) { + return OffloadCommandKind::Safe; + } + OffloadCommandKind::Unsafe +} + +fn build_offload_reply_guidance( + schemas: &[code_combo::PromptSchema], + command: &str, + executed: bool, +) -> String { + let field_args: Vec = schemas + .iter() + .map(|schema| format!("--{}=...", schema.name)) + .collect(); + let field_descriptions: Vec = schemas + .iter() + .map(|schema| format!("- {}: {}", schema.name, schema.description)) + .collect(); + let status = if executed { "executed" } else { "blocked" }; + format!( + "The previous tool call was {status} but did not use `coco reply` (command: {command}).\n\ +You must call the bash tool with `coco reply` and only that command.\n\ +Required fields:\n{field_list}\n\n\ +Example:\n\ +coco reply {field_args}", + field_list = field_descriptions.join("\n"), + 
field_args = field_args.join(" "), + ) +} + #[derive(Debug, Snafu)] enum ComboReplyError { #[snafu(display("prompt reply cancelled"))] @@ -1618,50 +1660,24 @@ fn should_retry_offload_combo_reply(error: &ComboReplyError) -> bool { error.should_retry() } -fn is_env_assignment(token: &str) -> bool { - let Some((key, _value)) = token.split_once('=') else { - return false; - }; - if key.is_empty() { - return false; - } - if key.starts_with('-') { - return false; - } - key.chars() - .all(|ch| ch == '_' || ch.is_ascii_alphanumeric()) +fn is_coco_command_name(name: &str) -> bool { + name == "coco" || name.ends_with("/coco") } -/// Check if a command string starts with `coco reply` (allowing env assignments). fn is_coco_reply_command(command: &str) -> bool { - let mut parts = command.split_whitespace(); - let mut token = match parts.next() { - Some(token) => token, - None => return false, + let summary = match parse_primary_command(command) { + Ok(summary) => summary, + Err(_) => return false, }; - - loop { - if token == "env" { - token = match parts.next() { - Some(next) => next, - None => return false, - }; - continue; - } - if is_env_assignment(token) { - token = match parts.next() { - Some(next) => next, - None => return false, - }; - continue; - } - break; + if !is_coco_command_name(&summary.name) { + return false; } + matches!(summary.args.first(), Some(arg) if arg == "reply") +} - let command_token = token; - let reply_token = parts.next(); - let is_coco = command_token == "coco" || command_token.ends_with("/coco"); - is_coco && matches!(reply_token, Some("reply")) +fn is_safe_command(command: &str) -> bool { + let trimmed = command.trim(); + !trimmed.is_empty() && bash_unsafe_ranges(command).is_empty() } /// Parse coco reply stdout output as JSON fields and validate required schemas. 
@@ -1836,26 +1852,17 @@ async fn handle_offload_combo_reply( } })?; - // Validate it's a coco reply command - if !is_coco_reply_command(&bash_input.command) { - return Err(ComboReplyError::UnexpectedCommand { - command: bash_input.command.clone(), - }); - } + let original_command = bash_input.command.clone(); + let command_kind = classify_offload_command(&bash_input.command); - if !bash_input.command.contains("COCO_SESSION_SOCK=") { + if matches!(command_kind, OffloadCommandKind::Coco) + && !bash_input.command.contains("COCO_SESSION_SOCK=") + { let socket_path = session_socket_path.to_string_lossy(); let escaped_path = shell_escape(socket_path.as_ref()); let command = bash_input.command.trim_start(); bash_input.command = format!("COCO_SESSION_SOCK={} {}", escaped_path, command); } - let bash_input_value = - serde_json::to_value(&bash_input).map_err(|err| ComboReplyError::InvalidBashInput { - message: err.to_string(), - })?; - - // Auto-grant and execute the bash command - agent.grant_once(&bash_tool_use.id, BASH_TOOL_NAME); // Send tool use event for UI feedback (thinking already streamed via PromptStream) tx.send( @@ -1869,6 +1876,42 @@ async fn handle_offload_combo_reply( ) .ok(); + if matches!(command_kind, OffloadCommandKind::Unsafe) { + let reason = match code_combo::bash_unsafe_reason(&original_command) { + Ok(_) => "command not allowlisted".to_string(), + Err(reason) => reason, + }; + let output = Final::Message(format!("command rejected: {reason}; expected coco reply")); + agent + .append_message(build_tool_result_message(&bash_tool_use.id, true, &output)) + .await; + tx.send( + ComboEvent::ReplyToolResult { + name: combo_name.to_string(), + tool_use_id: bash_tool_use.id.clone(), + is_error: true, + output: output.clone(), + } + .into(), + ) + .ok(); + let prompt = build_offload_reply_guidance(schemas, &original_command, false); + agent + .append_message(ChatMessage::user(ChatContent::Text(prompt))) + .await; + return Err(ComboReplyError::UnexpectedCommand 
{
+            command: original_command,
+        });
+    }
+
+    let bash_input_value =
+        serde_json::to_value(&bash_input).map_err(|err| ComboReplyError::InvalidBashInput {
+            message: err.to_string(),
+        })?;
+
+    // Auto-grant and execute the bash command
+    agent.grant_once(&bash_tool_use.id, BASH_TOOL_NAME);
+
     // Execute the bash command
     let mut final_output: Option<Final> = None;
     let tool_use_id = bash_tool_use.id.clone();
@@ -1919,6 +1962,16 @@
     )
     .ok();
 
+    if matches!(command_kind, OffloadCommandKind::Safe) {
+        let prompt = build_offload_reply_guidance(schemas, &original_command, true);
+        agent
+            .append_message(ChatMessage::user(ChatContent::Text(prompt)))
+            .await;
+        return Err(ComboReplyError::UnexpectedCommand {
+            command: original_command,
+        });
+    }
+
     if is_error {
         return Err(ComboReplyError::BashCommandFailed {
             output: format!("{output:?}"),
diff --git a/src/agent.rs b/src/agent.rs
index 47cd29f..bd959f0 100644
--- a/src/agent.rs
+++ b/src/agent.rs
@@ -28,7 +28,9 @@ mod executor;
 mod prompt;
 
 pub use crate::provider::{Content, Message, StopReason, ToolUse};
-pub use bash_executor::{bash_unsafe_ranges, bash_unsafe_reason};
+pub use bash_executor::{
+    ParsedCommandSummary, bash_unsafe_ranges, bash_unsafe_reason, parse_primary_command,
+};
 pub use executor::{ExecuteStatus, Executor, Input, Output};
 
 const DEFAULT_THINKING_BUDGET_TOKENS: usize = 1024;
diff --git a/src/agent/bash_executor.rs b/src/agent/bash_executor.rs
index 5c30204..c674f27 100644
--- a/src/agent/bash_executor.rs
+++ b/src/agent/bash_executor.rs
@@ -38,6 +38,12 @@ struct ParsedCommand {
     name_range: Range<usize>,
 }
 
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ParsedCommandSummary {
+    pub name: String,
+    pub args: Vec<String>,
+}
+
 #[derive(Debug)]
 enum ParseError {
     Empty,
@@ -64,6 +70,20 @@ pub fn should_bypass_permission(input: &BashInput) -> bool {
     is_safe_command(&input.command)
 }
 
+pub fn parse_primary_command(command: &str) -> Result<ParsedCommandSummary, String> {
+    let commands = parse_commands(command).map_err(|err|
err.to_reason().to_string())?;
+    if commands.len() != 1 {
+        return Err("multiple commands".to_string());
+    }
+    let command = commands
+        .first()
+        .ok_or_else(|| "command is empty".to_string())?;
+    Ok(ParsedCommandSummary {
+        name: command.name.clone(),
+        args: command.args.iter().map(|arg| arg.text.clone()).collect(),
+    })
+}
+
 pub fn bash_unsafe_ranges(command: &str) -> Vec<(Range<usize>, String)> {
     let trimmed = command.trim();
     if trimmed.is_empty() {
@@ -171,6 +191,19 @@
     ranges
 }
 
+impl ParseError {
+    fn to_reason(&self) -> &'static str {
+        match self {
+            ParseError::Empty => "command is empty",
+            ParseError::MultipleStatements => "multiple statements",
+            ParseError::MissingCommandName => "missing command name",
+            ParseError::ParseFailed => "command parse failed",
+            ParseError::SyntaxError => "syntax error",
+            ParseError::UnsupportedNode => "unsupported shell syntax",
+        }
+    }
+}
+
 pub fn bash_unsafe_reason(command: &str) -> Result<(), String> {
     let details = bash_unsafe_ranges(command);
     if details.is_empty() {