From ba28ea03c67e28456f2baf2a83190c4b6da7e082 Mon Sep 17 00:00:00 2001 From: Will Pfleger Date: Tue, 9 Jun 2026 13:54:40 -0400 Subject: [PATCH] feat(desktop): provider/model selection for personas and runtime-aware env injection Add structured provider and model fields to PersonaRecord, with runtime-aware env injection at spawn and deploy time. Pack import filters derived provider/model env vars to prevent stale env values from shadowing the structured fields. Signed-off-by: Will Pfleger --- crates/sprout-acp/src/acp.rs | 2 +- crates/sprout-acp/src/config.rs | 4 +- crates/sprout-agent/src/config.rs | 46 ++- crates/sprout-cli/src/commands/pack.rs | 4 +- crates/sprout-persona/src/resolve.rs | 89 +++- crates/sprout-persona/tests/e2e_env_flow.rs | 383 ++++++++++++++++++ crates/sprout-persona/tests/integration.rs | 4 +- desktop/scripts/check-file-sizes.mjs | 9 +- desktop/src-tauri/src/commands/agents.rs | 20 +- desktop/src-tauri/src/commands/personas.rs | 8 +- .../src-tauri/src/managed_agents/discovery.rs | 11 +- .../src-tauri/src/managed_agents/env_vars.rs | 41 ++ .../src/managed_agents/env_vars/tests.rs | 90 +++- desktop/src-tauri/src/managed_agents/nest.rs | 1 + .../src/managed_agents/persona_card.rs | 55 ++- .../src-tauri/src/managed_agents/personas.rs | 8 +- .../src/managed_agents/personas/tests.rs | 1 + .../src-tauri/src/managed_agents/runtime.rs | 136 +++++-- desktop/src-tauri/src/managed_agents/teams.rs | 1 + desktop/src-tauri/src/managed_agents/types.rs | 9 + desktop/src-tauri/src/migration.rs | 168 ++++++++ .../src/features/agents/ui/PersonaDialog.tsx | 25 ++ .../agents/ui/personaDialogState.test.mjs | 84 ++++ .../features/agents/ui/personaDialogState.ts | 3 + desktop/src/shared/api/tauriPersonas.ts | 7 + desktop/src/shared/api/types.ts | 4 + 26 files changed, 1136 insertions(+), 77 deletions(-) create mode 100644 crates/sprout-persona/tests/e2e_env_flow.rs diff --git a/crates/sprout-acp/src/acp.rs b/crates/sprout-acp/src/acp.rs index 5cdba8ce8..4a58f2767 100644 --- a/crates/sprout-acp/src/acp.rs +++ b/crates/sprout-acp/src/acp.rs @@ -202,7 +202,7 @@ impl AcpClient { // Callers MUST still call shutdown().await for guaranteed cleanup. .kill_on_drop(true); - // Per-persona env vars (e.g., GOOSE_PROVIDER, GOOSE_MODEL). + // Per-persona env vars (e.g., GOOSE_PROVIDER, SPROUT_AGENT_PROVIDER). // Only injected if not already set in parent env (operator precedence). for (key, value) in extra_env { if std::env::var(key).is_err() { diff --git a/crates/sprout-acp/src/config.rs b/crates/sprout-acp/src/config.rs index d9a6abbf1..7cf2d4ca9 100644 --- a/crates/sprout-acp/src/config.rs +++ b/crates/sprout-acp/src/config.rs @@ -462,7 +462,7 @@ pub struct Config { pub respond_to: RespondTo, /// Validated allowlist of pubkey hex strings (used when respond_to == Allowlist). pub respond_to_allowlist: HashSet, - /// Per-persona env vars to inject at agent spawn time (e.g., GOOSE_PROVIDER, GOOSE_MODEL). + /// Per-persona env vars to inject at agent spawn time (e.g., GOOSE_PROVIDER, GOOSE_MODEL, SPROUT_AGENT_MODEL). /// Populated from persona pack resolution. Empty when no pack is configured. pub persona_env_vars: Vec<(String, String)>, /// Whether to publish encrypted observer frames through the relay. @@ -760,7 +760,7 @@ impl Config { ( Some(persona.system_prompt), persona.model, - persona.goose_env_vars, + persona.runtime_env_vars, ) } (Some(_), None) => { diff --git a/crates/sprout-agent/src/config.rs b/crates/sprout-agent/src/config.rs index c67016da6..f4ede30d0 100644 --- a/crates/sprout-agent/src/config.rs +++ b/crates/sprout-agent/src/config.rs @@ -93,6 +93,12 @@ impl Config { databricks_host.as_deref(), databricks_model.as_deref(), )?; + + // Universal model override — any provider will use this when its own + // model env var is absent. Useful for wrapper scripts that set a single + // var regardless of which provider is active. + let sprout_agent_model = env("SPROUT_AGENT_MODEL"); + // OPENAI_COMPAT_API is only read when provider=openai, so a stray // bad value can't break an Anthropic-only deployment. // @@ -102,19 +108,28 @@ impl Config { let (api_key, model, base_url, openai_api) = match provider { Provider::Anthropic => ( req("ANTHROPIC_API_KEY")?, - req("ANTHROPIC_MODEL")?, + resolve_model( + env("ANTHROPIC_MODEL").as_deref(), + sprout_agent_model.as_deref(), + ) + .ok_or_else(|| "config: ANTHROPIC_MODEL required".to_string())?, env_or("ANTHROPIC_BASE_URL", "https://api.anthropic.com"), OpenAiApi::Auto, // unused for Anthropic ), Provider::OpenAi => ( req("OPENAI_COMPAT_API_KEY")?, - req("OPENAI_COMPAT_MODEL")?, + resolve_model( + env("OPENAI_COMPAT_MODEL").as_deref(), + sprout_agent_model.as_deref(), + ) + .ok_or_else(|| "config: OPENAI_COMPAT_MODEL required".to_string())?, env_or("OPENAI_COMPAT_BASE_URL", "https://api.openai.com/v1"), parse_openai_api(env("OPENAI_COMPAT_API").as_deref())?, ), Provider::Databricks => ( env("DATABRICKS_TOKEN").unwrap_or_default(), - databricks_model.ok_or_else(|| "config: DATABRICKS_MODEL required".to_string())?, + resolve_model(databricks_model.as_deref(), sprout_agent_model.as_deref()) + .ok_or_else(|| "config: DATABRICKS_MODEL required".to_string())?, databricks_host.ok_or_else(|| "config: DATABRICKS_HOST required".to_string())?, OpenAiApi::Chat, // Databricks invocations is chat-shaped ), @@ -232,6 +247,13 @@ fn req(k: &str) -> Result { env(k).ok_or_else(|| format!("config: {k} required")) } +/// Returns the first of `provider_model` or `universal_fallback` that is +/// `Some`, converting to an owned `String`. Returns `None` when both are +/// absent so the caller can supply a provider-specific error message. +fn resolve_model(provider_model: Option<&str>, universal_fallback: Option<&str>) -> Option { + provider_model.or(universal_fallback).map(str::to_owned) +} + fn present_nonempty(v: Option<&str>) -> bool { v.map(str::trim).is_some_and(|s| !s.is_empty()) } @@ -596,4 +618,22 @@ mod tests { assert_eq!(is_openai_host(url), want, "url={url}"); } } + + #[test] + fn resolve_model_prefers_provider_specific() { + let result = resolve_model(Some("anthropic-model"), Some("universal-model")); + assert_eq!(result.as_deref(), Some("anthropic-model")); + } + + #[test] + fn resolve_model_falls_back_to_universal() { + let result = resolve_model(None, Some("universal-model")); + assert_eq!(result.as_deref(), Some("universal-model")); + } + + #[test] + fn resolve_model_returns_none_when_both_absent() { + let result = resolve_model(None, None); + assert!(result.is_none()); + } } diff --git a/crates/sprout-cli/src/commands/pack.rs b/crates/sprout-cli/src/commands/pack.rs index e76cabdf8..de5f82fff 100644 --- a/crates/sprout-cli/src/commands/pack.rs +++ b/crates/sprout-cli/src/commands/pack.rs @@ -136,9 +136,9 @@ pub fn cmd_inspect(path: &str) -> Result<(), CliError> { prompt_preview.replace('\n', " ") ); - if !persona.goose_env_vars.is_empty() { + if !persona.runtime_env_vars.is_empty() { let env_str: Vec = persona - .goose_env_vars + .runtime_env_vars .iter() .map(|(k, v)| format!("{k}={v}")) .collect(); diff --git a/crates/sprout-persona/src/resolve.rs b/crates/sprout-persona/src/resolve.rs index dcc6e31af..47782a59a 100644 --- a/crates/sprout-persona/src/resolve.rs +++ b/crates/sprout-persona/src/resolve.rs @@ -60,8 +60,8 @@ pub struct ResolvedPersona { // Skills (bare names — reserved for future use, not yet wired) pub skills: Vec, - // Env var projection for goose subprocess - pub goose_env_vars: Vec<(String, String)>, + // Env var projection for agent subprocess + pub runtime_env_vars: Vec<(String, String)>, } /// An MCP server with env values as literals (no interpolation in this PR). @@ -106,7 +106,7 @@ pub struct ResolvedPack { /// Returns a `ResolvedPack` with fully typed, ACP-ready output for each /// persona. All merge policy (levels 3-5) is applied. MCP servers are /// merged with literal env passthrough (no `${VAR}` interpolation). -/// Goose env vars are projected from model/temperature/context config. +/// Env vars are projected from model/temperature/context config. pub fn resolve_pack(pack_dir: &Path) -> Result { let loaded = pack::load_pack(pack_dir)?; resolve_loaded_pack(&loaded) @@ -218,7 +218,7 @@ fn resolve_one_persona( let triggers = resolve_triggers(lp.triggers.as_ref()); let mcp_servers = merge_mcp_servers(shared_mcp, &lp.mcp_servers); let hooks = resolve_hooks(lp.hooks.as_ref()); - let goose_env_vars = project_env_vars(lp); + let runtime_env_vars = runtime_env_vars(lp); // Version: LoadedPersona has no per-persona version field — persona files // don't declare a version in frontmatter. The pack version is used as-is. @@ -246,7 +246,7 @@ fn resolve_one_persona( mcp_servers, hooks, skills: lp.skills.clone(), - goose_env_vars, + runtime_env_vars, } } @@ -381,17 +381,30 @@ fn resolve_hooks(hooks: Option<&crate::merge::HooksData>) -> Option Vec<(String, String)> { +fn runtime_env_vars(persona: &LoadedPersona) -> Vec<(String, String)> { let mut vars = Vec::new(); + let runtime = persona.runtime.as_deref(); if let Some(model_str) = &persona.model { let (provider, model_id) = split_model(model_str); - if let Some(p) = provider { - vars.push(("GOOSE_PROVIDER".to_owned(), p.to_owned())); + + match runtime { + Some("sprout-agent") => { + vars.push(("SPROUT_AGENT_MODEL".to_owned(), model_id.to_owned())); + if let Some(p) = provider { + vars.push(("SPROUT_AGENT_PROVIDER".to_owned(), p.to_owned())); + } + } + _ => { + if let Some(p) = provider { + vars.push(("GOOSE_PROVIDER".to_owned(), p.to_owned())); + } + vars.push(("GOOSE_MODEL".to_owned(), model_id.to_owned())); + } } - vars.push(("GOOSE_MODEL".to_owned(), model_id.to_owned())); } + // temperature and context_limit stay as GOOSE_* (only goose reads them) if let Some(temp) = persona.temperature { vars.push(("GOOSE_TEMPERATURE".to_owned(), temp.to_string())); } @@ -570,12 +583,12 @@ mod tests { assert!(resolve_hooks(None).is_none()); } - // ── project_env_vars ────────────────────────────────────────────────── + // ── runtime_env_vars ────────────────────────────────────────────────── #[test] fn env_vars_projected_from_model() { let lp = stub_persona(Some("anthropic:claude-sonnet-4-20250514"), None, None); - let vars = project_env_vars(&lp); + let vars = runtime_env_vars(&lp); let map: HashMap<&str, &str> = vars.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect(); assert_eq!(map["GOOSE_PROVIDER"], "anthropic"); assert_eq!(map["GOOSE_MODEL"], "claude-sonnet-4-20250514"); @@ -584,7 +597,7 @@ mod tests { #[test] fn env_vars_model_without_provider() { let lp = stub_persona(Some("gpt-4o"), None, None); - let vars = project_env_vars(&lp); + let vars = runtime_env_vars(&lp); let map: HashMap<&str, &str> = vars.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect(); assert!(!map.contains_key("GOOSE_PROVIDER")); assert_eq!(map["GOOSE_MODEL"], "gpt-4o"); @@ -593,7 +606,7 @@ mod tests { #[test] fn env_vars_temperature_and_context() { let lp = stub_persona(None, Some(0.7), Some(8192)); - let vars = project_env_vars(&lp); + let vars = runtime_env_vars(&lp); let map: HashMap<&str, &str> = vars.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect(); assert_eq!(map["GOOSE_TEMPERATURE"], "0.7"); assert_eq!(map["GOOSE_CONTEXT_LIMIT"], "8192"); @@ -602,17 +615,59 @@ mod tests { #[test] fn env_vars_empty_when_no_config() { let lp = stub_persona(None, None, None); - let vars = project_env_vars(&lp); + let vars = runtime_env_vars(&lp); assert!(vars.is_empty()); } #[test] fn env_vars_full_projection() { let lp = stub_persona(Some("openai:gpt-4o"), Some(0.5), Some(16384)); - let vars = project_env_vars(&lp); + let vars = runtime_env_vars(&lp); assert_eq!(vars.len(), 4); // PROVIDER, MODEL, TEMPERATURE, CONTEXT_LIMIT } + #[test] + fn runtime_env_vars_sprout_agent_emits_sprout_agent_vars() { + let mut lp = stub_persona(Some("databricks:goose-claude-4-6-opus"), None, None); + lp.runtime = Some("sprout-agent".to_owned()); + let vars = runtime_env_vars(&lp); + let map: HashMap<&str, &str> = vars.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect(); + assert_eq!(map["SPROUT_AGENT_MODEL"], "goose-claude-4-6-opus"); + assert_eq!(map["SPROUT_AGENT_PROVIDER"], "databricks"); + assert!(!map.contains_key("GOOSE_MODEL")); + assert!(!map.contains_key("GOOSE_PROVIDER")); + } + + #[test] + fn runtime_env_vars_goose_emits_goose_vars() { + let mut lp = stub_persona(Some("databricks:goose-claude-4-6-opus"), None, None); + lp.runtime = Some("goose".to_owned()); + let vars = runtime_env_vars(&lp); + let map: HashMap<&str, &str> = vars.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect(); + assert_eq!(map["GOOSE_MODEL"], "goose-claude-4-6-opus"); + assert_eq!(map["GOOSE_PROVIDER"], "databricks"); + assert!(!map.contains_key("SPROUT_AGENT_MODEL")); + } + + #[test] + fn runtime_env_vars_no_runtime_defaults_to_goose() { + let lp = stub_persona(Some("anthropic:claude-sonnet-4-20250514"), None, None); + let vars = runtime_env_vars(&lp); + let map: HashMap<&str, &str> = vars.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect(); + assert_eq!(map["GOOSE_PROVIDER"], "anthropic"); + assert_eq!(map["GOOSE_MODEL"], "claude-sonnet-4-20250514"); + } + + #[test] + fn runtime_env_vars_sprout_agent_model_without_provider() { + let mut lp = stub_persona(Some("gpt-4o"), None, None); + lp.runtime = Some("sprout-agent".to_owned()); + let vars = runtime_env_vars(&lp); + let map: HashMap<&str, &str> = vars.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect(); + assert_eq!(map["SPROUT_AGENT_MODEL"], "gpt-4o"); + assert!(!map.contains_key("SPROUT_AGENT_PROVIDER")); + } + // ── Full pipeline (resolve_pack via filesystem) ─────────────────────── #[test] @@ -650,7 +705,7 @@ mod tests { assert!(p.llm_provider.is_none()); assert!(p.triggers.mentions); // built-in default assert!(p.mcp_servers.is_empty()); - assert!(p.goose_env_vars.is_empty()); + assert!(p.runtime_env_vars.is_empty()); } #[test] @@ -694,7 +749,7 @@ mod tests { // Env vars projected let env_map: HashMap<&str, &str> = p - .goose_env_vars + .runtime_env_vars .iter() .map(|(k, v)| (k.as_str(), v.as_str())) .collect(); diff --git a/crates/sprout-persona/tests/e2e_env_flow.rs b/crates/sprout-persona/tests/e2e_env_flow.rs new file mode 100644 index 000000000..3246a1b86 --- /dev/null +++ b/crates/sprout-persona/tests/e2e_env_flow.rs @@ -0,0 +1,383 @@ +//! End-to-end tests for the env var flow introduced in PRs #783 and #794. +//! +//! These tests exercise the full pack-resolve pipeline and verify that: +//! - Goose personas emit GOOSE_PROVIDER, GOOSE_MODEL, GOOSE_TEMPERATURE +//! - Sprout-agent personas emit SPROUT_AGENT_MODEL, SPROUT_AGENT_PROVIDER +//! - The import filter strips derived provider/model keys but preserves knobs +//! - Multi-runtime packs produce correct per-persona env var prefixes +//! - Models without a provider prefix emit only the model key (no provider) + +use std::collections::BTreeMap; +use std::fs; + +use sprout_persona::resolve::resolve_pack; + +// ── Import filter (replicates desktop crate logic) ─────────────────────────── + +const DERIVED_PROVIDER_MODEL_ENV_KEYS: &[&str] = &[ + "GOOSE_MODEL", + "GOOSE_PROVIDER", + "SPROUT_AGENT_MODEL", + "SPROUT_AGENT_PROVIDER", +]; + +fn filter_derived(env_vars: Vec<(String, String)>) -> BTreeMap { + env_vars + .into_iter() + .filter(|(k, _)| { + !DERIVED_PROVIDER_MODEL_ENV_KEYS + .iter() + .any(|d| d.eq_ignore_ascii_case(k)) + }) + .collect() +} + +// ── Test 1: Goose persona emits correct runtime env vars ───────────────────── + +#[test] +fn resolve_pack_goose_persona_emits_correct_runtime_env_vars() { + let tmp = tempfile::tempdir().unwrap(); + let root = tmp.path(); + + fs::create_dir_all(root.join(".plugin")).unwrap(); + fs::create_dir_all(root.join("agents")).unwrap(); + + fs::write( + root.join(".plugin/plugin.json"), + r#"{ + "id": "com.test.e2e-env", + "name": "E2E Env Test", + "version": "1.0.0", + "personas": ["agents/bot.persona.md"], + "defaults": {} +}"#, + ) + .unwrap(); + + fs::write( + root.join("agents/bot.persona.md"), + r#"--- +name: "bot" +display_name: "Bot" +description: "Test bot" +model: "databricks:goose-claude-4-6-opus" +temperature: 0.7 +--- +You are a test bot. +"#, + ) + .unwrap(); + + let pack = resolve_pack(root).unwrap(); + let persona = &pack.personas[0]; + + let env: std::collections::HashMap<_, _> = persona + .runtime_env_vars + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())) + .collect(); + + assert_eq!( + env.get("GOOSE_PROVIDER"), + Some(&"databricks"), + "should emit GOOSE_PROVIDER=databricks" + ); + assert_eq!( + env.get("GOOSE_MODEL"), + Some(&"goose-claude-4-6-opus"), + "should emit GOOSE_MODEL=goose-claude-4-6-opus" + ); + assert_eq!( + env.get("GOOSE_TEMPERATURE"), + Some(&"0.7"), + "should emit GOOSE_TEMPERATURE=0.7" + ); +} + +// ── Test 2: Sprout-agent persona emits SPROUT_AGENT_* vars ─────────────────── + +#[test] +fn resolve_pack_sprout_agent_persona_emits_sprout_agent_vars() { + let tmp = tempfile::tempdir().unwrap(); + let root = tmp.path(); + + fs::create_dir_all(root.join(".plugin")).unwrap(); + fs::create_dir_all(root.join("agents")).unwrap(); + + fs::write( + root.join(".plugin/plugin.json"), + r#"{ + "id": "com.test.e2e-env", + "name": "E2E Env Test", + "version": "1.0.0", + "personas": ["agents/bot.persona.md"], + "defaults": {} +}"#, + ) + .unwrap(); + + fs::write( + root.join("agents/bot.persona.md"), + r#"--- +name: "bot" +display_name: "Bot" +description: "Test bot" +runtime: "sprout-agent" +model: "openai:gpt-4o" +--- +You are a test bot. +"#, + ) + .unwrap(); + + let pack = resolve_pack(root).unwrap(); + let persona = &pack.personas[0]; + + let env: std::collections::HashMap<_, _> = persona + .runtime_env_vars + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())) + .collect(); + + assert_eq!( + env.get("SPROUT_AGENT_MODEL"), + Some(&"gpt-4o"), + "should emit SPROUT_AGENT_MODEL=gpt-4o" + ); + assert_eq!( + env.get("SPROUT_AGENT_PROVIDER"), + Some(&"openai"), + "should emit SPROUT_AGENT_PROVIDER=openai" + ); + + // Must NOT contain GOOSE_* keys + assert!( + !env.contains_key("GOOSE_MODEL"), + "sprout-agent runtime must not emit GOOSE_MODEL" + ); + assert!( + !env.contains_key("GOOSE_PROVIDER"), + "sprout-agent runtime must not emit GOOSE_PROVIDER" + ); +} + +// ── Test 3: Import filter strips derived keys, preserves knobs ─────────────── + +#[test] +fn import_filter_strips_derived_preserves_knobs() { + let tmp = tempfile::tempdir().unwrap(); + let root = tmp.path(); + + fs::create_dir_all(root.join(".plugin")).unwrap(); + fs::create_dir_all(root.join("agents")).unwrap(); + + fs::write( + root.join(".plugin/plugin.json"), + r#"{ + "id": "com.test.e2e-env", + "name": "E2E Env Test", + "version": "1.0.0", + "personas": ["agents/bot.persona.md"], + "defaults": {} +}"#, + ) + .unwrap(); + + fs::write( + root.join("agents/bot.persona.md"), + r#"--- +name: "bot" +display_name: "Bot" +description: "Test bot" +model: "databricks:goose-claude-4-6-opus" +temperature: 0.7 +--- +You are a test bot. +"#, + ) + .unwrap(); + + let pack = resolve_pack(root).unwrap(); + let persona = &pack.personas[0]; + + // Apply the import filter (mirrors desktop import_persona_pack logic). + let filtered = filter_derived(persona.runtime_env_vars.clone()); + + // Derived provider/model keys must be stripped. + assert!( + !filtered.contains_key("GOOSE_MODEL"), + "GOOSE_MODEL must be stripped by import filter" + ); + assert!( + !filtered.contains_key("GOOSE_PROVIDER"), + "GOOSE_PROVIDER must be stripped by import filter" + ); + + // Knob keys must survive. + assert_eq!( + filtered.get("GOOSE_TEMPERATURE").map(|s| s.as_str()), + Some("0.7"), + "GOOSE_TEMPERATURE must survive the import filter" + ); +} + +// ── Test 4: Two runtimes in one pack get different env var prefixes ─────────── + +#[test] +fn full_pipeline_two_runtimes_different_env_vars() { + let tmp = tempfile::tempdir().unwrap(); + let root = tmp.path(); + + fs::create_dir_all(root.join(".plugin")).unwrap(); + fs::create_dir_all(root.join("agents")).unwrap(); + + fs::write( + root.join(".plugin/plugin.json"), + r#"{ + "id": "com.test.e2e-env", + "name": "E2E Env Test", + "version": "1.0.0", + "personas": [ + "agents/goose-bot.persona.md", + "agents/sprout-bot.persona.md" + ], + "defaults": {} +}"#, + ) + .unwrap(); + + // Goose persona (default runtime) + fs::write( + root.join("agents/goose-bot.persona.md"), + r#"--- +name: "goose-bot" +display_name: "Goose Bot" +description: "A goose runtime bot" +model: "anthropic:claude-sonnet-4-20250514" +--- +You are a goose bot. +"#, + ) + .unwrap(); + + // Sprout-agent persona + fs::write( + root.join("agents/sprout-bot.persona.md"), + r#"--- +name: "sprout-bot" +display_name: "Sprout Bot" +description: "A sprout-agent runtime bot" +runtime: "sprout-agent" +model: "openai:gpt-4o" +--- +You are a sprout bot. +"#, + ) + .unwrap(); + + let pack = resolve_pack(root).unwrap(); + assert_eq!(pack.personas.len(), 2); + + let goose = pack + .personas + .iter() + .find(|p| p.name == "goose-bot") + .expect("goose-bot should exist"); + let sprout = pack + .personas + .iter() + .find(|p| p.name == "sprout-bot") + .expect("sprout-bot should exist"); + + // Goose persona gets GOOSE_* env vars + let goose_env: std::collections::HashMap<_, _> = goose + .runtime_env_vars + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())) + .collect(); + assert_eq!(goose_env.get("GOOSE_PROVIDER"), Some(&"anthropic")); + assert_eq!( + goose_env.get("GOOSE_MODEL"), + Some(&"claude-sonnet-4-20250514") + ); + assert!( + !goose_env.contains_key("SPROUT_AGENT_MODEL"), + "goose persona must not emit SPROUT_AGENT_MODEL" + ); + assert!( + !goose_env.contains_key("SPROUT_AGENT_PROVIDER"), + "goose persona must not emit SPROUT_AGENT_PROVIDER" + ); + + // Sprout-agent persona gets SPROUT_AGENT_* env vars + let sprout_env: std::collections::HashMap<_, _> = sprout + .runtime_env_vars + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())) + .collect(); + assert_eq!(sprout_env.get("SPROUT_AGENT_MODEL"), Some(&"gpt-4o")); + assert_eq!(sprout_env.get("SPROUT_AGENT_PROVIDER"), Some(&"openai")); + assert!( + !sprout_env.contains_key("GOOSE_MODEL"), + "sprout-agent persona must not emit GOOSE_MODEL" + ); + assert!( + !sprout_env.contains_key("GOOSE_PROVIDER"), + "sprout-agent persona must not emit GOOSE_PROVIDER" + ); +} + +// ── Test 5: Model without provider prefix emits model only ─────────────────── + +#[test] +fn model_without_provider_prefix_emits_model_only() { + let tmp = tempfile::tempdir().unwrap(); + let root = tmp.path(); + + fs::create_dir_all(root.join(".plugin")).unwrap(); + fs::create_dir_all(root.join("agents")).unwrap(); + + fs::write( + root.join(".plugin/plugin.json"), + r#"{ + "id": "com.test.e2e-env", + "name": "E2E Env Test", + "version": "1.0.0", + "personas": ["agents/bot.persona.md"], + "defaults": {} +}"#, + ) + .unwrap(); + + fs::write( + root.join("agents/bot.persona.md"), + r#"--- +name: "bot" +display_name: "Bot" +description: "Test bot" +model: "gpt-4o" +--- +You are a test bot. +"#, + ) + .unwrap(); + + let pack = resolve_pack(root).unwrap(); + let persona = &pack.personas[0]; + + let env: std::collections::HashMap<_, _> = persona + .runtime_env_vars + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())) + .collect(); + + assert_eq!( + env.get("GOOSE_MODEL"), + Some(&"gpt-4o"), + "should emit GOOSE_MODEL=gpt-4o" + ); + assert!( + !env.contains_key("GOOSE_PROVIDER"), + "model without colon prefix must NOT emit GOOSE_PROVIDER" + ); +} diff --git a/crates/sprout-persona/tests/integration.rs b/crates/sprout-persona/tests/integration.rs index 7e445ee95..83f356e69 100644 --- a/crates/sprout-persona/tests/integration.rs +++ b/crates/sprout-persona/tests/integration.rs @@ -461,8 +461,8 @@ fn resolve_full_pipeline() { assert!(lep.triggers.keywords.contains(&"vulnerability".to_string())); assert!(!lep.triggers.all_messages); - // Goose env vars projected from model - let pip_env: std::collections::HashMap<_, _> = pip.goose_env_vars.iter().cloned().collect(); + // Env vars projected from model + let pip_env: std::collections::HashMap<_, _> = pip.runtime_env_vars.iter().cloned().collect(); assert_eq!( pip_env.get("GOOSE_PROVIDER").map(|s| s.as_str()), Some("anthropic") diff --git a/desktop/scripts/check-file-sizes.mjs b/desktop/scripts/check-file-sizes.mjs index cd30da97c..09e46abf5 100644 --- a/desktop/scripts/check-file-sizes.mjs +++ b/desktop/scripts/check-file-sizes.mjs @@ -30,14 +30,15 @@ const rules = [ // Do not add to this list; split the file instead. Remove each entry as its // file is broken up. Tracked as a follow-up. const overrides = new Map([ - ["src-tauri/src/commands/agents.rs", 1190], - ["src-tauri/src/managed_agents/nest.rs", 1417], - ["src-tauri/src/managed_agents/runtime.rs", 1387], + ["src-tauri/src/commands/agents.rs", 1208], + ["src-tauri/src/managed_agents/nest.rs", 1420], + ["src-tauri/src/managed_agents/runtime.rs", 1465], + ["src-tauri/src/managed_agents/persona_card.rs", 1050], ["src-tauri/src/huddle/tts.rs", 1364], ["src/shared/api/tauri.ts", 1196], ["src-tauri/src/nostr_convert.rs", 1116], ["src/shared/api/relayClientSession.ts", 1022], - ["src-tauri/src/migration.rs", 1130], + ["src-tauri/src/migration.rs", 1295], ]); await runFileSizeCheck({ diff --git a/desktop/src-tauri/src/commands/agents.rs b/desktop/src-tauri/src/commands/agents.rs index 7d06341ac..2c6933ede 100644 --- a/desktop/src-tauri/src/commands/agents.rs +++ b/desktop/src-tauri/src/commands/agents.rs @@ -131,6 +131,23 @@ fn build_deploy_payload( crate::managed_agents::resolve_persona_env(app, record.persona_id.as_deref())?; let merged_env = crate::managed_agents::merged_user_env(&persona_env, &record.env_vars); + // Resolve effective model/provider from the persona's structured fields. + // Agent record's model takes precedence (user override via UI). + let (effective_model, effective_provider) = if let Some(ref pid) = record.persona_id { + let personas = load_personas(app).map_err(|e| { + format!("failed to load personas for deploy payload model resolution: {e}") + })?; + let persona = personas.iter().find(|p| p.id == *pid); + let model = record + .model + .clone() + .or_else(|| persona.and_then(|p| p.model.clone())); + let provider = persona.and_then(|p| p.provider.clone()); + (model, provider) + } else { + (record.model.clone(), None) + }; + Ok(serde_json::json!({ "name": &record.name, "relay_url": &record.relay_url, @@ -139,7 +156,8 @@ fn build_deploy_payload( "agent_command": &record.agent_command, "agent_args": &record.agent_args, "system_prompt": &record.system_prompt, - "model": &record.model, + "model": effective_model, + "provider": effective_provider, "turn_timeout_seconds": record.turn_timeout_seconds, "idle_timeout_seconds": record.idle_timeout_seconds, "max_turn_duration_seconds": record.max_turn_duration_seconds, diff --git a/desktop/src-tauri/src/commands/personas.rs b/desktop/src-tauri/src/commands/personas.rs index 99138797a..a221e2d7a 100644 --- a/desktop/src-tauri/src/commands/personas.rs +++ b/desktop/src-tauri/src/commands/personas.rs @@ -53,6 +53,7 @@ pub fn create_persona( let avatar_url = trim_optional(input.avatar_url); let runtime = trim_optional(input.runtime); let model = trim_optional(input.model); + let provider = trim_optional(input.provider); let now = now_iso(); let _store_guard = state @@ -74,6 +75,7 @@ pub fn create_persona( system_prompt, runtime, model, + provider, name_pool, is_builtin: false, is_active: true, @@ -100,6 +102,7 @@ pub fn update_persona( let avatar_url = trim_optional(input.avatar_url); let runtime = trim_optional(input.runtime); let model = trim_optional(input.model); + let provider = trim_optional(input.provider); let _store_guard = state .managed_agents_store_lock @@ -119,6 +122,7 @@ pub fn update_persona( persona.system_prompt = system_prompt; persona.runtime = runtime; persona.model = model; + persona.provider = provider; persona.name_pool = input .name_pool .into_iter() @@ -335,7 +339,7 @@ pub async fn export_persona_to_json( // forked, distributed), and bundling API keys / credentials in them // would be a significant footgun. Users who import a card and need // credentials must supply them post-import via the persona dialog. - let (display_name, system_prompt, avatar_url, runtime, model, name_pool) = { + let (display_name, system_prompt, avatar_url, runtime, model, provider, name_pool) = { let _store_guard = state .managed_agents_store_lock .lock() @@ -351,6 +355,7 @@ pub async fn export_persona_to_json( persona.avatar_url.clone(), persona.runtime.clone(), persona.model.clone(), + persona.provider.clone(), persona.name_pool.clone(), ) }; @@ -361,6 +366,7 @@ pub async fn export_persona_to_json( avatar_url.as_deref(), runtime.as_deref(), model.as_deref(), + provider.as_deref(), &name_pool, )?; diff --git a/desktop/src-tauri/src/managed_agents/discovery.rs b/desktop/src-tauri/src/managed_agents/discovery.rs index 09c464675..e774f0a80 100644 --- a/desktop/src-tauri/src/managed_agents/discovery.rs +++ b/desktop/src-tauri/src/managed_agents/discovery.rs @@ -11,7 +11,8 @@ pub(crate) struct KnownAcpRuntime { pub commands: &'static [&'static str], pub aliases: &'static [&'static str], pub avatar_url: &'static str, - /// MCP server binary for this runtime, or `None` for no MCP server. + /// Legacy MCP server binary field. Vestigial — all agents now use sprout CLI + /// directly. Will be removed when runtime discovery is simplified. pub mcp_command: Option<&'static str>, /// Whether to enable MCP hook tools (`_Stop`, `_PostCompact`) for this agent. pub mcp_hooks: bool, @@ -32,11 +33,13 @@ pub(crate) struct KnownAcpRuntime { /// pointing to the canonical `.agents/skills/sprout-cli`. `None` → this /// runtime reads the canonical path directly or has no skill support. pub skill_dir: Option<&'static str>, + /// Whether this runtime handles model switching via ACP protocol natively. + /// Currently unused — env var injection runs unconditionally regardless of + /// this value. Retained as scaffolding for when ACP model switching matures. + #[allow(dead_code)] pub supports_acp_model_switching: bool, pub model_env_var: Option<&'static str>, - #[allow(dead_code)] pub provider_env_var: Option<&'static str>, - #[allow(dead_code)] pub provider_locked: bool, pub default_env: &'static [(&'static str, &'static str)], } @@ -149,7 +152,7 @@ const KNOWN_ACP_RUNTIMES: &[KnownAcpRuntime] = &[ adapter_install_hint: "", skill_dir: None, supports_acp_model_switching: true, - model_env_var: None, + model_env_var: Some("SPROUT_AGENT_MODEL"), provider_env_var: Some("SPROUT_AGENT_PROVIDER"), provider_locked: false, default_env: &[], diff --git a/desktop/src-tauri/src/managed_agents/env_vars.rs b/desktop/src-tauri/src/managed_agents/env_vars.rs index 95f8e2e93..5abe4118b 100644 --- a/desktop/src-tauri/src/managed_agents/env_vars.rs +++ b/desktop/src-tauri/src/managed_agents/env_vars.rs @@ -14,6 +14,47 @@ use std::collections::BTreeMap; +/// Env var keys that are *derived* from the structured `PersonaRecord.provider` +/// and `PersonaRecord.model` fields at spawn/deploy time. These must NOT be +/// persisted in `PersonaRecord.env_vars` because they would shadow the +/// structured fields after the user edits provider/model in the UI. +/// +/// At local spawn time, `runtime_metadata_env_vars` re-derives these from the +/// current structured fields, so they are always up-to-date. At remote deploy +/// time, `build_deploy_payload` projects the structured fields directly. +/// +/// Non-structured knobs (`GOOSE_TEMPERATURE`, `GOOSE_CONTEXT_LIMIT`) are NOT +/// in this list — they have no structured counterpart and must be preserved. +pub(crate) const DERIVED_PROVIDER_MODEL_ENV_KEYS: &[&str] = &[ + "GOOSE_MODEL", + "GOOSE_PROVIDER", + "SPROUT_AGENT_MODEL", + "SPROUT_AGENT_PROVIDER", +]; + +/// Returns `true` if `key` is a derived provider/model env key that should be +/// filtered out of persisted `PersonaRecord.env_vars` at pack import time. +pub(crate) fn is_derived_provider_model_key(key: &str) -> bool { + DERIVED_PROVIDER_MODEL_ENV_KEYS + .iter() + .any(|k| k.eq_ignore_ascii_case(key)) +} + +/// Strip derived provider/model env keys from a pack persona's `runtime_env_vars` +/// before persisting them in `PersonaRecord.env_vars`. +/// +/// The structured `PersonaRecord.provider` / `PersonaRecord.model` fields are +/// the authoritative source of truth. Keeping the derived copies would cause +/// stale env values to override updated structured fields at spawn/deploy time. +pub(crate) fn filter_derived_provider_model_env_vars( + env_vars: impl IntoIterator, +) -> BTreeMap { + env_vars + .into_iter() + .filter(|(k, _)| !is_derived_provider_model_key(k)) + .collect() +} + /// Env var keys that Sprout sets itself and users must not override from /// the persona/agent env_vars UI. Three categories: /// diff --git a/desktop/src-tauri/src/managed_agents/env_vars/tests.rs b/desktop/src-tauri/src/managed_agents/env_vars/tests.rs index 9edc772e0..f07c4359c 100644 --- a/desktop/src-tauri/src/managed_agents/env_vars/tests.rs +++ b/desktop/src-tauri/src/managed_agents/env_vars/tests.rs @@ -1,8 +1,9 @@ use std::collections::BTreeMap; use super::{ - display_invalid_key, is_reserved_env_key, is_well_formed_env_key, merged_user_env, - validate_user_env_keys, MAX_ENV_TOTAL_BYTES, MAX_ENV_VALUE_BYTES, RESERVED_ENV_KEYS, + display_invalid_key, filter_derived_provider_model_env_vars, is_derived_provider_model_key, + is_reserved_env_key, is_well_formed_env_key, merged_user_env, validate_user_env_keys, + DERIVED_PROVIDER_MODEL_ENV_KEYS, MAX_ENV_TOTAL_BYTES, MAX_ENV_VALUE_BYTES, RESERVED_ENV_KEYS, }; fn map(pairs: &[(&str, &str)]) -> BTreeMap { @@ -415,3 +416,88 @@ fn merged_env_drops_oversize_value() { assert!(!merged.contains_key("HUGE")); assert_eq!(merged.get("LEGIT").map(String::as_str), Some("v")); } + +// ── derived provider/model key filter ────────────────────────────── +// +// Pack import must strip derived env keys (GOOSE_MODEL, GOOSE_PROVIDER, +// SPROUT_AGENT_MODEL, SPROUT_AGENT_PROVIDER) so they don't shadow the +// structured PersonaRecord.model / PersonaRecord.provider fields after +// the user edits them in the UI. + +#[test] +fn is_derived_key_matches_all_known_keys() { + for key in DERIVED_PROVIDER_MODEL_ENV_KEYS { + assert!( + is_derived_provider_model_key(key), + "{key} should be recognized as derived" + ); + } +} + +#[test] +fn is_derived_key_is_case_insensitive() { + assert!(is_derived_provider_model_key("goose_model")); + assert!(is_derived_provider_model_key("Goose_Provider")); + assert!(is_derived_provider_model_key("sprout_agent_model")); + assert!(is_derived_provider_model_key("SPROUT_AGENT_PROVIDER")); +} + +#[test] +fn is_derived_key_does_not_match_unrelated_keys() { + assert!(!is_derived_provider_model_key("GOOSE_TEMPERATURE")); + assert!(!is_derived_provider_model_key("GOOSE_CONTEXT_LIMIT")); + assert!(!is_derived_provider_model_key("ANTHROPIC_API_KEY")); + assert!(!is_derived_provider_model_key("SPROUT_PRIVATE_KEY")); + assert!(!is_derived_provider_model_key("MODEL")); + assert!(!is_derived_provider_model_key("PROVIDER")); +} + +#[test] +fn filter_derived_strips_provider_model_keys_preserves_rest() { + let input = vec![ + ( + "GOOSE_MODEL".to_string(), + "claude-sonnet-4-20250514".to_string(), + ), + ("GOOSE_PROVIDER".to_string(), "anthropic".to_string()), + ("SPROUT_AGENT_MODEL".to_string(), "gpt-4o".to_string()), + ("SPROUT_AGENT_PROVIDER".to_string(), "openai".to_string()), + ("GOOSE_TEMPERATURE".to_string(), "0.7".to_string()), + ("ANTHROPIC_API_KEY".to_string(), "sk-test".to_string()), + ]; + let filtered = filter_derived_provider_model_env_vars(input); + assert_eq!(filtered.len(), 2); + assert_eq!( + filtered.get("GOOSE_TEMPERATURE").map(String::as_str), + Some("0.7") + ); + assert_eq!( + filtered.get("ANTHROPIC_API_KEY").map(String::as_str), + Some("sk-test") + ); +} + +#[test] +fn filter_derived_empty_input_returns_empty() { + let filtered = filter_derived_provider_model_env_vars(std::iter::empty()); + assert!(filtered.is_empty()); +} + +#[test] +fn stale_derived_env_does_not_override_structured_fields() { + // Documents that merged_user_env is transparent to derived keys — it + // does NOT strip them. The defense is the import filter + // (filter_derived_provider_model_env_vars) which prevents them from + // being persisted in the first place. If a stale record somehow has + // them, they flow through merged_user_env unchanged — the spawn-time + // re-derivation from structured fields writes AFTER merged env. + let persona_env = map(&[("GOOSE_MODEL", "stale-model"), ("LEGIT", "v")]); + let merged = merged_user_env(&persona_env, &BTreeMap::new()); + // merged_user_env does NOT filter derived keys — that's by design. + // The import filter is the boundary defense. + assert_eq!( + merged.get("GOOSE_MODEL").map(String::as_str), + Some("stale-model") + ); + assert_eq!(merged.get("LEGIT").map(String::as_str), Some("v")); +} diff --git a/desktop/src-tauri/src/managed_agents/nest.rs b/desktop/src-tauri/src/managed_agents/nest.rs index 9d8cf7488..062b4fc57 100644 --- a/desktop/src-tauri/src/managed_agents/nest.rs +++ b/desktop/src-tauri/src/managed_agents/nest.rs @@ -943,6 +943,7 @@ mod tests { system_prompt: String::new(), runtime: None, model: None, + provider: None, name_pool: vec![], is_builtin: false, is_active: true, diff --git a/desktop/src-tauri/src/managed_agents/persona_card.rs b/desktop/src-tauri/src/managed_agents/persona_card.rs index c2539da28..48466302b 100644 --- a/desktop/src-tauri/src/managed_agents/persona_card.rs +++ b/desktop/src-tauri/src/managed_agents/persona_card.rs @@ -15,6 +15,7 @@ pub struct ParsedPersonaPreview { pub avatar_data_url: Option, pub runtime: Option, pub model: Option, + pub provider: Option, #[serde(default, skip_serializing_if = "Vec::is_empty")] pub name_pool: Vec, pub source_file: String, @@ -82,6 +83,7 @@ pub fn parse_png_persona(png_bytes: &[u8]) -> Result, runtime: Option, model: Option, + provider: Option, name_pool: Vec, } @@ -149,6 +152,15 @@ fn extract_sprout_fields(v: &Value) -> Result { .map(|s| s.trim()) .filter(|s| !s.is_empty()) .map(|s| s.to_string()); + // "llmProvider" is the LLM inference provider (e.g. "databricks", "anthropic"). + // Distinct from "runtime" (the ACP harness) and from the legacy "provider" key + // (which mapped to runtime for backward compat). + let provider = v + .get("llmProvider") + .and_then(|v| v.as_str()) + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()); let name_pool = v .get("namePool") .and_then(|v| v.as_array()) @@ -166,6 +178,7 @@ fn extract_sprout_fields(v: &Value) -> Result { avatar_url, runtime, model, + provider, name_pool, }) } @@ -210,6 +223,7 @@ fn parse_chara_payload(b64: &str) -> Result { avatar_url: None, runtime: None, model: None, + provider: None, name_pool: Vec::new(), }) } @@ -228,6 +242,7 @@ pub fn parse_json_persona(json_bytes: &[u8]) -> Result, runtime: Option<&str>, model: Option<&str>, + provider: Option<&str>, name_pool: &[String], ) -> Result, String> { let mut map = serde_json::Map::new(); @@ -254,6 +270,9 @@ pub fn encode_persona_json( if let Some(m) = model { map.insert("model".to_string(), serde_json::json!(m)); } + if let Some(p) = provider { + map.insert("llmProvider".to_string(), serde_json::json!(p)); + } if !name_pool.is_empty() { map.insert("namePool".to_string(), serde_json::json!(name_pool)); } @@ -287,6 +306,7 @@ pub fn parse_md_persona(md_bytes: &[u8]) -> Result avatar_data_url: None, // .persona.md avatars are paths, not data URIs runtime: config.runtime, model, + provider: None, // .persona.md format does not carry llmProvider name_pool: Vec::new(), source_file: String::new(), }) @@ -387,6 +407,7 @@ pub fn parse_zip_pack(zip_bytes: &[u8]) -> Result Vec { system_prompt: persona.system_prompt.to_string(), runtime: persona.runtime.map(|s| s.to_string()), model: persona.model.map(|s| s.to_string()), + provider: None, name_pool: persona.name_pool.iter().map(|s| s.to_string()).collect(), is_builtin: true, is_active: true, @@ -837,12 +838,17 @@ pub fn import_persona_pack( system_prompt: p.system_prompt.clone(), runtime: p.runtime.clone(), model: p.model.clone(), + provider: p.llm_provider.clone(), name_pool: Vec::new(), is_builtin: false, is_active: true, source_pack: Some(resolved.id.clone()), source_pack_persona_slug: Some(p.name.clone()), - env_vars: std::collections::BTreeMap::new(), + // Filter derived provider/model env keys at import time so stale + // values don't shadow the structured fields after UI edits. + env_vars: crate::managed_agents::env_vars::filter_derived_provider_model_env_vars( + p.runtime_env_vars.iter().cloned(), + ), created_at: now.clone(), updated_at: now.clone(), }) diff --git a/desktop/src-tauri/src/managed_agents/personas/tests.rs b/desktop/src-tauri/src/managed_agents/personas/tests.rs index 1b8da0b2d..ae1b037bb 100644 --- a/desktop/src-tauri/src/managed_agents/personas/tests.rs +++ b/desktop/src-tauri/src/managed_agents/personas/tests.rs @@ -13,6 +13,7 @@ fn custom_persona(id: &str, display_name: &str) -> PersonaRecord { system_prompt: "Custom prompt".to_string(), runtime: None, model: None, + provider: None, name_pool: Vec::new(), is_builtin: false, is_active: true, diff --git a/desktop/src-tauri/src/managed_agents/runtime.rs b/desktop/src-tauri/src/managed_agents/runtime.rs index efe1f4f6d..00d2a701d 100644 --- a/desktop/src-tauri/src/managed_agents/runtime.rs +++ b/desktop/src-tauri/src/managed_agents/runtime.rs @@ -6,7 +6,7 @@ use crate::{ managed_agents::{ append_log_marker, known_acp_runtime, login_shell_path, managed_agent_log_path, missing_command_message, normalize_agent_args, open_log_file, resolve_command, - ManagedAgentProcess, ManagedAgentRecord, ManagedAgentSummary, + ManagedAgentProcess, ManagedAgentRecord, ManagedAgentSummary, PersonaRecord, }, util::now_iso, }; @@ -864,6 +864,7 @@ pub fn spawn_agent_child( } } // Enable MCP hook tools (_Stop, _PostCompact) for agents that need them. + // Uses "*" because build_mcp_servers() hard-codes the server name to "sprout-mcp". let runtime_meta = known_acp_runtime(&record.agent_command); if runtime_meta.is_some_and(|r| r.mcp_hooks) { command.env("MCP_HOOK_SERVERS", "*"); @@ -899,29 +900,31 @@ pub fn spawn_agent_child( command.env("SPROUT_ACP_PERSONA_NAME", persona_name); } - // Resolve system prompt and model: prefer the persona definition (if a - // persona pack is configured and the persona matched), otherwise fall back - // to the record-level overrides. + // Resolve system prompt, model, and provider: prefer the persona definition + // (if a persona pack is configured and the persona matched), otherwise fall + // back to the record-level overrides. Provider always flows from the persona + // when one is linked (the record has no provider field of its own). let has_persona_pack = record.persona_pack_path.is_some() && record.persona_name_in_pack.is_some(); - let persona_prompt_and_model: Option<(String, Option)> = has_persona_pack - .then(|| { - record - .persona_id - .as_deref() - .and_then(|pid| { - super::load_personas(app) - .ok()? - .into_iter() - .find(|p| p.id == pid) - }) - .map(|p| (p.system_prompt, p.model)) - }) - .flatten(); - - let (effective_prompt, effective_model) = match persona_prompt_and_model { - Some((prompt, model)) => (Some(prompt), model), - None => (record.system_prompt.clone(), record.model.clone()), + let persona_record: Option = record.persona_id.as_deref().and_then(|pid| { + super::load_personas(app) + .ok()? + .into_iter() + .find(|p| p.id == pid) + }); + + let (effective_prompt, effective_model, effective_provider) = if has_persona_pack { + match &persona_record { + Some(p) => ( + Some(p.system_prompt.clone()), + p.model.clone(), + p.provider.clone(), + ), + None => (record.system_prompt.clone(), record.model.clone(), None), + } + } else { + let provider = persona_record.as_ref().and_then(|p| p.provider.clone()); + (record.system_prompt.clone(), record.model.clone(), provider) }; if let Some(prompt) = &effective_prompt { @@ -935,10 +938,14 @@ pub fn spawn_agent_child( command.env_remove("SPROUT_ACP_MODEL"); } if let Some(meta) = runtime_meta { - if !meta.supports_acp_model_switching { - if let (Some(env_key), Some(model)) = (meta.model_env_var, &effective_model) { - command.env(env_key, model); - } + for (key, value) in runtime_metadata_env_vars( + meta.model_env_var, + meta.provider_env_var, + meta.provider_locked, + effective_model.as_deref(), + effective_provider.as_deref(), + ) { + command.env(key, value); } } if let Some(toolsets) = &record.mcp_toolsets { @@ -1185,6 +1192,32 @@ pub fn stop_managed_agent_process( Ok(()) } +/// Returns the (key, value) env var pairs that should be forwarded to the +/// agent process for model and provider selection. +/// +/// Model injection is unconditional — even agents that support ACP model +/// switching need the initial bootstrap value. Provider injection is skipped +/// when `provider_locked` is true (e.g. Claude runtimes that only work with +/// Anthropic). +fn runtime_metadata_env_vars<'a>( + model_env_var: Option<&'a str>, + provider_env_var: Option<&'a str>, + provider_locked: bool, + effective_model: Option<&'a str>, + effective_provider: Option<&'a str>, +) -> Vec<(&'a str, &'a str)> { + let mut vars = Vec::new(); + if let (Some(env_key), Some(model)) = (model_env_var, effective_model) { + vars.push((env_key, model)); + } + if !provider_locked { + if let (Some(env_key), Some(provider)) = (provider_env_var, effective_provider) { + vars.push((env_key, provider)); + } + } + vars +} + #[cfg(test)] mod tests { use crate::managed_agents::known_acp_runtime; @@ -1375,4 +1408,55 @@ mod tests { let err = build_respond_to_env(&rec, Some("owner")).unwrap_err(); assert!(err.contains("at least one pubkey")); } + + // ── runtime_metadata_env_vars tests ───────────────────────────────────── + + use super::runtime_metadata_env_vars; + + #[test] + fn runtime_metadata_env_vars_injects_model_and_provider() { + let vars = runtime_metadata_env_vars( + Some("GOOSE_MODEL"), + Some("GOOSE_PROVIDER"), + false, + Some("gpt-4o"), + Some("openai"), + ); + assert_eq!( + vars, + vec![("GOOSE_MODEL", "gpt-4o"), ("GOOSE_PROVIDER", "openai")] + ); + } + + #[test] + fn runtime_metadata_env_vars_skips_provider_when_locked() { + let vars = runtime_metadata_env_vars( + None, // claude has no model_env_var + None, // claude has no provider_env_var + true, // provider_locked = true + Some("claude-opus-4-7"), + Some("anthropic"), + ); + assert!(vars.is_empty()); + } + + #[test] + fn runtime_metadata_env_vars_injects_model_even_with_acp_model_switching() { + // sprout-agent has supports_acp_model_switching=true but we still inject + // the model env var because ACP model switching is post-bootstrap + let vars = runtime_metadata_env_vars( + Some("SPROUT_AGENT_MODEL"), + Some("SPROUT_AGENT_PROVIDER"), + false, + Some("goose-claude-4-6-opus"), + Some("databricks"), + ); + assert_eq!( + vars, + vec![ + ("SPROUT_AGENT_MODEL", "goose-claude-4-6-opus"), + ("SPROUT_AGENT_PROVIDER", "databricks"), + ] + ); + } } diff --git a/desktop/src-tauri/src/managed_agents/teams.rs b/desktop/src-tauri/src/managed_agents/teams.rs index 1833eeb4d..45b169970 100644 --- a/desktop/src-tauri/src/managed_agents/teams.rs +++ b/desktop/src-tauri/src/managed_agents/teams.rs @@ -331,6 +331,7 @@ mod tests { system_prompt: prompt.to_string(), runtime: None, model: None, + provider: None, name_pool: Vec::new(), is_builtin: false, is_active: true, diff --git a/desktop/src-tauri/src/managed_agents/types.rs b/desktop/src-tauri/src/managed_agents/types.rs index 4a83be996..6a160c384 100644 --- a/desktop/src-tauri/src/managed_agents/types.rs +++ b/desktop/src-tauri/src/managed_agents/types.rs @@ -28,6 +28,11 @@ pub struct PersonaRecord { /// direct). Sprout stores and passes through without interpretation. #[serde(default, skip_serializing_if = "Option::is_none")] pub model: Option, + /// LLM inference provider (e.g., 'databricks', 'anthropic', 'openai'). Optional — when set, + /// injected as the runtime's provider env var at agent creation time. When absent, the runtime + /// falls back to auto-detection (e.g., goose config file or available credentials). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub provider: Option, /// Pool of short, thematic names for bot instances created from this persona. /// When a new copy is added to a channel, a random unused name is picked from this pool. #[serde(default, skip_serializing_if = "Vec::is_empty")] @@ -279,6 +284,8 @@ pub struct CreatePersonaRequest { #[serde(default)] pub model: Option, #[serde(default)] + pub provider: Option, + #[serde(default)] pub name_pool: Vec, /// Environment variables for agents created from this persona. #[serde(default)] @@ -297,6 +304,8 @@ pub struct UpdatePersonaRequest { #[serde(default)] pub model: Option, #[serde(default)] + pub provider: Option, + #[serde(default)] pub name_pool: Vec, /// Environment variables for agents created from this persona. /// diff --git a/desktop/src-tauri/src/migration.rs b/desktop/src-tauri/src/migration.rs index eda5cae74..b006c84af 100644 --- a/desktop/src-tauri/src/migration.rs +++ b/desktop/src-tauri/src/migration.rs @@ -6,6 +6,11 @@ //! `SPROUT_SHARE_IDENTITY=1` and `SPROUT_PRIVATE_KEY` is set. All dev //! instances share the same physical files — edits in any worktree are //! immediately visible to all others. +//! +//! **Provider reconciliation** (`reconcile_provider_mcp_commands`): Per-launch +//! fix-up of `mcp_command` values in `managed-agents.json` against the +//! discovery table. Ensures known providers always have their canonical +//! `mcp_command`; unknown/custom agents are left untouched. use std::path::{Path, PathBuf}; use tauri::Manager; @@ -697,6 +702,169 @@ mod tests { serde_json::from_str(&content).unwrap() } + #[test] + fn reconcile_clears_mcp_command_for_goose() { + let dir = tempfile::tempdir().unwrap(); + write_agents_json( + dir.path(), + &serde_json::json!([{ + "name": "Scout", + "agent_command": "goose", + "mcp_command": "sprout-mcp-server" + }]), + ); + reconcile_mcp_commands_in_file(&dir.path().join("agents/managed-agents.json")); + let records = read_agents_json(dir.path()); + assert_eq!(records[0]["mcp_command"], ""); + } + + #[test] + fn reconcile_clears_mcp_command_for_claude() { + let dir = tempfile::tempdir().unwrap(); + write_agents_json( + dir.path(), + &serde_json::json!([{ + "name": "Claude Agent", + "agent_command": "claude-agent-acp", + "mcp_command": "sprout-mcp-server" + }]), + ); + reconcile_mcp_commands_in_file(&dir.path().join("agents/managed-agents.json")); + let records = read_agents_json(dir.path()); + assert_eq!(records[0]["mcp_command"], ""); + } + + #[test] + fn reconcile_preserves_sprout_dev_mcp() { + let dir = tempfile::tempdir().unwrap(); + write_agents_json( + dir.path(), + &serde_json::json!([{ + "name": "Solo", + "agent_command": "sprout-agent", + "mcp_command": "sprout-dev-mcp" + }]), + ); + let before = + std::fs::read_to_string(dir.path().join("agents/managed-agents.json")).unwrap(); + reconcile_mcp_commands_in_file(&dir.path().join("agents/managed-agents.json")); + let after = std::fs::read_to_string(dir.path().join("agents/managed-agents.json")).unwrap(); + assert_eq!( + before, after, + "file should not be rewritten when already correct" + ); + } + + #[test] + fn reconcile_fixes_sprout_agent_if_stale() { + let dir = tempfile::tempdir().unwrap(); + write_agents_json( + dir.path(), + &serde_json::json!([{ + "name": "Solo", + "agent_command": "sprout-agent", + "mcp_command": "sprout-mcp-server" + }]), + ); + reconcile_mcp_commands_in_file(&dir.path().join("agents/managed-agents.json")); + let records = read_agents_json(dir.path()); + assert_eq!(records[0]["mcp_command"], "sprout-dev-mcp"); + } + + #[test] + fn reconcile_leaves_unknown_agent_untouched() { + let dir = tempfile::tempdir().unwrap(); + write_agents_json( + dir.path(), + &serde_json::json!([{ + "name": "Custom Bot", + "agent_command": "my-custom-agent", + "mcp_command": "my-custom-mcp" + }]), + ); + reconcile_mcp_commands_in_file(&dir.path().join("agents/managed-agents.json")); + let records = read_agents_json(dir.path()); + assert_eq!(records[0]["mcp_command"], "my-custom-mcp"); + } + + #[test] + fn reconcile_is_idempotent() { + let dir = tempfile::tempdir().unwrap(); + write_agents_json( + dir.path(), + &serde_json::json!([{ + "name": "Scout", + "agent_command": "goose", + "mcp_command": "sprout-mcp-server" + }]), + ); + let path = dir.path().join("agents/managed-agents.json"); + reconcile_mcp_commands_in_file(&path); + let after_first = std::fs::read_to_string(&path).unwrap(); + reconcile_mcp_commands_in_file(&path); + let after_second = std::fs::read_to_string(&path).unwrap(); + assert_eq!(after_first, after_second); + } + + #[test] + fn reconcile_handles_mixed_records() { + let dir = tempfile::tempdir().unwrap(); + write_agents_json( + dir.path(), + &serde_json::json!([ + {"name": "Scout", "agent_command": "goose", "mcp_command": "sprout-mcp-server"}, + {"name": "Claude", "agent_command": "claude-agent-acp", "mcp_command": "sprout-mcp-server"}, + {"name": "Solo", "agent_command": "sprout-agent", "mcp_command": "sprout-dev-mcp"}, + {"name": "Custom", "agent_command": "my-bot", "mcp_command": "my-mcp"}, + {"name": "Codex", "agent_command": "codex-acp", "mcp_command": "sprout-mcp-server"} + ]), + ); + reconcile_mcp_commands_in_file(&dir.path().join("agents/managed-agents.json")); + let records = read_agents_json(dir.path()); + assert_eq!(records[0]["mcp_command"], "", "goose should be cleared"); + assert_eq!(records[1]["mcp_command"], "", "claude should be cleared"); + assert_eq!( + records[2]["mcp_command"], "sprout-dev-mcp", + "sprout-agent preserved" + ); + assert_eq!( + records[3]["mcp_command"], "my-mcp", + "custom agent untouched" + ); + assert_eq!(records[4]["mcp_command"], "", "codex should be cleared"); + } + + #[test] + fn reconcile_adds_mcp_command_when_key_absent() { + let dir = tempfile::tempdir().unwrap(); + write_agents_json( + dir.path(), + &serde_json::json!([{ + "name": "Solo", + "agent_command": "sprout-agent" + }]), + ); + reconcile_mcp_commands_in_file(&dir.path().join("agents/managed-agents.json")); + let records = read_agents_json(dir.path()); + assert_eq!(records[0]["mcp_command"], "sprout-dev-mcp"); + } + + #[test] + fn reconcile_treats_null_mcp_command_as_empty() { + let dir = tempfile::tempdir().unwrap(); + write_agents_json( + dir.path(), + &serde_json::json!([{ + "name": "Solo", + "agent_command": "sprout-agent", + "mcp_command": null + }]), + ); + reconcile_mcp_commands_in_file(&dir.path().join("agents/managed-agents.json")); + let records = read_agents_json(dir.path()); + assert_eq!(records[0]["mcp_command"], "sprout-dev-mcp"); + } + #[test] fn sync_creates_packs_directory_symlink() { let (_parent, canonical, worktree) = setup_sync_layout(); diff --git a/desktop/src/features/agents/ui/PersonaDialog.tsx b/desktop/src/features/agents/ui/PersonaDialog.tsx index f6fd3d756..4393187f5 100644 --- a/desktop/src/features/agents/ui/PersonaDialog.tsx +++ b/desktop/src/features/agents/ui/PersonaDialog.tsx @@ -66,6 +66,7 @@ export function PersonaDialog({ const [systemPrompt, setSystemPrompt] = React.useState(""); const [runtime, setRuntime] = React.useState(""); const [model, setModel] = React.useState(""); + const [provider, setProvider] = React.useState(""); const [namePoolText, setNamePoolText] = React.useState(""); const [envVars, setEnvVars] = React.useState({}); const [isImportingUpdate, setIsImportingUpdate] = React.useState(false); @@ -90,6 +91,7 @@ export function PersonaDialog({ setSystemPrompt(initialValues.systemPrompt); setRuntime(initialValues.runtime ?? ""); setModel(initialValues.model ?? ""); + setProvider(initialValues.provider ?? ""); setNamePoolText( ("namePool" in initialValues ? (initialValues as { namePool?: string[] }).namePool @@ -216,6 +218,7 @@ export function PersonaDialog({ setSystemPrompt(""); setRuntime(""); setModel(""); + setProvider(""); setNamePoolText(""); setImportErrorMessage(null); setIsImportingUpdate(false); @@ -240,6 +243,7 @@ export function PersonaDialog({ systemPrompt, runtime: runtime.trim() || undefined, model: model.trim() || undefined, + provider: provider.trim() || undefined, namePool: namePool.length > 0 ? namePool : undefined, envVars, }; @@ -405,6 +409,27 @@ export function PersonaDialog({

+
+ + setProvider(event.target.value)} + placeholder="e.g. databricks, anthropic, openai" + spellCheck={false} + value={provider} + /> +

+ Optional. Injected as the runtime's provider env var at agent + creation time. Leave blank for auto-detection or provider-locked + runtimes. +

+
+