diff --git a/crates/jp_llm/Cargo.toml b/crates/jp_llm/Cargo.toml index f5a828e5..e50561b8 100644 --- a/crates/jp_llm/Cargo.toml +++ b/crates/jp_llm/Cargo.toml @@ -15,23 +15,18 @@ version.workspace = true [dependencies] jp_config = { workspace = true } jp_conversation = { workspace = true } -jp_inquire = { workspace = true } jp_mcp = { workspace = true } jp_openrouter = { workspace = true } -jp_printer = { workspace = true } jp_tool = { workspace = true } async-anthropic = { workspace = true } async-stream = { workspace = true } async-trait = { workspace = true } camino = { workspace = true } -crossterm = { workspace = true } -duct = { workspace = true } -duct_sh = { workspace = true } +chrono = { workspace = true } futures = { workspace = true } gemini_client_rs = { workspace = true } indexmap = { workspace = true } -jsonschema = { workspace = true } minijinja = { workspace = true, features = [ "builtins", "json", @@ -41,18 +36,15 @@ minijinja = { workspace = true, features = [ ] } ollama-rs = { workspace = true, features = ["rustls", "stream"] } open-editor = { workspace = true } -openai = { workspace = true, features = ["rustls", "reqwest"] } openai_responses = { workspace = true, features = ["stream"] } quick-xml = { workspace = true, features = ["serialize"] } reqwest = { workspace = true } reqwest-eventsource = { workspace = true } -schemars = { workspace = true } serde = { workspace = true } serde_json = { workspace = true, features = ["preserve_order"] } thiserror = { workspace = true } -chrono = { workspace = true } tokio = { workspace = true } -tokio-stream = { workspace = true } +tokio-util = { workspace = true } tracing = { workspace = true } url = { workspace = true } diff --git a/crates/jp_llm/src/error.rs b/crates/jp_llm/src/error.rs index 4a9d552d..e1d0b653 100644 --- a/crates/jp_llm/src/error.rs +++ b/crates/jp_llm/src/error.rs @@ -1,9 +1,262 @@ +use std::{ + fmt, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; + +use 
async_anthropic::errors::AnthropicError; +use reqwest::header::{HeaderMap, RETRY_AFTER}; +use serde_json::Value; + use crate::stream::aggregator::tool_call_request::AggregationError; pub(crate) type Result = std::result::Result; +/// A provider-agnostic streaming error. +#[derive(Debug)] +pub struct StreamError { + /// The kind of streaming error. + pub kind: StreamErrorKind, + + /// Whether and when the request can be retried. + /// + /// If `Some`, the request can be retried after the specified duration. + /// If `None`, the caller should use exponential backoff or not retry. + pub retry_after: Option, + + /// Human-readable error message. + message: String, + + /// The underlying source of the error. + /// + /// This is kept for logging and display purposes, but callers should + /// make decisions based on `kind` and `retry_after`, not the source. + source: Option>, +} + +impl StreamError { + /// Create a new stream error. + #[must_use] + pub fn new(kind: StreamErrorKind, message: impl Into) -> Self { + Self { + kind, + message: message.into(), + retry_after: None, + source: None, + } + } + + /// Create a timeout error. + #[must_use] + pub fn timeout(message: impl Into) -> Self { + Self::new(StreamErrorKind::Timeout, message) + } + + /// Create a connection error. + #[must_use] + pub fn connect(message: impl Into) -> Self { + Self::new(StreamErrorKind::Connect, message) + } + + /// Create a rate limit error. + #[must_use] + pub fn rate_limit(retry_after: Option) -> Self { + Self { + kind: StreamErrorKind::RateLimit, + retry_after, + message: "Rate limited".into(), + source: None, + } + } + + /// Create a transient error. + #[must_use] + pub fn transient(message: impl Into) -> Self { + Self::new(StreamErrorKind::Transient, message) + } + + /// Create an error from another error type. + #[must_use] + pub fn other(message: impl Into) -> Self { + Self::new(StreamErrorKind::Other, message) + } + + /// Set the `retry_after` duration. 
+    #[must_use]
+    pub fn with_retry_after(mut self, duration: Duration) -> Self {
+        self.retry_after = Some(duration);
+        self
+    }
+
+    /// Set the source error.
+    #[must_use]
+    pub fn with_source(
+        mut self,
+        source: impl Into<Box<dyn std::error::Error + Send + Sync>>,
+    ) -> Self {
+        self.source = Some(source.into());
+        self
+    }
+
+    /// Returns whether this error is likely retryable.
+    #[must_use]
+    pub fn is_retryable(&self) -> bool {
+        matches!(
+            self.kind,
+            StreamErrorKind::Timeout
+                | StreamErrorKind::Connect
+                | StreamErrorKind::RateLimit
+                | StreamErrorKind::Transient
+        ) || self.retry_after.is_some()
+    }
+}
+
+impl fmt::Display for StreamError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.message)?;
+        if let Some(ref source) = self.source {
+            write!(f, ": {source}")?;
+        }
+
+        Ok(())
+    }
+}
+
+impl std::error::Error for StreamError {
+    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+        self.source
+            .as_ref()
+            .map(|e| e.as_ref() as &(dyn std::error::Error + 'static))
+    }
+}
+
+/// Canonical classifier for [`reqwest::Error`].
+///
+/// Provider-specific `From` impls should delegate to this when they encounter
+/// an inner [`reqwest::Error`], rather than re-implementing the classification
+/// logic.
+impl From<reqwest::Error> for StreamError {
+    fn from(err: reqwest::Error) -> Self {
+        if err.is_timeout() {
+            StreamError::timeout(err.to_string()).with_source(err)
+        } else if err.is_connect() {
+            StreamError::connect(err.to_string()).with_source(err)
+        } else if err.status().is_some_and(|s| s == 429) {
+            StreamError::rate_limit(None).with_source(err)
+        } else if err
+            .status()
+            // A `matches!` guard applies to every or-pattern alternative, so
+            // `408 | 409 | _ if s >= 500` reduced to `s >= 500`; test 408/409
+            // separately so those statuses are classified transient too.
+            .is_some_and(|s| matches!(s.as_u16(), 408 | 409) || s.as_u16() >= 500)
+            || err.is_body()
+            || err.is_decode()
+        {
+            StreamError::transient(err.to_string()).with_source(err)
+        } else {
+            StreamError::other(err.to_string()).with_source(err)
+        }
+    }
+}
+
+/// Canonical classifier for [`reqwest_eventsource::Error`].
+///
+/// Delegates the [`Transport`] variant to the [`reqwest::Error`] classifier.
+/// Classifies [`InvalidStatusCode`] by its HTTP status, extracts `Retry-After`
+/// headers, and honours the non-standard `x-should-retry` header used by
+/// OpenAI.
+///
+/// Provider-specific `From` impls should delegate to this when they encounter
+/// an inner [`reqwest_eventsource::Error`], rather than re-implementing the
+/// classification logic.
+///
+/// [`Transport`]: reqwest_eventsource::Error::Transport
+/// [`InvalidStatusCode`]: reqwest_eventsource::Error::InvalidStatusCode
+impl From<reqwest_eventsource::Error> for StreamError {
+    fn from(err: reqwest_eventsource::Error) -> Self {
+        use reqwest_eventsource::Error;
+
+        match err {
+            Error::Transport(error) => Self::from(error),
+            Error::InvalidStatusCode(status, response) => {
+                let headers = response.headers();
+                let retry_after = extract_retry_after(headers);
+                let code = status.as_u16();
+
+                // Non-standard `x-should-retry` header overrides
+                // status-code heuristics.
+                let retryable = match header_str(headers, "x-should-retry") {
+                    Some("true") => true,
+                    Some("false") => false,
+                    // A `matches!` guard applies to every or-pattern
+                    // alternative, so `408 | 409 | 429 | _ if code >= 500`
+                    // reduced to `code >= 500` (and 429 fell through to
+                    // non-retryable `Other`); split the comparison out.
+                    _ => matches!(code, 408 | 409 | 429) || code >= 500,
+                };
+
+                if !retryable {
+                    StreamError::other(format!("HTTP {status}"))
+                } else if code == 429 {
+                    StreamError::rate_limit(retry_after)
+                } else {
+                    let err = StreamError::transient(format!("HTTP {status}"));
+                    match retry_after {
+                        Some(d) => err.with_retry_after(d),
+                        None => err,
+                    }
+                }
+            }
+            error @ (Error::Utf8(_)
+            | Error::Parser(_)
+            | Error::InvalidContentType(_, _)
+            | Error::InvalidLastEventId(_)
+            | Error::StreamEnded) => StreamError::other(error.to_string()).with_source(error),
+        }
+    }
+}
+
+/// The kind of streaming error.
+///
+/// This abstraction allows the resilience layer to make retry decisions without
+/// knowing the specific provider implementation details.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum StreamErrorKind {
+    /// Request timed out.
+ Timeout, + + /// Failed to establish connection. + Connect, + + /// Rate limited by the provider. + RateLimit, + + /// Transient error (server error, temporary failure). + /// These are typically safe to retry. + Transient, + + /// The API key's quota has been exhausted. + /// This is not retryable — the user needs to top up or change plans. + InsufficientQuota, + + /// Other errors that are not categorized. + /// These may or may not be retryable depending on the specific error. + Other, +} + +impl StreamErrorKind { + /// Returns the error kind as a string. + #[must_use] + pub fn as_str(&self) -> &'static str { + match self { + Self::Timeout => "Timeout", + Self::Connect => "Connection error", + Self::RateLimit => "Rate limited", + Self::Transient => "Server error", + Self::InsufficientQuota => "Insufficient API quota", + Self::Other => "Stream Error", + } + } +} + #[derive(Debug, thiserror::Error)] pub enum Error { + /// Streaming error with provider-agnostic classification. + #[error(transparent)] + Stream(#[from] StreamError), + #[error("OpenRouter error: {0}")] OpenRouter(#[from] jp_openrouter::Error), @@ -43,9 +296,6 @@ pub enum Error { #[error("Ollama error: {0}")] Ollama(#[from] ollama_rs::error::OllamaError), - #[error("Missing structured data in response")] - MissingStructuredData, - #[error("Unknown model: {0}")] UnknownModel(String), @@ -56,7 +306,7 @@ pub enum Error { Request(#[from] reqwest::Error), #[error("Anthropic error: {0}")] - Anthropic(#[from] async_anthropic::errors::AnthropicError), + Anthropic(#[from] AnthropicError), #[error("Anthropic request builder error: {0}")] AnthropicRequestBuilder(#[from] async_anthropic::types::CreateMessagesRequestBuilderError), @@ -79,37 +329,6 @@ pub enum Error { ToolCallRequestAggregator(#[from] AggregationError), } -impl From for Error { - fn from(error: gemini_client_rs::GeminiError) -> Self { - use gemini_client_rs::GeminiError; - - match &error { - GeminiError::Api(api) if 
api.get("status").is_some_and(|v| v.as_u64() == Some(404)) => { - if let Some(model) = api.pointer("/message/error/message").and_then(|v| { - v.as_str().and_then(|s| { - s.contains("Call ListModels").then(|| { - s.split('/') - .nth(1) - .and_then(|v| v.split(' ').next()) - .unwrap_or("unknown") - }) - }) - }) { - return Self::UnknownModel(model.to_owned()); - } - Self::Gemini(error) - } - _ => Self::Gemini(error), - } - } -} - -impl From for Error { - fn from(error: openai_responses::types::response::Error) -> Self { - Self::OpenaiResponse(error) - } -} - #[cfg(test)] impl PartialEq for Error { fn eq(&self, other: &Self) -> bool { @@ -147,7 +366,7 @@ pub enum ToolError { #[error("Failed to serialize tool arguments")] SerializeArgumentsError { - arguments: serde_json::Value, + arguments: Value, #[source] error: serde_json::Error, }, @@ -155,16 +374,23 @@ pub enum ToolError { #[error("Tool call failed: {0}")] ToolCallFailed(String), + #[error("Failed to spawn command: {command}")] + SpawnError { + command: String, + #[source] + error: std::io::Error, + }, + #[error("Failed to open editor to edit tool call")] OpenEditorError { - arguments: serde_json::Value, + arguments: Value, #[source] error: open_editor::errors::OpenEditorError, }, #[error("Failed to edit tool call")] EditArgumentsError { - arguments: serde_json::Value, + arguments: Value, #[source] error: serde_json::Error, }, @@ -179,7 +405,7 @@ pub enum ToolError { #[error("Invalid `type` property for {key}, got {value:?}, expected one of {need:?}")] InvalidType { key: String, - value: serde_json::Value, + value: Value, need: Vec<&'static str>, }, @@ -219,3 +445,241 @@ impl PartialEq for ToolError { format!("{self:?}") == format!("{other:?}") } } + +/// Heuristic check for quota/billing exhaustion based on error text. 
+/// +/// This catches the common patterns across providers: +/// - OpenAI: `"insufficient_quota"` +/// - Anthropic: `"billing_error"`, `"Your credit balance is too low"` +/// - Google: `"RESOURCE_EXHAUSTED"`, `"Quota exceeded"` +/// - OpenRouter: `"insufficient credits"`, `"out of credits"` +pub(crate) fn looks_like_quota_error(text: &str) -> bool { + let lower = text.to_ascii_lowercase(); + lower.contains("insufficient_quota") + || lower.contains("insufficient quota") + || lower.contains("insufficient credits") + || lower.contains("out of credits") + || lower.contains("billing_error") + || lower.contains("credit balance is too low") + || lower.contains("quota exceeded") + || lower.contains("resource_exhausted") +} + +/// Extracts a retry-after duration from an error message body. +/// +/// Last-resort fallback when response headers don't carry retry timing. +/// Matches common natural-language patterns found in API error responses: +/// +/// - `"retry after 30 seconds"` +/// - `"retry-after: 30"` +/// - `"wait 30 seconds"` +/// - `"try again in 5s"` / `"try again in 5.5s"` +/// - `"retryDelay": "30s"` (Google Gemini JSON body) +pub(crate) fn extract_retry_from_text(text: &str) -> Option { + let lower = text.to_ascii_lowercase(); + + // Scan for patterns and extract the first match. 
+ for window in lower.split_whitespace().collect::>().windows(4) { + // "retry after N second(s)" + if window[0] == "retry" + && window[1] == "after" + && let Some(secs) = parse_secs_token(window[2]) + { + return Some(Duration::from_secs(secs)); + } + // "wait N second(s)" + if window[0] == "wait" + && let Some(secs) = parse_secs_token(window[1]) + { + return Some(Duration::from_secs(secs)); + } + // "try again in Ns" / "try again in N.Ns" + if window[0] == "try" + && window[1] == "again" + && window[2] == "in" + && let Some(secs) = parse_secs_token(window[3]) + { + return Some(Duration::from_secs(secs)); + } + } + + // "retry-after: N" / "retry-after:N" + if let Some(pos) = lower.find("retry-after:") { + let after = lower[pos + "retry-after:".len()..].trim_start(); + if let Some(secs) = after + .split(|c: char| !c.is_ascii_digit()) + .next() + .and_then(|s| s.parse::().ok()) + .filter(|s| *s > 0) + { + return Some(Duration::from_secs(secs)); + } + } + + // "retryDelay": "30s" (Gemini JSON body) + if let Some(pos) = lower.find("retrydelay") { + let after = &lower[pos..]; + if let Some(d) = after + .split('"') + .find(|s| s.ends_with('s') && s[..s.len() - 1].chars().all(|c| c.is_ascii_digit())) + .and_then(parse_human_duration) + { + return Some(Duration::from_secs(d)); + } + } + + None +} + +/// Extracts a retry-after duration from common rate-limit response +/// headers. +/// +/// Headers are checked in decreasing order of authority: +/// +/// 1. `retry-after-ms` — Non-standard (OpenAI). Millisecond precision. +/// 2. `Retry-After` — RFC 7231. Integer or float seconds (the spec +/// mandates integers, but floats are accepted). HTTP-date values are +/// not supported. +/// 3. `RateLimit` — IETF draft `t=` parameter (delta-seconds). +/// See: +/// 4. `x-ratelimit-reset-requests` / `x-ratelimit-reset-tokens` — +/// OpenAI-style Human-duration values (e.g. `6m0s`). Takes the longer +/// of the two if both are present. +/// 5. 
`x-ratelimit-reset` — Unix timestamp, converted relative to now. +fn extract_retry_after(headers: &HeaderMap) -> Option { + if let Some(d) = header_positive_f64(headers, "retry-after-ms") + .map(|ms| Duration::from_secs_f64(ms / 1000.0)) + { + return Some(d); + } + + if let Some(d) = header_positive_f64(headers, RETRY_AFTER).map(Duration::from_secs_f64) { + return Some(d); + } + + // IETF draft: `RateLimit: remaining=0; t=` + if let Some(secs) = headers + .get("ratelimit") + .and_then(|v| v.to_str().ok()) + .and_then(|v| { + v.split(';') + .map(str::trim) + .find_map(|p| p.strip_prefix("t=")) + }) + .and_then(|v| v.trim().parse::().ok()) + .filter(|v| *v > 0) + { + return Some(Duration::from_secs(secs)); + } + + // 4. OpenAI: `x-ratelimit-reset-requests` / `x-ratelimit-reset-tokens` + // Both use human-duration format (e.g. "1s", "6m0s"). Take the max. + let requests = header_str(headers, "x-ratelimit-reset-requests").and_then(parse_human_duration); + let tokens = header_str(headers, "x-ratelimit-reset-tokens").and_then(parse_human_duration); + + if let Some(secs) = requests.into_iter().chain(tokens).max() { + return Some(Duration::from_secs(secs)); + } + + if let Some(reset_ts) = header_u64(headers, "x-ratelimit-reset") { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + if reset_ts > now { + return Some(Duration::from_secs(reset_ts - now)); + } + } + + None +} + +/// Read a header value as `&str`. +fn header_str(headers: &HeaderMap, name: impl reqwest::header::AsHeaderName) -> Option<&str> { + headers.get(name).and_then(|v| v.to_str().ok()) +} + +/// Read a header value as `u64`. +fn header_u64(headers: &HeaderMap, name: impl reqwest::header::AsHeaderName) -> Option { + header_str(headers, name).and_then(|s| s.parse().ok()) +} + +/// Read a header value as a positive, finite `f64`. 
+fn header_positive_f64( + headers: &HeaderMap, + name: impl reqwest::header::AsHeaderName, +) -> Option { + header_str(headers, name) + .and_then(|s| s.parse::().ok()) + .filter(|v| *v > 0.0 && v.is_finite()) +} + +/// Parses a human-style duration string into whole seconds. +/// +/// Supported units: `h` (hours), `m` (minutes), `s` (seconds), `ms` +/// (milliseconds — rounded up to 1s if non-zero and total is 0). +/// +/// Examples: `"1s"` → 1, `"6m0s"` → 360, `"1h30m"` → 5400, +/// `"200ms"` → 1. +/// +/// Returns `None` for empty, zero, or unparseable values. +fn parse_human_duration(s: &str) -> Option { + let mut total: u64 = 0; + let mut has_sub_second = false; + let mut remaining = s.trim(); + + while !remaining.is_empty() { + let num_end = remaining + .find(|c: char| !c.is_ascii_digit()) + .unwrap_or(remaining.len()); + + if num_end == 0 { + return None; + } + + let n: u64 = remaining[..num_end].parse().ok()?; + remaining = &remaining[num_end..]; + + if remaining.starts_with("ms") { + has_sub_second = n > 0; + remaining = &remaining[2..]; + } else if remaining.starts_with('h') { + total += n * 3600; + remaining = &remaining[1..]; + } else if remaining.starts_with('m') { + total += n * 60; + remaining = &remaining[1..]; + } else if remaining.starts_with('s') { + total += n; + remaining = &remaining[1..]; + } else { + return None; + } + } + + // Round sub-second durations up to 1s so we don't return 0 + // when the server asked us to wait. + if total == 0 && has_sub_second { + total = 1; + } + + if total > 0 { Some(total) } else { None } +} + +/// Parse a token like `"30"`, `"30s"`, `"5.5s"` into whole seconds. 
+#[expect(clippy::cast_possible_truncation, clippy::cast_sign_loss)] +fn parse_secs_token(s: &str) -> Option { + let s = s + .trim_end_matches('s') + .trim_end_matches("second") + .trim_end_matches(','); + s.parse::() + .ok() + .filter(|v| *v > 0.0 && v.is_finite()) + .map(|v| v.ceil() as u64) +} + +#[cfg(test)] +#[path = "error_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/error_tests.rs b/crates/jp_llm/src/error_tests.rs new file mode 100644 index 00000000..4b94d2e1 --- /dev/null +++ b/crates/jp_llm/src/error_tests.rs @@ -0,0 +1,228 @@ +use super::*; + +#[test] +fn extract_retry_after_from_retry_after_ms() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("retry-after-ms", "1500".parse().unwrap()); + + assert_eq!( + extract_retry_after(&headers), + Some(Duration::from_millis(1500)) + ); +} + +#[test] +fn extract_retry_after_ms_takes_priority_over_retry_after() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("retry-after-ms", "500".parse().unwrap()); + headers.insert(reqwest::header::RETRY_AFTER, "30".parse().unwrap()); + + // retry-after-ms is more precise, should be preferred. 
+ assert_eq!( + extract_retry_after(&headers), + Some(Duration::from_millis(500)) + ); +} + +#[test] +fn extract_retry_after_from_standard_header() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert(reqwest::header::RETRY_AFTER, "30".parse().unwrap()); + + assert_eq!(extract_retry_after(&headers), Some(Duration::from_secs(30))); +} + +#[test] +fn extract_retry_after_accepts_float_seconds() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert(reqwest::header::RETRY_AFTER, "1.5".parse().unwrap()); + + assert_eq!( + extract_retry_after(&headers), + Some(Duration::from_millis(1500)) + ); +} + +#[test] +fn extract_retry_after_ignores_http_date() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + reqwest::header::RETRY_AFTER, + "Wed, 21 Oct 2025 07:28:00 GMT".parse().unwrap(), + ); + + // HTTP-date is not supported, should return None. + assert_eq!(extract_retry_after(&headers), None); +} + +#[test] +fn extract_retry_after_from_ietf_ratelimit_header() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("ratelimit", "remaining=0; t=45".parse().unwrap()); + + assert_eq!(extract_retry_after(&headers), Some(Duration::from_secs(45))); +} + +#[test] +fn extract_retry_after_from_openai_reset_requests() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("x-ratelimit-reset-requests", "6m0s".parse().unwrap()); + + assert_eq!(extract_retry_after(&headers), Some(Duration::from_mins(6))); +} + +#[test] +fn extract_retry_after_from_openai_reset_tokens() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("x-ratelimit-reset-tokens", "1s".parse().unwrap()); + + assert_eq!(extract_retry_after(&headers), Some(Duration::from_secs(1))); +} + +#[test] +fn extract_retry_after_openai_takes_max_of_both() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("x-ratelimit-reset-requests", "2s".parse().unwrap()); + 
headers.insert("x-ratelimit-reset-tokens", "6m0s".parse().unwrap()); + + assert_eq!(extract_retry_after(&headers), Some(Duration::from_mins(6))); +} + +#[test] +fn extract_retry_after_from_ratelimit_reset() { + let future_ts = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() + + 45; + + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("x-ratelimit-reset", future_ts.to_string().parse().unwrap()); + + let result = extract_retry_after(&headers).unwrap(); + // Allow 1s tolerance for test execution time. + assert!(result.as_secs() >= 44 && result.as_secs() <= 46); +} + +#[test] +fn extract_retry_after_prefers_standard_header() { + let future_ts = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() + + 120; + + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert(reqwest::header::RETRY_AFTER, "10".parse().unwrap()); + headers.insert("x-ratelimit-reset", future_ts.to_string().parse().unwrap()); + + assert_eq!(extract_retry_after(&headers), Some(Duration::from_secs(10))); +} + +#[test] +fn extract_retry_after_past_reset_returns_none() { + let past_ts = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() + - 60; + + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("x-ratelimit-reset", past_ts.to_string().parse().unwrap()); + + assert_eq!(extract_retry_after(&headers), None); +} + +#[test] +fn extract_retry_after_empty_headers() { + let headers = reqwest::header::HeaderMap::new(); + assert_eq!(extract_retry_after(&headers), None); +} + +#[test] +fn human_duration_seconds() { + assert_eq!(parse_human_duration("1s"), Some(1)); + assert_eq!(parse_human_duration("30s"), Some(30)); +} + +#[test] +fn human_duration_minutes_and_seconds() { + assert_eq!(parse_human_duration("6m0s"), Some(360)); + assert_eq!(parse_human_duration("1m30s"), Some(90)); +} + +#[test] +fn human_duration_hours() { + 
assert_eq!(parse_human_duration("1h30m0s"), Some(5400)); + assert_eq!(parse_human_duration("2h"), Some(7200)); +} + +#[test] +fn human_duration_milliseconds_rounds_up() { + assert_eq!(parse_human_duration("200ms"), Some(1)); + assert_eq!(parse_human_duration("0ms"), None); +} + +#[test] +fn human_duration_mixed_with_ms() { + // 1 second + 500ms = 1s (ms doesn't add to whole seconds). + assert_eq!(parse_human_duration("1s500ms"), Some(1)); +} + +#[test] +fn human_duration_zero_returns_none() { + assert_eq!(parse_human_duration("0s"), None); + assert_eq!(parse_human_duration("0m0s"), None); +} + +#[test] +fn human_duration_invalid() { + assert_eq!(parse_human_duration(""), None); + assert_eq!(parse_human_duration("abc"), None); + assert_eq!(parse_human_duration("5x"), None); +} + +#[test] +fn text_retry_after_n_seconds() { + let text = "Rate limit exceeded. Please retry after 30 seconds."; + assert_eq!(extract_retry_from_text(text), Some(Duration::from_secs(30))); +} + +#[test] +fn text_wait_n_seconds() { + let text = "Too many requests. 
Please wait 60 seconds before trying again."; + assert_eq!(extract_retry_from_text(text), Some(Duration::from_mins(1))); +} + +#[test] +fn text_try_again_in_ns() { + let text = "Service busy, try again in 5s"; + assert_eq!(extract_retry_from_text(text), Some(Duration::from_secs(5))); +} + +#[test] +fn text_try_again_in_float() { + let text = "Overloaded, try again in 5.5s please"; + assert_eq!( + extract_retry_from_text(text), + Some(Duration::from_secs(6)) // ceil(5.5) + ); +} + +#[test] +fn text_retry_after_colon() { + let text = "Error: retry-after: 15"; + assert_eq!(extract_retry_from_text(text), Some(Duration::from_secs(15))); +} + +#[test] +fn text_gemini_retry_delay() { + let text = r#"{"error":{"details":[{"retryDelay":"30s"}]}}"#; + assert_eq!(extract_retry_from_text(text), Some(Duration::from_secs(30))); +} + +#[test] +fn text_no_pattern_returns_none() { + assert_eq!(extract_retry_from_text("Something went wrong"), None); + assert_eq!(extract_retry_from_text(""), None); +} diff --git a/crates/jp_llm/src/event.rs b/crates/jp_llm/src/event.rs index e19fc475..8cd2656e 100644 --- a/crates/jp_llm/src/event.rs +++ b/crates/jp_llm/src/event.rs @@ -14,8 +14,8 @@ use serde_json::Value; /// will have different `index` values corresponding to their relative order. /// /// Only after [`Event::Flush`] is produced with the given `index` value, should -/// all previous parts be merged into a single [`ConversationEvent`], using -/// `EventAggregator`. +/// all previous parts be merged into a single [`ConversationEvent`] (e.g. via +/// [`EventBuilder`](jp_conversation::event_builder::EventBuilder)). #[derive(Debug, Clone, PartialEq)] pub enum Event { /// A part of a completed event. 
diff --git a/crates/jp_llm/src/lib.rs b/crates/jp_llm/src/lib.rs index 35d9ca18..3cdc6b47 100644 --- a/crates/jp_llm/src/lib.rs +++ b/crates/jp_llm/src/lib.rs @@ -1,15 +1,18 @@ -mod error; +pub mod error; pub mod event; pub mod model; pub mod provider; pub mod query; +pub mod retry; mod stream; -pub mod structured; +pub mod title; pub mod tool; #[cfg(test)] pub(crate) mod test; -pub use error::{Error, ToolError}; +pub use error::{Error, StreamError, StreamErrorKind, ToolError}; pub use provider::Provider; -pub use stream::{aggregator::tool_call_request::AggregationError, chain::EventChain}; +pub use retry::exponential_backoff; +pub use stream::{EventStream, aggregator::tool_call_request::AggregationError, chain::EventChain}; +pub use tool::{CommandResult, ExecutionOutcome, run_tool_command}; diff --git a/crates/jp_llm/src/model.rs b/crates/jp_llm/src/model.rs index 137e3249..310b0402 100644 --- a/crates/jp_llm/src/model.rs +++ b/crates/jp_llm/src/model.rs @@ -49,6 +49,11 @@ impl ModelDetails { } } + #[must_use] + pub fn name(&self) -> &str { + self.display_name.as_deref().unwrap_or(&self.id.name) + } + #[must_use] pub fn custom_reasoning_config( &self, @@ -134,6 +139,21 @@ impl ModelDetails { // Custom configuration, so use it. Some(ReasoningConfig::Custom(custom)) => Some(custom), }, + + // Adaptive + Some(ReasoningDetails::Adaptive { .. }) => match config { + // Off, so disabled. + Some(ReasoningConfig::Off) => None, + + // Unconfigured or auto, so use high effort (the API default). + None | Some(ReasoningConfig::Auto) => Some(CustomReasoningConfig { + effort: ReasoningEffort::High, + exclude: false, + }), + + // Custom configuration, so use it. + Some(ReasoningConfig::Custom(custom)) => Some(custom), + }, } } } @@ -207,6 +227,18 @@ pub enum ReasoningDetails { /// Whether the model supports extremely high effort reasoning. xhigh: bool, }, + + /// Adaptive reasoning support. 
+ /// + /// The model dynamically decides when and how much to think based on + /// task complexity. Uses effort levels (low/medium/high/max) instead of + /// token budgets. + /// + /// Currently only supported by Claude Opus 4.6+. + Adaptive { + /// Whether the model supports `max` effort level. + max: bool, + }, } impl ReasoningDetails { @@ -230,6 +262,11 @@ impl ReasoningDetails { } } + #[must_use] + pub fn adaptive(max: bool) -> Self { + Self::Adaptive { max } + } + #[must_use] pub fn unsupported() -> Self { Self::Unsupported @@ -239,7 +276,7 @@ impl ReasoningDetails { pub fn min_tokens(&self) -> u32 { match self { Self::Budgetted { min_tokens, .. } => *min_tokens, - Self::Leveled { .. } | Self::Unsupported => 0, + Self::Leveled { .. } | Self::Adaptive { .. } | Self::Unsupported => 0, } } @@ -247,7 +284,42 @@ impl ReasoningDetails { pub fn max_tokens(&self) -> Option { match self { Self::Budgetted { max_tokens, .. } => *max_tokens, - Self::Leveled { .. } | Self::Unsupported => None, + Self::Leveled { .. } | Self::Adaptive { .. } | Self::Unsupported => None, + } + } + + /// Returns the lowest reasoning effort level supported by this model, if + /// known. + /// + /// `Leveled` models return their lowest supported level. Other variants + /// return `Option::None` — callers should decide how to handle "disable + /// reasoning" for their provider (e.g. token budget 0, effort `minimal`, + /// `thinking: disabled`, etc.). + #[must_use] + pub fn lowest_effort(&self) -> Option { + match self { + Self::Leveled { + xlow, + low, + medium, + high, + xhigh, + } => { + if *xlow { + Some(ReasoningEffort::Xlow) + } else if *low { + Some(ReasoningEffort::Low) + } else if *medium { + Some(ReasoningEffort::Medium) + } else if *high { + Some(ReasoningEffort::High) + } else if *xhigh { + Some(ReasoningEffort::XHigh) + } else { + None + } + } + _ => None, } } @@ -265,4 +337,9 @@ impl ReasoningDetails { pub fn is_leveled(&self) -> bool { matches!(self, Self::Leveled { .. 
}) } + + #[must_use] + pub fn is_adaptive(&self) -> bool { + matches!(self, Self::Adaptive { .. }) + } } diff --git a/crates/jp_llm/src/provider.rs b/crates/jp_llm/src/provider.rs index 21e4b6bd..afa4d6ad 100644 --- a/crates/jp_llm/src/provider.rs +++ b/crates/jp_llm/src/provider.rs @@ -3,35 +3,26 @@ pub mod google; // pub mod xai; pub mod anthropic; pub mod llamacpp; +pub mod mock; pub mod ollama; pub mod openai; pub mod openrouter; use anthropic::Anthropic; use async_trait::async_trait; -use futures::{TryStreamExt as _, stream}; use google::Google; use jp_config::{ - assistant::instructions::InstructionsConfig, model::id::{Name, ProviderId}, providers::llm::LlmProviderConfig, }; -use jp_conversation::event::ConversationEvent; use llamacpp::Llamacpp; use ollama::Ollama; use openai::Openai; use openrouter::Openrouter; -use serde_json::Value; -use tracing::warn; use crate::{ - Error, - error::Result, - event::Event, - model::ModelDetails, - query::{ChatQuery, StructuredQuery}, - stream::{EventStream, aggregator::chunk::EventAggregator}, - structured::SCHEMA_TOOL_NAME, + error::Result, model::ModelDetails, provider::mock::MockProvider, query::ChatQuery, + stream::EventStream, }; #[async_trait] @@ -48,369 +39,27 @@ pub trait Provider: std::fmt::Debug + Send + Sync { model: &ModelDetails, query: ChatQuery, ) -> Result; - - /// Perform a non-streaming chat completion. - /// - /// Default implementation collects results from the streaming version. - async fn chat_completion(&self, model: &ModelDetails, query: ChatQuery) -> Result> { - let mut aggregator = EventAggregator::new(); - self.chat_completion_stream(model, query) - .await? - .map_ok(|event| stream::iter(aggregator.ingest(event).into_iter().map(Ok))) - .try_flatten() - .try_collect() - .await - } - - /// Perform a structured completion. - /// - /// Default implementation uses a specialized tool-call to get structured - /// results. 
- /// - /// Providers that have a dedicated structured response endpoint should - /// override this method. - async fn structured_completion( - &self, - model: &ModelDetails, - query: StructuredQuery, - ) -> Result { - let mut chat_query = ChatQuery { - thread: query.thread.clone(), - tools: vec![query.tool_definition()?], - tool_choice: query.tool_choice()?, - tool_call_strict_mode: true, - }; - - let max_retries = 3; - for i in 1..=max_retries { - let result = self.chat_completion(model, chat_query.clone()).await; - let events = match result { - Ok(events) => events, - Err(error) if i >= max_retries => return Err(error), - Err(error) => { - warn!(%error, "Error while getting structured data. Retrying in non-strict mode."); - chat_query.tool_call_strict_mode = false; - continue; - } - }; - - let data = events - .into_iter() - .filter_map(Event::into_conversation_event) - .filter_map(ConversationEvent::into_tool_call_request) - .find(|call| call.name == SCHEMA_TOOL_NAME) - .map(|call| Value::Object(call.arguments)); - - let result = data - .ok_or("Did not receive any structured data".to_owned()) - .and_then(|data| query.validate(&data).map(|()| data)); - - match result { - Ok(data) => return Ok(query.map(data)), - Err(error) => { - warn!(error, "Failed to fetch structured data. Retrying."); - - chat_query.thread.instructions.push( - InstructionsConfig::default() - .with_title("Structured Data Validation Error") - .with_description( - "The following error occurred while validating the structured \ - data. Please try again.", - ) - .with_item(error), - ); - } - } - } - - Err(Error::MissingStructuredData) - } } /// Get a provider by ID. -/// -/// # Panics -/// -/// Panics if the provider is `ProviderId::TEST`, which is reserved for testing -/// only. 
pub fn get_provider(id: ProviderId, config: &LlmProviderConfig) -> Result> { let provider: Box = match id { ProviderId::Anthropic => Box::new(Anthropic::try_from(&config.anthropic)?), - ProviderId::Deepseek => todo!(), ProviderId::Google => Box::new(Google::try_from(&config.google)?), ProviderId::Llamacpp => Box::new(Llamacpp::try_from(&config.llamacpp)?), ProviderId::Ollama => Box::new(Ollama::try_from(&config.ollama)?), ProviderId::Openai => Box::new(Openai::try_from(&config.openai)?), ProviderId::Openrouter => Box::new(Openrouter::try_from(&config.openrouter)?), + + ProviderId::Deepseek => todo!(), ProviderId::Xai => todo!(), + + ProviderId::Test => Box::new(MockProvider::new(vec![])), }; Ok(provider) } #[cfg(test)] -mod tests { - use std::sync::Arc; - - use jp_config::{ - assistant::tool_choice::ToolChoice, - conversation::tool::{OneOrManyTypes, ToolParameterConfig, item::ToolParameterItemConfig}, - }; - use jp_conversation::event::ChatRequest; - use jp_test::{Result, function_name}; - - use super::*; - use crate::{ - structured, - test::{TestRequest, run_test, test_model_details}, - }; - - macro_rules! test_all_providers { - ($($fn:ident),* $(,)?) => { - mod anthropic { use super::*; $(test_all_providers!(func; $fn, ProviderId::Anthropic);)* } - mod google { use super::*; $(test_all_providers!(func; $fn, ProviderId::Google);)* } - mod openai { use super::*; $(test_all_providers!(func; $fn, ProviderId::Openai);)* } - mod openrouter{ use super::*; $(test_all_providers!(func; $fn, ProviderId::Openrouter);)* } - mod ollama { use super::*; $(test_all_providers!(func; $fn, ProviderId::Ollama);)* } - mod llamacpp { use super::*; $(test_all_providers!(func; $fn, ProviderId::Llamacpp);)* } - }; - (func; $fn:ident, $provider:ty) => { - paste::paste! 
{ - #[test_log::test(tokio::test)] - async fn [< test_ $fn >]() -> Result { - $fn($provider, function_name!()).await - } - } - }; - } - - async fn chat_completion_nostream(provider: ProviderId, test_name: &str) -> Result { - let request = TestRequest::chat(provider) - .stream(false) - .enable_reasoning() - .event(ChatRequest::from("Test message")); - - run_test(provider, test_name, Some(request)).await - } - - async fn chat_completion_stream(provider: ProviderId, test_name: &str) -> Result { - let request = TestRequest::chat(provider) - .stream(true) - .enable_reasoning() - .event(ChatRequest::from("Test message")); - - run_test(provider, test_name, Some(request)).await - } - - fn tool_call_base(provider: ProviderId) -> TestRequest { - TestRequest::chat(provider) - .event(ChatRequest::from( - "Please run the tool, providing whatever arguments you want.", - )) - .tool("run_me", vec![ - ("foo", ToolParameterConfig { - kind: OneOrManyTypes::One("string".into()), - default: Some("foo".into()), - description: None, - required: false, - enumeration: vec![], - items: None, - }), - ("bar", ToolParameterConfig { - kind: OneOrManyTypes::Many(vec!["string".into(), "array".into()]), - default: None, - description: None, - required: true, - enumeration: vec!["foo".into(), vec!["foo", "bar"].into()], - items: Some(ToolParameterItemConfig { - kind: OneOrManyTypes::One("string".into()), - default: None, - description: None, - enumeration: vec![], - }), - }), - ]) - } - - async fn tool_call_nostream(provider: ProviderId, test_name: &str) -> Result { - let requests = vec![ - tool_call_base(provider), - TestRequest::tool_call_response(Ok("working!"), false), - ]; - - run_test(provider, test_name, requests).await - } - - async fn tool_call_stream(provider: ProviderId, test_name: &str) -> Result { - let requests = vec![ - tool_call_base(provider).stream(true), - TestRequest::tool_call_response(Ok("working!"), false), - ]; - - run_test(provider, test_name, requests).await - } - - async 
fn tool_call_strict(provider: ProviderId, test_name: &str) -> Result { - let requests = vec![ - tool_call_base(provider).tool_call_strict_mode(true), - TestRequest::tool_call_response(Ok("working!"), false), - ]; - - run_test(provider, test_name, requests).await - } - - /// Without reasoning, "forced" tool calls should work as expected. - async fn tool_call_required_no_reasoning(provider: ProviderId, test_name: &str) -> Result { - let requests = vec![ - tool_call_base(provider).tool_choice(ToolChoice::Required), - TestRequest::tool_call_response(Ok("working!"), true), - ]; - - run_test(provider, test_name, requests).await - } - - /// With reasoning, some models do not support "forced" tool calls, so - /// provider implementations should fall back to trying to instruct the - /// model to use the tool through regular textual instructions. - async fn tool_call_required_reasoning(provider: ProviderId, test_name: &str) -> Result { - let requests = vec![ - tool_call_base(provider) - .tool_choice(ToolChoice::Required) - .enable_reasoning(), - TestRequest::tool_call_response(Ok("working!"), false), - ]; - - run_test(provider, test_name, requests).await - } - - async fn tool_call_auto(provider: ProviderId, test_name: &str) -> Result { - let requests = vec![ - tool_call_base(provider).tool_choice(ToolChoice::Auto), - TestRequest::tool_call_response(Ok("working!"), false), - ]; - - run_test(provider, test_name, requests).await - } - - async fn tool_call_function(provider: ProviderId, test_name: &str) -> Result { - let requests = vec![ - tool_call_base(provider).tool_choice_fn("run_me"), - TestRequest::tool_call_response(Ok("working!"), true), - ]; - - run_test(provider, test_name, requests).await - } - - async fn tool_call_reasoning(provider: ProviderId, test_name: &str) -> Result { - let requests = vec![ - tool_call_base(provider).enable_reasoning(), - TestRequest::tool_call_response(Ok("working!"), false), - ]; - - run_test(provider, test_name, requests).await - } - - async 
fn structured_completion_success(provider: ProviderId, test_name: &str) -> Result { - let request = - TestRequest::chat(provider).chat_request("I am testing the structured completion API."); - let history = request.as_thread().unwrap().events.clone(); - let request = TestRequest::Structured { - query: structured::titles::titles(3, history, &[]).unwrap(), - model: match request { - TestRequest::Chat { model, .. } => model, - _ => unreachable!(), - }, - assert: Arc::new(|_| {}), - }; - - run_test(provider, test_name, Some(request)).await - } - - async fn structured_completion_error(provider: ProviderId, test_name: &str) -> Result { - let request = - TestRequest::chat(provider).chat_request("I am testing the structured completion API."); - let thread = request.as_thread().cloned().unwrap(); - let query = StructuredQuery::new( - schemars::json_schema!({ - "type": "object", - "description": "1 + 1 = ?", - "required": ["answer"], - "additionalProperties": false, - "properties": { "answer": { "type": "integer" } }, - }), - thread, - ) - .with_validator(move |value| { - value - .get("answer") - .ok_or("Missing `answer` field.".to_owned())? - .as_u64() - .ok_or("Answer must be an integer".to_owned()) - .and_then(|v| Err(format!("You thought 1 + 1 = {v}? Think again!"))) - }); - - let request = TestRequest::Structured { - query, - model: match request { - TestRequest::Chat { model, .. 
} => model, - _ => unreachable!(), - }, - assert: Arc::new(|results| { - results.iter().all(std::result::Result::is_err); - }), - }; - - run_test(provider, test_name, Some(request)).await - } - - async fn model_details(provider: ProviderId, test_name: &str) -> Result { - let request = TestRequest::ModelDetails { - name: test_model_details(provider).id.name.to_string(), - assert: Arc::new(|_| {}), - }; - - run_test(provider, test_name, Some(request)).await - } - - async fn models(provider: ProviderId, test_name: &str) -> Result { - let request = TestRequest::Models { - assert: Arc::new(|_| {}), - }; - - run_test(provider, test_name, Some(request)).await - } - - async fn multi_turn_conversation(provider: ProviderId, test_name: &str) -> Result { - let requests = vec![ - TestRequest::chat(provider).chat_request("Test message"), - TestRequest::chat(provider) - .enable_reasoning() - .chat_request("Repeat my previous message"), - tool_call_base(provider).tool_choice_fn("run_me"), - TestRequest::tool_call_response(Ok("The secret code is: 42"), true), - TestRequest::chat(provider) - .enable_reasoning() - .chat_request("What was the result of the previous tool call?"), - ]; - - run_test(provider, test_name, requests).await - } - - test_all_providers![ - chat_completion_nostream, - chat_completion_stream, - tool_call_auto, - tool_call_function, - tool_call_reasoning, - tool_call_nostream, - tool_call_required_no_reasoning, - tool_call_required_reasoning, - tool_call_stream, - tool_call_strict, - structured_completion_success, - structured_completion_error, - model_details, - models, - multi_turn_conversation, - ]; -} +#[path = "provider_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/provider/anthropic.rs b/crates/jp_llm/src/provider/anthropic.rs index a5fabe6a..7c917ffc 100644 --- a/crates/jp_llm/src/provider/anthropic.rs +++ b/crates/jp_llm/src/provider/anthropic.rs @@ -1,27 +1,30 @@ -use std::{env, time::Duration}; +use std::{env, future, time::Duration}; use 
async_anthropic::{ Client, errors::AnthropicError, messages::DEFAULT_MAX_TOKENS, types::{ - self, ListModelsResponse, System, Thinking, ToolBash, ToolCodeExecution, ToolComputerUse, - ToolTextEditor, ToolWebSearch, + self, Effort, JsonOutputFormat, ListModelsResponse, OutputConfig, System, Thinking, + ToolBash, ToolCodeExecution, ToolComputerUse, ToolTextEditor, ToolWebSearch, }, }; use async_stream::try_stream; use async_trait::async_trait; use chrono::NaiveDate; -use futures::{StreamExt as _, TryStreamExt as _, pin_mut}; +use futures::{FutureExt as _, StreamExt as _, TryStreamExt as _, pin_mut, stream}; use indexmap::IndexMap; use jp_config::{ assistant::tool_choice::ToolChoice, - model::id::{Name, ProviderId}, + model::{ + id::{Name, ProviderId}, + parameters::ReasoningEffort, + }, providers::llm::anthropic::AnthropicConfig, }; use jp_conversation::{ ConversationStream, - event::{ChatResponse, ConversationEvent, EventKind}, + event::{ChatResponse, ConversationEvent, EventKind, ToolCallRequest}, thread::{Document, Documents, Thread}, }; use serde_json::{Map, Value, json}; @@ -29,7 +32,10 @@ use tracing::{debug, info, trace, warn}; use super::Provider; use crate::{ - error::{Error, Result}, + error::{ + Error, Result, StreamError, StreamErrorKind, extract_retry_from_text, + looks_like_quota_error, + }, event::{Event, FinishReason}, model::{ModelDeprecation, ModelDetails, ReasoningDetails}, query::ChatQuery, @@ -51,6 +57,27 @@ const MAX_CACHE_CONTROL_COUNT: usize = 4; const THINKING_SIGNATURE_KEY: &str = "anthropic_thinking_signature"; const REDACTED_THINKING_KEY: &str = "anthropic_redacted_thinking"; +/// Known Anthropic error types that are safe to retry. +/// +/// See: +/// See: +const RETRYABLE_ANTHROPIC_ERROR_TYPES: &[&str] = + &["rate_limit_error", "overloaded_error", "api_error"]; + +/// Supported string `format` values for Anthropic structured output. 
+const SUPPORTED_STRING_FORMATS: &[&str] = &[ + "date-time", + "time", + "date", + "duration", + "email", + "hostname", + "uri", + "ipv4", + "ipv6", + "uuid", +]; + #[derive(Debug, Clone)] pub struct Anthropic { client: Client, @@ -105,18 +132,25 @@ impl Provider for Anthropic { query: ChatQuery, ) -> Result { let client = self.client.clone(); - let chain_on_max_tokens = query + let max_tokens_config = query .thread .events .config()? .assistant .model .parameters - .max_tokens - .is_none() - && self.chain_on_max_tokens; + .max_tokens; - let request = create_request(model, query, true, &self.beta)?; + let (request, is_structured) = create_request(model, query, true, &self.beta)?; + + // Chaining is disabled for structured output — the provider guarantees + // schema compliance so the response won't hit max_tokens for a + // well-constrained schema. + // + // It is also disabled when the user has explicitly configured a max + // tokens value, or when chaining is disabled in the provider config. + let chain_on_max_tokens = + !is_structured && max_tokens_config.is_none() && self.chain_on_max_tokens; debug!(stream = true, "Anthropic chat completion stream request."); trace!( @@ -124,7 +158,7 @@ impl Provider for Anthropic { "Request payload." 
); - Ok(call(client, request, chain_on_max_tokens)) + Ok(call(client, request, chain_on_max_tokens, is_structured)) } } @@ -139,6 +173,7 @@ fn call( client: Client, request: types::CreateMessagesRequest, chain_on_max_tokens: bool, + is_structured: bool, ) -> EventStream { Box::pin(try_stream!({ let mut tool_call_aggregator = ToolCallRequestAggregator::new(); @@ -153,87 +188,57 @@ fn call( .messages() .create_stream(request.clone()) .await - .map_err(|e| match e { - AnthropicError::RateLimit { retry_after } => Error::RateLimit { - retry_after: retry_after.map(Duration::from_secs), - }, - - // Anthropic's API is notoriously unreliable, so we - // special-case a few common errors that most of the times - // resolve themselves when retried. - // - // See: - // See: - AnthropicError::StreamError(e) - if ["rate_limit_error", "overloaded_error", "api_error"] - .contains(&e.error_type.as_str()) - || e.error.as_ref().is_some_and(|v| { - v.get("message").and_then(Value::as_str) == Some("Overloaded") - }) => - { - Error::RateLimit { - retry_after: Some(Duration::from_secs(3)), - } - } - - _ => Error::from(e), - }) + .map_err(StreamError::from) + .map_ok(|v| stream::iter(map_event(v, &mut tool_call_aggregator, is_structured))) + .try_flatten() + .chain(future::ready(Ok(Event::Finished(FinishReason::Completed))).into_stream()) .peekable(); pin_mut!(stream); while let Some(event) = stream.next().await.transpose()? { - if let Some(event) = map_event(event, &mut tool_call_aggregator)? { - match event { - // If the assistant has reached the maximum number of - // tokens, and we are in a state in which we can request - // more tokens, we do so by sending a new request and - // chaining those events onto the previous ones, keeping the - // existing stream of events alive. - // - // TODO: generalize this for any provider. 
- event if should_chain(&event, tool_calls_requested, chain_on_max_tokens) => { - debug!("Max tokens reached, auto-requesting more tokens."); - - for await event in chain(client.clone(), request.clone(), events) { - yield event?; - } - return; - } - done @ Event::Finished(_) => { - yield done; - return; + match event { + // If the assistant has reached the maximum number of + // tokens, and we are in a state in which we can request + // more tokens, we do so by sending a new request and + // chaining those events onto the previous ones, keeping the + // existing stream of events alive. + // + // TODO: generalize this for any provider. + event if should_chain(&event, tool_calls_requested, chain_on_max_tokens) => { + debug!("Max tokens reached, auto-requesting more tokens."); + + for await event in chain(client.clone(), request.clone(), events, is_structured) + { + yield event?; } - Event::Part { event, index } => { - if event.is_tool_call_request() { - tool_calls_requested = true; - } else if chain_on_max_tokens { - events.push(event.clone()); - } - - yield Event::Part { event, index }; + return; + } + done @ Event::Finished(_) => { + yield done; + return; + } + Event::Part { event, index } => { + if event.is_tool_call_request() { + tool_calls_requested = true; + } else if chain_on_max_tokens { + events.push(event.clone()); } - flush @ Event::Flush { .. } => { - let next_delta = stream.as_mut().peek().await.and_then(|e| { - e.as_ref().ok().and_then(|e| match e { - types::MessagesStreamEvent::MessageDelta { delta, .. } => { - Some(delta) - } - _ => None, - }) - }); - - // If we try to flush, but we're about to continue with - // a chained request, we ignore the flush event, to - // allow more event parts to be generated by the next - // response. 
- if let Some(delta) = next_delta.cloned().and_then(map_message_delta) - && should_chain(&delta, tool_calls_requested, chain_on_max_tokens) - { - continue; - } - - yield flush; + + yield Event::Part { event, index }; + } + flush @ Event::Flush { .. } => { + let next_event = stream.as_mut().peek().await.and_then(|e| e.as_ref().ok()); + + // If we try to flush, but we're about to continue with a + // chained request, we ignore the flush event, to allow more + // event parts to be generated by the next response. + if let Some(event) = next_event + && should_chain(event, tool_calls_requested, chain_on_max_tokens) + { + continue; } + + yield flush; } } } @@ -255,13 +260,18 @@ fn chain( client: Client, mut request: types::CreateMessagesRequest, events: Vec, + is_structured: bool, ) -> EventStream { debug_assert!(!events.iter().any(ConversationEvent::is_tool_call_request)); let mut should_merge = true; let previous_content = events .last() - .and_then(|e| e.as_chat_response().map(ChatResponse::content)) + .and_then(|e| match e.as_chat_response() { + Some(ChatResponse::Message { message }) => Some(message.as_str()), + Some(ChatResponse::Reasoning { reasoning }) => Some(reasoning.as_str()), + _ => None, + }) .unwrap_or_default() .to_owned(); @@ -292,7 +302,7 @@ fn chain( }); Box::pin(try_stream!({ - for await event in call(client, request, true) { + for await event in call(client, request, true, is_structured) { let mut event = event?; // When chaining new events, the reasoning content is irrelevant, as @@ -316,10 +326,12 @@ fn chain( // overlap. Sometimes the assistant will start a chaining response // with a small amount of content that was already seen in the // previous response, and we want to avoid duplicating that. 
- if let Some(content) = event + if let Some( + ChatResponse::Message { message: content } + | ChatResponse::Reasoning { reasoning: content }, + ) = event .as_conversation_event_mut() .and_then(ConversationEvent::as_chat_response_mut) - .map(ChatResponse::content_mut) { if should_merge { let merge_point = find_merge_point(&previous_content, content, 500); @@ -401,12 +413,11 @@ fn create_request( query: ChatQuery, stream: bool, beta: &BetaFeatures, -) -> Result { +) -> Result<(types::CreateMessagesRequest, bool)> { let ChatQuery { thread, tools, mut tool_choice, - tool_call_strict_mode, } = query; let mut builder = types::CreateMessagesRequestBuilder::default(); @@ -415,11 +426,22 @@ fn create_request( let Thread { system_prompt, - instructions, + sections, attachments, events, } = thread; + // Request a structured response if the very last event is a ChatRequest + // with a schema attached. The schema is transformed to strip unsupported + // properties (moving them into `description` fields as hints). 
+ let format = events + .last() + .and_then(|e| e.event.as_chat_request()) + .and_then(|req| req.schema.clone()) + .map(|schema| JsonOutputFormat::JsonSchema { + schema: transform_schema(schema), + }); + let mut cache_control_count = MAX_CACHE_CONTROL_COUNT; let config = events.config()?; @@ -427,13 +449,8 @@ fn create_request( .model(model.id.name.clone()) .messages(AnthropicMessages::build(events, &mut cache_control_count).0); - let tools = convert_tools( - tools, - tool_call_strict_mode - && model.features.contains(&"structured-outputs") - && beta.structured_outputs(), - &mut cache_control_count, - ); + let strict_tools = model.features.contains(&"structured-outputs") && beta.structured_outputs(); + let tools = convert_tools(tools, strict_tools, &mut cache_control_count); let mut system_content = vec![]; @@ -447,20 +464,29 @@ fn create_request( })); } - if !instructions.is_empty() { - let text = instructions - .into_iter() - .map(|instruction| instruction.try_to_xml().map_err(Into::into)) - .collect::>>()? - .join("\n\n"); - - system_content.push(types::SystemContent::Text(types::Text { - text, - cache_control: (cache_control_count > 0).then_some({ - cache_control_count = cache_control_count.saturating_sub(1); - types::CacheControl::default() - }), - })); + // FIXME: Somehow the system_prompt is being duplicated. It has to do with + // `impl PartialConfigDelta`? + // dbg!(&system_content); + + if !sections.is_empty() { + // Each section gets its own system content block. Cache control + // is placed on the last section, as section content is unlikely + // to change between requests. 
+ let mut sections = sections.iter().peekable(); + while let Some(section) = sections.next() { + system_content.push(types::SystemContent::Text(types::Text { + text: section.render(), + cache_control: sections.peek().map_or_else( + || { + (cache_control_count > 0).then(|| { + cache_control_count = cache_control_count.saturating_sub(1); + types::CacheControl::default() + }) + }, + |_| None, + ), + })); + } } if !attachments.is_empty() { @@ -546,36 +572,76 @@ fn create_request( builder.system(System::Content(system_content)); } + // Track the effort from reasoning config so we can merge it with the + // structured output format into a single OutputConfig at the end. + let mut effort = None; + let supports_thinking = model.reasoning.is_some_and(|r| !r.is_unsupported()); + if let Some(config) = reasoning_config { - let (min_budget, mut max_budget) = match model.reasoning { + match model.reasoning { + // Adaptive thinking for Opus 4.6+ + Some(ReasoningDetails::Adaptive { max: supports_max }) => { + builder.thinking(types::ExtendedThinking::Adaptive); + + effort = match config + .effort + .abs_to_rel(model.max_output_tokens) + .unwrap_or(ReasoningEffort::Auto) + { + ReasoningEffort::Max if supports_max => Some(Effort::Max), + ReasoningEffort::Max + | ReasoningEffort::XHigh + | ReasoningEffort::High + | ReasoningEffort::Absolute(_) => Some(Effort::High), + ReasoningEffort::Medium => Some(Effort::Medium), + ReasoningEffort::Low | ReasoningEffort::Xlow | ReasoningEffort::None => { + Some(Effort::Low) + } + ReasoningEffort::Auto => None, + }; + } + + // Budget-based thinking for older models Some(ReasoningDetails::Budgetted { min_tokens, - max_tokens, - }) => (min_tokens, max_tokens.unwrap_or(u32::MAX)), - _ => (0, u32::MAX), - }; + max_tokens: reasoning_max_tokens, + }) => { + let mut max_budget = reasoning_max_tokens.unwrap_or(u32::MAX); - // With interleaved thinking, the `budget_tokens` can exceed the - // `max_tokens` parameter, as it represents the total budget across 
all - // thinking blocks within one assistant turn. - // - // See: - // - // This is only enabled if the model supports it, otherwise an error is - // returned if the `max_tokens` parameter is larger than the model's - // supported range. - if beta.interleaved_thinking() && model.features.contains(&"interleaved-thinking") { - max_budget = model.context_window.unwrap_or(max_budget); + // With interleaved thinking, the `budget_tokens` can exceed the + // `max_tokens` parameter, as it represents the total budget across all + // thinking blocks within one assistant turn. + // + // See: + // + // This is only enabled if the model supports it, otherwise an error is + // returned if the `max_tokens` parameter is larger than the model's + // supported range. + if beta.interleaved_thinking() && model.features.contains(&"interleaved-thinking") { + max_budget = model.context_window.unwrap_or(max_budget); + } + + builder.thinking(types::ExtendedThinking::Enabled { + budget_tokens: config + .effort + .to_tokens(max_tokens) + .max(min_tokens) + .min(max_budget), + }); + } + + // Other reasoning details (Leveled, Unsupported) - no thinking config + _ => {} } + } else if supports_thinking { + // Reasoning is off but the model supports it — explicitly disable + // to prevent the model from thinking by default. 
+ builder.thinking(types::ExtendedThinking::Disabled); + } - builder.thinking(types::ExtendedThinking { - kind: "enabled".to_string(), - budget_tokens: config - .effort - .to_tokens(max_tokens) - .max(min_budget) - .min(max_budget), - }); + let is_structured = format.is_some(); + if effort.is_some() || is_structured { + builder.output_config(OutputConfig { effort, format }); } if let Some(temperature) = parameters.temperature { @@ -608,12 +674,53 @@ fn create_request( builder.context_management(strategy); } - builder.build().map_err(Into::into) + builder + .build() + .map(|req| (req, is_structured)) + .map_err(Into::into) } #[expect(clippy::too_many_lines)] fn map_model(model: types::Model, beta: &BetaFeatures) -> Result { let details = match model.id.as_str() { + "claude-sonnet-4-6" => ModelDetails { + id: (PROVIDER, model.id).try_into()?, + display_name: Some(model.display_name), + context_window: if beta.context_1m() { + Some(1_000_000) + } else { + Some(200_000) + }, + max_output_tokens: Some(64_000), + reasoning: Some(ReasoningDetails::adaptive(true)), + knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2025, 8, 1).unwrap()), + deprecated: Some(ModelDeprecation::Active), + features: vec![ + "interleaved-thinking", + "context-editing", + "structured-outputs", + "adaptive-thinking", + ], + }, + "claude-opus-4-6" | "claude-opus-4-6-20260205" => ModelDetails { + id: (PROVIDER, model.id).try_into()?, + display_name: Some(model.display_name), + context_window: if beta.context_1m() { + Some(1_000_000) + } else { + Some(200_000) + }, + max_output_tokens: Some(128_000), + reasoning: Some(ReasoningDetails::adaptive(true)), + knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2025, 5, 1).unwrap()), + deprecated: Some(ModelDeprecation::Active), + features: vec![ + "interleaved-thinking", + "context-editing", + "structured-outputs", + "adaptive-thinking", + ], + }, "claude-opus-4-5" | "claude-opus-4-5-20251101" => ModelDetails { id: (PROVIDER, model.id).try_into()?, display_name: 
Some(model.display_name), @@ -694,39 +801,6 @@ fn map_model(model: types::Model, beta: &BetaFeatures) -> Result { deprecated: Some(ModelDeprecation::Active), features: vec!["interleaved-thinking", "context-editing"], }, - "claude-3-7-sonnet-latest" | "claude-3-7-sonnet-20250219" => ModelDetails { - id: (PROVIDER, model.id).try_into()?, - display_name: Some(model.display_name), - context_window: Some(200_000), - max_output_tokens: Some(64_000), - reasoning: Some(ReasoningDetails::budgetted(1024, None)), - knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2024, 11, 1).unwrap()), - deprecated: Some(ModelDeprecation::Active), - features: vec![], - }, - "claude-3-5-haiku-latest" | "claude-3-5-haiku-20241022" => ModelDetails { - id: (PROVIDER, model.id).try_into()?, - display_name: Some(model.display_name), - context_window: Some(200_000), - max_output_tokens: Some(8_192), - reasoning: Some(ReasoningDetails::unsupported()), - knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2024, 7, 1).unwrap()), - deprecated: Some(ModelDeprecation::Active), - features: vec![], - }, - "claude-3-opus-latest" | "claude-3-opus-20240229" => ModelDetails { - id: (PROVIDER, model.id).try_into()?, - display_name: Some(model.display_name), - context_window: Some(200_000), - max_output_tokens: Some(4_096), - reasoning: Some(ReasoningDetails::unsupported()), - knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2023, 8, 1).unwrap()), - deprecated: Some(ModelDeprecation::deprecated( - &"recommended replacement: claude-opus-4-1-20250805", - Some(NaiveDate::from_ymd_opt(2026, 1, 5).unwrap()), - )), - features: vec![], - }, "claude-3-haiku-20240307" => ModelDetails { id: (PROVIDER, model.id).try_into()?, display_name: Some(model.display_name), @@ -734,7 +808,10 @@ fn map_model(model: types::Model, beta: &BetaFeatures) -> Result { max_output_tokens: Some(4_096), reasoning: Some(ReasoningDetails::unsupported()), knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2024, 8, 1).unwrap()), - deprecated: 
Some(ModelDeprecation::Active), + deprecated: Some(ModelDeprecation::deprecated( + &"recommended replacement: claude-haiku-4-5-20251001", + Some(NaiveDate::from_ymd_opt(2026, 4, 20).unwrap()), + )), features: vec![], }, id => { @@ -751,7 +828,8 @@ fn map_model(model: types::Model, beta: &BetaFeatures) -> Result { fn map_event( event: types::MessagesStreamEvent, agg: &mut ToolCallRequestAggregator, -) -> Result> { + is_structured: bool, +) -> Vec> { use types::MessagesStreamEvent::*; trace!( @@ -763,11 +841,17 @@ fn map_event( ContentBlockStart { content_block, index, - } => Ok(map_content_start(content_block, index, agg)), - ContentBlockDelta { delta, index } => Ok(map_content_delta(delta, index, agg)), + } => map_content_start(content_block, index, agg, is_structured) + .into_iter() + .map(Ok) + .collect(), + ContentBlockDelta { delta, index } => map_content_delta(delta, index, agg, is_structured) + .into_iter() + .map(Ok) + .collect(), ContentBlockStop { index } => map_content_stop(index, agg), - MessageDelta { delta, .. } => Ok(map_message_delta(delta)), - _ => Ok(None), + MessageDelta { delta, .. } => map_message_delta(&delta).into_iter().map(Ok).collect(), + _ => vec![], } } @@ -798,6 +882,139 @@ impl TryFrom<&AnthropicConfig> for Anthropic { } } +/// Transform a JSON schema to conform to Anthropic's structured output +/// constraints. +/// +/// Anthropic's structured output supports a subset of JSON Schema. Unsupported +/// properties are stripped and appended to the `description` field so the model +/// can still see them as soft hints. +/// +/// Mirrors the logic from Anthropic's Python SDK `transform_schema`. +/// +/// See: +fn transform_schema(mut src: Map) -> Map { + if let Some(r) = src.remove("$ref") { + return Map::from_iter([("$ref".into(), r)]); + } + + let mut out = Map::new(); + + // Helper macro to move a field from src to out. + macro_rules! 
move_field { + ($key:literal) => { + if let Some(v) = src.remove($key) { + out.insert($key.into(), v); + } + }; + } + + // Extract common fields + move_field!("title"); + move_field!("description"); + + // Recursive Transformation Helpers + let transform_val = |v: Value| match v { + Value::Object(o) => Value::Object(transform_schema(o)), + other => other, + }; + + let transform_map = |m: Map| -> Map { + m.into_iter().map(|(k, v)| (k, transform_val(v))).collect() + }; + + let transform_vec = + |v: Vec| -> Vec { v.into_iter().map(transform_val).collect() }; + + // Handle Recursive Dictionaries + for key in ["$defs", "definitions"] { + if let Some(Value::Object(defs)) = src.remove(key) { + out.insert(key.into(), Value::Object(transform_map(defs))); + } + } + + // Handle Combinators + if let Some(Value::Array(variants)) = src.remove("anyOf") { + out.insert("anyOf".into(), Value::Array(transform_vec(variants))); + } else if let Some(Value::Array(variants)) = src.remove("oneOf") { + // Remap oneOf -> anyOf + out.insert("anyOf".into(), Value::Array(transform_vec(variants))); + } else if let Some(Value::Array(variants)) = src.remove("allOf") { + out.insert("allOf".into(), Value::Array(transform_vec(variants))); + } + + // Handle Type-Specific Logic + // + // We remove "type" now so it doesn't get caught in the "leftovers" logic + // later. 
+ let type_val = src.remove("type"); + match type_val.as_ref().and_then(Value::as_str) { + Some("object") => { + if let Some(Value::Object(props)) = src.remove("properties") { + out.insert("properties".into(), Value::Object(transform_map(props))); + } + + move_field!("required"); + + // Force strictness + src.remove("additionalProperties"); + out.insert("additionalProperties".into(), Value::Bool(false)); + } + Some("array") => { + if let Some(items) = src.remove("items") { + out.insert("items".into(), transform_val(items)); + } + + // Enforce minItems logic + if let Some(min) = src.remove("minItems") { + if min.as_u64().is_some_and(|n| n <= 1) { + out.insert("minItems".into(), min); + } else { + // Put it back into src so it falls into description + src.insert("minItems".into(), min); + } + } + } + Some("string") => { + if let Some(format) = src.remove("format") { + let is_supported = format + .as_str() + .is_some_and(|f| SUPPORTED_STRING_FORMATS.contains(&f)); + + if is_supported { + out.insert("format".into(), format); + } else { + src.insert("format".into(), format); + } + } + } + _ => {} + } + + // Re-insert type. + if let Some(t) = type_val { + out.insert("type".into(), t); + } + + // 7. 
Handle "Leftovers" (Unsupported fields -> Description) + if !src.is_empty() { + let extra_info = src + .iter() + .map(|(k, v)| format!("{k}: {v}")) + .collect::>() + .join(", "); + + out.entry("description") + .and_modify(|v| { + if let Some(s) = v.as_str() { + *v = Value::from(format!("{s}\n\n{{{extra_info}}}")); + } + }) + .or_insert_with(|| Value::from(format!("{{{extra_info}}}"))); + } + + out +} + fn convert_tool_choice(choice: ToolChoice) -> types::ToolChoice { match choice { ToolChoice::None => types::ToolChoice::none(), @@ -979,7 +1196,11 @@ fn convert_event( && signature.is_some() { types::MessageContent::Thinking(Thinking { - thinking: response.into_content(), + thinking: match response { + ChatResponse::Reasoning { reasoning } => reasoning, + ChatResponse::Message { message } => message, + ChatResponse::Structured { data } => data.to_string(), + }, signature, }) } else if is_anthropic @@ -990,13 +1211,19 @@ fn convert_event( .map(str::to_owned) { types::MessageContent::RedactedThinking { data } - } else if response.is_reasoning() { - // Reasoning from other providers - wrap in XML tags - types::MessageContent::Text( - format!("\n{}\n\n\n", response.content()).into(), - ) } else { - types::MessageContent::Text(response.into_content().into()) + match response { + // Reasoning from other providers - wrap in tags. 
+ ChatResponse::Reasoning { reasoning } => types::MessageContent::Text( + format!("\n{reasoning}\n\n\n").into(), + ), + ChatResponse::Message { message } => { + types::MessageContent::Text(message.into()) + } + ChatResponse::Structured { data } => { + types::MessageContent::Text(data.to_string().into()) + } + } }; Some((types::MessageRole::Assistant, content)) @@ -1028,7 +1255,8 @@ fn convert_event( } EventKind::ChatRequest(_) | EventKind::InquiryRequest(_) - | EventKind::InquiryResponse(_) => None, + | EventKind::InquiryResponse(_) + | EventKind::TurnStart(_) => None, } } @@ -1036,16 +1264,29 @@ fn map_content_start( item: types::MessageContent, index: usize, agg: &mut ToolCallRequestAggregator, + is_structured: bool, ) -> Option { use types::MessageContent::*; let mut metadata = IndexMap::new(); let kind: EventKind = match item { + // Initial part indicating a tool call request has started. The eventual + // fully-aggregated arguments will be sent in a separate Part. ToolUse(types::ToolUse { id, name, .. 
}) => { - *agg = ToolCallRequestAggregator::new(); - agg.add_chunk(index, Some(id), Some(name), None); - return None; + *agg = ToolCallRequestAggregator::default(); + agg.add_chunk(index, Some(id.clone()), Some(name.clone()), None); + let request = ToolCallRequest { + id, + name, + arguments: Map::new(), + }; + + return Some(Event::Part { + index, + event: request.into(), + }); } + Text(text) if is_structured => ChatResponse::structured(Value::String(text.text)).into(), Text(text) if !text.text.is_empty() => ChatResponse::message(text.text).into(), Text(_) => return None, Thinking(types::Thinking { @@ -1077,9 +1318,13 @@ fn map_content_delta( delta: types::ContentBlockDelta, index: usize, agg: &mut ToolCallRequestAggregator, + is_structured: bool, ) -> Option { let mut metadata = IndexMap::new(); let kind: EventKind = match delta { + types::ContentBlockDelta::TextDelta { text } if is_structured => { + ChatResponse::structured(Value::String(text)).into() + } types::ContentBlockDelta::TextDelta { text } => ChatResponse::message(text).into(), types::ContentBlockDelta::ThinkingDelta { thinking } => { ChatResponse::reasoning(thinking).into() @@ -1105,163 +1350,90 @@ fn map_content_delta( }) } -fn map_content_stop(index: usize, agg: &mut ToolCallRequestAggregator) -> Result> { +fn map_content_stop( + index: usize, + agg: &mut ToolCallRequestAggregator, +) -> Vec> { + let mut events = vec![]; + // Check if we're buffering a tool call request match agg.finalize(index) { - Ok(tool_call) => { - return Ok(Some(Event::Part { - event: ConversationEvent::now(tool_call), - index, - })); - } + Ok(tool_call) => events.push(Ok(Event::Part { + event: ConversationEvent::now(tool_call), + index, + })), Err(AggregationError::UnknownIndex) => {} - Err(error) => return Err(error.into()), + Err(error) => { + events.push(Err(StreamError::other(error.to_string()))); + return events; + } } - Ok(Some(Event::flush(index))) + events.push(Ok(Event::flush(index))); + events } -fn 
map_message_delta(delta: types::MessageDelta) -> Option { - match delta.stop_reason?.as_str() { +fn map_message_delta(delta: &types::MessageDelta) -> Option { + match delta.stop_reason.as_deref()? { "max_tokens" => Some(Event::Finished(FinishReason::MaxTokens)), _ => None, } } -#[cfg(test)] -mod tests { - use indexmap::IndexMap; - use jp_config::model::parameters::{ - PartialCustomReasoningConfig, PartialReasoningConfig, ReasoningEffort, - }; - use jp_test::{Result, function_name}; - use test_log::test; - - use super::*; - use crate::test::{TestRequest, run_test}; - - const MAGIC_STRING: &str = "ANTHROPIC_MAGIC_STRING_TRIGGER_REDACTED_THINKING_46C9A13E193C177646C7398A98432ECCCE4C1253D5E2D82641AC0E52CC2876CB"; - - #[test(tokio::test)] - async fn test_redacted_thinking() -> Result { - let requests = vec![ - TestRequest::chat(PROVIDER) - .enable_reasoning() - .chat_request(MAGIC_STRING), - TestRequest::chat(PROVIDER) - .chat_request("Do you have access to your redacted thinking content?"), - ]; - - run_test(PROVIDER, function_name!(), requests).await - } +impl From for StreamError { + fn from(error: AnthropicError) -> Self { + use AnthropicError as E; - #[test(tokio::test)] - async fn test_request_chaining() -> Result { - let mut request = TestRequest::chat(PROVIDER) - .stream(true) - .reasoning(Some(PartialReasoningConfig::Custom( - PartialCustomReasoningConfig { - effort: Some(ReasoningEffort::Absolute(1024.into())), - exclude: Some(false), - }, - ))) - .chat_request("Give me a 2000 word explainer about Kirigami-inspired parachutes"); - - if let Some(details) = request.as_model_details_mut() { - details.max_output_tokens = Some(1152); - } + match error { + E::Network(error) => Self::from(error), + E::StreamTransport(_) => StreamError::transient(error.to_string()).with_source(error), + E::RateLimit { retry_after } => { + Self::rate_limit(retry_after.map(Duration::from_secs)).with_source(error) + } - run_test(PROVIDER, function_name!(), Some(request)).await - } + // 
Anthropic's API is notoriously unreliable, so we special-case a + // few common errors that most of the times resolve themselves when + // retried. + // + // See: + // See: + E::Api(ref api_error) + if RETRYABLE_ANTHROPIC_ERROR_TYPES.contains(&api_error.error_type.as_str()) => + { + let retry_after = api_error + .message + .as_deref() + .and_then(extract_retry_from_text) + .unwrap_or(Duration::from_secs(3)); + + StreamError::transient(error.to_string()) + .with_retry_after(retry_after) + .with_source(error) + } - #[test] - fn test_find_merge_point_edge_cases() { - struct TestCase { - left: &'static str, - right: &'static str, - expected: &'static str, - max_search: usize, - } + // Detect billing/quota errors before falling through to generic. + E::Api(ref api_error) + if looks_like_quota_error(&api_error.error_type) + || api_error + .message + .as_deref() + .is_some_and(looks_like_quota_error) => + { + StreamError::new( + StreamErrorKind::InsufficientQuota, + format!( + "Insufficient API quota. Check your plan and billing details \ + at https://console.anthropic.com/settings/billing. 
({error})" + ), + ) + .with_source(error) + } - let cases = IndexMap::from([ - ("no overlap", TestCase { - left: "Hello", - right: " world", - expected: "Hello world", - max_search: 500, - }), - ("single word overlap", TestCase { - left: "The quick brown", - right: "brown fox", - expected: "The quick brown fox", - max_search: 500, - }), - ("minimal overlap (5 chars)", TestCase { - expected: "abcdefghij", - left: "abcdefgh", - right: "defghij", - max_search: 500, - }), - ( - "below minimum overlap (4 chars) - should not merge", - TestCase { - left: "abcd", - right: "abcd", - expected: "abcdabcd", - max_search: 500, - }, - ), - ("complete overlap", TestCase { - left: "Hello world", - right: "world", - expected: "Hello world", - max_search: 500, - }), - ("overlap with punctuation", TestCase { - left: "Hello, how are", - right: "how are you?", - expected: "Hello, how are you?", - max_search: 500, - }), - ("overlap with whitespace", TestCase { - left: "Hello ", - right: " world", - expected: "Hello world", - max_search: 500, - }), - ("unicode overlap", TestCase { - left: "Hello 世界", - right: "世界 friend", - expected: "Hello 世界 friend", - max_search: 500, - }), - ("long overlap", TestCase { - left: "The quick brown fox jumps", - right: "fox jumps over the lazy dog", - expected: "The quick brown fox jumpsfox jumps over the lazy dog", - max_search: 8, - }), - ("empty right", TestCase { - left: "Hello", - right: "", - expected: "Hello", - max_search: 500, - }), - ]); - - for ( - name, - TestCase { - left, - right, - expected, - max_search, - }, - ) in cases - { - let pos = find_merge_point(left, right, max_search); - let result = format!("{left}{}", &right[pos..]); - assert_eq!(result, expected, "Failed test case: {name}"); + error => StreamError::other(error.to_string()).with_source(error), } } } + +#[cfg(test)] +#[path = "anthropic_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/provider/anthropic_tests.rs b/crates/jp_llm/src/provider/anthropic_tests.rs new file mode 
100644 index 00000000..b1da6420 --- /dev/null +++ b/crates/jp_llm/src/provider/anthropic_tests.rs @@ -0,0 +1,746 @@ +use indexmap::IndexMap; +use jp_config::model::parameters::{ + PartialCustomReasoningConfig, PartialReasoningConfig, ReasoningEffort, +}; +use jp_test::{Result, function_name}; +use test_log::test; + +use super::*; +use crate::test::{TestRequest, run_test}; + +const MAGIC_STRING: &str = "ANTHROPIC_MAGIC_STRING_TRIGGER_REDACTED_THINKING_46C9A13E193C177646C7398A98432ECCCE4C1253D5E2D82641AC0E52CC2876CB"; + +#[test(tokio::test)] +async fn test_redacted_thinking() -> Result { + let requests = vec![ + TestRequest::chat(PROVIDER) + .enable_reasoning() + .chat_request(MAGIC_STRING), + TestRequest::chat(PROVIDER) + .chat_request("Do you have access to your redacted thinking content?"), + ]; + + run_test(PROVIDER, function_name!(), requests).await +} + +#[test(tokio::test)] +async fn test_request_chaining() -> Result { + let mut request = TestRequest::chat(PROVIDER) + .reasoning(Some(PartialReasoningConfig::Custom( + PartialCustomReasoningConfig { + effort: Some(ReasoningEffort::Absolute(1024.into())), + exclude: Some(false), + }, + ))) + .chat_request("Give me a 2000 word explainer about Kirigami-inspired parachutes"); + + if let Some(details) = request.as_model_details_mut() { + details.max_output_tokens = Some(1152); + } + + run_test(PROVIDER, function_name!(), Some(request)).await +} + +/// Test that Opus 4.6 uses adaptive thinking mode with the effort parameter. 
+#[test(tokio::test)] +async fn test_opus_4_6_adaptive_thinking() -> Result { + let mut request = TestRequest::chat(PROVIDER) + .reasoning(Some(PartialReasoningConfig::Custom( + PartialCustomReasoningConfig { + effort: Some(ReasoningEffort::High), + exclude: Some(false), + }, + ))) + .model("anthropic/claude-opus-4-6".parse().unwrap()) + .chat_request("What is 2 + 2?"); + + // Configure model to use adaptive thinking (Opus 4.6 feature) + if let Some(details) = request.as_model_details_mut() { + details.reasoning = Some(ReasoningDetails::adaptive(true)); + details.features = vec!["adaptive-thinking"]; + } + + run_test(PROVIDER, function_name!(), Some(request)).await +} + +/// Test Opus 4.6 with `Max` effort level (only supported on Opus 4.6). +#[test(tokio::test)] +async fn test_opus_4_6_max_effort() -> Result { + let mut request = TestRequest::chat(PROVIDER) + .reasoning(Some(PartialReasoningConfig::Custom( + PartialCustomReasoningConfig { + effort: Some(ReasoningEffort::Max), + exclude: Some(false), + }, + ))) + .model("anthropic/claude-opus-4-6".parse().unwrap()) + .chat_request("What is 2 + 2?"); + + // Configure model to use adaptive thinking with max effort support (Opus 4.6 feature) + if let Some(details) = request.as_model_details_mut() { + details.reasoning = Some(ReasoningDetails::adaptive(true)); + details.features = vec!["adaptive-thinking"]; + } + + run_test(PROVIDER, function_name!(), Some(request)).await +} + +/// Unit test: Verify Opus 4.6 generates adaptive thinking request. 
+#[test] +fn test_opus_4_6_request_uses_adaptive_thinking() { + use jp_conversation::{ConversationStream, thread::Thread}; + + let model = ModelDetails { + id: (PROVIDER, "claude-opus-4-6").try_into().unwrap(), + display_name: Some("Claude Opus 4.6".to_string()), + context_window: Some(200_000), + max_output_tokens: Some(128_000), + reasoning: Some(ReasoningDetails::adaptive(true)), + knowledge_cutoff: None, + deprecated: None, + features: vec!["adaptive-thinking"], + }; + + let query = ChatQuery { + thread: Thread { + system_prompt: None, + sections: vec![], + attachments: vec![], + events: ConversationStream::new_test().with_chat_request("test"), + }, + tools: vec![], + tool_choice: ToolChoice::Auto, + }; + + let beta = BetaFeatures(vec![]); + let (request, is_structured) = create_request(&model, query, true, &beta).unwrap(); + assert!(!is_structured); + + // Verify adaptive thinking is used + assert_eq!(request.thinking, Some(types::ExtendedThinking::Adaptive)); + + // Verify output_config has effort set (defaults to High) + assert!(request.output_config.is_some()); + let output_config = request.output_config.unwrap(); + assert_eq!(output_config.effort, Some(Effort::High)); + assert_eq!(output_config.format, None); +} + +/// Unit test: Verify Max effort maps to `Effort::Max` for Opus 4.6. 
+#[test] +fn test_opus_4_6_max_effort_mapping() { + use jp_conversation::{ConversationStream, thread::Thread}; + + let model = ModelDetails { + id: (PROVIDER, "claude-opus-4-6").try_into().unwrap(), + display_name: Some("Claude Opus 4.6".to_string()), + context_window: Some(200_000), + max_output_tokens: Some(128_000), + reasoning: Some(ReasoningDetails::adaptive(true)), // supports max + knowledge_cutoff: None, + deprecated: None, + features: vec!["adaptive-thinking"], + }; + + let mut events = ConversationStream::new_test().with_chat_request("test"); + let mut delta = jp_config::PartialAppConfig::empty(); + delta.assistant.model.parameters.reasoning = Some(PartialReasoningConfig::Custom( + PartialCustomReasoningConfig { + effort: Some(ReasoningEffort::Max), + exclude: Some(false), + }, + )); + events.add_config_delta(delta); + + let query = ChatQuery { + thread: Thread { + system_prompt: None, + sections: vec![], + attachments: vec![], + events, + }, + tools: vec![], + tool_choice: ToolChoice::Auto, + }; + + let beta = BetaFeatures(vec![]); + let (request, _) = create_request(&model, query, true, &beta).unwrap(); + + // Verify Max effort is used + assert!(request.output_config.is_some()); + let output_config = request.output_config.unwrap(); + assert_eq!(output_config.effort, Some(Effort::Max)); +} + +/// Unit test: Verify budget-based model (Opus 4.5) still uses Enabled thinking. 
+#[test] +fn test_opus_4_5_uses_budgetted_thinking() { + use jp_conversation::{ConversationStream, thread::Thread}; + + let model = ModelDetails { + id: (PROVIDER, "claude-opus-4-5").try_into().unwrap(), + display_name: Some("Claude Opus 4.5".to_string()), + context_window: Some(200_000), + max_output_tokens: Some(64_000), + reasoning: Some(ReasoningDetails::budgetted(1024, None)), + knowledge_cutoff: None, + deprecated: None, + features: vec!["interleaved-thinking"], + }; + + let query = ChatQuery { + thread: Thread { + system_prompt: None, + sections: vec![], + attachments: vec![], + events: ConversationStream::new_test().with_chat_request("test"), + }, + tools: vec![], + tool_choice: ToolChoice::Auto, + }; + + let beta = BetaFeatures(vec![]); + let (request, _) = create_request(&model, query, true, &beta).unwrap(); + + // Verify budget-based thinking is used (not adaptive) + assert!(matches!( + request.thinking, + Some(types::ExtendedThinking::Enabled { .. }) + )); + + // Verify output_config is NOT set for budget-based models + assert!(request.output_config.is_none()); +} + +/// Verify structured output sets `output_config.format` when the last event +/// is a `ChatRequest` with a schema. 
+#[test] +fn test_structured_output_sets_format() { + use jp_conversation::{ConversationStream, event::ChatRequest, thread::Thread}; + use serde_json::json; + + let model = ModelDetails { + id: (PROVIDER, "claude-sonnet-4-5").try_into().unwrap(), + display_name: Some("Claude Sonnet 4.5".to_string()), + context_window: Some(200_000), + max_output_tokens: Some(64_000), + reasoning: Some(ReasoningDetails::budgetted(1024, None)), + knowledge_cutoff: None, + deprecated: None, + features: vec!["structured-outputs"], + }; + + let schema = serde_json::Map::from_iter([ + ("type".into(), json!("object")), + ("properties".into(), json!({"name": {"type": "string"}})), + ]); + + let events = ConversationStream::new_test().with_chat_request(ChatRequest { + content: "Extract contacts".into(), + schema: Some(schema), + }); + + let query = ChatQuery { + thread: Thread { + system_prompt: None, + sections: vec![], + attachments: vec![], + events, + }, + tools: vec![], + tool_choice: ToolChoice::Auto, + }; + + let beta = BetaFeatures(vec![]); + let (request, is_structured) = create_request(&model, query, true, &beta).unwrap(); + + assert!(is_structured); + assert!(request.output_config.is_some()); + let output_config = request.output_config.unwrap(); + // No adaptive thinking, so effort should be None. + assert_eq!(output_config.effort, None); + // transform_schema adds additionalProperties: false for objects. + let expected_schema = serde_json::Map::from_iter([ + ("type".into(), json!("object")), + ("properties".into(), json!({"name": {"type": "string"}})), + ("additionalProperties".into(), json!(false)), + ]); + assert_eq!( + output_config.format, + Some(JsonOutputFormat::JsonSchema { + schema: expected_schema + }) + ); +} + +/// When the last event is NOT a `ChatRequest` (e.g. a `ChatResponse`), no +/// structured output should be configured even if a prior `ChatRequest` had a +/// schema. 
+#[test] +fn test_schema_ignored_when_last_event_is_not_chat_request() { + use jp_conversation::{ + ConversationStream, + event::{ChatRequest, ChatResponse}, + thread::Thread, + }; + use serde_json::json; + + let model = ModelDetails { + id: (PROVIDER, "claude-sonnet-4-5").try_into().unwrap(), + display_name: None, + context_window: Some(200_000), + max_output_tokens: Some(64_000), + reasoning: None, + knowledge_cutoff: None, + deprecated: None, + features: vec![], + }; + + let mut events = ConversationStream::new_test(); + // First turn: structured request + events.add_chat_request(ChatRequest { + content: "Extract contacts".into(), + schema: Some(serde_json::Map::from_iter([( + "type".into(), + json!("object"), + )])), + }); + // Then a response (now the last event is not a ChatRequest) + events.add_chat_response(ChatResponse::structured(json!({"name": "Alice"}))); + // Follow-up without schema + events.add_chat_request(ChatRequest { + content: "Explain what you found".into(), + schema: None, + }); + + let query = ChatQuery { + thread: Thread { + system_prompt: None, + sections: vec![], + attachments: vec![], + events, + }, + tools: vec![], + tool_choice: ToolChoice::Auto, + }; + + let beta = BetaFeatures(vec![]); + let (_, is_structured) = create_request(&model, query, true, &beta).unwrap(); + assert!(!is_structured); +} + +/// Adaptive thinking + structured output should coexist on `OutputConfig`. 
+#[test] +fn test_adaptive_thinking_with_structured_output() { + use jp_conversation::{ConversationStream, event::ChatRequest, thread::Thread}; + use serde_json::json; + + let model = ModelDetails { + id: (PROVIDER, "claude-opus-4-6").try_into().unwrap(), + display_name: Some("Claude Opus 4.6".to_string()), + context_window: Some(200_000), + max_output_tokens: Some(128_000), + reasoning: Some(ReasoningDetails::adaptive(true)), + knowledge_cutoff: None, + deprecated: None, + features: vec!["adaptive-thinking", "structured-outputs"], + }; + + let schema = serde_json::Map::from_iter([("type".into(), json!("object"))]); + + let events = ConversationStream::new_test().with_chat_request(ChatRequest { + content: "Extract data".into(), + schema: Some(schema), + }); + + let query = ChatQuery { + thread: Thread { + system_prompt: None, + sections: vec![], + attachments: vec![], + events, + }, + tools: vec![], + tool_choice: ToolChoice::Auto, + }; + + let beta = BetaFeatures(vec![]); + let (request, is_structured) = create_request(&model, query, true, &beta).unwrap(); + + assert!(is_structured); + assert_eq!(request.thinking, Some(types::ExtendedThinking::Adaptive)); + + let output_config = request.output_config.unwrap(); + // Both effort and format should be present. 
+ assert_eq!(output_config.effort, Some(Effort::High)); + let expected_schema = serde_json::Map::from_iter([ + ("type".into(), json!("object")), + ("additionalProperties".into(), json!(false)), + ]); + assert_eq!( + output_config.format, + Some(JsonOutputFormat::JsonSchema { + schema: expected_schema + }) + ); +} + +#[test] +fn test_find_merge_point_edge_cases() { + struct TestCase { + left: &'static str, + right: &'static str, + expected: &'static str, + max_search: usize, + } + + let cases = IndexMap::from([ + ("no overlap", TestCase { + left: "Hello", + right: " world", + expected: "Hello world", + max_search: 500, + }), + ("single word overlap", TestCase { + left: "The quick brown", + right: "brown fox", + expected: "The quick brown fox", + max_search: 500, + }), + ("minimal overlap (5 chars)", TestCase { + expected: "abcdefghij", + left: "abcdefgh", + right: "defghij", + max_search: 500, + }), + ( + "below minimum overlap (4 chars) - should not merge", + TestCase { + left: "abcd", + right: "abcd", + expected: "abcdabcd", + max_search: 500, + }, + ), + ("complete overlap", TestCase { + left: "Hello world", + right: "world", + expected: "Hello world", + max_search: 500, + }), + ("overlap with punctuation", TestCase { + left: "Hello, how are", + right: "how are you?", + expected: "Hello, how are you?", + max_search: 500, + }), + ("overlap with whitespace", TestCase { + left: "Hello ", + right: " world", + expected: "Hello world", + max_search: 500, + }), + ("unicode overlap", TestCase { + left: "Hello 世界", + right: "世界 friend", + expected: "Hello 世界 friend", + max_search: 500, + }), + ("long overlap", TestCase { + left: "The quick brown fox jumps", + right: "fox jumps over the lazy dog", + expected: "The quick brown fox jumpsfox jumps over the lazy dog", + max_search: 8, + }), + ("empty right", TestCase { + left: "Hello", + right: "", + expected: "Hello", + max_search: 500, + }), + ]); + + for ( + name, + TestCase { + left, + right, + expected, + max_search, + }, 
+ ) in cases + { + let pos = find_merge_point(left, right, max_search); + let result = format!("{left}{}", &right[pos..]); + assert_eq!(result, expected, "Failed test case: {name}"); + } +} + +mod transform_schema { + use serde_json::{Map, Value, json}; + + use super::transform_schema; + + #[expect(clippy::needless_pass_by_value)] + fn schema(v: Value) -> Map { + v.as_object().unwrap().clone() + } + + #[test] + fn object_forces_additional_properties_false() { + let input = schema(json!({ + "type": "object", + "properties": { "name": { "type": "string" } }, + "required": ["name"] + })); + + let out = transform_schema(input); + + assert_eq!(out["additionalProperties"], json!(false)); + assert_eq!(out["required"], json!(["name"])); + assert_eq!(out["properties"]["name"]["type"], "string"); + } + + #[test] + fn object_drops_existing_additional_properties() { + let input = schema(json!({ + "type": "object", + "properties": {}, + "additionalProperties": true + })); + + let out = transform_schema(input); + assert_eq!(out["additionalProperties"], json!(false)); + } + + #[test] + fn array_keeps_min_items_0_and_1() { + for n in [0, 1] { + let input = schema(json!({ + "type": "array", + "items": { "type": "string" }, + "minItems": n + })); + + let out = transform_schema(input); + assert_eq!(out["minItems"], json!(n), "minItems {n} should be kept"); + } + } + + #[test] + fn array_moves_large_min_items_to_description() { + let input = schema(json!({ + "type": "array", + "items": { "type": "string" }, + "minItems": 3 + })); + + let out = transform_schema(input); + assert!(out.get("minItems").is_none()); + let desc = out["description"].as_str().unwrap(); + assert!( + desc.contains("minItems"), + "description should mention minItems: {desc}" + ); + } + + #[test] + fn array_moves_max_items_to_description() { + let input = schema(json!({ + "type": "array", + "items": { "type": "string" }, + "maxItems": 5 + })); + + let out = transform_schema(input); + 
assert!(out.get("maxItems").is_none()); + let desc = out["description"].as_str().unwrap(); + assert!( + desc.contains("maxItems"), + "description should mention maxItems: {desc}" + ); + } + + #[test] + fn string_keeps_supported_format() { + let input = schema(json!({ + "type": "string", + "format": "date-time" + })); + + let out = transform_schema(input); + assert_eq!(out["format"], "date-time"); + assert!(out.get("description").is_none()); + } + + #[test] + fn string_moves_unsupported_format_to_description() { + let input = schema(json!({ + "type": "string", + "format": "phone-number" + })); + + let out = transform_schema(input); + assert!(out.get("format").is_none()); + let desc = out["description"].as_str().unwrap(); + assert!( + desc.contains("phone-number"), + "description should contain the format: {desc}" + ); + } + + #[test] + fn numeric_constraints_moved_to_description() { + let input = schema(json!({ + "type": "integer", + "minimum": 1, + "maximum": 10, + "description": "A number" + })); + + let out = transform_schema(input); + assert!(out.get("minimum").is_none()); + assert!(out.get("maximum").is_none()); + let desc = out["description"].as_str().unwrap(); + assert!( + desc.starts_with("A number"), + "should preserve original description" + ); + assert!(desc.contains("minimum"), "should contain minimum: {desc}"); + assert!(desc.contains("maximum"), "should contain maximum: {desc}"); + } + + #[test] + fn ref_passes_through() { + let input = schema(json!({ + "$ref": "#/$defs/Address" + })); + + let out = transform_schema(input); + assert_eq!(out["$ref"], "#/$defs/Address"); + assert_eq!(out.len(), 1); + } + + #[test] + fn defs_recursively_transformed() { + let input = schema(json!({ + "type": "object", + "properties": { + "addr": { "$ref": "#/$defs/Address" } + }, + "$defs": { + "Address": { + "type": "object", + "properties": { "city": { "type": "string" } }, + "additionalProperties": true + } + } + })); + + let out = transform_schema(input); + let 
addr_def = out["$defs"]["Address"].as_object().unwrap(); + assert_eq!(addr_def["additionalProperties"], json!(false)); + } + + #[test] + fn one_of_converted_to_any_of() { + let input = schema(json!({ + "oneOf": [ + { "type": "string" }, + { "type": "integer" } + ] + })); + + let out = transform_schema(input); + assert!(out.get("oneOf").is_none()); + let any_of = out["anyOf"].as_array().unwrap(); + assert_eq!(any_of.len(), 2); + assert_eq!(any_of[0]["type"], "string"); + assert_eq!(any_of[1]["type"], "integer"); + } + + #[test] + fn nested_properties_recursively_transformed() { + let input = schema(json!({ + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "object", + "properties": { "id": { "type": "integer", "minimum": 0 } }, + "additionalProperties": true + }, + "maxItems": 10 + } + } + })); + + let out = transform_schema(input); + + // Top-level object + assert_eq!(out["additionalProperties"], json!(false)); + + // The array property + let items_prop = out["properties"]["items"].as_object().unwrap(); + assert!(items_prop.get("maxItems").is_none()); + + // The nested object inside the array + let nested = items_prop["items"].as_object().unwrap(); + assert_eq!(nested["additionalProperties"], json!(false)); + + // The integer property's minimum should be in description + let id_prop = nested["properties"]["id"].as_object().unwrap(); + assert!(id_prop.get("minimum").is_none()); + let desc = id_prop["description"].as_str().unwrap(); + assert!( + desc.contains("minimum"), + "nested constraint in description: {desc}" + ); + } + + /// Mirrors the example from Anthropic's Python SDK docstring. 
+ #[test] + fn sdk_docstring_example() { + let input = schema(json!({ + "type": "integer", + "minimum": 1, + "maximum": 10, + "description": "A number" + })); + + let out = transform_schema(input); + assert_eq!(out["type"], "integer"); + let desc = out["description"].as_str().unwrap(); + assert!(desc.starts_with("A number")); + assert!(desc.contains("minimum: 1")); + assert!(desc.contains("maximum: 10")); + } + + /// The `title_schema` used by the title generator should survive + /// transformation. + #[test] + fn title_schema_transforms_cleanly() { + let input = crate::title::title_schema(3); + let out = transform_schema(input); + + assert_eq!(out["type"], "object"); + assert_eq!(out["additionalProperties"], json!(false)); + assert_eq!(out["required"], json!(["titles"])); + + let titles = out["properties"]["titles"].as_object().unwrap(); + assert_eq!(titles["type"], "array"); + // minItems 1 is kept but > 1 is moved to description + assert!(titles.get("minItems").is_none()); + assert!(titles.get("maxItems").is_none()); + let desc = titles["description"].as_str().unwrap(); + assert!( + desc.contains("minItems"), + "should contain minItems hint: {desc}" + ); + assert!( + desc.contains("maxItems"), + "should contain maxItems hint: {desc}" + ); + } +} diff --git a/crates/jp_llm/src/provider/google.rs b/crates/jp_llm/src/provider/google.rs index 6d4a83ec..5ae80d5e 100644 --- a/crates/jp_llm/src/provider/google.rs +++ b/crates/jp_llm/src/provider/google.rs @@ -2,8 +2,9 @@ use std::{collections::HashMap, env}; use async_stream::stream; use async_trait::async_trait; +use chrono::NaiveDate; use futures::{StreamExt as _, TryStreamExt as _}; -use gemini_client_rs::{GeminiClient, types}; +use gemini_client_rs::{GeminiClient, GeminiError, types}; use indexmap::IndexMap; use jp_config::{ assistant::tool_choice::ToolChoice, @@ -18,14 +19,15 @@ use jp_conversation::{ event::{ChatResponse, ConversationEvent, EventKind, ToolCallRequest}, thread::{Document, Documents, Thread}, }; 
-use serde_json::Value; +use serde_json::{Map, Value}; use tracing::{debug, trace}; use super::{EventStream, Provider}; use crate::{ - error::{Error, Result}, + StreamErrorKind, + error::{Error, Result, StreamError, looks_like_quota_error}, event::{Event, FinishReason}, - model::{ModelDetails, ReasoningDetails}, + model::{ModelDeprecation, ModelDetails, ReasoningDetails}, query::ChatQuery, tool::ToolDefinition, }; @@ -69,7 +71,7 @@ impl Provider for Google { query: ChatQuery, ) -> Result { let client = self.client.clone(); - let request = create_request(model, query)?; + let (request, structured) = create_request(model, query)?; let slug = model.id.name.clone(); debug!(stream = true, "Google chat completion stream request."); @@ -78,7 +80,7 @@ impl Provider for Google { "Request payload." ); - Ok(call(client, request, slug, 0)) + Ok(call(client, request, slug, 0, structured)) } } @@ -87,17 +89,19 @@ fn call( request: types::GenerateContentRequest, model: Name, tries: usize, + is_structured: bool, ) -> EventStream { Box::pin(stream! { let mut state = IndexMap::new(); let stream = client .stream_content(&model, &request) - .await? - .map_err(Error::from); + .await + .map_err(|e| StreamError::other(e.to_string()))? + .map_err(StreamError::from); tokio::pin!(stream); while let Some(event) = stream.next().await { - for event in map_response(event?, &mut state)? { + for event in map_response(event?, &mut state, is_structured).map_err(|e| StreamError::other(e.to_string()))? { // Sometimes the API returns an "unexpected tool call" error, if // a previous turn had tools available but those were made // unavailable in follow-up turns. 
This is a known issue: @@ -117,7 +121,7 @@ fn call( let should_retry = matches!(&event, Event::Finished(FinishReason::Other(Value::String(s))) if s == "UNEXPECTED_TOOL_CALL"); if should_retry && tries < 3 { - let mut next_stream = call(client.clone(), request.clone(), model.clone(), tries + 1); + let mut next_stream = call(client.clone(), request.clone(), model.clone(), tries + 1, is_structured); while let Some(item) = next_stream.next().await { yield item; } @@ -131,21 +135,30 @@ fn call( } #[expect(clippy::too_many_lines)] -fn create_request(model: &ModelDetails, query: ChatQuery) -> Result { +fn create_request( + model: &ModelDetails, + query: ChatQuery, +) -> Result<(types::GenerateContentRequest, bool)> { let ChatQuery { thread, tools, tool_choice, - tool_call_strict_mode, } = query; let Thread { system_prompt, - instructions, + sections, attachments, events, } = thread; + // Only use the schema if the very last event is a ChatRequest with one. + let structured_schema = events + .last() + .and_then(|e| e.event.as_chat_request()) + .and_then(|req| req.schema.clone()); + let is_structured = structured_schema.is_some(); + let config = events.config()?; let parameters = &config.assistant.model.parameters; @@ -170,51 +183,81 @@ fn create_request(model: &ModelDetails, query: ChatQuery) -> Result 0) || reasoning.is_some()) - .map(|details| types::ThinkingConfig { - include_thoughts: reasoning.is_some_and(|v| !v.exclude), - thinking_budget: match details { - ReasoningDetails::Leveled { .. } => None, - _ => reasoning.map(|v| { + let supports_thinking = model.reasoning.is_some_and(|r| !r.is_unsupported()); + + // Add thinking config if the model supports it. + let thinking_config = if let Some(details) = model.reasoning.filter(|_| supports_thinking) { + if let Some(config) = reasoning { + // Reasoning is enabled — configure thinking accordingly. 
+ Some(types::ThinkingConfig { + include_thoughts: !config.exclude, + thinking_budget: if details.is_leveled() { + None + } else { // TODO: Once the `gemini` crate supports `-1` for "auto" // thinking, use that here if `effort` is `Auto`. // // See: #[expect(clippy::cast_sign_loss)] - v.effort + let tokens = config + .effort .to_tokens(max_output_tokens.unwrap_or(32_000) as u32) .min(details.max_tokens().unwrap_or(u32::MAX)) - .max(details.min_tokens()) - }), - }, - thinking_level: match details { - ReasoningDetails::Leveled { low, high, .. } => { - let level = reasoning.map(|v| { - v.effort + .max(details.min_tokens()); + Some(tokens) + }, + thinking_level: match details { + ReasoningDetails::Leveled { + xlow, + low, + medium, + high, + xhigh: _, + } => { + let level = config + .effort .abs_to_rel(max_output_tokens.map(i32::cast_unsigned)) - }); - - match level { - Some(ReasoningEffort::Low) if low => Some(types::ThinkingLevel::Low), - Some(ReasoningEffort::High) if high => Some(types::ThinkingLevel::High), - // Any other level is unsupported and treated as - // high (since the documentation specifies this is - // the default). - _ => Some(types::ThinkingLevel::High), + .unwrap_or(ReasoningEffort::Auto); + + match level { + ReasoningEffort::None | ReasoningEffort::Xlow if xlow => { + Some(types::ThinkingLevel::Minimal) + } + ReasoningEffort::Low if low => Some(types::ThinkingLevel::Low), + ReasoningEffort::Medium if medium => Some(types::ThinkingLevel::Medium), + ReasoningEffort::High if high => Some(types::ThinkingLevel::High), + + // Any other level is unsupported and treated as + // high (since the documentation specifies this is + // the default). + _ => Some(types::ThinkingLevel::High), + } } - } - _ => None, - }, - }); + _ => None, + }, + }) + } else if details.min_tokens() > 0 { + // Model requires a minimum thinking budget — can't fully disable. 
+ Some(types::ThinkingConfig { + include_thoughts: false, + thinking_budget: Some(details.min_tokens()), + thinking_level: None, + }) + } else { + // Reasoning is off — explicitly disable thinking. + Some(types::ThinkingConfig { + include_thoughts: false, + thinking_budget: Some(0), + thinking_level: None, + }) + } + } else { + None + }; let parts = { let mut parts = vec![]; @@ -222,14 +265,8 @@ fn create_request(model: &ModelDetails, query: ChatQuery) -> Result>>()? - .join("\n\n"); - - parts.push(types::ContentData::Text(text)); + for section in §ions { + parts.push(types::ContentData::Text(section.render())); } if !attachments.is_empty() { @@ -255,58 +292,186 @@ fn create_request(model: &ModelDetails, query: ChatQuery) -> Result>() }; - Ok(types::GenerateContentRequest { - system_instruction: if parts.is_empty() { - None - } else { - Some(types::Content { parts, role: None }) + // Set structured output config on GenerationConfig when a schema is present. + // We use `_responseJsonSchema` (the JSON Schema field) rather than + // `responseSchema` (the OpenAPI Schema field) so that standard JSON + // Schema properties like `additionalProperties` are accepted. + // + // The schema is transformed to rewrite unsupported properties (e.g. + // `const` → `enum`) so constraints aren't silently dropped. 
+ let (response_mime_type, response_json_schema) = match structured_schema { + Some(schema) => ( + Some("application/json".to_owned()), + Some(Value::Object(transform_schema(schema))), + ), + None => (None, None), + }; + + Ok(( + types::GenerateContentRequest { + system_instruction: if parts.is_empty() { + None + } else { + Some(types::Content { parts, role: None }) + }, + contents: convert_events(events), + tools, + tool_config: Some(tool_config), + generation_config: Some(types::GenerationConfig { + max_output_tokens, + #[expect(clippy::cast_lossless)] + temperature: parameters.temperature.map(|v| v as f64), + #[expect(clippy::cast_lossless)] + top_p: parameters.top_p.map(|v| v as f64), + #[expect(clippy::cast_possible_wrap)] + top_k: parameters.top_k.map(|v| v as i32), + thinking_config, + response_mime_type, + response_json_schema, + ..Default::default() + }), }, - contents: convert_events(events), - tools, - tool_config: Some(tool_config), - generation_config: Some(types::GenerationConfig { - max_output_tokens, - #[expect(clippy::cast_lossless)] - temperature: parameters.temperature.map(|v| v as f64), - #[expect(clippy::cast_lossless)] - top_p: parameters.top_p.map(|v| v as f64), - #[expect(clippy::cast_possible_wrap)] - top_k: parameters.top_k.map(|v| v as i32), - thinking_config, - ..Default::default() - }), - }) + is_structured, + )) } +/// Map a Gemini model to a `ModelDetails`. 
+/// +/// See: +/// See: +#[expect(clippy::too_many_lines)] fn map_model(model: types::Model) -> ModelDetails { - ModelDetails { - id: (PROVIDER, model.base_model_id.as_str()).try_into().unwrap(), - display_name: Some(model.display_name), - context_window: Some(model.input_token_limit), - max_output_tokens: Some(model.output_token_limit), - reasoning: model - .base_model_id - .starts_with("gemini-2.5-pro") - .then_some(ReasoningDetails::budgetted(128, Some(32768))) - .or_else(|| { - model - .base_model_id - .starts_with("gemini-2.5-flash") - .then_some(ReasoningDetails::budgetted(0, Some(24576))) - }) - .or_else(|| { - (model.base_model_id.starts_with("gemini-3-flash") - || model.base_model_id == "gemini-flash-latest") - .then_some(ReasoningDetails::leveled(true, true, true, true, false)) - }) - .or_else(|| { - (model.base_model_id.starts_with("gemini-3-pro") - || model.base_model_id == "gemini-pro-latest") - .then_some(ReasoningDetails::leveled(false, true, false, true, false)) - }), - knowledge_cutoff: None, - deprecated: None, - features: vec![], + let name = model.base_model_id.as_str(); + let display_name = Some(model.display_name); + let context_window = Some(model.input_token_limit); + let max_output_tokens = Some(model.output_token_limit); + let Ok(id) = (PROVIDER, model.base_model_id.as_str()).try_into() else { + return ModelDetails::empty((PROVIDER, "unknown").try_into().unwrap()); + }; + + match name { + "gemini-pro-latest" | "gemini-3.1-pro-preview" | "gemini-3.1-pro-preview-customtools" => { + ModelDetails { + id, + display_name, + context_window, + max_output_tokens, + reasoning: Some(ReasoningDetails::leveled(false, true, true, true, false)), + knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2025, 1, 1).unwrap()), + deprecated: Some(ModelDeprecation::Active), + features: vec![], + } + } + "gemini-3-pro-preview" => ModelDetails { + id, + display_name, + context_window, + max_output_tokens, + reasoning: Some(ReasoningDetails::leveled(false, true, false, 
true, false)), + knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2025, 1, 1).unwrap()), + deprecated: Some(ModelDeprecation::Active), + features: vec![], + }, + "gemini-flash-latest" | "gemini-3-flash-preview" => ModelDetails { + id, + display_name, + context_window, + max_output_tokens, + reasoning: Some(ReasoningDetails::leveled(true, true, true, true, false)), + knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2025, 1, 1).unwrap()), + deprecated: Some(ModelDeprecation::Active), + features: vec![], + }, + "gemini-2.5-flash" => ModelDetails { + id, + display_name, + context_window, + max_output_tokens, + reasoning: Some(ReasoningDetails::budgetted(0, Some(24576))), + knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2025, 1, 1).unwrap()), + deprecated: Some(ModelDeprecation::deprecated( + &"recommended replacement: gemini-3-flash-preview", + Some(NaiveDate::from_ymd_opt(2026, 6, 17).unwrap()), + )), + features: vec![], + }, + "gemini-flash-lite-latest" + | "gemini-2.5-flash-lite" + | "gemini-2.5-flash-lite-preview-09-2025" => ModelDetails { + id, + display_name, + context_window, + max_output_tokens, + reasoning: Some(ReasoningDetails::budgetted(512, Some(24576))), + knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2025, 1, 1).unwrap()), + deprecated: Some(ModelDeprecation::deprecated( + &"recommended replacement: unknown", + Some(NaiveDate::from_ymd_opt(2026, 7, 22).unwrap()), + )), + features: vec![], + }, + "gemini-2.5-pro" => ModelDetails { + id, + display_name, + context_window, + max_output_tokens, + reasoning: Some(ReasoningDetails::budgetted(512, Some(24576))), + knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2025, 1, 1).unwrap()), + deprecated: Some(ModelDeprecation::deprecated( + &"recommended replacement: gemini-3-pro-preview", + Some(NaiveDate::from_ymd_opt(2026, 6, 17).unwrap()), + )), + features: vec![], + }, + "gemini-2.0-flash" | "gemini-2.0-flash-001" => ModelDetails { + id, + display_name, + context_window, + max_output_tokens, + reasoning: 
Some(ReasoningDetails::budgetted(0, Some(24576))), + knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2024, 8, 1).unwrap()), + deprecated: Some(ModelDeprecation::deprecated( + &"recommended replacement: gemini-2.5-flash", + Some(NaiveDate::from_ymd_opt(2026, 6, 1).unwrap()), + )), + features: vec![], + }, + "gemini-2.0-flash-lite" | "gemini-2.0-flash-lite-001" => ModelDetails { + id, + display_name, + context_window, + max_output_tokens, + reasoning: Some(ReasoningDetails::unsupported()), + knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2024, 8, 1).unwrap()), + deprecated: Some(ModelDeprecation::deprecated( + &"recommended replacement: gemini-2.5-flash-lite", + Some(NaiveDate::from_ymd_opt(2026, 6, 1).unwrap()), + )), + features: vec![], + }, + id => { + trace!( + name, + display_name = display_name + .clone() + .unwrap_or_else(|| "".to_owned()), + id, + "Missing model details. Falling back to generic model details." + ); + + ModelDetails { + id: (PROVIDER, model.base_model_id.as_str()) + .try_into() + .unwrap_or((PROVIDER, "unknown").try_into().unwrap()), + display_name, + context_window, + max_output_tokens, + reasoning: None, + knowledge_cutoff: None, + deprecated: None, + features: vec![], + } + } } } @@ -340,6 +505,7 @@ impl CandidateState { fn map_response( response: types::GenerateContentResponse, state: &mut IndexMap, + is_structured: bool, ) -> Result> { debug!("Received response from Google API."); trace!( @@ -350,7 +516,7 @@ fn map_response( response .candidates .into_iter() - .flat_map(|v| map_candidate(v, state)) + .flat_map(|v| map_candidate(v, state, is_structured)) .try_fold(vec![], |mut acc, events| { acc.extend(events); Ok(acc) @@ -360,6 +526,7 @@ fn map_response( fn map_candidate( candidate: types::Candidate, states: &mut IndexMap, + is_structured: bool, ) -> Result> { let types::Candidate { content, @@ -382,6 +549,13 @@ fn map_candidate( .. } = part; + // Google sometimes sends empty text parts (e.g. 
the final chunk of a + // thinking+tool_use response with finishReason: STOP). Skip them before + // mode transition logic to avoid spurious flushes. + if matches!(&data, types::ContentData::Text(text) if text.is_empty()) { + continue; + } + // Determine what "mode" the content is in. let mode = if matches!(data, types::ContentData::FunctionCall(_)) { ContentMode::FunctionCall @@ -414,6 +588,10 @@ fn map_candidate( event: ConversationEvent::now(ChatResponse::reasoning(text)), index, }, + types::ContentData::Text(text) if is_structured => Event::Part { + event: ConversationEvent::now(ChatResponse::structured(Value::String(text))), + index, + }, types::ContentData::Text(text) => Event::Part { event: ConversationEvent::now(ChatResponse::message(text)), index, @@ -488,11 +666,140 @@ impl TryFrom<&GoogleConfig> for Google { } } -fn convert_tool_choice(choice: ToolChoice, strict: bool) -> types::ToolConfig { +/// Transform a JSON schema to conform to Google's structured output constraints. +/// +/// Google's Gemini API supports a subset of JSON Schema. This transformation: +/// - Inlines `$ref` references by replacing them with the referenced `$defs` +/// - Removes `$defs`/`definitions` from the output after inlining +/// - Rewrites `const` to `enum` with a single value (Google ignores `const`) +/// - Adds `propertyOrdering` to objects with multiple properties +/// - Recursively processes `anyOf`, object properties, `additionalProperties`, +/// array `items`, and `prefixItems` +/// +/// Mirrors the logic from Google's Python SDK `process_schema`. +/// +/// See: +fn transform_schema(mut src: Map) -> Map { + // Extract $defs from the root. They are inlined wherever $ref appears + // and discarded from the output. 
+ let defs = src + .remove("$defs") + .or_else(|| src.remove("definitions")) + .and_then(|v| match v { + Value::Object(m) => Some(m), + _ => None, + }) + .unwrap_or_default(); + + process_schema(src, &defs) +} + +/// Core recursive processor for a single schema node. +fn process_schema(mut src: Map, defs: &Map) -> Map { + // Resolve $ref by inlining the referenced definition. + if let Some(Value::String(ref_path)) = src.remove("$ref") + && let Some(resolved) = resolve_ref(&ref_path, defs) + { + let mut merged = resolved; + // Preserve sibling fields (e.g. description, nullable) from the + // referring schema — the definition's own fields take precedence. + for (k, v) in src { + merged.entry(k).or_insert(v); + } + return process_schema(merged, defs); + } + + // Rewrite `const` to `enum` with a single-element array. Google supports + // `enum` for strings and numbers but not `const`. + if let Some(val) = src.remove("const") { + src.insert("enum".into(), Value::Array(vec![val])); + } + + // Handle anyOf: recurse into each variant, then return early (matching the + // Python SDK's behavior). + if let Some(Value::Array(variants)) = src.remove("anyOf") { + src.insert( + "anyOf".into(), + Value::Array( + variants + .into_iter() + .map(|v| resolve_and_process(v, defs)) + .collect(), + ), + ); + return src; + } + + // Type-specific processing. + match src.get("type").and_then(Value::as_str) { + Some("object") => { + if let Some(Value::Object(props)) = src.remove("properties") { + let keys: Vec = props.keys().cloned().collect(); + let processed: Map = props + .into_iter() + .map(|(k, v)| (k, resolve_and_process(v, defs))) + .collect(); + + // Deterministic output ordering for objects with >1 property. 
+ if keys.len() > 1 && !src.contains_key("propertyOrdering") { + src.insert( + "propertyOrdering".into(), + Value::Array(keys.into_iter().map(Value::String).collect()), + ); + } + + src.insert("properties".into(), Value::Object(processed)); + } + + // Process additionalProperties when it's a schema (not a boolean). + if let Some(additional) = src.remove("additionalProperties") { + src.insert("additionalProperties".into(), match additional { + Value::Object(schema) => Value::Object(process_schema(schema, defs)), + other => other, + }); + } + } + Some("array") => { + if let Some(items) = src.remove("items") { + src.insert("items".into(), resolve_and_process(items, defs)); + } + + if let Some(Value::Array(prefixes)) = src.remove("prefixItems") { + src.insert( + "prefixItems".into(), + Value::Array( + prefixes + .into_iter() + .map(|v| resolve_and_process(v, defs)) + .collect(), + ), + ); + } + } + _ => {} + } + + src +} + +/// Resolve any `$ref` inside a value, then recursively process it. +fn resolve_and_process(value: Value, defs: &Map) -> Value { + match value { + Value::Object(map) => Value::Object(process_schema(map, defs)), + other => other, + } +} + +/// Look up a `$ref` path (e.g. `#/$defs/MyType`) in the definitions map. 
+fn resolve_ref(ref_path: &str, defs: &Map) -> Option> { + let name = ref_path.rsplit("defs/").next().unwrap_or(ref_path); + defs.get(name).and_then(Value::as_object).cloned() +} + +fn convert_tool_choice(choice: ToolChoice) -> types::ToolConfig { let (mode, allowed_function_names) = match choice { ToolChoice::None => (types::FunctionCallingMode::None, vec![]), - ToolChoice::Auto if strict => (types::FunctionCallingMode::Validated, vec![]), - ToolChoice::Auto => (types::FunctionCallingMode::Auto, vec![]), + ToolChoice::Auto => (types::FunctionCallingMode::Validated, vec![]), ToolChoice::Required => (types::FunctionCallingMode::Any, vec![]), ToolChoice::Function(name) => (types::FunctionCallingMode::Any, vec![name]), }; @@ -546,12 +853,20 @@ fn convert_events(events: ConversationStream) -> Vec { types::Role::User, types::ContentData::Text(request.content).into(), ), - EventKind::ChatResponse(response) => (types::Role::Model, types::ContentPart { - thought: response.is_reasoning(), - data: types::ContentData::Text(response.into_content()), - metadata: None, - thought_signature: None, - }), + EventKind::ChatResponse(response) => { + let thought = response.is_reasoning(); + let text = match response { + ChatResponse::Message { message } => message, + ChatResponse::Reasoning { reasoning } => reasoning, + ChatResponse::Structured { data } => data.to_string(), + }; + (types::Role::Model, types::ContentPart { + thought, + data: types::ContentData::Text(text), + metadata: None, + thought_signature: None, + }) + } EventKind::ToolCallRequest(request) => (types::Role::Model, types::ContentPart { data: types::ContentData::FunctionCall(types::FunctionCall { name: { @@ -611,37 +926,67 @@ fn convert_events(events: ConversationStream) -> Vec { }) } -#[cfg(test)] -mod tests { - use jp_config::model::parameters::{ - PartialCustomReasoningConfig, PartialReasoningConfig, ReasoningEffort, - }; - use jp_conversation::event::ChatRequest; - use jp_test::function_name; - use test_log::test; 
+impl From for StreamError { + fn from(err: GeminiError) -> Self { + match err { + GeminiError::Http(error) => Self::from(error), + GeminiError::EventSource(error) => Self::from(error), + GeminiError::Api(ref value) => { + let msg = err.to_string(); + + // Check for quota/billing exhaustion first. + if looks_like_quota_error(&msg) { + return StreamError::new( + StreamErrorKind::InsufficientQuota, + format!( + "Insufficient API quota. Check your plan and billing details \ + at https://console.cloud.google.com/billing. ({msg})" + ), + ) + .with_source(err); + } - use super::*; - use crate::test::{TestRequest, run_test}; + // Classify by HTTP status code if present in the API error. + let status = value + .get("status") + .or_else(|| value.pointer("/error/code")) + .and_then(serde_json::Value::as_u64); - // TODO: Test specific conditions as detailed in - // : - // - // - parallel function calls - // - dummy thought signatures - // - multi-turn conversations - #[test(tokio::test)] - async fn test_gemini_3_reasoning() -> std::result::Result<(), Box> { - let request = TestRequest::chat(PROVIDER) - .stream(true) - .reasoning(Some(PartialReasoningConfig::Custom( - PartialCustomReasoningConfig { - effort: Some(ReasoningEffort::Low), - exclude: Some(false), - }, - ))) - .model("google/gemini-3-pro-preview".parse().unwrap()) - .event(ChatRequest::from("Test message")); + match status { + Some(429) => StreamError::rate_limit(None).with_source(err), + Some(500 | 502 | 503 | 504) => StreamError::transient(msg).with_source(err), + _ => StreamError::other(msg).with_source(err), + } + } + GeminiError::Json { data, error } => StreamError::other(data).with_source(error), + GeminiError::FunctionExecution(msg) => StreamError::other(msg), + } + } +} - run_test(PROVIDER, function_name!(), Some(request)).await +impl From for Error { + fn from(error: GeminiError) -> Self { + match &error { + GeminiError::Api(api) if api.get("status").is_some_and(|v| v.as_u64() == Some(404)) => { + if let 
Some(model) = api.pointer("/message/error/message").and_then(|v| { + v.as_str().and_then(|s| { + s.contains("Call ListModels").then(|| { + s.split('/') + .nth(1) + .and_then(|v| v.split(' ').next()) + .unwrap_or("unknown") + }) + }) + }) { + return Self::UnknownModel(model.to_owned()); + } + Self::Gemini(error) + } + _ => Self::Gemini(error), + } } } + +#[cfg(test)] +#[path = "google_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/provider/google_tests.rs b/crates/jp_llm/src/provider/google_tests.rs new file mode 100644 index 00000000..c6fd1087 --- /dev/null +++ b/crates/jp_llm/src/provider/google_tests.rs @@ -0,0 +1,501 @@ +use jp_config::model::parameters::{ + PartialCustomReasoningConfig, PartialReasoningConfig, ReasoningEffort, +}; +use jp_conversation::event::ChatRequest; +use jp_test::function_name; +use test_log::test; + +use super::*; +use crate::test::{TestRequest, run_test}; + +// TODO: Test specific conditions as detailed in +// : +// +// - parallel function calls +// - dummy thought signatures +// - multi-turn conversations +#[test(tokio::test)] +async fn test_gemini_3_reasoning() -> std::result::Result<(), Box> { + let request = TestRequest::chat(PROVIDER) + .reasoning(Some(PartialReasoningConfig::Custom( + PartialCustomReasoningConfig { + effort: Some(ReasoningEffort::Low), + exclude: Some(false), + }, + ))) + .model("google/gemini-3-pro-preview".parse().unwrap()) + .event(ChatRequest::from("Test message")); + + run_test(PROVIDER, function_name!(), Some(request)).await +} + +mod transform_schema { + use serde_json::{Map, Value, json}; + + use super::transform_schema; + + #[expect(clippy::needless_pass_by_value)] + fn schema(v: Value) -> Map { + v.as_object().unwrap().clone() + } + + #[test] + fn const_rewritten_to_enum() { + let input = schema(json!({ + "type": "string", + "const": "tool_call.my_tool.call_123" + })); + + let out = transform_schema(input); + + assert_eq!(out.get("const"), None); + assert_eq!(out["enum"], 
json!(["tool_call.my_tool.call_123"])); + assert_eq!(out["type"], "string"); + } + + #[test] + fn const_rewritten_for_non_string_values() { + let input = schema(json!({ + "type": "integer", + "const": 42 + })); + + let out = transform_schema(input); + + assert_eq!(out.get("const"), None); + assert_eq!(out["enum"], json!([42])); + } + + #[test] + fn nested_const_in_properties_rewritten() { + let input = schema(json!({ + "type": "object", + "properties": { + "inquiry_id": { + "type": "string", + "const": "tool_call.fs_modify_file.call_abc" + }, + "answer": { + "type": "boolean" + } + }, + "required": ["inquiry_id", "answer"] + })); + + let out = transform_schema(input); + + let inquiry_id = out["properties"]["inquiry_id"].as_object().unwrap(); + assert_eq!(inquiry_id.get("const"), None); + assert_eq!( + inquiry_id["enum"], + json!(["tool_call.fs_modify_file.call_abc"]) + ); + assert_eq!(inquiry_id["type"], "string"); + assert_eq!(out["properties"]["answer"]["type"], "boolean"); + } + + #[test] + fn deeply_nested_const_rewritten() { + let input = schema(json!({ + "type": "object", + "properties": { + "outer": { + "type": "object", + "properties": { + "inner": { + "type": "string", + "const": "fixed" + } + } + } + } + })); + + let out = transform_schema(input); + + let inner = &out["properties"]["outer"]["properties"]["inner"]; + assert_eq!(inner.get("const"), None); + assert_eq!(inner["enum"], json!(["fixed"])); + } + + #[test] + fn const_in_array_items_rewritten() { + let input = schema(json!({ + "type": "array", + "items": { + "type": "string", + "const": "only_value" + } + })); + + let out = transform_schema(input); + + let items = out["items"].as_object().unwrap(); + assert_eq!(items.get("const"), None); + assert_eq!(items["enum"], json!(["only_value"])); + } + + #[test] + fn ref_inlined_from_defs() { + let input = schema(json!({ + "type": "array", + "items": { "$ref": "#/$defs/CountryInfo" }, + "$defs": { + "CountryInfo": { + "type": "object", + "properties": { + 
"continent": { "type": "string" }, + "gdp": { "type": "integer" } + }, + "required": ["continent", "gdp"] + } + } + })); + + let out = transform_schema(input); + + // $defs should be removed from the output. + assert!(out.get("$defs").is_none()); + + // items should be the inlined definition. + let items = out["items"].as_object().unwrap(); + assert_eq!(items["type"], "object"); + assert_eq!(items["properties"]["continent"]["type"], "string"); + assert_eq!(items["properties"]["gdp"]["type"], "integer"); + assert_eq!(items["required"], json!(["continent", "gdp"])); + } + + #[test] + fn ref_with_sibling_fields_preserved() { + let input = schema(json!({ + "type": "object", + "properties": { + "person": { + "$ref": "#/$defs/Person", + "description": "The main person" + } + }, + "$defs": { + "Person": { + "type": "object", + "properties": { + "name": { "type": "string" } + } + } + } + })); + + let out = transform_schema(input); + + // Sibling "description" should be preserved alongside inlined def. + let person = out["properties"]["person"].as_object().unwrap(); + assert_eq!(person["type"], "object"); + assert_eq!(person["description"], "The main person"); + assert_eq!(person["properties"]["name"]["type"], "string"); + } + + #[test] + fn ref_in_nested_property() { + let input = schema(json!({ + "type": "object", + "properties": { + "addr": { "$ref": "#/$defs/Address" } + }, + "$defs": { + "Address": { + "type": "object", + "properties": { + "city": { "type": "string" }, + "zip": { "type": "string" } + } + } + } + })); + + let out = transform_schema(input); + + let addr = out["properties"]["addr"].as_object().unwrap(); + assert_eq!(addr["type"], "object"); + assert_eq!(addr["properties"]["city"]["type"], "string"); + assert_eq!(addr["properties"]["zip"]["type"], "string"); + // Inlined object gets propertyOrdering. 
+ assert_eq!(addr["propertyOrdering"], json!(["city", "zip"])); + } + + #[test] + fn definitions_also_removed() { + let input = schema(json!({ + "type": "object", + "properties": { + "x": { "$ref": "#/$defs/X" } + }, + "definitions": { + "X": { "type": "string" } + } + })); + + let out = transform_schema(input); + + assert!(out.get("definitions").is_none()); + assert_eq!(out["properties"]["x"]["type"], "string"); + } + + #[test] + fn property_ordering_added_for_multiple_properties() { + let input = schema(json!({ + "type": "object", + "properties": { + "first": { "type": "string" }, + "second": { "type": "integer" }, + "third": { "type": "boolean" } + } + })); + + let out = transform_schema(input); + + assert_eq!(out["propertyOrdering"], json!(["first", "second", "third"])); + } + + #[test] + fn property_ordering_not_added_for_single_property() { + let input = schema(json!({ + "type": "object", + "properties": { + "only": { "type": "string" } + } + })); + + let out = transform_schema(input); + + assert!(out.get("propertyOrdering").is_none()); + } + + #[test] + fn property_ordering_preserved_if_already_set() { + let input = schema(json!({ + "type": "object", + "properties": { + "a": { "type": "string" }, + "b": { "type": "string" } + }, + "propertyOrdering": ["b", "a"] + })); + + let out = transform_schema(input); + + // Existing ordering should not be overwritten. + assert_eq!(out["propertyOrdering"], json!(["b", "a"])); + } + + #[test] + fn anyof_variants_processed() { + let input = schema(json!({ + "anyOf": [ + { "type": "string", "const": "fixed" }, + { "type": "integer" } + ] + })); + + let out = transform_schema(input); + + let variants = out["anyOf"].as_array().unwrap(); + assert_eq!(variants.len(), 2); + // const should be rewritten inside the variant. 
+ assert_eq!(variants[0]["enum"], json!(["fixed"])); + assert!(variants[0].get("const").is_none()); + assert_eq!(variants[1]["type"], "integer"); + } + + #[test] + fn anyof_with_ref_resolved() { + let input = schema(json!({ + "anyOf": [ + { "$ref": "#/$defs/Str" }, + { "type": "integer" } + ], + "$defs": { + "Str": { "type": "string" } + } + })); + + let out = transform_schema(input); + + let variants = out["anyOf"].as_array().unwrap(); + assert_eq!(variants[0]["type"], "string"); + assert_eq!(variants[1]["type"], "integer"); + } + + #[test] + fn additional_properties_bool_preserved() { + let input = schema(json!({ + "type": "object", + "properties": { + "name": { "type": "string" } + }, + "additionalProperties": false + })); + + let out = transform_schema(input); + + assert_eq!(out["additionalProperties"], json!(false)); + } + + #[test] + fn additional_properties_schema_processed() { + let input = schema(json!({ + "type": "object", + "properties": { + "name": { "type": "string" } + }, + "additionalProperties": { + "type": "string", + "const": "extra" + } + })); + + let out = transform_schema(input); + + let additional = out["additionalProperties"].as_object().unwrap(); + assert_eq!(additional.get("const"), None); + assert_eq!(additional["enum"], json!(["extra"])); + } + + #[test] + fn prefix_items_processed() { + let input = schema(json!({ + "type": "array", + "prefixItems": [ + { "type": "string", "const": "header" }, + { "type": "integer" } + ] + })); + + let out = transform_schema(input); + + let prefixes = out["prefixItems"].as_array().unwrap(); + assert_eq!(prefixes[0]["enum"], json!(["header"])); + assert!(prefixes[0].get("const").is_none()); + assert_eq!(prefixes[1]["type"], "integer"); + } + + #[test] + fn enum_preserved_unchanged() { + let input = schema(json!({ + "type": "string", + "enum": ["A", "B", "C"] + })); + + let out = transform_schema(input); + + assert_eq!(out["enum"], json!(["A", "B", "C"])); + } + + #[test] + fn 
supported_properties_preserved() { + let input = schema(json!({ + "type": "integer", + "minimum": 1, + "maximum": 10, + "description": "A number" + })); + + let out = transform_schema(input); + + assert_eq!(out["type"], "integer"); + assert_eq!(out["minimum"], 1); + assert_eq!(out["maximum"], 10); + assert_eq!(out["description"], "A number"); + } + + /// The actual inquiry schema should transform correctly for Google. + #[test] + fn inquiry_schema_transforms_correctly() { + let input = schema(json!({ + "type": "object", + "properties": { + "inquiry_id": { + "type": "string", + "const": "tool_call.fs_modify_file.call_a3b7c9d1" + }, + "answer": { + "type": "boolean" + } + }, + "required": ["inquiry_id", "answer"], + "additionalProperties": false + })); + + let out = transform_schema(input); + + assert_eq!( + Value::Object(out), + json!({ + "type": "object", + "required": ["inquiry_id", "answer"], + "additionalProperties": false, + "propertyOrdering": ["inquiry_id", "answer"], + "properties": { + "inquiry_id": { + "type": "string", + "enum": ["tool_call.fs_modify_file.call_a3b7c9d1"] + }, + "answer": { + "type": "boolean" + } + } + }) + ); + } + + /// The `title_schema` should pass through mostly unchanged. + /// It has a single property so no `propertyOrdering` is added. + #[test] + fn title_schema_passes_through() { + let input = crate::title::title_schema(3); + let out = transform_schema(input.clone()); + + assert_eq!(out, input); + } + + /// Matches the example from the Python SDK docstring. 
+ #[test] + fn sdk_docstring_example() { + let input = schema(json!({ + "items": { "$ref": "#/$defs/CountryInfo" }, + "title": "Placeholder", + "type": "array", + "$defs": { + "CountryInfo": { + "properties": { + "continent": { "title": "Continent", "type": "string" }, + "gdp": { "title": "Gdp", "type": "integer" } + }, + "required": ["continent", "gdp"], + "title": "CountryInfo", + "type": "object" + } + } + })); + + let out = transform_schema(input); + + // $defs removed, $ref inlined, propertyOrdering added. + assert_eq!( + Value::Object(out), + json!({ + "title": "Placeholder", + "type": "array", + "items": { + "properties": { + "continent": { "title": "Continent", "type": "string" }, + "gdp": { "title": "Gdp", "type": "integer" } + }, + "required": ["continent", "gdp"], + "title": "CountryInfo", + "type": "object", + "propertyOrdering": ["continent", "gdp"] + } + }) + ); + } +} diff --git a/crates/jp_llm/src/provider/llamacpp.rs b/crates/jp_llm/src/provider/llamacpp.rs index e4d1f68b..fdde4a4e 100644 --- a/crates/jp_llm/src/provider/llamacpp.rs +++ b/crates/jp_llm/src/provider/llamacpp.rs @@ -1,7 +1,7 @@ use std::mem; use async_trait::async_trait; -use futures::{FutureExt as _, StreamExt as _, future, stream}; +use futures::{StreamExt as _, future, stream}; use jp_config::{ assistant::tool_choice::ToolChoice, model::id::{ModelIdConfig, Name, ProviderId}, @@ -11,26 +11,19 @@ use jp_conversation::{ ConversationEvent, ConversationStream, event::{ChatResponse, EventKind, ToolCallResponse}, }; -use openai::{ - Credentials, - chat::{ - self, ChatCompletionBuilder, ChatCompletionChoiceDelta, ChatCompletionDelta, - ChatCompletionMessage, ChatCompletionMessageDelta, ChatCompletionMessageRole, - ToolCallFunction, structured_output::ToolCallFunctionDefinition, - }, -}; -use serde_json::Value; -use tokio_stream::wrappers::ReceiverStream; -use tracing::{debug, trace}; +use reqwest_eventsource::{Event as SseEvent, EventSource}; +use serde::Deserialize; +use 
serde_json::{Value, json}; +use tracing::{debug, trace, warn}; use super::{ EventStream, ModelDetails, - openai::{ModelListResponse, ModelResponse}, + openai::{ModelListResponse, ModelResponse, parameters_with_strict_mode}, }; use crate::{ - error::{Error, Result}, + error::{Error, StreamError}, event::{Event, FinishReason}, - provider::{Provider, openai::parameters_with_strict_mode}, + provider::Provider, query::ChatQuery, stream::aggregator::{ reasoning::ReasoningExtractor, tool_call_request::ToolCallRequestAggregator, @@ -43,46 +36,12 @@ static PROVIDER: ProviderId = ProviderId::Llamacpp; #[derive(Debug, Clone)] pub struct Llamacpp { reqwest_client: reqwest::Client, - credentials: Credentials, base_url: String, } -impl Llamacpp { - /// Build request for Llama.cpp API. - fn build_request( - &self, - model: &ModelDetails, - query: ChatQuery, - ) -> Result { - let slug = model.id.name.to_string(); - let ChatQuery { - thread, - tools, - tool_choice, - tool_call_strict_mode, - } = query; - - let messages = thread.into_messages(to_system_messages, convert_events)?; - let tools = convert_tools(tools, tool_call_strict_mode, &tool_choice); - let tool_choice = convert_tool_choice(&tool_choice); - - trace!( - slug, - messages_size = messages.len(), - tools_size = tools.len(), - "Built Llamacpp request." 
- ); - - Ok(ChatCompletionDelta::builder(&slug, messages) - .credentials(self.credentials.clone()) - .tools(tools) - .tool_choice(tool_choice)) - } -} - #[async_trait] impl Provider for Llamacpp { - async fn model_details(&self, name: &Name) -> Result { + async fn model_details(&self, name: &Name) -> Result { let id: ModelIdConfig = (PROVIDER, name.as_ref()).try_into()?; Ok(self @@ -93,7 +52,7 @@ impl Provider for Llamacpp { .unwrap_or(ModelDetails::empty(id))) } - async fn models(&self) -> Result> { + async fn models(&self) -> Result, Error> { self.reqwest_client .get(format!("{}/v1/models", self.base_url)) .send() @@ -104,113 +63,417 @@ impl Provider for Llamacpp { .data .iter() .map(map_model) - .collect::>() + .collect::>() } async fn chat_completion_stream( &self, model: &ModelDetails, query: ChatQuery, - ) -> Result { + ) -> Result { debug!( model = %model.id.name, "Starting Llamacpp chat completion stream." ); - let mut extractor = ReasoningExtractor::default(); - let mut agg = ToolCallRequestAggregator::default(); - let request = self.build_request(model, query)?; + let (body, is_structured) = build_request(model, query)?; - trace!(?request, "Sending request to Llamacpp."); + trace!( + body = serde_json::to_string(&body).unwrap_or_default(), + "Sending request to Llamacpp." 
+ ); + + let request = self + .reqwest_client + .post(format!("{}/v1/chat/completions", self.base_url)) + .header("content-type", "application/json") + .json(&body); + + let es = EventSource::new(request).map_err(|e| Error::InvalidResponse(e.to_string()))?; + + let mut state = StreamState { + extractor: ReasoningExtractor::default(), + tool_calls: ToolCallRequestAggregator::default(), + reasoning_flushed: false, + finish_reason: None, + is_structured, + }; - Ok(request - .create_stream() - .await - .map(ReceiverStream::new) - .expect("Should not fail to clone") - .flat_map(|v| stream::iter(v.choices)) - .flat_map(move |v| stream::iter(map_event(v, &mut extractor, &mut agg))) - .chain(future::ok(Event::Finished(FinishReason::Completed)).into_stream()) + Ok(es + // EventSource yields Err on close; stop the stream. + .take_while(|event| future::ready(event.is_ok())) + .flat_map(move |event| stream::iter(handle_sse_event(event, &mut state))) .boxed()) } } -fn map_event( - event: ChatCompletionChoiceDelta, - extractor: &mut ReasoningExtractor, - agg: &mut ToolCallRequestAggregator, -) -> Vec> { - let ChatCompletionChoiceDelta { - index, - finish_reason, - delta: - ChatCompletionMessageDelta { - content, - tool_calls, - .. - }, - } = event; - - #[allow(clippy::cast_possible_truncation)] - let index = index as usize; - let mut events = vec![]; +/// Mutable state carried across SSE events in a single stream. +struct StreamState { + extractor: ReasoningExtractor, + tool_calls: ToolCallRequestAggregator, + reasoning_flushed: bool, + /// Captured from `finish_reason` in the last choice delta. Emitted as + /// `Event::Finished` when the `[DONE]` sentinel arrives. + finish_reason: Option, + is_structured: bool, +} - for chat::ToolCallDelta { id, function, .. 
} in tool_calls.into_iter().flatten() { - let (name, arguments) = match function { - Some(chat::ToolCallFunction { name, arguments }) => (Some(name), Some(arguments)), - None => (None, None), - }; +/// Process a single SSE event into zero or more provider-agnostic events. +#[expect(clippy::too_many_lines)] +fn handle_sse_event( + event: Result, + state: &mut StreamState, +) -> Vec> { + match event { + Ok(SseEvent::Open) => vec![], + Ok(SseEvent::Message(msg)) => { + if msg.data == "[DONE]" { + // Finalize the reasoning extractor on stream end. + state.extractor.finalize(); + let mut events: Vec> = + drain_extractor(&mut state.extractor, state.is_structured) + .into_iter() + .map(Ok) + .collect(); + + // Flush reasoning if we never did. + if !state.reasoning_flushed { + events.push(Ok(Event::flush(0))); + state.reasoning_flushed = true; + } + + // Flush message content. + events.push(Ok(Event::flush(1))); + + events.push(Ok(Event::Finished( + state + .finish_reason + .take() + .unwrap_or(FinishReason::Completed), + ))); + return events; + } - agg.add_chunk(index, id, name, arguments.as_deref()); - } + let chunk: StreamChunk = match serde_json::from_str(&msg.data) { + Ok(c) => c, + Err(error) => { + warn!( + error = error.to_string(), + data = &msg.data, + "Failed to parse Llamacpp chunk." + ); + + return vec![]; + } + }; + + let mut events = Vec::new(); + + for choice in &chunk.choices { + let delta = &choice.delta; + + // Reasoning via `reasoning_content` (deepseek / deepseek-legacy formats) + if let Some(reasoning) = &delta.reasoning_content + && !reasoning.is_empty() + { + events.push(Ok(Event::Part { + index: 0, + event: ConversationEvent::now(ChatResponse::reasoning(reasoning.clone())), + })); + } + + // Content + // + // If reasoning_content was present, the server already + // separated reasoning from content (deepseek / + // deepseek-legacy). Otherwise, content may contain tags + // (none format) and needs the extractor. 
+ if let Some(content) = &delta.content + && !content.is_empty() + { + // Server separated reasoning; content is pure text. + if delta.reasoning_content.is_some() { + flush_reasoning_if_needed(&mut events, &mut state.reasoning_flushed); + + let response = if state.is_structured { + ChatResponse::structured(Value::String(content.clone())) + } else { + ChatResponse::message(content.clone()) + }; + events.push(Ok(Event::Part { + index: 1, + event: ConversationEvent::now(response), + })); + } else { + // Might contain tags — feed through extractor. + state.extractor.handle(content); + events.extend( + drain_extractor(&mut state.extractor, state.is_structured) + .into_iter() + .map(Ok), + ); + } + } + + // Tool calls + if let Some(tool_calls) = &delta.tool_calls { + flush_reasoning_if_needed(&mut events, &mut state.reasoning_flushed); + + for tc in tool_calls { + let index = tc.index as usize + 2; + let name = tc.function.as_ref().and_then(|f| f.name.clone()); + let arguments = tc.function.as_ref().and_then(|f| f.arguments.as_deref()); + state + .tool_calls + .add_chunk(index, tc.id.clone(), name, arguments); + } + } + + // Finish reason + if let Some(reason) = &choice.finish_reason { + state.extractor.finalize(); + events.extend( + drain_extractor(&mut state.extractor, state.is_structured) + .into_iter() + .map(Ok), + ); + + if matches!(reason.as_str(), "tool_calls" | "stop") { + events.extend(state.tool_calls.finalize_all().into_iter().flat_map( + |(index, result)| { + vec![ + result + .map(|call| Event::Part { + index, + event: ConversationEvent::now(call), + }) + .map_err(|e| StreamError::other(e.to_string())), + Ok(Event::flush(index)), + ] + }, + )); + } + + if !state.reasoning_flushed { + events.push(Ok(Event::flush(0))); + state.reasoning_flushed = true; + } + events.push(Ok(Event::flush(1))); + + // Per the OpenAI spec. 
+ match reason.as_str() { + "length" => state.finish_reason = Some(FinishReason::MaxTokens), + "stop" => state.finish_reason = Some(FinishReason::Completed), + _ => {} + } + } + } - if ["function_call", "tool_calls"].contains(&finish_reason.as_deref().unwrap_or_default()) { - match agg.finalize(index) { - Ok(request) => events.extend(vec![ - Ok(Event::Part { - index, - event: ConversationEvent::now(request), - }), - Ok(Event::flush(index)), - ]), - Err(error) => events.push(Err(error.into())), + events } + Err(e) => vec![Err(StreamError::from(e))], } +} - if let Some(content) = content { - extractor.handle(&content); - } - - if finish_reason.is_some() { - extractor.finalize(); +/// Push a reasoning flush event if we haven't already. +fn flush_reasoning_if_needed(events: &mut Vec>, flushed: &mut bool) { + if !*flushed { + events.push(Ok(Event::flush(0))); + *flushed = true; } - - events.extend(fetch_content(extractor, index).into_iter().map(Ok)); - events } -fn fetch_content(extractor: &mut ReasoningExtractor, index: usize) -> Vec { +/// Drain accumulated content from the `ReasoningExtractor` into events. +/// +/// Index convention matches Ollama: 0 = reasoning, 1 = message content. 
+fn drain_extractor(extractor: &mut ReasoningExtractor, is_structured: bool) -> Vec { let mut events = Vec::new(); + if !extractor.reasoning.is_empty() { let reasoning = mem::take(&mut extractor.reasoning); events.push(Event::Part { - index, + index: 0, event: ConversationEvent::now(ChatResponse::reasoning(reasoning)), }); } if !extractor.other.is_empty() { let content = mem::take(&mut extractor.other); + let response = if is_structured { + ChatResponse::structured(Value::String(content)) + } else { + ChatResponse::message(content) + }; events.push(Event::Part { - index, - event: ConversationEvent::now(ChatResponse::message(content)), + index: 1, + event: ConversationEvent::now(response), }); } events } -fn map_model(model: &ModelResponse) -> Result { +/// Build the JSON request body for the llama.cpp `/v1/chat/completions` +/// endpoint. +/// +/// Returns `(body, is_structured)`. +fn build_request(model: &ModelDetails, query: ChatQuery) -> Result<(Value, bool), Error> { + let ChatQuery { + thread, + tools, + tool_choice, + } = query; + + let structured_schema = thread + .events + .last() + .and_then(|e| e.event.as_chat_request()) + .and_then(|req| req.schema.clone()); + + let is_structured = structured_schema.is_some(); + let slug = model.id.name.to_string(); + + let messages = thread.into_messages(to_system_messages, convert_events)?; + let converted_tools = convert_tools(tools, &tool_choice); + let tool_choice_val = convert_tool_choice(&tool_choice); + + trace!( + slug, + messages_size = messages.len(), + tools_size = converted_tools.len(), + "Built Llamacpp request." 
+    );
+
+    let mut body = json!({
+        "model": slug,
+        "messages": messages,
+        "stream": true,
+    });
+
+    if !converted_tools.is_empty() {
+        body["tools"] = json!(converted_tools);
+        body["tool_choice"] = json!(tool_choice_val);
+    }
+
+    if let Some(schema) = structured_schema {
+        body["response_format"] = json!({
+            "type": "json_schema",
+            "json_schema": {
+                "name": "structured_output",
+                "schema": schema,
+                "strict": true,
+            },
+        });
+    }
+
+    Ok((body, is_structured))
+}
+
+/// Convert system prompt parts into a list of JSON message values.
+fn to_system_messages(parts: Vec<String>) -> impl Iterator<Item = Value> {
+    parts
+        .into_iter()
+        .map(|content| json!({ "role": "system", "content": content }))
+}
+
+/// Convert a conversation event stream into a list of JSON message values.
+fn convert_events(events: ConversationStream) -> Vec<Value> {
+    events
+        .into_iter()
+        .filter_map(|event| match event.into_kind() {
+            EventKind::ChatRequest(request) => {
+                Some(json!({ "role": "user", "content": request.content }))
+            }
+            EventKind::ChatResponse(response) => match response {
+                ChatResponse::Message { message } => {
+                    Some(json!({ "role": "assistant", "content": message }))
+                }
+                ChatResponse::Reasoning { reasoning } => {
+                    // Wrap reasoning in <think> tags so the model can pick up
+                    // its own chain-of-thought on the next turn.
+                    Some(json!({
+                        "role": "assistant",
+                        "content": format!("<think>\n{reasoning}\n</think>"),
+                    }))
+                }
+                ChatResponse::Structured { data } => {
+                    Some(json!({ "role": "assistant", "content": data.to_string() }))
+                }
+            },
+            EventKind::ToolCallRequest(request) => Some(json!({
+                "role": "assistant",
+                "tool_calls": [{
+                    "id": request.id,
+                    "type": "function",
+                    "function": {
+                        "name": request.name,
+                        "arguments": Value::Object(request.arguments).to_string(),
+                    },
+                }],
+            })),
+            EventKind::ToolCallResponse(ToolCallResponse { id, result }) => Some(json!({
+                "role": "tool",
+                "tool_call_id": id,
+                "content": match result {
+                    Ok(content) | Err(content) => content,
+                },
+            })),
+            _ => None,
+        })
+        .fold(vec![], |mut messages: Vec<Value>, message| {
+            // Merge consecutive assistant messages that carry tool_calls
+            // (same folding logic as the old openai-crate implementation).
+            if message.get("tool_calls").is_some()
+                && let Some(last) = messages.last_mut()
+                && last.get("tool_calls").is_some()
+                && let (Some(existing), Some(new)) = (
+                    last["tool_calls"].as_array_mut(),
+                    message["tool_calls"].as_array(),
+                )
+            {
+                existing.extend(new.iter().cloned());
+                return messages;
+            }
+            messages.push(message);
+            messages
+        })
+}
+
+/// Convert tool definitions to the OpenAI-compatible JSON format.
+///
+/// If [`ToolChoice::Function`] is set, only include the named tool. llama.cpp
+/// doesn't support calling a specific tool by name, but it supports `required`
+/// mode, so we limit the tool list instead.
+fn convert_tools(tools: Vec, tool_choice: &ToolChoice) -> Vec { + tools + .into_iter() + .map(|tool| { + json!({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description.unwrap_or_default(), + "parameters": parameters_with_strict_mode(tool.parameters, true), + "strict": true, + }, + }) + }) + .filter(|tool| match tool_choice { + ToolChoice::Function(req) => tool["function"]["name"].as_str() == Some(req.as_str()), + _ => true, + }) + .collect() +} + +fn convert_tool_choice(choice: &ToolChoice) -> &str { + match choice { + ToolChoice::Auto => "auto", + ToolChoice::None => "none", + ToolChoice::Required | ToolChoice::Function(_) => "required", + } +} + +fn map_model(model: &ModelResponse) -> Result { Ok(ModelDetails { id: ( PROVIDER, @@ -233,189 +496,65 @@ fn map_model(model: &ModelResponse) -> Result { impl TryFrom<&LlamacppConfig> for Llamacpp { type Error = Error; - fn try_from(config: &LlamacppConfig) -> Result { + fn try_from(config: &LlamacppConfig) -> Result { let reqwest_client = reqwest::Client::builder().build()?; let base_url = config.base_url.clone(); - let credentials = Credentials::new("", &base_url); Ok(Llamacpp { reqwest_client, - credentials, base_url, }) } } -/// Convert a list of [`jp_mcp::Tool`] to a list of [`chat::ChatCompletionTool`]. -/// -/// Additionally, if [`ToolChoice::Function`] is provided, only return the -/// tool(s) that matches the expected name. This is done because Llama.cpp does -/// not support calling a specific tool by name, but it *does* support -/// "required"/forced tool calling, which we can turn into a request to call a -/// specific tool, by limiting the list of tools to the ones that we want to be -/// called. 
-fn convert_tools( - tools: Vec, - strict: bool, - tool_choice: &ToolChoice, -) -> Vec { - tools - .into_iter() - .map(|tool| chat::ChatCompletionTool::Function { - function: ToolCallFunctionDefinition { - parameters: Some(Value::Object(parameters_with_strict_mode( - tool.parameters, - strict, - ))), - name: tool.name, - description: tool.description.or(Some(String::new())), - strict: Some(strict), - }, - }) - .filter(|tool| match tool_choice { - ToolChoice::Function(req) => matches!( - tool, - chat::ChatCompletionTool::Function { - function: ToolCallFunctionDefinition { name, .. } - } if name == req - ), - _ => true, - }) - .collect::>() +// These mirror the llama.cpp server's `common_chat_msg_diff_to_json_oaicompat` +// output. The critical addition over the `openai` crate is the +// `reasoning_content` field, which carries extracted reasoning for the +// `--reasoning-format deepseek` (default) and `deepseek-legacy` modes. + +#[derive(Debug, Deserialize)] +struct StreamChunk { + #[serde(default)] + choices: Vec, } -fn convert_tool_choice(choice: &ToolChoice) -> chat::ToolChoice { - match choice { - ToolChoice::Auto => chat::ToolChoice::mode(chat::ToolChoiceMode::Auto), - ToolChoice::None => chat::ToolChoice::mode(chat::ToolChoiceMode::None), - ToolChoice::Required | ToolChoice::Function(_) => { - chat::ToolChoice::mode(chat::ToolChoiceMode::Required) - } - } +#[derive(Debug, Deserialize)] +struct StreamChoice { + delta: StreamDelta, + #[serde(default)] + finish_reason: Option, } -/// Convert a list of content into system messages. -fn to_system_messages(parts: Vec) -> impl Iterator { - parts.into_iter().map(|content| ChatCompletionMessage { - role: ChatCompletionMessageRole::System, - content: Some(content), - ..Default::default() - }) +#[derive(Debug, Deserialize, Default)] +struct StreamDelta { + #[serde(default)] + content: Option, + /// Reasoning content extracted by the server (deepseek / deepseek-legacy). 
+ /// This is a non-standard `DeepSeek` extension that llama.cpp also uses. + #[serde(default)] + reasoning_content: Option, + #[serde(default)] + tool_calls: Option>, } -fn convert_events(events: ConversationStream) -> Vec { - events - .into_iter() - .filter_map(|event| match event.into_kind() { - EventKind::ChatRequest(request) => Some(ChatCompletionMessage { - role: ChatCompletionMessageRole::User, - content: Some(request.content), - ..Default::default() - }), - EventKind::ChatResponse(response) => Some(ChatCompletionMessage { - role: ChatCompletionMessageRole::Assistant, - content: Some(response.into_content()), - ..Default::default() - }), - EventKind::ToolCallRequest(request) => Some(ChatCompletionMessage { - role: ChatCompletionMessageRole::Assistant, - tool_calls: Some(vec![chat::ToolCall { - id: request.id.clone(), - r#type: chat::FunctionType::Function, - function: ToolCallFunction { - name: request.name.clone(), - arguments: Value::Object(request.arguments.clone()).to_string(), - }, - }]), - ..Default::default() - }), - EventKind::ToolCallResponse(ToolCallResponse { id, result }) => { - Some(ChatCompletionMessage { - role: ChatCompletionMessageRole::Tool, - tool_call_id: Some(id), - content: Some(match result { - Ok(content) | Err(content) => content, - }), - ..Default::default() - }) - } - _ => None, - }) - .fold(vec![], |mut messages, message| match messages.last_mut() { - Some(last) if message.tool_calls.is_some() && last.tool_calls.is_some() => { - last.tool_calls - .get_or_insert_default() - .extend(message.tool_calls.unwrap_or_default()); - messages - } - _ => { - messages.push(message); - messages - } - }) +#[derive(Debug, Deserialize)] +struct ToolCallDelta { + #[serde(default)] + index: u32, + #[serde(default)] + id: Option, + #[serde(default)] + function: Option, +} + +#[derive(Debug, Deserialize)] +struct FunctionDelta { + #[serde(default)] + name: Option, + #[serde(default)] + arguments: Option, } -// #[cfg(test)] -// mod tests { -// use 
jp_config::providers::llm::LlmProviderConfig; -// use jp_test::{Result, fn_name, mock::Vcr}; -// use test_log::test; -// -// use super::*; -// -// fn vcr() -> Vcr { -// Vcr::new("http://127.0.0.1:8080", env!("CARGO_MANIFEST_DIR")) -// } -// -// #[test(tokio::test)] -// async fn test_llamacpp_models() -> Result { -// let mut config = LlmProviderConfig::default().llamacpp; -// let vcr = vcr(); -// vcr.cassette( -// fn_name!(), -// |rule| { -// rule.filter(|when| { -// when.any_request(); -// }); -// }, -// |_, url| async move { -// config.base_url = url; -// Llamacpp::try_from(&config).unwrap().models().await -// }, -// ) -// .await -// } -// -// #[test(tokio::test)] -// async fn test_llamacpp_chat_completion() -> Result { -// let mut config = LlmProviderConfig::default().llamacpp; -// let model_id = "llamacpp/llama3:latest".parse().unwrap(); -// let model = ModelDetails::empty(model_id); -// let query = ChatQuery { -// thread: Thread { -// events: ConversationStream::default().with_chat_request("Test message"), -// ..Default::default() -// }, -// ..Default::default() -// }; -// -// let vcr = vcr(); -// vcr.cassette( -// fn_name!(), -// |rule| { -// rule.filter(|when| { -// when.any_request(); -// }); -// }, -// |_, url| async move { -// config.base_url = url; -// -// Llamacpp::try_from(&config) -// .unwrap() -// .chat_completion(&model, query) -// .await -// }, -// ) -// .await -// } -// } +#[cfg(test)] +#[path = "llamacpp_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/provider/llamacpp_tests.rs b/crates/jp_llm/src/provider/llamacpp_tests.rs new file mode 100644 index 00000000..97518ac6 --- /dev/null +++ b/crates/jp_llm/src/provider/llamacpp_tests.rs @@ -0,0 +1,181 @@ +use super::*; + +#[test] +fn parse_deepseek_format_reasoning_in_dedicated_field() { + // The default `--reasoning-format deepseek`: reasoning arrives in + // `reasoning_content`, regular content in `content`. 
+    let json = r#"{
+        "choices": [{
+            "delta": {
+                "reasoning_content": "Let me think step by step...",
+                "content": null
+            },
+            "index": 0,
+            "finish_reason": null
+        }]
+    }"#;
+
+    let chunk: StreamChunk = serde_json::from_str(json).unwrap();
+    assert_eq!(chunk.choices.len(), 1);
+
+    let delta = &chunk.choices[0].delta;
+    assert_eq!(
+        delta.reasoning_content.as_deref(),
+        Some("Let me think step by step...")
+    );
+    assert!(delta.content.is_none());
+}
+
+#[test]
+fn parse_deepseek_format_content_after_reasoning() {
+    let json = r#"{
+        "choices": [{
+            "delta": {
+                "reasoning_content": null,
+                "content": "The answer is 42."
+            },
+            "index": 0,
+            "finish_reason": null
+        }]
+    }"#;
+
+    let chunk: StreamChunk = serde_json::from_str(json).unwrap();
+    let delta = &chunk.choices[0].delta;
+    assert!(delta.reasoning_content.is_none());
+    assert_eq!(delta.content.as_deref(), Some("The answer is 42."));
+}
+
+#[test]
+fn parse_none_format_think_tags_in_content() {
+    // `--reasoning-format none`: everything in `content`, with <think> tags.
+    // No `reasoning_content` field at all.
+    let json = r#"{
+        "choices": [{
+            "delta": {
+                "content": "<think>\nLet me reason...\n</think>\nThe answer."
+            },
+            "index": 0,
+            "finish_reason": null
+        }]
+    }"#;
+
+    let chunk: StreamChunk = serde_json::from_str(json).unwrap();
+    let delta = &chunk.choices[0].delta;
+    assert!(delta.reasoning_content.is_none());
+    assert!(delta.content.as_ref().unwrap().contains("<think>"));
+}
+
+#[test]
+fn parse_finish_reason() {
+    let json = r#"{
+        "choices": [{
+            "delta": {},
+            "index": 0,
+            "finish_reason": "stop"
+        }]
+    }"#;
+
+    let chunk: StreamChunk = serde_json::from_str(json).unwrap();
+    assert_eq!(chunk.choices[0].finish_reason.as_deref(), Some("stop"));
+}
+
+#[test]
+fn parse_tool_call_delta() {
+    let json = r#"{
+        "choices": [{
+            "delta": {
+                "tool_calls": [{
+                    "index": 0,
+                    "id": "call_abc123",
+                    "function": {
+                        "name": "get_weather",
+                        "arguments": "{\"city\":"
+                    }
+                }]
+            },
+            "index": 0,
+            "finish_reason": null
+        }]
+    }"#;
+
+    let chunk: StreamChunk = serde_json::from_str(json).unwrap();
+    let tool_calls = chunk.choices[0].delta.tool_calls.as_ref().unwrap();
+    assert_eq!(tool_calls.len(), 1);
+    assert_eq!(tool_calls[0].id.as_deref(), Some("call_abc123"));
+    let func = tool_calls[0].function.as_ref().unwrap();
+    assert_eq!(func.name.as_deref(), Some("get_weather"));
+    assert_eq!(func.arguments.as_deref(), Some("{\"city\":"));
+}
+
+#[test]
+fn parse_empty_choices() {
+    // Some servers send empty choices arrays (e.g. usage-only chunks).
+    let json = r#"{"choices": []}"#;
+    let chunk: StreamChunk = serde_json::from_str(json).unwrap();
+    assert!(chunk.choices.is_empty());
+}
+
+#[test]
+fn parse_missing_optional_fields() {
+    // Minimal delta with only content.
+    let json = r#"{"choices": [{"delta": {"content": "hi"}}]}"#;
+    let chunk: StreamChunk = serde_json::from_str(json).unwrap();
+    let delta = &chunk.choices[0].delta;
+    assert_eq!(delta.content.as_deref(), Some("hi"));
+    assert!(delta.reasoning_content.is_none());
+    assert!(delta.tool_calls.is_none());
+    assert!(chunk.choices[0].finish_reason.is_none());
+}
+
+#[test]
+fn convert_events_merges_consecutive_tool_calls() {
+    use jp_conversation::event::ToolCallRequest;
+
+    let mut events = ConversationStream::new_test();
+    events.push(ConversationEvent::now(ToolCallRequest {
+        id: "call_1".into(),
+        name: "tool_a".into(),
+        arguments: serde_json::Map::new(),
+    }));
+    events.push(ConversationEvent::now(ToolCallRequest {
+        id: "call_2".into(),
+        name: "tool_b".into(),
+        arguments: serde_json::Map::new(),
+    }));
+
+    let messages = convert_events(events);
+
+    // Should be merged into a single assistant message with 2 tool_calls.
+    assert_eq!(messages.len(), 1);
+    let tool_calls = messages[0]["tool_calls"].as_array().unwrap();
+    assert_eq!(tool_calls.len(), 2);
+    assert_eq!(tool_calls[0]["function"]["name"], "tool_a");
+    assert_eq!(tool_calls[1]["function"]["name"], "tool_b");
+}
+
+#[test]
+fn convert_events_wraps_reasoning_in_think_tags() {
+    let mut events = ConversationStream::new_test();
+    events.push(ConversationEvent::now(ChatResponse::reasoning(
+        "step 1: think hard",
+    )));
+
+    let messages = convert_events(events);
+
+    assert_eq!(messages.len(), 1);
+    let content = messages[0]["content"].as_str().unwrap();
+    assert!(content.starts_with("<think>"));
+    assert!(content.contains("step 1: think hard"));
+    assert!(content.ends_with("</think>"));
+}
+
+#[test]
+fn convert_tool_choice_values() {
+    assert_eq!(convert_tool_choice(&ToolChoice::Auto), "auto");
+    assert_eq!(convert_tool_choice(&ToolChoice::None), "none");
+    assert_eq!(convert_tool_choice(&ToolChoice::Required), "required");
+    assert_eq!(
+        convert_tool_choice(&ToolChoice::Function("my_fn".into())),
+        "required"
+    );
+}
diff 
--git a/crates/jp_llm/src/provider/mock.rs b/crates/jp_llm/src/provider/mock.rs new file mode 100644 index 00000000..bfa36b9b --- /dev/null +++ b/crates/jp_llm/src/provider/mock.rs @@ -0,0 +1,203 @@ +//! Mock provider for testing LLM interactions without real API calls. +//! +//! This module provides a configurable mock implementation of the [`Provider`] +//! trait, useful for: +//! +//! - Integration tests that need to simulate LLM responses +//! - Testing interrupt/signal handling during streaming +//! - Verifying persistence logic without network calls +//! +//! # Example +//! +//! ```ignore +//! use jp_llm::provider::mock::MockProvider; +//! use jp_llm::event::{Event, FinishReason}; +//! use jp_conversation::event::{ConversationEvent, ChatResponse}; +//! +//! // Create a provider that returns a simple message +//! let provider = MockProvider::with_message("Hello, world!"); +//! +//! // Or create one with custom events for more complex scenarios +//! let provider = MockProvider::new(vec![ +//! Event::Part { +//! index: 0, +//! event: ConversationEvent::now(ChatResponse::message("Hello")), +//! }, +//! Event::flush(0), +//! Event::Finished(FinishReason::Completed), +//! ]); +//! ``` + +use async_trait::async_trait; +use futures::stream; +use jp_config::model::id::{ModelIdConfig, Name, ProviderId}; +use jp_conversation::event::{ChatResponse, ToolCallRequest}; +use serde_json::{Map, Value}; + +use super::Provider; +use crate::{ + error::Result, + event::{Event, FinishReason}, + model::ModelDetails, + query::ChatQuery, + stream::EventStream, +}; + +/// A mock LLM provider for testing. +/// +/// Returns predetermined events from [`chat_completion_stream`], allowing tests +/// to simulate various LLM behaviors without making real API calls. +/// +/// [`chat_completion_stream`]: Provider::chat_completion_stream +#[derive(Debug, Clone)] +pub struct MockProvider { + /// Events to return from the stream. + events: Vec, + + /// Model details to return. 
+ model: ModelDetails, +} + +impl MockProvider { + /// Create a new mock provider with the given events. + /// + /// The events will be returned in order from [`chat_completion_stream`]. + /// + /// [`chat_completion_stream`]: Provider::chat_completion_stream + #[must_use] + pub fn new(events: Vec) -> Self { + Self { + events, + model: Self::default_model(), + } + } + + /// Create a mock provider that streams a simple message response. + /// + /// Useful for basic tests that just need some content to be streamed. + #[must_use] + pub fn with_message(content: &str) -> Self { + Self::new(vec![ + Event::Part { + index: 0, + event: ChatResponse::message(content).into(), + }, + Event::flush(0), + Event::Finished(FinishReason::Completed), + ]) + } + + /// Create a mock provider that streams reasoning followed by a message. + #[must_use] + pub fn with_reasoning_and_message(reasoning: &str, message: &str) -> Self { + Self::new(vec![ + Event::Part { + index: 0, + event: ChatResponse::reasoning(reasoning).into(), + }, + Event::flush(0), + Event::Part { + index: 1, + event: ChatResponse::message(message).into(), + }, + Event::flush(1), + Event::Finished(FinishReason::Completed), + ]) + } + + /// Create a mock provider that streams content in multiple chunks. + /// + /// Useful for testing streaming behavior and partial content handling. + #[must_use] + pub fn with_chunked_message(chunks: &[&str]) -> Self { + let mut events = Vec::with_capacity(chunks.len() + 2); + + for &chunk in chunks { + events.push(Event::Part { + index: 0, + event: ChatResponse::message(chunk).into(), + }); + } + + events.push(Event::flush(0)); + events.push(Event::Finished(FinishReason::Completed)); + + Self::new(events) + } + + /// Create a mock provider that requests a tool call. 
+ #[must_use] + pub fn with_tool_call( + tool_id: impl Into, + tool_name: impl Into, + arguments: Map, + ) -> Self { + Self::new(vec![ + Event::Part { + index: 0, + event: ToolCallRequest { + id: tool_id.into(), + name: tool_name.into(), + arguments, + } + .into(), + }, + Event::flush(0), + Event::Finished(FinishReason::Completed), + ]) + } + + /// Set custom model details for this provider. + #[must_use] + pub fn with_model(mut self, model: ModelDetails) -> Self { + self.model = model; + self + } + + /// Set the model name. + #[must_use] + pub fn with_model_name(mut self, name: impl Into) -> Self { + self.model.id = Self::make_model_id(name); + self + } + + fn default_model() -> ModelDetails { + ModelDetails::empty(Self::make_model_id("mock-model")) + } + + fn make_model_id(name: impl Into) -> ModelIdConfig { + ModelIdConfig { + provider: ProviderId::Test, + name: name.into().parse().expect("valid model name"), + } + } +} + +#[async_trait] +impl Provider for MockProvider { + async fn model_details(&self, name: &Name) -> Result { + let mut model = self.model.clone(); + model.id = ModelIdConfig { + provider: ProviderId::Test, + name: name.clone(), + }; + Ok(model) + } + + async fn models(&self) -> Result> { + Ok(vec![self.model.clone()]) + } + + async fn chat_completion_stream( + &self, + _model: &ModelDetails, + _query: ChatQuery, + ) -> Result { + let events = self.events.clone(); + Ok(Box::pin(stream::iter(events.into_iter().map(Ok)))) + } +} + +#[cfg(test)] +#[path = "mock_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/provider/mock_tests.rs b/crates/jp_llm/src/provider/mock_tests.rs new file mode 100644 index 00000000..176b70b9 --- /dev/null +++ b/crates/jp_llm/src/provider/mock_tests.rs @@ -0,0 +1,124 @@ +use futures::StreamExt; +use jp_conversation::{ConversationStream, thread::Thread}; + +use super::*; + +fn empty_query() -> ChatQuery { + ChatQuery { + thread: Thread { + system_prompt: None, + sections: vec![], + attachments: vec![], + events: 
ConversationStream::new_test(), + }, + tools: vec![], + tool_choice: jp_config::assistant::tool_choice::ToolChoice::Auto, + } +} + +fn test_name(s: &str) -> Name { + s.parse().expect("valid test name") +} + +#[tokio::test] +async fn test_with_message() { + let provider = MockProvider::with_message("Hello, world!"); + let model = provider.model_details(&test_name("test")).await.unwrap(); + + let mut stream = provider + .chat_completion_stream(&model, empty_query()) + .await + .unwrap(); + + // First event: Part with message + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Part { index: 0, .. })); + + // Second event: Flush + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Flush { index: 0, .. })); + + // Third event: Finished + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Finished(FinishReason::Completed))); + + // Stream should be exhausted + assert!(stream.next().await.is_none()); +} + +#[tokio::test] +async fn test_with_chunked_message() { + let provider = MockProvider::with_chunked_message(&["Hello, ", "world", "!"]); + let model = provider.model_details(&test_name("test")).await.unwrap(); + + let mut stream = provider + .chat_completion_stream(&model, empty_query()) + .await + .unwrap(); + + // Three Part events + for _ in 0..3 { + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Part { index: 0, .. })); + } + + // Flush + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Flush { index: 0, .. 
})); + + // Finished + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Finished(FinishReason::Completed))); +} + +#[tokio::test] +async fn test_with_reasoning_and_message() { + let provider = MockProvider::with_reasoning_and_message("thinking...", "done"); + let model = provider.model_details(&test_name("test")).await.unwrap(); + + let mut stream = provider + .chat_completion_stream(&model, empty_query()) + .await + .unwrap(); + + // Reasoning part at index 0 + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Part { index: 0, .. })); + + // Flush index 0 + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Flush { index: 0, .. })); + + // Message part at index 1 + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Part { index: 1, .. })); + + // Flush index 1 + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Flush { index: 1, .. 
})); + + // Finished + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Finished(FinishReason::Completed))); +} + +#[tokio::test] +async fn test_model_details() { + let provider = MockProvider::with_message("test"); + let model = provider + .model_details(&test_name("custom-name")) + .await + .unwrap(); + + assert_eq!(model.id.name.as_ref(), "custom-name"); + assert_eq!(model.id.provider, ProviderId::Test); +} + +#[tokio::test] +async fn test_models_list() { + let provider = MockProvider::with_message("test").with_model_name("my-model"); + let models = provider.models().await.unwrap(); + + assert_eq!(models.len(), 1); + assert_eq!(models[0].id.name.as_ref(), "my-model"); +} diff --git a/crates/jp_llm/src/provider/ollama.rs b/crates/jp_llm/src/provider/ollama.rs index dfc7583c..d458e70b 100644 --- a/crates/jp_llm/src/provider/ollama.rs +++ b/crates/jp_llm/src/provider/ollama.rs @@ -1,7 +1,7 @@ -use std::{mem, str::FromStr as _}; +use std::str::FromStr as _; use async_trait::async_trait; -use futures::{FutureExt as _, StreamExt as _, future, stream}; +use futures::{FutureExt as _, StreamExt as _, TryStreamExt as _, future, stream}; use jp_config::{ assistant::tool_choice::ToolChoice, model::{ @@ -16,9 +16,10 @@ use jp_conversation::{ }; use ollama_rs::{ Ollama as Client, + error::OllamaError, generation::{ chat::{ChatMessage, ChatMessageResponse, MessageRole, request::ChatMessageRequest}, - parameters::{KeepAlive, TimeUnit}, + parameters::{FormatType, JsonStructure, KeepAlive, TimeUnit}, tools::{ToolCall, ToolCallFunction, ToolFunctionInfo, ToolInfo, ToolType}, }, models::{LocalModel, ModelOptions}, @@ -29,10 +30,9 @@ use url::Url; use super::{EventStream, ModelDetails, Provider}; use crate::{ - error::{Error, Result}, + error::{Error, Result, StreamError}, event::{Event, FinishReason}, query::ChatQuery, - stream::aggregator::reasoning::ReasoningExtractor, tool::ToolDefinition, }; @@ -72,8 +72,7 @@ impl Provider for Ollama { 
"Starting Ollama chat completion stream." ); - let mut extractor = ReasoningExtractor::default(); - let request = create_request(model, query)?; + let (request, is_structured) = create_request(model, query)?; trace!( request = serde_json::to_string(&request).unwrap_or_default(), @@ -84,11 +83,19 @@ impl Provider for Ollama { .client .send_chat_messages_stream(request) .await? - .filter_map(|v| future::ready(v.ok())) - .map(move |v| stream::iter(map_event(v, &mut extractor))) - .flatten() - .chain(future::ready(Event::Finished(FinishReason::Completed)).into_stream()) - .map(Ok) + .map(|v| v.map_err(|()| StreamError::other("Ollama stream error"))) + .map_ok({ + let mut reasoning_flushed = false; + move |v| { + stream::iter( + map_event(v, is_structured, &mut reasoning_flushed) + .into_iter() + .map(Ok), + ) + } + }) + .try_flatten() + .chain(future::ready(Ok(Event::Finished(FinishReason::Completed))).into_stream()) .boxed()) } } @@ -106,74 +113,111 @@ fn map_model(model: LocalModel) -> Result { }) } -fn map_event(event: ChatMessageResponse, extractor: &mut ReasoningExtractor) -> Vec { +/// Map an Ollama streaming chunk into provider-agnostic events. +/// +/// Index convention: 0 = reasoning, 1 = message content, 2+ = tool calls. +/// +/// Ollama guarantees that thinking tokens arrive before content/tool call +/// tokens. We exploit this by flushing the reasoning stream (index 0) as soon +/// as the first content or tool call chunk appears, ensuring the reasoning +/// event precedes content and tool call events in the history. +fn map_event( + event: ChatMessageResponse, + is_structured: bool, + reasoning_flushed: &mut bool, +) -> Vec { let ChatMessageResponse { message, done, .. } = event; - let mut events = fetch_content(extractor, done); + trace!( + content = message.content, + thinking = message.thinking, + tool_calls = message.tool_calls.len(), + done, + "Ollama stream chunk." 
+ ); - for ( - index, - ToolCall { - function: ToolCallFunction { name, arguments }, - }, - ) in message.tool_calls.into_iter().enumerate() + let mut events = Vec::new(); + + if let Some(thinking) = message.thinking + && !thinking.is_empty() { - events.extend(vec![ - Event::Part { - // These events don't have any index assigned, but we use `0` - // and `1` for regular chat messages and reasoning, and `2` and - // up for tool calls. - index: index + 2, - event: ConversationEvent::now(ToolCallRequest { - id: String::new(), - name, - arguments: match arguments { - Value::Object(map) => map, - v => Map::from_iter([("input".into(), v)]), - }, - }), - }, - Event::flush(0), - ]); + events.push(Event::Part { + index: 0, + event: ConversationEvent::now(ChatResponse::reasoning(thinking)), + }); } - events -} + let has_content = !message.content.is_empty(); + let has_tool_calls = !message.tool_calls.is_empty(); -fn fetch_content(extractor: &mut ReasoningExtractor, done: bool) -> Vec { - let mut events = Vec::new(); + // Flush reasoning before emitting content or tool calls so the reasoning + // event always precedes them in the conversation history. 
+ if !*reasoning_flushed && (has_content || has_tool_calls) { + events.push(Event::flush(0)); + *reasoning_flushed = true; + } - if !extractor.reasoning.is_empty() { - let reasoning = mem::take(&mut extractor.reasoning); + if has_content { + let response = if is_structured { + ChatResponse::structured(Value::String(message.content)) + } else { + ChatResponse::message(message.content) + }; events.push(Event::Part { - index: 0, - event: ConversationEvent::now(ChatResponse::reasoning(reasoning)), + index: 1, + event: ConversationEvent::now(response), }); } - if !extractor.other.is_empty() { - let content = mem::take(&mut extractor.other); + for ( + index, + ToolCall { + function: ToolCallFunction { name, arguments }, + }, + ) in message.tool_calls.into_iter().enumerate() + { + let index = index + 2; events.push(Event::Part { - index: 1, - event: ConversationEvent::now(ChatResponse::message(content)), + index, + event: ConversationEvent::now(ToolCallRequest { + id: String::new(), + name, + arguments: match arguments { + Value::Object(map) => map, + v => Map::from_iter([("input".into(), v)]), + }, + }), }); + events.push(Event::flush(index)); } if done { - events.extend(vec![Event::flush(0), Event::flush(1)]); + if !*reasoning_flushed { + events.push(Event::flush(0)); + } + events.push(Event::flush(1)); } events } -fn create_request(model: &ModelDetails, query: ChatQuery) -> Result { +fn create_request(model: &ModelDetails, query: ChatQuery) -> Result<(ChatMessageRequest, bool)> { let ChatQuery { thread, tools, tool_choice, - tool_call_strict_mode, } = query; + let structured_schema = thread + .events + .last() + .and_then(|e| e.event.as_chat_request()) + .and_then(|req| req.schema.clone()) + .map(|schema| JsonStructure::new_for_schema(schema.into())) + .map(|schema| FormatType::StructuredJson(Box::new(schema))); + + let is_structured = structured_schema.is_some(); + let config = thread.events.config()?; let parameters = &config.assistant.model.parameters; @@ -185,7 
+229,7 @@ fn create_request(model: &ModelDetails, query: ChatQuery) -> Result Result for Ollama { @@ -265,7 +313,7 @@ impl TryFrom<&OllamaConfig> for Ollama { } } -fn convert_tools(tools: Vec, _strict: bool) -> Result> { +fn convert_tools(tools: Vec) -> Result> { tools .into_iter() .map(|tool| { @@ -334,6 +382,7 @@ fn convert_events(events: ConversationStream) -> Vec { images: None, thinking: Some(reasoning), }), + ChatResponse::Structured { data } => Some(ChatMessage::assistant(data.to_string())), }, EventKind::ToolCallRequest(request) => Some(ChatMessage { role: MessageRole::Assistant, @@ -371,204 +420,20 @@ fn convert_events(events: ConversationStream) -> Vec { }) } -// #[cfg(test)] -// mod tests { -// use jp_config::providers::llm::LlmProviderConfig; -// use jp_conversation::event::ChatResponse; -// use jp_test::{Result, fn_name, mock::Vcr}; -// use test_log::test; -// -// use super::*; -// use crate::structured; -// -// fn vcr(url: &str) -> Vcr { -// Vcr::new(url, env!("CARGO_MANIFEST_DIR")) -// } -// -// #[test(tokio::test)] -// async fn test_ollama_models() -> Result { -// let mut config = LlmProviderConfig::default().ollama; -// let vcr = vcr(&config.base_url); -// vcr.cassette( -// fn_name!(), -// |rule| { -// rule.filter(|when| { -// when.any_request(); -// }); -// }, -// |_, url| async move { -// config.base_url = url; -// -// Ollama::try_from(&config) -// .unwrap() -// .models() -// .await -// .map(|mut v| { -// v.truncate(2); -// v -// }) -// }, -// ) -// .await -// } -// -// #[test(tokio::test)] -// async fn test_ollama_chat_completion() -> Result { -// let mut config = LlmProviderConfig::default().ollama; -// let model_id = "ollama/llama3:latest".parse().unwrap(); -// let model = ModelDetails::empty(model_id); -// let query = ChatQuery { -// thread: Thread { -// events: ConversationStream::default().with_chat_request("Test message"), -// ..Default::default() -// }, -// ..Default::default() -// }; -// -// let vcr = vcr(&config.base_url); -// 
vcr.cassette( -// fn_name!(), -// |rule| { -// rule.filter(|when| { -// when.any_request(); -// }); -// }, -// |_, url| async move { -// config.base_url = url; -// -// Ollama::try_from(&config) -// .unwrap() -// .chat_completion(&model, query) -// .await -// }, -// ) -// .await -// } -// -// #[test(tokio::test)] -// async fn test_ollama_chat_completion_stream() -> Result { -// let mut config = LlmProviderConfig::default().ollama; -// let model_id = "ollama/llama3:latest".parse().unwrap(); -// let model = ModelDetails::empty(model_id); -// let query = ChatQuery { -// thread: Thread { -// events: ConversationStream::default().with_chat_request("Test message"), -// ..Default::default() -// }, -// ..Default::default() -// }; -// -// let vcr = vcr(&config.base_url); -// vcr.cassette( -// fn_name!(), -// |rule| { -// rule.filter(|when| { -// when.any_request(); -// }); -// }, -// |_, url| async move { -// config.base_url = url; -// -// Ollama::try_from(&config) -// .unwrap() -// .chat_completion_stream(&model, query) -// .await -// .unwrap() -// .collect::>() -// .await -// }, -// ) -// .await -// } -// -// #[test(tokio::test)] -// async fn test_ollama_structured_completion() -> Result { -// let mut config = LlmProviderConfig::default().ollama; -// let model_id = "ollama/llama3.1:8b".parse().unwrap(); -// let model = ModelDetails::empty(model_id); -// let history = ConversationStream::default() -// .with_chat_request("Test message") -// .with_chat_response(ChatResponse::reasoning("Test response")); -// -// let vcr = vcr(&config.base_url); -// vcr.cassette( -// fn_name!(), -// |rule| { -// rule.filter(|when| { -// when.any_request(); -// }); -// }, -// |_, url| async move { -// config.base_url = url; -// let query = structured::titles::titles(3, history, &[]).unwrap(); -// -// Ollama::try_from(&config) -// .unwrap() -// .structured_completion(&model, query) -// .await -// }, -// ) -// .await -// } -// -// mod chunk_parser { -// use test_log::test; -// -// use super::*; 
-// -// #[test] -// fn test_no_think_tag_at_all() { -// let mut parser = ReasoningExtractor::default(); -// parser.handle("some other text"); -// parser.finalize(); -// assert_eq!(parser.other, "some other text"); -// assert_eq!(parser.reasoning, ""); -// } -// -// #[test] -// fn test_standard_case_with_newline() { -// let mut parser = ReasoningExtractor::default(); -// parser.handle("prefix\n\nthoughts\n\nsuffix"); -// parser.finalize(); -// assert_eq!(parser.reasoning, "thoughts\n"); -// assert_eq!(parser.other, "prefix\nsuffix"); -// } -// -// #[test] -// fn test_suffix_only() { -// let mut parser = ReasoningExtractor::default(); -// parser.handle("\nthoughts\n\n\nsuffix text here"); -// parser.finalize(); -// assert_eq!(parser.reasoning, "thoughts\n"); -// assert_eq!(parser.other, "\nsuffix text here"); -// } -// -// #[test] -// fn test_ends_with_closing_tag_no_newline() { -// let mut parser = ReasoningExtractor::default(); -// parser.handle("\nfinal thoughts\n"); -// parser.handle(""); -// parser.finalize(); -// assert_eq!(parser.reasoning, "final thoughts\n"); -// assert_eq!(parser.other, ""); -// } -// -// #[test] -// fn test_less_than_symbol_in_reasoning_content_is_not_stripped() { -// let mut parser = ReasoningExtractor::default(); -// parser.handle("\na < b is a true statement\n"); -// parser.finalize(); -// // The last '<' is part of "", so "a < b is a true statement" is kept. -// assert_eq!(parser.reasoning, "a < b is a true statement\n"); -// } -// -// #[test] -// fn test_less_than_symbol_not_part_of_tag_is_kept() { -// let mut parser = ReasoningExtractor::default(); -// parser.handle("\nhere is a random < symbol"); -// parser.finalize(); -// // The final '<' is not a prefix of '', so it's kept. 
-// assert_eq!(parser.reasoning, "here is a random < symbol"); -// } -// } -// } +impl From for StreamError { + fn from(err: OllamaError) -> Self { + use ollama_rs::error::InternalOllamaError; + + match err { + OllamaError::ReqwestError(error) => Self::from(error), + OllamaError::ToolCallError(error) => { + StreamError::other("tool-call error").with_source(error) + } + OllamaError::JsonError(error) => StreamError::other("json error").with_source(error), + OllamaError::InternalError(InternalOllamaError { message }) => { + StreamError::other(message) + } + OllamaError::Other(message) => StreamError::other(message), + } + } +} diff --git a/crates/jp_llm/src/provider/openai.rs b/crates/jp_llm/src/provider/openai.rs index 5b242af8..3ac2eb06 100644 --- a/crates/jp_llm/src/provider/openai.rs +++ b/crates/jp_llm/src/provider/openai.rs @@ -6,7 +6,7 @@ use futures::{FutureExt as _, StreamExt as _, TryStreamExt as _, future, stream} use indexmap::{IndexMap, IndexSet}; use jp_config::{ assistant::tool_choice::ToolChoice, - conversation::tool::{OneOrManyTypes, ToolParameterConfig, item::ToolParameterItemConfig}, + conversation::tool::{OneOrManyTypes, ToolParameterConfig}, model::{ id::{Name, ProviderId}, parameters::{CustomReasoningConfig, ReasoningEffort}, @@ -18,7 +18,7 @@ use jp_conversation::{ event::{ChatResponse, ConversationEvent, EventKind, ToolCallRequest, ToolCallResponse}, }; use openai_responses::{ - Client, CreateError, StreamError, + Client, CreateError, StreamError as OpenaiStreamError, types::{self, Include, Request, SummaryConfig}, }; use reqwest::header::{self, HeaderMap, HeaderValue}; @@ -28,7 +28,7 @@ use tracing::{trace, warn}; use super::{EventStream, ModelDetails, Provider}; use crate::{ - error::{Error, Result}, + error::{Error, Result, StreamError, StreamErrorKind}, event::{Event, FinishReason}, model::{ModelDeprecation, ReasoningDetails}, query::ChatQuery, @@ -80,15 +80,15 @@ impl Provider for Openai { model: &ModelDetails, query: ChatQuery, ) -> 
Result { - let request = create_request(model, query)?; + let (request, is_structured, reasoning_enabled) = create_request(model, query)?; Ok(self .client .stream(request) .or_else(map_error) - .map_ok(|v| stream::iter(map_event(v))) + .map_ok(move |v| stream::iter(map_event(v, is_structured, reasoning_enabled))) .try_flatten() - .chain(future::ok(Event::Finished(FinishReason::Completed)).into_stream()) + .chain(future::ready(Ok(Event::Finished(FinishReason::Completed))).into_stream()) .boxed()) } } @@ -112,41 +112,84 @@ pub(crate) struct ModelResponse { } /// Create a request for the given model and query details. -fn create_request(model: &ModelDetails, query: ChatQuery) -> Result { +/// +/// Returns `(request, is_structured, reasoning_enabled)`. +fn create_request(model: &ModelDetails, query: ChatQuery) -> Result<(Request, bool, bool)> { let ChatQuery { thread, tools, tool_choice, - tool_call_strict_mode, } = query; + // Only use the schema if the very last event is a ChatRequest with one. + // Transform the schema for OpenAI's strict structured output mode + // before passing it to the request. 
+ let text = thread + .events + .last() + .and_then(|e| e.event.as_chat_request()) + .and_then(|req| req.schema.clone()) + .map(|schema| types::TextConfig { + format: types::TextFormat::JsonSchema { + schema: Value::Object(transform_schema(schema)), + description: "Structured output".to_owned(), + name: "structured_output".to_owned(), + strict: Some(true), + }, + }); + + let is_structured = text.is_some(); let parameters = thread.events.config()?.assistant.model.parameters; - let reasoning = model - .custom_reasoning_config(parameters.reasoning) - .map(|r| convert_reasoning(r, model.max_output_tokens)); let supports_reasoning = model .reasoning .is_some_and(|v| !matches!(v, ReasoningDetails::Unsupported)); + let reasoning = match model.custom_reasoning_config(parameters.reasoning) { + Some(r) => Some(convert_reasoning(r, model.max_output_tokens)), + // Explicitly disable reasoning for models that support it when the + // user has turned it off. Sending `null` lets the model use its + // default (which may include reasoning). + // + // For leveled models, use their lowest supported effort. For all + // others (budgetted), fall back to `minimal` which is universally + // supported across OpenAI reasoning models. 
+ None if supports_reasoning => { + let effort = model + .reasoning + .and_then(|r| r.lowest_effort()) + .unwrap_or(ReasoningEffort::Xlow); + Some(convert_reasoning( + CustomReasoningConfig { + effort, + exclude: true, + }, + model.max_output_tokens, + )) + } + None => None, + }; + let reasoning_enabled = model + .custom_reasoning_config(parameters.reasoning) + .is_some(); let messages = thread.into_messages(to_system_messages, convert_events(supports_reasoning))?; - let request = Request { model: types::Model::Other(model.id.name.to_string()), input: types::Input::List(messages), - include: supports_reasoning.then_some(vec![Include::ReasoningEncryptedContent]), + include: reasoning_enabled.then_some(vec![Include::ReasoningEncryptedContent]), store: Some(false), tool_choice: Some(convert_tool_choice(tool_choice)), - tools: Some(convert_tools(tools, tool_call_strict_mode)), + tools: Some(convert_tools(tools)), temperature: parameters.temperature, reasoning, max_output_tokens: parameters.max_tokens.map(Into::into), truncation: Some(types::Truncation::Auto), top_p: parameters.top_p, + text, ..Default::default() }; trace!(?request, "Sending request to OpenAI."); - Ok(request) + Ok((request, is_structured, reasoning_enabled)) } #[expect(clippy::too_many_lines)] @@ -447,27 +490,26 @@ fn map_model(model: ModelResponse) -> Result { Ok(details) } -/// Convert a [`StreamError`] into an [`Error`]. +/// Convert an OpenAI [`OpenaiStreamError`] into a [`StreamError`]. /// -/// This needs an async function because we want to get the response text from -/// the body as contextual information. -async fn map_error(error: StreamError) -> Result { +/// This needs an async function because we read the response body for context. +/// Headers are extracted *before* consuming the body. 
+async fn map_error(error: OpenaiStreamError) -> std::result::Result { Err(match error { - StreamError::Parsing(error) => error.into(), - StreamError::Stream(error) => match error { - reqwest_eventsource::Error::InvalidStatusCode(status_code, response) => { - Error::OpenaiStatusCode { - status_code, - response: response.text().await.unwrap_or_default(), - } - } - _ => Error::OpenaiEvent(Box::new(error)), - }, + OpenaiStreamError::Stream(error) => StreamError::from(error), + OpenaiStreamError::Parsing(error) => { + StreamError::other(error.to_string()).with_source(error) + } }) } /// Map an Openai [`types::Event`] into one or more [`Event`]s. -fn map_event(event: types::Event) -> Vec> { +#[expect(clippy::too_many_lines)] +fn map_event( + event: types::Event, + is_structured: bool, + reasoning_enabled: bool, +) -> Vec> { use types::Event::*; #[expect(clippy::cast_possible_truncation)] @@ -481,11 +523,26 @@ fn map_event(event: types::Event) -> Vec> { output_index, item: types::OutputItem::Message(_), } => vec![Ok(Event::Part { - event: ConversationEvent::now(ChatResponse::message(String::new())), + event: ConversationEvent::now(if is_structured { + ChatResponse::structured(Value::String(String::new())) + } else { + ChatResponse::message(String::new()) + }), index: output_index as usize, })], - // See the previous `OutputItemAdded` case for details. + // Skip all reasoning events when reasoning is disabled. The model + // may still return minimal reasoning output at `effort: "minimal"`. + OutputItemAdded { + item: types::OutputItem::Reasoning(_), + .. + } + | ReasoningSummaryTextDelta { .. } + | OutputItemDone { + item: types::OutputItem::Reasoning(_), + .. + } if !reasoning_enabled => vec![], + OutputItemAdded { output_index, item: types::OutputItem::Reasoning(_), @@ -498,10 +555,17 @@ fn map_event(event: types::Event) -> Vec> { delta, output_index, .. 
- } => vec![Ok(Event::Part { - event: ConversationEvent::now(ChatResponse::message(delta)), - index: output_index as usize, - })], + } => { + let response = if is_structured { + ChatResponse::structured(Value::String(delta)) + } else { + ChatResponse::message(delta) + }; + vec![Ok(Event::Part { + event: ConversationEvent::now(response), + index: output_index as usize, + })] + } ReasoningSummaryTextDelta { delta, @@ -513,6 +577,8 @@ fn map_event(event: types::Event) -> Vec> { })], OutputItemDone { item, output_index } => { + let index = output_index as usize; + let mut events = vec![]; let metadata = match &item { types::OutputItem::FunctionCall(_) => IndexMap::new(), types::OutputItem::Message(v) => { @@ -536,14 +602,15 @@ fn map_event(event: types::Event) -> Vec> { | types::OutputItem::ComputerToolCall(_) => return vec![], }; - match item { - types::OutputItem::FunctionCall(types::FunctionCall { - name, - arguments, - call_id, - .. - }) => vec![Ok(Event::Part { - index: output_index as usize, + if let types::OutputItem::FunctionCall(types::FunctionCall { + name, + arguments, + call_id, + .. + }) = item + { + events.push(Ok(Event::Part { + index, event: ConversationEvent::now(ToolCallRequest { id: call_id, name, @@ -555,28 +622,42 @@ fn map_event(event: types::Event) -> Vec> { map }, }), - })], - _ => vec![Ok(Event::flush_with_metadata( - output_index as usize, - metadata, - ))], + })); } + + events.push(Ok(Event::flush_with_metadata(index, metadata))); + events } - Error { - code, - message, - param, - } => vec![Err(types::Error { - r#type: "stream_error".to_owned(), - code, - message, - param, - } - .into())], + Error { error } => vec![Err(classify_stream_error(error))], _ => vec![], } } +/// Classify an OpenAI streaming error event into a [`StreamError`]. +/// +/// Maps well-known error types (quota, rate-limit, auth, server errors) +/// to the appropriate [`StreamErrorKind`] so the retry and display layers +/// can handle them correctly. 
+fn classify_stream_error(error: types::response::Error) -> StreamError { + match error.r#type.as_str() { + "insufficient_quota" => StreamError::new( + StreamErrorKind::InsufficientQuota, + format!( + "Insufficient API quota. Check your plan and billing details \ + at https://platform.openai.com/settings/organization/billing. \ + ({})", + error.message + ), + ), + "rate_limit_exceeded" => StreamError::rate_limit(None), + "server_error" | "api_error" => StreamError::transient(error.message), + _ => StreamError::other(format!( + "OpenAI error: type={}, code={:?}, message={}, param={:?}", + error.r#type, error.code, error.message, error.param + )), + } +} + impl TryFrom<&OpenaiConfig> for Openai { type Error = Error; @@ -605,6 +686,145 @@ impl TryFrom<&OpenaiConfig> for Openai { } } +/// Transform a JSON schema for OpenAI's strict structured output mode. +/// +/// OpenAI's structured outputs require: +/// - `additionalProperties: false` on all objects +/// - All properties listed in `required` +/// - `allOf` is not supported and must be flattened +/// +/// Additionally handles: +/// - Unraveling `$ref` that has sibling properties (OpenAI doesn't support +/// `$ref` alongside other keys) +/// - Recursively processing `$defs`/`definitions`, `properties`, `items`, and +/// `anyOf` +/// - Stripping `null` defaults +/// +/// Unlike Google, OpenAI supports `$ref`/`$defs` and `const` natively, so those +/// are left in place when standalone. +/// +/// See: +fn transform_schema(src: Map) -> Map { + let root = Value::Object(src.clone()); + process_schema(src, &root) +} + +/// Core recursive processor for a single schema node. +fn process_schema(mut src: Map, root: &Value) -> Map { + // Recursively process $defs/definitions in place. 
+ for key in ["$defs", "definitions"] { + if let Some(Value::Object(defs)) = src.remove(key) { + let processed: Map = defs + .into_iter() + .map(|(k, v)| (k, resolve_and_process(v, root))) + .collect(); + src.insert(key.into(), Value::Object(processed)); + } + } + + // Force `additionalProperties: false` on all objects. + // The docs require this for strict mode. + if src.get("type").and_then(Value::as_str) == Some("object") { + src.insert("additionalProperties".into(), Value::Bool(false)); + } + + // Force all properties into `required` (strict mode requirement). + if let Some(Value::Object(props)) = src.get("properties") { + let keys: Vec = props.keys().map(|k| Value::String(k.clone())).collect(); + src.insert("required".into(), Value::Array(keys)); + } + + // Recursively process object properties. + if let Some(Value::Object(props)) = src.remove("properties") { + let processed: Map = props + .into_iter() + .map(|(k, v)| (k, resolve_and_process(v, root))) + .collect(); + src.insert("properties".into(), Value::Object(processed)); + } + + // Recursively process array items. + if let Some(items) = src.remove("items") { + src.insert("items".into(), resolve_and_process(items, root)); + } + + // Recursively process anyOf variants. + if let Some(Value::Array(variants)) = src.remove("anyOf") { + src.insert( + "anyOf".into(), + Value::Array( + variants + .into_iter() + .map(|v| resolve_and_process(v, root)) + .collect(), + ), + ); + } + + // Flatten `allOf` — not supported by OpenAI. + // Merge all entries into the parent schema; later entries yield to + // earlier ones (and to keys already present on the parent). + if let Some(Value::Array(entries)) = src.remove("allOf") { + for entry in entries { + if let Value::Object(entry_map) = resolve_and_process(entry, root) { + for (k, v) in entry_map { + src.entry(k).or_insert(v); + } + } + } + } + + // Strip `null` defaults (no meaningful distinction for strict mode). 
+ if src.get("default") == Some(&Value::Null) { + src.remove("default"); + } + + // Unravel `$ref` when it has sibling properties. + // OpenAI supports standalone `$ref` but not alongside other keys. + if src.contains_key("$ref") + && src.len() > 1 + && let Some(Value::String(ref_path)) = src.remove("$ref") + { + if let Some(resolved) = resolve_ref(&ref_path, root) { + // Current schema properties take priority over the + // resolved definition's. + let mut merged = resolved; + for (k, v) in src { + merged.insert(k, v); + } + return process_schema(merged, root); + } + // Failed to resolve — put it back. + src.insert("$ref".into(), Value::String(ref_path)); + } + + src +} + +/// Recursively process a value that may be a schema object. +fn resolve_and_process(value: Value, root: &Value) -> Value { + match value { + Value::Object(map) => Value::Object(process_schema(map, root)), + other => other, + } +} + +/// Resolve a JSON pointer against the root schema. +/// +/// Handles paths like `#/$defs/MyType` and `#` (root self-reference). +fn resolve_ref(ref_path: &str, root: &Value) -> Option> { + if ref_path == "#" { + return root.as_object().cloned(); + } + + let path = ref_path.strip_prefix("#/")?; + let mut current = root; + for segment in path.split('/') { + current = current.get(segment)?; + } + current.as_object().cloned() +} + fn convert_tool_choice(choice: ToolChoice) -> types::ToolChoice { match choice { ToolChoice::Auto => types::ToolChoice::Auto, @@ -706,9 +926,7 @@ fn make_config_nullable(cfg: &mut ToolParameterConfig) { /// moving array-based enums into the 'items' configuration. 
fn sanitize_parameter(config: &mut ToolParameterConfig) { if let Some(items) = &mut config.items { - let mut item_config = items.clone().into(); - sanitize_parameter(&mut item_config); - *items = item_config.into(); + sanitize_parameter(items); } let allows_array = match &config.kind { @@ -758,26 +976,31 @@ fn sanitize_parameter(config: &mut ToolParameterConfig) { OneOrManyTypes::Many(inferred_types.into_iter().collect()) }; - ToolParameterItemConfig { + Box::new(ToolParameterConfig { kind, default: None, + required: false, + summary: None, description: None, + examples: None, enumeration: vec![], - } + items: None, + properties: IndexMap::default(), + }) }); // Append the flattened values to the items enum items_config.enumeration.extend(items); } -fn convert_tools(tools: Vec, strict: bool) -> Vec { +fn convert_tools(tools: Vec) -> Vec { tools .into_iter() .map(|tool| types::Tool::Function { name: tool.name, - strict, + strict: true, description: tool.description, - parameters: parameters_with_strict_mode(tool.parameters, strict).into(), + parameters: parameters_with_strict_mode(tool.parameters, true).into(), }) .collect() } @@ -792,11 +1015,17 @@ fn convert_reasoning( } else { Some(SummaryConfig::Auto) }, - effort: match reasoning.effort.abs_to_rel(max_tokens) { - ReasoningEffort::XHigh => Some(types::ReasoningEffort::XHigh), + effort: match reasoning + .effort + .abs_to_rel(max_tokens) + .unwrap_or(ReasoningEffort::Auto) + { + ReasoningEffort::None => Some(types::ReasoningEffort::None), + ReasoningEffort::Max | ReasoningEffort::XHigh => Some(types::ReasoningEffort::XHigh), ReasoningEffort::High => Some(types::ReasoningEffort::High), ReasoningEffort::Auto | ReasoningEffort::Medium => Some(types::ReasoningEffort::Medium), - ReasoningEffort::Low | ReasoningEffort::Xlow => Some(types::ReasoningEffort::Low), + ReasoningEffort::Low => Some(types::ReasoningEffort::Low), + ReasoningEffort::Xlow => Some(types::ReasoningEffort::Minimal), ReasoningEffort::Absolute(_) => { 
debug_assert!(false, "Reasoning effort must be relative."); None @@ -898,6 +1127,12 @@ fn convert_events( })] } } + ChatResponse::Structured { data } => { + vec![types::InputListItem::Message(types::InputMessage { + role: types::Role::Assistant, + content: types::ContentInput::Text(data.to_string()), + })] + } } } EventKind::ToolCallRequest(request) => vec![types::InputListItem::Item( @@ -928,232 +1163,12 @@ fn convert_events( } } -// /// Converts a single event into `OpenAI` input items. -// /// -// /// Note: `OpenAI` requires separate items for different content types. -// fn convert_events( -// events: ConversationStream, -// supports_reasoning: bool, -// ) -> Vec { -// events -// .into_iter() -// .flat_map(|event| { -// let ConversationEvent { -// kind, mut metadata, .. -// } = event.event; -// -// match kind { -// EventKind::ChatRequest(request) => { -// vec![types::InputListItem::Message(types::InputMessage { -// role: types::Role::User, -// content: types::ContentInput::Text(request.content), -// })] -// } -// EventKind::ChatResponse(response) => { -// if let Some(item) = metadata.remove(ENCODED_PAYLOAD_KEY).and_then(|s| { -// Some(if response.is_reasoning() { -// if !supports_reasoning { -// return None; -// } -// -// types::InputItem::Reasoning( -// serde_json::from_value::(s).ok()?, -// ) -// } else { -// types::InputItem::OutputMessage( -// serde_json::from_value::(s).ok()?, -// ) -// }) -// }) { -// vec![types::InputListItem::Item(item)] -// } else if response.is_reasoning() { -// // Unsupported reasoning content - wrap in XML tags -// vec![types::InputListItem::Message(types::InputMessage { -// role: types::Role::Assistant, -// content: types::ContentInput::Text(format!( -// "\n{}\n\n\n", -// response.content() -// )), -// })] -// } else { -// vec![types::InputListItem::Message(types::InputMessage { -// role: types::Role::Assistant, -// content: types::ContentInput::Text(response.into_content()), -// })] -// } -// } -// 
EventKind::ToolCallRequest(request) => { -// let call = metadata -// .remove(ENCODED_PAYLOAD_KEY) -// .and_then(|s| serde_json::from_value::(s).ok()) -// .unwrap_or_else(|| types::FunctionCall { -// call_id: String::new(), -// name: request.name, -// arguments: Value::Object(request.arguments).to_string(), -// status: None, -// id: (!request.id.is_empty()).then_some(request.id), -// }); -// -// vec![types::InputListItem::Item(types::InputItem::FunctionCall( -// call, -// ))] -// } -// EventKind::ToolCallResponse(ToolCallResponse { id, result }) => { -// vec![types::InputListItem::Item( -// types::InputItem::FunctionCallOutput(types::FunctionCallOutput { -// call_id: id, -// output: match result { -// Ok(content) | Err(content) => content, -// }, -// id: None, -// status: None, -// }), -// )] -// } -// _ => vec![], -// } -// }) -// .collect() -// } - -// impl From for Delta { -// fn from(item: types::OutputItem) -> Self { -// match item { -// types::OutputItem::Message(message) => Delta::content( -// message -// .content -// .into_iter() -// .filter_map(|item| match item { -// types::OutputContent::Text { text, .. } => Some(text), -// types::OutputContent::Refusal { .. } => None, -// }) -// .collect::>() -// .join("\n\n"), -// ), -// types::OutputItem::Reasoning(reasoning) => Delta::reasoning( -// reasoning -// .summary -// .into_iter() -// .map(|item| match item { -// types::ReasoningSummary::Text { text, .. 
} => text, -// }) -// .collect::>() -// .join("\n\n"), -// ), -// types::OutputItem::FunctionCall(call) => { -// Delta::tool_call(call.call_id, call.name, call.arguments).finished() -// } -// _ => Delta::default(), -// } -// } -// } - -// #[cfg(test)] -// mod tests { -// use jp_config::providers::llm::LlmProviderConfig; -// use jp_test::{Result, fn_name, mock::Vcr}; -// use test_log::test; -// -// use super::*; -// -// fn vcr() -> Vcr { -// Vcr::new("https://api.openai.com", env!("CARGO_MANIFEST_DIR")) -// } -// -// #[test(tokio::test)] -// async fn test_openai_model_details() -> Result { -// let mut config = LlmProviderConfig::default().openai; -// let name: Name = "o4-mini".parse().unwrap(); -// -// let vcr = vcr(); -// vcr.cassette( -// fn_name!(), -// |rule| { -// rule.filter(|when| { -// when.any_request(); -// }); -// }, -// |recording, url| async move { -// config.base_url = url; -// if !recording { -// // dummy api key value when replaying a cassette -// config.api_key_env = "USER".to_owned(); -// } -// -// Openai::try_from(&config) -// .unwrap() -// .model_details(&name) -// .await -// }, -// ) -// .await -// } -// -// #[test(tokio::test)] -// async fn test_openai_models() -> Result { -// let mut config = LlmProviderConfig::default().openai; -// let vcr = vcr(); -// vcr.cassette( -// fn_name!(), -// |rule| { -// rule.filter(|when| { -// when.any_request(); -// }); -// }, -// |recording, url| async move { -// config.base_url = url; -// if !recording { -// // dummy api key value when replaying a cassette -// config.api_key_env = "USER".to_owned(); -// } -// -// Openai::try_from(&config) -// .unwrap() -// .models() -// .await -// .map(|mut v| { -// v.truncate(10); -// v -// }) -// }, -// ) -// .await -// } -// -// #[test(tokio::test)] -// async fn test_openai_chat_completion() -> Result { -// let mut config = LlmProviderConfig::default().openai; -// let model_id = "openai/o4-mini".parse().unwrap(); -// let model = ModelDetails::empty(model_id); -// let query 
= ChatQuery { -// thread: Thread { -// events: ConversationStream::default().with_chat_request("Test message"), -// ..Default::default() -// }, -// ..Default::default() -// }; -// -// let vcr = vcr(); -// vcr.cassette( -// fn_name!(), -// |rule| { -// rule.filter(|when| { -// when.any_request(); -// }); -// }, -// |recording, url| async move { -// config.base_url = url; -// if !recording { -// // dummy api key value when replaying a cassette -// config.api_key_env = "USER".to_owned(); -// } -// -// Openai::try_from(&config) -// .unwrap() -// .chat_completion(&model, query) -// .await -// }, -// ) -// .await -// } -// } +impl From for Error { + fn from(error: types::response::Error) -> Self { + Self::OpenaiResponse(error) + } +} + +#[cfg(test)] +#[path = "openai_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/provider/openai_tests.rs b/crates/jp_llm/src/provider/openai_tests.rs new file mode 100644 index 00000000..1d648ba1 --- /dev/null +++ b/crates/jp_llm/src/provider/openai_tests.rs @@ -0,0 +1,378 @@ +mod transform_schema { + use serde_json::{Map, Value, json}; + + use super::super::transform_schema; + + #[expect(clippy::needless_pass_by_value)] + fn schema(v: Value) -> Map { + v.as_object().unwrap().clone() + } + + #[test] + fn additional_properties_false_forced_on_objects() { + let input = schema(json!({ + "type": "object", + "properties": { + "name": { "type": "string" } + } + })); + + let out = transform_schema(input); + + assert_eq!(out["additionalProperties"], json!(false)); + } + + #[test] + fn additional_properties_true_overridden_to_false() { + let input = schema(json!({ + "type": "object", + "properties": { + "name": { "type": "string" } + }, + "additionalProperties": true + })); + + let out = transform_schema(input); + + assert_eq!(out["additionalProperties"], json!(false)); + } + + #[test] + fn all_properties_forced_into_required() { + let input = schema(json!({ + "type": "object", + "properties": { + "name": { "type": "string" }, + "age": { 
"type": "integer" } + } + })); + + let out = transform_schema(input); + + assert_eq!(out["required"], json!(["name", "age"])); + } + + #[test] + fn existing_required_overwritten_to_all_properties() { + let input = schema(json!({ + "type": "object", + "properties": { + "name": { "type": "string" }, + "age": { "type": "integer" } + }, + "required": ["name"] + })); + + let out = transform_schema(input); + + // Both properties must be required in strict mode. + assert_eq!(out["required"], json!(["name", "age"])); + } + + #[test] + fn nested_objects_get_strict_treatment() { + let input = schema(json!({ + "type": "object", + "properties": { + "inner": { + "type": "object", + "properties": { + "x": { "type": "string" } + } + } + } + })); + + let out = transform_schema(input); + + let inner = out["properties"]["inner"].as_object().unwrap(); + assert_eq!(inner["additionalProperties"], json!(false)); + assert_eq!(inner["required"], json!(["x"])); + } + + #[test] + fn defs_recursively_processed() { + let input = schema(json!({ + "type": "object", + "properties": { + "step": { "$ref": "#/$defs/Step" } + }, + "$defs": { + "Step": { + "type": "object", + "properties": { + "explanation": { "type": "string" } + } + } + } + })); + + let out = transform_schema(input); + + // $defs is kept (OpenAI supports it natively). + let step_def = out["$defs"]["Step"].as_object().unwrap(); + assert_eq!(step_def["additionalProperties"], json!(false)); + assert_eq!(step_def["required"], json!(["explanation"])); + } + + #[test] + fn standalone_ref_kept_as_is() { + let input = schema(json!({ + "type": "object", + "properties": { + "step": { "$ref": "#/$defs/Step" } + }, + "$defs": { + "Step": { + "type": "object", + "properties": { + "name": { "type": "string" } + } + } + } + })); + + let out = transform_schema(input); + + // Standalone $ref should stay. 
+ assert_eq!(out["properties"]["step"]["$ref"], "#/$defs/Step"); + } + + #[test] + fn ref_with_siblings_unraveled() { + let input = schema(json!({ + "type": "object", + "properties": { + "person": { + "$ref": "#/$defs/Person", + "description": "The main person" + } + }, + "$defs": { + "Person": { + "type": "object", + "properties": { + "name": { "type": "string" } + } + } + } + })); + + let out = transform_schema(input); + + let person = out["properties"]["person"].as_object().unwrap(); + // $ref should be removed, definition inlined. + assert!(person.get("$ref").is_none()); + assert_eq!(person["type"], "object"); + assert_eq!(person["description"], "The main person"); + assert_eq!(person["properties"]["name"]["type"], "string"); + // Inlined object should also get strict treatment. + assert_eq!(person["additionalProperties"], json!(false)); + } + + #[test] + fn anyof_variants_processed() { + let input = schema(json!({ + "type": "object", + "properties": { + "item": { + "anyOf": [ + { + "type": "object", + "properties": { + "name": { "type": "string" } + } + }, + { "type": "string" } + ] + } + } + })); + + let out = transform_schema(input); + + let variants = out["properties"]["item"]["anyOf"].as_array().unwrap(); + let obj_variant = variants[0].as_object().unwrap(); + assert_eq!(obj_variant["additionalProperties"], json!(false)); + assert_eq!(obj_variant["required"], json!(["name"])); + } + + #[test] + fn allof_single_element_merged() { + let input = schema(json!({ + "allOf": [{ + "type": "object", + "properties": { + "name": { "type": "string" } + } + }] + })); + + let out = transform_schema(input); + + assert!(out.get("allOf").is_none()); + assert_eq!(out["type"], "object"); + assert_eq!(out["additionalProperties"], json!(false)); + assert_eq!(out["required"], json!(["name"])); + } + + #[test] + fn allof_multiple_elements_merged() { + let input = schema(json!({ + "allOf": [ + { + "type": "object", + "properties": { + "name": { "type": "string" } + } + }, + { + 
"description": "Extra info" + } + ] + })); + + let out = transform_schema(input); + + assert!(out.get("allOf").is_none()); + assert_eq!(out["type"], "object"); + assert_eq!(out["description"], "Extra info"); + } + + #[test] + fn null_default_stripped() { + let input = schema(json!({ + "type": "string", + "default": null + })); + + let out = transform_schema(input); + + assert!(out.get("default").is_none()); + } + + #[test] + fn non_null_default_preserved() { + let input = schema(json!({ + "type": "string", + "default": "hello" + })); + + let out = transform_schema(input); + + assert_eq!(out["default"], "hello"); + } + + #[test] + fn const_preserved_unchanged() { + let input = schema(json!({ + "type": "string", + "const": "tool_call.my_tool.call_123" + })); + + let out = transform_schema(input); + + assert_eq!(out["const"], "tool_call.my_tool.call_123"); + } + + #[test] + fn array_items_recursively_processed() { + let input = schema(json!({ + "type": "array", + "items": { + "type": "object", + "properties": { + "id": { "type": "integer" } + } + } + })); + + let out = transform_schema(input); + + let items = out["items"].as_object().unwrap(); + assert_eq!(items["additionalProperties"], json!(false)); + assert_eq!(items["required"], json!(["id"])); + } + + /// The inquiry schema should get strict treatment. + #[test] + fn inquiry_schema_transforms_correctly() { + let input = schema(json!({ + "type": "object", + "properties": { + "inquiry_id": { + "type": "string", + "const": "tool_call.fs_modify_file.call_a3b7c9d1" + }, + "answer": { + "type": "boolean" + } + }, + "required": ["inquiry_id", "answer"], + "additionalProperties": false + })); + + let out = transform_schema(input); + + // const should be preserved (OpenAI supports it). 
+ assert_eq!( + out["properties"]["inquiry_id"]["const"], + "tool_call.fs_modify_file.call_a3b7c9d1" + ); + assert_eq!(out["additionalProperties"], json!(false)); + assert_eq!(out["required"], json!(["inquiry_id", "answer"])); + } + + /// The `title_schema` should get strict treatment applied. + #[test] + fn title_schema_gets_strict_treatment() { + let input = crate::title::title_schema(3); + let out = transform_schema(input); + + assert_eq!(out["additionalProperties"], json!(false)); + assert_eq!(out["required"], json!(["titles"])); + + let titles = out["properties"]["titles"].as_object().unwrap(); + let items = titles["items"].as_object().unwrap(); + assert_eq!(items["type"], "string"); + } + + /// Docs example: definitions with $ref. + #[test] + fn definitions_example_from_docs() { + let input = schema(json!({ + "type": "object", + "properties": { + "steps": { + "type": "array", + "items": { "$ref": "#/$defs/step" } + }, + "final_answer": { "type": "string" } + }, + "$defs": { + "step": { + "type": "object", + "properties": { + "explanation": { "type": "string" }, + "output": { "type": "string" } + }, + "additionalProperties": false + } + }, + "additionalProperties": false + })); + + let out = transform_schema(input); + + // Root object: all properties required. + assert_eq!(out["required"], json!(["steps", "final_answer"])); + + // $defs preserved, step def gets required added. + let step_def = out["$defs"]["step"].as_object().unwrap(); + assert_eq!(step_def["required"], json!(["explanation", "output"])); + assert_eq!(step_def["additionalProperties"], json!(false)); + + // $ref stays as-is (standalone, no siblings). 
+ assert_eq!(out["properties"]["steps"]["items"]["$ref"], "#/$defs/step"); + } +} diff --git a/crates/jp_llm/src/provider/openrouter.rs b/crates/jp_llm/src/provider/openrouter.rs index 2cd8b03c..a1572d53 100644 --- a/crates/jp_llm/src/provider/openrouter.rs +++ b/crates/jp_llm/src/provider/openrouter.rs @@ -21,7 +21,7 @@ use jp_openrouter::{ types::{ self, chat::{CacheControl, Content, Message}, - request::{self, RequestMessage}, + request::{self, JsonSchemaFormat, RequestMessage, ResponseFormat}, response::{ self, ChatCompletion as OpenRouterChunk, FinishReason, ReasoningDetails, ReasoningDetailsFormat, ReasoningDetailsKind, @@ -35,8 +35,8 @@ use tracing::{debug, trace, warn}; use super::{EventStream, ModelDetails}; use crate::{ - Error, - error::Result, + Error, StreamError, + error::{Result, StreamErrorKind, looks_like_quota_error}, event::{self, Event}, provider::{Provider, openai::parameters_with_strict_mode}, query::ChatQuery, @@ -108,18 +108,18 @@ impl Provider for Openrouter { "Starting OpenRouter chat completion stream." ); + let (request, is_structured) = build_request(query, model)?; let mut state = AggregationState { tool_calls: ToolCallRequestAggregator::default(), aggregating_reasoning: false, aggregating_message: false, + is_structured, }; - let request = build_request(query, model)?; - Ok(self .client .chat_completion_stream(request) - .map_err(Error::from) + .map_err(StreamError::from) .map_ok(move |v| stream::iter(map_completion(v, &mut state))) .try_flatten() .boxed()) @@ -136,6 +136,9 @@ struct AggregationState { /// Did the stream of events have any message content? aggregating_message: bool, + + /// Whether the current request uses structured (JSON schema) output. 
+ is_structured: bool, } /// Metadata stored in the conversation stream, based on Openrouter @@ -267,7 +270,10 @@ impl From for IndexMap { } } -fn map_completion(v: OpenRouterChunk, state: &mut AggregationState) -> Vec> { +fn map_completion( + v: OpenRouterChunk, + state: &mut AggregationState, +) -> Vec> { v.choices .into_iter() .flat_map(|v| map_event(v, state)) @@ -275,7 +281,10 @@ fn map_completion(v: OpenRouterChunk, state: &mut AggregationState) -> Vec Vec> { +fn map_event( + choice: types::response::Choice, + state: &mut AggregationState, +) -> Vec> { let types::response::Choice::Streaming(types::response::StreamingChoice { finish_reason, delta: @@ -306,7 +315,17 @@ fn map_event(choice: types::response::Choice, state: &mut AggregationState) -> V let reasoning_details = MultiProviderMetadata::from_details(reasoning_details); if let Some(error) = error { - return vec![Err(Error::from(error))]; + if looks_like_quota_error(&error.message) { + return vec![Err(StreamError::new( + StreamErrorKind::InsufficientQuota, + format!( + "Insufficient API quota. Check your credits \ + at https://openrouter.ai/settings/credits. 
({})", + error.message + ), + ))]; + } + return vec![Err(StreamError::other(error.message))]; } let mut events = vec![]; @@ -329,9 +348,14 @@ fn map_event(choice: types::response::Choice, state: &mut AggregationState) -> V { state.aggregating_message = true; + let response = if state.is_structured { + ChatResponse::structured(Value::String(content)) + } else { + ChatResponse::message(content) + }; events.push(Ok(Event::Part { index: 1, - event: ConversationEvent::now(ChatResponse::message(content)), + event: ConversationEvent::now(response), })); } @@ -375,7 +399,7 @@ fn map_event(choice: types::response::Choice, state: &mut AggregationState) -> V index, event: ConversationEvent::now(call), }) - .map_err(Error::from), + .map_err(|e| StreamError::other(e.to_string())), Ok(Event::flush(index)), ] }), @@ -389,10 +413,9 @@ fn map_event(choice: types::response::Choice, state: &mut AggregationState) -> V Some(FinishReason::Stop) => { events.push(Ok(Event::Finished(event::FinishReason::Completed))); } - Some(FinishReason::Error) => events.push(Err(jp_openrouter::Error::Stream( - "unknown stream error".into(), - ) - .into())), + Some(FinishReason::Error) => { + events.push(Err(StreamError::other("unknown stream error"))); + } Some(reason) => events.push(Ok(Event::Finished(event::FinishReason::Other( reason.as_str().into(), )))), @@ -403,17 +426,37 @@ fn map_event(choice: types::response::Choice, state: &mut AggregationState) -> V } /// Build request for Openrouter API. -fn build_request(query: ChatQuery, model: &ModelDetails) -> Result { +/// +/// Returns the request and whether structured output is active. +fn build_request( + query: ChatQuery, + model: &ModelDetails, +) -> Result<(request::ChatCompletion, bool)> { let ChatQuery { thread, tools, tool_choice, - tool_call_strict_mode, } = query; let config = thread.events.config()?; let parameters = &config.assistant.model.parameters; + // Only use the schema if the very last event is a ChatRequest with one. 
+ let response_format = thread + .events + .last() + .and_then(|e| e.event.as_chat_request()) + .and_then(|req| req.schema.clone()) + .map(|schema| ResponseFormat::JsonSchema { + json_schema: JsonSchemaFormat { + name: "structured_output".to_owned(), + schema: Value::Object(schema), + strict: Some(true), + }, + }); + + let is_structured = response_format.is_some(); + let slug = model.id.name.to_string(); let reasoning = model.custom_reasoning_config(parameters.reasoning); @@ -422,10 +465,10 @@ fn build_request(query: ChatQuery, model: &ModelDetails) -> Result>(); @@ -447,26 +490,40 @@ fn build_request(query: ChatQuery, model: &ModelDetails) -> Result request::ReasoningEffort::XHigh, - ReasoningEffort::High => request::ReasoningEffort::High, - ReasoningEffort::Auto | ReasoningEffort::Medium => request::ReasoningEffort::Medium, - ReasoningEffort::Low | ReasoningEffort::Xlow => request::ReasoningEffort::Low, - ReasoningEffort::Absolute(_) => { - debug_assert!(false, "Reasoning effort must be relative."); - request::ReasoningEffort::Medium - } - }, - }), - tools, - tool_choice, - ..Default::default() - }) + Ok(( + request::ChatCompletion { + model: slug, + messages: messages.0, + reasoning: reasoning.map(|r| request::Reasoning { + exclude: r.exclude, + effort: match r + .effort + .abs_to_rel(model.max_output_tokens) + .unwrap_or(ReasoningEffort::Auto) + { + ReasoningEffort::Max | ReasoningEffort::XHigh => { + request::ReasoningEffort::XHigh + } + ReasoningEffort::High => request::ReasoningEffort::High, + ReasoningEffort::Auto | ReasoningEffort::Medium => { + request::ReasoningEffort::Medium + } + ReasoningEffort::None => request::ReasoningEffort::None, + ReasoningEffort::Xlow => request::ReasoningEffort::Minimal, + ReasoningEffort::Low => request::ReasoningEffort::Low, + ReasoningEffort::Absolute(_) => { + debug_assert!(false, "Reasoning effort must be relative."); + request::ReasoningEffort::Medium + } + }, + }), + tools, + tool_choice, + response_format, + 
..Default::default() + }, + is_structured, + )) } // TODO: Manually add a bunch of often-used models. @@ -483,21 +540,6 @@ fn map_model(model: response::Model) -> Result { }) } -// impl From for Delta { -// fn from(delta: StreamingDelta) -> Self { -// let tool_call = delta.tool_calls.into_iter().next(); -// -// Self { -// content: delta.content, -// reasoning: delta.reasoning, -// tool_call_id: tool_call.as_ref().and_then(ToolCall::id), -// tool_call_name: tool_call.as_ref().and_then(ToolCall::name), -// tool_call_arguments: tool_call.as_ref().and_then(ToolCall::arguments), -// tool_call_finished: false, -// } -// } -// } - impl From for Error { fn from(error: types::response::ErrorResponse) -> Self { Self::OpenRouter(jp_openrouter::Error::Api { @@ -534,7 +576,7 @@ impl TryFrom<(&ModelIdConfig, Thread)> for RequestMessages { fn try_from((model_id, thread): (&ModelIdConfig, Thread)) -> Result { let Thread { system_prompt, - instructions, + sections, attachments, events, } = thread; @@ -554,7 +596,7 @@ impl TryFrom<(&ModelIdConfig, Thread)> for RequestMessages { }); } - if !instructions.is_empty() { + if !sections.is_empty() { content.push(Content::Text { text: "Before we continue, here are some contextual details that will help you \ generate a better response." @@ -562,15 +604,15 @@ impl TryFrom<(&ModelIdConfig, Thread)> for RequestMessages { cache_control: None, }); - // Then instructions in XML tags. + // Then sections as rendered text. // - // Cached (3/4), (for the last instruction), as it's not expected to + // Cached (3/4), (for the last section), as it's not expected to // change. 
- let mut instructions = instructions.iter().peekable(); - while let Some(instruction) = instructions.next() { + let mut sections = sections.iter().peekable(); + while let Some(section) = sections.next() { content.push(Content::Text { - text: instruction.try_to_xml()?, - cache_control: instructions + text: section.render(), + cache_control: sections .peek() .map_or(Some(CacheControl::Ephemeral), |_| None), }); @@ -631,6 +673,9 @@ fn convert_events(events: ConversationStream) -> Vec { ChatResponse::Reasoning { reasoning, .. } => { vec![Message::default().with_reasoning(reasoning).assistant()] } + ChatResponse::Structured { data } => { + vec![Message::default().with_text(data.to_string()).assistant()] + } }, EventKind::ToolCallRequest(request) => { let message = Message { @@ -663,62 +708,47 @@ fn convert_events(events: ConversationStream) -> Vec { name: None, })] } - EventKind::InquiryRequest(_) | EventKind::InquiryResponse(_) => vec![], + EventKind::InquiryRequest(_) + | EventKind::InquiryResponse(_) + | EventKind::TurnStart(_) => vec![], }) .collect() } -#[cfg(test)] -mod tests { - use jp_config::providers::llm::LlmProviderConfig; - use jp_test::{Result, function_name}; - - use super::*; - use crate::test::TestRequest; - - macro_rules! test_all_models { - ($($fn:ident),* $(,)?) => { - mod anthropic { use super::*; $(test_all_models!(func; $fn, "openrouter/anthropic/claude-haiku-4.5");)* } - mod google { use super::*; $(test_all_models!(func; $fn, "openrouter/google/gemini-2.5-flash");)* } - mod xai { use super::*; $(test_all_models!(func; $fn, "openrouter/x-ai/grok-code-fast-1");)* } - mod minimax { use super::*; $(test_all_models!(func; $fn, "openrouter/minimax/minimax-m2");)* } - }; - (func; $fn:ident, $model:literal) => { - paste::paste! 
{ - #[test_log::test(tokio::test)] - async fn [< test_ $fn >]() -> Result { - $fn($model, &format!("{}_{}", $model.split('/').nth(1).unwrap(), function_name!())).await - } +impl From for StreamError { + fn from(err: jp_openrouter::Error) -> Self { + use jp_openrouter::Error as E; + + match err { + E::Request(error) => Self::from(error), + E::Api { code: 429, .. } => StreamError::rate_limit(None).with_source(err), + // 402 Payment Required — OpenRouter returns this for insufficient credits. + E::Api { code: 402, .. } => StreamError::new( + StreamErrorKind::InsufficientQuota, + format!( + "Insufficient API quota. Check your credits \ + at https://openrouter.ai/settings/credits. ({err})" + ), + ) + .with_source(err), + E::Api { ref message, .. } if looks_like_quota_error(message) => StreamError::new( + StreamErrorKind::InsufficientQuota, + format!( + "Insufficient API quota. Check your credits \ + at https://openrouter.ai/settings/credits. ({err})" + ), + ) + .with_source(err), + E::Api { + code: 408 | 500 | 502 | 503 | 504, + .. 
} - }; - } - - test_all_models![sub_provider_event_metadata]; - - async fn sub_provider_event_metadata(model: &str, test_name: &str) -> Result { - let requests = vec![ - TestRequest::chat(ProviderId::Openrouter) - .model(model.parse().unwrap()) - .enable_reasoning() - .chat_request("Test message"), - ]; - - run_test(test_name, requests).await?; - - Ok(()) - } - - async fn run_test( - test_name: impl AsRef, - requests: impl IntoIterator, - ) -> Result { - crate::test::run_chat_completion( - test_name, - env!("CARGO_MANIFEST_DIR"), - ProviderId::Openrouter, - LlmProviderConfig::default(), - requests.into_iter().collect(), - ) - .await + | E::Stream(_) => StreamError::transient(err.to_string()).with_source(err), + _ => StreamError::other(err.to_string()).with_source(err), + } } } + +#[cfg(test)] +#[path = "openrouter_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/provider/openrouter_tests.rs b/crates/jp_llm/src/provider/openrouter_tests.rs new file mode 100644 index 00000000..c7c63ee6 --- /dev/null +++ b/crates/jp_llm/src/provider/openrouter_tests.rs @@ -0,0 +1,51 @@ +use jp_config::providers::llm::LlmProviderConfig; +use jp_test::{Result, function_name}; + +use super::*; +use crate::test::TestRequest; + +macro_rules! test_all_models { + ($($fn:ident),* $(,)?) => { + mod anthropic { use super::*; $(test_all_models!(func; $fn, "openrouter/anthropic/claude-haiku-4.5");)* } + mod google { use super::*; $(test_all_models!(func; $fn, "openrouter/google/gemini-2.5-flash");)* } + mod xai { use super::*; $(test_all_models!(func; $fn, "openrouter/x-ai/grok-code-fast-1");)* } + mod minimax { use super::*; $(test_all_models!(func; $fn, "openrouter/minimax/minimax-m2");)* } + }; + (func; $fn:ident, $model:literal) => { + paste::paste! 
{ + #[test_log::test(tokio::test)] + async fn [< test_ $fn >]() -> Result { + $fn($model, &format!("{}_{}", $model.split('/').nth(1).unwrap(), function_name!())).await + } + } + }; + } + +test_all_models![sub_provider_event_metadata]; + +async fn sub_provider_event_metadata(model: &str, test_name: &str) -> Result { + let requests = vec![ + TestRequest::chat(ProviderId::Openrouter) + .model(model.parse().unwrap()) + .enable_reasoning() + .chat_request("Test message"), + ]; + + run_test(test_name, requests).await?; + + Ok(()) +} + +async fn run_test( + test_name: impl AsRef, + requests: impl IntoIterator, +) -> Result { + crate::test::run_chat_completion( + test_name, + env!("CARGO_MANIFEST_DIR"), + ProviderId::Openrouter, + LlmProviderConfig::default(), + requests.into_iter().collect(), + ) + .await +} diff --git a/crates/jp_llm/src/provider_tests.rs b/crates/jp_llm/src/provider_tests.rs new file mode 100644 index 00000000..7f9c1880 --- /dev/null +++ b/crates/jp_llm/src/provider_tests.rs @@ -0,0 +1,198 @@ +use std::sync::Arc; + +use indexmap::IndexMap; +use jp_config::{ + assistant::tool_choice::ToolChoice, + conversation::tool::{OneOrManyTypes, ToolParameterConfig}, +}; +use jp_conversation::event::ChatRequest; +use jp_test::{Result, function_name}; + +use super::*; +use crate::test::{TestRequest, run_test, test_model_details}; + +macro_rules! test_all_providers { + ($($fn:ident),* $(,)?) 
=> { + mod anthropic { use super::*; $(test_all_providers!(func; $fn, ProviderId::Anthropic);)* } + mod google { use super::*; $(test_all_providers!(func; $fn, ProviderId::Google);)* } + mod openai { use super::*; $(test_all_providers!(func; $fn, ProviderId::Openai);)* } + mod openrouter{ use super::*; $(test_all_providers!(func; $fn, ProviderId::Openrouter);)* } + mod ollama { use super::*; $(test_all_providers!(func; $fn, ProviderId::Ollama);)* } + mod llamacpp { use super::*; $(test_all_providers!(func; $fn, ProviderId::Llamacpp);)* } + }; + (func; $fn:ident, $provider:ty) => { + paste::paste! { + #[test_log::test(tokio::test)] + async fn [< test_ $fn >]() -> Result { + $fn($provider, function_name!()).await + } + } + }; + } + +async fn chat_completion_stream(provider: ProviderId, test_name: &str) -> Result { + let request = TestRequest::chat(provider) + .enable_reasoning() + .event(ChatRequest::from("Test message")); + + run_test(provider, test_name, Some(request)).await +} + +fn tool_call_base(provider: ProviderId) -> TestRequest { + TestRequest::chat(provider) + .event(ChatRequest::from( + "Please run the tool, providing whatever arguments you want.", + )) + .tool("run_me", vec![ + ("foo", ToolParameterConfig { + kind: OneOrManyTypes::One("string".into()), + default: Some("foo".into()), + required: false, + summary: None, + description: None, + examples: None, + enumeration: vec![], + items: None, + properties: IndexMap::default(), + }), + ("bar", ToolParameterConfig { + kind: OneOrManyTypes::Many(vec!["string".into(), "array".into()]), + default: None, + required: true, + summary: None, + description: None, + examples: None, + enumeration: vec!["foo".into(), vec!["foo", "bar"].into()], + items: Some(Box::new(ToolParameterConfig { + kind: OneOrManyTypes::One("string".into()), + default: None, + required: false, + summary: None, + description: None, + examples: None, + enumeration: vec![], + items: None, + properties: IndexMap::default(), + })), + properties: 
IndexMap::default(), + }), + ]) +} + +async fn tool_call_stream(provider: ProviderId, test_name: &str) -> Result { + let requests = vec![ + tool_call_base(provider), + TestRequest::tool_call_response(Ok("working!"), false), + ]; + + run_test(provider, test_name, requests).await +} + +/// Without reasoning, "forced" tool calls should work as expected. +async fn tool_call_required_no_reasoning(provider: ProviderId, test_name: &str) -> Result { + let requests = vec![ + tool_call_base(provider).tool_choice(ToolChoice::Required), + TestRequest::tool_call_response(Ok("working!"), true), + ]; + + run_test(provider, test_name, requests).await +} + +/// With reasoning, some models do not support "forced" tool calls, so +/// provider implementations should fall back to trying to instruct the +/// model to use the tool through regular textual instructions. +async fn tool_call_required_reasoning(provider: ProviderId, test_name: &str) -> Result { + let requests = vec![ + tool_call_base(provider) + .tool_choice(ToolChoice::Required) + .enable_reasoning(), + TestRequest::tool_call_response(Ok("working!"), false), + ]; + + run_test(provider, test_name, requests).await +} + +async fn tool_call_auto(provider: ProviderId, test_name: &str) -> Result { + let requests = vec![ + tool_call_base(provider).tool_choice(ToolChoice::Auto), + TestRequest::tool_call_response(Ok("working!"), false), + ]; + + run_test(provider, test_name, requests).await +} + +async fn tool_call_function(provider: ProviderId, test_name: &str) -> Result { + let requests = vec![ + tool_call_base(provider).tool_choice_fn("run_me"), + TestRequest::tool_call_response(Ok("working!"), true), + ]; + + run_test(provider, test_name, requests).await +} + +async fn tool_call_reasoning(provider: ProviderId, test_name: &str) -> Result { + let requests = vec![ + tool_call_base(provider).enable_reasoning(), + TestRequest::tool_call_response(Ok("working!"), false), + ]; + + run_test(provider, test_name, requests).await +} + +async 
fn model_details(provider: ProviderId, test_name: &str) -> Result { + let request = TestRequest::ModelDetails { + name: test_model_details(provider).id.name.to_string(), + assert: Arc::new(|_| {}), + }; + + run_test(provider, test_name, Some(request)).await +} + +async fn models(provider: ProviderId, test_name: &str) -> Result { + let request = TestRequest::Models { + assert: Arc::new(|_| {}), + }; + + run_test(provider, test_name, Some(request)).await +} + +async fn structured_output(provider: ProviderId, test_name: &str) -> Result { + let schema = crate::title::title_schema(1); + + let request = TestRequest::chat(provider).chat_request(ChatRequest { + content: "Generate a title for this conversation.".into(), + schema: Some(schema), + }); + + run_test(provider, test_name, Some(request)).await +} + +async fn multi_turn_conversation(provider: ProviderId, test_name: &str) -> Result { + let requests = vec![ + TestRequest::chat(provider).chat_request("Test message"), + TestRequest::chat(provider) + .enable_reasoning() + .chat_request("Repeat my previous message"), + tool_call_base(provider).tool_choice_fn("run_me"), + TestRequest::tool_call_response(Ok("The secret code is: 42"), true), + TestRequest::chat(provider) + .enable_reasoning() + .chat_request("What was the result of the previous tool call?"), + ]; + + run_test(provider, test_name, requests).await +} + +test_all_providers![ + chat_completion_stream, + tool_call_auto, + tool_call_function, + tool_call_reasoning, + tool_call_required_no_reasoning, + tool_call_required_reasoning, + tool_call_stream, + model_details, + models, + multi_turn_conversation, + structured_output, +]; diff --git a/crates/jp_llm/src/query.rs b/crates/jp_llm/src/query.rs index 3da03670..b0645770 100644 --- a/crates/jp_llm/src/query.rs +++ b/crates/jp_llm/src/query.rs @@ -1,44 +1,31 @@ -pub mod chat; -pub mod structured; - -pub use chat::ChatQuery; +use jp_config::assistant::tool_choice::ToolChoice; use jp_conversation::thread::Thread; 
-pub use structured::StructuredQuery; -#[derive(Debug)] -pub enum Query { - Chat(ChatQuery), - Structured(StructuredQuery), -} +use crate::tool::ToolDefinition; -impl Query { - /// Get the [`Thread`] for the query. - #[must_use] - pub fn thread(&self) -> &Thread { - match self { - Self::Chat(v) => &v.thread, - Self::Structured(v) => &v.thread, - } - } +#[derive(Debug, Clone)] +pub struct ChatQuery { + pub thread: Thread, + // TODO: Should this be taken from `thread.events`, if not, document why? + // + // I think it should, because the tools that are available to the LLM are + // always represented by the configuration in the conversation stream. If a + // user adds a new tool to a config file, that tool is not automatically + // available in existing conversations (it will be in new ones), but will + // only become available when `--tool` or `--cfg` is used. + pub tools: Vec, + // TODO: Should this instead be a delta config on `thread.events`? + // + // Same logic applies here, I think? + pub tool_choice: ToolChoice, +} - /// Get a mutable reference to the [`Thread`] for the query. - #[must_use] - pub fn thread_mut(&mut self) -> &mut Thread { - match self { - Self::Chat(v) => &mut v.thread, - Self::Structured(v) => &mut v.thread, +impl From for ChatQuery { + fn from(thread: Thread) -> Self { + Self { + thread, + tools: vec![], + tool_choice: ToolChoice::default(), } } - - /// Returns `true` if the query is a chat query. - #[must_use] - pub fn is_chat(&self) -> bool { - matches!(self, Self::Chat(_)) - } - - /// Returns `true` if the query is a structured query. 
- #[must_use] - pub fn is_structured(&self) -> bool { - matches!(self, Self::Structured(_)) - } } diff --git a/crates/jp_llm/src/query/chat.rs b/crates/jp_llm/src/query/chat.rs deleted file mode 100644 index 9e950af6..00000000 --- a/crates/jp_llm/src/query/chat.rs +++ /dev/null @@ -1,33 +0,0 @@ -use jp_config::assistant::tool_choice::ToolChoice; -use jp_conversation::thread::Thread; - -use crate::tool::ToolDefinition; - -#[derive(Debug, Clone)] -pub struct ChatQuery { - pub thread: Thread, - // TODO: Should this be taken from `thread.events`, if not, document why? - // - // I think it should, because the tools that are available to the LLM are - // always represented by the configuration in the conversation stream. If a - // user adds a new tool to a config file, that tool is not automatically - // available in existing conversations (it will be in new ones), but will - // only become available when `--tool` or `--cfg` is used. - pub tools: Vec, - // TODO: Should this instead be a delta config on `thread.events`? - // - // Same logic applies here, I think? - pub tool_choice: ToolChoice, - pub tool_call_strict_mode: bool, -} - -impl From for ChatQuery { - fn from(thread: Thread) -> Self { - Self { - thread, - tools: vec![], - tool_choice: ToolChoice::default(), - tool_call_strict_mode: false, - } - } -} diff --git a/crates/jp_llm/src/query/structured.rs b/crates/jp_llm/src/query/structured.rs deleted file mode 100644 index 0383b0e6..00000000 --- a/crates/jp_llm/src/query/structured.rs +++ /dev/null @@ -1,181 +0,0 @@ -use std::fmt; - -use jp_config::{ - assistant::tool_choice::ToolChoice, - conversation::tool::{OneOrManyTypes, ToolParameterConfig}, -}; -use jp_conversation::thread::Thread; -use schemars::Schema; -use serde_json::Value; - -use crate::{Error, structured::SCHEMA_TOOL_NAME, tool::ToolDefinition}; - -type Mapping = Box Option + Send>; -type Validate = Box Result<(), String> + Send>; - -/// A structured query for LLMs. 
-pub struct StructuredQuery { - /// The thread to use for the query. - pub thread: Thread, - - /// The JSON schema to enforce the shape of the response. - schema: Schema, - - /// An optional mapping function to mutate the response object into a - /// different shape. - mapping: Option, - - /// Validators to run on the response. If a validator fails, its error is - /// sent back to the assistant, so that it can be fixed/retried. - /// - /// TODO: Add support for JSON Schema validation. - validators: Vec, -} - -impl fmt::Debug for StructuredQuery { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("StructuredQuery") - .field("thread", &self.thread) - .field("schema", &self.schema) - .field("mapping", &"") - .field("validators", &"") - .finish() - } -} - -impl StructuredQuery { - /// Create a new structured query. - #[must_use] - pub fn new(schema: Schema, thread: Thread) -> Self { - Self { - thread, - schema, - mapping: None, - validators: vec![], - } - } - - #[must_use] - pub fn with_mapping( - mut self, - mapping: impl Fn(&mut Value) -> Option + Send + Sync + 'static, - ) -> Self { - self.mapping = Some(Box::new(mapping)); - self - } - - #[must_use] - pub fn with_validator( - mut self, - validator: impl Fn(&Value) -> Result<(), String> + Send + Sync + 'static, - ) -> Self { - self.validators.push(Box::new(validator)); - self - } - - #[must_use] - pub fn with_schema_validator(self, schema: Schema) -> Self { - let validate = move |value: &Value| { - jsonschema::validate(schema.as_value(), value).map_err(|e| e.to_string()) - }; - - self.with_validator(validate) - } - - #[must_use] - pub fn map(&self, mut value: Value) -> Value { - self.mapping - .as_ref() - .and_then(|f| f(&mut value)) - .unwrap_or(value) - } - - pub fn validate(&self, value: &Value) -> Result<(), String> { - for validator in &self.validators { - validator(value)?; - } - - Ok(()) - } - - pub fn tool_definition(&self) -> Result { - let mut description = - "This tool can be 
used to deliver structured data to the caller. It is NOT intended \ - to GENERATE the requested data, but instead as a structured delivery mechanism. The \ - tool is a no-op implementation in that it allows the assistant to deliver structured \ - data to the user, but the tool will never report back with a result, instead the \ - user can take the structured data from the tool arguments provided." - .to_owned(); - if let Some(desc) = self.schema.get("description").and_then(|v| v.as_str()) { - description.push_str(&format!( - " Here is the description for the requested structured data:\n\n{desc}" - )); - } - - let required = self - .schema - .get("required") - .and_then(|v| v.as_array()) - .into_iter() - .flatten() - .filter_map(|v| v.as_str()) - .collect::>(); - - let parameters = self - .schema - .get("properties") - .and_then(|v| v.as_object()) - .into_iter() - .flatten() - .map(|(k, v)| { - let kind = v - .get("type") - .and_then(|v| match v.clone() { - Value::String(v) => Some(v.into()), - Value::Array(v) => Some( - v.into_iter() - .filter_map(|v| match v { - Value::String(v) => Some(v), - _ => None, - }) - .collect::>() - .into(), - ), - _ => None, - }) - .unwrap_or_else(|| OneOrManyTypes::One("object".to_owned())); - - let parameter = ToolParameterConfig { - kind, - required: required.contains(&k.as_str()), - description: v - .get("description") - .and_then(|v| v.as_str()) - .map(str::to_owned), - default: v.get("default").cloned(), - enumeration: v - .get("enum") - .and_then(|v| v.as_array().cloned()) - .unwrap_or_default(), - items: v - .get("items") - .map(|v| serde_json::from_value(v.clone())) - .transpose()?, - }; - - Ok((k.to_owned(), parameter)) - }) - .collect::>()?; - - Ok(ToolDefinition { - name: SCHEMA_TOOL_NAME.to_owned(), - description: Some(description), - parameters, - include_tool_answers_parameter: false, - }) - } - - pub fn tool_choice(&self) -> Result { - Ok(ToolChoice::Function(self.tool_definition()?.name)) - } -} diff --git 
a/crates/jp_llm/src/retry.rs b/crates/jp_llm/src/retry.rs new file mode 100644 index 00000000..180c8c64 --- /dev/null +++ b/crates/jp_llm/src/retry.rs @@ -0,0 +1,116 @@ +//! Retry utilities for resilient LLM request handling. + +use std::time::Duration; + +use futures::TryStreamExt as _; +use tracing::{debug, warn}; + +use crate::{Provider, error::Result, event::Event, model::ModelDetails, query::ChatQuery}; + +/// Configuration for resilient stream retries. +#[derive(Debug, Clone)] +pub struct RetryConfig { + /// Maximum number of retry attempts. + pub max_retries: u32, + + /// Base backoff delay in milliseconds. + pub base_backoff_ms: u64, + + /// Maximum backoff delay in seconds. + pub max_backoff_secs: u64, +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + max_retries: 3, + base_backoff_ms: 1000, + max_backoff_secs: 30, + } + } +} + +/// Execute `chat_completion_stream` with automatic retries on transient errors. +/// +/// Collects the full event stream into a `Vec`. On retryable stream +/// errors, backs off and retries the entire request up to `config.max_retries` +/// times. +/// +/// Non-retryable errors and errors from `chat_completion_stream` itself (before +/// streaming starts) are propagated immediately. +pub async fn collect_with_retry( + provider: &dyn Provider, + model: &ModelDetails, + query: ChatQuery, + config: &RetryConfig, +) -> Result> { + let mut attempt = 0u32; + + loop { + let stream = provider + .chat_completion_stream(model, query.clone()) + .await?; + + match stream.try_collect::>().await { + Ok(events) => return Ok(events), + Err(error) => { + attempt += 1; + + if !error.is_retryable() || attempt > config.max_retries { + warn!( + attempt, + max = config.max_retries, + error = error.to_string(), + "Stream error (exhausted retries)." 
+ ); + return Err(error.into()); + } + + let delay = match error.retry_after { + Some(d) => d.min(Duration::from_secs(config.max_backoff_secs)), + None => exponential_backoff( + attempt, + config.base_backoff_ms, + config.max_backoff_secs, + ), + }; + + debug!( + attempt, + max = config.max_retries, + delay_ms = delay.as_millis(), + error = error.to_string(), + "Retryable stream error, backing off." + ); + + tokio::time::sleep(delay).await; + } + } + } +} + +/// Calculate exponential backoff delay. +/// +/// Formula: `min(base * 2^attempt, max_backoff)` +/// +/// # Arguments +/// +/// * `attempt` - Current attempt number (1-based). The delay doubles with +/// each attempt. +/// * `base_backoff_ms` - Base delay in milliseconds for the first attempt. +/// * `max_backoff_secs` - Maximum delay cap in seconds. +#[must_use] +pub fn exponential_backoff(attempt: u32, base_backoff_ms: u64, max_backoff_secs: u64) -> Duration { + let max_ms = max_backoff_secs * 1000; + + // Cap the exponent to avoid overflow. + let capped_attempt = attempt.saturating_sub(1).min(20); + let base_delay = base_backoff_ms.saturating_mul(1u64 << capped_attempt); + let total_ms = base_delay.min(max_ms); + + Duration::from_millis(total_ms) +} + +#[cfg(test)] +#[path = "retry_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/retry_tests.rs b/crates/jp_llm/src/retry_tests.rs new file mode 100644 index 00000000..eb413905 --- /dev/null +++ b/crates/jp_llm/src/retry_tests.rs @@ -0,0 +1,61 @@ +use super::*; +use crate::error::StreamError; + +/// Default base backoff for tests. +const TEST_BASE_BACKOFF_MS: u64 = 1000; + +/// Default max backoff for tests. 
+const TEST_MAX_BACKOFF_SECS: u64 = 60; + +#[test] +fn backoff_increases() { + let d1 = exponential_backoff(1, TEST_BASE_BACKOFF_MS, TEST_MAX_BACKOFF_SECS); + let d2 = exponential_backoff(2, TEST_BASE_BACKOFF_MS, TEST_MAX_BACKOFF_SECS); + let d3 = exponential_backoff(3, TEST_BASE_BACKOFF_MS, TEST_MAX_BACKOFF_SECS); + + // Base delays should roughly double + // attempt 1: ~1000ms, attempt 2: ~2000ms, attempt 3: ~4000ms + assert!(d1 < d2); + assert!(d2 < d3); +} + +#[test] +fn backoff_capped() { + let d_high = exponential_backoff(100, TEST_BASE_BACKOFF_MS, TEST_MAX_BACKOFF_SECS); + + // Should be capped at max_backoff_secs + assert!(d_high <= Duration::from_secs(TEST_MAX_BACKOFF_SECS + 1)); +} + +#[test] +fn backoff_respects_config() { + // Custom base and max + let d1 = exponential_backoff(1, 500, 10); + let d2 = exponential_backoff(1, 2000, 10); + + // Higher base should give higher delay + assert!(d1 < d2); + + // Should respect max cap + let d_capped = exponential_backoff(100, 1000, 5); + assert!(d_capped <= Duration::from_secs(5)); +} + +#[test] +fn stream_error_is_retryable() { + // Retryable error kinds + assert!(StreamError::timeout("test").is_retryable()); + assert!(StreamError::connect("test").is_retryable()); + assert!(StreamError::rate_limit(None).is_retryable()); + assert!(StreamError::transient("test").is_retryable()); + + // Non-retryable + assert!(!StreamError::other("test").is_retryable()); +} + +#[test] +fn stream_error_with_retry_after() { + let err = StreamError::rate_limit(Some(Duration::from_secs(30))); + assert_eq!(err.retry_after, Some(Duration::from_secs(30))); + assert!(err.is_retryable()); +} diff --git a/crates/jp_llm/src/stream.rs b/crates/jp_llm/src/stream.rs index dbbb4289..b90c0035 100644 --- a/crates/jp_llm/src/stream.rs +++ b/crates/jp_llm/src/stream.rs @@ -2,9 +2,13 @@ use std::pin::Pin; use futures::Stream; -use crate::{Error, event::Event}; +use crate::{error::StreamError, event::Event}; pub(super) mod aggregator; pub(super) mod 
chain; -pub type EventStream = Pin> + Send>>; +/// A stream of events from an LLM provider. +/// +/// Errors are represented as `StreamError` to provide provider-agnostic error +/// classification for retry logic. +pub type EventStream = Pin> + Send>>; diff --git a/crates/jp_llm/src/stream/aggregator.rs b/crates/jp_llm/src/stream/aggregator.rs index a1e8adc9..f3c917b5 100644 --- a/crates/jp_llm/src/stream/aggregator.rs +++ b/crates/jp_llm/src/stream/aggregator.rs @@ -1,3 +1,2 @@ -pub(crate) mod chunk; pub(crate) mod reasoning; pub(crate) mod tool_call_request; diff --git a/crates/jp_llm/src/stream/aggregator/chunk.rs b/crates/jp_llm/src/stream/aggregator/chunk.rs deleted file mode 100644 index a0d730b4..00000000 --- a/crates/jp_llm/src/stream/aggregator/chunk.rs +++ /dev/null @@ -1,195 +0,0 @@ -use indexmap::{IndexMap, map::Entry}; -use jp_conversation::{ - ConversationEvent, EventKind, - event::{ChatResponse, ToolCallRequest}, -}; -use serde_json::Value; -use tracing::warn; - -use crate::event::Event; - -/// A buffering state machine that consumes multiplexed streaming events and -/// produces coalesced events keyed by index. -pub struct EventAggregator { - /// The currently accumulating events, keyed by stream index. - pending: IndexMap, -} - -impl EventAggregator { - /// Create a new, empty aggregator. - pub fn new() -> Self { - Self { - pending: IndexMap::new(), - } - } - - /// Consumes a single streaming [`Event`] and returns a vector of zero or - /// more "completed" `Event`s. - /// - /// For consistency, we always send a `Flush` event after flushing a merged - /// `Part` event. While this is not strictly necessary, it makes the API - /// more consistent to use, regardless of whether the event aggregator is - /// used or not. - pub fn ingest(&mut self, event: Event) -> Vec { - match event { - Event::Part { index, event } => match self.pending.entry(index) { - // Nothing buffered for this index, start buffering. 
- Entry::Vacant(e) => { - e.insert(event); - vec![] - } - Entry::Occupied(mut e) => match try_merge_events(e.get_mut(), event) { - // Merge succeeded. Continue buffering. - Ok(()) => vec![], - // Merge failed (types were different). Force flush the OLD - // event, replace it with the NEW event. - Err(unmerged) => vec![ - Event::Part { - index, - event: e.insert(unmerged), - }, - Event::flush(index), - ], - }, - }, - - Event::Flush { index, metadata } => { - if let Some(event) = self.pending.shift_remove(&index) { - vec![ - Event::Part { - index, - event: event.with_metadata(metadata), - }, - Event::flush(index), - ] - } else { - if !metadata.is_empty() { - warn!( - index, - metadata = ?metadata, - "Received Flush with metadata for empty index." - ); - } - - vec![] - } - } - - Event::Finished(reason) => self - .pending - .drain(..) - .flat_map(|(index, event)| vec![Event::Part { index, event }, Event::flush(index)]) - .chain(std::iter::once(Event::Finished(reason))) - .collect(), - } - } -} - -/// Attempts to merge `incoming` into `target`. Returns `Ok(())` if successful, -/// or `Err(incoming)` if the events were incompatible (e.g., different types), -/// passing ownership of the incoming event back to the caller. -fn try_merge_events( - target: &mut ConversationEvent, - incoming: ConversationEvent, -) -> Result<(), ConversationEvent> { - let ConversationEvent { - kind, - metadata, - timestamp, - } = incoming; - - match (&mut target.kind, kind) { - (EventKind::ChatResponse(t_resp), EventKind::ChatResponse(i_resp)) => { - match merge_chat_responses(t_resp, i_resp) { - Ok(()) => { - // Merge successful. - // - // Now merge the remaining fields from the destructured - // event. - target.metadata.extend(metadata); - target.timestamp = timestamp; - - Ok(()) - } - Err(returned_resp) => { - // Merge failed (variant mismatch). - // - // Reconstruct the event using the returned response and - // original fields. 
- Err(ConversationEvent { - kind: EventKind::ChatResponse(returned_resp), - metadata, - timestamp, - }) - } - } - } - (EventKind::ToolCallRequest(t_tool), EventKind::ToolCallRequest(i_tool)) => { - merge_tool_calls(t_tool, i_tool); - target.metadata.extend(metadata); - target.timestamp = timestamp; - - Ok(()) - } - // Mismatch in high-level types (e.g. ChatResponse vs ToolCallRequest). - // Reconstruct the event and return it as an error. - (_, other_kind) => Err(ConversationEvent { - kind: other_kind, - metadata, - timestamp, - }), - } -} - -/// Merges two `ChatResponse` items. -/// -/// Returns `Err(incoming)` if they are different variants (Message vs -/// Reasoning). -fn merge_chat_responses( - target: &mut ChatResponse, - incoming: ChatResponse, -) -> Result<(), ChatResponse> { - match (target, incoming) { - (ChatResponse::Message { message: t_msg }, ChatResponse::Message { message: i_msg }) => { - t_msg.push_str(&i_msg); - Ok(()) - } - ( - ChatResponse::Reasoning { reasoning: t_reas }, - ChatResponse::Reasoning { reasoning: i_reas }, - ) => { - t_reas.push_str(&i_reas); - Ok(()) - } - - // Variants didn't match. - (_, incoming) => Err(incoming), - } -} - -/// Merges two `ToolCallRequest` items. 
-fn merge_tool_calls(target: &mut ToolCallRequest, incoming: ToolCallRequest) { - if target.id.is_empty() && !incoming.id.is_empty() { - target.id = incoming.id; - } - - if target.name.is_empty() && !incoming.name.is_empty() { - target.name = incoming.name; - } - - for (key, val) in incoming.arguments { - match target.arguments.get_mut(&key) { - Some(existing_val) => { - if let (Value::String(s1), Value::String(s2)) = (existing_val, &val) { - s1.push_str(s2); - } else { - // Overwrite non-string values - *target.arguments.entry(key).or_insert(Value::Null) = val; - } - } - None => { - target.arguments.insert(key, val); - } - } - } -} diff --git a/crates/jp_llm/src/stream/aggregator/reasoning.rs b/crates/jp_llm/src/stream/aggregator/reasoning.rs index 3e8aa68c..34b9beb5 100644 --- a/crates/jp_llm/src/stream/aggregator/reasoning.rs +++ b/crates/jp_llm/src/stream/aggregator/reasoning.rs @@ -1,3 +1,6 @@ +//! An extractor that segments a stream of text into 'reasoning' and 'other' +//! buckets. + #[derive(Default, Debug)] /// A parser that segments a stream of text into 'reasoning' and 'other' /// buckets. It handles streams with or without a `` block. diff --git a/crates/jp_llm/src/stream/chain.rs b/crates/jp_llm/src/stream/chain.rs index 5a9fe19d..8bce4cb1 100644 --- a/crates/jp_llm/src/stream/chain.rs +++ b/crates/jp_llm/src/stream/chain.rs @@ -293,7 +293,11 @@ fn event_text_len(event: &Event) -> usize { event .as_conversation_event() .and_then(ConversationEvent::as_chat_response) - .map_or(0, |v| v.content().len()) + .map_or(0, |v| match v { + ChatResponse::Message { message } => message.len(), + ChatResponse::Reasoning { reasoning } => reasoning.len(), + ChatResponse::Structured { .. } => 0, + }) } /// Reconstruct text from a deque of events. 
@@ -305,10 +309,16 @@ fn reconstruct_text(events: &VecDeque) -> (String, Vec<(usize, usize)>) { let mut map = Vec::new(); for (i, event) in events.iter().enumerate() { - if let Some(content) = event + let content = event .as_conversation_event() .and_then(ConversationEvent::as_chat_response) - .map(ChatResponse::content) + .and_then(|v| match v { + ChatResponse::Message { message } => Some(message.as_str()), + ChatResponse::Reasoning { reasoning } => Some(reasoning.as_str()), + ChatResponse::Structured { .. } => None, + }); + + if let Some(content) = content && !content.is_empty() { s.push_str(content); @@ -324,7 +334,11 @@ fn trim_event_start(event: &mut Event, count: usize) { let Some(content) = event .as_conversation_event_mut() .and_then(ConversationEvent::as_chat_response_mut) - .map(ChatResponse::content_mut) + .and_then(|v| match v { + ChatResponse::Message { message } => Some(message), + ChatResponse::Reasoning { reasoning } => Some(reasoning), + ChatResponse::Structured { .. } => None, + }) else { return; }; diff --git a/crates/jp_llm/src/structured.rs b/crates/jp_llm/src/structured.rs deleted file mode 100644 index 1645c1f0..00000000 --- a/crates/jp_llm/src/structured.rs +++ /dev/null @@ -1,31 +0,0 @@ -//! Tools for requesting structured data from LLMs using tool calls. - -pub mod titles; - -use jp_config::model::id::ModelIdConfig; -use serde::de::DeserializeOwned; - -use crate::{error::Result, provider::Provider, query::StructuredQuery}; - -// Name of the schema enforcement tool -pub(crate) const SCHEMA_TOOL_NAME: &str = "generate_structured_data"; - -/// Request structured data from the LLM for any type `T` that implements -/// [`DeserializeOwned`]. -/// -/// It assumes a [`StructuredQuery`] that has a schema to enforce the correct -/// sturcute for `T`. -/// -/// If a LLM model enforces a JSON object as the response, but you want (e.g.) 
a -/// list of items, you can use [`StructuredQuery::with_mapping`] to map the -/// response object into the final shape of `T`. -pub async fn completion( - provider: &dyn Provider, - model_id: &ModelIdConfig, - query: StructuredQuery, -) -> Result { - let model = provider.model_details(&model_id.name).await?; - let value = provider.structured_completion(&model, query).await?; - - serde_json::from_value(value).map_err(Into::into) -} diff --git a/crates/jp_llm/src/structured/titles.rs b/crates/jp_llm/src/structured/titles.rs deleted file mode 100644 index fd2bad28..00000000 --- a/crates/jp_llm/src/structured/titles.rs +++ /dev/null @@ -1,83 +0,0 @@ -use std::error::Error; - -use jp_config::assistant::instructions::InstructionsConfig; -use jp_conversation::{ConversationStream, thread::ThreadBuilder}; -use serde_json::Value; - -use crate::query::StructuredQuery; - -pub fn titles( - count: usize, - events: ConversationStream, - rejected: &[String], -) -> Result> { - let schema = schemars::json_schema!({ - "type": "object", - "description": format!("Provide {count} concise, descriptive factual titles for this conversation."), - "required": ["titles"], - "additionalProperties": false, - "properties": { - "titles": { - "type": "array", - "items": { - "type": "string", - "description": "A concise, descriptive title for the conversation" - }, - }, - }, - }); - - // The validator schema is more strict than the schema we use to generate - // the titles, because not all providers support the full JSON schema - // feature-set. 
- let validator = schemars::json_schema!({ - "type": "object", - "required": ["titles"], - "additionalProperties": false, - "properties": { - "titles": { - "type": "array", - "items": { - "type": "string", - "minLength": 1, - "maxLength": 50, - }, - "minItems": count, - "maxItems": count, - }, - }, - }); - - let mut instructions = vec![ - InstructionsConfig::default() - .with_title("Title Generation") - .with_description("Generate titles to summarize the active conversation") - .with_item(format!("Generate exactly {count} titles")) - .with_item("Concise, descriptive, factual") - .with_item("Short and to the point, no more than 50 characters") - .with_item("Deliver as a JSON array of strings") - .with_item("DO NOT mention this request to generate titles"), - ]; - - if !rejected.is_empty() { - let mut instruction = InstructionsConfig::default() - .with_title("Rejected Titles") - .with_description("These listed titles were rejected by the user and must be avoided"); - - for title in rejected { - instruction = instruction.with_item(title); - } - - instructions.push(instruction); - } - - let mapping = |value: &mut Value| value.get_mut("titles").map(Value::take); - let thread = ThreadBuilder::default() - .with_events(events) - .with_instructions(instructions) - .build()?; - - Ok(StructuredQuery::new(schema, thread) - .with_mapping(mapping) - .with_schema_validator(validator)) -} diff --git a/crates/jp_llm/src/test.rs b/crates/jp_llm/src/test.rs index 3d812646..04f26b8a 100644 --- a/crates/jp_llm/src/test.rs +++ b/crates/jp_llm/src/test.rs @@ -17,39 +17,30 @@ use jp_config::{ use jp_conversation::{ ConversationEvent, ConversationStream, event::{ChatRequest, ToolCallResponse}, + event_builder::EventBuilder, stream::ConversationEventWithConfig, thread::{Thread, ThreadBuilder}, }; use jp_test::mock::{Snap, Vcr}; -use schemars::Schema; use crate::{ event::Event, model::{ModelDetails, ReasoningDetails}, provider::get_provider, - query::{ChatQuery, StructuredQuery}, - 
stream::aggregator::chunk::EventAggregator, + query::ChatQuery, tool::ToolDefinition, }; +#[allow(clippy::large_enum_variant)] pub enum TestRequest { /// A chat completion request. Chat { - stream: bool, model: ModelDetails, query: ChatQuery, #[expect(clippy::type_complexity)] assert: Arc])>, }, - /// A structured completion request. - Structured { - model: ModelDetails, - query: StructuredQuery, - #[expect(clippy::type_complexity)] - assert: Arc])>, - }, - /// List all models. Models { #[expect(clippy::type_complexity)] @@ -83,7 +74,6 @@ impl TestRequest { pub fn chat(provider: ProviderId) -> Self { Self::Chat { - stream: false, model: test_model_details(provider), query: ChatQuery { thread: ThreadBuilder::new() @@ -101,35 +91,11 @@ impl TestRequest { .unwrap(), tools: vec![], tool_choice: ToolChoice::default(), - tool_call_strict_mode: false, }, assert: Arc::new(|_| {}), } } - #[expect(dead_code)] - pub fn structured(provider: ProviderId) -> Self { - Self::Structured { - model: test_model_details(provider), - query: StructuredQuery::new( - true.into(), - ThreadBuilder::new() - .with_events({ - let mut config = AppConfig::new_test(); - config.assistant.model.id = ModelIdOrAliasConfig::Id(ModelIdConfig { - provider, - name: "test".parse().unwrap(), - }); - ConversationStream::new(config.into()) - .with_created_at(Utc.with_ymd_and_hms(2020, 1, 1, 0, 0, 0).unwrap()) - }) - .build() - .unwrap(), - ), - assert: Arc::new(|_| {}), - } - } - pub fn tool_call_response(result: Result<&str, &str>, panic_on_missing_request: bool) -> Self { Self::ToolCallResponse { result: result.map(Into::into).map_err(Into::into), @@ -144,14 +110,6 @@ impl TestRequest { } } - pub fn stream(mut self, stream: bool) -> Self { - if let Self::Chat { stream: s, .. 
} = &mut self { - *s = stream; - } - - self - } - pub fn model(mut self, model: ModelIdConfig) -> Self { let Some(thread) = self.as_thread_mut() else { return self; @@ -161,9 +119,8 @@ impl TestRequest { delta.assistant.model.id = PartialModelIdOrAliasConfig::Id(model.to_partial()); thread.events.add_config_delta(delta); - match &mut self { - Self::Chat { model: m, .. } | Self::Structured { model: m, .. } => m.id = model, - _ => {} + if let Self::Chat { model: m, .. } = &mut self { + m.id = model; } self @@ -211,22 +168,6 @@ impl TestRequest { self } - #[expect(dead_code)] - pub fn schema(self, schema: impl Into) -> Self { - match self { - Self::Structured { - model, - query, - assert, - } => Self::Structured { - model, - query: StructuredQuery::new(schema.into(), query.thread), - assert, - }, - _ => self, - } - } - pub fn tool_choice_fn(self, name: impl Into) -> Self { self.tool_choice(ToolChoice::Function(name.into())) } @@ -244,21 +185,12 @@ impl TestRequest { .into_iter() .map(|(k, v)| (k.to_owned(), v)) .collect(), - include_tool_answers_parameter: false, }); } self } - pub fn tool_call_strict_mode(mut self, strict: bool) -> Self { - if let Self::Chat { query, .. } = &mut self { - query.tool_call_strict_mode = strict; - } - - self - } - #[expect(dead_code)] pub fn assert_chat(mut self, assert: impl Fn(&[Vec]) + 'static) -> Self { if let Self::Chat { assert: a, .. } = &mut self { @@ -268,18 +200,6 @@ impl TestRequest { self } - #[expect(dead_code)] - pub fn assert_structured( - mut self, - assert: impl Fn(&[Result]) + 'static, - ) -> Self { - if let Self::Structured { assert: a, .. } = &mut self { - *a = Arc::new(assert); - } - - self - } - #[expect(dead_code)] pub fn assert_models(mut self, assert: impl Fn(&[ModelDetails]) + 'static) -> Self { if let Self::Models { assert: a, .. } = &mut self { @@ -292,7 +212,6 @@ impl TestRequest { pub fn as_thread(&self) -> Option<&Thread> { match self { Self::Chat { query, .. 
} => Some(&query.thread), - Self::Structured { query, .. } => Some(&query.thread), _ => None, } } @@ -300,14 +219,13 @@ impl TestRequest { pub fn as_thread_mut(&mut self) -> Option<&mut Thread> { match self { Self::Chat { query, .. } => Some(&mut query.thread), - Self::Structured { query, .. } => Some(&mut query.thread), _ => None, } } pub fn as_model_details_mut(&mut self) -> Option<&mut ModelDetails> { match self { - Self::Chat { model, .. } | Self::Structured { model, .. } => Some(model), + Self::Chat { model, .. } => Some(model), _ => None, } } @@ -316,19 +234,8 @@ impl TestRequest { impl std::fmt::Debug for TestRequest { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Chat { - stream, - model, - query, - .. - } => f + Self::Chat { model, query, .. } => f .debug_struct("Chat") - .field("stream", stream) - .field("model", model) - .field("query", query) - .finish(), - Self::Structured { model, query, .. } => f - .debug_struct("Structured") .field("model", model) .field("query", query) .finish(), @@ -414,9 +321,6 @@ pub async fn run_chat_completion( } let provider = get_provider(provider_id, &config).unwrap(); - let has_structured_request = requests - .iter() - .any(|v| matches!(v, TestRequest::Structured { .. })); let has_chat_request = requests .iter() .any(|v| matches!(v, TestRequest::Chat { .. })); @@ -433,7 +337,6 @@ pub async fn run_chat_completion( let mut all_events = vec![]; let mut history = vec![]; - let mut structured_history = vec![]; let mut model_details = vec![]; let mut models = vec![]; @@ -491,9 +394,6 @@ pub async fn run_chat_completion( TestRequest::Chat { query, .. } => { query.thread.events.config().unwrap().to_partial() } - TestRequest::Structured { query, .. } => { - query.thread.events.config().unwrap().to_partial() - } TestRequest::Models { .. } | TestRequest::ModelDetails { .. } => { PartialAppConfig::empty() } @@ -529,57 +429,75 @@ pub async fn run_chat_completion( // 3. 
Then we run the query, and collect the new events. match request { TestRequest::Chat { - stream, model, query, assert, } => { - let mut agg = EventAggregator::new(); - let events = if stream { - provider - .chat_completion_stream(&model, query) - .await - .unwrap() - .try_collect() - .await - .unwrap() - } else { - provider.chat_completion(&model, query).await.unwrap() - }; - - for mut event in events { - if let Event::Part { event, .. } = &mut event { - event.timestamp = - Utc.with_ymd_and_hms(2020, 1, 1, 0, 0, 0).unwrap(); - } - - all_events[index].push(event.clone()); - - for event in agg.ingest(event) { - if let Event::Part { event, .. } = event { - if let Some(stream) = conversation_stream.as_mut() { - stream.push(event.clone()); + let events: Vec = provider + .chat_completion_stream(&model, query) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + let stream = conversation_stream + .as_mut() + .expect("Chat request always sets conversation_stream"); + let mut builder = EventBuilder::new(); + + for llm_event in events { + match llm_event { + Event::Part { index: idx, event } => { + builder.handle_part(idx, event); + } + Event::Flush { + index: idx, + metadata, + } => { + if let Some(mut event) = builder.handle_flush(idx, metadata) { + // Normalize timestamp for deterministic snapshots. 
+ event.timestamp = + Utc.with_ymd_and_hms(2020, 1, 1, 0, 0, 0).unwrap(); + + all_events[index].push(Event::Part { + index: idx, + event: event.clone(), + }); + all_events[index].push(Event::flush(idx)); + + history.push(ConversationEventWithConfig { + event: event.clone(), + config: config.clone(), + }); + stream.push(event); + } + } + Event::Finished(reason) => { + for mut event in builder.drain() { + event.timestamp = + Utc.with_ymd_and_hms(2020, 1, 1, 0, 0, 0).unwrap(); + + all_events[index].push(Event::Part { + index: 0, + event: event.clone(), + }); + all_events[index].push(Event::flush(0)); + + history.push(ConversationEventWithConfig { + event: event.clone(), + config: config.clone(), + }); + stream.push(event); } - history.push(ConversationEventWithConfig { - event, - config: config.clone(), - }); + all_events[index].push(Event::Finished(reason)); } } } assert(&all_events); } - TestRequest::Structured { - model, - query, - assert, - } => { - let value = provider.structured_completion(&model, query).await; - structured_history.push(value); - assert(&structured_history); - } TestRequest::Models { assert } => { let value = provider.models().await.unwrap(); models.extend(value); @@ -605,18 +523,6 @@ pub async fn run_chat_completion( ]); } - if has_structured_request { - let out = structured_history - .into_iter() - .map(|v| match v { - Ok(value) => value, - Err(error) => format!("Error::{error:?}").into(), - }) - .collect::>(); - - outputs.push(("structured_outputs", Snap::json(out))); - } - if has_model_details_request { outputs.push(("model_details", Snap::debug(model_details))); } @@ -650,7 +556,7 @@ pub(crate) fn test_model_details(id: ProviderId) -> ModelDetails { display_name: None, context_window: Some(200_000), max_output_tokens: Some(64_000), - reasoning: Some(ReasoningDetails::budgetted(128, Some(24576))), + reasoning: Some(ReasoningDetails::budgetted(512, Some(24576))), knowledge_cutoff: None, deprecated: None, features: vec![], @@ -695,6 +601,7 @@ 
pub(crate) fn test_model_details(id: ProviderId) -> ModelDetails { deprecated: None, features: vec![], }, + ProviderId::Test => ModelDetails::empty("test/mock-model".parse().unwrap()), ProviderId::Xai => unimplemented!(), ProviderId::Deepseek => unimplemented!(), } diff --git a/crates/jp_llm/src/title.rs b/crates/jp_llm/src/title.rs new file mode 100644 index 00000000..6eced3f8 --- /dev/null +++ b/crates/jp_llm/src/title.rs @@ -0,0 +1,97 @@ +//! Shared title generation helpers. +//! +//! Provides the JSON schema and instructions used by background callers (e.g. +//! `TitleGeneratorTask`, `conversation edit --title`) to request conversation +//! titles from an LLM via structured output. + +use jp_config::assistant::{instructions::InstructionsConfig, sections::SectionConfig}; +use serde_json::{Map, Value, json}; + +/// JSON schema for the title generation structured output. +/// +/// Returns a schema requiring an object with a `titles` array of exactly +/// `count` string elements. +#[must_use] +#[allow(clippy::missing_panics_doc)] +pub fn title_schema(count: usize) -> Map { + let schema = json!({ + "type": "object", + "required": ["titles"], + "additionalProperties": false, + "properties": { + "titles": { + "type": "array", + "items": { + "type": "string", + "description": "A concise, descriptive title for the conversation" + }, + "minItems": count, + "maxItems": count, + }, + }, + }); + + schema + .as_object() + .expect("schema is always an object") + .clone() +} + +/// Build instruction sections for title generation. +/// +/// Returns one or two sections: the main generation instructions, and +/// optionally a "rejected titles" section if `rejected` is non-empty. 
+#[must_use] +pub fn title_instructions(count: usize, rejected: &[String]) -> Vec { + let mut sections = vec![ + InstructionsConfig::default() + .with_title("Title Generation") + .with_description("Generate titles to summarize the active conversation") + .with_item(format!("Generate exactly {count} titles")) + .with_item("Concise, descriptive, factual") + .with_item("Short and to the point, no more than 50 characters") + .with_item("Deliver as a JSON object with a \"titles\" array of strings") + .with_item("DO NOT mention this request to generate titles") + .to_section(), + ]; + + if !rejected.is_empty() { + let mut rejected_instruction = InstructionsConfig::default() + .with_title("Rejected Titles") + .with_description("These listed titles were rejected by the user and must be avoided"); + + for title in rejected { + rejected_instruction = rejected_instruction.with_item(title); + } + + sections.push(rejected_instruction.to_section()); + } + + sections +} + +/// Extract title strings from a structured JSON response. +/// +/// Expects a JSON object with a `titles` array of strings, e.g.: +/// +/// ```json +/// {"titles": ["My Title", "Another Title"]} +/// ``` +/// +/// Returns an empty vec if the structure doesn't match. 
+#[must_use] +pub fn extract_titles(data: &Value) -> Vec { + data.get("titles") + .and_then(Value::as_array) + .map(|arr| { + arr.iter() + .filter_map(Value::as_str) + .map(str::to_owned) + .collect() + }) + .unwrap_or_default() +} + +#[cfg(test)] +#[path = "title_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/title_tests.rs b/crates/jp_llm/src/title_tests.rs new file mode 100644 index 00000000..034709d1 --- /dev/null +++ b/crates/jp_llm/src/title_tests.rs @@ -0,0 +1,62 @@ +use super::*; + +#[test] +fn title_schema_has_correct_structure() { + let schema = title_schema(3); + + assert_eq!(schema["type"], "object"); + assert_eq!(schema["properties"]["titles"]["type"], "array"); + assert_eq!(schema["properties"]["titles"]["minItems"], 3); + assert_eq!(schema["properties"]["titles"]["maxItems"], 3); + assert!( + schema["required"] + .as_array() + .unwrap() + .contains(&json!("titles")) + ); + assert_eq!(schema["additionalProperties"], false); +} + +#[test] +fn title_schema_single_title() { + let schema = title_schema(1); + assert_eq!(schema["properties"]["titles"]["minItems"], 1); + assert_eq!(schema["properties"]["titles"]["maxItems"], 1); +} + +#[test] +fn title_instructions_without_rejected() { + let sections = title_instructions(3, &[]); + assert_eq!(sections.len(), 1); +} + +#[test] +fn title_instructions_with_rejected() { + let rejected = vec!["Bad Title".to_owned(), "Worse Title".to_owned()]; + let sections = title_instructions(3, &rejected); + assert_eq!(sections.len(), 2); +} + +#[test] +fn extract_titles_valid() { + let data = json!({"titles": ["Title A", "Title B"]}); + assert_eq!(extract_titles(&data), vec!["Title A", "Title B"]); +} + +#[test] +fn extract_titles_missing_key() { + let data = json!({"other": "value"}); + assert!(extract_titles(&data).is_empty()); +} + +#[test] +fn extract_titles_wrong_type() { + let data = json!({"titles": "not an array"}); + assert!(extract_titles(&data).is_empty()); +} + +#[test] +fn 
extract_titles_mixed_types_filters_non_strings() { + let data = json!({"titles": ["Valid", 42, null, "Also Valid"]}); + assert_eq!(extract_titles(&data), vec!["Valid", "Also Valid"]); +} diff --git a/crates/jp_llm/src/tool.rs b/crates/jp_llm/src/tool.rs index 8b98bb16..9e541542 100644 --- a/crates/jp_llm/src/tool.rs +++ b/crates/jp_llm/src/tool.rs @@ -1,26 +1,419 @@ -use std::{fmt::Write, sync::Arc}; +//! Tool call utilities. +pub mod builtin; +pub mod executor; + +use std::{ffi::OsStr, process::Stdio, sync::Arc}; + +pub use builtin::BuiltinTool; use camino::Utf8Path; -use crossterm::style::Stylize as _; -use indexmap::{IndexMap, IndexSet}; +use indexmap::IndexMap; use jp_config::conversation::tool::{ - OneOrManyTypes, QuestionConfig, QuestionTarget, ResultMode, RunMode, ToolCommandConfig, - ToolConfigWithDefaults, ToolParameterConfig, ToolSource, item::ToolParameterItemConfig, + OneOrManyTypes, ToolCommandConfig, ToolConfigWithDefaults, ToolParameterConfig, ToolSource, }; use jp_conversation::event::ToolCallResponse; -use jp_inquire::{InlineOption, InlineSelect}; use jp_mcp::{ RawContent, ResourceContents, id::{McpServerId, McpToolId}, }; -use jp_printer::PrinterWriter; -use jp_tool::{Action, Outcome}; +use jp_tool::{Action, Outcome, Question}; use minijinja::Environment; use serde_json::{Map, Value, json}; -use tracing::{error, info, trace}; +use tokio::process::Command; +use tokio_util::sync::CancellationToken; +use tracing::{info, trace}; use crate::error::ToolError; +/// Documentation for a single tool parameter. +#[derive(Debug, Clone)] +pub struct ParameterDocs { + pub summary: Option, + pub description: Option, + pub examples: Option, +} + +impl ParameterDocs { + #[must_use] + pub fn is_empty(&self) -> bool { + self.description.is_none() && self.examples.is_none() + } +} + +/// Documentation for a single tool. 
+#[derive(Debug, Clone)] +pub struct ToolDocs { + pub summary: Option, + pub description: Option, + pub examples: Option, + pub parameters: IndexMap, +} + +impl ToolDocs { + #[must_use] + pub fn is_empty(&self) -> bool { + self.description.is_none() + && self.examples.is_none() + && self.parameters.values().all(ParameterDocs::is_empty) + } +} + +/// The outcome of a tool execution. +/// +/// This type represents the possible results of executing a tool's underlying +/// command or MCP call, without any interactive prompts. The caller is +/// responsible for: +/// +/// 1. Handling permission prompts **before** calling +/// [`ToolDefinition::execute()`]. +/// 2. Handling [`ExecutionOutcome::NeedsInput`] by prompting the user or +/// assistant. +/// 3. Handling result editing **after** receiving the outcome. +/// +/// # Example Flow +/// +/// ```text +/// ToolExecutor (jp_cli) ToolDefinition (jp_llm) +/// ───────────────────── ────────────────────── +/// │ +/// ├── [AwaitingPermission] +/// │ prompt_permission() +/// │ +/// ├── [Running] +/// │ ────────────────────────────► execute() +/// │ │ +/// │ ◄──────────────────────────── ExecutionOutcome +/// ├── [AwaitingInput] (if NeedsInput) +/// │ prompt_question() +/// │ ────────────────────────────► execute() (with answer) +/// │ │ +/// │ ◄──────────────────────────── ExecutionOutcome +/// ├── [AwaitingResultEdit] +/// │ prompt_result_edit() +/// │ +/// └── [Completed] +/// ``` +#[derive(Debug)] +pub enum ExecutionOutcome { + /// Tool executed and produced a result. + Completed { + /// The tool call ID (for correlation with the request). + id: String, + + /// The execution result. + /// + /// If an error occurred, it means the tool ran, but reported an error. + result: Result, + }, + + /// Tool needs additional input before it can complete. + /// + /// The caller should: + /// 1. Present the question to the user (or delegate to the assistant) + /// 2. Collect the answer + /// 3. 
Call [`ToolDefinition::execute()`] again with the answer in `answers` + NeedsInput { + /// The tool call ID. + id: String, + + /// The question to ask. + question: Question, + }, + + /// Tool execution was cancelled via the cancellation token. + /// + /// This occurs when the user interrupts tool execution (e.g., Ctrl+C during + /// a long-running command). + Cancelled { + /// The tool call ID. + id: String, + }, +} + +impl ExecutionOutcome { + /// Convert the outcome to a [`ToolCallResponse`]. + /// + /// This is useful for building the final response to send to the LLM after + /// any post-processing (e.g., result editing) is complete. + /// + /// # Note + /// + /// For [`ExecutionOutcome::NeedsInput`], this returns a placeholder + /// response. The caller should typically handle `NeedsInput` specially + /// rather than converting it directly to a response. + #[must_use] + pub fn into_response(self) -> ToolCallResponse { + match self { + Self::Completed { id, result } => ToolCallResponse { id, result }, + Self::NeedsInput { id, question } => ToolCallResponse { + id, + result: Ok(format!("Tool requires additional input: {}", question.text)), + }, + Self::Cancelled { id } => ToolCallResponse { + id, + result: Ok("Tool execution cancelled by user.".to_string()), + }, + } + } + + /// Returns the tool call ID. + #[must_use] + pub fn id(&self) -> &str { + match self { + Self::Completed { id, .. } | Self::NeedsInput { id, .. } | Self::Cancelled { id } => id, + } + } + + /// Returns `true` if this is a `NeedsInput` outcome. + #[must_use] + pub fn needs_input(&self) -> bool { + matches!(self, Self::NeedsInput { .. }) + } + + /// Returns `true` if this is a `Cancelled` outcome. + #[must_use] + pub fn is_cancelled(&self) -> bool { + matches!(self, Self::Cancelled { .. }) + } + + /// Returns `true` if this is a `Completed` outcome with a successful result. + #[must_use] + pub fn is_success(&self) -> bool { + matches!(self, Self::Completed { result: Ok(_), .. 
}) + } +} + +/// Result of running a tool command. +/// +/// This is the single parsing point for all tool command output. Both tool +/// execution and argument formatting go through this type, ensuring consistent +/// handling of `Outcome` variants (including error traces). +#[derive(Debug)] +pub enum CommandResult { + /// Tool produced content. + Success(String), + + /// Tool reported a transient error (can be retried). + TransientError { + /// The error message. + message: String, + + /// The error trace (source chain from the tool process). + trace: Vec, + }, + + /// Tool reported a fatal error. + FatalError(String), + + /// Tool needs additional input before it can continue. + NeedsInput(Question), + + /// Tool was cancelled via the cancellation token. + Cancelled, + + /// stdout wasn't valid `Outcome` JSON. + /// + /// Falls back to treating stdout as plain text. The `success` flag + /// indicates the process exit status. + RawOutput { + /// Raw stdout content. + stdout: String, + + /// Raw stderr content. + stderr: String, + + /// Whether the process exited successfully. + success: bool, + }, +} + +impl CommandResult { + /// Format a transient error message including trace details. + /// + /// If the trace is empty, returns just the message. Otherwise appends + /// the trace entries so the LLM (or user) can see the root cause. + #[must_use] + pub fn format_error(message: &str, trace: &[String]) -> String { + if trace.is_empty() { + message.to_owned() + } else { + format!("{message}\n\nTrace:\n{}", trace.join("\n")) + } + } + + /// Convert to a `Result` suitable for tool call responses. 
+ /// + /// - `Success` → `Ok(content)` + /// - `TransientError` → `Err(json with message + trace)` + /// - `FatalError` → `Err(raw json)` + /// - `NeedsInput` → handled separately by callers (this panics) + /// - `Cancelled` → `Ok(cancellation message)` + /// - `RawOutput` → `Ok(stdout)` if success, `Err(json)` if failure + pub fn into_tool_result(self, name: &str) -> Result { + match self { + Self::Success(content) => Ok(content), + Self::TransientError { message, trace } => Err(json!({ + "message": message, + "trace": trace, + }) + .to_string()), + Self::FatalError(raw) => Err(raw), + Self::Cancelled => Ok("Tool execution cancelled by user.".to_string()), + Self::RawOutput { + stdout, + stderr, + success, + } => { + if success { + Ok(stdout) + } else { + Err(json!({ + "message": format!("Tool '{name}' execution failed."), + "stderr": stderr, + "stdout": stdout, + }) + .to_string()) + } + } + Self::NeedsInput(_) => { + unreachable!("NeedsInput should be handled by the caller") + } + } + } +} + +/// Run a tool command asynchronously with cancellation support. +/// +/// This is the **single entry point** for running tool commands (both execution +/// and argument formatting). It handles: +/// +/// 1. Template rendering via [`minijinja`] +/// 2. Process spawning via Tokio's [`Command`] +/// 3. Cancellation via [`CancellationToken`] +/// 4. 
Parsing stdout as [`jp_tool::Outcome`] +pub async fn run_tool_command( + command: ToolCommandConfig, + ctx: Value, + root: &Utf8Path, + cancellation_token: CancellationToken, +) -> Result { + let ToolCommandConfig { + program, + args, + shell, + } = command; + + let tmpl = Arc::new(Environment::new()); + + let program = tmpl + .render_str(&program, &ctx) + .map_err(|error| ToolError::TemplateError { + data: program.clone(), + error, + })?; + + let args = args + .iter() + .map(|s| tmpl.render_str(s, &ctx)) + .collect::, _>>() + .map_err(|error| ToolError::TemplateError { + data: args.join(" "), + error, + })?; + + let mut cmd = if shell { + let shell_cmd = std::iter::once(program.clone()) + .chain(args.iter().cloned()) + .collect::>() + .join(" "); + + let mut cmd = Command::new("sh"); + cmd.arg("-c").arg(&shell_cmd); + cmd + } else { + let mut cmd = Command::new(&program); + cmd.args(&args); + cmd + }; + + let child = cmd + .current_dir(root.as_std_path()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .map_err(|error| ToolError::SpawnError { + command: format!( + "{} {}", + cmd.as_std().get_program().to_string_lossy(), + cmd.as_std() + .get_args() + .filter_map(OsStr::to_str) + .collect::>() + .join(" ") + ), + error, + })?; + + let wait_handle = tokio::spawn(async move { child.wait_with_output().await }); + let abort_handle = wait_handle.abort_handle(); + + tokio::select! 
{ + biased; + () = cancellation_token.cancelled() => { + abort_handle.abort(); + Ok(CommandResult::Cancelled) + } + result = wait_handle => { + match result { + Ok(Ok(output)) => Ok(parse_command_output( + &output.stdout, + &output.stderr, + output.status.success(), + )), + Ok(Err(error)) => Ok(CommandResult::RawOutput { + stdout: String::new(), + stderr: error.to_string(), + success: false, + }), + Err(join_error) => Ok(CommandResult::RawOutput { + stdout: String::new(), + stderr: format!("Task panicked: {join_error}"), + success: false, + }), + } + } + } +} + +/// Parse raw command output into a [`CommandResult`]. +/// +/// Tries to deserialize stdout as [`jp_tool::Outcome`]. If that fails, +/// falls back to [`CommandResult::RawOutput`]. +fn parse_command_output(stdout: &[u8], stderr: &[u8], success: bool) -> CommandResult { + let stdout_str = String::from_utf8_lossy(stdout); + + match serde_json::from_str::(&stdout_str) { + Ok(Outcome::Success { content }) => CommandResult::Success(content), + Ok(Outcome::Error { + transient, + message, + trace, + }) => { + if transient { + CommandResult::TransientError { message, trace } + } else { + CommandResult::FatalError(stdout_str.into_owned()) + } + } + Ok(Outcome::NeedsInput { question }) => CommandResult::NeedsInput(question), + Err(_) => CommandResult::RawOutput { + stdout: stdout_str.into_owned(), + stderr: String::from_utf8_lossy(stderr).into_owned(), + success, + }, + } +} + /// The definition of a tool. /// /// The definition source is either a [`ToolConfig`] for `local` tools, or a @@ -33,13 +426,6 @@ pub struct ToolDefinition { pub name: String, pub description: Option, pub parameters: IndexMap, - - /// Whether the tool should include the `tool_answers` parameter. - /// - /// This is `true` for any tool that has its questions configured such that - /// at least one question has to be answered by the assistant instead of the - /// user. 
- pub include_tool_answers_parameter: bool, } impl ToolDefinition { @@ -48,15 +434,13 @@ impl ToolDefinition { source: &ToolSource, description: Option, parameters: IndexMap, - questions: &IndexMap, mcp_client: &jp_mcp::Client, ) -> Result { match &source { - ToolSource::Local { .. } => Ok(local_tool_definition( + ToolSource::Local { .. } | ToolSource::Builtin { .. } => Ok(local_tool_definition( name.to_owned(), description, parameters, - questions, )), ToolSource::Mcp { server, tool } => { mcp_tool_definition( @@ -69,126 +453,139 @@ impl ToolDefinition { ) .await } - ToolSource::Builtin { .. } => todo!(), } } - pub fn format_args( - &self, - name: Option<&str>, - cmd: &ToolCommandConfig, - arguments: &Map, - root: &Utf8Path, - ) -> Result, ToolError> { - let name = name.unwrap_or(&self.name); - if arguments.is_empty() { - return Ok(Ok(String::new())); - } - - let ctx = json!({ - "tool": { - "name": self.name, - "arguments": arguments, - }, - "context": { - "action": Action::FormatArguments, - "root": root.as_str(), - }, - }); - - run_cmd_with_ctx(name, cmd, &ctx, root) - } - - pub async fn call( + /// Execute the tool without any interactive prompts. + /// + /// This is a pure execution method that runs the tool's underlying command + /// or MCP call and returns an [`ExecutionOutcome`]. All interactive + /// decisions (permission prompts, result editing, question handling) are + /// the caller's responsibility. 
+ /// + /// # Arguments + /// + /// * `id` - The tool call ID for correlation with the request + /// * `arguments` - The tool arguments (caller is responsible for any pre-processing) + /// * `answers` - Pre-provided answers to tool questions (from previous `NeedsInput`) + /// * `config` - Tool configuration + /// * `mcp_client` - MCP client for MCP tool execution + /// * `root` - Working directory for local tool execution + /// * `cancellation_token` - Token to cancel long-running execution + /// * `builtin_executors` - Registry of builtin tools + /// + /// # Returns + /// + /// - [`ExecutionOutcome::Completed`] - Tool finished (check inner `Result` for success/error) + /// - [`ExecutionOutcome::NeedsInput`] - Tool needs user input to continue + /// - [`ExecutionOutcome::Cancelled`] - Execution was cancelled via the token + /// + /// # Errors + /// + /// Returns [`ToolError`] for infrastructure errors (spawn failure, missing + /// command, etc.). Tool-level errors (command returned non-zero) are + /// returned as `Ok(ExecutionOutcome::Completed { result: Err(...) })`. + /// + /// # Example + /// + /// ```ignore + /// loop { + /// match definition.execute(id, &args, &answers, ...).await? { + /// ExecutionOutcome::Completed { result, .. } => { + /// // Handle success or tool error + /// break result; + /// } + /// ExecutionOutcome::NeedsInput { question, .. } => { + /// // Prompt user for input + /// let answer = prompt_user(&question)?; + /// answers.insert(question.id, answer); + /// // Loop to retry with answer + /// } + /// ExecutionOutcome::Cancelled { .. 
} => { + /// break Ok("Cancelled".into()); + /// } + /// } + /// } + /// ``` + pub async fn execute( &self, id: String, - mut arguments: Value, + arguments: Value, answers: &IndexMap, - pending_questions: &IndexSet, + config: &ToolConfigWithDefaults, mcp_client: &jp_mcp::Client, - config: ToolConfigWithDefaults, root: &Utf8Path, - editor: Option<&Utf8Path>, - writer: PrinterWriter<'_>, - ) -> Result { - info!(tool = %self.name, arguments = ?arguments, "Calling tool."); - - // If the tool call has answers to provide to the tool, it means the - // tool already ran once, and we should not ask for confirmation again. - let run_mode = if pending_questions.is_empty() { - config.run() - } else { - RunMode::Unattended - }; + cancellation_token: CancellationToken, + builtin_executors: &builtin::BuiltinExecutors, + ) -> Result { + info!(tool = %self.name, arguments = ?arguments, "Executing tool."); - let mut result_mode = config.result(); - if answers.is_empty() { - self.prepare_run( - run_mode, - &mut result_mode, - &mut arguments, - config.source(), - mcp_client, - editor, - writer, - ) - .await?; - } - - let result = match config.source() { + match config.source() { ToolSource::Local { tool } => { - self.call_local(id, &arguments, answers, &config, tool.as_deref(), root)? + self.execute_local( + id, + arguments, + answers, + config, + tool.as_deref(), + root, + cancellation_token, + ) + .await } ToolSource::Mcp { server, tool } => { - self.call_mcp( + self.execute_mcp( id, - &arguments, + arguments, mcp_client, server.as_deref(), tool.as_deref(), + cancellation_token, ) - .await? + .await } - ToolSource::Builtin { .. } => todo!(), - }; - - trace!(result = ?result, "Tool call completed."); - self.prepare_result(result, result_mode, editor, writer) + ToolSource::Builtin { .. } => { + self.execute_builtin(id, &arguments, answers, builtin_executors) + .await + } + } } - fn call_local( + /// Execute a local tool and return the outcome. 
+ /// + /// This is the pure execution path for local tools. It validates arguments, + /// runs the command, and converts the result to an `ExecutionOutcome`. + async fn execute_local( &self, id: String, - arguments: &Value, + mut arguments: Value, answers: &IndexMap, config: &ToolConfigWithDefaults, tool: Option<&str>, root: &Utf8Path, - ) -> Result { + cancellation_token: CancellationToken, + ) -> Result { let name = tool.unwrap_or(&self.name); - // TODO: Should we enforce at a type-level this for all tool calls, even - // MCP? - if let Some(args) = arguments.as_object() - && let Err(error) = validate_tool_arguments( - args, - &config - .parameters() - .iter() - .map(|(k, v)| (k.to_owned(), v.required)) - .collect(), - ) - { - return Ok(ToolCallResponse { - id, - result: Err(format!("Invalid arguments: {error}")), - }); + // Apply configured defaults for missing parameters, then validate. + if let Some(args) = arguments.as_object_mut() { + apply_parameter_defaults(args, config.parameters()); + + if let Err(error) = validate_tool_arguments(args, config.parameters()) { + return Ok(ExecutionOutcome::Completed { + id, + result: Err(format!( + "Invalid arguments: {error}\n\nYou can call `describe_tools(tools: \ + [\"{name}\"])` to learn more about how to use the tool correctly." + )), + }); + } } let ctx = json!({ "tool": { "name": name, - "arguments": arguments, + "arguments": &arguments, "answers": answers, }, "context": { @@ -201,73 +598,123 @@ impl ToolDefinition { return Err(ToolError::MissingCommand); }; - Ok(ToolCallResponse { - id, - result: run_cmd_with_ctx(name, &command, &ctx, root)?, - }) + match run_tool_command(command, ctx, root, cancellation_token).await? 
{ + CommandResult::Success(content) => Ok(ExecutionOutcome::Completed { + id, + result: Ok(content), + }), + CommandResult::NeedsInput(question) => { + Ok(ExecutionOutcome::NeedsInput { id, question }) + } + CommandResult::Cancelled => Ok(ExecutionOutcome::Cancelled { id }), + other => Ok(ExecutionOutcome::Completed { + id, + result: other.into_tool_result(name), + }), + } } - async fn call_mcp( + /// Execute an MCP tool and return the outcome. + /// + /// This is the pure execution path for MCP tools. It calls the MCP server + /// and converts the result to an `ExecutionOutcome`. + async fn execute_mcp( &self, id: String, - arguments: &Value, + arguments: Value, mcp_client: &jp_mcp::Client, server: Option<&str>, tool: Option<&str>, - ) -> Result { + cancellation_token: CancellationToken, + ) -> Result { let name = tool.unwrap_or(&self.name); - let result = mcp_client - .call_tool(name, server, arguments) - .await - .map_err(ToolError::McpRunToolError)?; + let call_future = mcp_client.call_tool(name, server, &arguments); - let content = result - .content - .into_iter() - .filter_map(|v| match v.raw { - RawContent::Text(v) => Some(v.text), - RawContent::Resource(v) => match v.resource { - ResourceContents::TextResourceContents { text, .. } => Some(text), - ResourceContents::BlobResourceContents { blob, .. } => Some(blob), - }, - RawContent::Image(_) | RawContent::Audio(_) => None, - }) - .collect::>() - .join("\n\n"); + tokio::select! { + biased; + () = cancellation_token.cancelled() => { + info!(tool = %self.name, "MCP tool call cancelled"); + Ok(ExecutionOutcome::Cancelled { id }) + } + result = call_future => { + let result = result.map_err(ToolError::McpRunToolError)?; + + let content = result + .content + .into_iter() + .filter_map(|v| match v.raw { + RawContent::Text(v) => Some(v.text), + RawContent::Resource(v) => match v.resource { + ResourceContents::TextResourceContents { text, .. } => Some(text), + ResourceContents::BlobResourceContents { blob, .. 
} => Some(blob), + }, + RawContent::Image(_) | RawContent::Audio(_) => None, + }) + .collect::>() + .join("\n\n"); - Ok(ToolCallResponse { - id, - result: if result.is_error.unwrap_or_default() { - Err(content) - } else { - Ok(content) + let result = if result.is_error.unwrap_or_default() { + Err(content) + } else { + Ok(content) + }; + + Ok(ExecutionOutcome::Completed { id, result }) + } + } + } + + /// Execute a builtin tool and return the outcome. + async fn execute_builtin( + &self, + id: String, + arguments: &Value, + answers: &IndexMap, + builtin_executors: &builtin::BuiltinExecutors, + ) -> Result { + let executor = builtin_executors + .get(&self.name) + .ok_or_else(|| ToolError::NotFound { + name: self.name.clone(), + })?; + + let outcome = executor.execute(arguments, answers).await; + + Ok(match outcome { + jp_tool::Outcome::Success { content } => ExecutionOutcome::Completed { + id, + result: Ok(content), }, + jp_tool::Outcome::Error { + message, + trace, + transient: _, + } => { + let error_msg = if trace.is_empty() { + message + } else { + format!("{message}\n\nTrace:\n{}", trace.join("\n")) + }; + ExecutionOutcome::Completed { + id, + result: Err(error_msg), + } + } + jp_tool::Outcome::NeedsInput { question } => { + ExecutionOutcome::NeedsInput { id, question } + } }) } /// Return a map of parameter names to JSON schemas. #[must_use] pub fn to_parameters_map(&self) -> Map { - let mut map = self - .parameters + self.parameters .clone() .into_iter() .map(|(k, v)| (k, v.to_json_schema())) - .collect::>(); - - if self.include_tool_answers_parameter { - map.insert( - "tool_answers".to_owned(), - json!({ - "type": ["object", "null"], - "additionalProperties": true, - "description": "Answers to the tool's questions. This should only be used if explicitly requested by the user.", - }), - ); - } - - map + .collect() } /// Return a JSON schema for the parameters of the tool. 
@@ -287,421 +734,117 @@ impl ToolDefinition { "required": required, }) } +} - #[expect(clippy::too_many_lines)] - async fn prepare_run( - &self, - run_mode: RunMode, - result_mode: &mut ResultMode, - arguments: &mut Value, - source: &ToolSource, - mcp_client: &jp_mcp::Client, - editor: Option<&Utf8Path>, - mut writer: PrinterWriter<'_>, - ) -> Result<(), ToolError> { - match run_mode { - RunMode::Ask => match InlineSelect::new( - { - let mut question = format!( - "Run {} {} tool", - match source { - ToolSource::Builtin { .. } => "built-in", - ToolSource::Local { .. } => "local", - ToolSource::Mcp { .. } => "mcp", - }, - self.name.as_str().bold().yellow(), - ); - - if let ToolSource::Mcp { server, tool } = source { - let tool = McpToolId::new(tool.as_ref().unwrap_or(&self.name)); - let server = server.as_ref().map(|s| McpServerId::new(s.clone())); - let server_id = mcp_client - .get_tool_server_id(&tool, server.as_ref()) - .await - .map_err(ToolError::McpGetToolError)?; - - question = format!( - "{} from {} server?", - question, - server_id.as_str().bold().blue() - ); - } +/// Split a description string into a short summary and remaining detail. +/// +/// If the text is short (single line, ≤120 chars), it is returned as the +/// summary with no remaining description. +/// +/// Otherwise, the first sentence is extracted as the summary. A sentence +/// ends at `. ` or `.\n`. The remainder becomes the description. +pub(crate) fn split_description(text: &str) -> (String, Option) { + let text = text.trim(); + + // Find the first sentence boundary. + // Look for ". " or ".\n" — a period followed by whitespace. + for (i, _) in text.match_indices('.') { + let after = i + 1; + if after >= text.len() { + // Period at end of string — the whole text is one sentence. 
+ break; + } - question - }, - vec![ - InlineOption::new('y', "Run tool"), - InlineOption::new('n', "Skip running tool"), - InlineOption::new( - 'r', - format!( - "Change run mode (current: {})", - run_mode.to_string().italic().yellow() - ), - ), - InlineOption::new( - 'x', - format!( - "Change result mode (current: {})", - result_mode.to_string().italic().yellow() - ), - ), - InlineOption::new('p', "Print raw tool arguments"), - ], - ) - .prompt(&mut writer) - .unwrap_or('n') + let next_byte = text.as_bytes()[after]; + if next_byte == b'\n' { + // Period followed by newline is always a sentence boundary. + } else if next_byte == b' ' { + // Period followed by space: only split if the next non-space + // character is uppercase (heuristic to skip abbreviations + // like "e.g. foo"). + let rest_after_space = text[after..].trim_start(); + if rest_after_space.is_empty() + || !rest_after_space + .chars() + .next() + .is_some_and(char::is_uppercase) { - 'y' => return Ok(()), - 'n' => return Err(ToolError::Skipped { reason: None }), - 'r' => { - let new_run_mode = match InlineSelect::new("Run Mode", { - let mut options = vec![ - InlineOption::new('a', "Ask"), - InlineOption::new('u', "Unattended (Run Tool Without Changes)"), - InlineOption::new('e', "Edit Arguments"), - InlineOption::new('s', "Skip Call"), - ]; - - if editor.is_some() { - options.push(InlineOption::new('S', "Skip Call, with reasoning")); - } - - options.push(InlineOption::new('c', "Keep Current Run Mode")); - options - }) - .prompt(&mut writer) - .unwrap_or('c') - { - 'a' => RunMode::Ask, - 'u' => RunMode::Unattended, - 'e' => RunMode::Edit, - 's' => RunMode::Skip, - 'S' => match editor { - None => RunMode::Skip, - Some(editor) => { - return Err(ToolError::Skipped { - reason: Some( - open_editor::EditorCallBuilder::new() - .with_editor(open_editor::Editor::from_bin_path( - editor.into(), - )) - .edit_string( - "_Provide reasoning for skipping tool execution_", - ) - .map_err(|error| 
ToolError::OpenEditorError { - arguments: arguments.clone(), - error, - })?, - ), - }); - } - }, - 'c' => run_mode, - _ => unimplemented!(), - }; - - return Box::pin(self.prepare_run( - new_run_mode, - result_mode, - arguments, - source, - mcp_client, - editor, - writer, - )) - .await; - } - 'x' => { - match InlineSelect::new("Result Mode", vec![ - InlineOption::new('a', "Ask"), - InlineOption::new('u', "Unattended (Delver Payload As Is)"), - InlineOption::new('e', "Edit Result Payload"), - InlineOption::new('s', "Skip (Don't Deliver Payload)"), - InlineOption::new('c', "Keep Current Result Mode"), - ]) - .prompt(&mut writer) - .unwrap_or('c') - { - 'a' => *result_mode = ResultMode::Ask, - 'u' => *result_mode = ResultMode::Unattended, - 'e' => *result_mode = ResultMode::Edit, - 's' => *result_mode = ResultMode::Skip, - 'c' => {} - _ => unimplemented!(), - } - - return Box::pin(self.prepare_run( - run_mode, - result_mode, - arguments, - source, - mcp_client, - editor, - writer, - )) - .await; - } - 'p' => { - if let Err(error) = - writeln!(writer, "{}\n", serde_json::to_string_pretty(&arguments)?) 
- { - error!(%error, "Failed to write arguments"); - } - - return Box::pin(self.prepare_run( - RunMode::Ask, - result_mode, - arguments, - source, - mcp_client, - editor, - writer, - )) - .await; - } - _ => unreachable!(), - }, - RunMode::Unattended => return Ok(()), - RunMode::Skip => return Err(ToolError::Skipped { reason: None }), - RunMode::Edit => { - let mut args = serde_json::to_string_pretty(&arguments).map_err(|error| { - ToolError::SerializeArgumentsError { - arguments: arguments.clone(), - error, - } - })?; - - *arguments = { - if let Some(editor) = editor { - open_editor::EditorCallBuilder::new() - .with_editor(open_editor::Editor::from_bin_path(editor.into())) - .edit_string_mut(&mut args) - .map_err(|error| ToolError::OpenEditorError { - arguments: arguments.clone(), - error, - })?; - } - - // If the user removed all data from the arguments, we consider the - // edit a no-op, and ask the user if they want to run the tool. - if args.trim().is_empty() { - return Box::pin(self.prepare_run( - RunMode::Ask, - result_mode, - arguments, - source, - mcp_client, - editor, - writer, - )) - .await; - } - - match serde_json::from_str::(&args) { - Ok(value) => value, - - // If we can't parse the arguments as valid JSON, we consider - // the input invalid, and ask the user if they want to re-open - // the editor. 
- Err(error) => { - if let Err(error) = writeln!(writer, "JSON parsing error: {error}") { - error!(%error, "Failed to write error"); - } - - let retry = InlineSelect::new("Re-open editor?", vec![ - InlineOption::new('y', "Open editor to edit arguments"), - InlineOption::new('n', "Skip editing, failing with error"), - ]) - .with_default('y') - .prompt(&mut writer) - .unwrap_or('n'); - - if retry == 'n' { - return Err(ToolError::EditArgumentsError { - arguments: arguments.clone(), - error, - }); - } - - return Box::pin(self.prepare_run( - RunMode::Edit, - result_mode, - arguments, - source, - mcp_client, - editor, - writer, - )) - .await; - } - } - }; + continue; } + } else { + continue; } - Ok(()) - } + { + let summary = text[..=i].trim().to_owned(); + let rest = text[after..].trim(); - fn prepare_result( - &self, - mut result: ToolCallResponse, - result_mode: ResultMode, - editor: Option<&Utf8Path>, - mut writer: PrinterWriter<'_>, - ) -> Result { - match result_mode { - ResultMode::Ask => match InlineSelect::new( - format!( - "Deliver the results of the {} tool call?", - self.name.as_str().bold().yellow(), - ), - vec![ - InlineOption::new('y', "Deliver results"), - InlineOption::new('n', "Do not deliver results"), - InlineOption::new('e', "Edit results manually"), - ], - ) - .with_default('y') - .prompt(&mut writer) - .unwrap_or('n') - { - 'y' => return Ok(result), - 'n' => { - return Ok(ToolCallResponse { - id: result.id, - result: Ok("Tool call result omitted by user.".into()), - }); - } - 'e' => {} - _ => unreachable!(), - }, - ResultMode::Unattended => return Ok(result), - ResultMode::Skip => { - return Ok(ToolCallResponse { - id: result.id, - result: Ok("Tool ran successfully.".into()), - }); + if rest.is_empty() { + return (summary, None); } - ResultMode::Edit => {} + + return (summary, Some(rest.to_owned())); } + } - if let Some(editor) = editor { - let content = open_editor::EditorCallBuilder::new() - 
.with_editor(open_editor::Editor::from_bin_path(editor.into())) - .edit_string(result.content()) - .map_err(|error| ToolError::OpenEditorError { - arguments: Value::Null, - error, - })?; - - // If the user removed all data from the result, we consider the edit a - // no-op, and ask the user if they want to deliver the tool results. - if content.trim().is_empty() { - return self.prepare_result(result, ResultMode::Ask, Some(editor), writer); - } + // No sentence boundary found — take the first line. + if let Some(nl) = text.find('\n') { + let summary = text[..nl].trim().to_owned(); + let rest = text[nl..].trim(); - result.result = Ok(content); + if rest.is_empty() { + return (summary, None); } - Ok(result) + return (summary, Some(rest.to_owned())); } -} - -fn run_cmd_with_ctx( - name: &str, - command: &ToolCommandConfig, - ctx: &Value, - root: &Utf8Path, -) -> Result, ToolError> { - let command = { - let tmpl = Arc::new(Environment::new()); - - let program = - tmpl.render_str(&command.program, ctx) - .map_err(|error| ToolError::TemplateError { - data: command.program.clone(), - error, - })?; - - let args = command - .args - .iter() - .map(|s| tmpl.render_str(s, ctx)) - .collect::, _>>() - .map_err(|error| ToolError::TemplateError { - data: command.args.join(" ").clone(), - error, - })?; - - let expression = if command.shell { - let cmd = std::iter::once(program.clone()) - .chain(args.iter().cloned()) - .collect::>() - .join(" "); - duct_sh::sh_dangerous(cmd) - } else { - duct::cmd(program.clone(), args) - }; + // Single long line, no period — return as-is. + (text.to_owned(), None) +} - expression - .dir(root) - .unchecked() - .stdout_capture() - .stderr_capture() - }; +/// Fill in configured default values for missing parameters. +/// +/// LLMs commonly omit parameters that have a `default` in the JSON schema, +/// even when those parameters are marked `required`. 
This function patches +/// the arguments map before validation so that such omissions don't cause +/// spurious "missing argument" errors and unnecessary LLM retries. +fn apply_parameter_defaults( + arguments: &mut Map, + parameters: &IndexMap, +) { + for (name, cfg) in parameters { + if !arguments.contains_key(name) { + if let Some(default) = &cfg.default { + arguments.insert(name.clone(), default.clone()); + } + continue; + } - match command.run() { - Ok(output) => { - let stdout = String::from_utf8_lossy(&output.stdout); - let content = match serde_json::from_str::(&stdout) { - Err(_) => stdout.to_string(), - Ok(Outcome::Error { - transient, - message, - trace, - }) => { - if transient { - return Ok(Err(json!({ - "message": message, - "trace": trace, - }) - .to_string())); - } + // Recurse into object fields. + if let Some(obj) = arguments.get_mut(name).and_then(Value::as_object_mut) + && !cfg.properties.is_empty() + { + apply_parameter_defaults(obj, &cfg.properties); + } - return Err(ToolError::ToolCallFailed(stdout.to_string())); - } - Ok(Outcome::Success { content }) => content, - Ok(Outcome::NeedsInput { question }) => { - return Err(ToolError::NeedsInput { question }); + // Recurse into array elements. 
+ if let Some(items) = &cfg.items + && !items.properties.is_empty() + && let Some(arr) = arguments.get_mut(name).and_then(Value::as_array_mut) + { + for elem in arr.iter_mut() { + if let Some(obj) = elem.as_object_mut() { + apply_parameter_defaults(obj, &items.properties); } - }; - - if output.status.success() { - Ok(Ok(content)) - } else { - let stderr = String::from_utf8_lossy(&output.stderr); - Ok(Err(json!({ - "message": format!("Tool '{name}' execution failed."), - "stderr": stderr, - "stdout": content, - }) - .to_string())) } } - Err(error) => Ok(Err(json!({ - "message": format!( - "Failed to execute command '{command:?}': {error}", - ), - }) - .to_string())), } } fn validate_tool_arguments( arguments: &Map, - parameters: &IndexMap, + parameters: &IndexMap, ) -> Result<(), ToolError> { let unknown = arguments .keys() @@ -710,8 +853,8 @@ fn validate_tool_arguments( .collect::>(); let mut missing = vec![]; - for (name, required) in parameters { - if *required && !arguments.contains_key(name) { + for (name, cfg) in parameters { + if cfg.required && !arguments.contains_key(name) { missing.push(name.to_owned()); } } @@ -720,51 +863,241 @@ fn validate_tool_arguments( return Err(ToolError::Arguments { missing, unknown }); } + // Recurse into nested structures. + for (name, cfg) in parameters { + let Some(value) = arguments.get(name) else { + continue; + }; + + // Object parameters with properties: validate the object fields. + if let Some(obj) = value.as_object() + && !cfg.properties.is_empty() + { + validate_tool_arguments(obj, &cfg.properties)?; + } + + // Array parameters with items that have properties: validate each + // element. + if let Some(items) = &cfg.items + && !items.properties.is_empty() + && let Some(arr) = value.as_array() + { + for element in arr { + if let Some(obj) = element.as_object() { + validate_tool_arguments(obj, &items.properties)?; + } + } + } + } + Ok(()) } +/// Resolved tool definitions and their on-demand documentation. 
+pub struct ResolvedTools { + /// Tool definitions sent to the LLM provider. + pub definitions: Vec, + + /// Per-tool documentation for `describe_tools`, keyed by tool name. + pub docs: IndexMap, +} + pub async fn tool_definitions( configs: impl Iterator, mcp_client: &jp_mcp::Client, -) -> Result, ToolError> { +) -> Result { let mut definitions = vec![]; + let mut docs = IndexMap::new(); + for (name, config) in configs { // Skip disabled tools. if !config.enable() { continue; } - definitions.push( - ToolDefinition::new( + let (definition, tool_docs) = resolve_tool(name, &config, mcp_client).await?; + + if !tool_docs.is_empty() { + docs.insert(name.to_owned(), tool_docs); + } + + definitions.push(definition); + } + + Ok(ResolvedTools { definitions, docs }) +} + +/// Resolve a single tool definition and its documentation. +async fn resolve_tool( + name: &str, + config: &ToolConfigWithDefaults, + mcp_client: &jp_mcp::Client, +) -> Result<(ToolDefinition, ToolDocs), ToolError> { + match config.source() { + ToolSource::Local { .. } | ToolSource::Builtin { .. } => { + // For local/builtin tools, docs come from config fields. + let definition = ToolDefinition::new( name, config.source(), - config.description().map(str::to_owned), + config.summary().map(str::to_owned), config.parameters().clone(), - config.questions(), mcp_client, ) - .await?, - ); + .await?; + + let tool_docs = docs_from_config(config); + Ok((definition, tool_docs)) + } + ToolSource::Mcp { .. } => { + // For MCP tools, resolve against the server, then build docs + // from config overrides + auto-split of MCP descriptions. + resolve_mcp_tool(name, config, mcp_client).await + } + } +} + +/// Build `ToolDocs` from config fields (local/builtin tools). 
+fn docs_from_config(config: &ToolConfigWithDefaults) -> ToolDocs { + let summary = config.summary().map(str::to_owned); + let description = config.description().map(str::to_owned); + let examples = config.examples().map(str::to_owned); + + let parameters = config + .parameters() + .iter() + .filter_map(|(param_name, param_cfg)| { + let summary = param_cfg + .summary + .as_deref() + .or(param_cfg.description.as_deref()) + .map(str::to_owned); + let desc = param_cfg.description.as_deref().map(str::to_owned); + let ex = param_cfg.examples.as_deref().map(str::to_owned); + + if summary.is_none() && desc.is_none() && ex.is_none() { + return None; + } + + Some((param_name.to_owned(), ParameterDocs { + summary, + description: desc, + examples: ex, + })) + }) + .collect(); + + ToolDocs { + summary, + description, + examples, + parameters, + } +} + +/// Resolve an MCP tool: build definition + docs with auto-split heuristic. +async fn resolve_mcp_tool( + name: &str, + config: &ToolConfigWithDefaults, + mcp_client: &jp_mcp::Client, +) -> Result<(ToolDefinition, ToolDocs), ToolError> { + let has_user_summary = config.summary().is_some(); + + let definition = ToolDefinition::new( + name, + config.source(), + config.summary().map(str::to_owned), + config.parameters().clone(), + mcp_client, + ) + .await?; + + // Build docs. If the user provided summary/description/examples in + // config, use those. Otherwise, auto-split the resolved MCP description. + let (summary, description) = if has_user_summary { + // User provided explicit summary — use config fields as-is. + ( + config.summary().map(str::to_owned), + config.description().map(str::to_owned), + ) + } else if let Some(resolved) = &definition.description { + // No user summary — auto-split the MCP description. + let (s, d) = split_description(resolved); + (Some(s), d) + } else { + (None, None) + }; + + let examples = config.examples().map(str::to_owned); + + // Build parameter docs. 
For each parameter, check if the user + // provided an override. If not, auto-split the MCP description. + let parameters = definition + .parameters + .iter() + .filter_map(|(param_name, param_cfg)| { + let user_override = config.parameters().get(param_name); + let has_user_param_summary = user_override.and_then(|o| o.summary.as_ref()).is_some(); + + let (summary, desc) = if has_user_param_summary { + // User provided explicit summary for this parameter. + let summary = user_override + .and_then(|o| o.summary.as_deref()) + .or(user_override.and_then(|o| o.description.as_deref())) + .map(str::to_owned); + let desc = user_override + .and_then(|o| o.description.as_deref()) + .map(str::to_owned); + (summary, desc) + } else if let Some(resolved) = ¶m_cfg.description { + // Auto-split the resolved (possibly MCP) description. + let (s, d) = split_description(resolved); + (Some(s), d) + } else { + (None, None) + }; + + let ex = user_override + .and_then(|o| o.examples.as_deref()) + .map(str::to_owned); + + if summary.is_none() && desc.is_none() && ex.is_none() { + return None; + } + + Some((param_name.to_owned(), ParameterDocs { + summary, + description: desc, + examples: ex, + })) + }) + .collect(); + + // The definition's description should be the summary (short) for the + // provider API. Replace it if we auto-split. 
+ let mut definition = definition; + if !has_user_summary && let Some(ref s) = summary { + definition.description = Some(s.clone()); } - Ok(definitions) + let tool_docs = ToolDocs { + summary, + description, + examples, + parameters, + }; + + Ok((definition, tool_docs)) } fn local_tool_definition( name: String, description: Option, parameters: IndexMap, - questions: &IndexMap, ) -> ToolDefinition { - let include_tool_answers_parameter = questions - .iter() - .any(|(_, v)| v.target == QuestionTarget::Assistant); - ToolDefinition { name, description, parameters, - include_tool_answers_parameter, } } @@ -901,11 +1234,13 @@ async fn mcp_tool_definition( params.insert(name.to_owned(), ToolParameterConfig { kind, default, - description, required, + summary: None, + description, + examples: None, enumeration, items: opts.get("items").and_then(|v| v.as_object()).and_then(|v| { - Some(ToolParameterItemConfig { + Some(Box::new(ToolParameterConfig { kind: match v.get("type")? { Value::String(v) => OneOrManyTypes::One(v.to_owned()), Value::Array(v) => OneOrManyTypes::Many( @@ -917,10 +1252,16 @@ async fn mcp_tool_definition( _ => return None, }, default: None, + required: false, + summary: None, description: None, + examples: None, enumeration: vec![], - }) + items: None, + properties: IndexMap::default(), + })) }), + properties: IndexMap::default(), }); } @@ -928,65 +1269,9 @@ async fn mcp_tool_definition( name: name.to_owned(), description, parameters: params, - include_tool_answers_parameter: false, }) } #[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_validate_tool_arguments() { - struct TestCase { - arguments: Map, - parameters: IndexMap, - want: Result<(), ToolError>, - } - - let cases = vec![ - ("empty", TestCase { - arguments: Map::new(), - parameters: IndexMap::new(), - want: Ok(()), - }), - ("correct", TestCase { - arguments: Map::from_iter([("foo".to_owned(), json!("bar"))]), - parameters: IndexMap::from_iter([ - ("foo".to_owned(), true), - 
("bar".to_owned(), false), - ]), - want: Ok(()), - }), - ("missing", TestCase { - arguments: Map::new(), - parameters: IndexMap::from_iter([("foo".to_owned(), true)]), - want: Err(ToolError::Arguments { - missing: vec!["foo".to_owned()], - unknown: vec![], - }), - }), - ("unknown", TestCase { - arguments: Map::from_iter([("foo".to_owned(), json!("bar"))]), - parameters: IndexMap::from_iter([("bar".to_owned(), false)]), - want: Err(ToolError::Arguments { - missing: vec![], - unknown: vec!["foo".to_owned()], - }), - }), - ("both", TestCase { - arguments: Map::from_iter([("foo".to_owned(), json!("bar"))]), - parameters: IndexMap::from_iter([("bar".to_owned(), true)]), - want: Err(ToolError::Arguments { - missing: vec!["bar".to_owned()], - unknown: vec!["foo".to_owned()], - }), - }), - ]; - - for (name, test_case) in cases { - let result = validate_tool_arguments(&test_case.arguments, &test_case.parameters); - assert_eq!(result, test_case.want, "failed case: {name}"); - } - } -} +#[path = "tool_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/tool/builtin.rs b/crates/jp_llm/src/tool/builtin.rs new file mode 100644 index 00000000..b247c43c --- /dev/null +++ b/crates/jp_llm/src/tool/builtin.rs @@ -0,0 +1,45 @@ +//! Builtin tool trait and executor registry. +//! +//! Maps tool names to their Rust implementations. + +pub mod describe_tools; + +use std::{collections::HashMap, sync::Arc}; + +use async_trait::async_trait; +use indexmap::IndexMap; +use jp_tool::Outcome; +use serde_json::Value; + +/// A built-in tool that executes Rust code instead of shelling out. +#[async_trait] +pub trait BuiltinTool: Send + Sync { + /// Execute the tool with the given arguments and accumulated answers. + async fn execute(&self, arguments: &Value, answers: &IndexMap) -> Outcome; +} + +/// Registry mapping builtin tool names to their executors. 
+#[derive(Clone, Default)] +pub struct BuiltinExecutors { + executors: HashMap>, +} + +impl BuiltinExecutors { + #[must_use] + pub fn new() -> Self { + Self { + executors: HashMap::new(), + } + } + + #[must_use] + pub fn register(mut self, name: impl Into, tool: impl BuiltinTool + 'static) -> Self { + self.executors.insert(name.into(), Arc::new(tool)); + self + } + + #[must_use] + pub fn get(&self, name: &str) -> Option> { + self.executors.get(name).cloned() + } +} diff --git a/crates/jp_llm/src/tool/builtin/describe_tools.rs b/crates/jp_llm/src/tool/builtin/describe_tools.rs new file mode 100644 index 00000000..2d4ab907 --- /dev/null +++ b/crates/jp_llm/src/tool/builtin/describe_tools.rs @@ -0,0 +1,126 @@ +//! The `describe_tools` builtin implementation. + +use async_trait::async_trait; +use indexmap::IndexMap; +use jp_tool::Outcome; +use serde_json::Value; + +use crate::tool::{BuiltinTool, ToolDocs}; + +pub struct DescribeTools { + docs: IndexMap, +} + +impl DescribeTools { + #[must_use] + pub fn new(docs: IndexMap) -> Self { + Self { docs } + } + + fn format_tool_docs(name: &str, docs: &ToolDocs) -> String { + let mut out = format!("## {name}\n"); + + if let Some(summary) = &docs.summary { + out.push('\n'); + out.push_str(summary); + out.push('\n'); + } + + if let Some(desc) = &docs.description { + out.push('\n'); + out.push_str(desc); + out.push('\n'); + } + + if let Some(examples) = &docs.examples { + out.push_str("\n### Examples\n\n"); + out.push_str(examples); + out.push('\n'); + } + + let has_param_docs = docs.parameters.values().any(|p| !p.is_empty()); + if has_param_docs { + out.push_str("\n### Parameters\n"); + + for (param_name, param_docs) in &docs.parameters { + if param_docs.is_empty() { + continue; + } + + out.push_str(&format!("\n#### `{param_name}`\n")); + + if let Some(summary) = ¶m_docs.summary { + out.push('\n'); + out.push_str(summary); + out.push('\n'); + } + + if let Some(desc) = ¶m_docs.description { + out.push('\n'); + out.push_str(desc); 
+ out.push('\n'); + } + + if let Some(examples) = ¶m_docs.examples { + out.push('\n'); + out.push_str(examples); + out.push('\n'); + } + } + } + + out + } +} + +#[async_trait] +impl BuiltinTool for DescribeTools { + async fn execute(&self, arguments: &Value, _answers: &IndexMap) -> Outcome { + let tool_names = match arguments.get("tools").and_then(Value::as_array) { + Some(arr) => arr.iter().filter_map(Value::as_str).collect::>(), + None => { + return Outcome::Error { + message: "Missing or invalid `tools` parameter.".to_owned(), + trace: vec![], + transient: false, + }; + } + }; + + if tool_names.is_empty() { + return Outcome::Error { + message: "The `tools` array must not be empty.".to_owned(), + trace: vec![], + transient: false, + }; + } + + let mut sections = Vec::new(); + let mut not_found = Vec::new(); + + for name in &tool_names { + match self.docs.get(*name) { + Some(docs) => sections.push(Self::format_tool_docs(name, docs)), + None => not_found.push(*name), + } + } + + let mut output = sections.join("\n---\n\n"); + + if !not_found.is_empty() { + if !output.is_empty() { + output.push_str("\n---\n\n"); + } + output.push_str(&format!( + "No additional documentation available for: {}", + not_found.join(", ") + )); + } + + Outcome::Success { content: output } + } +} + +#[cfg(test)] +#[path = "describe_tools_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/tool/builtin/describe_tools_tests.rs b/crates/jp_llm/src/tool/builtin/describe_tools_tests.rs new file mode 100644 index 00000000..e250ef25 --- /dev/null +++ b/crates/jp_llm/src/tool/builtin/describe_tools_tests.rs @@ -0,0 +1,343 @@ +use indexmap::IndexMap; +use jp_tool::Outcome; +use serde_json::{Value, json}; + +use super::*; +use crate::tool::{ParameterDocs, ToolDocs}; + +fn empty_tool_docs() -> ToolDocs { + ToolDocs { + summary: None, + description: None, + examples: None, + parameters: IndexMap::new(), + } +} + +fn param_docs( + summary: Option<&str>, + description: Option<&str>, + examples: 
Option<&str>, +) -> ParameterDocs { + ParameterDocs { + summary: summary.map(str::to_owned), + description: description.map(str::to_owned), + examples: examples.map(str::to_owned), + } +} + +fn no_answers() -> IndexMap { + IndexMap::new() +} + +#[test] +fn test_format_empty_docs() { + let out = DescribeTools::format_tool_docs("my_tool", &empty_tool_docs()); + assert_eq!(out, "## my_tool\n"); +} + +#[test] +fn test_format_summary_only() { + let docs = ToolDocs { + summary: Some("A brief summary.".to_owned()), + ..empty_tool_docs() + }; + let out = DescribeTools::format_tool_docs("my_tool", &docs); + assert_eq!(out, "## my_tool\n\nA brief summary.\n"); +} + +#[test] +fn test_format_description_only() { + let docs = ToolDocs { + description: Some("Detailed description.".to_owned()), + ..empty_tool_docs() + }; + let out = DescribeTools::format_tool_docs("my_tool", &docs); + assert_eq!(out, "## my_tool\n\nDetailed description.\n"); +} + +#[test] +fn test_format_summary_and_description() { + let docs = ToolDocs { + summary: Some("Summary.".to_owned()), + description: Some("Description.".to_owned()), + ..empty_tool_docs() + }; + let out = DescribeTools::format_tool_docs("my_tool", &docs); + assert_eq!(out, "## my_tool\n\nSummary.\n\nDescription.\n"); +} + +#[test] +fn test_format_all_tool_fields() { + let docs = ToolDocs { + summary: Some("Summary.".to_owned()), + description: Some("Description.".to_owned()), + examples: Some("my_tool()".to_owned()), + parameters: IndexMap::new(), + }; + let out = DescribeTools::format_tool_docs("my_tool", &docs); + assert_eq!( + out, + "## my_tool\n\nSummary.\n\nDescription.\n\n### Examples\n\nmy_tool()\n" + ); +} + +#[test] +fn test_format_parameter_with_description() { + let mut parameters = IndexMap::new(); + parameters.insert( + "input".to_owned(), + param_docs(Some("Short summary."), Some("Long description."), None), + ); + + let docs = ToolDocs { + parameters, + ..empty_tool_docs() + }; + let out = 
DescribeTools::format_tool_docs("my_tool", &docs); + assert_eq!( + out, + "## my_tool\n\n### Parameters\n\n#### `input`\n\nShort summary.\n\nLong description.\n" + ); +} + +#[test] +fn test_format_parameter_with_examples() { + let mut parameters = IndexMap::new(); + parameters.insert( + "query".to_owned(), + param_docs(None, Some("The search query."), Some("\"hello world\"")), + ); + + let docs = ToolDocs { + parameters, + ..empty_tool_docs() + }; + let out = DescribeTools::format_tool_docs("my_tool", &docs); + assert_eq!( + out, + "## my_tool\n\n### Parameters\n\n#### `query`\n\nThe search query.\n\n\"hello world\"\n" + ); +} + +#[test] +fn test_format_parameter_all_fields() { + let mut parameters = IndexMap::new(); + parameters.insert( + "path".to_owned(), + param_docs( + Some("File path."), + Some("Absolute or relative path to the target file."), + Some("\"src/main.rs\""), + ), + ); + + let docs = ToolDocs { + parameters, + ..empty_tool_docs() + }; + let out = DescribeTools::format_tool_docs("my_tool", &docs); + assert_eq!( + out, + "## my_tool\n\n### Parameters\n\n#### `path`\n\nFile path.\n\nAbsolute or relative path \ + to the target file.\n\n\"src/main.rs\"\n" + ); +} + +#[test] +fn test_format_skips_empty_parameters() { + // is_empty() checks description and examples only, so a param with only + // summary is still "empty" and is excluded from the output. 
+ let mut parameters = IndexMap::new(); + parameters.insert( + "documented".to_owned(), + param_docs(None, Some("Has description."), None), + ); + parameters.insert("undocumented".to_owned(), param_docs(None, None, None)); + parameters.insert( + "summary_only".to_owned(), + param_docs(Some("Only a summary."), None, None), + ); + + let docs = ToolDocs { + parameters, + ..empty_tool_docs() + }; + let out = DescribeTools::format_tool_docs("my_tool", &docs); + assert!( + out.contains("#### `documented`"), + "documented param should appear" + ); + assert!( + !out.contains("#### `undocumented`"), + "undocumented param should be skipped" + ); + assert!( + !out.contains("#### `summary_only`"), + "summary-only param should be skipped" + ); +} + +#[test] +fn test_format_no_parameters_section_when_all_params_empty() { + // If every parameter is empty, the "### Parameters" section is omitted. + let mut parameters = IndexMap::new(); + parameters.insert("a".to_owned(), param_docs(None, None, None)); + parameters.insert("b".to_owned(), param_docs(Some("summary only"), None, None)); + + let docs = ToolDocs { + parameters, + ..empty_tool_docs() + }; + let out = DescribeTools::format_tool_docs("my_tool", &docs); + assert!(!out.contains("### Parameters")); + assert_eq!(out, "## my_tool\n"); +} + +#[tokio::test] +async fn test_execute_missing_tools_argument() { + let tool = DescribeTools::new(IndexMap::new()); + let result = tool.execute(&json!({}), &no_answers()).await; + let Outcome::Error { + message, transient, .. + } = result + else { + panic!("expected Outcome::Error"); + }; + assert!(message.contains("`tools`")); + assert!(!transient); +} + +#[tokio::test] +async fn test_execute_tools_not_an_array() { + let tool = DescribeTools::new(IndexMap::new()); + let result = tool + .execute(&json!({"tools": "my_tool"}), &no_answers()) + .await; + assert!( + matches!(result, Outcome::Error { .. 
}), + "non-array `tools` should be an error" + ); +} + +#[tokio::test] +async fn test_execute_empty_tools_array() { + let tool = DescribeTools::new(IndexMap::new()); + let result = tool.execute(&json!({"tools": []}), &no_answers()).await; + let Outcome::Error { message, .. } = result else { + panic!("expected Outcome::Error"); + }; + assert!(message.contains("must not be empty")); +} + +#[tokio::test] +async fn test_execute_single_known_tool() { + let mut docs = IndexMap::new(); + docs.insert("my_tool".to_owned(), ToolDocs { + summary: Some("Tool summary.".to_owned()), + ..empty_tool_docs() + }); + + let tool = DescribeTools::new(docs); + let result = tool + .execute(&json!({"tools": ["my_tool"]}), &no_answers()) + .await; + + let Outcome::Success { content } = result else { + panic!("expected Outcome::Success"); + }; + assert_eq!(content, "## my_tool\n\nTool summary.\n"); +} + +#[tokio::test] +async fn test_execute_known_tool_with_empty_docs() { + let mut docs = IndexMap::new(); + docs.insert("bare_tool".to_owned(), empty_tool_docs()); + + let tool = DescribeTools::new(docs); + let result = tool + .execute(&json!({"tools": ["bare_tool"]}), &no_answers()) + .await; + + let Outcome::Success { content } = result else { + panic!("expected Outcome::Success"); + }; + assert_eq!(content, "## bare_tool\n"); +} + +#[tokio::test] +async fn test_execute_multiple_known_tools_separated_by_divider() { + let mut docs = IndexMap::new(); + docs.insert("tool_a".to_owned(), ToolDocs { + summary: Some("A.".to_owned()), + ..empty_tool_docs() + }); + docs.insert("tool_b".to_owned(), ToolDocs { + summary: Some("B.".to_owned()), + ..empty_tool_docs() + }); + + let tool = DescribeTools::new(docs); + let result = tool + .execute(&json!({"tools": ["tool_a", "tool_b"]}), &no_answers()) + .await; + + let Outcome::Success { content } = result else { + panic!("expected Outcome::Success"); + }; + assert_eq!(content, "## tool_a\n\nA.\n\n---\n\n## tool_b\n\nB.\n"); +} + +#[tokio::test] +async fn 
test_execute_single_unknown_tool() { + let tool = DescribeTools::new(IndexMap::new()); + let result = tool + .execute(&json!({"tools": ["unknown_tool"]}), &no_answers()) + .await; + + let Outcome::Success { content } = result else { + panic!("expected Outcome::Success"); + }; + assert_eq!( + content, + "No additional documentation available for: unknown_tool" + ); +} + +#[tokio::test] +async fn test_execute_multiple_unknown_tools() { + let tool = DescribeTools::new(IndexMap::new()); + let result = tool + .execute(&json!({"tools": ["foo", "bar"]}), &no_answers()) + .await; + + let Outcome::Success { content } = result else { + panic!("expected Outcome::Success"); + }; + assert_eq!( + content, + "No additional documentation available for: foo, bar" + ); +} + +#[tokio::test] +async fn test_execute_mixed_known_and_unknown_tools() { + let mut docs = IndexMap::new(); + docs.insert("known".to_owned(), ToolDocs { + summary: Some("Summary.".to_owned()), + ..empty_tool_docs() + }); + + let tool = DescribeTools::new(docs); + let result = tool + .execute(&json!({"tools": ["known", "unknown"]}), &no_answers()) + .await; + + let Outcome::Success { content } = result else { + panic!("expected Outcome::Success"); + }; + assert_eq!( + content, + "## known\n\nSummary.\n\n---\n\nNo additional documentation available for: unknown" + ); +} diff --git a/crates/jp_llm/src/tool/executor.rs b/crates/jp_llm/src/tool/executor.rs new file mode 100644 index 00000000..dbde0514 --- /dev/null +++ b/crates/jp_llm/src/tool/executor.rs @@ -0,0 +1,346 @@ +use std::sync::Mutex; + +use async_trait::async_trait; +use camino::Utf8Path; +use indexmap::IndexMap; +use jp_config::conversation::tool::{RunMode, ToolSource, ToolsConfig}; +use jp_conversation::event::{ToolCallRequest, ToolCallResponse}; +use jp_mcp::Client; +use jp_tool::Question; +use serde_json::{Map, Value}; +use tokio_util::sync::CancellationToken; + +use crate::ToolError; + +/// Trait for tool execution, enabling mock implementations for 
testing. +/// +/// This trait abstracts the execution of a single tool call, allowing the +/// `ToolCoordinator` to work with both real and mock executors. +/// +/// # Design +/// +/// The executor is intentionally simple - it just executes tools with given +/// answers. All decision-making about question targets, static answers, and how +/// to handle `NeedsInput` is done by the coordinator, which has access to the +/// tool configuration. +#[async_trait] +pub trait Executor: Send + Sync { + /// Returns the tool call ID. + fn tool_id(&self) -> &str; + + /// Returns the tool name. + fn tool_name(&self) -> &str; + + /// Returns the tool call arguments. + /// + /// This is separate from [`permission_info()`](Self::permission_info) + /// because arguments are always available, while permission info is only + /// present for tools that require a permission prompt. + fn arguments(&self) -> &Map; + + /// Returns information needed for permission prompting. + /// + /// Returns `None` if the tool doesn't need a permission prompt (e.g., + /// `RunMode::Unattended` or `RunMode::Skip`). + fn permission_info(&self) -> Option; + + /// Updates the arguments to use for execution. + /// + /// This is called after permission prompting if the user edited the + /// arguments (via `RunMode::Edit`). The new arguments replace the original + /// arguments from the tool call request. + fn set_arguments(&mut self, args: Value); + + /// Executes the tool once with the given answers. + /// + /// This method performs a single execution pass. If the tool needs + /// additional input, it returns `ExecutorResult::NeedsInput` and the + /// coordinator handles prompting and retrying. + /// + /// The executor doesn't know how questions should be answered - it just + /// reports that input is needed. The coordinator looks up the tool + /// configuration to determine whether to prompt the user or ask the LLM. 
+ /// + /// # Arguments + /// + /// * `answers` - Accumulated answers from previous `NeedsInput` responses + /// * `mcp_client` - MCP client for remote tool execution + /// * `root` - Project root directory + /// * `cancellation_token` - Token to cancel execution + async fn execute( + &self, + answers: &IndexMap, + mcp_client: &Client, + root: &Utf8Path, + cancellation_token: CancellationToken, + ) -> ExecutorResult; +} + +/// Abstraction over how executors are created for tool calls. +/// +/// This trait enables dependency injection of executor creation, allowing tests +/// to use mock executors without executing real shell commands. +#[async_trait] +pub trait ExecutorSource: Send + Sync { + /// Creates an executor for the given tool call request. + /// + /// # Arguments + /// + /// * `request` - The tool call request from the LLM + /// * `tools_config` - Configuration for all tools + /// * `mcp_client` - MCP client for remote tool execution + /// + /// # Errors + /// + /// Returns an error if the tool is not found or cannot be initialized. + async fn create( + &self, + request: ToolCallRequest, + tools_config: &ToolsConfig, + mcp_client: &Client, + ) -> Result, ToolError>; +} + +/// Result of a tool execution attempt. +/// +/// Tools may need multiple rounds of execution if they require additional +/// input. This enum allows the executor to return control to the coordinator, +/// which decides how to handle the `NeedsInput` case by looking up the question +/// configuration. +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] // NeedsInput variant is larger but rarely used +pub enum ExecutorResult { + /// Tool completed (success or error). + Completed(ToolCallResponse), + + /// Tool needs additional input before it can continue. + /// + /// The executor doesn't know who should answer - it just reports that input + /// is needed. 
The coordinator looks up the question configuration to + /// determine the target: + /// + /// - `User`: Prompt the user interactively, then restart the tool + /// - `Assistant`: Format a response asking the LLM to re-run with answers + NeedsInput { + /// Tool call ID. + tool_id: String, + + /// Tool name (for persisting answers). + tool_name: String, + + /// The question that needs to be answered. + question: Question, + + /// Accumulated answers so far (for retry). + accumulated_answers: IndexMap, + }, +} + +/// A mock executor for testing that returns pre-configured results. +/// +/// This executor doesn't execute any real commands - it simply returns whatever +/// result is configured, making it ideal for testing tool coordination flows +/// without side effects. +/// +/// # Example +/// +/// ```ignore +/// let executor = MockExecutor::completed("call_1", "my_tool", "success output"); +/// let result = executor.execute(&answers, &client, &root, token).await; +/// assert!(result.is_completed()); +/// ``` +pub struct MockExecutor { + tool_id: String, + tool_name: String, + arguments: Map, + permission_info: Option, + result: Mutex>, +} + +impl MockExecutor { + /// Creates a mock executor that returns a successful completion. + #[must_use] + pub fn completed(tool_id: &str, tool_name: &str, output: &str) -> Self { + Self { + tool_id: tool_id.to_string(), + tool_name: tool_name.to_string(), + arguments: Map::new(), + permission_info: None, + result: Mutex::new(Some(ExecutorResult::Completed(ToolCallResponse { + id: tool_id.to_string(), + result: Ok(output.to_string()), + }))), + } + } + + /// Creates a mock executor that returns an error. 
+ #[must_use] + pub fn error(tool_id: &str, tool_name: &str, error: &str) -> Self { + Self { + tool_id: tool_id.to_string(), + tool_name: tool_name.to_string(), + arguments: Map::new(), + permission_info: None, + result: Mutex::new(Some(ExecutorResult::Completed(ToolCallResponse { + id: tool_id.to_string(), + result: Err(error.to_string()), + }))), + } + } + + /// Sets the arguments for this executor. + #[must_use] + pub fn with_arguments(mut self, args: Map) -> Self { + self.arguments = args; + self + } + + /// Sets the permission info for this executor. + /// + /// If set, the executor will require permission prompting based on the + /// configured `RunMode`. + #[must_use] + pub fn with_permission_info(mut self, info: PermissionInfo) -> Self { + self.permission_info = Some(info); + self + } + + /// Sets a custom result for this executor. + #[must_use] + pub fn with_result(mut self, result: ExecutorResult) -> Self { + self.result = Mutex::new(Some(result)); + self + } +} + +#[async_trait] +impl Executor for MockExecutor { + fn tool_id(&self) -> &str { + &self.tool_id + } + + fn tool_name(&self) -> &str { + &self.tool_name + } + + fn arguments(&self) -> &Map { + &self.arguments + } + + fn permission_info(&self) -> Option { + self.permission_info.clone() + } + + fn set_arguments(&mut self, _args: Value) { + // No-op for mock executor - arguments don't affect the pre-configured + // result + } + + async fn execute( + &self, + _answers: &IndexMap, + _mcp_client: &Client, + _root: &Utf8Path, + _cancellation_token: CancellationToken, + ) -> ExecutorResult { + self.result.lock().unwrap().take().unwrap_or_else(|| { + ExecutorResult::Completed(ToolCallResponse { + id: self.tool_id.clone(), + result: Err("MockExecutor: result already consumed".to_string()), + }) + }) + } +} + +/// An executor source for testing that returns pre-registered mock executors. 
+/// +/// This allows tests to inject mock executors for specific tool names without +/// executing any real shell commands. +/// +/// # Example +/// +/// ```ignore +/// let source = TestExecutorSource::new() +/// .with_executor("my_tool", |req| { +/// Box::new(MockExecutor::completed(&req.id, &req.name, "mock output")) +/// }); +/// +/// let coordinator = ToolCoordinator::new(tools_config, Arc::new(source)); +/// ``` +pub struct TestExecutorSource { + #[allow(clippy::type_complexity)] + factories: std::collections::HashMap< + String, + Box Box + Send + Sync>, + >, +} + +impl TestExecutorSource { + /// Creates a new empty test executor source. + #[must_use] + pub fn new() -> Self { + Self { + factories: std::collections::HashMap::new(), + } + } + + /// Registers a factory function for a tool name. + /// + /// When `create()` is called for this tool name, the factory will be + /// invoked to create the executor. + #[must_use] + pub fn with_executor(mut self, tool_name: &str, factory: F) -> Self + where + F: Fn(ToolCallRequest) -> Box + Send + Sync + 'static, + { + self.factories + .insert(tool_name.to_string(), Box::new(factory)); + self + } +} + +impl Default for TestExecutorSource { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl ExecutorSource for TestExecutorSource { + async fn create( + &self, + request: ToolCallRequest, + _tools_config: &ToolsConfig, + _mcp_client: &Client, + ) -> Result, ToolError> { + if let Some(factory) = self.factories.get(&request.name) { + Ok(factory(request)) + } else { + Err(ToolError::NotFound { + name: request.name.clone(), + }) + } + } +} + +/// Information needed to prompt for tool execution permission. +/// +/// This struct contains all the data the `ToolPrompter` needs to show a +/// permission prompt to the user. +#[derive(Debug, Clone)] +pub struct PermissionInfo { + /// The tool call ID. + pub tool_id: String, + + /// The tool name. + pub tool_name: String, + + /// The tool source (builtin, local, MCP). 
+ pub tool_source: ToolSource, + + /// The configured run mode. + pub run_mode: RunMode, + + /// The arguments to pass to the tool. + pub arguments: Value, +} diff --git a/crates/jp_llm/src/tool_tests.rs b/crates/jp_llm/src/tool_tests.rs new file mode 100644 index 00000000..d7e19e5b --- /dev/null +++ b/crates/jp_llm/src/tool_tests.rs @@ -0,0 +1,567 @@ +use jp_tool::AnswerType; + +use super::*; + +#[test] +fn test_execution_outcome_completed_success_into_response() { + let outcome = ExecutionOutcome::Completed { + id: "call_123".to_string(), + result: Ok("Tool output".to_string()), + }; + + let response = outcome.into_response(); + assert_eq!(response.id, "call_123"); + assert_eq!(response.result, Ok("Tool output".to_string())); +} + +#[test] +fn test_execution_outcome_completed_error_into_response() { + let outcome = ExecutionOutcome::Completed { + id: "call_456".to_string(), + result: Err("Tool failed".to_string()), + }; + + let response = outcome.into_response(); + assert_eq!(response.id, "call_456"); + assert_eq!(response.result, Err("Tool failed".to_string())); +} + +#[test] +fn test_execution_outcome_needs_input_into_response() { + let question = Question { + id: "q1".to_string(), + text: "What is your name?".to_string(), + answer_type: AnswerType::Text, + default: None, + }; + + let outcome = ExecutionOutcome::NeedsInput { + id: "call_789".to_string(), + question, + }; + + let response = outcome.into_response(); + assert_eq!(response.id, "call_789"); + assert!(response.result.is_ok()); + assert!( + response + .result + .unwrap() + .contains("requires additional input") + ); +} + +#[test] +fn test_execution_outcome_cancelled_into_response() { + let outcome = ExecutionOutcome::Cancelled { + id: "call_abc".to_string(), + }; + + let response = outcome.into_response(); + assert_eq!(response.id, "call_abc"); + assert!(response.result.is_ok()); + assert!(response.result.unwrap().contains("cancelled")); +} + +#[test] +fn test_execution_outcome_id() { + let completed 
= ExecutionOutcome::Completed { + id: "id1".to_string(), + result: Ok(String::new()), + }; + assert_eq!(completed.id(), "id1"); + + let needs_input = ExecutionOutcome::NeedsInput { + id: "id2".to_string(), + question: Question { + id: "q".to_string(), + text: "?".to_string(), + answer_type: AnswerType::Text, + default: None, + }, + }; + assert_eq!(needs_input.id(), "id2"); + + let cancelled = ExecutionOutcome::Cancelled { + id: "id3".to_string(), + }; + assert_eq!(cancelled.id(), "id3"); +} + +#[test] +fn test_execution_outcome_helper_methods() { + let success = ExecutionOutcome::Completed { + id: "1".to_string(), + result: Ok("output".to_string()), + }; + assert!(success.is_success()); + assert!(!success.needs_input()); + assert!(!success.is_cancelled()); + + let failure = ExecutionOutcome::Completed { + id: "2".to_string(), + result: Err("error".to_string()), + }; + assert!(!failure.is_success()); + assert!(!failure.needs_input()); + assert!(!failure.is_cancelled()); + + let needs_input = ExecutionOutcome::NeedsInput { + id: "3".to_string(), + question: Question { + id: "q".to_string(), + text: "?".to_string(), + answer_type: AnswerType::Boolean, + default: None, + }, + }; + assert!(!needs_input.is_success()); + assert!(needs_input.needs_input()); + assert!(!needs_input.is_cancelled()); + + let cancelled = ExecutionOutcome::Cancelled { + id: "4".to_string(), + }; + assert!(!cancelled.is_success()); + assert!(!cancelled.needs_input()); + assert!(cancelled.is_cancelled()); +} + +/// Build a minimal `ToolParameterConfig` for use in validation tests. 
+fn param(kind: &str, required: bool) -> ToolParameterConfig { + ToolParameterConfig { + kind: kind.to_owned().into(), + required, + default: None, + summary: None, + description: None, + examples: None, + enumeration: vec![], + items: None, + properties: IndexMap::default(), + } +} + +#[test] +fn test_validate_tool_arguments() { + struct TestCase { + arguments: Map, + parameters: IndexMap, + want: Result<(), ToolError>, + } + + let cases = vec![ + ("empty", TestCase { + arguments: Map::new(), + parameters: IndexMap::new(), + want: Ok(()), + }), + ("correct", TestCase { + arguments: Map::from_iter([("foo".to_owned(), json!("bar"))]), + parameters: IndexMap::from_iter([ + ("foo".to_owned(), param("string", true)), + ("bar".to_owned(), param("string", false)), + ]), + want: Ok(()), + }), + ("missing", TestCase { + arguments: Map::new(), + parameters: IndexMap::from_iter([("foo".to_owned(), param("string", true))]), + want: Err(ToolError::Arguments { + missing: vec!["foo".to_owned()], + unknown: vec![], + }), + }), + ("unknown", TestCase { + arguments: Map::from_iter([("foo".to_owned(), json!("bar"))]), + parameters: IndexMap::from_iter([("bar".to_owned(), param("string", false))]), + want: Err(ToolError::Arguments { + missing: vec![], + unknown: vec!["foo".to_owned()], + }), + }), + ("both", TestCase { + arguments: Map::from_iter([("foo".to_owned(), json!("bar"))]), + parameters: IndexMap::from_iter([("bar".to_owned(), param("string", true))]), + want: Err(ToolError::Arguments { + missing: vec!["bar".to_owned()], + unknown: vec!["foo".to_owned()], + }), + }), + ]; + + for (name, test_case) in cases { + let result = validate_tool_arguments(&test_case.arguments, &test_case.parameters); + assert_eq!(result, test_case.want, "failed case: {name}"); + } +} + +#[test] +fn test_validate_nested_array_item_properties() { + // Mirrors the fs_modify_file schema: + // patterns: array of { old: string (required), new: string (required) } + let parameters = IndexMap::from_iter([ + 
("path".to_owned(), param("string", true)), + ("patterns".to_owned(), ToolParameterConfig { + kind: "array".to_owned().into(), + required: true, + items: Some(Box::new(ToolParameterConfig { + kind: "object".to_owned().into(), + required: false, + properties: IndexMap::from_iter([ + ("old".to_owned(), param("string", true)), + ("new".to_owned(), param("string", true)), + ]), + ..param("object", false) + })), + ..param("array", true) + }), + ]); + + // Valid: correct inner fields. + let args = json!({ + "path": "src/lib.rs", + "patterns": [{"old": "foo", "new": "bar"}] + }); + assert_eq!( + validate_tool_arguments(args.as_object().unwrap(), ¶meters), + Ok(()) + ); + + // Valid: multiple items. + let args = json!({ + "path": "src/lib.rs", + "patterns": [ + {"old": "a", "new": "b"}, + {"old": "c", "new": "d"} + ] + }); + assert_eq!( + validate_tool_arguments(args.as_object().unwrap(), ¶meters), + Ok(()) + ); + + // Invalid: unknown inner field. + let args = json!({ + "path": "src/lib.rs", + "patterns": [{"old": "foo", "new": "bar", "extra": true}] + }); + assert_eq!( + validate_tool_arguments(args.as_object().unwrap(), ¶meters), + Err(ToolError::Arguments { + missing: vec![], + unknown: vec!["extra".to_owned()], + }) + ); + + // Invalid: missing required inner field. + let args = json!({ + "path": "src/lib.rs", + "patterns": [{"old": "foo"}] + }); + assert_eq!( + validate_tool_arguments(args.as_object().unwrap(), ¶meters), + Err(ToolError::Arguments { + missing: vec!["new".to_owned()], + unknown: vec![], + }) + ); + + // Invalid: wrong inner field names (the LLM hallucinated names). 
+ let args = json!({ + "path": "src/lib.rs", + "patterns": [{"string_to_replace": "foo", "new_string": "bar"}] + }); + let err = validate_tool_arguments(args.as_object().unwrap(), ¶meters); + assert!(err.is_err()); + let ToolError::Arguments { missing, unknown } = err.unwrap_err() else { + panic!("expected Arguments error"); + }; + assert_eq!(missing, vec!["old".to_owned(), "new".to_owned()]); + // preserve_order: keys iterate in insertion order from json! macro + assert_eq!(unknown, vec![ + "string_to_replace".to_owned(), + "new_string".to_owned() + ]); + + // Valid: non-object array items are skipped (no crash). + let args = json!({ + "path": "src/lib.rs", + "patterns": ["not an object"] + }); + assert_eq!( + validate_tool_arguments(args.as_object().unwrap(), ¶meters), + Ok(()) + ); + + // Valid: parameter is not an array (type mismatch, but not our job to check types). + let args = json!({ + "path": "src/lib.rs", + "patterns": "not an array" + }); + assert_eq!( + validate_tool_arguments(args.as_object().unwrap(), ¶meters), + Ok(()) + ); +} + +#[test] +fn test_validate_nested_object_properties() { + let parameters = IndexMap::from_iter([ + ("name".to_owned(), param("string", true)), + ("config".to_owned(), ToolParameterConfig { + kind: "object".to_owned().into(), + required: false, + properties: IndexMap::from_iter([ + ("verbose".to_owned(), param("boolean", false)), + ("output".to_owned(), param("string", true)), + ]), + ..param("object", false) + }), + ]); + + // Valid. + let args = json!({ "name": "test", "config": { "verbose": true, "output": "out.txt" } }); + assert_eq!( + validate_tool_arguments(args.as_object().unwrap(), ¶meters), + Ok(()) + ); + + // Valid: optional object param omitted entirely. + let args = json!({ "name": "test" }); + assert_eq!( + validate_tool_arguments(args.as_object().unwrap(), ¶meters), + Ok(()) + ); + + // Invalid: unknown field inside the object. 
+ let args = json!({ "name": "test", "config": { "output": "o", "bogus": 1 } }); + assert_eq!( + validate_tool_arguments(args.as_object().unwrap(), ¶meters), + Err(ToolError::Arguments { + missing: vec![], + unknown: vec!["bogus".to_owned()], + }) + ); + + // Invalid: missing required field inside the object. + let args = json!({ "name": "test", "config": { "verbose": true } }); + assert_eq!( + validate_tool_arguments(args.as_object().unwrap(), ¶meters), + Err(ToolError::Arguments { + missing: vec!["output".to_owned()], + unknown: vec![], + }) + ); +} + +/// Build a parameter with a default value. +fn param_with_default(kind: &str, required: bool, default: Value) -> ToolParameterConfig { + ToolParameterConfig { + default: Some(default), + ..param(kind, required) + } +} + +#[test] +fn test_apply_defaults_fills_missing_required_with_default() { + let parameters = IndexMap::from_iter([ + ("path".to_owned(), param("string", true)), + ( + "use_regex".to_owned(), + param_with_default("boolean", true, json!(false)), + ), + ]); + + let mut args: Map = Map::from_iter([("path".to_owned(), json!("src/lib.rs"))]); + + apply_parameter_defaults(&mut args, ¶meters); + + assert_eq!(args.get("path"), Some(&json!("src/lib.rs"))); + assert_eq!(args.get("use_regex"), Some(&json!(false))); +} + +#[test] +fn test_apply_defaults_does_not_overwrite_provided_values() { + let parameters = IndexMap::from_iter([( + "use_regex".to_owned(), + param_with_default("boolean", true, json!(false)), + )]); + + let mut args: Map = Map::from_iter([("use_regex".to_owned(), json!(true))]); + + apply_parameter_defaults(&mut args, ¶meters); + + assert_eq!(args.get("use_regex"), Some(&json!(true))); +} + +#[test] +fn test_apply_defaults_fills_optional_param_with_default() { + let parameters = IndexMap::from_iter([( + "verbose".to_owned(), + param_with_default("boolean", false, json!(false)), + )]); + + let mut args: Map = Map::new(); + apply_parameter_defaults(&mut args, ¶meters); + + 
assert_eq!(args.get("verbose"), Some(&json!(false))); +} + +#[test] +fn test_apply_defaults_skips_params_without_default() { + let parameters = IndexMap::from_iter([("path".to_owned(), param("string", true))]); + + let mut args: Map = Map::new(); + apply_parameter_defaults(&mut args, ¶meters); + + assert!(!args.contains_key("path")); +} + +#[test] +fn test_apply_defaults_recurses_into_objects() { + let parameters = IndexMap::from_iter([("config".to_owned(), ToolParameterConfig { + kind: "object".to_owned().into(), + required: false, + properties: IndexMap::from_iter([( + "verbose".to_owned(), + param_with_default("boolean", false, json!(true)), + )]), + ..param("object", false) + })]); + + let mut args: Map = Map::from_iter([("config".to_owned(), json!({}))]); + + apply_parameter_defaults(&mut args, ¶meters); + + assert_eq!(args["config"]["verbose"], json!(true)); +} + +#[test] +fn test_apply_defaults_recurses_into_array_items() { + let parameters = IndexMap::from_iter([("items".to_owned(), ToolParameterConfig { + kind: "array".to_owned().into(), + required: true, + items: Some(Box::new(ToolParameterConfig { + kind: "object".to_owned().into(), + required: false, + properties: IndexMap::from_iter([( + "enabled".to_owned(), + param_with_default("boolean", false, json!(true)), + )]), + ..param("object", false) + })), + ..param("array", true) + })]); + + let mut args: Map = Map::from_iter([( + "items".to_owned(), + json!([{"name": "a"}, {"name": "b", "enabled": false}]), + )]); + + apply_parameter_defaults(&mut args, ¶meters); + + let items = args["items"].as_array().unwrap(); + assert_eq!(items[0]["enabled"], json!(true)); + // Explicitly provided false is preserved. + assert_eq!(items[1]["enabled"], json!(false)); +} + +#[test] +fn test_apply_defaults_then_validate_passes() { + // Mirrors the fs_modify_file scenario: replace_using_regex is required + // with a default, and the LLM omits it. 
+ let parameters = IndexMap::from_iter([ + ("path".to_owned(), param("string", true)), + ( + "replace_using_regex".to_owned(), + param_with_default("boolean", true, json!(false)), + ), + ]); + + let mut args: Map = Map::from_iter([("path".to_owned(), json!("README.md"))]); + + // Without defaults, validation would fail. + assert!(validate_tool_arguments(&args, ¶meters).is_err()); + + // After applying defaults, validation passes. + apply_parameter_defaults(&mut args, ¶meters); + assert!(validate_tool_arguments(&args, ¶meters).is_ok()); + assert_eq!(args["replace_using_regex"], json!(false)); +} + +#[test] +fn test_split_short_single_line() { + let (s, d) = split_description("Run cargo check."); + assert_eq!(s, "Run cargo check."); + assert_eq!(d, None); +} + +#[test] +fn test_split_short_no_period() { + let (s, d) = split_description("Run cargo check"); + assert_eq!(s, "Run cargo check"); + assert_eq!(d, None); +} + +#[test] +fn test_split_two_sentences() { + let (s, d) = split_description( + "Run cargo check on a package. Supports workspace packages and feature flags.", + ); + assert_eq!(s, "Run cargo check on a package."); + assert_eq!( + d, + Some("Supports workspace packages and feature flags.".to_owned()) + ); +} + +#[test] +fn test_split_multiline() { + let input = "Search for code in a repository.\n\nSupports regex and qualifiers."; + let (s, d) = split_description(input); + assert_eq!(s, "Search for code in a repository."); + assert_eq!(d, Some("Supports regex and qualifiers.".to_owned())); +} + +#[test] +fn test_split_multiline_no_period() { + let input = "First line without period\nSecond line here."; + let (s, d) = split_description(input); + assert_eq!(s, "First line without period"); + assert_eq!(d, Some("Second line here.".to_owned())); +} + +#[test] +fn test_split_preserves_abbreviations() { + // "e.g." should not be treated as a sentence boundary. + let (s, d) = split_description("Use e.g. foo or bar."); + assert_eq!(s, "Use e.g. 
foo or bar."); + assert_eq!(d, None); +} + +#[test] +fn test_split_long_single_line_with_period() { + let input = "This is a very long description that exceeds the threshold. It contains \ + additional details about the tool's behavior."; + let (s, d) = split_description(input); + assert_eq!( + s, + "This is a very long description that exceeds the threshold." + ); + assert!(d.is_some()); +} + +#[test] +fn test_split_empty() { + let (s, d) = split_description(""); + assert_eq!(s, ""); + assert_eq!(d, None); +} + +#[test] +fn test_split_trims_whitespace() { + let (s, d) = split_description(" hello "); + assert_eq!(s, "hello"); + assert_eq!(d, None); +}