From 633ea4436548180fc8ce6cfc32743376017b1be3 Mon Sep 17 00:00:00 2001 From: Jean Mertz Date: Wed, 25 Feb 2026 17:06:35 +0100 Subject: [PATCH] feat(llm): implement resilient streaming and native structured output Introduce a provider-agnostic resilience layer for LLM requests and enable native structured output support across all providers. This overhaul replaces the previous tool-based schema enforcement with native provider capabilities (OpenAI Strict Mode, Anthropic Structured Outputs, Gemini JSON Schema). It includes a robust error classification system (`StreamError`) that automatically handles retries with exponential backoff for transient network errors and rate limits. The tool execution layer is refactored to support asynchronous cancellation, improved parameter validation with automatic default injection, and a new built-in tool system starting with `describe_tools`. Additional enhancements include: - Support for adaptive reasoning in Claude Opus 4.6+. - Complete rewrite of the `llamacpp` provider to use SSE and support DeepSeek-style reasoning content. - A new `MockProvider` and `ExecutorSource` to facilitate side-effect-free testing of conversation flows. - Automated description splitting for generating concise tool summaries. 
Signed-off-by: Jean Mertz --- crates/jp_llm/Cargo.toml | 12 +- crates/jp_llm/src/error.rs | 542 +++++- crates/jp_llm/src/error_tests.rs | 228 +++ crates/jp_llm/src/event.rs | 4 +- crates/jp_llm/src/lib.rs | 11 +- crates/jp_llm/src/model.rs | 81 +- crates/jp_llm/src/provider.rs | 369 +--- crates/jp_llm/src/provider/anthropic.rs | 834 +++++---- crates/jp_llm/src/provider/anthropic_tests.rs | 746 ++++++++ crates/jp_llm/src/provider/google.rs | 627 +++++-- crates/jp_llm/src/provider/google_tests.rs | 501 ++++++ crates/jp_llm/src/provider/llamacpp.rs | 705 +++++--- crates/jp_llm/src/provider/llamacpp_tests.rs | 181 ++ crates/jp_llm/src/provider/mock.rs | 203 +++ crates/jp_llm/src/provider/mock_tests.rs | 124 ++ crates/jp_llm/src/provider/ollama.rs | 377 ++-- crates/jp_llm/src/provider/openai.rs | 615 +++---- crates/jp_llm/src/provider/openai_tests.rs | 378 ++++ crates/jp_llm/src/provider/openrouter.rs | 258 +-- .../jp_llm/src/provider/openrouter_tests.rs | 51 + crates/jp_llm/src/provider_tests.rs | 198 +++ crates/jp_llm/src/query.rs | 61 +- crates/jp_llm/src/query/chat.rs | 33 - crates/jp_llm/src/query/structured.rs | 181 -- crates/jp_llm/src/retry.rs | 116 ++ crates/jp_llm/src/retry_tests.rs | 61 + crates/jp_llm/src/stream.rs | 8 +- crates/jp_llm/src/stream/aggregator.rs | 1 - crates/jp_llm/src/stream/aggregator/chunk.rs | 195 --- .../jp_llm/src/stream/aggregator/reasoning.rs | 3 + crates/jp_llm/src/stream/chain.rs | 22 +- crates/jp_llm/src/structured.rs | 31 - crates/jp_llm/src/structured/titles.rs | 83 - crates/jp_llm/src/test.rs | 225 +-- crates/jp_llm/src/title.rs | 97 ++ crates/jp_llm/src/title_tests.rs | 62 + crates/jp_llm/src/tool.rs | 1521 ++++++++++------- crates/jp_llm/src/tool/builtin.rs | 45 + .../jp_llm/src/tool/builtin/describe_tools.rs | 126 ++ .../src/tool/builtin/describe_tools_tests.rs | 343 ++++ crates/jp_llm/src/tool/executor.rs | 346 ++++ crates/jp_llm/src/tool_tests.rs | 567 ++++++ 42 files changed, 7986 insertions(+), 3186 deletions(-) create 
mode 100644 crates/jp_llm/src/error_tests.rs create mode 100644 crates/jp_llm/src/provider/anthropic_tests.rs create mode 100644 crates/jp_llm/src/provider/google_tests.rs create mode 100644 crates/jp_llm/src/provider/llamacpp_tests.rs create mode 100644 crates/jp_llm/src/provider/mock.rs create mode 100644 crates/jp_llm/src/provider/mock_tests.rs create mode 100644 crates/jp_llm/src/provider/openai_tests.rs create mode 100644 crates/jp_llm/src/provider/openrouter_tests.rs create mode 100644 crates/jp_llm/src/provider_tests.rs delete mode 100644 crates/jp_llm/src/query/chat.rs delete mode 100644 crates/jp_llm/src/query/structured.rs create mode 100644 crates/jp_llm/src/retry.rs create mode 100644 crates/jp_llm/src/retry_tests.rs delete mode 100644 crates/jp_llm/src/stream/aggregator/chunk.rs delete mode 100644 crates/jp_llm/src/structured.rs delete mode 100644 crates/jp_llm/src/structured/titles.rs create mode 100644 crates/jp_llm/src/title.rs create mode 100644 crates/jp_llm/src/title_tests.rs create mode 100644 crates/jp_llm/src/tool/builtin.rs create mode 100644 crates/jp_llm/src/tool/builtin/describe_tools.rs create mode 100644 crates/jp_llm/src/tool/builtin/describe_tools_tests.rs create mode 100644 crates/jp_llm/src/tool/executor.rs create mode 100644 crates/jp_llm/src/tool_tests.rs diff --git a/crates/jp_llm/Cargo.toml b/crates/jp_llm/Cargo.toml index f5a828e5..e50561b8 100644 --- a/crates/jp_llm/Cargo.toml +++ b/crates/jp_llm/Cargo.toml @@ -15,23 +15,18 @@ version.workspace = true [dependencies] jp_config = { workspace = true } jp_conversation = { workspace = true } -jp_inquire = { workspace = true } jp_mcp = { workspace = true } jp_openrouter = { workspace = true } -jp_printer = { workspace = true } jp_tool = { workspace = true } async-anthropic = { workspace = true } async-stream = { workspace = true } async-trait = { workspace = true } camino = { workspace = true } -crossterm = { workspace = true } -duct = { workspace = true } -duct_sh = { workspace = 
true } +chrono = { workspace = true } futures = { workspace = true } gemini_client_rs = { workspace = true } indexmap = { workspace = true } -jsonschema = { workspace = true } minijinja = { workspace = true, features = [ "builtins", "json", @@ -41,18 +36,15 @@ minijinja = { workspace = true, features = [ ] } ollama-rs = { workspace = true, features = ["rustls", "stream"] } open-editor = { workspace = true } -openai = { workspace = true, features = ["rustls", "reqwest"] } openai_responses = { workspace = true, features = ["stream"] } quick-xml = { workspace = true, features = ["serialize"] } reqwest = { workspace = true } reqwest-eventsource = { workspace = true } -schemars = { workspace = true } serde = { workspace = true } serde_json = { workspace = true, features = ["preserve_order"] } thiserror = { workspace = true } -chrono = { workspace = true } tokio = { workspace = true } -tokio-stream = { workspace = true } +tokio-util = { workspace = true } tracing = { workspace = true } url = { workspace = true } diff --git a/crates/jp_llm/src/error.rs b/crates/jp_llm/src/error.rs index 4a9d552d..e1d0b653 100644 --- a/crates/jp_llm/src/error.rs +++ b/crates/jp_llm/src/error.rs @@ -1,9 +1,262 @@ +use std::{ + fmt, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; + +use async_anthropic::errors::AnthropicError; +use reqwest::header::{HeaderMap, RETRY_AFTER}; +use serde_json::Value; + use crate::stream::aggregator::tool_call_request::AggregationError; pub(crate) type Result = std::result::Result; +/// A provider-agnostic streaming error. +#[derive(Debug)] +pub struct StreamError { + /// The kind of streaming error. + pub kind: StreamErrorKind, + + /// Whether and when the request can be retried. + /// + /// If `Some`, the request can be retried after the specified duration. + /// If `None`, the caller should use exponential backoff or not retry. + pub retry_after: Option, + + /// Human-readable error message. + message: String, + + /// The underlying source of the error. 
+ /// + /// This is kept for logging and display purposes, but callers should + /// make decisions based on `kind` and `retry_after`, not the source. + source: Option>, +} + +impl StreamError { + /// Create a new stream error. + #[must_use] + pub fn new(kind: StreamErrorKind, message: impl Into) -> Self { + Self { + kind, + message: message.into(), + retry_after: None, + source: None, + } + } + + /// Create a timeout error. + #[must_use] + pub fn timeout(message: impl Into) -> Self { + Self::new(StreamErrorKind::Timeout, message) + } + + /// Create a connection error. + #[must_use] + pub fn connect(message: impl Into) -> Self { + Self::new(StreamErrorKind::Connect, message) + } + + /// Create a rate limit error. + #[must_use] + pub fn rate_limit(retry_after: Option) -> Self { + Self { + kind: StreamErrorKind::RateLimit, + retry_after, + message: "Rate limited".into(), + source: None, + } + } + + /// Create a transient error. + #[must_use] + pub fn transient(message: impl Into) -> Self { + Self::new(StreamErrorKind::Transient, message) + } + + /// Create an error from another error type. + #[must_use] + pub fn other(message: impl Into) -> Self { + Self::new(StreamErrorKind::Other, message) + } + + /// Set the `retry_after` duration. + #[must_use] + pub fn with_retry_after(mut self, duration: Duration) -> Self { + self.retry_after = Some(duration); + self + } + + /// Set the source error. + #[must_use] + pub fn with_source( + mut self, + source: impl Into>, + ) -> Self { + self.source = Some(source.into()); + self + } + + /// Returns whether this error is likely retryable. 
+ #[must_use] + pub fn is_retryable(&self) -> bool { + matches!( + self.kind, + StreamErrorKind::Timeout + | StreamErrorKind::Connect + | StreamErrorKind::RateLimit + | StreamErrorKind::Transient + ) || self.retry_after.is_some() + } +} + +impl fmt::Display for StreamError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.message)?; + if let Some(ref source) = self.source { + write!(f, ": {source}")?; + } + + Ok(()) + } +} + +impl std::error::Error for StreamError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.source + .as_ref() + .map(|e| e.as_ref() as &(dyn std::error::Error + 'static)) + } +} + +/// Canonical classifier for [`reqwest::Error`]. +/// +/// Provider-specific `From` impls should delegate to this when they encounter +/// an inner [`reqwest::Error`], rather than re-implementing the classification +/// logic. +impl From for StreamError { + fn from(err: reqwest::Error) -> Self { + if err.is_timeout() { + StreamError::timeout(err.to_string()).with_source(err) + } else if err.is_connect() { + StreamError::connect(err.to_string()).with_source(err) + } else if err.status().is_some_and(|s| s == 429) { + StreamError::rate_limit(None).with_source(err) + } else if err + .status() + .is_some_and(|s| matches!(s.as_u16(), 408 | 409 | _ if s.as_u16() >= 500)) + || err.is_body() + || err.is_decode() + { + StreamError::transient(err.to_string()).with_source(err) + } else { + StreamError::other(err.to_string()).with_source(err) + } + } +} + +/// Canonical classifier for [`reqwest_eventsource::Error`]. +/// +/// Delegates the [`Transport`] variant to the [`reqwest::Error`] classifier. +/// Classifies [`InvalidStatusCode`] by its HTTP status, extracts `Retry-After` +/// headers, and honours the non-standard `x-should-retry` header used by +/// OpenAI. 
+/// +/// Provider-specific `From` impls should delegate to this when they encounter +/// an inner [`reqwest_eventsource::Error`], rather than re-implementing the +/// classification logic. +/// +/// [`Transport`]: reqwest_eventsource::Error::Transport +/// [`InvalidStatusCode`]: reqwest_eventsource::Error::InvalidStatusCode +impl From for StreamError { + fn from(err: reqwest_eventsource::Error) -> Self { + use reqwest_eventsource::Error; + + match err { + Error::Transport(error) => Self::from(error), + Error::InvalidStatusCode(status, response) => { + let headers = response.headers(); + let retry_after = extract_retry_after(headers); + let code = status.as_u16(); + + // Non-standard `x-should-retry` header overrides + // status-code heuristics. + let retryable = match header_str(headers, "x-should-retry") { + Some("true") => true, + Some("false") => false, + _ => matches!(code, 408 | 409 | 429 | _ if code >= 500), + }; + + if !retryable { + StreamError::other(format!("HTTP {status}")) + } else if code == 429 { + StreamError::rate_limit(retry_after) + } else { + let err = StreamError::transient(format!("HTTP {status}")); + match retry_after { + Some(d) => err.with_retry_after(d), + None => err, + } + } + } + error @ (Error::Utf8(_) + | Error::Parser(_) + | Error::InvalidContentType(_, _) + | Error::InvalidLastEventId(_) + | Error::StreamEnded) => StreamError::other(error.to_string()).with_source(error), + } + } +} + +/// The kind of streaming error. +/// +/// This abstraction allows the resilience layer to make retry decisions without +/// knowing the specific provider implementation details. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamErrorKind { + /// Request timed out. + Timeout, + + /// Failed to establish connection. + Connect, + + /// Rate limited by the provider. + RateLimit, + + /// Transient error (server error, temporary failure). + /// These are typically safe to retry. + Transient, + + /// The API key's quota has been exhausted. 
+ /// This is not retryable — the user needs to top up or change plans. + InsufficientQuota, + + /// Other errors that are not categorized. + /// These may or may not be retryable depending on the specific error. + Other, +} + +impl StreamErrorKind { + /// Returns the error kind as a string. + #[must_use] + pub fn as_str(&self) -> &'static str { + match self { + Self::Timeout => "Timeout", + Self::Connect => "Connection error", + Self::RateLimit => "Rate limited", + Self::Transient => "Server error", + Self::InsufficientQuota => "Insufficient API quota", + Self::Other => "Stream Error", + } + } +} + #[derive(Debug, thiserror::Error)] pub enum Error { + /// Streaming error with provider-agnostic classification. + #[error(transparent)] + Stream(#[from] StreamError), + #[error("OpenRouter error: {0}")] OpenRouter(#[from] jp_openrouter::Error), @@ -43,9 +296,6 @@ pub enum Error { #[error("Ollama error: {0}")] Ollama(#[from] ollama_rs::error::OllamaError), - #[error("Missing structured data in response")] - MissingStructuredData, - #[error("Unknown model: {0}")] UnknownModel(String), @@ -56,7 +306,7 @@ pub enum Error { Request(#[from] reqwest::Error), #[error("Anthropic error: {0}")] - Anthropic(#[from] async_anthropic::errors::AnthropicError), + Anthropic(#[from] AnthropicError), #[error("Anthropic request builder error: {0}")] AnthropicRequestBuilder(#[from] async_anthropic::types::CreateMessagesRequestBuilderError), @@ -79,37 +329,6 @@ pub enum Error { ToolCallRequestAggregator(#[from] AggregationError), } -impl From for Error { - fn from(error: gemini_client_rs::GeminiError) -> Self { - use gemini_client_rs::GeminiError; - - match &error { - GeminiError::Api(api) if api.get("status").is_some_and(|v| v.as_u64() == Some(404)) => { - if let Some(model) = api.pointer("/message/error/message").and_then(|v| { - v.as_str().and_then(|s| { - s.contains("Call ListModels").then(|| { - s.split('/') - .nth(1) - .and_then(|v| v.split(' ').next()) - .unwrap_or("unknown") - }) - }) 
- }) { - return Self::UnknownModel(model.to_owned()); - } - Self::Gemini(error) - } - _ => Self::Gemini(error), - } - } -} - -impl From for Error { - fn from(error: openai_responses::types::response::Error) -> Self { - Self::OpenaiResponse(error) - } -} - #[cfg(test)] impl PartialEq for Error { fn eq(&self, other: &Self) -> bool { @@ -147,7 +366,7 @@ pub enum ToolError { #[error("Failed to serialize tool arguments")] SerializeArgumentsError { - arguments: serde_json::Value, + arguments: Value, #[source] error: serde_json::Error, }, @@ -155,16 +374,23 @@ pub enum ToolError { #[error("Tool call failed: {0}")] ToolCallFailed(String), + #[error("Failed to spawn command: {command}")] + SpawnError { + command: String, + #[source] + error: std::io::Error, + }, + #[error("Failed to open editor to edit tool call")] OpenEditorError { - arguments: serde_json::Value, + arguments: Value, #[source] error: open_editor::errors::OpenEditorError, }, #[error("Failed to edit tool call")] EditArgumentsError { - arguments: serde_json::Value, + arguments: Value, #[source] error: serde_json::Error, }, @@ -179,7 +405,7 @@ pub enum ToolError { #[error("Invalid `type` property for {key}, got {value:?}, expected one of {need:?}")] InvalidType { key: String, - value: serde_json::Value, + value: Value, need: Vec<&'static str>, }, @@ -219,3 +445,241 @@ impl PartialEq for ToolError { format!("{self:?}") == format!("{other:?}") } } + +/// Heuristic check for quota/billing exhaustion based on error text. 
+/// +/// This catches the common patterns across providers: +/// - OpenAI: `"insufficient_quota"` +/// - Anthropic: `"billing_error"`, `"Your credit balance is too low"` +/// - Google: `"RESOURCE_EXHAUSTED"`, `"Quota exceeded"` +/// - OpenRouter: `"insufficient credits"`, `"out of credits"` +pub(crate) fn looks_like_quota_error(text: &str) -> bool { + let lower = text.to_ascii_lowercase(); + lower.contains("insufficient_quota") + || lower.contains("insufficient quota") + || lower.contains("insufficient credits") + || lower.contains("out of credits") + || lower.contains("billing_error") + || lower.contains("credit balance is too low") + || lower.contains("quota exceeded") + || lower.contains("resource_exhausted") +} + +/// Extracts a retry-after duration from an error message body. +/// +/// Last-resort fallback when response headers don't carry retry timing. +/// Matches common natural-language patterns found in API error responses: +/// +/// - `"retry after 30 seconds"` +/// - `"retry-after: 30"` +/// - `"wait 30 seconds"` +/// - `"try again in 5s"` / `"try again in 5.5s"` +/// - `"retryDelay": "30s"` (Google Gemini JSON body) +pub(crate) fn extract_retry_from_text(text: &str) -> Option { + let lower = text.to_ascii_lowercase(); + + // Scan for patterns and extract the first match. 
+ for window in lower.split_whitespace().collect::>().windows(4) { + // "retry after N second(s)" + if window[0] == "retry" + && window[1] == "after" + && let Some(secs) = parse_secs_token(window[2]) + { + return Some(Duration::from_secs(secs)); + } + // "wait N second(s)" + if window[0] == "wait" + && let Some(secs) = parse_secs_token(window[1]) + { + return Some(Duration::from_secs(secs)); + } + // "try again in Ns" / "try again in N.Ns" + if window[0] == "try" + && window[1] == "again" + && window[2] == "in" + && let Some(secs) = parse_secs_token(window[3]) + { + return Some(Duration::from_secs(secs)); + } + } + + // "retry-after: N" / "retry-after:N" + if let Some(pos) = lower.find("retry-after:") { + let after = lower[pos + "retry-after:".len()..].trim_start(); + if let Some(secs) = after + .split(|c: char| !c.is_ascii_digit()) + .next() + .and_then(|s| s.parse::().ok()) + .filter(|s| *s > 0) + { + return Some(Duration::from_secs(secs)); + } + } + + // "retryDelay": "30s" (Gemini JSON body) + if let Some(pos) = lower.find("retrydelay") { + let after = &lower[pos..]; + if let Some(d) = after + .split('"') + .find(|s| s.ends_with('s') && s[..s.len() - 1].chars().all(|c| c.is_ascii_digit())) + .and_then(parse_human_duration) + { + return Some(Duration::from_secs(d)); + } + } + + None +} + +/// Extracts a retry-after duration from common rate-limit response +/// headers. +/// +/// Headers are checked in decreasing order of authority: +/// +/// 1. `retry-after-ms` — Non-standard (OpenAI). Millisecond precision. +/// 2. `Retry-After` — RFC 7231. Integer or float seconds (the spec +/// mandates integers, but floats are accepted). HTTP-date values are +/// not supported. +/// 3. `RateLimit` — IETF draft `t=` parameter (delta-seconds). +/// See: +/// 4. `x-ratelimit-reset-requests` / `x-ratelimit-reset-tokens` — +/// OpenAI-style Human-duration values (e.g. `6m0s`). Takes the longer +/// of the two if both are present. +/// 5. 
`x-ratelimit-reset` — Unix timestamp, converted relative to now. +fn extract_retry_after(headers: &HeaderMap) -> Option { + if let Some(d) = header_positive_f64(headers, "retry-after-ms") + .map(|ms| Duration::from_secs_f64(ms / 1000.0)) + { + return Some(d); + } + + if let Some(d) = header_positive_f64(headers, RETRY_AFTER).map(Duration::from_secs_f64) { + return Some(d); + } + + // IETF draft: `RateLimit: remaining=0; t=` + if let Some(secs) = headers + .get("ratelimit") + .and_then(|v| v.to_str().ok()) + .and_then(|v| { + v.split(';') + .map(str::trim) + .find_map(|p| p.strip_prefix("t=")) + }) + .and_then(|v| v.trim().parse::().ok()) + .filter(|v| *v > 0) + { + return Some(Duration::from_secs(secs)); + } + + // 4. OpenAI: `x-ratelimit-reset-requests` / `x-ratelimit-reset-tokens` + // Both use human-duration format (e.g. "1s", "6m0s"). Take the max. + let requests = header_str(headers, "x-ratelimit-reset-requests").and_then(parse_human_duration); + let tokens = header_str(headers, "x-ratelimit-reset-tokens").and_then(parse_human_duration); + + if let Some(secs) = requests.into_iter().chain(tokens).max() { + return Some(Duration::from_secs(secs)); + } + + if let Some(reset_ts) = header_u64(headers, "x-ratelimit-reset") { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + if reset_ts > now { + return Some(Duration::from_secs(reset_ts - now)); + } + } + + None +} + +/// Read a header value as `&str`. +fn header_str(headers: &HeaderMap, name: impl reqwest::header::AsHeaderName) -> Option<&str> { + headers.get(name).and_then(|v| v.to_str().ok()) +} + +/// Read a header value as `u64`. +fn header_u64(headers: &HeaderMap, name: impl reqwest::header::AsHeaderName) -> Option { + header_str(headers, name).and_then(|s| s.parse().ok()) +} + +/// Read a header value as a positive, finite `f64`. 
+fn header_positive_f64( + headers: &HeaderMap, + name: impl reqwest::header::AsHeaderName, +) -> Option { + header_str(headers, name) + .and_then(|s| s.parse::().ok()) + .filter(|v| *v > 0.0 && v.is_finite()) +} + +/// Parses a human-style duration string into whole seconds. +/// +/// Supported units: `h` (hours), `m` (minutes), `s` (seconds), `ms` +/// (milliseconds — rounded up to 1s if non-zero and total is 0). +/// +/// Examples: `"1s"` → 1, `"6m0s"` → 360, `"1h30m"` → 5400, +/// `"200ms"` → 1. +/// +/// Returns `None` for empty, zero, or unparseable values. +fn parse_human_duration(s: &str) -> Option { + let mut total: u64 = 0; + let mut has_sub_second = false; + let mut remaining = s.trim(); + + while !remaining.is_empty() { + let num_end = remaining + .find(|c: char| !c.is_ascii_digit()) + .unwrap_or(remaining.len()); + + if num_end == 0 { + return None; + } + + let n: u64 = remaining[..num_end].parse().ok()?; + remaining = &remaining[num_end..]; + + if remaining.starts_with("ms") { + has_sub_second = n > 0; + remaining = &remaining[2..]; + } else if remaining.starts_with('h') { + total += n * 3600; + remaining = &remaining[1..]; + } else if remaining.starts_with('m') { + total += n * 60; + remaining = &remaining[1..]; + } else if remaining.starts_with('s') { + total += n; + remaining = &remaining[1..]; + } else { + return None; + } + } + + // Round sub-second durations up to 1s so we don't return 0 + // when the server asked us to wait. + if total == 0 && has_sub_second { + total = 1; + } + + if total > 0 { Some(total) } else { None } +} + +/// Parse a token like `"30"`, `"30s"`, `"5.5s"` into whole seconds. 
+#[expect(clippy::cast_possible_truncation, clippy::cast_sign_loss)] +fn parse_secs_token(s: &str) -> Option { + let s = s + .trim_end_matches('s') + .trim_end_matches("second") + .trim_end_matches(','); + s.parse::() + .ok() + .filter(|v| *v > 0.0 && v.is_finite()) + .map(|v| v.ceil() as u64) +} + +#[cfg(test)] +#[path = "error_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/error_tests.rs b/crates/jp_llm/src/error_tests.rs new file mode 100644 index 00000000..4b94d2e1 --- /dev/null +++ b/crates/jp_llm/src/error_tests.rs @@ -0,0 +1,228 @@ +use super::*; + +#[test] +fn extract_retry_after_from_retry_after_ms() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("retry-after-ms", "1500".parse().unwrap()); + + assert_eq!( + extract_retry_after(&headers), + Some(Duration::from_millis(1500)) + ); +} + +#[test] +fn extract_retry_after_ms_takes_priority_over_retry_after() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("retry-after-ms", "500".parse().unwrap()); + headers.insert(reqwest::header::RETRY_AFTER, "30".parse().unwrap()); + + // retry-after-ms is more precise, should be preferred. 
+ assert_eq!( + extract_retry_after(&headers), + Some(Duration::from_millis(500)) + ); +} + +#[test] +fn extract_retry_after_from_standard_header() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert(reqwest::header::RETRY_AFTER, "30".parse().unwrap()); + + assert_eq!(extract_retry_after(&headers), Some(Duration::from_secs(30))); +} + +#[test] +fn extract_retry_after_accepts_float_seconds() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert(reqwest::header::RETRY_AFTER, "1.5".parse().unwrap()); + + assert_eq!( + extract_retry_after(&headers), + Some(Duration::from_millis(1500)) + ); +} + +#[test] +fn extract_retry_after_ignores_http_date() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + reqwest::header::RETRY_AFTER, + "Wed, 21 Oct 2025 07:28:00 GMT".parse().unwrap(), + ); + + // HTTP-date is not supported, should return None. + assert_eq!(extract_retry_after(&headers), None); +} + +#[test] +fn extract_retry_after_from_ietf_ratelimit_header() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("ratelimit", "remaining=0; t=45".parse().unwrap()); + + assert_eq!(extract_retry_after(&headers), Some(Duration::from_secs(45))); +} + +#[test] +fn extract_retry_after_from_openai_reset_requests() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("x-ratelimit-reset-requests", "6m0s".parse().unwrap()); + + assert_eq!(extract_retry_after(&headers), Some(Duration::from_mins(6))); +} + +#[test] +fn extract_retry_after_from_openai_reset_tokens() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("x-ratelimit-reset-tokens", "1s".parse().unwrap()); + + assert_eq!(extract_retry_after(&headers), Some(Duration::from_secs(1))); +} + +#[test] +fn extract_retry_after_openai_takes_max_of_both() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("x-ratelimit-reset-requests", "2s".parse().unwrap()); + 
headers.insert("x-ratelimit-reset-tokens", "6m0s".parse().unwrap()); + + assert_eq!(extract_retry_after(&headers), Some(Duration::from_mins(6))); +} + +#[test] +fn extract_retry_after_from_ratelimit_reset() { + let future_ts = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() + + 45; + + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("x-ratelimit-reset", future_ts.to_string().parse().unwrap()); + + let result = extract_retry_after(&headers).unwrap(); + // Allow 1s tolerance for test execution time. + assert!(result.as_secs() >= 44 && result.as_secs() <= 46); +} + +#[test] +fn extract_retry_after_prefers_standard_header() { + let future_ts = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() + + 120; + + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert(reqwest::header::RETRY_AFTER, "10".parse().unwrap()); + headers.insert("x-ratelimit-reset", future_ts.to_string().parse().unwrap()); + + assert_eq!(extract_retry_after(&headers), Some(Duration::from_secs(10))); +} + +#[test] +fn extract_retry_after_past_reset_returns_none() { + let past_ts = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() + - 60; + + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("x-ratelimit-reset", past_ts.to_string().parse().unwrap()); + + assert_eq!(extract_retry_after(&headers), None); +} + +#[test] +fn extract_retry_after_empty_headers() { + let headers = reqwest::header::HeaderMap::new(); + assert_eq!(extract_retry_after(&headers), None); +} + +#[test] +fn human_duration_seconds() { + assert_eq!(parse_human_duration("1s"), Some(1)); + assert_eq!(parse_human_duration("30s"), Some(30)); +} + +#[test] +fn human_duration_minutes_and_seconds() { + assert_eq!(parse_human_duration("6m0s"), Some(360)); + assert_eq!(parse_human_duration("1m30s"), Some(90)); +} + +#[test] +fn human_duration_hours() { + 
assert_eq!(parse_human_duration("1h30m0s"), Some(5400)); + assert_eq!(parse_human_duration("2h"), Some(7200)); +} + +#[test] +fn human_duration_milliseconds_rounds_up() { + assert_eq!(parse_human_duration("200ms"), Some(1)); + assert_eq!(parse_human_duration("0ms"), None); +} + +#[test] +fn human_duration_mixed_with_ms() { + // 1 second + 500ms = 1s (ms doesn't add to whole seconds). + assert_eq!(parse_human_duration("1s500ms"), Some(1)); +} + +#[test] +fn human_duration_zero_returns_none() { + assert_eq!(parse_human_duration("0s"), None); + assert_eq!(parse_human_duration("0m0s"), None); +} + +#[test] +fn human_duration_invalid() { + assert_eq!(parse_human_duration(""), None); + assert_eq!(parse_human_duration("abc"), None); + assert_eq!(parse_human_duration("5x"), None); +} + +#[test] +fn text_retry_after_n_seconds() { + let text = "Rate limit exceeded. Please retry after 30 seconds."; + assert_eq!(extract_retry_from_text(text), Some(Duration::from_secs(30))); +} + +#[test] +fn text_wait_n_seconds() { + let text = "Too many requests. 
Please wait 60 seconds before trying again."; + assert_eq!(extract_retry_from_text(text), Some(Duration::from_mins(1))); +} + +#[test] +fn text_try_again_in_ns() { + let text = "Service busy, try again in 5s"; + assert_eq!(extract_retry_from_text(text), Some(Duration::from_secs(5))); +} + +#[test] +fn text_try_again_in_float() { + let text = "Overloaded, try again in 5.5s please"; + assert_eq!( + extract_retry_from_text(text), + Some(Duration::from_secs(6)) // ceil(5.5) + ); +} + +#[test] +fn text_retry_after_colon() { + let text = "Error: retry-after: 15"; + assert_eq!(extract_retry_from_text(text), Some(Duration::from_secs(15))); +} + +#[test] +fn text_gemini_retry_delay() { + let text = r#"{"error":{"details":[{"retryDelay":"30s"}]}}"#; + assert_eq!(extract_retry_from_text(text), Some(Duration::from_secs(30))); +} + +#[test] +fn text_no_pattern_returns_none() { + assert_eq!(extract_retry_from_text("Something went wrong"), None); + assert_eq!(extract_retry_from_text(""), None); +} diff --git a/crates/jp_llm/src/event.rs b/crates/jp_llm/src/event.rs index e19fc475..8cd2656e 100644 --- a/crates/jp_llm/src/event.rs +++ b/crates/jp_llm/src/event.rs @@ -14,8 +14,8 @@ use serde_json::Value; /// will have different `index` values corresponding to their relative order. /// /// Only after [`Event::Flush`] is produced with the given `index` value, should -/// all previous parts be merged into a single [`ConversationEvent`], using -/// `EventAggregator`. +/// all previous parts be merged into a single [`ConversationEvent`] (e.g. via +/// [`EventBuilder`](jp_conversation::event_builder::EventBuilder)). #[derive(Debug, Clone, PartialEq)] pub enum Event { /// A part of a completed event. 
diff --git a/crates/jp_llm/src/lib.rs b/crates/jp_llm/src/lib.rs index 35d9ca18..3cdc6b47 100644 --- a/crates/jp_llm/src/lib.rs +++ b/crates/jp_llm/src/lib.rs @@ -1,15 +1,18 @@ -mod error; +pub mod error; pub mod event; pub mod model; pub mod provider; pub mod query; +pub mod retry; mod stream; -pub mod structured; +pub mod title; pub mod tool; #[cfg(test)] pub(crate) mod test; -pub use error::{Error, ToolError}; +pub use error::{Error, StreamError, StreamErrorKind, ToolError}; pub use provider::Provider; -pub use stream::{aggregator::tool_call_request::AggregationError, chain::EventChain}; +pub use retry::exponential_backoff; +pub use stream::{EventStream, aggregator::tool_call_request::AggregationError, chain::EventChain}; +pub use tool::{CommandResult, ExecutionOutcome, run_tool_command}; diff --git a/crates/jp_llm/src/model.rs b/crates/jp_llm/src/model.rs index 137e3249..310b0402 100644 --- a/crates/jp_llm/src/model.rs +++ b/crates/jp_llm/src/model.rs @@ -49,6 +49,11 @@ impl ModelDetails { } } + #[must_use] + pub fn name(&self) -> &str { + self.display_name.as_deref().unwrap_or(&self.id.name) + } + #[must_use] pub fn custom_reasoning_config( &self, @@ -134,6 +139,21 @@ impl ModelDetails { // Custom configuration, so use it. Some(ReasoningConfig::Custom(custom)) => Some(custom), }, + + // Adaptive + Some(ReasoningDetails::Adaptive { .. }) => match config { + // Off, so disabled. + Some(ReasoningConfig::Off) => None, + + // Unconfigured or auto, so use high effort (the API default). + None | Some(ReasoningConfig::Auto) => Some(CustomReasoningConfig { + effort: ReasoningEffort::High, + exclude: false, + }), + + // Custom configuration, so use it. + Some(ReasoningConfig::Custom(custom)) => Some(custom), + }, } } } @@ -207,6 +227,18 @@ pub enum ReasoningDetails { /// Whether the model supports extremely high effort reasoning. xhigh: bool, }, + + /// Adaptive reasoning support. 
+ /// + /// The model dynamically decides when and how much to think based on + /// task complexity. Uses effort levels (low/medium/high/max) instead of + /// token budgets. + /// + /// Currently only supported by Claude Opus 4.6+. + Adaptive { + /// Whether the model supports `max` effort level. + max: bool, + }, } impl ReasoningDetails { @@ -230,6 +262,11 @@ impl ReasoningDetails { } } + #[must_use] + pub fn adaptive(max: bool) -> Self { + Self::Adaptive { max } + } + #[must_use] pub fn unsupported() -> Self { Self::Unsupported @@ -239,7 +276,7 @@ impl ReasoningDetails { pub fn min_tokens(&self) -> u32 { match self { Self::Budgetted { min_tokens, .. } => *min_tokens, - Self::Leveled { .. } | Self::Unsupported => 0, + Self::Leveled { .. } | Self::Adaptive { .. } | Self::Unsupported => 0, } } @@ -247,7 +284,42 @@ impl ReasoningDetails { pub fn max_tokens(&self) -> Option { match self { Self::Budgetted { max_tokens, .. } => *max_tokens, - Self::Leveled { .. } | Self::Unsupported => None, + Self::Leveled { .. } | Self::Adaptive { .. } | Self::Unsupported => None, + } + } + + /// Returns the lowest reasoning effort level supported by this model, if + /// known. + /// + /// `Leveled` models return their lowest supported level. Other variants + /// return `Option::None` — callers should decide how to handle "disable + /// reasoning" for their provider (e.g. token budget 0, effort `minimal`, + /// `thinking: disabled`, etc.). + #[must_use] + pub fn lowest_effort(&self) -> Option { + match self { + Self::Leveled { + xlow, + low, + medium, + high, + xhigh, + } => { + if *xlow { + Some(ReasoningEffort::Xlow) + } else if *low { + Some(ReasoningEffort::Low) + } else if *medium { + Some(ReasoningEffort::Medium) + } else if *high { + Some(ReasoningEffort::High) + } else if *xhigh { + Some(ReasoningEffort::XHigh) + } else { + None + } + } + _ => None, } } @@ -265,4 +337,9 @@ impl ReasoningDetails { pub fn is_leveled(&self) -> bool { matches!(self, Self::Leveled { .. 
}) } + + #[must_use] + pub fn is_adaptive(&self) -> bool { + matches!(self, Self::Adaptive { .. }) + } } diff --git a/crates/jp_llm/src/provider.rs b/crates/jp_llm/src/provider.rs index 21e4b6bd..afa4d6ad 100644 --- a/crates/jp_llm/src/provider.rs +++ b/crates/jp_llm/src/provider.rs @@ -3,35 +3,26 @@ pub mod google; // pub mod xai; pub mod anthropic; pub mod llamacpp; +pub mod mock; pub mod ollama; pub mod openai; pub mod openrouter; use anthropic::Anthropic; use async_trait::async_trait; -use futures::{TryStreamExt as _, stream}; use google::Google; use jp_config::{ - assistant::instructions::InstructionsConfig, model::id::{Name, ProviderId}, providers::llm::LlmProviderConfig, }; -use jp_conversation::event::ConversationEvent; use llamacpp::Llamacpp; use ollama::Ollama; use openai::Openai; use openrouter::Openrouter; -use serde_json::Value; -use tracing::warn; use crate::{ - Error, - error::Result, - event::Event, - model::ModelDetails, - query::{ChatQuery, StructuredQuery}, - stream::{EventStream, aggregator::chunk::EventAggregator}, - structured::SCHEMA_TOOL_NAME, + error::Result, model::ModelDetails, provider::mock::MockProvider, query::ChatQuery, + stream::EventStream, }; #[async_trait] @@ -48,369 +39,27 @@ pub trait Provider: std::fmt::Debug + Send + Sync { model: &ModelDetails, query: ChatQuery, ) -> Result; - - /// Perform a non-streaming chat completion. - /// - /// Default implementation collects results from the streaming version. - async fn chat_completion(&self, model: &ModelDetails, query: ChatQuery) -> Result> { - let mut aggregator = EventAggregator::new(); - self.chat_completion_stream(model, query) - .await? - .map_ok(|event| stream::iter(aggregator.ingest(event).into_iter().map(Ok))) - .try_flatten() - .try_collect() - .await - } - - /// Perform a structured completion. - /// - /// Default implementation uses a specialized tool-call to get structured - /// results. 
- /// - /// Providers that have a dedicated structured response endpoint should - /// override this method. - async fn structured_completion( - &self, - model: &ModelDetails, - query: StructuredQuery, - ) -> Result { - let mut chat_query = ChatQuery { - thread: query.thread.clone(), - tools: vec![query.tool_definition()?], - tool_choice: query.tool_choice()?, - tool_call_strict_mode: true, - }; - - let max_retries = 3; - for i in 1..=max_retries { - let result = self.chat_completion(model, chat_query.clone()).await; - let events = match result { - Ok(events) => events, - Err(error) if i >= max_retries => return Err(error), - Err(error) => { - warn!(%error, "Error while getting structured data. Retrying in non-strict mode."); - chat_query.tool_call_strict_mode = false; - continue; - } - }; - - let data = events - .into_iter() - .filter_map(Event::into_conversation_event) - .filter_map(ConversationEvent::into_tool_call_request) - .find(|call| call.name == SCHEMA_TOOL_NAME) - .map(|call| Value::Object(call.arguments)); - - let result = data - .ok_or("Did not receive any structured data".to_owned()) - .and_then(|data| query.validate(&data).map(|()| data)); - - match result { - Ok(data) => return Ok(query.map(data)), - Err(error) => { - warn!(error, "Failed to fetch structured data. Retrying."); - - chat_query.thread.instructions.push( - InstructionsConfig::default() - .with_title("Structured Data Validation Error") - .with_description( - "The following error occurred while validating the structured \ - data. Please try again.", - ) - .with_item(error), - ); - } - } - } - - Err(Error::MissingStructuredData) - } } /// Get a provider by ID. -/// -/// # Panics -/// -/// Panics if the provider is `ProviderId::TEST`, which is reserved for testing -/// only. 
pub fn get_provider(id: ProviderId, config: &LlmProviderConfig) -> Result> { let provider: Box = match id { ProviderId::Anthropic => Box::new(Anthropic::try_from(&config.anthropic)?), - ProviderId::Deepseek => todo!(), ProviderId::Google => Box::new(Google::try_from(&config.google)?), ProviderId::Llamacpp => Box::new(Llamacpp::try_from(&config.llamacpp)?), ProviderId::Ollama => Box::new(Ollama::try_from(&config.ollama)?), ProviderId::Openai => Box::new(Openai::try_from(&config.openai)?), ProviderId::Openrouter => Box::new(Openrouter::try_from(&config.openrouter)?), + + ProviderId::Deepseek => todo!(), ProviderId::Xai => todo!(), + + ProviderId::Test => Box::new(MockProvider::new(vec![])), }; Ok(provider) } #[cfg(test)] -mod tests { - use std::sync::Arc; - - use jp_config::{ - assistant::tool_choice::ToolChoice, - conversation::tool::{OneOrManyTypes, ToolParameterConfig, item::ToolParameterItemConfig}, - }; - use jp_conversation::event::ChatRequest; - use jp_test::{Result, function_name}; - - use super::*; - use crate::{ - structured, - test::{TestRequest, run_test, test_model_details}, - }; - - macro_rules! test_all_providers { - ($($fn:ident),* $(,)?) => { - mod anthropic { use super::*; $(test_all_providers!(func; $fn, ProviderId::Anthropic);)* } - mod google { use super::*; $(test_all_providers!(func; $fn, ProviderId::Google);)* } - mod openai { use super::*; $(test_all_providers!(func; $fn, ProviderId::Openai);)* } - mod openrouter{ use super::*; $(test_all_providers!(func; $fn, ProviderId::Openrouter);)* } - mod ollama { use super::*; $(test_all_providers!(func; $fn, ProviderId::Ollama);)* } - mod llamacpp { use super::*; $(test_all_providers!(func; $fn, ProviderId::Llamacpp);)* } - }; - (func; $fn:ident, $provider:ty) => { - paste::paste! 
{ - #[test_log::test(tokio::test)] - async fn [< test_ $fn >]() -> Result { - $fn($provider, function_name!()).await - } - } - }; - } - - async fn chat_completion_nostream(provider: ProviderId, test_name: &str) -> Result { - let request = TestRequest::chat(provider) - .stream(false) - .enable_reasoning() - .event(ChatRequest::from("Test message")); - - run_test(provider, test_name, Some(request)).await - } - - async fn chat_completion_stream(provider: ProviderId, test_name: &str) -> Result { - let request = TestRequest::chat(provider) - .stream(true) - .enable_reasoning() - .event(ChatRequest::from("Test message")); - - run_test(provider, test_name, Some(request)).await - } - - fn tool_call_base(provider: ProviderId) -> TestRequest { - TestRequest::chat(provider) - .event(ChatRequest::from( - "Please run the tool, providing whatever arguments you want.", - )) - .tool("run_me", vec![ - ("foo", ToolParameterConfig { - kind: OneOrManyTypes::One("string".into()), - default: Some("foo".into()), - description: None, - required: false, - enumeration: vec![], - items: None, - }), - ("bar", ToolParameterConfig { - kind: OneOrManyTypes::Many(vec!["string".into(), "array".into()]), - default: None, - description: None, - required: true, - enumeration: vec!["foo".into(), vec!["foo", "bar"].into()], - items: Some(ToolParameterItemConfig { - kind: OneOrManyTypes::One("string".into()), - default: None, - description: None, - enumeration: vec![], - }), - }), - ]) - } - - async fn tool_call_nostream(provider: ProviderId, test_name: &str) -> Result { - let requests = vec![ - tool_call_base(provider), - TestRequest::tool_call_response(Ok("working!"), false), - ]; - - run_test(provider, test_name, requests).await - } - - async fn tool_call_stream(provider: ProviderId, test_name: &str) -> Result { - let requests = vec![ - tool_call_base(provider).stream(true), - TestRequest::tool_call_response(Ok("working!"), false), - ]; - - run_test(provider, test_name, requests).await - } - - async 
fn tool_call_strict(provider: ProviderId, test_name: &str) -> Result { - let requests = vec![ - tool_call_base(provider).tool_call_strict_mode(true), - TestRequest::tool_call_response(Ok("working!"), false), - ]; - - run_test(provider, test_name, requests).await - } - - /// Without reasoning, "forced" tool calls should work as expected. - async fn tool_call_required_no_reasoning(provider: ProviderId, test_name: &str) -> Result { - let requests = vec![ - tool_call_base(provider).tool_choice(ToolChoice::Required), - TestRequest::tool_call_response(Ok("working!"), true), - ]; - - run_test(provider, test_name, requests).await - } - - /// With reasoning, some models do not support "forced" tool calls, so - /// provider implementations should fall back to trying to instruct the - /// model to use the tool through regular textual instructions. - async fn tool_call_required_reasoning(provider: ProviderId, test_name: &str) -> Result { - let requests = vec![ - tool_call_base(provider) - .tool_choice(ToolChoice::Required) - .enable_reasoning(), - TestRequest::tool_call_response(Ok("working!"), false), - ]; - - run_test(provider, test_name, requests).await - } - - async fn tool_call_auto(provider: ProviderId, test_name: &str) -> Result { - let requests = vec![ - tool_call_base(provider).tool_choice(ToolChoice::Auto), - TestRequest::tool_call_response(Ok("working!"), false), - ]; - - run_test(provider, test_name, requests).await - } - - async fn tool_call_function(provider: ProviderId, test_name: &str) -> Result { - let requests = vec![ - tool_call_base(provider).tool_choice_fn("run_me"), - TestRequest::tool_call_response(Ok("working!"), true), - ]; - - run_test(provider, test_name, requests).await - } - - async fn tool_call_reasoning(provider: ProviderId, test_name: &str) -> Result { - let requests = vec![ - tool_call_base(provider).enable_reasoning(), - TestRequest::tool_call_response(Ok("working!"), false), - ]; - - run_test(provider, test_name, requests).await - } - - async 
fn structured_completion_success(provider: ProviderId, test_name: &str) -> Result { - let request = - TestRequest::chat(provider).chat_request("I am testing the structured completion API."); - let history = request.as_thread().unwrap().events.clone(); - let request = TestRequest::Structured { - query: structured::titles::titles(3, history, &[]).unwrap(), - model: match request { - TestRequest::Chat { model, .. } => model, - _ => unreachable!(), - }, - assert: Arc::new(|_| {}), - }; - - run_test(provider, test_name, Some(request)).await - } - - async fn structured_completion_error(provider: ProviderId, test_name: &str) -> Result { - let request = - TestRequest::chat(provider).chat_request("I am testing the structured completion API."); - let thread = request.as_thread().cloned().unwrap(); - let query = StructuredQuery::new( - schemars::json_schema!({ - "type": "object", - "description": "1 + 1 = ?", - "required": ["answer"], - "additionalProperties": false, - "properties": { "answer": { "type": "integer" } }, - }), - thread, - ) - .with_validator(move |value| { - value - .get("answer") - .ok_or("Missing `answer` field.".to_owned())? - .as_u64() - .ok_or("Answer must be an integer".to_owned()) - .and_then(|v| Err(format!("You thought 1 + 1 = {v}? Think again!"))) - }); - - let request = TestRequest::Structured { - query, - model: match request { - TestRequest::Chat { model, .. 
} => model, - _ => unreachable!(), - }, - assert: Arc::new(|results| { - results.iter().all(std::result::Result::is_err); - }), - }; - - run_test(provider, test_name, Some(request)).await - } - - async fn model_details(provider: ProviderId, test_name: &str) -> Result { - let request = TestRequest::ModelDetails { - name: test_model_details(provider).id.name.to_string(), - assert: Arc::new(|_| {}), - }; - - run_test(provider, test_name, Some(request)).await - } - - async fn models(provider: ProviderId, test_name: &str) -> Result { - let request = TestRequest::Models { - assert: Arc::new(|_| {}), - }; - - run_test(provider, test_name, Some(request)).await - } - - async fn multi_turn_conversation(provider: ProviderId, test_name: &str) -> Result { - let requests = vec![ - TestRequest::chat(provider).chat_request("Test message"), - TestRequest::chat(provider) - .enable_reasoning() - .chat_request("Repeat my previous message"), - tool_call_base(provider).tool_choice_fn("run_me"), - TestRequest::tool_call_response(Ok("The secret code is: 42"), true), - TestRequest::chat(provider) - .enable_reasoning() - .chat_request("What was the result of the previous tool call?"), - ]; - - run_test(provider, test_name, requests).await - } - - test_all_providers![ - chat_completion_nostream, - chat_completion_stream, - tool_call_auto, - tool_call_function, - tool_call_reasoning, - tool_call_nostream, - tool_call_required_no_reasoning, - tool_call_required_reasoning, - tool_call_stream, - tool_call_strict, - structured_completion_success, - structured_completion_error, - model_details, - models, - multi_turn_conversation, - ]; -} +#[path = "provider_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/provider/anthropic.rs b/crates/jp_llm/src/provider/anthropic.rs index a5fabe6a..7c917ffc 100644 --- a/crates/jp_llm/src/provider/anthropic.rs +++ b/crates/jp_llm/src/provider/anthropic.rs @@ -1,27 +1,30 @@ -use std::{env, time::Duration}; +use std::{env, future, time::Duration}; use 
async_anthropic::{ Client, errors::AnthropicError, messages::DEFAULT_MAX_TOKENS, types::{ - self, ListModelsResponse, System, Thinking, ToolBash, ToolCodeExecution, ToolComputerUse, - ToolTextEditor, ToolWebSearch, + self, Effort, JsonOutputFormat, ListModelsResponse, OutputConfig, System, Thinking, + ToolBash, ToolCodeExecution, ToolComputerUse, ToolTextEditor, ToolWebSearch, }, }; use async_stream::try_stream; use async_trait::async_trait; use chrono::NaiveDate; -use futures::{StreamExt as _, TryStreamExt as _, pin_mut}; +use futures::{FutureExt as _, StreamExt as _, TryStreamExt as _, pin_mut, stream}; use indexmap::IndexMap; use jp_config::{ assistant::tool_choice::ToolChoice, - model::id::{Name, ProviderId}, + model::{ + id::{Name, ProviderId}, + parameters::ReasoningEffort, + }, providers::llm::anthropic::AnthropicConfig, }; use jp_conversation::{ ConversationStream, - event::{ChatResponse, ConversationEvent, EventKind}, + event::{ChatResponse, ConversationEvent, EventKind, ToolCallRequest}, thread::{Document, Documents, Thread}, }; use serde_json::{Map, Value, json}; @@ -29,7 +32,10 @@ use tracing::{debug, info, trace, warn}; use super::Provider; use crate::{ - error::{Error, Result}, + error::{ + Error, Result, StreamError, StreamErrorKind, extract_retry_from_text, + looks_like_quota_error, + }, event::{Event, FinishReason}, model::{ModelDeprecation, ModelDetails, ReasoningDetails}, query::ChatQuery, @@ -51,6 +57,27 @@ const MAX_CACHE_CONTROL_COUNT: usize = 4; const THINKING_SIGNATURE_KEY: &str = "anthropic_thinking_signature"; const REDACTED_THINKING_KEY: &str = "anthropic_redacted_thinking"; +/// Known Anthropic error types that are safe to retry. +/// +/// See: +/// See: +const RETRYABLE_ANTHROPIC_ERROR_TYPES: &[&str] = + &["rate_limit_error", "overloaded_error", "api_error"]; + +/// Supported string `format` values for Anthropic structured output. 
+const SUPPORTED_STRING_FORMATS: &[&str] = &[ + "date-time", + "time", + "date", + "duration", + "email", + "hostname", + "uri", + "ipv4", + "ipv6", + "uuid", +]; + #[derive(Debug, Clone)] pub struct Anthropic { client: Client, @@ -105,18 +132,25 @@ impl Provider for Anthropic { query: ChatQuery, ) -> Result { let client = self.client.clone(); - let chain_on_max_tokens = query + let max_tokens_config = query .thread .events .config()? .assistant .model .parameters - .max_tokens - .is_none() - && self.chain_on_max_tokens; + .max_tokens; - let request = create_request(model, query, true, &self.beta)?; + let (request, is_structured) = create_request(model, query, true, &self.beta)?; + + // Chaining is disabled for structured output — the provider guarantees + // schema compliance so the response won't hit max_tokens for a + // well-constrained schema. + // + // It is also disabled when the user has explicitly configured a max + // tokens value, or when chaining is disabled in the provider config. + let chain_on_max_tokens = + !is_structured && max_tokens_config.is_none() && self.chain_on_max_tokens; debug!(stream = true, "Anthropic chat completion stream request."); trace!( @@ -124,7 +158,7 @@ impl Provider for Anthropic { "Request payload." 
); - Ok(call(client, request, chain_on_max_tokens)) + Ok(call(client, request, chain_on_max_tokens, is_structured)) } } @@ -139,6 +173,7 @@ fn call( client: Client, request: types::CreateMessagesRequest, chain_on_max_tokens: bool, + is_structured: bool, ) -> EventStream { Box::pin(try_stream!({ let mut tool_call_aggregator = ToolCallRequestAggregator::new(); @@ -153,87 +188,57 @@ fn call( .messages() .create_stream(request.clone()) .await - .map_err(|e| match e { - AnthropicError::RateLimit { retry_after } => Error::RateLimit { - retry_after: retry_after.map(Duration::from_secs), - }, - - // Anthropic's API is notoriously unreliable, so we - // special-case a few common errors that most of the times - // resolve themselves when retried. - // - // See: - // See: - AnthropicError::StreamError(e) - if ["rate_limit_error", "overloaded_error", "api_error"] - .contains(&e.error_type.as_str()) - || e.error.as_ref().is_some_and(|v| { - v.get("message").and_then(Value::as_str) == Some("Overloaded") - }) => - { - Error::RateLimit { - retry_after: Some(Duration::from_secs(3)), - } - } - - _ => Error::from(e), - }) + .map_err(StreamError::from) + .map_ok(|v| stream::iter(map_event(v, &mut tool_call_aggregator, is_structured))) + .try_flatten() + .chain(future::ready(Ok(Event::Finished(FinishReason::Completed))).into_stream()) .peekable(); pin_mut!(stream); while let Some(event) = stream.next().await.transpose()? { - if let Some(event) = map_event(event, &mut tool_call_aggregator)? { - match event { - // If the assistant has reached the maximum number of - // tokens, and we are in a state in which we can request - // more tokens, we do so by sending a new request and - // chaining those events onto the previous ones, keeping the - // existing stream of events alive. - // - // TODO: generalize this for any provider. 
- event if should_chain(&event, tool_calls_requested, chain_on_max_tokens) => { - debug!("Max tokens reached, auto-requesting more tokens."); - - for await event in chain(client.clone(), request.clone(), events) { - yield event?; - } - return; - } - done @ Event::Finished(_) => { - yield done; - return; + match event { + // If the assistant has reached the maximum number of + // tokens, and we are in a state in which we can request + // more tokens, we do so by sending a new request and + // chaining those events onto the previous ones, keeping the + // existing stream of events alive. + // + // TODO: generalize this for any provider. + event if should_chain(&event, tool_calls_requested, chain_on_max_tokens) => { + debug!("Max tokens reached, auto-requesting more tokens."); + + for await event in chain(client.clone(), request.clone(), events, is_structured) + { + yield event?; } - Event::Part { event, index } => { - if event.is_tool_call_request() { - tool_calls_requested = true; - } else if chain_on_max_tokens { - events.push(event.clone()); - } - - yield Event::Part { event, index }; + return; + } + done @ Event::Finished(_) => { + yield done; + return; + } + Event::Part { event, index } => { + if event.is_tool_call_request() { + tool_calls_requested = true; + } else if chain_on_max_tokens { + events.push(event.clone()); } - flush @ Event::Flush { .. } => { - let next_delta = stream.as_mut().peek().await.and_then(|e| { - e.as_ref().ok().and_then(|e| match e { - types::MessagesStreamEvent::MessageDelta { delta, .. } => { - Some(delta) - } - _ => None, - }) - }); - - // If we try to flush, but we're about to continue with - // a chained request, we ignore the flush event, to - // allow more event parts to be generated by the next - // response. 
- if let Some(delta) = next_delta.cloned().and_then(map_message_delta) - && should_chain(&delta, tool_calls_requested, chain_on_max_tokens) - { - continue; - } - - yield flush; + + yield Event::Part { event, index }; + } + flush @ Event::Flush { .. } => { + let next_event = stream.as_mut().peek().await.and_then(|e| e.as_ref().ok()); + + // If we try to flush, but we're about to continue with a + // chained request, we ignore the flush event, to allow more + // event parts to be generated by the next response. + if let Some(event) = next_event + && should_chain(event, tool_calls_requested, chain_on_max_tokens) + { + continue; } + + yield flush; } } } @@ -255,13 +260,18 @@ fn chain( client: Client, mut request: types::CreateMessagesRequest, events: Vec, + is_structured: bool, ) -> EventStream { debug_assert!(!events.iter().any(ConversationEvent::is_tool_call_request)); let mut should_merge = true; let previous_content = events .last() - .and_then(|e| e.as_chat_response().map(ChatResponse::content)) + .and_then(|e| match e.as_chat_response() { + Some(ChatResponse::Message { message }) => Some(message.as_str()), + Some(ChatResponse::Reasoning { reasoning }) => Some(reasoning.as_str()), + _ => None, + }) .unwrap_or_default() .to_owned(); @@ -292,7 +302,7 @@ fn chain( }); Box::pin(try_stream!({ - for await event in call(client, request, true) { + for await event in call(client, request, true, is_structured) { let mut event = event?; // When chaining new events, the reasoning content is irrelevant, as @@ -316,10 +326,12 @@ fn chain( // overlap. Sometimes the assistant will start a chaining response // with a small amount of content that was already seen in the // previous response, and we want to avoid duplicating that. 
- if let Some(content) = event + if let Some( + ChatResponse::Message { message: content } + | ChatResponse::Reasoning { reasoning: content }, + ) = event .as_conversation_event_mut() .and_then(ConversationEvent::as_chat_response_mut) - .map(ChatResponse::content_mut) { if should_merge { let merge_point = find_merge_point(&previous_content, content, 500); @@ -401,12 +413,11 @@ fn create_request( query: ChatQuery, stream: bool, beta: &BetaFeatures, -) -> Result { +) -> Result<(types::CreateMessagesRequest, bool)> { let ChatQuery { thread, tools, mut tool_choice, - tool_call_strict_mode, } = query; let mut builder = types::CreateMessagesRequestBuilder::default(); @@ -415,11 +426,22 @@ fn create_request( let Thread { system_prompt, - instructions, + sections, attachments, events, } = thread; + // Request a structured response if the very last event is a ChatRequest + // with a schema attached. The schema is transformed to strip unsupported + // properties (moving them into `description` fields as hints). 
+ let format = events + .last() + .and_then(|e| e.event.as_chat_request()) + .and_then(|req| req.schema.clone()) + .map(|schema| JsonOutputFormat::JsonSchema { + schema: transform_schema(schema), + }); + let mut cache_control_count = MAX_CACHE_CONTROL_COUNT; let config = events.config()?; @@ -427,13 +449,8 @@ fn create_request( .model(model.id.name.clone()) .messages(AnthropicMessages::build(events, &mut cache_control_count).0); - let tools = convert_tools( - tools, - tool_call_strict_mode - && model.features.contains(&"structured-outputs") - && beta.structured_outputs(), - &mut cache_control_count, - ); + let strict_tools = model.features.contains(&"structured-outputs") && beta.structured_outputs(); + let tools = convert_tools(tools, strict_tools, &mut cache_control_count); let mut system_content = vec![]; @@ -447,20 +464,29 @@ fn create_request( })); } - if !instructions.is_empty() { - let text = instructions - .into_iter() - .map(|instruction| instruction.try_to_xml().map_err(Into::into)) - .collect::>>()? - .join("\n\n"); - - system_content.push(types::SystemContent::Text(types::Text { - text, - cache_control: (cache_control_count > 0).then_some({ - cache_control_count = cache_control_count.saturating_sub(1); - types::CacheControl::default() - }), - })); + // FIXME: Somehow the system_prompt is being duplicated. It has to do with + // `impl PartialConfigDelta`? + // dbg!(&system_content); + + if !sections.is_empty() { + // Each section gets its own system content block. Cache control + // is placed on the last section, as section content is unlikely + // to change between requests. 
+ let mut sections = sections.iter().peekable(); + while let Some(section) = sections.next() { + system_content.push(types::SystemContent::Text(types::Text { + text: section.render(), + cache_control: sections.peek().map_or_else( + || { + (cache_control_count > 0).then(|| { + cache_control_count = cache_control_count.saturating_sub(1); + types::CacheControl::default() + }) + }, + |_| None, + ), + })); + } } if !attachments.is_empty() { @@ -546,36 +572,76 @@ fn create_request( builder.system(System::Content(system_content)); } + // Track the effort from reasoning config so we can merge it with the + // structured output format into a single OutputConfig at the end. + let mut effort = None; + let supports_thinking = model.reasoning.is_some_and(|r| !r.is_unsupported()); + if let Some(config) = reasoning_config { - let (min_budget, mut max_budget) = match model.reasoning { + match model.reasoning { + // Adaptive thinking for Opus 4.6+ + Some(ReasoningDetails::Adaptive { max: supports_max }) => { + builder.thinking(types::ExtendedThinking::Adaptive); + + effort = match config + .effort + .abs_to_rel(model.max_output_tokens) + .unwrap_or(ReasoningEffort::Auto) + { + ReasoningEffort::Max if supports_max => Some(Effort::Max), + ReasoningEffort::Max + | ReasoningEffort::XHigh + | ReasoningEffort::High + | ReasoningEffort::Absolute(_) => Some(Effort::High), + ReasoningEffort::Medium => Some(Effort::Medium), + ReasoningEffort::Low | ReasoningEffort::Xlow | ReasoningEffort::None => { + Some(Effort::Low) + } + ReasoningEffort::Auto => None, + }; + } + + // Budget-based thinking for older models Some(ReasoningDetails::Budgetted { min_tokens, - max_tokens, - }) => (min_tokens, max_tokens.unwrap_or(u32::MAX)), - _ => (0, u32::MAX), - }; + max_tokens: reasoning_max_tokens, + }) => { + let mut max_budget = reasoning_max_tokens.unwrap_or(u32::MAX); - // With interleaved thinking, the `budget_tokens` can exceed the - // `max_tokens` parameter, as it represents the total budget across 
all - // thinking blocks within one assistant turn. - // - // See: - // - // This is only enabled if the model supports it, otherwise an error is - // returned if the `max_tokens` parameter is larger than the model's - // supported range. - if beta.interleaved_thinking() && model.features.contains(&"interleaved-thinking") { - max_budget = model.context_window.unwrap_or(max_budget); + // With interleaved thinking, the `budget_tokens` can exceed the + // `max_tokens` parameter, as it represents the total budget across all + // thinking blocks within one assistant turn. + // + // See: + // + // This is only enabled if the model supports it, otherwise an error is + // returned if the `max_tokens` parameter is larger than the model's + // supported range. + if beta.interleaved_thinking() && model.features.contains(&"interleaved-thinking") { + max_budget = model.context_window.unwrap_or(max_budget); + } + + builder.thinking(types::ExtendedThinking::Enabled { + budget_tokens: config + .effort + .to_tokens(max_tokens) + .max(min_tokens) + .min(max_budget), + }); + } + + // Other reasoning details (Leveled, Unsupported) - no thinking config + _ => {} } + } else if supports_thinking { + // Reasoning is off but the model supports it — explicitly disable + // to prevent the model from thinking by default. 
+ builder.thinking(types::ExtendedThinking::Disabled); + } - builder.thinking(types::ExtendedThinking { - kind: "enabled".to_string(), - budget_tokens: config - .effort - .to_tokens(max_tokens) - .max(min_budget) - .min(max_budget), - }); + let is_structured = format.is_some(); + if effort.is_some() || is_structured { + builder.output_config(OutputConfig { effort, format }); } if let Some(temperature) = parameters.temperature { @@ -608,12 +674,53 @@ fn create_request( builder.context_management(strategy); } - builder.build().map_err(Into::into) + builder + .build() + .map(|req| (req, is_structured)) + .map_err(Into::into) } #[expect(clippy::too_many_lines)] fn map_model(model: types::Model, beta: &BetaFeatures) -> Result { let details = match model.id.as_str() { + "claude-sonnet-4-6" => ModelDetails { + id: (PROVIDER, model.id).try_into()?, + display_name: Some(model.display_name), + context_window: if beta.context_1m() { + Some(1_000_000) + } else { + Some(200_000) + }, + max_output_tokens: Some(64_000), + reasoning: Some(ReasoningDetails::adaptive(true)), + knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2025, 8, 1).unwrap()), + deprecated: Some(ModelDeprecation::Active), + features: vec![ + "interleaved-thinking", + "context-editing", + "structured-outputs", + "adaptive-thinking", + ], + }, + "claude-opus-4-6" | "claude-opus-4-6-20260205" => ModelDetails { + id: (PROVIDER, model.id).try_into()?, + display_name: Some(model.display_name), + context_window: if beta.context_1m() { + Some(1_000_000) + } else { + Some(200_000) + }, + max_output_tokens: Some(128_000), + reasoning: Some(ReasoningDetails::adaptive(true)), + knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2025, 5, 1).unwrap()), + deprecated: Some(ModelDeprecation::Active), + features: vec![ + "interleaved-thinking", + "context-editing", + "structured-outputs", + "adaptive-thinking", + ], + }, "claude-opus-4-5" | "claude-opus-4-5-20251101" => ModelDetails { id: (PROVIDER, model.id).try_into()?, display_name: 
Some(model.display_name), @@ -694,39 +801,6 @@ fn map_model(model: types::Model, beta: &BetaFeatures) -> Result { deprecated: Some(ModelDeprecation::Active), features: vec!["interleaved-thinking", "context-editing"], }, - "claude-3-7-sonnet-latest" | "claude-3-7-sonnet-20250219" => ModelDetails { - id: (PROVIDER, model.id).try_into()?, - display_name: Some(model.display_name), - context_window: Some(200_000), - max_output_tokens: Some(64_000), - reasoning: Some(ReasoningDetails::budgetted(1024, None)), - knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2024, 11, 1).unwrap()), - deprecated: Some(ModelDeprecation::Active), - features: vec![], - }, - "claude-3-5-haiku-latest" | "claude-3-5-haiku-20241022" => ModelDetails { - id: (PROVIDER, model.id).try_into()?, - display_name: Some(model.display_name), - context_window: Some(200_000), - max_output_tokens: Some(8_192), - reasoning: Some(ReasoningDetails::unsupported()), - knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2024, 7, 1).unwrap()), - deprecated: Some(ModelDeprecation::Active), - features: vec![], - }, - "claude-3-opus-latest" | "claude-3-opus-20240229" => ModelDetails { - id: (PROVIDER, model.id).try_into()?, - display_name: Some(model.display_name), - context_window: Some(200_000), - max_output_tokens: Some(4_096), - reasoning: Some(ReasoningDetails::unsupported()), - knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2023, 8, 1).unwrap()), - deprecated: Some(ModelDeprecation::deprecated( - &"recommended replacement: claude-opus-4-1-20250805", - Some(NaiveDate::from_ymd_opt(2026, 1, 5).unwrap()), - )), - features: vec![], - }, "claude-3-haiku-20240307" => ModelDetails { id: (PROVIDER, model.id).try_into()?, display_name: Some(model.display_name), @@ -734,7 +808,10 @@ fn map_model(model: types::Model, beta: &BetaFeatures) -> Result { max_output_tokens: Some(4_096), reasoning: Some(ReasoningDetails::unsupported()), knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2024, 8, 1).unwrap()), - deprecated: 
Some(ModelDeprecation::Active), + deprecated: Some(ModelDeprecation::deprecated( + &"recommended replacement: claude-haiku-4-5-20251001", + Some(NaiveDate::from_ymd_opt(2026, 4, 20).unwrap()), + )), features: vec![], }, id => { @@ -751,7 +828,8 @@ fn map_model(model: types::Model, beta: &BetaFeatures) -> Result { fn map_event( event: types::MessagesStreamEvent, agg: &mut ToolCallRequestAggregator, -) -> Result> { + is_structured: bool, +) -> Vec> { use types::MessagesStreamEvent::*; trace!( @@ -763,11 +841,17 @@ fn map_event( ContentBlockStart { content_block, index, - } => Ok(map_content_start(content_block, index, agg)), - ContentBlockDelta { delta, index } => Ok(map_content_delta(delta, index, agg)), + } => map_content_start(content_block, index, agg, is_structured) + .into_iter() + .map(Ok) + .collect(), + ContentBlockDelta { delta, index } => map_content_delta(delta, index, agg, is_structured) + .into_iter() + .map(Ok) + .collect(), ContentBlockStop { index } => map_content_stop(index, agg), - MessageDelta { delta, .. } => Ok(map_message_delta(delta)), - _ => Ok(None), + MessageDelta { delta, .. } => map_message_delta(&delta).into_iter().map(Ok).collect(), + _ => vec![], } } @@ -798,6 +882,139 @@ impl TryFrom<&AnthropicConfig> for Anthropic { } } +/// Transform a JSON schema to conform to Anthropic's structured output +/// constraints. +/// +/// Anthropic's structured output supports a subset of JSON Schema. Unsupported +/// properties are stripped and appended to the `description` field so the model +/// can still see them as soft hints. +/// +/// Mirrors the logic from Anthropic's Python SDK `transform_schema`. +/// +/// See: +fn transform_schema(mut src: Map) -> Map { + if let Some(r) = src.remove("$ref") { + return Map::from_iter([("$ref".into(), r)]); + } + + let mut out = Map::new(); + + // Helper macro to move a field from src to out. + macro_rules! 
move_field { + ($key:literal) => { + if let Some(v) = src.remove($key) { + out.insert($key.into(), v); + } + }; + } + + // Extract common fields + move_field!("title"); + move_field!("description"); + + // Recursive Transformation Helpers + let transform_val = |v: Value| match v { + Value::Object(o) => Value::Object(transform_schema(o)), + other => other, + }; + + let transform_map = |m: Map| -> Map { + m.into_iter().map(|(k, v)| (k, transform_val(v))).collect() + }; + + let transform_vec = + |v: Vec| -> Vec { v.into_iter().map(transform_val).collect() }; + + // Handle Recursive Dictionaries + for key in ["$defs", "definitions"] { + if let Some(Value::Object(defs)) = src.remove(key) { + out.insert(key.into(), Value::Object(transform_map(defs))); + } + } + + // Handle Combinators + if let Some(Value::Array(variants)) = src.remove("anyOf") { + out.insert("anyOf".into(), Value::Array(transform_vec(variants))); + } else if let Some(Value::Array(variants)) = src.remove("oneOf") { + // Remap oneOf -> anyOf + out.insert("anyOf".into(), Value::Array(transform_vec(variants))); + } else if let Some(Value::Array(variants)) = src.remove("allOf") { + out.insert("allOf".into(), Value::Array(transform_vec(variants))); + } + + // Handle Type-Specific Logic + // + // We remove "type" now so it doesn't get caught in the "leftovers" logic + // later. 
+ let type_val = src.remove("type"); + match type_val.as_ref().and_then(Value::as_str) { + Some("object") => { + if let Some(Value::Object(props)) = src.remove("properties") { + out.insert("properties".into(), Value::Object(transform_map(props))); + } + + move_field!("required"); + + // Force strictness + src.remove("additionalProperties"); + out.insert("additionalProperties".into(), Value::Bool(false)); + } + Some("array") => { + if let Some(items) = src.remove("items") { + out.insert("items".into(), transform_val(items)); + } + + // Enforce minItems logic + if let Some(min) = src.remove("minItems") { + if min.as_u64().is_some_and(|n| n <= 1) { + out.insert("minItems".into(), min); + } else { + // Put it back into src so it falls into description + src.insert("minItems".into(), min); + } + } + } + Some("string") => { + if let Some(format) = src.remove("format") { + let is_supported = format + .as_str() + .is_some_and(|f| SUPPORTED_STRING_FORMATS.contains(&f)); + + if is_supported { + out.insert("format".into(), format); + } else { + src.insert("format".into(), format); + } + } + } + _ => {} + } + + // Re-insert type. + if let Some(t) = type_val { + out.insert("type".into(), t); + } + + // 7. 
Handle "Leftovers" (Unsupported fields -> Description) + if !src.is_empty() { + let extra_info = src + .iter() + .map(|(k, v)| format!("{k}: {v}")) + .collect::>() + .join(", "); + + out.entry("description") + .and_modify(|v| { + if let Some(s) = v.as_str() { + *v = Value::from(format!("{s}\n\n{{{extra_info}}}")); + } + }) + .or_insert_with(|| Value::from(format!("{{{extra_info}}}"))); + } + + out +} + fn convert_tool_choice(choice: ToolChoice) -> types::ToolChoice { match choice { ToolChoice::None => types::ToolChoice::none(), @@ -979,7 +1196,11 @@ fn convert_event( && signature.is_some() { types::MessageContent::Thinking(Thinking { - thinking: response.into_content(), + thinking: match response { + ChatResponse::Reasoning { reasoning } => reasoning, + ChatResponse::Message { message } => message, + ChatResponse::Structured { data } => data.to_string(), + }, signature, }) } else if is_anthropic @@ -990,13 +1211,19 @@ fn convert_event( .map(str::to_owned) { types::MessageContent::RedactedThinking { data } - } else if response.is_reasoning() { - // Reasoning from other providers - wrap in XML tags - types::MessageContent::Text( - format!("\n{}\n\n\n", response.content()).into(), - ) } else { - types::MessageContent::Text(response.into_content().into()) + match response { + // Reasoning from other providers - wrap in tags. 
+ ChatResponse::Reasoning { reasoning } => types::MessageContent::Text( + format!("\n{reasoning}\n\n\n").into(), + ), + ChatResponse::Message { message } => { + types::MessageContent::Text(message.into()) + } + ChatResponse::Structured { data } => { + types::MessageContent::Text(data.to_string().into()) + } + } }; Some((types::MessageRole::Assistant, content)) @@ -1028,7 +1255,8 @@ fn convert_event( } EventKind::ChatRequest(_) | EventKind::InquiryRequest(_) - | EventKind::InquiryResponse(_) => None, + | EventKind::InquiryResponse(_) + | EventKind::TurnStart(_) => None, } } @@ -1036,16 +1264,29 @@ fn map_content_start( item: types::MessageContent, index: usize, agg: &mut ToolCallRequestAggregator, + is_structured: bool, ) -> Option { use types::MessageContent::*; let mut metadata = IndexMap::new(); let kind: EventKind = match item { + // Initial part indicating a tool call request has started. The eventual + // fully-aggregated arguments will be sent in a separate Part. ToolUse(types::ToolUse { id, name, .. 
}) => { - *agg = ToolCallRequestAggregator::new(); - agg.add_chunk(index, Some(id), Some(name), None); - return None; + *agg = ToolCallRequestAggregator::default(); + agg.add_chunk(index, Some(id.clone()), Some(name.clone()), None); + let request = ToolCallRequest { + id, + name, + arguments: Map::new(), + }; + + return Some(Event::Part { + index, + event: request.into(), + }); } + Text(text) if is_structured => ChatResponse::structured(Value::String(text.text)).into(), Text(text) if !text.text.is_empty() => ChatResponse::message(text.text).into(), Text(_) => return None, Thinking(types::Thinking { @@ -1077,9 +1318,13 @@ fn map_content_delta( delta: types::ContentBlockDelta, index: usize, agg: &mut ToolCallRequestAggregator, + is_structured: bool, ) -> Option { let mut metadata = IndexMap::new(); let kind: EventKind = match delta { + types::ContentBlockDelta::TextDelta { text } if is_structured => { + ChatResponse::structured(Value::String(text)).into() + } types::ContentBlockDelta::TextDelta { text } => ChatResponse::message(text).into(), types::ContentBlockDelta::ThinkingDelta { thinking } => { ChatResponse::reasoning(thinking).into() @@ -1105,163 +1350,90 @@ fn map_content_delta( }) } -fn map_content_stop(index: usize, agg: &mut ToolCallRequestAggregator) -> Result> { +fn map_content_stop( + index: usize, + agg: &mut ToolCallRequestAggregator, +) -> Vec> { + let mut events = vec![]; + // Check if we're buffering a tool call request match agg.finalize(index) { - Ok(tool_call) => { - return Ok(Some(Event::Part { - event: ConversationEvent::now(tool_call), - index, - })); - } + Ok(tool_call) => events.push(Ok(Event::Part { + event: ConversationEvent::now(tool_call), + index, + })), Err(AggregationError::UnknownIndex) => {} - Err(error) => return Err(error.into()), + Err(error) => { + events.push(Err(StreamError::other(error.to_string()))); + return events; + } } - Ok(Some(Event::flush(index))) + events.push(Ok(Event::flush(index))); + events } -fn 
map_message_delta(delta: types::MessageDelta) -> Option { - match delta.stop_reason?.as_str() { +fn map_message_delta(delta: &types::MessageDelta) -> Option { + match delta.stop_reason.as_deref()? { "max_tokens" => Some(Event::Finished(FinishReason::MaxTokens)), _ => None, } } -#[cfg(test)] -mod tests { - use indexmap::IndexMap; - use jp_config::model::parameters::{ - PartialCustomReasoningConfig, PartialReasoningConfig, ReasoningEffort, - }; - use jp_test::{Result, function_name}; - use test_log::test; - - use super::*; - use crate::test::{TestRequest, run_test}; - - const MAGIC_STRING: &str = "ANTHROPIC_MAGIC_STRING_TRIGGER_REDACTED_THINKING_46C9A13E193C177646C7398A98432ECCCE4C1253D5E2D82641AC0E52CC2876CB"; - - #[test(tokio::test)] - async fn test_redacted_thinking() -> Result { - let requests = vec![ - TestRequest::chat(PROVIDER) - .enable_reasoning() - .chat_request(MAGIC_STRING), - TestRequest::chat(PROVIDER) - .chat_request("Do you have access to your redacted thinking content?"), - ]; - - run_test(PROVIDER, function_name!(), requests).await - } +impl From for StreamError { + fn from(error: AnthropicError) -> Self { + use AnthropicError as E; - #[test(tokio::test)] - async fn test_request_chaining() -> Result { - let mut request = TestRequest::chat(PROVIDER) - .stream(true) - .reasoning(Some(PartialReasoningConfig::Custom( - PartialCustomReasoningConfig { - effort: Some(ReasoningEffort::Absolute(1024.into())), - exclude: Some(false), - }, - ))) - .chat_request("Give me a 2000 word explainer about Kirigami-inspired parachutes"); - - if let Some(details) = request.as_model_details_mut() { - details.max_output_tokens = Some(1152); - } + match error { + E::Network(error) => Self::from(error), + E::StreamTransport(_) => StreamError::transient(error.to_string()).with_source(error), + E::RateLimit { retry_after } => { + Self::rate_limit(retry_after.map(Duration::from_secs)).with_source(error) + } - run_test(PROVIDER, function_name!(), Some(request)).await - } + // 
Anthropic's API is notoriously unreliable, so we special-case a + // few common errors that most of the times resolve themselves when + // retried. + // + // See: + // See: + E::Api(ref api_error) + if RETRYABLE_ANTHROPIC_ERROR_TYPES.contains(&api_error.error_type.as_str()) => + { + let retry_after = api_error + .message + .as_deref() + .and_then(extract_retry_from_text) + .unwrap_or(Duration::from_secs(3)); + + StreamError::transient(error.to_string()) + .with_retry_after(retry_after) + .with_source(error) + } - #[test] - fn test_find_merge_point_edge_cases() { - struct TestCase { - left: &'static str, - right: &'static str, - expected: &'static str, - max_search: usize, - } + // Detect billing/quota errors before falling through to generic. + E::Api(ref api_error) + if looks_like_quota_error(&api_error.error_type) + || api_error + .message + .as_deref() + .is_some_and(looks_like_quota_error) => + { + StreamError::new( + StreamErrorKind::InsufficientQuota, + format!( + "Insufficient API quota. Check your plan and billing details \ + at https://console.anthropic.com/settings/billing. 
({error})" + ), + ) + .with_source(error) + } - let cases = IndexMap::from([ - ("no overlap", TestCase { - left: "Hello", - right: " world", - expected: "Hello world", - max_search: 500, - }), - ("single word overlap", TestCase { - left: "The quick brown", - right: "brown fox", - expected: "The quick brown fox", - max_search: 500, - }), - ("minimal overlap (5 chars)", TestCase { - expected: "abcdefghij", - left: "abcdefgh", - right: "defghij", - max_search: 500, - }), - ( - "below minimum overlap (4 chars) - should not merge", - TestCase { - left: "abcd", - right: "abcd", - expected: "abcdabcd", - max_search: 500, - }, - ), - ("complete overlap", TestCase { - left: "Hello world", - right: "world", - expected: "Hello world", - max_search: 500, - }), - ("overlap with punctuation", TestCase { - left: "Hello, how are", - right: "how are you?", - expected: "Hello, how are you?", - max_search: 500, - }), - ("overlap with whitespace", TestCase { - left: "Hello ", - right: " world", - expected: "Hello world", - max_search: 500, - }), - ("unicode overlap", TestCase { - left: "Hello 世界", - right: "世界 friend", - expected: "Hello 世界 friend", - max_search: 500, - }), - ("long overlap", TestCase { - left: "The quick brown fox jumps", - right: "fox jumps over the lazy dog", - expected: "The quick brown fox jumpsfox jumps over the lazy dog", - max_search: 8, - }), - ("empty right", TestCase { - left: "Hello", - right: "", - expected: "Hello", - max_search: 500, - }), - ]); - - for ( - name, - TestCase { - left, - right, - expected, - max_search, - }, - ) in cases - { - let pos = find_merge_point(left, right, max_search); - let result = format!("{left}{}", &right[pos..]); - assert_eq!(result, expected, "Failed test case: {name}"); + error => StreamError::other(error.to_string()).with_source(error), } } } + +#[cfg(test)] +#[path = "anthropic_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/provider/anthropic_tests.rs b/crates/jp_llm/src/provider/anthropic_tests.rs new file mode 
100644 index 00000000..b1da6420 --- /dev/null +++ b/crates/jp_llm/src/provider/anthropic_tests.rs @@ -0,0 +1,746 @@ +use indexmap::IndexMap; +use jp_config::model::parameters::{ + PartialCustomReasoningConfig, PartialReasoningConfig, ReasoningEffort, +}; +use jp_test::{Result, function_name}; +use test_log::test; + +use super::*; +use crate::test::{TestRequest, run_test}; + +const MAGIC_STRING: &str = "ANTHROPIC_MAGIC_STRING_TRIGGER_REDACTED_THINKING_46C9A13E193C177646C7398A98432ECCCE4C1253D5E2D82641AC0E52CC2876CB"; + +#[test(tokio::test)] +async fn test_redacted_thinking() -> Result { + let requests = vec![ + TestRequest::chat(PROVIDER) + .enable_reasoning() + .chat_request(MAGIC_STRING), + TestRequest::chat(PROVIDER) + .chat_request("Do you have access to your redacted thinking content?"), + ]; + + run_test(PROVIDER, function_name!(), requests).await +} + +#[test(tokio::test)] +async fn test_request_chaining() -> Result { + let mut request = TestRequest::chat(PROVIDER) + .reasoning(Some(PartialReasoningConfig::Custom( + PartialCustomReasoningConfig { + effort: Some(ReasoningEffort::Absolute(1024.into())), + exclude: Some(false), + }, + ))) + .chat_request("Give me a 2000 word explainer about Kirigami-inspired parachutes"); + + if let Some(details) = request.as_model_details_mut() { + details.max_output_tokens = Some(1152); + } + + run_test(PROVIDER, function_name!(), Some(request)).await +} + +/// Test that Opus 4.6 uses adaptive thinking mode with the effort parameter. 
+#[test(tokio::test)] +async fn test_opus_4_6_adaptive_thinking() -> Result { + let mut request = TestRequest::chat(PROVIDER) + .reasoning(Some(PartialReasoningConfig::Custom( + PartialCustomReasoningConfig { + effort: Some(ReasoningEffort::High), + exclude: Some(false), + }, + ))) + .model("anthropic/claude-opus-4-6".parse().unwrap()) + .chat_request("What is 2 + 2?"); + + // Configure model to use adaptive thinking (Opus 4.6 feature) + if let Some(details) = request.as_model_details_mut() { + details.reasoning = Some(ReasoningDetails::adaptive(true)); + details.features = vec!["adaptive-thinking"]; + } + + run_test(PROVIDER, function_name!(), Some(request)).await +} + +/// Test Opus 4.6 with `Max` effort level (only supported on Opus 4.6). +#[test(tokio::test)] +async fn test_opus_4_6_max_effort() -> Result { + let mut request = TestRequest::chat(PROVIDER) + .reasoning(Some(PartialReasoningConfig::Custom( + PartialCustomReasoningConfig { + effort: Some(ReasoningEffort::Max), + exclude: Some(false), + }, + ))) + .model("anthropic/claude-opus-4-6".parse().unwrap()) + .chat_request("What is 2 + 2?"); + + // Configure model to use adaptive thinking with max effort support (Opus 4.6 feature) + if let Some(details) = request.as_model_details_mut() { + details.reasoning = Some(ReasoningDetails::adaptive(true)); + details.features = vec!["adaptive-thinking"]; + } + + run_test(PROVIDER, function_name!(), Some(request)).await +} + +/// Unit test: Verify Opus 4.6 generates adaptive thinking request. 
+#[test] +fn test_opus_4_6_request_uses_adaptive_thinking() { + use jp_conversation::{ConversationStream, thread::Thread}; + + let model = ModelDetails { + id: (PROVIDER, "claude-opus-4-6").try_into().unwrap(), + display_name: Some("Claude Opus 4.6".to_string()), + context_window: Some(200_000), + max_output_tokens: Some(128_000), + reasoning: Some(ReasoningDetails::adaptive(true)), + knowledge_cutoff: None, + deprecated: None, + features: vec!["adaptive-thinking"], + }; + + let query = ChatQuery { + thread: Thread { + system_prompt: None, + sections: vec![], + attachments: vec![], + events: ConversationStream::new_test().with_chat_request("test"), + }, + tools: vec![], + tool_choice: ToolChoice::Auto, + }; + + let beta = BetaFeatures(vec![]); + let (request, is_structured) = create_request(&model, query, true, &beta).unwrap(); + assert!(!is_structured); + + // Verify adaptive thinking is used + assert_eq!(request.thinking, Some(types::ExtendedThinking::Adaptive)); + + // Verify output_config has effort set (defaults to High) + assert!(request.output_config.is_some()); + let output_config = request.output_config.unwrap(); + assert_eq!(output_config.effort, Some(Effort::High)); + assert_eq!(output_config.format, None); +} + +/// Unit test: Verify Max effort maps to `Effort::Max` for Opus 4.6. 
+#[test] +fn test_opus_4_6_max_effort_mapping() { + use jp_conversation::{ConversationStream, thread::Thread}; + + let model = ModelDetails { + id: (PROVIDER, "claude-opus-4-6").try_into().unwrap(), + display_name: Some("Claude Opus 4.6".to_string()), + context_window: Some(200_000), + max_output_tokens: Some(128_000), + reasoning: Some(ReasoningDetails::adaptive(true)), // supports max + knowledge_cutoff: None, + deprecated: None, + features: vec!["adaptive-thinking"], + }; + + let mut events = ConversationStream::new_test().with_chat_request("test"); + let mut delta = jp_config::PartialAppConfig::empty(); + delta.assistant.model.parameters.reasoning = Some(PartialReasoningConfig::Custom( + PartialCustomReasoningConfig { + effort: Some(ReasoningEffort::Max), + exclude: Some(false), + }, + )); + events.add_config_delta(delta); + + let query = ChatQuery { + thread: Thread { + system_prompt: None, + sections: vec![], + attachments: vec![], + events, + }, + tools: vec![], + tool_choice: ToolChoice::Auto, + }; + + let beta = BetaFeatures(vec![]); + let (request, _) = create_request(&model, query, true, &beta).unwrap(); + + // Verify Max effort is used + assert!(request.output_config.is_some()); + let output_config = request.output_config.unwrap(); + assert_eq!(output_config.effort, Some(Effort::Max)); +} + +/// Unit test: Verify budget-based model (Opus 4.5) still uses Enabled thinking. 
+#[test] +fn test_opus_4_5_uses_budgetted_thinking() { + use jp_conversation::{ConversationStream, thread::Thread}; + + let model = ModelDetails { + id: (PROVIDER, "claude-opus-4-5").try_into().unwrap(), + display_name: Some("Claude Opus 4.5".to_string()), + context_window: Some(200_000), + max_output_tokens: Some(64_000), + reasoning: Some(ReasoningDetails::budgetted(1024, None)), + knowledge_cutoff: None, + deprecated: None, + features: vec!["interleaved-thinking"], + }; + + let query = ChatQuery { + thread: Thread { + system_prompt: None, + sections: vec![], + attachments: vec![], + events: ConversationStream::new_test().with_chat_request("test"), + }, + tools: vec![], + tool_choice: ToolChoice::Auto, + }; + + let beta = BetaFeatures(vec![]); + let (request, _) = create_request(&model, query, true, &beta).unwrap(); + + // Verify budget-based thinking is used (not adaptive) + assert!(matches!( + request.thinking, + Some(types::ExtendedThinking::Enabled { .. }) + )); + + // Verify output_config is NOT set for budget-based models + assert!(request.output_config.is_none()); +} + +/// Verify structured output sets `output_config.format` when the last event +/// is a `ChatRequest` with a schema. 
+#[test] +fn test_structured_output_sets_format() { + use jp_conversation::{ConversationStream, event::ChatRequest, thread::Thread}; + use serde_json::json; + + let model = ModelDetails { + id: (PROVIDER, "claude-sonnet-4-5").try_into().unwrap(), + display_name: Some("Claude Sonnet 4.5".to_string()), + context_window: Some(200_000), + max_output_tokens: Some(64_000), + reasoning: Some(ReasoningDetails::budgetted(1024, None)), + knowledge_cutoff: None, + deprecated: None, + features: vec!["structured-outputs"], + }; + + let schema = serde_json::Map::from_iter([ + ("type".into(), json!("object")), + ("properties".into(), json!({"name": {"type": "string"}})), + ]); + + let events = ConversationStream::new_test().with_chat_request(ChatRequest { + content: "Extract contacts".into(), + schema: Some(schema), + }); + + let query = ChatQuery { + thread: Thread { + system_prompt: None, + sections: vec![], + attachments: vec![], + events, + }, + tools: vec![], + tool_choice: ToolChoice::Auto, + }; + + let beta = BetaFeatures(vec![]); + let (request, is_structured) = create_request(&model, query, true, &beta).unwrap(); + + assert!(is_structured); + assert!(request.output_config.is_some()); + let output_config = request.output_config.unwrap(); + // No adaptive thinking, so effort should be None. + assert_eq!(output_config.effort, None); + // transform_schema adds additionalProperties: false for objects. + let expected_schema = serde_json::Map::from_iter([ + ("type".into(), json!("object")), + ("properties".into(), json!({"name": {"type": "string"}})), + ("additionalProperties".into(), json!(false)), + ]); + assert_eq!( + output_config.format, + Some(JsonOutputFormat::JsonSchema { + schema: expected_schema + }) + ); +} + +/// When the last event is NOT a `ChatRequest` (e.g. a `ChatResponse`), no +/// structured output should be configured even if a prior `ChatRequest` had a +/// schema. 
+#[test] +fn test_schema_ignored_when_last_event_is_not_chat_request() { + use jp_conversation::{ + ConversationStream, + event::{ChatRequest, ChatResponse}, + thread::Thread, + }; + use serde_json::json; + + let model = ModelDetails { + id: (PROVIDER, "claude-sonnet-4-5").try_into().unwrap(), + display_name: None, + context_window: Some(200_000), + max_output_tokens: Some(64_000), + reasoning: None, + knowledge_cutoff: None, + deprecated: None, + features: vec![], + }; + + let mut events = ConversationStream::new_test(); + // First turn: structured request + events.add_chat_request(ChatRequest { + content: "Extract contacts".into(), + schema: Some(serde_json::Map::from_iter([( + "type".into(), + json!("object"), + )])), + }); + // Then a response (now the last event is not a ChatRequest) + events.add_chat_response(ChatResponse::structured(json!({"name": "Alice"}))); + // Follow-up without schema + events.add_chat_request(ChatRequest { + content: "Explain what you found".into(), + schema: None, + }); + + let query = ChatQuery { + thread: Thread { + system_prompt: None, + sections: vec![], + attachments: vec![], + events, + }, + tools: vec![], + tool_choice: ToolChoice::Auto, + }; + + let beta = BetaFeatures(vec![]); + let (_, is_structured) = create_request(&model, query, true, &beta).unwrap(); + assert!(!is_structured); +} + +/// Adaptive thinking + structured output should coexist on `OutputConfig`. 
+#[test] +fn test_adaptive_thinking_with_structured_output() { + use jp_conversation::{ConversationStream, event::ChatRequest, thread::Thread}; + use serde_json::json; + + let model = ModelDetails { + id: (PROVIDER, "claude-opus-4-6").try_into().unwrap(), + display_name: Some("Claude Opus 4.6".to_string()), + context_window: Some(200_000), + max_output_tokens: Some(128_000), + reasoning: Some(ReasoningDetails::adaptive(true)), + knowledge_cutoff: None, + deprecated: None, + features: vec!["adaptive-thinking", "structured-outputs"], + }; + + let schema = serde_json::Map::from_iter([("type".into(), json!("object"))]); + + let events = ConversationStream::new_test().with_chat_request(ChatRequest { + content: "Extract data".into(), + schema: Some(schema), + }); + + let query = ChatQuery { + thread: Thread { + system_prompt: None, + sections: vec![], + attachments: vec![], + events, + }, + tools: vec![], + tool_choice: ToolChoice::Auto, + }; + + let beta = BetaFeatures(vec![]); + let (request, is_structured) = create_request(&model, query, true, &beta).unwrap(); + + assert!(is_structured); + assert_eq!(request.thinking, Some(types::ExtendedThinking::Adaptive)); + + let output_config = request.output_config.unwrap(); + // Both effort and format should be present. 
+ assert_eq!(output_config.effort, Some(Effort::High)); + let expected_schema = serde_json::Map::from_iter([ + ("type".into(), json!("object")), + ("additionalProperties".into(), json!(false)), + ]); + assert_eq!( + output_config.format, + Some(JsonOutputFormat::JsonSchema { + schema: expected_schema + }) + ); +} + +#[test] +fn test_find_merge_point_edge_cases() { + struct TestCase { + left: &'static str, + right: &'static str, + expected: &'static str, + max_search: usize, + } + + let cases = IndexMap::from([ + ("no overlap", TestCase { + left: "Hello", + right: " world", + expected: "Hello world", + max_search: 500, + }), + ("single word overlap", TestCase { + left: "The quick brown", + right: "brown fox", + expected: "The quick brown fox", + max_search: 500, + }), + ("minimal overlap (5 chars)", TestCase { + expected: "abcdefghij", + left: "abcdefgh", + right: "defghij", + max_search: 500, + }), + ( + "below minimum overlap (4 chars) - should not merge", + TestCase { + left: "abcd", + right: "abcd", + expected: "abcdabcd", + max_search: 500, + }, + ), + ("complete overlap", TestCase { + left: "Hello world", + right: "world", + expected: "Hello world", + max_search: 500, + }), + ("overlap with punctuation", TestCase { + left: "Hello, how are", + right: "how are you?", + expected: "Hello, how are you?", + max_search: 500, + }), + ("overlap with whitespace", TestCase { + left: "Hello ", + right: " world", + expected: "Hello world", + max_search: 500, + }), + ("unicode overlap", TestCase { + left: "Hello 世界", + right: "世界 friend", + expected: "Hello 世界 friend", + max_search: 500, + }), + ("long overlap", TestCase { + left: "The quick brown fox jumps", + right: "fox jumps over the lazy dog", + expected: "The quick brown fox jumpsfox jumps over the lazy dog", + max_search: 8, + }), + ("empty right", TestCase { + left: "Hello", + right: "", + expected: "Hello", + max_search: 500, + }), + ]); + + for ( + name, + TestCase { + left, + right, + expected, + max_search, + }, 
+ ) in cases + { + let pos = find_merge_point(left, right, max_search); + let result = format!("{left}{}", &right[pos..]); + assert_eq!(result, expected, "Failed test case: {name}"); + } +} + +mod transform_schema { + use serde_json::{Map, Value, json}; + + use super::transform_schema; + + #[expect(clippy::needless_pass_by_value)] + fn schema(v: Value) -> Map { + v.as_object().unwrap().clone() + } + + #[test] + fn object_forces_additional_properties_false() { + let input = schema(json!({ + "type": "object", + "properties": { "name": { "type": "string" } }, + "required": ["name"] + })); + + let out = transform_schema(input); + + assert_eq!(out["additionalProperties"], json!(false)); + assert_eq!(out["required"], json!(["name"])); + assert_eq!(out["properties"]["name"]["type"], "string"); + } + + #[test] + fn object_drops_existing_additional_properties() { + let input = schema(json!({ + "type": "object", + "properties": {}, + "additionalProperties": true + })); + + let out = transform_schema(input); + assert_eq!(out["additionalProperties"], json!(false)); + } + + #[test] + fn array_keeps_min_items_0_and_1() { + for n in [0, 1] { + let input = schema(json!({ + "type": "array", + "items": { "type": "string" }, + "minItems": n + })); + + let out = transform_schema(input); + assert_eq!(out["minItems"], json!(n), "minItems {n} should be kept"); + } + } + + #[test] + fn array_moves_large_min_items_to_description() { + let input = schema(json!({ + "type": "array", + "items": { "type": "string" }, + "minItems": 3 + })); + + let out = transform_schema(input); + assert!(out.get("minItems").is_none()); + let desc = out["description"].as_str().unwrap(); + assert!( + desc.contains("minItems"), + "description should mention minItems: {desc}" + ); + } + + #[test] + fn array_moves_max_items_to_description() { + let input = schema(json!({ + "type": "array", + "items": { "type": "string" }, + "maxItems": 5 + })); + + let out = transform_schema(input); + 
assert!(out.get("maxItems").is_none()); + let desc = out["description"].as_str().unwrap(); + assert!( + desc.contains("maxItems"), + "description should mention maxItems: {desc}" + ); + } + + #[test] + fn string_keeps_supported_format() { + let input = schema(json!({ + "type": "string", + "format": "date-time" + })); + + let out = transform_schema(input); + assert_eq!(out["format"], "date-time"); + assert!(out.get("description").is_none()); + } + + #[test] + fn string_moves_unsupported_format_to_description() { + let input = schema(json!({ + "type": "string", + "format": "phone-number" + })); + + let out = transform_schema(input); + assert!(out.get("format").is_none()); + let desc = out["description"].as_str().unwrap(); + assert!( + desc.contains("phone-number"), + "description should contain the format: {desc}" + ); + } + + #[test] + fn numeric_constraints_moved_to_description() { + let input = schema(json!({ + "type": "integer", + "minimum": 1, + "maximum": 10, + "description": "A number" + })); + + let out = transform_schema(input); + assert!(out.get("minimum").is_none()); + assert!(out.get("maximum").is_none()); + let desc = out["description"].as_str().unwrap(); + assert!( + desc.starts_with("A number"), + "should preserve original description" + ); + assert!(desc.contains("minimum"), "should contain minimum: {desc}"); + assert!(desc.contains("maximum"), "should contain maximum: {desc}"); + } + + #[test] + fn ref_passes_through() { + let input = schema(json!({ + "$ref": "#/$defs/Address" + })); + + let out = transform_schema(input); + assert_eq!(out["$ref"], "#/$defs/Address"); + assert_eq!(out.len(), 1); + } + + #[test] + fn defs_recursively_transformed() { + let input = schema(json!({ + "type": "object", + "properties": { + "addr": { "$ref": "#/$defs/Address" } + }, + "$defs": { + "Address": { + "type": "object", + "properties": { "city": { "type": "string" } }, + "additionalProperties": true + } + } + })); + + let out = transform_schema(input); + let 
addr_def = out["$defs"]["Address"].as_object().unwrap(); + assert_eq!(addr_def["additionalProperties"], json!(false)); + } + + #[test] + fn one_of_converted_to_any_of() { + let input = schema(json!({ + "oneOf": [ + { "type": "string" }, + { "type": "integer" } + ] + })); + + let out = transform_schema(input); + assert!(out.get("oneOf").is_none()); + let any_of = out["anyOf"].as_array().unwrap(); + assert_eq!(any_of.len(), 2); + assert_eq!(any_of[0]["type"], "string"); + assert_eq!(any_of[1]["type"], "integer"); + } + + #[test] + fn nested_properties_recursively_transformed() { + let input = schema(json!({ + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "object", + "properties": { "id": { "type": "integer", "minimum": 0 } }, + "additionalProperties": true + }, + "maxItems": 10 + } + } + })); + + let out = transform_schema(input); + + // Top-level object + assert_eq!(out["additionalProperties"], json!(false)); + + // The array property + let items_prop = out["properties"]["items"].as_object().unwrap(); + assert!(items_prop.get("maxItems").is_none()); + + // The nested object inside the array + let nested = items_prop["items"].as_object().unwrap(); + assert_eq!(nested["additionalProperties"], json!(false)); + + // The integer property's minimum should be in description + let id_prop = nested["properties"]["id"].as_object().unwrap(); + assert!(id_prop.get("minimum").is_none()); + let desc = id_prop["description"].as_str().unwrap(); + assert!( + desc.contains("minimum"), + "nested constraint in description: {desc}" + ); + } + + /// Mirrors the example from Anthropic's Python SDK docstring. 
+ #[test] + fn sdk_docstring_example() { + let input = schema(json!({ + "type": "integer", + "minimum": 1, + "maximum": 10, + "description": "A number" + })); + + let out = transform_schema(input); + assert_eq!(out["type"], "integer"); + let desc = out["description"].as_str().unwrap(); + assert!(desc.starts_with("A number")); + assert!(desc.contains("minimum: 1")); + assert!(desc.contains("maximum: 10")); + } + + /// The `title_schema` used by the title generator should survive + /// transformation. + #[test] + fn title_schema_transforms_cleanly() { + let input = crate::title::title_schema(3); + let out = transform_schema(input); + + assert_eq!(out["type"], "object"); + assert_eq!(out["additionalProperties"], json!(false)); + assert_eq!(out["required"], json!(["titles"])); + + let titles = out["properties"]["titles"].as_object().unwrap(); + assert_eq!(titles["type"], "array"); + // minItems 1 is kept but > 1 is moved to description + assert!(titles.get("minItems").is_none()); + assert!(titles.get("maxItems").is_none()); + let desc = titles["description"].as_str().unwrap(); + assert!( + desc.contains("minItems"), + "should contain minItems hint: {desc}" + ); + assert!( + desc.contains("maxItems"), + "should contain maxItems hint: {desc}" + ); + } +} diff --git a/crates/jp_llm/src/provider/google.rs b/crates/jp_llm/src/provider/google.rs index 6d4a83ec..5ae80d5e 100644 --- a/crates/jp_llm/src/provider/google.rs +++ b/crates/jp_llm/src/provider/google.rs @@ -2,8 +2,9 @@ use std::{collections::HashMap, env}; use async_stream::stream; use async_trait::async_trait; +use chrono::NaiveDate; use futures::{StreamExt as _, TryStreamExt as _}; -use gemini_client_rs::{GeminiClient, types}; +use gemini_client_rs::{GeminiClient, GeminiError, types}; use indexmap::IndexMap; use jp_config::{ assistant::tool_choice::ToolChoice, @@ -18,14 +19,15 @@ use jp_conversation::{ event::{ChatResponse, ConversationEvent, EventKind, ToolCallRequest}, thread::{Document, Documents, Thread}, }; 
-use serde_json::Value; +use serde_json::{Map, Value}; use tracing::{debug, trace}; use super::{EventStream, Provider}; use crate::{ - error::{Error, Result}, + StreamErrorKind, + error::{Error, Result, StreamError, looks_like_quota_error}, event::{Event, FinishReason}, - model::{ModelDetails, ReasoningDetails}, + model::{ModelDeprecation, ModelDetails, ReasoningDetails}, query::ChatQuery, tool::ToolDefinition, }; @@ -69,7 +71,7 @@ impl Provider for Google { query: ChatQuery, ) -> Result { let client = self.client.clone(); - let request = create_request(model, query)?; + let (request, structured) = create_request(model, query)?; let slug = model.id.name.clone(); debug!(stream = true, "Google chat completion stream request."); @@ -78,7 +80,7 @@ impl Provider for Google { "Request payload." ); - Ok(call(client, request, slug, 0)) + Ok(call(client, request, slug, 0, structured)) } } @@ -87,17 +89,19 @@ fn call( request: types::GenerateContentRequest, model: Name, tries: usize, + is_structured: bool, ) -> EventStream { Box::pin(stream! { let mut state = IndexMap::new(); let stream = client .stream_content(&model, &request) - .await? - .map_err(Error::from); + .await + .map_err(|e| StreamError::other(e.to_string()))? + .map_err(StreamError::from); tokio::pin!(stream); while let Some(event) = stream.next().await { - for event in map_response(event?, &mut state)? { + for event in map_response(event?, &mut state, is_structured).map_err(|e| StreamError::other(e.to_string()))? { // Sometimes the API returns an "unexpected tool call" error, if // a previous turn had tools available but those were made // unavailable in follow-up turns. 
This is a known issue: @@ -117,7 +121,7 @@ fn call( let should_retry = matches!(&event, Event::Finished(FinishReason::Other(Value::String(s))) if s == "UNEXPECTED_TOOL_CALL"); if should_retry && tries < 3 { - let mut next_stream = call(client.clone(), request.clone(), model.clone(), tries + 1); + let mut next_stream = call(client.clone(), request.clone(), model.clone(), tries + 1, is_structured); while let Some(item) = next_stream.next().await { yield item; } @@ -131,21 +135,30 @@ fn call( } #[expect(clippy::too_many_lines)] -fn create_request(model: &ModelDetails, query: ChatQuery) -> Result { +fn create_request( + model: &ModelDetails, + query: ChatQuery, +) -> Result<(types::GenerateContentRequest, bool)> { let ChatQuery { thread, tools, tool_choice, - tool_call_strict_mode, } = query; let Thread { system_prompt, - instructions, + sections, attachments, events, } = thread; + // Only use the schema if the very last event is a ChatRequest with one. + let structured_schema = events + .last() + .and_then(|e| e.event.as_chat_request()) + .and_then(|req| req.schema.clone()); + let is_structured = structured_schema.is_some(); + let config = events.config()?; let parameters = &config.assistant.model.parameters; @@ -170,51 +183,81 @@ fn create_request(model: &ModelDetails, query: ChatQuery) -> Result 0) || reasoning.is_some()) - .map(|details| types::ThinkingConfig { - include_thoughts: reasoning.is_some_and(|v| !v.exclude), - thinking_budget: match details { - ReasoningDetails::Leveled { .. } => None, - _ => reasoning.map(|v| { + let supports_thinking = model.reasoning.is_some_and(|r| !r.is_unsupported()); + + // Add thinking config if the model supports it. + let thinking_config = if let Some(details) = model.reasoning.filter(|_| supports_thinking) { + if let Some(config) = reasoning { + // Reasoning is enabled — configure thinking accordingly. 
+ Some(types::ThinkingConfig { + include_thoughts: !config.exclude, + thinking_budget: if details.is_leveled() { + None + } else { // TODO: Once the `gemini` crate supports `-1` for "auto" // thinking, use that here if `effort` is `Auto`. // // See: #[expect(clippy::cast_sign_loss)] - v.effort + let tokens = config + .effort .to_tokens(max_output_tokens.unwrap_or(32_000) as u32) .min(details.max_tokens().unwrap_or(u32::MAX)) - .max(details.min_tokens()) - }), - }, - thinking_level: match details { - ReasoningDetails::Leveled { low, high, .. } => { - let level = reasoning.map(|v| { - v.effort + .max(details.min_tokens()); + Some(tokens) + }, + thinking_level: match details { + ReasoningDetails::Leveled { + xlow, + low, + medium, + high, + xhigh: _, + } => { + let level = config + .effort .abs_to_rel(max_output_tokens.map(i32::cast_unsigned)) - }); - - match level { - Some(ReasoningEffort::Low) if low => Some(types::ThinkingLevel::Low), - Some(ReasoningEffort::High) if high => Some(types::ThinkingLevel::High), - // Any other level is unsupported and treated as - // high (since the documentation specifies this is - // the default). - _ => Some(types::ThinkingLevel::High), + .unwrap_or(ReasoningEffort::Auto); + + match level { + ReasoningEffort::None | ReasoningEffort::Xlow if xlow => { + Some(types::ThinkingLevel::Minimal) + } + ReasoningEffort::Low if low => Some(types::ThinkingLevel::Low), + ReasoningEffort::Medium if medium => Some(types::ThinkingLevel::Medium), + ReasoningEffort::High if high => Some(types::ThinkingLevel::High), + + // Any other level is unsupported and treated as + // high (since the documentation specifies this is + // the default). + _ => Some(types::ThinkingLevel::High), + } } - } - _ => None, - }, - }); + _ => None, + }, + }) + } else if details.min_tokens() > 0 { + // Model requires a minimum thinking budget — can't fully disable. 
+ Some(types::ThinkingConfig { + include_thoughts: false, + thinking_budget: Some(details.min_tokens()), + thinking_level: None, + }) + } else { + // Reasoning is off — explicitly disable thinking. + Some(types::ThinkingConfig { + include_thoughts: false, + thinking_budget: Some(0), + thinking_level: None, + }) + } + } else { + None + }; let parts = { let mut parts = vec![]; @@ -222,14 +265,8 @@ fn create_request(model: &ModelDetails, query: ChatQuery) -> Result>>()? - .join("\n\n"); - - parts.push(types::ContentData::Text(text)); + for section in §ions { + parts.push(types::ContentData::Text(section.render())); } if !attachments.is_empty() { @@ -255,58 +292,186 @@ fn create_request(model: &ModelDetails, query: ChatQuery) -> Result>() }; - Ok(types::GenerateContentRequest { - system_instruction: if parts.is_empty() { - None - } else { - Some(types::Content { parts, role: None }) + // Set structured output config on GenerationConfig when a schema is present. + // We use `_responseJsonSchema` (the JSON Schema field) rather than + // `responseSchema` (the OpenAPI Schema field) so that standard JSON + // Schema properties like `additionalProperties` are accepted. + // + // The schema is transformed to rewrite unsupported properties (e.g. + // `const` → `enum`) so constraints aren't silently dropped. 
+ let (response_mime_type, response_json_schema) = match structured_schema { + Some(schema) => ( + Some("application/json".to_owned()), + Some(Value::Object(transform_schema(schema))), + ), + None => (None, None), + }; + + Ok(( + types::GenerateContentRequest { + system_instruction: if parts.is_empty() { + None + } else { + Some(types::Content { parts, role: None }) + }, + contents: convert_events(events), + tools, + tool_config: Some(tool_config), + generation_config: Some(types::GenerationConfig { + max_output_tokens, + #[expect(clippy::cast_lossless)] + temperature: parameters.temperature.map(|v| v as f64), + #[expect(clippy::cast_lossless)] + top_p: parameters.top_p.map(|v| v as f64), + #[expect(clippy::cast_possible_wrap)] + top_k: parameters.top_k.map(|v| v as i32), + thinking_config, + response_mime_type, + response_json_schema, + ..Default::default() + }), }, - contents: convert_events(events), - tools, - tool_config: Some(tool_config), - generation_config: Some(types::GenerationConfig { - max_output_tokens, - #[expect(clippy::cast_lossless)] - temperature: parameters.temperature.map(|v| v as f64), - #[expect(clippy::cast_lossless)] - top_p: parameters.top_p.map(|v| v as f64), - #[expect(clippy::cast_possible_wrap)] - top_k: parameters.top_k.map(|v| v as i32), - thinking_config, - ..Default::default() - }), - }) + is_structured, + )) } +/// Map a Gemini model to a `ModelDetails`. 
+/// +/// See: +/// See: +#[expect(clippy::too_many_lines)] fn map_model(model: types::Model) -> ModelDetails { - ModelDetails { - id: (PROVIDER, model.base_model_id.as_str()).try_into().unwrap(), - display_name: Some(model.display_name), - context_window: Some(model.input_token_limit), - max_output_tokens: Some(model.output_token_limit), - reasoning: model - .base_model_id - .starts_with("gemini-2.5-pro") - .then_some(ReasoningDetails::budgetted(128, Some(32768))) - .or_else(|| { - model - .base_model_id - .starts_with("gemini-2.5-flash") - .then_some(ReasoningDetails::budgetted(0, Some(24576))) - }) - .or_else(|| { - (model.base_model_id.starts_with("gemini-3-flash") - || model.base_model_id == "gemini-flash-latest") - .then_some(ReasoningDetails::leveled(true, true, true, true, false)) - }) - .or_else(|| { - (model.base_model_id.starts_with("gemini-3-pro") - || model.base_model_id == "gemini-pro-latest") - .then_some(ReasoningDetails::leveled(false, true, false, true, false)) - }), - knowledge_cutoff: None, - deprecated: None, - features: vec![], + let name = model.base_model_id.as_str(); + let display_name = Some(model.display_name); + let context_window = Some(model.input_token_limit); + let max_output_tokens = Some(model.output_token_limit); + let Ok(id) = (PROVIDER, model.base_model_id.as_str()).try_into() else { + return ModelDetails::empty((PROVIDER, "unknown").try_into().unwrap()); + }; + + match name { + "gemini-pro-latest" | "gemini-3.1-pro-preview" | "gemini-3.1-pro-preview-customtools" => { + ModelDetails { + id, + display_name, + context_window, + max_output_tokens, + reasoning: Some(ReasoningDetails::leveled(false, true, true, true, false)), + knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2025, 1, 1).unwrap()), + deprecated: Some(ModelDeprecation::Active), + features: vec![], + } + } + "gemini-3-pro-preview" => ModelDetails { + id, + display_name, + context_window, + max_output_tokens, + reasoning: Some(ReasoningDetails::leveled(false, true, false, 
true, false)), + knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2025, 1, 1).unwrap()), + deprecated: Some(ModelDeprecation::Active), + features: vec![], + }, + "gemini-flash-latest" | "gemini-3-flash-preview" => ModelDetails { + id, + display_name, + context_window, + max_output_tokens, + reasoning: Some(ReasoningDetails::leveled(true, true, true, true, false)), + knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2025, 1, 1).unwrap()), + deprecated: Some(ModelDeprecation::Active), + features: vec![], + }, + "gemini-2.5-flash" => ModelDetails { + id, + display_name, + context_window, + max_output_tokens, + reasoning: Some(ReasoningDetails::budgetted(0, Some(24576))), + knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2025, 1, 1).unwrap()), + deprecated: Some(ModelDeprecation::deprecated( + &"recommended replacement: gemini-3-flash-preview", + Some(NaiveDate::from_ymd_opt(2026, 6, 17).unwrap()), + )), + features: vec![], + }, + "gemini-flash-lite-latest" + | "gemini-2.5-flash-lite" + | "gemini-2.5-flash-lite-preview-09-2025" => ModelDetails { + id, + display_name, + context_window, + max_output_tokens, + reasoning: Some(ReasoningDetails::budgetted(512, Some(24576))), + knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2025, 1, 1).unwrap()), + deprecated: Some(ModelDeprecation::deprecated( + &"recommended replacement: unknown", + Some(NaiveDate::from_ymd_opt(2026, 7, 22).unwrap()), + )), + features: vec![], + }, + "gemini-2.5-pro" => ModelDetails { + id, + display_name, + context_window, + max_output_tokens, + reasoning: Some(ReasoningDetails::budgetted(512, Some(24576))), + knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2025, 1, 1).unwrap()), + deprecated: Some(ModelDeprecation::deprecated( + &"recommended replacement: gemini-3-pro-preview", + Some(NaiveDate::from_ymd_opt(2026, 6, 17).unwrap()), + )), + features: vec![], + }, + "gemini-2.0-flash" | "gemini-2.0-flash-001" => ModelDetails { + id, + display_name, + context_window, + max_output_tokens, + reasoning: 
Some(ReasoningDetails::budgetted(0, Some(24576))), + knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2024, 8, 1).unwrap()), + deprecated: Some(ModelDeprecation::deprecated( + &"recommended replacement: gemini-2.5-flash", + Some(NaiveDate::from_ymd_opt(2026, 6, 1).unwrap()), + )), + features: vec![], + }, + "gemini-2.0-flash-lite" | "gemini-2.0-flash-lite-001" => ModelDetails { + id, + display_name, + context_window, + max_output_tokens, + reasoning: Some(ReasoningDetails::unsupported()), + knowledge_cutoff: Some(NaiveDate::from_ymd_opt(2024, 8, 1).unwrap()), + deprecated: Some(ModelDeprecation::deprecated( + &"recommended replacement: gemini-2.5-flash-lite", + Some(NaiveDate::from_ymd_opt(2026, 6, 1).unwrap()), + )), + features: vec![], + }, + id => { + trace!( + name, + display_name = display_name + .clone() + .unwrap_or_else(|| "".to_owned()), + id, + "Missing model details. Falling back to generic model details." + ); + + ModelDetails { + id: (PROVIDER, model.base_model_id.as_str()) + .try_into() + .unwrap_or((PROVIDER, "unknown").try_into().unwrap()), + display_name, + context_window, + max_output_tokens, + reasoning: None, + knowledge_cutoff: None, + deprecated: None, + features: vec![], + } + } } } @@ -340,6 +505,7 @@ impl CandidateState { fn map_response( response: types::GenerateContentResponse, state: &mut IndexMap, + is_structured: bool, ) -> Result> { debug!("Received response from Google API."); trace!( @@ -350,7 +516,7 @@ fn map_response( response .candidates .into_iter() - .flat_map(|v| map_candidate(v, state)) + .flat_map(|v| map_candidate(v, state, is_structured)) .try_fold(vec![], |mut acc, events| { acc.extend(events); Ok(acc) @@ -360,6 +526,7 @@ fn map_response( fn map_candidate( candidate: types::Candidate, states: &mut IndexMap, + is_structured: bool, ) -> Result> { let types::Candidate { content, @@ -382,6 +549,13 @@ fn map_candidate( .. } = part; + // Google sometimes sends empty text parts (e.g. 
the final chunk of a + // thinking+tool_use response with finishReason: STOP). Skip them before + // mode transition logic to avoid spurious flushes. + if matches!(&data, types::ContentData::Text(text) if text.is_empty()) { + continue; + } + // Determine what "mode" the content is in. let mode = if matches!(data, types::ContentData::FunctionCall(_)) { ContentMode::FunctionCall @@ -414,6 +588,10 @@ fn map_candidate( event: ConversationEvent::now(ChatResponse::reasoning(text)), index, }, + types::ContentData::Text(text) if is_structured => Event::Part { + event: ConversationEvent::now(ChatResponse::structured(Value::String(text))), + index, + }, types::ContentData::Text(text) => Event::Part { event: ConversationEvent::now(ChatResponse::message(text)), index, @@ -488,11 +666,140 @@ impl TryFrom<&GoogleConfig> for Google { } } -fn convert_tool_choice(choice: ToolChoice, strict: bool) -> types::ToolConfig { +/// Transform a JSON schema to conform to Google's structured output constraints. +/// +/// Google's Gemini API supports a subset of JSON Schema. This transformation: +/// - Inlines `$ref` references by replacing them with the referenced `$defs` +/// - Removes `$defs`/`definitions` from the output after inlining +/// - Rewrites `const` to `enum` with a single value (Google ignores `const`) +/// - Adds `propertyOrdering` to objects with multiple properties +/// - Recursively processes `anyOf`, object properties, `additionalProperties`, +/// array `items`, and `prefixItems` +/// +/// Mirrors the logic from Google's Python SDK `process_schema`. +/// +/// See: +fn transform_schema(mut src: Map) -> Map { + // Extract $defs from the root. They are inlined wherever $ref appears + // and discarded from the output. 
+ let defs = src + .remove("$defs") + .or_else(|| src.remove("definitions")) + .and_then(|v| match v { + Value::Object(m) => Some(m), + _ => None, + }) + .unwrap_or_default(); + + process_schema(src, &defs) +} + +/// Core recursive processor for a single schema node. +fn process_schema(mut src: Map, defs: &Map) -> Map { + // Resolve $ref by inlining the referenced definition. + if let Some(Value::String(ref_path)) = src.remove("$ref") + && let Some(resolved) = resolve_ref(&ref_path, defs) + { + let mut merged = resolved; + // Preserve sibling fields (e.g. description, nullable) from the + // referring schema — the definition's own fields take precedence. + for (k, v) in src { + merged.entry(k).or_insert(v); + } + return process_schema(merged, defs); + } + + // Rewrite `const` to `enum` with a single-element array. Google supports + // `enum` for strings and numbers but not `const`. + if let Some(val) = src.remove("const") { + src.insert("enum".into(), Value::Array(vec![val])); + } + + // Handle anyOf: recurse into each variant, then return early (matching the + // Python SDK's behavior). + if let Some(Value::Array(variants)) = src.remove("anyOf") { + src.insert( + "anyOf".into(), + Value::Array( + variants + .into_iter() + .map(|v| resolve_and_process(v, defs)) + .collect(), + ), + ); + return src; + } + + // Type-specific processing. + match src.get("type").and_then(Value::as_str) { + Some("object") => { + if let Some(Value::Object(props)) = src.remove("properties") { + let keys: Vec = props.keys().cloned().collect(); + let processed: Map = props + .into_iter() + .map(|(k, v)| (k, resolve_and_process(v, defs))) + .collect(); + + // Deterministic output ordering for objects with >1 property. 
+ if keys.len() > 1 && !src.contains_key("propertyOrdering") { + src.insert( + "propertyOrdering".into(), + Value::Array(keys.into_iter().map(Value::String).collect()), + ); + } + + src.insert("properties".into(), Value::Object(processed)); + } + + // Process additionalProperties when it's a schema (not a boolean). + if let Some(additional) = src.remove("additionalProperties") { + src.insert("additionalProperties".into(), match additional { + Value::Object(schema) => Value::Object(process_schema(schema, defs)), + other => other, + }); + } + } + Some("array") => { + if let Some(items) = src.remove("items") { + src.insert("items".into(), resolve_and_process(items, defs)); + } + + if let Some(Value::Array(prefixes)) = src.remove("prefixItems") { + src.insert( + "prefixItems".into(), + Value::Array( + prefixes + .into_iter() + .map(|v| resolve_and_process(v, defs)) + .collect(), + ), + ); + } + } + _ => {} + } + + src +} + +/// Resolve any `$ref` inside a value, then recursively process it. +fn resolve_and_process(value: Value, defs: &Map) -> Value { + match value { + Value::Object(map) => Value::Object(process_schema(map, defs)), + other => other, + } +} + +/// Look up a `$ref` path (e.g. `#/$defs/MyType`) in the definitions map. 
+fn resolve_ref(ref_path: &str, defs: &Map<String, Value>) -> Option<Map<String, Value>> { + let name = ref_path.rsplit("defs/").next().unwrap_or(ref_path); + defs.get(name).and_then(Value::as_object).cloned() +} + +fn convert_tool_choice(choice: ToolChoice) -> types::ToolConfig { let (mode, allowed_function_names) = match choice { ToolChoice::None => (types::FunctionCallingMode::None, vec![]), - ToolChoice::Auto if strict => (types::FunctionCallingMode::Validated, vec![]), - ToolChoice::Auto => (types::FunctionCallingMode::Auto, vec![]), + ToolChoice::Auto => (types::FunctionCallingMode::Validated, vec![]), ToolChoice::Required => (types::FunctionCallingMode::Any, vec![]), ToolChoice::Function(name) => (types::FunctionCallingMode::Any, vec![name]), }; @@ -546,12 +853,20 @@ fn convert_events(events: ConversationStream) -> Vec<types::Content> { types::Role::User, types::ContentData::Text(request.content).into(), ), - EventKind::ChatResponse(response) => (types::Role::Model, types::ContentPart { - thought: response.is_reasoning(), - data: types::ContentData::Text(response.into_content()), - metadata: None, - thought_signature: None, - }), + EventKind::ChatResponse(response) => { + let thought = response.is_reasoning(); + let text = match response { + ChatResponse::Message { message } => message, + ChatResponse::Reasoning { reasoning } => reasoning, + ChatResponse::Structured { data } => data.to_string(), + }; + (types::Role::Model, types::ContentPart { + thought, + data: types::ContentData::Text(text), + metadata: None, + thought_signature: None, + }) + } EventKind::ToolCallRequest(request) => (types::Role::Model, types::ContentPart { data: types::ContentData::FunctionCall(types::FunctionCall { name: { @@ -611,37 +926,67 @@ fn convert_events(events: ConversationStream) -> Vec<types::Content> { }) } -#[cfg(test)] -mod tests { - use jp_config::model::parameters::{ - PartialCustomReasoningConfig, PartialReasoningConfig, ReasoningEffort, - }; - use jp_conversation::event::ChatRequest; - use jp_test::function_name; - use test_log::test;
+impl From for StreamError { + fn from(err: GeminiError) -> Self { + match err { + GeminiError::Http(error) => Self::from(error), + GeminiError::EventSource(error) => Self::from(error), + GeminiError::Api(ref value) => { + let msg = err.to_string(); + + // Check for quota/billing exhaustion first. + if looks_like_quota_error(&msg) { + return StreamError::new( + StreamErrorKind::InsufficientQuota, + format!( + "Insufficient API quota. Check your plan and billing details \ + at https://console.cloud.google.com/billing. ({msg})" + ), + ) + .with_source(err); + } - use super::*; - use crate::test::{TestRequest, run_test}; + // Classify by HTTP status code if present in the API error. + let status = value + .get("status") + .or_else(|| value.pointer("/error/code")) + .and_then(serde_json::Value::as_u64); - // TODO: Test specific conditions as detailed in - // : - // - // - parallel function calls - // - dummy thought signatures - // - multi-turn conversations - #[test(tokio::test)] - async fn test_gemini_3_reasoning() -> std::result::Result<(), Box> { - let request = TestRequest::chat(PROVIDER) - .stream(true) - .reasoning(Some(PartialReasoningConfig::Custom( - PartialCustomReasoningConfig { - effort: Some(ReasoningEffort::Low), - exclude: Some(false), - }, - ))) - .model("google/gemini-3-pro-preview".parse().unwrap()) - .event(ChatRequest::from("Test message")); + match status { + Some(429) => StreamError::rate_limit(None).with_source(err), + Some(500 | 502 | 503 | 504) => StreamError::transient(msg).with_source(err), + _ => StreamError::other(msg).with_source(err), + } + } + GeminiError::Json { data, error } => StreamError::other(data).with_source(error), + GeminiError::FunctionExecution(msg) => StreamError::other(msg), + } + } +} - run_test(PROVIDER, function_name!(), Some(request)).await +impl From for Error { + fn from(error: GeminiError) -> Self { + match &error { + GeminiError::Api(api) if api.get("status").is_some_and(|v| v.as_u64() == Some(404)) => { + if let 
Some(model) = api.pointer("/message/error/message").and_then(|v| { + v.as_str().and_then(|s| { + s.contains("Call ListModels").then(|| { + s.split('/') + .nth(1) + .and_then(|v| v.split(' ').next()) + .unwrap_or("unknown") + }) + }) + }) { + return Self::UnknownModel(model.to_owned()); + } + Self::Gemini(error) + } + _ => Self::Gemini(error), + } } } + +#[cfg(test)] +#[path = "google_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/provider/google_tests.rs b/crates/jp_llm/src/provider/google_tests.rs new file mode 100644 index 00000000..c6fd1087 --- /dev/null +++ b/crates/jp_llm/src/provider/google_tests.rs @@ -0,0 +1,501 @@ +use jp_config::model::parameters::{ + PartialCustomReasoningConfig, PartialReasoningConfig, ReasoningEffort, +}; +use jp_conversation::event::ChatRequest; +use jp_test::function_name; +use test_log::test; + +use super::*; +use crate::test::{TestRequest, run_test}; + +// TODO: Test specific conditions as detailed in +// : +// +// - parallel function calls +// - dummy thought signatures +// - multi-turn conversations +#[test(tokio::test)] +async fn test_gemini_3_reasoning() -> std::result::Result<(), Box<dyn std::error::Error>> { + let request = TestRequest::chat(PROVIDER) + .reasoning(Some(PartialReasoningConfig::Custom( + PartialCustomReasoningConfig { + effort: Some(ReasoningEffort::Low), + exclude: Some(false), + }, + ))) + .model("google/gemini-3-pro-preview".parse().unwrap()) + .event(ChatRequest::from("Test message")); + + run_test(PROVIDER, function_name!(), Some(request)).await +} + +mod transform_schema { + use serde_json::{Map, Value, json}; + + use super::transform_schema; + + #[expect(clippy::needless_pass_by_value)] + fn schema(v: Value) -> Map<String, Value> { + v.as_object().unwrap().clone() + } + + #[test] + fn const_rewritten_to_enum() { + let input = schema(json!({ + "type": "string", + "const": "tool_call.my_tool.call_123" + })); + + let out = transform_schema(input); + + assert_eq!(out.get("const"), None); + assert_eq!(out["enum"],
json!(["tool_call.my_tool.call_123"])); + assert_eq!(out["type"], "string"); + } + + #[test] + fn const_rewritten_for_non_string_values() { + let input = schema(json!({ + "type": "integer", + "const": 42 + })); + + let out = transform_schema(input); + + assert_eq!(out.get("const"), None); + assert_eq!(out["enum"], json!([42])); + } + + #[test] + fn nested_const_in_properties_rewritten() { + let input = schema(json!({ + "type": "object", + "properties": { + "inquiry_id": { + "type": "string", + "const": "tool_call.fs_modify_file.call_abc" + }, + "answer": { + "type": "boolean" + } + }, + "required": ["inquiry_id", "answer"] + })); + + let out = transform_schema(input); + + let inquiry_id = out["properties"]["inquiry_id"].as_object().unwrap(); + assert_eq!(inquiry_id.get("const"), None); + assert_eq!( + inquiry_id["enum"], + json!(["tool_call.fs_modify_file.call_abc"]) + ); + assert_eq!(inquiry_id["type"], "string"); + assert_eq!(out["properties"]["answer"]["type"], "boolean"); + } + + #[test] + fn deeply_nested_const_rewritten() { + let input = schema(json!({ + "type": "object", + "properties": { + "outer": { + "type": "object", + "properties": { + "inner": { + "type": "string", + "const": "fixed" + } + } + } + } + })); + + let out = transform_schema(input); + + let inner = &out["properties"]["outer"]["properties"]["inner"]; + assert_eq!(inner.get("const"), None); + assert_eq!(inner["enum"], json!(["fixed"])); + } + + #[test] + fn const_in_array_items_rewritten() { + let input = schema(json!({ + "type": "array", + "items": { + "type": "string", + "const": "only_value" + } + })); + + let out = transform_schema(input); + + let items = out["items"].as_object().unwrap(); + assert_eq!(items.get("const"), None); + assert_eq!(items["enum"], json!(["only_value"])); + } + + #[test] + fn ref_inlined_from_defs() { + let input = schema(json!({ + "type": "array", + "items": { "$ref": "#/$defs/CountryInfo" }, + "$defs": { + "CountryInfo": { + "type": "object", + "properties": { + 
"continent": { "type": "string" }, + "gdp": { "type": "integer" } + }, + "required": ["continent", "gdp"] + } + } + })); + + let out = transform_schema(input); + + // $defs should be removed from the output. + assert!(out.get("$defs").is_none()); + + // items should be the inlined definition. + let items = out["items"].as_object().unwrap(); + assert_eq!(items["type"], "object"); + assert_eq!(items["properties"]["continent"]["type"], "string"); + assert_eq!(items["properties"]["gdp"]["type"], "integer"); + assert_eq!(items["required"], json!(["continent", "gdp"])); + } + + #[test] + fn ref_with_sibling_fields_preserved() { + let input = schema(json!({ + "type": "object", + "properties": { + "person": { + "$ref": "#/$defs/Person", + "description": "The main person" + } + }, + "$defs": { + "Person": { + "type": "object", + "properties": { + "name": { "type": "string" } + } + } + } + })); + + let out = transform_schema(input); + + // Sibling "description" should be preserved alongside inlined def. + let person = out["properties"]["person"].as_object().unwrap(); + assert_eq!(person["type"], "object"); + assert_eq!(person["description"], "The main person"); + assert_eq!(person["properties"]["name"]["type"], "string"); + } + + #[test] + fn ref_in_nested_property() { + let input = schema(json!({ + "type": "object", + "properties": { + "addr": { "$ref": "#/$defs/Address" } + }, + "$defs": { + "Address": { + "type": "object", + "properties": { + "city": { "type": "string" }, + "zip": { "type": "string" } + } + } + } + })); + + let out = transform_schema(input); + + let addr = out["properties"]["addr"].as_object().unwrap(); + assert_eq!(addr["type"], "object"); + assert_eq!(addr["properties"]["city"]["type"], "string"); + assert_eq!(addr["properties"]["zip"]["type"], "string"); + // Inlined object gets propertyOrdering. 
+ assert_eq!(addr["propertyOrdering"], json!(["city", "zip"])); + } + + #[test] + fn definitions_also_removed() { + let input = schema(json!({ + "type": "object", + "properties": { + "x": { "$ref": "#/$defs/X" } + }, + "definitions": { + "X": { "type": "string" } + } + })); + + let out = transform_schema(input); + + assert!(out.get("definitions").is_none()); + assert_eq!(out["properties"]["x"]["type"], "string"); + } + + #[test] + fn property_ordering_added_for_multiple_properties() { + let input = schema(json!({ + "type": "object", + "properties": { + "first": { "type": "string" }, + "second": { "type": "integer" }, + "third": { "type": "boolean" } + } + })); + + let out = transform_schema(input); + + assert_eq!(out["propertyOrdering"], json!(["first", "second", "third"])); + } + + #[test] + fn property_ordering_not_added_for_single_property() { + let input = schema(json!({ + "type": "object", + "properties": { + "only": { "type": "string" } + } + })); + + let out = transform_schema(input); + + assert!(out.get("propertyOrdering").is_none()); + } + + #[test] + fn property_ordering_preserved_if_already_set() { + let input = schema(json!({ + "type": "object", + "properties": { + "a": { "type": "string" }, + "b": { "type": "string" } + }, + "propertyOrdering": ["b", "a"] + })); + + let out = transform_schema(input); + + // Existing ordering should not be overwritten. + assert_eq!(out["propertyOrdering"], json!(["b", "a"])); + } + + #[test] + fn anyof_variants_processed() { + let input = schema(json!({ + "anyOf": [ + { "type": "string", "const": "fixed" }, + { "type": "integer" } + ] + })); + + let out = transform_schema(input); + + let variants = out["anyOf"].as_array().unwrap(); + assert_eq!(variants.len(), 2); + // const should be rewritten inside the variant. 
+ assert_eq!(variants[0]["enum"], json!(["fixed"])); + assert!(variants[0].get("const").is_none()); + assert_eq!(variants[1]["type"], "integer"); + } + + #[test] + fn anyof_with_ref_resolved() { + let input = schema(json!({ + "anyOf": [ + { "$ref": "#/$defs/Str" }, + { "type": "integer" } + ], + "$defs": { + "Str": { "type": "string" } + } + })); + + let out = transform_schema(input); + + let variants = out["anyOf"].as_array().unwrap(); + assert_eq!(variants[0]["type"], "string"); + assert_eq!(variants[1]["type"], "integer"); + } + + #[test] + fn additional_properties_bool_preserved() { + let input = schema(json!({ + "type": "object", + "properties": { + "name": { "type": "string" } + }, + "additionalProperties": false + })); + + let out = transform_schema(input); + + assert_eq!(out["additionalProperties"], json!(false)); + } + + #[test] + fn additional_properties_schema_processed() { + let input = schema(json!({ + "type": "object", + "properties": { + "name": { "type": "string" } + }, + "additionalProperties": { + "type": "string", + "const": "extra" + } + })); + + let out = transform_schema(input); + + let additional = out["additionalProperties"].as_object().unwrap(); + assert_eq!(additional.get("const"), None); + assert_eq!(additional["enum"], json!(["extra"])); + } + + #[test] + fn prefix_items_processed() { + let input = schema(json!({ + "type": "array", + "prefixItems": [ + { "type": "string", "const": "header" }, + { "type": "integer" } + ] + })); + + let out = transform_schema(input); + + let prefixes = out["prefixItems"].as_array().unwrap(); + assert_eq!(prefixes[0]["enum"], json!(["header"])); + assert!(prefixes[0].get("const").is_none()); + assert_eq!(prefixes[1]["type"], "integer"); + } + + #[test] + fn enum_preserved_unchanged() { + let input = schema(json!({ + "type": "string", + "enum": ["A", "B", "C"] + })); + + let out = transform_schema(input); + + assert_eq!(out["enum"], json!(["A", "B", "C"])); + } + + #[test] + fn 
supported_properties_preserved() { + let input = schema(json!({ + "type": "integer", + "minimum": 1, + "maximum": 10, + "description": "A number" + })); + + let out = transform_schema(input); + + assert_eq!(out["type"], "integer"); + assert_eq!(out["minimum"], 1); + assert_eq!(out["maximum"], 10); + assert_eq!(out["description"], "A number"); + } + + /// The actual inquiry schema should transform correctly for Google. + #[test] + fn inquiry_schema_transforms_correctly() { + let input = schema(json!({ + "type": "object", + "properties": { + "inquiry_id": { + "type": "string", + "const": "tool_call.fs_modify_file.call_a3b7c9d1" + }, + "answer": { + "type": "boolean" + } + }, + "required": ["inquiry_id", "answer"], + "additionalProperties": false + })); + + let out = transform_schema(input); + + assert_eq!( + Value::Object(out), + json!({ + "type": "object", + "required": ["inquiry_id", "answer"], + "additionalProperties": false, + "propertyOrdering": ["inquiry_id", "answer"], + "properties": { + "inquiry_id": { + "type": "string", + "enum": ["tool_call.fs_modify_file.call_a3b7c9d1"] + }, + "answer": { + "type": "boolean" + } + } + }) + ); + } + + /// The `title_schema` should pass through mostly unchanged. + /// It has a single property so no `propertyOrdering` is added. + #[test] + fn title_schema_passes_through() { + let input = crate::title::title_schema(3); + let out = transform_schema(input.clone()); + + assert_eq!(out, input); + } + + /// Matches the example from the Python SDK docstring. 
+ #[test] + fn sdk_docstring_example() { + let input = schema(json!({ + "items": { "$ref": "#/$defs/CountryInfo" }, + "title": "Placeholder", + "type": "array", + "$defs": { + "CountryInfo": { + "properties": { + "continent": { "title": "Continent", "type": "string" }, + "gdp": { "title": "Gdp", "type": "integer" } + }, + "required": ["continent", "gdp"], + "title": "CountryInfo", + "type": "object" + } + } + })); + + let out = transform_schema(input); + + // $defs removed, $ref inlined, propertyOrdering added. + assert_eq!( + Value::Object(out), + json!({ + "title": "Placeholder", + "type": "array", + "items": { + "properties": { + "continent": { "title": "Continent", "type": "string" }, + "gdp": { "title": "Gdp", "type": "integer" } + }, + "required": ["continent", "gdp"], + "title": "CountryInfo", + "type": "object", + "propertyOrdering": ["continent", "gdp"] + } + }) + ); + } +} diff --git a/crates/jp_llm/src/provider/llamacpp.rs b/crates/jp_llm/src/provider/llamacpp.rs index e4d1f68b..fdde4a4e 100644 --- a/crates/jp_llm/src/provider/llamacpp.rs +++ b/crates/jp_llm/src/provider/llamacpp.rs @@ -1,7 +1,7 @@ use std::mem; use async_trait::async_trait; -use futures::{FutureExt as _, StreamExt as _, future, stream}; +use futures::{StreamExt as _, future, stream}; use jp_config::{ assistant::tool_choice::ToolChoice, model::id::{ModelIdConfig, Name, ProviderId}, @@ -11,26 +11,19 @@ use jp_conversation::{ ConversationEvent, ConversationStream, event::{ChatResponse, EventKind, ToolCallResponse}, }; -use openai::{ - Credentials, - chat::{ - self, ChatCompletionBuilder, ChatCompletionChoiceDelta, ChatCompletionDelta, - ChatCompletionMessage, ChatCompletionMessageDelta, ChatCompletionMessageRole, - ToolCallFunction, structured_output::ToolCallFunctionDefinition, - }, -}; -use serde_json::Value; -use tokio_stream::wrappers::ReceiverStream; -use tracing::{debug, trace}; +use reqwest_eventsource::{Event as SseEvent, EventSource}; +use serde::Deserialize; +use 
serde_json::{Value, json}; +use tracing::{debug, trace, warn}; use super::{ EventStream, ModelDetails, - openai::{ModelListResponse, ModelResponse}, + openai::{ModelListResponse, ModelResponse, parameters_with_strict_mode}, }; use crate::{ - error::{Error, Result}, + error::{Error, StreamError}, event::{Event, FinishReason}, - provider::{Provider, openai::parameters_with_strict_mode}, + provider::Provider, query::ChatQuery, stream::aggregator::{ reasoning::ReasoningExtractor, tool_call_request::ToolCallRequestAggregator, @@ -43,46 +36,12 @@ static PROVIDER: ProviderId = ProviderId::Llamacpp; #[derive(Debug, Clone)] pub struct Llamacpp { reqwest_client: reqwest::Client, - credentials: Credentials, base_url: String, } -impl Llamacpp { - /// Build request for Llama.cpp API. - fn build_request( - &self, - model: &ModelDetails, - query: ChatQuery, - ) -> Result { - let slug = model.id.name.to_string(); - let ChatQuery { - thread, - tools, - tool_choice, - tool_call_strict_mode, - } = query; - - let messages = thread.into_messages(to_system_messages, convert_events)?; - let tools = convert_tools(tools, tool_call_strict_mode, &tool_choice); - let tool_choice = convert_tool_choice(&tool_choice); - - trace!( - slug, - messages_size = messages.len(), - tools_size = tools.len(), - "Built Llamacpp request." 
- ); - - Ok(ChatCompletionDelta::builder(&slug, messages) - .credentials(self.credentials.clone()) - .tools(tools) - .tool_choice(tool_choice)) - } -} - #[async_trait] impl Provider for Llamacpp { - async fn model_details(&self, name: &Name) -> Result { + async fn model_details(&self, name: &Name) -> Result { let id: ModelIdConfig = (PROVIDER, name.as_ref()).try_into()?; Ok(self @@ -93,7 +52,7 @@ impl Provider for Llamacpp { .unwrap_or(ModelDetails::empty(id))) } - async fn models(&self) -> Result> { + async fn models(&self) -> Result, Error> { self.reqwest_client .get(format!("{}/v1/models", self.base_url)) .send() @@ -104,113 +63,417 @@ impl Provider for Llamacpp { .data .iter() .map(map_model) - .collect::>() + .collect::>() } async fn chat_completion_stream( &self, model: &ModelDetails, query: ChatQuery, - ) -> Result { + ) -> Result { debug!( model = %model.id.name, "Starting Llamacpp chat completion stream." ); - let mut extractor = ReasoningExtractor::default(); - let mut agg = ToolCallRequestAggregator::default(); - let request = self.build_request(model, query)?; + let (body, is_structured) = build_request(model, query)?; - trace!(?request, "Sending request to Llamacpp."); + trace!( + body = serde_json::to_string(&body).unwrap_or_default(), + "Sending request to Llamacpp." 
+ ); + + let request = self + .reqwest_client + .post(format!("{}/v1/chat/completions", self.base_url)) + .header("content-type", "application/json") + .json(&body); + + let es = EventSource::new(request).map_err(|e| Error::InvalidResponse(e.to_string()))?; + + let mut state = StreamState { + extractor: ReasoningExtractor::default(), + tool_calls: ToolCallRequestAggregator::default(), + reasoning_flushed: false, + finish_reason: None, + is_structured, + }; - Ok(request - .create_stream() - .await - .map(ReceiverStream::new) - .expect("Should not fail to clone") - .flat_map(|v| stream::iter(v.choices)) - .flat_map(move |v| stream::iter(map_event(v, &mut extractor, &mut agg))) - .chain(future::ok(Event::Finished(FinishReason::Completed)).into_stream()) + Ok(es + // EventSource yields Err on close; stop the stream. + .take_while(|event| future::ready(event.is_ok())) + .flat_map(move |event| stream::iter(handle_sse_event(event, &mut state))) .boxed()) } } -fn map_event( - event: ChatCompletionChoiceDelta, - extractor: &mut ReasoningExtractor, - agg: &mut ToolCallRequestAggregator, -) -> Vec> { - let ChatCompletionChoiceDelta { - index, - finish_reason, - delta: - ChatCompletionMessageDelta { - content, - tool_calls, - .. - }, - } = event; - - #[allow(clippy::cast_possible_truncation)] - let index = index as usize; - let mut events = vec![]; +/// Mutable state carried across SSE events in a single stream. +struct StreamState { + extractor: ReasoningExtractor, + tool_calls: ToolCallRequestAggregator, + reasoning_flushed: bool, + /// Captured from `finish_reason` in the last choice delta. Emitted as + /// `Event::Finished` when the `[DONE]` sentinel arrives. + finish_reason: Option, + is_structured: bool, +} - for chat::ToolCallDelta { id, function, .. 
} in tool_calls.into_iter().flatten() { - let (name, arguments) = match function { - Some(chat::ToolCallFunction { name, arguments }) => (Some(name), Some(arguments)), - None => (None, None), - }; +/// Process a single SSE event into zero or more provider-agnostic events. +#[expect(clippy::too_many_lines)] +fn handle_sse_event( + event: Result, + state: &mut StreamState, +) -> Vec> { + match event { + Ok(SseEvent::Open) => vec![], + Ok(SseEvent::Message(msg)) => { + if msg.data == "[DONE]" { + // Finalize the reasoning extractor on stream end. + state.extractor.finalize(); + let mut events: Vec> = + drain_extractor(&mut state.extractor, state.is_structured) + .into_iter() + .map(Ok) + .collect(); + + // Flush reasoning if we never did. + if !state.reasoning_flushed { + events.push(Ok(Event::flush(0))); + state.reasoning_flushed = true; + } + + // Flush message content. + events.push(Ok(Event::flush(1))); + + events.push(Ok(Event::Finished( + state + .finish_reason + .take() + .unwrap_or(FinishReason::Completed), + ))); + return events; + } - agg.add_chunk(index, id, name, arguments.as_deref()); - } + let chunk: StreamChunk = match serde_json::from_str(&msg.data) { + Ok(c) => c, + Err(error) => { + warn!( + error = error.to_string(), + data = &msg.data, + "Failed to parse Llamacpp chunk." + ); + + return vec![]; + } + }; + + let mut events = Vec::new(); + + for choice in &chunk.choices { + let delta = &choice.delta; + + // Reasoning via `reasoning_content` (deepseek / deepseek-legacy formats) + if let Some(reasoning) = &delta.reasoning_content + && !reasoning.is_empty() + { + events.push(Ok(Event::Part { + index: 0, + event: ConversationEvent::now(ChatResponse::reasoning(reasoning.clone())), + })); + } + + // Content + // + // If reasoning_content was present, the server already + // separated reasoning from content (deepseek / + // deepseek-legacy). Otherwise, content may contain tags + // (none format) and needs the extractor. 
+ if let Some(content) = &delta.content + && !content.is_empty() + { + // Server separated reasoning; content is pure text. + if delta.reasoning_content.is_some() { + flush_reasoning_if_needed(&mut events, &mut state.reasoning_flushed); + + let response = if state.is_structured { + ChatResponse::structured(Value::String(content.clone())) + } else { + ChatResponse::message(content.clone()) + }; + events.push(Ok(Event::Part { + index: 1, + event: ConversationEvent::now(response), + })); + } else { + // Might contain tags — feed through extractor. + state.extractor.handle(content); + events.extend( + drain_extractor(&mut state.extractor, state.is_structured) + .into_iter() + .map(Ok), + ); + } + } + + // Tool calls + if let Some(tool_calls) = &delta.tool_calls { + flush_reasoning_if_needed(&mut events, &mut state.reasoning_flushed); + + for tc in tool_calls { + let index = tc.index as usize + 2; + let name = tc.function.as_ref().and_then(|f| f.name.clone()); + let arguments = tc.function.as_ref().and_then(|f| f.arguments.as_deref()); + state + .tool_calls + .add_chunk(index, tc.id.clone(), name, arguments); + } + } + + // Finish reason + if let Some(reason) = &choice.finish_reason { + state.extractor.finalize(); + events.extend( + drain_extractor(&mut state.extractor, state.is_structured) + .into_iter() + .map(Ok), + ); + + if matches!(reason.as_str(), "tool_calls" | "stop") { + events.extend(state.tool_calls.finalize_all().into_iter().flat_map( + |(index, result)| { + vec![ + result + .map(|call| Event::Part { + index, + event: ConversationEvent::now(call), + }) + .map_err(|e| StreamError::other(e.to_string())), + Ok(Event::flush(index)), + ] + }, + )); + } + + if !state.reasoning_flushed { + events.push(Ok(Event::flush(0))); + state.reasoning_flushed = true; + } + events.push(Ok(Event::flush(1))); + + // Per the OpenAI spec. 
+ match reason.as_str() { + "length" => state.finish_reason = Some(FinishReason::MaxTokens), + "stop" => state.finish_reason = Some(FinishReason::Completed), + _ => {} + } + } + } - if ["function_call", "tool_calls"].contains(&finish_reason.as_deref().unwrap_or_default()) { - match agg.finalize(index) { - Ok(request) => events.extend(vec![ - Ok(Event::Part { - index, - event: ConversationEvent::now(request), - }), - Ok(Event::flush(index)), - ]), - Err(error) => events.push(Err(error.into())), + events } + Err(e) => vec![Err(StreamError::from(e))], } +} - if let Some(content) = content { - extractor.handle(&content); - } - - if finish_reason.is_some() { - extractor.finalize(); +/// Push a reasoning flush event if we haven't already. +fn flush_reasoning_if_needed(events: &mut Vec>, flushed: &mut bool) { + if !*flushed { + events.push(Ok(Event::flush(0))); + *flushed = true; } - - events.extend(fetch_content(extractor, index).into_iter().map(Ok)); - events } -fn fetch_content(extractor: &mut ReasoningExtractor, index: usize) -> Vec { +/// Drain accumulated content from the `ReasoningExtractor` into events. +/// +/// Index convention matches Ollama: 0 = reasoning, 1 = message content. 
+fn drain_extractor(extractor: &mut ReasoningExtractor, is_structured: bool) -> Vec { let mut events = Vec::new(); + if !extractor.reasoning.is_empty() { let reasoning = mem::take(&mut extractor.reasoning); events.push(Event::Part { - index, + index: 0, event: ConversationEvent::now(ChatResponse::reasoning(reasoning)), }); } if !extractor.other.is_empty() { let content = mem::take(&mut extractor.other); + let response = if is_structured { + ChatResponse::structured(Value::String(content)) + } else { + ChatResponse::message(content) + }; events.push(Event::Part { - index, - event: ConversationEvent::now(ChatResponse::message(content)), + index: 1, + event: ConversationEvent::now(response), }); } events } -fn map_model(model: &ModelResponse) -> Result { +/// Build the JSON request body for the llama.cpp `/v1/chat/completions` +/// endpoint. +/// +/// Returns `(body, is_structured)`. +fn build_request(model: &ModelDetails, query: ChatQuery) -> Result<(Value, bool), Error> { + let ChatQuery { + thread, + tools, + tool_choice, + } = query; + + let structured_schema = thread + .events + .last() + .and_then(|e| e.event.as_chat_request()) + .and_then(|req| req.schema.clone()); + + let is_structured = structured_schema.is_some(); + let slug = model.id.name.to_string(); + + let messages = thread.into_messages(to_system_messages, convert_events)?; + let converted_tools = convert_tools(tools, &tool_choice); + let tool_choice_val = convert_tool_choice(&tool_choice); + + trace!( + slug, + messages_size = messages.len(), + tools_size = converted_tools.len(), + "Built Llamacpp request." 
+ ); + + let mut body = json!({ + "model": slug, + "messages": messages, + "stream": true, + }); + + if !converted_tools.is_empty() { + body["tools"] = json!(converted_tools); + body["tool_choice"] = json!(tool_choice_val); + } + + if let Some(schema) = structured_schema { + body["response_format"] = json!({ + "type": "json_schema", + "json_schema": { + "name": "structured_output", + "schema": schema, + "strict": true, + }, + }); + } + + Ok((body, is_structured)) +} + +/// Convert system prompt parts into a list of JSON message values. +fn to_system_messages(parts: Vec) -> impl Iterator { + parts + .into_iter() + .map(|content| json!({ "role": "system", "content": content })) +} + +/// Convert a conversation event stream into a list of JSON message values. +fn convert_events(events: ConversationStream) -> Vec { + events + .into_iter() + .filter_map(|event| match event.into_kind() { + EventKind::ChatRequest(request) => { + Some(json!({ "role": "user", "content": request.content })) + } + EventKind::ChatResponse(response) => match response { + ChatResponse::Message { message } => { + Some(json!({ "role": "assistant", "content": message })) + } + ChatResponse::Reasoning { reasoning } => { + // Wrap reasoning in tags so the model can pick up + // its own chain-of-thought on the next turn. 
+ Some(json!({ + "role": "assistant", + "content": format!("<think>\n{reasoning}\n</think>"), + })) + } + ChatResponse::Structured { data } => { + Some(json!({ "role": "assistant", "content": data.to_string() })) + } + }, + EventKind::ToolCallRequest(request) => Some(json!({ + "role": "assistant", + "tool_calls": [{ + "id": request.id, + "type": "function", + "function": { + "name": request.name, + "arguments": Value::Object(request.arguments).to_string(), + }, + }], + })), + EventKind::ToolCallResponse(ToolCallResponse { id, result }) => Some(json!({ + "role": "tool", + "tool_call_id": id, + "content": match result { + Ok(content) | Err(content) => content, + }, + })), + _ => None, + }) + .fold(vec![], |mut messages: Vec<Value>, message| { + // Merge consecutive assistant messages that carry tool_calls + // (same folding logic as the old openai-crate implementation). + if message.get("tool_calls").is_some() + && let Some(last) = messages.last_mut() + && last.get("tool_calls").is_some() + && let (Some(existing), Some(new)) = ( + last["tool_calls"].as_array_mut(), + message["tool_calls"].as_array(), + ) + { + existing.extend(new.iter().cloned()); + return messages; + } + messages.push(message); + messages + }) +} + +/// Convert tool definitions to the OpenAI-compatible JSON format. +/// +/// If [`ToolChoice::Function`] is set, only include the named tool. llama.cpp +/// doesn't support calling a specific tool by name, but it supports `required` +/// mode, so we limit the tool list instead.
+fn convert_tools(tools: Vec<jp_mcp::Tool>, tool_choice: &ToolChoice) -> Vec<Value> { + tools + .into_iter() + .map(|tool| { + json!({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description.unwrap_or_default(), + "parameters": parameters_with_strict_mode(tool.parameters, true), + "strict": true, + }, + }) + }) + .filter(|tool| match tool_choice { + ToolChoice::Function(req) => tool["function"]["name"].as_str() == Some(req.as_str()), + _ => true, + }) + .collect() +} + +fn convert_tool_choice(choice: &ToolChoice) -> &str { + match choice { + ToolChoice::Auto => "auto", + ToolChoice::None => "none", + ToolChoice::Required | ToolChoice::Function(_) => "required", + } +} + +fn map_model(model: &ModelResponse) -> Result<ModelDetails, Error> { + Ok(ModelDetails { + id: ( + PROVIDER, @@ -233,189 +496,65 @@ fn map_model(model: &ModelResponse) -> Result { impl TryFrom<&LlamacppConfig> for Llamacpp { type Error = Error; - fn try_from(config: &LlamacppConfig) -> Result<Self> { + fn try_from(config: &LlamacppConfig) -> Result<Self, Error> { let reqwest_client = reqwest::Client::builder().build()?; let base_url = config.base_url.clone(); - let credentials = Credentials::new("", &base_url); Ok(Llamacpp { reqwest_client, - credentials, base_url, }) } } -/// Convert a list of [`jp_mcp::Tool`] to a list of [`chat::ChatCompletionTool`]. -/// -/// Additionally, if [`ToolChoice::Function`] is provided, only return the -/// tool(s) that matches the expected name. This is done because Llama.cpp does -/// not support calling a specific tool by name, but it *does* support -/// "required"/forced tool calling, which we can turn into a request to call a -/// specific tool, by limiting the list of tools to the ones that we want to be -/// called.
-fn convert_tools( - tools: Vec, - strict: bool, - tool_choice: &ToolChoice, -) -> Vec { - tools - .into_iter() - .map(|tool| chat::ChatCompletionTool::Function { - function: ToolCallFunctionDefinition { - parameters: Some(Value::Object(parameters_with_strict_mode( - tool.parameters, - strict, - ))), - name: tool.name, - description: tool.description.or(Some(String::new())), - strict: Some(strict), - }, - }) - .filter(|tool| match tool_choice { - ToolChoice::Function(req) => matches!( - tool, - chat::ChatCompletionTool::Function { - function: ToolCallFunctionDefinition { name, .. } - } if name == req - ), - _ => true, - }) - .collect::>() +// These mirror the llama.cpp server's `common_chat_msg_diff_to_json_oaicompat` +// output. The critical addition over the `openai` crate is the +// `reasoning_content` field, which carries extracted reasoning for the +// `--reasoning-format deepseek` (default) and `deepseek-legacy` modes. + +#[derive(Debug, Deserialize)] +struct StreamChunk { + #[serde(default)] + choices: Vec, } -fn convert_tool_choice(choice: &ToolChoice) -> chat::ToolChoice { - match choice { - ToolChoice::Auto => chat::ToolChoice::mode(chat::ToolChoiceMode::Auto), - ToolChoice::None => chat::ToolChoice::mode(chat::ToolChoiceMode::None), - ToolChoice::Required | ToolChoice::Function(_) => { - chat::ToolChoice::mode(chat::ToolChoiceMode::Required) - } - } +#[derive(Debug, Deserialize)] +struct StreamChoice { + delta: StreamDelta, + #[serde(default)] + finish_reason: Option, } -/// Convert a list of content into system messages. -fn to_system_messages(parts: Vec) -> impl Iterator { - parts.into_iter().map(|content| ChatCompletionMessage { - role: ChatCompletionMessageRole::System, - content: Some(content), - ..Default::default() - }) +#[derive(Debug, Deserialize, Default)] +struct StreamDelta { + #[serde(default)] + content: Option, + /// Reasoning content extracted by the server (deepseek / deepseek-legacy). 
+ /// This is a non-standard `DeepSeek` extension that llama.cpp also uses. + #[serde(default)] + reasoning_content: Option, + #[serde(default)] + tool_calls: Option>, } -fn convert_events(events: ConversationStream) -> Vec { - events - .into_iter() - .filter_map(|event| match event.into_kind() { - EventKind::ChatRequest(request) => Some(ChatCompletionMessage { - role: ChatCompletionMessageRole::User, - content: Some(request.content), - ..Default::default() - }), - EventKind::ChatResponse(response) => Some(ChatCompletionMessage { - role: ChatCompletionMessageRole::Assistant, - content: Some(response.into_content()), - ..Default::default() - }), - EventKind::ToolCallRequest(request) => Some(ChatCompletionMessage { - role: ChatCompletionMessageRole::Assistant, - tool_calls: Some(vec![chat::ToolCall { - id: request.id.clone(), - r#type: chat::FunctionType::Function, - function: ToolCallFunction { - name: request.name.clone(), - arguments: Value::Object(request.arguments.clone()).to_string(), - }, - }]), - ..Default::default() - }), - EventKind::ToolCallResponse(ToolCallResponse { id, result }) => { - Some(ChatCompletionMessage { - role: ChatCompletionMessageRole::Tool, - tool_call_id: Some(id), - content: Some(match result { - Ok(content) | Err(content) => content, - }), - ..Default::default() - }) - } - _ => None, - }) - .fold(vec![], |mut messages, message| match messages.last_mut() { - Some(last) if message.tool_calls.is_some() && last.tool_calls.is_some() => { - last.tool_calls - .get_or_insert_default() - .extend(message.tool_calls.unwrap_or_default()); - messages - } - _ => { - messages.push(message); - messages - } - }) +#[derive(Debug, Deserialize)] +struct ToolCallDelta { + #[serde(default)] + index: u32, + #[serde(default)] + id: Option, + #[serde(default)] + function: Option, +} + +#[derive(Debug, Deserialize)] +struct FunctionDelta { + #[serde(default)] + name: Option, + #[serde(default)] + arguments: Option, } -// #[cfg(test)] -// mod tests { -// use 
jp_config::providers::llm::LlmProviderConfig; -// use jp_test::{Result, fn_name, mock::Vcr}; -// use test_log::test; -// -// use super::*; -// -// fn vcr() -> Vcr { -// Vcr::new("http://127.0.0.1:8080", env!("CARGO_MANIFEST_DIR")) -// } -// -// #[test(tokio::test)] -// async fn test_llamacpp_models() -> Result { -// let mut config = LlmProviderConfig::default().llamacpp; -// let vcr = vcr(); -// vcr.cassette( -// fn_name!(), -// |rule| { -// rule.filter(|when| { -// when.any_request(); -// }); -// }, -// |_, url| async move { -// config.base_url = url; -// Llamacpp::try_from(&config).unwrap().models().await -// }, -// ) -// .await -// } -// -// #[test(tokio::test)] -// async fn test_llamacpp_chat_completion() -> Result { -// let mut config = LlmProviderConfig::default().llamacpp; -// let model_id = "llamacpp/llama3:latest".parse().unwrap(); -// let model = ModelDetails::empty(model_id); -// let query = ChatQuery { -// thread: Thread { -// events: ConversationStream::default().with_chat_request("Test message"), -// ..Default::default() -// }, -// ..Default::default() -// }; -// -// let vcr = vcr(); -// vcr.cassette( -// fn_name!(), -// |rule| { -// rule.filter(|when| { -// when.any_request(); -// }); -// }, -// |_, url| async move { -// config.base_url = url; -// -// Llamacpp::try_from(&config) -// .unwrap() -// .chat_completion(&model, query) -// .await -// }, -// ) -// .await -// } -// } +#[cfg(test)] +#[path = "llamacpp_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/provider/llamacpp_tests.rs b/crates/jp_llm/src/provider/llamacpp_tests.rs new file mode 100644 index 00000000..97518ac6 --- /dev/null +++ b/crates/jp_llm/src/provider/llamacpp_tests.rs @@ -0,0 +1,181 @@ +use super::*; + +#[test] +fn parse_deepseek_format_reasoning_in_dedicated_field() { + // The default `--reasoning-format deepseek`: reasoning arrives in + // `reasoning_content`, regular content in `content`. 
+ let json = r#"{ + "choices": [{ + "delta": { + "reasoning_content": "Let me think step by step...", + "content": null + }, + "index": 0, + "finish_reason": null + }] + }"#; + + let chunk: StreamChunk = serde_json::from_str(json).unwrap(); + assert_eq!(chunk.choices.len(), 1); + + let delta = &chunk.choices[0].delta; + assert_eq!( + delta.reasoning_content.as_deref(), + Some("Let me think step by step...") + ); + assert!(delta.content.is_none()); +} + +#[test] +fn parse_deepseek_format_content_after_reasoning() { + let json = r#"{ + "choices": [{ + "delta": { + "reasoning_content": null, + "content": "The answer is 42." + }, + "index": 0, + "finish_reason": null + }] + }"#; + + let chunk: StreamChunk = serde_json::from_str(json).unwrap(); + let delta = &chunk.choices[0].delta; + assert!(delta.reasoning_content.is_none()); + assert_eq!(delta.content.as_deref(), Some("The answer is 42.")); +} + +#[test] +fn parse_none_format_think_tags_in_content() { + // `--reasoning-format none`: everything in `content`, with `<think>` tags. + // No `reasoning_content` field at all. + let json = r#"{ + "choices": [{ + "delta": { + "content": "<think>\nLet me reason...\n</think>\n\nThe answer."
+ }, + "index": 0, + "finish_reason": null + }] + }"#; + + let chunk: StreamChunk = serde_json::from_str(json).unwrap(); + let delta = &chunk.choices[0].delta; + assert!(delta.reasoning_content.is_none()); + assert!(delta.content.as_ref().unwrap().contains("<think>")); +} + +#[test] +fn parse_finish_reason() { + let json = r#"{ + "choices": [{ + "delta": {}, + "index": 0, + "finish_reason": "stop" + }] + }"#; + + let chunk: StreamChunk = serde_json::from_str(json).unwrap(); + assert_eq!(chunk.choices[0].finish_reason.as_deref(), Some("stop")); +} + +#[test] +fn parse_tool_call_delta() { + let json = r#"{ + "choices": [{ + "delta": { + "tool_calls": [{ + "index": 0, + "id": "call_abc123", + "function": { + "name": "get_weather", + "arguments": "{\"city\":" + } + }] + }, + "index": 0, + "finish_reason": null + }] + }"#; + + let chunk: StreamChunk = serde_json::from_str(json).unwrap(); + let tool_calls = chunk.choices[0].delta.tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].id.as_deref(), Some("call_abc123")); + let func = tool_calls[0].function.as_ref().unwrap(); + assert_eq!(func.name.as_deref(), Some("get_weather")); + assert_eq!(func.arguments.as_deref(), Some("{\"city\":")); +} + +#[test] +fn parse_empty_choices() { + // Some servers send empty choices arrays (e.g. usage-only chunks). + let json = r#"{"choices": []}"#; + let chunk: StreamChunk = serde_json::from_str(json).unwrap(); + assert!(chunk.choices.is_empty()); +} + +#[test] +fn parse_missing_optional_fields() { + // Minimal delta with only content.
+ let json = r#"{"choices": [{"delta": {"content": "hi"}}]}"#; + let chunk: StreamChunk = serde_json::from_str(json).unwrap(); + let delta = &chunk.choices[0].delta; + assert_eq!(delta.content.as_deref(), Some("hi")); + assert!(delta.reasoning_content.is_none()); + assert!(delta.tool_calls.is_none()); + assert!(chunk.choices[0].finish_reason.is_none()); +} + +#[test] +fn convert_events_merges_consecutive_tool_calls() { + use jp_conversation::event::ToolCallRequest; + + let mut events = ConversationStream::new_test(); + events.push(ConversationEvent::now(ToolCallRequest { + id: "call_1".into(), + name: "tool_a".into(), + arguments: serde_json::Map::new(), + })); + events.push(ConversationEvent::now(ToolCallRequest { + id: "call_2".into(), + name: "tool_b".into(), + arguments: serde_json::Map::new(), + })); + + let messages = convert_events(events); + + // Should be merged into a single assistant message with 2 tool_calls. + assert_eq!(messages.len(), 1); + let tool_calls = messages[0]["tool_calls"].as_array().unwrap(); + assert_eq!(tool_calls.len(), 2); + assert_eq!(tool_calls[0]["function"]["name"], "tool_a"); + assert_eq!(tool_calls[1]["function"]["name"], "tool_b"); +} + +#[test] +fn convert_events_wraps_reasoning_in_think_tags() { + let mut events = ConversationStream::new_test(); + events.push(ConversationEvent::now(ChatResponse::reasoning( + "step 1: think hard", + ))); + + let messages = convert_events(events); + + assert_eq!(messages.len(), 1); + let content = messages[0]["content"].as_str().unwrap(); + assert!(content.starts_with("<think>")); + assert!(content.contains("step 1: think hard")); + assert!(content.ends_with("</think>")); +} + +#[test] +fn convert_tool_choice_values() { + assert_eq!(convert_tool_choice(&ToolChoice::Auto), "auto"); + assert_eq!(convert_tool_choice(&ToolChoice::None), "none"); + assert_eq!(convert_tool_choice(&ToolChoice::Required), "required"); + assert_eq!( + convert_tool_choice(&ToolChoice::Function("my_fn".into())), + "required" + ); +} diff 
--git a/crates/jp_llm/src/provider/mock.rs b/crates/jp_llm/src/provider/mock.rs new file mode 100644 index 00000000..bfa36b9b --- /dev/null +++ b/crates/jp_llm/src/provider/mock.rs @@ -0,0 +1,203 @@ +//! Mock provider for testing LLM interactions without real API calls. +//! +//! This module provides a configurable mock implementation of the [`Provider`] +//! trait, useful for: +//! +//! - Integration tests that need to simulate LLM responses +//! - Testing interrupt/signal handling during streaming +//! - Verifying persistence logic without network calls +//! +//! # Example +//! +//! ```ignore +//! use jp_llm::provider::mock::MockProvider; +//! use jp_llm::event::{Event, FinishReason}; +//! use jp_conversation::event::{ConversationEvent, ChatResponse}; +//! +//! // Create a provider that returns a simple message +//! let provider = MockProvider::with_message("Hello, world!"); +//! +//! // Or create one with custom events for more complex scenarios +//! let provider = MockProvider::new(vec![ +//! Event::Part { +//! index: 0, +//! event: ConversationEvent::now(ChatResponse::message("Hello")), +//! }, +//! Event::flush(0), +//! Event::Finished(FinishReason::Completed), +//! ]); +//! ``` + +use async_trait::async_trait; +use futures::stream; +use jp_config::model::id::{ModelIdConfig, Name, ProviderId}; +use jp_conversation::event::{ChatResponse, ToolCallRequest}; +use serde_json::{Map, Value}; + +use super::Provider; +use crate::{ + error::Result, + event::{Event, FinishReason}, + model::ModelDetails, + query::ChatQuery, + stream::EventStream, +}; + +/// A mock LLM provider for testing. +/// +/// Returns predetermined events from [`chat_completion_stream`], allowing tests +/// to simulate various LLM behaviors without making real API calls. +/// +/// [`chat_completion_stream`]: Provider::chat_completion_stream +#[derive(Debug, Clone)] +pub struct MockProvider { + /// Events to return from the stream. + events: Vec, + + /// Model details to return. 
+ model: ModelDetails, +} + +impl MockProvider { + /// Create a new mock provider with the given events. + /// + /// The events will be returned in order from [`chat_completion_stream`]. + /// + /// [`chat_completion_stream`]: Provider::chat_completion_stream + #[must_use] + pub fn new(events: Vec<Event>) -> Self { + Self { + events, + model: Self::default_model(), + } + } + + /// Create a mock provider that streams a simple message response. + /// + /// Useful for basic tests that just need some content to be streamed. + #[must_use] + pub fn with_message(content: &str) -> Self { + Self::new(vec![ + Event::Part { + index: 0, + event: ChatResponse::message(content).into(), + }, + Event::flush(0), + Event::Finished(FinishReason::Completed), + ]) + } + + /// Create a mock provider that streams reasoning followed by a message. + #[must_use] + pub fn with_reasoning_and_message(reasoning: &str, message: &str) -> Self { + Self::new(vec![ + Event::Part { + index: 0, + event: ChatResponse::reasoning(reasoning).into(), + }, + Event::flush(0), + Event::Part { + index: 1, + event: ChatResponse::message(message).into(), + }, + Event::flush(1), + Event::Finished(FinishReason::Completed), + ]) + } + + /// Create a mock provider that streams content in multiple chunks. + /// + /// Useful for testing streaming behavior and partial content handling. + #[must_use] + pub fn with_chunked_message(chunks: &[&str]) -> Self { + let mut events = Vec::with_capacity(chunks.len() + 2); + + for &chunk in chunks { + events.push(Event::Part { + index: 0, + event: ChatResponse::message(chunk).into(), + }); + } + + events.push(Event::flush(0)); + events.push(Event::Finished(FinishReason::Completed)); + + Self::new(events) + } + + /// Create a mock provider that requests a tool call.
+ #[must_use] + pub fn with_tool_call( + tool_id: impl Into<String>, + tool_name: impl Into<String>, + arguments: Map<String, Value>, + ) -> Self { + Self::new(vec![ + Event::Part { + index: 0, + event: ToolCallRequest { + id: tool_id.into(), + name: tool_name.into(), + arguments, + } + .into(), + }, + Event::flush(0), + Event::Finished(FinishReason::Completed), + ]) + } + + /// Set custom model details for this provider. + #[must_use] + pub fn with_model(mut self, model: ModelDetails) -> Self { + self.model = model; + self + } + + /// Set the model name. + #[must_use] + pub fn with_model_name(mut self, name: impl Into<String>) -> Self { + self.model.id = Self::make_model_id(name); + self + } + + fn default_model() -> ModelDetails { + ModelDetails::empty(Self::make_model_id("mock-model")) + } + + fn make_model_id(name: impl Into<String>) -> ModelIdConfig { + ModelIdConfig { + provider: ProviderId::Test, + name: name.into().parse().expect("valid model name"), + } + } +} + +#[async_trait] +impl Provider for MockProvider { + async fn model_details(&self, name: &Name) -> Result<ModelDetails> { + let mut model = self.model.clone(); + model.id = ModelIdConfig { + provider: ProviderId::Test, + name: name.clone(), + }; + Ok(model) + } + + async fn models(&self) -> Result<Vec<ModelDetails>> { + Ok(vec![self.model.clone()]) + } + + async fn chat_completion_stream( + &self, + _model: &ModelDetails, + _query: ChatQuery, + ) -> Result<EventStream> { + let events = self.events.clone(); + Ok(Box::pin(stream::iter(events.into_iter().map(Ok)))) + } +} + +#[cfg(test)] +#[path = "mock_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/provider/mock_tests.rs b/crates/jp_llm/src/provider/mock_tests.rs new file mode 100644 index 00000000..176b70b9 --- /dev/null +++ b/crates/jp_llm/src/provider/mock_tests.rs @@ -0,0 +1,124 @@ +use futures::StreamExt; +use jp_conversation::{ConversationStream, thread::Thread}; + +use super::*; + +fn empty_query() -> ChatQuery { + ChatQuery { + thread: Thread { + system_prompt: None, + sections: vec![], + attachments: vec![], + events:
ConversationStream::new_test(), + }, + tools: vec![], + tool_choice: jp_config::assistant::tool_choice::ToolChoice::Auto, + } +} + +fn test_name(s: &str) -> Name { + s.parse().expect("valid test name") +} + +#[tokio::test] +async fn test_with_message() { + let provider = MockProvider::with_message("Hello, world!"); + let model = provider.model_details(&test_name("test")).await.unwrap(); + + let mut stream = provider + .chat_completion_stream(&model, empty_query()) + .await + .unwrap(); + + // First event: Part with message + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Part { index: 0, .. })); + + // Second event: Flush + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Flush { index: 0, .. })); + + // Third event: Finished + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Finished(FinishReason::Completed))); + + // Stream should be exhausted + assert!(stream.next().await.is_none()); +} + +#[tokio::test] +async fn test_with_chunked_message() { + let provider = MockProvider::with_chunked_message(&["Hello, ", "world", "!"]); + let model = provider.model_details(&test_name("test")).await.unwrap(); + + let mut stream = provider + .chat_completion_stream(&model, empty_query()) + .await + .unwrap(); + + // Three Part events + for _ in 0..3 { + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Part { index: 0, .. })); + } + + // Flush + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Flush { index: 0, .. 
})); + + // Finished + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Finished(FinishReason::Completed))); +} + +#[tokio::test] +async fn test_with_reasoning_and_message() { + let provider = MockProvider::with_reasoning_and_message("thinking...", "done"); + let model = provider.model_details(&test_name("test")).await.unwrap(); + + let mut stream = provider + .chat_completion_stream(&model, empty_query()) + .await + .unwrap(); + + // Reasoning part at index 0 + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Part { index: 0, .. })); + + // Flush index 0 + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Flush { index: 0, .. })); + + // Message part at index 1 + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Part { index: 1, .. })); + + // Flush index 1 + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Flush { index: 1, .. 
})); + + // Finished + let event = stream.next().await.unwrap().unwrap(); + assert!(matches!(event, Event::Finished(FinishReason::Completed))); +} + +#[tokio::test] +async fn test_model_details() { + let provider = MockProvider::with_message("test"); + let model = provider + .model_details(&test_name("custom-name")) + .await + .unwrap(); + + assert_eq!(model.id.name.as_ref(), "custom-name"); + assert_eq!(model.id.provider, ProviderId::Test); +} + +#[tokio::test] +async fn test_models_list() { + let provider = MockProvider::with_message("test").with_model_name("my-model"); + let models = provider.models().await.unwrap(); + + assert_eq!(models.len(), 1); + assert_eq!(models[0].id.name.as_ref(), "my-model"); +} diff --git a/crates/jp_llm/src/provider/ollama.rs b/crates/jp_llm/src/provider/ollama.rs index dfc7583c..d458e70b 100644 --- a/crates/jp_llm/src/provider/ollama.rs +++ b/crates/jp_llm/src/provider/ollama.rs @@ -1,7 +1,7 @@ -use std::{mem, str::FromStr as _}; +use std::str::FromStr as _; use async_trait::async_trait; -use futures::{FutureExt as _, StreamExt as _, future, stream}; +use futures::{FutureExt as _, StreamExt as _, TryStreamExt as _, future, stream}; use jp_config::{ assistant::tool_choice::ToolChoice, model::{ @@ -16,9 +16,10 @@ use jp_conversation::{ }; use ollama_rs::{ Ollama as Client, + error::OllamaError, generation::{ chat::{ChatMessage, ChatMessageResponse, MessageRole, request::ChatMessageRequest}, - parameters::{KeepAlive, TimeUnit}, + parameters::{FormatType, JsonStructure, KeepAlive, TimeUnit}, tools::{ToolCall, ToolCallFunction, ToolFunctionInfo, ToolInfo, ToolType}, }, models::{LocalModel, ModelOptions}, @@ -29,10 +30,9 @@ use url::Url; use super::{EventStream, ModelDetails, Provider}; use crate::{ - error::{Error, Result}, + error::{Error, Result, StreamError}, event::{Event, FinishReason}, query::ChatQuery, - stream::aggregator::reasoning::ReasoningExtractor, tool::ToolDefinition, }; @@ -72,8 +72,7 @@ impl Provider for Ollama { 
"Starting Ollama chat completion stream." ); - let mut extractor = ReasoningExtractor::default(); - let request = create_request(model, query)?; + let (request, is_structured) = create_request(model, query)?; trace!( request = serde_json::to_string(&request).unwrap_or_default(), @@ -84,11 +83,19 @@ impl Provider for Ollama { .client .send_chat_messages_stream(request) .await? - .filter_map(|v| future::ready(v.ok())) - .map(move |v| stream::iter(map_event(v, &mut extractor))) - .flatten() - .chain(future::ready(Event::Finished(FinishReason::Completed)).into_stream()) - .map(Ok) + .map(|v| v.map_err(|()| StreamError::other("Ollama stream error"))) + .map_ok({ + let mut reasoning_flushed = false; + move |v| { + stream::iter( + map_event(v, is_structured, &mut reasoning_flushed) + .into_iter() + .map(Ok), + ) + } + }) + .try_flatten() + .chain(future::ready(Ok(Event::Finished(FinishReason::Completed))).into_stream()) .boxed()) } } @@ -106,74 +113,111 @@ fn map_model(model: LocalModel) -> Result { }) } -fn map_event(event: ChatMessageResponse, extractor: &mut ReasoningExtractor) -> Vec { +/// Map an Ollama streaming chunk into provider-agnostic events. +/// +/// Index convention: 0 = reasoning, 1 = message content, 2+ = tool calls. +/// +/// Ollama guarantees that thinking tokens arrive before content/tool call +/// tokens. We exploit this by flushing the reasoning stream (index 0) as soon +/// as the first content or tool call chunk appears, ensuring the reasoning +/// event precedes content and tool call events in the history. +fn map_event( + event: ChatMessageResponse, + is_structured: bool, + reasoning_flushed: &mut bool, +) -> Vec { let ChatMessageResponse { message, done, .. } = event; - let mut events = fetch_content(extractor, done); + trace!( + content = message.content, + thinking = message.thinking, + tool_calls = message.tool_calls.len(), + done, + "Ollama stream chunk." 
+ ); - for ( - index, - ToolCall { - function: ToolCallFunction { name, arguments }, - }, - ) in message.tool_calls.into_iter().enumerate() + let mut events = Vec::new(); + + if let Some(thinking) = message.thinking + && !thinking.is_empty() { - events.extend(vec![ - Event::Part { - // These events don't have any index assigned, but we use `0` - // and `1` for regular chat messages and reasoning, and `2` and - // up for tool calls. - index: index + 2, - event: ConversationEvent::now(ToolCallRequest { - id: String::new(), - name, - arguments: match arguments { - Value::Object(map) => map, - v => Map::from_iter([("input".into(), v)]), - }, - }), - }, - Event::flush(0), - ]); + events.push(Event::Part { + index: 0, + event: ConversationEvent::now(ChatResponse::reasoning(thinking)), + }); } - events -} + let has_content = !message.content.is_empty(); + let has_tool_calls = !message.tool_calls.is_empty(); -fn fetch_content(extractor: &mut ReasoningExtractor, done: bool) -> Vec { - let mut events = Vec::new(); + // Flush reasoning before emitting content or tool calls so the reasoning + // event always precedes them in the conversation history. 
+ if !*reasoning_flushed && (has_content || has_tool_calls) { + events.push(Event::flush(0)); + *reasoning_flushed = true; + } - if !extractor.reasoning.is_empty() { - let reasoning = mem::take(&mut extractor.reasoning); + if has_content { + let response = if is_structured { + ChatResponse::structured(Value::String(message.content)) + } else { + ChatResponse::message(message.content) + }; events.push(Event::Part { - index: 0, - event: ConversationEvent::now(ChatResponse::reasoning(reasoning)), + index: 1, + event: ConversationEvent::now(response), }); } - if !extractor.other.is_empty() { - let content = mem::take(&mut extractor.other); + for ( + index, + ToolCall { + function: ToolCallFunction { name, arguments }, + }, + ) in message.tool_calls.into_iter().enumerate() + { + let index = index + 2; events.push(Event::Part { - index: 1, - event: ConversationEvent::now(ChatResponse::message(content)), + index, + event: ConversationEvent::now(ToolCallRequest { + id: String::new(), + name, + arguments: match arguments { + Value::Object(map) => map, + v => Map::from_iter([("input".into(), v)]), + }, + }), }); + events.push(Event::flush(index)); } if done { - events.extend(vec![Event::flush(0), Event::flush(1)]); + if !*reasoning_flushed { + events.push(Event::flush(0)); + } + events.push(Event::flush(1)); } events } -fn create_request(model: &ModelDetails, query: ChatQuery) -> Result { +fn create_request(model: &ModelDetails, query: ChatQuery) -> Result<(ChatMessageRequest, bool)> { let ChatQuery { thread, tools, tool_choice, - tool_call_strict_mode, } = query; + let structured_schema = thread + .events + .last() + .and_then(|e| e.event.as_chat_request()) + .and_then(|req| req.schema.clone()) + .map(|schema| JsonStructure::new_for_schema(schema.into())) + .map(|schema| FormatType::StructuredJson(Box::new(schema))); + + let is_structured = structured_schema.is_some(); + let config = thread.events.config()?; let parameters = &config.assistant.model.parameters; @@ -185,7 
+229,7 @@ fn create_request(model: &ModelDetails, query: ChatQuery) -> Result Result for Ollama { @@ -265,7 +313,7 @@ impl TryFrom<&OllamaConfig> for Ollama { } } -fn convert_tools(tools: Vec, _strict: bool) -> Result> { +fn convert_tools(tools: Vec) -> Result> { tools .into_iter() .map(|tool| { @@ -334,6 +382,7 @@ fn convert_events(events: ConversationStream) -> Vec { images: None, thinking: Some(reasoning), }), + ChatResponse::Structured { data } => Some(ChatMessage::assistant(data.to_string())), }, EventKind::ToolCallRequest(request) => Some(ChatMessage { role: MessageRole::Assistant, @@ -371,204 +420,20 @@ fn convert_events(events: ConversationStream) -> Vec { }) } -// #[cfg(test)] -// mod tests { -// use jp_config::providers::llm::LlmProviderConfig; -// use jp_conversation::event::ChatResponse; -// use jp_test::{Result, fn_name, mock::Vcr}; -// use test_log::test; -// -// use super::*; -// use crate::structured; -// -// fn vcr(url: &str) -> Vcr { -// Vcr::new(url, env!("CARGO_MANIFEST_DIR")) -// } -// -// #[test(tokio::test)] -// async fn test_ollama_models() -> Result { -// let mut config = LlmProviderConfig::default().ollama; -// let vcr = vcr(&config.base_url); -// vcr.cassette( -// fn_name!(), -// |rule| { -// rule.filter(|when| { -// when.any_request(); -// }); -// }, -// |_, url| async move { -// config.base_url = url; -// -// Ollama::try_from(&config) -// .unwrap() -// .models() -// .await -// .map(|mut v| { -// v.truncate(2); -// v -// }) -// }, -// ) -// .await -// } -// -// #[test(tokio::test)] -// async fn test_ollama_chat_completion() -> Result { -// let mut config = LlmProviderConfig::default().ollama; -// let model_id = "ollama/llama3:latest".parse().unwrap(); -// let model = ModelDetails::empty(model_id); -// let query = ChatQuery { -// thread: Thread { -// events: ConversationStream::default().with_chat_request("Test message"), -// ..Default::default() -// }, -// ..Default::default() -// }; -// -// let vcr = vcr(&config.base_url); -// 
vcr.cassette( -// fn_name!(), -// |rule| { -// rule.filter(|when| { -// when.any_request(); -// }); -// }, -// |_, url| async move { -// config.base_url = url; -// -// Ollama::try_from(&config) -// .unwrap() -// .chat_completion(&model, query) -// .await -// }, -// ) -// .await -// } -// -// #[test(tokio::test)] -// async fn test_ollama_chat_completion_stream() -> Result { -// let mut config = LlmProviderConfig::default().ollama; -// let model_id = "ollama/llama3:latest".parse().unwrap(); -// let model = ModelDetails::empty(model_id); -// let query = ChatQuery { -// thread: Thread { -// events: ConversationStream::default().with_chat_request("Test message"), -// ..Default::default() -// }, -// ..Default::default() -// }; -// -// let vcr = vcr(&config.base_url); -// vcr.cassette( -// fn_name!(), -// |rule| { -// rule.filter(|when| { -// when.any_request(); -// }); -// }, -// |_, url| async move { -// config.base_url = url; -// -// Ollama::try_from(&config) -// .unwrap() -// .chat_completion_stream(&model, query) -// .await -// .unwrap() -// .collect::>() -// .await -// }, -// ) -// .await -// } -// -// #[test(tokio::test)] -// async fn test_ollama_structured_completion() -> Result { -// let mut config = LlmProviderConfig::default().ollama; -// let model_id = "ollama/llama3.1:8b".parse().unwrap(); -// let model = ModelDetails::empty(model_id); -// let history = ConversationStream::default() -// .with_chat_request("Test message") -// .with_chat_response(ChatResponse::reasoning("Test response")); -// -// let vcr = vcr(&config.base_url); -// vcr.cassette( -// fn_name!(), -// |rule| { -// rule.filter(|when| { -// when.any_request(); -// }); -// }, -// |_, url| async move { -// config.base_url = url; -// let query = structured::titles::titles(3, history, &[]).unwrap(); -// -// Ollama::try_from(&config) -// .unwrap() -// .structured_completion(&model, query) -// .await -// }, -// ) -// .await -// } -// -// mod chunk_parser { -// use test_log::test; -// -// use super::*; 
-// -// #[test] -// fn test_no_think_tag_at_all() { -// let mut parser = ReasoningExtractor::default(); -// parser.handle("some other text"); -// parser.finalize(); -// assert_eq!(parser.other, "some other text"); -// assert_eq!(parser.reasoning, ""); -// } -// -// #[test] -// fn test_standard_case_with_newline() { -// let mut parser = ReasoningExtractor::default(); -// parser.handle("prefix\n\nthoughts\n\nsuffix"); -// parser.finalize(); -// assert_eq!(parser.reasoning, "thoughts\n"); -// assert_eq!(parser.other, "prefix\nsuffix"); -// } -// -// #[test] -// fn test_suffix_only() { -// let mut parser = ReasoningExtractor::default(); -// parser.handle("\nthoughts\n\n\nsuffix text here"); -// parser.finalize(); -// assert_eq!(parser.reasoning, "thoughts\n"); -// assert_eq!(parser.other, "\nsuffix text here"); -// } -// -// #[test] -// fn test_ends_with_closing_tag_no_newline() { -// let mut parser = ReasoningExtractor::default(); -// parser.handle("\nfinal thoughts\n"); -// parser.handle(""); -// parser.finalize(); -// assert_eq!(parser.reasoning, "final thoughts\n"); -// assert_eq!(parser.other, ""); -// } -// -// #[test] -// fn test_less_than_symbol_in_reasoning_content_is_not_stripped() { -// let mut parser = ReasoningExtractor::default(); -// parser.handle("\na < b is a true statement\n"); -// parser.finalize(); -// // The last '<' is part of "", so "a < b is a true statement" is kept. -// assert_eq!(parser.reasoning, "a < b is a true statement\n"); -// } -// -// #[test] -// fn test_less_than_symbol_not_part_of_tag_is_kept() { -// let mut parser = ReasoningExtractor::default(); -// parser.handle("\nhere is a random < symbol"); -// parser.finalize(); -// // The final '<' is not a prefix of '', so it's kept. 
-// assert_eq!(parser.reasoning, "here is a random < symbol"); -// } -// } -// } +impl From for StreamError { + fn from(err: OllamaError) -> Self { + use ollama_rs::error::InternalOllamaError; + + match err { + OllamaError::ReqwestError(error) => Self::from(error), + OllamaError::ToolCallError(error) => { + StreamError::other("tool-call error").with_source(error) + } + OllamaError::JsonError(error) => StreamError::other("json error").with_source(error), + OllamaError::InternalError(InternalOllamaError { message }) => { + StreamError::other(message) + } + OllamaError::Other(message) => StreamError::other(message), + } + } +} diff --git a/crates/jp_llm/src/provider/openai.rs b/crates/jp_llm/src/provider/openai.rs index 5b242af8..3ac2eb06 100644 --- a/crates/jp_llm/src/provider/openai.rs +++ b/crates/jp_llm/src/provider/openai.rs @@ -6,7 +6,7 @@ use futures::{FutureExt as _, StreamExt as _, TryStreamExt as _, future, stream} use indexmap::{IndexMap, IndexSet}; use jp_config::{ assistant::tool_choice::ToolChoice, - conversation::tool::{OneOrManyTypes, ToolParameterConfig, item::ToolParameterItemConfig}, + conversation::tool::{OneOrManyTypes, ToolParameterConfig}, model::{ id::{Name, ProviderId}, parameters::{CustomReasoningConfig, ReasoningEffort}, @@ -18,7 +18,7 @@ use jp_conversation::{ event::{ChatResponse, ConversationEvent, EventKind, ToolCallRequest, ToolCallResponse}, }; use openai_responses::{ - Client, CreateError, StreamError, + Client, CreateError, StreamError as OpenaiStreamError, types::{self, Include, Request, SummaryConfig}, }; use reqwest::header::{self, HeaderMap, HeaderValue}; @@ -28,7 +28,7 @@ use tracing::{trace, warn}; use super::{EventStream, ModelDetails, Provider}; use crate::{ - error::{Error, Result}, + error::{Error, Result, StreamError, StreamErrorKind}, event::{Event, FinishReason}, model::{ModelDeprecation, ReasoningDetails}, query::ChatQuery, @@ -80,15 +80,15 @@ impl Provider for Openai { model: &ModelDetails, query: ChatQuery, ) -> 
Result { - let request = create_request(model, query)?; + let (request, is_structured, reasoning_enabled) = create_request(model, query)?; Ok(self .client .stream(request) .or_else(map_error) - .map_ok(|v| stream::iter(map_event(v))) + .map_ok(move |v| stream::iter(map_event(v, is_structured, reasoning_enabled))) .try_flatten() - .chain(future::ok(Event::Finished(FinishReason::Completed)).into_stream()) + .chain(future::ready(Ok(Event::Finished(FinishReason::Completed))).into_stream()) .boxed()) } } @@ -112,41 +112,84 @@ pub(crate) struct ModelResponse { } /// Create a request for the given model and query details. -fn create_request(model: &ModelDetails, query: ChatQuery) -> Result { +/// +/// Returns `(request, is_structured, reasoning_enabled)`. +fn create_request(model: &ModelDetails, query: ChatQuery) -> Result<(Request, bool, bool)> { let ChatQuery { thread, tools, tool_choice, - tool_call_strict_mode, } = query; + // Only use the schema if the very last event is a ChatRequest with one. + // Transform the schema for OpenAI's strict structured output mode + // before passing it to the request. 
+ let text = thread + .events + .last() + .and_then(|e| e.event.as_chat_request()) + .and_then(|req| req.schema.clone()) + .map(|schema| types::TextConfig { + format: types::TextFormat::JsonSchema { + schema: Value::Object(transform_schema(schema)), + description: "Structured output".to_owned(), + name: "structured_output".to_owned(), + strict: Some(true), + }, + }); + + let is_structured = text.is_some(); let parameters = thread.events.config()?.assistant.model.parameters; - let reasoning = model - .custom_reasoning_config(parameters.reasoning) - .map(|r| convert_reasoning(r, model.max_output_tokens)); let supports_reasoning = model .reasoning .is_some_and(|v| !matches!(v, ReasoningDetails::Unsupported)); + let reasoning = match model.custom_reasoning_config(parameters.reasoning) { + Some(r) => Some(convert_reasoning(r, model.max_output_tokens)), + // Explicitly disable reasoning for models that support it when the + // user has turned it off. Sending `null` lets the model use its + // default (which may include reasoning). + // + // For leveled models, use their lowest supported effort. For all + // others (budgetted), fall back to `minimal` which is universally + // supported across OpenAI reasoning models. 
+ None if supports_reasoning => { + let effort = model + .reasoning + .and_then(|r| r.lowest_effort()) + .unwrap_or(ReasoningEffort::Xlow); + Some(convert_reasoning( + CustomReasoningConfig { + effort, + exclude: true, + }, + model.max_output_tokens, + )) + } + None => None, + }; + let reasoning_enabled = model + .custom_reasoning_config(parameters.reasoning) + .is_some(); let messages = thread.into_messages(to_system_messages, convert_events(supports_reasoning))?; - let request = Request { model: types::Model::Other(model.id.name.to_string()), input: types::Input::List(messages), - include: supports_reasoning.then_some(vec![Include::ReasoningEncryptedContent]), + include: reasoning_enabled.then_some(vec![Include::ReasoningEncryptedContent]), store: Some(false), tool_choice: Some(convert_tool_choice(tool_choice)), - tools: Some(convert_tools(tools, tool_call_strict_mode)), + tools: Some(convert_tools(tools)), temperature: parameters.temperature, reasoning, max_output_tokens: parameters.max_tokens.map(Into::into), truncation: Some(types::Truncation::Auto), top_p: parameters.top_p, + text, ..Default::default() }; trace!(?request, "Sending request to OpenAI."); - Ok(request) + Ok((request, is_structured, reasoning_enabled)) } #[expect(clippy::too_many_lines)] @@ -447,27 +490,26 @@ fn map_model(model: ModelResponse) -> Result { Ok(details) } -/// Convert a [`StreamError`] into an [`Error`]. +/// Convert an OpenAI [`OpenaiStreamError`] into a [`StreamError`]. /// -/// This needs an async function because we want to get the response text from -/// the body as contextual information. -async fn map_error(error: StreamError) -> Result { +/// This needs an async function because we read the response body for context. +/// Headers are extracted *before* consuming the body. 
+async fn map_error(error: OpenaiStreamError) -> std::result::Result { Err(match error { - StreamError::Parsing(error) => error.into(), - StreamError::Stream(error) => match error { - reqwest_eventsource::Error::InvalidStatusCode(status_code, response) => { - Error::OpenaiStatusCode { - status_code, - response: response.text().await.unwrap_or_default(), - } - } - _ => Error::OpenaiEvent(Box::new(error)), - }, + OpenaiStreamError::Stream(error) => StreamError::from(error), + OpenaiStreamError::Parsing(error) => { + StreamError::other(error.to_string()).with_source(error) + } }) } /// Map an Openai [`types::Event`] into one or more [`Event`]s. -fn map_event(event: types::Event) -> Vec> { +#[expect(clippy::too_many_lines)] +fn map_event( + event: types::Event, + is_structured: bool, + reasoning_enabled: bool, +) -> Vec> { use types::Event::*; #[expect(clippy::cast_possible_truncation)] @@ -481,11 +523,26 @@ fn map_event(event: types::Event) -> Vec> { output_index, item: types::OutputItem::Message(_), } => vec![Ok(Event::Part { - event: ConversationEvent::now(ChatResponse::message(String::new())), + event: ConversationEvent::now(if is_structured { + ChatResponse::structured(Value::String(String::new())) + } else { + ChatResponse::message(String::new()) + }), index: output_index as usize, })], - // See the previous `OutputItemAdded` case for details. + // Skip all reasoning events when reasoning is disabled. The model + // may still return minimal reasoning output at `effort: "minimal"`. + OutputItemAdded { + item: types::OutputItem::Reasoning(_), + .. + } + | ReasoningSummaryTextDelta { .. } + | OutputItemDone { + item: types::OutputItem::Reasoning(_), + .. + } if !reasoning_enabled => vec![], + OutputItemAdded { output_index, item: types::OutputItem::Reasoning(_), @@ -498,10 +555,17 @@ fn map_event(event: types::Event) -> Vec> { delta, output_index, .. 
- } => vec![Ok(Event::Part { - event: ConversationEvent::now(ChatResponse::message(delta)), - index: output_index as usize, - })], + } => { + let response = if is_structured { + ChatResponse::structured(Value::String(delta)) + } else { + ChatResponse::message(delta) + }; + vec![Ok(Event::Part { + event: ConversationEvent::now(response), + index: output_index as usize, + })] + } ReasoningSummaryTextDelta { delta, @@ -513,6 +577,8 @@ fn map_event(event: types::Event) -> Vec> { })], OutputItemDone { item, output_index } => { + let index = output_index as usize; + let mut events = vec![]; let metadata = match &item { types::OutputItem::FunctionCall(_) => IndexMap::new(), types::OutputItem::Message(v) => { @@ -536,14 +602,15 @@ fn map_event(event: types::Event) -> Vec> { | types::OutputItem::ComputerToolCall(_) => return vec![], }; - match item { - types::OutputItem::FunctionCall(types::FunctionCall { - name, - arguments, - call_id, - .. - }) => vec![Ok(Event::Part { - index: output_index as usize, + if let types::OutputItem::FunctionCall(types::FunctionCall { + name, + arguments, + call_id, + .. + }) = item + { + events.push(Ok(Event::Part { + index, event: ConversationEvent::now(ToolCallRequest { id: call_id, name, @@ -555,28 +622,42 @@ fn map_event(event: types::Event) -> Vec> { map }, }), - })], - _ => vec![Ok(Event::flush_with_metadata( - output_index as usize, - metadata, - ))], + })); } + + events.push(Ok(Event::flush_with_metadata(index, metadata))); + events } - Error { - code, - message, - param, - } => vec![Err(types::Error { - r#type: "stream_error".to_owned(), - code, - message, - param, - } - .into())], + Error { error } => vec![Err(classify_stream_error(error))], _ => vec![], } } +/// Classify an OpenAI streaming error event into a [`StreamError`]. +/// +/// Maps well-known error types (quota, rate-limit, auth, server errors) +/// to the appropriate [`StreamErrorKind`] so the retry and display layers +/// can handle them correctly. 
+fn classify_stream_error(error: types::response::Error) -> StreamError { + match error.r#type.as_str() { + "insufficient_quota" => StreamError::new( + StreamErrorKind::InsufficientQuota, + format!( + "Insufficient API quota. Check your plan and billing details \ + at https://platform.openai.com/settings/organization/billing. \ + ({})", + error.message + ), + ), + "rate_limit_exceeded" => StreamError::rate_limit(None), + "server_error" | "api_error" => StreamError::transient(error.message), + _ => StreamError::other(format!( + "OpenAI error: type={}, code={:?}, message={}, param={:?}", + error.r#type, error.code, error.message, error.param + )), + } +} + impl TryFrom<&OpenaiConfig> for Openai { type Error = Error; @@ -605,6 +686,145 @@ impl TryFrom<&OpenaiConfig> for Openai { } } +/// Transform a JSON schema for OpenAI's strict structured output mode. +/// +/// OpenAI's structured outputs require: +/// - `additionalProperties: false` on all objects +/// - All properties listed in `required` +/// - `allOf` is not supported and must be flattened +/// +/// Additionally handles: +/// - Unraveling `$ref` that has sibling properties (OpenAI doesn't support +/// `$ref` alongside other keys) +/// - Recursively processing `$defs`/`definitions`, `properties`, `items`, and +/// `anyOf` +/// - Stripping `null` defaults +/// +/// Unlike Google, OpenAI supports `$ref`/`$defs` and `const` natively, so those +/// are left in place when standalone. +/// +/// See: +fn transform_schema(src: Map) -> Map { + let root = Value::Object(src.clone()); + process_schema(src, &root) +} + +/// Core recursive processor for a single schema node. +fn process_schema(mut src: Map, root: &Value) -> Map { + // Recursively process $defs/definitions in place. 
+ for key in ["$defs", "definitions"] { + if let Some(Value::Object(defs)) = src.remove(key) { + let processed: Map = defs + .into_iter() + .map(|(k, v)| (k, resolve_and_process(v, root))) + .collect(); + src.insert(key.into(), Value::Object(processed)); + } + } + + // Force `additionalProperties: false` on all objects. + // The docs require this for strict mode. + if src.get("type").and_then(Value::as_str) == Some("object") { + src.insert("additionalProperties".into(), Value::Bool(false)); + } + + // Force all properties into `required` (strict mode requirement). + if let Some(Value::Object(props)) = src.get("properties") { + let keys: Vec = props.keys().map(|k| Value::String(k.clone())).collect(); + src.insert("required".into(), Value::Array(keys)); + } + + // Recursively process object properties. + if let Some(Value::Object(props)) = src.remove("properties") { + let processed: Map = props + .into_iter() + .map(|(k, v)| (k, resolve_and_process(v, root))) + .collect(); + src.insert("properties".into(), Value::Object(processed)); + } + + // Recursively process array items. + if let Some(items) = src.remove("items") { + src.insert("items".into(), resolve_and_process(items, root)); + } + + // Recursively process anyOf variants. + if let Some(Value::Array(variants)) = src.remove("anyOf") { + src.insert( + "anyOf".into(), + Value::Array( + variants + .into_iter() + .map(|v| resolve_and_process(v, root)) + .collect(), + ), + ); + } + + // Flatten `allOf` — not supported by OpenAI. + // Merge all entries into the parent schema; later entries yield to + // earlier ones (and to keys already present on the parent). + if let Some(Value::Array(entries)) = src.remove("allOf") { + for entry in entries { + if let Value::Object(entry_map) = resolve_and_process(entry, root) { + for (k, v) in entry_map { + src.entry(k).or_insert(v); + } + } + } + } + + // Strip `null` defaults (no meaningful distinction for strict mode). 
+ if src.get("default") == Some(&Value::Null) { + src.remove("default"); + } + + // Unravel `$ref` when it has sibling properties. + // OpenAI supports standalone `$ref` but not alongside other keys. + if src.contains_key("$ref") + && src.len() > 1 + && let Some(Value::String(ref_path)) = src.remove("$ref") + { + if let Some(resolved) = resolve_ref(&ref_path, root) { + // Current schema properties take priority over the + // resolved definition's. + let mut merged = resolved; + for (k, v) in src { + merged.insert(k, v); + } + return process_schema(merged, root); + } + // Failed to resolve — put it back. + src.insert("$ref".into(), Value::String(ref_path)); + } + + src +} + +/// Recursively process a value that may be a schema object. +fn resolve_and_process(value: Value, root: &Value) -> Value { + match value { + Value::Object(map) => Value::Object(process_schema(map, root)), + other => other, + } +} + +/// Resolve a JSON pointer against the root schema. +/// +/// Handles paths like `#/$defs/MyType` and `#` (root self-reference). +fn resolve_ref(ref_path: &str, root: &Value) -> Option> { + if ref_path == "#" { + return root.as_object().cloned(); + } + + let path = ref_path.strip_prefix("#/")?; + let mut current = root; + for segment in path.split('/') { + current = current.get(segment)?; + } + current.as_object().cloned() +} + fn convert_tool_choice(choice: ToolChoice) -> types::ToolChoice { match choice { ToolChoice::Auto => types::ToolChoice::Auto, @@ -706,9 +926,7 @@ fn make_config_nullable(cfg: &mut ToolParameterConfig) { /// moving array-based enums into the 'items' configuration. 
fn sanitize_parameter(config: &mut ToolParameterConfig) { if let Some(items) = &mut config.items { - let mut item_config = items.clone().into(); - sanitize_parameter(&mut item_config); - *items = item_config.into(); + sanitize_parameter(items); } let allows_array = match &config.kind { @@ -758,26 +976,31 @@ fn sanitize_parameter(config: &mut ToolParameterConfig) { OneOrManyTypes::Many(inferred_types.into_iter().collect()) }; - ToolParameterItemConfig { + Box::new(ToolParameterConfig { kind, default: None, + required: false, + summary: None, description: None, + examples: None, enumeration: vec![], - } + items: None, + properties: IndexMap::default(), + }) }); // Append the flattened values to the items enum items_config.enumeration.extend(items); } -fn convert_tools(tools: Vec, strict: bool) -> Vec { +fn convert_tools(tools: Vec) -> Vec { tools .into_iter() .map(|tool| types::Tool::Function { name: tool.name, - strict, + strict: true, description: tool.description, - parameters: parameters_with_strict_mode(tool.parameters, strict).into(), + parameters: parameters_with_strict_mode(tool.parameters, true).into(), }) .collect() } @@ -792,11 +1015,17 @@ fn convert_reasoning( } else { Some(SummaryConfig::Auto) }, - effort: match reasoning.effort.abs_to_rel(max_tokens) { - ReasoningEffort::XHigh => Some(types::ReasoningEffort::XHigh), + effort: match reasoning + .effort + .abs_to_rel(max_tokens) + .unwrap_or(ReasoningEffort::Auto) + { + ReasoningEffort::None => Some(types::ReasoningEffort::None), + ReasoningEffort::Max | ReasoningEffort::XHigh => Some(types::ReasoningEffort::XHigh), ReasoningEffort::High => Some(types::ReasoningEffort::High), ReasoningEffort::Auto | ReasoningEffort::Medium => Some(types::ReasoningEffort::Medium), - ReasoningEffort::Low | ReasoningEffort::Xlow => Some(types::ReasoningEffort::Low), + ReasoningEffort::Low => Some(types::ReasoningEffort::Low), + ReasoningEffort::Xlow => Some(types::ReasoningEffort::Minimal), ReasoningEffort::Absolute(_) => { 
debug_assert!(false, "Reasoning effort must be relative."); None @@ -898,6 +1127,12 @@ fn convert_events( })] } } + ChatResponse::Structured { data } => { + vec![types::InputListItem::Message(types::InputMessage { + role: types::Role::Assistant, + content: types::ContentInput::Text(data.to_string()), + })] + } } } EventKind::ToolCallRequest(request) => vec![types::InputListItem::Item( @@ -928,232 +1163,12 @@ fn convert_events( } } -// /// Converts a single event into `OpenAI` input items. -// /// -// /// Note: `OpenAI` requires separate items for different content types. -// fn convert_events( -// events: ConversationStream, -// supports_reasoning: bool, -// ) -> Vec { -// events -// .into_iter() -// .flat_map(|event| { -// let ConversationEvent { -// kind, mut metadata, .. -// } = event.event; -// -// match kind { -// EventKind::ChatRequest(request) => { -// vec![types::InputListItem::Message(types::InputMessage { -// role: types::Role::User, -// content: types::ContentInput::Text(request.content), -// })] -// } -// EventKind::ChatResponse(response) => { -// if let Some(item) = metadata.remove(ENCODED_PAYLOAD_KEY).and_then(|s| { -// Some(if response.is_reasoning() { -// if !supports_reasoning { -// return None; -// } -// -// types::InputItem::Reasoning( -// serde_json::from_value::(s).ok()?, -// ) -// } else { -// types::InputItem::OutputMessage( -// serde_json::from_value::(s).ok()?, -// ) -// }) -// }) { -// vec![types::InputListItem::Item(item)] -// } else if response.is_reasoning() { -// // Unsupported reasoning content - wrap in XML tags -// vec![types::InputListItem::Message(types::InputMessage { -// role: types::Role::Assistant, -// content: types::ContentInput::Text(format!( -// "\n{}\n\n\n", -// response.content() -// )), -// })] -// } else { -// vec![types::InputListItem::Message(types::InputMessage { -// role: types::Role::Assistant, -// content: types::ContentInput::Text(response.into_content()), -// })] -// } -// } -// 
EventKind::ToolCallRequest(request) => { -// let call = metadata -// .remove(ENCODED_PAYLOAD_KEY) -// .and_then(|s| serde_json::from_value::(s).ok()) -// .unwrap_or_else(|| types::FunctionCall { -// call_id: String::new(), -// name: request.name, -// arguments: Value::Object(request.arguments).to_string(), -// status: None, -// id: (!request.id.is_empty()).then_some(request.id), -// }); -// -// vec![types::InputListItem::Item(types::InputItem::FunctionCall( -// call, -// ))] -// } -// EventKind::ToolCallResponse(ToolCallResponse { id, result }) => { -// vec![types::InputListItem::Item( -// types::InputItem::FunctionCallOutput(types::FunctionCallOutput { -// call_id: id, -// output: match result { -// Ok(content) | Err(content) => content, -// }, -// id: None, -// status: None, -// }), -// )] -// } -// _ => vec![], -// } -// }) -// .collect() -// } - -// impl From for Delta { -// fn from(item: types::OutputItem) -> Self { -// match item { -// types::OutputItem::Message(message) => Delta::content( -// message -// .content -// .into_iter() -// .filter_map(|item| match item { -// types::OutputContent::Text { text, .. } => Some(text), -// types::OutputContent::Refusal { .. } => None, -// }) -// .collect::>() -// .join("\n\n"), -// ), -// types::OutputItem::Reasoning(reasoning) => Delta::reasoning( -// reasoning -// .summary -// .into_iter() -// .map(|item| match item { -// types::ReasoningSummary::Text { text, .. 
} => text, -// }) -// .collect::>() -// .join("\n\n"), -// ), -// types::OutputItem::FunctionCall(call) => { -// Delta::tool_call(call.call_id, call.name, call.arguments).finished() -// } -// _ => Delta::default(), -// } -// } -// } - -// #[cfg(test)] -// mod tests { -// use jp_config::providers::llm::LlmProviderConfig; -// use jp_test::{Result, fn_name, mock::Vcr}; -// use test_log::test; -// -// use super::*; -// -// fn vcr() -> Vcr { -// Vcr::new("https://api.openai.com", env!("CARGO_MANIFEST_DIR")) -// } -// -// #[test(tokio::test)] -// async fn test_openai_model_details() -> Result { -// let mut config = LlmProviderConfig::default().openai; -// let name: Name = "o4-mini".parse().unwrap(); -// -// let vcr = vcr(); -// vcr.cassette( -// fn_name!(), -// |rule| { -// rule.filter(|when| { -// when.any_request(); -// }); -// }, -// |recording, url| async move { -// config.base_url = url; -// if !recording { -// // dummy api key value when replaying a cassette -// config.api_key_env = "USER".to_owned(); -// } -// -// Openai::try_from(&config) -// .unwrap() -// .model_details(&name) -// .await -// }, -// ) -// .await -// } -// -// #[test(tokio::test)] -// async fn test_openai_models() -> Result { -// let mut config = LlmProviderConfig::default().openai; -// let vcr = vcr(); -// vcr.cassette( -// fn_name!(), -// |rule| { -// rule.filter(|when| { -// when.any_request(); -// }); -// }, -// |recording, url| async move { -// config.base_url = url; -// if !recording { -// // dummy api key value when replaying a cassette -// config.api_key_env = "USER".to_owned(); -// } -// -// Openai::try_from(&config) -// .unwrap() -// .models() -// .await -// .map(|mut v| { -// v.truncate(10); -// v -// }) -// }, -// ) -// .await -// } -// -// #[test(tokio::test)] -// async fn test_openai_chat_completion() -> Result { -// let mut config = LlmProviderConfig::default().openai; -// let model_id = "openai/o4-mini".parse().unwrap(); -// let model = ModelDetails::empty(model_id); -// let query 
= ChatQuery { -// thread: Thread { -// events: ConversationStream::default().with_chat_request("Test message"), -// ..Default::default() -// }, -// ..Default::default() -// }; -// -// let vcr = vcr(); -// vcr.cassette( -// fn_name!(), -// |rule| { -// rule.filter(|when| { -// when.any_request(); -// }); -// }, -// |recording, url| async move { -// config.base_url = url; -// if !recording { -// // dummy api key value when replaying a cassette -// config.api_key_env = "USER".to_owned(); -// } -// -// Openai::try_from(&config) -// .unwrap() -// .chat_completion(&model, query) -// .await -// }, -// ) -// .await -// } -// } +impl From for Error { + fn from(error: types::response::Error) -> Self { + Self::OpenaiResponse(error) + } +} + +#[cfg(test)] +#[path = "openai_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/provider/openai_tests.rs b/crates/jp_llm/src/provider/openai_tests.rs new file mode 100644 index 00000000..1d648ba1 --- /dev/null +++ b/crates/jp_llm/src/provider/openai_tests.rs @@ -0,0 +1,378 @@ +mod transform_schema { + use serde_json::{Map, Value, json}; + + use super::super::transform_schema; + + #[expect(clippy::needless_pass_by_value)] + fn schema(v: Value) -> Map { + v.as_object().unwrap().clone() + } + + #[test] + fn additional_properties_false_forced_on_objects() { + let input = schema(json!({ + "type": "object", + "properties": { + "name": { "type": "string" } + } + })); + + let out = transform_schema(input); + + assert_eq!(out["additionalProperties"], json!(false)); + } + + #[test] + fn additional_properties_true_overridden_to_false() { + let input = schema(json!({ + "type": "object", + "properties": { + "name": { "type": "string" } + }, + "additionalProperties": true + })); + + let out = transform_schema(input); + + assert_eq!(out["additionalProperties"], json!(false)); + } + + #[test] + fn all_properties_forced_into_required() { + let input = schema(json!({ + "type": "object", + "properties": { + "name": { "type": "string" }, + "age": { 
"type": "integer" } + } + })); + + let out = transform_schema(input); + + assert_eq!(out["required"], json!(["name", "age"])); + } + + #[test] + fn existing_required_overwritten_to_all_properties() { + let input = schema(json!({ + "type": "object", + "properties": { + "name": { "type": "string" }, + "age": { "type": "integer" } + }, + "required": ["name"] + })); + + let out = transform_schema(input); + + // Both properties must be required in strict mode. + assert_eq!(out["required"], json!(["name", "age"])); + } + + #[test] + fn nested_objects_get_strict_treatment() { + let input = schema(json!({ + "type": "object", + "properties": { + "inner": { + "type": "object", + "properties": { + "x": { "type": "string" } + } + } + } + })); + + let out = transform_schema(input); + + let inner = out["properties"]["inner"].as_object().unwrap(); + assert_eq!(inner["additionalProperties"], json!(false)); + assert_eq!(inner["required"], json!(["x"])); + } + + #[test] + fn defs_recursively_processed() { + let input = schema(json!({ + "type": "object", + "properties": { + "step": { "$ref": "#/$defs/Step" } + }, + "$defs": { + "Step": { + "type": "object", + "properties": { + "explanation": { "type": "string" } + } + } + } + })); + + let out = transform_schema(input); + + // $defs is kept (OpenAI supports it natively). + let step_def = out["$defs"]["Step"].as_object().unwrap(); + assert_eq!(step_def["additionalProperties"], json!(false)); + assert_eq!(step_def["required"], json!(["explanation"])); + } + + #[test] + fn standalone_ref_kept_as_is() { + let input = schema(json!({ + "type": "object", + "properties": { + "step": { "$ref": "#/$defs/Step" } + }, + "$defs": { + "Step": { + "type": "object", + "properties": { + "name": { "type": "string" } + } + } + } + })); + + let out = transform_schema(input); + + // Standalone $ref should stay. 
+ assert_eq!(out["properties"]["step"]["$ref"], "#/$defs/Step"); + } + + #[test] + fn ref_with_siblings_unraveled() { + let input = schema(json!({ + "type": "object", + "properties": { + "person": { + "$ref": "#/$defs/Person", + "description": "The main person" + } + }, + "$defs": { + "Person": { + "type": "object", + "properties": { + "name": { "type": "string" } + } + } + } + })); + + let out = transform_schema(input); + + let person = out["properties"]["person"].as_object().unwrap(); + // $ref should be removed, definition inlined. + assert!(person.get("$ref").is_none()); + assert_eq!(person["type"], "object"); + assert_eq!(person["description"], "The main person"); + assert_eq!(person["properties"]["name"]["type"], "string"); + // Inlined object should also get strict treatment. + assert_eq!(person["additionalProperties"], json!(false)); + } + + #[test] + fn anyof_variants_processed() { + let input = schema(json!({ + "type": "object", + "properties": { + "item": { + "anyOf": [ + { + "type": "object", + "properties": { + "name": { "type": "string" } + } + }, + { "type": "string" } + ] + } + } + })); + + let out = transform_schema(input); + + let variants = out["properties"]["item"]["anyOf"].as_array().unwrap(); + let obj_variant = variants[0].as_object().unwrap(); + assert_eq!(obj_variant["additionalProperties"], json!(false)); + assert_eq!(obj_variant["required"], json!(["name"])); + } + + #[test] + fn allof_single_element_merged() { + let input = schema(json!({ + "allOf": [{ + "type": "object", + "properties": { + "name": { "type": "string" } + } + }] + })); + + let out = transform_schema(input); + + assert!(out.get("allOf").is_none()); + assert_eq!(out["type"], "object"); + assert_eq!(out["additionalProperties"], json!(false)); + assert_eq!(out["required"], json!(["name"])); + } + + #[test] + fn allof_multiple_elements_merged() { + let input = schema(json!({ + "allOf": [ + { + "type": "object", + "properties": { + "name": { "type": "string" } + } + }, + { + 
"description": "Extra info" + } + ] + })); + + let out = transform_schema(input); + + assert!(out.get("allOf").is_none()); + assert_eq!(out["type"], "object"); + assert_eq!(out["description"], "Extra info"); + } + + #[test] + fn null_default_stripped() { + let input = schema(json!({ + "type": "string", + "default": null + })); + + let out = transform_schema(input); + + assert!(out.get("default").is_none()); + } + + #[test] + fn non_null_default_preserved() { + let input = schema(json!({ + "type": "string", + "default": "hello" + })); + + let out = transform_schema(input); + + assert_eq!(out["default"], "hello"); + } + + #[test] + fn const_preserved_unchanged() { + let input = schema(json!({ + "type": "string", + "const": "tool_call.my_tool.call_123" + })); + + let out = transform_schema(input); + + assert_eq!(out["const"], "tool_call.my_tool.call_123"); + } + + #[test] + fn array_items_recursively_processed() { + let input = schema(json!({ + "type": "array", + "items": { + "type": "object", + "properties": { + "id": { "type": "integer" } + } + } + })); + + let out = transform_schema(input); + + let items = out["items"].as_object().unwrap(); + assert_eq!(items["additionalProperties"], json!(false)); + assert_eq!(items["required"], json!(["id"])); + } + + /// The inquiry schema should get strict treatment. + #[test] + fn inquiry_schema_transforms_correctly() { + let input = schema(json!({ + "type": "object", + "properties": { + "inquiry_id": { + "type": "string", + "const": "tool_call.fs_modify_file.call_a3b7c9d1" + }, + "answer": { + "type": "boolean" + } + }, + "required": ["inquiry_id", "answer"], + "additionalProperties": false + })); + + let out = transform_schema(input); + + // const should be preserved (OpenAI supports it). 
+ assert_eq!( + out["properties"]["inquiry_id"]["const"], + "tool_call.fs_modify_file.call_a3b7c9d1" + ); + assert_eq!(out["additionalProperties"], json!(false)); + assert_eq!(out["required"], json!(["inquiry_id", "answer"])); + } + + /// The `title_schema` should get strict treatment applied. + #[test] + fn title_schema_gets_strict_treatment() { + let input = crate::title::title_schema(3); + let out = transform_schema(input); + + assert_eq!(out["additionalProperties"], json!(false)); + assert_eq!(out["required"], json!(["titles"])); + + let titles = out["properties"]["titles"].as_object().unwrap(); + let items = titles["items"].as_object().unwrap(); + assert_eq!(items["type"], "string"); + } + + /// Docs example: definitions with $ref. + #[test] + fn definitions_example_from_docs() { + let input = schema(json!({ + "type": "object", + "properties": { + "steps": { + "type": "array", + "items": { "$ref": "#/$defs/step" } + }, + "final_answer": { "type": "string" } + }, + "$defs": { + "step": { + "type": "object", + "properties": { + "explanation": { "type": "string" }, + "output": { "type": "string" } + }, + "additionalProperties": false + } + }, + "additionalProperties": false + })); + + let out = transform_schema(input); + + // Root object: all properties required. + assert_eq!(out["required"], json!(["steps", "final_answer"])); + + // $defs preserved, step def gets required added. + let step_def = out["$defs"]["step"].as_object().unwrap(); + assert_eq!(step_def["required"], json!(["explanation", "output"])); + assert_eq!(step_def["additionalProperties"], json!(false)); + + // $ref stays as-is (standalone, no siblings). 
+ assert_eq!(out["properties"]["steps"]["items"]["$ref"], "#/$defs/step"); + } +} diff --git a/crates/jp_llm/src/provider/openrouter.rs b/crates/jp_llm/src/provider/openrouter.rs index 2cd8b03c..a1572d53 100644 --- a/crates/jp_llm/src/provider/openrouter.rs +++ b/crates/jp_llm/src/provider/openrouter.rs @@ -21,7 +21,7 @@ use jp_openrouter::{ types::{ self, chat::{CacheControl, Content, Message}, - request::{self, RequestMessage}, + request::{self, JsonSchemaFormat, RequestMessage, ResponseFormat}, response::{ self, ChatCompletion as OpenRouterChunk, FinishReason, ReasoningDetails, ReasoningDetailsFormat, ReasoningDetailsKind, @@ -35,8 +35,8 @@ use tracing::{debug, trace, warn}; use super::{EventStream, ModelDetails}; use crate::{ - Error, - error::Result, + Error, StreamError, + error::{Result, StreamErrorKind, looks_like_quota_error}, event::{self, Event}, provider::{Provider, openai::parameters_with_strict_mode}, query::ChatQuery, @@ -108,18 +108,18 @@ impl Provider for Openrouter { "Starting OpenRouter chat completion stream." ); + let (request, is_structured) = build_request(query, model)?; let mut state = AggregationState { tool_calls: ToolCallRequestAggregator::default(), aggregating_reasoning: false, aggregating_message: false, + is_structured, }; - let request = build_request(query, model)?; - Ok(self .client .chat_completion_stream(request) - .map_err(Error::from) + .map_err(StreamError::from) .map_ok(move |v| stream::iter(map_completion(v, &mut state))) .try_flatten() .boxed()) @@ -136,6 +136,9 @@ struct AggregationState { /// Did the stream of events have any message content? aggregating_message: bool, + + /// Whether the current request uses structured (JSON schema) output. 
+ is_structured: bool, } /// Metadata stored in the conversation stream, based on Openrouter @@ -267,7 +270,10 @@ impl From for IndexMap { } } -fn map_completion(v: OpenRouterChunk, state: &mut AggregationState) -> Vec> { +fn map_completion( + v: OpenRouterChunk, + state: &mut AggregationState, +) -> Vec> { v.choices .into_iter() .flat_map(|v| map_event(v, state)) @@ -275,7 +281,10 @@ fn map_completion(v: OpenRouterChunk, state: &mut AggregationState) -> Vec Vec> { +fn map_event( + choice: types::response::Choice, + state: &mut AggregationState, +) -> Vec> { let types::response::Choice::Streaming(types::response::StreamingChoice { finish_reason, delta: @@ -306,7 +315,17 @@ fn map_event(choice: types::response::Choice, state: &mut AggregationState) -> V let reasoning_details = MultiProviderMetadata::from_details(reasoning_details); if let Some(error) = error { - return vec![Err(Error::from(error))]; + if looks_like_quota_error(&error.message) { + return vec![Err(StreamError::new( + StreamErrorKind::InsufficientQuota, + format!( + "Insufficient API quota. Check your credits \ + at https://openrouter.ai/settings/credits. 
({})", + error.message + ), + ))]; + } + return vec![Err(StreamError::other(error.message))]; } let mut events = vec![]; @@ -329,9 +348,14 @@ fn map_event(choice: types::response::Choice, state: &mut AggregationState) -> V { state.aggregating_message = true; + let response = if state.is_structured { + ChatResponse::structured(Value::String(content)) + } else { + ChatResponse::message(content) + }; events.push(Ok(Event::Part { index: 1, - event: ConversationEvent::now(ChatResponse::message(content)), + event: ConversationEvent::now(response), })); } @@ -375,7 +399,7 @@ fn map_event(choice: types::response::Choice, state: &mut AggregationState) -> V index, event: ConversationEvent::now(call), }) - .map_err(Error::from), + .map_err(|e| StreamError::other(e.to_string())), Ok(Event::flush(index)), ] }), @@ -389,10 +413,9 @@ fn map_event(choice: types::response::Choice, state: &mut AggregationState) -> V Some(FinishReason::Stop) => { events.push(Ok(Event::Finished(event::FinishReason::Completed))); } - Some(FinishReason::Error) => events.push(Err(jp_openrouter::Error::Stream( - "unknown stream error".into(), - ) - .into())), + Some(FinishReason::Error) => { + events.push(Err(StreamError::other("unknown stream error"))); + } Some(reason) => events.push(Ok(Event::Finished(event::FinishReason::Other( reason.as_str().into(), )))), @@ -403,17 +426,37 @@ fn map_event(choice: types::response::Choice, state: &mut AggregationState) -> V } /// Build request for Openrouter API. -fn build_request(query: ChatQuery, model: &ModelDetails) -> Result { +/// +/// Returns the request and whether structured output is active. +fn build_request( + query: ChatQuery, + model: &ModelDetails, +) -> Result<(request::ChatCompletion, bool)> { let ChatQuery { thread, tools, tool_choice, - tool_call_strict_mode, } = query; let config = thread.events.config()?; let parameters = &config.assistant.model.parameters; + // Only use the schema if the very last event is a ChatRequest with one. 
+ let response_format = thread + .events + .last() + .and_then(|e| e.event.as_chat_request()) + .and_then(|req| req.schema.clone()) + .map(|schema| ResponseFormat::JsonSchema { + json_schema: JsonSchemaFormat { + name: "structured_output".to_owned(), + schema: Value::Object(schema), + strict: Some(true), + }, + }); + + let is_structured = response_format.is_some(); + let slug = model.id.name.to_string(); let reasoning = model.custom_reasoning_config(parameters.reasoning); @@ -422,10 +465,10 @@ fn build_request(query: ChatQuery, model: &ModelDetails) -> Result>(); @@ -447,26 +490,40 @@ fn build_request(query: ChatQuery, model: &ModelDetails) -> Result request::ReasoningEffort::XHigh, - ReasoningEffort::High => request::ReasoningEffort::High, - ReasoningEffort::Auto | ReasoningEffort::Medium => request::ReasoningEffort::Medium, - ReasoningEffort::Low | ReasoningEffort::Xlow => request::ReasoningEffort::Low, - ReasoningEffort::Absolute(_) => { - debug_assert!(false, "Reasoning effort must be relative."); - request::ReasoningEffort::Medium - } - }, - }), - tools, - tool_choice, - ..Default::default() - }) + Ok(( + request::ChatCompletion { + model: slug, + messages: messages.0, + reasoning: reasoning.map(|r| request::Reasoning { + exclude: r.exclude, + effort: match r + .effort + .abs_to_rel(model.max_output_tokens) + .unwrap_or(ReasoningEffort::Auto) + { + ReasoningEffort::Max | ReasoningEffort::XHigh => { + request::ReasoningEffort::XHigh + } + ReasoningEffort::High => request::ReasoningEffort::High, + ReasoningEffort::Auto | ReasoningEffort::Medium => { + request::ReasoningEffort::Medium + } + ReasoningEffort::None => request::ReasoningEffort::None, + ReasoningEffort::Xlow => request::ReasoningEffort::Minimal, + ReasoningEffort::Low => request::ReasoningEffort::Low, + ReasoningEffort::Absolute(_) => { + debug_assert!(false, "Reasoning effort must be relative."); + request::ReasoningEffort::Medium + } + }, + }), + tools, + tool_choice, + response_format, + 
..Default::default() + }, + is_structured, + )) } // TODO: Manually add a bunch of often-used models. @@ -483,21 +540,6 @@ fn map_model(model: response::Model) -> Result { }) } -// impl From for Delta { -// fn from(delta: StreamingDelta) -> Self { -// let tool_call = delta.tool_calls.into_iter().next(); -// -// Self { -// content: delta.content, -// reasoning: delta.reasoning, -// tool_call_id: tool_call.as_ref().and_then(ToolCall::id), -// tool_call_name: tool_call.as_ref().and_then(ToolCall::name), -// tool_call_arguments: tool_call.as_ref().and_then(ToolCall::arguments), -// tool_call_finished: false, -// } -// } -// } - impl From for Error { fn from(error: types::response::ErrorResponse) -> Self { Self::OpenRouter(jp_openrouter::Error::Api { @@ -534,7 +576,7 @@ impl TryFrom<(&ModelIdConfig, Thread)> for RequestMessages { fn try_from((model_id, thread): (&ModelIdConfig, Thread)) -> Result { let Thread { system_prompt, - instructions, + sections, attachments, events, } = thread; @@ -554,7 +596,7 @@ impl TryFrom<(&ModelIdConfig, Thread)> for RequestMessages { }); } - if !instructions.is_empty() { + if !sections.is_empty() { content.push(Content::Text { text: "Before we continue, here are some contextual details that will help you \ generate a better response." @@ -562,15 +604,15 @@ impl TryFrom<(&ModelIdConfig, Thread)> for RequestMessages { cache_control: None, }); - // Then instructions in XML tags. + // Then sections as rendered text. // - // Cached (3/4), (for the last instruction), as it's not expected to + // Cached (3/4), (for the last section), as it's not expected to // change. 
- let mut instructions = instructions.iter().peekable(); - while let Some(instruction) = instructions.next() { + let mut sections = sections.iter().peekable(); + while let Some(section) = sections.next() { content.push(Content::Text { - text: instruction.try_to_xml()?, - cache_control: instructions + text: section.render(), + cache_control: sections .peek() .map_or(Some(CacheControl::Ephemeral), |_| None), }); @@ -631,6 +673,9 @@ fn convert_events(events: ConversationStream) -> Vec { ChatResponse::Reasoning { reasoning, .. } => { vec![Message::default().with_reasoning(reasoning).assistant()] } + ChatResponse::Structured { data } => { + vec![Message::default().with_text(data.to_string()).assistant()] + } }, EventKind::ToolCallRequest(request) => { let message = Message { @@ -663,62 +708,47 @@ fn convert_events(events: ConversationStream) -> Vec { name: None, })] } - EventKind::InquiryRequest(_) | EventKind::InquiryResponse(_) => vec![], + EventKind::InquiryRequest(_) + | EventKind::InquiryResponse(_) + | EventKind::TurnStart(_) => vec![], }) .collect() } -#[cfg(test)] -mod tests { - use jp_config::providers::llm::LlmProviderConfig; - use jp_test::{Result, function_name}; - - use super::*; - use crate::test::TestRequest; - - macro_rules! test_all_models { - ($($fn:ident),* $(,)?) => { - mod anthropic { use super::*; $(test_all_models!(func; $fn, "openrouter/anthropic/claude-haiku-4.5");)* } - mod google { use super::*; $(test_all_models!(func; $fn, "openrouter/google/gemini-2.5-flash");)* } - mod xai { use super::*; $(test_all_models!(func; $fn, "openrouter/x-ai/grok-code-fast-1");)* } - mod minimax { use super::*; $(test_all_models!(func; $fn, "openrouter/minimax/minimax-m2");)* } - }; - (func; $fn:ident, $model:literal) => { - paste::paste! 
{ - #[test_log::test(tokio::test)] - async fn [< test_ $fn >]() -> Result { - $fn($model, &format!("{}_{}", $model.split('/').nth(1).unwrap(), function_name!())).await - } +impl From for StreamError { + fn from(err: jp_openrouter::Error) -> Self { + use jp_openrouter::Error as E; + + match err { + E::Request(error) => Self::from(error), + E::Api { code: 429, .. } => StreamError::rate_limit(None).with_source(err), + // 402 Payment Required — OpenRouter returns this for insufficient credits. + E::Api { code: 402, .. } => StreamError::new( + StreamErrorKind::InsufficientQuota, + format!( + "Insufficient API quota. Check your credits \ + at https://openrouter.ai/settings/credits. ({err})" + ), + ) + .with_source(err), + E::Api { ref message, .. } if looks_like_quota_error(message) => StreamError::new( + StreamErrorKind::InsufficientQuota, + format!( + "Insufficient API quota. Check your credits \ + at https://openrouter.ai/settings/credits. ({err})" + ), + ) + .with_source(err), + E::Api { + code: 408 | 500 | 502 | 503 | 504, + .. 
} - }; - } - - test_all_models![sub_provider_event_metadata]; - - async fn sub_provider_event_metadata(model: &str, test_name: &str) -> Result { - let requests = vec![ - TestRequest::chat(ProviderId::Openrouter) - .model(model.parse().unwrap()) - .enable_reasoning() - .chat_request("Test message"), - ]; - - run_test(test_name, requests).await?; - - Ok(()) - } - - async fn run_test( - test_name: impl AsRef, - requests: impl IntoIterator, - ) -> Result { - crate::test::run_chat_completion( - test_name, - env!("CARGO_MANIFEST_DIR"), - ProviderId::Openrouter, - LlmProviderConfig::default(), - requests.into_iter().collect(), - ) - .await + | E::Stream(_) => StreamError::transient(err.to_string()).with_source(err), + _ => StreamError::other(err.to_string()).with_source(err), + } } } + +#[cfg(test)] +#[path = "openrouter_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/provider/openrouter_tests.rs b/crates/jp_llm/src/provider/openrouter_tests.rs new file mode 100644 index 00000000..c7c63ee6 --- /dev/null +++ b/crates/jp_llm/src/provider/openrouter_tests.rs @@ -0,0 +1,51 @@ +use jp_config::providers::llm::LlmProviderConfig; +use jp_test::{Result, function_name}; + +use super::*; +use crate::test::TestRequest; + +macro_rules! test_all_models { + ($($fn:ident),* $(,)?) => { + mod anthropic { use super::*; $(test_all_models!(func; $fn, "openrouter/anthropic/claude-haiku-4.5");)* } + mod google { use super::*; $(test_all_models!(func; $fn, "openrouter/google/gemini-2.5-flash");)* } + mod xai { use super::*; $(test_all_models!(func; $fn, "openrouter/x-ai/grok-code-fast-1");)* } + mod minimax { use super::*; $(test_all_models!(func; $fn, "openrouter/minimax/minimax-m2");)* } + }; + (func; $fn:ident, $model:literal) => { + paste::paste! 
{ + #[test_log::test(tokio::test)] + async fn [< test_ $fn >]() -> Result { + $fn($model, &format!("{}_{}", $model.split('/').nth(1).unwrap(), function_name!())).await + } + } + }; + } + +test_all_models![sub_provider_event_metadata]; + +async fn sub_provider_event_metadata(model: &str, test_name: &str) -> Result { + let requests = vec![ + TestRequest::chat(ProviderId::Openrouter) + .model(model.parse().unwrap()) + .enable_reasoning() + .chat_request("Test message"), + ]; + + run_test(test_name, requests).await?; + + Ok(()) +} + +async fn run_test( + test_name: impl AsRef, + requests: impl IntoIterator, +) -> Result { + crate::test::run_chat_completion( + test_name, + env!("CARGO_MANIFEST_DIR"), + ProviderId::Openrouter, + LlmProviderConfig::default(), + requests.into_iter().collect(), + ) + .await +} diff --git a/crates/jp_llm/src/provider_tests.rs b/crates/jp_llm/src/provider_tests.rs new file mode 100644 index 00000000..7f9c1880 --- /dev/null +++ b/crates/jp_llm/src/provider_tests.rs @@ -0,0 +1,198 @@ +use std::sync::Arc; + +use indexmap::IndexMap; +use jp_config::{ + assistant::tool_choice::ToolChoice, + conversation::tool::{OneOrManyTypes, ToolParameterConfig}, +}; +use jp_conversation::event::ChatRequest; +use jp_test::{Result, function_name}; + +use super::*; +use crate::test::{TestRequest, run_test, test_model_details}; + +macro_rules! test_all_providers { + ($($fn:ident),* $(,)?) 
=> { + mod anthropic { use super::*; $(test_all_providers!(func; $fn, ProviderId::Anthropic);)* } + mod google { use super::*; $(test_all_providers!(func; $fn, ProviderId::Google);)* } + mod openai { use super::*; $(test_all_providers!(func; $fn, ProviderId::Openai);)* } + mod openrouter{ use super::*; $(test_all_providers!(func; $fn, ProviderId::Openrouter);)* } + mod ollama { use super::*; $(test_all_providers!(func; $fn, ProviderId::Ollama);)* } + mod llamacpp { use super::*; $(test_all_providers!(func; $fn, ProviderId::Llamacpp);)* } + }; + (func; $fn:ident, $provider:ty) => { + paste::paste! { + #[test_log::test(tokio::test)] + async fn [< test_ $fn >]() -> Result { + $fn($provider, function_name!()).await + } + } + }; + } + +async fn chat_completion_stream(provider: ProviderId, test_name: &str) -> Result { + let request = TestRequest::chat(provider) + .enable_reasoning() + .event(ChatRequest::from("Test message")); + + run_test(provider, test_name, Some(request)).await +} + +fn tool_call_base(provider: ProviderId) -> TestRequest { + TestRequest::chat(provider) + .event(ChatRequest::from( + "Please run the tool, providing whatever arguments you want.", + )) + .tool("run_me", vec![ + ("foo", ToolParameterConfig { + kind: OneOrManyTypes::One("string".into()), + default: Some("foo".into()), + required: false, + summary: None, + description: None, + examples: None, + enumeration: vec![], + items: None, + properties: IndexMap::default(), + }), + ("bar", ToolParameterConfig { + kind: OneOrManyTypes::Many(vec!["string".into(), "array".into()]), + default: None, + required: true, + summary: None, + description: None, + examples: None, + enumeration: vec!["foo".into(), vec!["foo", "bar"].into()], + items: Some(Box::new(ToolParameterConfig { + kind: OneOrManyTypes::One("string".into()), + default: None, + required: false, + summary: None, + description: None, + examples: None, + enumeration: vec![], + items: None, + properties: IndexMap::default(), + })), + properties: 
IndexMap::default(), + }), + ]) +} + +async fn tool_call_stream(provider: ProviderId, test_name: &str) -> Result { + let requests = vec![ + tool_call_base(provider), + TestRequest::tool_call_response(Ok("working!"), false), + ]; + + run_test(provider, test_name, requests).await +} + +/// Without reasoning, "forced" tool calls should work as expected. +async fn tool_call_required_no_reasoning(provider: ProviderId, test_name: &str) -> Result { + let requests = vec![ + tool_call_base(provider).tool_choice(ToolChoice::Required), + TestRequest::tool_call_response(Ok("working!"), true), + ]; + + run_test(provider, test_name, requests).await +} + +/// With reasoning, some models do not support "forced" tool calls, so +/// provider implementations should fall back to trying to instruct the +/// model to use the tool through regular textual instructions. +async fn tool_call_required_reasoning(provider: ProviderId, test_name: &str) -> Result { + let requests = vec![ + tool_call_base(provider) + .tool_choice(ToolChoice::Required) + .enable_reasoning(), + TestRequest::tool_call_response(Ok("working!"), false), + ]; + + run_test(provider, test_name, requests).await +} + +async fn tool_call_auto(provider: ProviderId, test_name: &str) -> Result { + let requests = vec![ + tool_call_base(provider).tool_choice(ToolChoice::Auto), + TestRequest::tool_call_response(Ok("working!"), false), + ]; + + run_test(provider, test_name, requests).await +} + +async fn tool_call_function(provider: ProviderId, test_name: &str) -> Result { + let requests = vec![ + tool_call_base(provider).tool_choice_fn("run_me"), + TestRequest::tool_call_response(Ok("working!"), true), + ]; + + run_test(provider, test_name, requests).await +} + +async fn tool_call_reasoning(provider: ProviderId, test_name: &str) -> Result { + let requests = vec![ + tool_call_base(provider).enable_reasoning(), + TestRequest::tool_call_response(Ok("working!"), false), + ]; + + run_test(provider, test_name, requests).await +} + +async 
fn model_details(provider: ProviderId, test_name: &str) -> Result { + let request = TestRequest::ModelDetails { + name: test_model_details(provider).id.name.to_string(), + assert: Arc::new(|_| {}), + }; + + run_test(provider, test_name, Some(request)).await +} + +async fn models(provider: ProviderId, test_name: &str) -> Result { + let request = TestRequest::Models { + assert: Arc::new(|_| {}), + }; + + run_test(provider, test_name, Some(request)).await +} + +async fn structured_output(provider: ProviderId, test_name: &str) -> Result { + let schema = crate::title::title_schema(1); + + let request = TestRequest::chat(provider).chat_request(ChatRequest { + content: "Generate a title for this conversation.".into(), + schema: Some(schema), + }); + + run_test(provider, test_name, Some(request)).await +} + +async fn multi_turn_conversation(provider: ProviderId, test_name: &str) -> Result { + let requests = vec![ + TestRequest::chat(provider).chat_request("Test message"), + TestRequest::chat(provider) + .enable_reasoning() + .chat_request("Repeat my previous message"), + tool_call_base(provider).tool_choice_fn("run_me"), + TestRequest::tool_call_response(Ok("The secret code is: 42"), true), + TestRequest::chat(provider) + .enable_reasoning() + .chat_request("What was the result of the previous tool call?"), + ]; + + run_test(provider, test_name, requests).await +} + +test_all_providers![ + chat_completion_stream, + tool_call_auto, + tool_call_function, + tool_call_reasoning, + tool_call_required_no_reasoning, + tool_call_required_reasoning, + tool_call_stream, + model_details, + models, + multi_turn_conversation, + structured_output, +]; diff --git a/crates/jp_llm/src/query.rs b/crates/jp_llm/src/query.rs index 3da03670..b0645770 100644 --- a/crates/jp_llm/src/query.rs +++ b/crates/jp_llm/src/query.rs @@ -1,44 +1,31 @@ -pub mod chat; -pub mod structured; - -pub use chat::ChatQuery; +use jp_config::assistant::tool_choice::ToolChoice; use jp_conversation::thread::Thread; 
-pub use structured::StructuredQuery; -#[derive(Debug)] -pub enum Query { - Chat(ChatQuery), - Structured(StructuredQuery), -} +use crate::tool::ToolDefinition; -impl Query { - /// Get the [`Thread`] for the query. - #[must_use] - pub fn thread(&self) -> &Thread { - match self { - Self::Chat(v) => &v.thread, - Self::Structured(v) => &v.thread, - } - } +#[derive(Debug, Clone)] +pub struct ChatQuery { + pub thread: Thread, + // TODO: Should this be taken from `thread.events`, if not, document why? + // + // I think it should, because the tools that are available to the LLM are + // always represented by the configuration in the conversation stream. If a + // user adds a new tool to a config file, that tool is not automatically + // available in existing conversations (it will be in new ones), but will + // only become available when `--tool` or `--cfg` is used. + pub tools: Vec, + // TODO: Should this instead be a delta config on `thread.events`? + // + // Same logic applies here, I think? + pub tool_choice: ToolChoice, +} - /// Get a mutable reference to the [`Thread`] for the query. - #[must_use] - pub fn thread_mut(&mut self) -> &mut Thread { - match self { - Self::Chat(v) => &mut v.thread, - Self::Structured(v) => &mut v.thread, +impl From for ChatQuery { + fn from(thread: Thread) -> Self { + Self { + thread, + tools: vec![], + tool_choice: ToolChoice::default(), } } - - /// Returns `true` if the query is a chat query. - #[must_use] - pub fn is_chat(&self) -> bool { - matches!(self, Self::Chat(_)) - } - - /// Returns `true` if the query is a structured query. 
- #[must_use] - pub fn is_structured(&self) -> bool { - matches!(self, Self::Structured(_)) - } } diff --git a/crates/jp_llm/src/query/chat.rs b/crates/jp_llm/src/query/chat.rs deleted file mode 100644 index 9e950af6..00000000 --- a/crates/jp_llm/src/query/chat.rs +++ /dev/null @@ -1,33 +0,0 @@ -use jp_config::assistant::tool_choice::ToolChoice; -use jp_conversation::thread::Thread; - -use crate::tool::ToolDefinition; - -#[derive(Debug, Clone)] -pub struct ChatQuery { - pub thread: Thread, - // TODO: Should this be taken from `thread.events`, if not, document why? - // - // I think it should, because the tools that are available to the LLM are - // always represented by the configuration in the conversation stream. If a - // user adds a new tool to a config file, that tool is not automatically - // available in existing conversations (it will be in new ones), but will - // only become available when `--tool` or `--cfg` is used. - pub tools: Vec, - // TODO: Should this instead be a delta config on `thread.events`? - // - // Same logic applies here, I think? - pub tool_choice: ToolChoice, - pub tool_call_strict_mode: bool, -} - -impl From for ChatQuery { - fn from(thread: Thread) -> Self { - Self { - thread, - tools: vec![], - tool_choice: ToolChoice::default(), - tool_call_strict_mode: false, - } - } -} diff --git a/crates/jp_llm/src/query/structured.rs b/crates/jp_llm/src/query/structured.rs deleted file mode 100644 index 0383b0e6..00000000 --- a/crates/jp_llm/src/query/structured.rs +++ /dev/null @@ -1,181 +0,0 @@ -use std::fmt; - -use jp_config::{ - assistant::tool_choice::ToolChoice, - conversation::tool::{OneOrManyTypes, ToolParameterConfig}, -}; -use jp_conversation::thread::Thread; -use schemars::Schema; -use serde_json::Value; - -use crate::{Error, structured::SCHEMA_TOOL_NAME, tool::ToolDefinition}; - -type Mapping = Box Option + Send>; -type Validate = Box Result<(), String> + Send>; - -/// A structured query for LLMs. 
-pub struct StructuredQuery { - /// The thread to use for the query. - pub thread: Thread, - - /// The JSON schema to enforce the shape of the response. - schema: Schema, - - /// An optional mapping function to mutate the response object into a - /// different shape. - mapping: Option, - - /// Validators to run on the response. If a validator fails, its error is - /// sent back to the assistant, so that it can be fixed/retried. - /// - /// TODO: Add support for JSON Schema validation. - validators: Vec, -} - -impl fmt::Debug for StructuredQuery { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("StructuredQuery") - .field("thread", &self.thread) - .field("schema", &self.schema) - .field("mapping", &"") - .field("validators", &"") - .finish() - } -} - -impl StructuredQuery { - /// Create a new structured query. - #[must_use] - pub fn new(schema: Schema, thread: Thread) -> Self { - Self { - thread, - schema, - mapping: None, - validators: vec![], - } - } - - #[must_use] - pub fn with_mapping( - mut self, - mapping: impl Fn(&mut Value) -> Option + Send + Sync + 'static, - ) -> Self { - self.mapping = Some(Box::new(mapping)); - self - } - - #[must_use] - pub fn with_validator( - mut self, - validator: impl Fn(&Value) -> Result<(), String> + Send + Sync + 'static, - ) -> Self { - self.validators.push(Box::new(validator)); - self - } - - #[must_use] - pub fn with_schema_validator(self, schema: Schema) -> Self { - let validate = move |value: &Value| { - jsonschema::validate(schema.as_value(), value).map_err(|e| e.to_string()) - }; - - self.with_validator(validate) - } - - #[must_use] - pub fn map(&self, mut value: Value) -> Value { - self.mapping - .as_ref() - .and_then(|f| f(&mut value)) - .unwrap_or(value) - } - - pub fn validate(&self, value: &Value) -> Result<(), String> { - for validator in &self.validators { - validator(value)?; - } - - Ok(()) - } - - pub fn tool_definition(&self) -> Result { - let mut description = - "This tool can be 
used to deliver structured data to the caller. It is NOT intended \ - to GENERATE the requested data, but instead as a structured delivery mechanism. The \ - tool is a no-op implementation in that it allows the assistant to deliver structured \ - data to the user, but the tool will never report back with a result, instead the \ - user can take the structured data from the tool arguments provided." - .to_owned(); - if let Some(desc) = self.schema.get("description").and_then(|v| v.as_str()) { - description.push_str(&format!( - " Here is the description for the requested structured data:\n\n{desc}" - )); - } - - let required = self - .schema - .get("required") - .and_then(|v| v.as_array()) - .into_iter() - .flatten() - .filter_map(|v| v.as_str()) - .collect::>(); - - let parameters = self - .schema - .get("properties") - .and_then(|v| v.as_object()) - .into_iter() - .flatten() - .map(|(k, v)| { - let kind = v - .get("type") - .and_then(|v| match v.clone() { - Value::String(v) => Some(v.into()), - Value::Array(v) => Some( - v.into_iter() - .filter_map(|v| match v { - Value::String(v) => Some(v), - _ => None, - }) - .collect::>() - .into(), - ), - _ => None, - }) - .unwrap_or_else(|| OneOrManyTypes::One("object".to_owned())); - - let parameter = ToolParameterConfig { - kind, - required: required.contains(&k.as_str()), - description: v - .get("description") - .and_then(|v| v.as_str()) - .map(str::to_owned), - default: v.get("default").cloned(), - enumeration: v - .get("enum") - .and_then(|v| v.as_array().cloned()) - .unwrap_or_default(), - items: v - .get("items") - .map(|v| serde_json::from_value(v.clone())) - .transpose()?, - }; - - Ok((k.to_owned(), parameter)) - }) - .collect::>()?; - - Ok(ToolDefinition { - name: SCHEMA_TOOL_NAME.to_owned(), - description: Some(description), - parameters, - include_tool_answers_parameter: false, - }) - } - - pub fn tool_choice(&self) -> Result { - Ok(ToolChoice::Function(self.tool_definition()?.name)) - } -} diff --git 
a/crates/jp_llm/src/retry.rs b/crates/jp_llm/src/retry.rs new file mode 100644 index 00000000..180c8c64 --- /dev/null +++ b/crates/jp_llm/src/retry.rs @@ -0,0 +1,116 @@ +//! Retry utilities for resilient LLM request handling. + +use std::time::Duration; + +use futures::TryStreamExt as _; +use tracing::{debug, warn}; + +use crate::{Provider, error::Result, event::Event, model::ModelDetails, query::ChatQuery}; + +/// Configuration for resilient stream retries. +#[derive(Debug, Clone)] +pub struct RetryConfig { + /// Maximum number of retry attempts. + pub max_retries: u32, + + /// Base backoff delay in milliseconds. + pub base_backoff_ms: u64, + + /// Maximum backoff delay in seconds. + pub max_backoff_secs: u64, +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + max_retries: 3, + base_backoff_ms: 1000, + max_backoff_secs: 30, + } + } +} + +/// Execute `chat_completion_stream` with automatic retries on transient errors. +/// +/// Collects the full event stream into a `Vec`. On retryable stream +/// errors, backs off and retries the entire request up to `config.max_retries` +/// times. +/// +/// Non-retryable errors and errors from `chat_completion_stream` itself (before +/// streaming starts) are propagated immediately. +pub async fn collect_with_retry( + provider: &dyn Provider, + model: &ModelDetails, + query: ChatQuery, + config: &RetryConfig, +) -> Result> { + let mut attempt = 0u32; + + loop { + let stream = provider + .chat_completion_stream(model, query.clone()) + .await?; + + match stream.try_collect::>().await { + Ok(events) => return Ok(events), + Err(error) => { + attempt += 1; + + if !error.is_retryable() || attempt > config.max_retries { + warn!( + attempt, + max = config.max_retries, + error = error.to_string(), + "Stream error (exhausted retries)." 
+ ); + return Err(error.into()); + } + + let delay = match error.retry_after { + Some(d) => d.min(Duration::from_secs(config.max_backoff_secs)), + None => exponential_backoff( + attempt, + config.base_backoff_ms, + config.max_backoff_secs, + ), + }; + + debug!( + attempt, + max = config.max_retries, + delay_ms = delay.as_millis(), + error = error.to_string(), + "Retryable stream error, backing off." + ); + + tokio::time::sleep(delay).await; + } + } + } +} + +/// Calculate exponential backoff delay. +/// +/// Formula: `min(base * 2^attempt, max_backoff)` +/// +/// # Arguments +/// +/// * `attempt` - Current attempt number (1-based). The delay doubles with +/// each attempt. +/// * `base_backoff_ms` - Base delay in milliseconds for the first attempt. +/// * `max_backoff_secs` - Maximum delay cap in seconds. +#[must_use] +pub fn exponential_backoff(attempt: u32, base_backoff_ms: u64, max_backoff_secs: u64) -> Duration { + let max_ms = max_backoff_secs * 1000; + + // Cap the exponent to avoid overflow. + let capped_attempt = attempt.saturating_sub(1).min(20); + let base_delay = base_backoff_ms.saturating_mul(1u64 << capped_attempt); + let total_ms = base_delay.min(max_ms); + + Duration::from_millis(total_ms) +} + +#[cfg(test)] +#[path = "retry_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/retry_tests.rs b/crates/jp_llm/src/retry_tests.rs new file mode 100644 index 00000000..eb413905 --- /dev/null +++ b/crates/jp_llm/src/retry_tests.rs @@ -0,0 +1,61 @@ +use super::*; +use crate::error::StreamError; + +/// Default base backoff for tests. +const TEST_BASE_BACKOFF_MS: u64 = 1000; + +/// Default max backoff for tests. 
+const TEST_MAX_BACKOFF_SECS: u64 = 60; + +#[test] +fn backoff_increases() { + let d1 = exponential_backoff(1, TEST_BASE_BACKOFF_MS, TEST_MAX_BACKOFF_SECS); + let d2 = exponential_backoff(2, TEST_BASE_BACKOFF_MS, TEST_MAX_BACKOFF_SECS); + let d3 = exponential_backoff(3, TEST_BASE_BACKOFF_MS, TEST_MAX_BACKOFF_SECS); + + // Base delays should roughly double + // attempt 1: ~1000ms, attempt 2: ~2000ms, attempt 3: ~4000ms + assert!(d1 < d2); + assert!(d2 < d3); +} + +#[test] +fn backoff_capped() { + let d_high = exponential_backoff(100, TEST_BASE_BACKOFF_MS, TEST_MAX_BACKOFF_SECS); + + // Should be capped at max_backoff_secs + assert!(d_high <= Duration::from_secs(TEST_MAX_BACKOFF_SECS + 1)); +} + +#[test] +fn backoff_respects_config() { + // Custom base and max + let d1 = exponential_backoff(1, 500, 10); + let d2 = exponential_backoff(1, 2000, 10); + + // Higher base should give higher delay + assert!(d1 < d2); + + // Should respect max cap + let d_capped = exponential_backoff(100, 1000, 5); + assert!(d_capped <= Duration::from_secs(5)); +} + +#[test] +fn stream_error_is_retryable() { + // Retryable error kinds + assert!(StreamError::timeout("test").is_retryable()); + assert!(StreamError::connect("test").is_retryable()); + assert!(StreamError::rate_limit(None).is_retryable()); + assert!(StreamError::transient("test").is_retryable()); + + // Non-retryable + assert!(!StreamError::other("test").is_retryable()); +} + +#[test] +fn stream_error_with_retry_after() { + let err = StreamError::rate_limit(Some(Duration::from_secs(30))); + assert_eq!(err.retry_after, Some(Duration::from_secs(30))); + assert!(err.is_retryable()); +} diff --git a/crates/jp_llm/src/stream.rs b/crates/jp_llm/src/stream.rs index dbbb4289..b90c0035 100644 --- a/crates/jp_llm/src/stream.rs +++ b/crates/jp_llm/src/stream.rs @@ -2,9 +2,13 @@ use std::pin::Pin; use futures::Stream; -use crate::{Error, event::Event}; +use crate::{error::StreamError, event::Event}; pub(super) mod aggregator; pub(super) mod 
chain; -pub type EventStream = Pin> + Send>>; +/// A stream of events from an LLM provider. +/// +/// Errors are represented as `StreamError` to provide provider-agnostic error +/// classification for retry logic. +pub type EventStream = Pin> + Send>>; diff --git a/crates/jp_llm/src/stream/aggregator.rs b/crates/jp_llm/src/stream/aggregator.rs index a1e8adc9..f3c917b5 100644 --- a/crates/jp_llm/src/stream/aggregator.rs +++ b/crates/jp_llm/src/stream/aggregator.rs @@ -1,3 +1,2 @@ -pub(crate) mod chunk; pub(crate) mod reasoning; pub(crate) mod tool_call_request; diff --git a/crates/jp_llm/src/stream/aggregator/chunk.rs b/crates/jp_llm/src/stream/aggregator/chunk.rs deleted file mode 100644 index a0d730b4..00000000 --- a/crates/jp_llm/src/stream/aggregator/chunk.rs +++ /dev/null @@ -1,195 +0,0 @@ -use indexmap::{IndexMap, map::Entry}; -use jp_conversation::{ - ConversationEvent, EventKind, - event::{ChatResponse, ToolCallRequest}, -}; -use serde_json::Value; -use tracing::warn; - -use crate::event::Event; - -/// A buffering state machine that consumes multiplexed streaming events and -/// produces coalesced events keyed by index. -pub struct EventAggregator { - /// The currently accumulating events, keyed by stream index. - pending: IndexMap, -} - -impl EventAggregator { - /// Create a new, empty aggregator. - pub fn new() -> Self { - Self { - pending: IndexMap::new(), - } - } - - /// Consumes a single streaming [`Event`] and returns a vector of zero or - /// more "completed" `Event`s. - /// - /// For consistency, we always send a `Flush` event after flushing a merged - /// `Part` event. While this is not strictly necessary, it makes the API - /// more consistent to use, regardless of whether the event aggregator is - /// used or not. - pub fn ingest(&mut self, event: Event) -> Vec { - match event { - Event::Part { index, event } => match self.pending.entry(index) { - // Nothing buffered for this index, start buffering. 
- Entry::Vacant(e) => { - e.insert(event); - vec![] - } - Entry::Occupied(mut e) => match try_merge_events(e.get_mut(), event) { - // Merge succeeded. Continue buffering. - Ok(()) => vec![], - // Merge failed (types were different). Force flush the OLD - // event, replace it with the NEW event. - Err(unmerged) => vec![ - Event::Part { - index, - event: e.insert(unmerged), - }, - Event::flush(index), - ], - }, - }, - - Event::Flush { index, metadata } => { - if let Some(event) = self.pending.shift_remove(&index) { - vec![ - Event::Part { - index, - event: event.with_metadata(metadata), - }, - Event::flush(index), - ] - } else { - if !metadata.is_empty() { - warn!( - index, - metadata = ?metadata, - "Received Flush with metadata for empty index." - ); - } - - vec![] - } - } - - Event::Finished(reason) => self - .pending - .drain(..) - .flat_map(|(index, event)| vec![Event::Part { index, event }, Event::flush(index)]) - .chain(std::iter::once(Event::Finished(reason))) - .collect(), - } - } -} - -/// Attempts to merge `incoming` into `target`. Returns `Ok(())` if successful, -/// or `Err(incoming)` if the events were incompatible (e.g., different types), -/// passing ownership of the incoming event back to the caller. -fn try_merge_events( - target: &mut ConversationEvent, - incoming: ConversationEvent, -) -> Result<(), ConversationEvent> { - let ConversationEvent { - kind, - metadata, - timestamp, - } = incoming; - - match (&mut target.kind, kind) { - (EventKind::ChatResponse(t_resp), EventKind::ChatResponse(i_resp)) => { - match merge_chat_responses(t_resp, i_resp) { - Ok(()) => { - // Merge successful. - // - // Now merge the remaining fields from the destructured - // event. - target.metadata.extend(metadata); - target.timestamp = timestamp; - - Ok(()) - } - Err(returned_resp) => { - // Merge failed (variant mismatch). - // - // Reconstruct the event using the returned response and - // original fields. 
- Err(ConversationEvent { - kind: EventKind::ChatResponse(returned_resp), - metadata, - timestamp, - }) - } - } - } - (EventKind::ToolCallRequest(t_tool), EventKind::ToolCallRequest(i_tool)) => { - merge_tool_calls(t_tool, i_tool); - target.metadata.extend(metadata); - target.timestamp = timestamp; - - Ok(()) - } - // Mismatch in high-level types (e.g. ChatResponse vs ToolCallRequest). - // Reconstruct the event and return it as an error. - (_, other_kind) => Err(ConversationEvent { - kind: other_kind, - metadata, - timestamp, - }), - } -} - -/// Merges two `ChatResponse` items. -/// -/// Returns `Err(incoming)` if they are different variants (Message vs -/// Reasoning). -fn merge_chat_responses( - target: &mut ChatResponse, - incoming: ChatResponse, -) -> Result<(), ChatResponse> { - match (target, incoming) { - (ChatResponse::Message { message: t_msg }, ChatResponse::Message { message: i_msg }) => { - t_msg.push_str(&i_msg); - Ok(()) - } - ( - ChatResponse::Reasoning { reasoning: t_reas }, - ChatResponse::Reasoning { reasoning: i_reas }, - ) => { - t_reas.push_str(&i_reas); - Ok(()) - } - - // Variants didn't match. - (_, incoming) => Err(incoming), - } -} - -/// Merges two `ToolCallRequest` items. 
-fn merge_tool_calls(target: &mut ToolCallRequest, incoming: ToolCallRequest) { - if target.id.is_empty() && !incoming.id.is_empty() { - target.id = incoming.id; - } - - if target.name.is_empty() && !incoming.name.is_empty() { - target.name = incoming.name; - } - - for (key, val) in incoming.arguments { - match target.arguments.get_mut(&key) { - Some(existing_val) => { - if let (Value::String(s1), Value::String(s2)) = (existing_val, &val) { - s1.push_str(s2); - } else { - // Overwrite non-string values - *target.arguments.entry(key).or_insert(Value::Null) = val; - } - } - None => { - target.arguments.insert(key, val); - } - } - } -} diff --git a/crates/jp_llm/src/stream/aggregator/reasoning.rs b/crates/jp_llm/src/stream/aggregator/reasoning.rs index 3e8aa68c..34b9beb5 100644 --- a/crates/jp_llm/src/stream/aggregator/reasoning.rs +++ b/crates/jp_llm/src/stream/aggregator/reasoning.rs @@ -1,3 +1,6 @@ +//! An extractor that segments a stream of text into 'reasoning' and 'other' +//! buckets. + #[derive(Default, Debug)] /// A parser that segments a stream of text into 'reasoning' and 'other' /// buckets. It handles streams with or without a `` block. diff --git a/crates/jp_llm/src/stream/chain.rs b/crates/jp_llm/src/stream/chain.rs index 5a9fe19d..8bce4cb1 100644 --- a/crates/jp_llm/src/stream/chain.rs +++ b/crates/jp_llm/src/stream/chain.rs @@ -293,7 +293,11 @@ fn event_text_len(event: &Event) -> usize { event .as_conversation_event() .and_then(ConversationEvent::as_chat_response) - .map_or(0, |v| v.content().len()) + .map_or(0, |v| match v { + ChatResponse::Message { message } => message.len(), + ChatResponse::Reasoning { reasoning } => reasoning.len(), + ChatResponse::Structured { .. } => 0, + }) } /// Reconstruct text from a deque of events. 
@@ -305,10 +309,16 @@ fn reconstruct_text(events: &VecDeque) -> (String, Vec<(usize, usize)>) { let mut map = Vec::new(); for (i, event) in events.iter().enumerate() { - if let Some(content) = event + let content = event .as_conversation_event() .and_then(ConversationEvent::as_chat_response) - .map(ChatResponse::content) + .and_then(|v| match v { + ChatResponse::Message { message } => Some(message.as_str()), + ChatResponse::Reasoning { reasoning } => Some(reasoning.as_str()), + ChatResponse::Structured { .. } => None, + }); + + if let Some(content) = content && !content.is_empty() { s.push_str(content); @@ -324,7 +334,11 @@ fn trim_event_start(event: &mut Event, count: usize) { let Some(content) = event .as_conversation_event_mut() .and_then(ConversationEvent::as_chat_response_mut) - .map(ChatResponse::content_mut) + .and_then(|v| match v { + ChatResponse::Message { message } => Some(message), + ChatResponse::Reasoning { reasoning } => Some(reasoning), + ChatResponse::Structured { .. } => None, + }) else { return; }; diff --git a/crates/jp_llm/src/structured.rs b/crates/jp_llm/src/structured.rs deleted file mode 100644 index 1645c1f0..00000000 --- a/crates/jp_llm/src/structured.rs +++ /dev/null @@ -1,31 +0,0 @@ -//! Tools for requesting structured data from LLMs using tool calls. - -pub mod titles; - -use jp_config::model::id::ModelIdConfig; -use serde::de::DeserializeOwned; - -use crate::{error::Result, provider::Provider, query::StructuredQuery}; - -// Name of the schema enforcement tool -pub(crate) const SCHEMA_TOOL_NAME: &str = "generate_structured_data"; - -/// Request structured data from the LLM for any type `T` that implements -/// [`DeserializeOwned`]. -/// -/// It assumes a [`StructuredQuery`] that has a schema to enforce the correct -/// sturcute for `T`. -/// -/// If a LLM model enforces a JSON object as the response, but you want (e.g.) 
a -/// list of items, you can use [`StructuredQuery::with_mapping`] to map the -/// response object into the final shape of `T`. -pub async fn completion( - provider: &dyn Provider, - model_id: &ModelIdConfig, - query: StructuredQuery, -) -> Result { - let model = provider.model_details(&model_id.name).await?; - let value = provider.structured_completion(&model, query).await?; - - serde_json::from_value(value).map_err(Into::into) -} diff --git a/crates/jp_llm/src/structured/titles.rs b/crates/jp_llm/src/structured/titles.rs deleted file mode 100644 index fd2bad28..00000000 --- a/crates/jp_llm/src/structured/titles.rs +++ /dev/null @@ -1,83 +0,0 @@ -use std::error::Error; - -use jp_config::assistant::instructions::InstructionsConfig; -use jp_conversation::{ConversationStream, thread::ThreadBuilder}; -use serde_json::Value; - -use crate::query::StructuredQuery; - -pub fn titles( - count: usize, - events: ConversationStream, - rejected: &[String], -) -> Result> { - let schema = schemars::json_schema!({ - "type": "object", - "description": format!("Provide {count} concise, descriptive factual titles for this conversation."), - "required": ["titles"], - "additionalProperties": false, - "properties": { - "titles": { - "type": "array", - "items": { - "type": "string", - "description": "A concise, descriptive title for the conversation" - }, - }, - }, - }); - - // The validator schema is more strict than the schema we use to generate - // the titles, because not all providers support the full JSON schema - // feature-set. 
- let validator = schemars::json_schema!({ - "type": "object", - "required": ["titles"], - "additionalProperties": false, - "properties": { - "titles": { - "type": "array", - "items": { - "type": "string", - "minLength": 1, - "maxLength": 50, - }, - "minItems": count, - "maxItems": count, - }, - }, - }); - - let mut instructions = vec![ - InstructionsConfig::default() - .with_title("Title Generation") - .with_description("Generate titles to summarize the active conversation") - .with_item(format!("Generate exactly {count} titles")) - .with_item("Concise, descriptive, factual") - .with_item("Short and to the point, no more than 50 characters") - .with_item("Deliver as a JSON array of strings") - .with_item("DO NOT mention this request to generate titles"), - ]; - - if !rejected.is_empty() { - let mut instruction = InstructionsConfig::default() - .with_title("Rejected Titles") - .with_description("These listed titles were rejected by the user and must be avoided"); - - for title in rejected { - instruction = instruction.with_item(title); - } - - instructions.push(instruction); - } - - let mapping = |value: &mut Value| value.get_mut("titles").map(Value::take); - let thread = ThreadBuilder::default() - .with_events(events) - .with_instructions(instructions) - .build()?; - - Ok(StructuredQuery::new(schema, thread) - .with_mapping(mapping) - .with_schema_validator(validator)) -} diff --git a/crates/jp_llm/src/test.rs b/crates/jp_llm/src/test.rs index 3d812646..04f26b8a 100644 --- a/crates/jp_llm/src/test.rs +++ b/crates/jp_llm/src/test.rs @@ -17,39 +17,30 @@ use jp_config::{ use jp_conversation::{ ConversationEvent, ConversationStream, event::{ChatRequest, ToolCallResponse}, + event_builder::EventBuilder, stream::ConversationEventWithConfig, thread::{Thread, ThreadBuilder}, }; use jp_test::mock::{Snap, Vcr}; -use schemars::Schema; use crate::{ event::Event, model::{ModelDetails, ReasoningDetails}, provider::get_provider, - query::{ChatQuery, StructuredQuery}, - 
stream::aggregator::chunk::EventAggregator, + query::ChatQuery, tool::ToolDefinition, }; +#[allow(clippy::large_enum_variant)] pub enum TestRequest { /// A chat completion request. Chat { - stream: bool, model: ModelDetails, query: ChatQuery, #[expect(clippy::type_complexity)] assert: Arc])>, }, - /// A structured completion request. - Structured { - model: ModelDetails, - query: StructuredQuery, - #[expect(clippy::type_complexity)] - assert: Arc])>, - }, - /// List all models. Models { #[expect(clippy::type_complexity)] @@ -83,7 +74,6 @@ impl TestRequest { pub fn chat(provider: ProviderId) -> Self { Self::Chat { - stream: false, model: test_model_details(provider), query: ChatQuery { thread: ThreadBuilder::new() @@ -101,35 +91,11 @@ impl TestRequest { .unwrap(), tools: vec![], tool_choice: ToolChoice::default(), - tool_call_strict_mode: false, }, assert: Arc::new(|_| {}), } } - #[expect(dead_code)] - pub fn structured(provider: ProviderId) -> Self { - Self::Structured { - model: test_model_details(provider), - query: StructuredQuery::new( - true.into(), - ThreadBuilder::new() - .with_events({ - let mut config = AppConfig::new_test(); - config.assistant.model.id = ModelIdOrAliasConfig::Id(ModelIdConfig { - provider, - name: "test".parse().unwrap(), - }); - ConversationStream::new(config.into()) - .with_created_at(Utc.with_ymd_and_hms(2020, 1, 1, 0, 0, 0).unwrap()) - }) - .build() - .unwrap(), - ), - assert: Arc::new(|_| {}), - } - } - pub fn tool_call_response(result: Result<&str, &str>, panic_on_missing_request: bool) -> Self { Self::ToolCallResponse { result: result.map(Into::into).map_err(Into::into), @@ -144,14 +110,6 @@ impl TestRequest { } } - pub fn stream(mut self, stream: bool) -> Self { - if let Self::Chat { stream: s, .. 
} = &mut self { - *s = stream; - } - - self - } - pub fn model(mut self, model: ModelIdConfig) -> Self { let Some(thread) = self.as_thread_mut() else { return self; @@ -161,9 +119,8 @@ impl TestRequest { delta.assistant.model.id = PartialModelIdOrAliasConfig::Id(model.to_partial()); thread.events.add_config_delta(delta); - match &mut self { - Self::Chat { model: m, .. } | Self::Structured { model: m, .. } => m.id = model, - _ => {} + if let Self::Chat { model: m, .. } = &mut self { + m.id = model; } self @@ -211,22 +168,6 @@ impl TestRequest { self } - #[expect(dead_code)] - pub fn schema(self, schema: impl Into) -> Self { - match self { - Self::Structured { - model, - query, - assert, - } => Self::Structured { - model, - query: StructuredQuery::new(schema.into(), query.thread), - assert, - }, - _ => self, - } - } - pub fn tool_choice_fn(self, name: impl Into) -> Self { self.tool_choice(ToolChoice::Function(name.into())) } @@ -244,21 +185,12 @@ impl TestRequest { .into_iter() .map(|(k, v)| (k.to_owned(), v)) .collect(), - include_tool_answers_parameter: false, }); } self } - pub fn tool_call_strict_mode(mut self, strict: bool) -> Self { - if let Self::Chat { query, .. } = &mut self { - query.tool_call_strict_mode = strict; - } - - self - } - #[expect(dead_code)] pub fn assert_chat(mut self, assert: impl Fn(&[Vec]) + 'static) -> Self { if let Self::Chat { assert: a, .. } = &mut self { @@ -268,18 +200,6 @@ impl TestRequest { self } - #[expect(dead_code)] - pub fn assert_structured( - mut self, - assert: impl Fn(&[Result]) + 'static, - ) -> Self { - if let Self::Structured { assert: a, .. } = &mut self { - *a = Arc::new(assert); - } - - self - } - #[expect(dead_code)] pub fn assert_models(mut self, assert: impl Fn(&[ModelDetails]) + 'static) -> Self { if let Self::Models { assert: a, .. } = &mut self { @@ -292,7 +212,6 @@ impl TestRequest { pub fn as_thread(&self) -> Option<&Thread> { match self { Self::Chat { query, .. 
} => Some(&query.thread), - Self::Structured { query, .. } => Some(&query.thread), _ => None, } } @@ -300,14 +219,13 @@ impl TestRequest { pub fn as_thread_mut(&mut self) -> Option<&mut Thread> { match self { Self::Chat { query, .. } => Some(&mut query.thread), - Self::Structured { query, .. } => Some(&mut query.thread), _ => None, } } pub fn as_model_details_mut(&mut self) -> Option<&mut ModelDetails> { match self { - Self::Chat { model, .. } | Self::Structured { model, .. } => Some(model), + Self::Chat { model, .. } => Some(model), _ => None, } } @@ -316,19 +234,8 @@ impl TestRequest { impl std::fmt::Debug for TestRequest { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Chat { - stream, - model, - query, - .. - } => f + Self::Chat { model, query, .. } => f .debug_struct("Chat") - .field("stream", stream) - .field("model", model) - .field("query", query) - .finish(), - Self::Structured { model, query, .. } => f - .debug_struct("Structured") .field("model", model) .field("query", query) .finish(), @@ -414,9 +321,6 @@ pub async fn run_chat_completion( } let provider = get_provider(provider_id, &config).unwrap(); - let has_structured_request = requests - .iter() - .any(|v| matches!(v, TestRequest::Structured { .. })); let has_chat_request = requests .iter() .any(|v| matches!(v, TestRequest::Chat { .. })); @@ -433,7 +337,6 @@ pub async fn run_chat_completion( let mut all_events = vec![]; let mut history = vec![]; - let mut structured_history = vec![]; let mut model_details = vec![]; let mut models = vec![]; @@ -491,9 +394,6 @@ pub async fn run_chat_completion( TestRequest::Chat { query, .. } => { query.thread.events.config().unwrap().to_partial() } - TestRequest::Structured { query, .. } => { - query.thread.events.config().unwrap().to_partial() - } TestRequest::Models { .. } | TestRequest::ModelDetails { .. } => { PartialAppConfig::empty() } @@ -529,57 +429,75 @@ pub async fn run_chat_completion( // 3. 
Then we run the query, and collect the new events. match request { TestRequest::Chat { - stream, model, query, assert, } => { - let mut agg = EventAggregator::new(); - let events = if stream { - provider - .chat_completion_stream(&model, query) - .await - .unwrap() - .try_collect() - .await - .unwrap() - } else { - provider.chat_completion(&model, query).await.unwrap() - }; - - for mut event in events { - if let Event::Part { event, .. } = &mut event { - event.timestamp = - Utc.with_ymd_and_hms(2020, 1, 1, 0, 0, 0).unwrap(); - } - - all_events[index].push(event.clone()); - - for event in agg.ingest(event) { - if let Event::Part { event, .. } = event { - if let Some(stream) = conversation_stream.as_mut() { - stream.push(event.clone()); + let events: Vec = provider + .chat_completion_stream(&model, query) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + let stream = conversation_stream + .as_mut() + .expect("Chat request always sets conversation_stream"); + let mut builder = EventBuilder::new(); + + for llm_event in events { + match llm_event { + Event::Part { index: idx, event } => { + builder.handle_part(idx, event); + } + Event::Flush { + index: idx, + metadata, + } => { + if let Some(mut event) = builder.handle_flush(idx, metadata) { + // Normalize timestamp for deterministic snapshots. 
+ event.timestamp = + Utc.with_ymd_and_hms(2020, 1, 1, 0, 0, 0).unwrap(); + + all_events[index].push(Event::Part { + index: idx, + event: event.clone(), + }); + all_events[index].push(Event::flush(idx)); + + history.push(ConversationEventWithConfig { + event: event.clone(), + config: config.clone(), + }); + stream.push(event); + } + } + Event::Finished(reason) => { + for mut event in builder.drain() { + event.timestamp = + Utc.with_ymd_and_hms(2020, 1, 1, 0, 0, 0).unwrap(); + + all_events[index].push(Event::Part { + index: 0, + event: event.clone(), + }); + all_events[index].push(Event::flush(0)); + + history.push(ConversationEventWithConfig { + event: event.clone(), + config: config.clone(), + }); + stream.push(event); } - history.push(ConversationEventWithConfig { - event, - config: config.clone(), - }); + all_events[index].push(Event::Finished(reason)); } } } assert(&all_events); } - TestRequest::Structured { - model, - query, - assert, - } => { - let value = provider.structured_completion(&model, query).await; - structured_history.push(value); - assert(&structured_history); - } TestRequest::Models { assert } => { let value = provider.models().await.unwrap(); models.extend(value); @@ -605,18 +523,6 @@ pub async fn run_chat_completion( ]); } - if has_structured_request { - let out = structured_history - .into_iter() - .map(|v| match v { - Ok(value) => value, - Err(error) => format!("Error::{error:?}").into(), - }) - .collect::>(); - - outputs.push(("structured_outputs", Snap::json(out))); - } - if has_model_details_request { outputs.push(("model_details", Snap::debug(model_details))); } @@ -650,7 +556,7 @@ pub(crate) fn test_model_details(id: ProviderId) -> ModelDetails { display_name: None, context_window: Some(200_000), max_output_tokens: Some(64_000), - reasoning: Some(ReasoningDetails::budgetted(128, Some(24576))), + reasoning: Some(ReasoningDetails::budgetted(512, Some(24576))), knowledge_cutoff: None, deprecated: None, features: vec![], @@ -695,6 +601,7 @@ 
pub(crate) fn test_model_details(id: ProviderId) -> ModelDetails { deprecated: None, features: vec![], }, + ProviderId::Test => ModelDetails::empty("test/mock-model".parse().unwrap()), ProviderId::Xai => unimplemented!(), ProviderId::Deepseek => unimplemented!(), } diff --git a/crates/jp_llm/src/title.rs b/crates/jp_llm/src/title.rs new file mode 100644 index 00000000..6eced3f8 --- /dev/null +++ b/crates/jp_llm/src/title.rs @@ -0,0 +1,97 @@ +//! Shared title generation helpers. +//! +//! Provides the JSON schema and instructions used by background callers (e.g. +//! `TitleGeneratorTask`, `conversation edit --title`) to request conversation +//! titles from an LLM via structured output. + +use jp_config::assistant::{instructions::InstructionsConfig, sections::SectionConfig}; +use serde_json::{Map, Value, json}; + +/// JSON schema for the title generation structured output. +/// +/// Returns a schema requiring an object with a `titles` array of exactly +/// `count` string elements. +#[must_use] +#[allow(clippy::missing_panics_doc)] +pub fn title_schema(count: usize) -> Map { + let schema = json!({ + "type": "object", + "required": ["titles"], + "additionalProperties": false, + "properties": { + "titles": { + "type": "array", + "items": { + "type": "string", + "description": "A concise, descriptive title for the conversation" + }, + "minItems": count, + "maxItems": count, + }, + }, + }); + + schema + .as_object() + .expect("schema is always an object") + .clone() +} + +/// Build instruction sections for title generation. +/// +/// Returns one or two sections: the main generation instructions, and +/// optionally a "rejected titles" section if `rejected` is non-empty. 
+#[must_use] +pub fn title_instructions(count: usize, rejected: &[String]) -> Vec { + let mut sections = vec![ + InstructionsConfig::default() + .with_title("Title Generation") + .with_description("Generate titles to summarize the active conversation") + .with_item(format!("Generate exactly {count} titles")) + .with_item("Concise, descriptive, factual") + .with_item("Short and to the point, no more than 50 characters") + .with_item("Deliver as a JSON object with a \"titles\" array of strings") + .with_item("DO NOT mention this request to generate titles") + .to_section(), + ]; + + if !rejected.is_empty() { + let mut rejected_instruction = InstructionsConfig::default() + .with_title("Rejected Titles") + .with_description("These listed titles were rejected by the user and must be avoided"); + + for title in rejected { + rejected_instruction = rejected_instruction.with_item(title); + } + + sections.push(rejected_instruction.to_section()); + } + + sections +} + +/// Extract title strings from a structured JSON response. +/// +/// Expects a JSON object with a `titles` array of strings, e.g.: +/// +/// ```json +/// {"titles": ["My Title", "Another Title"]} +/// ``` +/// +/// Returns an empty vec if the structure doesn't match. 
+#[must_use] +pub fn extract_titles(data: &Value) -> Vec { + data.get("titles") + .and_then(Value::as_array) + .map(|arr| { + arr.iter() + .filter_map(Value::as_str) + .map(str::to_owned) + .collect() + }) + .unwrap_or_default() +} + +#[cfg(test)] +#[path = "title_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/title_tests.rs b/crates/jp_llm/src/title_tests.rs new file mode 100644 index 00000000..034709d1 --- /dev/null +++ b/crates/jp_llm/src/title_tests.rs @@ -0,0 +1,62 @@ +use super::*; + +#[test] +fn title_schema_has_correct_structure() { + let schema = title_schema(3); + + assert_eq!(schema["type"], "object"); + assert_eq!(schema["properties"]["titles"]["type"], "array"); + assert_eq!(schema["properties"]["titles"]["minItems"], 3); + assert_eq!(schema["properties"]["titles"]["maxItems"], 3); + assert!( + schema["required"] + .as_array() + .unwrap() + .contains(&json!("titles")) + ); + assert_eq!(schema["additionalProperties"], false); +} + +#[test] +fn title_schema_single_title() { + let schema = title_schema(1); + assert_eq!(schema["properties"]["titles"]["minItems"], 1); + assert_eq!(schema["properties"]["titles"]["maxItems"], 1); +} + +#[test] +fn title_instructions_without_rejected() { + let sections = title_instructions(3, &[]); + assert_eq!(sections.len(), 1); +} + +#[test] +fn title_instructions_with_rejected() { + let rejected = vec!["Bad Title".to_owned(), "Worse Title".to_owned()]; + let sections = title_instructions(3, &rejected); + assert_eq!(sections.len(), 2); +} + +#[test] +fn extract_titles_valid() { + let data = json!({"titles": ["Title A", "Title B"]}); + assert_eq!(extract_titles(&data), vec!["Title A", "Title B"]); +} + +#[test] +fn extract_titles_missing_key() { + let data = json!({"other": "value"}); + assert!(extract_titles(&data).is_empty()); +} + +#[test] +fn extract_titles_wrong_type() { + let data = json!({"titles": "not an array"}); + assert!(extract_titles(&data).is_empty()); +} + +#[test] +fn 
extract_titles_mixed_types_filters_non_strings() { + let data = json!({"titles": ["Valid", 42, null, "Also Valid"]}); + assert_eq!(extract_titles(&data), vec!["Valid", "Also Valid"]); +} diff --git a/crates/jp_llm/src/tool.rs b/crates/jp_llm/src/tool.rs index 8b98bb16..9e541542 100644 --- a/crates/jp_llm/src/tool.rs +++ b/crates/jp_llm/src/tool.rs @@ -1,26 +1,419 @@ -use std::{fmt::Write, sync::Arc}; +//! Tool call utilities. +pub mod builtin; +pub mod executor; + +use std::{ffi::OsStr, process::Stdio, sync::Arc}; + +pub use builtin::BuiltinTool; use camino::Utf8Path; -use crossterm::style::Stylize as _; -use indexmap::{IndexMap, IndexSet}; +use indexmap::IndexMap; use jp_config::conversation::tool::{ - OneOrManyTypes, QuestionConfig, QuestionTarget, ResultMode, RunMode, ToolCommandConfig, - ToolConfigWithDefaults, ToolParameterConfig, ToolSource, item::ToolParameterItemConfig, + OneOrManyTypes, ToolCommandConfig, ToolConfigWithDefaults, ToolParameterConfig, ToolSource, }; use jp_conversation::event::ToolCallResponse; -use jp_inquire::{InlineOption, InlineSelect}; use jp_mcp::{ RawContent, ResourceContents, id::{McpServerId, McpToolId}, }; -use jp_printer::PrinterWriter; -use jp_tool::{Action, Outcome}; +use jp_tool::{Action, Outcome, Question}; use minijinja::Environment; use serde_json::{Map, Value, json}; -use tracing::{error, info, trace}; +use tokio::process::Command; +use tokio_util::sync::CancellationToken; +use tracing::{info, trace}; use crate::error::ToolError; +/// Documentation for a single tool parameter. +#[derive(Debug, Clone)] +pub struct ParameterDocs { + pub summary: Option, + pub description: Option, + pub examples: Option, +} + +impl ParameterDocs { + #[must_use] + pub fn is_empty(&self) -> bool { + self.description.is_none() && self.examples.is_none() + } +} + +/// Documentation for a single tool. 
+#[derive(Debug, Clone)] +pub struct ToolDocs { + pub summary: Option, + pub description: Option, + pub examples: Option, + pub parameters: IndexMap, +} + +impl ToolDocs { + #[must_use] + pub fn is_empty(&self) -> bool { + self.description.is_none() + && self.examples.is_none() + && self.parameters.values().all(ParameterDocs::is_empty) + } +} + +/// The outcome of a tool execution. +/// +/// This type represents the possible results of executing a tool's underlying +/// command or MCP call, without any interactive prompts. The caller is +/// responsible for: +/// +/// 1. Handling permission prompts **before** calling +/// [`ToolDefinition::execute()`]. +/// 2. Handling [`ExecutionOutcome::NeedsInput`] by prompting the user or +/// assistant. +/// 3. Handling result editing **after** receiving the outcome. +/// +/// # Example Flow +/// +/// ```text +/// ToolExecutor (jp_cli) ToolDefinition (jp_llm) +/// ───────────────────── ────────────────────── +/// │ +/// ├── [AwaitingPermission] +/// │ prompt_permission() +/// │ +/// ├── [Running] +/// │ ────────────────────────────► execute() +/// │ │ +/// │ ◄──────────────────────────── ExecutionOutcome +/// ├── [AwaitingInput] (if NeedsInput) +/// │ prompt_question() +/// │ ────────────────────────────► execute() (with answer) +/// │ │ +/// │ ◄──────────────────────────── ExecutionOutcome +/// ├── [AwaitingResultEdit] +/// │ prompt_result_edit() +/// │ +/// └── [Completed] +/// ``` +#[derive(Debug)] +pub enum ExecutionOutcome { + /// Tool executed and produced a result. + Completed { + /// The tool call ID (for correlation with the request). + id: String, + + /// The execution result. + /// + /// If an error occurred, it means the tool ran, but reported an error. + result: Result, + }, + + /// Tool needs additional input before it can complete. + /// + /// The caller should: + /// 1. Present the question to the user (or delegate to the assistant) + /// 2. Collect the answer + /// 3. 
Call [`ToolDefinition::execute()`] again with the answer in `answers` + NeedsInput { + /// The tool call ID. + id: String, + + /// The question to ask. + question: Question, + }, + + /// Tool execution was cancelled via the cancellation token. + /// + /// This occurs when the user interrupts tool execution (e.g., Ctrl+C during + /// a long-running command). + Cancelled { + /// The tool call ID. + id: String, + }, +} + +impl ExecutionOutcome { + /// Convert the outcome to a [`ToolCallResponse`]. + /// + /// This is useful for building the final response to send to the LLM after + /// any post-processing (e.g., result editing) is complete. + /// + /// # Note + /// + /// For [`ExecutionOutcome::NeedsInput`], this returns a placeholder + /// response. The caller should typically handle `NeedsInput` specially + /// rather than converting it directly to a response. + #[must_use] + pub fn into_response(self) -> ToolCallResponse { + match self { + Self::Completed { id, result } => ToolCallResponse { id, result }, + Self::NeedsInput { id, question } => ToolCallResponse { + id, + result: Ok(format!("Tool requires additional input: {}", question.text)), + }, + Self::Cancelled { id } => ToolCallResponse { + id, + result: Ok("Tool execution cancelled by user.".to_string()), + }, + } + } + + /// Returns the tool call ID. + #[must_use] + pub fn id(&self) -> &str { + match self { + Self::Completed { id, .. } | Self::NeedsInput { id, .. } | Self::Cancelled { id } => id, + } + } + + /// Returns `true` if this is a `NeedsInput` outcome. + #[must_use] + pub fn needs_input(&self) -> bool { + matches!(self, Self::NeedsInput { .. }) + } + + /// Returns `true` if this is a `Cancelled` outcome. + #[must_use] + pub fn is_cancelled(&self) -> bool { + matches!(self, Self::Cancelled { .. }) + } + + /// Returns `true` if this is a `Completed` outcome with a successful result. + #[must_use] + pub fn is_success(&self) -> bool { + matches!(self, Self::Completed { result: Ok(_), .. 
}) + } +} + +/// Result of running a tool command. +/// +/// This is the single parsing point for all tool command output. Both tool +/// execution and argument formatting go through this type, ensuring consistent +/// handling of `Outcome` variants (including error traces). +#[derive(Debug)] +pub enum CommandResult { + /// Tool produced content. + Success(String), + + /// Tool reported a transient error (can be retried). + TransientError { + /// The error message. + message: String, + + /// The error trace (source chain from the tool process). + trace: Vec, + }, + + /// Tool reported a fatal error. + FatalError(String), + + /// Tool needs additional input before it can continue. + NeedsInput(Question), + + /// Tool was cancelled via the cancellation token. + Cancelled, + + /// stdout wasn't valid `Outcome` JSON. + /// + /// Falls back to treating stdout as plain text. The `success` flag + /// indicates the process exit status. + RawOutput { + /// Raw stdout content. + stdout: String, + + /// Raw stderr content. + stderr: String, + + /// Whether the process exited successfully. + success: bool, + }, +} + +impl CommandResult { + /// Format a transient error message including trace details. + /// + /// If the trace is empty, returns just the message. Otherwise appends + /// the trace entries so the LLM (or user) can see the root cause. + #[must_use] + pub fn format_error(message: &str, trace: &[String]) -> String { + if trace.is_empty() { + message.to_owned() + } else { + format!("{message}\n\nTrace:\n{}", trace.join("\n")) + } + } + + /// Convert to a `Result` suitable for tool call responses. 
+ /// + /// - `Success` → `Ok(content)` + /// - `TransientError` → `Err(json with message + trace)` + /// - `FatalError` → `Err(raw json)` + /// - `NeedsInput` → handled separately by callers (this panics) + /// - `Cancelled` → `Ok(cancellation message)` + /// - `RawOutput` → `Ok(stdout)` if success, `Err(json)` if failure + pub fn into_tool_result(self, name: &str) -> Result { + match self { + Self::Success(content) => Ok(content), + Self::TransientError { message, trace } => Err(json!({ + "message": message, + "trace": trace, + }) + .to_string()), + Self::FatalError(raw) => Err(raw), + Self::Cancelled => Ok("Tool execution cancelled by user.".to_string()), + Self::RawOutput { + stdout, + stderr, + success, + } => { + if success { + Ok(stdout) + } else { + Err(json!({ + "message": format!("Tool '{name}' execution failed."), + "stderr": stderr, + "stdout": stdout, + }) + .to_string()) + } + } + Self::NeedsInput(_) => { + unreachable!("NeedsInput should be handled by the caller") + } + } + } +} + +/// Run a tool command asynchronously with cancellation support. +/// +/// This is the **single entry point** for running tool commands (both execution +/// and argument formatting). It handles: +/// +/// 1. Template rendering via [`minijinja`] +/// 2. Process spawning via Tokio's [`Command`] +/// 3. Cancellation via [`CancellationToken`] +/// 4. 
Parsing stdout as [`jp_tool::Outcome`] +pub async fn run_tool_command( + command: ToolCommandConfig, + ctx: Value, + root: &Utf8Path, + cancellation_token: CancellationToken, +) -> Result { + let ToolCommandConfig { + program, + args, + shell, + } = command; + + let tmpl = Arc::new(Environment::new()); + + let program = tmpl + .render_str(&program, &ctx) + .map_err(|error| ToolError::TemplateError { + data: program.clone(), + error, + })?; + + let args = args + .iter() + .map(|s| tmpl.render_str(s, &ctx)) + .collect::, _>>() + .map_err(|error| ToolError::TemplateError { + data: args.join(" "), + error, + })?; + + let mut cmd = if shell { + let shell_cmd = std::iter::once(program.clone()) + .chain(args.iter().cloned()) + .collect::>() + .join(" "); + + let mut cmd = Command::new("sh"); + cmd.arg("-c").arg(&shell_cmd); + cmd + } else { + let mut cmd = Command::new(&program); + cmd.args(&args); + cmd + }; + + let child = cmd + .current_dir(root.as_std_path()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .map_err(|error| ToolError::SpawnError { + command: format!( + "{} {}", + cmd.as_std().get_program().to_string_lossy(), + cmd.as_std() + .get_args() + .filter_map(OsStr::to_str) + .collect::>() + .join(" ") + ), + error, + })?; + + let wait_handle = tokio::spawn(async move { child.wait_with_output().await }); + let abort_handle = wait_handle.abort_handle(); + + tokio::select! 
{ + biased; + () = cancellation_token.cancelled() => { + abort_handle.abort(); + Ok(CommandResult::Cancelled) + } + result = wait_handle => { + match result { + Ok(Ok(output)) => Ok(parse_command_output( + &output.stdout, + &output.stderr, + output.status.success(), + )), + Ok(Err(error)) => Ok(CommandResult::RawOutput { + stdout: String::new(), + stderr: error.to_string(), + success: false, + }), + Err(join_error) => Ok(CommandResult::RawOutput { + stdout: String::new(), + stderr: format!("Task panicked: {join_error}"), + success: false, + }), + } + } + } +} + +/// Parse raw command output into a [`CommandResult`]. +/// +/// Tries to deserialize stdout as [`jp_tool::Outcome`]. If that fails, +/// falls back to [`CommandResult::RawOutput`]. +fn parse_command_output(stdout: &[u8], stderr: &[u8], success: bool) -> CommandResult { + let stdout_str = String::from_utf8_lossy(stdout); + + match serde_json::from_str::(&stdout_str) { + Ok(Outcome::Success { content }) => CommandResult::Success(content), + Ok(Outcome::Error { + transient, + message, + trace, + }) => { + if transient { + CommandResult::TransientError { message, trace } + } else { + CommandResult::FatalError(stdout_str.into_owned()) + } + } + Ok(Outcome::NeedsInput { question }) => CommandResult::NeedsInput(question), + Err(_) => CommandResult::RawOutput { + stdout: stdout_str.into_owned(), + stderr: String::from_utf8_lossy(stderr).into_owned(), + success, + }, + } +} + /// The definition of a tool. /// /// The definition source is either a [`ToolConfig`] for `local` tools, or a @@ -33,13 +426,6 @@ pub struct ToolDefinition { pub name: String, pub description: Option, pub parameters: IndexMap, - - /// Whether the tool should include the `tool_answers` parameter. - /// - /// This is `true` for any tool that has its questions configured such that - /// at least one question has to be answered by the assistant instead of the - /// user. 
- pub include_tool_answers_parameter: bool, } impl ToolDefinition { @@ -48,15 +434,13 @@ impl ToolDefinition { source: &ToolSource, description: Option, parameters: IndexMap, - questions: &IndexMap, mcp_client: &jp_mcp::Client, ) -> Result { match &source { - ToolSource::Local { .. } => Ok(local_tool_definition( + ToolSource::Local { .. } | ToolSource::Builtin { .. } => Ok(local_tool_definition( name.to_owned(), description, parameters, - questions, )), ToolSource::Mcp { server, tool } => { mcp_tool_definition( @@ -69,126 +453,139 @@ impl ToolDefinition { ) .await } - ToolSource::Builtin { .. } => todo!(), } } - pub fn format_args( - &self, - name: Option<&str>, - cmd: &ToolCommandConfig, - arguments: &Map, - root: &Utf8Path, - ) -> Result, ToolError> { - let name = name.unwrap_or(&self.name); - if arguments.is_empty() { - return Ok(Ok(String::new())); - } - - let ctx = json!({ - "tool": { - "name": self.name, - "arguments": arguments, - }, - "context": { - "action": Action::FormatArguments, - "root": root.as_str(), - }, - }); - - run_cmd_with_ctx(name, cmd, &ctx, root) - } - - pub async fn call( + /// Execute the tool without any interactive prompts. + /// + /// This is a pure execution method that runs the tool's underlying command + /// or MCP call and returns an [`ExecutionOutcome`]. All interactive + /// decisions (permission prompts, result editing, question handling) are + /// the caller's responsibility. 
+ /// + /// # Arguments + /// + /// * `id` - The tool call ID for correlation with the request + /// * `arguments` - The tool arguments (caller is responsible for any pre-processing) + /// * `answers` - Pre-provided answers to tool questions (from previous `NeedsInput`) + /// * `config` - Tool configuration + /// * `mcp_client` - MCP client for MCP tool execution + /// * `root` - Working directory for local tool execution + /// * `cancellation_token` - Token to cancel long-running execution + /// * `builtin_executors` - Registry of builtin tools + /// + /// # Returns + /// + /// - [`ExecutionOutcome::Completed`] - Tool finished (check inner `Result` for success/error) + /// - [`ExecutionOutcome::NeedsInput`] - Tool needs user input to continue + /// - [`ExecutionOutcome::Cancelled`] - Execution was cancelled via the token + /// + /// # Errors + /// + /// Returns [`ToolError`] for infrastructure errors (spawn failure, missing + /// command, etc.). Tool-level errors (command returned non-zero) are + /// returned as `Ok(ExecutionOutcome::Completed { result: Err(...) })`. + /// + /// # Example + /// + /// ```ignore + /// loop { + /// match definition.execute(id, &args, &answers, ...).await? { + /// ExecutionOutcome::Completed { result, .. } => { + /// // Handle success or tool error + /// break result; + /// } + /// ExecutionOutcome::NeedsInput { question, .. } => { + /// // Prompt user for input + /// let answer = prompt_user(&question)?; + /// answers.insert(question.id, answer); + /// // Loop to retry with answer + /// } + /// ExecutionOutcome::Cancelled { .. 
} => { + /// break Ok("Cancelled".into()); + /// } + /// } + /// } + /// ``` + pub async fn execute( &self, id: String, - mut arguments: Value, + arguments: Value, answers: &IndexMap, - pending_questions: &IndexSet, + config: &ToolConfigWithDefaults, mcp_client: &jp_mcp::Client, - config: ToolConfigWithDefaults, root: &Utf8Path, - editor: Option<&Utf8Path>, - writer: PrinterWriter<'_>, - ) -> Result { - info!(tool = %self.name, arguments = ?arguments, "Calling tool."); - - // If the tool call has answers to provide to the tool, it means the - // tool already ran once, and we should not ask for confirmation again. - let run_mode = if pending_questions.is_empty() { - config.run() - } else { - RunMode::Unattended - }; + cancellation_token: CancellationToken, + builtin_executors: &builtin::BuiltinExecutors, + ) -> Result { + info!(tool = %self.name, arguments = ?arguments, "Executing tool."); - let mut result_mode = config.result(); - if answers.is_empty() { - self.prepare_run( - run_mode, - &mut result_mode, - &mut arguments, - config.source(), - mcp_client, - editor, - writer, - ) - .await?; - } - - let result = match config.source() { + match config.source() { ToolSource::Local { tool } => { - self.call_local(id, &arguments, answers, &config, tool.as_deref(), root)? + self.execute_local( + id, + arguments, + answers, + config, + tool.as_deref(), + root, + cancellation_token, + ) + .await } ToolSource::Mcp { server, tool } => { - self.call_mcp( + self.execute_mcp( id, - &arguments, + arguments, mcp_client, server.as_deref(), tool.as_deref(), + cancellation_token, ) - .await? + .await } - ToolSource::Builtin { .. } => todo!(), - }; - - trace!(result = ?result, "Tool call completed."); - self.prepare_result(result, result_mode, editor, writer) + ToolSource::Builtin { .. } => { + self.execute_builtin(id, &arguments, answers, builtin_executors) + .await + } + } } - fn call_local( + /// Execute a local tool and return the outcome. 
+ /// + /// This is the pure execution path for local tools. It validates arguments, + /// runs the command, and converts the result to an `ExecutionOutcome`. + async fn execute_local( &self, id: String, - arguments: &Value, + mut arguments: Value, answers: &IndexMap, config: &ToolConfigWithDefaults, tool: Option<&str>, root: &Utf8Path, - ) -> Result { + cancellation_token: CancellationToken, + ) -> Result { let name = tool.unwrap_or(&self.name); - // TODO: Should we enforce at a type-level this for all tool calls, even - // MCP? - if let Some(args) = arguments.as_object() - && let Err(error) = validate_tool_arguments( - args, - &config - .parameters() - .iter() - .map(|(k, v)| (k.to_owned(), v.required)) - .collect(), - ) - { - return Ok(ToolCallResponse { - id, - result: Err(format!("Invalid arguments: {error}")), - }); + // Apply configured defaults for missing parameters, then validate. + if let Some(args) = arguments.as_object_mut() { + apply_parameter_defaults(args, config.parameters()); + + if let Err(error) = validate_tool_arguments(args, config.parameters()) { + return Ok(ExecutionOutcome::Completed { + id, + result: Err(format!( + "Invalid arguments: {error}\n\nYou can call `describe_tools(tools: \ + [\"{name}\"])` to learn more about how to use the tool correctly." + )), + }); + } } let ctx = json!({ "tool": { "name": name, - "arguments": arguments, + "arguments": &arguments, "answers": answers, }, "context": { @@ -201,73 +598,123 @@ impl ToolDefinition { return Err(ToolError::MissingCommand); }; - Ok(ToolCallResponse { - id, - result: run_cmd_with_ctx(name, &command, &ctx, root)?, - }) + match run_tool_command(command, ctx, root, cancellation_token).await? 
{ + CommandResult::Success(content) => Ok(ExecutionOutcome::Completed { + id, + result: Ok(content), + }), + CommandResult::NeedsInput(question) => { + Ok(ExecutionOutcome::NeedsInput { id, question }) + } + CommandResult::Cancelled => Ok(ExecutionOutcome::Cancelled { id }), + other => Ok(ExecutionOutcome::Completed { + id, + result: other.into_tool_result(name), + }), + } } - async fn call_mcp( + /// Execute an MCP tool and return the outcome. + /// + /// This is the pure execution path for MCP tools. It calls the MCP server + /// and converts the result to an `ExecutionOutcome`. + async fn execute_mcp( &self, id: String, - arguments: &Value, + arguments: Value, mcp_client: &jp_mcp::Client, server: Option<&str>, tool: Option<&str>, - ) -> Result { + cancellation_token: CancellationToken, + ) -> Result { let name = tool.unwrap_or(&self.name); - let result = mcp_client - .call_tool(name, server, arguments) - .await - .map_err(ToolError::McpRunToolError)?; + let call_future = mcp_client.call_tool(name, server, &arguments); - let content = result - .content - .into_iter() - .filter_map(|v| match v.raw { - RawContent::Text(v) => Some(v.text), - RawContent::Resource(v) => match v.resource { - ResourceContents::TextResourceContents { text, .. } => Some(text), - ResourceContents::BlobResourceContents { blob, .. } => Some(blob), - }, - RawContent::Image(_) | RawContent::Audio(_) => None, - }) - .collect::>() - .join("\n\n"); + tokio::select! { + biased; + () = cancellation_token.cancelled() => { + info!(tool = %self.name, "MCP tool call cancelled"); + Ok(ExecutionOutcome::Cancelled { id }) + } + result = call_future => { + let result = result.map_err(ToolError::McpRunToolError)?; + + let content = result + .content + .into_iter() + .filter_map(|v| match v.raw { + RawContent::Text(v) => Some(v.text), + RawContent::Resource(v) => match v.resource { + ResourceContents::TextResourceContents { text, .. } => Some(text), + ResourceContents::BlobResourceContents { blob, .. 
} => Some(blob), + }, + RawContent::Image(_) | RawContent::Audio(_) => None, + }) + .collect::>() + .join("\n\n"); - Ok(ToolCallResponse { - id, - result: if result.is_error.unwrap_or_default() { - Err(content) - } else { - Ok(content) + let result = if result.is_error.unwrap_or_default() { + Err(content) + } else { + Ok(content) + }; + + Ok(ExecutionOutcome::Completed { id, result }) + } + } + } + + /// Execute a builtin tool and return the outcome. + async fn execute_builtin( + &self, + id: String, + arguments: &Value, + answers: &IndexMap, + builtin_executors: &builtin::BuiltinExecutors, + ) -> Result { + let executor = builtin_executors + .get(&self.name) + .ok_or_else(|| ToolError::NotFound { + name: self.name.clone(), + })?; + + let outcome = executor.execute(arguments, answers).await; + + Ok(match outcome { + jp_tool::Outcome::Success { content } => ExecutionOutcome::Completed { + id, + result: Ok(content), }, + jp_tool::Outcome::Error { + message, + trace, + transient: _, + } => { + let error_msg = if trace.is_empty() { + message + } else { + format!("{message}\n\nTrace:\n{}", trace.join("\n")) + }; + ExecutionOutcome::Completed { + id, + result: Err(error_msg), + } + } + jp_tool::Outcome::NeedsInput { question } => { + ExecutionOutcome::NeedsInput { id, question } + } }) } /// Return a map of parameter names to JSON schemas. #[must_use] pub fn to_parameters_map(&self) -> Map { - let mut map = self - .parameters + self.parameters .clone() .into_iter() .map(|(k, v)| (k, v.to_json_schema())) - .collect::>(); - - if self.include_tool_answers_parameter { - map.insert( - "tool_answers".to_owned(), - json!({ - "type": ["object", "null"], - "additionalProperties": true, - "description": "Answers to the tool's questions. This should only be used if explicitly requested by the user.", - }), - ); - } - - map + .collect() } /// Return a JSON schema for the parameters of the tool. 
@@ -287,421 +734,117 @@ impl ToolDefinition { "required": required, }) } +} - #[expect(clippy::too_many_lines)] - async fn prepare_run( - &self, - run_mode: RunMode, - result_mode: &mut ResultMode, - arguments: &mut Value, - source: &ToolSource, - mcp_client: &jp_mcp::Client, - editor: Option<&Utf8Path>, - mut writer: PrinterWriter<'_>, - ) -> Result<(), ToolError> { - match run_mode { - RunMode::Ask => match InlineSelect::new( - { - let mut question = format!( - "Run {} {} tool", - match source { - ToolSource::Builtin { .. } => "built-in", - ToolSource::Local { .. } => "local", - ToolSource::Mcp { .. } => "mcp", - }, - self.name.as_str().bold().yellow(), - ); - - if let ToolSource::Mcp { server, tool } = source { - let tool = McpToolId::new(tool.as_ref().unwrap_or(&self.name)); - let server = server.as_ref().map(|s| McpServerId::new(s.clone())); - let server_id = mcp_client - .get_tool_server_id(&tool, server.as_ref()) - .await - .map_err(ToolError::McpGetToolError)?; - - question = format!( - "{} from {} server?", - question, - server_id.as_str().bold().blue() - ); - } +/// Split a description string into a short summary and remaining detail. +/// +/// If the text is short (single line, ≤120 chars), it is returned as the +/// summary with no remaining description. +/// +/// Otherwise, the first sentence is extracted as the summary. A sentence +/// ends at `. ` or `.\n`. The remainder becomes the description. +pub(crate) fn split_description(text: &str) -> (String, Option) { + let text = text.trim(); + + // Find the first sentence boundary. + // Look for ". " or ".\n" — a period followed by whitespace. + for (i, _) in text.match_indices('.') { + let after = i + 1; + if after >= text.len() { + // Period at end of string — the whole text is one sentence. 
+ break; + } - question - }, - vec![ - InlineOption::new('y', "Run tool"), - InlineOption::new('n', "Skip running tool"), - InlineOption::new( - 'r', - format!( - "Change run mode (current: {})", - run_mode.to_string().italic().yellow() - ), - ), - InlineOption::new( - 'x', - format!( - "Change result mode (current: {})", - result_mode.to_string().italic().yellow() - ), - ), - InlineOption::new('p', "Print raw tool arguments"), - ], - ) - .prompt(&mut writer) - .unwrap_or('n') + let next_byte = text.as_bytes()[after]; + if next_byte == b'\n' { + // Period followed by newline is always a sentence boundary. + } else if next_byte == b' ' { + // Period followed by space: only split if the next non-space + // character is uppercase (heuristic to skip abbreviations + // like "e.g. foo"). + let rest_after_space = text[after..].trim_start(); + if rest_after_space.is_empty() + || !rest_after_space + .chars() + .next() + .is_some_and(char::is_uppercase) { - 'y' => return Ok(()), - 'n' => return Err(ToolError::Skipped { reason: None }), - 'r' => { - let new_run_mode = match InlineSelect::new("Run Mode", { - let mut options = vec![ - InlineOption::new('a', "Ask"), - InlineOption::new('u', "Unattended (Run Tool Without Changes)"), - InlineOption::new('e', "Edit Arguments"), - InlineOption::new('s', "Skip Call"), - ]; - - if editor.is_some() { - options.push(InlineOption::new('S', "Skip Call, with reasoning")); - } - - options.push(InlineOption::new('c', "Keep Current Run Mode")); - options - }) - .prompt(&mut writer) - .unwrap_or('c') - { - 'a' => RunMode::Ask, - 'u' => RunMode::Unattended, - 'e' => RunMode::Edit, - 's' => RunMode::Skip, - 'S' => match editor { - None => RunMode::Skip, - Some(editor) => { - return Err(ToolError::Skipped { - reason: Some( - open_editor::EditorCallBuilder::new() - .with_editor(open_editor::Editor::from_bin_path( - editor.into(), - )) - .edit_string( - "_Provide reasoning for skipping tool execution_", - ) - .map_err(|error| 
ToolError::OpenEditorError { - arguments: arguments.clone(), - error, - })?, - ), - }); - } - }, - 'c' => run_mode, - _ => unimplemented!(), - }; - - return Box::pin(self.prepare_run( - new_run_mode, - result_mode, - arguments, - source, - mcp_client, - editor, - writer, - )) - .await; - } - 'x' => { - match InlineSelect::new("Result Mode", vec![ - InlineOption::new('a', "Ask"), - InlineOption::new('u', "Unattended (Delver Payload As Is)"), - InlineOption::new('e', "Edit Result Payload"), - InlineOption::new('s', "Skip (Don't Deliver Payload)"), - InlineOption::new('c', "Keep Current Result Mode"), - ]) - .prompt(&mut writer) - .unwrap_or('c') - { - 'a' => *result_mode = ResultMode::Ask, - 'u' => *result_mode = ResultMode::Unattended, - 'e' => *result_mode = ResultMode::Edit, - 's' => *result_mode = ResultMode::Skip, - 'c' => {} - _ => unimplemented!(), - } - - return Box::pin(self.prepare_run( - run_mode, - result_mode, - arguments, - source, - mcp_client, - editor, - writer, - )) - .await; - } - 'p' => { - if let Err(error) = - writeln!(writer, "{}\n", serde_json::to_string_pretty(&arguments)?) 
- { - error!(%error, "Failed to write arguments"); - } - - return Box::pin(self.prepare_run( - RunMode::Ask, - result_mode, - arguments, - source, - mcp_client, - editor, - writer, - )) - .await; - } - _ => unreachable!(), - }, - RunMode::Unattended => return Ok(()), - RunMode::Skip => return Err(ToolError::Skipped { reason: None }), - RunMode::Edit => { - let mut args = serde_json::to_string_pretty(&arguments).map_err(|error| { - ToolError::SerializeArgumentsError { - arguments: arguments.clone(), - error, - } - })?; - - *arguments = { - if let Some(editor) = editor { - open_editor::EditorCallBuilder::new() - .with_editor(open_editor::Editor::from_bin_path(editor.into())) - .edit_string_mut(&mut args) - .map_err(|error| ToolError::OpenEditorError { - arguments: arguments.clone(), - error, - })?; - } - - // If the user removed all data from the arguments, we consider the - // edit a no-op, and ask the user if they want to run the tool. - if args.trim().is_empty() { - return Box::pin(self.prepare_run( - RunMode::Ask, - result_mode, - arguments, - source, - mcp_client, - editor, - writer, - )) - .await; - } - - match serde_json::from_str::(&args) { - Ok(value) => value, - - // If we can't parse the arguments as valid JSON, we consider - // the input invalid, and ask the user if they want to re-open - // the editor. 
- Err(error) => { - if let Err(error) = writeln!(writer, "JSON parsing error: {error}") { - error!(%error, "Failed to write error"); - } - - let retry = InlineSelect::new("Re-open editor?", vec![ - InlineOption::new('y', "Open editor to edit arguments"), - InlineOption::new('n', "Skip editing, failing with error"), - ]) - .with_default('y') - .prompt(&mut writer) - .unwrap_or('n'); - - if retry == 'n' { - return Err(ToolError::EditArgumentsError { - arguments: arguments.clone(), - error, - }); - } - - return Box::pin(self.prepare_run( - RunMode::Edit, - result_mode, - arguments, - source, - mcp_client, - editor, - writer, - )) - .await; - } - } - }; + continue; } + } else { + continue; } - Ok(()) - } + { + let summary = text[..=i].trim().to_owned(); + let rest = text[after..].trim(); - fn prepare_result( - &self, - mut result: ToolCallResponse, - result_mode: ResultMode, - editor: Option<&Utf8Path>, - mut writer: PrinterWriter<'_>, - ) -> Result { - match result_mode { - ResultMode::Ask => match InlineSelect::new( - format!( - "Deliver the results of the {} tool call?", - self.name.as_str().bold().yellow(), - ), - vec![ - InlineOption::new('y', "Deliver results"), - InlineOption::new('n', "Do not deliver results"), - InlineOption::new('e', "Edit results manually"), - ], - ) - .with_default('y') - .prompt(&mut writer) - .unwrap_or('n') - { - 'y' => return Ok(result), - 'n' => { - return Ok(ToolCallResponse { - id: result.id, - result: Ok("Tool call result omitted by user.".into()), - }); - } - 'e' => {} - _ => unreachable!(), - }, - ResultMode::Unattended => return Ok(result), - ResultMode::Skip => { - return Ok(ToolCallResponse { - id: result.id, - result: Ok("Tool ran successfully.".into()), - }); + if rest.is_empty() { + return (summary, None); } - ResultMode::Edit => {} + + return (summary, Some(rest.to_owned())); } + } - if let Some(editor) = editor { - let content = open_editor::EditorCallBuilder::new() - 
.with_editor(open_editor::Editor::from_bin_path(editor.into())) - .edit_string(result.content()) - .map_err(|error| ToolError::OpenEditorError { - arguments: Value::Null, - error, - })?; - - // If the user removed all data from the result, we consider the edit a - // no-op, and ask the user if they want to deliver the tool results. - if content.trim().is_empty() { - return self.prepare_result(result, ResultMode::Ask, Some(editor), writer); - } + // No sentence boundary found — take the first line. + if let Some(nl) = text.find('\n') { + let summary = text[..nl].trim().to_owned(); + let rest = text[nl..].trim(); - result.result = Ok(content); + if rest.is_empty() { + return (summary, None); } - Ok(result) + return (summary, Some(rest.to_owned())); } -} - -fn run_cmd_with_ctx( - name: &str, - command: &ToolCommandConfig, - ctx: &Value, - root: &Utf8Path, -) -> Result, ToolError> { - let command = { - let tmpl = Arc::new(Environment::new()); - - let program = - tmpl.render_str(&command.program, ctx) - .map_err(|error| ToolError::TemplateError { - data: command.program.clone(), - error, - })?; - - let args = command - .args - .iter() - .map(|s| tmpl.render_str(s, ctx)) - .collect::, _>>() - .map_err(|error| ToolError::TemplateError { - data: command.args.join(" ").clone(), - error, - })?; - - let expression = if command.shell { - let cmd = std::iter::once(program.clone()) - .chain(args.iter().cloned()) - .collect::>() - .join(" "); - duct_sh::sh_dangerous(cmd) - } else { - duct::cmd(program.clone(), args) - }; + // Single long line, no period — return as-is. + (text.to_owned(), None) +} - expression - .dir(root) - .unchecked() - .stdout_capture() - .stderr_capture() - }; +/// Fill in configured default values for missing parameters. +/// +/// LLMs commonly omit parameters that have a `default` in the JSON schema, +/// even when those parameters are marked `required`. 
This function patches +/// the arguments map before validation so that such omissions don't cause +/// spurious "missing argument" errors and unnecessary LLM retries. +fn apply_parameter_defaults( + arguments: &mut Map, + parameters: &IndexMap, +) { + for (name, cfg) in parameters { + if !arguments.contains_key(name) { + if let Some(default) = &cfg.default { + arguments.insert(name.clone(), default.clone()); + } + continue; + } - match command.run() { - Ok(output) => { - let stdout = String::from_utf8_lossy(&output.stdout); - let content = match serde_json::from_str::(&stdout) { - Err(_) => stdout.to_string(), - Ok(Outcome::Error { - transient, - message, - trace, - }) => { - if transient { - return Ok(Err(json!({ - "message": message, - "trace": trace, - }) - .to_string())); - } + // Recurse into object fields. + if let Some(obj) = arguments.get_mut(name).and_then(Value::as_object_mut) + && !cfg.properties.is_empty() + { + apply_parameter_defaults(obj, &cfg.properties); + } - return Err(ToolError::ToolCallFailed(stdout.to_string())); - } - Ok(Outcome::Success { content }) => content, - Ok(Outcome::NeedsInput { question }) => { - return Err(ToolError::NeedsInput { question }); + // Recurse into array elements. 
+ if let Some(items) = &cfg.items + && !items.properties.is_empty() + && let Some(arr) = arguments.get_mut(name).and_then(Value::as_array_mut) + { + for elem in arr.iter_mut() { + if let Some(obj) = elem.as_object_mut() { + apply_parameter_defaults(obj, &items.properties); } - }; - - if output.status.success() { - Ok(Ok(content)) - } else { - let stderr = String::from_utf8_lossy(&output.stderr); - Ok(Err(json!({ - "message": format!("Tool '{name}' execution failed."), - "stderr": stderr, - "stdout": content, - }) - .to_string())) } } - Err(error) => Ok(Err(json!({ - "message": format!( - "Failed to execute command '{command:?}': {error}", - ), - }) - .to_string())), } } fn validate_tool_arguments( arguments: &Map, - parameters: &IndexMap, + parameters: &IndexMap, ) -> Result<(), ToolError> { let unknown = arguments .keys() @@ -710,8 +853,8 @@ fn validate_tool_arguments( .collect::>(); let mut missing = vec![]; - for (name, required) in parameters { - if *required && !arguments.contains_key(name) { + for (name, cfg) in parameters { + if cfg.required && !arguments.contains_key(name) { missing.push(name.to_owned()); } } @@ -720,51 +863,241 @@ fn validate_tool_arguments( return Err(ToolError::Arguments { missing, unknown }); } + // Recurse into nested structures. + for (name, cfg) in parameters { + let Some(value) = arguments.get(name) else { + continue; + }; + + // Object parameters with properties: validate the object fields. + if let Some(obj) = value.as_object() + && !cfg.properties.is_empty() + { + validate_tool_arguments(obj, &cfg.properties)?; + } + + // Array parameters with items that have properties: validate each + // element. + if let Some(items) = &cfg.items + && !items.properties.is_empty() + && let Some(arr) = value.as_array() + { + for element in arr { + if let Some(obj) = element.as_object() { + validate_tool_arguments(obj, &items.properties)?; + } + } + } + } + Ok(()) } +/// Resolved tool definitions and their on-demand documentation. 
+pub struct ResolvedTools { + /// Tool definitions sent to the LLM provider. + pub definitions: Vec, + + /// Per-tool documentation for `describe_tools`, keyed by tool name. + pub docs: IndexMap, +} + pub async fn tool_definitions( configs: impl Iterator, mcp_client: &jp_mcp::Client, -) -> Result, ToolError> { +) -> Result { let mut definitions = vec![]; + let mut docs = IndexMap::new(); + for (name, config) in configs { // Skip disabled tools. if !config.enable() { continue; } - definitions.push( - ToolDefinition::new( + let (definition, tool_docs) = resolve_tool(name, &config, mcp_client).await?; + + if !tool_docs.is_empty() { + docs.insert(name.to_owned(), tool_docs); + } + + definitions.push(definition); + } + + Ok(ResolvedTools { definitions, docs }) +} + +/// Resolve a single tool definition and its documentation. +async fn resolve_tool( + name: &str, + config: &ToolConfigWithDefaults, + mcp_client: &jp_mcp::Client, +) -> Result<(ToolDefinition, ToolDocs), ToolError> { + match config.source() { + ToolSource::Local { .. } | ToolSource::Builtin { .. } => { + // For local/builtin tools, docs come from config fields. + let definition = ToolDefinition::new( name, config.source(), - config.description().map(str::to_owned), + config.summary().map(str::to_owned), config.parameters().clone(), - config.questions(), mcp_client, ) - .await?, - ); + .await?; + + let tool_docs = docs_from_config(config); + Ok((definition, tool_docs)) + } + ToolSource::Mcp { .. } => { + // For MCP tools, resolve against the server, then build docs + // from config overrides + auto-split of MCP descriptions. + resolve_mcp_tool(name, config, mcp_client).await + } + } +} + +/// Build `ToolDocs` from config fields (local/builtin tools). 
+fn docs_from_config(config: &ToolConfigWithDefaults) -> ToolDocs { + let summary = config.summary().map(str::to_owned); + let description = config.description().map(str::to_owned); + let examples = config.examples().map(str::to_owned); + + let parameters = config + .parameters() + .iter() + .filter_map(|(param_name, param_cfg)| { + let summary = param_cfg + .summary + .as_deref() + .or(param_cfg.description.as_deref()) + .map(str::to_owned); + let desc = param_cfg.description.as_deref().map(str::to_owned); + let ex = param_cfg.examples.as_deref().map(str::to_owned); + + if summary.is_none() && desc.is_none() && ex.is_none() { + return None; + } + + Some((param_name.to_owned(), ParameterDocs { + summary, + description: desc, + examples: ex, + })) + }) + .collect(); + + ToolDocs { + summary, + description, + examples, + parameters, + } +} + +/// Resolve an MCP tool: build definition + docs with auto-split heuristic. +async fn resolve_mcp_tool( + name: &str, + config: &ToolConfigWithDefaults, + mcp_client: &jp_mcp::Client, +) -> Result<(ToolDefinition, ToolDocs), ToolError> { + let has_user_summary = config.summary().is_some(); + + let definition = ToolDefinition::new( + name, + config.source(), + config.summary().map(str::to_owned), + config.parameters().clone(), + mcp_client, + ) + .await?; + + // Build docs. If the user provided summary/description/examples in + // config, use those. Otherwise, auto-split the resolved MCP description. + let (summary, description) = if has_user_summary { + // User provided explicit summary — use config fields as-is. + ( + config.summary().map(str::to_owned), + config.description().map(str::to_owned), + ) + } else if let Some(resolved) = &definition.description { + // No user summary — auto-split the MCP description. + let (s, d) = split_description(resolved); + (Some(s), d) + } else { + (None, None) + }; + + let examples = config.examples().map(str::to_owned); + + // Build parameter docs. 
For each parameter, check if the user + // provided an override. If not, auto-split the MCP description. + let parameters = definition + .parameters + .iter() + .filter_map(|(param_name, param_cfg)| { + let user_override = config.parameters().get(param_name); + let has_user_param_summary = user_override.and_then(|o| o.summary.as_ref()).is_some(); + + let (summary, desc) = if has_user_param_summary { + // User provided explicit summary for this parameter. + let summary = user_override + .and_then(|o| o.summary.as_deref()) + .or(user_override.and_then(|o| o.description.as_deref())) + .map(str::to_owned); + let desc = user_override + .and_then(|o| o.description.as_deref()) + .map(str::to_owned); + (summary, desc) + } else if let Some(resolved) = ¶m_cfg.description { + // Auto-split the resolved (possibly MCP) description. + let (s, d) = split_description(resolved); + (Some(s), d) + } else { + (None, None) + }; + + let ex = user_override + .and_then(|o| o.examples.as_deref()) + .map(str::to_owned); + + if summary.is_none() && desc.is_none() && ex.is_none() { + return None; + } + + Some((param_name.to_owned(), ParameterDocs { + summary, + description: desc, + examples: ex, + })) + }) + .collect(); + + // The definition's description should be the summary (short) for the + // provider API. Replace it if we auto-split. 
+ let mut definition = definition; + if !has_user_summary && let Some(ref s) = summary { + definition.description = Some(s.clone()); } - Ok(definitions) + let tool_docs = ToolDocs { + summary, + description, + examples, + parameters, + }; + + Ok((definition, tool_docs)) } fn local_tool_definition( name: String, description: Option, parameters: IndexMap, - questions: &IndexMap, ) -> ToolDefinition { - let include_tool_answers_parameter = questions - .iter() - .any(|(_, v)| v.target == QuestionTarget::Assistant); - ToolDefinition { name, description, parameters, - include_tool_answers_parameter, } } @@ -901,11 +1234,13 @@ async fn mcp_tool_definition( params.insert(name.to_owned(), ToolParameterConfig { kind, default, - description, required, + summary: None, + description, + examples: None, enumeration, items: opts.get("items").and_then(|v| v.as_object()).and_then(|v| { - Some(ToolParameterItemConfig { + Some(Box::new(ToolParameterConfig { kind: match v.get("type")? { Value::String(v) => OneOrManyTypes::One(v.to_owned()), Value::Array(v) => OneOrManyTypes::Many( @@ -917,10 +1252,16 @@ async fn mcp_tool_definition( _ => return None, }, default: None, + required: false, + summary: None, description: None, + examples: None, enumeration: vec![], - }) + items: None, + properties: IndexMap::default(), + })) }), + properties: IndexMap::default(), }); } @@ -928,65 +1269,9 @@ async fn mcp_tool_definition( name: name.to_owned(), description, parameters: params, - include_tool_answers_parameter: false, }) } #[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_validate_tool_arguments() { - struct TestCase { - arguments: Map, - parameters: IndexMap, - want: Result<(), ToolError>, - } - - let cases = vec![ - ("empty", TestCase { - arguments: Map::new(), - parameters: IndexMap::new(), - want: Ok(()), - }), - ("correct", TestCase { - arguments: Map::from_iter([("foo".to_owned(), json!("bar"))]), - parameters: IndexMap::from_iter([ - ("foo".to_owned(), true), - 
("bar".to_owned(), false), - ]), - want: Ok(()), - }), - ("missing", TestCase { - arguments: Map::new(), - parameters: IndexMap::from_iter([("foo".to_owned(), true)]), - want: Err(ToolError::Arguments { - missing: vec!["foo".to_owned()], - unknown: vec![], - }), - }), - ("unknown", TestCase { - arguments: Map::from_iter([("foo".to_owned(), json!("bar"))]), - parameters: IndexMap::from_iter([("bar".to_owned(), false)]), - want: Err(ToolError::Arguments { - missing: vec![], - unknown: vec!["foo".to_owned()], - }), - }), - ("both", TestCase { - arguments: Map::from_iter([("foo".to_owned(), json!("bar"))]), - parameters: IndexMap::from_iter([("bar".to_owned(), true)]), - want: Err(ToolError::Arguments { - missing: vec!["bar".to_owned()], - unknown: vec!["foo".to_owned()], - }), - }), - ]; - - for (name, test_case) in cases { - let result = validate_tool_arguments(&test_case.arguments, &test_case.parameters); - assert_eq!(result, test_case.want, "failed case: {name}"); - } - } -} +#[path = "tool_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/tool/builtin.rs b/crates/jp_llm/src/tool/builtin.rs new file mode 100644 index 00000000..b247c43c --- /dev/null +++ b/crates/jp_llm/src/tool/builtin.rs @@ -0,0 +1,45 @@ +//! Builtin tool trait and executor registry. +//! +//! Maps tool names to their Rust implementations. + +pub mod describe_tools; + +use std::{collections::HashMap, sync::Arc}; + +use async_trait::async_trait; +use indexmap::IndexMap; +use jp_tool::Outcome; +use serde_json::Value; + +/// A built-in tool that executes Rust code instead of shelling out. +#[async_trait] +pub trait BuiltinTool: Send + Sync { + /// Execute the tool with the given arguments and accumulated answers. + async fn execute(&self, arguments: &Value, answers: &IndexMap) -> Outcome; +} + +/// Registry mapping builtin tool names to their executors. 
+#[derive(Clone, Default)] +pub struct BuiltinExecutors { + executors: HashMap>, +} + +impl BuiltinExecutors { + #[must_use] + pub fn new() -> Self { + Self { + executors: HashMap::new(), + } + } + + #[must_use] + pub fn register(mut self, name: impl Into, tool: impl BuiltinTool + 'static) -> Self { + self.executors.insert(name.into(), Arc::new(tool)); + self + } + + #[must_use] + pub fn get(&self, name: &str) -> Option> { + self.executors.get(name).cloned() + } +} diff --git a/crates/jp_llm/src/tool/builtin/describe_tools.rs b/crates/jp_llm/src/tool/builtin/describe_tools.rs new file mode 100644 index 00000000..2d4ab907 --- /dev/null +++ b/crates/jp_llm/src/tool/builtin/describe_tools.rs @@ -0,0 +1,126 @@ +//! The `describe_tools` builtin implementation. + +use async_trait::async_trait; +use indexmap::IndexMap; +use jp_tool::Outcome; +use serde_json::Value; + +use crate::tool::{BuiltinTool, ToolDocs}; + +pub struct DescribeTools { + docs: IndexMap, +} + +impl DescribeTools { + #[must_use] + pub fn new(docs: IndexMap) -> Self { + Self { docs } + } + + fn format_tool_docs(name: &str, docs: &ToolDocs) -> String { + let mut out = format!("## {name}\n"); + + if let Some(summary) = &docs.summary { + out.push('\n'); + out.push_str(summary); + out.push('\n'); + } + + if let Some(desc) = &docs.description { + out.push('\n'); + out.push_str(desc); + out.push('\n'); + } + + if let Some(examples) = &docs.examples { + out.push_str("\n### Examples\n\n"); + out.push_str(examples); + out.push('\n'); + } + + let has_param_docs = docs.parameters.values().any(|p| !p.is_empty()); + if has_param_docs { + out.push_str("\n### Parameters\n"); + + for (param_name, param_docs) in &docs.parameters { + if param_docs.is_empty() { + continue; + } + + out.push_str(&format!("\n#### `{param_name}`\n")); + + if let Some(summary) = ¶m_docs.summary { + out.push('\n'); + out.push_str(summary); + out.push('\n'); + } + + if let Some(desc) = ¶m_docs.description { + out.push('\n'); + out.push_str(desc); 
+ out.push('\n'); + } + + if let Some(examples) = ¶m_docs.examples { + out.push('\n'); + out.push_str(examples); + out.push('\n'); + } + } + } + + out + } +} + +#[async_trait] +impl BuiltinTool for DescribeTools { + async fn execute(&self, arguments: &Value, _answers: &IndexMap) -> Outcome { + let tool_names = match arguments.get("tools").and_then(Value::as_array) { + Some(arr) => arr.iter().filter_map(Value::as_str).collect::>(), + None => { + return Outcome::Error { + message: "Missing or invalid `tools` parameter.".to_owned(), + trace: vec![], + transient: false, + }; + } + }; + + if tool_names.is_empty() { + return Outcome::Error { + message: "The `tools` array must not be empty.".to_owned(), + trace: vec![], + transient: false, + }; + } + + let mut sections = Vec::new(); + let mut not_found = Vec::new(); + + for name in &tool_names { + match self.docs.get(*name) { + Some(docs) => sections.push(Self::format_tool_docs(name, docs)), + None => not_found.push(*name), + } + } + + let mut output = sections.join("\n---\n\n"); + + if !not_found.is_empty() { + if !output.is_empty() { + output.push_str("\n---\n\n"); + } + output.push_str(&format!( + "No additional documentation available for: {}", + not_found.join(", ") + )); + } + + Outcome::Success { content: output } + } +} + +#[cfg(test)] +#[path = "describe_tools_tests.rs"] +mod tests; diff --git a/crates/jp_llm/src/tool/builtin/describe_tools_tests.rs b/crates/jp_llm/src/tool/builtin/describe_tools_tests.rs new file mode 100644 index 00000000..e250ef25 --- /dev/null +++ b/crates/jp_llm/src/tool/builtin/describe_tools_tests.rs @@ -0,0 +1,343 @@ +use indexmap::IndexMap; +use jp_tool::Outcome; +use serde_json::{Value, json}; + +use super::*; +use crate::tool::{ParameterDocs, ToolDocs}; + +fn empty_tool_docs() -> ToolDocs { + ToolDocs { + summary: None, + description: None, + examples: None, + parameters: IndexMap::new(), + } +} + +fn param_docs( + summary: Option<&str>, + description: Option<&str>, + examples: 
Option<&str>, +) -> ParameterDocs { + ParameterDocs { + summary: summary.map(str::to_owned), + description: description.map(str::to_owned), + examples: examples.map(str::to_owned), + } +} + +fn no_answers() -> IndexMap { + IndexMap::new() +} + +#[test] +fn test_format_empty_docs() { + let out = DescribeTools::format_tool_docs("my_tool", &empty_tool_docs()); + assert_eq!(out, "## my_tool\n"); +} + +#[test] +fn test_format_summary_only() { + let docs = ToolDocs { + summary: Some("A brief summary.".to_owned()), + ..empty_tool_docs() + }; + let out = DescribeTools::format_tool_docs("my_tool", &docs); + assert_eq!(out, "## my_tool\n\nA brief summary.\n"); +} + +#[test] +fn test_format_description_only() { + let docs = ToolDocs { + description: Some("Detailed description.".to_owned()), + ..empty_tool_docs() + }; + let out = DescribeTools::format_tool_docs("my_tool", &docs); + assert_eq!(out, "## my_tool\n\nDetailed description.\n"); +} + +#[test] +fn test_format_summary_and_description() { + let docs = ToolDocs { + summary: Some("Summary.".to_owned()), + description: Some("Description.".to_owned()), + ..empty_tool_docs() + }; + let out = DescribeTools::format_tool_docs("my_tool", &docs); + assert_eq!(out, "## my_tool\n\nSummary.\n\nDescription.\n"); +} + +#[test] +fn test_format_all_tool_fields() { + let docs = ToolDocs { + summary: Some("Summary.".to_owned()), + description: Some("Description.".to_owned()), + examples: Some("my_tool()".to_owned()), + parameters: IndexMap::new(), + }; + let out = DescribeTools::format_tool_docs("my_tool", &docs); + assert_eq!( + out, + "## my_tool\n\nSummary.\n\nDescription.\n\n### Examples\n\nmy_tool()\n" + ); +} + +#[test] +fn test_format_parameter_with_description() { + let mut parameters = IndexMap::new(); + parameters.insert( + "input".to_owned(), + param_docs(Some("Short summary."), Some("Long description."), None), + ); + + let docs = ToolDocs { + parameters, + ..empty_tool_docs() + }; + let out = 
DescribeTools::format_tool_docs("my_tool", &docs); + assert_eq!( + out, + "## my_tool\n\n### Parameters\n\n#### `input`\n\nShort summary.\n\nLong description.\n" + ); +} + +#[test] +fn test_format_parameter_with_examples() { + let mut parameters = IndexMap::new(); + parameters.insert( + "query".to_owned(), + param_docs(None, Some("The search query."), Some("\"hello world\"")), + ); + + let docs = ToolDocs { + parameters, + ..empty_tool_docs() + }; + let out = DescribeTools::format_tool_docs("my_tool", &docs); + assert_eq!( + out, + "## my_tool\n\n### Parameters\n\n#### `query`\n\nThe search query.\n\n\"hello world\"\n" + ); +} + +#[test] +fn test_format_parameter_all_fields() { + let mut parameters = IndexMap::new(); + parameters.insert( + "path".to_owned(), + param_docs( + Some("File path."), + Some("Absolute or relative path to the target file."), + Some("\"src/main.rs\""), + ), + ); + + let docs = ToolDocs { + parameters, + ..empty_tool_docs() + }; + let out = DescribeTools::format_tool_docs("my_tool", &docs); + assert_eq!( + out, + "## my_tool\n\n### Parameters\n\n#### `path`\n\nFile path.\n\nAbsolute or relative path \ + to the target file.\n\n\"src/main.rs\"\n" + ); +} + +#[test] +fn test_format_skips_empty_parameters() { + // is_empty() checks description and examples only, so a param with only + // summary is still "empty" and is excluded from the output. 
+ let mut parameters = IndexMap::new(); + parameters.insert( + "documented".to_owned(), + param_docs(None, Some("Has description."), None), + ); + parameters.insert("undocumented".to_owned(), param_docs(None, None, None)); + parameters.insert( + "summary_only".to_owned(), + param_docs(Some("Only a summary."), None, None), + ); + + let docs = ToolDocs { + parameters, + ..empty_tool_docs() + }; + let out = DescribeTools::format_tool_docs("my_tool", &docs); + assert!( + out.contains("#### `documented`"), + "documented param should appear" + ); + assert!( + !out.contains("#### `undocumented`"), + "undocumented param should be skipped" + ); + assert!( + !out.contains("#### `summary_only`"), + "summary-only param should be skipped" + ); +} + +#[test] +fn test_format_no_parameters_section_when_all_params_empty() { + // If every parameter is empty, the "### Parameters" section is omitted. + let mut parameters = IndexMap::new(); + parameters.insert("a".to_owned(), param_docs(None, None, None)); + parameters.insert("b".to_owned(), param_docs(Some("summary only"), None, None)); + + let docs = ToolDocs { + parameters, + ..empty_tool_docs() + }; + let out = DescribeTools::format_tool_docs("my_tool", &docs); + assert!(!out.contains("### Parameters")); + assert_eq!(out, "## my_tool\n"); +} + +#[tokio::test] +async fn test_execute_missing_tools_argument() { + let tool = DescribeTools::new(IndexMap::new()); + let result = tool.execute(&json!({}), &no_answers()).await; + let Outcome::Error { + message, transient, .. + } = result + else { + panic!("expected Outcome::Error"); + }; + assert!(message.contains("`tools`")); + assert!(!transient); +} + +#[tokio::test] +async fn test_execute_tools_not_an_array() { + let tool = DescribeTools::new(IndexMap::new()); + let result = tool + .execute(&json!({"tools": "my_tool"}), &no_answers()) + .await; + assert!( + matches!(result, Outcome::Error { .. 
}), + "non-array `tools` should be an error" + ); +} + +#[tokio::test] +async fn test_execute_empty_tools_array() { + let tool = DescribeTools::new(IndexMap::new()); + let result = tool.execute(&json!({"tools": []}), &no_answers()).await; + let Outcome::Error { message, .. } = result else { + panic!("expected Outcome::Error"); + }; + assert!(message.contains("must not be empty")); +} + +#[tokio::test] +async fn test_execute_single_known_tool() { + let mut docs = IndexMap::new(); + docs.insert("my_tool".to_owned(), ToolDocs { + summary: Some("Tool summary.".to_owned()), + ..empty_tool_docs() + }); + + let tool = DescribeTools::new(docs); + let result = tool + .execute(&json!({"tools": ["my_tool"]}), &no_answers()) + .await; + + let Outcome::Success { content } = result else { + panic!("expected Outcome::Success"); + }; + assert_eq!(content, "## my_tool\n\nTool summary.\n"); +} + +#[tokio::test] +async fn test_execute_known_tool_with_empty_docs() { + let mut docs = IndexMap::new(); + docs.insert("bare_tool".to_owned(), empty_tool_docs()); + + let tool = DescribeTools::new(docs); + let result = tool + .execute(&json!({"tools": ["bare_tool"]}), &no_answers()) + .await; + + let Outcome::Success { content } = result else { + panic!("expected Outcome::Success"); + }; + assert_eq!(content, "## bare_tool\n"); +} + +#[tokio::test] +async fn test_execute_multiple_known_tools_separated_by_divider() { + let mut docs = IndexMap::new(); + docs.insert("tool_a".to_owned(), ToolDocs { + summary: Some("A.".to_owned()), + ..empty_tool_docs() + }); + docs.insert("tool_b".to_owned(), ToolDocs { + summary: Some("B.".to_owned()), + ..empty_tool_docs() + }); + + let tool = DescribeTools::new(docs); + let result = tool + .execute(&json!({"tools": ["tool_a", "tool_b"]}), &no_answers()) + .await; + + let Outcome::Success { content } = result else { + panic!("expected Outcome::Success"); + }; + assert_eq!(content, "## tool_a\n\nA.\n\n---\n\n## tool_b\n\nB.\n"); +} + +#[tokio::test] +async fn 
test_execute_single_unknown_tool() { + let tool = DescribeTools::new(IndexMap::new()); + let result = tool + .execute(&json!({"tools": ["unknown_tool"]}), &no_answers()) + .await; + + let Outcome::Success { content } = result else { + panic!("expected Outcome::Success"); + }; + assert_eq!( + content, + "No additional documentation available for: unknown_tool" + ); +} + +#[tokio::test] +async fn test_execute_multiple_unknown_tools() { + let tool = DescribeTools::new(IndexMap::new()); + let result = tool + .execute(&json!({"tools": ["foo", "bar"]}), &no_answers()) + .await; + + let Outcome::Success { content } = result else { + panic!("expected Outcome::Success"); + }; + assert_eq!( + content, + "No additional documentation available for: foo, bar" + ); +} + +#[tokio::test] +async fn test_execute_mixed_known_and_unknown_tools() { + let mut docs = IndexMap::new(); + docs.insert("known".to_owned(), ToolDocs { + summary: Some("Summary.".to_owned()), + ..empty_tool_docs() + }); + + let tool = DescribeTools::new(docs); + let result = tool + .execute(&json!({"tools": ["known", "unknown"]}), &no_answers()) + .await; + + let Outcome::Success { content } = result else { + panic!("expected Outcome::Success"); + }; + assert_eq!( + content, + "## known\n\nSummary.\n\n---\n\nNo additional documentation available for: unknown" + ); +} diff --git a/crates/jp_llm/src/tool/executor.rs b/crates/jp_llm/src/tool/executor.rs new file mode 100644 index 00000000..dbde0514 --- /dev/null +++ b/crates/jp_llm/src/tool/executor.rs @@ -0,0 +1,346 @@ +use std::sync::Mutex; + +use async_trait::async_trait; +use camino::Utf8Path; +use indexmap::IndexMap; +use jp_config::conversation::tool::{RunMode, ToolSource, ToolsConfig}; +use jp_conversation::event::{ToolCallRequest, ToolCallResponse}; +use jp_mcp::Client; +use jp_tool::Question; +use serde_json::{Map, Value}; +use tokio_util::sync::CancellationToken; + +use crate::ToolError; + +/// Trait for tool execution, enabling mock implementations for 
testing. +/// +/// This trait abstracts the execution of a single tool call, allowing the +/// `ToolCoordinator` to work with both real and mock executors. +/// +/// # Design +/// +/// The executor is intentionally simple - it just executes tools with given +/// answers. All decision-making about question targets, static answers, and how +/// to handle `NeedsInput` is done by the coordinator, which has access to the +/// tool configuration. +#[async_trait] +pub trait Executor: Send + Sync { + /// Returns the tool call ID. + fn tool_id(&self) -> &str; + + /// Returns the tool name. + fn tool_name(&self) -> &str; + + /// Returns the tool call arguments. + /// + /// This is separate from [`permission_info()`](Self::permission_info) + /// because arguments are always available, while permission info is only + /// present for tools that require a permission prompt. + fn arguments(&self) -> &Map; + + /// Returns information needed for permission prompting. + /// + /// Returns `None` if the tool doesn't need a permission prompt (e.g., + /// `RunMode::Unattended` or `RunMode::Skip`). + fn permission_info(&self) -> Option; + + /// Updates the arguments to use for execution. + /// + /// This is called after permission prompting if the user edited the + /// arguments (via `RunMode::Edit`). The new arguments replace the original + /// arguments from the tool call request. + fn set_arguments(&mut self, args: Value); + + /// Executes the tool once with the given answers. + /// + /// This method performs a single execution pass. If the tool needs + /// additional input, it returns `ExecutorResult::NeedsInput` and the + /// coordinator handles prompting and retrying. + /// + /// The executor doesn't know how questions should be answered - it just + /// reports that input is needed. The coordinator looks up the tool + /// configuration to determine whether to prompt the user or ask the LLM. 
+ /// + /// # Arguments + /// + /// * `answers` - Accumulated answers from previous `NeedsInput` responses + /// * `mcp_client` - MCP client for remote tool execution + /// * `root` - Project root directory + /// * `cancellation_token` - Token to cancel execution + async fn execute( + &self, + answers: &IndexMap, + mcp_client: &Client, + root: &Utf8Path, + cancellation_token: CancellationToken, + ) -> ExecutorResult; +} + +/// Abstraction over how executors are created for tool calls. +/// +/// This trait enables dependency injection of executor creation, allowing tests +/// to use mock executors without executing real shell commands. +#[async_trait] +pub trait ExecutorSource: Send + Sync { + /// Creates an executor for the given tool call request. + /// + /// # Arguments + /// + /// * `request` - The tool call request from the LLM + /// * `tools_config` - Configuration for all tools + /// * `mcp_client` - MCP client for remote tool execution + /// + /// # Errors + /// + /// Returns an error if the tool is not found or cannot be initialized. + async fn create( + &self, + request: ToolCallRequest, + tools_config: &ToolsConfig, + mcp_client: &Client, + ) -> Result, ToolError>; +} + +/// Result of a tool execution attempt. +/// +/// Tools may need multiple rounds of execution if they require additional +/// input. This enum allows the executor to return control to the coordinator, +/// which decides how to handle the `NeedsInput` case by looking up the question +/// configuration. +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] // NeedsInput variant is larger but rarely used +pub enum ExecutorResult { + /// Tool completed (success or error). + Completed(ToolCallResponse), + + /// Tool needs additional input before it can continue. + /// + /// The executor doesn't know who should answer - it just reports that input + /// is needed. 
The coordinator looks up the question configuration to + /// determine the target: + /// + /// - `User`: Prompt the user interactively, then restart the tool + /// - `Assistant`: Format a response asking the LLM to re-run with answers + NeedsInput { + /// Tool call ID. + tool_id: String, + + /// Tool name (for persisting answers). + tool_name: String, + + /// The question that needs to be answered. + question: Question, + + /// Accumulated answers so far (for retry). + accumulated_answers: IndexMap, + }, +} + +/// A mock executor for testing that returns pre-configured results. +/// +/// This executor doesn't execute any real commands - it simply returns whatever +/// result is configured, making it ideal for testing tool coordination flows +/// without side effects. +/// +/// # Example +/// +/// ```ignore +/// let executor = MockExecutor::completed("call_1", "my_tool", "success output"); +/// let result = executor.execute(&answers, &client, &root, token).await; +/// assert!(result.is_completed()); +/// ``` +pub struct MockExecutor { + tool_id: String, + tool_name: String, + arguments: Map, + permission_info: Option, + result: Mutex>, +} + +impl MockExecutor { + /// Creates a mock executor that returns a successful completion. + #[must_use] + pub fn completed(tool_id: &str, tool_name: &str, output: &str) -> Self { + Self { + tool_id: tool_id.to_string(), + tool_name: tool_name.to_string(), + arguments: Map::new(), + permission_info: None, + result: Mutex::new(Some(ExecutorResult::Completed(ToolCallResponse { + id: tool_id.to_string(), + result: Ok(output.to_string()), + }))), + } + } + + /// Creates a mock executor that returns an error. 
+ #[must_use] + pub fn error(tool_id: &str, tool_name: &str, error: &str) -> Self { + Self { + tool_id: tool_id.to_string(), + tool_name: tool_name.to_string(), + arguments: Map::new(), + permission_info: None, + result: Mutex::new(Some(ExecutorResult::Completed(ToolCallResponse { + id: tool_id.to_string(), + result: Err(error.to_string()), + }))), + } + } + + /// Sets the arguments for this executor. + #[must_use] + pub fn with_arguments(mut self, args: Map) -> Self { + self.arguments = args; + self + } + + /// Sets the permission info for this executor. + /// + /// If set, the executor will require permission prompting based on the + /// configured `RunMode`. + #[must_use] + pub fn with_permission_info(mut self, info: PermissionInfo) -> Self { + self.permission_info = Some(info); + self + } + + /// Sets a custom result for this executor. + #[must_use] + pub fn with_result(mut self, result: ExecutorResult) -> Self { + self.result = Mutex::new(Some(result)); + self + } +} + +#[async_trait] +impl Executor for MockExecutor { + fn tool_id(&self) -> &str { + &self.tool_id + } + + fn tool_name(&self) -> &str { + &self.tool_name + } + + fn arguments(&self) -> &Map { + &self.arguments + } + + fn permission_info(&self) -> Option { + self.permission_info.clone() + } + + fn set_arguments(&mut self, _args: Value) { + // No-op for mock executor - arguments don't affect the pre-configured + // result + } + + async fn execute( + &self, + _answers: &IndexMap, + _mcp_client: &Client, + _root: &Utf8Path, + _cancellation_token: CancellationToken, + ) -> ExecutorResult { + self.result.lock().unwrap().take().unwrap_or_else(|| { + ExecutorResult::Completed(ToolCallResponse { + id: self.tool_id.clone(), + result: Err("MockExecutor: result already consumed".to_string()), + }) + }) + } +} + +/// An executor source for testing that returns pre-registered mock executors. 
+/// +/// This allows tests to inject mock executors for specific tool names without +/// executing any real shell commands. +/// +/// # Example +/// +/// ```ignore +/// let source = TestExecutorSource::new() +/// .with_executor("my_tool", |req| { +/// Box::new(MockExecutor::completed(&req.id, &req.name, "mock output")) +/// }); +/// +/// let coordinator = ToolCoordinator::new(tools_config, Arc::new(source)); +/// ``` +pub struct TestExecutorSource { + #[allow(clippy::type_complexity)] + factories: std::collections::HashMap< + String, + Box Box + Send + Sync>, + >, +} + +impl TestExecutorSource { + /// Creates a new empty test executor source. + #[must_use] + pub fn new() -> Self { + Self { + factories: std::collections::HashMap::new(), + } + } + + /// Registers a factory function for a tool name. + /// + /// When `create()` is called for this tool name, the factory will be + /// invoked to create the executor. + #[must_use] + pub fn with_executor(mut self, tool_name: &str, factory: F) -> Self + where + F: Fn(ToolCallRequest) -> Box + Send + Sync + 'static, + { + self.factories + .insert(tool_name.to_string(), Box::new(factory)); + self + } +} + +impl Default for TestExecutorSource { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl ExecutorSource for TestExecutorSource { + async fn create( + &self, + request: ToolCallRequest, + _tools_config: &ToolsConfig, + _mcp_client: &Client, + ) -> Result, ToolError> { + if let Some(factory) = self.factories.get(&request.name) { + Ok(factory(request)) + } else { + Err(ToolError::NotFound { + name: request.name.clone(), + }) + } + } +} + +/// Information needed to prompt for tool execution permission. +/// +/// This struct contains all the data the `ToolPrompter` needs to show a +/// permission prompt to the user. +#[derive(Debug, Clone)] +pub struct PermissionInfo { + /// The tool call ID. + pub tool_id: String, + + /// The tool name. + pub tool_name: String, + + /// The tool source (builtin, local, MCP). 
+ pub tool_source: ToolSource, + + /// The configured run mode. + pub run_mode: RunMode, + + /// The arguments to pass to the tool. + pub arguments: Value, +} diff --git a/crates/jp_llm/src/tool_tests.rs b/crates/jp_llm/src/tool_tests.rs new file mode 100644 index 00000000..d7e19e5b --- /dev/null +++ b/crates/jp_llm/src/tool_tests.rs @@ -0,0 +1,567 @@ +use jp_tool::AnswerType; + +use super::*; + +#[test] +fn test_execution_outcome_completed_success_into_response() { + let outcome = ExecutionOutcome::Completed { + id: "call_123".to_string(), + result: Ok("Tool output".to_string()), + }; + + let response = outcome.into_response(); + assert_eq!(response.id, "call_123"); + assert_eq!(response.result, Ok("Tool output".to_string())); +} + +#[test] +fn test_execution_outcome_completed_error_into_response() { + let outcome = ExecutionOutcome::Completed { + id: "call_456".to_string(), + result: Err("Tool failed".to_string()), + }; + + let response = outcome.into_response(); + assert_eq!(response.id, "call_456"); + assert_eq!(response.result, Err("Tool failed".to_string())); +} + +#[test] +fn test_execution_outcome_needs_input_into_response() { + let question = Question { + id: "q1".to_string(), + text: "What is your name?".to_string(), + answer_type: AnswerType::Text, + default: None, + }; + + let outcome = ExecutionOutcome::NeedsInput { + id: "call_789".to_string(), + question, + }; + + let response = outcome.into_response(); + assert_eq!(response.id, "call_789"); + assert!(response.result.is_ok()); + assert!( + response + .result + .unwrap() + .contains("requires additional input") + ); +} + +#[test] +fn test_execution_outcome_cancelled_into_response() { + let outcome = ExecutionOutcome::Cancelled { + id: "call_abc".to_string(), + }; + + let response = outcome.into_response(); + assert_eq!(response.id, "call_abc"); + assert!(response.result.is_ok()); + assert!(response.result.unwrap().contains("cancelled")); +} + +#[test] +fn test_execution_outcome_id() { + let completed 
= ExecutionOutcome::Completed { + id: "id1".to_string(), + result: Ok(String::new()), + }; + assert_eq!(completed.id(), "id1"); + + let needs_input = ExecutionOutcome::NeedsInput { + id: "id2".to_string(), + question: Question { + id: "q".to_string(), + text: "?".to_string(), + answer_type: AnswerType::Text, + default: None, + }, + }; + assert_eq!(needs_input.id(), "id2"); + + let cancelled = ExecutionOutcome::Cancelled { + id: "id3".to_string(), + }; + assert_eq!(cancelled.id(), "id3"); +} + +#[test] +fn test_execution_outcome_helper_methods() { + let success = ExecutionOutcome::Completed { + id: "1".to_string(), + result: Ok("output".to_string()), + }; + assert!(success.is_success()); + assert!(!success.needs_input()); + assert!(!success.is_cancelled()); + + let failure = ExecutionOutcome::Completed { + id: "2".to_string(), + result: Err("error".to_string()), + }; + assert!(!failure.is_success()); + assert!(!failure.needs_input()); + assert!(!failure.is_cancelled()); + + let needs_input = ExecutionOutcome::NeedsInput { + id: "3".to_string(), + question: Question { + id: "q".to_string(), + text: "?".to_string(), + answer_type: AnswerType::Boolean, + default: None, + }, + }; + assert!(!needs_input.is_success()); + assert!(needs_input.needs_input()); + assert!(!needs_input.is_cancelled()); + + let cancelled = ExecutionOutcome::Cancelled { + id: "4".to_string(), + }; + assert!(!cancelled.is_success()); + assert!(!cancelled.needs_input()); + assert!(cancelled.is_cancelled()); +} + +/// Build a minimal `ToolParameterConfig` for use in validation tests. 
+fn param(kind: &str, required: bool) -> ToolParameterConfig { + ToolParameterConfig { + kind: kind.to_owned().into(), + required, + default: None, + summary: None, + description: None, + examples: None, + enumeration: vec![], + items: None, + properties: IndexMap::default(), + } +} + +#[test] +fn test_validate_tool_arguments() { + struct TestCase { + arguments: Map, + parameters: IndexMap, + want: Result<(), ToolError>, + } + + let cases = vec![ + ("empty", TestCase { + arguments: Map::new(), + parameters: IndexMap::new(), + want: Ok(()), + }), + ("correct", TestCase { + arguments: Map::from_iter([("foo".to_owned(), json!("bar"))]), + parameters: IndexMap::from_iter([ + ("foo".to_owned(), param("string", true)), + ("bar".to_owned(), param("string", false)), + ]), + want: Ok(()), + }), + ("missing", TestCase { + arguments: Map::new(), + parameters: IndexMap::from_iter([("foo".to_owned(), param("string", true))]), + want: Err(ToolError::Arguments { + missing: vec!["foo".to_owned()], + unknown: vec![], + }), + }), + ("unknown", TestCase { + arguments: Map::from_iter([("foo".to_owned(), json!("bar"))]), + parameters: IndexMap::from_iter([("bar".to_owned(), param("string", false))]), + want: Err(ToolError::Arguments { + missing: vec![], + unknown: vec!["foo".to_owned()], + }), + }), + ("both", TestCase { + arguments: Map::from_iter([("foo".to_owned(), json!("bar"))]), + parameters: IndexMap::from_iter([("bar".to_owned(), param("string", true))]), + want: Err(ToolError::Arguments { + missing: vec!["bar".to_owned()], + unknown: vec!["foo".to_owned()], + }), + }), + ]; + + for (name, test_case) in cases { + let result = validate_tool_arguments(&test_case.arguments, &test_case.parameters); + assert_eq!(result, test_case.want, "failed case: {name}"); + } +} + +#[test] +fn test_validate_nested_array_item_properties() { + // Mirrors the fs_modify_file schema: + // patterns: array of { old: string (required), new: string (required) } + let parameters = IndexMap::from_iter([ + 
("path".to_owned(), param("string", true)), + ("patterns".to_owned(), ToolParameterConfig { + kind: "array".to_owned().into(), + required: true, + items: Some(Box::new(ToolParameterConfig { + kind: "object".to_owned().into(), + required: false, + properties: IndexMap::from_iter([ + ("old".to_owned(), param("string", true)), + ("new".to_owned(), param("string", true)), + ]), + ..param("object", false) + })), + ..param("array", true) + }), + ]); + + // Valid: correct inner fields. + let args = json!({ + "path": "src/lib.rs", + "patterns": [{"old": "foo", "new": "bar"}] + }); + assert_eq!( + validate_tool_arguments(args.as_object().unwrap(), ¶meters), + Ok(()) + ); + + // Valid: multiple items. + let args = json!({ + "path": "src/lib.rs", + "patterns": [ + {"old": "a", "new": "b"}, + {"old": "c", "new": "d"} + ] + }); + assert_eq!( + validate_tool_arguments(args.as_object().unwrap(), ¶meters), + Ok(()) + ); + + // Invalid: unknown inner field. + let args = json!({ + "path": "src/lib.rs", + "patterns": [{"old": "foo", "new": "bar", "extra": true}] + }); + assert_eq!( + validate_tool_arguments(args.as_object().unwrap(), ¶meters), + Err(ToolError::Arguments { + missing: vec![], + unknown: vec!["extra".to_owned()], + }) + ); + + // Invalid: missing required inner field. + let args = json!({ + "path": "src/lib.rs", + "patterns": [{"old": "foo"}] + }); + assert_eq!( + validate_tool_arguments(args.as_object().unwrap(), ¶meters), + Err(ToolError::Arguments { + missing: vec!["new".to_owned()], + unknown: vec![], + }) + ); + + // Invalid: wrong inner field names (the LLM hallucinated names). 
+ let args = json!({ + "path": "src/lib.rs", + "patterns": [{"string_to_replace": "foo", "new_string": "bar"}] + }); + let err = validate_tool_arguments(args.as_object().unwrap(), ¶meters); + assert!(err.is_err()); + let ToolError::Arguments { missing, unknown } = err.unwrap_err() else { + panic!("expected Arguments error"); + }; + assert_eq!(missing, vec!["old".to_owned(), "new".to_owned()]); + // preserve_order: keys iterate in insertion order from json! macro + assert_eq!(unknown, vec![ + "string_to_replace".to_owned(), + "new_string".to_owned() + ]); + + // Valid: non-object array items are skipped (no crash). + let args = json!({ + "path": "src/lib.rs", + "patterns": ["not an object"] + }); + assert_eq!( + validate_tool_arguments(args.as_object().unwrap(), ¶meters), + Ok(()) + ); + + // Valid: parameter is not an array (type mismatch, but not our job to check types). + let args = json!({ + "path": "src/lib.rs", + "patterns": "not an array" + }); + assert_eq!( + validate_tool_arguments(args.as_object().unwrap(), ¶meters), + Ok(()) + ); +} + +#[test] +fn test_validate_nested_object_properties() { + let parameters = IndexMap::from_iter([ + ("name".to_owned(), param("string", true)), + ("config".to_owned(), ToolParameterConfig { + kind: "object".to_owned().into(), + required: false, + properties: IndexMap::from_iter([ + ("verbose".to_owned(), param("boolean", false)), + ("output".to_owned(), param("string", true)), + ]), + ..param("object", false) + }), + ]); + + // Valid. + let args = json!({ "name": "test", "config": { "verbose": true, "output": "out.txt" } }); + assert_eq!( + validate_tool_arguments(args.as_object().unwrap(), ¶meters), + Ok(()) + ); + + // Valid: optional object param omitted entirely. + let args = json!({ "name": "test" }); + assert_eq!( + validate_tool_arguments(args.as_object().unwrap(), ¶meters), + Ok(()) + ); + + // Invalid: unknown field inside the object. 
+ let args = json!({ "name": "test", "config": { "output": "o", "bogus": 1 } }); + assert_eq!( + validate_tool_arguments(args.as_object().unwrap(), ¶meters), + Err(ToolError::Arguments { + missing: vec![], + unknown: vec!["bogus".to_owned()], + }) + ); + + // Invalid: missing required field inside the object. + let args = json!({ "name": "test", "config": { "verbose": true } }); + assert_eq!( + validate_tool_arguments(args.as_object().unwrap(), ¶meters), + Err(ToolError::Arguments { + missing: vec!["output".to_owned()], + unknown: vec![], + }) + ); +} + +/// Build a parameter with a default value. +fn param_with_default(kind: &str, required: bool, default: Value) -> ToolParameterConfig { + ToolParameterConfig { + default: Some(default), + ..param(kind, required) + } +} + +#[test] +fn test_apply_defaults_fills_missing_required_with_default() { + let parameters = IndexMap::from_iter([ + ("path".to_owned(), param("string", true)), + ( + "use_regex".to_owned(), + param_with_default("boolean", true, json!(false)), + ), + ]); + + let mut args: Map = Map::from_iter([("path".to_owned(), json!("src/lib.rs"))]); + + apply_parameter_defaults(&mut args, ¶meters); + + assert_eq!(args.get("path"), Some(&json!("src/lib.rs"))); + assert_eq!(args.get("use_regex"), Some(&json!(false))); +} + +#[test] +fn test_apply_defaults_does_not_overwrite_provided_values() { + let parameters = IndexMap::from_iter([( + "use_regex".to_owned(), + param_with_default("boolean", true, json!(false)), + )]); + + let mut args: Map = Map::from_iter([("use_regex".to_owned(), json!(true))]); + + apply_parameter_defaults(&mut args, ¶meters); + + assert_eq!(args.get("use_regex"), Some(&json!(true))); +} + +#[test] +fn test_apply_defaults_fills_optional_param_with_default() { + let parameters = IndexMap::from_iter([( + "verbose".to_owned(), + param_with_default("boolean", false, json!(false)), + )]); + + let mut args: Map = Map::new(); + apply_parameter_defaults(&mut args, ¶meters); + + 
assert_eq!(args.get("verbose"), Some(&json!(false))); +} + +#[test] +fn test_apply_defaults_skips_params_without_default() { + let parameters = IndexMap::from_iter([("path".to_owned(), param("string", true))]); + + let mut args: Map = Map::new(); + apply_parameter_defaults(&mut args, ¶meters); + + assert!(!args.contains_key("path")); +} + +#[test] +fn test_apply_defaults_recurses_into_objects() { + let parameters = IndexMap::from_iter([("config".to_owned(), ToolParameterConfig { + kind: "object".to_owned().into(), + required: false, + properties: IndexMap::from_iter([( + "verbose".to_owned(), + param_with_default("boolean", false, json!(true)), + )]), + ..param("object", false) + })]); + + let mut args: Map = Map::from_iter([("config".to_owned(), json!({}))]); + + apply_parameter_defaults(&mut args, ¶meters); + + assert_eq!(args["config"]["verbose"], json!(true)); +} + +#[test] +fn test_apply_defaults_recurses_into_array_items() { + let parameters = IndexMap::from_iter([("items".to_owned(), ToolParameterConfig { + kind: "array".to_owned().into(), + required: true, + items: Some(Box::new(ToolParameterConfig { + kind: "object".to_owned().into(), + required: false, + properties: IndexMap::from_iter([( + "enabled".to_owned(), + param_with_default("boolean", false, json!(true)), + )]), + ..param("object", false) + })), + ..param("array", true) + })]); + + let mut args: Map = Map::from_iter([( + "items".to_owned(), + json!([{"name": "a"}, {"name": "b", "enabled": false}]), + )]); + + apply_parameter_defaults(&mut args, ¶meters); + + let items = args["items"].as_array().unwrap(); + assert_eq!(items[0]["enabled"], json!(true)); + // Explicitly provided false is preserved. + assert_eq!(items[1]["enabled"], json!(false)); +} + +#[test] +fn test_apply_defaults_then_validate_passes() { + // Mirrors the fs_modify_file scenario: replace_using_regex is required + // with a default, and the LLM omits it. 
+ let parameters = IndexMap::from_iter([ + ("path".to_owned(), param("string", true)), + ( + "replace_using_regex".to_owned(), + param_with_default("boolean", true, json!(false)), + ), + ]); + + let mut args: Map = Map::from_iter([("path".to_owned(), json!("README.md"))]); + + // Without defaults, validation would fail. + assert!(validate_tool_arguments(&args, ¶meters).is_err()); + + // After applying defaults, validation passes. + apply_parameter_defaults(&mut args, ¶meters); + assert!(validate_tool_arguments(&args, ¶meters).is_ok()); + assert_eq!(args["replace_using_regex"], json!(false)); +} + +#[test] +fn test_split_short_single_line() { + let (s, d) = split_description("Run cargo check."); + assert_eq!(s, "Run cargo check."); + assert_eq!(d, None); +} + +#[test] +fn test_split_short_no_period() { + let (s, d) = split_description("Run cargo check"); + assert_eq!(s, "Run cargo check"); + assert_eq!(d, None); +} + +#[test] +fn test_split_two_sentences() { + let (s, d) = split_description( + "Run cargo check on a package. Supports workspace packages and feature flags.", + ); + assert_eq!(s, "Run cargo check on a package."); + assert_eq!( + d, + Some("Supports workspace packages and feature flags.".to_owned()) + ); +} + +#[test] +fn test_split_multiline() { + let input = "Search for code in a repository.\n\nSupports regex and qualifiers."; + let (s, d) = split_description(input); + assert_eq!(s, "Search for code in a repository."); + assert_eq!(d, Some("Supports regex and qualifiers.".to_owned())); +} + +#[test] +fn test_split_multiline_no_period() { + let input = "First line without period\nSecond line here."; + let (s, d) = split_description(input); + assert_eq!(s, "First line without period"); + assert_eq!(d, Some("Second line here.".to_owned())); +} + +#[test] +fn test_split_preserves_abbreviations() { + // "e.g." should not be treated as a sentence boundary. + let (s, d) = split_description("Use e.g. foo or bar."); + assert_eq!(s, "Use e.g. 
foo or bar."); + assert_eq!(d, None); +} + +#[test] +fn test_split_long_single_line_with_period() { + let input = "This is a very long description that exceeds the threshold. It contains \ + additional details about the tool's behavior."; + let (s, d) = split_description(input); + assert_eq!( + s, + "This is a very long description that exceeds the threshold." + ); + assert!(d.is_some()); +} + +#[test] +fn test_split_empty() { + let (s, d) = split_description(""); + assert_eq!(s, ""); + assert_eq!(d, None); +} + +#[test] +fn test_split_trims_whitespace() { + let (s, d) = split_description(" hello "); + assert_eq!(s, "hello"); + assert_eq!(d, None); +}