Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
350 changes: 302 additions & 48 deletions crates/coco-tui/src/components/chat.rs

Large diffs are not rendered by default.

364 changes: 198 additions & 166 deletions src/agent.rs

Large diffs are not rendered by default.

33 changes: 33 additions & 0 deletions src/agent/bash_executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ struct ParsedCommand {
name_range: Range<usize>,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedCommandSummary {
pub name: String,
pub args: Vec<String>,
}

#[derive(Debug)]
enum ParseError {
Empty,
Expand All @@ -64,6 +70,20 @@ pub fn should_bypass_permission(input: &BashInput) -> bool {
is_safe_command(&input.command)
}

pub fn parse_primary_command(command: &str) -> Result<ParsedCommandSummary, String> {
let commands = parse_commands(command).map_err(|err| err.to_reason().to_string())?;
if commands.len() != 1 {
return Err("multiple commands".to_string());
}
let command = commands
.first()
.ok_or_else(|| "command is empty".to_string())?;
Ok(ParsedCommandSummary {
name: command.name.clone(),
args: command.args.iter().map(|arg| arg.text.clone()).collect(),
})
}

pub fn bash_unsafe_ranges(command: &str) -> Vec<(Range<usize>, String)> {
let trimmed = command.trim();
if trimmed.is_empty() {
Expand Down Expand Up @@ -171,6 +191,19 @@ pub fn bash_unsafe_ranges(command: &str) -> Vec<(Range<usize>, String)> {
ranges
}

impl ParseError {
fn to_reason(&self) -> &'static str {
match self {
ParseError::Empty => "command is empty",
ParseError::MultipleStatements => "multiple statements",
ParseError::MissingCommandName => "missing command name",
ParseError::ParseFailed => "command parse failed",
ParseError::SyntaxError => "syntax error",
ParseError::UnsupportedNode => "unsupported shell syntax",
}
}
}

pub fn bash_unsafe_reason(command: &str) -> Result<(), String> {
let details = bash_unsafe_ranges(command);
if details.is_empty() {
Expand Down
50 changes: 24 additions & 26 deletions src/cmd/prompt.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use std::collections::HashMap;

use serde_json::Value;
use snafu::prelude::*;
use tokio::io::AsyncReadExt;
use tracing::info;

use crate::{PromptPayload, PromptSchema, SessionSocketClient, error::Result};
use crate::{PromptPayload, PromptSchema, ReplyPayload, SessionSocketClient, error::Result};

pub async fn handle_ask(prompt: String, schemas: Vec<String>) -> Result<()> {
let prompt = resolve_prompt(prompt).await?;
Expand Down Expand Up @@ -68,35 +66,35 @@ pub async fn handle_tell(prompt: String) -> Result<()> {
/// Fields are provided as --field=value format.
/// Validation is done by the parent process (TUI) which knows the required schemas.
pub async fn handle_reply(fields: Vec<String>) -> Result<()> {
let parsed_fields = parse_reply_fields(&fields)?;
let Some(client) = SessionSocketClient::from_env()
.await
.whatever_context("failed to new from env COCO_SESSION_SOCK")?
else {
whatever!("env COCO_SESSION_SOCK is not set");
};

let validation = client
.send_reply_wait_validation(ReplyPayload { fields })
.await
.whatever_context("failed to send reply to session socket")?;

// Output the fields as JSON for bash result parsing
let output = serde_json::to_string(&parsed_fields)
.whatever_context("failed to serialize reply fields")?;
println!("{output}");
if !validation.success {
let error = validation
.error
.unwrap_or_else(|| "reply validation failed".to_string());
whatever!("{error}");
}

let Some(response) = validation.response else {
whatever!("reply validation succeeded without response");
};

println!("{response}");

info!("reply output generated");
Ok(())
}

fn parse_reply_fields(fields: &[String]) -> Result<HashMap<String, String>> {
let mut parsed = HashMap::new();
for field in fields {
// Handle --field=value format
let field = field.strip_prefix("--").unwrap_or(field);
let Some((key, value)) = field.split_once('=') else {
whatever!("invalid field format {field:?}, expected --field=value");
};
let key = key.trim();
ensure_whatever!(!key.is_empty(), "field key cannot be empty");
if parsed.contains_key(key) {
whatever!("duplicate field key: {key}");
}
parsed.insert(key.to_string(), value.to_string());
}
Ok(parsed)
}

async fn resolve_prompt(prompt: String) -> Result<String> {
if !prompt.trim().is_empty() {
return Ok(prompt);
Expand Down
10 changes: 6 additions & 4 deletions src/combo/session.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{collections::HashMap, path::Path, path::PathBuf, sync::Arc};
use std::{path::Path, path::PathBuf, sync::Arc};

use serde::{Deserialize, Serialize};
use snafu::prelude::*;
Expand Down Expand Up @@ -110,11 +110,11 @@ pub struct ThinkingConfig {
}

/// Payload for combo reply via bash command offload.
/// Contains the field values extracted by the LLM.
/// Contains raw `--field=value` args for server-side parsing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ReplyPayload {
/// Field name to value mapping
pub fields: HashMap<String, String>,
/// Raw field args, e.g. "--message=hello"
pub fields: Vec<String>,
}

/// Server response to validate reply fields against required schemas.
Expand All @@ -123,6 +123,8 @@ pub struct ReplyValidation {
pub success: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub response: Option<String>,
}

#[derive(Debug, Snafu)]
Expand Down
Loading