From 62278e67bd29a783c53efae766be95f3855d6890 Mon Sep 17 00:00:00 2001 From: darin Date: Thu, 19 Feb 2026 18:46:49 -0800 Subject: [PATCH 01/44] feat: multi-turn Predict + grouped Message model - Refactor Message from flat enum to Role + Vec - Reasoning continuity preserved through rig round-trips - From trivial (no data loss), RigChatMessage removed - Predict API split: forward(input) + forward_continue(chat) - ToolLoopMode::CallerManaged for caller-controlled tool loops - Full conversation history in LMResponse.chat - temp_env replaces unsafe set_var in all tests - 14 new tests: round-trip, CallerManaged conversation, reasoning preservation --- Cargo.lock | 11 + crates/dspy-rs/Cargo.toml | 3 + crates/dspy-rs/src/adapter/chat.rs | 2 +- crates/dspy-rs/src/core/lm/chat.rs | 519 +++++++++++++++--- crates/dspy-rs/src/core/lm/mod.rs | 230 +++++++- crates/dspy-rs/src/predictors/predict.rs | 118 +++- .../tests/test_caller_managed_conversation.rs | 215 ++++++++ .../tests/test_chain_of_thought_swap.rs | 23 +- crates/dspy-rs/tests/test_chat.rs | 262 +++++++-- crates/dspy-rs/tests/test_lm.rs | 114 ++-- .../dspy-rs/tests/test_message_roundtrip.rs | 345 ++++++++++++ .../tests/test_predict_conversation.rs | 159 ++++++ .../tests/test_predict_conversation_live.rs | 65 +++ crates/dspy-rs/tests/test_react_builder.rs | 23 +- crates/dspy-rs/tests/test_settings.rs | 32 +- crates/dspy-rs/tests/test_tool_call.rs | 24 +- crates/dspy-rs/tests/typed_integration.rs | 23 +- 17 files changed, 1850 insertions(+), 318 deletions(-) create mode 100644 crates/dspy-rs/tests/test_caller_managed_conversation.rs create mode 100644 crates/dspy-rs/tests/test_message_roundtrip.rs create mode 100644 crates/dspy-rs/tests/test_predict_conversation.rs create mode 100644 crates/dspy-rs/tests/test_predict_conversation_live.rs diff --git a/Cargo.lock b/Cargo.lock index d0e24aeb..11e8b33b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1236,6 +1236,7 @@ dependencies = [ "schemars", "serde", "serde_json", + "temp-env", "tempfile", "thiserror 2.0.17", "tokio", @@ -4239,6 +4240,16 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ac9aa371f599d22256307c24a9d748c041e548cbf599f35d890f9d365361790" +[[package]] +name = "temp-env" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96374855068f47402c3121c6eed88d29cb1de8f3ab27090e273e420bdabcf050" +dependencies = [ + "futures", + "parking_lot", +] + [[package]] name = "tempfile" version = "3.23.0" diff --git a/crates/dspy-rs/Cargo.toml b/crates/dspy-rs/Cargo.toml index 8a8d5a5f..94f32610 100644 --- a/crates/dspy-rs/Cargo.toml +++ b/crates/dspy-rs/Cargo.toml @@ -51,3 +51,6 @@ ignored = ["rig-core"] [features] default = [] + +[dev-dependencies] +temp-env = { version = "0.3.6", features = ["async_closure"] } diff --git a/crates/dspy-rs/src/adapter/chat.rs b/crates/dspy-rs/src/adapter/chat.rs index 2bf58a1b..8ff78cb8 100644 --- a/crates/dspy-rs/src/adapter/chat.rs +++ b/crates/dspy-rs/src/adapter/chat.rs @@ -625,7 +625,7 @@ impl ChatAdapter { where O: BamlType + for<'a> facet::Facet<'a>, { - let content = response.content(); + let content = response.text_content(); let output_format = schema.output_format(); let sections = parse_sections(&content); diff --git a/crates/dspy-rs/src/core/lm/chat.rs b/crates/dspy-rs/src/core/lm/chat.rs index d3459fd0..87aa260a 100644 --- a/crates/dspy-rs/src/core/lm/chat.rs +++ b/crates/dspy-rs/src/core/lm/chat.rs @@ -3,106 +3,469 @@ use anyhow::Result; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; -use rig::completion::{AssistantContent, Message as RigMessage, message::UserContent}; +use rig::OneOrMany; +use rig::message::{ + AssistantContent, Message as RigMessage, Reasoning, ToolCall, ToolResult, ToolResultContent, + UserContent, +}; + +// --------------------------------------------------------------------------- +// ContentBlock — one piece of content within a message +// --------------------------------------------------------------------------- + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ContentBlock { + Text { text: String }, + ToolCall { tool_call: ToolCall }, + ToolResult { tool_result: ToolResult }, + Reasoning { reasoning: Reasoning }, +} + +impl ContentBlock { + pub fn text(t: impl Into) -> Self { + ContentBlock::Text { text: t.into() } + } + + pub fn tool_call(tc: ToolCall) -> Self { + ContentBlock::ToolCall { tool_call: tc } + } + + pub fn tool_result(tr: ToolResult) -> Self { + ContentBlock::ToolResult { tool_result: tr } + } + + pub fn reasoning(r: Reasoning) -> Self { + ContentBlock::Reasoning { reasoning: r } + } +} + +// --------------------------------------------------------------------------- +// Role +// --------------------------------------------------------------------------- + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Role { + System, + User, + Assistant, +} + +impl Role { + pub fn as_str(&self) -> &'static str { + match self { + Role::System => "system", + Role::User => "user", + Role::Assistant => "assistant", + } + } +} + +// --------------------------------------------------------------------------- +// Message — a single turn in a conversation +// --------------------------------------------------------------------------- #[derive(Clone, Debug, Serialize, Deserialize)] -pub enum Message { - System { content: String }, - User { content: String }, - Assistant { content: String }, +pub struct Message { + pub role: Role, + pub content: Vec, + /// Provider-assigned message ID (e.g. Anthropic thinking turn IDs). + #[serde(skip_serializing_if = "Option::is_none", default)] + pub id: Option, } impl Message { + /// Creates a text-only message from a role string. + /// + /// # Panics + /// + /// Panics if `role` is not one of `"system"`, `"user"`, or `"assistant"`. pub fn new(role: &str, content: &str) -> Self { - match role { - "system" => Message::system(content), - "user" => Message::user(content), - "assistant" => Message::assistant(content), + let role = match role { + "system" => Role::System, + "user" => Role::User, + "assistant" => Role::Assistant, _ => panic!("Invalid role: {role}"), + }; + Self { + role, + content: vec![ContentBlock::text(content)], + id: None, } } pub fn user(content: impl Into) -> Self { - Message::User { - content: content.into(), + Self { + role: Role::User, + content: vec![ContentBlock::text(content)], + id: None, } } pub fn assistant(content: impl Into) -> Self { - Message::Assistant { - content: content.into(), + Self { + role: Role::Assistant, + content: vec![ContentBlock::text(content)], + id: None, } } pub fn system(content: impl Into) -> Self { - Message::System { - content: content.into(), + Self { + role: Role::System, + content: vec![ContentBlock::text(content)], + id: None, } } - pub fn content(&self) -> String { - match self { - Message::System { content } => content.clone(), - Message::User { content } => content.clone(), - Message::Assistant { content } => content.clone(), + /// Creates an assistant message containing a single tool call. + pub fn tool_call(tool_call: ToolCall) -> Self { + Self { + role: Role::Assistant, + content: vec![ContentBlock::tool_call(tool_call)], + id: None, } } - pub fn get_message_turn(&self) -> RigMessage { - match self { - Message::User { content } => RigMessage::user(content.clone()), - Message::Assistant { content } => RigMessage::assistant(content.clone()), - _ => panic!("Invalid role: {:?}", self), + /// Creates a user message containing a single tool result. + pub fn tool_result(tool_result: ToolResult) -> Self { + Self { + role: Role::User, + content: vec![ContentBlock::tool_result(tool_result)], + id: None, + } + } + + /// Creates an assistant message containing a single reasoning block. + pub fn reasoning(reasoning: Reasoning) -> Self { + Self { + role: Role::Assistant, + content: vec![ContentBlock::reasoning(reasoning)], + id: None, + } + } + + /// Creates a message with arbitrary content blocks. + pub fn with_content(role: Role, content: Vec) -> Self { + Self { + role, + content, + id: None, + } + } + + // -- Accessors ----------------------------------------------------------- + + /// Returns a string representation of the message's content. + /// + /// For text-only messages, returns the text. For multi-content messages, + /// returns all blocks formatted and joined with newlines. + pub fn content(&self) -> String { + let parts: Vec = self + .content + .iter() + .map(|block| match block { + ContentBlock::Text { text } => text.clone(), + ContentBlock::ToolCall { tool_call } => { + format!( + "{}({})", + tool_call.function.name, tool_call.function.arguments + ) + } + ContentBlock::ToolResult { tool_result } => tool_result + .content + .iter() + .filter_map(|item| match item { + ToolResultContent::Text(text) => Some(text.text.as_str()), + ToolResultContent::Image(_) => None, + }) + .collect::>() + .join("\n"), + ContentBlock::Reasoning { reasoning } => reasoning.display_text(), + }) + .collect(); + parts.join("\n") + } + + /// Returns only the text content, ignoring tool calls, tool results, + /// and reasoning blocks. Used by the parser to extract structured output. + pub fn text_content(&self) -> String { + self.content + .iter() + .filter_map(|block| match block { + ContentBlock::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::>() + .join("\n") + } + + // -- Content query helpers ----------------------------------------------- + + /// Returns `true` if this message contains at least one tool call. + pub fn has_tool_calls(&self) -> bool { + self.content + .iter() + .any(|b| matches!(b, ContentBlock::ToolCall { .. })) + } + + /// Returns `true` if this message contains at least one tool result. + pub fn has_tool_results(&self) -> bool { + self.content + .iter() + .any(|b| matches!(b, ContentBlock::ToolResult { .. })) + } + + /// Returns `true` if this message contains at least one reasoning block. + pub fn has_reasoning(&self) -> bool { + self.content + .iter() + .any(|b| matches!(b, ContentBlock::Reasoning { .. })) + } + + /// Extracts all tool calls from this message. + pub fn tool_calls(&self) -> Vec<&ToolCall> { + self.content + .iter() + .filter_map(|b| match b { + ContentBlock::ToolCall { tool_call } => Some(tool_call), + _ => None, + }) + .collect() + } + + // -- Rig conversion ------------------------------------------------------ + + /// Converts this message to a rig message for provider API calls. + /// + /// Returns `None` for system messages (rig handles them as preamble). + pub fn to_rig_message(&self) -> Option { + match self.role { + Role::System => None, + Role::User => { + let user_content: Vec = self + .content + .iter() + .filter_map(|block| match block { + ContentBlock::Text { text } => Some(UserContent::text(text.clone())), + ContentBlock::ToolResult { tool_result } => { + Some(UserContent::ToolResult(tool_result.clone())) + } + // ToolCall/Reasoning don't belong in user messages; skip gracefully + _ => None, + }) + .collect(); + if user_content.is_empty() { + return Some(RigMessage::user(String::new())); + } + Some(RigMessage::User { + content: OneOrMany::many(user_content) + .unwrap_or_else(|_| OneOrMany::one(UserContent::text(String::new()))), + }) + } + Role::Assistant => { + let asst_content: Vec = self + .content + .iter() + .filter_map(|block| match block { + ContentBlock::Text { text } => Some(AssistantContent::text(text.clone())), + ContentBlock::ToolCall { tool_call } => { + Some(AssistantContent::ToolCall(tool_call.clone())) + } + ContentBlock::Reasoning { reasoning } => { + Some(AssistantContent::Reasoning(reasoning.clone())) + } + // ToolResult doesn't belong in assistant messages; skip gracefully + _ => None, + }) + .collect(); + if asst_content.is_empty() { + return Some(RigMessage::assistant(String::new())); + } + Some(RigMessage::Assistant { + id: self.id.clone(), + content: OneOrMany::many(asst_content) + .unwrap_or_else(|_| OneOrMany::one(AssistantContent::text(String::new()))), + }) + } } } + // -- JSON serialization -------------------------------------------------- + pub fn to_json(&self) -> Value { - match self { - Message::System { content } => json!({ "role": "system", "content": content }), - Message::User { content } => json!({ "role": "user", "content": content }), - Message::Assistant { content } => json!({ "role": "assistant", "content": content }), + let content_json: Vec = self + .content + .iter() + .map(|block| match block { + ContentBlock::Text { text } => json!({ "type": "text", "text": text }), + ContentBlock::ToolCall { tool_call } => { + json!({ "type": "tool_call", "tool_call": tool_call }) + } + ContentBlock::ToolResult { tool_result } => { + json!({ "type": "tool_result", "tool_result": tool_result }) + } + ContentBlock::Reasoning { reasoning } => { + json!({ "type": "reasoning", "reasoning": reasoning }) + } + }) + .collect(); + + let mut msg = json!({ + "role": self.role.as_str(), + "content": content_json, + }); + + if let Some(id) = &self.id { + msg.as_object_mut() + .unwrap() + .insert("id".to_string(), json!(id)); } + + msg + } + + fn from_json_value(message: &Value) -> Result { + let role_str = message + .get("role") + .and_then(Value::as_str) + .ok_or_else(|| anyhow::anyhow!("chat message missing string role"))?; + + let role = match role_str { + "system" => Role::System, + "user" => Role::User, + "assistant" => Role::Assistant, + other => return Err(anyhow::anyhow!("unsupported chat message role: {other}")), + }; + + let id = message.get("id").and_then(Value::as_str).map(String::from); + + let content_val = message.get("content"); + + // Support both formats: + // New: "content": [{ "type": "text", "text": "..." }, ...] + // Legacy: "content": "plain string" + let content = match content_val { + Some(Value::Array(arr)) => arr + .iter() + .map(parse_content_block) + .collect::>>()?, + Some(Value::String(s)) => vec![ContentBlock::text(s.clone())], + _ => { + // Legacy type-tagged format: { "type": "tool_call", "tool_call": {...} } + match message.get("type").and_then(Value::as_str) { + Some("tool_call") => { + let tc: ToolCall = serde_json::from_value(message["tool_call"].clone())?; + vec![ContentBlock::tool_call(tc)] + } + Some("tool_result") => { + let tr: ToolResult = + serde_json::from_value(message["tool_result"].clone())?; + vec![ContentBlock::tool_result(tr)] + } + Some("reasoning") => { + let r: Reasoning = serde_json::from_value(message["reasoning"].clone())?; + vec![ContentBlock::reasoning(r)] + } + Some(other) => { + return Err(anyhow::anyhow!("unsupported chat message type: {other}")); + } + None => return Err(anyhow::anyhow!("chat message missing content field")), + } + } + }; + + Ok(Self { role, content, id }) + } +} + +fn parse_content_block(value: &Value) -> Result { + let block_type = value + .get("type") + .and_then(Value::as_str) + .ok_or_else(|| anyhow::anyhow!("content block missing type"))?; + + match block_type { + "text" => { + let text = value + .get("text") + .and_then(Value::as_str) + .ok_or_else(|| anyhow::anyhow!("text block missing text field"))?; + Ok(ContentBlock::text(text)) + } + "tool_call" => { + let tc: ToolCall = serde_json::from_value(value["tool_call"].clone())?; + Ok(ContentBlock::tool_call(tc)) + } + "tool_result" => { + let tr: ToolResult = serde_json::from_value(value["tool_result"].clone())?; + Ok(ContentBlock::tool_result(tr)) + } + "reasoning" => { + let r: Reasoning = serde_json::from_value(value["reasoning"].clone())?; + Ok(ContentBlock::reasoning(r)) + } + other => Err(anyhow::anyhow!("unsupported content block type: {other}")), } } +// --------------------------------------------------------------------------- +// From — lossless conversion, one rig message → one DSRs message +// --------------------------------------------------------------------------- + impl From for Message { fn from(message: RigMessage) -> Self { match message { RigMessage::User { content } => { - let text = content + let blocks: Vec = content .into_iter() - .find_map(|c| { - if let UserContent::Text(t) = c { - Some(t.text) - } else { - None - } + .filter_map(|item| match item { + UserContent::Text(text) => Some(ContentBlock::text(text.text)), + UserContent::ToolResult(result) => Some(ContentBlock::tool_result(result)), + UserContent::Image(_) + | UserContent::Audio(_) + | UserContent::Video(_) + | UserContent::Document(_) => None, }) - .unwrap_or_default(); - Message::user(text) + .collect(); + Message { + role: Role::User, + content: if blocks.is_empty() { + vec![ContentBlock::text(String::new())] + } else { + blocks + }, + id: None, + } } - RigMessage::Assistant { content, .. } => { - let text = content + RigMessage::Assistant { id, content } => { + let blocks: Vec = content .into_iter() - .find_map(|c| { - if let AssistantContent::Text(t) = c { - Some(t.text) - } else { - None - } + .filter_map(|item| match item { + AssistantContent::Text(text) => Some(ContentBlock::text(text.text)), + AssistantContent::ToolCall(tc) => Some(ContentBlock::tool_call(tc)), + AssistantContent::Reasoning(r) => Some(ContentBlock::reasoning(r)), + AssistantContent::Image(_) => None, }) - .unwrap_or_default(); - Message::assistant(text) + .collect(); + Message { + role: Role::Assistant, + content: if blocks.is_empty() { + vec![ContentBlock::text(String::new())] + } else { + blocks + }, + id, + } } } } } -pub struct RigChatMessage { - pub system: String, - pub conversation: Vec, - pub prompt: RigMessage, -} +// --------------------------------------------------------------------------- +// Chat — ordered sequence of messages +// --------------------------------------------------------------------------- #[derive(Clone, Debug)] pub struct Chat { @@ -139,16 +502,13 @@ impl Chat { } pub fn from_json(&self, json_dump: Value) -> Result { - let messages = json_dump.as_array().unwrap(); + let messages = json_dump + .as_array() + .ok_or_else(|| anyhow::anyhow!("chat dump must be an array"))?; let messages = messages .iter() - .map(|message| { - Message::new( - message["role"].as_str().unwrap(), - message["content"].as_str().unwrap(), - ) - }) - .collect(); + .map(Message::from_json_value) + .collect::>>()?; Ok(Self { messages }) } @@ -161,22 +521,27 @@ impl Chat { json!(messages) } - pub fn get_rig_messages(&self) -> RigChatMessage { - let system: String = self.messages[0].content(); - let conversation: Vec = if self.messages.len() > 2 { - self.messages[1..self.messages.len() - 1] - .iter() - .map(|message| message.get_message_turn()) - .collect::>() - } else { - vec![] - }; - let prompt = self.messages.last().unwrap().get_message_turn(); + // -- Rig interop --------------------------------------------------------- - RigChatMessage { - system, - conversation, - prompt, - } + /// Extracts the system prompt text from the first system message. + pub fn system_prompt(&self) -> String { + self.messages + .iter() + .find_map(|message| { + if message.role == Role::System { + Some(message.text_content()) + } else { + None + } + }) + .unwrap_or_default() + } + + /// Converts all non-system messages to rig messages for provider API calls. + pub fn to_rig_chat_history(&self) -> Vec { + self.messages + .iter() + .filter_map(Message::to_rig_message) + .collect() } } diff --git a/crates/dspy-rs/src/core/lm/mod.rs b/crates/dspy-rs/src/core/lm/mod.rs index 4d18c50b..52273430 100644 --- a/crates/dspy-rs/src/core/lm/mod.rs +++ b/crates/dspy-rs/src/core/lm/mod.rs @@ -31,6 +31,12 @@ pub struct LMResponse { pub tool_executions: Vec, } +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum ToolLoopMode { + Auto, + CallerManaged, +} + #[derive(Builder)] #[builder(finish_fn(vis = "", name = __internal_build))] pub struct LM { @@ -179,7 +185,6 @@ impl LMBuilder { struct ToolLoopResult { message: Message, - #[allow(unused)] chat_history: Vec, tool_calls: Vec, tool_executions: Vec, @@ -197,6 +202,7 @@ enum ChoiceAction { ToolCalls { calls: Vec, full_content: Box>, + assistant_text: Option, }, } @@ -224,6 +230,7 @@ fn classify_choice(choice: rig::OneOrMany) -> ChoiceAction { return ChoiceAction::ToolCalls { calls: tool_calls, full_content: Box::new(choice), + assistant_text: text, }; } @@ -261,6 +268,17 @@ async fn find_and_execute_tool( } impl LM { + fn chat_from_rig_history(system_prompt: &str, history: &[rig::message::Message]) -> Chat { + let mut chat = Chat::new(Vec::new()); + if !system_prompt.is_empty() { + chat.push_message(Message::system(system_prompt.to_string())); + } + for message in history { + chat.push_message(Message::from(message.clone())); + } + chat + } + /// Execute all tool calls in a batch, returning results paired with their calls. async fn execute_tool_batch( tools: &mut [Arc], @@ -422,6 +440,7 @@ impl LM { ChoiceAction::ToolCalls { calls, full_content, + .. } => { let context = format!("iteration {}", iteration); debug!(iteration, count = calls.len(), "executing tool calls"); @@ -445,39 +464,47 @@ impl LM { Err(anyhow::anyhow!("Max tool iterations reached")) } + pub async fn call(&self, messages: Chat, tools: Vec>) -> Result { + self.call_with_tool_loop_mode(messages, tools, ToolLoopMode::Auto) + .await + } + #[tracing::instrument( - name = "dsrs.lm.call", + name = "dsrs.lm.call_with_tool_loop_mode", level = "debug", skip(self, messages, tools), fields( model = %self.model, message_count = messages.len(), tool_count = tools.len(), - cache_enabled = self.cache + cache_enabled = self.cache, + tool_loop_mode = ?tool_loop_mode ) )] - pub async fn call(&self, messages: Chat, tools: Vec>) -> Result { + pub async fn call_with_tool_loop_mode( + &self, + messages: Chat, + tools: Vec>, + tool_loop_mode: ToolLoopMode, + ) -> Result { use rig::OneOrMany; use rig::completion::CompletionRequest; - let request_messages = messages.get_rig_messages(); + let system_prompt = messages.system_prompt(); + let chat_history = messages.to_rig_chat_history(); let mut tool_definitions = Vec::new(); for tool in &tools { tool_definitions.push(tool.definition("".to_string()).await); } trace!( - conversation_messages = request_messages.conversation.len(), + conversation_messages = chat_history.len(), tool_definitions = tool_definitions.len(), "prepared completion request inputs" ); - // Build the completion request manually - let mut chat_history = request_messages.conversation; - chat_history.push(request_messages.prompt); - let request = CompletionRequest { model: None, - preamble: Some(request_messages.system.clone()), + preamble: Some(system_prompt.clone()), chat_history: if chat_history.len() == 1 { OneOrMany::one(chat_history.clone().into_iter().next().unwrap()) } else { @@ -517,12 +544,17 @@ impl LM { // Scan ALL content blocks in the response — don't just look at .first(). // Responses can be [Reasoning, ToolCall] or [Reasoning, Text]. let mut tool_loop_result = None; - let first_choice = match classify_choice(response.choice) { + let mut returned_tool_calls = Vec::new(); + let mut assistant_content_for_history: Option> = None; + let mut append_output_after_history = false; + let classified = classify_choice(response.choice.clone()); + let first_choice = match classified { ChoiceAction::Text(text) => Message::assistant(&text), ChoiceAction::ToolCalls { calls, full_content, - } if !tools.is_empty() => { + assistant_text, + } if tool_loop_mode == ToolLoopMode::Auto && !tools.is_empty() => { debug!(count = calls.len(), "entering tool loop"); let result = self .execute_tool_loop( @@ -531,24 +563,58 @@ impl LM { tools, tool_definitions, chat_history, - request_messages.system, + system_prompt.clone(), &mut accumulated_usage, ) .await?; let message = result.message.clone(); tool_loop_result = Some(result); + append_output_after_history = true; message } - ChoiceAction::ToolCalls { calls, .. } => { + ChoiceAction::ToolCalls { calls, .. } + if tool_loop_mode == ToolLoopMode::Auto && tools.is_empty() => + { let names: Vec<_> = calls.iter().map(|tc| tc.function.name.as_str()).collect(); warn!(?names, "tools requested but no tools available"); let msg = format!("Tool calls requested: {:?}, but no tools available", names); + assistant_content_for_history = Some(rig::OneOrMany::many( + calls + .into_iter() + .map(AssistantContent::ToolCall) + .collect::>(), + )?); + append_output_after_history = true; Message::assistant(&msg) } + ChoiceAction::ToolCalls { + calls, + assistant_text, + full_content, + } => { + returned_tool_calls = calls; + assistant_content_for_history = Some(*full_content); + Message::assistant(assistant_text.unwrap_or_default()) + } }; - let mut full_chat = messages.clone(); - full_chat.push_message(first_choice.clone()); + let mut full_chat = if let Some(result) = tool_loop_result.as_ref() { + Self::chat_from_rig_history(&system_prompt, &result.chat_history) + } else { + let mut chat = messages.clone(); + if let Some(content) = assistant_content_for_history { + // Convert grouped rig content into a single grouped Message. + let rig_msg = rig::message::Message::Assistant { id: None, content }; + chat.push_message(Message::from(rig_msg)); + } else { + // Text-only path: preserve a single assistant response turn. + chat.push_message(first_choice.clone()); + } + chat + }; + if append_output_after_history { + full_chat.push_message(first_choice.clone()); + } debug!( tool_calls = tool_loop_result .as_ref() @@ -569,7 +635,7 @@ impl LM { tool_calls: tool_loop_result .as_ref() .map(|result| result.tool_calls.clone()) - .unwrap_or_default(), + .unwrap_or(returned_tool_calls), tool_executions: tool_loop_result .map(|result| result.tool_executions) .unwrap_or_default(), @@ -648,9 +714,7 @@ impl DummyLM { prediction: String, ) -> Result { let mut full_chat = messages.clone(); - full_chat.push_message(Message::Assistant { - content: prediction.clone(), - }); + full_chat.push_message(Message::assistant(prediction.clone())); if self.cache && let Some(cache) = self.cache_handler.as_ref() @@ -682,9 +746,7 @@ impl DummyLM { } Ok(LMResponse { - output: Message::Assistant { - content: prediction.clone(), - }, + output: Message::assistant(prediction.clone()), usage: LmUsage::default(), chat: full_chat, tool_calls: Vec::new(), @@ -716,6 +778,10 @@ mod tests { use super::*; use rig::OneOrMany; use rig::completion::AssistantContent; + use rig::completion::ToolDefinition; + use rig::tool::Tool; + use std::sync::Arc; + use std::sync::atomic::{AtomicUsize, Ordering}; fn make_tool_call(name: &str) -> AssistantContent { AssistantContent::tool_call( @@ -749,10 +815,12 @@ mod tests { ChoiceAction::ToolCalls { calls, full_content, + assistant_text, } => { assert_eq!(calls.len(), 1); assert_eq!(calls[0].function.name, "search"); assert_eq!(full_content.iter().count(), 1); + assert!(assistant_text.is_none()); } ChoiceAction::Text(_) => panic!("expected ToolCalls, got Text"), } @@ -770,11 +838,13 @@ mod tests { ChoiceAction::ToolCalls { calls, full_content, + assistant_text, } => { assert_eq!(calls.len(), 1); assert_eq!(calls[0].function.name, "search"); // full_content preserves both blocks assert_eq!(full_content.iter().count(), 2); + assert!(assistant_text.is_none()); } ChoiceAction::Text(_) => panic!("expected ToolCalls, got Text"), } @@ -809,9 +879,14 @@ mod tests { OneOrMany::many(vec![make_text("some text"), make_tool_call("search")]).unwrap(); match classify_choice(choice) { - ChoiceAction::ToolCalls { calls, .. } => { + ChoiceAction::ToolCalls { + calls, + assistant_text, + .. + } => { assert_eq!(calls.len(), 1); assert_eq!(calls[0].function.name, "search"); + assert_eq!(assistant_text.as_deref(), Some("some text")); } ChoiceAction::Text(_) => panic!("expected ToolCalls, got Text"), } @@ -830,11 +905,13 @@ mod tests { ChoiceAction::ToolCalls { calls, full_content, + assistant_text, } => { assert_eq!(calls.len(), 2); assert_eq!(calls[0].function.name, "search"); assert_eq!(calls[1].function.name, "calculate"); assert_eq!(full_content.iter().count(), 3); + assert!(assistant_text.is_none()); } ChoiceAction::Text(_) => panic!("expected ToolCalls, got Text"), } @@ -850,4 +927,107 @@ mod tests { ChoiceAction::ToolCalls { .. } => panic!("expected Text, got ToolCalls"), } } + + #[derive(Clone)] + struct CountingTool { + calls: Arc, + } + + #[derive(Debug)] + struct CountingToolError; + + impl std::fmt::Display for CountingToolError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "counting tool error") + } + } + + impl std::error::Error for CountingToolError {} + + impl Tool for CountingTool { + const NAME: &'static str = "counter"; + type Error = CountingToolError; + type Args = serde_json::Value; + type Output = String; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + ToolDefinition { + name: Self::NAME.to_string(), + description: "counter tool".to_string(), + parameters: serde_json::json!({ + "type": "object", + "additionalProperties": true + }), + } + } + + async fn call(&self, _args: Self::Args) -> Result { + self.calls.fetch_add(1, Ordering::SeqCst); + Ok("counted".to_string()) + } + } + + fn test_lm_with_model(model: TestCompletionModel) -> LM { + LM { + base_url: None, + api_key: None, + model: "openai:gpt-4o-mini".to_string(), + temperature: 0.0, + max_tokens: 128, + max_tool_iterations: 4, + cache: false, + cache_handler: None, + client: Some(Arc::new(LMClient::Test(model))), + } + } + + #[tokio::test] + async fn call_with_caller_managed_mode_returns_tool_calls_without_executing() { + let model = TestCompletionModel::new([make_tool_call("counter")]); + let lm = test_lm_with_model(model); + + let call_count = Arc::new(AtomicUsize::new(0)); + let tools: Vec> = vec![Arc::new(CountingTool { + calls: Arc::clone(&call_count), + })]; + + let chat = Chat::new(vec![Message::user("Use the counter tool")]); + let response = lm + .call_with_tool_loop_mode(chat, tools, ToolLoopMode::CallerManaged) + .await + .expect("caller-managed call should succeed"); + + assert_eq!(response.tool_calls.len(), 1); + assert!(response.tool_executions.is_empty()); + assert_eq!(call_count.load(Ordering::SeqCst), 0); + assert_eq!(response.output.content(), ""); + assert_eq!(response.chat.len(), 2); + assert!(response.chat.messages[1].has_tool_calls()); + } + + #[tokio::test] + async fn call_default_auto_mode_executes_tool_loop() { + let model = TestCompletionModel::new([make_tool_call("counter"), make_text("done")]); + let lm = test_lm_with_model(model); + + let call_count = Arc::new(AtomicUsize::new(0)); + let tools: Vec> = vec![Arc::new(CountingTool { + calls: Arc::clone(&call_count), + })]; + + let chat = Chat::new(vec![Message::user("Use the counter tool")]); + let response = lm + .call(chat, tools) + .await + .expect("auto call should succeed"); + + assert_eq!(response.tool_calls.len(), 1); + assert_eq!(response.tool_executions.len(), 1); + assert_eq!(call_count.load(Ordering::SeqCst), 1); + assert_eq!(response.output.content(), "done"); + assert_eq!(response.chat.len(), 4); + assert!(response.chat.messages[1].has_tool_calls()); + assert!(response.chat.messages[2].has_tool_results()); + assert_eq!(response.chat.messages[3].role, Role::Assistant); + } } diff --git a/crates/dspy-rs/src/predictors/predict.rs b/crates/dspy-rs/src/predictors/predict.rs index b45a0e69..d3196bcd 100644 --- a/crates/dspy-rs/src/predictors/predict.rs +++ b/crates/dspy-rs/src/predictors/predict.rs @@ -189,12 +189,43 @@ impl Predict { S::Input: BamlType, S::Output: BamlType, { - let lm = { - let guard = GLOBAL_SETTINGS.read().unwrap(); - let settings = guard.as_ref().unwrap(); - Arc::clone(&settings.lm) - }; + let chat = self.build_chat(&input)?; + let (predicted, _) = self.call_and_parse(chat).await?; + Ok(predicted) + } + /// Continues a prior conversation and parses the LM's response. + /// + /// The caller owns the `Chat` between calls: + /// 1. Call [`forward`] to get the first turn's `(Predicted, Chat)`. + /// 2. Append a follow-up user message to the returned `Chat`. + /// 3. Call `forward_continue` with the updated `Chat`. + /// + /// The LM response is parsed using the same `[[ ## field ## ]]` protocol. + /// The caller is responsible for including format instructions in follow-up + /// messages if the model needs reminding of the output format. + pub async fn forward_continue( + &self, + chat: Chat, + ) -> Result<(Predicted, Chat), PredictError> + where + S::Input: BamlType, + S::Output: BamlType, + { + trace!(message_count = chat.len(), "continuing prior chat"); + self.call_and_parse(chat).await + } + + /// Builds the first-turn chat from the signature, demos, and input. + /// + /// Returns a [`Chat`] ready to pass to [`call_and_parse`](Predict::call_and_parse) + /// or [`forward_continue`](Predict::forward_continue). Useful when you need to + /// inspect or modify the prompt before sending it to the LM. + #[allow(clippy::result_large_err)] + pub fn build_chat(&self, input: &S::Input) -> Result + where + S::Input: BamlType, + { let chat_adapter = ChatAdapter; let system = match chat_adapter .format_system_message_typed_with_instruction::(self.instruction_override.as_deref()) @@ -211,7 +242,7 @@ impl Predict { } }; - let user = chat_adapter.format_user_message_typed::(&input); + let user = chat_adapter.format_user_message_typed::(input); trace!( system_len = system.len(), user_len = user.len(), @@ -228,6 +259,27 @@ impl Predict { } chat.push("user", &user); trace!(message_count = chat.len(), "chat constructed"); + Ok(chat) + } + + /// Calls the LM with the given chat and parses the response. + /// + /// This is the shared implementation behind [`forward`](Predict::forward) and + /// [`forward_continue`](Predict::forward_continue). Use it directly when you need + /// both the prediction and the updated conversation history. + pub async fn call_and_parse( + &self, + chat: Chat, + ) -> Result<(Predicted, Chat), PredictError> + where + S::Input: BamlType, + S::Output: BamlType, + { + let lm = { + let guard = GLOBAL_SETTINGS.read().unwrap(); + let settings = guard.as_ref().unwrap(); + Arc::clone(&settings.lm) + }; let response = match lm.call(chat, self.tools.clone()).await { Ok(response) => response, @@ -249,6 +301,14 @@ impl Predict { "lm response received" ); + let crate::core::lm::LMResponse { + output, + usage, + chat, + tool_calls, + tool_executions, + } = response; + let node_id = if crate::trace::is_tracing() { crate::trace::record_node( crate::trace::NodeType::Predict { @@ -261,27 +321,27 @@ impl Predict { None }; - let raw_response = response.output.content().to_string(); - let lm_usage = response.usage.clone(); + let chat_adapter = ChatAdapter; + let raw_response = output.content().to_string(); + let lm_usage = usage.clone(); - let (typed_output, field_metas) = - match chat_adapter.parse_response_typed::(&response.output) { - Ok(parsed) => parsed, - Err(err) => { - let failed_fields = err.fields(); - debug!( - failed_fields = failed_fields.len(), - fields = ?failed_fields, - raw_response_len = raw_response.len(), - "typed parse failed" - ); - return Err(PredictError::Parse { - source: err, - raw_response, - lm_usage, - }); - } - }; + let (typed_output, field_metas) = match chat_adapter.parse_response_typed::(&output) { + Ok(parsed) => parsed, + Err(err) => { + let failed_fields = err.fields(); + debug!( + failed_fields = failed_fields.len(), + fields = ?failed_fields, + raw_response_len = raw_response.len(), + "typed parse failed" + ); + return Err(PredictError::Parse { + source: err, + raw_response, + lm_usage, + }); + } + }; let checks_total = field_metas .values() @@ -316,13 +376,13 @@ impl Predict { let metadata = CallMetadata::new( raw_response, lm_usage, - response.tool_calls, - response.tool_executions, + tool_calls, + tool_executions, node_id, field_metas, ); - Ok(Predicted::new(typed_output, metadata)) + Ok((Predicted::new(typed_output, metadata), chat)) } } diff --git a/crates/dspy-rs/tests/test_caller_managed_conversation.rs b/crates/dspy-rs/tests/test_caller_managed_conversation.rs new file mode 100644 index 00000000..dc2f9b2f --- /dev/null +++ b/crates/dspy-rs/tests/test_caller_managed_conversation.rs @@ -0,0 +1,215 @@ +//! CallerManaged + tools + conversation flow test. +//! +//! This is the RLM critical path: the caller controls tool execution and +//! manages the conversation loop, not the LM layer's auto tool loop. + +use dspy_rs::{ + ChatAdapter, LM, LMClient, Message, Predict, Role, Signature, TestCompletionModel, + ToolLoopMode, configure, +}; +use rig::completion::AssistantContent; +use rig::message::{Text, ToolCall, ToolFunction}; +use std::sync::LazyLock; +use tokio::sync::Mutex; + +static SETTINGS_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); + +fn response_with_fields(fields: &[(&str, &str)]) -> String { + let mut response = String::new(); + for (name, value) in fields { + response.push_str(&format!("[[ ## {name} ## ]]\n{value}\n\n")); + } + response.push_str("[[ ## completed ## ]]\n"); + response +} + +fn text_response(text: impl Into) -> AssistantContent { + AssistantContent::Text(Text { text: text.into() }) +} + +async fn build_test_lm(responses: Vec) -> (LM, TestCompletionModel) { + let client = TestCompletionModel::new(responses); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .build(), + ) + .await + .unwrap() + .with_client(LMClient::Test(client.clone())) + .await + .unwrap(); + (lm, client) +} + +#[derive(Signature, Clone, Debug, PartialEq)] +/// Code execution signature for RLM-style interaction. +struct CodeExec { + #[input] + prompt: String, + + #[output] + result: String, +} + +/// The full RLM-style loop: +/// 1. Predict builds initial chat → calls LM → model requests a tool call +/// 2. CallerManaged mode: LM returns the tool call without executing it +/// 3. Caller manually executes the tool, appends result to chat +/// 4. Caller calls forward_continue → LM returns the final text answer +/// +/// This is the exact pattern RLM will use for Python REPL interaction. +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn caller_managed_tool_loop_with_conversation() { + let _lock = SETTINGS_LOCK.lock().await; + + // Response 1: model wants to call a tool (returned as text since TestCompletionModel + // only supports single-content responses via AssistantContent) + let tool_call_response = + text_response("[[ ## result ## ]]\nNeed to execute code first\n\n[[ ## completed ## ]]\n"); + // Response 2: after seeing tool result, model gives final answer + let final_response = text_response(response_with_fields(&[("result", "42")])); + + let (lm, _client) = build_test_lm(vec![tool_call_response, final_response]).await; + configure(lm, ChatAdapter {}); + + let predict = Predict::::new(); + let input = CodeExecInput { + prompt: "Calculate 6 * 7".to_string(), + }; + + // Turn 1: Build chat and call LM + let chat = predict + .build_chat(&input) + .expect("build_chat should succeed"); + let (first_result, mut chat) = predict + .call_and_parse(chat) + .await + .expect("first turn should succeed"); + assert_eq!( + first_result.into_inner().result, + "Need to execute code first" + ); + + // Caller simulates tool execution: append user message with result + chat.push_message(Message::user("Tool output: 42")); + + // Turn 2: Continue the conversation + let (second_result, final_chat) = predict + .forward_continue(chat) + .await + .expect("second turn should succeed"); + assert_eq!(second_result.into_inner().result, "42"); + + // Verify chat grew across turns + assert!( + final_chat.len() >= 5, + "chat should have system + user + asst + user + asst, got {}", + final_chat.len() + ); + + // Verify turn ordering + assert_eq!(final_chat.messages[0].role, Role::System); + assert_eq!(final_chat.messages[1].role, Role::User); + assert_eq!(final_chat.messages[2].role, Role::Assistant); + assert_eq!(final_chat.messages[3].role, Role::User); // caller's tool result + assert_eq!(final_chat.messages[4].role, Role::Assistant); // final answer +} + +/// Tests the LM-level CallerManaged mode directly: when a tool call is requested +/// with CallerManaged mode, the LM returns the tool calls without executing them +/// and the caller controls what happens next. +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn lm_caller_managed_returns_tool_calls_in_chat_history() { + let _lock = SETTINGS_LOCK.lock().await; + + // Model responds with a tool call + let tool_call_content = AssistantContent::ToolCall(ToolCall::new( + "tc-1".to_string(), + ToolFunction { + name: "python_repl".to_string(), + arguments: serde_json::json!({"code": "print(6 * 7)"}), + }, + )); + + let (lm, _client) = build_test_lm(vec![tool_call_content]).await; + + let chat = dspy_rs::Chat::new(vec![Message::user("Run some code")]); + let response = lm + .call_with_tool_loop_mode(chat, vec![], ToolLoopMode::CallerManaged) + .await + .expect("caller-managed call should succeed"); + + // Tool calls returned but NOT executed + assert_eq!(response.tool_calls.len(), 1); + assert_eq!(response.tool_calls[0].function.name, "python_repl"); + assert!( + response.tool_executions.is_empty(), + "CallerManaged should not execute tools" + ); + + // Chat history should contain the tool call message + assert!( + response.chat.messages.iter().any(|m| m.has_tool_calls()), + "chat history should include the tool call message" + ); +} + +/// Multi-turn with parse failure on second turn verifies that errors +/// include the correct raw_response from the continuation, not the first turn. +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn parse_failure_on_second_turn_includes_correct_raw_response() { + let _lock = SETTINGS_LOCK.lock().await; + + let good_response = text_response(response_with_fields(&[("result", "first answer")])); + // Second response is malformed — no field markers + let bad_response = text_response("This response has no field markers at all."); + + let (lm, _client) = build_test_lm(vec![good_response, bad_response]).await; + configure(lm, ChatAdapter {}); + + let predict = Predict::::new(); + let input = CodeExecInput { + prompt: "test".to_string(), + }; + + // Turn 1: succeeds + let chat = predict.build_chat(&input).expect("build_chat"); + let (first_result, mut chat) = predict.call_and_parse(chat).await.expect("turn 1"); + assert_eq!(first_result.into_inner().result, "first answer"); + + // Turn 2: should fail with parse error containing the bad response + chat.push_message(Message::user("follow up")); + let err = predict + .forward_continue(chat) + .await + .expect_err("second turn should fail"); + + match err { + dspy_rs::PredictError::Parse { + raw_response, + source, + .. + } => { + assert!( + raw_response.contains("no field markers"), + "raw_response should be from the second turn, got: {}", + raw_response + ); + // The error should mention the missing field + let fields = source.fields(); + assert!( + !fields.is_empty() || source.field().is_some(), + "parse error should identify which field(s) failed" + ); + } + other => panic!( + "expected PredictError::Parse, got: {:?}", + std::mem::discriminant(&other) + ), + } +} diff --git a/crates/dspy-rs/tests/test_chain_of_thought_swap.rs b/crates/dspy-rs/tests/test_chain_of_thought_swap.rs index 6e7614a0..80819747 100644 --- a/crates/dspy-rs/tests/test_chain_of_thought_swap.rs +++ b/crates/dspy-rs/tests/test_chain_of_thought_swap.rs @@ -23,19 +23,18 @@ fn text_response(text: impl Into) -> AssistantContent { } async fn configure_test_lm(responses: Vec) { - unsafe { - std::env::set_var("OPENAI_API_KEY", "test"); - } - let client = TestCompletionModel::new(responses.into_iter().map(text_response)); - let lm = LM::builder() - .model("openai:gpt-4o-mini".to_string()) - .build() - .await - .unwrap() - .with_client(LMClient::Test(client)) - .await - .unwrap(); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .build(), + ) + .await + .unwrap() + .with_client(LMClient::Test(client)) + .await + .unwrap(); configure(lm, ChatAdapter {}); } diff --git a/crates/dspy-rs/tests/test_chat.rs b/crates/dspy-rs/tests/test_chat.rs index 0e231301..fd8e0fe9 100644 --- a/crates/dspy-rs/tests/test_chat.rs +++ b/crates/dspy-rs/tests/test_chat.rs @@ -1,4 +1,9 @@ -use dspy_rs::core::{Chat, Message}; +use dspy_rs::core::lm::chat::{Chat, ContentBlock, Message, Role}; +use rig::OneOrMany; +use rig::message::{ + AssistantContent, Message as RigMessage, Reasoning, ToolCall, ToolFunction, ToolResult, + ToolResultContent, UserContent, +}; use rstest::*; use serde_json::json; @@ -10,20 +15,14 @@ fn test_chat_init() { Message::assistant("Hello, world to you!"), ]); - let json_value = chat.to_json(); - let json = json_value.as_array().unwrap(); - assert_eq!(chat.len(), 3); - assert_eq!(json[0]["role"], "system"); assert!(!chat.is_empty()); - assert_eq!( - json[0]["content"], - "You are a helpful assistant.".to_string() - ); - assert_eq!(json[1]["role"], "user"); - assert_eq!(json[1]["content"], "Hello, world!".to_string()); - assert_eq!(json[2]["role"], "assistant"); - assert_eq!(json[2]["content"], "Hello, world to you!".to_string()); + assert_eq!(chat.messages[0].role, Role::System); + assert_eq!(chat.messages[0].content(), "You are a helpful assistant."); + assert_eq!(chat.messages[1].role, Role::User); + assert_eq!(chat.messages[1].content(), "Hello, world!"); + assert_eq!(chat.messages[2].role, Role::Assistant); + assert_eq!(chat.messages[2].content(), "Hello, world to you!"); } #[rstest] @@ -31,11 +30,9 @@ fn test_chat_push() { let mut chat = Chat::new(vec![]); chat.push("user", "Hello, world!"); - let json_value = chat.to_json(); - let json = json_value.as_array().unwrap(); - assert_eq!(json.len(), 1); - assert_eq!(json[0]["role"], "user"); - assert_eq!(json[0]["content"], "Hello, world!".to_string()); + assert_eq!(chat.len(), 1); + assert_eq!(chat.messages[0].role, Role::User); + assert_eq!(chat.messages[0].content(), "Hello, world!"); } #[rstest] @@ -44,47 +41,48 @@ fn test_chat_pop() { chat.push("user", "Hello, world!"); chat.pop(); - let json_value = chat.to_json(); - let json = json_value.as_array().unwrap(); - assert_eq!(json.len(), 0); + assert_eq!(chat.len(), 0); } #[rstest] -fn test_chat_to_json() { +fn test_chat_to_json_and_back() { let chat = Chat::new(vec![ Message::system("You are a helpful assistant."), Message::user("Hello, world!"), Message::assistant("Hello, world to you!"), ]); - let json = chat.to_json(); + let json_dump = chat.to_json(); + let reparsed = Chat::new(vec![]).from_json(json_dump).unwrap(); + + assert_eq!(reparsed.len(), 3); + assert_eq!(reparsed.messages[0].role, Role::System); assert_eq!( - json.to_string(), - "[{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},{\"role\":\"user\",\"content\":\"Hello, world!\"},{\"role\":\"assistant\",\"content\":\"Hello, world to you!\"}]" + reparsed.messages[0].content(), + "You are a helpful assistant." ); + assert_eq!(reparsed.messages[1].role, Role::User); + assert_eq!(reparsed.messages[1].content(), "Hello, world!"); + assert_eq!(reparsed.messages[2].role, Role::Assistant); + assert_eq!(reparsed.messages[2].content(), "Hello, world to you!"); } #[rstest] -fn test_chat_from_json() { +fn test_chat_from_legacy_json() { + // Legacy format: "content" is a plain string let json = json!([ {"role":"system","content":"You are a helpful assistant."}, {"role":"user","content":"Hello, world!"}, {"role":"assistant","content":"Hello, world to you!"} ]); - let empty_chat = Chat::new(vec![]); - let chat = empty_chat.from_json(json).unwrap(); - - let json_value = chat.to_json(); - let json = json_value.as_array().unwrap(); + let chat = Chat::new(vec![]).from_json(json).unwrap(); assert_eq!(chat.len(), 3); - assert_eq!(json[0]["role"], "system"); - assert_eq!( - json[0]["content"], - "You are a helpful assistant.".to_string() - ); - assert_eq!(json[1]["role"], "user"); - assert_eq!(json[1]["content"], "Hello, world!".to_string()); - assert_eq!(json[2]["content"], "Hello, world to you!".to_string()); + assert_eq!(chat.messages[0].role, Role::System); + assert_eq!(chat.messages[0].content(), "You are a helpful assistant."); + assert_eq!(chat.messages[1].role, Role::User); + assert_eq!(chat.messages[1].content(), "Hello, world!"); + assert_eq!(chat.messages[2].role, Role::Assistant); + assert_eq!(chat.messages[2].content(), "Hello, world to you!"); } #[rstest] @@ -103,20 +101,16 @@ fn test_chat_push_all() { chat1.push_all(&chat2); assert_eq!(chat1.len(), 5); - - let json_value = chat1.to_json(); - let json = json_value.as_array().unwrap(); - - assert_eq!(json[0]["role"], "system"); - assert_eq!(json[0]["content"], "You are a helpful assistant."); - assert_eq!(json[1]["role"], "user"); - assert_eq!(json[1]["content"], "Hello!"); - assert_eq!(json[2]["role"], "assistant"); - assert_eq!(json[2]["content"], "Hi there!"); - assert_eq!(json[3]["role"], "user"); - assert_eq!(json[3]["content"], "How are you?"); - assert_eq!(json[4]["role"], "assistant"); - assert_eq!(json[4]["content"], "I'm doing well, thank you!"); + assert_eq!(chat1.messages[0].role, Role::System); + assert_eq!(chat1.messages[0].content(), "You are a helpful assistant."); + assert_eq!(chat1.messages[1].role, Role::User); + assert_eq!(chat1.messages[1].content(), "Hello!"); + assert_eq!(chat1.messages[2].role, Role::Assistant); + assert_eq!(chat1.messages[2].content(), "Hi there!"); + assert_eq!(chat1.messages[3].role, Role::User); + assert_eq!(chat1.messages[3].content(), "How are you?"); + assert_eq!(chat1.messages[4].role, Role::Assistant); + assert_eq!(chat1.messages[4].content(), "I'm doing well, thank you!"); } #[rstest] @@ -127,10 +121,164 @@ fn test_chat_push_all_empty() { chat1.push_all(&empty_chat); assert_eq!(chat1.len(), 1); + assert_eq!(chat1.messages[0].role, Role::System); + assert_eq!(chat1.messages[0].content(), "System message"); +} + +#[rstest] +fn test_new_variants_round_trip_json() { + let call = ToolCall::new( + "call-1".to_string(), + ToolFunction { + name: "lookup".to_string(), + arguments: json!({ "query": "rust" }), + }, + ); + let result = ToolResult { + id: "call-1".to_string(), + call_id: Some("provider-call-1".to_string()), + content: OneOrMany::one(ToolResultContent::text("result payload")), + }; + let reasoning = Reasoning::new("thinking..."); + + let chat = Chat::new(vec![ + Message::system("You are a tool-using assistant."), + Message::tool_call(call.clone()), + Message::tool_result(result.clone()), + Message::reasoning(reasoning.clone()), + ]); + + let json_dump = chat.to_json(); + let reparsed = Chat::new(vec![]).from_json(json_dump).unwrap(); + assert_eq!(reparsed.len(), 4); + + assert_eq!(reparsed.messages[0].role, Role::System); + + assert_eq!(reparsed.messages[1].role, Role::Assistant); + assert!(reparsed.messages[1].has_tool_calls()); + let reparsed_calls = reparsed.messages[1].tool_calls(); + assert_eq!(reparsed_calls[0].function.name, call.function.name); + + assert_eq!(reparsed.messages[2].role, Role::User); + assert!(reparsed.messages[2].has_tool_results()); + + assert_eq!(reparsed.messages[3].role, Role::Assistant); + assert!(reparsed.messages[3].has_reasoning()); +} - let json_value = chat1.to_json(); - let json = json_value.as_array().unwrap(); +#[rstest] +fn test_system_prompt_and_rig_chat_history() { + let chat = Chat::new(vec![ + Message::system("Be helpful"), + Message::user("Hello"), + Message::assistant("Hi!"), + ]); + + assert_eq!(chat.system_prompt(), "Be helpful"); + let history = chat.to_rig_chat_history(); + assert_eq!(history.len(), 2); // system excluded +} + +#[rstest] +fn test_empty_chat_system_prompt_and_rig_history() { + let chat = Chat::new(vec![]); + + assert_eq!(chat.system_prompt(), ""); + let history = chat.to_rig_chat_history(); + assert!(history.is_empty()); +} + +#[rstest] +fn test_from_rig_message_preserves_all_content() { + // User with text + tool result — both preserved + let user_msg = RigMessage::User { + content: OneOrMany::many(vec![ + UserContent::text("some context"), + UserContent::ToolResult(ToolResult { + id: "id-1".to_string(), + call_id: None, + content: OneOrMany::one(ToolResultContent::text("ok")), + }), + ]) + .unwrap(), + }; + let converted = Message::from(user_msg); + assert_eq!(converted.role, Role::User); + assert_eq!(converted.content.len(), 2); + assert!(matches!(converted.content[0], ContentBlock::Text { .. })); + assert!(matches!( + converted.content[1], + ContentBlock::ToolResult { .. } + )); + + // Assistant with reasoning + tool call — both preserved (was lossy before) + let assistant_msg = RigMessage::Assistant { + id: Some("asst-123".to_string()), + content: OneOrMany::many(vec![ + AssistantContent::Reasoning(Reasoning::new("step by step")), + AssistantContent::ToolCall(ToolCall::new( + "tool-2".to_string(), + ToolFunction { + name: "search".to_string(), + arguments: json!({ "q": "x" }), + }, + )), + ]) + .unwrap(), + }; + let converted = Message::from(assistant_msg); + assert_eq!(converted.role, Role::Assistant); + assert_eq!(converted.id, Some("asst-123".to_string())); + assert_eq!(converted.content.len(), 2); + assert!(converted.has_reasoning()); + assert!(converted.has_tool_calls()); +} + +#[rstest] +fn test_rig_round_trip_preserves_grouped_content() { + // Create a grouped assistant message with reasoning + tool call + let original_rig = RigMessage::Assistant { + id: None, + content: OneOrMany::many(vec![ + AssistantContent::Reasoning(Reasoning::new("thinking")), + AssistantContent::ToolCall(ToolCall::new( + "tc-1".to_string(), + ToolFunction { + name: "search".to_string(), + arguments: json!({"q": "rust"}), + }, + )), + ]) + .unwrap(), + }; + + // Convert to DSRs Message + let dsrs_msg = Message::from(original_rig); + assert_eq!(dsrs_msg.content.len(), 2); + + // Convert back to rig message + let round_tripped = dsrs_msg.to_rig_message().unwrap(); + match round_tripped { + RigMessage::Assistant { content, .. } => { + assert_eq!(content.iter().count(), 2); // Both blocks preserved! + } + _ => panic!("expected assistant message"), + } +} + +#[rstest] +fn test_text_content_filters_non_text_blocks() { + let msg = Message::with_content( + Role::Assistant, + vec![ + ContentBlock::reasoning(Reasoning::new("thinking")), + ContentBlock::text("the answer is 42"), + ], + ); - assert_eq!(json[0]["role"], "system"); - assert_eq!(json[0]["content"], "System message"); + // text_content() returns only Text blocks + assert_eq!(msg.text_content(), "the answer is 42"); + // content() returns everything + assert!(msg.content().contains("thinking")); + assert!(msg.content().contains("the answer is 42")); } diff --git a/crates/dspy-rs/tests/test_lm.rs b/crates/dspy-rs/tests/test_lm.rs index 41106d9f..9fed309c 100644 --- a/crates/dspy-rs/tests/test_lm.rs +++ b/crates/dspy-rs/tests/test_lm.rs @@ -84,16 +84,15 @@ fn test_lm_usage_add() { #[tokio::test] #[cfg_attr(miri, ignore)] async fn test_lm_with_cache_enabled() { - unsafe { - std::env::set_var("OPENAI_API_KEY", "test"); - } - // Create LM with cache enabled - let lm = LM::builder() - .model("openai:gpt-4o-mini".to_string()) - .cache(true) - .build() - .await - .unwrap(); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .cache(true) + .build(), + ) + .await + .unwrap(); // Verify cache handler is initialized assert!(lm.cache_handler.is_some()); @@ -102,16 +101,15 @@ async fn test_lm_with_cache_enabled() { #[tokio::test] #[cfg_attr(miri, ignore)] async fn test_lm_with_cache_disabled() { - unsafe { - std::env::set_var("OPENAI_API_KEY", "test"); - } - // Create LM with cache explicitly disabled - let lm = LM::builder() - .model("openai:gpt-4o-mini".to_string()) - .cache(false) - .build() - .await - .unwrap(); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .cache(false) + .build(), + ) + .await + .unwrap(); // Verify cache handler is NOT initialized when cache is disabled assert!(lm.cache_handler.is_none()); @@ -120,16 +118,15 @@ async fn test_lm_with_cache_disabled() { #[tokio::test] #[cfg_attr(miri, ignore)] async fn test_lm_cache_initialization_on_first_call() { - unsafe { - std::env::set_var("OPENAI_API_KEY", "test"); - } - // Create LM with cache enabled - let lm = LM::builder() - .model("openai:gpt-4o-mini".to_string()) - .cache(true) - .build() - .await - .unwrap(); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .cache(true) + .build(), + ) + .await + .unwrap(); // After build, cache_handler should be initialized assert!(lm.cache_handler.is_some()); @@ -138,20 +135,19 @@ async fn test_lm_cache_initialization_on_first_call() { #[tokio::test] #[cfg_attr(miri, ignore)] async fn test_lm_cache_direct_operations() { - unsafe { - std::env::set_var("OPENAI_API_KEY", "test"); - } use dspy_rs::Prediction; use dspy_rs::data::RawExample; use std::collections::HashMap; - // Create LM with cache enabled - let lm = LM::builder() - .model("openai:gpt-4o-mini".to_string()) - .cache(true) - .build() - .await - .unwrap(); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .cache(true) + .build(), + ) + .await + .unwrap(); // Get cache handler let cache = lm @@ -207,20 +203,19 @@ async fn test_lm_cache_direct_operations() { #[tokio::test] #[cfg_attr(miri, ignore)] async fn test_lm_cache_with_different_models() { - unsafe { - std::env::set_var("OPENAI_API_KEY", "test"); - std::env::set_var("ANTHROPIC_API_KEY", "test"); - } // Test that cache works with different model configurations let models = vec!["openai:gpt-3.5-turbo", "anthropic:claude-3-haiku-20240307"]; for model in models { - let lm = LM::builder() - .model(model.to_string()) - .cache(true) - .build() - .await - .unwrap(); + let lm = temp_env::async_with_vars( + [ + ("OPENAI_API_KEY", Some("test")), + ("ANTHROPIC_API_KEY", Some("test")), + ], + LM::builder().model(model.to_string()).cache(true).build(), + ) + .await + .unwrap(); // Cache should be initialized regardless of model assert!( @@ -234,20 +229,19 @@ async fn test_lm_cache_with_different_models() { #[tokio::test] #[cfg_attr(miri, ignore)] async fn test_cache_with_complex_inputs() { - unsafe { - std::env::set_var("OPENAI_API_KEY", "test"); - } use dspy_rs::Prediction; use dspy_rs::data::RawExample; use std::collections::HashMap; - // Create LM with cache enabled - let lm = LM::builder() - .model("openai:gpt-4o-mini".to_string()) - .cache(true) - .build() - .await - .unwrap(); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .cache(true) + .build(), + ) + .await + .unwrap(); let cache = lm .cache_handler diff --git a/crates/dspy-rs/tests/test_message_roundtrip.rs b/crates/dspy-rs/tests/test_message_roundtrip.rs new file mode 100644 index 00000000..96461483 --- /dev/null +++ b/crates/dspy-rs/tests/test_message_roundtrip.rs @@ -0,0 +1,345 @@ +//! Round-trip tests for the new Message model. +//! +//! Verifies that the grouped Role + ContentBlock representation preserves +//! all content through: DSRs Message → rig Message → DSRs Message, and +//! through JSON serialization/deserialization. + +use dspy_rs::core::lm::chat::{Chat, ContentBlock, Message, Role}; +use rig::OneOrMany; +use rig::message::{ + Message as RigMessage, Reasoning, ToolCall, ToolFunction, ToolResult, ToolResultContent, +}; +use serde_json::json; + +// --------------------------------------------------------------------------- +// Reasoning continuity round-trip +// --------------------------------------------------------------------------- + +/// Anthropic's thinking turns produce [Reasoning, Reasoning, ToolCall] in a +/// single assistant turn. The entire chain of thought must survive: +/// DSRs Message → rig → DSRs Message +#[test] +fn reasoning_chain_survives_rig_roundtrip() { + let original = Message::with_content( + Role::Assistant, + vec![ + ContentBlock::reasoning(Reasoning::new("step 1: analyze the query")), + ContentBlock::reasoning(Reasoning::new("step 2: plan the search")), + ContentBlock::tool_call(ToolCall::new( + "tc-1".to_string(), + ToolFunction { + name: "search".to_string(), + arguments: json!({"q": "rust ownership"}), + }, + )), + ], + ); + + // Forward: DSRs → rig + let rig_msg = original + .to_rig_message() + .expect("assistant message should convert to rig"); + + // Backward: rig → DSRs + let roundtripped = Message::from(rig_msg); + + assert_eq!(roundtripped.role, Role::Assistant); + assert_eq!( + roundtripped.content.len(), + 3, + "all three content blocks must survive: got {:?}", + roundtripped.content + ); + + assert!( + matches!(&roundtripped.content[0], ContentBlock::Reasoning { reasoning } if reasoning.display_text().contains("step 1")), + "first reasoning block lost" + ); + assert!( + matches!(&roundtripped.content[1], ContentBlock::Reasoning { reasoning } if reasoning.display_text().contains("step 2")), + "second reasoning block lost" + ); + assert!( + matches!(&roundtripped.content[2], ContentBlock::ToolCall { tool_call } if tool_call.function.name == "search"), + "tool call lost" + ); +} + +/// A reasoning-only assistant turn (no text, no tool call) must round-trip. +#[test] +fn reasoning_only_turn_roundtrips() { + let original = Message::with_content( + Role::Assistant, + vec![ContentBlock::reasoning(Reasoning::new( + "just thinking out loud", + ))], + ); + + let rig_msg = original.to_rig_message().unwrap(); + let roundtripped = Message::from(rig_msg); + + assert_eq!(roundtripped.role, Role::Assistant); + assert_eq!(roundtripped.content.len(), 1); + assert!(roundtripped.has_reasoning()); + assert!(!roundtripped.has_tool_calls()); +} + +// --------------------------------------------------------------------------- +// Multi-content user messages +// --------------------------------------------------------------------------- + +/// A user message with both text and a tool result must preserve both. +#[test] +fn user_text_plus_tool_result_roundtrips() { + let original = Message::with_content( + Role::User, + vec![ + ContentBlock::text("Here is context"), + ContentBlock::tool_result(ToolResult { + id: "tr-1".to_string(), + call_id: Some("tc-1".to_string()), + content: OneOrMany::one(ToolResultContent::text("search result")), + }), + ], + ); + + let rig_msg = original.to_rig_message().unwrap(); + let roundtripped = Message::from(rig_msg); + + assert_eq!(roundtripped.role, Role::User); + assert_eq!( + roundtripped.content.len(), + 2, + "both text and tool result must survive" + ); + assert!(matches!( + &roundtripped.content[0], + ContentBlock::Text { text } if text == "Here is context" + )); + assert!(roundtripped.has_tool_results()); +} + +// --------------------------------------------------------------------------- +// Multi-turn conversation with reasoning in history +// --------------------------------------------------------------------------- + +/// Build a multi-turn conversation where an earlier assistant turn has +/// reasoning blocks. Convert the full chat to rig format and back. +/// The reasoning from earlier turns must be preserved. +#[test] +fn multi_turn_conversation_preserves_earlier_reasoning() { + let chat = Chat::new(vec![ + Message::system("You are a helpful assistant."), + Message::user("What is the capital of France?"), + // Turn 1 reply: reasoning + text answer + Message::with_content( + Role::Assistant, + vec![ + ContentBlock::reasoning(Reasoning::new("The user is asking about geography.")), + ContentBlock::text("The capital of France is Paris."), + ], + ), + // User follow-up + Message::user("And Germany?"), + // Turn 2 reply: just text + Message::assistant("The capital of Germany is Berlin."), + ]); + + // Convert to rig and back + let rig_history = chat.to_rig_chat_history(); + // rig_history should have 4 messages (system excluded) + assert_eq!(rig_history.len(), 4); + + // Reconstruct from rig history + let mut reconstructed = Chat::new(vec![Message::system(chat.system_prompt())]); + for rig_msg in rig_history { + reconstructed.push_message(Message::from(rig_msg)); + } + + assert_eq!(reconstructed.len(), 5); + + // Verify turn 1's reasoning survived + let turn1_reply = &reconstructed.messages[2]; + assert_eq!(turn1_reply.role, Role::Assistant); + assert!( + turn1_reply.has_reasoning(), + "turn 1 reasoning must survive rig round-trip" + ); + assert_eq!( + turn1_reply.content.len(), + 2, + "turn 1 must have both reasoning and text" + ); +} + +// --------------------------------------------------------------------------- +// JSON serialization round-trip +// --------------------------------------------------------------------------- + +/// Full multi-content message survives JSON serialization. +#[test] +fn grouped_message_json_roundtrip() { + let original = Chat::new(vec![ + Message::system("Be helpful"), + Message::with_content( + Role::Assistant, + vec![ + ContentBlock::reasoning(Reasoning::new("let me think")), + ContentBlock::text("the answer is 42"), + ContentBlock::tool_call(ToolCall::new( + "tc-1".to_string(), + ToolFunction { + name: "verify".to_string(), + arguments: json!({"answer": 42}), + }, + )), + ], + ), + Message::with_content( + Role::User, + vec![ + ContentBlock::tool_result(ToolResult { + id: "tc-1".to_string(), + call_id: None, + content: OneOrMany::one(ToolResultContent::text("confirmed")), + }), + ContentBlock::text("Thanks! Can you also check 43?"), + ], + ), + ]); + + let json = original.to_json(); + let reparsed = Chat::new(vec![]).from_json(json).unwrap(); + + assert_eq!(reparsed.len(), 3); + + // Verify the assistant message preserved all 3 content blocks + let asst = &reparsed.messages[1]; + assert_eq!(asst.role, Role::Assistant); + assert_eq!(asst.content.len(), 3); + assert!(asst.has_reasoning()); + assert!(asst.has_tool_calls()); + + // Verify the user message preserved both blocks + let user = &reparsed.messages[2]; + assert_eq!(user.role, Role::User); + assert_eq!(user.content.len(), 2); + assert!(user.has_tool_results()); +} + +/// Legacy JSON format (content as plain string) still parses correctly. +#[test] +fn legacy_plain_string_json_parses_into_new_model() { + let legacy_json = json!([ + {"role": "system", "content": "Be helpful"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"} + ]); + + let chat = Chat::new(vec![]).from_json(legacy_json).unwrap(); + assert_eq!(chat.len(), 3); + assert_eq!(chat.messages[0].role, Role::System); + assert_eq!(chat.messages[0].content(), "Be helpful"); + assert_eq!(chat.messages[2].text_content(), "Hi there!"); +} + +// --------------------------------------------------------------------------- +// Accessor correctness +// --------------------------------------------------------------------------- + +#[test] +fn text_content_excludes_non_text_blocks() { + let msg = Message::with_content( + Role::Assistant, + vec![ + ContentBlock::reasoning(Reasoning::new("internal monologue")), + ContentBlock::text("visible answer"), + ContentBlock::tool_call(ToolCall::new( + "tc".to_string(), + ToolFunction { + name: "search".to_string(), + arguments: json!({}), + }, + )), + ], + ); + + assert_eq!(msg.text_content(), "visible answer"); + // content() includes everything + let full = msg.content(); + assert!(full.contains("internal monologue")); + assert!(full.contains("visible answer")); + assert!(full.contains("search")); +} + +#[test] +fn tool_calls_accessor_returns_all_tool_calls() { + let msg = Message::with_content( + Role::Assistant, + vec![ + ContentBlock::reasoning(Reasoning::new("planning")), + ContentBlock::tool_call(ToolCall::new( + "tc-1".to_string(), + ToolFunction { + name: "search".to_string(), + arguments: json!({"q": "a"}), + }, + )), + ContentBlock::tool_call(ToolCall::new( + "tc-2".to_string(), + ToolFunction { + name: "calculate".to_string(), + arguments: json!({"expr": "1+1"}), + }, + )), + ], + ); + + let calls = msg.tool_calls(); + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].function.name, "search"); + assert_eq!(calls[1].function.name, "calculate"); +} + +// --------------------------------------------------------------------------- +// Edge cases +// --------------------------------------------------------------------------- + +/// Empty content vec (pathological) should not panic. +#[test] +fn empty_content_message_does_not_panic() { + let msg = Message::with_content(Role::Assistant, vec![]); + assert_eq!(msg.content(), ""); + assert_eq!(msg.text_content(), ""); + assert!(!msg.has_tool_calls()); + assert!(!msg.has_reasoning()); + + // Rig conversion should produce an assistant message with empty text + let rig_msg = msg.to_rig_message().unwrap(); + match rig_msg { + RigMessage::Assistant { content, .. } => { + assert_eq!(content.iter().count(), 1); // empty text fallback + } + _ => panic!("expected assistant message"), + } +} + +/// System messages return None from to_rig_message (handled as preamble). +#[test] +fn system_message_excluded_from_rig_conversion() { + let msg = Message::system("You are helpful"); + assert!(msg.to_rig_message().is_none()); +} + +/// Message ID (e.g. Anthropic thinking turn IDs) survives round-trip. +#[test] +fn message_id_survives_rig_roundtrip() { + let mut msg = Message::assistant("some text"); + msg.id = Some("msg_abc123".to_string()); + + let rig_msg = msg.to_rig_message().unwrap(); + let roundtripped = Message::from(rig_msg); + + // Note: rig's User messages don't carry IDs, but Assistant messages do + assert_eq!(roundtripped.id, Some("msg_abc123".to_string())); +} diff --git a/crates/dspy-rs/tests/test_predict_conversation.rs b/crates/dspy-rs/tests/test_predict_conversation.rs new file mode 100644 index 00000000..5ec381a4 --- /dev/null +++ b/crates/dspy-rs/tests/test_predict_conversation.rs @@ -0,0 +1,159 @@ +use dspy_rs::{ + ChatAdapter, LM, LMClient, Message, Predict, Role, Signature, TestCompletionModel, configure, +}; +use rig::completion::{AssistantContent, CompletionRequest}; +use rig::message::{Message as RigMessage, Text, UserContent}; +use std::sync::LazyLock; +use tokio::sync::Mutex; + +static SETTINGS_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); + +fn response_with_fields(fields: &[(&str, &str)]) -> String { + let mut response = String::new(); + for (name, value) in fields { + response.push_str(&format!("[[ ## {name} ## ]]\n{value}\n\n")); + } + response.push_str("[[ ## completed ## ]]\n"); + response +} + +fn text_response(text: impl Into) -> AssistantContent { + AssistantContent::Text(Text { text: text.into() }) +} + +async fn configure_test_lm(responses: Vec) -> TestCompletionModel { + let client = TestCompletionModel::new(responses.into_iter().map(text_response)); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .build(), + ) + .await + .unwrap() + .with_client(LMClient::Test(client.clone())) + .await + .unwrap(); + + configure(lm, ChatAdapter {}); + + client +} + +fn request_contains_text(request: &CompletionRequest, needle: &str) -> bool { + if request + .preamble + .as_ref() + .is_some_and(|preamble| preamble.contains(needle)) + { + return true; + } + + for message in request.chat_history.iter() { + match message { + RigMessage::User { content } => { + for item in content.iter() { + if let UserContent::Text(text) = item + && text.text.contains(needle) + { + return true; + } + } + } + RigMessage::Assistant { content, .. } => { + for item in content.iter() { + match item { + AssistantContent::Text(text) if text.text.contains(needle) => return true, + AssistantContent::Reasoning(reasoning) + if reasoning.display_text().contains(needle) => + { + return true; + } + _ => {} + } + } + } + } + } + + false +} + +#[derive(Signature, Clone, Debug, PartialEq)] +/// Conversational QA test signature. +struct ConversationQA { + #[input] + question: String, + + #[output] + answer: String, +} + +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn forward_returns_chat_and_prediction() { + let _lock = SETTINGS_LOCK.lock().await; + let response = response_with_fields(&[("answer", "Paris")]); + let _client = configure_test_lm(vec![response]).await; + + let predict = Predict::::new(); + let input = ConversationQAInput { + question: "What is the capital of France?".to_string(), + }; + + let chat = predict + .build_chat(&input) + .expect("build_chat should succeed"); + let (predicted, chat) = predict + .call_and_parse(chat) + .await + .expect("first turn should succeed"); + + assert_eq!(predicted.into_inner().answer, "Paris"); + assert_eq!(chat.len(), 3); + assert_eq!(chat.messages[0].role, Role::System); + assert_eq!(chat.messages[1].role, Role::User); + assert_eq!(chat.messages[2].role, Role::Assistant); +} + +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn forward_continue_supports_two_turn_roundtrip() { + let _lock = SETTINGS_LOCK.lock().await; + let first_response = response_with_fields(&[("answer", "First turn answer")]); + let second_response = response_with_fields(&[("answer", "Second turn answer")]); + let client = configure_test_lm(vec![first_response, second_response]).await; + + let predict = Predict::::new(); + let first_input = ConversationQAInput { + question: "Turn 1 question".to_string(), + }; + + // First turn: build fresh chat + let chat = predict + .build_chat(&first_input) + .expect("build_chat should succeed"); + let (first_predicted, mut chat) = predict + .call_and_parse(chat) + .await + .expect("first turn should succeed"); + assert_eq!(first_predicted.into_inner().answer, "First turn answer"); + + // Second turn: append follow-up, continue conversation + let caller_follow_up = "Caller follow-up message"; + chat.push_message(Message::user(caller_follow_up)); + + let (second_predicted, second_chat) = predict + .forward_continue(chat) + .await + .expect("second turn should succeed"); + + assert_eq!(second_predicted.into_inner().answer, "Second turn answer"); + assert!(second_chat.len() >= 5); + + // Verify the follow-up text was sent to the LM + let last_request = client + .last_request() + .expect("test model should capture last request"); + assert!(request_contains_text(&last_request, caller_follow_up)); +} diff --git a/crates/dspy-rs/tests/test_predict_conversation_live.rs b/crates/dspy-rs/tests/test_predict_conversation_live.rs new file mode 100644 index 00000000..984fe7a7 --- /dev/null +++ b/crates/dspy-rs/tests/test_predict_conversation_live.rs @@ -0,0 +1,65 @@ +use dspy_rs::{ChatAdapter, LM, Message, Predict, Signature, configure}; +use std::sync::LazyLock; +use tokio::sync::Mutex; + +static SETTINGS_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); + +#[derive(Signature, Clone, Debug, PartialEq)] +/// Live multi-turn conversation signature. +struct LiveConversation { + #[input] + prompt: String, + + #[output] + answer: String, +} + +#[tokio::test] +#[ignore] // Requires real network access and provider API key(s) +async fn live_forward_continue_two_turn_roundtrip() { + let _lock = SETTINGS_LOCK.lock().await; + + let lm = LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .temperature(0.0) + .max_tokens(256) + .build() + .await + .expect("failed to build LM for live smoke test"); + configure(lm, ChatAdapter {}); + + let predict = Predict::::new(); + + // First turn: build and call + let first_input = LiveConversationInput { + prompt: "Reply with the word ONE.".to_string(), + }; + let chat = predict + .build_chat(&first_input) + .expect("build_chat should succeed"); + let (first, mut chat) = predict + .call_and_parse(chat) + .await + .expect("first turn failed"); + assert!( + !first.answer.trim().is_empty(), + "first turn answer should not be empty" + ); + + // Second turn: append follow-up, continue + chat.push_message(Message::user( + "Now reply with the word TWO. Use the same answer field format.", + )); + + let (second, chat2) = predict + .forward_continue(chat) + .await + .expect("second turn failed"); + + assert!( + second.answer.to_ascii_lowercase().contains("two"), + "second turn answer should include 'two', got: {}", + second.answer + ); + assert!(chat2.len() >= 5, "chat should grow across turns"); +} diff --git a/crates/dspy-rs/tests/test_react_builder.rs b/crates/dspy-rs/tests/test_react_builder.rs index 12ef96f7..40f69e33 100644 --- a/crates/dspy-rs/tests/test_react_builder.rs +++ b/crates/dspy-rs/tests/test_react_builder.rs @@ -31,19 +31,18 @@ fn parse_calculator_args(args: &str) -> (i64, i64) { } async fn configure_test_lm(responses: Vec) { - unsafe { - std::env::set_var("OPENAI_API_KEY", "test"); - } - let client = TestCompletionModel::new(responses.into_iter().map(text_response)); - let lm = LM::builder() - .model("openai:gpt-4o-mini".to_string()) - .build() - .await - .unwrap() - .with_client(LMClient::Test(client)) - .await - .unwrap(); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .build(), + ) + .await + .unwrap() + .with_client(LMClient::Test(client)) + .await + .unwrap(); configure(lm, ChatAdapter {}); } diff --git a/crates/dspy-rs/tests/test_settings.rs b/crates/dspy-rs/tests/test_settings.rs index 2b2bea2d..3bc328fd 100644 --- a/crates/dspy-rs/tests/test_settings.rs +++ b/crates/dspy-rs/tests/test_settings.rs @@ -3,31 +3,27 @@ use dspy_rs::{ChatAdapter, LM, configure, get_lm}; #[tokio::test] #[cfg_attr(miri, ignore)] async fn test_settings() { - unsafe { - std::env::set_var("OPENAI_API_KEY", "test"); - } - configure( + let lm1 = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], LM::builder() .model("openai:gpt-4o-mini".to_string()) - .build() - .await - .unwrap(), - ChatAdapter {}, - ); + .build(), + ) + .await + .unwrap(); + configure(lm1, ChatAdapter {}); let lm = get_lm(); assert_eq!(lm.model, "openai:gpt-4o-mini"); - configure( - LM::builder() - .model("openai:gpt-4o".to_string()) - .build() - .await - .unwrap(), - ChatAdapter {}, - ); + let lm2 = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder().model("openai:gpt-4o".to_string()).build(), + ) + .await + .unwrap(); + configure(lm2, ChatAdapter {}); let lm = get_lm(); - assert_eq!(lm.model, "openai:gpt-4o"); } diff --git a/crates/dspy-rs/tests/test_tool_call.rs b/crates/dspy-rs/tests/test_tool_call.rs index 566f36f5..80a8b741 100644 --- a/crates/dspy-rs/tests/test_tool_call.rs +++ b/crates/dspy-rs/tests/test_tool_call.rs @@ -108,13 +108,10 @@ async fn test_tool_call_with_no_tools() { } let response = response.unwrap(); - match response.output { - Message::Assistant { content } => { - // The response should contain some mention of 4 - println!("Assistant response: {}", content); - } - _ => panic!("Expected assistant message"), - } + assert_eq!(response.output.role, dspy_rs::Role::Assistant); + let content = response.output.content(); + // The response should contain some mention of 4 + println!("Assistant response: {}", content); } #[tokio::test] @@ -140,12 +137,9 @@ async fn test_tool_call_with_calculator() { // Call with the calculator tool let response = lm.call(chat, tools).await.unwrap(); - match response.output { - Message::Assistant { content } => { - println!("Assistant response after tool use: {}", content); - // The response should mention the result (100) or that the tool was called - assert!(content.contains("100") || content.contains("Tool call")); - } - _ => panic!("Expected assistant message"), - } + assert_eq!(response.output.role, dspy_rs::Role::Assistant); + let content = response.output.content(); + println!("Assistant response after tool use: {}", content); + // The response should mention the result (100) or that the tool was called + assert!(content.contains("100") || content.contains("Tool call")); } diff --git a/crates/dspy-rs/tests/typed_integration.rs b/crates/dspy-rs/tests/typed_integration.rs index 6b6f7968..973dc0a4 100644 --- a/crates/dspy-rs/tests/typed_integration.rs +++ b/crates/dspy-rs/tests/typed_integration.rs @@ -23,19 +23,18 @@ fn text_response(text: impl Into) -> AssistantContent { } async fn configure_test_lm(responses: Vec) -> TestCompletionModel { - unsafe { - std::env::set_var("OPENAI_API_KEY", "test"); - } - let client = TestCompletionModel::new(responses.into_iter().map(text_response)); - let lm = LM::builder() - .model("openai:gpt-4o-mini".to_string()) - .build() - .await - .unwrap() - .with_client(LMClient::Test(client.clone())) - .await - .unwrap(); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .build(), + ) + .await + .unwrap() + .with_client(LMClient::Test(client.clone())) + .await + .unwrap(); configure(lm, ChatAdapter {}); From eb92f4248bfe7b1af68797f981ff102f2bc6f89f Mon Sep 17 00:00:00 2001 From: darin Date: Thu, 19 Feb 2026 20:56:07 -0800 Subject: [PATCH 02/44] fix: use new_spanned for stable/nightly-consistent trybuild snapshots Two trybuild tests (render_invalid_jinja, render_non_literal) failed on CI because syn::Error::new(span, msg) with .span() produces different underline widths on stable (CI) vs nightly (local). Switch to syn::Error::new_spanned(tokens, msg) which reliably spans from first to last token regardless of compiler version. --- crates/dsrs-macros/src/lib.rs | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/crates/dsrs-macros/src/lib.rs b/crates/dsrs-macros/src/lib.rs index 320d521c..c00f3fba 100644 --- a/crates/dsrs-macros/src/lib.rs +++ b/crates/dsrs-macros/src/lib.rs @@ -6,7 +6,6 @@ use syn::{ Token, Visibility, parse::{Parse, ParseStream}, parse_macro_input, - spanned::Spanned, visit::Visit, }; @@ -268,7 +267,7 @@ fn parse_single_field(field: &syn::Field) -> syn::Result { )); } let template = parse_render_jinja_attr(attr)?; - validate_jinja_template(&template, attr.span())?; + validate_jinja_template(&template, attr)?; render_jinja = Some(template); } else if attr.path().is_ident("flatten") { if saw_flatten { @@ -367,7 +366,7 @@ fn parse_desc_from_attr(attr: &Attribute, attr_name: &str) -> syn::Result