diff --git a/Cargo.lock b/Cargo.lock index d0e24aeb..87e8e697 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1227,15 +1227,18 @@ dependencies = [ "kdam", "minijinja", "parquet", + "pyo3", "rand 0.8.5", "rayon", "regex", "reqwest 0.13.2", "rig-core", + "rlm-derive", "rstest 0.25.0", "schemars", "serde", "serde_json", + "temp-env", "tempfile", "thiserror 2.0.17", "tokio", @@ -3251,6 +3254,67 @@ dependencies = [ "unarray", ] +[[package]] +name = "pyo3" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab53c047fcd1a1d2a8820fe84f05d6be69e9526be40cb03b73f86b6b03e6d87d" +dependencies = [ + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b455933107de8642b4487ed26d912c2d899dec6114884214a0b3bb3be9261ea6" +dependencies = [ + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c85c9cbfaddf651b1221594209aed57e9e5cff63c4d11d1feead529b872a089" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a5b10c9bf9888125d917fb4d2ca2d25c8df94c7ab5a52e13313a07e050a3b02" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03b51720d314836e53327f5871d4c0cfb4fb37cc2c4a11cc71907a86342c40f9" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn 2.0.106", +] + [[package]] name = "quick-error" version = "1.2.3" @@ -3633,6 +3697,19 @@ dependencies = [ "windows-sys 0.52.0", ] 
+[[package]] +name = "rlm-derive" +version = "0.1.0" +dependencies = [ + "dspy-rs", + "facet", + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.106", + "trybuild", +] + [[package]] name = "rstest" version = "0.22.0" @@ -4233,12 +4310,28 @@ dependencies = [ "libc", ] +[[package]] +name = "target-lexicon" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb6935a6f5c20170eeceb1a3835a49e12e19d792f6dd344ccc76a985ca5a6ca" + [[package]] name = "target-triple" version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ac9aa371f599d22256307c24a9d748c041e548cbf599f35d890f9d365361790" +[[package]] +name = "temp-env" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96374855068f47402c3121c6eed88d29cb1de8f3ab27090e273e420bdabcf050" +dependencies = [ + "futures", + "parking_lot", +] + [[package]] name = "tempfile" version = "3.23.0" @@ -4778,6 +4871,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a1a07cc7db3810833284e8d372ccdc6da29741639ecc70c9ec107df0fa6154c" +[[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" + [[package]] name = "unsafe-libyaml" version = "0.2.11" diff --git a/INVESTIGATION_facet_baml_bridge.md b/INVESTIGATION_facet_baml_bridge.md new file mode 100644 index 00000000..97773740 --- /dev/null +++ b/INVESTIGATION_facet_baml_bridge.md @@ -0,0 +1,130 @@ +# Investigation: Facet ↔ BAML Bridge Redundancy + +## Summary + +The codebase has **two independent paths** that produce BAML `TypeIR` from facet type metadata, causing divergence. 
BAML's native rendering is already used — the problem isn't that we're re-implementing rendering, it's that `SignatureSchema` builds its own TypeIR from raw `facet::Shape` while `bamltype`'s `SchemaBuilder` builds a richer, field-attr-aware TypeIR. These can disagree silently. + +## The Two Paths (Root Problem) + +### Path 1: `bamltype` SchemaBuilder (full-fidelity) +``` +#[BamlType] struct → facet::Facet derive → facet attrs (bamltype::*) + → SchemaBuilder::build_field_type_ir(field, owner, variant) + → checks field attrs: with adapters, int_repr, map_key_repr + → registers Classes/Enums into SchemaRegistry + → builds OutputFormatContent (with recursive class detection) + → cached in BamlSchema::baml_schema() as SchemaBundle +``` +**Location:** `crates/bamltype/src/schema_builder.rs:453-490` (`build_field_type_ir`) + +This path sees `#[baml(with="Codec")]`, `#[baml(int_repr="string")]`, `#[baml(map_key_repr="pairs")]` and transforms the TypeIR accordingly. An adapter can completely replace a field's type. A map can become a list of generated entry classes. + +### Path 2: `SignatureSchema` (shape-only, loses field attrs) +``` +#[derive(Signature)] struct → facet::Shape for Input/Output + → collect_fields() iterates struct fields + → emit_field() calls build_type_ir_from_shape(field.shape()) + → TypeIR built from shape alone, NO field attr awareness + → stored in FieldSchema.type_ir +``` +**Location:** `crates/dspy-rs/src/core/schema.rs:337` — the critical line: +```rust +let mut type_ir = build_type_ir_from_shape(field.shape()); +``` + +This calls `schema_builder::build_type_ir_from_shape()` which creates a **fresh SchemaBuilder** and calls `build_type_ir(shape)` — NOT `build_field_type_ir(field, ...)`. It never sees field-level attributes. 
+ +### Where They're Used Together (Mismatch Surface) + +| Consumer | Uses FieldSchema.type_ir (Path 2) | Uses OutputFormatContent (Path 1) | +|----------|----------------------------------|----------------------------------| +| `ChatAdapter::parse_structured_output_with_meta` | ✅ `jsonish::from_str(..., &field.type_ir, ...)` | ✅ `schema.output_format()` | +| RLM Output Contract (prompt.rs:99) | ✅ `field.type_ir.diagnostic_repr()` | ❌ | +| RLM py_bridge kwargs coercion | ✅ `field.type_ir` for dispatch | ✅ `output_format` for class/enum lookups | +| ChatAdapter field schema rendering | ❌ | ✅ `OutputFormatContent::render()` | + +When Path 1 and Path 2 disagree (e.g., a field has `int_repr="string"` or `with="Codec"`), `jsonish` gets a TypeIR that says "int" while OutputFormatContent says "string" (or a completely different adapter type). This is a silent correctness bug. + +## What BAML "Native" Actually Means Here + +BAML's native rendering (`internal-baml-jinja`) is **already used**: +- `OutputFormatContent::render(options)` — schema prompt text ✅ +- `jsonish::from_str(...)` — LLM output parsing ✅ +- `format_baml_value(...)` — value formatting ✅ + +The custom bridge (`crates/bamltype`) builds the **inputs** to native rendering: +- `facet::Shape` → `TypeIR` (the type graph) +- `facet::Shape` → `OutputFormatContent` (class/enum registry) + +You can't remove this bridge without a replacement source of truth (e.g., a BAML compiler, or user-authored BAML schemas). + +## RlmType: Not a Schema Divergence + +`#[rlm_type]` is a composition macro, not a competing schema system: +```rust +// rlm_attr.rs:43-45 — it literally just adds these: +input.attrs.push(syn::parse_quote!(#[pyclass(...)])); +input.attrs.push(syn::parse_quote!(#[BamlType])); +merge_derive(&mut input.attrs, &[syn::parse_quote!(RlmType)]); +``` + +`RlmType` derive adds Python interop methods (`__baml__`, `__repr__`, `__iter__`, etc.) that delegate to `BamlType` for conversion. 
There's no schema divergence here — it's a pure consumer of `bamltype`. + +## Internal Name Drift + +There's a subtle naming inconsistency between two functions that compute BAML internal names: + +**`schema_builder::internal_name_for_shape(shape)`** (schema_builder.rs:44-55): +```rust +// Uses module_path::type_identifier +format!("{module}::{}", shape.type_identifier) +``` + +**`runtime::baml_internal_name::<T>()`** (runtime.rs:80-94): +```rust +// Falls back to std::any::type_name::<T>() +std::any::type_name::<T>() +``` + +`std::any::type_name` returns e.g. `my_crate::my_module::MyType` while `internal_name_for_shape` returns `my_module::MyType`. These could drift in edge cases, causing class lookup failures in value conversion or formatting. + +## Complexity Hotspots + +### 1. Adapter Function Pointers in Facet Attrs +`bamltype-derive` encodes function pointers (`WithAdapterFns`) into facet attribute metadata. These are `fn()` pointers stored as `&'static dyn Any` in compile-time reflection data. This works but is deeply non-obvious and makes the bridge hard to replace. + +### 2. Map Key Repr "pairs" Generates Phantom Classes +`map_key_repr="pairs"` lowers `Map` → `List` and registers a generated class. Any code that assumes maps stay maps will break. + +### 3. Two Value Conversion Engines +- `bamltype/src/convert.rs`: Rust value ↔ BamlValue (facet Peek-based) +- `rlm/py_bridge.rs`: Python value → BamlValue (TypeIR + OutputFormatContent-aware) + +Both walk value trees against schemas, both have relaxed parsing heuristics, both could diverge. 
+ +## Recommendations + +### Fix 1: Make SignatureSchema source TypeIR from bamltype's SchemaBundle (HIGH PRIORITY) + +Instead of: +```rust +let mut type_ir = build_type_ir_from_shape(field.shape()); +``` + +Do one of: +- **Option A**: Look up the field's TypeIR from `<Output as BamlSchema>::baml_schema().output_format` class definitions +- **Option B**: Expose `SchemaBuilder::build_field_type_ir` as a public API that `SignatureSchema` can call + +This eliminates the "two sources of truth" problem entirely. + +### Fix 2: Unify internal name computation + +Change `runtime::baml_internal_name::<T>()` fallback from `type_name::<T>()` to `internal_name_for_shape(T::SHAPE)`. + +### Fix 3: Use OutputFormatContent::render for RLM Output Contract + +Instead of `field.type_ir.diagnostic_repr()` (which uses the divergent Path 2 TypeIR), render the contract using the same native rendering used for structured output prompts. + +### Fix 4 (Optional): Consolidate py_bridge coercion through jsonish + +Normalize Python values to JSON, then use `jsonish::from_str(output_format, type_ir, ...)` instead of a parallel walker. Keeps one coercion engine. 
diff --git a/crates/dspy-rs/Cargo.toml b/crates/dspy-rs/Cargo.toml index 8a8d5a5f..1e2e795c 100644 --- a/crates/dspy-rs/Cargo.toml +++ b/crates/dspy-rs/Cargo.toml @@ -45,9 +45,15 @@ enum_dispatch = "0.3.13" tracing = "0.1.44" tracing-subscriber = { version = "0.3.22", features = ["env-filter", "fmt"] } minijinja = { git = "https://github.com/boundaryml/minijinja.git", branch = "main", default-features = false, features = ["builtins", "serde"] } +pyo3 = { version = "0.27", features = ["auto-initialize"], optional = true } +rlm-derive = { path = "../rlm-derive", optional = true } [package.metadata.cargo-machete] ignored = ["rig-core"] [features] default = [] +rlm = ["dep:pyo3", "dep:rlm-derive", "dsrs_macros/rlm"] + +[dev-dependencies] +temp-env = { version = "0.3.6", features = ["async_closure"] } diff --git a/crates/dspy-rs/examples/01-simple.rs b/crates/dspy-rs/examples/01-simple.rs index d662d92e..027d486b 100644 --- a/crates/dspy-rs/examples/01-simple.rs +++ b/crates/dspy-rs/examples/01-simple.rs @@ -17,8 +17,8 @@ use anyhow::Result; use bon::Builder; use dspy_rs::data::RawExample; use dspy_rs::{ - CallMetadata, ChatAdapter, Example, LM, LmError, Module, Predict, PredictError, Predicted, - Prediction, configure, init_tracing, + CallMetadata, Chat, ChatAdapter, Example, LM, LmError, Module, Predict, PredictError, + Predicted, Prediction, configure, init_tracing, }; const QA_INSTRUCTION: &str = "Answer the question step by step."; @@ -115,7 +115,11 @@ impl Module for QARater { .data .insert("rating".into(), rate_result.rating.into()); - Ok(Predicted::new(combined, CallMetadata::default())) + Ok(Predicted::new( + combined, + CallMetadata::default(), + Chat::new(vec![]), + )) } } @@ -128,7 +132,7 @@ async fn main() -> Result<()> { .model("openai:gpt-4o-mini".to_string()) .build() .await?, - ChatAdapter, + ChatAdapter::new(), ); // ========================================================================= @@ -147,7 +151,7 @@ async fn main() -> Result<()> { 
println!("Reasoning: {}", output.reasoning); println!("Answer: {}", output.answer); - // Predicted carries both typed output and metadata. + // Predicted carries typed output, metadata, and chat history. let result = predict.call(input).await?; println!("\nWith metadata:"); println!( diff --git a/crates/dspy-rs/examples/02-module-iteration-and-updation.rs b/crates/dspy-rs/examples/02-module-iteration-and-updation.rs index d0d09893..e35c10cd 100644 --- a/crates/dspy-rs/examples/02-module-iteration-and-updation.rs +++ b/crates/dspy-rs/examples/02-module-iteration-and-updation.rs @@ -83,7 +83,7 @@ async fn main() -> Result<()> { .model("openai:gpt-4o-mini".to_string()) .build() .await?, - ChatAdapter, + ChatAdapter::new(), ); let metric = ExactMatch; diff --git a/crates/dspy-rs/examples/03-evaluate-hotpotqa.rs b/crates/dspy-rs/examples/03-evaluate-hotpotqa.rs index f9cf6a69..94a6b6e1 100644 --- a/crates/dspy-rs/examples/03-evaluate-hotpotqa.rs +++ b/crates/dspy-rs/examples/03-evaluate-hotpotqa.rs @@ -48,7 +48,7 @@ async fn main() -> Result<()> { .model("openai:gpt-4o-mini".to_string()) .build() .await?, - ChatAdapter, + ChatAdapter::new(), ); let examples = DataLoader::load_hf::( diff --git a/crates/dspy-rs/examples/04-optimize-hotpotqa.rs b/crates/dspy-rs/examples/04-optimize-hotpotqa.rs index 0907db86..ba13143f 100644 --- a/crates/dspy-rs/examples/04-optimize-hotpotqa.rs +++ b/crates/dspy-rs/examples/04-optimize-hotpotqa.rs @@ -65,7 +65,7 @@ async fn main() -> Result<()> { .model("openai:gpt-4o-mini".to_string()) .build() .await?, - ChatAdapter, + ChatAdapter::new(), ); let examples = DataLoader::load_hf::( diff --git a/crates/dspy-rs/examples/05-heterogenous-examples.rs b/crates/dspy-rs/examples/05-heterogenous-examples.rs index d32d01ea..1795d7d3 100644 --- a/crates/dspy-rs/examples/05-heterogenous-examples.rs +++ b/crates/dspy-rs/examples/05-heterogenous-examples.rs @@ -34,7 +34,7 @@ async fn main() -> Result<()> { .model("openai:gpt-4o-mini".to_string()) 
.build() .await?, - ChatAdapter, + ChatAdapter::new(), ); let heterogeneous = RawExample::new( diff --git a/crates/dspy-rs/examples/06-other-providers-batch.rs b/crates/dspy-rs/examples/06-other-providers-batch.rs index 57cf792b..3d360523 100644 --- a/crates/dspy-rs/examples/06-other-providers-batch.rs +++ b/crates/dspy-rs/examples/06-other-providers-batch.rs @@ -49,7 +49,7 @@ async fn main() -> Result<()> { .model("anthropic:claude-sonnet-4-5-20250929".to_string()) .build() .await?, - ChatAdapter, + ChatAdapter::new(), ); let mut anthropic = Vec::new(); @@ -63,7 +63,7 @@ async fn main() -> Result<()> { .model("gemini:gemini-2.0-flash".to_string()) .build() .await?, - ChatAdapter, + ChatAdapter::new(), ); let mut gemini = Vec::new(); diff --git a/crates/dspy-rs/examples/07-inspect-history.rs b/crates/dspy-rs/examples/07-inspect-history.rs index b15b5cec..eb4e9d60 100644 --- a/crates/dspy-rs/examples/07-inspect-history.rs +++ b/crates/dspy-rs/examples/07-inspect-history.rs @@ -27,7 +27,7 @@ async fn main() -> Result<()> { .model("openai:gpt-4o-mini".to_string()) .build() .await?; - configure(lm, ChatAdapter); + configure(lm, ChatAdapter::new()); let predictor = Predict::::new(); let output = predictor diff --git a/crates/dspy-rs/examples/08-optimize-mipro.rs b/crates/dspy-rs/examples/08-optimize-mipro.rs index 6fab8439..24c41d4b 100644 --- a/crates/dspy-rs/examples/08-optimize-mipro.rs +++ b/crates/dspy-rs/examples/08-optimize-mipro.rs @@ -74,7 +74,7 @@ async fn main() -> Result<()> { println!("=== MIPROv2 Optimizer Example ===\n"); - configure(LM::default(), ChatAdapter); + configure(LM::default(), ChatAdapter::new()); println!("Loading training data from HuggingFace..."); let train_examples = DataLoader::load_hf::( diff --git a/crates/dspy-rs/examples/09-gepa-sentiment.rs b/crates/dspy-rs/examples/09-gepa-sentiment.rs index 515fe70b..7158bcfe 100644 --- a/crates/dspy-rs/examples/09-gepa-sentiment.rs +++ b/crates/dspy-rs/examples/09-gepa-sentiment.rs @@ -88,7 
+88,10 @@ fn sentiment_example(text: &str, expected: &str) -> Example async fn main() -> Result<()> { init_tracing()?; - configure(LM::builder().temperature(0.7).build().await?, ChatAdapter); + configure( + LM::builder().temperature(0.7).build().await?, + ChatAdapter::new(), + ); let trainset = vec![ sentiment_example( diff --git a/crates/dspy-rs/examples/10-gepa-llm-judge.rs b/crates/dspy-rs/examples/10-gepa-llm-judge.rs index 95255284..60337444 100644 --- a/crates/dspy-rs/examples/10-gepa-llm-judge.rs +++ b/crates/dspy-rs/examples/10-gepa-llm-judge.rs @@ -150,7 +150,10 @@ fn training_example(problem: &str, expected_answer: &str) -> Example Result<()> { init_tracing()?; - configure(LM::builder().temperature(0.7).build().await?, ChatAdapter); + configure( + LM::builder().temperature(0.7).build().await?, + ChatAdapter::new(), + ); let trainset = vec![ training_example( diff --git a/crates/dspy-rs/examples/11-custom-client.rs b/crates/dspy-rs/examples/11-custom-client.rs index 8bdcb6b0..52b18489 100644 --- a/crates/dspy-rs/examples/11-custom-client.rs +++ b/crates/dspy-rs/examples/11-custom-client.rs @@ -42,7 +42,7 @@ async fn main() -> Result<()> { .with_client(custom_lm_client) .await?; - configure(lm, ChatAdapter); + configure(lm, ChatAdapter::new()); let predictor = Predict::::new(); let prediction = predictor diff --git a/crates/dspy-rs/examples/12-tracing.rs b/crates/dspy-rs/examples/12-tracing.rs index f1e3d412..86ba69a4 100644 --- a/crates/dspy-rs/examples/12-tracing.rs +++ b/crates/dspy-rs/examples/12-tracing.rs @@ -11,8 +11,8 @@ use anyhow::Result; use bon::Builder; use dspy_rs::data::RawExample; use dspy_rs::{ - CallMetadata, ChatAdapter, LM, LmUsage, Module, Predict, PredictError, Predicted, Prediction, - Signature, configure, init_tracing, + CallMetadata, Chat, ChatAdapter, LM, LmUsage, Module, Predict, PredictError, Predicted, + Prediction, Signature, configure, init_tracing, trace::{self, Executor}, }; use serde_json::json; @@ -83,7 +83,11 @@ impl 
Module for QARater { }, ); - Ok(Predicted::new(prediction, CallMetadata::default())) + Ok(Predicted::new( + prediction, + CallMetadata::default(), + Chat::new(vec![]), + )) } } @@ -96,7 +100,7 @@ async fn main() -> Result<()> { .model("openai:gpt-4o-mini".to_string()) .build() .await?, - ChatAdapter, + ChatAdapter::new(), ); let module = QARater::builder().build(); diff --git a/crates/dspy-rs/examples/15-tools.rs b/crates/dspy-rs/examples/15-tools.rs index c2170238..f57ae742 100644 --- a/crates/dspy-rs/examples/15-tools.rs +++ b/crates/dspy-rs/examples/15-tools.rs @@ -100,7 +100,7 @@ async fn main() -> Result<()> { .model("groq:openai/gpt-oss-120b".to_string()) .build() .await?; - configure(lm, ChatAdapter); + configure(lm, ChatAdapter::new()); let predictor = Predict::::builder() .instruction("You must call the calculator tool for arithmetic.") diff --git a/crates/dspy-rs/examples/16-insurance-claim-prompt.rs b/crates/dspy-rs/examples/16-insurance-claim-prompt.rs index da712dc9..c61c8a8c 100644 --- a/crates/dspy-rs/examples/16-insurance-claim-prompt.rs +++ b/crates/dspy-rs/examples/16-insurance-claim-prompt.rs @@ -196,7 +196,7 @@ pub struct InsuranceClaimInfo { fn main() { init_tracing().expect("failed to initialize tracing"); - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let system = adapter .format_system_message_typed::() .expect("system prompt"); diff --git a/crates/dspy-rs/examples/90-smoke-slice1-typed-predict.rs b/crates/dspy-rs/examples/90-smoke-slice1-typed-predict.rs index 7b485756..bf451a74 100644 --- a/crates/dspy-rs/examples/90-smoke-slice1-typed-predict.rs +++ b/crates/dspy-rs/examples/90-smoke-slice1-typed-predict.rs @@ -18,7 +18,7 @@ async fn main() -> Result<()> { .model("openai:gpt-5.2".to_string()) .build() .await?, - ChatAdapter, + ChatAdapter::new(), ); let module = Predict::::new(); diff --git a/crates/dspy-rs/examples/91-smoke-slice2-chain-of-thought.rs b/crates/dspy-rs/examples/91-smoke-slice2-chain-of-thought.rs index 
12b90e56..d8ab2053 100644 --- a/crates/dspy-rs/examples/91-smoke-slice2-chain-of-thought.rs +++ b/crates/dspy-rs/examples/91-smoke-slice2-chain-of-thought.rs @@ -18,7 +18,7 @@ async fn main() -> Result<()> { .model("openai:gpt-5.2".to_string()) .build() .await?, - ChatAdapter, + ChatAdapter::new(), ); let module = ChainOfThought::::new(); diff --git a/crates/dspy-rs/examples/92-smoke-slice3-module-authoring.rs b/crates/dspy-rs/examples/92-smoke-slice3-module-authoring.rs index 50da034c..08989940 100644 --- a/crates/dspy-rs/examples/92-smoke-slice3-module-authoring.rs +++ b/crates/dspy-rs/examples/92-smoke-slice3-module-authoring.rs @@ -39,7 +39,7 @@ async fn main() -> Result<()> { .model("openai:gpt-5.2".to_string()) .build() .await?, - ChatAdapter, + ChatAdapter::new(), ); let module = SmokeModule::new(); diff --git a/crates/dspy-rs/examples/93-smoke-slice4-react-operational.rs b/crates/dspy-rs/examples/93-smoke-slice4-react-operational.rs index 90c358f7..9546aaa8 100644 --- a/crates/dspy-rs/examples/93-smoke-slice4-react-operational.rs +++ b/crates/dspy-rs/examples/93-smoke-slice4-react-operational.rs @@ -40,7 +40,7 @@ async fn main() -> Result<()> { .model("openai:gpt-5.2".to_string()) .build() .await?, - ChatAdapter, + ChatAdapter::new(), ); let module = ReAct::::builder() @@ -83,7 +83,7 @@ async fn main() -> Result<()> { } anyhow::anyhow!("slice4 smoke failed") })?; - let (output, metadata) = predicted.into_parts(); + let (output, metadata, _chat) = predicted.into_parts(); println!("tool_calls: {}", metadata.tool_calls.len()); println!("tool_executions: {}", metadata.tool_executions.len()); diff --git a/crates/dspy-rs/examples/94-smoke-slice5-optimizer-interface.rs b/crates/dspy-rs/examples/94-smoke-slice5-optimizer-interface.rs index 97d7fec1..284bfbbc 100644 --- a/crates/dspy-rs/examples/94-smoke-slice5-optimizer-interface.rs +++ b/crates/dspy-rs/examples/94-smoke-slice5-optimizer-interface.rs @@ -37,7 +37,7 @@ async fn main() -> Result<()> { 
.model("openai:gpt-5.2".to_string()) .build() .await?, - ChatAdapter, + ChatAdapter::new(), ); let mut module = ChainOfThought::::new(); diff --git a/crates/dspy-rs/src/adapter/chat.rs b/crates/dspy-rs/src/adapter/chat.rs index 2bf58a1b..73c31660 100644 --- a/crates/dspy-rs/src/adapter/chat.rs +++ b/crates/dspy-rs/src/adapter/chat.rs @@ -15,12 +15,20 @@ use tracing::{debug, trace}; use super::Adapter; use crate::CallMetadata; use crate::{ - BamlType, BamlValue, ConstraintLevel, ConstraintResult, FieldMeta, Flag, InputRenderSpec, + BamlType, BamlValue, Chat, ConstraintLevel, ConstraintResult, FieldMeta, Flag, InputRenderSpec, JsonishError, Message, OutputFormatContent, ParseError, PredictError, Predicted, RenderOptions, Signature, TypeIR, }; -/// Builds prompts and parses responses using the `[[ ## field ## ]]` delimiter protocol. +/// Output formatting/parsing dialect for [`ChatAdapter`]. +#[derive(Debug, Clone, Copy, Default, Eq, PartialEq)] +pub enum Dialect { + #[default] + Chat, + Passthrough, +} + +/// Builds prompts and parses responses using signature-aware adapter dialects. /// /// The adapter is stateless — all state comes from the [`SignatureSchema`](crate::SignatureSchema) /// passed to each method. Two usage patterns: @@ -32,8 +40,10 @@ use crate::{ /// /// The building blocks exist so module authors can compose custom prompt flows (e.g. /// ReAct's action/extract loop) without reimplementing the delimiter protocol. -#[derive(Default, Clone)] -pub struct ChatAdapter; +#[derive(Debug, Clone, Default)] +pub struct ChatAdapter { + dialect: Dialect, +} static FIELD_HEADER_PATTERN: LazyLock = LazyLock::new(|| Regex::new(r"^\[\[ ## ([^#]+?) 
## \]\]").unwrap()); @@ -167,42 +177,16 @@ fn resolve_rendered_type_token(token: &str, output_format: Option<&OutputFormatC } } - token.rsplit("::").next().unwrap_or(token).to_string() -} - -fn simplify_type_name(raw: &str, output_format: Option<&OutputFormatContent>) -> String { - let mut result = String::with_capacity(raw.len()); - let mut chars = raw.chars(); - while let Some(ch) = chars.next() { - if ch == '`' { - let mut token = String::new(); - for next in chars.by_ref() { - if next == '`' { - break; - } - token.push(next); - } - let rendered = resolve_rendered_type_token(&token, output_format); - result.push_str(&rendered); - } else { - result.push(ch); - } - } - result + crate::core::simplify_type_token(token) } fn render_type_name_for_prompt( type_ir: &TypeIR, output_format: Option<&OutputFormatContent>, ) -> String { - let raw = type_ir.diagnostic_repr().to_string(); - let simplified = simplify_type_name(&raw, output_format); - simplified - .replace("class ", "") - .replace("enum ", "") - .replace(" | ", " or ") - .trim() - .to_string() + crate::core::render_type_name_for_prompt_with(type_ir, |token| { + resolve_rendered_type_token(token, output_format) + }) } fn split_schema_definitions(schema: &str) -> Option<(String, String)> { @@ -300,6 +284,26 @@ fn format_schema_for_prompt(schema: &str) -> String { } impl ChatAdapter { + pub fn new() -> Self { + Self { + dialect: Dialect::Chat, + } + } + + pub fn passthrough() -> Self { + Self { + dialect: Dialect::Passthrough, + } + } + + pub fn dialect(&self) -> Dialect { + self.dialect + } + + fn is_structured_output(&self) -> bool { + !matches!(self.dialect, Dialect::Passthrough) + } + fn format_task_description_schema( &self, schema: &crate::SignatureSchema, @@ -331,7 +335,11 @@ impl ChatAdapter { indented.push_str(line); } - format!("In adhering to this structure, your objective is: {indented}") + if self.is_structured_output() { + format!("In adhering to this structure, your objective is: {indented}") + } 
else { + format!("Your objective is: {indented}") + } } fn format_response_instructions_schema(&self, schema: &crate::SignatureSchema) -> String { @@ -396,12 +404,29 @@ impl ChatAdapter { schema: &crate::SignatureSchema, instruction_override: Option<&str>, ) -> Result { - let parts = [ - self.format_field_descriptions_schema(schema), - self.format_field_structure_schema(schema)?, - self.format_response_instructions_schema(schema), - self.format_task_description_schema(schema, instruction_override), - ]; + if !self.is_structured_output() + && let Some(instruction_override) = instruction_override + { + trace!( + system_len = instruction_override.len(), + "formatted schema system prompt" + ); + return Ok(instruction_override.to_string()); + } + + let parts = if self.is_structured_output() { + vec![ + self.format_field_descriptions_schema(schema), + self.format_field_structure_schema(schema)?, + self.format_response_instructions_schema(schema), + self.format_task_description_schema(schema, instruction_override), + ] + } else { + vec![ + self.format_field_descriptions_schema(schema), + self.format_task_description_schema(schema, instruction_override), + ] + }; let system = parts.join("\n\n"); trace!(system_len = system.len(), "formatted schema system prompt"); @@ -488,8 +513,10 @@ impl ChatAdapter { /// Navigates the `BamlValue` using each field's [`FieldPath`](crate::FieldPath) to /// handle flattened structs correctly. A field with path `["inner", "question"]` is /// extracted from the nested structure but rendered as a flat `[[ ## question ## ]]` - /// section in the prompt. Appends response instructions so the LM sees - /// output-field ordering guidance in the latest user turn. + /// section in the prompt. + /// + /// Structured dialects append explicit output-field ordering guidance in the + /// user turn. Passthrough dialect omits output protocol instructions entirely. 
pub fn format_input(&self, schema: &crate::SignatureSchema, input: &I) -> String where I: BamlType + for<'a> facet::Facet<'a>, @@ -500,21 +527,45 @@ impl ChatAdapter { let vars = Value::Object(serde_json::Map::new()); let mut result = String::new(); + let raw_perception_mode = !self.is_structured_output() + && schema.input_fields().len() == 1 + && schema.input_fields()[0].lm_name == "perception"; for field_spec in schema.input_fields() { if let Some(value) = value_for_path_relaxed(&baml_value, field_spec.path()) { - result.push_str(&format!("[[ ## {} ## ]]\n", field_spec.lm_name)); - result.push_str(&render_input_field( - field_spec, - value, - &input_json, - input_output_format, - &vars, - )); - result.push_str("\n\n"); + if self.is_structured_output() { + result.push_str(&format!("[[ ## {} ## ]]\n", field_spec.lm_name)); + result.push_str(&render_input_field( + field_spec, + value, + &input_json, + input_output_format, + &vars, + )); + result.push_str("\n\n"); + } else { + let rendered = render_input_field( + field_spec, + value, + &input_json, + input_output_format, + &vars, + ); + if raw_perception_mode { + result.push_str(&rendered); + result.push_str("\n"); + } else { + result.push_str(field_spec.lm_name); + result.push_str(":\n"); + result.push_str(&rendered); + result.push_str("\n\n"); + } + } } } - result.push_str(&self.format_response_instructions_schema(schema)); + if self.is_structured_output() { + result.push_str(&self.format_response_instructions_schema(schema)); + } result } @@ -625,7 +676,22 @@ impl ChatAdapter { where O: BamlType + for<'a> facet::Facet<'a>, { - let content = response.content(); + match self.dialect { + Dialect::Chat => self.parse_structured_output_with_meta::(schema, response), + Dialect::Passthrough => self.parse_passthrough_output_with_meta::(schema, response), + } + } + + #[allow(clippy::result_large_err)] + fn parse_structured_output_with_meta( + &self, + schema: &crate::SignatureSchema, + response: &Message, + ) -> 
std::result::Result<(O, IndexMap), ParseError> + where + O: BamlType + for<'a> facet::Facet<'a>, + { + let content = response.text_content(); let output_format = schema.output_format(); let sections = parse_sections(&content); @@ -786,6 +852,64 @@ impl ChatAdapter { Ok((typed_output, metas)) } + #[allow(clippy::result_large_err)] + fn parse_passthrough_output_with_meta( + &self, + schema: &crate::SignatureSchema, + response: &Message, + ) -> std::result::Result<(O, IndexMap), ParseError> + where + O: BamlType + for<'a> facet::Facet<'a>, + { + let output_fields = schema.output_fields(); + if output_fields.len() != 1 { + return Err(ParseError::ExtractionFailed { + field: "".to_string(), + raw_response: response.content(), + reason: format!( + "passthrough adapter requires exactly one output field, got {}", + output_fields.len() + ), + }); + } + + let raw_response = response.content(); + let code = + extract_passthrough_body(response).ok_or_else(|| ParseError::ExtractionFailed { + field: output_fields[0].rust_name.clone(), + raw_response: raw_response.clone(), + reason: "empty passthrough response".to_string(), + })?; + + let mut output_map = bamltype::baml_types::BamlMap::new(); + output_map.insert( + output_fields[0].rust_name.clone(), + BamlValue::String(code.clone()), + ); + + let typed_output = ::try_from_baml_value(BamlValue::Class( + ::baml_internal_name().to_string(), + output_map, + )) + .map_err(|err| ParseError::ExtractionFailed { + field: output_fields[0].rust_name.clone(), + raw_response: raw_response.clone(), + reason: err.to_string(), + })?; + + let mut metas = IndexMap::new(); + metas.insert( + output_fields[0].rust_name.clone(), + FieldMeta { + raw_text: code, + flags: Vec::new(), + checks: Vec::new(), + }, + ); + + Ok((typed_output, metas)) + } + #[allow(clippy::result_large_err)] /// Parses an LM response into a typed output, discarding field metadata. 
/// @@ -829,12 +953,14 @@ impl ChatAdapter { response: Message, ) -> std::result::Result, PredictError> { let raw_response = response.content(); + let parse_chat = Chat::new(vec![response.clone()]); let (output, field_meta) = self .parse_response_typed::(&response) .map_err(|source| PredictError::Parse { source, raw_response: raw_response.clone(), lm_usage: crate::LmUsage::default(), + chat: parse_chat, })?; let metadata = CallMetadata::new( raw_response, @@ -844,7 +970,7 @@ impl ChatAdapter { None, field_meta, ); - Ok(Predicted::new(output, metadata)) + Ok(Predicted::new(output, metadata, Chat::new(vec![response]))) } } @@ -884,6 +1010,92 @@ fn parse_sections(content: &str) -> IndexMap { parsed } +fn extract_passthrough_body(response: &Message) -> Option { + extract_passthrough_body_from_text(&response.text_content()) +} + +fn extract_passthrough_body_from_text(text: &str) -> Option { + let trimmed = text.trim(); + if trimmed.is_empty() { + return None; + } + + let fenced_blocks = extract_fenced_code_blocks(trimmed); + if !fenced_blocks.is_empty() { + return Some(fenced_blocks.join("\n\n")); + } + + Some(trimmed.to_string()) +} + +fn extract_fenced_code_blocks(text: &str) -> Vec { + let mut blocks = Vec::new(); + let mut in_fence = false; + let mut current = Vec::new(); + + for line in text.lines() { + if line.trim_start().starts_with("```") { + if in_fence { + let block = current.join("\n").trim().to_string(); + if !block.is_empty() { + blocks.push(block); + } + current.clear(); + in_fence = false; + } else { + in_fence = true; + current.clear(); + } + continue; + } + + if in_fence { + current.push(line); + } + } + + if in_fence { + let block = current.join("\n").trim().to_string(); + if !block.is_empty() { + blocks.push(block); + } + } + + blocks +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn passthrough_extracts_all_fenced_blocks() { + let text = "```python\nx = 1\n```\n\nSome prose.\n\n```py\ny = 2\n```\n``` \nz = 3\n```"; + assert_eq!( + 
extract_passthrough_body_from_text(text), + Some("x = 1\n\ny = 2\n\nz = 3".to_string()) + ); + } + + #[test] + fn passthrough_extracts_unclosed_fenced_block() { + let text = "```python\nx = 1\ny = 2"; + assert_eq!( + extract_passthrough_body_from_text(text), + Some("x = 1\ny = 2".to_string()) + ); + } + + #[test] + fn passthrough_falls_back_to_full_text_without_fences() { + let text = " x = 1\ny = x + 1 "; + assert_eq!( + extract_passthrough_body_from_text(text), + Some("x = 1\ny = x + 1".to_string()) + ); + } +} + fn value_for_path_relaxed<'a>( value: &'a BamlValue, path: &crate::FieldPath, diff --git a/crates/dspy-rs/src/core/errors.rs b/crates/dspy-rs/src/core/errors.rs index bc206dbf..3903de24 100644 --- a/crates/dspy-rs/src/core/errors.rs +++ b/crates/dspy-rs/src/core/errors.rs @@ -1,6 +1,6 @@ use std::{error::Error as StdError, time::Duration}; -use crate::{BamlConvertError, BamlValue, LmUsage}; +use crate::{BamlConvertError, BamlValue, Chat, LmUsage}; /// Error from the jsonish coercion layer when LM output can't be parsed as a typed value. #[derive(Debug)] @@ -56,6 +56,8 @@ pub enum ErrorClass { /// 3. **[`Conversion`](PredictError::Conversion)** — we parsed a valid `BamlValue` /// from the response, but it doesn't fit the Rust output type. Code bug or schema /// mismatch. **Not retryable** — the same parsed value will fail the same way. +/// 4. **[`Module`](PredictError::Module)** — module-internal execution error outside +/// of direct LM parsing/provider failure. /// /// Use [`is_retryable`](PredictError::is_retryable) for retry logic. /// Use [`class`](PredictError::class) for coarse [`ErrorClass`] bucketing. @@ -78,6 +80,11 @@ pub enum PredictError { source: ParseError, raw_response: String, lm_usage: LmUsage, + /// Conversation history including the failed assistant turn. + /// + /// This enables callers (for example, multi-turn modules like RLM) to continue + /// the conversation after a recoverable parse failure. 
+ chat: Chat, }, /// The response parsed into a `BamlValue` but doesn't match the typed output struct. @@ -91,6 +98,14 @@ pub enum PredictError { /// The successfully parsed `BamlValue` that failed type conversion. parsed: BamlValue, }, + + /// Module-level execution error not represented by LM/provider/parse conversion. + #[error("{module} module failed")] + Module { + module: &'static str, + #[source] + source: Box, + }, } impl PredictError { @@ -99,6 +114,7 @@ impl PredictError { Self::Lm { source } => source.class(), Self::Parse { .. } => ErrorClass::BadResponse, Self::Conversion { .. } => ErrorClass::Internal, + Self::Module { .. } => ErrorClass::Internal, } } @@ -107,6 +123,7 @@ impl PredictError { Self::Lm { source } => source.is_retryable(), Self::Parse { .. } => true, Self::Conversion { .. } => false, + Self::Module { .. } => false, } } } diff --git a/crates/dspy-rs/src/core/lm/chat.rs b/crates/dspy-rs/src/core/lm/chat.rs index d3459fd0..690db9b2 100644 --- a/crates/dspy-rs/src/core/lm/chat.rs +++ b/crates/dspy-rs/src/core/lm/chat.rs @@ -3,106 +3,432 @@ use anyhow::Result; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; -use rig::completion::{AssistantContent, Message as RigMessage, message::UserContent}; +use rig::OneOrMany; +use rig::message::{ + AssistantContent, Message as RigMessage, Reasoning, ToolCall, ToolResult, ToolResultContent, + UserContent, +}; + +// --------------------------------------------------------------------------- +// ContentBlock — one piece of content within a message +// --------------------------------------------------------------------------- #[derive(Clone, Debug, Serialize, Deserialize)] -pub enum Message { - System { content: String }, - User { content: String }, - Assistant { content: String }, +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ContentBlock { + Text { text: String }, + ToolCall { tool_call: ToolCall }, + ToolResult { tool_result: ToolResult }, + Reasoning { reasoning: 
Reasoning }, +} + +impl ContentBlock { + pub fn text(t: impl Into) -> Self { + ContentBlock::Text { text: t.into() } + } + + pub fn tool_call(tc: ToolCall) -> Self { + ContentBlock::ToolCall { tool_call: tc } + } + + pub fn tool_result(tr: ToolResult) -> Self { + ContentBlock::ToolResult { tool_result: tr } + } + + pub fn reasoning(r: Reasoning) -> Self { + ContentBlock::Reasoning { reasoning: r } + } +} + +// --------------------------------------------------------------------------- +// Role +// --------------------------------------------------------------------------- + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Role { + System, + User, + Assistant, +} + +impl Role { + pub fn as_str(&self) -> &'static str { + match self { + Role::System => "system", + Role::User => "user", + Role::Assistant => "assistant", + } + } +} + +// --------------------------------------------------------------------------- +// Message — a single turn in a conversation +// --------------------------------------------------------------------------- + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Message { + pub role: Role, + pub content: Vec, + /// Provider-assigned message ID (e.g. Anthropic thinking turn IDs). + #[serde(skip_serializing_if = "Option::is_none", default)] + pub id: Option, } impl Message { - pub fn new(role: &str, content: &str) -> Self { - match role { - "system" => Message::system(content), - "user" => Message::user(content), - "assistant" => Message::assistant(content), - _ => panic!("Invalid role: {role}"), + /// Creates a text-only message for a typed role. 
+ pub fn new(role: Role, content: impl Into) -> Self { + Self { + role, + content: vec![ContentBlock::text(content)], + id: None, } } pub fn user(content: impl Into) -> Self { - Message::User { - content: content.into(), + Self { + role: Role::User, + content: vec![ContentBlock::text(content)], + id: None, } } pub fn assistant(content: impl Into) -> Self { - Message::Assistant { - content: content.into(), + Self { + role: Role::Assistant, + content: vec![ContentBlock::text(content)], + id: None, } } pub fn system(content: impl Into) -> Self { - Message::System { - content: content.into(), + Self { + role: Role::System, + content: vec![ContentBlock::text(content)], + id: None, } } - pub fn content(&self) -> String { - match self { - Message::System { content } => content.clone(), - Message::User { content } => content.clone(), - Message::Assistant { content } => content.clone(), + /// Creates an assistant message containing a single tool call. + pub fn tool_call(tool_call: ToolCall) -> Self { + Self { + role: Role::Assistant, + content: vec![ContentBlock::tool_call(tool_call)], + id: None, } } - pub fn get_message_turn(&self) -> RigMessage { - match self { - Message::User { content } => RigMessage::user(content.clone()), - Message::Assistant { content } => RigMessage::assistant(content.clone()), - _ => panic!("Invalid role: {:?}", self), + /// Creates a user message containing a single tool result. + pub fn tool_result(tool_result: ToolResult) -> Self { + Self { + role: Role::User, + content: vec![ContentBlock::tool_result(tool_result)], + id: None, + } + } + + /// Creates an assistant message containing a single reasoning block. + pub fn reasoning(reasoning: Reasoning) -> Self { + Self { + role: Role::Assistant, + content: vec![ContentBlock::reasoning(reasoning)], + id: None, + } + } + + /// Creates a message with arbitrary content blocks. 
+ pub fn with_content(role: Role, content: Vec) -> Self { + Self { + role, + content, + id: None, + } + } + + // -- Accessors ----------------------------------------------------------- + + /// Returns a string representation of the message's content. + /// + /// For text-only messages, returns the text. For multi-content messages, + /// returns all blocks formatted and joined with newlines. + pub fn content(&self) -> String { + let parts: Vec = self + .content + .iter() + .map(|block| match block { + ContentBlock::Text { text } => text.clone(), + ContentBlock::ToolCall { tool_call } => { + format!( + "{}({})", + tool_call.function.name, tool_call.function.arguments + ) + } + ContentBlock::ToolResult { tool_result } => tool_result + .content + .iter() + .filter_map(|item| match item { + ToolResultContent::Text(text) => Some(text.text.as_str()), + ToolResultContent::Image(_) => None, + }) + .collect::>() + .join("\n"), + ContentBlock::Reasoning { reasoning } => reasoning.display_text(), + }) + .collect(); + parts.join("\n") + } + + /// Returns only the text content, ignoring tool calls, tool results, + /// and reasoning blocks. Used by the parser to extract structured output. + pub fn text_content(&self) -> String { + self.content + .iter() + .filter_map(|block| match block { + ContentBlock::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::>() + .join("\n") + } + + // -- Content query helpers ----------------------------------------------- + + /// Returns `true` if this message contains at least one tool call. + pub fn has_tool_calls(&self) -> bool { + self.content + .iter() + .any(|b| matches!(b, ContentBlock::ToolCall { .. })) + } + + /// Returns `true` if this message contains at least one tool result. + pub fn has_tool_results(&self) -> bool { + self.content + .iter() + .any(|b| matches!(b, ContentBlock::ToolResult { .. })) + } + + /// Returns `true` if this message contains at least one reasoning block. 
+ pub fn has_reasoning(&self) -> bool { + self.content + .iter() + .any(|b| matches!(b, ContentBlock::Reasoning { .. })) + } + + /// Extracts all tool calls from this message. + pub fn tool_calls(&self) -> Vec<&ToolCall> { + self.content + .iter() + .filter_map(|b| match b { + ContentBlock::ToolCall { tool_call } => Some(tool_call), + _ => None, + }) + .collect() + } + + // -- Rig conversion ------------------------------------------------------ + + /// Converts this message to a rig message for provider API calls. + /// + /// Returns `None` for system messages (rig handles them as preamble). + pub(crate) fn to_rig_message(&self) -> Option { + match self.role { + Role::System => None, + Role::User => { + let user_content: Vec = self + .content + .iter() + .filter_map(|block| match block { + ContentBlock::Text { text } => Some(UserContent::text(text.clone())), + ContentBlock::ToolResult { tool_result } => { + Some(UserContent::ToolResult(tool_result.clone())) + } + // ToolCall/Reasoning don't belong in user messages; skip gracefully + _ => None, + }) + .collect(); + if user_content.is_empty() { + return Some(RigMessage::user(String::new())); + } + Some(RigMessage::User { + content: OneOrMany::many(user_content) + .unwrap_or_else(|_| OneOrMany::one(UserContent::text(String::new()))), + }) + } + Role::Assistant => { + let asst_content: Vec = self + .content + .iter() + .filter_map(|block| match block { + ContentBlock::Text { text } => Some(AssistantContent::text(text.clone())), + ContentBlock::ToolCall { tool_call } => { + Some(AssistantContent::ToolCall(tool_call.clone())) + } + ContentBlock::Reasoning { reasoning } => { + Some(AssistantContent::Reasoning(reasoning.clone())) + } + // ToolResult doesn't belong in assistant messages; skip gracefully + _ => None, + }) + .collect(); + if asst_content.is_empty() { + return Some(RigMessage::assistant(String::new())); + } + Some(RigMessage::Assistant { + id: self.id.clone(), + content: OneOrMany::many(asst_content) + 
.unwrap_or_else(|_| OneOrMany::one(AssistantContent::text(String::new()))), + }) + } } } + // -- JSON serialization -------------------------------------------------- + pub fn to_json(&self) -> Value { - match self { - Message::System { content } => json!({ "role": "system", "content": content }), - Message::User { content } => json!({ "role": "user", "content": content }), - Message::Assistant { content } => json!({ "role": "assistant", "content": content }), + let content_json: Vec = self + .content + .iter() + .map(|block| match block { + ContentBlock::Text { text } => json!({ "type": "text", "text": text }), + ContentBlock::ToolCall { tool_call } => { + json!({ "type": "tool_call", "tool_call": tool_call }) + } + ContentBlock::ToolResult { tool_result } => { + json!({ "type": "tool_result", "tool_result": tool_result }) + } + ContentBlock::Reasoning { reasoning } => { + json!({ "type": "reasoning", "reasoning": reasoning }) + } + }) + .collect(); + + let mut msg = json!({ + "role": self.role.as_str(), + "content": content_json, + }); + + if let Some(id) = &self.id { + msg.as_object_mut() + .unwrap() + .insert("id".to_string(), json!(id)); } + + msg + } + + fn from_json_value(message: &Value) -> Result { + let role_str = message + .get("role") + .and_then(Value::as_str) + .ok_or_else(|| anyhow::anyhow!("chat message missing string role"))?; + + let role = match role_str { + "system" => Role::System, + "user" => Role::User, + "assistant" => Role::Assistant, + other => return Err(anyhow::anyhow!("unsupported chat message role: {other}")), + }; + + let id = message.get("id").and_then(Value::as_str).map(String::from); + + let content = message + .get("content") + .and_then(Value::as_array) + .ok_or_else(|| anyhow::anyhow!("chat message content must be an array"))? 
+ .iter() + .map(parse_content_block) + .collect::>>()?; + + Ok(Self { role, content, id }) } } +fn parse_content_block(value: &Value) -> Result { + let block_type = value + .get("type") + .and_then(Value::as_str) + .ok_or_else(|| anyhow::anyhow!("content block missing type"))?; + + match block_type { + "text" => { + let text = value + .get("text") + .and_then(Value::as_str) + .ok_or_else(|| anyhow::anyhow!("text block missing text field"))?; + Ok(ContentBlock::text(text)) + } + "tool_call" => { + let tc: ToolCall = serde_json::from_value(value["tool_call"].clone())?; + Ok(ContentBlock::tool_call(tc)) + } + "tool_result" => { + let tr: ToolResult = serde_json::from_value(value["tool_result"].clone())?; + Ok(ContentBlock::tool_result(tr)) + } + "reasoning" => { + let r: Reasoning = serde_json::from_value(value["reasoning"].clone())?; + Ok(ContentBlock::reasoning(r)) + } + other => Err(anyhow::anyhow!("unsupported content block type: {other}")), + } +} + +// --------------------------------------------------------------------------- +// From — grouped conversion, one rig message → one DSRs message +// --------------------------------------------------------------------------- + impl From for Message { fn from(message: RigMessage) -> Self { match message { RigMessage::User { content } => { - let text = content + let blocks: Vec = content .into_iter() - .find_map(|c| { - if let UserContent::Text(t) = c { - Some(t.text) - } else { - None - } + .filter_map(|item| match item { + UserContent::Text(text) => Some(ContentBlock::text(text.text)), + UserContent::ToolResult(result) => Some(ContentBlock::tool_result(result)), + UserContent::Image(_) + | UserContent::Audio(_) + | UserContent::Video(_) + | UserContent::Document(_) => None, }) - .unwrap_or_default(); - Message::user(text) + .collect(); + Message { + role: Role::User, + content: if blocks.is_empty() { + vec![ContentBlock::text(String::new())] + } else { + blocks + }, + id: None, + } } - RigMessage::Assistant { 
content, .. } => { - let text = content + RigMessage::Assistant { id, content } => { + let blocks: Vec = content .into_iter() - .find_map(|c| { - if let AssistantContent::Text(t) = c { - Some(t.text) - } else { - None - } + .filter_map(|item| match item { + AssistantContent::Text(text) => Some(ContentBlock::text(text.text)), + AssistantContent::ToolCall(tc) => Some(ContentBlock::tool_call(tc)), + AssistantContent::Reasoning(r) => Some(ContentBlock::reasoning(r)), + AssistantContent::Image(_) => None, }) - .unwrap_or_default(); - Message::assistant(text) + .collect(); + Message { + role: Role::Assistant, + content: if blocks.is_empty() { + vec![ContentBlock::text(String::new())] + } else { + blocks + }, + id, + } } } } } -pub struct RigChatMessage { - pub system: String, - pub conversation: Vec, - pub prompt: RigMessage, -} +// --------------------------------------------------------------------------- +// Chat — ordered sequence of messages +// --------------------------------------------------------------------------- #[derive(Clone, Debug)] pub struct Chat { @@ -122,7 +448,7 @@ impl Chat { self.messages.is_empty() } - pub fn push(&mut self, role: &str, content: &str) { + pub fn push(&mut self, role: Role, content: impl Into) { self.messages.push(Message::new(role, content)); } @@ -139,16 +465,13 @@ impl Chat { } pub fn from_json(&self, json_dump: Value) -> Result { - let messages = json_dump.as_array().unwrap(); + let messages = json_dump + .as_array() + .ok_or_else(|| anyhow::anyhow!("chat dump must be an array"))?; let messages = messages .iter() - .map(|message| { - Message::new( - message["role"].as_str().unwrap(), - message["content"].as_str().unwrap(), - ) - }) - .collect(); + .map(Message::from_json_value) + .collect::>>()?; Ok(Self { messages }) } @@ -161,22 +484,130 @@ impl Chat { json!(messages) } - pub fn get_rig_messages(&self) -> RigChatMessage { - let system: String = self.messages[0].content(); - let conversation: Vec = if self.messages.len() > 2 { 
- self.messages[1..self.messages.len() - 1] - .iter() - .map(|message| message.get_message_turn()) - .collect::>() - } else { - vec![] - }; - let prompt = self.messages.last().unwrap().get_message_turn(); + // -- Rig interop --------------------------------------------------------- - RigChatMessage { - system, - conversation, - prompt, - } + /// Extracts the system prompt text from the first system message. + pub(crate) fn system_prompt(&self) -> String { + self.messages + .iter() + .find_map(|message| { + if message.role == Role::System { + Some(message.text_content()) + } else { + None + } + }) + .unwrap_or_default() + } + + /// Converts all non-system messages to rig messages for provider API calls. + pub(crate) fn to_rig_chat_history(&self) -> Vec { + self.messages + .iter() + .filter_map(Message::to_rig_message) + .collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rig::OneOrMany; + use rig::message::{ToolFunction, ToolResultContent}; + use serde_json::json; + + #[test] + fn rig_conversion_preserves_assistant_reasoning_and_tool_calls() { + let original = Message::with_content( + Role::Assistant, + vec![ + ContentBlock::reasoning(Reasoning::new("step 1")), + ContentBlock::reasoning(Reasoning::new("step 2")), + ContentBlock::tool_call(ToolCall::new( + "tc-1".to_string(), + ToolFunction { + name: "search".to_string(), + arguments: json!({"q": "rust ownership"}), + }, + )), + ], + ); + + let rig_msg = original + .to_rig_message() + .expect("assistant message should convert to rig"); + let roundtripped = Message::from(rig_msg); + + assert_eq!(roundtripped.role, Role::Assistant); + assert_eq!(roundtripped.content.len(), 3); + assert!(matches!( + &roundtripped.content[0], + ContentBlock::Reasoning { .. } + )); + assert!(matches!( + &roundtripped.content[1], + ContentBlock::Reasoning { .. 
} + )); + assert!( + matches!(&roundtripped.content[2], ContentBlock::ToolCall { tool_call } if tool_call.function.name == "search") + ); + } + + #[test] + fn rig_conversion_preserves_user_text_and_tool_result() { + let original = Message::with_content( + Role::User, + vec![ + ContentBlock::text("Here is context"), + ContentBlock::tool_result(ToolResult { + id: "tr-1".to_string(), + call_id: Some("tc-1".to_string()), + content: OneOrMany::one(ToolResultContent::text("search result")), + }), + ], + ); + + let rig_msg = original + .to_rig_message() + .expect("user should convert to rig"); + let roundtripped = Message::from(rig_msg); + + assert_eq!(roundtripped.role, Role::User); + assert_eq!(roundtripped.content.len(), 2); + assert!( + matches!(&roundtripped.content[0], ContentBlock::Text { text } if text == "Here is context") + ); + assert!(roundtripped.has_tool_results()); + } + + #[test] + fn rig_chat_history_excludes_system_message() { + let chat = Chat::new(vec![ + Message::system("You are a helpful assistant."), + Message::user("What is the capital of France?"), + Message::assistant("Paris."), + ]); + + let rig_history = chat.to_rig_chat_history(); + assert_eq!(rig_history.len(), 2); + } + + #[test] + fn system_messages_are_not_converted_to_rig_messages() { + let msg = Message::system("You are helpful"); + assert!(msg.to_rig_message().is_none()); + } + + #[test] + fn assistant_message_id_survives_rig_roundtrip() { + let mut msg = Message::assistant("some text"); + msg.id = Some("msg_abc123".to_string()); + + let rig_msg = msg + .to_rig_message() + .expect("assistant should convert to rig"); + let roundtripped = Message::from(rig_msg); + + assert_eq!(roundtripped.id, Some("msg_abc123".to_string())); } } diff --git a/crates/dspy-rs/src/core/lm/client_registry.rs b/crates/dspy-rs/src/core/lm/client_registry.rs index 6df3c7ca..d41d2cf5 100644 --- a/crates/dspy-rs/src/core/lm/client_registry.rs +++ b/crates/dspy-rs/src/core/lm/client_registry.rs @@ -14,6 +14,11 @@ 
use std::{ }; use tracing::{debug, trace, warn}; +#[derive(Clone, Debug, Default)] +pub struct ProviderOptions { + pub anthropic_prompt_caching: bool, +} + #[enum_dispatch] #[allow(async_fn_in_trait)] pub trait CompletionProvider { @@ -77,6 +82,7 @@ impl CompletionProvider for TestCompletionModel { #[derive(Clone)] pub enum LMClient { OpenAI(openai::completion::CompletionModel), + OpenAIResponses(openai::responses_api::ResponsesCompletionModel), Gemini(gemini::completion::CompletionModel), Anthropic(anthropic::completion::CompletionModel), Groq(groq::CompletionModel), @@ -103,6 +109,16 @@ impl CompletionProvider for openai::completion::CompletionModel { } } +impl CompletionProvider for openai::responses_api::ResponsesCompletionModel { + async fn completion( + &self, + request: CompletionRequest, + ) -> Result, CompletionError> { + let response = rig::completion::CompletionModel::completion(self, request).await?; + Ok(to_unit_completion_response(response)) + } +} + impl CompletionProvider for anthropic::completion::CompletionModel { async fn completion( &self, @@ -213,6 +229,20 @@ impl CompletionProvider for deepseek::CompletionModel { } impl LMClient { + fn parse_openai_model(model: &str) -> (&str, bool) { + if let Some(rest) = model + .strip_prefix("openai-responses:") + .or_else(|| model.strip_prefix("openai_responses:")) + .or_else(|| model.strip_prefix("openai.responses:")) + { + return (rest, true); + } + if let Some(rest) = model.strip_prefix("openai:") { + return (rest, false); + } + (model, false) + } + #[tracing::instrument( name = "dsrs.lm.client.get_api_key", level = "trace", @@ -236,15 +266,27 @@ impl LMClient { fields(model, base_url_present = true, api_key_present = true) )] pub fn from_openai_compatible(base_url: &str, api_key: &str, model: &str) -> Result { - trace!(base_url, "creating openai-compatible client"); - let client = openai::CompletionsClient::builder() - .api_key(api_key) - .base_url(base_url) - .build()?; - debug!("openai-compatible 
client ready"); - Ok(LMClient::OpenAI(openai::completion::CompletionModel::new( - client, model, - ))) + let (model, use_responses) = Self::parse_openai_model(model); + trace!(base_url, use_responses, "creating openai-compatible client"); + if use_responses { + let client = openai::Client::builder() + .api_key(api_key) + .base_url(base_url) + .build()?; + debug!("openai-compatible responses client ready"); + Ok(LMClient::OpenAIResponses( + openai::responses_api::ResponsesCompletionModel::new(client, model), + )) + } else { + let client = openai::CompletionsClient::builder() + .api_key(api_key) + .base_url(base_url) + .build()?; + debug!("openai-compatible client ready"); + Ok(LMClient::OpenAI(openai::completion::CompletionModel::new( + client, model, + ))) + } } /// Build case 2: Local OpenAI-compatible model from base_url (vLLM, etc.) @@ -256,15 +298,30 @@ impl LMClient { fields(model, base_url_present = true) )] pub fn from_local(base_url: &str, model: &str) -> Result { - trace!(base_url, "creating local openai-compatible client"); - let client = openai::CompletionsClient::builder() - .api_key("dummy-key-for-local-server") - .base_url(base_url) - .build()?; - debug!("local openai-compatible client ready"); - Ok(LMClient::OpenAI(openai::completion::CompletionModel::new( - client, model, - ))) + let (model, use_responses) = Self::parse_openai_model(model); + trace!( + base_url, + use_responses, "creating local openai-compatible client" + ); + if use_responses { + let client = openai::Client::builder() + .api_key("dummy-key-for-local-server") + .base_url(base_url) + .build()?; + debug!("local openai-compatible responses client ready"); + Ok(LMClient::OpenAIResponses( + openai::responses_api::ResponsesCompletionModel::new(client, model), + )) + } else { + let client = openai::CompletionsClient::builder() + .api_key("dummy-key-for-local-server") + .base_url(base_url) + .build()?; + debug!("local openai-compatible client ready"); + 
Ok(LMClient::OpenAI(openai::completion::CompletionModel::new( + client, model, + ))) + } } /// Build case 3: From provider via model name (provider:model format) @@ -279,6 +336,25 @@ impl LMClient { ) )] pub fn from_model_string(model_str: &str, api_key: Option<&str>) -> Result { + Self::from_model_string_with_options(model_str, api_key, &ProviderOptions::default()) + } + + #[tracing::instrument( + name = "dsrs.lm.client.from_model_string_with_options", + level = "debug", + skip(model_str, api_key, options), + fields( + provider = tracing::field::Empty, + model_id = tracing::field::Empty, + api_key_present = api_key.is_some(), + anthropic_prompt_caching = options.anthropic_prompt_caching + ) + )] + pub fn from_model_string_with_options( + model_str: &str, + api_key: Option<&str>, + options: &ProviderOptions, + ) -> Result { let (provider, model_id) = model_str.split_once(':').ok_or(anyhow::anyhow!( "Model string must be in format 'provider:model_name'" ))?; @@ -286,6 +362,14 @@ impl LMClient { tracing::Span::current().record("model_id", tracing::field::display(model_id)); match provider { + "openai-responses" | "openai_responses" | "openai.responses" => { + debug!("selecting openai responses provider"); + let key = Self::get_api_key(api_key, "OPENAI_API_KEY")?; + let client = openai::Client::builder().api_key(key.as_ref()).build()?; + Ok(LMClient::OpenAIResponses( + openai::responses_api::ResponsesCompletionModel::new(client, model_id), + )) + } "openai" => { debug!("selecting openai provider"); let key = Self::get_api_key(api_key, "OPENAI_API_KEY")?; @@ -300,9 +384,11 @@ impl LMClient { debug!("selecting anthropic provider"); let key = Self::get_api_key(api_key, "ANTHROPIC_API_KEY")?; let client = anthropic::Client::builder().api_key(key.as_ref()).build()?; - Ok(LMClient::Anthropic( - anthropic::completion::CompletionModel::new(client, model_id), - )) + let mut model = anthropic::completion::CompletionModel::new(client, model_id); + if 
options.anthropic_prompt_caching { + model = model.with_prompt_caching(); + } + Ok(LMClient::Anthropic(model)) } "gemini" => { debug!("selecting gemini provider"); @@ -340,7 +426,7 @@ impl LMClient { _ => { warn!(provider, "unsupported provider"); anyhow::bail!( - "Unsupported provider: {}. Supported providers are: openai, anthropic, gemini, groq, openrouter, ollama", + "Unsupported provider: {}. Supported providers are: openai, openai-responses, anthropic, gemini, groq, openrouter, ollama", provider ); } @@ -356,3 +442,40 @@ impl LMClient { client.into() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn anthropic_prompt_caching_option_enables_model_flag() { + let client = LMClient::from_model_string_with_options( + "anthropic:claude-opus-4-0", + Some("test-key"), + &ProviderOptions { + anthropic_prompt_caching: true, + }, + ) + .expect("anthropic client should build"); + + match client { + LMClient::Anthropic(model) => assert!(model.prompt_caching), + _ => panic!("expected anthropic client"), + } + } + + #[test] + fn anthropic_prompt_caching_defaults_to_disabled() { + let client = LMClient::from_model_string_with_options( + "anthropic:claude-opus-4-0", + Some("test-key"), + &ProviderOptions::default(), + ) + .expect("anthropic client should build"); + + match client { + LMClient::Anthropic(model) => assert!(!model.prompt_caching), + _ => panic!("expected anthropic client"), + } + } +} diff --git a/crates/dspy-rs/src/core/lm/mod.rs b/crates/dspy-rs/src/core/lm/mod.rs index 4d18c50b..70ad0670 100644 --- a/crates/dspy-rs/src/core/lm/mod.rs +++ b/crates/dspy-rs/src/core/lm/mod.rs @@ -8,6 +8,7 @@ pub use usage::*; use anyhow::Result; use rig::{completion::AssistantContent, message::ToolCall, message::ToolChoice, tool::ToolDyn}; +use serde_json::Value; use bon::Builder; use std::{collections::HashMap, sync::Arc}; @@ -31,6 +32,12 @@ pub struct LMResponse { pub tool_executions: Vec, } +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum ToolLoopMode { + 
Auto, + CallerManaged, +} + #[derive(Builder)] #[builder(finish_fn(vis = "", name = __internal_build))] pub struct LM { @@ -44,6 +51,11 @@ pub struct LM { pub max_tokens: u32, #[builder(default = 10)] pub max_tool_iterations: u32, + /// Provider-specific request parameters forwarded to rig `CompletionRequest.additional_params`. + pub additional_params: Option, + /// Enables Anthropic prompt caching when using `anthropic:model` model strings. + #[builder(default = false)] + pub anthropic_prompt_caching: bool, #[builder(default = false)] pub cache: bool, pub cache_handler: Option>>, @@ -66,6 +78,8 @@ impl Clone for LM { temperature: self.temperature, max_tokens: self.max_tokens, max_tool_iterations: self.max_tool_iterations, + additional_params: self.additional_params.clone(), + anthropic_prompt_caching: self.anthropic_prompt_caching, cache: self.cache, cache_handler: self.cache_handler.clone(), client: self.client.clone(), @@ -121,7 +135,13 @@ impl LM { // Uses provider-specific clients (None, api_key, model) if model.contains(':') => { debug!(build_case = 3, "using provider:model client"); - Arc::new(LMClient::from_model_string(model, api_key.as_deref())?) + Arc::new(LMClient::from_model_string_with_options( + model, + api_key.as_deref(), + &ProviderOptions { + anthropic_prompt_caching: self.anthropic_prompt_caching, + }, + )?) } // Default case: assume OpenAI provider if no colon in model name (None, api_key, model) => { @@ -131,7 +151,13 @@ impl LM { } else { format!("openai:{}", model) }; - Arc::new(LMClient::from_model_string(&model_str, api_key.as_deref())?) + Arc::new(LMClient::from_model_string_with_options( + &model_str, + api_key.as_deref(), + &ProviderOptions { + anthropic_prompt_caching: self.anthropic_prompt_caching, + }, + )?) 
} }; @@ -179,7 +205,6 @@ impl LMBuilder { struct ToolLoopResult { message: Message, - #[allow(unused)] chat_history: Vec, tool_calls: Vec, tool_executions: Vec, @@ -189,7 +214,10 @@ struct ToolLoopResult { /// Reasoning blocks are preserved in `full_content` for faithful history replay. enum ChoiceAction { /// Terminal text response (possibly preceded by reasoning). - Text(String), + Text { + text: String, + full_content: Box>, + }, /// One or more tool calls to execute. Carries the full `OneOrMany` so /// reasoning blocks are preserved when we push the assistant turn into /// chat history. Supports parallel tool calling (Anthropic multi-tool-use, @@ -197,6 +225,7 @@ enum ChoiceAction { ToolCalls { calls: Vec, full_content: Box>, + assistant_text: Option, }, } @@ -224,11 +253,15 @@ fn classify_choice(choice: rig::OneOrMany) -> ChoiceAction { return ChoiceAction::ToolCalls { calls: tool_calls, full_content: Box::new(choice), + assistant_text: text, }; } if let Some(t) = text { - return ChoiceAction::Text(t); + return ChoiceAction::Text { + text: t, + full_content: Box::new(choice), + }; } // Fallback: only reasoning blocks — extract display text @@ -240,7 +273,10 @@ fn classify_choice(choice: rig::OneOrMany) -> ChoiceAction { }) .collect::>() .join("\n"); - ChoiceAction::Text(display) + ChoiceAction::Text { + text: display, + full_content: Box::new(choice), + } } /// Look up a tool by name in the tool list and execute it. 
@@ -261,6 +297,35 @@ async fn find_and_execute_tool( } impl LM { + fn chat_from_rig_history(system_prompt: &str, history: &[rig::message::Message]) -> Chat { + let mut chat = Chat::new(Vec::new()); + if !system_prompt.is_empty() { + chat.push_message(Message::system(system_prompt.to_string())); + } + for message in history { + chat.push_message(Message::from(message.clone())); + } + chat + } + + fn to_request_chat_history( + chat_history: Vec, + ) -> Result> { + match chat_history.len() { + 0 => Err(anyhow::anyhow!( + "chat must contain at least one non-system message" + )), + 1 => Ok(rig::OneOrMany::one( + chat_history + .into_iter() + .next() + .ok_or_else(|| anyhow::anyhow!("chat history unexpectedly empty"))?, + )), + _ => rig::OneOrMany::many(chat_history) + .map_err(|_| anyhow::anyhow!("chat must contain at least one message")), + } + } + /// Execute all tool calls in a batch, returning results paired with their calls. async fn execute_tool_batch( tools: &mut [Arc], @@ -346,7 +411,6 @@ impl LM { system_prompt: String, accumulated_usage: &mut LmUsage, ) -> Result { - use rig::OneOrMany; use rig::completion::CompletionRequest; let max_iterations = self.max_tool_iterations as usize; @@ -375,17 +439,13 @@ impl LM { let request = CompletionRequest { model: None, preamble: Some(system_prompt.clone()), - chat_history: if chat_history.len() == 1 { - OneOrMany::one(chat_history.clone().into_iter().next().unwrap()) - } else { - OneOrMany::many(chat_history.clone()).expect("chat_history should not be empty") - }, + chat_history: Self::to_request_chat_history(chat_history.clone())?, documents: Vec::new(), tools: tool_definitions.clone(), temperature: Some(self.temperature as f64), max_tokens: Some(self.max_tokens as u64), tool_choice: Some(ToolChoice::Auto), - additional_params: None, + additional_params: self.additional_params.clone(), output_schema: None, }; @@ -410,10 +470,13 @@ impl LM { // Scan ALL content blocks — don't just look at .first(), since // responses can 
be [Reasoning, ToolCall] or [Reasoning, Text]. match classify_choice(response.choice) { - ChoiceAction::Text(text) => { + ChoiceAction::Text { full_content, .. } => { debug!(iteration, "tool loop completed with text"); + let content = *full_content; + let message = + Message::from(rig::message::Message::Assistant { id: None, content }); return Ok(ToolLoopResult { - message: Message::assistant(&text), + message, chat_history, tool_calls: all_tool_calls, tool_executions: all_tool_executions, @@ -422,6 +485,7 @@ impl LM { ChoiceAction::ToolCalls { calls, full_content, + .. } => { let context = format!("iteration {}", iteration); debug!(iteration, count = calls.len(), "executing tool calls"); @@ -453,36 +517,34 @@ impl LM { model = %self.model, message_count = messages.len(), tool_count = tools.len(), - cache_enabled = self.cache + cache_enabled = self.cache, + tool_loop_mode = ?tool_loop_mode ) )] - pub async fn call(&self, messages: Chat, tools: Vec>) -> Result { - use rig::OneOrMany; + pub async fn call( + &self, + messages: Chat, + tools: Vec>, + tool_loop_mode: ToolLoopMode, + ) -> Result { use rig::completion::CompletionRequest; - let request_messages = messages.get_rig_messages(); + let system_prompt = messages.system_prompt(); + let chat_history = messages.to_rig_chat_history(); let mut tool_definitions = Vec::new(); for tool in &tools { tool_definitions.push(tool.definition("".to_string()).await); } trace!( - conversation_messages = request_messages.conversation.len(), + conversation_messages = chat_history.len(), tool_definitions = tool_definitions.len(), "prepared completion request inputs" ); - // Build the completion request manually - let mut chat_history = request_messages.conversation; - chat_history.push(request_messages.prompt); - let request = CompletionRequest { model: None, - preamble: Some(request_messages.system.clone()), - chat_history: if chat_history.len() == 1 { - OneOrMany::one(chat_history.clone().into_iter().next().unwrap()) - } else { - 
OneOrMany::many(chat_history.clone()).expect("chat_history should not be empty") - }, + preamble: Some(system_prompt.clone()), + chat_history: Self::to_request_chat_history(chat_history.clone())?, documents: Vec::new(), tools: tool_definitions.clone(), temperature: Some(self.temperature as f64), @@ -492,7 +554,7 @@ impl LM { } else { None }, - additional_params: None, + additional_params: self.additional_params.clone(), output_schema: None, }; @@ -517,12 +579,26 @@ impl LM { // Scan ALL content blocks in the response — don't just look at .first(). // Responses can be [Reasoning, ToolCall] or [Reasoning, Text]. let mut tool_loop_result = None; - let first_choice = match classify_choice(response.choice) { - ChoiceAction::Text(text) => Message::assistant(&text), + let mut returned_tool_calls = Vec::new(); + let mut assistant_content_for_history: Option> = None; + let mut output_override: Option = None; + let mut append_output_after_history = false; + let classified = classify_choice(response.choice.clone()); + let first_choice = match classified { + ChoiceAction::Text { text, full_content } => { + let content = *full_content; + assistant_content_for_history = Some(content.clone()); + output_override = Some(Message::from(rig::message::Message::Assistant { + id: None, + content, + })); + Message::assistant(&text) + } ChoiceAction::ToolCalls { calls, full_content, - } if !tools.is_empty() => { + assistant_text: _, + } if tool_loop_mode == ToolLoopMode::Auto && !tools.is_empty() => { debug!(count = calls.len(), "entering tool loop"); let result = self .execute_tool_loop( @@ -531,24 +607,65 @@ impl LM { tools, tool_definitions, chat_history, - request_messages.system, + system_prompt.clone(), &mut accumulated_usage, ) .await?; let message = result.message.clone(); tool_loop_result = Some(result); + append_output_after_history = true; message } - ChoiceAction::ToolCalls { calls, .. 
} => { + ChoiceAction::ToolCalls { + calls, + assistant_text, + full_content, + } if tool_loop_mode == ToolLoopMode::Auto && tools.is_empty() => { let names: Vec<_> = calls.iter().map(|tc| tc.function.name.as_str()).collect(); warn!(?names, "tools requested but no tools available"); - let msg = format!("Tool calls requested: {:?}, but no tools available", names); - Message::assistant(&msg) + returned_tool_calls = calls.clone(); + let content = *full_content; + assistant_content_for_history = Some(content.clone()); + output_override = Some(Message::from(rig::message::Message::Assistant { + id: None, + content, + })); + Message::assistant(assistant_text.unwrap_or_default()) + } + ChoiceAction::ToolCalls { + calls, + assistant_text, + full_content, + } => { + returned_tool_calls = calls; + let content = *full_content; + assistant_content_for_history = Some(content.clone()); + output_override = Some(Message::from(rig::message::Message::Assistant { + id: None, + content, + })); + Message::assistant(assistant_text.unwrap_or_default()) } }; - - let mut full_chat = messages.clone(); - full_chat.push_message(first_choice.clone()); + let output = output_override.unwrap_or_else(|| first_choice.clone()); + + let mut full_chat = if let Some(result) = tool_loop_result.as_ref() { + Self::chat_from_rig_history(&system_prompt, &result.chat_history) + } else { + let mut chat = messages.clone(); + if let Some(content) = assistant_content_for_history { + // Convert grouped rig content into a single grouped Message. + let rig_msg = rig::message::Message::Assistant { id: None, content }; + chat.push_message(Message::from(rig_msg)); + } else { + // Text-only path: preserve a single assistant response turn. 
+ chat.push_message(first_choice.clone()); + } + chat + }; + if append_output_after_history { + full_chat.push_message(first_choice.clone()); + } debug!( tool_calls = tool_loop_result .as_ref() @@ -563,13 +680,13 @@ impl LM { ); Ok(LMResponse { - output: first_choice, + output, usage: accumulated_usage, chat: full_chat, tool_calls: tool_loop_result .as_ref() .map(|result| result.tool_calls.clone()) - .unwrap_or_default(), + .unwrap_or(returned_tool_calls), tool_executions: tool_loop_result .map(|result| result.tool_executions) .unwrap_or_default(), @@ -648,9 +765,7 @@ impl DummyLM { prediction: String, ) -> Result { let mut full_chat = messages.clone(); - full_chat.push_message(Message::Assistant { - content: prediction.clone(), - }); + full_chat.push_message(Message::assistant(prediction.clone())); if self.cache && let Some(cache) = self.cache_handler.as_ref() @@ -682,9 +797,7 @@ impl DummyLM { } Ok(LMResponse { - output: Message::Assistant { - content: prediction.clone(), - }, + output: Message::assistant(prediction.clone()), usage: LmUsage::default(), chat: full_chat, tool_calls: Vec::new(), @@ -716,6 +829,10 @@ mod tests { use super::*; use rig::OneOrMany; use rig::completion::AssistantContent; + use rig::completion::ToolDefinition; + use rig::tool::Tool; + use std::sync::Arc; + use std::sync::atomic::{AtomicUsize, Ordering}; fn make_tool_call(name: &str) -> AssistantContent { AssistantContent::tool_call( @@ -737,7 +854,10 @@ mod tests { fn classify_text_only() { let choice = OneOrMany::one(make_text("hello")); match classify_choice(choice) { - ChoiceAction::Text(t) => assert_eq!(t, "hello"), + ChoiceAction::Text { text, full_content } => { + assert_eq!(text, "hello"); + assert_eq!(full_content.iter().count(), 1); + } ChoiceAction::ToolCalls { .. 
} => panic!("expected Text, got ToolCalls"), } } @@ -749,12 +869,14 @@ mod tests { ChoiceAction::ToolCalls { calls, full_content, + assistant_text, } => { assert_eq!(calls.len(), 1); assert_eq!(calls[0].function.name, "search"); assert_eq!(full_content.iter().count(), 1); + assert!(assistant_text.is_none()); } - ChoiceAction::Text(_) => panic!("expected ToolCalls, got Text"), + ChoiceAction::Text { .. } => panic!("expected ToolCalls, got Text"), } } @@ -770,13 +892,15 @@ mod tests { ChoiceAction::ToolCalls { calls, full_content, + assistant_text, } => { assert_eq!(calls.len(), 1); assert_eq!(calls[0].function.name, "search"); // full_content preserves both blocks assert_eq!(full_content.iter().count(), 2); + assert!(assistant_text.is_none()); } - ChoiceAction::Text(_) => panic!("expected ToolCalls, got Text"), + ChoiceAction::Text { .. } => panic!("expected ToolCalls, got Text"), } } @@ -789,7 +913,10 @@ mod tests { .unwrap(); match classify_choice(choice) { - ChoiceAction::Text(t) => assert_eq!(t, "the answer is 42"), + ChoiceAction::Text { text, full_content } => { + assert_eq!(text, "the answer is 42"); + assert_eq!(full_content.iter().count(), 2); + } ChoiceAction::ToolCalls { .. } => panic!("expected Text, got ToolCalls"), } } @@ -798,7 +925,10 @@ mod tests { fn classify_reasoning_only_fallback() { let choice = OneOrMany::one(make_reasoning("just thinking")); match classify_choice(choice) { - ChoiceAction::Text(t) => assert_eq!(t, "just thinking"), + ChoiceAction::Text { text, full_content } => { + assert_eq!(text, "just thinking"); + assert_eq!(full_content.iter().count(), 1); + } ChoiceAction::ToolCalls { .. } => panic!("expected Text, got ToolCalls"), } } @@ -809,11 +939,16 @@ mod tests { OneOrMany::many(vec![make_text("some text"), make_tool_call("search")]).unwrap(); match classify_choice(choice) { - ChoiceAction::ToolCalls { calls, .. } => { + ChoiceAction::ToolCalls { + calls, + assistant_text, + .. 
+ } => { assert_eq!(calls.len(), 1); assert_eq!(calls[0].function.name, "search"); + assert_eq!(assistant_text.as_deref(), Some("some text")); } - ChoiceAction::Text(_) => panic!("expected ToolCalls, got Text"), + ChoiceAction::Text { .. } => panic!("expected ToolCalls, got Text"), } } @@ -830,13 +965,15 @@ mod tests { ChoiceAction::ToolCalls { calls, full_content, + assistant_text, } => { assert_eq!(calls.len(), 2); assert_eq!(calls[0].function.name, "search"); assert_eq!(calls[1].function.name, "calculate"); assert_eq!(full_content.iter().count(), 3); + assert!(assistant_text.is_none()); } - ChoiceAction::Text(_) => panic!("expected ToolCalls, got Text"), + ChoiceAction::Text { .. } => panic!("expected ToolCalls, got Text"), } } @@ -846,8 +983,199 @@ mod tests { rig::completion::message::Image::default(), )); match classify_choice(choice) { - ChoiceAction::Text(t) => assert!(t.is_empty()), + ChoiceAction::Text { text, full_content } => { + assert!(text.is_empty()); + assert_eq!(full_content.iter().count(), 1); + } ChoiceAction::ToolCalls { .. 
} => panic!("expected Text, got ToolCalls"), } } + + #[derive(Clone)] + struct CountingTool { + calls: Arc, + } + + #[derive(Debug)] + struct CountingToolError; + + impl std::fmt::Display for CountingToolError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "counting tool error") + } + } + + impl std::error::Error for CountingToolError {} + + impl Tool for CountingTool { + const NAME: &'static str = "counter"; + type Error = CountingToolError; + type Args = serde_json::Value; + type Output = String; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + ToolDefinition { + name: Self::NAME.to_string(), + description: "counter tool".to_string(), + parameters: serde_json::json!({ + "type": "object", + "additionalProperties": true + }), + } + } + + async fn call(&self, _args: Self::Args) -> Result { + self.calls.fetch_add(1, Ordering::SeqCst); + Ok("counted".to_string()) + } + } + + fn test_lm_with_model(model: TestCompletionModel) -> LM { + LM { + base_url: None, + api_key: None, + model: "openai:gpt-4o-mini".to_string(), + temperature: 0.0, + max_tokens: 128, + max_tool_iterations: 4, + additional_params: None, + anthropic_prompt_caching: false, + cache: false, + cache_handler: None, + client: Some(Arc::new(LMClient::Test(model))), + } + } + + fn test_lm_with_model_and_params(model: TestCompletionModel, additional_params: Value) -> LM { + LM { + additional_params: Some(additional_params), + ..test_lm_with_model(model) + } + } + + #[tokio::test] + async fn call_with_caller_managed_mode_returns_tool_calls_without_executing() { + let model = TestCompletionModel::new([make_tool_call("counter")]); + let lm = test_lm_with_model(model); + + let call_count = Arc::new(AtomicUsize::new(0)); + let tools: Vec> = vec![Arc::new(CountingTool { + calls: Arc::clone(&call_count), + })]; + + let chat = Chat::new(vec![Message::user("Use the counter tool")]); + let response = lm + .call(chat, tools, ToolLoopMode::CallerManaged) + .await 
+ .expect("caller-managed call should succeed"); + + assert_eq!(response.tool_calls.len(), 1); + assert!(response.tool_executions.is_empty()); + assert_eq!(call_count.load(Ordering::SeqCst), 0); + assert!(response.output.has_tool_calls()); + assert!(response.output.content().contains("counter")); + assert_eq!(response.chat.len(), 2); + assert!(response.chat.messages[1].has_tool_calls()); + } + + #[tokio::test] + async fn call_forwards_additional_params_to_completion_request() { + let model = TestCompletionModel::new([make_text("ok")]); + let lm = test_lm_with_model_and_params( + model.clone(), + serde_json::json!({ + "reasoning": { "effort": "high" } + }), + ); + + let chat = Chat::new(vec![Message::user("hello")]); + let _ = lm + .call(chat, Vec::new(), ToolLoopMode::CallerManaged) + .await + .expect("call should succeed"); + + let request = model.last_request().expect("request should be captured"); + assert_eq!( + request.additional_params, + Some(serde_json::json!({ "reasoning": { "effort": "high" } })) + ); + } + + #[tokio::test] + async fn call_default_auto_mode_executes_tool_loop() { + let model = TestCompletionModel::new([make_tool_call("counter"), make_text("done")]); + let lm = test_lm_with_model(model); + + let call_count = Arc::new(AtomicUsize::new(0)); + let tools: Vec> = vec![Arc::new(CountingTool { + calls: Arc::clone(&call_count), + })]; + + let chat = Chat::new(vec![Message::user("Use the counter tool")]); + let response = lm + .call(chat, tools, ToolLoopMode::Auto) + .await + .expect("auto call should succeed"); + + assert_eq!(response.tool_calls.len(), 1); + assert_eq!(response.tool_executions.len(), 1); + assert_eq!(call_count.load(Ordering::SeqCst), 1); + assert_eq!(response.output.content(), "done"); + assert_eq!(response.chat.len(), 4); + assert!(response.chat.messages[1].has_tool_calls()); + assert!(response.chat.messages[2].has_tool_results()); + assert_eq!(response.chat.messages[3].role, Role::Assistant); + } + + #[tokio::test] + async fn 
call_auto_mode_with_no_tools_returns_requested_tool_calls() { + let model = TestCompletionModel::new([make_tool_call("counter")]); + let lm = test_lm_with_model(model); + + let chat = Chat::new(vec![Message::user("Use the counter tool")]); + let response = lm + .call(chat, vec![], ToolLoopMode::Auto) + .await + .expect("auto call should succeed"); + + assert_eq!(response.tool_calls.len(), 1); + assert_eq!(response.tool_calls[0].function.name, "counter"); + assert!(response.tool_executions.is_empty()); + assert!(response.chat.messages.iter().any(Message::has_tool_calls)); + } + + #[test] + fn text_choice_with_reasoning_preserves_grouped_content_for_output_conversion() { + let choice = OneOrMany::many(vec![make_reasoning("thinking"), make_text("done")]).unwrap(); + + let (text, full_content) = match classify_choice(choice) { + ChoiceAction::Text { text, full_content } => (text, full_content), + ChoiceAction::ToolCalls { .. } => panic!("expected Text, got ToolCalls"), + }; + + assert_eq!(text, "done"); + + let output = Message::from(rig::message::Message::Assistant { + id: None, + content: *full_content, + }); + assert!(output.has_reasoning()); + assert_eq!(output.text_content(), "done"); + } + + #[tokio::test] + async fn call_with_only_system_message_returns_error() { + let model = TestCompletionModel::new([make_text("unused")]); + let lm = test_lm_with_model(model); + let chat = Chat::new(vec![Message::system("system only")]); + + let err = lm + .call(chat, vec![], ToolLoopMode::Auto) + .await + .expect_err("system-only chat should fail"); + assert!( + err.to_string() + .contains("chat must contain at least one non-system message") + ); + } } diff --git a/crates/dspy-rs/src/core/mod.rs b/crates/dspy-rs/src/core/mod.rs index 64bf2cb8..a28e3b3f 100644 --- a/crates/dspy-rs/src/core/mod.rs +++ b/crates/dspy-rs/src/core/mod.rs @@ -7,7 +7,8 @@ //! input, returns a predicted output) so that strategies are interchangeable. //! //! 
[`Predicted`] wraps a typed output with [`CallMetadata`] (raw response text, token -//! usage, per-field parse results). The error hierarchy — [`PredictError`], [`ParseError`], +//! usage, per-field parse results) and [`Chat`] (the conversation history from the LM +//! call). The error hierarchy — [`PredictError`], [`ParseError`], //! [`LmError`] — distinguishes LM failures from parse failures so callers can handle //! retries differently. [`LM`] is the language model client itself. //! @@ -30,6 +31,7 @@ mod schema; pub mod settings; pub mod signature; pub mod specials; +mod type_name; pub(crate) use dyn_predictor::*; pub use errors::{ConversionError, ErrorClass, JsonishError, LmError, ParseError, PredictError}; @@ -41,3 +43,4 @@ pub use schema::{FieldMetadataSpec, FieldPath, FieldSchema, InputRenderSpec, Sig pub use settings::*; pub use signature::*; pub use specials::*; +pub(crate) use type_name::{render_type_name_for_prompt_with, simplify_type_token}; diff --git a/crates/dspy-rs/src/core/module.rs b/crates/dspy-rs/src/core/module.rs index 234998e3..feeae5ed 100644 --- a/crates/dspy-rs/src/core/module.rs +++ b/crates/dspy-rs/src/core/module.rs @@ -17,13 +17,15 @@ type IndexedForwardResult = (usize, Result, PredictError>); /// implementors. `call` currently just delegates to `forward` — the split exists so we /// can add hooks or tracing around `call` without breaking module implementations. /// -/// # Two kinds of output data +/// # Three kinds of output data /// /// Every call returns [`Predicted`](crate::Predicted), which carries: /// - **`Output`** — what the LM was asked to produce. Shaped by your signature and any /// augmentations. Accessible directly via `Deref`: `result.answer`, `result.reasoning`. /// - **[`CallMetadata`](crate::CallMetadata)** — what the runtime observed. Token counts, /// raw response, constraint results. Never enters a prompt. Via `result.metadata()`. 
+/// - **[`Chat`](crate::Chat)** — conversation history for the call, including the assistant +/// response turn, so callers can continue multi-turn interactions via `result.chat()`. /// /// This drives the type system: [`ChainOfThought`](crate::ChainOfThought) changes `Output` /// because it modifies the prompt (adds a `reasoning` field). A wrapper like `BestOfN` keeps diff --git a/crates/dspy-rs/src/core/module_ext.rs b/crates/dspy-rs/src/core/module_ext.rs index 7586b203..ed448488 100644 --- a/crates/dspy-rs/src/core/module_ext.rs +++ b/crates/dspy-rs/src/core/module_ext.rs @@ -77,8 +77,8 @@ where async fn forward(&self, input: Self::Input) -> Result, PredictError> { let predicted = self.inner.call(input).await?; - let (output, metadata) = predicted.into_parts(); - Ok(Predicted::new((self.map)(output), metadata)) + let (output, metadata, chat) = predicted.into_parts(); + Ok(Predicted::new((self.map)(output), metadata, chat)) } } @@ -107,8 +107,8 @@ where async fn forward(&self, input: Self::Input) -> Result, PredictError> { let predicted = self.inner.call(input).await?; - let (output, metadata) = predicted.into_parts(); + let (output, metadata, chat) = predicted.into_parts(); let transformed = (self.and_then)(output)?; - Ok(Predicted::new(transformed, metadata)) + Ok(Predicted::new(transformed, metadata, chat)) } } diff --git a/crates/dspy-rs/src/core/predicted.rs b/crates/dspy-rs/src/core/predicted.rs index c40940c2..8056f61e 100644 --- a/crates/dspy-rs/src/core/predicted.rs +++ b/crates/dspy-rs/src/core/predicted.rs @@ -3,7 +3,7 @@ use std::ops::Deref; use indexmap::IndexMap; use rig::message::ToolCall; -use crate::{Flag, LmUsage}; +use crate::{Chat, Flag, LmUsage}; /// Per-field details from parsing an LM response. 
/// @@ -159,22 +159,28 @@ impl CallMetadata { /// let result = Predicted::new( /// QAOutput { answer: "42".into() }, /// CallMetadata::default(), +/// dspy_rs::Chat::new(vec![]), /// ); /// assert_eq!(result.answer, "42"); // output field via Deref /// let _usage = &result.metadata().lm_usage; // runtime info, never in prompts -/// let (output, meta) = result.into_parts(); // decompose for ownership +/// let (output, meta, _chat) = result.into_parts(); // decompose for ownership /// assert_eq!(output.answer, "42"); /// ``` #[derive(Debug, Clone)] pub struct Predicted { output: O, metadata: CallMetadata, + chat: Chat, } impl Predicted { /// Creates a new `Predicted` from an output value and call metadata. - pub fn new(output: O, metadata: CallMetadata) -> Self { - Self { output, metadata } + pub fn new(output: O, metadata: CallMetadata, chat: Chat) -> Self { + Self { + output, + metadata, + chat, + } } /// Returns the call metadata (raw response, token usage, tool calls, field-level details). @@ -182,14 +188,19 @@ impl Predicted { &self.metadata } + /// Returns conversation history associated with this prediction. + pub fn chat(&self) -> &Chat { + &self.chat + } + /// Unwraps the typed output, discarding metadata. pub fn into_inner(self) -> O { self.output } - /// Splits into the typed output and call metadata. - pub fn into_parts(self) -> (O, CallMetadata) { - (self.output, self.metadata) + /// Splits into typed output, call metadata, and conversation history. 
+ pub fn into_parts(self) -> (O, CallMetadata, Chat) { + (self.output, self.metadata, self.chat) } } diff --git a/crates/dspy-rs/src/core/type_name.rs b/crates/dspy-rs/src/core/type_name.rs new file mode 100644 index 00000000..bf52918f --- /dev/null +++ b/crates/dspy-rs/src/core/type_name.rs @@ -0,0 +1,52 @@ +use crate::TypeIR; + +pub(crate) fn simplify_type_token(token: &str) -> String { + token.rsplit("::").next().unwrap_or(token).to_string() +} + +pub(crate) fn simplify_type_name_with( + raw: &str, + mut render_token: impl FnMut(&str) -> String, +) -> String { + let mut result = String::with_capacity(raw.len()); + let mut chars = raw.chars(); + while let Some(ch) = chars.next() { + if ch == '`' { + let mut token = String::new(); + for next in chars.by_ref() { + if next == '`' { + break; + } + token.push(next); + } + result.push_str(&render_token(&token)); + } else { + result.push(ch); + } + } + result +} + +pub(crate) fn render_type_name_for_prompt_with( + type_ir: &TypeIR, + render_token: impl FnMut(&str) -> String, +) -> String { + simplify_type_name_with(&type_ir.diagnostic_repr().to_string(), render_token) + .replace("class ", "") + .replace("enum ", "") + .replace(" | ", " or ") + .trim() + .to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn simplify_type_name_with_rewrites_backtick_tokens() { + let raw = "class `my::pkg::Thing` | `other::Foo`"; + let rendered = simplify_type_name_with(raw, simplify_type_token); + assert_eq!(rendered, "class Thing | Foo"); + } +} diff --git a/crates/dspy-rs/src/lib.rs b/crates/dspy-rs/src/lib.rs index c8e5e2af..8399ccc4 100644 --- a/crates/dspy-rs/src/lib.rs +++ b/crates/dspy-rs/src/lib.rs @@ -41,7 +41,7 @@ //! .build() //! .await //! .unwrap(); -//! dspy_rs::configure(lm, ChatAdapter); +//! dspy_rs::configure(lm, ChatAdapter::new()); //! //! // 2. Pick a strategy //! 
let cot = ChainOfThought::::new(); @@ -131,6 +131,10 @@ pub use bamltype::internal_baml_jinja::types::{OutputFormatContent, RenderOption pub use bamltype::jsonish::deserializer::deserialize_flags::Flag; pub use dsrs_macros::*; pub use facet::Facet; +#[cfg(feature = "rlm")] +pub use modules::rlm::RlmInputFields; +#[cfg(feature = "rlm")] +pub use rlm_derive::{RlmType, rlm_type}; /// Pre-built signature for use in doc examples. Not part of the public API. #[doc(hidden)] @@ -150,6 +154,8 @@ pub mod __macro_support { pub use anyhow; pub use bamltype; pub use indexmap; + #[cfg(feature = "rlm")] + pub use pyo3; pub use schemars; pub use serde; pub use serde_json; diff --git a/crates/dspy-rs/src/modules/mod.rs b/crates/dspy-rs/src/modules/mod.rs index bb78415a..9739773f 100644 --- a/crates/dspy-rs/src/modules/mod.rs +++ b/crates/dspy-rs/src/modules/mod.rs @@ -1,5 +1,9 @@ pub mod chain_of_thought; pub mod react; +#[cfg(feature = "rlm")] +pub mod rlm; pub use chain_of_thought::{ChainOfThought, ChainOfThoughtOutput, Reasoning, WithReasoning}; pub use react::ReAct; +#[cfg(feature = "rlm")] +pub use rlm::Rlm; diff --git a/crates/dspy-rs/src/modules/react.rs b/crates/dspy-rs/src/modules/react.rs index 471d1889..752de6b2 100644 --- a/crates/dspy-rs/src/modules/react.rs +++ b/crates/dspy-rs/src/modules/react.rs @@ -158,7 +158,7 @@ where ReActActionStepInput::new(serialized_input.clone(), trajectory_text.clone()); let action_predicted = self.action.call(action_input).await?; - let (action_output, mut action_metadata) = action_predicted.into_parts(); + let (action_output, mut action_metadata, _action_chat) = action_predicted.into_parts(); tool_calls.append(&mut action_metadata.tool_calls); tool_executions.append(&mut action_metadata.tool_executions); @@ -220,12 +220,16 @@ where let extract_input = ReActExtractStepInput::new(serialized_input, trajectory_text); let extract_predicted = self.extract.call(extract_input).await?; - let (extract_output, mut extract_metadata) = 
extract_predicted.into_parts(); + let (extract_output, mut extract_metadata, extract_chat) = extract_predicted.into_parts(); extract_metadata.tool_calls.extend(tool_calls); extract_metadata.tool_executions.extend(tool_executions); let output: ReActExtractStepOutput = extract_output; - Ok(Predicted::new(output.output, extract_metadata)) + Ok(Predicted::new( + output.output, + extract_metadata, + extract_chat, + )) } } diff --git a/crates/dspy-rs/src/modules/rlm/exec.rs b/crates/dspy-rs/src/modules/rlm/exec.rs new file mode 100644 index 00000000..07d098b8 --- /dev/null +++ b/crates/dspy-rs/src/modules/rlm/exec.rs @@ -0,0 +1,583 @@ +use pyo3::exceptions::PyRuntimeError; +use pyo3::ffi::c_str; +use pyo3::types::{PyAnyMethods, PyDict, PyModule}; +use pyo3::{Py, PyResult, Python}; + +use super::submit::{SUBMIT_STDOUT_ATTR, is_submit_terminated}; + +const NO_OUTPUT_MESSAGE: &str = "(no output - did you forget to print?)"; +const TRACEBACK_ATTR: &str = "__dsrs_traceback__"; +const LLM_BUDGET_EXHAUSTED_PREFIX: &str = "LLM call budget exhausted: requested "; + +static EXEC_HELPER_CODE: &std::ffi::CStr = c_str!( + r#" +import ast +import contextlib +import io +import traceback + + +def dsrs_exec(code, globals_dict, suppress_output): + buffer = io.StringIO() + result = None + with contextlib.redirect_stdout(buffer): + try: + parsed = ast.parse(code, mode="exec") + if suppress_output or not parsed.body: + exec(compile(parsed, "", "exec"), globals_dict, globals_dict) + else: + last = parsed.body[-1] + if isinstance(last, ast.Expr): + body = parsed.body[:-1] + if body: + exec( + compile(ast.Module(body=body, type_ignores=[]), "", "exec"), + globals_dict, + globals_dict, + ) + result = eval( + compile(ast.Expression(last.value), "", "eval"), + globals_dict, + globals_dict, + ) + else: + exec(compile(parsed, "", "exec"), globals_dict, globals_dict) + except BaseException as exc: + try: + setattr(exc, "__dsrs_stdout__", buffer.getvalue()) + except Exception: + pass + try: + 
setattr(exc, "__dsrs_traceback__", traceback.format_exc()) + except Exception: + pass + raise + return buffer.getvalue(), (None if result is None else repr(result)) +"# +); + +pub fn execute_repl_code( + py: Python<'_>, + globals: &Py, + code: &str, + max_output_chars: usize, +) -> Result { + let prepared_code = preprocess_repl_code(code); + let suppress_output = prepared_code.trim_end().ends_with(';'); + + match run_exec( + py, + globals, + &prepared_code, + suppress_output, + max_output_chars, + ) { + Ok(output) => Ok(output), + Err(err) => { + if let Some(repaired_code) = maybe_repair_submit_code(py, &prepared_code, &err) { + match run_exec( + py, + globals, + &repaired_code, + suppress_output, + max_output_chars, + ) { + Ok(output) => return Ok(output), + Err(_repaired_err) => {} + } + } + let stdout = extract_submit_stdout(py, &err).unwrap_or_default(); + let traceback = extract_traceback(py, &err) + .or_else(|| format_python_traceback(py, &err).ok()) + .unwrap_or_else(|| err.to_string()); + if let Some(resource_message) = format_resource_budget_message(&traceback) { + let combined = combine_stdout_and_message(stdout, resource_message); + return Err(truncate_capture_output(&combined, max_output_chars)); + } + let combined = combine_stdout_and_traceback(stdout, traceback); + Err(truncate_capture_output(&combined, max_output_chars)) + } + } +} + +fn maybe_repair_submit_code(py: Python<'_>, code: &str, err: &pyo3::PyErr) -> Option { + if !code.contains("SUBMIT(") { + return None; + } + + let traceback = extract_traceback(py, err).or_else(|| format_python_traceback(py, err).ok())?; + if !traceback.contains("SyntaxError") + || (!traceback.contains("unterminated triple-quoted string literal") + && !traceback.contains("unterminated string literal")) + { + return None; + } + + repair_submit_code(code) +} + +fn repair_submit_code(code: &str) -> Option { + if !code.contains("SUBMIT(") { + return None; + } + + let mut repaired = code.trim_end().to_string(); + let mut 
changed = false; + + for quote in ["\"\"\"", "'''"] { + if repaired.matches(quote).count() % 2 != 0 { + repaired.push_str(quote); + changed = true; + } + } + + if let Some(submit_start) = repaired.rfind("SUBMIT(") { + let tail = &repaired[submit_start..]; + let open_parens = tail.chars().filter(|&c| c == '(').count(); + let close_parens = tail.chars().filter(|&c| c == ')').count(); + if open_parens > close_parens { + repaired.push_str(&")".repeat(open_parens - close_parens)); + changed = true; + } + } + + if changed { Some(repaired) } else { None } +} + +fn preprocess_repl_code(code: &str) -> String { + let without_fences = strip_markdown_fence_lines(code); + strip_leading_non_python_lines(&without_fences) +} + +fn strip_markdown_fence_lines(text: &str) -> String { + text.lines() + .filter(|line| !line.trim_start().starts_with("```")) + .collect::>() + .join("\n") +} + +fn strip_leading_non_python_lines(text: &str) -> String { + let lines = text.lines().collect::>(); + let first_code_index = lines.iter().position(|line| { + let trimmed = line.trim(); + !trimmed.is_empty() && looks_like_python_line(trimmed) + }); + + let selected = match first_code_index { + Some(index) => lines[index..].join("\n"), + None => text.to_string(), + }; + selected.trim_end().to_string() +} + +fn looks_like_python_line(line: &str) -> bool { + let trimmed = line.trim_start(); + if trimmed.is_empty() { + return false; + } + if trimmed.starts_with('#') + || trimmed.starts_with('[') + || trimmed.starts_with('{') + || trimmed.contains('=') + || trimmed.contains('(') + { + return true; + } + + let lower = trimmed.to_ascii_lowercase(); + for prefix in [ + "import ", "from ", "print", "def ", "for ", "if ", "while ", "try:", "with ", "class ", + ] { + if lower.starts_with(prefix) { + return true; + } + } + + false +} + +fn run_exec( + py: Python<'_>, + globals: &Py, + code: &str, + suppress_output: bool, + max_output_chars: usize, +) -> PyResult { + let helper_globals = PyDict::new(py); + py.run( 
+ EXEC_HELPER_CODE, + Some(&helper_globals), + Some(&helper_globals), + )?; + let exec_fn = helper_globals + .get_item("dsrs_exec") + .map_err(|_| PyRuntimeError::new_err("dsrs_exec helper function missing"))?; + let globals = globals.bind(py); + match exec_fn.call1((code, globals, suppress_output)) { + Ok(result) => { + let (stdout, repr) = result.extract::<(String, Option)>()?; + Ok(format_output(stdout, repr, max_output_chars)) + } + Err(err) if is_submit_terminated(&err, py) => { + let stdout = extract_submit_stdout(py, &err).unwrap_or_default(); + Ok(format_output(stdout, None, max_output_chars)) + } + Err(err) => Err(err), + } +} + +fn extract_submit_stdout(py: Python<'_>, err: &pyo3::PyErr) -> Option { + err.value(py) + .getattr(SUBMIT_STDOUT_ATTR) + .ok() + .and_then(|value| value.extract::().ok()) +} + +fn extract_traceback(py: Python<'_>, err: &pyo3::PyErr) -> Option { + err.value(py) + .getattr(TRACEBACK_ATTR) + .ok() + .and_then(|value| value.extract::().ok()) +} + +fn format_python_traceback(py: Python<'_>, err: &pyo3::PyErr) -> PyResult { + let traceback = PyModule::import(py, "traceback")?; + let formatted = traceback.getattr("format_exception")?.call1(( + err.get_type(py), + err.value(py), + err.traceback(py), + ))?; + let parts: Vec = formatted.extract()?; + Ok(parts.join("")) +} + +fn combine_stdout_and_traceback(stdout: String, traceback: String) -> String { + if stdout.is_empty() { + return traceback; + } + if stdout.ends_with('\n') { + format!("{stdout}{traceback}") + } else { + format!("{stdout}\n{traceback}") + } +} + +fn combine_stdout_and_message(stdout: String, message: String) -> String { + if stdout.is_empty() { + return message; + } + if stdout.ends_with('\n') { + format!("{stdout}{message}") + } else { + format!("{stdout}\n{message}") + } +} + +fn format_resource_budget_message(traceback: &str) -> Option { + let after_prefix = traceback.split_once(LLM_BUDGET_EXHAUSTED_PREFIX)?.1; + let (requested_raw, after_requested) = 
after_prefix.split_once(", remaining ")?; + let (remaining_raw, after_remaining) = after_requested.split_once(", max ")?; + let max_raw: String = after_remaining + .chars() + .take_while(|ch| ch.is_ascii_digit()) + .collect(); + if max_raw.is_empty() { + return None; + } + + let requested = requested_raw.trim().parse::().ok()?; + let remaining = remaining_raw.trim().parse::().ok()?; + let max = max_raw.parse::().ok()?; + + Some(format!( + "⛔ RESOURCE: llm_query({requested}) refused — {remaining} of {max} calls remain.\ncode was valid. namespace unchanged." + )) +} + +fn format_output(stdout: String, repr: Option, max_chars: usize) -> String { + let mut output = stdout; + if let Some(repr) = repr { + if !output.is_empty() && !output.ends_with('\n') { + output.push('\n'); + } + output.push_str(&repr); + } + + if output.is_empty() { + output = NO_OUTPUT_MESSAGE.to_string(); + } + + truncate_capture_output(&output, max_chars) +} + +fn truncate_capture_output(text: &str, max_chars: usize) -> String { + if max_chars == 0 { + return String::new(); + } + let total = text.chars().count(); + if total <= max_chars { + return text.to_string(); + } + + let head_len = max_chars / 2; + let tail_len = max_chars.saturating_sub(head_len); + + let head: String = text.chars().take(head_len).collect(); + let tail: String = text.chars().skip(total.saturating_sub(tail_len)).collect(); + let truncation_notice = format!( + "[STDOUT TRUNCATED at {} chars ({} total)]", + format_count(max_chars), + format_count(total) + ); + + format!("{head}\n{tail}\n{truncation_notice}") +} + +fn format_count(value: usize) -> String { + let raw = value.to_string(); + let mut out = String::with_capacity(raw.len() + raw.len() / 3); + for (index, ch) in raw.chars().rev().enumerate() { + if index > 0 && index % 3 == 0 { + out.push(','); + } + out.push(ch); + } + out.chars().rev().collect() +} + +#[cfg(test)] +mod tests { + use pyo3::types::{PyDict, PyDictMethods}; + + use super::*; + use 
crate::modules::rlm::submit::SubmitTerminated; + + #[test] + fn executes_expression_and_returns_repr() { + Python::attach(|py| { + let globals = PyDict::new(py).unbind(); + let output = execute_repl_code(py, &globals, "1 + 2", 100).expect("exec"); + assert_eq!(output, "3"); + }); + } + + #[test] + fn combines_stdout_and_repr() { + Python::attach(|py| { + let globals = PyDict::new(py).unbind(); + let output = execute_repl_code(py, &globals, "print('hi')\n2 + 3", 100).expect("exec"); + assert_eq!(output, "hi\n5"); + }); + } + + #[test] + fn suppresses_output_on_trailing_semicolon() { + Python::attach(|py| { + let globals = PyDict::new(py).unbind(); + let output = execute_repl_code(py, &globals, "2 + 3;", 100).expect("exec"); + assert_eq!(output, NO_OUTPUT_MESSAGE); + }); + } + + #[test] + fn returns_no_output_message_when_no_stdout_or_repr() { + Python::attach(|py| { + let globals = PyDict::new(py).unbind(); + let output = execute_repl_code(py, &globals, "x = 10", 100).expect("exec"); + assert_eq!(output, NO_OUTPUT_MESSAGE); + }); + } + + #[test] + fn truncates_with_head_and_tail() { + Python::attach(|py| { + let globals = PyDict::new(py).unbind(); + let output = execute_repl_code(py, &globals, "print('abcdefghijklmnopqrstuvwxyz')", 10) + .expect("exec"); + assert!(output.contains("[STDOUT TRUNCATED at 10 chars (27 total)]")); + assert!(output.starts_with("abcde")); + assert!(output.contains("wxyz\n")); + assert!(output.ends_with("[STDOUT TRUNCATED at 10 chars (27 total)]")); + }); + } + + #[test] + fn submit_terminated_is_treated_as_success_path() { + Python::attach(|py| { + let globals = PyDict::new(py); + globals + .set_item("SubmitTerminated", py.get_type::()) + .expect("set type"); + let globals = globals.unbind(); + + let output = execute_repl_code( + py, + &globals, + "print('before submit')\nraise SubmitTerminated('done')", + 200, + ) + .expect("exec"); + + assert_eq!(output, "before submit\n"); + }); + } + + #[test] + fn syntax_errors_return_err_string() { + 
Python::attach(|py| { + let globals = PyDict::new(py).unbind(); + let err = execute_repl_code(py, &globals, "if True print('x')", 100) + .expect_err("should fail"); + assert!(err.contains("SyntaxError")); + assert!(err.contains("Traceback")); + }); + } + + #[test] + fn includes_stdout_and_traceback_on_python_errors() { + Python::attach(|py| { + let globals = PyDict::new(py).unbind(); + let err = execute_repl_code( + py, + &globals, + "print('before failure')\nraise ValueError('boom')", + 500, + ) + .expect_err("should fail"); + + assert!(err.contains("before failure")); + assert!(err.contains("Traceback")); + assert!(err.contains("ValueError: boom")); + }); + } + + #[test] + fn import_errors_include_traceback_and_exception_type() { + Python::attach(|py| { + let globals = PyDict::new(py).unbind(); + let err = execute_repl_code(py, &globals, "import definitely_missing_module_xyz", 500) + .expect_err("should fail"); + + assert!(err.contains("Traceback")); + assert!( + err.contains("ModuleNotFoundError") + || err.contains("ImportError") + || err.contains("AttributeError"), + "expected import-related failure class in traceback: {err}" + ); + assert!( + err.contains("definitely_missing_module_xyz") + || err.contains("partially initialized module"), + "expected import target or fallback import error context: {err}" + ); + }); + } + + #[test] + fn truncates_error_output_with_budget() { + Python::attach(|py| { + let globals = PyDict::new(py).unbind(); + let err = execute_repl_code( + py, + &globals, + "print('abcdefghijklmnopqrstuvwxyz')\nraise RuntimeError('abcdefghijklmnopqrstuvwxyz')", + 20, + ) + .expect_err("should fail"); + + assert!(err.contains("[STDOUT TRUNCATED at 20 chars (")); + assert!(err.chars().count() > 20); + }); + } + + #[test] + fn truncation_is_unicode_safe_for_multibyte_characters() { + let text = "😀".repeat(40); + let truncated = truncate_capture_output(&text, 9); + + assert!(truncated.contains("[STDOUT TRUNCATED at 9 chars (40 total)]")); + 
assert!(truncated.is_char_boundary(truncated.len())); + assert!(truncated.contains('😀')); + } + + #[test] + fn format_count_uses_thousands_separators() { + assert_eq!(format_count(0), "0"); + assert_eq!(format_count(10_000), "10,000"); + assert_eq!(format_count(2_345_678), "2,345,678"); + } + + #[test] + fn preprocess_strips_markdown_fences() { + let raw = "```python\nprint('a')\n```\n```py\nprint('b')\n```\n```\nprint('c')\n```"; + let prepared = preprocess_repl_code(raw); + assert_eq!(prepared, "print('a')\nprint('b')\nprint('c')"); + } + + #[test] + fn preprocess_strips_leading_prose_until_python() { + let raw = "Let me start by exploring the data first.\n\n```python\n# First, inspect\na = 1\nprint(a)\n```"; + let prepared = preprocess_repl_code(raw); + assert_eq!(prepared, "# First, inspect\na = 1\nprint(a)"); + } + + #[test] + fn execute_repl_code_handles_failed_turn_one_pattern() { + Python::attach(|py| { + let globals = PyDict::new(py).unbind(); + let raw = "Let me start by exploring the data to understand the structure and then systematically find recurring corrections.\n\n```python\n# First, explore the data structure\nprint('ok')\n```"; + let output = execute_repl_code(py, &globals, raw, 500).expect("exec"); + assert_eq!(output, "ok\n"); + }); + } + + #[test] + fn repair_submit_code_closes_unterminated_triple_quote_and_paren() { + let repaired = repair_submit_code("SUBMIT(direct_answer=\"\"\"hello") + .expect("repair should produce code"); + assert_eq!(repaired, "SUBMIT(direct_answer=\"\"\"hello\"\"\")"); + } + + #[test] + fn execute_repl_code_repairs_unterminated_submit_payload() { + Python::attach(|py| { + let globals = PyDict::new(py); + globals + .set_item("SubmitTerminated", py.get_type::()) + .expect("set type"); + py.run( + c_str!("def SUBMIT(**kwargs):\n raise SubmitTerminated('done')\n"), + Some(&globals), + Some(&globals), + ) + .expect("submit helper"); + let globals = globals.unbind(); + + let output = execute_repl_code(py, &globals, 
"SUBMIT(direct_answer=\"\"\"hello", 500) + .expect("submit should recover"); + assert_eq!(output, NO_OUTPUT_MESSAGE); + }); + } + + #[test] + fn budget_exhaustion_errors_are_rendered_as_resource_messages_without_traceback() { + Python::attach(|py| { + let globals = PyDict::new(py); + py.run( + c_str!( + "def llm_query(prompt):\n raise RuntimeError(\"[Error] RuntimeError: LLM call budget exhausted: requested 1, remaining 0, max 20. This is retryable after reducing llm_query usage.\")\n" + ), + Some(&globals), + Some(&globals), + ) + .expect("define llm_query"); + let globals = globals.unbind(); + + let err = execute_repl_code(py, &globals, "llm_query('hello')", 500) + .expect_err("budget exhaustion should fail"); + assert!(err.contains("⛔ RESOURCE: llm_query(1) refused — 0 of 20 calls remain.")); + assert!(err.contains("code was valid. namespace unchanged.")); + assert!(!err.contains("Traceback")); + }); + } +} diff --git a/crates/dspy-rs/src/modules/rlm/mod.rs b/crates/dspy-rs/src/modules/rlm/mod.rs new file mode 100644 index 00000000..8266f906 --- /dev/null +++ b/crates/dspy-rs/src/modules/rlm/mod.rs @@ -0,0 +1,2085 @@ +use std::collections::{BTreeMap, BTreeSet}; +use std::marker::PhantomData; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; + +use indexmap::IndexMap; +use pyo3::types::{ + PyAnyMethods, PyBool, PyDict, PyDictMethods, PyFloat, PyInt, PyList, PyListMethods, PyModule, + PySet, PyString, PyStringMethods, PyTuple, PyTypeMethods, +}; +use pyo3::{Bound, Py, Python}; +use rig::message::ToolCall; +use tracing::{debug, info, info_span}; + +use crate::{ + BamlType, BamlValue, CallMetadata, Chat, ChatAdapter, Facet, FieldMeta, LmUsage, Module, + Predict, PredictError, Predicted, Role, Signature, +}; + +mod exec; +mod previews; +mod prompt; +mod py_bridge; +pub mod runtime; +mod submit; +mod tools; +use previews::render_previews; +use prompt::{render_action_instruction, render_extract_instruction}; +pub use runtime::{ + DynRuntime, LlmTools, 
PyO3Runtime, RlmInputFields, RlmRuntime, StubRuntime, SubmitError, + SubmitHandler, SubmitResultDyn, SubmitSlot, clear_submit_slot, take_submit_result, +}; +pub use tools::LlmQuery; + +const DEFAULT_MAX_ITERATIONS: usize = 20; +const DEFAULT_MAX_LLM_CALLS: usize = 50; +const DEFAULT_MAX_OUTPUT_CHARS: usize = 10_000; +const DEFAULT_ENABLE_EXTRACTION_FALLBACK: bool = true; +const MAX_RECOVERABLE_PARSE_SNIPPET_CHARS: usize = 80; +const STDOUT_TRUNCATION_NOTICE_PREFIX: &str = "[STDOUT TRUNCATED at "; +const SYNTHETIC_TURN_ZERO_ASSISTANT_CODE: &str = r#"# turn-0 API orientation +if "sessions" in globals() and hasattr(sessions, "items") and len(sessions.items) > 0: + s = sessions.items[0] + print(s.render()[:500]) + msgs = s.thread("darin") + print(f"darin messages: {len(msgs)}") +else: + print("hello world")"#; + +const REPL_HISTORY_INPUT_RENDER_TEMPLATE: &str = r#"{% if this.entries|length == 0 %}(no executed REPL turns captured){% else %}{% for entry in this.entries %}=== Turn {{ entry.turn }} === +Code: +{{ entry.code }} + +Output: +{% if entry.output %}{{ entry.output }}{% else %}{% endif %}{% if not loop.last %} + +{% endif %}{% endfor %}{% endif %}"#; + +#[derive(Signature, Clone, Debug)] +struct RlmActionSig { + #[input] + perception: String, + + #[output] + code: String, +} + +#[derive(Clone, Debug)] +#[BamlType] +struct REPLHistory { + entries: Vec, +} + +#[derive(Clone, Debug)] +#[BamlType] +struct REPLEntry { + turn: u32, + code: String, + output: String, +} + +#[derive(Clone, Debug)] +#[BamlType] +struct RlmExtractInput { + variables_info: String, + repl_history: REPLHistory, +} + +struct RlmExtractSig(PhantomData); + +impl Signature for RlmExtractSig +where + S: Signature, + S::Input: BamlType + for<'a> Facet<'a> + Clone + Send + Sync + RlmInputFields, + S::Output: BamlType + for<'a> Facet<'a> + Clone + Send + Sync, +{ + type Input = RlmExtractInput; + type Output = S::Output; + + fn instruction() -> &'static str { + S::instruction() + } + + fn 
input_shape() -> &'static facet::Shape { + facet::shape_of::() + } + + fn output_shape() -> &'static facet::Shape { + facet::shape_of::() + } + + fn input_field_metadata() -> &'static [crate::FieldMetadataSpec] { + const INPUT_META: [crate::FieldMetadataSpec; 2] = [ + crate::FieldMetadataSpec { + rust_name: "variables_info", + alias: None, + constraints: &[], + input_render: crate::InputRenderSpec::Default, + }, + crate::FieldMetadataSpec { + rust_name: "repl_history", + alias: None, + constraints: &[], + input_render: crate::InputRenderSpec::Jinja(REPL_HISTORY_INPUT_RENDER_TEMPLATE), + }, + ]; + &INPUT_META + } + + fn output_field_metadata() -> &'static [crate::FieldMetadataSpec] { + S::output_field_metadata() + } +} + +#[derive(Debug, Clone, facet::Facet)] +#[facet(crate = facet)] +pub struct RlmConfig { + pub max_iterations: usize, + pub max_llm_calls: usize, + pub max_output_chars: usize, + pub enable_extraction_fallback: bool, +} + +impl Default for RlmConfig { + fn default() -> Self { + Self { + max_iterations: DEFAULT_MAX_ITERATIONS, + max_llm_calls: DEFAULT_MAX_LLM_CALLS, + max_output_chars: DEFAULT_MAX_OUTPUT_CHARS, + enable_extraction_fallback: DEFAULT_ENABLE_EXTRACTION_FALLBACK, + } + } +} + +#[derive(Debug, Clone, Default)] +struct MetadataAcc { + lm_usage: LmUsage, + tool_calls: Vec, + tool_executions: Vec, + raw_responses: Vec, + field_meta: IndexMap, +} + +impl MetadataAcc { + fn absorb_call_metadata(&mut self, metadata: CallMetadata) { + self.lm_usage = std::mem::take(&mut self.lm_usage) + metadata.lm_usage; + self.tool_calls.extend(metadata.tool_calls); + self.tool_executions.extend(metadata.tool_executions); + self.raw_responses.push(metadata.raw_response); + self.field_meta.extend(metadata.field_meta); + } + + fn absorb_parse_metadata(&mut self, raw_response: String, lm_usage: LmUsage) { + self.lm_usage = std::mem::take(&mut self.lm_usage) + lm_usage; + self.raw_responses.push(raw_response); + } + + fn into_call_metadata(self) -> CallMetadata { + 
let raw_response = if self.raw_responses.is_empty() { + String::new() + } else { + self.raw_responses.join("\n\n") + }; + + CallMetadata::new( + raw_response, + self.lm_usage, + self.tool_calls, + self.tool_executions, + None, + self.field_meta, + ) + } +} + +enum ActionTurn { + Parsed(Predicted), + RecoverableParse { + raw_response: String, + lm_usage: LmUsage, + chat: Chat, + reason: String, + }, +} + +enum ExecOutcome { + Continue { + output: String, + }, + SubmitAccepted { + value: BamlValue, + field_meta: IndexMap, + }, + SubmitValidationError { + message: String, + errors: Vec, + raw_output: String, + }, + SubmitAssertionFailed { + label: String, + expression: String, + raw_output: String, + }, + PythonException { + message: String, + }, + RecoverableParse { + message: String, + }, +} + +enum TurnDecision { + Continue, + Finalization, + Fallback, +} + +#[derive(Debug, Clone, Default)] +struct PerceptionFeedback { + stdout: Option, + stderr: Option, + execution_time: Option, +} + +#[derive(Debug, Clone)] +struct PerceptionMessage { + text: String, + namespace_snapshot: BTreeMap, +} + +#[derive(Debug, Clone, Default)] +struct NamespaceSections { + injected: Vec<(String, String)>, + recent: Vec<(String, String)>, + stable: Vec<(String, String)>, + updated_names: Vec, + namespace_snapshot: BTreeMap, + initial_state: bool, +} + +#[derive(Debug, thiserror::Error)] +pub enum RlmError { + #[error("configuration error: {message}")] + Configuration { message: String }, + + #[error("action predict failed")] + ActionPredict { + #[source] + source: PredictError, + }, + + #[error("python execution failed: {message}")] + PythonExec { message: String }, + + #[error("extraction fallback failed")] + ExtractFallback { + #[source] + source: PredictError, + }, + + #[error("max iterations reached ({max})")] + MaxIterationsReached { max: usize }, + + #[error("internal invariant violated: {message}")] + Invariant { message: String }, +} + +impl From for PredictError { + fn 
from(value: RlmError) -> Self { + match value { + RlmError::ActionPredict { source } => source, + RlmError::ExtractFallback { source } => source, + other => PredictError::Module { + module: "Rlm", + source: Box::new(other), + }, + } + } +} + +#[derive(facet::Facet)] +#[facet(crate = facet)] +pub struct Rlm +where + S: Signature, + S::Input: BamlType + for<'a> Facet<'a> + Clone + Send + Sync + RlmInputFields, + S::Output: BamlType + for<'a> Facet<'a> + Clone + Send + Sync, +{ + extract: Predict>, + + #[facet(skip)] + config: RlmConfig, + #[facet(skip)] + instruction_override: Option, + #[facet(skip, opaque)] + sub_lm: Option>, + #[facet(skip, opaque)] + runtime: Arc>, +} + +impl Default for Rlm +where + S: Signature, + S::Input: BamlType + for<'a> Facet<'a> + Clone + Send + Sync + RlmInputFields, + S::Output: BamlType + for<'a> Facet<'a> + Clone + Send + Sync, +{ + fn default() -> Self { + Self::new() + } +} + +impl Rlm +where + S: Signature, + S::Input: BamlType + for<'a> Facet<'a> + Clone + Send + Sync + RlmInputFields, + S::Output: BamlType + for<'a> Facet<'a> + Clone + Send + Sync, +{ + pub fn new() -> Self { + Self::builder().build() + } + + pub fn builder() -> RlmBuilder { + RlmBuilder::new() + } + + pub async fn call(&self, input: S::Input) -> Result, PredictError> { + self.forward(input).await + } + + pub async fn forward(&self, input: S::Input) -> Result, PredictError> { + self.run_loop(&input).await.map_err(Into::into) + } + + async fn run_loop(&self, input: &S::Input) -> Result, RlmError> { + if self.config.max_iterations == 0 { + return Err(RlmError::Configuration { + message: "max_iterations must be >= 1".to_string(), + }); + } + info!( + max_iterations = self.config.max_iterations, + max_llm_calls = self.config.max_llm_calls, + max_output_chars = self.config.max_output_chars, + extraction_fallback = self.config.enable_extraction_fallback, + "rlm run started" + ); + + let submit_slot: SubmitSlot = Arc::new(Mutex::new(None)); + let submit_handler = 
SubmitHandler::new::(Arc::clone(&submit_slot)); + let sub_lm = self.sub_lm.clone().or_else(|| { + let guard = crate::GLOBAL_SETTINGS.read().ok()?; + guard.as_ref().map(|settings| Arc::clone(&settings.lm)) + }); + if self.runtime.requires_sub_lm_tools() && sub_lm.is_none() { + return Err(RlmError::Configuration { + message: "Rlm runtime requires a configured sub-LM (global configure() or builder.sub_lm(...))" + .to_string(), + }); + } + let llm_tools = if self.runtime.requires_sub_lm_tools() { + Some(LlmTools::with_budget( + sub_lm.expect("sub_lm present when required by runtime"), + self.config.max_llm_calls, + tokio::runtime::Handle::try_current().map_err(|err| RlmError::Configuration { + message: format!("Rlm requires an active Tokio runtime handle: {err}"), + })?, + )) + } else { + None + }; + let input_fields = input.rlm_field_names().len(); + let setup = { + let _inject_span = info_span!( + "rlm.inject", + input_fields, + sub_lm_tools = llm_tools.is_some() + ) + .entered(); + Python::attach(|py| { + self.runtime.setup_interpreter_globals( + py, + input, + &submit_handler, + llm_tools.as_ref(), + ) + }) + } + .map_err(|err| RlmError::Configuration { + message: err.to_string(), + })?; + debug!( + input_fields, + injected_objects = setup.methods_by_var.len(), + "interpreter globals injected" + ); + let globals = setup.globals; + + let preview_span = info_span!( + "rlm.preview", + input_fields, + preview_len = tracing::field::Empty + ); + let previews = { + let _preview_span = preview_span.enter(); + render_previews::(input, &setup.methods_by_var, &setup.methods_by_type) + }; + let preview_len = previews.chars().count(); + preview_span.record("preview_len", preview_len); + info!(preview_len, "rlm preview generated"); + + let action_instruction = render_action_instruction::( + &self.config, + self.instruction_override.as_deref(), + &previews, + ); + // TODO(dsrs-rlm): This local Predict is a runtime-workaround so instruction + // composition can include 
runtime-collected method metadata and rendered + // input schemas. Structural fix options: + // 1) public post-build instruction override on Predict, or + // 2) build-time instruction composition using compile-time method metadata. + let generate_action = Predict::::builder() + .instruction(action_instruction.clone()) + .adapter(ChatAdapter::passthrough()) + .build(); + let task_hint = task_hint_from_input::(input).unwrap_or_else(|| { + if let Some(instruction) = self.instruction_override.as_deref() { + instruction.trim().to_string() + } else { + S::instruction().trim().to_string() + } + }); + + let enable_turn_zero_demo = true; + let mut previous_namespace_snapshot: Option> = None; + let mut previous_sub_lm_remaining: Option = None; + let (mut history, mut feedback): (Option, Option) = + if enable_turn_zero_demo { + let initial_sub_lm_remaining = + self.runtime.sub_lm_budget_remaining(llm_tools.as_ref()); + previous_sub_lm_remaining = Some(initial_sub_lm_remaining); + let synthetic_turn_zero_user = build_synthetic_turn_zero_user_message( + self.config.max_iterations, + initial_sub_lm_remaining, + ); + let mut synthetic_history = Chat::new(vec![]); + synthetic_history.push(Role::System, &action_instruction); + synthetic_history.push(Role::User, &synthetic_turn_zero_user); + synthetic_history.push(Role::Assistant, SYNTHETIC_TURN_ZERO_ASSISTANT_CODE); + + clear_submit_slot(&submit_slot); + let synthetic_start = Instant::now(); + let synthetic_feedback = Python::attach(|py| { + self.runtime.execute_repl_code( + py, + &globals, + SYNTHETIC_TURN_ZERO_ASSISTANT_CODE, + self.config.max_output_chars, + ) + }); + let synthetic_execution_time = synthetic_start.elapsed(); + + ( + Some(synthetic_history), + Some(match synthetic_feedback { + Ok(stdout) => PerceptionFeedback { + stdout: Some(stdout), + stderr: None, + execution_time: Some(synthetic_execution_time), + }, + Err(stderr) => PerceptionFeedback { + stdout: None, + stderr: Some(stderr), + execution_time: 
Some(synthetic_execution_time), + }, + }), + ) + } else { + (None, None) + }; + if enable_turn_zero_demo { + previous_namespace_snapshot = Some( + Python::attach(|py| { + collect_namespace_snapshot(py, &globals, input.rlm_field_names()) + .map(|snapshot| namespace_snapshot_map(&snapshot)) + }) + .map_err(|message| RlmError::Configuration { message })?, + ); + } + let mut turn_index = 1usize; + let mut acc = MetadataAcc::default(); + let mut repl_history = REPLHistory { + entries: Vec::new(), + }; + + loop { + let is_first_turn = turn_index == 1; + let _turn_span = info_span!( + "rlm.turn", + iteration = turn_index, + first_turn = is_first_turn + ) + .entered(); + match self.decide_turn_policy(turn_index, self.config.max_iterations) { + TurnDecision::Fallback => { + if self.config.enable_extraction_fallback { + let action_history = history.take(); + return self + .run_extraction_fallback( + &previews, + repl_history, + action_history, + &mut acc, + ) + .await; + } + return Err(RlmError::MaxIterationsReached { + max: self.config.max_iterations, + }); + } + TurnDecision::Continue | TurnDecision::Finalization => {} + } + + let budget_remaining = self + .config + .max_iterations + .saturating_sub(turn_index) + .saturating_add(1); + let sub_lm_remaining = self.runtime.sub_lm_budget_remaining(llm_tools.as_ref()); + let sub_lm_spent_last_turn = + previous_sub_lm_remaining.map(|prev| prev.saturating_sub(sub_lm_remaining)); + let perception = Python::attach(|py| { + build_perception_message::( + py, + &globals, + input, + &task_hint, + feedback.as_ref(), + budget_remaining, + sub_lm_remaining, + is_first_turn, + turn_index, + sub_lm_spent_last_turn, + previous_namespace_snapshot.as_ref(), + ) + }) + .map_err(|message| RlmError::Configuration { message })?; + previous_sub_lm_remaining = Some(sub_lm_remaining); + previous_namespace_snapshot = Some(perception.namespace_snapshot); + let action_input = RlmActionSigInput::new(perception.text); + + info!( + iteration = turn_index, 
+ first_turn = is_first_turn, + budget_remaining, + "running action predict call" + ); + let turn_history = history.take(); + match self + .run_action_turn(&generate_action, action_input, turn_history) + .await? + { + ActionTurn::RecoverableParse { + raw_response, + lm_usage, + chat, + reason, + } => { + debug!( + iteration = turn_index, + response_kind = "error", + error_kind = "recoverable_parse", + "predict response received" + ); + acc.absorb_parse_metadata(raw_response, lm_usage); + history = Some(chat); + feedback = Some(PerceptionFeedback { + stdout: None, + stderr: Some(reason), + execution_time: None, + }); + turn_index += 1; + } + ActionTurn::Parsed(predicted) => { + let (action_output, action_metadata, action_chat) = predicted.into_parts(); + acc.absorb_call_metadata(action_metadata); + history = Some(action_chat); + + let code = action_output.code; + clear_submit_slot(&submit_slot); + + let execution_started = Instant::now(); + let exec_result = Python::attach(|py| { + self.runtime.execute_repl_code( + py, + &globals, + &code, + self.config.max_output_chars, + ) + }); + let execution_time = execution_started.elapsed(); + let submit_result = take_submit_result(&submit_slot); + let outcome = classify_exec_outcome(exec_result, submit_result); + + match outcome { + ExecOutcome::SubmitAccepted { value, field_meta } => { + info!( + iteration = turn_index, + response_kind = "submit", + "predict response received" + ); + let typed_output = + S::Output::try_from_baml_value(value).map_err(|err| { + RlmError::Invariant { + message: format!( + "SUBMIT produced invalid output value: {err}" + ), + } + })?; + acc.field_meta.extend(field_meta); + + let final_chat = history.unwrap_or_else(|| Chat::new(vec![])); + return Ok(Predicted::new( + typed_output, + acc.into_call_metadata(), + final_chat, + )); + } + other => { + debug!( + iteration = turn_index, + response_kind = predict_response_kind_from_outcome(&other), + outcome = exec_outcome_kind(&other), + "predict 
response received" + ); + feedback = Some(perception_feedback_from_outcome( + &other, + Some(execution_time), + )); + repl_history.entries.push(REPLEntry { + turn: turn_index.min(u32::MAX as usize) as u32, + code, + output: outcome_to_raw_output(&other), + }); + turn_index += 1; + } + } + } + } + } + } + + async fn run_action_turn( + &self, + generate_action: &Predict, + action_input: RlmActionSigInput, + history: Option, + ) -> Result { + match generate_action.forward(action_input, history).await { + Ok(predicted) => Ok(ActionTurn::Parsed(predicted)), + Err(error) => match error { + PredictError::Parse { + source, + raw_response, + lm_usage, + chat, + } if raw_response.trim().is_empty() => { + let reason = format_empty_response_recovery_reason(&raw_response, &source); + Ok(ActionTurn::RecoverableParse { + raw_response, + lm_usage, + chat, + reason, + }) + } + other => Err(RlmError::ActionPredict { source: other }), + }, + } + } + + fn decide_turn_policy(&self, turn_index: usize, max_iterations: usize) -> TurnDecision { + if turn_index < max_iterations { + TurnDecision::Continue + } else if turn_index == max_iterations { + TurnDecision::Finalization + } else { + TurnDecision::Fallback + } + } + + async fn run_extraction_fallback( + &self, + previews: &str, + repl_history: REPLHistory, + action_history: Option, + acc: &mut MetadataAcc, + ) -> Result, RlmError> { + let extract_input = RlmExtractInput { + variables_info: previews.to_string(), + repl_history, + }; + let predicted = self + .extract + .forward(extract_input, None) + .await + .map_err(|source| RlmError::ExtractFallback { source })?; + let (output, metadata, extract_chat) = predicted.into_parts(); + acc.absorb_call_metadata(metadata); + let metadata = std::mem::take(acc).into_call_metadata(); + // Preserve action-loop chat when fallback extraction runs so downstream + // transcripts still reflect the REPL interaction that produced the evidence. 
+ // If no action history exists, fall back to the extractor chat. + let final_chat = action_history.unwrap_or(extract_chat); + Ok(Predicted::new(output, metadata, final_chat)) + } +} + +impl Module for Rlm +where + S: Signature, + S::Input: BamlType + for<'a> Facet<'a> + Clone + Send + Sync + RlmInputFields, + S::Output: BamlType + for<'a> Facet<'a> + Clone + Send + Sync, +{ + type Input = S::Input; + type Output = S::Output; + + async fn forward(&self, input: S::Input) -> Result, PredictError> { + Rlm::forward(self, input).await + } +} + +pub struct RlmBuilder +where + S: Signature, + S::Input: BamlType + for<'a> Facet<'a> + Clone + Send + Sync + RlmInputFields, + S::Output: BamlType + for<'a> Facet<'a> + Clone + Send + Sync, +{ + config: RlmConfig, + instruction_override: Option, + sub_lm: Option>, + runtime: Option>>, + _marker: PhantomData, +} + +impl RlmBuilder +where + S: Signature, + S::Input: BamlType + for<'a> Facet<'a> + Clone + Send + Sync + RlmInputFields, + S::Output: BamlType + for<'a> Facet<'a> + Clone + Send + Sync, +{ + fn new() -> Self { + Self { + config: RlmConfig::default(), + instruction_override: None, + sub_lm: None, + runtime: None, + _marker: PhantomData, + } + } + + pub fn max_iterations(mut self, max_iterations: usize) -> Self { + self.config.max_iterations = max_iterations; + self + } + + pub fn max_llm_calls(mut self, max_llm_calls: usize) -> Self { + self.config.max_llm_calls = max_llm_calls; + self + } + + pub fn max_output_chars(mut self, max_output_chars: usize) -> Self { + self.config.max_output_chars = max_output_chars; + self + } + + pub fn enable_extraction_fallback(mut self, enable_extraction_fallback: bool) -> Self { + self.config.enable_extraction_fallback = enable_extraction_fallback; + self + } + + pub fn instruction(mut self, instruction: impl Into) -> Self { + self.instruction_override = Some(instruction.into()); + self + } + + pub fn sub_lm(mut self, sub_lm: Arc) -> Self { + self.sub_lm = Some(sub_lm); + self + } + + 
pub fn runtime(mut self, runtime: Arc>) -> Self { + self.runtime = Some(runtime); + self + } + + pub fn build(self) -> Rlm { + let extract_instruction = + render_extract_instruction::(self.instruction_override.as_deref()); + let extract = Predict::>::builder() + .instruction(extract_instruction) + .adapter(ChatAdapter::new()) + .build(); + + let runtime = self + .runtime + .unwrap_or_else(|| default_runtime::(self.config.max_llm_calls)); + + Rlm { + extract, + config: self.config, + instruction_override: self.instruction_override, + sub_lm: self.sub_lm, + runtime, + } + } +} + +fn default_runtime(max_llm_calls: usize) -> Arc> +where + S::Input: BamlType + for<'a> Facet<'a> + Clone + Send + Sync + RlmInputFields, + S::Output: BamlType + for<'a> Facet<'a> + Clone + Send + Sync, +{ + if let Ok(runtime_override) = std::env::var("DSPY_RS_RLM_RUNTIME") { + match runtime_override.trim().to_ascii_lowercase().as_str() { + "stub" => return Arc::new(StubRuntime::new(max_llm_calls)), + "pyo3" => return Arc::new(PyO3Runtime), + _ => {} + } + } + + #[cfg(test)] + { + Arc::new(StubRuntime::new(max_llm_calls)) + } + #[cfg(not(test))] + { + let _ = max_llm_calls; + Arc::new(PyO3Runtime) + } +} + +fn task_hint_from_input(input: &S::Input) -> Option +where + S: Signature, + S::Input: BamlType, +{ + let value = input.to_baml_value(); + let question = match &value { + BamlValue::Class(_, fields) | BamlValue::Map(fields) => fields.get("question"), + _ => None, + }?; + if let BamlValue::String(text) = question { + let trimmed = text.trim(); + if !trimmed.is_empty() { + return Some(trimmed.to_string()); + } + } + None +} + +fn build_perception_message( + py: Python<'_>, + globals: &Py, + input: &S::Input, + task_hint: &str, + feedback: Option<&PerceptionFeedback>, + budget_remaining: usize, + sub_lm_remaining: usize, + first_turn: bool, + turn_index: usize, + sub_lm_spent_last_turn: Option, + previous_namespace_snapshot: Option<&BTreeMap>, +) -> Result +where + S: Signature, + S::Input: 
BamlType + RlmInputFields, +{ + let namespace = collect_namespace_snapshot(py, globals, input.rlm_field_names())?; + let namespace_sections = partition_namespace_snapshot( + &namespace, + input.rlm_field_names(), + previous_namespace_snapshot, + ); + + let mut lines = Vec::new(); + lines.push(format!("=== Execution Receipt (Turn {turn_index}) ===")); + lines.push(format!( + "Time: {}", + format_execution_time(feedback.and_then(|item| item.execution_time)) + )); + lines.push(format!( + "Budget: {} remaining | {} sub-LLM calls remaining", + turns_label(budget_remaining), + sub_lm_remaining + )); + let sub_lm_cost_line = match sub_lm_spent_last_turn { + Some(spent) => format!( + "Sub-LLM cost: {spent} call{} spent last turn", + plural_suffix(spent) + ), + None => "Sub-LLM cost: n/a (first turn)".to_string(), + }; + lines.push(sub_lm_cost_line); + lines.push(format!( + "Updated: {}", + render_updated_names(&namespace_sections) + )); + + if let Some(feedback) = feedback { + if let Some(stdout) = feedback.stdout.as_deref() + && !stdout.trim().is_empty() + { + lines.push(String::new()); + lines.push("--- stdout ---".to_string()); + append_stdout_lines_with_truncation_hint( + &mut lines, + stdout, + &namespace_sections.updated_names, + ); + lines.push("--------------".to_string()); + } + if let Some(stderr) = feedback.stderr.as_deref() + && !stderr.trim().is_empty() + { + lines.push(String::new()); + lines.push("[stderr]".to_string()); + lines.push(stderr.to_string()); + } + } + + if first_turn { + if !lines.is_empty() { + lines.push(String::new()); + } + lines.push(format!("[query] {}", truncate_chars(task_hint, 180))); + } + + if budget_remaining == 1 { + lines.push(String::new()); + lines.push("⚠ LAST TURN — you MUST call SUBMIT() now with your best answer.".to_string()); + } + + lines.push(String::new()); + lines.push("=== Namespace ===".to_string()); + render_namespace_section(&mut lines, "Injected", &namespace_sections.injected); + render_namespace_section(&mut 
lines, "Recent", &namespace_sections.recent); + render_stable_namespace_summary(&mut lines, namespace_sections.stable.len()); + lines.push(String::new()); + lines.push(render_repl_prompt( + turn_index, + budget_remaining, + sub_lm_remaining, + namespace_sections.namespace_snapshot.len(), + )); + + Ok(PerceptionMessage { + text: lines.join("\n"), + namespace_snapshot: namespace_sections.namespace_snapshot, + }) +} + +fn build_synthetic_turn_zero_user_message( + budget_remaining: usize, + sub_lm_remaining: usize, +) -> String { + [ + "=== Execution Receipt (Turn 0) ===".to_string(), + "Time: n/a".to_string(), + format!( + "Budget: {} remaining | {} sub-LLM calls remaining", + turns_label(budget_remaining), + sub_lm_remaining + ), + "Sub-LLM cost: n/a (synthetic setup turn)".to_string(), + "Updated: (initial state — no prior diff)".to_string(), + String::new(), + "=== Namespace ===".to_string(), + "[Injected]".to_string(), + "(none)".to_string(), + String::new(), + "[Recent]".to_string(), + "(none)".to_string(), + String::new(), + "[Stable] 0 variables".to_string(), + String::new(), + render_repl_prompt(0, budget_remaining, sub_lm_remaining, 0), + ] + .join("\n") +} + +fn partition_namespace_snapshot( + namespace: &[(String, String)], + injected_roots: &[&str], + previous_namespace_snapshot: Option<&BTreeMap>, +) -> NamespaceSections { + let roots = injected_roots + .iter() + .map(|name| (*name).to_string()) + .collect::>(); + let mut sections = NamespaceSections { + initial_state: previous_namespace_snapshot.is_none(), + ..NamespaceSections::default() + }; + + for (name, repr_value) in namespace { + sections + .namespace_snapshot + .insert(name.clone(), repr_value.clone()); + + let changed_since_last_turn = previous_namespace_snapshot + .and_then(|snapshot| snapshot.get(name)) + .map(|previous| previous != repr_value) + .unwrap_or(previous_namespace_snapshot.is_some()); + + if changed_since_last_turn { + sections.updated_names.push(name.clone()); + } + + if 
roots.contains(name) { + sections.injected.push((name.clone(), repr_value.clone())); + } else if previous_namespace_snapshot.is_none() || changed_since_last_turn { + sections.recent.push((name.clone(), repr_value.clone())); + } else { + sections.stable.push((name.clone(), repr_value.clone())); + } + } + + sections +} + +fn namespace_snapshot_map(namespace: &[(String, String)]) -> BTreeMap { + namespace + .iter() + .map(|(name, repr_value)| (name.clone(), repr_value.clone())) + .collect() +} + +fn render_namespace_section(lines: &mut Vec, title: &str, entries: &[(String, String)]) { + lines.push(String::new()); + lines.push(format!("[{title}]")); + if entries.is_empty() { + lines.push("(none)".to_string()); + return; + } + for (name, repr_value) in entries { + lines.push(format!("{name} = {repr_value}")); + } +} + +fn render_stable_namespace_summary(lines: &mut Vec, stable_count: usize) { + lines.push(String::new()); + lines.push(format!( + "[Stable] {stable_count} {}", + if stable_count == 1 { + "variable" + } else { + "variables" + } + )); +} + +fn render_updated_names(sections: &NamespaceSections) -> String { + if sections.initial_state { + return "(initial state — no prior diff)".to_string(); + } + if sections.updated_names.is_empty() { + return "none".to_string(); + } + render_updated_var_names_inline(§ions.updated_names) +} + +fn append_stdout_lines_with_truncation_hint( + lines: &mut Vec, + stdout: &str, + updated_names: &[String], +) { + for line in stdout.lines() { + lines.push(line.to_string()); + if line + .trim_start() + .starts_with(STDOUT_TRUNCATION_NOTICE_PREFIX) + && !updated_names.is_empty() + { + lines.push(format!( + "hint: updated vars this turn: {} — query directly", + render_updated_var_names_inline(updated_names) + )); + } + } +} + +fn render_updated_var_names_inline(updated_names: &[String]) -> String { + updated_names + .iter() + .map(|name| format!("`{name}`")) + .collect::>() + .join(", ") +} + +fn format_execution_time(duration: Option) 
-> String { + duration + .map(|value| format!("{:.1}s", value.as_secs_f64())) + .unwrap_or_else(|| "n/a".to_string()) +} + +fn turns_label(turns: usize) -> String { + if turns == 1 { + "1 turn".to_string() + } else { + format!("{turns} turns") + } +} + +fn render_repl_prompt( + turn_index: usize, + turns_remaining: usize, + sub_lm_remaining: usize, + namespace_var_count: usize, +) -> String { + format!( + "[T{turn_index} | {} | {sub_lm_remaining} llm | {namespace_var_count} vars] >>>", + turns_label(turns_remaining), + ) +} + +fn plural_suffix(count: usize) -> &'static str { + if count == 1 { "" } else { "s" } +} + +fn collect_namespace_snapshot( + py: Python<'_>, + globals: &Py, + injected_roots: &[&str], +) -> Result, String> { + let dict = globals.bind(py); + let roots = injected_roots + .iter() + .map(|name| (*name).to_string()) + .collect::>(); + + let mut out = Vec::new(); + for root in injected_roots { + if let Some(value) = dict + .get_item(*root) + .map_err(|err| format!("failed to fetch root `{root}` from globals: {err}"))? 
+ { + out.push(((*root).to_string(), safe_namespace_repr(&value, true)?)); + } + } + + let mut extras = Vec::new(); + for (name, value) in dict.iter() { + let Ok(name) = name.extract::() else { + continue; + }; + if roots.contains(name.as_str()) { + continue; + } + if !include_in_namespace(name.as_str(), &value, &roots) { + continue; + } + extras.push((name, safe_namespace_repr(&value, false)?)); + } + extras.sort_by(|a, b| a.0.cmp(&b.0)); + out.extend(extras); + + Ok(out) +} + +fn include_in_namespace( + name: &str, + value: &Bound<'_, pyo3::PyAny>, + roots: &BTreeSet, +) -> bool { + if roots.contains(name) { + return true; + } + if name.starts_with('_') { + return false; + } + if name.chars().count() <= 1 { + return false; + } + if value.is_instance_of::() { + return false; + } + if value.is_callable() { + return false; + } + true +} + +fn safe_namespace_repr(value: &Bound<'_, pyo3::PyAny>, is_root: bool) -> Result { + if is_root { + if value.is_instance_of::() { + let len = value.len().unwrap_or_default(); + if len > 5 { + if let Ok(list) = value.cast::() { + let mut preview = Vec::new(); + for item in list.iter().take(2) { + let rendered = sanitize_python_surface(&repr_value(&item)?); + preview.push(truncate_chars(&rendered, 100)); + } + if !preview.is_empty() { + return Ok(format!("[{}, ... 
({} total)]", preview.join(", "), len)); + } + } + return Ok(format!("list({len} items)")); + } + } + return Ok(truncate_chars(&repr_value(value)?, 200)); + } + + if value.is_instance_of::() { + let text = value + .extract::() + .map_err(|err| format!("string extract failed: {err}"))?; + return Ok(format!("{:?}", truncate_chars(&text, 50))); + } + if value.is_instance_of::() + || value.is_instance_of::() + || value.is_instance_of::() + { + return repr_value(value); + } + + if value.is_instance_of::() { + let len = value.len().unwrap_or_default(); + if len <= 5 { + return Ok(truncate_chars( + &sanitize_python_surface(&repr_value(value)?), + 120, + )); + } + return Ok(format!("")); + } + if value.is_instance_of::() { + let len = value.len().unwrap_or_default(); + if len <= 5 { + return Ok(truncate_chars( + &sanitize_python_surface(&repr_value(value)?), + 120, + )); + } + return Ok(format!("")); + } + if value.is_instance_of::() { + let len = value.len().unwrap_or_default(); + if len <= 5 { + return Ok(truncate_chars( + &sanitize_python_surface(&repr_value(value)?), + 120, + )); + } + return Ok(format!("")); + } + if value.is_instance_of::() { + let len = value.len().unwrap_or_default(); + if len <= 5 { + return Ok(truncate_chars( + &sanitize_python_surface(&repr_value(value)?), + 120, + )); + } + return Ok(format!("")); + } + + let class_name = value + .get_type() + .name() + .map(|name| name.to_string_lossy().to_string()) + .unwrap_or_else(|_| "Object".to_string()); + + if let Ok(len) = value.len() { + return Ok(format!("<{class_name}: {len} items>")); + } + + Ok(format!("<{class_name}>")) +} + +fn repr_value(value: &Bound<'_, pyo3::PyAny>) -> Result { + let repr = value + .repr() + .map_err(|err| format!("repr() failed: {err}"))?; + Ok(repr.to_string_lossy().to_string()) +} + +fn sanitize_python_surface(text: &str) -> String { + let mut out = String::with_capacity(text.len()); + let mut token = String::new(); + + let flush = |out: &mut String, token: &mut String| { 
+ if token.is_empty() { + return; + } + if let Some(last) = token.rsplit("::").next() { + out.push_str(last); + } else { + out.push_str(token); + } + token.clear(); + }; + + for ch in text.chars() { + if ch.is_ascii_alphanumeric() || ch == '_' || ch == ':' { + token.push(ch); + } else { + flush(&mut out, &mut token); + out.push(ch); + } + } + flush(&mut out, &mut token); + out +} + +fn perception_feedback_from_outcome( + outcome: &ExecOutcome, + execution_time: Option, +) -> PerceptionFeedback { + match outcome { + ExecOutcome::Continue { output } => PerceptionFeedback { + stdout: (!output.trim().is_empty()).then(|| output.clone()), + stderr: None, + execution_time, + }, + ExecOutcome::SubmitAccepted { .. } => PerceptionFeedback::default(), + ExecOutcome::SubmitValidationError { .. } + | ExecOutcome::SubmitAssertionFailed { .. } + | ExecOutcome::PythonException { .. } + | ExecOutcome::RecoverableParse { .. } => PerceptionFeedback { + stdout: None, + stderr: Some(outcome_to_raw_output(outcome)), + execution_time, + }, + } +} + +fn truncate_chars(text: &str, max_chars: usize) -> String { + if text.chars().count() <= max_chars { + return text.to_string(); + } + let mut out = String::new(); + for ch in text.chars().take(max_chars) { + out.push(ch); + } + out.push_str("..."); + out +} + +#[cfg(test)] +fn recoverable_outcome_from_parse_error(error: &PredictError) -> Option<(String, Chat)> { + match error { + PredictError::Parse { + raw_response, + chat, + source, + .. 
+ } if raw_response.trim().is_empty() => Some(( + format_empty_response_recovery_reason(raw_response, source), + chat.clone(), + )), + _ => None, + } +} + +fn format_empty_response_recovery_reason( + raw_response: &str, + source: &impl std::fmt::Display, +) -> String { + let total_chars = raw_response.chars().count(); + let mut snippet = raw_response + .chars() + .take(MAX_RECOVERABLE_PARSE_SNIPPET_CHARS) + .collect::(); + if total_chars > MAX_RECOVERABLE_PARSE_SNIPPET_CHARS { + snippet.push_str("..."); + } + + format!( + "Empty response from model ({source}). Write executable Python code. Raw response: len={total_chars}, snippet={snippet:?}." + ) +} + +fn classify_exec_outcome( + exec_result: Result, + submit_result: Option, +) -> ExecOutcome { + if let Some(submit_result) = submit_result { + let raw_exec_output = match exec_result { + Ok(output) => output, + Err(message) => message, + }; + return match submit_result { + Ok((value, field_meta)) => ExecOutcome::SubmitAccepted { value, field_meta }, + Err(SubmitError::ValidationError { message, errors }) => { + ExecOutcome::SubmitValidationError { + message, + errors, + raw_output: raw_exec_output, + } + } + Err(SubmitError::AssertionFailed { label, expression }) => { + ExecOutcome::SubmitAssertionFailed { + label, + expression, + raw_output: raw_exec_output, + } + } + }; + } + + match exec_result { + Ok(output) => ExecOutcome::Continue { output }, + Err(message) => ExecOutcome::PythonException { message }, + } +} + +fn predict_response_kind_from_outcome(outcome: &ExecOutcome) -> &'static str { + match outcome { + ExecOutcome::SubmitAccepted { .. } => "submit", + ExecOutcome::Continue { .. } => "code", + ExecOutcome::SubmitValidationError { .. } + | ExecOutcome::SubmitAssertionFailed { .. } + | ExecOutcome::PythonException { .. } + | ExecOutcome::RecoverableParse { .. } => "error", + } +} + +fn exec_outcome_kind(outcome: &ExecOutcome) -> &'static str { + match outcome { + ExecOutcome::Continue { .. 
} => "continue", + ExecOutcome::SubmitAccepted { .. } => "submit_accepted", + ExecOutcome::SubmitValidationError { .. } => "submit_validation_error", + ExecOutcome::SubmitAssertionFailed { .. } => "submit_assertion_failed", + ExecOutcome::PythonException { .. } => "python_exception", + ExecOutcome::RecoverableParse { .. } => "recoverable_parse", + } +} + +fn outcome_to_raw_output(outcome: &ExecOutcome) -> String { + match outcome { + ExecOutcome::Continue { output, .. } => output.clone(), + ExecOutcome::SubmitAccepted { .. } => String::new(), + ExecOutcome::SubmitValidationError { + message, + errors, + raw_output, + } => { + if !raw_output.is_empty() { + return raw_output.clone(); + } + if errors.is_empty() { + message.clone() + } else { + format!("{message}\n{}", errors.join("\n")) + } + } + ExecOutcome::SubmitAssertionFailed { + label, + expression, + raw_output, + } => { + if !raw_output.is_empty() { + return raw_output.clone(); + } + format!("Submit assertion failed: `{label}` ({expression})") + } + ExecOutcome::PythonException { message } => message.clone(), + ExecOutcome::RecoverableParse { message } => message.clone(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ParseError, Signature}; + use pyo3::Python; + use pyo3::types::{PyDict, PyDictMethods, PyModule}; + use std::sync::Arc; + use temp_env::with_var; + + #[derive(Signature, Clone, Debug)] + struct RuntimePolicySig { + #[input] + prompt: String, + #[output] + answer: String, + } + + #[test] + fn default_runtime_in_tests_uses_stub_policy() { + let runtime = default_runtime::(3); + assert!( + !runtime.requires_sub_lm_tools(), + "test default runtime should be StubRuntime without required sub-LM tools" + ); + } + + #[test] + fn default_runtime_override_to_pyo3_is_explicit() { + with_var("DSPY_RS_RLM_RUNTIME", Some("pyo3"), || { + let runtime = default_runtime::(3); + assert!( + runtime.requires_sub_lm_tools(), + "explicit pyo3 override should require sub-LM tools" + ); + }); + } + + 
#[test] + fn default_runtime_override_to_stub_is_explicit() { + with_var("DSPY_RS_RLM_RUNTIME", Some("stub"), || { + let runtime = default_runtime::(3); + assert!( + !runtime.requires_sub_lm_tools(), + "explicit stub override should not require sub-LM tools" + ); + }); + } + + #[test] + fn perception_message_uses_execution_receipt_and_namespace_sections() { + Python::attach(|py| { + let globals = PyDict::new(py); + globals + .set_item("prompt", "Where did signal drop?") + .expect("set prompt"); + globals + .set_item("result_count", 7) + .expect("set result_count"); + globals.set_item("_tmp", 99).expect("set tmp"); + + let input = RuntimePolicySigInput { + prompt: "Where did signal drop?".to_string(), + }; + let message = build_perception_message::( + py, + &globals.unbind(), + &input, + "Inspect trajectories", + None, + 3, + 11, + true, + 1, + None, + None, + ) + .expect("message"); + let message = message.text; + + assert!(message.contains("=== Execution Receipt (Turn 1) ===")); + assert!(message.contains("Budget: 3 turns remaining | 11 sub-LLM calls remaining")); + assert!(message.contains("Sub-LLM cost: n/a (first turn)")); + assert!(message.contains("[query] Inspect trajectories")); + assert!(message.contains("=== Namespace ===")); + assert!(message.contains("[Injected]")); + assert!(message.contains("[Recent]")); + assert!(message.contains("[Stable]")); + assert!(message.contains("prompt =")); + assert!(message.contains("result_count = 7")); + assert!(!message.contains("_tmp =")); + assert!(message.ends_with("[T1 | 3 turns | 11 llm | 2 vars] >>>")); + }); + } + + #[test] + fn perception_message_turn_two_includes_stdout_and_last_turn_warning() { + Python::attach(|py| { + let globals = PyDict::new(py); + globals.set_item("prompt", "x").expect("set prompt"); + let input = RuntimePolicySigInput { + prompt: "x".to_string(), + }; + let feedback = PerceptionFeedback { + stdout: Some("computed summary".to_string()), + stderr: None, + execution_time: 
Some(Duration::from_millis(1250)), + }; + + let message = build_perception_message::( + py, + &globals.unbind(), + &input, + "Inspect trajectories", + Some(&feedback), + 1, + 3, + false, + 2, + Some(2), + None, + ) + .expect("message"); + let message = message.text; + + assert!(message.contains("Time: 1.2s")); + assert!(message.contains("Sub-LLM cost: 2 calls spent last turn")); + assert!(message.contains("--- stdout ---")); + assert!(message.contains("computed summary")); + assert!( + message.contains("⚠ LAST TURN — you MUST call SUBMIT() now with your best answer.") + ); + assert!(!message.contains("[query]")); + }); + } + + #[test] + fn synthetic_turn_zero_user_message_matches_demo_shape() { + let message = build_synthetic_turn_zero_user_message(12, 20); + assert!(message.contains("=== Execution Receipt (Turn 0) ===")); + assert!(message.contains("Budget: 12 turns remaining | 20 sub-LLM calls remaining")); + assert!(message.contains("=== Namespace ===")); + assert!(message.ends_with("[T0 | 12 turns | 20 llm | 0 vars] >>>")); + assert!(!message.contains("[query]")); + } + + #[test] + fn first_turn_with_feedback_places_stdout_before_query() { + Python::attach(|py| { + let globals = PyDict::new(py); + globals.set_item("prompt", "x").expect("set prompt"); + let input = RuntimePolicySigInput { + prompt: "x".to_string(), + }; + let feedback = PerceptionFeedback { + stdout: Some("hello world".to_string()), + stderr: None, + execution_time: Some(Duration::from_millis(100)), + }; + let message = build_perception_message::( + py, + &globals.unbind(), + &input, + "Inspect trajectories", + Some(&feedback), + 12, + 20, + true, + 1, + None, + None, + ) + .expect("message"); + let message = message.text; + let stdout_idx = message.find("--- stdout ---").expect("stdout marker"); + let query_idx = message.find("[query]").expect("query marker"); + assert!(stdout_idx < query_idx, "stdout should appear before query"); + }); + } + + #[test] + fn 
perception_message_adds_truncation_hint_when_vars_updated() { + Python::attach(|py| { + let baseline = PyDict::new(py); + baseline.set_item("prompt", "x").expect("set prompt"); + let baseline = baseline.unbind(); + let previous_snapshot = collect_namespace_snapshot(py, &baseline, &["prompt"]) + .map(|snapshot| namespace_snapshot_map(&snapshot)) + .expect("previous snapshot"); + + let globals = PyDict::new(py); + globals.set_item("prompt", "x").expect("set prompt"); + globals + .set_item("retro_corrections", vec!["a", "b"]) + .expect("set updated var"); + + let input = RuntimePolicySigInput { + prompt: "x".to_string(), + }; + let feedback = PerceptionFeedback { + stdout: Some( + "partial output\n[STDOUT TRUNCATED at 10,000 chars (24,847 total)]".to_string(), + ), + stderr: None, + execution_time: Some(Duration::from_millis(220)), + }; + + let message = build_perception_message::( + py, + &globals.unbind(), + &input, + "Inspect trajectories", + Some(&feedback), + 7, + 0, + false, + 6, + Some(20), + Some(&previous_snapshot), + ) + .expect("message") + .text; + + assert!(message.contains("[STDOUT TRUNCATED at 10,000 chars (24,847 total)]")); + assert!( + message + .contains("hint: updated vars this turn: `retro_corrections` — query directly"), + "{message}" + ); + }); + } + + #[test] + fn perception_message_skips_truncation_hint_when_no_vars_updated() { + Python::attach(|py| { + let globals = PyDict::new(py); + globals.set_item("prompt", "x").expect("set prompt"); + globals.set_item("stable_value", 1).expect("set stable"); + let globals = globals.unbind(); + let previous_snapshot = collect_namespace_snapshot(py, &globals, &["prompt"]) + .map(|snapshot| namespace_snapshot_map(&snapshot)) + .expect("previous snapshot"); + + let input = RuntimePolicySigInput { + prompt: "x".to_string(), + }; + let feedback = PerceptionFeedback { + stdout: Some("[STDOUT TRUNCATED at 10,000 chars (24,847 total)]".to_string()), + stderr: None, + execution_time: 
Some(Duration::from_millis(180)), + }; + + let message = build_perception_message::( + py, + &globals, + &input, + "Inspect trajectories", + Some(&feedback), + 7, + 0, + false, + 6, + Some(0), + Some(&previous_snapshot), + ) + .expect("message") + .text; + + assert!(!message.contains("hint: updated vars this turn:")); + }); + } + + #[test] + fn perception_message_partitions_recent_and_stable_with_snapshot_diff() { + Python::attach(|py| { + let globals = PyDict::new(py); + globals.set_item("prompt", "x").expect("set prompt"); + globals + .set_item("stable_value", 1) + .expect("set stable value"); + let globals = globals.unbind(); + let previous_snapshot = collect_namespace_snapshot(py, &globals, &["prompt"]) + .map(|snapshot| namespace_snapshot_map(&snapshot)) + .expect("previous snapshot"); + + let globals = PyDict::new(py); + globals.set_item("prompt", "x").expect("set prompt"); + globals + .set_item("stable_value", 1) + .expect("set stable value"); + globals + .set_item("recent_value", 2) + .expect("set recent value"); + + let input = RuntimePolicySigInput { + prompt: "x".to_string(), + }; + let message = build_perception_message::( + py, + &globals.unbind(), + &input, + "Inspect trajectories", + None, + 8, + 5, + false, + 4, + Some(1), + Some(&previous_snapshot), + ) + .expect("message") + .text; + + assert!(message.contains("Updated: `recent_value`")); + assert!(message.contains("[Recent]\nrecent_value = 2")); + assert!(message.contains("[Stable] 1 variable")); + }); + } + + #[test] + fn namespace_filtering_excludes_noise_and_keeps_roots() { + Python::attach(|py| { + let globals = PyDict::new(py); + globals + .set_item("prompt", "Where did signal drop?") + .expect("set prompt root"); + globals.set_item("i", 1).expect("set single char"); + globals + .set_item("_scratch", "temp") + .expect("set private name"); + + let json_mod = PyModule::import(py, "json").expect("import json"); + globals + .set_item("json", json_mod) + .expect("set module variable"); + + let 
builtins = PyModule::import(py, "builtins").expect("import builtins"); + let len_fn = builtins.getattr("len").expect("load len"); + globals + .set_item("callable_fn", len_fn) + .expect("set callable variable"); + + globals + .set_item("kept_value", 42) + .expect("set regular value"); + + let input = RuntimePolicySigInput { + prompt: "Where did signal drop?".to_string(), + }; + let message = build_perception_message::( + py, + &globals.unbind(), + &input, + "Inspect trajectories", + None, + 3, + 9, + true, + 1, + None, + None, + ) + .expect("message"); + let message = message.text; + + assert!(message.contains("prompt =")); + assert!(message.contains("kept_value = 42")); + assert!(!message.contains("\ni = ")); + assert!(!message.contains("_scratch = ")); + assert!(!message.contains("json = ")); + assert!(!message.contains("callable_fn = ")); + }); + } + + #[test] + fn sanitize_python_surface_strips_module_paths() { + let rendered = sanitize_python_surface( + "Sessions(items=[tanha::types::Session(id='abc')], kind=tanha::types::Kind::Fast)", + ); + assert!(!rendered.contains("tanha::types::")); + assert!(rendered.contains("Session(id='abc')")); + assert!(rendered.contains("kind=Fast")); + } + + #[test] + fn root_namespace_repr_uses_object_repr_without_custom_heuristics() { + Python::attach(|py| { + let globals = PyDict::new(py); + py.run( + pyo3::ffi::c_str!( + "class Sessions:\n def __repr__(self):\n return 'Sessions(CUSTOM_REPR)'\nsessions = Sessions()\n" + ), + Some(&globals), + Some(&globals), + ) + .expect("python setup"); + let sessions = globals + .get_item("sessions") + .expect("sessions lookup should succeed") + .expect("sessions should exist"); + let rendered = safe_namespace_repr(&sessions, true).expect("repr"); + assert_eq!(rendered, "Sessions(CUSTOM_REPR)"); + }); + } + + #[test] + fn extract_signature_uses_custom_repl_history_render_template() { + let fields = RlmExtractSig::::input_field_metadata(); + assert_eq!(fields.len(), 2); + match 
fields[1].input_render { + crate::InputRenderSpec::Jinja(template) => { + assert!(template.contains("=== Turn {{ entry.turn }} ===")); + assert!(template.contains("Code:")); + assert!(template.contains("Output:")); + } + other => panic!("expected jinja render template, got: {other:?}"), + } + } + + #[test] + fn turn_policy_reserves_last_turn_for_finalization_then_fallback() { + let module = Rlm::::builder().build(); + + assert!(matches!( + module.decide_turn_policy(1, 3), + TurnDecision::Continue + )); + assert!(matches!( + module.decide_turn_policy(2, 3), + TurnDecision::Continue + )); + assert!(matches!( + module.decide_turn_policy(3, 3), + TurnDecision::Finalization + )); + assert!(matches!( + module.decide_turn_policy(4, 3), + TurnDecision::Fallback + )); + } + + #[test] + fn perception_feedback_maps_stdout_and_stderr_honestly() { + let continue_feedback = perception_feedback_from_outcome( + &ExecOutcome::Continue { + output: "ok".to_string(), + }, + Some(Duration::from_secs(2)), + ); + assert_eq!(continue_feedback.stdout.as_deref(), Some("ok")); + assert!(continue_feedback.stderr.is_none()); + assert_eq!( + continue_feedback + .execution_time + .map(|value| value.as_secs()), + Some(2) + ); + + let error_feedback = perception_feedback_from_outcome( + &ExecOutcome::PythonException { + message: "Traceback...".to_string(), + }, + Some(Duration::from_millis(750)), + ); + assert_eq!(error_feedback.stderr.as_deref(), Some("Traceback...")); + assert!(error_feedback.stdout.is_none()); + assert_eq!( + error_feedback + .execution_time + .map(|value| value.as_millis() as u64), + Some(750) + ); + } + + #[test] + fn classify_exec_outcome_covers_all_variants_and_feedback_projection() { + let continue_outcome = classify_exec_outcome(Ok("x\n".into()), None); + assert!(matches!( + continue_outcome, + ExecOutcome::Continue { ref output } if output == "x\n" + )); + assert_eq!(outcome_to_raw_output(&continue_outcome), "x\n"); + + let submit_ok = classify_exec_outcome( + 
Ok(String::new()), + Some(Ok((BamlValue::String("ok".to_string()), IndexMap::new()))), + ); + assert!(matches!(submit_ok, ExecOutcome::SubmitAccepted { .. })); + assert!(outcome_to_raw_output(&submit_ok).is_empty()); + + let submit_validation = classify_exec_outcome( + Err("Traceback...\nSubmitError".to_string()), + Some(Err(SubmitError::ValidationError { + message: "validation failed".to_string(), + errors: vec!["field `answer` expected string".to_string()], + })), + ); + assert!(matches!( + submit_validation, + ExecOutcome::SubmitValidationError { .. } + )); + assert_eq!( + outcome_to_raw_output(&submit_validation), + "Traceback...\nSubmitError" + ); + + let submit_assert = classify_exec_outcome( + Err("SubmitError: Assertion failed".to_string()), + Some(Err(SubmitError::AssertionFailed { + label: "non_empty".to_string(), + expression: "this.len() > 0".to_string(), + })), + ); + assert!(matches!( + submit_assert, + ExecOutcome::SubmitAssertionFailed { .. } + )); + assert_eq!( + outcome_to_raw_output(&submit_assert), + "SubmitError: Assertion failed" + ); + + let python_exception = classify_exec_outcome(Err("Traceback...".into()), None); + assert!(matches!( + python_exception, + ExecOutcome::PythonException { ref message } if message == "Traceback..." + )); + assert_eq!(outcome_to_raw_output(&python_exception), "Traceback..."); + + let recoverable = ExecOutcome::RecoverableParse { + message: "Your response was empty.".to_string(), + }; + assert_eq!( + outcome_to_raw_output(&recoverable), + "Your response was empty." 
+ ); + } + + #[test] + fn recoverable_parse_error_detection_only_triggers_on_empty_response() { + let empty_err = PredictError::Parse { + source: ParseError::ExtractionFailed { + field: "code".to_string(), + raw_response: String::new(), + reason: "empty passthrough response".to_string(), + }, + raw_response: " \n\t".to_string(), + lm_usage: LmUsage::default(), + chat: Chat::new(vec![]), + }; + let recovered = recoverable_outcome_from_parse_error(&empty_err) + .expect("empty response should be recoverable"); + assert!(recovered.0.contains("Empty response from model")); + assert!(recovered.0.contains("Raw response: len=")); + assert!(recovered.0.contains("\\n\\t")); + + let non_empty_err = PredictError::Parse { + source: ParseError::ExtractionFailed { + field: "code".to_string(), + raw_response: "no code".to_string(), + reason: "failed extraction".to_string(), + }, + raw_response: "I refuse".to_string(), + lm_usage: LmUsage::default(), + chat: Chat::new(vec![]), + }; + assert!( + recoverable_outcome_from_parse_error(&non_empty_err).is_none(), + "non-empty parse failures should remain terminal" + ); + } + + #[tokio::test] + async fn pyo3_runtime_requires_sub_lm_when_not_configured() { + let module = Rlm::::builder() + .runtime(Arc::new(PyO3Runtime)) + .build(); + + let err = module + .call(RuntimePolicySigInput { + prompt: "ping".to_string(), + }) + .await + .expect_err("missing sub-LM should fail before first action turn"); + match err { + PredictError::Module { source, .. 
} => { + assert!( + source.to_string().contains("configured sub-LM"), + "expected sub-LM config error, got: {source}" + ); + } + other => panic!("expected module error, got: {other}"), + } + } +} diff --git a/crates/dspy-rs/src/modules/rlm/previews.rs b/crates/dspy-rs/src/modules/rlm/previews.rs new file mode 100644 index 00000000..5a933be6 --- /dev/null +++ b/crates/dspy-rs/src/modules/rlm/previews.rs @@ -0,0 +1,1111 @@ +use std::collections::{BTreeMap, BTreeSet}; + +use bamltype::baml_types::ir_type::{TypeGeneric, UnionTypeViewGeneric}; +use bamltype::baml_types::type_meta::base::TypeMeta; +use bamltype::baml_types::{StreamingMode, TypeIR, TypeValue}; +use bamltype::internal_baml_jinja::types::{Class, OutputFormatContent}; +use tracing::{debug, info_span}; + +use super::runtime::MethodSignature; +use crate::{BamlType, Facet, FieldSchema, Signature, SignatureSchema}; + +#[derive(Clone, Copy)] +struct RenderBudget { + max_methods: usize, + max_depth: usize, +} + +impl RenderBudget { + const fn relaxed() -> Self { + Self { + max_methods: usize::MAX, + max_depth: 12, + } + } +} + +pub(super) fn render_previews( + _input: &S::Input, + methods_by_var: &BTreeMap>, + methods_by_type: &BTreeMap>, +) -> String +where + S::Input: BamlType + for<'a> Facet<'a>, +{ + let schema = SignatureSchema::of::(); + let input_format = ::baml_output_format(); + + let render_span = info_span!( + "rlm.preview.render", + input_fields = schema.input_fields().len(), + method_vars = methods_by_var.len(), + output_len = tracing::field::Empty + ); + let _render_guard = render_span.enter(); + + let budget = RenderBudget::relaxed(); + let rendered = render_with_budget( + schema, + input_format, + methods_by_var, + methods_by_type, + budget, + ); + let output_len = rendered.chars().count(); + debug!( + output_len, + max_methods = budget.max_methods, + max_depth = budget.max_depth, + "preview rendered" + ); + render_span.record("output_len", output_len); + rendered +} + +pub(super) fn 
is_primitive_input_type(type_ir: &TypeIR) -> bool { + let Some(inner) = strip_optional(type_ir) else { + return false; + }; + + matches!( + inner, + TypeGeneric::Primitive(TypeValue::String, _) + | TypeGeneric::Primitive(TypeValue::Int, _) + | TypeGeneric::Primitive(TypeValue::Float, _) + | TypeGeneric::Primitive(TypeValue::Bool, _) + ) +} + +pub(super) fn type_label(type_ir: &TypeIR, output_format: &OutputFormatContent) -> String { + clean_type_expr(type_ir, output_format) +} + +pub(super) fn render_type_shape( + type_ir: &TypeIR, + output_format: &OutputFormatContent, + indent: usize, +) -> Vec { + let mut visited = BTreeSet::new(); + let methods_by_type = BTreeMap::new(); + render_type_node( + type_ir, + output_format, + &methods_by_type, + indent, + 0, + RenderBudget::relaxed().max_depth, + &mut visited, + ) +} + +fn render_with_budget( + schema: &SignatureSchema, + input_format: &OutputFormatContent, + methods_by_var: &BTreeMap>, + methods_by_type: &BTreeMap>, + budget: RenderBudget, +) -> String { + let mut lines = Vec::new(); + let mut rendered_any = false; + + for field in schema.input_fields() { + if is_primitive_input_type(&field.type_ir) { + continue; + } + + rendered_any = true; + lines.extend(render_variable_block( + field, + input_format, + methods_by_var + .get(field.rust_name.as_str()) + .map(Vec::as_slice), + methods_by_type, + budget, + )); + lines.push(String::new()); + } + + if !rendered_any { + lines.push("(No complex input variables.)".to_string()); + } + + while lines.last().is_some_and(String::is_empty) { + lines.pop(); + } + + lines.join("\n") +} + +fn render_variable_block( + field: &FieldSchema, + output_format: &OutputFormatContent, + methods: Option<&[MethodSignature]>, + methods_by_type: &BTreeMap>, + budget: RenderBudget, +) -> Vec { + let mut lines = Vec::new(); + + lines.push(format!( + "Variable: `{}` (access it in your code)", + field.rust_name + )); + lines.push(format!( + "Type: {}", + type_label(&field.type_ir, output_format) + 
)); + + if !field.docs.trim().is_empty() { + lines.push(format!("Description: {}", normalize_doc_text(&field.docs))); + } + + lines.push("Schema:".to_string()); + + let mut visited = BTreeSet::new(); + lines.extend(render_root_schema( + &field.type_ir, + output_format, + methods, + methods_by_type, + 2, + 0, + budget, + &mut visited, + )); + + lines +} + +fn render_root_schema( + type_ir: &TypeIR, + output_format: &OutputFormatContent, + methods: Option<&[MethodSignature]>, + methods_by_type: &BTreeMap>, + indent: usize, + depth: usize, + budget: RenderBudget, + visited: &mut BTreeSet, +) -> Vec { + if let Some((class_name, mode)) = class_type_ref(type_ir) + && let Some(class) = output_format.classes.get(&(class_name.to_string(), mode)) + { + return render_class_block( + class, + output_format, + methods, + methods_by_type, + indent, + depth, + budget, + visited, + ); + } + + render_type_node( + type_ir, + output_format, + methods_by_type, + indent, + depth, + budget.max_depth, + visited, + ) +} + +fn render_class_block( + class: &Class, + output_format: &OutputFormatContent, + methods: Option<&[MethodSignature]>, + methods_by_type: &BTreeMap>, + indent: usize, + depth: usize, + budget: RenderBudget, + visited: &mut BTreeSet, +) -> Vec { + let class_name = class.name.rendered_name().to_string(); + if depth >= budget.max_depth || !visited.insert(class_name.clone()) { + return vec![format!("{}{}", spaces(indent), class_name)]; + } + + let mut lines = Vec::new(); + lines.push(format!("{}{} {{", spaces(indent), class_name)); + + let methods = methods.or_else(|| methods_by_type.get(&class_name).map(Vec::as_slice)); + if let Some(methods) = methods { + let methods = methods + .iter() + .filter(|method| !method.is_dunder) + .take(budget.max_methods) + .collect::>(); + + if !methods.is_empty() { + lines.push(format!("{}// methods", spaces(indent + 2))); + for method in methods { + lines.push(format!( + "{}{}", + spaces(indent + 2), + render_method_line(method) + )); + } + 
} + } + + lines.push(format!("{}// shape", spaces(indent + 2))); + for (field_name, field_type, description, _) in &class.fields { + lines.extend(render_field_line( + field_name.real_name(), + field_type, + description.as_deref(), + output_format, + methods_by_type, + indent + 2, + depth + 1, + budget, + visited, + )); + } + + lines.push(format!("{}}}", spaces(indent))); + lines +} + +fn render_field_line( + field_name: &str, + field_type: &TypeIR, + description: Option<&str>, + output_format: &OutputFormatContent, + methods_by_type: &BTreeMap>, + indent: usize, + depth: usize, + budget: RenderBudget, + visited: &mut BTreeSet, +) -> Vec { + let mut lines = Vec::new(); + let rendered = render_type_node( + field_type, + output_format, + methods_by_type, + indent + 2, + depth, + budget.max_depth, + visited, + ); + + if rendered.len() == 1 { + let mut line = format!( + "{}{}: {}", + spaces(indent), + field_name, + rendered[0].trim_start() + ); + if let Some(description) = description + && !description.trim().is_empty() + { + line.push_str(" // "); + line.push_str(&normalize_doc_text(description)); + } + lines.push(line); + return lines; + } + + let mut first_line = format!( + "{}{}: {}", + spaces(indent), + field_name, + rendered[0].trim_start() + ); + if let Some(description) = description + && !description.trim().is_empty() + { + first_line.push_str(" // "); + first_line.push_str(&normalize_doc_text(description)); + } + lines.push(first_line); + lines.extend(rendered.into_iter().skip(1)); + + lines +} + +fn render_type_node( + type_ir: &TypeIR, + output_format: &OutputFormatContent, + methods_by_type: &BTreeMap>, + indent: usize, + depth: usize, + max_depth: usize, + visited: &mut BTreeSet, +) -> Vec { + if depth >= max_depth { + return vec![format!( + "{}{}", + spaces(indent), + type_label(type_ir, output_format) + )]; + } + + if let Some(optional_inner) = optional_inner(type_ir) + && is_simple_type(optional_inner) + { + return vec![format!( + "{}{} | null", + 
spaces(indent), + type_label(optional_inner, output_format) + )]; + } + + match type_ir { + TypeGeneric::List(inner, _) => { + render_list_node( + inner, + output_format, + methods_by_type, + indent, + depth + 1, + max_depth, + visited, + ) + } + TypeGeneric::Map(key, value, _) => { + let key_name = type_label(key, output_format); + if is_simple_type(value) { + return vec![format!( + "{}map<{}, {}>", + spaces(indent), + key_name, + type_label(value, output_format) + )]; + } + + let mut lines = vec![format!("{}map<{},", spaces(indent), key_name)]; + lines.extend(render_type_node( + value, + output_format, + methods_by_type, + indent + 2, + depth + 1, + max_depth, + visited, + )); + lines.push(format!("{}>", spaces(indent))); + lines + } + TypeGeneric::Class { name, mode, .. } => { + if let Some(class) = output_format.classes.get(&(name.to_string(), *mode)) { + let class_methods = methods_by_type + .get(class.name.rendered_name()) + .map(Vec::as_slice); + render_class_block( + class, + output_format, + class_methods, + methods_by_type, + indent, + depth, + RenderBudget::relaxed(), + visited, + ) + } else { + vec![format!("{}{}", spaces(indent), short_name(name))] + } + } + TypeGeneric::Enum { name, .. } => vec![format!( + "{}{}", + spaces(indent), + enum_name(name, output_format) + )], + TypeGeneric::Union(union, _) => { + render_union_node( + union, + output_format, + methods_by_type, + indent, + depth, + max_depth, + visited, + ) + } + TypeGeneric::RecursiveTypeAlias { name, .. 
} => { + if let Some(alias) = output_format.structural_recursive_aliases.get(name) { + render_type_node( + alias, + output_format, + methods_by_type, + indent, + depth + 1, + max_depth, + visited, + ) + } else { + vec![format!("{}{}", spaces(indent), short_name(name))] + } + } + TypeGeneric::Primitive(value, _) => { + vec![format!("{}{}", spaces(indent), primitive_name(*value))] + } + TypeGeneric::Literal(literal, _) => { + vec![format!("{}{:?}", spaces(indent), literal)] + } + _ => vec![format!( + "{}{}", + spaces(indent), + clean_diagnostic_repr(type_ir) + )], + } +} + +fn render_list_node( + inner: &TypeIR, + output_format: &OutputFormatContent, + methods_by_type: &BTreeMap>, + indent: usize, + depth: usize, + max_depth: usize, + visited: &mut BTreeSet, +) -> Vec { + if is_simple_type(inner) { + return vec![format!( + "{}list[{}]", + spaces(indent), + type_label(inner, output_format) + )]; + } + + let mut lines = vec![format!("{}list[", spaces(indent))]; + lines.extend(render_type_node( + inner, + output_format, + methods_by_type, + indent + 2, + depth, + max_depth, + visited, + )); + lines.push(format!("{}]", spaces(indent))); + lines +} + +fn render_union_node( + union: &bamltype::baml_types::ir_type::UnionTypeGeneric, + output_format: &OutputFormatContent, + methods_by_type: &BTreeMap>, + indent: usize, + depth: usize, + max_depth: usize, + visited: &mut BTreeSet, +) -> Vec { + if let UnionTypeViewGeneric::Optional(inner) = union.view() { + if is_simple_type(inner) { + return vec![format!( + "{}{} | null", + spaces(indent), + type_label(inner, output_format) + )]; + } + } + + let mut lines = vec![format!("{}one of:", spaces(indent))]; + for option in union.iter_include_null() { + let option = unwrap_single_payload_variant_class(option, output_format).unwrap_or(option); + let rendered = render_type_node( + option, + output_format, + methods_by_type, + indent + 4, + depth + 1, + max_depth, + visited, + ); + if rendered.is_empty() { + continue; + } + + 
lines.push(format!( + "{}- {}", + spaces(indent + 2), + rendered[0].trim_start() + )); + for extra in rendered.iter().skip(1) { + lines.push(extra.to_string()); + } + } + + lines +} + +fn unwrap_single_payload_variant_class<'a>( + type_ir: &'a TypeIR, + output_format: &'a OutputFormatContent, +) -> Option<&'a TypeIR> { + let TypeGeneric::Class { name, mode, .. } = type_ir else { + return None; + }; + let class = output_format.classes.get(&(name.to_string(), *mode))?; + if class.fields.len() != 2 { + return None; + } + + let mut literal_count = 0usize; + let mut payload: Option<&TypeIR> = None; + for (_, field_type, _, _) in &class.fields { + if matches!(field_type, TypeGeneric::Literal(..)) { + literal_count += 1; + continue; + } + if payload.is_some() { + return None; + } + payload = Some(field_type); + } + + if literal_count == 1 { payload } else { None } +} + +fn render_method_line(method: &MethodSignature) -> String { + let mut line = format!(".{}{}", method.name, method.signature); + let doc = normalize_doc_text(&method.doc); + if !doc.is_empty() { + line.push_str(" // "); + line.push_str(&doc); + } + line +} + +fn is_simple_type(type_ir: &TypeIR) -> bool { + if let Some(inner) = strip_optional(type_ir) { + return matches!( + inner, + TypeGeneric::Primitive(..) + | TypeGeneric::Enum { .. } + | TypeGeneric::Literal(..) + | TypeGeneric::Top(..) + ); + } + + matches!( + type_ir, + TypeGeneric::Primitive(..) + | TypeGeneric::Enum { .. } + | TypeGeneric::Literal(..) + | TypeGeneric::Top(..) 
+ ) +} + +fn strip_optional(type_ir: &TypeIR) -> Option<&TypeIR> { + match type_ir { + TypeGeneric::Union(union, _) => match union.view() { + UnionTypeViewGeneric::Optional(inner) => Some(inner), + _ => None, + }, + _ => Some(type_ir), + } +} + +fn optional_inner(type_ir: &TypeIR) -> Option<&TypeIR> { + match type_ir { + TypeGeneric::Union(union, _) => match union.view() { + UnionTypeViewGeneric::Optional(inner) => Some(inner), + _ => None, + }, + _ => None, + } +} + +fn class_type_ref(type_ir: &TypeIR) -> Option<(&str, StreamingMode)> { + match type_ir { + TypeGeneric::Class { name, mode, .. } => Some((name.as_str(), *mode)), + TypeGeneric::Union(union, _) => match union.view() { + UnionTypeViewGeneric::Optional(inner) => class_type_ref(inner), + _ => None, + }, + _ => None, + } +} + +fn clean_type_expr(type_ir: &TypeIR, output_format: &OutputFormatContent) -> String { + match type_ir { + TypeGeneric::Primitive(value, _) => primitive_name(*value).to_string(), + TypeGeneric::Class { name, mode, .. } => output_format + .classes + .get(&(name.to_string(), *mode)) + .map(|class| class.name.rendered_name().to_string()) + .unwrap_or_else(|| short_name(name)), + TypeGeneric::Enum { name, .. } => enum_name(name, output_format), + TypeGeneric::List(inner, _) => { + format!("list[{}]", clean_type_expr(inner, output_format)) + } + TypeGeneric::Map(key, value, _) => format!( + "map<{}, {}>", + clean_type_expr(key, output_format), + clean_type_expr(value, output_format) + ), + TypeGeneric::Union(union, _) => { + if let UnionTypeViewGeneric::Optional(inner) = union.view() { + return format!("{} | null", clean_type_expr(inner, output_format)); + } + + let variants = union + .iter_include_null() + .into_iter() + .map(|variant| clean_type_expr(variant, output_format)) + .collect::>(); + variants.join(" | ") + } + TypeGeneric::RecursiveTypeAlias { name, .. 
} => short_name(name), + _ => clean_diagnostic_repr(type_ir), + } +} + +fn clean_diagnostic_repr(type_ir: &TypeIR) -> String { + let mut out = type_ir.diagnostic_repr().to_string(); + out = out.replace("class `", ""); + out = out.replace("enum `", ""); + out = out.replace('`', ""); + for token in ["class ", "enum "] { + out = out.replace(token, ""); + } + short_path_tokens(&out) +} + +fn short_path_tokens(raw: &str) -> String { + let mut out = String::with_capacity(raw.len()); + let mut token = String::new(); + + let flush = |out: &mut String, token: &mut String| { + if token.is_empty() { + return; + } + if token.contains("::") { + if let Some(last) = token.rsplit("::").next() { + out.push_str(last); + } + } else { + out.push_str(token); + } + token.clear(); + }; + + for ch in raw.chars() { + if ch.is_ascii_alphanumeric() || ch == '_' || ch == ':' { + token.push(ch); + } else { + flush(&mut out, &mut token); + out.push(ch); + } + } + flush(&mut out, &mut token); + out +} + +fn enum_name(internal: &str, output_format: &OutputFormatContent) -> String { + output_format + .enums + .get(internal) + .map(|enm| enm.name.rendered_name().to_string()) + .unwrap_or_else(|| short_name(internal)) +} + +fn primitive_name(value: TypeValue) -> &'static str { + match value { + TypeValue::String => "string", + TypeValue::Int => "int", + TypeValue::Float => "float", + TypeValue::Bool => "bool", + TypeValue::Null => "null", + _ => "value", + } +} + +fn short_name(path: &str) -> String { + path.rsplit("::").next().unwrap_or(path).to_string() +} + +fn normalize_doc_text(text: &str) -> String { + text.lines() + .map(str::trim) + .filter(|line| !line.is_empty()) + .collect::>() + .join(" ") +} + +fn spaces(count: usize) -> String { + " ".repeat(count) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::BamlType; + use crate::Signature; + + #[derive(Clone, Debug)] + #[BamlType] + struct PreviewAction { + /// Tool name. + name: String, + /// JSON arguments. 
+ arguments: String, + /// Tool output. + result: Option, + /// True if the tool errored. + is_error: bool, + } + + #[derive(Clone, Debug)] + #[BamlType] + struct PreviewTurn { + /// User message that started this turn. + trigger: Option, + /// Tool actions in this turn. + actions: Vec, + } + + #[derive(Clone, Debug)] + #[BamlType] + struct PreviewSession { + /// First user message, truncated. + brief: Option, + /// Turn sequence. + turns: Vec, + } + + #[derive(Clone, Debug)] + #[BamlType] + struct PreviewSessions { + /// Stored sessions. + items: Vec, + } + + #[derive(Signature, Clone, Debug)] + struct PreviewSig { + #[input] + title: String, + + #[input] + count: i64, + + #[output] + answer: String, + } + + #[test] + fn primitive_inputs_are_skipped() { + let input = PreviewSigInput { + title: "x".to_string(), + count: 3, + }; + let rendered = render_previews::(&input, &BTreeMap::new(), &BTreeMap::new()); + assert!(rendered.contains("(No complex input variables.)")); + } + + #[derive(Signature, Clone, Debug)] + struct RichPreviewSig { + #[input] + /// Turn-level trajectories for each development session. + sessions: PreviewSessions, + + #[output] + answer: String, + } + + #[derive(Clone, Debug)] + #[BamlType] + struct DuplicateActions { + /// Primary action list. + primary: Vec, + /// Backup action list. + backup: Vec, + } + + #[derive(Signature, Clone, Debug)] + struct DedupPreviewSig { + #[input] + /// Two fields that reference the same nested class. 
+ actions: DuplicateActions, + + #[output] + answer: String, + } + + #[test] + fn schema_rendering_has_methods_shape_comments_and_nested_lists() { + let input = RichPreviewSigInput { + sessions: PreviewSessions { + items: vec![PreviewSession { + brief: Some("Investigate signal drop".to_string()), + turns: vec![PreviewTurn { + trigger: Some("start".to_string()), + actions: vec![PreviewAction { + name: "search".to_string(), + arguments: "{\"q\":\"start\"}".to_string(), + result: Some("ok".to_string()), + is_error: false, + }], + }], + }], + }, + }; + let methods = BTreeMap::from([( + "sessions".to_string(), + vec![ + MethodSignature { + name: "search".to_string(), + signature: "(query)".to_string(), + doc: "Find matching sessions.".to_string(), + source: super::super::runtime::MethodSource::Custom, + is_dunder: false, + }, + MethodSignature { + name: "hidden".to_string(), + signature: "()".to_string(), + doc: "".to_string(), + source: super::super::runtime::MethodSource::Custom, + is_dunder: false, + }, + ], + )]); + + let rendered = render_previews::(&input, &methods, &BTreeMap::new()); + assert!(rendered.contains("Variable: `sessions` (access it in your code)")); + assert!(rendered.contains("Type: PreviewSessions")); + assert!( + rendered.contains("Description: Turn-level trajectories for each development session.") + ); + assert!(rendered.contains("// methods")); + assert!(rendered.contains(".search(query) // Find matching sessions.")); + assert!(rendered.contains(".hidden()")); + assert!(!rendered.contains(".hidden() //")); + assert!(rendered.contains("// shape")); + assert!(rendered.contains("items: list[ // Stored sessions.")); + assert!(rendered.contains("brief: string | null // First user message, truncated.")); + assert!(rendered.contains("turns: list[ // Turn sequence.")); + assert!(rendered.contains("PreviewTurn {")); + assert!(rendered.contains("actions: list[ // Tool actions in this turn.")); + assert!(rendered.contains("PreviewAction {")); + 
assert!(rendered.contains("name: string // Tool name.")); + assert!(rendered.contains("arguments: string // JSON arguments.")); + assert!(rendered.contains("result: string | null // Tool output.")); + assert!(rendered.contains("is_error: bool // True if the tool errored.")); + assert!( + rendered.contains("trigger: string | null // User message that started this turn.") + ); + assert!(!rendered.contains("Vec<")); + assert!(!rendered.contains("String")); + assert!(!rendered.contains("i64")); + assert!(!rendered.contains("$self")); + } + + #[test] + fn shared_nested_type_is_rendered_once_then_referenced() { + let input = DedupPreviewSigInput { + actions: DuplicateActions { + primary: vec![PreviewAction { + name: "search".to_string(), + arguments: "{}".to_string(), + result: None, + is_error: false, + }], + backup: vec![PreviewAction { + name: "grep".to_string(), + arguments: "{}".to_string(), + result: None, + is_error: false, + }], + }, + }; + + let rendered = + render_previews::(&input, &BTreeMap::new(), &BTreeMap::new()); + assert_eq!(rendered.matches("PreviewAction {").count(), 1); + assert_eq!(rendered.matches("name: string // Tool name.").count(), 1); + } + + #[derive(Clone, Debug)] + #[BamlType] + enum UnionIndentChoice { + First { data: PreviewSession }, + Second { data: PreviewAction }, + } + + #[derive(Signature, Clone, Debug)] + struct UnionIndentSig { + #[input] + choice: UnionIndentChoice, + + #[output] + answer: String, + } + + #[test] + fn union_option_continuations_preserve_nested_indentation() { + let input = UnionIndentSigInput { + choice: UnionIndentChoice::First { + data: PreviewSession { + brief: Some("Investigate signal drop".to_string()), + turns: vec![PreviewTurn { + trigger: Some("start".to_string()), + actions: vec![PreviewAction { + name: "search".to_string(), + arguments: "{\"q\":\"start\"}".to_string(), + result: Some("ok".to_string()), + is_error: false, + }], + }], + }, + }, + }; + + let rendered = + render_previews::(&input, 
&BTreeMap::new(), &BTreeMap::new()); + let option_indent = rendered + .lines() + .find(|line| line.contains("- PreviewSession {")) + .map(|line| line.chars().take_while(|ch| *ch == ' ').count()) + .expect("union option line present"); + let brief_indent = rendered + .lines() + .find(|line| line.contains("brief: string | null")) + .map(|line| line.chars().take_while(|ch| *ch == ' ').count()) + .expect("nested brief line present"); + assert!( + brief_indent > option_indent, + "nested field indent should be deeper than its union option line" + ); + } + + #[derive(Clone, Debug)] + #[BamlType] + struct WrapA { + value: String, + } + + #[derive(Clone, Debug)] + #[BamlType] + struct WrapB { + value: String, + } + + #[derive(Clone, Debug)] + #[BamlType] + enum WrapperUnion { + First { data: WrapA }, + Second { data: WrapB }, + } + + #[derive(Signature, Clone, Debug)] + struct WrapperUnionSig { + #[input] + item: WrapperUnion, + + #[output] + answer: String, + } + + #[test] + fn single_payload_data_enum_variants_render_as_payload_union_arms() { + let input = WrapperUnionSigInput { + item: WrapperUnion::First { + data: WrapA { + value: "x".to_string(), + }, + }, + }; + let rendered = + render_previews::(&input, &BTreeMap::new(), &BTreeMap::new()); + + assert!(rendered.contains("one of:")); + assert!(rendered.contains("WrapA {")); + assert!(rendered.contains("WrapB {")); + assert!(!rendered.contains("WrapperUnion_First {")); + assert!(!rendered.contains("WrapperUnion_Second {")); + assert!(!rendered.contains("type: String(\"First\")")); + assert!(!rendered.contains("type: String(\"Second\")")); + assert!(!rendered.contains("data: WrapA {")); + assert!(!rendered.contains("data: WrapB {")); + } + + #[test] + fn render_method_line_collapses_multiline_docs_to_single_line() { + let method = MethodSignature { + name: "after".to_string(), + signature: "(date)".to_string(), + doc: "Returns `Sessions`: sessions on or after an ISO date prefix like `2026-02-25`.\n\nReturns a `Sessions` 
sub-collection so calls can be chained.".to_string(), + source: super::super::runtime::MethodSource::Custom, + is_dunder: false, + }; + + let rendered = render_method_line(&method); + assert_eq!( + rendered, + ".after(date) // Returns `Sessions`: sessions on or after an ISO date prefix like `2026-02-25`. Returns a `Sessions` sub-collection so calls can be chained." + ); + } + + #[test] + fn nested_class_methods_render_when_provided_by_type_name() { + let input = RichPreviewSigInput { + sessions: PreviewSessions { + items: vec![PreviewSession { + brief: Some("Investigate signal drop".to_string()), + turns: vec![PreviewTurn { + trigger: Some("start".to_string()), + actions: vec![PreviewAction { + name: "search".to_string(), + arguments: "{\"q\":\"start\"}".to_string(), + result: Some("ok".to_string()), + is_error: false, + }], + }], + }], + }, + }; + let methods_by_type = BTreeMap::from([( + "PreviewSession".to_string(), + vec![MethodSignature { + name: "thread".to_string(), + signature: "(participants)".to_string(), + doc: "Conversation view for selected participants.".to_string(), + source: super::super::runtime::MethodSource::Custom, + is_dunder: false, + }], + )]); + + let rendered = + render_previews::(&input, &BTreeMap::new(), &methods_by_type); + assert!(rendered.contains("PreviewSession {")); + assert!(rendered.contains(".thread(participants) // Conversation view for selected participants.")); + } + + #[test] + fn methods_without_docstrings_are_rendered_without_comment_suffix() { + let input = RichPreviewSigInput { + sessions: PreviewSessions { + items: vec![PreviewSession { + brief: Some("Investigate signal drop".to_string()), + turns: vec![PreviewTurn { + trigger: Some("start".to_string()), + actions: vec![PreviewAction { + name: "search".to_string(), + arguments: "{\"q\":\"start\"}".to_string(), + result: Some("ok".to_string()), + is_error: false, + }], + }], + }], + }, + }; + let methods = BTreeMap::from([( + "sessions".to_string(), + vec![MethodSignature 
{ + name: "undocumented".to_string(), + signature: "()".to_string(), + doc: "".to_string(), + source: super::super::runtime::MethodSource::Custom, + is_dunder: false, + }], + )]); + + let rendered = render_previews::(&input, &methods, &BTreeMap::new()); + assert!(rendered.contains(".undocumented()")); + assert!(!rendered.contains(".undocumented() //")); + } +} diff --git a/crates/dspy-rs/src/modules/rlm/prompt.rs b/crates/dspy-rs/src/modules/rlm/prompt.rs new file mode 100644 index 00000000..e43bcb48 --- /dev/null +++ b/crates/dspy-rs/src/modules/rlm/prompt.rs @@ -0,0 +1,385 @@ +use crate::{ConstraintKind, Signature, SignatureSchema}; +use bamltype::baml_types::ir_type::{TypeGeneric, UnionTypeViewGeneric}; +use bamltype::internal_baml_jinja::types::OutputFormatContent; + +use super::RlmConfig; +use super::previews::{is_primitive_input_type, render_type_shape, type_label}; + +const PATTERNS_BLOCK: &str = r#"## Sub-LLM Patterns + +Use these patterns when direct string operations are not enough. + +# Semantic filter +# Budget-aware: llm_query_batched uses 1 call per item. Slice first! +relevant = [x for x, r in zip(items, llm_query_batched( + [f"Is {x.label} about {topic}? yes/no" for x in items[:5]] +)) if 'yes' in r.lower()] + +# Chain +findings = llm_query_batched([f"Key finding: {x}" for x in relevant]) +answer = llm_query(f"Synthesize:\n" + "\n---\n".join(findings)) + +# Map-reduce +chunks = [text[i:i+2000] for i in range(0, len(text), 2000)] +parts = llm_query_batched([f"Summarize: {c}" for c in chunks]) +summary = llm_query(f"Combine:\n" + "\n".join(parts)) + +# Direct +quick_answer = llm_query(f"Answer directly: {question}") + +# SUBMIT safely for long answers +# Build long strings via variables first to avoid unterminated triple-quote/parens errors. +direct_answer = ( + "Line 1...\n" + "Line 2..." +) +key_findings = ( + "1. First finding...\n" + "2. Second finding..." 
+) + +SUBMIT(direct_answer=direct_answer, key_findings=key_findings)"#; + +pub(super) fn render_action_instruction( + _config: &RlmConfig, + instruction_override: Option<&str>, + variable_schemas: &str, +) -> String { + let schema = SignatureSchema::of::(); + let task = instruction_override + .unwrap_or_else(|| schema.instruction()) + .trim(); + + let mut lines = vec![ + "## Task".to_string(), + task.to_string(), + String::new(), + "## Input Variables".to_string(), + ]; + + if variable_schemas.trim().is_empty() { + lines.push("(No complex input variables.)".to_string()); + } else { + lines.push(variable_schemas.trim().to_string()); + } + + lines.push(String::new()); + lines.push("## Output Schema".to_string()); + lines.push("Call SUBMIT() with the following fields when you have your answer:".to_string()); + lines.push(String::new()); + lines.push("Your output fields are:".to_string()); + + let output_format = schema.output_format(); + for (index, field) in schema.output_fields().iter().enumerate() { + let type_name = type_label(&field.type_ir, output_format); + let mut doc_lines = field.docs.lines().map(str::trim_end).collect::>(); + while doc_lines.first().is_some_and(|line| line.trim().is_empty()) { + doc_lines.remove(0); + } + while doc_lines.last().is_some_and(|line| line.trim().is_empty()) { + doc_lines.pop(); + } + + if let Some(first_doc) = doc_lines.first() { + lines.push(format!( + "{}. `{}` ({}): {}", + index + 1, + field.lm_name, + type_name, + first_doc + )); + for line in doc_lines.iter().skip(1) { + lines.push(format!(" {}", line)); + } + } else { + lines.push(format!( + "{}. 
`{}` ({})", + index + 1, + field.lm_name, + type_name + )); + } + + if let Some(variants) = enum_variants_line(&field.type_ir, output_format) { + lines.push(format!(" Valid values: {variants}")); + } + + if !is_simple_output_type(&field.type_ir) { + lines.push(" Schema:".to_string()); + for line in render_type_shape(&field.type_ir, output_format, 5) { + lines.push(line); + } + } + + lines.push(String::new()); + } + + let submit_assignments = schema + .output_fields() + .iter() + .map(|field| format!("{}=...", field.lm_name)) + .collect::>() + .join(", "); + lines.push(format!("When final, call SUBMIT({submit_assignments}).")); + + lines.push(String::new()); + lines.push("## Available Tools".to_string()); + lines.push("Available in the REPL:".to_string()); + lines.push("- Input variables accessible directly by name".to_string()); + lines.push("- `llm_query(prompt)` — query a sub-LLM (~500K char capacity)".to_string()); + lines.push("- `llm_query_batched(prompts)` — batch query concurrently".to_string()); + lines.push("- `SUBMIT(field1=value1, ...)` — submit final answer".to_string()); + lines.push("- `print()` — ALWAYS print to see results".to_string()); + lines.push("- Standard libraries available (import as needed)".to_string()); + lines.push("Plus any user-provided tools with their descriptions.".to_string()); + + lines.push(String::new()); + lines.push("## Guidelines".to_string()); + lines.push("Response format contract:".to_string()); + lines.push("- Output code only.".to_string()); + lines.push("- No prose.".to_string()); + lines.push("- No markdown fences.".to_string()); + lines.push("- If needed, put reasoning only in Python comments.".to_string()); + lines.push(String::new()); + lines.push("1. EXPLORE FIRST - Look at your data before processing it.".to_string()); + lines.push("2. ITERATE - Write small code snippets, observe, decide next steps.".to_string()); + lines.push("3. 
VERIFY BEFORE SUBMITTING - If results seem wrong, reconsider.".to_string()); + lines.push( + "4. USE llm_query FOR SEMANTICS - String matching finds WHERE; llm_query understands WHAT." + .to_string(), + ); + lines.push( + "5. MINIMIZE RETYPING — keep intermediate results in named variables for reuse." + .to_string(), + ); + lines.push( + "6. SUBMIT ONLY AFTER SEEING OUTPUTS — verify your answer looks right before calling SUBMIT.".to_string(), + ); + + lines.push(String::new()); + lines.push("## Constraints".to_string()); + lines.push("- Soft checks use ⚠. Hard assertions use ❌.".to_string()); + let mut any_constraints = false; + for field in schema.output_fields() { + for constraint in field.constraints { + any_constraints = true; + let marker = match constraint.kind { + ConstraintKind::Check => "⚠ soft", + ConstraintKind::Assert => "❌ hard", + }; + lines.push(format!( + "- `{}`: {marker} - {} ({})", + field.lm_name, constraint.label, constraint.expression + )); + } + } + if !any_constraints { + lines.push("- No explicit soft checks or hard assertions for this signature.".to_string()); + } + + lines.push(String::new()); + lines.push(PATTERNS_BLOCK.to_string()); + + while lines.last().is_some_and(String::is_empty) { + lines.pop(); + } + + lines.join("\n") +} + +pub(super) fn render_extract_instruction( + instruction_override: Option<&str>, +) -> String { + let schema = SignatureSchema::of::(); + let task = instruction_override + .unwrap_or_else(|| schema.instruction()) + .trim(); + + [ + "The following REPL session was generated for this task:", + task, + "", + "Based on the execution history, extract the final outputs. 
Review what was computed and provide the best answer from the trajectory.", + ] + .join("\n") +} + +fn is_simple_output_type(type_ir: &crate::TypeIR) -> bool { + match type_ir { + TypeGeneric::Union(union, _) => match union.view() { + UnionTypeViewGeneric::Optional(inner) => is_simple_output_type(inner), + _ => false, + }, + TypeGeneric::List(inner, _) => is_simple_output_type(inner), + TypeGeneric::Primitive(..) + | TypeGeneric::Enum { .. } + | TypeGeneric::Literal(..) + | TypeGeneric::Top(..) => true, + _ => is_primitive_input_type(type_ir), + } +} + +fn enum_variants_line( + type_ir: &crate::TypeIR, + output_format: &OutputFormatContent, +) -> Option { + let enum_name = match type_ir { + TypeGeneric::Enum { name, .. } => Some(name.as_str()), + TypeGeneric::Union(union, _) => match union.view() { + UnionTypeViewGeneric::Optional(inner) => match inner { + TypeGeneric::Enum { name, .. } => Some(name.as_str()), + _ => None, + }, + _ => None, + }, + _ => None, + }?; + + let enm = output_format.enums.get(enum_name)?; + let variants = enm + .values + .iter() + .map(|(name, _)| name.rendered_name().to_string()) + .collect::>(); + if variants.is_empty() { + None + } else { + Some(variants.join(" | ")) + } +} + +#[cfg(test)] +mod tests { + use crate::BamlType; + use crate::Signature; + + use super::*; + + #[derive(Signature, Clone, Debug)] + /// Solve the query against the corpus. + struct PromptSig { + #[input] + papers: Vec, + + #[input] + question: String, + + #[output] + #[assert("this.len() > 0", label = "non_empty")] + answer: String, + } + + #[derive(Clone, Debug)] + #[BamlType] + enum FailureMode { + Ignorance, + DiscoveryFailure, + } + + #[derive(Signature, Clone, Debug)] + struct OutputFormatSig { + #[input] + question: String, + + #[output] + tags: Vec, + + #[output] + mode: FailureMode, + + #[output] + /// Line one. 
+ /// - top bullet + /// - nested bullet + notes: String, + } + + #[test] + fn includes_new_core_sections() { + let rendered = render_action_instruction::( + &RlmConfig::default(), + None, + "Variable: `papers`", + ); + + assert!(rendered.contains("## Task")); + assert!(rendered.contains("## Input Variables")); + assert!(rendered.contains("## Output Schema")); + assert!(rendered.contains("## Available Tools")); + assert!(rendered.contains("## Guidelines")); + assert!(rendered.contains("## Constraints")); + assert!(rendered.contains("## Sub-LLM Patterns")); + assert!(rendered.contains("No markdown fences")); + assert!(rendered.contains("SUBMIT safely for long answers")); + } + + #[test] + fn system_message_sections_are_in_locked_order() { + let rendered = render_action_instruction::( + &RlmConfig::default(), + None, + "Variable: `papers`", + ); + + let idx_task = rendered.find("## Task").expect("task section"); + let idx_inputs = rendered + .find("## Input Variables") + .expect("input variables section"); + let idx_output = rendered + .find("## Output Schema") + .expect("output schema section"); + let idx_tools = rendered.find("## Available Tools").expect("tools section"); + let idx_guidelines = rendered.find("## Guidelines").expect("guidelines section"); + let idx_constraints = rendered + .find("## Constraints") + .expect("constraints section"); + let idx_patterns = rendered + .find("## Sub-LLM Patterns") + .expect("patterns section"); + + assert!(idx_task < idx_inputs); + assert!(idx_inputs < idx_output); + assert!(idx_output < idx_tools); + assert!(idx_tools < idx_guidelines); + assert!(idx_guidelines < idx_constraints); + assert!(idx_constraints < idx_patterns); + } + + #[test] + fn extract_instruction_includes_task_and_extraction_guidance() { + let rendered = render_extract_instruction::(None); + + assert!(rendered.contains("The following REPL session was generated for this task:")); + assert!(rendered.contains("Solve the query against the corpus.")); + 
assert!(rendered.contains("Based on the execution history, extract the final outputs.")); + } + + #[test] + fn output_section_skips_schema_for_simple_list_and_enum_and_shows_enum_values() { + let rendered = + render_action_instruction::(&RlmConfig::default(), None, ""); + + let tags_block = rendered + .split("1. `tags` (list[string])") + .nth(1) + .and_then(|tail| tail.split("2. `mode` (FailureMode)").next()) + .expect("tags block"); + assert!(!tags_block.contains("Schema:")); + + let mode_block = rendered + .split("2. `mode` (FailureMode)") + .nth(1) + .and_then(|tail| tail.split("3. `notes` (string)").next()) + .expect("mode block"); + assert!(!mode_block.contains("Schema:")); + assert!(mode_block.contains("Valid values: Ignorance | DiscoveryFailure")); + } + + #[test] + fn output_docstrings_preserve_leading_whitespace() { + let rendered = + render_action_instruction::(&RlmConfig::default(), None, ""); + assert!(rendered.contains(" - top bullet")); + assert!(rendered.contains("- nested bullet")); + } +} diff --git a/crates/dspy-rs/src/modules/rlm/py_bridge.rs b/crates/dspy-rs/src/modules/rlm/py_bridge.rs new file mode 100644 index 00000000..b91bbf77 --- /dev/null +++ b/crates/dspy-rs/src/modules/rlm/py_bridge.rs @@ -0,0 +1,2078 @@ +use std::collections::{BTreeMap, BTreeSet}; + +use anyhow::anyhow; +use bamltype::BamlParseError; +use bamltype::baml_types::ir_type::UnionTypeViewGeneric; +use bamltype::baml_types::{BamlMap, BamlValue, LiteralValue, StreamingMode, TypeIR, TypeValue}; +use bamltype::internal_baml_jinja::types::{Class, OutputFormatContent}; +use bamltype::jsonish; +use bamltype::jsonish::deserializer::coercer::run_user_checks; +use pyo3::IntoPyObjectExt; +use pyo3::types::{ + PyAnyMethods, PyBool, PyDict, PyDictMethods, PyFloat, PyInt, PyList, PyListMethods, PyModule, + PyString, PyTuple, PyTupleMethods, PyTypeMethods, +}; +use pyo3::{Bound, Py, PyAny, PyResult, Python}; +use serde_json::Value as JsonValue; + +use super::runtime::{InterpreterSetup, 
MethodSignature, MethodSource, RlmInputFields}; +use super::submit::SubmitHandler; +use super::tools::LlmTools; +use crate::{BamlConvertError, BamlType, ConstraintLevel, ResponseCheck, Signature}; + +const RESERVED_GLOBAL_NAMES: [&str; 3] = ["llm_query", "llm_query_batched", "SUBMIT"]; +const MAX_METHOD_COLLECTION_DEPTH: usize = 8; +const MAX_METHOD_COLLECTION_ITEMS: usize = 12; + +pub fn setup_interpreter_globals( + py: Python<'_>, + input: &S::Input, + submit_handler: &SubmitHandler, + llm_tools: Option<&LlmTools>, +) -> PyResult +where + S::Input: RlmInputFields, +{ + let globals = PyDict::new(py); + + if let Some(name) = input + .rlm_field_names() + .iter() + .copied() + .find(|name| RESERVED_GLOBAL_NAMES.contains(name)) + { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "RLM input field '{name}' conflicts with reserved runtime binding. Rename this field (reserved names: {}).", + RESERVED_GLOBAL_NAMES.join(", ") + ))); + } + input.inject_into_python(py, &globals)?; + let input_format = ::baml_output_format(); + let (methods_by_var, methods_by_type) = + collect_methods_by_var(py, &globals, input.rlm_field_names(), input_format)?; + + if let Some(llm_tools) = llm_tools { + let tools_py = Py::new(py, llm_tools.clone())?; + let tools_bound = tools_py.bind(py); + globals.set_item("llm_query", tools_bound.getattr("llm_query")?)?; + globals.set_item( + "llm_query_batched", + tools_bound.getattr("llm_query_batched")?, + )?; + } + globals.set_item("SUBMIT", Py::new(py, submit_handler.clone())?)?; + + Ok(InterpreterSetup { + globals: globals.unbind(), + methods_by_var, + methods_by_type, + }) +} + +fn collect_methods_by_var( + py: Python<'_>, + globals: &Bound<'_, PyDict>, + field_names: &[&str], + output_format: &OutputFormatContent, +) -> PyResult<(BTreeMap>, BTreeMap>)> { + let inspect = PyModule::import(py, "inspect")?; + let mut methods_by_var = BTreeMap::new(); + let mut methods_by_type = BTreeMap::new(); + let mut observed_classes_by_name = 
BTreeMap::new(); + let mut observed_instances_by_name = BTreeMap::new(); + let mut candidate_modules = BTreeSet::new(); + let mut visited_object_ids = BTreeSet::new(); + let mut visited_type_names = BTreeSet::new(); + + for field_name in field_names { + let Some(value) = globals.get_item(field_name)? else { + continue; + }; + let methods = collect_visible_methods_for_object(&inspect, &value)?; + methods_by_var.insert((*field_name).to_string(), methods); + collect_methods_for_reachable_types( + &inspect, + &value, + &mut methods_by_type, + &mut observed_classes_by_name, + &mut observed_instances_by_name, + &mut candidate_modules, + &mut visited_object_ids, + &mut visited_type_names, + 0, + )?; + } + + collect_methods_for_schema_types( + py, + &inspect, + output_format, + &mut methods_by_type, + &observed_classes_by_name, + &observed_instances_by_name, + &candidate_modules, + )?; + + Ok((methods_by_var, methods_by_type)) +} + +fn collect_visible_methods_for_object( + inspect: &Bound<'_, PyModule>, + value: &Bound<'_, PyAny>, +) -> PyResult> { + if value.is_instance_of::() + || value.is_instance_of::() + || value.is_instance_of::() + || value.is_instance_of::() + || value.is_instance_of::() + || value.is_instance_of::() + || value.is_instance_of::() + { + return Ok(Vec::new()); + } + + let class = value.get_type(); + collect_visible_methods_for_class(inspect, class.as_any()) +} + +fn collect_visible_methods_for_class( + inspect: &Bound<'_, PyModule>, + class: &Bound<'_, PyAny>, +) -> PyResult> { + let members = inspect.call_method1("getmembers", (&class, inspect.getattr("isroutine")?))?; + let members = members.cast::()?; + let mut methods = Vec::new(); + + for member in members.iter() { + let tuple = member.cast::()?; + if tuple.len() != 2 { + continue; + } + let name = tuple.get_item(0)?.extract::()?; + let is_dunder = name.starts_with("__") && name.ends_with("__"); + if name == "__baml__" + || (is_dunder && !matches!(name.as_str(), "__len__" | "__iter__" | 
"__getitem__")) + { + continue; + } + + let callable = tuple.get_item(1)?; + let doc = extract_trimmed_docstring(&callable)?; + + methods.push(MethodSignature { + signature: sanitize_signature( + &extract_signature(inspect, &callable).unwrap_or_else(|| "()".to_string()), + ), + source: classify_method_source(&name), + name, + doc, + is_dunder, + }); + } + + methods.sort_by(|a, b| { + a.name + .cmp(&b.name) + .then(a.signature.cmp(&b.signature)) + .then(a.doc.cmp(&b.doc)) + }); + methods.dedup_by(|a, b| a.name == b.name && a.signature == b.signature); + Ok(methods) +} + +fn collect_methods_for_reachable_types( + inspect: &Bound<'_, PyModule>, + value: &Bound<'_, PyAny>, + methods_by_type: &mut BTreeMap>, + observed_classes_by_name: &mut BTreeMap>, + observed_instances_by_name: &mut BTreeMap>, + candidate_modules: &mut BTreeSet, + visited_object_ids: &mut BTreeSet, + visited_type_names: &mut BTreeSet, + depth: usize, +) -> PyResult<()> { + if depth > MAX_METHOD_COLLECTION_DEPTH { + return Ok(()); + } + + let object_id = value.as_ptr() as usize; + if !visited_object_ids.insert(object_id) { + return Ok(()); + } + + let class = value.get_type(); + let class_name = class + .name() + .ok() + .and_then(|name| name.extract::().ok()) + .unwrap_or_else(|| "".to_string()); + if visited_type_names.insert(class_name.clone()) { + let methods = collect_visible_methods_for_class(inspect, class.as_any())?; + methods_by_type.insert(class_name.clone(), methods); + } + if let Ok(module_name) = class + .getattr("__module__") + .and_then(|name| name.extract::()) + { + candidate_modules.insert(module_name); + } + if let Ok(py_name) = class.name().and_then(|name| name.extract::()) { + observed_classes_by_name + .entry(py_name) + .or_insert_with(|| class.as_any().clone().unbind()); + } + observed_instances_by_name + .entry(class_name) + .or_insert_with(|| value.clone().unbind()); + + if value.is_instance_of::() + || value.is_instance_of::() + || value.is_instance_of::() + || 
value.is_instance_of::() + { + return Ok(()); + } + + if let Ok(list) = value.cast::() { + for item in list.iter().take(MAX_METHOD_COLLECTION_ITEMS) { + collect_methods_for_reachable_types( + inspect, + &item, + methods_by_type, + observed_classes_by_name, + observed_instances_by_name, + candidate_modules, + visited_object_ids, + visited_type_names, + depth + 1, + )?; + } + return Ok(()); + } + + if let Ok(tuple) = value.cast::() { + for item in tuple.iter().take(MAX_METHOD_COLLECTION_ITEMS) { + collect_methods_for_reachable_types( + inspect, + &item, + methods_by_type, + observed_classes_by_name, + observed_instances_by_name, + candidate_modules, + visited_object_ids, + visited_type_names, + depth + 1, + )?; + } + return Ok(()); + } + + if let Ok(dict) = value.cast::() { + for (key, item) in dict.iter().take(MAX_METHOD_COLLECTION_ITEMS) { + collect_methods_for_reachable_types( + inspect, + &key, + methods_by_type, + observed_classes_by_name, + observed_instances_by_name, + candidate_modules, + visited_object_ids, + visited_type_names, + depth + 1, + )?; + collect_methods_for_reachable_types( + inspect, + &item, + methods_by_type, + observed_classes_by_name, + observed_instances_by_name, + candidate_modules, + visited_object_ids, + visited_type_names, + depth + 1, + )?; + } + return Ok(()); + } + + if let Ok(object_dict) = value.getattr("__dict__") + && let Ok(object_dict) = object_dict.cast::() + { + for (_name, item) in object_dict.iter() { + collect_methods_for_reachable_types( + inspect, + &item, + methods_by_type, + observed_classes_by_name, + observed_instances_by_name, + candidate_modules, + visited_object_ids, + visited_type_names, + depth + 1, + )?; + } + } + + if let Ok(class_dict_any) = class.getattr("__dict__") + && let Ok(class_dict) = class_dict_any.cast::() + { + for (name, _) in class_dict.iter() { + let Ok(name) = name.extract::() else { + continue; + }; + if name.starts_with("__") { + continue; + } + let Ok(item) = value.getattr(name.as_str()) 
else { + continue; + }; + if item.is_callable() { + continue; + } + collect_methods_for_reachable_types( + inspect, + &item, + methods_by_type, + observed_classes_by_name, + observed_instances_by_name, + candidate_modules, + visited_object_ids, + visited_type_names, + depth + 1, + )?; + } + } + + if let Ok(annotations_any) = class.getattr("__annotations__") + && let Ok(annotations) = annotations_any.cast::() + { + for (name, _) in annotations.iter() { + let Ok(name) = name.extract::() else { + continue; + }; + if name.starts_with("__") { + continue; + } + let Ok(item) = value.getattr(name.as_str()) else { + continue; + }; + collect_methods_for_reachable_types( + inspect, + &item, + methods_by_type, + observed_classes_by_name, + observed_instances_by_name, + candidate_modules, + visited_object_ids, + visited_type_names, + depth + 1, + )?; + } + } + + Ok(()) +} + +fn collect_methods_for_schema_types( + py: Python<'_>, + inspect: &Bound<'_, PyModule>, + output_format: &OutputFormatContent, + methods_by_type: &mut BTreeMap>, + observed_classes_by_name: &BTreeMap>, + observed_instances_by_name: &BTreeMap>, + candidate_modules: &BTreeSet, +) -> PyResult<()> { + let module_classes = collect_module_class_objects(py, inspect, candidate_modules)?; + let object_subclasses = collect_object_subclass_index(py)?; + let schema_type_names = collect_schema_type_names(output_format); + let runtime_type_names = observed_classes_by_name + .keys() + .cloned() + .collect::>(); + + let mut schema_class_names = BTreeMap::>::new(); + let mut schema_fields = BTreeMap::>::new(); + for ((raw_name, _streaming), class) in output_format.classes.iter() { + let rendered_name = class.name.rendered_name().to_string(); + let aliases = schema_class_names.entry(rendered_name.clone()).or_default(); + aliases.insert(rendered_name); + aliases.insert(raw_name.clone()); + schema_fields.entry(class.name.rendered_name().to_string()).or_insert_with(|| { + class + .fields + .iter() + .map(|(field_name, _, _, _)| 
field_name.real_name().to_string()) + .collect() + }); + } + + let mut resolved_classes = BTreeMap::>::new(); + let mut resolved_instances = BTreeMap::>::new(); + for (rendered_name, aliases) in &schema_class_names { + if let Some(class_obj) = resolve_schema_class_object( + py, + aliases, + observed_classes_by_name, + &module_classes, + &object_subclasses, + ) { + resolved_classes.insert(rendered_name.clone(), class_obj); + } + if let Some(instance_obj) = resolve_schema_instance_object(py, aliases, observed_instances_by_name) + { + resolved_instances.insert(rendered_name.clone(), instance_obj); + } + if !resolved_classes.contains_key(rendered_name) + && let Some(instance) = resolved_instances.get(rendered_name) + { + resolved_classes.insert( + rendered_name.clone(), + instance.bind(py).get_type().as_any().clone().unbind(), + ); + } + } + + loop { + let unresolved = schema_class_names + .keys() + .filter(|name| !resolved_classes.contains_key(*name)) + .cloned() + .collect::>(); + if unresolved.is_empty() { + break; + } + let progressed = project_unresolved_schema_classes_from_runtime_fields( + py, + &unresolved, + &schema_class_names, + &schema_fields, + &mut resolved_classes, + &mut resolved_instances, + )?; + if !progressed { + break; + } + } + + for (rendered_name, aliases) in schema_class_names { + let synthetic_by_alias = rendered_name.contains('_') && aliases.iter().any(|a| a.contains("__")); + if synthetic_by_alias + || is_synthetic_variant_class_name(&rendered_name, &schema_type_names, &runtime_type_names) + { + methods_by_type.insert(rendered_name, Vec::new()); + continue; + } + if methods_by_type.contains_key(&rendered_name) { + continue; + } + let methods = if let Some(class_obj) = resolved_classes.get(&rendered_name) { + let class_obj = class_obj.bind(py); + let resolved_name = class_obj + .getattr("__name__") + .ok() + .and_then(|name| name.extract::().ok()) + .unwrap_or_default(); + if resolved_name == rendered_name { + 
collect_visible_methods_for_class(inspect, class_obj)? + } else { + Vec::new() + } + } else { + Vec::new() + }; + let _ = aliases; + methods_by_type.insert(rendered_name, methods); + } + + Ok(()) +} + +fn collect_module_class_objects( + py: Python<'_>, + inspect: &Bound<'_, PyModule>, + module_names: &BTreeSet, +) -> PyResult>> { + let is_class = inspect.getattr("isclass")?; + let mut classes = BTreeMap::new(); + for module_name in module_names { + let Ok(module) = PyModule::import(py, module_name.as_str()) else { + continue; + }; + let Ok(members_any) = inspect.call_method1("getmembers", (&module, &is_class)) else { + continue; + }; + let Ok(members) = members_any.cast::() else { + continue; + }; + for member in members.iter() { + let Ok(tuple) = member.cast::() else { + continue; + }; + if tuple.len() != 2 { + continue; + } + let Ok(name) = tuple.get_item(0)?.extract::() else { + continue; + }; + let Ok(class_obj) = tuple.get_item(1) else { + continue; + }; + classes + .entry(name) + .or_insert_with(|| class_obj.clone().unbind()); + } + } + Ok(classes) +} + +fn collect_object_subclass_index(py: Python<'_>) -> PyResult>> { + let builtins = PyModule::import(py, "builtins")?; + let object_type = builtins.getattr("object")?; + let subclasses_any = object_type.call_method0("__subclasses__")?; + let subclasses = subclasses_any.cast::()?; + let mut classes = BTreeMap::new(); + for subclass in subclasses.iter() { + let Ok(name) = subclass.getattr("__name__").and_then(|name| name.extract::()) else { + continue; + }; + classes.entry(name).or_insert_with(|| subclass.clone().unbind()); + } + Ok(classes) +} + +fn resolve_schema_class_object( + py: Python<'_>, + aliases: &BTreeSet, + observed_classes_by_name: &BTreeMap>, + module_classes: &BTreeMap>, + object_subclasses: &BTreeMap>, +) -> Option> { + for alias in aliases { + if let Some(class_obj) = observed_classes_by_name.get(alias) { + return Some(class_obj.clone_ref(py)); + } + } + for alias in aliases { + if let 
Some(class_obj) = module_classes.get(alias) { + return Some(class_obj.clone_ref(py)); + } + } + for alias in aliases { + if let Some(class_obj) = object_subclasses.get(alias) { + return Some(class_obj.clone_ref(py)); + } + } + None +} + +fn resolve_schema_instance_object( + py: Python<'_>, + aliases: &BTreeSet, + observed_instances_by_name: &BTreeMap>, +) -> Option> { + for alias in aliases { + if let Some(instance) = observed_instances_by_name.get(alias) { + return Some(instance.clone_ref(py)); + } + } + None +} + +fn collect_schema_type_names(output_format: &OutputFormatContent) -> BTreeSet { + let mut names = BTreeSet::new(); + for class in output_format.classes.values() { + names.insert(class.name.rendered_name().to_string()); + } + for enm in output_format.enums.values() { + names.insert(enm.name.rendered_name().to_string()); + } + names +} + +fn is_synthetic_variant_class_name( + rendered_name: &str, + schema_type_names: &BTreeSet, + runtime_type_names: &BTreeSet, +) -> bool { + let Some((prefix, suffix)) = rendered_name.split_once('_') else { + return false; + }; + if prefix.is_empty() || suffix.is_empty() { + return false; + } + let Some(first) = suffix.chars().next() else { + return false; + }; + first.is_ascii_uppercase() + && (schema_type_names.contains(prefix) || runtime_type_names.contains(prefix)) +} + +fn project_unresolved_schema_classes_from_runtime_fields( + py: Python<'_>, + unresolved: &[String], + schema_aliases: &BTreeMap>, + schema_fields: &BTreeMap>, + resolved_classes: &mut BTreeMap>, + resolved_instances: &mut BTreeMap>, +) -> PyResult { + let mut progressed = false; + let mut discovered = Vec::<(String, Py, Py)>::new(); + let parents = resolved_instances + .keys() + .cloned() + .collect::>(); + + for parent in parents { + let Some(instance) = resolved_instances.get(&parent) else { + continue; + }; + let Some(field_names) = schema_fields.get(&parent) else { + continue; + }; + let instance = instance.bind(py); + for field_name in 
field_names { + let Ok(field_value) = instance.getattr(field_name.as_str()) else { + continue; + }; + + let candidate = if let Ok(list) = field_value.cast::() { + if list.is_empty() { + None + } else { + list.get_item(0).ok() + } + } else if let Ok(tuple) = field_value.cast::() { + if tuple.is_empty() { + None + } else { + tuple.get_item(0).ok() + } + } else { + Some(field_value) + }; + + let Some(candidate) = candidate else { + continue; + }; + if candidate.is_none() || candidate.is_callable() { + continue; + } + + let candidate_class = candidate.get_type(); + let Ok(candidate_name) = candidate_class + .name() + .and_then(|name| name.extract::()) + else { + continue; + }; + + for target in unresolved { + if resolved_classes.contains_key(target) { + continue; + } + let Some(aliases) = schema_aliases.get(target) else { + continue; + }; + if !aliases.contains(&candidate_name) { + continue; + } + + discovered.push(( + target.clone(), + candidate_class.as_any().clone().unbind(), + candidate.clone().unbind(), + )); + } + } + } + + for (target, class_obj, instance_obj) in discovered { + if resolved_classes.contains_key(&target) { + continue; + } + resolved_classes.insert(target.clone(), class_obj); + resolved_instances.entry(target).or_insert(instance_obj); + progressed = true; + } + + Ok(progressed) +} + +fn extract_trimmed_docstring(callable: &Bound<'_, PyAny>) -> PyResult { + let Some(raw_doc) = callable.getattr("__doc__")?.extract::>()? 
else { + return Ok(String::new()); + }; + Ok(raw_doc.trim().to_string()) +} + +fn extract_signature(inspect: &Bound<'_, PyModule>, callable: &Bound<'_, PyAny>) -> Option { + if let Ok(text_sig) = callable.getattr("__text_signature__") + && let Ok(Some(text_sig)) = text_sig.extract::>() + { + let trimmed = text_sig.trim(); + if !trimmed.is_empty() { + return Some(trimmed.to_string()); + } + } + + inspect + .call_method1("signature", (callable,)) + .ok() + .and_then(|sig| sig.str().ok()) + .and_then(|sig| sig.extract::().ok()) + .map(|sig| sig.trim().to_string()) + .filter(|sig| !sig.is_empty()) + .or_else(|| { + callable + .call_method0("__signature__") + .ok() + .and_then(|sig| sig.str().ok()) + .and_then(|sig| sig.extract::().ok()) + .map(|sig| sig.trim().to_string()) + .filter(|sig| !sig.is_empty()) + }) + .or_else(|| None) +} + +fn sanitize_signature(raw_signature: &str) -> String { + let mut signature = raw_signature.trim().to_string(); + + if signature == "($self)" || signature == "($self, /)" { + signature = "()".to_string(); + } else if signature.starts_with("($self, /, ") { + signature = signature.replacen("($self, /, ", "(", 1); + } else if signature.starts_with("($self, ") { + signature = signature.replacen("($self, ", "(", 1); + } + + if signature == "(self)" || signature == "(self, /)" { + signature = "()".to_string(); + } else if signature.starts_with("(self, /, ") { + signature = signature.replacen("(self, /, ", "(", 1); + } else if signature.starts_with("(self, ") { + signature = signature.replacen("(self, ", "(", 1); + } + signature = signature.replace("($self, /)", "()"); + signature = signature.replace("($self,)", "()"); + signature = signature.replace(", /)", ")"); + signature = signature.replace(", /, ", ", "); + + if !signature.starts_with('(') { + signature = format!("({signature})"); + } + + simplify_qualified_type_paths(&signature) +} + +fn simplify_qualified_type_paths(raw: &str) -> String { + let mut out = String::with_capacity(raw.len()); 
+ let mut token = String::new(); + + let flush = |out: &mut String, token: &mut String| { + if token.is_empty() { + return; + } + if token.contains('.') { + if let Some(last) = token.rsplit('.').next() { + out.push_str(last); + } + } else { + out.push_str(token); + } + token.clear(); + }; + + for ch in raw.chars() { + if ch.is_ascii_alphanumeric() || ch == '_' || ch == '.' { + token.push(ch); + } else { + flush(&mut out, &mut token); + out.push(ch); + } + } + flush(&mut out, &mut token); + out +} + +fn classify_method_source(name: &str) -> MethodSource { + match name { + "__len__" | "__iter__" | "__getitem__" | "__repr__" | "__baml__" => MethodSource::Generated, + _ => MethodSource::Custom, + } +} + +/// Convert BamlValue tree to Python objects recursively. +pub fn baml_value_to_py(py: Python<'_>, value: &BamlValue) -> PyResult> { + match value { + BamlValue::String(value) => Ok(value.clone().into_py_any(py)?), + BamlValue::Int(value) => Ok(value.into_py_any(py)?), + BamlValue::Float(value) => Ok(value.into_py_any(py)?), + BamlValue::Bool(value) => Ok(value.into_py_any(py)?), + BamlValue::Null => Ok(py.None()), + BamlValue::List(items) => { + let list = PyList::empty(py); + for item in items { + list.append(baml_value_to_py(py, item)?)?; + } + Ok(list.into_any().unbind()) + } + BamlValue::Map(map) => { + let dict = PyDict::new(py); + for (key, value) in map.iter() { + dict.set_item(key, baml_value_to_py(py, value)?)?; + } + Ok(dict.into_any().unbind()) + } + BamlValue::Enum(_, variant) => Ok(variant.clone().into_py_any(py)?), + BamlValue::Class(_, fields) => { + let dict = PyDict::new(py); + for (key, value) in fields.iter() { + dict.set_item(key, baml_value_to_py(py, value)?)?; + } + Ok(dict.into_any().unbind()) + } + BamlValue::Media(_) => Err(pyo3::exceptions::PyTypeError::new_err( + "Media values are not supported in RLM V1", + )), + } +} + +pub fn kwargs_to_baml_value( + py: Python<'_>, + kwargs: &Bound<'_, PyDict>, +) -> Result { + let schema = S::schema(); + 
let output_format = schema.output_format(); + let mut fields = BamlMap::new(); + + for field in schema.output_fields() { + let value = kwargs + .get_item(field.lm_name) + .map_err(py_err_to_parse)? + .ok_or_else(|| missing_field_error(&[], field.lm_name))?; + let baml_value = py_to_baml_value(py, &value, &field.type_ir, output_format) + .map_err(|err| add_field_context(err, field.lm_name))?; + fields.insert(field.rust_name.to_string(), baml_value); + } + + if let Some(class_name) = output_class_name(output_format) { + Ok(BamlValue::Class(class_name, fields)) + } else { + Ok(BamlValue::Map(fields)) + } +} + +pub fn collect_checks_for_output( + value: &BamlValue, +) -> Result, BamlParseError> { + let schema = S::schema(); + + let fields = match value { + BamlValue::Class(_, fields) | BamlValue::Map(fields) => fields, + other => { + return Err(BamlParseError::Convert(BamlConvertError::new( + Vec::new(), + "object", + format!("{other:?}"), + "expected an object", + ))); + } + }; + + let mut checks = Vec::new(); + let mut failed = Vec::new(); + + for field in schema.output_fields() { + let Some(value) = fields.get(field.rust_name.as_str()) else { + return Err(missing_field_error(&[], field.rust_name.as_str())); + }; + + let results = run_user_checks(value, &field.type_ir).map_err(BamlParseError::from)?; + for (constraint, ok) in results { + if constraint.level == ConstraintLevel::Assert && !ok { + failed.push(ResponseCheck { + name: constraint + .label + .clone() + .unwrap_or_else(|| "assert".to_string()), + expression: constraint.expression.0.clone(), + status: "failed".to_string(), + }); + } + + if let Some(check) = ResponseCheck::from_check_result((constraint, ok)) { + checks.push(check); + } + } + } + + if !failed.is_empty() { + return Err(BamlParseError::ConstraintAssertsFailed { failed }); + } + + Ok(checks) +} + +fn output_class_name(output_format: &OutputFormatContent) -> Option { + let mut current = output_format.target.clone(); + loop { + match current { + 
TypeIR::Class { name, .. } => return Some(name), + TypeIR::RecursiveTypeAlias { name, .. } => { + if let Some(next) = output_format.structural_recursive_aliases.get(&name) { + current = next.clone(); + continue; + } + return None; + } + _ => return None, + } + } +} + +fn add_field_context(err: BamlParseError, field: &str) -> BamlParseError { + match err { + BamlParseError::Convert(err) => { + let mut path = Vec::with_capacity(err.path.len() + 1); + path.push(field.to_string()); + path.extend(err.path); + BamlParseError::Convert(BamlConvertError::new( + path, + err.expected, + err.got, + err.message, + )) + } + BamlParseError::Jsonish(inner) => BamlParseError::Convert(BamlConvertError::new( + vec![field.to_string()], + "schema", + "python", + inner.to_string(), + )), + other => other, + } +} + +pub fn py_to_baml_value( + py: Python<'_>, + obj: &Bound<'_, PyAny>, + r#type: &TypeIR, + output_format: &OutputFormatContent, +) -> Result { + let obj = if obj.hasattr("__baml__").map_err(py_err_to_parse)? { + obj.call_method0("__baml__").map_err(py_err_to_parse)? + } else { + obj.clone() + }; + let obj = normalize_python_object(py, &obj).map_err(py_err_to_parse)?; + let mut path = Vec::new(); + py_to_baml_value_inner(py, &obj, r#type, output_format, &mut path) +} + +pub fn normalize_python_object<'py>( + py: Python<'py>, + obj: &Bound<'py, PyAny>, +) -> PyResult> { + if obj.is_instance_of::() || obj.is_instance_of::() { + return Ok(obj.clone()); + } + + if let Ok(value) = obj.call_method0("model_dump") { + return Ok(value); + } + + if let Ok(value) = obj.call_method0("dict") { + return Ok(value); + } + + if let Ok(value) = obj.call_method0("_asdict") { + return Ok(value); + } + + if let Ok(dataclasses) = PyModule::import(py, "dataclasses") + && let Ok(is_dataclass) = dataclasses.getattr("is_dataclass") + && is_dataclass.call1((obj,))?.is_truthy()? 
+ && let Ok(asdict) = dataclasses.getattr("asdict") + { + return asdict.call1((obj,)); + } + + if let Ok(attrs) = PyModule::import(py, "attr") + && let Ok(has) = attrs.getattr("has") + && has.call1((obj,))?.is_truthy()? + && let Ok(asdict) = attrs.getattr("asdict") + { + return asdict.call1((obj,)); + } + + if let Ok(dict) = obj.getattr("__dict__") { + return Ok(dict); + } + + Ok(obj.clone()) +} + +fn py_to_baml_value_inner( + py: Python<'_>, + obj: &Bound<'_, PyAny>, + r#type: &TypeIR, + output_format: &OutputFormatContent, + path: &mut Vec, +) -> Result { + let resolved = resolve_recursive_type(r#type, output_format); + + if !is_string_target(&resolved) && obj.is_instance_of::() { + let raw = obj.extract::().map_err(py_err_to_parse)?; + if let Ok(parsed_json) = serde_json::from_str::(&raw) { + let py_obj = json_value_to_py(py, &parsed_json).into_bound(py); + return py_to_baml_value_inner(py, &py_obj, &resolved, output_format, path); + } + } + + match &resolved { + TypeIR::Primitive(TypeValue::String, _) => obj + .extract::() + .map(BamlValue::String) + .map_err(py_err_to_parse), + TypeIR::Primitive(TypeValue::Int, _) => { + if obj.is_instance_of::() { + return Err(conversion_error(path, &resolved, obj)); + } + obj.extract::() + .map(BamlValue::Int) + .map_err(py_err_to_parse) + } + TypeIR::Primitive(TypeValue::Float, _) => { + if obj.is_instance_of::() { + return Err(conversion_error(path, &resolved, obj)); + } + obj.extract::() + .map(BamlValue::Float) + .map_err(py_err_to_parse) + } + TypeIR::Primitive(TypeValue::Bool, _) => obj + .extract::() + .map(BamlValue::Bool) + .map_err(py_err_to_parse), + TypeIR::Primitive(TypeValue::Null, _) => { + if obj.is_none() { + Ok(BamlValue::Null) + } else { + Err(conversion_error(path, &resolved, obj)) + } + } + TypeIR::Primitive(TypeValue::Media(_), _) => Err(conversion_error(path, &resolved, obj)), + TypeIR::Enum { name, .. 
} => { + let raw = obj.extract::().map_err(py_err_to_parse)?; + let enum_type = output_format.enums.get(name).ok_or_else(|| { + BamlParseError::Jsonish(anyhow!("missing enum definition for {name}")) + })?; + let mut matches_variant = false; + for (value, _) in &enum_type.values { + if value.real_name() == raw || value.rendered_name() == raw { + matches_variant = true; + break; + } + } + if !matches_variant { + return Err(conversion_error(path, &resolved, obj)); + } + Ok(BamlValue::Enum(name.to_string(), raw)) + } + TypeIR::Literal(LiteralValue::String(literal), _) => { + let raw = obj.extract::().map_err(py_err_to_parse)?; + if raw == *literal { + Ok(BamlValue::String(raw)) + } else { + Err(conversion_error(path, &resolved, obj)) + } + } + TypeIR::Literal(LiteralValue::Int(literal), _) => { + if obj.is_instance_of::() { + return Err(conversion_error(path, &resolved, obj)); + } + let raw = obj.extract::().map_err(py_err_to_parse)?; + if raw == *literal { + Ok(BamlValue::Int(raw)) + } else { + Err(conversion_error(path, &resolved, obj)) + } + } + TypeIR::Literal(LiteralValue::Bool(literal), _) => { + let raw = obj.extract::().map_err(py_err_to_parse)?; + if raw == *literal { + Ok(BamlValue::Bool(raw)) + } else { + Err(conversion_error(path, &resolved, obj)) + } + } + TypeIR::Class { name, .. } => { + py_to_class_value(py, obj, name.as_str(), output_format, path) + } + TypeIR::List(item_type, _) => { + py_to_list_value(py, obj, item_type.as_ref(), output_format, path) + } + TypeIR::Map(key_type, value_type, _) => py_to_map_value( + py, + obj, + key_type.as_ref(), + value_type.as_ref(), + output_format, + path, + ), + TypeIR::Tuple(items, _) => py_to_tuple_value(py, obj, items, output_format, path), + TypeIR::RecursiveTypeAlias { name, .. 
} => Err(BamlParseError::Jsonish(anyhow!( + "missing recursive alias {name}" + ))), + TypeIR::Top(_) => py_any_to_baml_value_untyped(py, obj), + TypeIR::Arrow(_, _) => Err(conversion_error(path, &resolved, obj)), + TypeIR::Union(inner, _) => match inner.view() { + UnionTypeViewGeneric::Null => { + if obj.is_none() { + Ok(BamlValue::Null) + } else { + Err(conversion_error(path, &resolved, obj)) + } + } + UnionTypeViewGeneric::Optional(t) => { + if obj.is_none() { + Ok(BamlValue::Null) + } else { + py_to_baml_value_inner(py, obj, t, output_format, path) + } + } + UnionTypeViewGeneric::OneOf(types) | UnionTypeViewGeneric::OneOfOptional(types) => { + let mut last_err: Option = None; + for t in types { + match py_to_baml_value_inner(py, obj, t, output_format, path) { + Ok(value) => return Ok(value), + Err(err) => last_err = Some(err), + } + } + Err(last_err.unwrap_or_else(|| conversion_error(path, &resolved, obj))) + } + }, + } +} + +fn py_to_class_value( + py: Python<'_>, + obj: &Bound<'_, PyAny>, + class_name: &str, + output_format: &OutputFormatContent, + path: &mut Vec, +) -> Result { + let dict = match obj.cast::() { + Ok(dict) => dict, + Err(_) => { + if let Some(value) = + orjson_fallback_to_baml(py, obj, &TypeIR::class(class_name), output_format) + { + return Ok(value); + } + return Err(conversion_error(path, &TypeIR::class(class_name), obj)); + } + }; + + let class = find_class(output_format, class_name).ok_or_else(|| { + BamlParseError::Jsonish(anyhow!("missing class definition for {class_name}")) + })?; + + let mut fields = BamlMap::new(); + for field in &class.fields { + let (name, field_type, _, _) = field; + let rendered: &str = name.rendered_name(); + let real: &str = name.real_name(); + + let value = dict + .get_item(rendered) + .map_err(py_err_to_parse)? 
+ .or_else(|| dict.get_item(real).ok().flatten()); + + let value = match value { + Some(value) => value, + None => { + if field_type.is_optional() { + fields.insert(real.to_string(), BamlValue::Null); + continue; + } + return Err(missing_field_error(path, real)); + } + }; + + let field_value = with_path_segment(path, real.to_string(), |path| { + py_to_baml_value_inner(py, &value, field_type, output_format, path) + })?; + fields.insert(real.to_string(), field_value); + } + + Ok(BamlValue::Class(class_name.to_string(), fields)) +} + +fn py_to_map_value( + py: Python<'_>, + obj: &Bound<'_, PyAny>, + key_type: &TypeIR, + value_type: &TypeIR, + output_format: &OutputFormatContent, + path: &mut Vec, +) -> Result { + if !matches!( + key_type, + TypeIR::Primitive(TypeValue::String, _) | TypeIR::Literal(LiteralValue::String(_), _) + ) { + return Err(BamlParseError::Convert(BamlConvertError::new( + path.clone(), + "string", + schema_type_name(key_type), + "map keys must be strings", + ))); + } + + let dict = match obj.cast::() { + Ok(dict) => dict, + Err(_) => { + if let Some(value) = orjson_fallback_to_baml( + py, + obj, + &TypeIR::map(key_type.clone(), value_type.clone()), + output_format, + ) { + return Ok(value); + } + return Err(conversion_error( + path, + &TypeIR::map(key_type.clone(), value_type.clone()), + obj, + )); + } + }; + + let mut map = BamlMap::new(); + for (key, value) in dict.iter() { + let key = key + .extract::() + .map_err(|_| conversion_error(path, key_type, &key))?; + let value = with_path_segment(path, key.clone(), |path| { + py_to_baml_value_inner(py, &value, value_type, output_format, path) + })?; + map.insert(key, value); + } + + Ok(BamlValue::Map(map)) +} + +fn py_to_list_value( + py: Python<'_>, + obj: &Bound<'_, PyAny>, + item_type: &TypeIR, + output_format: &OutputFormatContent, + path: &mut Vec, +) -> Result { + let list = if let Ok(list) = obj.cast::() { + list + } else if let Ok(tuple) = obj.cast::() { + let mut items = 
Vec::with_capacity(tuple.len()); + for (idx, item) in tuple.iter().enumerate() { + let value = with_path_segment(path, idx.to_string(), |path| { + py_to_baml_value_inner(py, &item, item_type, output_format, path) + })?; + items.push(value); + } + return Ok(BamlValue::List(items)); + } else { + if let Some(value) = + orjson_fallback_to_baml(py, obj, &TypeIR::list(item_type.clone()), output_format) + { + return Ok(value); + } + return Err(conversion_error( + path, + &TypeIR::list(item_type.clone()), + obj, + )); + }; + + let mut items = Vec::with_capacity(list.len()); + for (idx, item) in list.iter().enumerate() { + let value = with_path_segment(path, idx.to_string(), |path| { + py_to_baml_value_inner(py, &item, item_type, output_format, path) + })?; + items.push(value); + } + + Ok(BamlValue::List(items)) +} + +fn py_to_tuple_value( + py: Python<'_>, + obj: &Bound<'_, PyAny>, + items: &[TypeIR], + output_format: &OutputFormatContent, + path: &mut Vec, +) -> Result { + if let Ok(tuple) = obj.cast::() { + if tuple.len() != items.len() { + return Err(conversion_error(path, &TypeIR::tuple(items.to_vec()), obj)); + } + let mut values = Vec::with_capacity(items.len()); + for (idx, (item, item_type)) in tuple.iter().zip(items.iter()).enumerate() { + let value = with_path_segment(path, idx.to_string(), |path| { + py_to_baml_value_inner(py, &item, item_type, output_format, path) + })?; + values.push(value); + } + return Ok(BamlValue::List(values)); + } + + if let Ok(list) = obj.cast::() { + if list.len() != items.len() { + return Err(conversion_error(path, &TypeIR::tuple(items.to_vec()), obj)); + } + let mut values = Vec::with_capacity(items.len()); + for (idx, (item, item_type)) in list.iter().zip(items.iter()).enumerate() { + let value = with_path_segment(path, idx.to_string(), |path| { + py_to_baml_value_inner(py, &item, item_type, output_format, path) + })?; + values.push(value); + } + return Ok(BamlValue::List(values)); + } + + Err(conversion_error(path, 
&TypeIR::tuple(items.to_vec()), obj)) +} + +fn py_any_to_baml_value_untyped( + py: Python<'_>, + obj: &Bound<'_, PyAny>, +) -> Result { + if obj.is_none() { + return Ok(BamlValue::Null); + } + + if obj.is_instance_of::() { + return obj + .extract::() + .map(BamlValue::Bool) + .map_err(py_err_to_parse); + } + + if let Ok(value) = obj.extract::() { + return Ok(BamlValue::Int(value)); + } + + if let Ok(value) = obj.extract::() { + return Ok(BamlValue::Float(value)); + } + + if let Ok(value) = obj.extract::() { + return Ok(BamlValue::String(value)); + } + + if let Ok(dict) = obj.cast::() { + let mut map = BamlMap::new(); + for (key, value) in dict.iter() { + let key = key.extract::().map_err(py_err_to_parse)?; + let value = py_any_to_baml_value_untyped(py, &value)?; + map.insert(key, value); + } + return Ok(BamlValue::Map(map)); + } + + if let Ok(list) = obj.cast::() { + let mut items = Vec::with_capacity(list.len()); + for item in list.iter() { + items.push(py_any_to_baml_value_untyped(py, &item)?); + } + return Ok(BamlValue::List(items)); + } + + if let Ok(tuple) = obj.cast::() { + let mut items = Vec::with_capacity(tuple.len()); + for item in tuple.iter() { + items.push(py_any_to_baml_value_untyped(py, &item)?); + } + return Ok(BamlValue::List(items)); + } + + let raw = python_object_to_json_string(py, obj)?; + let parsed: JsonValue = + serde_json::from_str(&raw).map_err(|err| BamlParseError::Jsonish(anyhow!(err)))?; + Ok(json_value_to_baml_value(&parsed)) +} + +fn python_object_to_json_string( + py: Python<'_>, + obj: &Bound<'_, PyAny>, +) -> Result { + if let Ok(orjson) = PyModule::import(py, "orjson") + && let Ok(dumps) = orjson.getattr("dumps") + && let Ok(raw) = dumps.call1((obj,)) + && let Ok(bytes) = raw.extract::>() + { + return String::from_utf8(bytes).map_err(|err| BamlParseError::Jsonish(anyhow!(err))); + } + + let json = PyModule::import(py, "json").map_err(py_err_to_parse)?; + let dumps = json.getattr("dumps").map_err(py_err_to_parse)?; + dumps + 
.call1((obj,)) + .map_err(py_err_to_parse)? + .extract::() + .map_err(py_err_to_parse) +} + +fn json_value_to_py(py: Python<'_>, value: &JsonValue) -> Py { + match value { + JsonValue::Null => py.None(), + JsonValue::Bool(value) => value.into_py_any(py).unwrap_or_else(|_| py.None()), + JsonValue::Number(value) => value + .as_i64() + .map(|value| value.into_py_any(py).unwrap_or_else(|_| py.None())) + .or_else(|| { + value + .as_f64() + .map(|value| value.into_py_any(py).unwrap_or_else(|_| py.None())) + }) + .unwrap_or_else(|| py.None()), + JsonValue::String(value) => value.clone().into_py_any(py).unwrap_or_else(|_| py.None()), + JsonValue::Array(values) => { + let list = PyList::empty(py); + for item in values { + let _ = list.append(json_value_to_py(py, item)); + } + list.into_any().unbind() + } + JsonValue::Object(values) => { + let dict = PyDict::new(py); + for (key, value) in values { + let _ = dict.set_item(key, json_value_to_py(py, value)); + } + dict.into_any().unbind() + } + } +} + +fn json_value_to_baml_value(value: &JsonValue) -> BamlValue { + match value { + JsonValue::Null => BamlValue::Null, + JsonValue::Bool(value) => BamlValue::Bool(*value), + JsonValue::Number(value) => { + if let Some(value) = value.as_i64() { + BamlValue::Int(value) + } else if let Some(value) = value.as_f64() { + BamlValue::Float(value) + } else { + BamlValue::Null + } + } + JsonValue::String(value) => BamlValue::String(value.clone()), + JsonValue::Array(values) => { + BamlValue::List(values.iter().map(json_value_to_baml_value).collect()) + } + JsonValue::Object(values) => BamlValue::Map( + values + .iter() + .map(|(key, value)| (key.clone(), json_value_to_baml_value(value))) + .collect(), + ), + } +} + +fn resolve_recursive_type(r#type: &TypeIR, output_format: &OutputFormatContent) -> TypeIR { + let mut current = r#type.clone(); + loop { + let next = match ¤t { + TypeIR::RecursiveTypeAlias { name, .. 
} => output_format + .structural_recursive_aliases + .get(name) + .cloned(), + _ => None, + }; + + match next { + Some(next) => current = next, + None => return current, + } + } +} + +fn find_class<'a>(output_format: &'a OutputFormatContent, class_name: &str) -> Option<&'a Class> { + let key = (class_name.to_string(), StreamingMode::NonStreaming); + if let Some(class) = output_format.classes.get(&key) { + return Some(class); + } + + output_format + .classes + .iter() + .find(|((name, _), _)| name == class_name) + .map(|(_, class)| class) +} + +fn is_string_target(r#type: &TypeIR) -> bool { + matches!( + r#type, + TypeIR::Primitive(TypeValue::String, _) | TypeIR::Literal(LiteralValue::String(_), _) + ) +} + +fn conversion_error(path: &[String], expected: &TypeIR, got: &Bound<'_, PyAny>) -> BamlParseError { + let got_type = py_type_name(got); + BamlParseError::Convert(BamlConvertError::new( + path.to_vec(), + "schema", + got_type, + format!("expected {}", schema_type_name(expected)), + )) +} + +fn with_path_segment( + path: &mut Vec, + segment: String, + convert: impl FnOnce(&mut Vec) -> Result, +) -> Result { + path.push(segment); + let result = convert(path); + path.pop(); + result +} + +fn schema_type_name(type_ir: &TypeIR) -> String { + crate::core::render_type_name_for_prompt_with(type_ir, crate::core::simplify_type_token) +} + +fn missing_field_error(path: &[String], field: &str) -> BamlParseError { + let mut full_path = path.to_vec(); + full_path.push(field.to_string()); + + BamlParseError::Convert(BamlConvertError::new( + full_path, + "field", + "missing", + format!("missing required field {field}"), + )) +} + +fn py_type_name(obj: &Bound<'_, PyAny>) -> String { + obj.get_type() + .name() + .ok() + .and_then(|name| name.extract::().ok()) + .unwrap_or_else(|| "".to_string()) +} + +fn py_err_to_parse(err: pyo3::PyErr) -> BamlParseError { + BamlParseError::Jsonish(anyhow!(err.to_string())) +} + +fn orjson_fallback_to_baml( + py: Python<'_>, + obj: &Bound<'_, 
PyAny>, + r#type: &TypeIR, + output_format: &OutputFormatContent, +) -> Option { + let raw = python_object_to_json_string(py, obj).ok()?; + let parsed = jsonish::from_str(output_format, r#type, &raw, true).ok()?; + Some(BamlValue::from(parsed)) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use bamltype::baml_types::ir_type::UnionConstructor; + use pyo3::prelude::*; + use pyo3::types::{PyDict, PyDictMethods}; + use tokio::runtime::Handle; + + use super::*; + use crate::BamlType; + use crate::Signature; + use crate::modules::rlm::{LlmQuery, SubmitSlot}; + + #[derive(Signature, Clone, Debug)] + struct BridgeSig { + #[input] + question: String, + + #[input] + count: i64, + + #[output] + answer: String, + + #[output] + #[check("this >= 0.0", label = "non_negative")] + score: f64, + } + + #[derive(Signature, Clone, Debug)] + struct AssertSig { + #[input] + prompt: String, + + #[output] + #[assert("this > 0", label = "positive")] + score: i64, + } + + #[derive(Signature, Clone, Debug)] + struct ReservedNameSig { + #[input] + llm_query: String, + + #[output] + answer: String, + } + + #[pyclass] + #[BamlType] + #[derive(Clone, Debug)] + struct MethodFixture { + label: String, + } + + #[pymethods] + impl MethodFixture { + #[new] + fn new(label: String) -> Self { + Self { label } + } + + #[pyo3(text_signature = "(query)")] + /// Search entries by query text. + fn search(&self, query: String) -> String { + format!("{}:{query}", self.label) + } + + /// Return the character count for this fixture label. 
+ fn __len__(&self) -> usize { + self.label.chars().count() + } + + fn undocumented(&self) -> String { + self.label.clone() + } + } + + #[derive(Signature, Clone, Debug)] + struct MethodFixtureSig { + #[input] + trajectory: MethodFixture, + + #[output] + answer: String, + } + + #[derive(Signature, Clone, Debug)] + struct MethodFixtureListSig { + #[input] + trajectories: Vec, + + #[output] + answer: String, + } + + #[pyclass] + #[BamlType] + #[derive(Clone, Debug)] + struct NoAnnotationsChild { + label: String, + } + + #[pymethods] + impl NoAnnotationsChild { + #[new] + fn new(label: String) -> Self { + Self { label } + } + + /// Thread view for this child fixture. + fn thread(&self, participants: Vec) -> String { + format!("{}:{}", self.label, participants.join(",")) + } + } + + #[pyclass] + #[BamlType] + #[derive(Clone, Debug)] + struct NoAnnotationsContainer { + items: Vec, + } + + #[derive(Signature, Clone, Debug)] + struct NoAnnotationsSig { + #[input] + container: NoAnnotationsContainer, + + #[output] + answer: String, + } + + struct MockLm; + + #[async_trait::async_trait] + impl LlmQuery for MockLm { + async fn query(&self, prompt: &str) -> anyhow::Result { + Ok(format!("mock:{prompt}")) + } + } + + #[test] + fn baml_value_to_py_supports_common_types() { + Python::attach(|py| { + let value = BamlValue::Map(BamlMap::from_iter([ + ("name".to_string(), BamlValue::String("alice".to_string())), + ( + "nums".to_string(), + BamlValue::List(vec![BamlValue::Int(1), BamlValue::Int(2)]), + ), + ("ok".to_string(), BamlValue::Bool(true)), + ( + "nested".to_string(), + BamlValue::Class( + "Nested".to_string(), + BamlMap::from_iter([("x".to_string(), BamlValue::Float(1.25))]), + ), + ), + ])); + + let py_obj = baml_value_to_py(py, &value).expect("convert to py"); + let dict = py_obj.bind(py).cast::().expect("dict"); + assert_eq!( + dict.get_item("name") + .expect("getitem") + .expect("name") + .extract::() + .expect("name str"), + "alice" + ); + assert!( + 
dict.get_item("ok") + .expect("getitem") + .expect("ok") + .extract::() + .expect("ok bool") + ); + }); + } + + #[test] + fn kwargs_to_baml_value_validates_typed_fields() { + Python::attach(|py| { + let kwargs = PyDict::new(py); + kwargs.set_item("answer", "done").expect("set answer"); + kwargs.set_item("score", 0.85).expect("set score"); + + let converted = kwargs_to_baml_value::(py, &kwargs).expect("convert kwargs"); + let BamlValue::Class(_, fields) = converted else { + panic!("expected class output"); + }; + assert_eq!( + fields.get("answer"), + Some(&BamlValue::String("done".to_string())) + ); + assert_eq!(fields.get("score"), Some(&BamlValue::Float(0.85))); + }); + } + + #[test] + fn kwargs_to_baml_value_reports_type_error_context() { + Python::attach(|py| { + let kwargs = PyDict::new(py); + kwargs.set_item("answer", "done").expect("set answer"); + kwargs.set_item("score", "oops").expect("set score"); + + let err = kwargs_to_baml_value::(py, &kwargs).expect_err("should fail"); + match err { + BamlParseError::Convert(err) => { + assert_eq!(err.path.first().map(|s| s.as_str()), Some("score")); + } + other => panic!("unexpected error: {other}"), + } + }); + } + + #[test] + fn collect_checks_for_output_reports_assert_failures() { + let value = BamlValue::Map(BamlMap::from_iter([( + "score".to_string(), + BamlValue::Int(-1), + )])); + + let err = collect_checks_for_output::(&value).expect_err("assert should fail"); + match err { + BamlParseError::ConstraintAssertsFailed { failed } => { + assert_eq!(failed.len(), 1); + assert_eq!(failed[0].name, "positive"); + } + other => panic!("unexpected error: {other}"), + } + } + + #[test] + fn setup_interpreter_globals_injects_inputs_and_tools() { + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .expect("runtime"); + + runtime.block_on(async { + Python::attach(|py| { + let slot: SubmitSlot = Arc::new(std::sync::Mutex::new(None)); + let submit = SubmitHandler::new::(Arc::clone(&slot)); 
+ let tools = LlmTools::with_budget(Arc::new(MockLm), 2, Handle::current()); + + let input = BridgeSigInput { + question: "what?".to_string(), + count: 3, + }; + + let setup = + setup_interpreter_globals::(py, &input, &submit, Some(&tools)) + .expect("setup globals"); + let globals = setup.globals.bind(py).clone(); + + assert!(globals.get_item("question").expect("getitem").is_some()); + assert!(globals.get_item("count").expect("getitem").is_some()); + assert!(globals.get_item("llm_query").expect("getitem").is_some()); + assert!( + globals + .get_item("llm_query_batched") + .expect("getitem") + .is_some() + ); + assert!(globals.get_item("SUBMIT").expect("getitem").is_some()); + assert!(setup.methods_by_var.contains_key("question")); + assert!(setup.methods_by_var.contains_key("count")); + assert!(setup.methods_by_type.contains_key("str")); + assert!(setup.methods_by_type.contains_key("int")); + }); + }); + } + + #[test] + fn setup_interpreter_globals_without_sub_lm_tools_still_injects_submit_and_inputs() { + Python::attach(|py| { + let slot: SubmitSlot = Arc::new(std::sync::Mutex::new(None)); + let submit = SubmitHandler::new::(Arc::clone(&slot)); + let input = BridgeSigInput { + question: "what?".to_string(), + count: 3, + }; + + let setup = setup_interpreter_globals::(py, &input, &submit, None) + .expect("setup globals"); + let globals = setup.globals.bind(py).clone(); + + assert!(globals.get_item("question").expect("getitem").is_some()); + assert!(globals.get_item("count").expect("getitem").is_some()); + assert!(globals.get_item("SUBMIT").expect("getitem").is_some()); + assert!(globals.get_item("llm_query").expect("getitem").is_none()); + assert!( + globals + .get_item("llm_query_batched") + .expect("getitem") + .is_none() + ); + assert!(setup.methods_by_var.contains_key("question")); + assert!(setup.methods_by_var.contains_key("count")); + assert!(setup.methods_by_type.contains_key("str")); + assert!(setup.methods_by_type.contains_key("int")); + }); + } + + 
#[test] + fn setup_interpreter_globals_rejects_reserved_input_names() { + Python::attach(|py| { + let slot: SubmitSlot = Arc::new(std::sync::Mutex::new(None)); + let submit = SubmitHandler::new::(Arc::clone(&slot)); + let input = ReservedNameSigInput { + llm_query: "collision".to_string(), + }; + + let err = setup_interpreter_globals::(py, &input, &submit, None) + .expect_err("reserved input names should fail setup"); + let message = err.to_string(); + assert!(message.contains("llm_query")); + assert!(message.contains("reserved runtime binding")); + }); + } + + #[test] + fn setup_interpreter_globals_collects_filtered_method_metadata() { + Python::attach(|py| { + let slot: SubmitSlot = Arc::new(std::sync::Mutex::new(None)); + let submit = SubmitHandler::new::(Arc::clone(&slot)); + let input = MethodFixtureSigInput { + trajectory: MethodFixture { + label: "root".to_string(), + }, + }; + + let setup = setup_interpreter_globals::(py, &input, &submit, None) + .expect("setup globals"); + let methods = setup + .methods_by_var + .get("trajectory") + .expect("trajectory methods"); + let type_methods = setup + .methods_by_type + .get("MethodFixture") + .expect("MethodFixture methods"); + + assert_eq!( + setup.methods_by_var.keys().collect::>(), + vec![&"trajectory".to_string()], + "keys must match injected variable names" + ); + assert!( + methods.windows(2).all(|w| w[0].name <= w[1].name), + "method list should be deterministic and sorted by name" + ); + assert!(methods.iter().any(|m| m.name == "search")); + assert!(methods.iter().any(|m| m.name == "__len__")); + assert!(methods.iter().any(|m| m.name == "undocumented")); + assert!(!methods.iter().any(|m| m.name == "__baml__")); + assert!(type_methods.iter().any(|m| m.name == "search")); + + let search = methods + .iter() + .find(|m| m.name == "search") + .expect("search method metadata"); + assert!(search.signature.contains("query")); + assert!(!search.signature.contains("self")); + assert!(search.doc.contains("Search 
entries")); + assert!(matches!(search.source, MethodSource::Custom)); + assert!(!search.is_dunder); + + let undocumented = methods + .iter() + .find(|m| m.name == "undocumented") + .expect("undocumented method metadata"); + assert!(undocumented.doc.is_empty()); + assert!(matches!(undocumented.source, MethodSource::Custom)); + assert!(!undocumented.is_dunder); + + let dunder_len = methods + .iter() + .find(|m| m.name == "__len__") + .expect("__len__ metadata"); + assert!(dunder_len.is_dunder); + assert!(matches!(dunder_len.source, MethodSource::Generated)); + assert!(!dunder_len.doc.trim().is_empty()); + }); + } + + #[test] + fn setup_interpreter_globals_collects_reachable_nested_type_methods() { + Python::attach(|py| { + let slot: SubmitSlot = Arc::new(std::sync::Mutex::new(None)); + let submit = SubmitHandler::new::(Arc::clone(&slot)); + let input = MethodFixtureListSigInput { + trajectories: vec![MethodFixture { + label: "root".to_string(), + }], + }; + + let setup = setup_interpreter_globals::(py, &input, &submit, None) + .expect("setup globals"); + let nested_type_methods = setup + .methods_by_type + .get("MethodFixture") + .expect("nested MethodFixture methods"); + + assert!( + nested_type_methods.iter().any(|m| m.name == "search"), + "nested type methods should include custom MethodFixture methods" + ); + }); + } + + #[test] + fn setup_interpreter_globals_collects_schema_nested_type_methods_without_runtime_instance() { + Python::attach(|py| { + let _unused = Py::new( + py, + NoAnnotationsChild { + label: "seed".to_string(), + }, + ) + .expect("seed nested class type object"); + + let slot: SubmitSlot = Arc::new(std::sync::Mutex::new(None)); + let submit = SubmitHandler::new::(Arc::clone(&slot)); + let input = NoAnnotationsSigInput { + container: NoAnnotationsContainer { items: Vec::new() }, + }; + + let setup = + setup_interpreter_globals::(py, &input, &submit, None) + .expect("setup globals"); + let nested_methods = setup + .methods_by_type + 
.get("NoAnnotationsChild") + .expect("nested schema type methods"); + + assert!( + nested_methods.iter().any(|m| m.name == "thread"), + "schema-driven class lookup should collect nested type methods even when the input graph has no nested instances" + ); + }); + } + + #[test] + fn sanitize_signature_removes_python_self_variants() { + assert_eq!( + sanitize_signature("($self, path_fragment)"), + "(path_fragment)" + ); + assert_eq!( + sanitize_signature("($self, /, path_fragment)"), + "(path_fragment)" + ); + assert_eq!( + sanitize_signature("(self, /, path_fragment)"), + "(path_fragment)" + ); + assert_eq!(sanitize_signature("($self, /)"), "()"); + } + + #[test] + fn sanitize_signature_simplifies_qualified_type_paths() { + let raw = "(query: builtins.str, other: tanha.types.Sessions) -> tanha.types.Sessions"; + let sanitized = sanitize_signature(raw); + assert!(!sanitized.contains("builtins.")); + assert!(!sanitized.contains("tanha.types.")); + assert!(sanitized.contains("str")); + assert!(sanitized.contains("Sessions")); + } + + #[test] + fn union_attempts_do_not_leak_path_segments_between_branches() { + Python::attach(|py| { + let list = PyList::empty(py); + list.append(3).expect("append"); + + let union = TypeIR::union(vec![ + TypeIR::list(TypeIR::literal_int(1)), + TypeIR::list(TypeIR::literal_int(2)), + ]); + let output_format = BridgeSig::schema().output_format(); + + let err = py_to_baml_value(py, list.as_any(), &union, output_format) + .expect_err("union should fail to parse mismatched literal"); + match err { + BamlParseError::Convert(err) => { + assert_eq!( + err.path, + vec!["0".to_string()], + "path should represent one nesting level, not accumulate from prior union attempts" + ); + } + other => panic!("unexpected error: {other}"), + } + }); + } +} diff --git a/crates/dspy-rs/src/modules/rlm/runtime.rs b/crates/dspy-rs/src/modules/rlm/runtime.rs new file mode 100644 index 00000000..a29dc97b --- /dev/null +++ b/crates/dspy-rs/src/modules/rlm/runtime.rs @@ 
-0,0 +1,170 @@ +use std::collections::BTreeMap; +use std::sync::Arc; + +use super::exec; +use super::py_bridge; +use super::submit; +use super::tools; +use crate::Signature; +use pyo3::types::{PyAny, PyDict, PyDictMethods}; +use pyo3::{Bound, Py, PyResult, Python}; + +pub type SubmitResultDyn = submit::SubmitResultDyn; +pub type SubmitSlot = submit::SubmitSlot; +pub type SubmitError = submit::SubmitError; +pub type SubmitHandler = submit::SubmitHandler; +pub type LlmTools = tools::LlmTools; + +pub use submit::{clear_submit_slot, take_submit_result}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum MethodSource { + Generated, + Custom, +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct MethodSignature { + pub name: String, + pub signature: String, + pub doc: String, + pub source: MethodSource, + pub is_dunder: bool, +} + +#[derive(Debug)] +pub struct InterpreterSetup { + pub globals: Py, + pub methods_by_var: BTreeMap>, + pub methods_by_type: BTreeMap>, +} + +pub trait RlmInputFields { + fn rlm_field_names(&self) -> &'static [&'static str]; + + fn rlm_py_fields(&self, py: Python<'_>) -> PyResult)>>; + + fn inject_into_python<'py>( + &self, + py: Python<'py>, + globals: &Bound<'py, PyDict>, + ) -> PyResult<()> { + for (name, obj) in self.rlm_py_fields(py)? { + globals.set_item(name, obj)?; + } + Ok(()) + } +} + +/// Runtime abstraction for REPL-backed RLM execution. +/// +/// V1 ships with a stub implementation in this crate. Another module can provide +/// a concrete PyO3-backed implementation by implementing this trait and wiring it +/// through `RlmBuilder::runtime(...)`. +pub trait RlmRuntime: Send + Sync { + /// Whether this runtime needs sub-LM tools (`llm_query*`) to be installed. + /// Stub runtimes can return `false` so tests can run without sub-LM wiring. 
+ fn requires_sub_lm_tools(&self) -> bool { + true + } + + fn setup_interpreter_globals( + &self, + py: Python<'_>, + input: &S::Input, + submit_handler: &SubmitHandler, + llm_tools: Option<&LlmTools>, + ) -> PyResult + where + S::Input: RlmInputFields; + + fn execute_repl_code( + &self, + py: Python<'_>, + globals: &Py, + code: &str, + max_output_chars: usize, + ) -> Result; + + fn sub_lm_budget_remaining(&self, llm_tools: Option<&LlmTools>) -> usize; +} + +#[derive(Default, Debug, Clone)] +pub struct StubRuntime; + +impl StubRuntime { + pub fn new(_max_llm_calls: usize) -> Self { + Self + } +} + +impl RlmRuntime for StubRuntime { + fn requires_sub_lm_tools(&self) -> bool { + false + } + + fn setup_interpreter_globals( + &self, + py: Python<'_>, + _input: &S::Input, + _submit_handler: &SubmitHandler, + _llm_tools: Option<&LlmTools>, + ) -> PyResult + where + S::Input: RlmInputFields, + { + Ok(InterpreterSetup { + globals: PyDict::new(py).unbind(), + methods_by_var: BTreeMap::new(), + methods_by_type: BTreeMap::new(), + }) + } + + fn execute_repl_code( + &self, + _py: Python<'_>, + _globals: &Py, + _code: &str, + _max_output_chars: usize, + ) -> Result { + Ok(String::new()) + } + + fn sub_lm_budget_remaining(&self, _llm_tools: Option<&LlmTools>) -> usize { + 0 + } +} + +#[derive(Default, Debug, Clone)] +pub struct PyO3Runtime; + +impl RlmRuntime for PyO3Runtime { + fn setup_interpreter_globals( + &self, + py: Python<'_>, + input: &S::Input, + submit_handler: &SubmitHandler, + llm_tools: Option<&LlmTools>, + ) -> PyResult + where + S::Input: RlmInputFields, + { + py_bridge::setup_interpreter_globals::(py, input, submit_handler, llm_tools) + } + + fn execute_repl_code( + &self, + py: Python<'_>, + globals: &Py, + code: &str, + max_output_chars: usize, + ) -> Result { + exec::execute_repl_code(py, globals, code, max_output_chars) + } + + fn sub_lm_budget_remaining(&self, llm_tools: Option<&LlmTools>) -> usize { + llm_tools.map(LlmTools::remaining_calls).unwrap_or(0) + 
} +} + +pub type DynRuntime = Arc>; diff --git a/crates/dspy-rs/src/modules/rlm/submit.rs b/crates/dspy-rs/src/modules/rlm/submit.rs new file mode 100644 index 00000000..ea65616b --- /dev/null +++ b/crates/dspy-rs/src/modules/rlm/submit.rs @@ -0,0 +1,551 @@ +use std::collections::HashSet; +use std::sync::{Arc, Mutex}; + +use bamltype::BamlParseError; +use indexmap::IndexMap; +use pyo3::exceptions::PyException; +use pyo3::prelude::*; +use pyo3::types::{PyAnyMethods, PyDict, PyDictMethods}; + +use crate::{ + BamlValue, ConstraintKind, ConstraintResult, FieldMeta, ResponseCheck, Signature, + SignatureSchema, +}; + +/// Type-erased SUBMIT result used by the outer loop controller. +pub type SubmitResultDyn = Result<(BamlValue, IndexMap), SubmitError>; + +/// Shared storage slot written by SUBMIT and consumed by the RLM loop. +pub type SubmitSlot = Arc>>; + +#[derive(Debug, Clone)] +pub enum SubmitError { + ValidationError { + message: String, + errors: Vec, + }, + AssertionFailed { + label: String, + expression: String, + }, +} + +struct ParsedDyn { + baml_value: BamlValue, + checks: Vec, +} + +type ParseFn = dyn for<'py> Fn(Python<'py>, &Bound<'py, PyDict>) -> Result + + Send + + Sync; + +pyo3::create_exception!( + dspy_rs_rlm, + SubmitTerminated, + PyException, + "Raised to terminate REPL execution after a successful SUBMIT." 
+); + +pub const SUBMIT_STDOUT_ATTR: &str = "__dsrs_stdout__"; + +pub fn is_submit_terminated(err: &PyErr, py: Python<'_>) -> bool { + err.is_instance_of::(py) +} + +pub fn clear_submit_slot(slot: &SubmitSlot) { + set_submit_result(slot, None); +} + +pub fn take_submit_result(slot: &SubmitSlot) -> Option { + slot.lock().expect("submit slot lock poisoned").take() +} + +fn set_submit_result(slot: &SubmitSlot, value: Option) { + *slot.lock().expect("submit slot lock poisoned") = value; +} + +#[pyclass] +#[derive(Clone)] +pub struct SubmitHandler { + parse_fn: Arc, + schema: Arc, + slot: SubmitSlot, + schema_description: String, + output_fields_lm: Vec, + output_fields_set: HashSet, +} + +impl SubmitHandler { + pub fn new(slot: SubmitSlot) -> Self { + let schema = Arc::new(S::schema().clone()); + let schema_description = generate_schema_description(schema.as_ref()); + let output_fields_lm = schema + .output_fields() + .iter() + .map(|field| field.lm_name.to_string()) + .collect::>(); + let output_fields_set = output_fields_lm.iter().cloned().collect::>(); + + let parse_fn: Arc = Arc::new(|py, kwargs| { + let baml_value = super::py_bridge::kwargs_to_baml_value::(py, kwargs)?; + let checks = super::py_bridge::collect_checks_for_output::(&baml_value)?; + Ok(ParsedDyn { baml_value, checks }) + }); + + Self { + parse_fn, + schema, + slot, + schema_description, + output_fields_lm, + output_fields_set, + } + } + + pub fn with_new_slot() -> (Self, SubmitSlot) { + let slot = Arc::new(Mutex::new(None)); + (Self::new::(Arc::clone(&slot)), slot) + } +} + +#[pymethods] +impl SubmitHandler { + #[pyo3(signature = (**kwargs))] + fn __call__(&self, py: Python<'_>, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult { + let kwargs = kwargs.ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err( + "SUBMIT requires keyword arguments. 
Usage: SUBMIT(field1=value1, field2=value2)", + ) + })?; + + let mut unexpected = Vec::new(); + for (key, _) in kwargs.iter() { + let key = key.extract::().map_err(py_err_to_value)?; + if !self.output_fields_set.contains(&key) { + unexpected.push(key); + } + } + unexpected.sort(); + + let mut missing = Vec::new(); + for field in &self.output_fields_lm { + let present = kwargs.contains(field.as_str()).map_err(py_err_to_value)?; + if !present { + missing.push(field.clone()); + } + } + + if !missing.is_empty() || !unexpected.is_empty() { + let usage = format_submit_usage(&self.output_fields_lm); + let mut errors = Vec::new(); + if !missing.is_empty() { + errors.push(format!("missing fields: {:?}", missing)); + } + if !unexpected.is_empty() { + errors.push(format!("unexpected fields: {:?}", unexpected)); + } + errors.push(format!("use SUBMIT({usage})")); + + let message = match (missing.is_empty(), unexpected.is_empty()) { + (false, true) => "Missing output fields".to_string(), + (true, false) => "Unexpected output fields".to_string(), + (false, false) => "Invalid output fields".to_string(), + (true, true) => unreachable!(), + }; + + let user_message = format_submit_error("Validation failed", &errors, None); + set_submit_result( + &self.slot, + Some(Err(SubmitError::ValidationError { message, errors })), + ); + return Ok(user_message); + } + + let parsed_result = (self.parse_fn)(py, kwargs); + + match parsed_result { + Ok(parsed) => { + let ParsedDyn { baml_value, checks } = parsed; + let raw_text = serde_json::to_string(&baml_value) + .unwrap_or_else(|_| "".to_string()); + let metas = build_field_metas(&checks, &raw_text); + set_submit_result(&self.slot, Some(Ok((baml_value, metas)))); + + Err(SubmitTerminated::new_err("SUBMIT accepted")) + } + Err(BamlParseError::ConstraintAssertsFailed { failed }) => { + let failure = failed.first().ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err( + "SUBMIT assertion failed with no details", + ) + })?; + + 
set_submit_result( + &self.slot, + Some(Err(SubmitError::AssertionFailed { + label: failure.name.clone(), + expression: failure.expression.clone(), + })), + ); + + Ok(format_submit_error( + "Assertion failed", + &[format!( + "'{}': {} (please fix and try again)", + failure.name, failure.expression + )], + None, + )) + } + Err(err) => { + let errors = format_parse_errors(kwargs, &self.schema, &err); + set_submit_result( + &self.slot, + Some(Err(SubmitError::ValidationError { + message: err.to_string(), + errors: errors.clone(), + })), + ); + + Ok(format_submit_error( + "Validation failed", + &errors, + if self.schema_description.is_empty() { + None + } else { + Some(self.schema_description.as_str()) + }, + )) + } + } + } + + pub fn schema(&self) -> String { + self.schema_description.clone() + } +} + +fn build_field_metas(checks: &[ResponseCheck], raw_json: &str) -> IndexMap { + let mut metas = IndexMap::new(); + let mut meta = FieldMeta { + raw_text: raw_json.to_string(), + flags: Vec::new(), + checks: Vec::new(), + }; + + for check in checks { + meta.checks.push(ConstraintResult { + label: check.name.clone(), + expression: check.expression.clone(), + passed: check.status == "succeeded", + }); + } + + metas.insert("_root".to_string(), meta); + metas +} + +fn format_parse_errors( + kwargs: &Bound<'_, PyDict>, + schema: &SignatureSchema, + err: &BamlParseError, +) -> Vec { + match err { + BamlParseError::Convert(err) => vec![format_convert_error(kwargs, schema, err)], + BamlParseError::Jsonish(err) => vec![err.to_string()], + BamlParseError::ConstraintAssertsFailed { failed } => failed + .iter() + .map(|check| format!("assertion '{}' failed: {}", check.name, check.expression)) + .collect(), + } +} + +fn format_convert_error( + kwargs: &Bound<'_, PyDict>, + schema: &SignatureSchema, + err: &crate::BamlConvertError, +) -> String { + if err.expected == "field" && err.got == "missing" { + return format!("missing required field: {}", err.path_string()); + } + + let 
expected = err + .message + .strip_prefix("expected ") + .unwrap_or(err.expected) + .trim(); + let expected = to_python_type_name(expected); + let got = to_python_type_name(err.got.as_str()); + + let field_path = err.path_string(); + let value_repr = first_path_value_repr(kwargs, schema, &err.path); + + match value_repr { + Some(value_repr) => format!( + "field '{}' expected {}, got {} {}", + field_path, expected, got, value_repr + ), + None => format!("field '{}' expected {}, got {}", field_path, expected, got), + } +} + +fn to_python_type_name(raw: &str) -> String { + let trimmed = raw.trim(); + let base = trimmed.strip_prefix("BamlValue::").unwrap_or(trimmed); + match base { + "String" => "str".to_string(), + "Int" => "int".to_string(), + "Float" => "float".to_string(), + "Bool" => "bool".to_string(), + "Null" => "None".to_string(), + "List" => "list".to_string(), + "Map" | "Class" => "dict".to_string(), + "Enum" => "enum".to_string(), + "Media" => "media".to_string(), + other => other.to_string(), + } +} + +fn format_submit_error(summary: &str, details: &[String], schema: Option<&str>) -> String { + let mut message = format!("SubmitError: {summary}"); + if !details.is_empty() { + message.push('\n'); + for detail in details { + message.push_str(" - "); + message.push_str(detail); + message.push('\n'); + } + message.pop(); + } + if let Some(schema) = schema { + message.push_str("\n\nExpected schema:\n"); + message.push_str(schema); + } + message +} + +fn first_path_value_repr( + kwargs: &Bound<'_, PyDict>, + schema: &SignatureSchema, + path: &[String], +) -> Option { + let first = path.first()?; + + let lm_name = schema + .output_fields() + .iter() + .find_map(|field| { + if field.rust_name == *first || field.lm_name == first { + Some(field.lm_name) + } else { + None + } + }) + .unwrap_or(first.as_str()); + + let value = kwargs.get_item(lm_name).ok().flatten()?; + value + .repr() + .ok() + .and_then(|repr| repr.extract::().ok()) +} + +fn 
format_submit_usage(fields: &[String]) -> String { + fields + .iter() + .map(|field| format!("{field}={field}")) + .collect::>() + .join(", ") +} + +fn generate_schema_description(schema: &SignatureSchema) -> String { + let fields = schema.output_fields(); + if fields.is_empty() { + return String::new(); + } + + let mut desc = String::new(); + desc.push_str("SUBMIT("); + desc.push_str( + &fields + .iter() + .map(|field| field.lm_name) + .collect::>() + .join(", "), + ); + desc.push_str(") where:\n"); + + for field in fields { + let type_name = crate::core::render_type_name_for_prompt_with( + &field.type_ir, + crate::core::simplify_type_token, + ); + desc.push_str(&format!(" {}: {}", field.lm_name, type_name)); + + if !field.docs.is_empty() { + desc.push_str(&format!(" # {}", field.docs)); + } + desc.push('\n'); + + for constraint in field.constraints { + let kind = match constraint.kind { + ConstraintKind::Check => "check", + ConstraintKind::Assert => "ASSERT", + }; + if constraint.label.is_empty() { + desc.push_str(&format!(" [{kind}] {}\n", constraint.expression)); + } else { + desc.push_str(&format!( + " [{kind}] {}: {}\n", + constraint.label, constraint.expression + )); + } + } + } + + desc.trim_end().to_string() +} + +fn py_err_to_value(err: pyo3::PyErr) -> pyo3::PyErr { + pyo3::exceptions::PyValueError::new_err(err.to_string()) +} + +#[cfg(test)] +mod tests { + use pyo3::types::PyDict; + + use super::*; + use crate::Signature; + + #[derive(Signature, Clone, Debug)] + struct SubmitSig { + #[input] + question: String, + + #[output] + answer: String, + + #[output] + score: f64, + } + + #[derive(Signature, Clone, Debug)] + struct SubmitAssertSig { + #[input] + question: String, + + #[output] + #[assert("this > 0", label = "positive")] + score: i64, + } + + #[test] + fn submit_success_writes_slot_and_raises_terminated() { + Python::attach(|py| { + let (handler, slot) = SubmitHandler::with_new_slot::(); + let kwargs = PyDict::new(py); + kwargs.set_item("answer", 
"ok").expect("set answer"); + kwargs.set_item("score", 0.9).expect("set score"); + + let err = handler + .__call__(py, Some(&kwargs)) + .expect_err("successful submit must raise SubmitTerminated"); + assert!(is_submit_terminated(&err, py)); + + let stored = take_submit_result(&slot).expect("slot must be populated"); + assert!(stored.is_ok()); + }); + } + + #[test] + fn missing_field_returns_validation_error() { + Python::attach(|py| { + let (handler, slot) = SubmitHandler::with_new_slot::(); + let kwargs = PyDict::new(py); + kwargs.set_item("answer", "ok").expect("set answer"); + + let message = handler + .__call__(py, Some(&kwargs)) + .expect("missing field should return recoverable message"); + assert!(message.contains("SubmitError: Validation failed")); + + let stored = take_submit_result(&slot).expect("slot must be populated"); + match stored { + Err(SubmitError::ValidationError { errors, .. }) => { + assert!(errors.iter().any(|err| err.contains("missing fields"))); + } + other => panic!("unexpected stored result: {other:?}"), + } + }); + } + + #[test] + fn type_mismatch_returns_detailed_field_error() { + Python::attach(|py| { + let (handler, slot) = SubmitHandler::with_new_slot::(); + let kwargs = PyDict::new(py); + kwargs.set_item("answer", "ok").expect("set answer"); + kwargs.set_item("score", "oops").expect("set score"); + + let message = handler + .__call__(py, Some(&kwargs)) + .expect("type mismatch should be recoverable"); + assert!(message.contains("field 'score'")); + assert!(message.contains("expected")); + assert!(message.contains("got")); + + let stored = take_submit_result(&slot).expect("slot must be populated"); + assert!(matches!(stored, Err(SubmitError::ValidationError { .. 
}))); + }); + } + + #[test] + fn assertion_failure_is_recorded() { + Python::attach(|py| { + let (handler, slot) = SubmitHandler::with_new_slot::(); + let kwargs = PyDict::new(py); + kwargs.set_item("score", -1).expect("set score"); + + let message = handler + .__call__(py, Some(&kwargs)) + .expect("assertion failure should be recoverable"); + assert!(message.contains("SubmitError: Assertion failed")); + + let stored = take_submit_result(&slot).expect("slot must be populated"); + match stored { + Err(SubmitError::AssertionFailed { label, .. }) => { + assert_eq!(label, "positive"); + } + other => panic!("unexpected stored result: {other:?}"), + } + }); + } + + #[test] + fn clear_submit_slot_removes_previous_value() { + let (handler, slot) = SubmitHandler::with_new_slot::(); + drop(handler); + + *slot.lock().expect("lock") = Some(Err(SubmitError::ValidationError { + message: "x".to_string(), + errors: vec!["y".to_string()], + })); + + clear_submit_slot(&slot); + assert!(slot.lock().expect("lock").is_none()); + } + + #[test] + fn python_type_name_mapping_covers_baml_tokens() { + assert_eq!(to_python_type_name("BamlValue::String"), "str"); + assert_eq!(to_python_type_name("BamlValue::Int"), "int"); + assert_eq!(to_python_type_name("BamlValue::Float"), "float"); + assert_eq!(to_python_type_name("BamlValue::Bool"), "bool"); + assert_eq!(to_python_type_name("BamlValue::Null"), "None"); + assert_eq!(to_python_type_name("BamlValue::List"), "list"); + assert_eq!(to_python_type_name("BamlValue::Map"), "dict"); + assert_eq!(to_python_type_name("BamlValue::Class"), "dict"); + assert_eq!(to_python_type_name("BamlValue::Enum"), "enum"); + assert_eq!(to_python_type_name("BamlValue::Media"), "media"); + } +} diff --git a/crates/dspy-rs/src/modules/rlm/tools.rs b/crates/dspy-rs/src/modules/rlm/tools.rs new file mode 100644 index 00000000..4d6b3471 --- /dev/null +++ b/crates/dspy-rs/src/modules/rlm/tools.rs @@ -0,0 +1,451 @@ +use std::future::Future; +use std::sync::Arc; +use 
std::sync::atomic::{AtomicUsize, Ordering}; + +use async_trait::async_trait; +use pyo3::exceptions::{PyRuntimeError, PyValueError}; +use pyo3::prelude::*; +use pyo3::types::PyModule; +use tokio::runtime::{Handle, RuntimeFlavor}; + +use crate::LM; +use crate::core::lm::{Chat, Message, ToolLoopMode}; + +#[async_trait] +pub trait LlmQuery: Send + Sync { + async fn query(&self, prompt: &str) -> anyhow::Result; +} + +#[async_trait] +impl LlmQuery for LM { + async fn query(&self, prompt: &str) -> anyhow::Result { + let messages = Chat::new(vec![Message::user(prompt)]); + let response = self + .call(messages, Vec::new(), ToolLoopMode::CallerManaged) + .await?; + Ok(response.output.text_content()) + } +} + +#[pyclass] +#[derive(Clone)] +pub struct LlmTools { + lm: Arc, + pub max_llm_calls: usize, + budget_remaining: Arc, + handle: Handle, +} + +impl LlmTools { + pub fn new( + lm: Arc, + budget_remaining: Arc, + max_llm_calls: usize, + handle: Handle, + ) -> Self { + Self { + lm, + max_llm_calls, + budget_remaining, + handle, + } + } + + pub fn with_budget(lm: Arc, max_llm_calls: usize, handle: Handle) -> Self { + Self::new( + lm, + Arc::new(AtomicUsize::new(max_llm_calls)), + max_llm_calls, + handle, + ) + } + + #[cfg(test)] + fn call_count(&self) -> usize { + self.max_llm_calls + .saturating_sub(self.budget_remaining.load(Ordering::SeqCst)) + } + + pub fn remaining_calls(&self) -> usize { + self.budget_remaining.load(Ordering::SeqCst) + } + + fn reserve_calls(&self, count: usize) -> PyResult<()> { + loop { + let current = self.budget_remaining.load(Ordering::SeqCst); + if current < count { + return Err(PyRuntimeError::new_err(format!( + "[Error] RuntimeError: LLM call budget exhausted: requested {count}, remaining {current}, max {}. 
This is retryable after reducing llm_query usage.", + self.max_llm_calls + ))); + } + + if self + .budget_remaining + .compare_exchange(current, current - count, Ordering::SeqCst, Ordering::SeqCst) + .is_ok() + { + return Ok(()); + } + } + } + + fn reserve_calls_for_batch(&self, requested: usize) -> usize { + loop { + let current = self.budget_remaining.load(Ordering::SeqCst); + let to_execute = current.min(requested); + if self + .budget_remaining + .compare_exchange( + current, + current.saturating_sub(to_execute), + Ordering::SeqCst, + Ordering::SeqCst, + ) + .is_ok() + { + return to_execute; + } + } + } + + fn emit_budget_warning(&self, executed: usize, requested: usize) { + if executed >= requested { + return; + } + let remaining = self.remaining_calls(); + let warning = format!( + "⚠ Budget: executed first {executed} of {requested} requested queries ({remaining} remaining of {} max). \ +results[i] aligns to prompts[i] for i < {executed}; skipped prompts[{executed}..{requested}].", + self.max_llm_calls, + ); + Python::attach(|py| { + if let Ok(builtins) = PyModule::import(py, "builtins") + && let Ok(print_fn) = builtins.getattr("print") + { + let _ = print_fn.call1((warning,)); + } + }); + } + + fn ensure_prompt(prompt: &str) -> PyResult<()> { + if prompt.trim().is_empty() { + return Err(PyValueError::new_err( + "[Error] ValueError: prompt cannot be empty", + )); + } + Ok(()) + } + + fn block_with_runtime(&self, fut: F) -> PyResult + where + F: Future, + { + let current_handle = Handle::try_current().map_err(|err| { + Self::runtime_error(format!("an active Tokio runtime is required: {err}")) + })?; + if current_handle.runtime_flavor() == RuntimeFlavor::CurrentThread { + return Err(Self::runtime_error( + "llm_query requires a multi-thread Tokio runtime; current-thread runtime is not supported", + )); + } + + std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + tokio::task::block_in_place(|| self.handle.block_on(fut)) + })) + .map_err(|_| { + 
Self::runtime_error( + "failed to block in the current Tokio runtime; use a multi-thread runtime", + ) + }) + } + + fn runtime_error(err: impl std::fmt::Display) -> PyErr { + PyRuntimeError::new_err(format!("[Error] RuntimeError: {err}")) + } +} + +#[pymethods] +impl LlmTools { + fn llm_query(&self, prompt: String) -> PyResult { + Self::ensure_prompt(&prompt)?; + self.reserve_calls(1)?; + + let response = self.block_with_runtime(self.lm.query(&prompt))?; + let response = response.map_err(Self::runtime_error)?; + + Ok(response) + } + + fn llm_query_batched(&self, prompts: Vec) -> PyResult> { + if prompts.is_empty() { + return Ok(Vec::new()); + } + + for prompt in &prompts { + Self::ensure_prompt(prompt)?; + } + + let requested = prompts.len(); + let executable = self.reserve_calls_for_batch(requested); + if executable == 0 { + self.emit_budget_warning(0, requested); + return Ok(Vec::new()); + } + self.emit_budget_warning(executable, requested); + + let responses = self.block_with_runtime(async { + let futures = prompts + .iter() + .take(executable) + .map(|prompt| self.lm.query(prompt)); + futures::future::join_all(futures).await + })?; + + let mut results = Vec::with_capacity(responses.len()); + for response in responses { + match response { + Ok(text) => results.push(text), + Err(err) => return Err(Self::runtime_error(err)), + } + } + Ok(results) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + use std::sync::Mutex; + + use super::*; + + #[derive(Default)] + struct MockLm { + calls: Mutex>, + fail_on: Mutex>, + } + + #[async_trait] + impl LlmQuery for MockLm { + async fn query(&self, prompt: &str) -> anyhow::Result { + self.calls + .lock() + .expect("calls mutex poisoned") + .push(prompt.to_string()); + + if self + .fail_on + .lock() + .expect("fail_on mutex poisoned") + .contains(prompt) + { + anyhow::bail!("mock failure for {prompt}"); + } + + Ok(format!("answer:{prompt}")) + } + } + + #[test] + fn 
llm_query_consumes_budget_and_returns_text() { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .expect("runtime"); + + rt.block_on(async { + let lm = Arc::new(MockLm::default()); + let tools = LlmTools::with_budget(lm.clone(), 2, Handle::current()); + + let first = tools.llm_query("hello".to_string()).expect("first call"); + assert_eq!(first, "answer:hello"); + assert_eq!(tools.call_count(), 1); + + let second = tools.llm_query("world".to_string()).expect("second call"); + assert_eq!(second, "answer:world"); + assert_eq!(tools.call_count(), 2); + + let calls = lm.calls.lock().expect("calls lock").clone(); + assert_eq!(calls, vec!["hello".to_string(), "world".to_string()]); + }); + } + + #[test] + fn budget_exhaustion_returns_retryable_runtime_error() { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .expect("runtime"); + + rt.block_on(async { + let tools = LlmTools::with_budget(Arc::new(MockLm::default()), 1, Handle::current()); + let _ = tools.llm_query("one".to_string()).expect("first call"); + + let err = tools + .llm_query("two".to_string()) + .expect_err("budget should be exhausted"); + assert!(err.to_string().contains("budget exhausted")); + assert!(err.to_string().contains("retryable")); + }); + } + + #[test] + fn llm_query_batched_runs_all_prompts() { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .expect("runtime"); + + rt.block_on(async { + let lm = Arc::new(MockLm::default()); + let tools = LlmTools::with_budget(lm.clone(), 5, Handle::current()); + + let responses = tools + .llm_query_batched(vec!["a".to_string(), "b".to_string(), "c".to_string()]) + .expect("batched call"); + assert_eq!(responses, vec!["answer:a", "answer:b", "answer:c"]); + assert_eq!(tools.call_count(), 3); + + let mut calls = lm.calls.lock().expect("calls lock").clone(); + calls.sort(); + assert_eq!(calls, vec!["a", "b", "c"]); + }); + } + + #[test] + fn 
llm_query_batched_propagates_runtime_errors() { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .expect("runtime"); + + rt.block_on(async { + let lm = Arc::new(MockLm::default()); + lm.fail_on + .lock() + .expect("fail_on lock") + .insert("bad".to_string()); + + let tools = LlmTools::with_budget(lm, 3, Handle::current()); + let err = tools + .llm_query_batched(vec!["ok".to_string(), "bad".to_string()]) + .expect_err("second prompt should fail"); + + assert!(err.to_string().contains("mock failure for bad")); + }); + } + + #[test] + fn llm_query_batched_executes_partial_batch_when_budget_is_short() { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .expect("runtime"); + + rt.block_on(async { + let lm = Arc::new(MockLm::default()); + let tools = LlmTools::with_budget(lm.clone(), 2, Handle::current()); + + let responses = tools + .llm_query_batched(vec![ + "one".to_string(), + "two".to_string(), + "three".to_string(), + ]) + .expect("partial batch should succeed"); + assert_eq!(responses, vec!["answer:one", "answer:two"]); + assert_eq!(tools.remaining_calls(), 0); + + let calls = lm.calls.lock().expect("calls lock").clone(); + assert_eq!(calls, vec!["one".to_string(), "two".to_string()]); + }); + } + + #[test] + fn llm_query_batched_returns_empty_when_budget_is_zero() { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .expect("runtime"); + + rt.block_on(async { + let lm = Arc::new(MockLm::default()); + let tools = LlmTools::with_budget(lm.clone(), 1, Handle::current()); + let _ = tools.llm_query("one".to_string()).expect("first call"); + assert_eq!(tools.remaining_calls(), 0); + + let responses = tools + .llm_query_batched(vec!["two".to_string(), "three".to_string()]) + .expect("zero-budget batch should not error"); + assert!(responses.is_empty()); + + let calls = lm.calls.lock().expect("calls lock").clone(); + assert_eq!(calls, vec!["one".to_string()]); + }); + } 
+ + #[test] + fn shared_budget_is_enforced_across_single_and_batched_calls() { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .expect("runtime"); + + rt.block_on(async { + let tools = LlmTools::with_budget(Arc::new(MockLm::default()), 3, Handle::current()); + + let first = tools.llm_query("one".to_string()).expect("first call"); + assert_eq!(first, "answer:one"); + assert_eq!(tools.remaining_calls(), 2); + + let responses = tools + .llm_query_batched(vec!["two".to_string(), "three".to_string()]) + .expect("batched call"); + assert_eq!(responses, vec!["answer:two", "answer:three"]); + assert_eq!(tools.remaining_calls(), 0); + + let err = tools + .llm_query("four".to_string()) + .expect_err("budget should be exhausted"); + assert!(err.to_string().contains("budget exhausted")); + }); + } + + #[test] + fn empty_batched_call_returns_immediately_without_consuming_budget() { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .expect("runtime"); + + rt.block_on(async { + let tools = LlmTools::with_budget(Arc::new(MockLm::default()), 2, Handle::current()); + + let responses = tools + .llm_query_batched(Vec::new()) + .expect("empty batch should be valid"); + assert!(responses.is_empty()); + assert_eq!(tools.remaining_calls(), 2); + }); + } + + #[test] + fn current_thread_runtime_returns_clear_error_instead_of_panicking() { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("runtime"); + + rt.block_on(async { + let tools = LlmTools::with_budget(Arc::new(MockLm::default()), 1, Handle::current()); + let err = tools + .llm_query("hello".to_string()) + .expect_err("current-thread runtime should fail gracefully"); + + let message = err.to_string(); + assert!(message.contains("multi-thread Tokio runtime")); + assert!(message.contains("current-thread")); + }); + } +} diff --git a/crates/dspy-rs/src/optimizer/copro.rs b/crates/dspy-rs/src/optimizer/copro.rs index 
736f0f87..0b7925d6 100644 --- a/crates/dspy-rs/src/optimizer/copro.rs +++ b/crates/dspy-rs/src/optimizer/copro.rs @@ -216,7 +216,7 @@ mod tests { use super::*; use crate::evaluate::{MetricOutcome, TypedMetric}; - use crate::{CallMetadata, Predict, PredictError, Predicted, Signature}; + use crate::{CallMetadata, Chat, Predict, PredictError, Predicted, Signature}; #[derive(Signature, Clone, Debug)] struct CoproStateSig { @@ -246,6 +246,7 @@ mod tests { answer: input.prompt, }, CallMetadata::default(), + Chat::new(vec![]), )) } } diff --git a/crates/dspy-rs/src/optimizer/gepa.rs b/crates/dspy-rs/src/optimizer/gepa.rs index e4c799c6..90ec6ecc 100644 --- a/crates/dspy-rs/src/optimizer/gepa.rs +++ b/crates/dspy-rs/src/optimizer/gepa.rs @@ -506,7 +506,7 @@ mod tests { use super::*; use crate::evaluate::{MetricOutcome, TypedMetric}; - use crate::{CallMetadata, Predict, PredictError, Predicted, Signature}; + use crate::{CallMetadata, Chat, Predict, PredictError, Predicted, Signature}; #[derive(Signature, Clone, Debug)] struct GepaStateSig { @@ -536,6 +536,7 @@ mod tests { answer: input.prompt, }, CallMetadata::default(), + Chat::new(vec![]), )) } } diff --git a/crates/dspy-rs/src/optimizer/mipro.rs b/crates/dspy-rs/src/optimizer/mipro.rs index 6f2b4136..d0023899 100644 --- a/crates/dspy-rs/src/optimizer/mipro.rs +++ b/crates/dspy-rs/src/optimizer/mipro.rs @@ -184,7 +184,7 @@ impl MIPROv2 { let input = example.input.clone(); let predicted = module.call(input).await.map_err(|err| anyhow!("{err}"))?; let outcome = metric.evaluate(example, &predicted).await?; - let (output, _) = predicted.into_parts(); + let (output, _, _) = predicted.into_parts(); traces.push(Trace::new( example.input.clone(), output.to_baml_value(), @@ -417,7 +417,7 @@ mod tests { use super::*; use crate::evaluate::{MetricOutcome, TypedMetric}; - use crate::{CallMetadata, Predict, PredictError, Predicted, Signature}; + use crate::{CallMetadata, Chat, Predict, PredictError, Predicted, Signature}; 
#[derive(Signature, Clone, Debug)] struct MiproStateSig { @@ -447,6 +447,7 @@ mod tests { answer: input.prompt, }, CallMetadata::default(), + Chat::new(vec![]), )) } } diff --git a/crates/dspy-rs/src/predictors/predict.rs b/crates/dspy-rs/src/predictors/predict.rs index b45a0e69..f1ecf10a 100644 --- a/crates/dspy-rs/src/predictors/predict.rs +++ b/crates/dspy-rs/src/predictors/predict.rs @@ -13,7 +13,7 @@ use crate::core::{DynPredictor, Module, PredictAccessorFns, PredictState, Signat use crate::data::example::Example as RawExample; use crate::{ BamlType, BamlValue, CallMetadata, Chat, ChatAdapter, GLOBAL_SETTINGS, LmError, LmUsage, - PredictError, Predicted, Prediction, SignatureSchema, + PredictError, Predicted, Prediction, Role, SignatureSchema, ToolLoopMode, }; /// A typed input/output pair for few-shot prompting. @@ -127,6 +127,8 @@ pub struct Predict { demos: Vec>, instruction_override: Option, #[facet(skip, opaque)] + adapter: Option, + #[facet(skip, opaque)] _marker: PhantomData, } @@ -137,6 +139,7 @@ impl Predict { tools: Vec::new(), demos: Vec::new(), instruction_override: None, + adapter: None, _marker: PhantomData, } } @@ -146,10 +149,15 @@ impl Predict { PredictBuilder::new() } + /// Overrides the adapter used for prompt composition and response parsing. + pub fn adapter(mut self, adapter: ChatAdapter) -> Self { + self.adapter = Some(adapter); + self + } + /// Calls the LM with this predictor's signature, demos, and tools. /// - /// Delegates to [`forward`](Predict::forward). Both exist for symmetry with the - /// [`Module`] trait; `call` is what you use, `forward` is the implementation. + /// Convenience wrapper around [`forward`](Predict::forward) with `history = None`. #[tracing::instrument( name = "dsrs.predict.call", level = "debug", @@ -167,35 +175,49 @@ impl Predict { S::Input: BamlType, S::Output: BamlType, { - self.forward(input).await + self.forward(input, None).await } - /// Builds the prompt, calls the LM, and parses the response. 
- /// - /// The full pipeline: - /// 1. Format system message from the signature's schema and instruction override - /// 2. Format demo examples as user/assistant exchanges - /// 3. Format the input as the final user message - /// 4. Call the LM (with any tools attached) - /// 5. Parse the response into `S::Output` via the `[[ ## field ## ]]` protocol - /// 6. Record a trace node if inside a [`trace()`](crate::trace::trace) scope + /// Canonical typed predict path. /// - /// # Errors + /// - `history = None` starts a new conversation (system + demos + input). + /// - `history = Some(chat)` continues a prior conversation by appending the + /// typed `input` as the next user turn. /// - /// - [`PredictError::Lm`] if the LM call fails (network, rate limit, timeout) - /// - [`PredictError::Parse`] if the response can't be parsed into the output fields - pub async fn forward(&self, input: S::Input) -> Result, PredictError> + /// Returns the parsed prediction. Updated chat history is available via + /// [`Predicted::chat`](crate::Predicted::chat). 
+ pub async fn forward( + &self, + input: S::Input, + history: Option, + ) -> Result, PredictError> where S::Input: BamlType, S::Output: BamlType, { - let lm = { - let guard = GLOBAL_SETTINGS.read().unwrap(); - let settings = guard.as_ref().unwrap(); - Arc::clone(&settings.lm) - }; + let chat = self.compose_chat(&input, history)?; + self.execute_chat(chat).await + } + + #[allow(clippy::result_large_err)] + fn compose_chat(&self, input: &S::Input, history: Option) -> Result + where + S::Input: BamlType, + { + let chat_adapter = self.adapter.clone().unwrap_or_default(); + let user = chat_adapter.format_user_message_typed::(input); + trace!( + user_len = user.len(), + continuing = history.is_some(), + "typed input formatted" + ); + + if let Some(mut chat) = history { + chat.push(Role::User, &user); + trace!(message_count = chat.len(), "chat continued"); + return Ok(chat); + } - let chat_adapter = ChatAdapter; let system = match chat_adapter .format_system_message_typed_with_instruction::(self.instruction_override.as_deref()) { @@ -211,25 +233,37 @@ impl Predict { } }; - let user = chat_adapter.format_user_message_typed::(&input); trace!( system_len = system.len(), user_len = user.len(), - "typed prompt formatted" + "typed prompt initialized" ); let mut chat = Chat::new(vec![]); - chat.push("system", &system); + chat.push(Role::System, &system); for demo in &self.demos { let demo_user = chat_adapter.format_user_message_typed::(&demo.input); let demo_assistant = chat_adapter.format_assistant_message_typed::(&demo.output); - chat.push("user", &demo_user); - chat.push("assistant", &demo_assistant); + chat.push(Role::User, &demo_user); + chat.push(Role::Assistant, &demo_assistant); } - chat.push("user", &user); + chat.push(Role::User, &user); trace!(message_count = chat.len(), "chat constructed"); + Ok(chat) + } - let response = match lm.call(chat, self.tools.clone()).await { + async fn execute_chat(&self, chat: Chat) -> Result, PredictError> + where + S::Input: BamlType, 
+ S::Output: BamlType, + { + let lm = { + let guard = GLOBAL_SETTINGS.read().unwrap(); + let settings = guard.as_ref().unwrap(); + Arc::clone(&settings.lm) + }; + + let response = match lm.call(chat, self.tools.clone(), ToolLoopMode::Auto).await { Ok(response) => response, Err(err) => { return Err(PredictError::Lm { @@ -249,6 +283,14 @@ impl Predict { "lm response received" ); + let crate::core::lm::LMResponse { + output, + usage, + chat, + tool_calls, + tool_executions, + } = response; + let node_id = if crate::trace::is_tracing() { crate::trace::record_node( crate::trace::NodeType::Predict { @@ -261,27 +303,28 @@ impl Predict { None }; - let raw_response = response.output.content().to_string(); - let lm_usage = response.usage.clone(); + let chat_adapter = self.adapter.clone().unwrap_or_default(); + let raw_response = output.content().to_string(); + let lm_usage = usage.clone(); - let (typed_output, field_metas) = - match chat_adapter.parse_response_typed::(&response.output) { - Ok(parsed) => parsed, - Err(err) => { - let failed_fields = err.fields(); - debug!( - failed_fields = failed_fields.len(), - fields = ?failed_fields, - raw_response_len = raw_response.len(), - "typed parse failed" - ); - return Err(PredictError::Parse { - source: err, - raw_response, - lm_usage, - }); - } - }; + let (typed_output, field_metas) = match chat_adapter.parse_response_typed::(&output) { + Ok(parsed) => parsed, + Err(err) => { + let failed_fields = err.fields(); + debug!( + failed_fields = failed_fields.len(), + fields = ?failed_fields, + raw_response_len = raw_response.len(), + "typed parse failed" + ); + return Err(PredictError::Parse { + source: err, + raw_response, + lm_usage, + chat: chat.clone(), + }); + } + }; let checks_total = field_metas .values() @@ -316,13 +359,13 @@ impl Predict { let metadata = CallMetadata::new( raw_response, lm_usage, - response.tool_calls, - response.tool_executions, + tool_calls, + tool_executions, node_id, field_metas, ); - 
Ok(Predicted::new(typed_output, metadata)) + Ok(Predicted::new(typed_output, metadata, chat)) } } @@ -346,6 +389,7 @@ pub struct PredictBuilder { tools: Vec>, demos: Vec>, instruction_override: Option, + adapter: Option, _marker: PhantomData, } @@ -355,6 +399,7 @@ impl PredictBuilder { tools: Vec::new(), demos: Vec::new(), instruction_override: None, + adapter: None, _marker: PhantomData, } } @@ -389,12 +434,19 @@ impl PredictBuilder { self } + /// Overrides the adapter used for prompt composition and parsing. + pub fn adapter(mut self, adapter: ChatAdapter) -> Self { + self.adapter = Some(adapter); + self + } + /// Builds the [`Predict`]. pub fn build(self) -> Predict { Predict { tools: self.tools, demos: self.demos, instruction_override: self.instruction_override, + adapter: self.adapter, _marker: PhantomData, } } @@ -537,7 +589,7 @@ where ) )] async fn forward(&self, input: S::Input) -> Result, PredictError> { - Predict::forward(self, input).await + Predict::forward(self, input, None).await } } diff --git a/crates/dspy-rs/tests/test_adapter_dialect_passthrough.rs b/crates/dspy-rs/tests/test_adapter_dialect_passthrough.rs new file mode 100644 index 00000000..3a1c603b --- /dev/null +++ b/crates/dspy-rs/tests/test_adapter_dialect_passthrough.rs @@ -0,0 +1,154 @@ +use dspy_rs::{ + ChatAdapter, LM, LMClient, ParseError, Predict, PredictError, Signature, TestCompletionModel, + configure, +}; +use rig::completion::AssistantContent; +use rig::message::Text; +use std::sync::LazyLock; +use tokio::sync::Mutex; + +static SETTINGS_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); + +fn text_response(text: impl Into) -> AssistantContent { + AssistantContent::Text(Text { text: text.into() }) +} + +fn structured_response(fields: &[(&str, &str)]) -> String { + let mut response = String::new(); + for (name, value) in fields { + response.push_str(&format!("[[ ## {name} ## ]]\n{value}\n\n")); + } + response.push_str("[[ ## completed ## ]]\n"); + response +} + +async fn 
configure_test_lm(responses: Vec) { + let client = TestCompletionModel::new(responses.into_iter().map(text_response)); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .build(), + ) + .await + .unwrap() + .with_client(LMClient::Test(client)) + .await + .unwrap(); + configure(lm, ChatAdapter::new()); +} + +#[derive(Signature, Clone, Debug, PartialEq)] +/// Generate executable Python code for the task. +struct RlmActionLike { + #[input] + task: String, + + #[output] + code: String, +} + +#[derive(Signature, Clone, Debug, PartialEq)] +/// Answer the question. +struct QaLike { + #[input] + question: String, + + #[output] + answer: String, +} + +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn passthrough_adapter_maps_entire_response_to_code() { + let _lock = SETTINGS_LOCK.lock().await; + configure_test_lm(vec![r#"print("ok")"#.to_string()]).await; + + let predict = Predict::::new().adapter(ChatAdapter::passthrough()); + let result = predict + .call(RlmActionLikeInput { + task: "print ok".to_string(), + }) + .await + .expect("passthrough parse should succeed"); + + assert_eq!(result.into_inner().code, r#"print("ok")"#); +} + +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn passthrough_adapter_extracts_fenced_code() { + let _lock = SETTINGS_LOCK.lock().await; + configure_test_lm(vec!["```python\nprint('hello')\n```\n".to_string()]).await; + + let predict = Predict::::new().adapter(ChatAdapter::passthrough()); + let result = predict + .call(RlmActionLikeInput { + task: "say hello".to_string(), + }) + .await + .expect("fenced passthrough parse should succeed"); + + assert_eq!(result.into_inner().code, "print('hello')"); +} + +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn passthrough_whitespace_response_surfaces_parse_error_with_chat() { + let 
_lock = SETTINGS_LOCK.lock().await; + configure_test_lm(vec![" \n\t".to_string()]).await; + + let predict = Predict::::new().adapter(ChatAdapter::passthrough()); + let err = predict + .call(RlmActionLikeInput { + task: "do something".to_string(), + }) + .await + .expect_err("whitespace passthrough response should fail parse"); + + match err { + PredictError::Parse { + source: ParseError::ExtractionFailed { .. }, + raw_response, + chat, + .. + } => { + assert!(raw_response.trim().is_empty()); + assert!( + !chat.is_empty(), + "parse error should carry conversation chat for recovery" + ); + } + other => panic!("unexpected error variant: {other:?}"), + } +} + +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn per_predict_adapter_selection_allows_mixed_dialects() { + let _lock = SETTINGS_LOCK.lock().await; + configure_test_lm(vec![ + "print(2 + 2)".to_string(), + structured_response(&[("answer", "4")]), + ]) + .await; + + let action_predict = Predict::::new().adapter(ChatAdapter::passthrough()); + let extract_predict = Predict::::new(); + + let action = action_predict + .call(RlmActionLikeInput { + task: "math".to_string(), + }) + .await + .expect("passthrough action parse should succeed"); + assert_eq!(action.code, "print(2 + 2)"); + + let extract = extract_predict + .call(QaLikeInput { + question: "2 + 2".to_string(), + }) + .await + .expect("default chat parse should succeed"); + assert_eq!(extract.answer, "4"); +} diff --git a/crates/dspy-rs/tests/test_adapters.rs b/crates/dspy-rs/tests/test_adapters.rs index 65ee7279..46e46de7 100644 --- a/crates/dspy-rs/tests/test_adapters.rs +++ b/crates/dspy-rs/tests/test_adapters.rs @@ -51,7 +51,7 @@ struct DeepFlattenSig { #[test] fn chat_adapter_formats_typed_system_prompt() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let system = adapter .format_system_message_typed::() .expect("system prompt should format"); @@ -65,7 +65,7 @@ fn 
chat_adapter_formats_typed_system_prompt() { #[test] fn chat_adapter_formats_user_and_assistant_messages() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let user = adapter.format_user_message_typed::(&BasicSignatureInput { problem: "What is the capital of France?".to_string(), @@ -87,7 +87,7 @@ fn chat_adapter_formats_user_and_assistant_messages() { #[test] fn chat_adapter_parses_typed_response() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let response = Message::assistant("[[ ## answer ## ]]\nParis\n\n[[ ## completed ## ]]"); let (output, field_meta) = adapter @@ -114,7 +114,7 @@ fn parse_sections_accepts_non_word_field_names() { #[test] fn chat_adapter_formats_user_messages_with_multi_level_flatten_paths() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let user = adapter.format_user_message_typed::(&DeepFlattenSigInput { question: "What should we answer?".to_string(), middle: FlattenMiddleSigInput { diff --git a/crates/dspy-rs/tests/test_bamltype_docs_contract.rs b/crates/dspy-rs/tests/test_bamltype_docs_contract.rs index f1216b73..73510f9f 100644 --- a/crates/dspy-rs/tests/test_bamltype_docs_contract.rs +++ b/crates/dspy-rs/tests/test_bamltype_docs_contract.rs @@ -63,7 +63,7 @@ struct DocsTypeEffectsSig { } fn system_message() -> String { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); adapter .format_system_message_typed::() .expect("system message") diff --git a/crates/dspy-rs/tests/test_call_outcome.rs b/crates/dspy-rs/tests/test_call_outcome.rs index ee8183c1..c9d0a7bb 100644 --- a/crates/dspy-rs/tests/test_call_outcome.rs +++ b/crates/dspy-rs/tests/test_call_outcome.rs @@ -1,5 +1,5 @@ use dspy_rs::{ - CallMetadata, ConstraintResult, FieldMeta, LmUsage, ParseError, PredictError, Predicted, + CallMetadata, Chat, ConstraintResult, FieldMeta, LmUsage, ParseError, PredictError, Predicted, }; use indexmap::IndexMap; @@ -17,6 +17,7 @@ fn parse_error_preserves_raw_response_and_usage() 
{ }, raw_response: "raw response".to_string(), lm_usage: usage.clone(), + chat: Chat::new(vec![]), }; match err { @@ -24,6 +25,7 @@ fn parse_error_preserves_raw_response_and_usage() { source: ParseError::MissingField { field, .. }, raw_response, lm_usage, + .. } => { assert_eq!(field, "answer"); assert_eq!(raw_response, "raw response"); @@ -60,7 +62,7 @@ fn predicted_exposes_field_metadata() { field_meta, ); - let predicted = Predicted::new("Paris".to_string(), metadata); + let predicted = Predicted::new("Paris".to_string(), metadata, Chat::new(vec![])); assert_eq!(predicted.metadata().field_raw("answer"), Some("Paris")); assert!(!predicted.metadata().has_failed_checks()); diff --git a/crates/dspy-rs/tests/test_caller_managed_conversation.rs b/crates/dspy-rs/tests/test_caller_managed_conversation.rs new file mode 100644 index 00000000..c059c64a --- /dev/null +++ b/crates/dspy-rs/tests/test_caller_managed_conversation.rs @@ -0,0 +1,216 @@ +//! CallerManaged + tools + conversation flow test. +//! +//! This is the RLM critical path: the caller controls tool execution and +//! manages the conversation loop, not the LM layer's auto tool loop. 
+ +use dspy_rs::{ + ChatAdapter, LM, LMClient, Message, Predict, Role, Signature, TestCompletionModel, + ToolLoopMode, configure, +}; +use rig::completion::AssistantContent; +use rig::message::{Text, ToolCall, ToolFunction}; +use std::sync::LazyLock; +use tokio::sync::Mutex; + +static SETTINGS_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); + +fn response_with_fields(fields: &[(&str, &str)]) -> String { + let mut response = String::new(); + for (name, value) in fields { + response.push_str(&format!("[[ ## {name} ## ]]\n{value}\n\n")); + } + response.push_str("[[ ## completed ## ]]\n"); + response +} + +fn text_response(text: impl Into) -> AssistantContent { + AssistantContent::Text(Text { text: text.into() }) +} + +async fn build_test_lm(responses: Vec) -> (LM, TestCompletionModel) { + let client = TestCompletionModel::new(responses); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .build(), + ) + .await + .unwrap() + .with_client(LMClient::Test(client.clone())) + .await + .unwrap(); + (lm, client) +} + +#[derive(Signature, Clone, Debug, PartialEq)] +/// Code execution signature for RLM-style interaction. +struct CodeExec { + #[input] + prompt: String, + + #[output] + result: String, +} + +/// The full RLM-style loop: +/// 1. Predict builds initial chat → calls LM → model requests a tool call +/// 2. CallerManaged mode: LM returns the tool call without executing it +/// 3. Caller manually executes the tool, then calls Predict forward with prior chat history +/// 4. LM returns the final text answer +/// +/// This is the exact pattern RLM will use for Python REPL interaction. 
+#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn caller_managed_tool_loop_with_conversation() { + let _lock = SETTINGS_LOCK.lock().await; + + // Response 1: model wants to call a tool (returned as text since TestCompletionModel + // only supports single-content responses via AssistantContent) + let tool_call_response = + text_response("[[ ## result ## ]]\nNeed to execute code first\n\n[[ ## completed ## ]]\n"); + // Response 2: after seeing tool result, model gives final answer + let final_response = text_response(response_with_fields(&[("result", "42")])); + + let (lm, _client) = build_test_lm(vec![tool_call_response, final_response]).await; + configure(lm, ChatAdapter::new()); + + let predict = Predict::::new(); + let input = CodeExecInput { + prompt: "Calculate 6 * 7".to_string(), + }; + + // Turn 1 + let first_result = predict + .forward(input, None) + .await + .expect("first turn forward should succeed"); + let chat = first_result.chat().clone(); + assert_eq!( + first_result.into_inner().result, + "Need to execute code first" + ); + + // Turn 2: continue with prior chat and typed follow-up + let follow_up = CodeExecInput { + prompt: "Tool output: 42".to_string(), + }; + let second_result = predict + .forward(follow_up, Some(chat)) + .await + .expect("second turn forward should succeed"); + let final_chat = second_result.chat().clone(); + assert_eq!(second_result.into_inner().result, "42"); + + // Verify chat grew across turns + assert!( + final_chat.len() >= 5, + "chat should have system + user + asst + user + asst, got {}", + final_chat.len() + ); + + // Verify turn ordering + assert_eq!(final_chat.messages[0].role, Role::System); + assert_eq!(final_chat.messages[1].role, Role::User); + assert_eq!(final_chat.messages[2].role, Role::Assistant); + assert_eq!(final_chat.messages[3].role, Role::User); // caller's tool result + assert_eq!(final_chat.messages[4].role, Role::Assistant); // final answer +} + +/// Tests 
the LM-level CallerManaged mode directly: when a tool call is requested +/// with CallerManaged mode, the LM returns the tool calls without executing them +/// and the caller controls what happens next. +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn lm_caller_managed_returns_tool_calls_in_chat_history() { + let _lock = SETTINGS_LOCK.lock().await; + + // Model responds with a tool call + let tool_call_content = AssistantContent::ToolCall(ToolCall::new( + "tc-1".to_string(), + ToolFunction { + name: "python_repl".to_string(), + arguments: serde_json::json!({"code": "print(6 * 7)"}), + }, + )); + + let (lm, _client) = build_test_lm(vec![tool_call_content]).await; + + let chat = dspy_rs::Chat::new(vec![Message::user("Run some code")]); + let response = lm + .call(chat, vec![], ToolLoopMode::CallerManaged) + .await + .expect("caller-managed call should succeed"); + + // Tool calls returned but NOT executed + assert_eq!(response.tool_calls.len(), 1); + assert_eq!(response.tool_calls[0].function.name, "python_repl"); + assert!( + response.tool_executions.is_empty(), + "CallerManaged should not execute tools" + ); + + // Chat history should contain the tool call message + assert!( + response.chat.messages.iter().any(|m| m.has_tool_calls()), + "chat history should include the tool call message" + ); +} + +/// Multi-turn with parse failure on second turn verifies that errors +/// include the correct raw_response from the continuation, not the first turn. 
+#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn parse_failure_on_second_turn_includes_correct_raw_response() { + let _lock = SETTINGS_LOCK.lock().await; + + let good_response = text_response(response_with_fields(&[("result", "first answer")])); + // Second response is malformed — no field markers + let bad_response = text_response("This response has no field markers at all."); + + let (lm, _client) = build_test_lm(vec![good_response, bad_response]).await; + configure(lm, ChatAdapter::new()); + + let predict = Predict::::new(); + let input = CodeExecInput { + prompt: "test".to_string(), + }; + + // Turn 1: succeeds + let first_result = predict.forward(input, None).await.expect("turn 1"); + let chat = first_result.chat().clone(); + assert_eq!(first_result.into_inner().result, "first answer"); + + // Turn 2: should fail with parse error containing the bad response + let follow_up = CodeExecInput { + prompt: "follow up".to_string(), + }; + let err = predict + .forward(follow_up, Some(chat)) + .await + .expect_err("second turn should fail"); + + match err { + dspy_rs::PredictError::Parse { + raw_response, + source, + .. 
+ } => { + assert!( + raw_response.contains("no field markers"), + "raw_response should be from the second turn, got: {}", + raw_response + ); + // The error should mention the missing field + let fields = source.fields(); + assert!( + !fields.is_empty() || source.field().is_some(), + "parse error should identify which field(s) failed" + ); + } + other => panic!( + "expected PredictError::Parse, got: {:?}", + std::mem::discriminant(&other) + ), + } +} diff --git a/crates/dspy-rs/tests/test_chain_of_thought_swap.rs b/crates/dspy-rs/tests/test_chain_of_thought_swap.rs index 6e7614a0..a199e56b 100644 --- a/crates/dspy-rs/tests/test_chain_of_thought_swap.rs +++ b/crates/dspy-rs/tests/test_chain_of_thought_swap.rs @@ -23,21 +23,20 @@ fn text_response(text: impl Into) -> AssistantContent { } async fn configure_test_lm(responses: Vec) { - unsafe { - std::env::set_var("OPENAI_API_KEY", "test"); - } - let client = TestCompletionModel::new(responses.into_iter().map(text_response)); - let lm = LM::builder() - .model("openai:gpt-4o-mini".to_string()) - .build() - .await - .unwrap() - .with_client(LMClient::Test(client)) - .await - .unwrap(); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .build(), + ) + .await + .unwrap() + .with_client(LMClient::Test(client)) + .await + .unwrap(); - configure(lm, ChatAdapter {}); + configure(lm, ChatAdapter::new()); } #[derive(Signature, Clone, Debug, PartialEq, facet::Facet)] diff --git a/crates/dspy-rs/tests/test_chat.rs b/crates/dspy-rs/tests/test_chat.rs index 0e231301..0efd2b10 100644 --- a/crates/dspy-rs/tests/test_chat.rs +++ b/crates/dspy-rs/tests/test_chat.rs @@ -1,4 +1,9 @@ -use dspy_rs::core::{Chat, Message}; +use dspy_rs::core::lm::chat::{Chat, ContentBlock, Message, Role}; +use rig::OneOrMany; +use rig::message::{ + AssistantContent, Message as RigMessage, Reasoning, ToolCall, ToolFunction, ToolResult, + ToolResultContent, UserContent, +}; 
use rstest::*; use serde_json::json; @@ -10,81 +15,66 @@ fn test_chat_init() { Message::assistant("Hello, world to you!"), ]); - let json_value = chat.to_json(); - let json = json_value.as_array().unwrap(); - assert_eq!(chat.len(), 3); - assert_eq!(json[0]["role"], "system"); assert!(!chat.is_empty()); - assert_eq!( - json[0]["content"], - "You are a helpful assistant.".to_string() - ); - assert_eq!(json[1]["role"], "user"); - assert_eq!(json[1]["content"], "Hello, world!".to_string()); - assert_eq!(json[2]["role"], "assistant"); - assert_eq!(json[2]["content"], "Hello, world to you!".to_string()); + assert_eq!(chat.messages[0].role, Role::System); + assert_eq!(chat.messages[0].content(), "You are a helpful assistant."); + assert_eq!(chat.messages[1].role, Role::User); + assert_eq!(chat.messages[1].content(), "Hello, world!"); + assert_eq!(chat.messages[2].role, Role::Assistant); + assert_eq!(chat.messages[2].content(), "Hello, world to you!"); } #[rstest] fn test_chat_push() { let mut chat = Chat::new(vec![]); - chat.push("user", "Hello, world!"); + chat.push(Role::User, "Hello, world!"); - let json_value = chat.to_json(); - let json = json_value.as_array().unwrap(); - assert_eq!(json.len(), 1); - assert_eq!(json[0]["role"], "user"); - assert_eq!(json[0]["content"], "Hello, world!".to_string()); + assert_eq!(chat.len(), 1); + assert_eq!(chat.messages[0].role, Role::User); + assert_eq!(chat.messages[0].content(), "Hello, world!"); } #[rstest] fn test_chat_pop() { let mut chat = Chat::new(vec![]); - chat.push("user", "Hello, world!"); + chat.push(Role::User, "Hello, world!"); chat.pop(); - let json_value = chat.to_json(); - let json = json_value.as_array().unwrap(); - assert_eq!(json.len(), 0); + assert_eq!(chat.len(), 0); } #[rstest] -fn test_chat_to_json() { +fn test_chat_to_json_and_back() { let chat = Chat::new(vec![ Message::system("You are a helpful assistant."), Message::user("Hello, world!"), Message::assistant("Hello, world to you!"), ]); - let json = 
chat.to_json(); + let json_dump = chat.to_json(); + let reparsed = Chat::new(vec![]).from_json(json_dump).unwrap(); + + assert_eq!(reparsed.len(), 3); + assert_eq!(reparsed.messages[0].role, Role::System); assert_eq!( - json.to_string(), - "[{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},{\"role\":\"user\",\"content\":\"Hello, world!\"},{\"role\":\"assistant\",\"content\":\"Hello, world to you!\"}]" + reparsed.messages[0].content(), + "You are a helpful assistant." ); + assert_eq!(reparsed.messages[1].role, Role::User); + assert_eq!(reparsed.messages[1].content(), "Hello, world!"); + assert_eq!(reparsed.messages[2].role, Role::Assistant); + assert_eq!(reparsed.messages[2].content(), "Hello, world to you!"); } #[rstest] -fn test_chat_from_json() { +fn test_chat_from_json_requires_grouped_content() { let json = json!([ {"role":"system","content":"You are a helpful assistant."}, {"role":"user","content":"Hello, world!"}, {"role":"assistant","content":"Hello, world to you!"} ]); - let empty_chat = Chat::new(vec![]); - let chat = empty_chat.from_json(json).unwrap(); - - let json_value = chat.to_json(); - let json = json_value.as_array().unwrap(); - - assert_eq!(chat.len(), 3); - assert_eq!(json[0]["role"], "system"); - assert_eq!( - json[0]["content"], - "You are a helpful assistant.".to_string() - ); - assert_eq!(json[1]["role"], "user"); - assert_eq!(json[1]["content"], "Hello, world!".to_string()); - assert_eq!(json[2]["content"], "Hello, world to you!".to_string()); + let err = Chat::new(vec![]).from_json(json).unwrap_err(); + assert!(err.to_string().contains("content must be an array")); } #[rstest] @@ -103,20 +93,16 @@ fn test_chat_push_all() { chat1.push_all(&chat2); assert_eq!(chat1.len(), 5); - - let json_value = chat1.to_json(); - let json = json_value.as_array().unwrap(); - - assert_eq!(json[0]["role"], "system"); - assert_eq!(json[0]["content"], "You are a helpful assistant."); - assert_eq!(json[1]["role"], "user"); - 
assert_eq!(json[1]["content"], "Hello!"); - assert_eq!(json[2]["role"], "assistant"); - assert_eq!(json[2]["content"], "Hi there!"); - assert_eq!(json[3]["role"], "user"); - assert_eq!(json[3]["content"], "How are you?"); - assert_eq!(json[4]["role"], "assistant"); - assert_eq!(json[4]["content"], "I'm doing well, thank you!"); + assert_eq!(chat1.messages[0].role, Role::System); + assert_eq!(chat1.messages[0].content(), "You are a helpful assistant."); + assert_eq!(chat1.messages[1].role, Role::User); + assert_eq!(chat1.messages[1].content(), "Hello!"); + assert_eq!(chat1.messages[2].role, Role::Assistant); + assert_eq!(chat1.messages[2].content(), "Hi there!"); + assert_eq!(chat1.messages[3].role, Role::User); + assert_eq!(chat1.messages[3].content(), "How are you?"); + assert_eq!(chat1.messages[4].role, Role::Assistant); + assert_eq!(chat1.messages[4].content(), "I'm doing well, thank you!"); } #[rstest] @@ -127,10 +113,110 @@ fn test_chat_push_all_empty() { chat1.push_all(&empty_chat); assert_eq!(chat1.len(), 1); + assert_eq!(chat1.messages[0].role, Role::System); + assert_eq!(chat1.messages[0].content(), "System message"); +} + +#[rstest] +fn test_new_variants_round_trip_json() { + let call = ToolCall::new( + "call-1".to_string(), + ToolFunction { + name: "lookup".to_string(), + arguments: json!({ "query": "rust" }), + }, + ); + let result = ToolResult { + id: "call-1".to_string(), + call_id: Some("provider-call-1".to_string()), + content: OneOrMany::one(ToolResultContent::text("result payload")), + }; + let reasoning = Reasoning::new("thinking..."); - let json_value = chat1.to_json(); - let json = json_value.as_array().unwrap(); + let chat = Chat::new(vec![ + Message::system("You are a tool-using assistant."), + Message::tool_call(call.clone()), + Message::tool_result(result.clone()), + Message::reasoning(reasoning.clone()), + ]); + + let json_dump = chat.to_json(); + let reparsed = Chat::new(vec![]).from_json(json_dump).unwrap(); + assert_eq!(reparsed.len(), 
4); + + assert_eq!(reparsed.messages[0].role, Role::System); + + assert_eq!(reparsed.messages[1].role, Role::Assistant); + assert!(reparsed.messages[1].has_tool_calls()); + let reparsed_calls = reparsed.messages[1].tool_calls(); + assert_eq!(reparsed_calls[0].function.name, call.function.name); + + assert_eq!(reparsed.messages[2].role, Role::User); + assert!(reparsed.messages[2].has_tool_results()); + + assert_eq!(reparsed.messages[3].role, Role::Assistant); + assert!(reparsed.messages[3].has_reasoning()); +} + +#[rstest] +fn test_from_rig_message_preserves_all_content() { + // User with text + tool result — both preserved + let user_msg = RigMessage::User { + content: OneOrMany::many(vec![ + UserContent::text("some context"), + UserContent::ToolResult(ToolResult { + id: "id-1".to_string(), + call_id: None, + content: OneOrMany::one(ToolResultContent::text("ok")), + }), + ]) + .unwrap(), + }; + let converted = Message::from(user_msg); + assert_eq!(converted.role, Role::User); + assert_eq!(converted.content.len(), 2); + assert!(matches!(converted.content[0], ContentBlock::Text { .. })); + assert!(matches!( + converted.content[1], + ContentBlock::ToolResult { .. 
} + )); + + // Assistant with reasoning + tool call — both preserved (was lossy before) + let assistant_msg = RigMessage::Assistant { + id: Some("asst-123".to_string()), + content: OneOrMany::many(vec![ + AssistantContent::Reasoning(Reasoning::new("step by step")), + AssistantContent::ToolCall(ToolCall::new( + "tool-2".to_string(), + ToolFunction { + name: "search".to_string(), + arguments: json!({ "q": "x" }), + }, + )), + ]) + .unwrap(), + }; + let converted = Message::from(assistant_msg); + assert_eq!(converted.role, Role::Assistant); + assert_eq!(converted.id, Some("asst-123".to_string())); + assert_eq!(converted.content.len(), 2); + assert!(converted.has_reasoning()); + assert!(converted.has_tool_calls()); +} + +#[rstest] +fn test_text_content_filters_non_text_blocks() { + let msg = Message::with_content( + Role::Assistant, + vec![ + ContentBlock::reasoning(Reasoning::new("thinking")), + ContentBlock::text("the answer is 42"), + ], + ); - assert_eq!(json[0]["role"], "system"); - assert_eq!(json[0]["content"], "System message"); + // text_content() returns only Text blocks + assert_eq!(msg.text_content(), "the answer is 42"); + // content() returns everything + assert!(msg.content().contains("thinking")); + assert!(msg.content().contains("the answer is 42")); } diff --git a/crates/dspy-rs/tests/test_chat_adapter_schema.rs b/crates/dspy-rs/tests/test_chat_adapter_schema.rs index 388218a7..92b9a633 100644 --- a/crates/dspy-rs/tests/test_chat_adapter_schema.rs +++ b/crates/dspy-rs/tests/test_chat_adapter_schema.rs @@ -1,4 +1,4 @@ -use dspy_rs::{CallMetadata, ChatAdapter, Message, Predicted, Signature}; +use dspy_rs::{CallMetadata, Chat, ChatAdapter, Message, Predicted, Signature}; #[derive(Signature, Clone, Debug)] /// Adapter schema parse fixture. 
@@ -23,7 +23,7 @@ struct AliasSig { #[test] fn parse_response_typed_uses_schema_field_names() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let response = Message::assistant("[[ ## answer ## ]]\nParis\n\n[[ ## completed ## ]]\n"); let (output, field_meta) = adapter @@ -42,7 +42,7 @@ fn parse_response_typed_uses_schema_field_names() { None, field_meta, ); - let predicted = Predicted::new(output, metadata); + let predicted = Predicted::new(output, metadata, Chat::new(vec![])); assert_eq!(predicted.metadata().field_raw("answer"), Some("Paris")); assert!(!predicted.metadata().has_failed_checks()); @@ -51,7 +51,7 @@ fn parse_response_typed_uses_schema_field_names() { #[test] fn parse_response_typed_accepts_dotted_field_markers() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let response = Message::assistant("[[ ## answer.value ## ]]\nParis\n\n[[ ## completed ## ]]\n"); let (output, field_meta) = adapter diff --git a/crates/dspy-rs/tests/test_chat_prompt_composition.rs b/crates/dspy-rs/tests/test_chat_prompt_composition.rs index e216c15a..7db552e9 100644 --- a/crates/dspy-rs/tests/test_chat_prompt_composition.rs +++ b/crates/dspy-rs/tests/test_chat_prompt_composition.rs @@ -40,7 +40,7 @@ fn response_instruction_line(message: &str) -> &str { #[test] fn system_prompt_includes_all_sections_in_order_with_boundaries() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let system = adapter .format_system_message_typed::() .expect("system prompt should format"); @@ -80,7 +80,7 @@ fn system_prompt_includes_all_sections_in_order_with_boundaries() { #[test] fn system_prompt_field_descriptions_and_structure_are_present() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let system = adapter .format_system_message_typed::() .expect("system prompt should format"); @@ -101,7 +101,7 @@ fn system_prompt_field_descriptions_and_structure_are_present() { #[test] fn response_instruction_line_orders_output_fields() { 
- let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let system = adapter .format_system_message_typed::() .expect("system prompt should format"); @@ -115,7 +115,7 @@ fn response_instruction_line_orders_output_fields() { #[test] fn instruction_override_is_used_in_objective_section() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let override_instruction = "Follow the rubric.\nCite the context."; let system = adapter .format_system_message_typed_with_instruction::(Some(override_instruction)) @@ -129,7 +129,7 @@ fn instruction_override_is_used_in_objective_section() { #[test] fn empty_instruction_uses_generated_fallback_objective() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let system = adapter .format_system_message_typed::() .expect("system prompt should format"); @@ -140,7 +140,7 @@ fn empty_instruction_uses_generated_fallback_objective() { #[test] fn typed_and_schema_system_builders_match() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let typed = adapter .format_system_message_typed_with_instruction::(Some("Override objective")) .expect("typed system prompt"); @@ -153,7 +153,7 @@ fn typed_and_schema_system_builders_match() { #[test] fn typed_and_schema_user_builders_match_and_append_requirements() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let input = PromptPartsSigInput { question: "What is the capital of France?".to_string(), context: "Facts: Paris is the capital city of France.".to_string(), @@ -184,9 +184,65 @@ fn typed_and_schema_user_builders_match_and_append_requirements() { ); } +#[test] +fn passthrough_user_message_has_no_marker_or_output_protocol_ceremony() { + let adapter = ChatAdapter::passthrough(); + let input = PromptPartsSigInput { + question: "What is the capital of France?".to_string(), + context: "Facts: Paris is the capital city of France.".to_string(), + }; + + let typed = adapter.format_user_message_typed::(&input); + let schema = 
adapter.format_input(PromptPartsSig::schema(), &input); + assert_eq!(typed, schema); + + assert!( + !typed.contains("[[ ##"), + "passthrough message should not include marker protocol:\n{typed}" + ); + assert!( + !typed.contains("Respond with the corresponding output fields"), + "passthrough message should not include output format instructions:\n{typed}" + ); + assert!( + !typed.contains("[[ ## completed ## ]]"), + "passthrough message should not include completion marker:\n{typed}" + ); + + assert!(typed.contains("question:")); + assert!(typed.contains("context:")); + assert!(typed.contains("What is the capital of France?")); + assert!(typed.contains("Facts: Paris is the capital city of France.")); +} + +#[test] +fn passthrough_system_with_instruction_override_is_raw_instruction_only() { + let adapter = ChatAdapter::passthrough(); + let override_instruction = "Use Python only.\nCall SUBMIT when done."; + let system = adapter + .format_system_message_typed_with_instruction::(Some(override_instruction)) + .expect("passthrough system prompt should format"); + + assert_eq!(system, override_instruction); + assert!(!system.contains("Your input fields are:")); + assert!(!system.contains("Your objective is:")); +} + +#[test] +fn passthrough_system_without_override_keeps_existing_scaffolding() { + let adapter = ChatAdapter::passthrough(); + let system = adapter + .format_system_message_typed::() + .expect("passthrough system prompt should format"); + + assert!(system.contains("Your input fields are:")); + assert!(system.contains("Your objective is:")); + assert!(system.contains("Answer the prompt using the provided context.")); +} + #[test] fn demo_format_composes_user_and_assistant_parts() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let demo = Example::::new( PromptPartsSigInput { question: "Question?".to_string(), @@ -213,7 +269,7 @@ fn demo_format_composes_user_and_assistant_parts() { #[test] fn 
typed_and_schema_assistant_builders_match_and_end_with_completed_marker() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let output = PromptPartsSigOutput { answer: "Paris".to_string(), confidence: 0.9, diff --git a/crates/dspy-rs/tests/test_chat_prompt_golden.rs b/crates/dspy-rs/tests/test_chat_prompt_golden.rs index 0cca5ece..b14e46aa 100644 --- a/crates/dspy-rs/tests/test_chat_prompt_golden.rs +++ b/crates/dspy-rs/tests/test_chat_prompt_golden.rs @@ -11,7 +11,7 @@ struct GoldenSig { #[test] fn golden_system_prompt_is_stable() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let system = adapter .format_system_message_typed::() .expect("system prompt should format"); @@ -44,7 +44,7 @@ fn golden_system_prompt_is_stable() { #[test] fn golden_user_prompt_is_stable() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let input = GoldenSigInput { question: "What is 2+2?".to_string(), }; @@ -62,7 +62,7 @@ fn golden_user_prompt_is_stable() { #[test] fn golden_assistant_prompt_is_stable() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let output = GoldenSigOutput { answer: "4".to_string(), }; @@ -79,7 +79,7 @@ fn golden_assistant_prompt_is_stable() { #[test] fn golden_demo_messages_are_stable() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let demo = Example::::new( GoldenSigInput { question: "What is 2+2?".to_string(), diff --git a/crates/dspy-rs/tests/test_dataloader.rs b/crates/dspy-rs/tests/test_dataloader.rs index e98d8db8..47cd9329 100644 --- a/crates/dspy-rs/tests/test_dataloader.rs +++ b/crates/dspy-rs/tests/test_dataloader.rs @@ -4,7 +4,7 @@ use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use bon::Builder; use dspy_rs::{ - COPRO, CallMetadata, DataLoader, Example, MetricOutcome, Module, Optimizer, Predict, + COPRO, CallMetadata, Chat, DataLoader, Example, MetricOutcome, Module, Optimizer, Predict, PredictError, Predicted, 
Signature, TypedLoadOptions, TypedMetric, UnknownFieldPolicy, average_score, evaluate_trainset, }; @@ -54,6 +54,7 @@ impl Module for EchoModule { answer: input.question, }, CallMetadata::default(), + Chat::new(vec![]), )) } } diff --git a/crates/dspy-rs/tests/test_evaluate_trainset_typed.rs b/crates/dspy-rs/tests/test_evaluate_trainset_typed.rs index 95e3f26b..0a5670ea 100644 --- a/crates/dspy-rs/tests/test_evaluate_trainset_typed.rs +++ b/crates/dspy-rs/tests/test_evaluate_trainset_typed.rs @@ -1,7 +1,7 @@ use anyhow::{Result, anyhow}; use dspy_rs::{ - CallMetadata, Example, MetricOutcome, Module, PredictError, Predicted, Signature, TypedMetric, - average_score, evaluate_trainset, + CallMetadata, Chat, Example, MetricOutcome, Module, PredictError, Predicted, Signature, + TypedMetric, average_score, evaluate_trainset, }; use std::sync::{Arc, Mutex}; @@ -26,6 +26,7 @@ impl Module for EchoModule { answer: input.prompt, }, CallMetadata::default(), + Chat::new(vec![]), )) } } diff --git a/crates/dspy-rs/tests/test_flatten_roundtrip.rs b/crates/dspy-rs/tests/test_flatten_roundtrip.rs index 78874ff9..eb88b160 100644 --- a/crates/dspy-rs/tests/test_flatten_roundtrip.rs +++ b/crates/dspy-rs/tests/test_flatten_roundtrip.rs @@ -11,7 +11,7 @@ struct QA { #[test] fn augmented_demo_roundtrips_through_adapter() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let demo = Example::>::new( QAInput { question: "What is 2+2?".to_string(), diff --git a/crates/dspy-rs/tests/test_gepa_typed_metric_feedback.rs b/crates/dspy-rs/tests/test_gepa_typed_metric_feedback.rs index f6ab2a63..0dca6630 100644 --- a/crates/dspy-rs/tests/test_gepa_typed_metric_feedback.rs +++ b/crates/dspy-rs/tests/test_gepa_typed_metric_feedback.rs @@ -1,6 +1,6 @@ use anyhow::Result; use dspy_rs::{ - CallMetadata, Example, FeedbackMetric, GEPA, MetricOutcome, Module, Optimizer, Predict, + CallMetadata, Chat, Example, FeedbackMetric, GEPA, MetricOutcome, Module, Optimizer, Predict, PredictError, 
Predicted, Signature, TypedMetric, }; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -35,6 +35,7 @@ impl Module for InstructionEchoModule { answer: input.prompt, }, CallMetadata::default(), + Chat::new(vec![]), )) } } diff --git a/crates/dspy-rs/tests/test_input_format.rs b/crates/dspy-rs/tests/test_input_format.rs index 8170ae9c..44ac40df 100644 --- a/crates/dspy-rs/tests/test_input_format.rs +++ b/crates/dspy-rs/tests/test_input_format.rs @@ -161,7 +161,7 @@ fn extract_baml_field<'a>(value: &'a BamlValue, field_name: &str) -> &'a BamlVal #[test] fn typed_input_format_yaml_renders_field_names() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let input = FormatSigInput { question: "What is YAML?".to_string(), context: vec![Document { @@ -179,7 +179,7 @@ fn typed_input_format_yaml_renders_field_names() { #[test] fn typed_input_format_json_is_parsable() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let input = FormatJsonSigInput { question: "What is JSON?".to_string(), context: vec![Document { @@ -201,7 +201,7 @@ fn typed_input_format_json_is_parsable() { #[test] fn typed_input_format_toon_matches_formatter() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let input = FormatToonSigInput { question: "What is TOON?".to_string(), context: vec![Document { @@ -224,7 +224,7 @@ fn typed_input_format_toon_matches_formatter() { #[test] fn typed_input_default_string_is_raw() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let input = DefaultFormatSigInput { question: "Raw string".to_string(), context: vec![Document { @@ -240,7 +240,7 @@ fn typed_input_default_string_is_raw() { #[test] fn typed_input_default_non_string_is_json() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let input = DefaultFormatSigInput { question: "Default JSON".to_string(), context: vec![Document { @@ -261,7 +261,7 @@ fn typed_input_default_non_string_is_json() { #[test] fn 
typed_input_appends_response_instruction_reminder() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let input = DefaultFormatSigInput { question: "Reminder check".to_string(), context: vec![Document { @@ -277,7 +277,7 @@ fn typed_input_appends_response_instruction_reminder() { #[test] fn typed_input_render_jinja_uses_context_values() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let input = RenderJinjaSigInput { question: "Question".to_string(), context: Document { @@ -296,7 +296,7 @@ fn typed_input_render_jinja_uses_context_values() { #[test] fn typed_input_render_jinja_missing_var_panics() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let input = RenderJinjaStrictSigInput { question: "Question".to_string(), }; @@ -309,7 +309,7 @@ fn typed_input_render_jinja_missing_var_panics() { #[test] fn typed_input_render_jinja_exposes_field_metadata_and_vars() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let input = RenderJinjaFieldMetaSigInput { context: Document { text: "Hello".to_string(), @@ -329,7 +329,7 @@ fn typed_input_render_jinja_exposes_field_metadata_and_vars() { #[test] fn typed_input_render_jinja_non_string_primitives() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let input = RenderPrimitiveSigInput { count: 42, is_ready: true, @@ -345,7 +345,7 @@ fn typed_input_render_jinja_non_string_primitives() { #[test] fn typed_input_render_jinja_supports_contrib_filters() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let input = RenderContribFilterSigInput { context: Document { text: "abcdefg".to_string(), diff --git a/crates/dspy-rs/tests/test_lm.rs b/crates/dspy-rs/tests/test_lm.rs index 41106d9f..d6557b02 100644 --- a/crates/dspy-rs/tests/test_lm.rs +++ b/crates/dspy-rs/tests/test_lm.rs @@ -84,16 +84,15 @@ fn test_lm_usage_add() { #[tokio::test] #[cfg_attr(miri, ignore)] async fn test_lm_with_cache_enabled() { - unsafe { - 
std::env::set_var("OPENAI_API_KEY", "test"); - } - // Create LM with cache enabled - let lm = LM::builder() - .model("openai:gpt-4o-mini".to_string()) - .cache(true) - .build() - .await - .unwrap(); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .cache(true) + .build(), + ) + .await + .unwrap(); // Verify cache handler is initialized assert!(lm.cache_handler.is_some()); @@ -102,16 +101,15 @@ async fn test_lm_with_cache_enabled() { #[tokio::test] #[cfg_attr(miri, ignore)] async fn test_lm_with_cache_disabled() { - unsafe { - std::env::set_var("OPENAI_API_KEY", "test"); - } - // Create LM with cache explicitly disabled - let lm = LM::builder() - .model("openai:gpt-4o-mini".to_string()) - .cache(false) - .build() - .await - .unwrap(); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .cache(false) + .build(), + ) + .await + .unwrap(); // Verify cache handler is NOT initialized when cache is disabled assert!(lm.cache_handler.is_none()); @@ -120,16 +118,15 @@ async fn test_lm_with_cache_disabled() { #[tokio::test] #[cfg_attr(miri, ignore)] async fn test_lm_cache_initialization_on_first_call() { - unsafe { - std::env::set_var("OPENAI_API_KEY", "test"); - } - // Create LM with cache enabled - let lm = LM::builder() - .model("openai:gpt-4o-mini".to_string()) - .cache(true) - .build() - .await - .unwrap(); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .cache(true) + .build(), + ) + .await + .unwrap(); // After build, cache_handler should be initialized assert!(lm.cache_handler.is_some()); @@ -138,20 +135,19 @@ async fn test_lm_cache_initialization_on_first_call() { #[tokio::test] #[cfg_attr(miri, ignore)] async fn test_lm_cache_direct_operations() { - unsafe { - std::env::set_var("OPENAI_API_KEY", "test"); - } 
use dspy_rs::Prediction; use dspy_rs::data::RawExample; use std::collections::HashMap; - // Create LM with cache enabled - let lm = LM::builder() - .model("openai:gpt-4o-mini".to_string()) - .cache(true) - .build() - .await - .unwrap(); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .cache(true) + .build(), + ) + .await + .unwrap(); // Get cache handler let cache = lm @@ -207,20 +203,23 @@ async fn test_lm_cache_direct_operations() { #[tokio::test] #[cfg_attr(miri, ignore)] async fn test_lm_cache_with_different_models() { - unsafe { - std::env::set_var("OPENAI_API_KEY", "test"); - std::env::set_var("ANTHROPIC_API_KEY", "test"); - } // Test that cache works with different model configurations - let models = vec!["openai:gpt-3.5-turbo", "anthropic:claude-3-haiku-20240307"]; + let models = vec![ + "openai:gpt-3.5-turbo", + "openai-responses:gpt-4o-mini", + "anthropic:claude-3-haiku-20240307", + ]; for model in models { - let lm = LM::builder() - .model(model.to_string()) - .cache(true) - .build() - .await - .unwrap(); + let lm = temp_env::async_with_vars( + [ + ("OPENAI_API_KEY", Some("test")), + ("ANTHROPIC_API_KEY", Some("test")), + ], + LM::builder().model(model.to_string()).cache(true).build(), + ) + .await + .unwrap(); // Cache should be initialized regardless of model assert!( @@ -231,23 +230,35 @@ async fn test_lm_cache_with_different_models() { } } +#[tokio::test] +#[cfg_attr(miri, ignore)] +async fn test_lm_local_openai_responses_provider_builds() { + let lm = LM::builder() + .base_url("http://localhost:11434/v1".to_string()) + .model("openai-responses:gpt-5.2".to_string()) + .build() + .await + .expect("openai-responses local build should succeed"); + + assert_eq!(lm.model, "openai-responses:gpt-5.2"); +} + #[tokio::test] #[cfg_attr(miri, ignore)] async fn test_cache_with_complex_inputs() { - unsafe { - std::env::set_var("OPENAI_API_KEY", "test"); - } use 
dspy_rs::Prediction; use dspy_rs::data::RawExample; use std::collections::HashMap; - // Create LM with cache enabled - let lm = LM::builder() - .model("openai:gpt-4o-mini".to_string()) - .cache(true) - .build() - .await - .unwrap(); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .cache(true) + .build(), + ) + .await + .unwrap(); let cache = lm .cache_handler diff --git a/crates/dspy-rs/tests/test_message_roundtrip.rs b/crates/dspy-rs/tests/test_message_roundtrip.rs new file mode 100644 index 00000000..48f82b2e --- /dev/null +++ b/crates/dspy-rs/tests/test_message_roundtrip.rs @@ -0,0 +1,169 @@ +//! Public-API tests for the grouped Message model. +//! +//! These tests validate message/content behavior through stable public methods +//! (`to_json`/`from_json`, content accessors) without calling crate-internal +//! rig conversion helpers. + +use dspy_rs::{Chat, ContentBlock, Message, Role}; +use rig::OneOrMany; +use rig::message::{Reasoning, ToolCall, ToolFunction, ToolResult, ToolResultContent}; +use serde_json::json; + +#[test] +fn grouped_message_json_roundtrip() { + let original = Chat::new(vec![ + Message::system("Be helpful"), + Message::with_content( + Role::Assistant, + vec![ + ContentBlock::reasoning(Reasoning::new("let me think")), + ContentBlock::text("the answer is 42"), + ContentBlock::tool_call(ToolCall::new( + "tc-1".to_string(), + ToolFunction { + name: "verify".to_string(), + arguments: json!({"answer": 42}), + }, + )), + ], + ), + Message::with_content( + Role::User, + vec![ + ContentBlock::tool_result(ToolResult { + id: "tc-1".to_string(), + call_id: None, + content: OneOrMany::one(ToolResultContent::text("confirmed")), + }), + ContentBlock::text("Thanks! 
Can you also check 43?"), + ], + ), + ]); + + let json = original.to_json(); + let reparsed = Chat::new(vec![]).from_json(json).unwrap(); + + assert_eq!(reparsed.len(), 3); + + let asst = &reparsed.messages[1]; + assert_eq!(asst.role, Role::Assistant); + assert_eq!(asst.content.len(), 3); + assert!(asst.has_reasoning()); + assert!(asst.has_tool_calls()); + + let user = &reparsed.messages[2]; + assert_eq!(user.role, Role::User); + assert_eq!(user.content.len(), 2); + assert!(user.has_tool_results()); +} + +#[test] +fn multi_turn_json_roundtrip_preserves_earlier_reasoning() { + let chat = Chat::new(vec![ + Message::system("You are a helpful assistant."), + Message::user("What is the capital of France?"), + Message::with_content( + Role::Assistant, + vec![ + ContentBlock::reasoning(Reasoning::new("The user is asking about geography.")), + ContentBlock::text("The capital of France is Paris."), + ], + ), + Message::user("And Germany?"), + Message::assistant("The capital of Germany is Berlin."), + ]); + + let reparsed = Chat::new(vec![]).from_json(chat.to_json()).unwrap(); + assert_eq!(reparsed.len(), 5); + + let turn1_reply = &reparsed.messages[2]; + assert_eq!(turn1_reply.role, Role::Assistant); + assert!(turn1_reply.has_reasoning()); + assert_eq!(turn1_reply.content.len(), 2); +} + +#[test] +fn legacy_plain_string_json_is_rejected() { + let legacy_json = json!([ + {"role": "system", "content": "Be helpful"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"} + ]); + + let err = Chat::new(vec![]).from_json(legacy_json).unwrap_err(); + assert!(err.to_string().contains("content must be an array")); +} + +#[test] +fn text_content_excludes_non_text_blocks() { + let msg = Message::with_content( + Role::Assistant, + vec![ + ContentBlock::reasoning(Reasoning::new("internal monologue")), + ContentBlock::text("visible answer"), + ContentBlock::tool_call(ToolCall::new( + "tc".to_string(), + ToolFunction { + name: "search".to_string(), + 
arguments: json!({}), + }, + )), + ], + ); + + assert_eq!(msg.text_content(), "visible answer"); + let full = msg.content(); + assert!(full.contains("internal monologue")); + assert!(full.contains("visible answer")); + assert!(full.contains("search")); +} + +#[test] +fn tool_calls_accessor_returns_all_tool_calls() { + let msg = Message::with_content( + Role::Assistant, + vec![ + ContentBlock::reasoning(Reasoning::new("planning")), + ContentBlock::tool_call(ToolCall::new( + "tc-1".to_string(), + ToolFunction { + name: "search".to_string(), + arguments: json!({"q": "a"}), + }, + )), + ContentBlock::tool_call(ToolCall::new( + "tc-2".to_string(), + ToolFunction { + name: "calculate".to_string(), + arguments: json!({"expr": "1+1"}), + }, + )), + ], + ); + + let calls = msg.tool_calls(); + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].function.name, "search"); + assert_eq!(calls[1].function.name, "calculate"); +} + +#[test] +fn empty_content_message_does_not_panic() { + let msg = Message::with_content(Role::Assistant, vec![]); + assert_eq!(msg.content(), ""); + assert_eq!(msg.text_content(), ""); + assert!(!msg.has_tool_calls()); + assert!(!msg.has_reasoning()); +} + +#[test] +fn message_id_survives_json_roundtrip() { + let mut msg = Message::assistant("some text"); + msg.id = Some("msg_abc123".to_string()); + + let chat = Chat::new(vec![msg]); + let reparsed = Chat::new(vec![]).from_json(chat.to_json()).unwrap(); + + assert_eq!(reparsed.messages.len(), 1); + assert_eq!(reparsed.messages[0].id, Some("msg_abc123".to_string())); +} diff --git a/crates/dspy-rs/tests/test_module_ext.rs b/crates/dspy-rs/tests/test_module_ext.rs index c7bb1c16..d0f2511f 100644 --- a/crates/dspy-rs/tests/test_module_ext.rs +++ b/crates/dspy-rs/tests/test_module_ext.rs @@ -1,4 +1,6 @@ -use dspy_rs::{BamlType, CallMetadata, Module, ModuleExt, ParseError, PredictError, Predicted}; +use dspy_rs::{ + BamlType, CallMetadata, Chat, Module, ModuleExt, ParseError, PredictError, Predicted, +}; struct 
MaybeFails; @@ -37,6 +39,7 @@ impl Module for MaybeFails { }, raw_response: format!("raw:{input_value}"), lm_usage: dspy_rs::LmUsage::default(), + chat: Chat::new(vec![]), }) } else { Ok(Predicted::new( @@ -44,6 +47,7 @@ impl Module for MaybeFails { value: input_value * 2, }, metadata, + Chat::new(vec![]), )) } } @@ -66,6 +70,7 @@ fn transform_int_payload(value: IntPayload) -> Result }, raw_response: "transform".to_string(), lm_usage: dspy_rs::LmUsage::default(), + chat: Chat::new(vec![]), }) } } diff --git a/crates/dspy-rs/tests/test_module_facet_shapes.rs b/crates/dspy-rs/tests/test_module_facet_shapes.rs index 9aaa8d07..89ce69b5 100644 --- a/crates/dspy-rs/tests/test_module_facet_shapes.rs +++ b/crates/dspy-rs/tests/test_module_facet_shapes.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "rlm")] +use dspy_rs::Rlm; use dspy_rs::{ChainOfThought, Facet, ModuleExt, PredictError, ReAct, Signature}; use facet::{self, Type, UserType}; @@ -112,3 +114,23 @@ fn and_then_shape_exposes_inner_chain_of_thought_shape() { let nested_predictor = find_field(inner.shape(), "predictor"); assert_eq!(nested_predictor.shape().type_identifier, "Predict"); } + +#[cfg(feature = "rlm")] +#[test] +fn rlm_shape_exposes_extract_and_skips_runtime_fields() { + let module = Rlm::::new(); + let shape = shape_of(&module); + + let extract = find_field(shape, "extract"); + assert!(!extract.should_skip_deserializing()); + assert_eq!(extract.shape().type_identifier, "Predict"); + + let config = find_field(shape, "config"); + let instruction_override = find_field(shape, "instruction_override"); + let sub_lm = find_field(shape, "sub_lm"); + let runtime = find_field(shape, "runtime"); + assert!(config.should_skip_deserializing()); + assert!(instruction_override.should_skip_deserializing()); + assert!(sub_lm.should_skip_deserializing()); + assert!(runtime.should_skip_deserializing()); +} diff --git a/crates/dspy-rs/tests/test_module_forward_all.rs b/crates/dspy-rs/tests/test_module_forward_all.rs index 
a2376455..fc28b746 100644 --- a/crates/dspy-rs/tests/test_module_forward_all.rs +++ b/crates/dspy-rs/tests/test_module_forward_all.rs @@ -1,6 +1,6 @@ use std::time::Duration; -use dspy_rs::{BamlType, CallMetadata, Module, PredictError, Predicted, forward_all}; +use dspy_rs::{BamlType, CallMetadata, Chat, Module, PredictError, Predicted, forward_all}; use tokio::time::sleep; struct DelayEcho; @@ -27,6 +27,7 @@ impl Module for DelayEcho { Ok(Predicted::new( DelayOutput { value: input.value }, CallMetadata::default(), + Chat::new(vec![]), )) } } diff --git a/crates/dspy-rs/tests/test_optimizer_named_parameters_integration.rs b/crates/dspy-rs/tests/test_optimizer_named_parameters_integration.rs index 8d4abdcd..08ec77f5 100644 --- a/crates/dspy-rs/tests/test_optimizer_named_parameters_integration.rs +++ b/crates/dspy-rs/tests/test_optimizer_named_parameters_integration.rs @@ -1,6 +1,6 @@ use anyhow::Result; use dspy_rs::{ - COPRO, CallMetadata, Example, MetricOutcome, Module, Optimizer, Predict, PredictError, + COPRO, CallMetadata, Chat, Example, MetricOutcome, Module, Optimizer, Predict, PredictError, Predicted, Signature, TypedMetric, }; @@ -33,6 +33,7 @@ impl Module for InstructionEchoModule { answer: input.prompt, }, CallMetadata::default(), + Chat::new(vec![]), )) } } diff --git a/crates/dspy-rs/tests/test_optimizer_typed_metric.rs b/crates/dspy-rs/tests/test_optimizer_typed_metric.rs index c05a590d..f1983560 100644 --- a/crates/dspy-rs/tests/test_optimizer_typed_metric.rs +++ b/crates/dspy-rs/tests/test_optimizer_typed_metric.rs @@ -1,7 +1,7 @@ use anyhow::{Result, anyhow}; use dspy_rs::{ - COPRO, CallMetadata, Example, MIPROv2, MetricOutcome, Module, Optimizer, Predict, PredictError, - Predicted, Signature, TypedMetric, + COPRO, CallMetadata, Chat, Example, MIPROv2, MetricOutcome, Module, Optimizer, Predict, + PredictError, Predicted, Signature, TypedMetric, }; use std::collections::HashSet; use std::sync::{Arc, Mutex}; @@ -35,6 +35,7 @@ impl Module for 
InstructionEchoModule { answer: input.prompt, }, CallMetadata::default(), + Chat::new(vec![]), )) } } diff --git a/crates/dspy-rs/tests/test_predict_conversation.rs b/crates/dspy-rs/tests/test_predict_conversation.rs new file mode 100644 index 00000000..7e830340 --- /dev/null +++ b/crates/dspy-rs/tests/test_predict_conversation.rs @@ -0,0 +1,157 @@ +use dspy_rs::{ + ChatAdapter, LM, LMClient, Predict, Role, Signature, TestCompletionModel, configure, +}; +use rig::completion::{AssistantContent, CompletionRequest}; +use rig::message::{Message as RigMessage, Text, UserContent}; +use std::sync::LazyLock; +use tokio::sync::Mutex; + +static SETTINGS_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); + +fn response_with_fields(fields: &[(&str, &str)]) -> String { + let mut response = String::new(); + for (name, value) in fields { + response.push_str(&format!("[[ ## {name} ## ]]\n{value}\n\n")); + } + response.push_str("[[ ## completed ## ]]\n"); + response +} + +fn text_response(text: impl Into) -> AssistantContent { + AssistantContent::Text(Text { text: text.into() }) +} + +async fn configure_test_lm(responses: Vec) -> TestCompletionModel { + let client = TestCompletionModel::new(responses.into_iter().map(text_response)); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .build(), + ) + .await + .unwrap() + .with_client(LMClient::Test(client.clone())) + .await + .unwrap(); + + configure(lm, ChatAdapter::new()); + + client +} + +fn request_contains_text(request: &CompletionRequest, needle: &str) -> bool { + if request + .preamble + .as_ref() + .is_some_and(|preamble| preamble.contains(needle)) + { + return true; + } + + for message in request.chat_history.iter() { + match message { + RigMessage::User { content } => { + for item in content.iter() { + if let UserContent::Text(text) = item + && text.text.contains(needle) + { + return true; + } + } + } + RigMessage::Assistant { content, .. 
} => { + for item in content.iter() { + match item { + AssistantContent::Text(text) if text.text.contains(needle) => return true, + AssistantContent::Reasoning(reasoning) + if reasoning.display_text().contains(needle) => + { + return true; + } + _ => {} + } + } + } + } + } + + false +} + +#[derive(Signature, Clone, Debug, PartialEq)] +/// Conversational QA test signature. +struct ConversationQA { + #[input] + question: String, + + #[output] + answer: String, +} + +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn forward_returns_prediction_with_chat_metadata() { + let _lock = SETTINGS_LOCK.lock().await; + let response = response_with_fields(&[("answer", "Paris")]); + let _client = configure_test_lm(vec![response]).await; + + let predict = Predict::::new(); + let input = ConversationQAInput { + question: "What is the capital of France?".to_string(), + }; + + let predicted = predict + .forward(input, None) + .await + .expect("forward should succeed"); + let chat = predicted.chat(); + + assert_eq!(chat.len(), 3); + assert_eq!(chat.messages[0].role, Role::System); + assert_eq!(chat.messages[1].role, Role::User); + assert_eq!(chat.messages[2].role, Role::Assistant); + assert_eq!(predicted.into_inner().answer, "Paris"); +} + +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn forward_with_history_supports_two_turn_roundtrip() { + let _lock = SETTINGS_LOCK.lock().await; + let first_response = response_with_fields(&[("answer", "First turn answer")]); + let second_response = response_with_fields(&[("answer", "Second turn answer")]); + let client = configure_test_lm(vec![first_response, second_response]).await; + + let predict = Predict::::new(); + let first_input = ConversationQAInput { + question: "Turn 1 question".to_string(), + }; + + let first_predicted = predict + .forward(first_input, None) + .await + .expect("first turn forward should succeed"); + let chat = 
first_predicted.chat().clone(); + assert_eq!(first_predicted.into_inner().answer, "First turn answer"); + + // Second turn: typed follow-up with prior history + let caller_follow_up = "Caller follow-up message"; + let second_input = ConversationQAInput { + question: caller_follow_up.to_string(), + }; + + let second_predicted = predict + .forward(second_input, Some(chat)) + .await + .expect("second turn forward should succeed"); + let second_chat = second_predicted.chat().clone(); + + assert_eq!(second_predicted.into_inner().answer, "Second turn answer"); + assert!(second_chat.len() >= 5); + + // Verify the follow-up text was sent to the LM + let last_request = client + .last_request() + .expect("test model should capture last request"); + assert!(request_contains_text(&last_request, caller_follow_up)); +} diff --git a/crates/dspy-rs/tests/test_predict_conversation_live.rs b/crates/dspy-rs/tests/test_predict_conversation_live.rs new file mode 100644 index 00000000..a12ee20d --- /dev/null +++ b/crates/dspy-rs/tests/test_predict_conversation_live.rs @@ -0,0 +1,64 @@ +use dspy_rs::{ChatAdapter, LM, Predict, Signature, configure}; +use std::sync::LazyLock; +use tokio::sync::Mutex; + +static SETTINGS_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); + +#[derive(Signature, Clone, Debug, PartialEq)] +/// Live multi-turn conversation signature. 
+struct LiveConversation { + #[input] + prompt: String, + + #[output] + answer: String, +} + +#[tokio::test] +#[ignore] // Requires real network access and provider API key(s) +async fn live_forward_with_history_two_turn_roundtrip() { + let _lock = SETTINGS_LOCK.lock().await; + + let lm = LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .temperature(0.0) + .max_tokens(256) + .build() + .await + .expect("failed to build LM for live smoke test"); + configure(lm, ChatAdapter::new()); + + let predict = Predict::<LiveConversation>::new(); + + // First turn: build and call + let first_input = LiveConversationInput { + prompt: "Reply with the word ONE.".to_string(), + }; + let first = predict + .forward(first_input, None) + .await + .expect("first turn forward failed"); + let chat = first.chat().clone(); + assert!( + !first.answer.trim().is_empty(), + "first turn answer should not be empty" + ); + + // Second turn: continue with typed follow-up and prior history + let second_input = LiveConversationInput { + prompt: "Now reply with the word TWO. 
Use the same answer field format.".to_string(), + }; + + let second = predict + .forward(second_input, Some(chat)) + .await + .expect("second turn forward failed"); + let chat2 = second.chat(); + + assert!( + second.answer.to_ascii_lowercase().contains("two"), + "second turn answer should include 'two', got: {}", + second.answer + ); + assert!(chat2.len() >= 5, "chat should grow across turns"); +} diff --git a/crates/dspy-rs/tests/test_react_builder.rs b/crates/dspy-rs/tests/test_react_builder.rs index 12ef96f7..56072013 100644 --- a/crates/dspy-rs/tests/test_react_builder.rs +++ b/crates/dspy-rs/tests/test_react_builder.rs @@ -31,21 +31,20 @@ fn parse_calculator_args(args: &str) -> (i64, i64) { } async fn configure_test_lm(responses: Vec) { - unsafe { - std::env::set_var("OPENAI_API_KEY", "test"); - } - let client = TestCompletionModel::new(responses.into_iter().map(text_response)); - let lm = LM::builder() - .model("openai:gpt-4o-mini".to_string()) - .build() - .await - .unwrap() - .with_client(LMClient::Test(client)) - .await - .unwrap(); - - configure(lm, ChatAdapter {}); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .build(), + ) + .await + .unwrap() + .with_client(LMClient::Test(client)) + .await + .unwrap(); + + configure(lm, ChatAdapter::new()); } #[derive(Signature, Clone, Debug)] @@ -113,7 +112,7 @@ async fn react_builder_executes_multi_tool_calculator_loop_and_extracts_output() .await .expect("react call should succeed"); - let (result, metadata) = predicted.into_parts(); + let (result, metadata, _chat) = predicted.into_parts(); assert_eq!( add_calls.load(Ordering::SeqCst), 1, @@ -209,7 +208,7 @@ async fn react_unknown_tool_name_does_not_execute_first_tool() { }) .await .expect("react call should succeed"); - let (_, metadata) = predicted.into_parts(); + let (_, metadata, _chat) = predicted.into_parts(); assert_eq!( add_calls.load(Ordering::SeqCst), diff --git 
a/crates/dspy-rs/tests/test_rlm_fallback_integration.rs b/crates/dspy-rs/tests/test_rlm_fallback_integration.rs new file mode 100644 index 00000000..ef5aa30a --- /dev/null +++ b/crates/dspy-rs/tests/test_rlm_fallback_integration.rs @@ -0,0 +1,138 @@ +#![cfg(feature = "rlm")] + +use dspy_rs::modules::rlm::PyO3Runtime; +use dspy_rs::{ + ChatAdapter, LM, LMClient, PredictError, Rlm, Signature, TestCompletionModel, configure, +}; +use rig::completion::AssistantContent; +use rig::message::Text; +use std::sync::{Arc, LazyLock}; +use tokio::sync::Mutex; + +static SETTINGS_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); + +fn text_response(text: impl Into) -> AssistantContent { + AssistantContent::Text(Text { text: text.into() }) +} + +fn response_with_fields(fields: &[(&str, &str)]) -> String { + let mut response = String::new(); + for (name, value) in fields { + response.push_str(&format!("[[ ## {name} ## ]]\n{value}\n\n")); + } + response.push_str("[[ ## completed ## ]]\n"); + response +} + +async fn build_test_lm_with_client(responses: Vec) -> (LM, TestCompletionModel) { + let client = TestCompletionModel::new(responses.into_iter().map(text_response)); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .temperature(0.0) + .build(), + ) + .await + .expect("build lm") + .with_client(LMClient::Test(client.clone())) + .await + .expect("install test client"); + (lm, client) +} + +async fn configure_test_lm(responses: Vec) -> LM { + let (lm, _) = build_test_lm_with_client(responses).await; + configure(lm.clone(), ChatAdapter::new()); + lm +} + +async fn configure_test_lm_with_client(responses: Vec) -> (LM, TestCompletionModel) { + let (lm, client) = build_test_lm_with_client(responses).await; + configure(lm.clone(), ChatAdapter::new()); + (lm, client) +} + +#[derive(Signature, Clone, Debug, PartialEq)] +struct RlmFallbackSig { + #[input] + prompt: String, + #[output] + answer: String, 
+} + +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn rlm_fallback_extractor_runs_after_finalization_failure_and_uses_repl_history() { + let _lock = SETTINGS_LOCK.lock().await; + let (lm, client) = configure_test_lm_with_client(vec![ + "x = 40 + 2\nprint(f'x={x}')".to_string(), + "print('still working')".to_string(), + "print('final turn still no submit')".to_string(), + response_with_fields(&[("answer", "from-fallback")]), + ]) + .await; + + let rlm = Rlm::<RlmFallbackSig>::builder() + .runtime(Arc::new(PyO3Runtime)) + .sub_lm(Arc::new(lm)) + .max_iterations(3) + .enable_extraction_fallback(true) + .build(); + + let predicted = rlm + .call(RlmFallbackSigInput { + prompt: "never submit; let fallback extract".to_string(), + }) + .await + .expect("fallback extraction should produce typed output"); + + assert_eq!(predicted.answer, "from-fallback"); + let raw = &predicted.metadata().raw_response; + assert!(raw.contains("x = 40 + 2")); + assert!(raw.contains("print('final turn still no submit')")); + assert!(raw.contains("[[ ## answer ## ]]")); + + let last_request = client.last_request().expect("expected extraction request"); + let request_debug = format!("{last_request:?}"); + assert!(request_debug.contains("[[ ## repl_history ## ]]")); + assert!(request_debug.contains("=== Turn 1 ===")); + assert!(request_debug.contains("Code:")); + assert!(request_debug.contains("Output:")); + assert!(request_debug.contains("x = 40 + 2")); + assert!(request_debug.contains("[[ ## variables_info ## ]]")); +} + +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn rlm_without_extraction_fallback_returns_max_iterations_error() { + let _lock = SETTINGS_LOCK.lock().await; + let lm = configure_test_lm(vec![ + "print('turn1')".to_string(), + "print('turn2')".to_string(), + ]) + .await; + + let rlm = Rlm::<RlmFallbackSig>::builder() + .runtime(Arc::new(PyO3Runtime)) + .sub_lm(Arc::new(lm)) + .max_iterations(2) + 
.enable_extraction_fallback(false) + .build(); + + let err = rlm + .call(RlmFallbackSigInput { + prompt: "never submit".to_string(), + }) + .await + .expect_err("expected max-iteration failure when fallback is disabled"); + match err { + PredictError::Module { source, .. } => { + assert!( + source.to_string().contains("max iterations reached (2)"), + "unexpected error: {source}" + ); + } + other => panic!("expected module error, got: {other}"), + } +} diff --git a/crates/dspy-rs/tests/test_rlm_integration_demo.rs b/crates/dspy-rs/tests/test_rlm_integration_demo.rs new file mode 100644 index 00000000..cd3778d6 --- /dev/null +++ b/crates/dspy-rs/tests/test_rlm_integration_demo.rs @@ -0,0 +1,336 @@ +#![cfg(feature = "rlm")] +#![allow(legacy_derive_helpers)] + +use dspy_rs::modules::rlm::PyO3Runtime; +use dspy_rs::{ + ChatAdapter, LM, LMClient, Rlm, Signature, TestCompletionModel, configure, rlm_type, +}; +use rig::completion::AssistantContent; +use rig::message::Text; +use std::sync::{Arc, LazyLock}; +use tokio::sync::Mutex; + +use dspy_rs::__macro_support::pyo3; +use pyo3::types::{PyAnyMethods, PyDict}; +use pyo3::{IntoPyObjectExt, Python}; + +static SETTINGS_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); + +fn text_response(text: impl Into) -> AssistantContent { + AssistantContent::Text(Text { text: text.into() }) +} + +async fn build_test_lm_with_client(responses: Vec) -> (LM, TestCompletionModel) { + let client = TestCompletionModel::new(responses.into_iter().map(text_response)); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .temperature(0.0) + .build(), + ) + .await + .expect("build lm") + .with_client(LMClient::Test(client.clone())) + .await + .expect("install test client"); + (lm, client) +} + +async fn configure_test_lm_with_client(responses: Vec) -> (LM, TestCompletionModel) { + let (lm, client) = build_test_lm_with_client(responses).await; + configure(lm.clone(), 
ChatAdapter::new()); + (lm, client) +} + +#[rlm_type] +#[rlm(iter = "keywords", index = "keywords")] +#[derive(Clone, Debug)] +struct Paper { + /// Paper title. + title: String, + /// Abstract body. + abstract_text: String, + /// Publication year. + year: i32, + /// Search keywords. + keywords: Vec, + #[rlm(skip_python)] + internal_rank: i32, +} + +#[derive(Signature, Clone, Debug)] +/// Find the most relevant papers for the query. +struct PaperSearch { + #[input] + papers: Vec, + #[input] + query: String, + #[output] + relevant_titles: Vec, + #[output] + reasoning: String, +} + +#[derive(Signature, Clone, Debug)] +/// Render-only signature for previewing a single paper object. +struct PaperPreviewSig { + #[input] + paper: Paper, + #[output] + ok: bool, +} + +fn demo_papers() -> Vec { + vec![ + Paper { + title: "Intro to Rust for LLMs".to_string(), + abstract_text: "Typed pipelines for model programs.".to_string(), + year: 2024, + keywords: vec!["rust".to_string(), "llm".to_string()], + internal_rank: 1, + }, + Paper { + title: "Graph Reasoning at Scale".to_string(), + abstract_text: "Large context retrieval and synthesis.".to_string(), + year: 2023, + keywords: vec!["graph".to_string(), "retrieval".to_string()], + internal_rank: 2, + }, + ] +} + +#[test] +fn demo_signature_generates_correct_pyclass_methods() { + Python::attach(|py| { + let paper = demo_papers().remove(0); + let py_obj = paper + .clone() + .into_py_any(py) + .expect("Paper should convert to native PyO3 object"); + let bound = py_obj.bind(py); + + assert!( + !bound.is_instance_of::(), + "Paper must inject as native object, not dict" + ); + assert_eq!( + bound + .getattr("title") + .expect("title getter") + .extract::() + .expect("title extract"), + paper.title + ); + assert!( + !bound + .hasattr("internal_rank") + .expect("hasattr internal_rank"), + "skip_python fields must not be exposed as Python attributes" + ); + + let repr = bound + .repr() + .expect("repr") + .extract::() + .expect("repr 
string"); + assert!(repr.contains("Paper")); + + let baml = bound.call_method0("__baml__").expect("__baml__ call"); + assert!(baml.is_instance_of::()); + let baml_dict = baml.cast::().expect("__baml__ returns dict"); + assert_eq!( + baml_dict + .get_item("title") + .expect("title get_item") + .extract::() + .expect("title value"), + "Intro to Rust for LLMs" + ); + assert_eq!( + bound + .call_method0("__len__") + .expect("len call") + .extract::() + .expect("len value"), + 2 + ); + }); +} + +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn integration_preview_matches_spec_and_prompt_is_clean() { + let _lock = SETTINGS_LOCK.lock().await; + let (lm, client) = configure_test_lm_with_client(vec![ + "dict_style_ok = True\ntry:\n _ = papers[0]['title']\nexcept Exception:\n dict_style_ok = False\nreason = f\"{type(papers[0]).__name__}:{papers[0].title}:dict={dict_style_ok}\"\nSUBMIT(relevant_titles=[papers[0].title], reasoning=reason)".to_string(), + ]) + .await; + + let rlm = Rlm::::builder() + .runtime(Arc::new(PyO3Runtime)) + .sub_lm(Arc::new(lm)) + .max_iterations(1) + .enable_extraction_fallback(false) + .build(); + + let predicted = rlm + .call(PaperSearchInput { + papers: demo_papers(), + query: "rust typed model pipelines".to_string(), + }) + .await + .expect("RLM demo call should complete"); + + assert_eq!(predicted.relevant_titles, vec!["Intro to Rust for LLMs"]); + assert!( + predicted + .reasoning + .contains("Paper:Intro to Rust for LLMs:dict=False"), + "reasoning should confirm native object access and dict-style failure" + ); + + let request = client + .last_request() + .expect("expected action turn request capture"); + let request_debug = format!("{request:?}"); + + assert!( + request_debug.contains("## Task"), + "system prompt should include Task section" + ); + assert!( + request_debug.contains("Find the most relevant papers for the query."), + "system prompt should include developer instruction" + ); + 
assert!( + !request_debug.contains("Your input fields are:"), + "adapter wrapping should be absent" + ); + assert!( + !request_debug.contains("Your objective is:"), + "adapter wrapping should be absent" + ); + + assert!( + request_debug.contains("## Input Variables"), + "{request_debug}" + ); + assert!( + request_debug.contains("Variable: `papers` (access it in your code)"), + "{request_debug}" + ); + assert!(request_debug.contains("title: string"), "{request_debug}"); + assert!( + request_debug.contains("=== Execution Receipt (Turn 1) ==="), + "{request_debug}" + ); + assert!( + request_debug.contains("Budget: 1 turn remaining |"), + "{request_debug}" + ); + assert!(request_debug.contains("[query]"), "{request_debug}"); + assert!( + request_debug.contains("=== Namespace ==="), + "{request_debug}" + ); + assert!(request_debug.contains("[Injected]"), "{request_debug}"); + assert!(request_debug.contains("[Recent]"), "{request_debug}"); + assert!(request_debug.contains(">>>"), "{request_debug}"); + assert!( + !request_debug.contains("__baml__"), + "preview should hide __baml__" + ); +} + +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn integration_preview_shows_paper_fields_and_methods() { + let _lock = SETTINGS_LOCK.lock().await; + let (lm, client) = configure_test_lm_with_client(vec!["SUBMIT(ok=True)".to_string()]).await; + + let rlm = Rlm::::builder() + .runtime(Arc::new(PyO3Runtime)) + .sub_lm(Arc::new(lm)) + .max_iterations(1) + .enable_extraction_fallback(false) + .build(); + + let predicted = rlm + .call(PaperPreviewSigInput { + paper: demo_papers().remove(0), + }) + .await + .expect("preview demo should submit"); + assert!(predicted.ok); + + let request = client + .last_request() + .expect("expected preview request capture"); + let request_debug = format!("{request:?}"); + assert!( + request_debug.contains("Variable: `paper` (access it in your code)"), + "{request_debug}" + ); + 
assert!(request_debug.contains("title: string"), "{request_debug}"); + assert!( + request_debug.contains("## Input Variables"), + "{request_debug}" + ); + assert!( + !request_debug.contains("Methods:"), + "legacy methods block should not appear in new schema format" + ); + assert!( + !request_debug.contains(".__len__("), + "dunder methods should not appear in schema-facing method surface" + ); + assert!( + !request_debug.contains("__baml__"), + "preview should hide __baml__" + ); +} + +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn submit_validation_errors_are_pythonic_not_baml_internal() { + let _lock = SETTINGS_LOCK.lock().await; + let (lm, client) = configure_test_lm_with_client(vec![ + "SUBMIT(relevant_titles=123, reasoning=5)".to_string(), + "SUBMIT(relevant_titles=['Intro to Rust for LLMs'], reasoning='fixed')".to_string(), + ]) + .await; + + let rlm = Rlm::::builder() + .runtime(Arc::new(PyO3Runtime)) + .sub_lm(Arc::new(lm)) + .max_iterations(2) + .enable_extraction_fallback(false) + .build(); + + let predicted = rlm + .call(PaperSearchInput { + papers: demo_papers(), + query: "rust".to_string(), + }) + .await + .expect("second submit should succeed after validation feedback"); + + assert_eq!(predicted.relevant_titles, vec!["Intro to Rust for LLMs"]); + assert_eq!(predicted.reasoning, "fixed"); + + let request = client + .last_request() + .expect("expected second-turn request capture"); + let request_debug = format!("{request:?}"); + + assert!( + request_debug.contains("SubmitError: Validation failed"), + "{request_debug}" + ); + assert!(request_debug.contains("got python"), "{request_debug}"); + assert!( + !request_debug.contains("BamlValue::"), + "feedback should not leak Baml internal type names" + ); +} diff --git a/crates/dspy-rs/tests/test_rlm_live_openai_gpt52.rs b/crates/dspy-rs/tests/test_rlm_live_openai_gpt52.rs new file mode 100644 index 00000000..389c3e3b --- /dev/null +++ 
b/crates/dspy-rs/tests/test_rlm_live_openai_gpt52.rs @@ -0,0 +1,59 @@ +#![cfg(feature = "rlm")] + +use dspy_rs::modules::rlm::PyO3Runtime; +use dspy_rs::{ChatAdapter, LM, Rlm, Signature, configure}; +use std::sync::{Arc, LazyLock}; +use tokio::sync::Mutex; + +static SETTINGS_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); + +#[derive(Signature, Clone, Debug, PartialEq)] +/// Return executable Python only and call SUBMIT with the final typed answer. +struct LiveMathProblem { + #[input] + problem: String, + + #[output] + answer: i64, +} + +#[tokio::test] +#[ignore] // Requires network access + OPENAI_API_KEY +async fn live_rlm_v1_openai_responses_gpt52_end_to_end() { + let _lock = SETTINGS_LOCK.lock().await; + let _ = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set for live test"); + + let lm = LM::builder() + .model("openai-responses:gpt-5.2".to_string()) + .temperature(0.0) + .max_tokens(512) + .build() + .await + .expect("failed to build live LM"); + + configure(lm.clone(), ChatAdapter::new()); + + let rlm = Rlm::::builder() + .runtime(Arc::new(PyO3Runtime)) + .sub_lm(Arc::new(lm)) + .max_iterations(12) + .max_llm_calls(8) + .enable_extraction_fallback(false) + .build(); + + let predicted = rlm + .call(LiveMathProblemInput { + problem: + "Compute 12 * 13. Respond with executable Python only (no markdown, no prose). \ +Immediately call SUBMIT(answer=). 
Example shape:\nresult = 12 * 13\nSUBMIT(answer=result)" + .to_string(), + }) + .await + .expect("live RLM call failed"); + + assert_eq!(predicted.answer, 156); + assert!( + predicted.metadata().raw_response.contains("SUBMIT("), + "expected SUBMIT path evidence in raw response" + ); +} diff --git a/crates/dspy-rs/tests/test_rlm_loop_integration.rs b/crates/dspy-rs/tests/test_rlm_loop_integration.rs new file mode 100644 index 00000000..553c4db5 --- /dev/null +++ b/crates/dspy-rs/tests/test_rlm_loop_integration.rs @@ -0,0 +1,246 @@ +#![cfg(feature = "rlm")] + +use dspy_rs::modules::rlm::PyO3Runtime; +use dspy_rs::{ChatAdapter, LM, LMClient, Rlm, Signature, TestCompletionModel, configure}; +use rig::completion::AssistantContent; +use rig::message::Text; +use std::sync::{Arc, LazyLock}; +use tokio::sync::Mutex; + +static SETTINGS_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); + +fn text_response(text: impl Into) -> AssistantContent { + AssistantContent::Text(Text { text: text.into() }) +} + +async fn build_test_lm_with_client(responses: Vec) -> (LM, TestCompletionModel) { + let client = TestCompletionModel::new(responses.into_iter().map(text_response)); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .temperature(0.0) + .build(), + ) + .await + .expect("build lm") + .with_client(LMClient::Test(client.clone())) + .await + .expect("install test client"); + (lm, client) +} + +async fn configure_test_lm(responses: Vec) -> LM { + let (lm, _) = build_test_lm_with_client(responses).await; + configure(lm.clone(), ChatAdapter::new()); + lm +} + +async fn configure_test_lm_with_client(responses: Vec) -> (LM, TestCompletionModel) { + let (lm, client) = build_test_lm_with_client(responses).await; + configure(lm.clone(), ChatAdapter::new()); + (lm, client) +} + +#[derive(Signature, Clone, Debug, PartialEq)] +struct RlmLoopSig { + #[input] + prompt: String, + #[output] + answer: String, +} 
+ +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn rlm_recovers_from_empty_action_then_submits() { + let _lock = SETTINGS_LOCK.lock().await; + let lm = configure_test_lm(vec![ + String::new(), + "SUBMIT(answer='recovered')".to_string(), + ]) + .await; + + let rlm = Rlm::<RlmLoopSig>::builder() + .runtime(Arc::new(PyO3Runtime)) + .sub_lm(Arc::new(lm)) + .max_iterations(3) + .enable_extraction_fallback(false) + .build(); + + let predicted = rlm + .call(RlmLoopSigInput { + prompt: "return recovered".to_string(), + }) + .await + .expect("rlm call should recover and submit"); + + assert_eq!(predicted.answer, "recovered"); + assert!( + predicted + .metadata() + .raw_response + .contains("SUBMIT(answer='recovered')") + ); +} + +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn rlm_invalid_submit_retries_then_accepts_valid_submit() { + let _lock = SETTINGS_LOCK.lock().await; + let lm = configure_test_lm(vec![ + "SUBMIT(answer=123)".to_string(), + "SUBMIT(answer='fixed')".to_string(), + ]) + .await; + + let rlm = Rlm::<RlmLoopSig>::builder() + .runtime(Arc::new(PyO3Runtime)) + .sub_lm(Arc::new(lm)) + .max_iterations(3) + .enable_extraction_fallback(false) + .build(); + + let predicted = rlm + .call(RlmLoopSigInput { + prompt: "return fixed".to_string(), + }) + .await + .expect("rlm call should retry after invalid submit"); + + assert_eq!(predicted.answer, "fixed"); + assert!( + predicted + .metadata() + .raw_response + .contains("SUBMIT(answer=123)") + ); + assert!( + predicted + .metadata() + .raw_response + .contains("SUBMIT(answer='fixed')") + ); +} + +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn rlm_v3_demo_recovers_empty_then_python_error_then_finalization_submit() { + let _lock = SETTINGS_LOCK.lock().await; + let (lm, client) = configure_test_lm_with_client(vec![ + String::new(), + "if True print('x')".to_string(), + 
"SUBMIT(answer='finalized')".to_string(), + ]) + .await; + + let rlm = Rlm::::builder() + .runtime(Arc::new(PyO3Runtime)) + .sub_lm(Arc::new(lm)) + .max_iterations(3) + .max_output_chars(500) + .enable_extraction_fallback(false) + .build(); + + let predicted = rlm + .call(RlmLoopSigInput { + prompt: "finalize with best answer".to_string(), + }) + .await + .expect("rlm should recover and submit on finalization turn"); + assert_eq!(predicted.answer, "finalized"); + + let last_request = client + .last_request() + .expect("expected final request to be captured"); + let request_debug = format!("{last_request:?}"); + assert!( + request_debug.contains("SyntaxError"), + "finalization turn should include prior python error feedback" + ); + assert!( + request_debug.contains("⚠ LAST TURN — you MUST call SUBMIT() now with your best answer."), + "finalization directive should be present on last repair turn" + ); +} + +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] +#[tokio::test] +async fn rlm_feedback_carries_truncation_marker_with_configured_budget() { + let _lock = SETTINGS_LOCK.lock().await; + let (lm, client) = configure_test_lm_with_client(vec![ + "print('abcdefghijklmnopqrstuvwxyz0123456789')".to_string(), + "SUBMIT(answer='done')".to_string(), + ]) + .await; + + let rlm = Rlm::::builder() + .runtime(Arc::new(PyO3Runtime)) + .sub_lm(Arc::new(lm)) + .max_iterations(2) + .max_output_chars(10) + .enable_extraction_fallback(false) + .build(); + + let predicted = rlm + .call(RlmLoopSigInput { + prompt: "test truncation".to_string(), + }) + .await + .expect("rlm should truncate feedback and still submit"); + assert_eq!(predicted.answer, "done"); + + let last_request = client + .last_request() + .expect("expected request carrying truncated feedback"); + let request_debug = format!("{last_request:?}"); + assert!(request_debug.contains("[STDOUT TRUNCATED at 10 chars (")); +} + +#[cfg_attr(miri, ignore = "MIRI has issues with tokio's I/O driver")] 
+#[tokio::test(flavor = "multi_thread")] +async fn rlm_sub_lm_tools_persist_state_and_decrement_budget_across_turns() { + let _lock = SETTINGS_LOCK.lock().await; + let (_action_lm, action_client) = configure_test_lm_with_client(vec![ + "single = llm_query('single')\nbatch = llm_query_batched(['left', 'right'])\ncounter = 40 + len(batch)".to_string(), + "try:\n llm_query('should_fail')\n budget_state = 'not_exhausted'\nexcept Exception as err:\n budget_state = 'exhausted' if 'budget exhausted' in str(err) else f'unexpected:{err}'\nSUBMIT(answer=f'{counter}:{budget_state}:{single}')".to_string(), + ]) + .await; + let (sub_lm, _) = build_test_lm_with_client(vec![ + "single-ok".to_string(), + "batch-a".to_string(), + "batch-b".to_string(), + ]) + .await; + + let rlm = Rlm::::builder() + .runtime(Arc::new(PyO3Runtime)) + .sub_lm(Arc::new(sub_lm)) + .max_iterations(2) + .max_llm_calls(3) + .enable_extraction_fallback(false) + .build(); + + let predicted = rlm + .call(RlmLoopSigInput { + prompt: "Use both sub-LM helpers, then submit on turn two.".to_string(), + }) + .await + .expect("rlm should complete with persisted state and enforced budget"); + + assert_eq!(predicted.answer, "42:exhausted:single-ok"); + assert!( + predicted + .metadata() + .raw_response + .contains("llm_query_batched") + ); + + let last_request = action_client + .last_request() + .expect("expected second-turn request with feedback"); + let request_debug = format!("{last_request:?}"); + assert!( + request_debug.contains("Budget: 1 turn remaining | 0 sub-LLM calls remaining"), + "second turn should see depleted sub-LM budget" + ); +} diff --git a/crates/dspy-rs/tests/test_settings.rs b/crates/dspy-rs/tests/test_settings.rs index 2b2bea2d..1c940f1a 100644 --- a/crates/dspy-rs/tests/test_settings.rs +++ b/crates/dspy-rs/tests/test_settings.rs @@ -3,31 +3,27 @@ use dspy_rs::{ChatAdapter, LM, configure, get_lm}; #[tokio::test] #[cfg_attr(miri, ignore)] async fn test_settings() { - unsafe { - 
std::env::set_var("OPENAI_API_KEY", "test"); - } - configure( + let lm1 = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], LM::builder() .model("openai:gpt-4o-mini".to_string()) - .build() - .await - .unwrap(), - ChatAdapter {}, - ); + .build(), + ) + .await + .unwrap(); + configure(lm1, ChatAdapter::new()); let lm = get_lm(); assert_eq!(lm.model, "openai:gpt-4o-mini"); - configure( - LM::builder() - .model("openai:gpt-4o".to_string()) - .build() - .await - .unwrap(), - ChatAdapter {}, - ); + let lm2 = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder().model("openai:gpt-4o".to_string()).build(), + ) + .await + .unwrap(); + configure(lm2, ChatAdapter::new()); let lm = get_lm(); - assert_eq!(lm.model, "openai:gpt-4o"); } diff --git a/crates/dspy-rs/tests/test_tool_call.rs b/crates/dspy-rs/tests/test_tool_call.rs index 566f36f5..8be6782e 100644 --- a/crates/dspy-rs/tests/test_tool_call.rs +++ b/crates/dspy-rs/tests/test_tool_call.rs @@ -1,4 +1,4 @@ -use dspy_rs::{Chat, LM, Message}; +use dspy_rs::{Chat, LM, Message, ToolLoopMode}; use rig::completion::ToolDefinition; use rig::tool::ToolDyn; use std::error::Error; @@ -99,7 +99,7 @@ async fn test_tool_call_with_no_tools() { chat.push_message(Message::user("What is 2 + 2?")); // Call without tools - let response = lm.call(chat, vec![]).await; + let response = lm.call(chat, vec![], ToolLoopMode::Auto).await; // Should get a text response (or network error if no real API key) if let Err(e) = &response { @@ -108,13 +108,10 @@ async fn test_tool_call_with_no_tools() { } let response = response.unwrap(); - match response.output { - Message::Assistant { content } => { - // The response should contain some mention of 4 - println!("Assistant response: {}", content); - } - _ => panic!("Expected assistant message"), - } + assert_eq!(response.output.role, dspy_rs::Role::Assistant); + let content = response.output.content(); + // The response should contain some mention of 4 + 
println!("Assistant response: {}", content); } #[tokio::test] @@ -138,14 +135,11 @@ async fn test_tool_call_with_calculator() { let tools: Vec> = vec![Arc::new(calculator)]; // Call with the calculator tool - let response = lm.call(chat, tools).await.unwrap(); + let response = lm.call(chat, tools, ToolLoopMode::Auto).await.unwrap(); - match response.output { - Message::Assistant { content } => { - println!("Assistant response after tool use: {}", content); - // The response should mention the result (100) or that the tool was called - assert!(content.contains("100") || content.contains("Tool call")); - } - _ => panic!("Expected assistant message"), - } + assert_eq!(response.output.role, dspy_rs::Role::Assistant); + let content = response.output.content(); + println!("Assistant response after tool use: {}", content); + // The response should mention the result (100) or that the tool was called + assert!(content.contains("100") || content.contains("Tool call")); } diff --git a/crates/dspy-rs/tests/test_typed_alias.rs b/crates/dspy-rs/tests/test_typed_alias.rs index 55118527..ef447031 100644 --- a/crates/dspy-rs/tests/test_typed_alias.rs +++ b/crates/dspy-rs/tests/test_typed_alias.rs @@ -14,7 +14,7 @@ struct AliasSignature { #[test] fn typed_alias_is_used_in_prompt_and_user_message() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let system = adapter .format_system_message_typed::() .expect("system message"); @@ -37,7 +37,7 @@ fn typed_alias_is_used_in_prompt_and_user_message() { #[test] fn typed_alias_parses_output_and_maps_to_rust_name() { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); let response = Message::assistant("[[ ## final_answer ## ]]\nHi\n\n[[ ## completed ## ]]"); let (output, metas) = adapter .parse_response_typed::(&response) diff --git a/crates/dspy-rs/tests/test_typed_prompt_format.rs b/crates/dspy-rs/tests/test_typed_prompt_format.rs index c8e9f7dd..8735149a 100644 --- 
a/crates/dspy-rs/tests/test_typed_prompt_format.rs +++ b/crates/dspy-rs/tests/test_typed_prompt_format.rs @@ -49,7 +49,7 @@ struct ComprehensiveSignature { } fn system_message() -> String { - let adapter = ChatAdapter; + let adapter = ChatAdapter::new(); adapter .format_system_message_typed::() .expect("system message") diff --git a/crates/dspy-rs/tests/typed_integration.rs b/crates/dspy-rs/tests/typed_integration.rs index 6b6f7968..f662d25f 100644 --- a/crates/dspy-rs/tests/typed_integration.rs +++ b/crates/dspy-rs/tests/typed_integration.rs @@ -23,21 +23,20 @@ fn text_response(text: impl Into) -> AssistantContent { } async fn configure_test_lm(responses: Vec) -> TestCompletionModel { - unsafe { - std::env::set_var("OPENAI_API_KEY", "test"); - } - let client = TestCompletionModel::new(responses.into_iter().map(text_response)); - let lm = LM::builder() - .model("openai:gpt-4o-mini".to_string()) - .build() - .await - .unwrap() - .with_client(LMClient::Test(client.clone())) - .await - .unwrap(); - - configure(lm, ChatAdapter {}); + let lm = temp_env::async_with_vars( + [("OPENAI_API_KEY", Some("test"))], + LM::builder() + .model("openai:gpt-4o-mini".to_string()) + .build(), + ) + .await + .unwrap() + .with_client(LMClient::Test(client.clone())) + .await + .unwrap(); + + configure(lm, ChatAdapter::new()); client } diff --git a/crates/dsrs-macros/Cargo.toml b/crates/dsrs-macros/Cargo.toml index 4b666638..06eb82c7 100644 --- a/crates/dsrs-macros/Cargo.toml +++ b/crates/dsrs-macros/Cargo.toml @@ -13,6 +13,10 @@ license = "Apache-2.0" [lib] proc-macro = true +[features] +default = [] +rlm = [] + [dependencies] syn = { version = "2", features = ["full"] } quote = "1" diff --git a/crates/dsrs-macros/src/lib.rs b/crates/dsrs-macros/src/lib.rs index 320d521c..d197fc51 100644 --- a/crates/dsrs-macros/src/lib.rs +++ b/crates/dsrs-macros/src/lib.rs @@ -6,7 +6,6 @@ use syn::{ Token, Visibility, parse::{Parse, ParseStream}, parse_macro_input, - spanned::Spanned, visit::Visit, }; 
@@ -268,7 +267,7 @@ fn parse_single_field(field: &syn::Field) -> syn::Result { )); } let template = parse_render_jinja_attr(attr)?; - validate_jinja_template(&template, attr.span())?; + validate_jinja_template(&template, attr)?; render_jinja = Some(template); } else if attr.path().is_ident("flatten") { if saw_flatten { @@ -367,7 +366,7 @@ fn parse_desc_from_attr(attr: &Attribute, attr_name: &str) -> syn::Result