diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS deleted file mode 100644 index 8d5534ec90..0000000000 --- a/.github/CODEOWNERS +++ /dev/null @@ -1 +0,0 @@ -schemas/ @chaynabors diff --git a/Cargo.lock b/Cargo.lock index 3a59446f65..74a6c9f8e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -948,9 +948,9 @@ dependencies = [ [[package]] name = "bm25" -version = "2.3.1" +version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b84ff0d57042bc263e2ebadb3703424b59b65870902649a2b3d0f4d7ab863244" +checksum = "1cbd8ffdfb7b4c2ff038726178a780a94f90525ed0ad264c0afaa75dd8c18a64" dependencies = [ "cached", "deunicode", @@ -983,7 +983,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "234113d19d0d7d613b40e86fb654acf958910802bcceab913a4f9e7cda03b1a4" dependencies = [ "memchr", - "regex-automata 0.4.9", + "regex-automata", "serde", ] @@ -1062,7 +1062,7 @@ version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9225bdcf4e4a9a4c08bf16607908eb2fbf746828d5e0b5e019726dbf6571f201" dependencies = [ - "darling", + "darling 0.20.11", "proc-macro2", "quote", "syn 2.0.104", @@ -1207,7 +1207,7 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chat_cli" -version = "1.15.0" +version = "1.16.3" dependencies = [ "amzn-codewhisperer-client", "amzn-codewhisperer-streaming-client", @@ -1280,6 +1280,7 @@ dependencies = [ "regex", "reqwest", "ring", + "rmcp", "rusqlite", "rustls 0.23.31", "rustls-native-certs 0.8.1", @@ -1812,8 +1813,18 @@ version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" dependencies = [ - "darling_core", - "darling_macro", + "darling_core 0.20.11", + "darling_macro 0.20.11", +] + +[[package]] +name = "darling" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" +dependencies = [ + "darling_core 0.21.3", + "darling_macro 0.21.3", ] [[package]] @@ -1830,13 +1841,38 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "darling_core" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.104", +] + [[package]] name = "darling_macro" version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ - "darling_core", + "darling_core 0.20.11", + "quote", + "syn 2.0.104", +] + +[[package]] +name = "darling_macro" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" +dependencies = [ + "darling_core 0.21.3", "quote", "syn 2.0.104", ] @@ -1902,7 +1938,7 @@ version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" dependencies = [ - "darling", + "darling 0.20.11", "proc-macro2", "quote", "syn 2.0.104", @@ -2257,8 +2293,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2" dependencies = [ "bit-set 0.5.3", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", + "regex-automata", + "regex-syntax", ] [[package]] @@ -2268,8 +2304,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e24cb5a94bcae1e5408b0effca5cd7172ea3c5755049c5f3af4cd283a165298" dependencies = [ "bit-set 0.8.0", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", + "regex-automata", + "regex-syntax", ] [[package]] @@ -2775,8 +2811,8 @@ dependencies = [ "aho-corasick", "bstr", "log", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", + "regex-automata", + "regex-syntax", ] [[package]] @@ -3454,7 +3490,7 @@ dependencies = [ "percent-encoding", "referencing", "regex", - "regex-syntax 0.8.5", + "regex-syntax", "reqwest", "serde", "serde_json", @@ -3617,7 +3653,7 @@ version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53304fff6ab1e597661eee37e42ea8c47a146fca280af902bb76bff8a896e523" dependencies = [ - "nu-ansi-term 0.50.1", + "nu-ansi-term", ] [[package]] @@ -3653,11 +3689,11 @@ checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" [[package]] name = "matchers" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" dependencies = [ - "regex-automata 0.1.10", + "regex-automata", ] [[package]] @@ -3899,16 +3935,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "nu-ansi-term" -version = "0.46.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" -dependencies = [ - "overload", - "winapi", -] - [[package]] name = "nu-ansi-term" version = "0.50.1" @@ -3924,7 +3950,7 @@ version = "0.104.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5185420e479f45c9afabfb534b26282d3de13b9b286ac16851221cc17d04def3" dependencies = [ - "nu-ansi-term 0.50.1", + "nu-ansi-term", "nu-engine", "nu-json", "nu-protocol", @@ -4205,6 +4231,26 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +[[package]] +name = "oauth2" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51e219e79014df21a225b1860a479e2dcd7cbd9130f4defd4bd0e191ea31d67d" +dependencies = [ + "base64 0.22.1", + "chrono", + "getrandom 0.2.16", + "http 1.3.1", + "rand 0.8.5", + "reqwest", + "serde", + "serde_json", + "serde_path_to_error", + "sha2", + "thiserror 1.0.69", + "url", +] + [[package]] name = "objc-sys" version = "0.3.5" @@ -4454,12 +4500,6 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "owo-colors" version = "4.2.2" @@ -4699,6 +4739,20 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "process-wrap" +version = "8.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3ef4f2f0422f23a82ec9f628ea2acd12871c81a9362b02c43c1aa86acfc3ba1" +dependencies = [ + "futures", + "indexmap", + "nix 0.30.1", + "tokio", + "tracing", + "windows 0.61.3", +] + [[package]] name = "procfs" version = "0.17.0" @@ -5110,17 +5164,8 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", -] - -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" -dependencies = [ - "regex-syntax 0.6.29", + "regex-automata", + "regex-syntax", ] [[package]] @@ -5131,7 +5176,7 @@ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.5", + "regex-syntax", ] [[package]] @@ -5140,12 +5185,6 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" -[[package]] -name = "regex-syntax" -version = "0.6.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" - [[package]] name = "regex-syntax" version = "0.8.5" @@ -5215,6 +5254,48 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rmcp" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7dd163d26e254725137b7933e4ba042ea6bf2d756a4260559aaea8b6ad4c27e" +dependencies = [ + "base64 0.22.1", + "chrono", + "futures", + "http 1.3.1", + "oauth2", + "paste", + "pin-project-lite", + "process-wrap", + "reqwest", + "rmcp-macros", + "schemars", + "serde", + "serde_json", + "sse-stream", + "thiserror 2.0.14", + "tokio", + "tokio-stream", + "tokio-util", + "tower-service", + "tracing", + "url", +] + +[[package]] +name = "rmcp-macros" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a43bb4c90a0d4b12f7315eb681a73115d335a2cee81322eca96f3467fe4cd06f" +dependencies = [ + "darling 0.21.3", + "proc-macro2", + "quote", + "serde_json", + "syn 2.0.104", +] + [[package]] name = "roxmltree" version = "0.14.1" @@ -5492,6 +5573,7 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "82d20c4491bc164fa2f6c5d44565947a52ad80b9505d8e36f8d54c27c739fcd0" dependencies = [ + "chrono", "dyn-clone", "ref-cast", "schemars_derive", @@ -5565,7 +5647,7 @@ dependencies = [ [[package]] name = "semantic_search_client" -version = "1.15.0" +version = "1.16.3" dependencies = [ "anyhow", "bm25", @@ -5659,6 +5741,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_plain" version = "1.0.2" @@ -5727,7 +5819,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fce6d5bc71503c9ec2337c80dc41f4fb2ac62fe52d6ab7500d899db19ae436f8" dependencies = [ "bitflags 2.9.1", - "nu-ansi-term 0.50.1", + "nu-ansi-term", "nu-color-config", ] @@ -5889,6 +5981,19 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "sse-stream" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb4dc4d33c68ec1f27d386b5610a351922656e1fdf5c05bbaad930cd1519479a" +dependencies = [ + "bytes", + "futures-util", + "http-body 1.0.1", + "http-body-util", + "pin-project-lite", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -5903,9 +6008,9 @@ checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" [[package]] name = "stop-words" -version = "0.8.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c6a86be9f7fa4559b7339669e72026eb437f5e9c5a85c207fe1033079033a17" +checksum = "645a3d441ccf4bf47f2e4b7681461986681a6eeea9937d4c3bc9febd61d17c71" dependencies = [ "serde_json", ] @@ -6069,7 +6174,7 @@ dependencies = [ "once_cell", "onig", "plist", - "regex-syntax 0.8.5", + "regex-syntax", "serde", "serde_derive", "serde_json", @@ -6338,7 +6443,7 @@ dependencies = [ "rayon", "rayon-cond", "regex", - "regex-syntax 0.8.5", + "regex-syntax", "serde", "serde_json", "spm_precompiled", @@ -6600,15 +6705,15 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.19" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" dependencies = [ "matchers", - "nu-ansi-term 0.46.0", + "nu-ansi-term", "once_cell", "parking_lot", - "regex", + "regex-automata", "serde", "serde_json", "sharded-slab", @@ -6820,6 +6925,7 @@ dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 48d9e5d937..d2fb418eb6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ authors = ["Amazon Q CLI Team (q-cli@amazon.com)", "Chay Nabors (nabochay@amazon edition = "2024" homepage = "https://aws.amazon.com/q/" publish = false -version = "1.15.0" +version = "1.16.3" license = "MIT OR Apache-2.0" [workspace.dependencies] @@ -129,6 +129,7 @@ winnow = "=0.6.2" winreg = "0.55.0" schemars = "1.0.4" jsonschema = "0.30.0" +rmcp = { version = "0.6.3", features = ["client", "transport-sse-client-reqwest", "reqwest", "transport-streamable-http-client-reqwest", "transport-child-process", "tower", "auth"] } [workspace.lints.rust] future_incompatible = "warn" diff --git a/README.md b/README.md index 3554b6c239..ecfe529a4e 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ cargo install typos-cli ## Project Layout -- [`chat_cli`](crates/chat_cli/) - the `q` CLI, allows users to interface with Amazon Q Developer from +- [`chat_cli`](crates/chat-cli/) - the `q` CLI, allows users to interface with Amazon Q Developer from the command line - [`scripts/`](scripts/) - Contains ops and build related scripts - [`crates/`](crates/) - Contains all rust crates diff --git a/crates/chat-cli/Cargo.toml b/crates/chat-cli/Cargo.toml index 1b9b78d869..37911de5f1 100644 --- a/crates/chat-cli/Cargo.toml +++ b/crates/chat-cli/Cargo.toml @@ -15,12 +15,6 @@ workspace = true default = [] wayland = ["arboard/wayland-data-control"] -[[bin]] -name = "test_mcp_server" -path = "test_mcp_server/test_server.rs" -test = true -doc = false - [dependencies] amzn-codewhisperer-client.workspace = true amzn-codewhisperer-streaming-client.workspace = true @@ -123,6 +117,7 @@ whoami.workspace = true winnow.workspace = true schemars.workspace = true jsonschema.workspace = true +rmcp.workspace = true [target.'cfg(unix)'.dependencies] nix.workspace = true diff --git a/crates/chat-cli/build.rs b/crates/chat-cli/build.rs index fe39e2dbc3..0735d05e78 100644 --- a/crates/chat-cli/build.rs +++ b/crates/chat-cli/build.rs @@ -83,6 +83,11 @@ fn write_plist() { fn main() { println!("cargo:rerun-if-changed=def.json"); + // Download feed.json if FETCH_FEED environment variable is set + if std::env::var("FETCH_FEED").is_ok() { + download_feed_json(); + } + #[cfg(target_os = "macos")] write_plist(); @@ -319,6 +324,70 @@ fn main() { let file: syn::File = syn::parse_str(&out).unwrap(); let pp = prettyplease::unparse(&file); - // write an empty file to the output directory - std::fs::write(format!("{}/mod.rs", outdir), pp).unwrap(); +/// Downloads the latest feed.json from the autocomplete repository. +/// This ensures official builds have the most up-to-date changelog information. +/// +/// # Errors +/// +/// Returns errors if: +/// - `curl` command is not available +/// - Network request fails +/// - File write operation fails +/// - Downloaded content fails validation +fn download_feed_json() { + use std::process::Command; + + println!("cargo:warning=Downloading latest feed.json from autocomplete repo..."); + + // Check if curl is available first + let curl_check = Command::new("curl").arg("--version").output(); + + if curl_check.is_err() { + eprintln!("curl command not found. Cannot download latest feed.json. Please install curl or build without FETCH_FEED=1 to use existing feed.json."); + std::process::exit(1); + } + + let output = Command::new("curl") + .args([ + "-H", + "Accept: application/vnd.github.v3.raw", + "-f", // fail on HTTP errors + "-s", // silent + "-v", // verbose output printed to stderr + "--show-error", // print error message to stderr (since -s is used) + "--max-filesize", "1048576", // 1MB limit + "", + ]) + .output(); + + match output { + Ok(result) if result.status.success() => { + // Basic validation - ensure it's valid JSON + if let Err(e) = serde_json::from_slice::(&result.stdout) { + eprintln!("Downloaded content is not valid JSON: {}", e); + std::process::exit(1); + } + + if let Err(e) = std::fs::write("src/cli/feed.json", result.stdout) { + eprintln!("Failed to write feed.json: {}", e); + std::process::exit(1); + } else { + println!("cargo:warning=Successfully downloaded latest feed.json"); + } + }, + Ok(result) => { + let error_msg = if !result.stderr.is_empty() { + format!("{}", String::from_utf8_lossy(&result.stderr)) + } else { + "An unknown error occurred".to_string() + }; + eprintln!("Failed to download feed.json: {}", error_msg); + std::process::exit(1); + }, + Err(e) => { + eprintln!("Failed to execute curl: {}", e); + std::process::exit(1); + }, + } +} } diff --git a/crates/chat-cli/src/api_client/mod.rs b/crates/chat-cli/src/api_client/mod.rs index caa2507f2d..f21b448b77 100644 --- a/crates/chat-cli/src/api_client/mod.rs +++ b/crates/chat-cli/src/api_client/mod.rs @@ -193,12 +193,19 @@ impl ApiClient { }, } - let profile = match database.get_auth_profile() { - Ok(profile) => profile, - Err(err) => { - error!("Failed to get auth profile: {err}"); - None - }, + // Check if using custom endpoint + let use_profile = !Self::is_custom_endpoint(database); + let profile = if use_profile { + match database.get_auth_profile() { + Ok(profile) => profile, + Err(err) => { + error!("Failed to get auth profile: {err}"); + None + }, + } + } else { + debug!("Custom endpoint detected, skipping profile ARN"); + None }; Ok(Self { @@ -598,6 +605,11 @@ impl ApiClient { self.mock_client = Some(Arc::new(Mutex::new(mock.into_iter()))); } + + // Add a helper method to check if using non-default endpoint + fn is_custom_endpoint(database: &Database) -> bool { + database.settings.get(Setting::ApiCodeWhispererService).is_some() + } } fn timeout_config(database: &Database) -> TimeoutConfig { diff --git a/crates/chat-cli/src/auth/mod.rs b/crates/chat-cli/src/auth/mod.rs index 4b425f2a6f..1e38864750 100644 --- a/crates/chat-cli/src/auth/mod.rs +++ b/crates/chat-cli/src/auth/mod.rs @@ -14,15 +14,17 @@ pub use builder_id::{ pub use consts::START_URL; use thiserror::Error; +use crate::aws_common::SdkErrorDisplay; + #[derive(Debug, Error)] pub enum AuthError { #[error(transparent)] Ssooidc(Box), - #[error(transparent)] + #[error("{}", SdkErrorDisplay(.0))] SdkRegisterClient(Box>), - #[error(transparent)] + #[error("{}", SdkErrorDisplay(.0))] SdkCreateToken(Box>), - #[error(transparent)] + #[error("{}", SdkErrorDisplay(.0))] SdkStartDeviceAuthorization(Box>), #[error(transparent)] Io(#[from] std::io::Error), diff --git a/crates/chat-cli/src/cli/agent/hook.rs b/crates/chat-cli/src/cli/agent/hook.rs index 1cb899f5a7..89ca74146b 100644 --- a/crates/chat-cli/src/cli/agent/hook.rs +++ b/crates/chat-cli/src/cli/agent/hook.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::fmt::Display; use schemars::JsonSchema; @@ -11,9 +10,6 @@ const DEFAULT_TIMEOUT_MS: u64 = 30_000; const DEFAULT_MAX_OUTPUT_SIZE: usize = 1024 * 10; const DEFAULT_CACHE_TTL_SECONDS: u64 = 0; -#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, JsonSchema)] -pub struct Hooks(HashMap); - #[derive(Debug, Clone, Copy, Serialize, Deserialize, Eq, PartialEq, JsonSchema, Hash)] #[serde(rename_all = "camelCase")] pub enum HookTrigger { @@ -21,6 +17,10 @@ pub enum HookTrigger { AgentSpawn, /// Triggered per user message submission UserPromptSubmit, + /// Triggered before tool execution + PreToolUse, + /// Triggered after tool execution + PostToolUse, } impl Display for HookTrigger { @@ -28,6 +28,8 @@ impl Display for HookTrigger { match self { HookTrigger::AgentSpawn => write!(f, "agentSpawn"), HookTrigger::UserPromptSubmit => write!(f, "userPromptSubmit"), + HookTrigger::PreToolUse => write!(f, "preToolUse"), + HookTrigger::PostToolUse => write!(f, "postToolUse"), } } } @@ -61,6 +63,11 @@ pub struct Hook { #[serde(default = "Hook::default_cache_ttl_seconds")] pub cache_ttl_seconds: u64, + /// Optional glob matcher for hook + /// Currently used for matching tool name of PreToolUse and PostToolUse hook + #[serde(skip_serializing_if = "Option::is_none")] + pub matcher: Option, + #[schemars(skip)] #[serde(default, skip_serializing)] pub source: Source, @@ -73,6 +80,7 @@ impl Hook { timeout_ms: Self::default_timeout_ms(), max_output_size: Self::default_max_output_size(), cache_ttl_seconds: Self::default_cache_ttl_seconds(), + matcher: None, source, } } diff --git a/crates/chat-cli/src/cli/agent/legacy/hooks.rs b/crates/chat-cli/src/cli/agent/legacy/hooks.rs index 2a7a639d7f..6929049b3a 100644 --- a/crates/chat-cli/src/cli/agent/legacy/hooks.rs +++ b/crates/chat-cli/src/cli/agent/legacy/hooks.rs @@ -80,6 +80,7 @@ impl From for Option { timeout_ms: value.timeout_ms, max_output_size: value.max_output_size, cache_ttl_seconds: value.cache_ttl_seconds, + matcher: None, source: Default::default(), }) } diff --git a/crates/chat-cli/src/cli/agent/mod.rs b/crates/chat-cli/src/cli/agent/mod.rs index e3d89b4846..7ee44c601d 100644 --- a/crates/chat-cli/src/cli/agent/mod.rs +++ b/crates/chat-cli/src/cli/agent/mod.rs @@ -161,6 +161,9 @@ pub struct Agent { /// you configure in the mcpServers field in this config #[serde(default)] pub use_legacy_mcp_json: bool, + /// The model ID to use for this agent. If not specified, uses the default model. + #[serde(default)] + pub model: Option, #[serde(skip)] pub path: Option, } @@ -181,13 +184,19 @@ impl Default for Agent { set.extend(default_approve); set }, - resources: vec!["file://AmazonQ.md", "file://README.md", "file://.amazonq/rules/**/*.md"] - .into_iter() - .map(Into::into) - .collect::>(), + resources: vec![ + "file://AmazonQ.md", + "file://AGENTS.md", + "file://README.md", + "file://.amazonq/rules/**/*.md", + ] + .into_iter() + .map(Into::into) + .collect::>(), hooks: Default::default(), tools_settings: Default::default(), use_legacy_mcp_json: true, + model: None, path: None, } } @@ -659,21 +668,24 @@ impl Agents { } if let Some(user_set_default) = os.database.settings.get_string(Setting::ChatDefaultAgent) { - if all_agents.iter().any(|a| a.name == user_set_default) { - break 'active_idx user_set_default; + // Treat empty strings as "no default set" to allow clean reset + if !user_set_default.is_empty() { + if all_agents.iter().any(|a| a.name == user_set_default) { + break 'active_idx user_set_default; + } + let _ = queue!( + output, + style::SetForegroundColor(Color::Red), + style::Print("Error"), + style::SetForegroundColor(Color::Yellow), + style::Print(format!( + ": user defined default {} not found. Falling back to in-memory default", + user_set_default + )), + style::Print("\n"), + style::SetForegroundColor(Color::Reset) + ); } - let _ = queue!( - output, - style::SetForegroundColor(Color::Red), - style::Print("Error"), - style::SetForegroundColor(Color::Yellow), - style::Print(format!( - ": user defined default {} not found. Falling back to in-memory default", - user_set_default - )), - style::Print("\n"), - style::SetForegroundColor(Color::Reset) - ); } all_agents.push({ @@ -767,32 +779,14 @@ impl Agents { /// Returns a label to describe the permission status for a given tool. pub fn display_label(&self, tool_name: &str, origin: &ToolOrigin) -> String { - use crate::util::pattern_matching::matches_any_pattern; + use crate::util::tool_permission_checker::is_tool_in_allowlist; let tool_trusted = self.get_active().is_some_and(|a| { - if matches!(origin, &ToolOrigin::Native) { - return matches_any_pattern(&a.allowed_tools, tool_name); - } - - a.allowed_tools.iter().any(|name| { - name.strip_prefix("@").is_some_and(|remainder| { - remainder - .split_once(MCP_SERVER_TOOL_DELIMITER) - .is_some_and(|(_left, right)| right == tool_name) - || remainder == >::borrow(origin) - }) || { - if let Some(server_name) = name.strip_prefix("@").and_then(|s| s.split('/').next()) { - if server_name == >::borrow(origin) { - let tool_pattern = format!("@{}/{}", server_name, tool_name); - matches_any_pattern(&a.allowed_tools, &tool_pattern) - } else { - false - } - } else { - false - } - } - }) + let server_name = match origin { + ToolOrigin::Native => None, + ToolOrigin::McpServer(_) => Some(>::borrow(origin)), + }; + is_tool_in_allowlist(&a.allowed_tools, tool_name, server_name) }); if tool_trusted || self.trust_all_tools { @@ -806,12 +800,12 @@ impl Agents { // This "static" way avoids needing to construct a tool instance. fn default_permission_label(&self, tool_name: &str) -> String { let label = match tool_name { - "fs_read" => "trusted".dark_green().bold(), + "fs_read" => "trust working directory".dark_grey(), "fs_write" => "not trusted".dark_grey(), #[cfg(not(windows))] - "execute_bash" => "trust read-only commands".dark_grey(), + "execute_bash" => "not trusted".dark_grey(), #[cfg(windows)] - "execute_cmd" => "trust read-only commands".dark_grey(), + "execute_cmd" => "not trusted".dark_grey(), "use_aws" => "trust read-only commands".dark_grey(), "report_issue" => "trusted".dark_green().bold(), "introspect" => "trusted".dark_green().bold(), @@ -950,6 +944,7 @@ mod tests { use serde_json::json; use super::*; + use crate::cli::agent::hook::Source; const INPUT: &str = r#" { "name": "some_agent", @@ -959,21 +954,21 @@ mod tests { "fetch": { "command": "fetch3.1", "args": [] }, "git": { "command": "git-mcp", "args": [] } }, - "tools": [ + "tools": [ "@git" ], "toolAliases": { "@gits/some_tool": "some_tool2" }, - "allowedTools": [ - "fs_read", + "allowedTools": [ + "fs_read", "@fetch", "@gits/git_status" ], - "resources": [ + "resources": [ "file://~/my-genai-prompts/unittest.md" ], - "toolsSettings": { + "toolsSettings": { "fs_write": { "allowedPaths": ["~/**"] }, "@git/git_status": { "git_user": "$GIT_USER" } } @@ -1133,9 +1128,9 @@ mod tests { let label = agents.display_label("fs_read", &ToolOrigin::Native); // With no active agent, it should fall back to default permissions - // fs_read has a default of "trusted" + // fs_read has a default of "trust working directory" assert!( - label.contains("trusted"), + label.contains("trust working directory"), "fs_read should show default trusted permission, instead found: {}", label ); @@ -1164,7 +1159,7 @@ mod tests { // Test default permissions for known tools let fs_read_label = agents.display_label("fs_read", &ToolOrigin::Native); assert!( - fs_read_label.contains("trusted"), + fs_read_label.contains("trust working directory"), "fs_read should be trusted by default, instead found: {}", fs_read_label ); @@ -1179,8 +1174,8 @@ mod tests { let execute_name = if cfg!(windows) { "execute_cmd" } else { "execute_bash" }; let execute_bash_label = agents.display_label(execute_name, &ToolOrigin::Native); assert!( - execute_bash_label.contains("read-only"), - "execute_bash should show read-only by default, instead found: {}", + execute_bash_label.contains("not trusted"), + "execute_bash should not be trusted by default, instead found: {}", execute_bash_label ); } @@ -1215,6 +1210,7 @@ mod tests { resources: Vec::new(), hooks: Default::default(), use_legacy_mcp_json: false, + model: None, path: None, }; @@ -1285,4 +1281,128 @@ mod tests { label ); } + + #[test] + fn test_agent_model_field() { + // Test deserialization with model field + let agent_json = r#"{ + "name": "test-agent", + "model": "claude-sonnet-4" + }"#; + + let agent: Agent = serde_json::from_str(agent_json).expect("Failed to deserialize agent with model"); + assert_eq!(agent.model, Some("claude-sonnet-4".to_string())); + + // Test default agent has no model + let default_agent = Agent::default(); + assert_eq!(default_agent.model, None); + + // Test serialization includes model field + let agent_with_model = Agent { + model: Some("test-model".to_string()), + ..Default::default() + }; + let serialized = serde_json::to_string(&agent_with_model).expect("Failed to serialize"); + assert!(serialized.contains("\"model\":\"test-model\"")); + } + + #[test] + fn test_agent_model_fallback_priority() { + // Test that agent model is checked and falls back correctly + let mut agents = Agents::default(); + + // Create agent with unavailable model + let agent_with_invalid_model = Agent { + name: "test-agent".to_string(), + model: Some("unavailable-model".to_string()), + ..Default::default() + }; + + agents.agents.insert("test-agent".to_string(), agent_with_invalid_model); + agents.active_idx = "test-agent".to_string(); + + // Verify the agent has the model set + assert_eq!( + agents.get_active().and_then(|a| a.model.as_ref()), + Some(&"unavailable-model".to_string()) + ); + + // Test agent without model + let agent_without_model = Agent { + name: "no-model-agent".to_string(), + model: None, + ..Default::default() + }; + + agents.agents.insert("no-model-agent".to_string(), agent_without_model); + agents.active_idx = "no-model-agent".to_string(); + + assert_eq!(agents.get_active().and_then(|a| a.model.as_ref()), None); + } + + #[test] + fn test_agent_with_hooks() { + let agent_json = json!({ + "name": "test-agent", + "hooks": { + "agentSpawn": [ + { + "command": "git status" + } + ], + "preToolUse": [ + { + "matcher": "fs_write", + "command": "validate-tool.sh" + }, + { + "matcher": "fs_read", + "command": "enforce-tdd.sh" + } + ], + "postToolUse": [ + { + "matcher": "fs_write", + "command": "format-python.sh" + } + ] + } + }); + + let agent: Agent = serde_json::from_value(agent_json).expect("Failed to deserialize agent"); + + // Verify agent name + assert_eq!(agent.name, "test-agent"); + + // Verify agentSpawn hook + assert!(agent.hooks.contains_key(&HookTrigger::AgentSpawn)); + let agent_spawn_hooks = &agent.hooks[&HookTrigger::AgentSpawn]; + assert_eq!(agent_spawn_hooks.len(), 1); + assert_eq!(agent_spawn_hooks[0].command, "git status"); + assert_eq!(agent_spawn_hooks[0].matcher, None); + + // Verify preToolUse hooks + assert!(agent.hooks.contains_key(&HookTrigger::PreToolUse)); + let pre_tool_hooks = &agent.hooks[&HookTrigger::PreToolUse]; + assert_eq!(pre_tool_hooks.len(), 2); + + assert_eq!(pre_tool_hooks[0].command, "validate-tool.sh"); + assert_eq!(pre_tool_hooks[0].matcher, Some("fs_write".to_string())); + + assert_eq!(pre_tool_hooks[1].command, "enforce-tdd.sh"); + assert_eq!(pre_tool_hooks[1].matcher, Some("fs_read".to_string())); + + // Verify postToolUse hooks + assert!(agent.hooks.contains_key(&HookTrigger::PostToolUse)); + + // Verify default values are set correctly + for hooks in agent.hooks.values() { + for hook in hooks { + assert_eq!(hook.timeout_ms, 30_000); + assert_eq!(hook.max_output_size, 10_240); + assert_eq!(hook.cache_ttl_seconds, 0); + assert_eq!(hook.source, Source::Agent); + } + } + } } diff --git a/crates/chat-cli/src/cli/agent/root_command_args.rs b/crates/chat-cli/src/cli/agent/root_command_args.rs index 469e0982ba..0f02028e50 100644 --- a/crates/chat-cli/src/cli/agent/root_command_args.rs +++ b/crates/chat-cli/src/cli/agent/root_command_args.rs @@ -46,6 +46,12 @@ pub enum AgentSubcommands { #[arg(long, short)] from: Option, }, + /// Edit an existing agent config + Edit { + /// Name of the agent to edit + #[arg(long, short)] + name: String, + }, /// Validate a config with the given path Validate { #[arg(long, short)] @@ -138,6 +144,38 @@ impl AgentArgs { path_with_file_name.display() )?; }, + Some(AgentSubcommands::Edit { name }) => { + let _agents = Agents::load(os, None, true, &mut stderr, mcp_enabled).await.0; + let (_agent, path_with_file_name) = Agent::get_agent_by_name(os, &name).await?; + + let editor_cmd = std::env::var("EDITOR").unwrap_or_else(|_| "vi".to_string()); + let mut cmd = std::process::Command::new(editor_cmd); + + let status = cmd.arg(&path_with_file_name).status()?; + if !status.success() { + bail!("Editor process did not exit with success"); + } + + let Ok(content) = os.fs.read(&path_with_file_name).await else { + bail!( + "Post edit validation failed. Error opening {}. Aborting", + path_with_file_name.display() + ); + }; + if let Err(e) = serde_json::from_slice::(&content) { + bail!( + "Post edit validation failed for agent '{name}' at path: {}. Malformed config detected: {e}", + path_with_file_name.display() + ); + } + + writeln!( + stderr, + "\n✏️ Edited agent {} '{}'\n", + name, + path_with_file_name.display() + )?; + }, Some(AgentSubcommands::Validate { path }) => { let mut global_mcp_config = None::; let agent = Agent::load(os, path.as_str(), &mut global_mcp_config, mcp_enabled, &mut stderr).await; @@ -386,4 +424,24 @@ mod tests { }) ); } + + #[test] + fn test_agent_subcommand_edit() { + assert_parse!( + ["agent", "edit", "--name", "existing_agent"], + RootSubcommand::Agent(AgentArgs { + cmd: Some(AgentSubcommands::Edit { + name: "existing_agent".to_string(), + }) + }) + ); + assert_parse!( + ["agent", "edit", "-n", "existing_agent"], + RootSubcommand::Agent(AgentArgs { + cmd: Some(AgentSubcommands::Edit { + name: "existing_agent".to_string(), + }) + }) + ); + } } diff --git a/crates/chat-cli/src/cli/chat/checkpoint.rs b/crates/chat-cli/src/cli/chat/checkpoint.rs new file mode 100644 index 0000000000..c5fb0b8183 --- /dev/null +++ b/crates/chat-cli/src/cli/chat/checkpoint.rs @@ -0,0 +1,422 @@ +use std::collections::{ + HashMap, + VecDeque, +}; +use std::path::{ + Path, + PathBuf, +}; +use std::process::{ + Command, + Output, +}; + +use chrono::{ + DateTime, + Local, +}; +use crossterm::style::Stylize; +use eyre::{ + Result, + bail, + eyre, +}; +use serde::{ + Deserialize, + Serialize, +}; + +use crate::cli::ConversationState; +use crate::cli::chat::conversation::HistoryEntry; +use crate::os::Os; + +/// Manages a shadow git repository for tracking and restoring workspace changes +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CheckpointManager { + /// Path to the shadow (bare) git repository + pub shadow_repo_path: PathBuf, + + /// All checkpoints in chronological order + pub checkpoints: Vec, + + /// Fast lookup: tag -> index in checkpoints vector + pub tag_index: HashMap, + + /// Track the current turn number + pub current_turn: usize, + + /// Track tool uses within current turn + pub tools_in_turn: usize, + + /// Last user message for commit description + pub pending_user_message: Option, + + /// Whether the message has been locked for this turn + pub message_locked: bool, + + /// Cached file change statistics + #[serde(default)] + pub file_stats_cache: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct FileStats { + pub added: usize, + pub modified: usize, + pub deleted: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Checkpoint { + pub tag: String, + pub timestamp: DateTime, + pub description: String, + pub history_snapshot: VecDeque, + pub is_turn: bool, + pub tool_name: Option, +} + +impl CheckpointManager { + /// Initialize checkpoint manager automatically (when in a git repo) + pub async fn auto_init( + os: &Os, + shadow_path: impl AsRef, + current_history: &VecDeque, + ) -> Result { + if !is_git_installed() { + bail!("Git is not installed. Checkpoints require git to function."); + } + if !is_in_git_repo() { + bail!("Not in a git repository. Use '/checkpoint init' to manually enable checkpoints."); + } + + let manager = Self::manual_init(os, shadow_path, current_history).await?; + Ok(manager) + } + + /// Initialize checkpoint manager manually + pub async fn manual_init( + os: &Os, + path: impl AsRef, + current_history: &VecDeque, + ) -> Result { + let path = path.as_ref(); + os.fs.create_dir_all(path).await?; + + // Initialize bare repository + run_git(path, false, &["init", "--bare", &path.to_string_lossy()])?; + + // Configure git + configure_git(&path.to_string_lossy())?; + + // Create initial checkpoint + stage_commit_tag(&path.to_string_lossy(), "Initial state", "0")?; + + let initial_checkpoint = Checkpoint { + tag: "0".to_string(), + timestamp: Local::now(), + description: "Initial state".to_string(), + history_snapshot: current_history.clone(), + is_turn: true, + tool_name: None, + }; + + let mut tag_index = HashMap::new(); + tag_index.insert("0".to_string(), 0); + + Ok(Self { + shadow_repo_path: path.to_path_buf(), + checkpoints: vec![initial_checkpoint], + tag_index, + current_turn: 0, + tools_in_turn: 0, + pending_user_message: None, + message_locked: false, + file_stats_cache: HashMap::new(), + }) + } + + /// Create a new checkpoint point + pub fn create_checkpoint( + &mut self, + tag: &str, + description: &str, + history: &VecDeque, + is_turn: bool, + tool_name: Option, + ) -> Result<()> { + // Stage, commit and tag + stage_commit_tag(&self.shadow_repo_path.to_string_lossy(), description, tag)?; + + // Record checkpoint metadata + let checkpoint = Checkpoint { + tag: tag.to_string(), + timestamp: Local::now(), + description: description.to_string(), + history_snapshot: history.clone(), + is_turn, + tool_name, + }; + + self.checkpoints.push(checkpoint); + self.tag_index.insert(tag.to_string(), self.checkpoints.len() - 1); + + // Cache file stats for this checkpoint + if let Ok(stats) = self.compute_file_stats(tag) { + self.file_stats_cache.insert(tag.to_string(), stats); + } + + Ok(()) + } + + /// Restore workspace to a specific checkpoint + pub fn restore(&self, conversation: &mut ConversationState, tag: &str, hard: bool) -> Result<()> { + let checkpoint = self.get_checkpoint(tag)?; + + if hard { + // Hard: reset the whole work-tree to the tag + let output = run_git(&self.shadow_repo_path, true, &["reset", "--hard", tag])?; + if !output.status.success() { + bail!("Failed to restore: {}", String::from_utf8_lossy(&output.stderr)); + } + } else { + // Soft: only restore tracked files. If the tag is an empty tree, this is a no-op. + if !self.tag_has_any_paths(tag)? { + // Nothing tracked in this checkpoint -> nothing to restore; treat as success. + conversation.restore_to_checkpoint(checkpoint)?; + return Ok(()); + } + // Use checkout against work-tree + let output = run_git(&self.shadow_repo_path, true, &["checkout", tag, "--", "."])?; + if !output.status.success() { + bail!("Failed to restore: {}", String::from_utf8_lossy(&output.stderr)); + } + } + + // Restore conversation history + conversation.restore_to_checkpoint(checkpoint)?; + + Ok(()) + } + + /// Return true iff the given tag/tree has any tracked paths. + fn tag_has_any_paths(&self, tag: &str) -> eyre::Result { + // Use `git ls-tree -r --name-only ` to check if the tree is empty + let out = run_git( + &self.shadow_repo_path, + // work_tree + false, + &["ls-tree", "-r", "--name-only", tag], + )?; + Ok(!out.stdout.is_empty()) + } + + /// Get file change statistics for a checkpoint + pub fn compute_file_stats(&self, tag: &str) -> Result { + if tag == "0" { + return Ok(FileStats::default()); + } + + let prev_tag = get_previous_tag(tag); + self.compute_stats_between(&prev_tag, tag) + } + + /// Compute file statistics between two checkpoints + pub fn compute_stats_between(&self, from: &str, to: &str) -> Result { + let output = run_git(&self.shadow_repo_path, false, &["diff", "--name-status", from, to])?; + + let mut stats = FileStats::default(); + for line in String::from_utf8_lossy(&output.stdout).lines() { + if let Some((status, _)) = line.split_once('\t') { + match status.chars().next() { + Some('A') => stats.added += 1, + Some('M') => stats.modified += 1, + Some('D') => stats.deleted += 1, + Some('R' | 'C') => stats.modified += 1, + _ => {}, + } + } + } + + Ok(stats) + } + + /// Generate detailed diff between checkpoints + pub fn diff(&self, from: &str, to: &str) -> Result { + let mut result = String::new(); + + // Get file changes + let output = run_git(&self.shadow_repo_path, false, &["diff", "--name-status", from, to])?; + + for line in String::from_utf8_lossy(&output.stdout).lines() { + if let Some((status, file)) = line.split_once('\t') { + match status.chars().next() { + Some('A') => result.push_str(&format!(" + {} (added)\n", file).green().to_string()), + Some('M') => result.push_str(&format!(" ~ {} (modified)\n", file).yellow().to_string()), + Some('D') => result.push_str(&format!(" - {} (deleted)\n", file).red().to_string()), + Some('R' | 'C') => result.push_str(&format!(" ~ {} (renamed)\n", file).yellow().to_string()), + _ => {}, + } + } + } + + // Add statistics + let stat_output = run_git(&self.shadow_repo_path, false, &[ + "diff", + from, + to, + "--stat", + "--color=always", + ])?; + + if stat_output.status.success() { + result.push('\n'); + result.push_str(&String::from_utf8_lossy(&stat_output.stdout)); + } + + Ok(result) + } + + /// Check for uncommitted changes + pub fn has_changes(&self) -> Result { + let output = run_git(&self.shadow_repo_path, true, &["status", "--porcelain"])?; + Ok(!output.stdout.is_empty()) + } + + /// Clean up shadow repository + pub async fn cleanup(&self, os: &Os) -> Result<()> { + if self.shadow_repo_path.exists() { + os.fs.remove_dir_all(&self.shadow_repo_path).await?; + } + Ok(()) + } + + fn get_checkpoint(&self, tag: &str) -> Result<&Checkpoint> { + self.tag_index + .get(tag) + .and_then(|&idx| self.checkpoints.get(idx)) + .ok_or_else(|| eyre!("Checkpoint '{}' not found", tag)) + } +} + +impl Drop for CheckpointManager { + fn drop(&mut self) { + let path = self.shadow_repo_path.clone(); + // Try to spawn cleanup task + if let Ok(handle) = tokio::runtime::Handle::try_current() { + handle.spawn(async move { + let _ = tokio::fs::remove_dir_all(path).await; + }); + } else { + // Fallback to thread + std::thread::spawn(move || { + let _ = std::fs::remove_dir_all(path); + }); + } + } +} + +// Helper functions + +/// Truncate message for display +pub fn truncate_message(s: &str, max_len: usize) -> String { + if s.len() <= max_len { + return s.to_string(); + } + + let truncated = &s[..max_len]; + if let Some(pos) = truncated.rfind(' ') { + format!("{}...", &truncated[..pos]) + } else { + format!("{}...", truncated) + } +} + +pub const CHECKPOINT_MESSAGE_MAX_LENGTH: usize = 60; + +fn is_git_installed() -> bool { + Command::new("git") + .arg("--version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) +} + +fn is_in_git_repo() -> bool { + Command::new("git") + .args(["rev-parse", "--is-inside-work-tree"]) + .output() + .map(|o| o.status.success()) + .unwrap_or(false) +} + +fn configure_git(shadow_path: &str) -> Result<()> { + run_git(Path::new(shadow_path), false, &["config", "user.name", "Q"])?; + run_git(Path::new(shadow_path), false, &["config", "user.email", "qcli@local"])?; + run_git(Path::new(shadow_path), false, &["config", "core.preloadindex", "true"])?; + Ok(()) +} + +fn stage_commit_tag(shadow_path: &str, message: &str, tag: &str) -> Result<()> { + // Stage all changes + run_git(Path::new(shadow_path), true, &["add", "-A"])?; + + // Commit + let output = run_git(Path::new(shadow_path), true, &[ + "commit", + "--allow-empty", + "--no-verify", + "-m", + message, + ])?; + + if !output.status.success() { + bail!("Git commit failed: {}", String::from_utf8_lossy(&output.stderr)); + } + + // Tag + let output = run_git(Path::new(shadow_path), false, &["tag", tag])?; + if !output.status.success() { + bail!("Git tag failed: {}", String::from_utf8_lossy(&output.stderr)); + } + + Ok(()) +} + +fn run_git(dir: &Path, with_work_tree: bool, args: &[&str]) -> Result { + let mut cmd = Command::new("git"); + cmd.arg(format!("--git-dir={}", dir.display())); + + if with_work_tree { + cmd.arg("--work-tree=."); + } + + cmd.args(args); + + let output = cmd.output()?; + if !output.status.success() && !output.stderr.is_empty() { + bail!(String::from_utf8_lossy(&output.stderr).to_string()); + } + + Ok(output) +} + +fn get_previous_tag(tag: &str) -> String { + // Parse turn.tool format + if let Some((turn_str, tool_str)) = tag.split_once('.') { + if let Ok(tool_num) = tool_str.parse::() { + return if tool_num > 1 { + format!("{}.{}", turn_str, tool_num - 1) + } else { + turn_str.to_string() + }; + } + } + + // Parse turn-only format + if let Ok(turn) = tag.parse::() { + return turn.saturating_sub(1).to_string(); + } + + "0".to_string() +} diff --git a/crates/chat-cli/src/cli/chat/cli/changelog.rs b/crates/chat-cli/src/cli/chat/cli/changelog.rs new file mode 100644 index 0000000000..5578c599c4 --- /dev/null +++ b/crates/chat-cli/src/cli/chat/cli/changelog.rs @@ -0,0 +1,23 @@ +use clap::Args; +use eyre::Result; + +use crate::cli::chat::{ + ChatError, + ChatSession, + ChatState, +}; +use crate::util::ui; + +#[derive(Debug, PartialEq, Args)] +pub struct ChangelogArgs {} + +impl ChangelogArgs { + pub async fn execute(self, session: &mut ChatSession) -> Result { + // Use the shared rendering function from util::ui + ui::render_changelog_content(&mut session.stderr).map_err(|e| ChatError::Std(std::io::Error::other(e)))?; + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + } +} diff --git a/crates/chat-cli/src/cli/chat/cli/checkpoint.rs b/crates/chat-cli/src/cli/chat/cli/checkpoint.rs new file mode 100644 index 0000000000..634da119c3 --- /dev/null +++ b/crates/chat-cli/src/cli/chat/cli/checkpoint.rs @@ -0,0 +1,573 @@ +use std::io::Write; + +use clap::Subcommand; +use crossterm::style::{ + Attribute, + Color, + StyledContent, + Stylize, +}; +use crossterm::{ + execute, + style, +}; +use dialoguer::Select; + +use crate::cli::chat::checkpoint::{ + Checkpoint, + CheckpointManager, + FileStats, +}; +use crate::cli::chat::{ + ChatError, + ChatSession, + ChatState, +}; +use crate::database::settings::Setting; +use crate::os::Os; +use crate::util::directories::get_shadow_repo_dir; + +#[derive(Debug, PartialEq, Subcommand)] +pub enum CheckpointSubcommand { + /// Initialize checkpoints manually + Init, + + /// Restore workspace to a checkpoint + #[command( + about = "Restore workspace to a checkpoint", + long_about = r#"Restore files to a checkpoint . If is omitted, you'll pick one interactively. + +Default mode: + • Restores tracked file changes + • Keeps new files created after the checkpoint + +With --hard: + • Exactly matches the checkpoint state + • Removes files created after the checkpoint"# + )] + Restore { + /// Checkpoint tag (e.g., 3 or 3.1). Leave empty to select interactively. + tag: Option, + + /// Exactly match checkpoint state (removes newer files) + #[arg(long)] + hard: bool, + }, + + /// List all checkpoints + List { + /// Limit number of results shown + #[arg(short, long)] + limit: Option, + }, + + /// Delete the shadow repository + Clean, + + /// Show details of a checkpoint + Expand { + /// Checkpoint tag to expand + tag: String, + }, + + /// Show differences between checkpoints + Diff { + /// First checkpoint tag + tag1: String, + + /// Second checkpoint tag (defaults to current state) + #[arg(required = false)] + tag2: Option, + }, +} + +impl CheckpointSubcommand { + pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { + // Check if checkpoint is enabled + if !os + .database + .settings + .get_bool(Setting::EnabledCheckpoint) + .unwrap_or(false) + { + execute!( + session.stderr, + style::SetForegroundColor(Color::Red), + style::Print("\nCheckpoint is disabled. Enable it with: q settings chat.enableCheckpoint true\n"), + style::SetForegroundColor(Color::Reset) + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + + // Check if in tangent mode - captures are disabled during tangent mode + if session.conversation.is_in_tangent_mode() { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print( + "⚠️ Checkpoint is disabled while in tangent mode. Disable tangent mode with: q settings -d chat.enableTangentMode.\n\n" + ), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + match self { + Self::Init => self.handle_init(os, session).await, + Self::Restore { ref tag, hard } => self.handle_restore(session, tag.clone(), hard).await, + Self::List { limit } => Self::handle_list(session, limit), + Self::Clean => self.handle_clean(os, session).await, + Self::Expand { ref tag } => Self::handle_expand(session, tag.clone()), + Self::Diff { ref tag1, ref tag2 } => Self::handle_diff(session, tag1.clone(), tag2.clone()), + } + } + + async fn handle_init(&self, os: &Os, session: &mut ChatSession) -> Result { + if session.conversation.checkpoint_manager.is_some() { + execute!( + session.stderr, + style::SetForegroundColor(Color::Blue), + style::Print( + "✓ Checkpoints are already enabled for this session! Use /checkpoint list to see current checkpoints.\n" + ), + style::SetForegroundColor(Color::Reset) + )?; + } else { + let path = get_shadow_repo_dir(os, session.conversation.conversation_id().to_string()) + .map_err(|e| ChatError::Custom(e.to_string().into()))?; + + let start = std::time::Instant::now(); + session.conversation.checkpoint_manager = Some( + CheckpointManager::manual_init(os, path, session.conversation.history()) + .await + .map_err(|e| ChatError::Custom(format!("Checkpoints could not be initialized: {e}").into()))?, + ); + + execute!( + session.stderr, + style::SetForegroundColor(Color::Blue), + style::SetAttribute(Attribute::Bold), + style::Print(format!( + "📷 Checkpoints are enabled! (took {:.2}s)\n", + start.elapsed().as_secs_f32() + )), + style::SetForegroundColor(Color::Reset), + style::SetAttribute(Attribute::Reset), + )?; + } + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + } + + async fn handle_restore( + &self, + session: &mut ChatSession, + tag: Option, + hard: bool, + ) -> Result { + // Take manager out temporarily to avoid borrow issues + let Some(manager) = session.conversation.checkpoint_manager.take() else { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("⚠️ Checkpoints not enabled. Use '/checkpoint init' to enable.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + }; + + let tag_result = if let Some(tag) = tag { + Ok(tag) + } else { + // Interactive selection + match gather_turn_checkpoints(&manager) { + Ok(entries) => { + if let Some(idx) = select_checkpoint(&entries, "Select checkpoint to restore:") { + Ok(entries[idx].tag.clone()) + } else { + Err(()) + } + }, + Err(e) => { + session.conversation.checkpoint_manager = Some(manager); + return Err(ChatError::Custom(format!("Failed to gather checkpoints: {}", e).into())); + }, + } + }; + + let tag = match tag_result { + Ok(tag) => tag, + Err(_) => { + session.conversation.checkpoint_manager = Some(manager); + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + }, + }; + + match manager.restore(&mut session.conversation, &tag, hard) { + Ok(_) => { + execute!( + session.stderr, + style::SetForegroundColor(Color::Blue), + style::SetAttribute(Attribute::Bold), + style::Print(format!("✓ Restored to checkpoint {}\n", tag)), + style::SetForegroundColor(Color::Reset), + style::SetAttribute(Attribute::Reset), + )?; + session.conversation.checkpoint_manager = Some(manager); + }, + Err(e) => { + session.conversation.checkpoint_manager = Some(manager); + return Err(ChatError::Custom(format!("Failed to restore: {}", e).into())); + }, + } + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + } + + fn handle_list(session: &mut ChatSession, limit: Option) -> Result { + let Some(manager) = session.conversation.checkpoint_manager.as_ref() else { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("⚠️ Checkpoints not enabled. Use '/checkpoint init' to enable.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + }; + + print_checkpoints(manager, &mut session.stderr, limit) + .map_err(|e| ChatError::Custom(format!("Could not display all checkpoints: {}", e).into()))?; + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + } + + async fn handle_clean(&self, os: &Os, session: &mut ChatSession) -> Result { + let Some(manager) = session.conversation.checkpoint_manager.take() else { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("⚠️ ️Checkpoints not enabled.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + }; + + // Print the path that will be deleted + execute!( + session.stderr, + style::Print(format!("Deleting: {}\n", manager.shadow_repo_path.display())) + )?; + + match manager.cleanup(os).await { + Ok(()) => { + execute!( + session.stderr, + style::SetAttribute(Attribute::Bold), + style::Print("✓ Deleted shadow repository for this session.\n"), + style::SetAttribute(Attribute::Reset), + )?; + }, + Err(e) => { + session.conversation.checkpoint_manager = Some(manager); + return Err(ChatError::Custom(format!("Failed to clean: {e}").into())); + }, + } + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + } + + fn handle_expand(session: &mut ChatSession, tag: String) -> Result { + let Some(manager) = session.conversation.checkpoint_manager.as_ref() else { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("⚠️ ️Checkpoints not enabled. Use '/checkpoint init' to enable.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + }; + + expand_checkpoint(manager, &mut session.stderr, &tag) + .map_err(|e| ChatError::Custom(format!("Failed to expand checkpoint: {}", e).into()))?; + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + } + + fn handle_diff(session: &mut ChatSession, tag1: String, tag2: Option) -> Result { + let Some(manager) = session.conversation.checkpoint_manager.as_ref() else { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("⚠️ Checkpoints not enabled. Use '/checkpoint init' to enable.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + }; + + let tag2 = tag2.unwrap_or_else(|| "HEAD".to_string()); + + // Validate tags exist + if tag1 != "HEAD" && !manager.tag_index.contains_key(&tag1) { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print(format!( + "⚠️ Checkpoint '{}' not found! Use /checkpoint list to see available checkpoints\n", + tag1 + )), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + + if tag2 != "HEAD" && !manager.tag_index.contains_key(&tag2) { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print(format!( + "⚠️ Checkpoint '{}' not found! Use /checkpoint list to see available checkpoints\n", + tag2 + )), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + + let header = if tag2 == "HEAD" { + format!("Changes since checkpoint {}:\n", tag1) + } else { + format!("Changes from {} to {}:\n", tag1, tag2) + }; + + execute!( + session.stderr, + style::SetForegroundColor(Color::Blue), + style::Print(header), + style::SetForegroundColor(Color::Reset), + )?; + + match manager.diff(&tag1, &tag2) { + Ok(diff) => { + if diff.trim().is_empty() { + execute!( + session.stderr, + style::SetForegroundColor(Color::DarkGrey), + style::Print("No changes.\n"), + style::SetForegroundColor(Color::Reset), + )?; + } else { + execute!(session.stderr, style::Print(diff))?; + } + }, + Err(e) => { + return Err(ChatError::Custom(format!("Failed to generate diff: {e}").into())); + }, + } + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + } +} + +// Display helpers + +struct CheckpointDisplay { + tag: String, + parts: Vec>, +} + +impl CheckpointDisplay { + fn from_checkpoint(checkpoint: &Checkpoint, manager: &CheckpointManager) -> Result { + let mut parts = Vec::new(); + + // Tag + parts.push(format!("[{}] ", checkpoint.tag).blue()); + + // Content + if checkpoint.is_turn { + // Turn checkpoint: show timestamp and description + parts.push( + format!( + "{} - {}", + checkpoint.timestamp.format("%Y-%m-%d %H:%M:%S"), + checkpoint.description + ) + .reset(), + ); + + // Add file stats if available + if let Some(stats) = manager.file_stats_cache.get(&checkpoint.tag) { + let stats_str = format_stats(stats); + if !stats_str.is_empty() { + parts.push(format!(" ({})", stats_str).dark_grey()); + } + } + } else { + // Tool checkpoint: show tool name and description + let tool_name = checkpoint.tool_name.clone().unwrap_or_else(|| "Tool".to_string()); + parts.push(format!("{}: ", tool_name).magenta()); + parts.push(checkpoint.description.clone().reset()); + } + + Ok(Self { + tag: checkpoint.tag.clone(), + parts, + }) + } +} + +impl std::fmt::Display for CheckpointDisplay { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for part in &self.parts { + write!(f, "{}", part)?; + } + Ok(()) + } +} + +fn format_stats(stats: &FileStats) -> String { + let mut parts = Vec::new(); + + if stats.added > 0 { + parts.push(format!("+{}", stats.added)); + } + if stats.modified > 0 { + parts.push(format!("~{}", stats.modified)); + } + if stats.deleted > 0 { + parts.push(format!("-{}", stats.deleted)); + } + + parts.join(" ") +} + +fn gather_turn_checkpoints(manager: &CheckpointManager) -> Result, eyre::Report> { + manager + .checkpoints + .iter() + .filter(|c| c.is_turn) + .map(|c| CheckpointDisplay::from_checkpoint(c, manager)) + .collect() +} + +fn print_checkpoints( + manager: &CheckpointManager, + output: &mut impl Write, + limit: Option, +) -> Result<(), eyre::Report> { + let entries = gather_turn_checkpoints(manager)?; + let limit = limit.unwrap_or(entries.len()); + + for entry in entries.iter().take(limit) { + execute!(output, style::Print(&entry), style::Print("\n"))?; + } + + Ok(()) +} + +fn expand_checkpoint(manager: &CheckpointManager, output: &mut impl Write, tag: &str) -> Result<(), eyre::Report> { + let Some(&idx) = manager.tag_index.get(tag) else { + execute!( + output, + style::SetForegroundColor(Color::Yellow), + style::Print(format!("⚠️ checkpoint '{}' not found\n", tag)), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(()); + }; + + let checkpoint = &manager.checkpoints[idx]; + + // Print main checkpoint + let display = CheckpointDisplay::from_checkpoint(checkpoint, manager)?; + execute!(output, style::Print(&display), style::Print("\n"))?; + + if !checkpoint.is_turn { + return Ok(()); + } + + // Print tool checkpoints for this turn + let mut tool_checkpoints = Vec::new(); + for i in (0..idx).rev() { + let c = &manager.checkpoints[i]; + if c.is_turn { + break; + } + tool_checkpoints.push((i, CheckpointDisplay::from_checkpoint(c, manager)?)); + } + + for (checkpoint_idx, display) in tool_checkpoints.iter().rev() { + // Compute stats for this tool + let curr_tag = &manager.checkpoints[*checkpoint_idx].tag; + let prev_tag = if *checkpoint_idx > 0 { + &manager.checkpoints[checkpoint_idx - 1].tag + } else { + "0" + }; + + let stats_str = manager + .compute_stats_between(prev_tag, curr_tag) + .map(|s| format_stats(&s)) + .unwrap_or_default(); + + execute!( + output, + style::SetForegroundColor(Color::Blue), + style::Print(" └─ "), + style::Print(display), + style::SetForegroundColor(Color::Reset), + )?; + + if !stats_str.is_empty() { + execute!( + output, + style::SetForegroundColor(Color::DarkGrey), + style::Print(format!(" ({})", stats_str)), + style::SetForegroundColor(Color::Reset), + )?; + } + + execute!(output, style::Print("\n"))?; + } + + Ok(()) +} + +fn select_checkpoint(entries: &[CheckpointDisplay], prompt: &str) -> Option { + Select::with_theme(&crate::util::dialoguer_theme()) + .with_prompt(prompt) + .items(entries) + .report(false) + .interact_opt() + .unwrap_or(None) +} diff --git a/crates/chat-cli/src/cli/chat/cli/clear.rs b/crates/chat-cli/src/cli/chat/cli/clear.rs index 7f2bd9d9ae..994ed35e03 100644 --- a/crates/chat-cli/src/cli/chat/cli/clear.rs +++ b/crates/chat-cli/src/cli/chat/cli/clear.rs @@ -17,6 +17,7 @@ use crate::cli::chat::{ #[deny(missing_docs)] #[derive(Debug, PartialEq, Args)] +/// Arguments for the clear command that erases conversation history and context. pub struct ClearArgs; impl ClearArgs { @@ -47,10 +48,16 @@ impl ClearArgs { }; if ["y", "Y"].contains(&user_input.as_str()) { - session.conversation.clear(true); + session.conversation.clear(); if let Some(cm) = session.conversation.context_manager.as_mut() { cm.hook_executor.cache.clear(); } + + // Reset pending tool state to prevent orphaned tool approval prompts + session.tool_uses.clear(); + session.pending_tool_index = None; + session.tool_turn_start_time = None; + execute!( session.stderr, style::SetForegroundColor(Color::Green), diff --git a/crates/chat-cli/src/cli/chat/cli/compact.rs b/crates/chat-cli/src/cli/chat/cli/compact.rs index 79e5727ef4..27b9b1a465 100644 --- a/crates/chat-cli/src/cli/chat/cli/compact.rs +++ b/crates/chat-cli/src/cli/chat/cli/compact.rs @@ -31,6 +31,12 @@ How it works Compaction will be automatically performed whenever the context window overflows. To disable this behavior, run: `q settings chat.disableAutoCompaction true`" )] +/// Arguments for the `/compact` command that summarizes conversation history to free up context +/// space. +/// +/// This command creates an AI-generated summary of the conversation while preserving essential +/// information, code, and tool executions. It's useful for long-running conversations that +/// may reach memory constraints. pub struct CompactArgs { /// The prompt to use when generating the summary prompt: Vec, diff --git a/crates/chat-cli/src/cli/chat/cli/context.rs b/crates/chat-cli/src/cli/chat/cli/context.rs index df008330cf..025ef440a0 100644 --- a/crates/chat-cli/src/cli/chat/cli/context.rs +++ b/crates/chat-cli/src/cli/chat/cli/context.rs @@ -38,6 +38,7 @@ Notes: • Agent rules apply only to the current agent • Context changes are NOT preserved between chat sessions. To make these changes permanent, edit the agent config file." )] +/// Subcommands for managing context rules and files in Amazon Q chat sessions pub enum ContextSubcommand { /// Display the context rule configuration and matched files Show { @@ -52,17 +53,20 @@ pub enum ContextSubcommand { #[arg(short, long)] force: bool, #[arg(required = true)] + /// Paths or glob patterns to remove from context rules paths: Vec, }, /// Remove specified rules #[command(alias = "rm")] Remove { + /// Paths or glob patterns to remove from context rules #[arg(required = true)] paths: Vec, }, /// Remove all rules Clear, #[command(hide = true)] + /// Display information about agent format hooks (deprecated) Hooks, } diff --git a/crates/chat-cli/src/cli/chat/cli/editor.rs b/crates/chat-cli/src/cli/chat/cli/editor.rs index 53ddc54ddf..ff0433e9e4 100644 --- a/crates/chat-cli/src/cli/chat/cli/editor.rs +++ b/crates/chat-cli/src/cli/chat/cli/editor.rs @@ -15,7 +15,9 @@ use crate::cli::chat::{ #[deny(missing_docs)] #[derive(Debug, PartialEq, Args)] +/// Command-line arguments for the editor functionality pub struct EditorArgs { + /// Initial text to populate in the editor pub initial_text: Vec, } @@ -82,13 +84,8 @@ impl EditorArgs { } } -/// Opens the user's preferred editor to compose a prompt -pub fn open_editor(initial_text: Option) -> Result { - // Create a temporary file with a unique name - let temp_dir = std::env::temp_dir(); - let file_name = format!("q_prompt_{}.md", Uuid::new_v4()); - let temp_file_path = temp_dir.join(file_name); - +/// Launch the user's preferred editor with the given file path +fn launch_editor(file_path: &std::path::Path) -> Result<(), ChatError> { // Get the editor from environment variable or use a default let editor_cmd = std::env::var("EDITOR").unwrap_or_else(|_| "vi".to_string()); @@ -102,11 +99,6 @@ pub fn open_editor(initial_text: Option) -> Result { let editor_bin = parts.remove(0); - // Write initial content to the file if provided - let initial_content = initial_text.unwrap_or_default(); - std::fs::write(&temp_file_path, &initial_content) - .map_err(|e| ChatError::Custom(format!("Failed to create temporary file: {}", e).into()))?; - // Open the editor with the parsed command and arguments let mut cmd = std::process::Command::new(editor_bin); // Add any arguments that were part of the EDITOR variable @@ -115,7 +107,7 @@ pub fn open_editor(initial_text: Option) -> Result { } // Add the file path as the last argument let status = cmd - .arg(&temp_file_path) + .arg(file_path) .status() .map_err(|e| ChatError::Custom(format!("Failed to open editor: {}", e).into()))?; @@ -123,6 +115,29 @@ pub fn open_editor(initial_text: Option) -> Result { return Err(ChatError::Custom("Editor exited with non-zero status".into())); } + Ok(()) +} + +/// Opens the user's preferred editor to edit an existing file +pub fn open_editor_file(file_path: &std::path::Path) -> Result<(), ChatError> { + launch_editor(file_path) +} + +/// Opens the user's preferred editor to compose a prompt +pub fn open_editor(initial_text: Option) -> Result { + // Create a temporary file with a unique name + let temp_dir = std::env::temp_dir(); + let file_name = format!("q_prompt_{}.md", Uuid::new_v4()); + let temp_file_path = temp_dir.join(file_name); + + // Write initial content to the file if provided + let initial_content = initial_text.unwrap_or_default(); + std::fs::write(&temp_file_path, &initial_content) + .map_err(|e| ChatError::Custom(format!("Failed to create temporary file: {}", e).into()))?; + + // Launch the editor + launch_editor(&temp_file_path)?; + // Read the content back let content = std::fs::read_to_string(&temp_file_path) .map_err(|e| ChatError::Custom(format!("Failed to read temporary file: {}", e).into()))?; diff --git a/crates/chat-cli/src/cli/chat/cli/experiment.rs b/crates/chat-cli/src/cli/chat/cli/experiment.rs index 7854974c42..9b7c3a8cd2 100644 --- a/crates/chat-cli/src/cli/chat/cli/experiment.rs +++ b/crates/chat-cli/src/cli/chat/cli/experiment.rs @@ -50,6 +50,20 @@ static AVAILABLE_EXPERIMENTS: &[Experiment] = &[ description: "Enables Q to create todo lists that can be viewed and managed using /todos", setting_key: Setting::EnabledTodoList, }, + Experiment { + name: "Checkpoint", + description: concat!( + "Enables workspace checkpoints to snapshot, list, expand, diff, and restore files (/checkpoint)\n", + " ", + "Cannot be used in tangent mode (to avoid mixing up conversation history)" + ), + setting_key: Setting::EnabledCheckpoint, + }, + Experiment { + name: "Context Usage Indicator", + description: "Shows context usage percentage in the prompt (e.g., [rust-agent] 6% >)", + setting_key: Setting::EnabledContextUsageIndicator, + }, ]; #[derive(Debug, PartialEq, Args)] @@ -108,7 +122,16 @@ async fn select_experiment(os: &mut Os, session: &mut ChatSession) -> Result Err(Interrupted) - Err(dialoguer::Error::IO(ref e)) if e.kind() == std::io::ErrorKind::Interrupted => return Ok(None), + Err(dialoguer::Error::IO(ref e)) if e.kind() == std::io::ErrorKind::Interrupted => { + // Move to beginning of line and clear everything from warning message down + queue!( + session.stderr, + crossterm::cursor::MoveToColumn(0), + crossterm::cursor::MoveUp(experiment_labels.len() as u16 + 3), + crossterm::terminal::Clear(crossterm::terminal::ClearType::FromCursorDown), + )?; + return Ok(None); + }, Err(e) => return Err(ChatError::Custom(format!("Failed to choose experiment: {e}").into())), }; @@ -161,6 +184,13 @@ async fn select_experiment(os: &mut Os, session: &mut ChatSession) -> Result bool { + match &hook.matcher { + None => true, // No matcher means the hook runs for all tools + Some(pattern) => { + match pattern.as_str() { + "*" => true, // Wildcard matches all tools + "@builtin" => !is_mcp_tool_ref(tool_name), // Built-in tools are not MCP tools + _ => { + // If tool_name is MCP, check server pattern first + if is_mcp_tool_ref(tool_name) { + if let Some(server_name) = tool_name + .strip_prefix('@') + .and_then(|s| s.split(MCP_SERVER_TOOL_DELIMITER).next()) + { + let server_pattern = format!("@{}", server_name); + if pattern == &server_pattern { + return true; + } + } + } + + // Use matches_any_pattern for both MCP and built-in tools + let mut patterns = std::collections::HashSet::new(); + patterns.insert(pattern.clone()); + matches_any_pattern(&patterns, tool_name) + }, + } + }, + } +} + +#[derive(Debug, Clone)] +pub struct ToolContext { + pub tool_name: String, + pub tool_input: serde_json::Value, + pub tool_response: Option, +} #[derive(Debug, Clone)] pub struct CachedHook { @@ -74,22 +120,32 @@ impl HookExecutor { &mut self, hooks: HashMap>, output: &mut impl Write, + cwd: &str, prompt: Option<&str>, - ) -> Result, ChatError> { + tool_context: Option, + ) -> Result, ChatError> { let mut cached = vec![]; let mut futures = FuturesUnordered::new(); for hook in hooks .into_iter() .flat_map(|(trigger, hooks)| hooks.into_iter().map(move |hook| (trigger, hook))) { + // Filter hooks by tool matcher + if let Some(tool_ctx) = &tool_context { + if !hook_matches_tool(&hook.1, &tool_ctx.tool_name) { + continue; // Skip this hook - doesn't match tool + } + } + if let Some(cache) = self.get_cache(&hook) { - cached.push((hook.clone(), cache.clone())); + // Note: we only cache successful hook run. hence always using 0 as exit code for cached hook + cached.push((hook.clone(), (0, cache))); continue; } - futures.push(self.run_hook(hook, prompt)); + futures.push(self.run_hook(hook, cwd, prompt, tool_context.clone())); } - let mut complete = 0; + let mut complete = 0; // number of hooks that are run successfully with exit code 0 let total = futures.len(); let mut spinner = None; let spinner_text = |complete: usize, total: usize| { @@ -138,9 +194,29 @@ impl HookExecutor { } // Process results regardless of output enabled - if let Ok(output) = result { - complete += 1; - results.push((hook, output)); + if let Ok((exit_code, hook_output)) = &result { + // Print warning if exit code is not 0 + if *exit_code != 0 { + queue!( + output, + style::SetForegroundColor(style::Color::Red), + style::Print("✗ "), + style::ResetColor, + style::Print(format!("{} \"", hook.0)), + style::Print(&hook.1.command), + style::Print("\""), + style::SetForegroundColor(style::Color::Red), + style::Print(format!( + " failed with exit code: {}, stderr: {})\n", + exit_code, + hook_output.trim_end() + )), + style::ResetColor, + )?; + } else { + complete += 1; + } + results.push((hook, result.unwrap())); } // Display ending summary or add a new spinner @@ -167,12 +243,17 @@ impl HookExecutor { drop(futures); // Fill cache with executed results, skipping what was already from cache - for ((trigger, hook), output) in &results { + for ((trigger, hook), (exit_code, output)) in &results { + if *exit_code != 0 { + continue; // Only cache successful hooks + } self.cache.insert((*trigger, hook.clone()), CachedHook { output: output.clone(), expiry: match trigger { HookTrigger::AgentSpawn => None, HookTrigger::UserPromptSubmit => Some(Instant::now() + Duration::from_secs(hook.cache_ttl_seconds)), + HookTrigger::PreToolUse => Some(Instant::now() + Duration::from_secs(hook.cache_ttl_seconds)), + HookTrigger::PostToolUse => Some(Instant::now() + Duration::from_secs(hook.cache_ttl_seconds)), }, }); } @@ -185,8 +266,10 @@ impl HookExecutor { async fn run_hook( &self, hook: (HookTrigger, Hook), + cwd: &str, prompt: Option<&str>, - ) -> ((HookTrigger, Hook), Result, Duration) { + tool_context: Option, + ) -> ((HookTrigger, Hook), Result, Duration) { let start_time = Instant::now(); let command = &hook.1.command; @@ -213,33 +296,61 @@ impl HookExecutor { let timeout = Duration::from_millis(hook.1.timeout_ms); - // Set USER_PROMPT environment variable if provided + // Generate hook command input in JSON format + let mut hook_input = serde_json::json!({ + "hook_event_name": hook.0.to_string(), + "cwd": cwd + }); + + // Set USER_PROMPT environment variable and add to JSON input if provided if let Some(prompt) = prompt { // Sanitize the prompt to avoid issues with special characters let sanitized_prompt = sanitize_user_prompt(prompt); cmd.env("USER_PROMPT", sanitized_prompt); + hook_input["prompt"] = serde_json::Value::String(prompt.to_string()); } - let command_future = cmd.output(); + // ToolUse specific input + if let Some(tool_ctx) = tool_context { + hook_input["tool_name"] = serde_json::Value::String(tool_ctx.tool_name); + hook_input["tool_input"] = tool_ctx.tool_input; + if let Some(response) = tool_ctx.tool_response { + hook_input["tool_response"] = response; + } + } + let json_input = serde_json::to_string(&hook_input).unwrap_or_default(); + + // Build a future for hook command w/ the JSON input passed in through STDIN + let command_future = async move { + let mut child = cmd.spawn()?; + if let Some(stdin) = child.stdin.take() { + use tokio::io::AsyncWriteExt; + let mut stdin = stdin; + let _ = stdin.write_all(json_input.as_bytes()).await; + let _ = stdin.shutdown().await; + } + child.wait_with_output().await + }; // Run with timeout let result = match tokio::time::timeout(timeout, command_future).await { - Ok(Ok(result)) => { - if result.status.success() { - let stdout = result.stdout.to_str_lossy(); - let stdout = format!( - "{}{}", - truncate_safe(&stdout, hook.1.max_output_size), - if stdout.len() > hook.1.max_output_size { - " ... truncated" - } else { - "" - } - ); - Ok(stdout) + Ok(Ok(output)) => { + let exit_code = output.status.code().unwrap_or(-1); + let raw_output = if exit_code == 0 { + output.stdout.to_str_lossy() } else { - Err(eyre!("command returned non-zero exit code: {}", result.status)) - } + output.stderr.to_str_lossy() + }; + let formatted_output = format!( + "{}{}", + truncate_safe(&raw_output, hook.1.max_output_size), + if raw_output.len() > hook.1.max_output_size { + " ... truncated" + } else { + "" + } + ); + Ok((exit_code, formatted_output)) }, Ok(Err(err)) => Err(eyre!("failed to execute command: {}", err)), Err(_) => Err(eyre!("command timed out after {} ms", timeout.as_millis())), @@ -286,6 +397,7 @@ Notes: • 'conversation_start' hooks run on the first user prompt and are attached once to the conversation history sent to Amazon Q • 'per_prompt' hooks run on each user prompt and are attached to the prompt, but are not stored in conversation history" )] +/// Arguments for the hooks command that displays configured context hooks pub struct HooksArgs; impl HooksArgs { @@ -329,3 +441,267 @@ impl HooksArgs { }) } } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use tempfile::TempDir; + + use super::*; + use crate::cli::agent::hook::{ + Hook, + HookTrigger, + }; + + #[test] + fn test_hook_matches_tool() { + let hook_no_matcher = Hook { + command: "echo test".to_string(), + timeout_ms: 5000, + cache_ttl_seconds: 0, + max_output_size: 1000, + matcher: None, + source: crate::cli::agent::hook::Source::Session, + }; + + let fs_write_hook = Hook { + command: "echo test".to_string(), + timeout_ms: 5000, + cache_ttl_seconds: 0, + max_output_size: 1000, + matcher: Some("fs_write".to_string()), + source: crate::cli::agent::hook::Source::Session, + }; + + let fs_wildcard_hook = Hook { + command: "echo test".to_string(), + timeout_ms: 5000, + cache_ttl_seconds: 0, + max_output_size: 1000, + matcher: Some("fs_*".to_string()), + source: crate::cli::agent::hook::Source::Session, + }; + + let all_tools_hook = Hook { + command: "echo test".to_string(), + timeout_ms: 5000, + cache_ttl_seconds: 0, + max_output_size: 1000, + matcher: Some("*".to_string()), + source: crate::cli::agent::hook::Source::Session, + }; + + let builtin_hook = Hook { + command: "echo test".to_string(), + timeout_ms: 5000, + cache_ttl_seconds: 0, + max_output_size: 1000, + matcher: Some("@builtin".to_string()), + source: crate::cli::agent::hook::Source::Session, + }; + + let git_server_hook = Hook { + command: "echo test".to_string(), + timeout_ms: 5000, + cache_ttl_seconds: 0, + max_output_size: 1000, + matcher: Some("@git".to_string()), + source: crate::cli::agent::hook::Source::Session, + }; + + let git_status_hook = Hook { + command: "echo test".to_string(), + timeout_ms: 5000, + cache_ttl_seconds: 0, + max_output_size: 1000, + matcher: Some("@git/status".to_string()), + source: crate::cli::agent::hook::Source::Session, + }; + + // No matcher should match all tools + assert!(hook_matches_tool(&hook_no_matcher, "fs_write")); + assert!(hook_matches_tool(&hook_no_matcher, "execute_bash")); + assert!(hook_matches_tool(&hook_no_matcher, "@git/status")); + + // Exact matcher should only match exact tool + assert!(hook_matches_tool(&fs_write_hook, "fs_write")); + assert!(!hook_matches_tool(&fs_write_hook, "fs_read")); + + // Wildcard matcher should match pattern + assert!(hook_matches_tool(&fs_wildcard_hook, "fs_write")); + assert!(hook_matches_tool(&fs_wildcard_hook, "fs_read")); + assert!(!hook_matches_tool(&fs_wildcard_hook, "execute_bash")); + + // * should match all tools + assert!(hook_matches_tool(&all_tools_hook, "fs_write")); + assert!(hook_matches_tool(&all_tools_hook, "execute_bash")); + assert!(hook_matches_tool(&all_tools_hook, "@git/status")); + + // @builtin should match built-in tools only + assert!(hook_matches_tool(&builtin_hook, "fs_write")); + assert!(hook_matches_tool(&builtin_hook, "execute_bash")); + assert!(!hook_matches_tool(&builtin_hook, "@git/status")); + + // @git should match all git server tools + assert!(hook_matches_tool(&git_server_hook, "@git/status")); + assert!(!hook_matches_tool(&git_server_hook, "@other/tool")); + assert!(!hook_matches_tool(&git_server_hook, "fs_write")); + + // @git/status should match exact MCP tool + assert!(hook_matches_tool(&git_status_hook, "@git/status")); + assert!(!hook_matches_tool(&git_status_hook, "@git/commit")); + assert!(!hook_matches_tool(&git_status_hook, "fs_write")); + } + + #[tokio::test] + async fn test_hook_executor_with_tool_context() { + let mut executor = HookExecutor::new(); + let mut output = Vec::new(); + + // Create temp directory and file + let temp_dir = TempDir::new().unwrap(); + let test_file = temp_dir.path().join("hook_output.json"); + let test_file_str = test_file.to_string_lossy(); + + // Create a simple hook that writes JSON input to a file + #[cfg(unix)] + let command = format!("cat > {}", test_file_str); + #[cfg(windows)] + let command = format!("type > {}", test_file_str); + + let hook = Hook { + command, + timeout_ms: 5000, + cache_ttl_seconds: 0, + max_output_size: 1000, + matcher: Some("fs_write".to_string()), + source: crate::cli::agent::hook::Source::Session, + }; + + let mut hooks = HashMap::new(); + hooks.insert(HookTrigger::PreToolUse, vec![hook]); + + let tool_context = ToolContext { + tool_name: "fs_write".to_string(), + tool_input: serde_json::json!({ + "command": "create", + "path": "/test/file.py" + }), + tool_response: None, + }; + + // Run the hook + let result = executor + .run_hooks(hooks, &mut output, ".", None, Some(tool_context)) + .await; + + assert!(result.is_ok()); + + // Verify the hook wrote the JSON input to the file + if let Ok(content) = std::fs::read_to_string(&test_file) { + let json: serde_json::Value = serde_json::from_str(&content).unwrap(); + assert_eq!(json["hook_event_name"], "preToolUse"); + assert_eq!(json["tool_name"], "fs_write"); + assert_eq!(json["tool_input"]["command"], "create"); + assert_eq!(json["cwd"], "."); + } + // TempDir automatically cleans up when dropped + } + + #[tokio::test] + async fn test_hook_filtering_no_match() { + let mut executor = HookExecutor::new(); + let mut output = Vec::new(); + + // Hook that matches execute_bash (should NOT run for fs_write tool call) + let execute_bash_hook = Hook { + command: "echo 'should not run'".to_string(), + timeout_ms: 5000, + cache_ttl_seconds: 0, + max_output_size: 1000, + matcher: Some("execute_bash".to_string()), + source: crate::cli::agent::hook::Source::Session, + }; + + let mut hooks = HashMap::new(); + hooks.insert(HookTrigger::PostToolUse, vec![execute_bash_hook]); + + let tool_context = ToolContext { + tool_name: "fs_write".to_string(), + tool_input: serde_json::json!({"command": "create"}), + tool_response: Some(serde_json::json!({"success": true})), + }; + + // Run the hooks + let result = executor + .run_hooks( + hooks, + &mut output, + ".", // cwd - using current directory for now + None, // prompt - no user prompt for this test + Some(tool_context), + ) + .await; + + assert!(result.is_ok()); + let hook_results = result.unwrap(); + + // Should run 0 hooks because matcher doesn't match tool_name + assert_eq!(hook_results.len(), 0); + + // Output should be empty since no hooks ran + assert!(output.is_empty()); + } + + #[tokio::test] + async fn test_hook_exit_code_2() { + let mut executor = HookExecutor::new(); + let mut output = Vec::new(); + + // Create a hook that exits with code 2 and outputs to stderr + #[cfg(unix)] + let command = "echo 'Tool execution blocked by security policy' >&2; exit 2"; + #[cfg(windows)] + let command = "echo Tool execution blocked by security policy 1>&2 & exit /b 2"; + + let hook = Hook { + command: command.to_string(), + timeout_ms: 5000, + cache_ttl_seconds: 0, + max_output_size: 1000, + matcher: Some("fs_write".to_string()), + source: crate::cli::agent::hook::Source::Session, + }; + + let hooks = HashMap::from([(HookTrigger::PreToolUse, vec![hook])]); + + let tool_context = ToolContext { + tool_name: "fs_write".to_string(), + tool_input: serde_json::json!({ + "command": "create", + "path": "/sensitive/file.py" + }), + tool_response: None, + }; + + let results = executor + .run_hooks( + hooks, + &mut output, + ".", // cwd + None, // prompt + Some(tool_context), + ) + .await + .unwrap(); + + // Should have one result + assert_eq!(results.len(), 1); + + let ((trigger, _hook), (exit_code, hook_output)) = &results[0]; + assert_eq!(*trigger, HookTrigger::PreToolUse); + assert_eq!(*exit_code, 2); + assert!(hook_output.contains("Tool execution blocked by security policy")); + } +} diff --git a/crates/chat-cli/src/cli/chat/cli/mcp.rs b/crates/chat-cli/src/cli/chat/cli/mcp.rs index 82a9740c5e..1cabd5344b 100644 --- a/crates/chat-cli/src/cli/chat/cli/mcp.rs +++ b/crates/chat-cli/src/cli/chat/cli/mcp.rs @@ -14,6 +14,10 @@ use crate::cli::chat::{ ChatState, }; +/// Arguments for the MCP (Model Context Protocol) command. +/// +/// This struct handles MCP-related functionality, allowing users to view +/// the status of MCP servers and their loading progress. #[deny(missing_docs)] #[derive(Debug, PartialEq, Args)] pub struct McpArgs; @@ -50,9 +54,9 @@ impl McpArgs { let msg = msg .iter() .map(|record| match record { - LoadingRecord::Err(content) | LoadingRecord::Warn(content) | LoadingRecord::Success(content) => { - content.clone() - }, + LoadingRecord::Err(timestamp, content) + | LoadingRecord::Warn(timestamp, content) + | LoadingRecord::Success(timestamp, content) => format!("[{timestamp}]: {content}"), }) .collect::>() .join("\n--- tools refreshed ---\n"); diff --git a/crates/chat-cli/src/cli/chat/cli/mod.rs b/crates/chat-cli/src/cli/chat/cli/mod.rs index 4e0f38a3d4..bf951596e6 100644 --- a/crates/chat-cli/src/cli/chat/cli/mod.rs +++ b/crates/chat-cli/src/cli/chat/cli/mod.rs @@ -1,3 +1,5 @@ +pub mod changelog; +pub mod checkpoint; pub mod clear; pub mod compact; pub mod context; @@ -16,6 +18,7 @@ pub mod todos; pub mod tools; pub mod usage; +use changelog::ChangelogArgs; use clap::Parser; use clear::ClearArgs; use compact::CompactArgs; @@ -33,6 +36,7 @@ use tangent::TangentArgs; use todos::TodoSubcommand; use tools::ToolsArgs; +use crate::cli::chat::cli::checkpoint::CheckpointSubcommand; use crate::cli::chat::cli::subscribe::SubscribeArgs; use crate::cli::chat::cli::usage::UsageArgs; use crate::cli::chat::consts::AGENT_MIGRATION_DOC_URL; @@ -75,6 +79,9 @@ pub enum SlashCommand { Tools(ToolsArgs), /// Create a new Github issue or make a feature request Issue(issue::IssueArgs), + /// View changelog for Amazon Q CLI + #[command(name = "changelog")] + Changelog(ChangelogArgs), /// View and retrieve prompts Prompts(PromptsArgs), /// View context hooks @@ -97,6 +104,8 @@ pub enum SlashCommand { Persist(PersistSubcommand), // #[command(flatten)] // Root(RootSubcommand), + #[command(subcommand)] + Checkpoint(CheckpointSubcommand), /// View, manage, and resume to-do lists #[command(subcommand)] Todos(TodoSubcommand), @@ -145,7 +154,8 @@ impl SlashCommand { skip_printing_tools: true, }) }, - Self::Prompts(args) => args.execute(session).await, + Self::Changelog(args) => args.execute(session).await, + Self::Prompts(args) => args.execute(os, session).await, Self::Hooks(args) => args.execute(session).await, Self::Usage(args) => args.execute(os, session).await, Self::Mcp(args) => args.execute(session).await, @@ -163,6 +173,7 @@ impl SlashCommand { // skip_printing_tools: true, // }) // }, + Self::Checkpoint(subcommand) => subcommand.execute(os, session).await, Self::Todos(subcommand) => subcommand.execute(os, session).await, } } @@ -179,6 +190,7 @@ impl SlashCommand { Self::Compact(_) => "compact", Self::Tools(_) => "tools", Self::Issue(_) => "issue", + Self::Changelog(_) => "changelog", Self::Prompts(_) => "prompts", Self::Hooks(_) => "hooks", Self::Usage(_) => "usage", @@ -191,6 +203,7 @@ impl SlashCommand { PersistSubcommand::Save { .. } => "save", PersistSubcommand::Load { .. } => "load", }, + Self::Checkpoint(_) => "checkpoint", Self::Todos(_) => "todos", } } diff --git a/crates/chat-cli/src/cli/chat/cli/model.rs b/crates/chat-cli/src/cli/chat/cli/model.rs index 1e484666f0..de9819fe41 100644 --- a/crates/chat-cli/src/cli/chat/cli/model.rs +++ b/crates/chat-cli/src/cli/chat/cli/model.rs @@ -60,10 +60,11 @@ impl ModelInfo { self.model_name.as_deref().unwrap_or(&self.model_id) } } + +/// Command-line arguments for model selection operations #[deny(missing_docs)] #[derive(Debug, PartialEq, Args)] pub struct ModelArgs; - impl ModelArgs { pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { Ok(select_model(os, session).await?.unwrap_or(ChatState::PromptUser { diff --git a/crates/chat-cli/src/cli/chat/cli/persist.rs b/crates/chat-cli/src/cli/chat/cli/persist.rs index 1f4568f7ed..b1f5d0da19 100644 --- a/crates/chat-cli/src/cli/chat/cli/persist.rs +++ b/crates/chat-cli/src/cli/chat/cli/persist.rs @@ -14,17 +14,23 @@ use crate::cli::chat::{ }; use crate::os::Os; +/// Commands for persisting and loading conversation state #[deny(missing_docs)] #[derive(Debug, PartialEq, Subcommand)] pub enum PersistSubcommand { /// Save the current conversation Save { + /// Path where the conversation will be saved path: String, #[arg(short, long)] + /// Force overwrite if file already exists force: bool, }, /// Load a previous conversation - Load { path: String }, + Load { + /// Path to the conversation file to load + path: String, + }, } impl PersistSubcommand { diff --git a/crates/chat-cli/src/cli/chat/cli/profile.rs b/crates/chat-cli/src/cli/chat/cli/profile.rs index 14150524bc..83a7b634cd 100644 --- a/crates/chat-cli/src/cli/chat/cli/profile.rs +++ b/crates/chat-cli/src/cli/chat/cli/profile.rs @@ -55,11 +55,12 @@ use crate::util::{ Notes • Launch q chat with a specific agent with --agent -• Construct an agent under ~/.aws/amazonq/cli-agents/ (accessible globally) or cwd/.aws/amazonq/cli-agents (accessible in workspace) +• Construct an agent under ~/.aws/amazonq/cli-agents/ (accessible globally) or cwd/.amazonq/cli-agents (accessible in workspace) • See example config under global directory • Set default agent to assume with settings by running \"q settings chat.defaultAgent agent_name\" • Each agent maintains its own set of context and customizations" )] +/// Subcommands for managing agents in the chat CLI pub enum AgentSubcommand { /// List all available agents List, @@ -76,40 +77,67 @@ pub enum AgentSubcommand { #[arg(long, short)] from: Option, }, + /// Edit an existing agent configuration + Edit { + /// Name of the agent to edit + #[arg(long, short)] + name: String, + }, /// Generate an agent configuration using AI Generate {}, /// Delete the specified agent #[command(hide = true)] - Delete { name: String }, + Delete { + /// Name of the agent to delete + name: String, + }, /// Switch to the specified agent #[command(hide = true)] - Set { name: String }, + Set { + /// Name of the agent to switch to + name: String, + }, /// Show agent config schema Schema, /// Define a default agent to use when q chat launches SetDefault { + /// Name of the agent to set as default #[arg(long, short)] name: String, }, /// Swap to a new agent at runtime #[command(alias = "switch")] - Swap { name: Option }, + Swap { + /// Optional name of the agent to swap to. If not provided, a selection dialog will be shown + name: Option, + }, } -fn prompt_mcp_server_selection(servers: &[McpServerInfo]) -> eyre::Result> { +fn prompt_mcp_server_selection(servers: &[McpServerInfo]) -> eyre::Result>> { let items: Vec = servers .iter() .map(|server| format!("{} ({})", server.name, server.config.command)) .collect(); - let selections = MultiSelect::new() + let selections = match MultiSelect::new() .with_prompt("Select MCP servers (use Space to toggle, Enter to confirm)") .items(&items) - .interact()?; - - let selected_servers: Vec<&McpServerInfo> = selections.iter().filter_map(|&i| servers.get(i)).collect(); + .interact_on_opt(&dialoguer::console::Term::stdout()) + { + Ok(sel) => sel, + Err(dialoguer::Error::IO(ref e)) if e.kind() == std::io::ErrorKind::Interrupted => { + return Ok(None); + }, + Err(e) => return Err(eyre::eyre!("Failed to get MCP server selection: {e}")), + }; + + let selected_servers: Vec<&McpServerInfo> = selections + .unwrap_or_default() + .iter() + .filter_map(|&i| servers.get(i)) + .collect(); - Ok(selected_servers) + Ok(Some(selected_servers)) } impl AgentSubcommand { @@ -220,6 +248,64 @@ impl AgentSubcommand { )?; }, + Self::Edit { name } => { + let (_agent, path_with_file_name) = Agent::get_agent_by_name(os, &name) + .await + .map_err(|e| ChatError::Custom(Cow::Owned(e.to_string())))?; + + let editor_cmd = std::env::var("EDITOR").unwrap_or_else(|_| "vi".to_string()); + let mut cmd = std::process::Command::new(editor_cmd); + + let status = cmd.arg(&path_with_file_name).status()?; + if !status.success() { + return Err(ChatError::Custom("Editor process did not exit with success".into())); + } + + let updated_agent = Agent::load( + os, + &path_with_file_name, + &mut None, + session.conversation.mcp_enabled, + &mut session.stderr, + ) + .await; + match updated_agent { + Ok(agent) => { + session.conversation.agents.agents.insert(agent.name.clone(), agent); + }, + Err(e) => { + execute!( + session.stderr, + style::SetForegroundColor(Color::Red), + style::Print("Error: "), + style::ResetColor, + style::Print(&e), + style::Print("\n"), + )?; + + return Err(ChatError::Custom( + format!("Post edit validation failed for agent '{name}'. Malformed config detected: {e}") + .into(), + )); + }, + } + + execute!( + session.stderr, + style::SetForegroundColor(Color::Green), + style::Print("Agent "), + style::SetForegroundColor(Color::Cyan), + style::Print(name), + style::SetForegroundColor(Color::Green), + style::Print(" has been edited successfully"), + style::SetForegroundColor(Color::Reset), + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Changes take effect on next launch"), + style::SetForegroundColor(Color::Reset) + )?; + }, + Self::Generate {} => { let agent_name = match crate::util::input("Enter agent name: ", None) { Ok(input) => input.trim().to_string(), @@ -280,7 +366,12 @@ impl AgentSubcommand { let selected_servers = if mcp_servers.is_empty() { Vec::new() } else { - prompt_mcp_server_selection(&mcp_servers).map_err(|e| ChatError::Custom(e.to_string().into()))? + match prompt_mcp_server_selection(&mcp_servers) + .map_err(|e| ChatError::Custom(e.to_string().into()))? + { + Some(servers) => servers, + None => return Ok(ChatState::default()), + } }; let mcp_servers_json = if !selected_servers.is_empty() { @@ -413,6 +504,7 @@ impl AgentSubcommand { match self { Self::List => "list", Self::Create { .. } => "create", + Self::Edit { .. } => "edit", Self::Generate { .. } => "generate", Self::Delete { .. } => "delete", Self::Set { .. } => "set", diff --git a/crates/chat-cli/src/cli/chat/cli/prompts.rs b/crates/chat-cli/src/cli/chat/cli/prompts.rs index 53b0012a57..1dffde5169 100644 --- a/crates/chat-cli/src/cli/chat/cli/prompts.rs +++ b/crates/chat-cli/src/cli/chat/cli/prompts.rs @@ -2,6 +2,9 @@ use std::collections::{ HashMap, VecDeque, }; +use std::fs; +use std::path::PathBuf; +use std::sync::LazyLock; use clap::{ Args, @@ -16,17 +19,34 @@ use crossterm::{ execute, queue, }; +use regex::Regex; +use rmcp::model::{ + PromptMessage, + PromptMessageContent, + PromptMessageRole, +}; use thiserror::Error; use unicode_width::UnicodeWidthStr; -use crate::cli::chat::error_formatter::format_mcp_error; +use crate::cli::chat::cli::editor::open_editor_file; use crate::cli::chat::tool_manager::PromptBundle; use crate::cli::chat::{ ChatError, ChatSession, ChatState, }; -use crate::mcp_client::PromptGetResult; +use crate::mcp_client::McpClientError; +use crate::os::Os; +use crate::util::directories::{ + chat_global_prompts_dir, + chat_local_prompts_dir, +}; + +/// Maximum allowed length for prompt names +const MAX_PROMPT_NAME_LENGTH: usize = 50; + +/// Regex for validating prompt names (alphanumeric, hyphens, underscores only) +static PROMPT_NAME_REGEX: LazyLock = LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap()); #[derive(Debug, Error)] pub enum GetPromptError { @@ -46,8 +66,160 @@ pub enum GetPromptError { IncorrectResponseType, #[error("Missing channel")] MissingChannel, + #[error(transparent)] + McpClient(#[from] McpClientError), + #[error(transparent)] + Service(#[from] rmcp::ServiceError), + #[error(transparent)] + Io(#[from] std::io::Error), +} + +/// Represents a single prompt (local or global) +#[derive(Debug, Clone)] +struct Prompt { + name: String, + path: PathBuf, +} + +impl std::fmt::Display for Prompt { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.name) + } +} + +impl Prompt { + /// Create a new prompt with the given name in the specified directory + fn new(name: &str, base_dir: PathBuf) -> Self { + let path = base_dir.join(format!("{}.md", name)); + Self { + name: name.to_string(), + path, + } + } + + /// Check if the prompt file exists + fn exists(&self) -> bool { + self.path.exists() + } + + /// Load the content of the prompt file + fn load_content(&self) -> Result { + fs::read_to_string(&self.path).map_err(GetPromptError::Io) + } + + /// Save content to the prompt file + fn save_content(&self, content: &str) -> Result<(), GetPromptError> { + // Ensure parent directory exists + if let Some(parent) = self.path.parent() { + fs::create_dir_all(parent).map_err(GetPromptError::Io)?; + } + fs::write(&self.path, content).map_err(GetPromptError::Io) + } + + /// Delete the prompt file + fn delete(&self) -> Result<(), GetPromptError> { + fs::remove_file(&self.path).map_err(GetPromptError::Io) + } +} + +/// Represents both local and global prompts for a given name +#[derive(Debug)] +struct Prompts { + local: Prompt, + global: Prompt, +} + +impl Prompts { + /// Create a new Prompts instance for the given name + fn new(name: &str, os: &Os) -> Result { + let local_dir = chat_local_prompts_dir(os).map_err(|e| GetPromptError::General(e.into()))?; + let global_dir = chat_global_prompts_dir(os).map_err(|e| GetPromptError::General(e.into()))?; + + Ok(Self { + local: Prompt::new(name, local_dir), + global: Prompt::new(name, global_dir), + }) + } + + /// Check if local prompt overrides a global one (both local and global exist) + fn has_local_override(&self) -> bool { + self.local.exists() && self.global.exists() + } + + /// Find and load existing prompt content (local takes priority) + fn load_existing(&self) -> Result, GetPromptError> { + if self.local.exists() { + let content = self.local.load_content()?; + Ok(Some((content, self.local.path.clone()))) + } else if self.global.exists() { + let content = self.global.load_content()?; + Ok(Some((content, self.global.path.clone()))) + } else { + Ok(None) + } + } + + /// Get all available prompt names from both directories + fn get_available_names(os: &Os) -> Result, GetPromptError> { + let mut prompt_names = std::collections::HashSet::new(); + + // Helper function to collect prompt names from a directory + let collect_from_dir = + |dir: PathBuf, names: &mut std::collections::HashSet| -> Result<(), GetPromptError> { + if dir.exists() { + for entry in fs::read_dir(&dir)? { + let entry = entry?; + let path = entry.path(); + if path.is_file() && path.extension().and_then(|s| s.to_str()) == Some("md") { + if let Some(file_stem) = path.file_stem().and_then(|s| s.to_str()) { + let prompt = Prompt::new(file_stem, dir.clone()); + names.insert(prompt.name); + } + } + } + } + Ok(()) + }; + + // Check global prompts + if let Ok(global_dir) = chat_global_prompts_dir(os) { + collect_from_dir(global_dir, &mut prompt_names)?; + } + + // Check local prompts + if let Ok(local_dir) = chat_local_prompts_dir(os) { + collect_from_dir(local_dir, &mut prompt_names)?; + } + + Ok(prompt_names.into_iter().collect()) + } +} + +/// Validate prompt name to ensure it's safe and follows naming conventions +fn validate_prompt_name(name: &str) -> Result<(), String> { + // Check for empty name + if name.trim().is_empty() { + return Err("Prompt name cannot be empty. Please provide a valid name for your prompt.".to_string()); + } + + // Check length limit + if name.len() > MAX_PROMPT_NAME_LENGTH { + return Err(format!( + "Prompt name must be {} characters or less. Current length: {} characters.", + MAX_PROMPT_NAME_LENGTH, + name.len() + )); + } + + // Check for valid characters using regex (alphanumeric, hyphens, underscores only) + if !PROMPT_NAME_REGEX.is_match(name) { + return Err("Prompt name can only contain letters, numbers, hyphens (-), and underscores (_). Special characters, spaces, and path separators are not allowed.".to_string()); + } + + Ok(()) } +/// Command-line arguments for prompt operations #[deny(missing_docs)] #[derive(Debug, PartialEq, Args)] #[command(color = clap::ColorChoice::Always, @@ -65,21 +237,39 @@ pub struct PromptsArgs { } impl PromptsArgs { - pub async fn execute(self, session: &mut ChatSession) -> Result { + pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { let search_word = match &self.subcommand { Some(PromptsSubcommand::List { search_word }) => search_word.clone(), _ => None, }; if let Some(subcommand) = self.subcommand { - if matches!(subcommand, PromptsSubcommand::Get { .. }) { - return subcommand.execute(session).await; + if matches!( + subcommand, + PromptsSubcommand::Get { .. } + | PromptsSubcommand::Create { .. } + | PromptsSubcommand::Edit { .. } + | PromptsSubcommand::Remove { .. } + ) { + return subcommand.execute(os, session).await; } } let terminal_width = session.terminal_width(); let prompts = session.conversation.tool_manager.list_prompts().await?; + + // Get available prompt names + let prompt_names = Prompts::get_available_names(os).map_err(|e| ChatError::Custom(e.to_string().into()))?; + let mut longest_name = ""; + + // Update longest_name to include local prompts + for name in &prompt_names { + if name.contains(search_word.as_deref().unwrap_or("")) && name.len() > longest_name.len() { + longest_name = name; + } + } + let arg_pos = { let optimal_case = UnicodeWidthStr::width(longest_name) + terminal_width / 4; if optimal_case > terminal_width { @@ -142,10 +332,96 @@ impl PromptsArgs { .collect(); prompts_by_server.sort_by_key(|(server_name, _)| server_name.as_str()); + // Display prompts by category + let filtered_names: Vec<_> = prompt_names + .iter() + .filter(|name| name.contains(search_word.as_deref().unwrap_or(""))) + .collect(); + + if !filtered_names.is_empty() { + // Separate global and local prompts for display + let _global_dir = chat_global_prompts_dir(os).ok(); + let _local_dir = chat_local_prompts_dir(os).ok(); + + let mut global_prompts = Vec::new(); + let mut local_prompts = Vec::new(); + let mut overridden_globals = Vec::new(); + + for name in &filtered_names { + // Use the Prompts struct to check for conflicts + if let Ok(prompts) = Prompts::new(name, os) { + let (local_exists, global_exists) = (prompts.local.exists(), prompts.global.exists()); + + if global_exists { + global_prompts.push(name); + } + + if local_exists { + local_prompts.push(name); + // Check for overrides using has_local_override method + if global_exists { + overridden_globals.push(name); + } + } + } + } + + if !global_prompts.is_empty() { + queue!( + session.stderr, + style::SetAttribute(Attribute::Bold), + style::Print("Global (.aws/amazonq/prompts):"), + style::SetAttribute(Attribute::Reset), + style::Print("\n"), + )?; + for name in &global_prompts { + queue!(session.stderr, style::Print("- "), style::Print(name))?; + queue!(session.stderr, style::Print("\n"))?; + } + } + + if !local_prompts.is_empty() { + if !global_prompts.is_empty() { + queue!(session.stderr, style::Print("\n"))?; + } + queue!( + session.stderr, + style::SetAttribute(Attribute::Bold), + style::Print("Local (.amazonq/prompts):"), + style::SetAttribute(Attribute::Reset), + style::Print("\n"), + )?; + for name in &local_prompts { + let has_global_version = overridden_globals.contains(name); + queue!(session.stderr, style::Print("- "), style::Print(name),)?; + if has_global_version { + queue!( + session.stderr, + style::SetForegroundColor(Color::Green), + style::Print(" (overrides global)"), + style::SetForegroundColor(Color::Reset), + )?; + } + + // Show override indicator if this local prompt overrides a global one + if overridden_globals.contains(name) { + queue!( + session.stderr, + style::SetForegroundColor(Color::DarkGrey), + style::Print(" (overrides global)"), + style::SetForegroundColor(Color::Reset), + )?; + } + + queue!(session.stderr, style::Print("\n"))?; + } + } + } + for (i, (server_name, bundles)) in prompts_by_server.iter_mut().enumerate() { bundles.sort_by_key(|bundle| &bundle.prompt_get.name); - if i > 0 { + if i > 0 || !filtered_names.is_empty() { queue!(session.stderr, style::Print("\n"))?; } queue!( @@ -205,31 +481,103 @@ impl PromptsArgs { } } +/// Subcommands for prompt operations #[deny(missing_docs)] #[derive(Clone, Debug, PartialEq, Subcommand)] pub enum PromptsSubcommand { /// List available prompts from a tool or show all available prompt - List { search_word: Option }, + List { + /// Optional search word to filter prompts + search_word: Option, + }, + /// Get a specific prompt by name Get { #[arg(long, hide = true)] + /// Original input string (hidden) orig_input: Option, + /// Name of the prompt to retrieve name: String, + /// Optional arguments for the prompt arguments: Option>, }, + /// Create a new prompt + Create { + /// Name of the prompt to create + #[arg(short = 'n', long)] + name: String, + /// Content of the prompt (if not provided, opens editor) + #[arg(long)] + content: Option, + /// Create in global directory instead of local + #[arg(long)] + global: bool, + }, + /// Edit an existing prompt + Edit { + /// Name of the prompt to edit + name: String, + /// Edit global prompt instead of local + #[arg(long)] + global: bool, + }, + /// Remove an existing prompt + Remove { + /// Name of the prompt to remove + name: String, + /// Remove global prompt instead of local + #[arg(long)] + global: bool, + }, } impl PromptsSubcommand { - pub async fn execute(self, session: &mut ChatSession) -> Result { - let PromptsSubcommand::Get { - orig_input, - name, - arguments, - } = self - else { - unreachable!("List has already been parsed out at this point"); - }; + pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { + match self { + PromptsSubcommand::Get { + orig_input, + name, + arguments: _, + } => Self::execute_get(os, session, orig_input, name).await, + PromptsSubcommand::Create { name, content, global } => { + Self::execute_create(os, session, name, content, global).await + }, + PromptsSubcommand::Edit { name, global } => Self::execute_edit(os, session, name, global).await, + PromptsSubcommand::Remove { name, global } => Self::execute_remove(os, session, name, global).await, + PromptsSubcommand::List { .. } => { + unreachable!("List has already been parsed out at this point"); + }, + } + } + + async fn execute_get( + os: &Os, + session: &mut ChatSession, + orig_input: Option, + name: String, + ) -> Result { + // First try to find prompt (global or local) + let prompts = Prompts::new(&name, os).map_err(|e| ChatError::Custom(e.to_string().into()))?; + if let Some((content, _)) = prompts + .load_existing() + .map_err(|e| ChatError::Custom(e.to_string().into()))? + { + // Handle local prompt + session.pending_prompts.clear(); + + // Create a PromptMessage from the local prompt content + let prompt_message = PromptMessage { + role: PromptMessageRole::User, + content: PromptMessageContent::Text { text: content.clone() }, + }; + session.pending_prompts.push_back(prompt_message); + + return Ok(ChatState::HandleInput { + input: orig_input.unwrap_or_default(), + }); + } - let prompts = match session.conversation.tool_manager.get_prompt(name, arguments).await { + // If not found locally, try MCP prompts + let prompts = match session.conversation.tool_manager.get_prompt(name, None).await { Ok(resp) => resp, Err(e) => { match e { @@ -273,36 +621,535 @@ impl PromptsSubcommand { }); }, }; - if let Some(err) = prompts.error { - // If we are running into error we should just display the error - // and abort. - let to_display = serde_json::json!(err); + + session.pending_prompts.clear(); + session.pending_prompts.append(&mut VecDeque::from(prompts.messages)); + + Ok(ChatState::HandleInput { + input: orig_input.unwrap_or_default(), + }) + } + + async fn execute_create( + os: &Os, + session: &mut ChatSession, + name: String, + content: Option, + global: bool, + ) -> Result { + // Create prompts instance and validate name + let mut prompts = Prompts::new(&name, os).map_err(|e| ChatError::Custom(e.to_string().into()))?; + + if let Err(validation_error) = validate_prompt_name(&name) { queue!( session.stderr, style::Print("\n"), - style::SetAttribute(Attribute::Bold), - style::Print("Error encountered while retrieving prompt:"), - style::SetAttribute(Attribute::Reset), + style::SetForegroundColor(Color::Red), + style::Print("❌ Invalid prompt name: "), + style::Print(validation_error), + style::Print("\n"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("Valid names contain only letters, numbers, hyphens, and underscores (1-50 characters)\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + + // Check if prompt already exists in target location + let (local_exists, global_exists) = (prompts.local.exists(), prompts.global.exists()); + let target_exists = if global { global_exists } else { local_exists }; + + if target_exists { + let location = if global { "global" } else { "local" }; + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" already exists in "), + style::Print(location), + style::Print(" directory. Use "), + style::SetForegroundColor(Color::Cyan), + style::Print("/prompts edit "), + style::Print(&name), + if global { + style::Print(" --global") + } else { + style::Print("") + }, + style::SetForegroundColor(Color::Yellow), + style::Print(" to modify it.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + + // Check if creating this prompt will cause or involve a conflict + let opposite_exists = if global { local_exists } else { global_exists }; + + if prompts.has_local_override() || opposite_exists { + let (existing_scope, _creating_scope, override_message) = if !global { + ( + "global", + "local", + "Creating this local prompt will override the global one.", + ) + } else { + ( + "local", + "global", + "The local prompt will continue to override this global one.", + ) + }; + + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("⚠ Warning: A "), + style::Print(existing_scope), + style::Print(" prompt named '"), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print("' already exists.\n"), + style::Print(override_message), + style::Print("\n"), + style::SetForegroundColor(Color::Reset), + )?; + + // Flush stderr to ensure the warning is displayed before asking for input + execute!(session.stderr)?; + + // Ask for user confirmation + let user_input = match crate::util::input("Do you want to continue? (y/n): ", None) { + Ok(input) => input.trim().to_lowercase(), + Err(_) => { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Green), + style::Print("✓ Prompt creation cancelled.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + }, + }; + + if user_input != "y" && user_input != "yes" { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Green), + style::Print("✓ Prompt creation cancelled.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + } + + match content { + Some(content) => { + // Write the prompt file with provided content + let target_prompt = if global { + &mut prompts.global + } else { + &mut prompts.local + }; + + target_prompt + .save_content(&content) + .map_err(|e| ChatError::Custom(e.to_string().into()))?; + + let location = if global { "global" } else { "local" }; + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Green), + style::Print("✓ Created "), + style::Print(location), + style::Print(" prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Green), + style::Print(" at "), + style::SetForegroundColor(Color::DarkGrey), + style::Print(target_prompt.path.display().to_string()), + style::SetForegroundColor(Color::Reset), + style::Print("\n\n"), + )?; + }, + None => { + // Create file with default template and open editor + let default_content = "# Enter your prompt content here\n\nDescribe what this prompt should do..."; + let target_prompt = if global { + &mut prompts.global + } else { + &mut prompts.local + }; + + target_prompt + .save_content(default_content) + .map_err(|e| ChatError::Custom(e.to_string().into()))?; + + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Green), + style::Print("Opening editor to create prompt content...\n"), + style::SetForegroundColor(Color::Reset), + )?; + + // Try to open the editor + match open_editor_file(&target_prompt.path) { + Ok(()) => { + let location = if global { "global" } else { "local" }; + queue!( + session.stderr, + style::SetForegroundColor(Color::Green), + style::Print("✓ Created "), + style::Print(location), + style::Print(" prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Green), + style::Print(" at "), + style::SetForegroundColor(Color::DarkGrey), + style::Print(target_prompt.path.display().to_string()), + style::SetForegroundColor(Color::Reset), + style::Print("\n\n"), + )?; + }, + Err(err) => { + queue!( + session.stderr, + style::SetForegroundColor(Color::Red), + style::Print("Error opening editor: "), + style::Print(err.to_string()), + style::SetForegroundColor(Color::Reset), + style::Print("\n"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("Tip: You can edit this file directly: "), + style::Print(target_prompt.path.display().to_string()), + style::SetForegroundColor(Color::Reset), + style::Print("\n\n"), + )?; + }, + } + }, + }; + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + } + + async fn execute_edit( + os: &Os, + session: &mut ChatSession, + name: String, + global: bool, + ) -> Result { + // Validate prompt name + if let Err(validation_error) = validate_prompt_name(&name) { + queue!( + session.stderr, style::Print("\n"), style::SetForegroundColor(Color::Red), - style::Print(format_mcp_error(&to_display)), + style::Print("❌ Invalid prompt name: "), + style::Print(validation_error), + style::Print("\n"), style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + + let prompts = Prompts::new(&name, os).map_err(|e| ChatError::Custom(e.to_string().into()))?; + let (local_exists, global_exists) = (prompts.local.exists(), prompts.global.exists()); + + // Find the target prompt to edit + let target_prompt = if global { + if !global_exists { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Global prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" not found.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + &prompts.global + } else if local_exists { + &prompts.local + } else if global_exists { + // Found global prompt, but user wants to edit local + queue!( + session.stderr, style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Local prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" not found, but global version exists.\n"), + style::Print("Use "), + style::SetForegroundColor(Color::Cyan), + style::Print("/prompts edit "), + style::Print(&name), + style::Print(" --global"), + style::SetForegroundColor(Color::Yellow), + style::Print(" to edit the global version, or\n"), + style::Print("use "), + style::SetForegroundColor(Color::Cyan), + style::Print("/prompts create "), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" to create a local override.\n"), + style::SetForegroundColor(Color::Reset), )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); } else { - let prompts = prompts - .result - .ok_or(ChatError::Custom("Result field missing from prompt/get request".into()))?; - let prompts = serde_json::from_value::(prompts) - .map_err(|e| ChatError::Custom(format!("Failed to deserialize prompt/get result: {:?}", e).into()))?; - session.pending_prompts.clear(); - session.pending_prompts.append(&mut VecDeque::from(prompts.messages)); - return Ok(ChatState::HandleInput { - input: orig_input.unwrap_or_default(), + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" not found.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + }; + + let location = if global { "global" } else { "local" }; + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Green), + style::Print("Opening editor for "), + style::Print(location), + style::Print(" prompt: "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Reset), + style::Print("\n"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("File: "), + style::Print(target_prompt.path.display().to_string()), + style::SetForegroundColor(Color::Reset), + style::Print("\n\n"), + )?; + + // Try to open the editor + match open_editor_file(&target_prompt.path) { + Ok(()) => { + queue!( + session.stderr, + style::SetForegroundColor(Color::Green), + style::Print("✓ Prompt edited successfully.\n\n"), + style::SetForegroundColor(Color::Reset), + )?; + }, + Err(err) => { + queue!( + session.stderr, + style::SetForegroundColor(Color::Red), + style::Print("Error opening editor: "), + style::Print(err.to_string()), + style::SetForegroundColor(Color::Reset), + style::Print("\n"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("Tip: You can edit this file directly: "), + style::Print(target_prompt.path.display().to_string()), + style::SetForegroundColor(Color::Reset), + style::Print("\n\n"), + )?; + }, + } + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + } + + async fn execute_remove( + os: &Os, + session: &mut ChatSession, + name: String, + global: bool, + ) -> Result { + let prompts = Prompts::new(&name, os).map_err(|e| ChatError::Custom(e.to_string().into()))?; + let (local_exists, global_exists) = (prompts.local.exists(), prompts.global.exists()); + + // Find the target prompt to remove + let target_prompt = if global { + if !global_exists { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Global prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" not found.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + &prompts.global + } else if local_exists { + &prompts.local + } else if global_exists { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Local prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" not found, but global version exists.\n"), + style::Print("Use "), + style::SetForegroundColor(Color::Cyan), + style::Print("/prompts remove "), + style::Print(&name), + style::Print(" --global"), + style::SetForegroundColor(Color::Yellow), + style::Print(" to remove the global version.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } else { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" not found.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + }; + + let location = if global { "global" } else { "local" }; + + // Ask for confirmation + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("⚠ Warning: This will permanently remove the "), + style::Print(location), + style::Print(" prompt '"), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print("'.\n"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("File: "), + style::Print(target_prompt.path.display().to_string()), + style::SetForegroundColor(Color::Reset), + style::Print("\n"), + )?; + + // Flush stderr to ensure the warning is displayed before asking for input + execute!(session.stderr)?; + + // Ask for user confirmation + let user_input = match crate::util::input("Are you sure you want to remove this prompt? (y/n): ", None) { + Ok(input) => input.trim().to_lowercase(), + Err(_) => { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Green), + style::Print("✓ Removal cancelled.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + }, + }; + + if user_input != "y" && user_input != "yes" { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Green), + style::Print("✓ Removal cancelled.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, }); } - execute!(session.stderr, style::Print("\n"))?; + // Remove the file + match target_prompt.delete() { + Ok(()) => { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Green), + style::Print("✓ Removed "), + style::Print(location), + style::Print(" prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Green), + style::Print(" successfully.\n\n"), + style::SetForegroundColor(Color::Reset), + )?; + }, + Err(err) => { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Red), + style::Print("Error deleting prompt: "), + style::Print(err.to_string()), + style::SetForegroundColor(Color::Reset), + style::Print("\n\n"), + )?; + }, + } Ok(ChatState::PromptUser { skip_printing_tools: true, @@ -313,6 +1160,154 @@ impl PromptsSubcommand { match self { PromptsSubcommand::List { .. } => "list", PromptsSubcommand::Get { .. } => "get", + PromptsSubcommand::Create { .. } => "create", + PromptsSubcommand::Edit { .. } => "edit", + PromptsSubcommand::Remove { .. } => "remove", + } + } +} +#[cfg(test)] +mod tests { + use std::fs; + + use tempfile::TempDir; + + use super::*; + + fn create_prompt_file(dir: &PathBuf, name: &str, content: &str) { + fs::create_dir_all(dir).unwrap(); + fs::write(dir.join(format!("{}.md", name)), content).unwrap(); + } + + #[tokio::test] + async fn test_prompt_file_operations() { + let temp_dir = TempDir::new().unwrap(); + + // Create test prompts in temp directory structure + let global_dir = temp_dir.path().join(".aws/amazonq/prompts"); + let local_dir = temp_dir.path().join(".amazonq/prompts"); + + create_prompt_file(&global_dir, "global_only", "Global content"); + create_prompt_file(&global_dir, "shared", "Global shared"); + create_prompt_file(&local_dir, "local_only", "Local content"); + create_prompt_file(&local_dir, "shared", "Local shared"); + + // Test that we can read the files directly + assert_eq!( + fs::read_to_string(global_dir.join("global_only.md")).unwrap(), + "Global content" + ); + assert_eq!(fs::read_to_string(local_dir.join("shared.md")).unwrap(), "Local shared"); + } + + #[test] + fn test_local_prompts_override_global() { + let temp_dir = TempDir::new().unwrap(); + + // Create global and local directories + let global_dir = temp_dir.path().join(".aws/amazonq/prompts"); + let local_dir = temp_dir.path().join(".amazonq/prompts"); + + // Create prompts: one with same name in both directories, one unique to each + create_prompt_file(&global_dir, "shared", "Global version"); + create_prompt_file(&global_dir, "global_only", "Global only"); + create_prompt_file(&local_dir, "shared", "Local version"); + create_prompt_file(&local_dir, "local_only", "Local only"); + + // Simulate the priority logic from get_available_prompt_names() + let mut names = Vec::new(); + + // Add global prompts first + for entry in fs::read_dir(&global_dir).unwrap() { + let entry = entry.unwrap(); + let path = entry.path(); + if path.extension().and_then(|s| s.to_str()) == Some("md") { + if let Some(file_stem) = path.file_stem().and_then(|s| s.to_str()) { + let prompt = Prompt::new(file_stem, global_dir.clone()); + names.push(prompt.name); + } + } } + + // Add local prompts (with override logic) + for entry in fs::read_dir(&local_dir).unwrap() { + let entry = entry.unwrap(); + let path = entry.path(); + if path.extension().and_then(|s| s.to_str()) == Some("md") { + if let Some(file_stem) = path.file_stem().and_then(|s| s.to_str()) { + let prompt = Prompt::new(file_stem, local_dir.clone()); + let name = prompt.name; + // Remove duplicate if it exists (local overrides global) + names.retain(|n| n != &name); + names.push(name); + } + } + } + + // Verify: should have 3 unique prompts (shared, global_only, local_only) + assert_eq!(names.len(), 3); + assert!(names.contains(&"shared".to_string())); + assert!(names.contains(&"global_only".to_string())); + assert!(names.contains(&"local_only".to_string())); + + // Verify only one "shared" exists (local overrode global) + let shared_count = names.iter().filter(|&name| name == "shared").count(); + assert_eq!(shared_count, 1); + + // Simulate load_prompt_by_name() priority: local first, then global + let shared_content = if local_dir.join("shared.md").exists() { + fs::read_to_string(local_dir.join("shared.md")).unwrap() + } else { + fs::read_to_string(global_dir.join("shared.md")).unwrap() + }; + + // Verify local version was loaded + assert_eq!(shared_content, "Local version"); + } + + #[test] + fn test_validate_prompt_name() { + // Empty name + assert!(validate_prompt_name("").is_err()); + assert!(validate_prompt_name(" ").is_err()); + + // Too long name (over 50 characters) + let long_name = "a".repeat(51); + assert!(validate_prompt_name(&long_name).is_err()); + + // Exactly 50 characters should be valid + let max_name = "a".repeat(50); + assert!(validate_prompt_name(&max_name).is_ok()); + + // Valid names with allowed characters + assert!(validate_prompt_name("valid_name").is_ok()); + assert!(validate_prompt_name("valid-name-v2").is_ok()); + + // Invalid characters (spaces, special chars, path separators) + assert!(validate_prompt_name("invalid name").is_err()); // space + assert!(validate_prompt_name("path/name").is_err()); // forward slash + assert!(validate_prompt_name("path\\name").is_err()); // backslash + assert!(validate_prompt_name("name.ext").is_err()); // dot + assert!(validate_prompt_name("name@host").is_err()); // at symbol + assert!(validate_prompt_name("name#tag").is_err()); // hash + assert!(validate_prompt_name("name$var").is_err()); // dollar sign + assert!(validate_prompt_name("name%percent").is_err()); // percent + assert!(validate_prompt_name("name&and").is_err()); // ampersand + assert!(validate_prompt_name("name*star").is_err()); // asterisk + assert!(validate_prompt_name("name+plus").is_err()); // plus + assert!(validate_prompt_name("name=equals").is_err()); // equals + assert!(validate_prompt_name("name?question").is_err()); // question mark + assert!(validate_prompt_name("name[bracket]").is_err()); // brackets + assert!(validate_prompt_name("name{brace}").is_err()); // braces + assert!(validate_prompt_name("name(paren)").is_err()); // parentheses + assert!(validate_prompt_name("name").is_err()); // angle brackets + assert!(validate_prompt_name("name|pipe").is_err()); // pipe + assert!(validate_prompt_name("name;semicolon").is_err()); // semicolon + assert!(validate_prompt_name("name:colon").is_err()); // colon + assert!(validate_prompt_name("name\"quote").is_err()); // double quote + assert!(validate_prompt_name("name'apostrophe").is_err()); // single quote + assert!(validate_prompt_name("name`backtick").is_err()); // backtick + assert!(validate_prompt_name("name~tilde").is_err()); // tilde + assert!(validate_prompt_name("name!exclamation").is_err()); // exclamation } } diff --git a/crates/chat-cli/src/cli/chat/cli/subscribe.rs b/crates/chat-cli/src/cli/chat/cli/subscribe.rs index c920908743..36dd670f04 100644 --- a/crates/chat-cli/src/cli/chat/cli/subscribe.rs +++ b/crates/chat-cli/src/cli/chat/cli/subscribe.rs @@ -28,9 +28,11 @@ const SUBSCRIBE_TEXT: &str = color_print::cstr! { "During the upgrade, you'll be Need help? Visit our subscription support page> https://docs.aws.amazon.com/console/amazonq/upgrade-builder-id" }; +/// Arguments for the subscribe command to manage Q Developer Pro subscriptions #[deny(missing_docs)] #[derive(Debug, PartialEq, Args)] pub struct SubscribeArgs { + /// Open the AWS console to manage an existing subscription #[arg(long)] manage: bool, } diff --git a/crates/chat-cli/src/cli/chat/cli/tangent.rs b/crates/chat-cli/src/cli/chat/cli/tangent.rs index 5bd04eb9bf..65165c84f3 100644 --- a/crates/chat-cli/src/cli/chat/cli/tangent.rs +++ b/crates/chat-cli/src/cli/chat/cli/tangent.rs @@ -1,4 +1,7 @@ -use clap::Args; +use clap::{ + Args, + Subcommand, +}; use crossterm::execute; use crossterm::style::{ self, @@ -14,9 +17,33 @@ use crate::database::settings::Setting; use crate::os::Os; #[derive(Debug, PartialEq, Args)] -pub struct TangentArgs; +pub struct TangentArgs { + #[command(subcommand)] + pub subcommand: Option, +} + +#[derive(Debug, PartialEq, Subcommand)] +pub enum TangentSubcommand { + /// Exit tangent mode and keep the last conversation entry (user question + assistant response) + Tail, +} impl TangentArgs { + async fn send_tangent_telemetry(os: &Os, session: &ChatSession, duration_seconds: i64) { + if let Err(err) = os + .telemetry + .send_tangent_mode_session( + &os.database, + session.conversation.conversation_id().to_string(), + crate::telemetry::TelemetryResult::Succeeded, + crate::telemetry::core::TangentModeSessionArgs { duration_seconds }, + ) + .await + { + tracing::warn!(?err, "Failed to send tangent mode session telemetry"); + } + } + pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { // Check if tangent mode is enabled if !os @@ -35,69 +62,119 @@ impl TangentArgs { skip_printing_tools: true, }); } - if session.conversation.is_in_tangent_mode() { - // Get duration before exiting tangent mode - let duration_seconds = session.conversation.get_tangent_duration_seconds().unwrap_or(0); - - session.conversation.exit_tangent_mode(); - - // Send telemetry for tangent mode session - if let Err(err) = os - .telemetry - .send_tangent_mode_session( - &os.database, - session.conversation.conversation_id().to_string(), - crate::telemetry::TelemetryResult::Succeeded, - crate::telemetry::core::TangentModeSessionArgs { duration_seconds }, - ) - .await - { - tracing::warn!(?err, "Failed to send tangent mode session telemetry"); - } - execute!( - session.stderr, - style::SetForegroundColor(Color::DarkGrey), - style::Print("Restored conversation from checkpoint ("), - style::SetForegroundColor(Color::Yellow), - style::Print("↯"), - style::SetForegroundColor(Color::DarkGrey), - style::Print("). - Returned to main conversation.\n"), - style::SetForegroundColor(Color::Reset) - )?; - } else { - session.conversation.enter_tangent_mode(); - - // Get the configured tangent mode key for display - let tangent_key_char = match os - .database - .settings - .get_string(crate::database::settings::Setting::TangentModeKey) - { - Some(key) if key.len() == 1 => key.chars().next().unwrap_or('t'), - _ => 't', // Default to 't' if setting is missing or invalid - }; - let tangent_key_display = format!("ctrl + {}", tangent_key_char.to_lowercase()); + match self.subcommand { + Some(TangentSubcommand::Tail) => { + // Check if checkpoint is enabled + if os + .database + .settings + .get_bool(Setting::EnabledCheckpoint) + .unwrap_or(false) + { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print( + "⚠️ Checkpoint is disabled while in tangent mode. Please exit tangent mode if you want to use checkpoint.\n" + ), + style::SetForegroundColor(Color::Reset), + )?; + } + if session.conversation.is_in_tangent_mode() { + let duration_seconds = session.conversation.get_tangent_duration_seconds().unwrap_or(0); + session.conversation.exit_tangent_mode_with_tail(); + Self::send_tangent_telemetry(os, session, duration_seconds).await; - execute!( - session.stderr, - style::SetForegroundColor(Color::DarkGrey), - style::Print("Created a conversation checkpoint ("), - style::SetForegroundColor(Color::Yellow), - style::Print("↯"), - style::SetForegroundColor(Color::DarkGrey), - style::Print("). Use "), - style::SetForegroundColor(Color::Green), - style::Print(&tangent_key_display), - style::SetForegroundColor(Color::DarkGrey), - style::Print(" or "), - style::SetForegroundColor(Color::Green), - style::Print("/tangent"), - style::SetForegroundColor(Color::DarkGrey), - style::Print(" to restore the conversation later.\n"), - style::Print("Note: this functionality is experimental and may change or be removed in the future.\n"), - style::SetForegroundColor(Color::Reset) - )?; + execute!( + session.stderr, + style::SetForegroundColor(Color::DarkGrey), + style::Print("Restored conversation from checkpoint ("), + style::SetForegroundColor(Color::Yellow), + style::Print("↯"), + style::SetForegroundColor(Color::DarkGrey), + style::Print(") with last conversation entry preserved.\n"), + style::SetForegroundColor(Color::Reset) + )?; + } else { + execute!( + session.stderr, + style::SetForegroundColor(Color::Red), + style::Print("You need to be in tangent mode to use tail.\n"), + style::SetForegroundColor(Color::Reset) + )?; + } + }, + None => { + if session.conversation.is_in_tangent_mode() { + let duration_seconds = session.conversation.get_tangent_duration_seconds().unwrap_or(0); + session.conversation.exit_tangent_mode(); + Self::send_tangent_telemetry(os, session, duration_seconds).await; + + execute!( + session.stderr, + style::SetForegroundColor(Color::DarkGrey), + style::Print("Restored conversation from checkpoint ("), + style::SetForegroundColor(Color::Yellow), + style::Print("↯"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("). - Returned to main conversation.\n"), + style::SetForegroundColor(Color::Reset) + )?; + } else { + // Check if checkpoint is enabled + if os + .database + .settings + .get_bool(Setting::EnabledCheckpoint) + .unwrap_or(false) + { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print( + "⚠️ Checkpoint is disabled while in tangent mode. Please exit tangent mode if you want to use checkpoint.\n" + ), + style::SetForegroundColor(Color::Reset), + )?; + } + + session.conversation.enter_tangent_mode(); + + // Get the configured tangent mode key for display + let tangent_key_char = match os + .database + .settings + .get_string(crate::database::settings::Setting::TangentModeKey) + { + Some(key) if key.len() == 1 => key.chars().next().unwrap_or('t'), + _ => 't', // Default to 't' if setting is missing or invalid + }; + let tangent_key_display = format!("ctrl + {}", tangent_key_char.to_lowercase()); + + execute!( + session.stderr, + style::SetForegroundColor(Color::DarkGrey), + style::Print("Created a conversation checkpoint ("), + style::SetForegroundColor(Color::Yellow), + style::Print("↯"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("). Use "), + style::SetForegroundColor(Color::Green), + style::Print(&tangent_key_display), + style::SetForegroundColor(Color::DarkGrey), + style::Print(" or "), + style::SetForegroundColor(Color::Green), + style::Print("/tangent"), + style::SetForegroundColor(Color::DarkGrey), + style::Print(" to restore the conversation later.\n"), + style::Print( + "Note: this functionality is experimental and may change or be removed in the future.\n" + ), + style::SetForegroundColor(Color::Reset) + )?; + } + }, } Ok(ChatState::PromptUser { diff --git a/crates/chat-cli/src/cli/chat/cli/tools.rs b/crates/chat-cli/src/cli/chat/cli/tools.rs index ce05dce3cb..1f1e5267ff 100644 --- a/crates/chat-cli/src/cli/chat/cli/tools.rs +++ b/crates/chat-cli/src/cli/chat/cli/tools.rs @@ -35,6 +35,7 @@ use crate::cli::chat::{ }; use crate::util::consts::MCP_SERVER_TOOL_DELIMITER; +/// Command-line arguments for managing tools in the chat session #[deny(missing_docs)] #[derive(Debug, PartialEq, Args)] pub struct ToolsArgs { @@ -146,7 +147,11 @@ impl ToolsArgs { queue!( session.stderr, style::SetAttribute(Attribute::Bold), - style::Print("Servers still loading"), + style::Print("Servers loading (Some of these might need auth. See "), + style::SetForegroundColor(Color::Green), + style::Print("/mcp"), + style::SetForegroundColor(Color::Reset), + style::Print(" for details)"), style::SetAttribute(Attribute::Reset), style::Print("\n"), style::Print("▔".repeat(terminal_width)), @@ -197,17 +202,20 @@ trust so that no confirmation is required. Refer to the documentation for how to configure tools with your agent: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/agent-format.md#tools-field" )] +/// Subcommands for managing tool permissions and configurations pub enum ToolsSubcommand { /// Show the input schema for all available tools Schema, /// Trust a specific tool or tools for the session Trust { #[arg(required = true)] + /// Names of tools to trust tool_names: Vec, }, /// Revert a tool or tools to per-request confirmation Untrust { #[arg(required = true)] + /// Names of tools to untrust tool_names: Vec, }, /// Trust all tools (equivalent to deprecated /acceptall) diff --git a/crates/chat-cli/src/cli/chat/cli/usage.rs b/crates/chat-cli/src/cli/chat/cli/usage.rs index eca538e2b6..4240bf4c2f 100644 --- a/crates/chat-cli/src/cli/chat/cli/usage.rs +++ b/crates/chat-cli/src/cli/chat/cli/usage.rs @@ -20,18 +20,75 @@ use crate::cli::chat::{ ChatState, }; use crate::os::Os; + +/// Detailed usage data for context window analysis +#[derive(Debug)] +pub struct DetailedUsageData { + pub total_tokens: TokenCount, + pub context_tokens: TokenCount, + pub assistant_tokens: TokenCount, + pub user_tokens: TokenCount, + pub tools_tokens: TokenCount, + pub context_window_size: usize, + pub dropped_context_files: Vec<(String, String)>, +} + +/// Calculate usage percentage from token counts +pub fn calculate_usage_percentage(tokens: TokenCount, context_window_size: usize) -> f32 { + (tokens.value() as f32 / context_window_size as f32) * 100.0 +} + +/// Get detailed usage data for context window analysis +pub async fn get_detailed_usage_data(session: &mut ChatSession, os: &Os) -> Result { + let context_window_size = context_window_tokens(session.conversation.model_info.as_ref()); + + let state = session + .conversation + .backend_conversation_state(os, true, &mut std::io::stderr()) + .await?; + + let data = state.calculate_conversation_size(); + let tool_specs_json: String = state + .tools + .values() + .filter_map(|s| serde_json::to_string(s).ok()) + .collect::>() + .join(""); + let tools_char_count: CharCount = tool_specs_json.len().into(); + let total_tokens: TokenCount = + (data.context_messages + data.user_messages + data.assistant_messages + tools_char_count).into(); + + Ok(DetailedUsageData { + total_tokens, + context_tokens: data.context_messages.into(), + assistant_tokens: data.assistant_messages.into(), + user_tokens: data.user_messages.into(), + tools_tokens: tools_char_count.into(), + context_window_size, + dropped_context_files: state.dropped_context_files, + }) +} + +/// Get total usage percentage (simple interface for prompt generation) +pub async fn get_total_usage_percentage(session: &mut ChatSession, os: &Os) -> Result { + let data = get_detailed_usage_data(session, os).await?; + Ok(calculate_usage_percentage(data.total_tokens, data.context_window_size)) +} + +/// Arguments for the usage command that displays token usage statistics and context window +/// information. +/// +/// This command shows how many tokens are being used by different components (context files, tools, +/// assistant responses, and user prompts) within the current chat session's context window. #[deny(missing_docs)] #[derive(Debug, PartialEq, Args)] pub struct UsageArgs; impl UsageArgs { pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { - let state = session - .conversation - .backend_conversation_state(os, true, &mut session.stderr) - .await?; + let usage_data = get_detailed_usage_data(session, os).await?; - if !state.dropped_context_files.is_empty() { + if !usage_data.dropped_context_files.is_empty() { execute!( session.stderr, style::SetForegroundColor(Color::DarkYellow), @@ -44,33 +101,18 @@ impl UsageArgs { )?; } - let data = state.calculate_conversation_size(); - let tool_specs_json: String = state - .tools - .values() - .filter_map(|s| serde_json::to_string(s).ok()) - .collect::>() - .join(""); - let context_token_count: TokenCount = data.context_messages.into(); - let assistant_token_count: TokenCount = data.assistant_messages.into(); - let user_token_count: TokenCount = data.user_messages.into(); - let tools_char_count: CharCount = tool_specs_json.len().into(); // usize → CharCount - let tools_token_count: TokenCount = tools_char_count.into(); // CharCount → TokenCount - let total_token_used: TokenCount = - (data.context_messages + data.user_messages + data.assistant_messages + tools_char_count).into(); let window_width = session.terminal_width(); // set a max width for the progress bar for better aesthetic let progress_bar_width = std::cmp::min(window_width, 80); - let context_window_size = context_window_tokens(session.conversation.model_info.as_ref()); - let context_width = - ((context_token_count.value() as f64 / context_window_size as f64) * progress_bar_width as f64) as usize; - let assistant_width = - ((assistant_token_count.value() as f64 / context_window_size as f64) * progress_bar_width as f64) as usize; - let tools_width = - ((tools_token_count.value() as f64 / context_window_size as f64) * progress_bar_width as f64) as usize; - let user_width = - ((user_token_count.value() as f64 / context_window_size as f64) * progress_bar_width as f64) as usize; + let context_width = ((usage_data.context_tokens.value() as f64 / usage_data.context_window_size as f64) + * progress_bar_width as f64) as usize; + let assistant_width = ((usage_data.assistant_tokens.value() as f64 / usage_data.context_window_size as f64) + * progress_bar_width as f64) as usize; + let tools_width = ((usage_data.tools_tokens.value() as f64 / usage_data.context_window_size as f64) + * progress_bar_width as f64) as usize; + let user_width = ((usage_data.user_tokens.value() as f64 / usage_data.context_window_size as f64) + * progress_bar_width as f64) as usize; let left_over_width = progress_bar_width - std::cmp::min( @@ -80,44 +122,45 @@ impl UsageArgs { let is_overflow = (context_width + assistant_width + user_width + tools_width) > progress_bar_width; + let total_percentage = calculate_usage_percentage(usage_data.total_tokens, usage_data.context_window_size); + if is_overflow { queue!( session.stderr, style::Print(format!( "\nCurrent context window ({} of {}k tokens used)\n", - total_token_used, - context_window_size / 1000 + usage_data.total_tokens, + usage_data.context_window_size / 1000 )), style::SetForegroundColor(Color::DarkRed), style::Print("█".repeat(progress_bar_width)), style::SetForegroundColor(Color::Reset), style::Print(" "), - style::Print(format!( - "{:.2}%", - (total_token_used.value() as f32 / context_window_size as f32) * 100.0 - )), + style::Print(format!("{:.2}%", total_percentage)), )?; } else { queue!( session.stderr, style::Print(format!( "\nCurrent context window ({} of {}k tokens used)\n", - total_token_used, - context_window_size / 1000 + usage_data.total_tokens, + usage_data.context_window_size / 1000 )), // Context files style::SetForegroundColor(Color::DarkCyan), // add a nice visual to mimic "tiny" progress, so the overrall progress bar doesn't look too // empty - style::Print("|".repeat(if context_width == 0 && *context_token_count > 0 { - 1 - } else { - 0 - })), + style::Print( + "|".repeat(if context_width == 0 && usage_data.context_tokens.value() > 0 { + 1 + } else { + 0 + }) + ), style::Print("█".repeat(context_width)), // Tools style::SetForegroundColor(Color::DarkRed), - style::Print("|".repeat(if tools_width == 0 && *tools_token_count > 0 { + style::Print("|".repeat(if tools_width == 0 && usage_data.tools_tokens.value() > 0 { 1 } else { 0 @@ -125,24 +168,27 @@ impl UsageArgs { style::Print("█".repeat(tools_width)), // Assistant responses style::SetForegroundColor(Color::Blue), - style::Print("|".repeat(if assistant_width == 0 && *assistant_token_count > 0 { + style::Print( + "|".repeat(if assistant_width == 0 && usage_data.assistant_tokens.value() > 0 { + 1 + } else { + 0 + }) + ), + style::Print("█".repeat(assistant_width)), + // User prompts + style::SetForegroundColor(Color::Magenta), + style::Print("|".repeat(if user_width == 0 && usage_data.user_tokens.value() > 0 { 1 } else { 0 })), - style::Print("█".repeat(assistant_width)), - // User prompts - style::SetForegroundColor(Color::Magenta), - style::Print("|".repeat(if user_width == 0 && *user_token_count > 0 { 1 } else { 0 })), style::Print("█".repeat(user_width)), style::SetForegroundColor(Color::DarkGrey), style::Print("█".repeat(left_over_width)), style::Print(" "), style::SetForegroundColor(Color::Reset), - style::Print(format!( - "{:.2}%", - (total_token_used.value() as f32 / context_window_size as f32) * 100.0 - )), + style::Print(format!("{:.2}%", total_percentage)), )?; } @@ -155,32 +201,32 @@ impl UsageArgs { style::SetForegroundColor(Color::Reset), style::Print(format!( "~{} tokens ({:.2}%)\n", - context_token_count, - (context_token_count.value() as f32 / context_window_size as f32) * 100.0 + usage_data.context_tokens, + calculate_usage_percentage(usage_data.context_tokens, usage_data.context_window_size) )), style::SetForegroundColor(Color::DarkRed), style::Print("█ Tools: "), style::SetForegroundColor(Color::Reset), style::Print(format!( " ~{} tokens ({:.2}%)\n", - tools_token_count, - (tools_token_count.value() as f32 / context_window_size as f32) * 100.0 + usage_data.tools_tokens, + calculate_usage_percentage(usage_data.tools_tokens, usage_data.context_window_size) )), style::SetForegroundColor(Color::Blue), style::Print("█ Q responses: "), style::SetForegroundColor(Color::Reset), style::Print(format!( " ~{} tokens ({:.2}%)\n", - assistant_token_count, - (assistant_token_count.value() as f32 / context_window_size as f32) * 100.0 + usage_data.assistant_tokens, + calculate_usage_percentage(usage_data.assistant_tokens, usage_data.context_window_size) )), style::SetForegroundColor(Color::Magenta), style::Print("█ Your prompts: "), style::SetForegroundColor(Color::Reset), style::Print(format!( " ~{} tokens ({:.2}%)\n\n", - user_token_count, - (user_token_count.value() as f32 / context_window_size as f32) * 100.0 + usage_data.user_tokens, + calculate_usage_percentage(usage_data.user_tokens, usage_data.context_window_size) )), )?; diff --git a/crates/chat-cli/src/cli/chat/context.rs b/crates/chat-cli/src/cli/chat/context.rs index 1fdcc5e8ee..89edfb47aa 100644 --- a/crates/chat-cli/src/cli/chat/context.rs +++ b/crates/chat-cli/src/cli/chat/context.rs @@ -14,6 +14,7 @@ use serde::{ Serializer, }; +use super::cli::hooks::HookOutput; use super::cli::model::context_window_tokens; use super::util::drop_matched_context_files; use crate::cli::agent::Agent; @@ -247,11 +248,16 @@ impl ContextManager { &mut self, trigger: HookTrigger, output: &mut impl Write, + os: &crate::os::Os, prompt: Option<&str>, - ) -> Result, ChatError> { + tool_context: Option, + ) -> Result, ChatError> { let mut hooks = self.hooks.clone(); hooks.retain(|t, _| *t == trigger); - self.hook_executor.run_hooks(hooks, output, prompt).await + let cwd = os.env.current_dir()?.to_string_lossy().to_string(); + self.hook_executor + .run_hooks(hooks, output, &cwd, prompt, tool_context) + .await } } diff --git a/crates/chat-cli/src/cli/chat/conversation.rs b/crates/chat-cli/src/cli/chat/conversation.rs index 48fd13f991..1217c0289b 100644 --- a/crates/chat-cli/src/cli/chat/conversation.rs +++ b/crates/chat-cli/src/cli/chat/conversation.rs @@ -13,6 +13,12 @@ use crossterm::{ style, }; use eyre::Result; +use rmcp::model::{ + PromptMessage, + PromptMessageContent, + PromptMessageRole, + ResourceContents, +}; use serde::{ Deserialize, Serialize, @@ -23,6 +29,7 @@ use tracing::{ }; use super::cli::compact::CompactStrategy; +use super::cli::hooks::HookOutput; use super::cli::model::context_window_tokens; use super::consts::{ DUMMY_TOOL_NAME, @@ -67,12 +74,15 @@ use crate::cli::agent::hook::{ HookTrigger, }; use crate::cli::chat::ChatError; +use crate::cli::chat::checkpoint::{ + Checkpoint, + CheckpointManager, +}; use crate::cli::chat::cli::model::{ ModelInfo, get_model_info, }; use crate::cli::chat::tools::custom_tool::CustomToolConfig; -use crate::mcp_client::Prompt; use crate::os::Os; pub const CONTEXT_ENTRY_START_HEADER: &str = "--- CONTEXT ENTRY BEGIN ---\n"; @@ -132,6 +142,8 @@ pub struct ConversationState { /// Maps from a file path to [FileLineTracker] #[serde(default)] pub file_line_tracker: HashMap, + + pub checkpoint_manager: Option, #[serde(default = "default_true")] pub mcp_enabled: bool, /// Tangent mode checkpoint - stores main conversation when in tangent mode @@ -197,6 +209,7 @@ impl ConversationState { model: None, model_info: model, file_line_tracker: HashMap::new(), + checkpoint_manager: None, mcp_enabled, tangent_state: None, } @@ -210,13 +223,11 @@ impl ConversationState { &self.history } - /// Clears the conversation history and optionally the summary. - pub fn clear(&mut self, preserve_summary: bool) { + /// Clears the conversation history and summary. + pub fn clear(&mut self) { self.next_message = None; self.history.clear(); - if !preserve_summary { - self.latest_summary = None; - } + self.latest_summary = None; } /// Check if currently in tangent mode @@ -266,25 +277,80 @@ impl ConversationState { } } + /// Exit tangent mode and preserve the last conversation entry (user + assistant) + pub fn exit_tangent_mode_with_tail(&mut self) { + if let Some(checkpoint) = self.tangent_state.take() { + // Capture the last history entry from tangent conversation if it exists + // and if it's different from what was in the main conversation + let last_entry = if self.history.len() > checkpoint.main_history.len() { + self.history.back().cloned() + } else { + None // No new entries in tangent mode + }; + + // Restore from checkpoint + self.restore_from_checkpoint(checkpoint); + + // Add the last entry if it exists + if let Some(entry) = last_entry { + self.history.push_back(entry); + } + } + } + /// Appends a collection prompts into history and returns the last message in the collection. /// It asserts that the collection ends with a prompt that assumes the role of user. - pub fn append_prompts(&mut self, mut prompts: VecDeque) -> Option { + pub fn append_prompts(&mut self, mut prompts: VecDeque) -> Option { + fn stringify_prompt_message_content(prompt_msg_content: PromptMessageContent) -> String { + match prompt_msg_content { + PromptMessageContent::Text { text } => text, + PromptMessageContent::Image { image } => image.raw.data, + PromptMessageContent::Resource { resource } => { + // TODO: add support for resources for prompt + match resource.raw.resource { + ResourceContents::TextResourceContents { + uri, mime_type, text, .. + } => { + let mime_type = mime_type.as_deref().unwrap_or("unknown"); + format!("Text resource of uri: {uri}, mime_type: {mime_type}, text: {text}") + }, + ResourceContents::BlobResourceContents { + uri, mime_type, blob, .. + } => { + let mime_type = mime_type.as_deref().unwrap_or("unknown"); + format!("Blob resource of uri: {uri}, mime_type: {mime_type}, blob: {blob}") + }, + } + }, + PromptMessageContent::ResourceLink { link } => serde_json::to_string(&link.raw).unwrap_or(format!( + "Resource link with uri: {}, name: {}", + link.raw.uri, link.raw.name + )), + } + } + debug_assert!(self.next_message.is_none(), "next_message should not exist"); - debug_assert!(prompts.back().is_some_and(|p| p.role == crate::mcp_client::Role::User)); + debug_assert!(prompts.back().is_some_and(|p| p.role == PromptMessageRole::User)); let last_msg = prompts.pop_back()?; let (mut candidate_user, mut candidate_asst) = (None::, None::); - while let Some(prompt) = prompts.pop_front() { - let Prompt { role, content } = prompt; + while let Some(prompt_msg) = prompts.pop_front() { + let PromptMessage { + role, + content: prompt_msg_content, + } = prompt_msg; + let content_str = stringify_prompt_message_content(prompt_msg_content); + match role { - crate::mcp_client::Role::User => { - let user_msg = UserMessage::new_prompt(content.to_string(), None); + PromptMessageRole::User => { + let user_msg = UserMessage::new_prompt(content_str, None); candidate_user.replace(user_msg); }, - crate::mcp_client::Role::Assistant => { - let assistant_msg = AssistantMessage::new_response(None, content.into()); + PromptMessageRole::Assistant => { + let assistant_msg = AssistantMessage::new_response(None, content_str); candidate_asst.replace(assistant_msg); }, } + if candidate_asst.is_some() && candidate_user.is_some() { let assistant = candidate_asst.take().unwrap(); let user = candidate_user.take().unwrap(); @@ -296,7 +362,8 @@ impl ConversationState { }); } } - Some(last_msg.content.to_string()) + + Some(stringify_prompt_message_content(last_msg.content)) } pub fn next_user_message(&self) -> Option<&UserMessage> { @@ -504,12 +571,26 @@ impl ConversationState { let mut agent_spawn_context = None; if let Some(cm) = self.context_manager.as_mut() { let user_prompt = self.next_message.as_ref().and_then(|m| m.prompt()); - let agent_spawn = cm.run_hooks(HookTrigger::AgentSpawn, output, user_prompt).await?; + let agent_spawn = cm + .run_hooks( + HookTrigger::AgentSpawn, + output, + os, + user_prompt, + None, // tool_context + ) + .await?; agent_spawn_context = format_hook_context(&agent_spawn, HookTrigger::AgentSpawn); if let (true, Some(next_message)) = (run_perprompt_hooks, self.next_message.as_mut()) { let per_prompt = cm - .run_hooks(HookTrigger::UserPromptSubmit, output, next_message.prompt()) + .run_hooks( + HookTrigger::UserPromptSubmit, + output, + os, + next_message.prompt(), + None, // tool_context + ) .await?; if let Some(ctx) = format_hook_context(&per_prompt, HookTrigger::UserPromptSubmit) { next_message.additional_context = ctx; @@ -665,6 +746,7 @@ IMPORTANT: Return ONLY raw JSON with NO markdown formatting, NO code blocks, NO Your task is to generate an agent configuration file for an agent named '{}' with the following description: {}\n\n\ The configuration must conform to this JSON schema:\n{}\n\n\ We have a prepopulated template: {} \n\n\ +Please change the useLegacyMcpJson field to false. Please generate the prompt field using user provided description, and fill in the MCP tools that user has selected {}. Return only the JSON configuration, no additional text.", agent_name, agent_description, schema, prepopulated_content, selected_servers @@ -816,6 +898,20 @@ Return only the JSON configuration, no additional text.", self.transcript.push_back(message); } + /// Restore conversation from a checkpoint's history snapshot + pub fn restore_to_checkpoint(&mut self, checkpoint: &Checkpoint) -> Result<(), eyre::Report> { + // 1. Restore history from snapshot + self.history = checkpoint.history_snapshot.clone(); + + // 2. Clear any pending next message (uncommitted state) + self.next_message = None; + + // 3. Update valid history range + self.valid_history_range = (0, self.history.len()); + + Ok(()) + } + /// Swapping agent involves the following: /// - Reinstantiate the context manager /// - Swap agent on tool manager @@ -970,8 +1066,12 @@ impl From for ToolInputSchema { /// # Returns /// [Option::Some] if `hook_results` is not empty and at least one hook has content. Otherwise, /// [Option::None] -fn format_hook_context(hook_results: &[((HookTrigger, Hook), String)], trigger: HookTrigger) -> Option { - if hook_results.iter().all(|(_, content)| content.is_empty()) { +fn format_hook_context(hook_results: &[((HookTrigger, Hook), HookOutput)], trigger: HookTrigger) -> Option { + // Note: only format context when hook command exit code is 0 + if hook_results + .iter() + .all(|(_, (exit_code, content))| *exit_code != 0 || content.is_empty()) + { return None; } @@ -984,7 +1084,10 @@ fn format_hook_context(hook_results: &[((HookTrigger, Hook), String)], trigger: } context_content.push_str("\n\n"); - for (_, output) in hook_results.iter().filter(|((h_trigger, _), _)| *h_trigger == trigger) { + for (_, (_, output)) in hook_results + .iter() + .filter(|((h_trigger, _), (exit_code, _))| *h_trigger == trigger && *exit_code == 0) + { context_content.push_str(&format!("{output}\n\n")); } context_content.push_str(CONTEXT_ENTRY_END_HEADER); @@ -1156,6 +1259,7 @@ mod tests { use crate::cli::chat::tool_manager::ToolManager; const AMAZONQ_FILENAME: &str = "AmazonQ.md"; + const AGENTS_FILENAME: &str = "AGENTS.md"; fn assert_conversation_state_invariants(state: FigConversationState, assertion_iteration: usize) { if let Some(Some(msg)) = state.history.as_ref().map(|h| h.first()) { @@ -1368,11 +1472,13 @@ mod tests { let mut agents = Agents::default(); let mut agent = Agent::default(); agent.resources.push(AMAZONQ_FILENAME.into()); + agent.resources.push(AGENTS_FILENAME.into()); agents.agents.insert("TestAgent".to_string(), agent); agents.switch("TestAgent").expect("Agent switch failed"); agents }; os.fs.write(AMAZONQ_FILENAME, "test context").await.unwrap(); + os.fs.write(AGENTS_FILENAME, "test agents context").await.unwrap(); let mut output = vec![]; let mut tool_manager = ToolManager::default(); @@ -1526,4 +1632,99 @@ mod tests { // No duration when not in tangent mode assert!(conversation.get_tangent_duration_seconds().is_none()); } + + #[tokio::test] + async fn test_tangent_mode_with_tail() { + let mut os = Os::new().await.unwrap(); + let agents = Agents::default(); + let mut tool_manager = ToolManager::default(); + let mut conversation = ConversationState::new( + "test_conv_id", + agents, + tool_manager.load_tools(&mut os, &mut vec![]).await.unwrap(), + tool_manager, + None, + &os, + false, + ) + .await; + + // Add main conversation + conversation.set_next_user_message("main question".to_string()).await; + conversation.push_assistant_message( + &mut os, + AssistantMessage::new_response(None, "main response".to_string()), + None, + ); + + let main_history_len = conversation.history.len(); + + // Enter tangent mode + conversation.enter_tangent_mode(); + assert!(conversation.is_in_tangent_mode()); + + // Add tangent conversation + conversation.set_next_user_message("tangent question".to_string()).await; + conversation.push_assistant_message( + &mut os, + AssistantMessage::new_response(None, "tangent response".to_string()), + None, + ); + + // Exit tangent mode with tail + conversation.exit_tangent_mode_with_tail(); + assert!(!conversation.is_in_tangent_mode()); + + // Should have main conversation + last assistant message from tangent + assert_eq!(conversation.history.len(), main_history_len + 1); + + // Check that the last message is the tangent response + if let Some(entry) = conversation.history.back() { + assert_eq!(entry.assistant.content(), "tangent response"); + } else { + panic!("Expected history entry at the end"); + } + } + + #[tokio::test] + async fn test_tangent_mode_with_tail_edge_cases() { + let mut os = Os::new().await.unwrap(); + let agents = Agents::default(); + let mut tool_manager = ToolManager::default(); + let mut conversation = ConversationState::new( + "test_conv_id", + agents, + tool_manager.load_tools(&mut os, &mut vec![]).await.unwrap(), + tool_manager, + None, + &os, + false, + ) + .await; + + // Add main conversation + conversation.set_next_user_message("main question".to_string()).await; + conversation.push_assistant_message( + &mut os, + AssistantMessage::new_response(None, "main response".to_string()), + None, + ); + + let main_history_len = conversation.history.len(); + + // Test: Enter tangent mode but don't add any new conversation + conversation.enter_tangent_mode(); + assert!(conversation.is_in_tangent_mode()); + + // Exit tangent mode with tail (should not add anything since no new entries) + conversation.exit_tangent_mode_with_tail(); + assert!(!conversation.is_in_tangent_mode()); + + // Should have same length as before (no new entries added) + assert_eq!(conversation.history.len(), main_history_len); + + // Test: Call exit_tangent_mode_with_tail when not in tangent mode (should do nothing) + conversation.exit_tangent_mode_with_tail(); + assert_eq!(conversation.history.len(), main_history_len); + } } diff --git a/crates/chat-cli/src/cli/chat/error_formatter.rs b/crates/chat-cli/src/cli/chat/error_formatter.rs deleted file mode 100644 index 96a604bdbe..0000000000 --- a/crates/chat-cli/src/cli/chat/error_formatter.rs +++ /dev/null @@ -1,148 +0,0 @@ -/// Formats an MCP error message to be more user-friendly. -/// -/// This function extracts nested JSON from the error message and formats it -/// with proper indentation and newlines. -/// -/// # Arguments -/// -/// * `err` - A reference to a serde_json::Value containing the error information -/// -/// # Returns -/// -/// A formatted string representation of the error message -pub fn format_mcp_error(err: &serde_json::Value) -> String { - // Extract the message field from the error JSON - if let Some(message) = err.get("message").and_then(|m| m.as_str()) { - // Check if the message contains a nested JSON array - if let Some(start_idx) = message.find('[') { - if let Some(end_idx) = message.rfind(']') { - let prefix = &message[..start_idx].trim(); - let nested_json = &message[start_idx..=end_idx]; - - // Try to parse the nested JSON - if let Ok(nested_value) = serde_json::from_str::(nested_json) { - // Format the error message with the prefix and pretty-printed nested JSON - return format!( - "{}\n{}", - prefix, - serde_json::to_string_pretty(&nested_value).unwrap_or_else(|_| nested_json.to_string()) - ); - } - } - } - } - - // Fallback if message field is missing or if we couldn't extract and parse nested JSON - serde_json::to_string_pretty(err).unwrap_or_else(|_| format!("{:?}", err)) -} - -#[cfg(test)] -mod tests { - use serde_json::json; - - use super::*; - - #[test] - fn test_format_mcp_error_with_nested_json() { - let error = json!({ - "code": -32602, - "message": "MCP error -32602: Invalid arguments for prompt agent_script_coco_was_sev2_ticket_details_retrieve: [\n {\n \"code\": \"invalid_type\",\n \"expected\": \"object\",\n \"received\": \"undefined\",\n \"path\": [],\n \"message\": \"Required\"\n }\n]" - }); - - let formatted = format_mcp_error(&error); - - // Extract the prefix and JSON part from the formatted string - let parts: Vec<&str> = formatted.split('\n').collect(); - let prefix = parts[0]; - let json_part = &formatted[prefix.len() + 1..]; - - // Check that the prefix is correct - assert_eq!( - prefix, - "MCP error -32602: Invalid arguments for prompt agent_script_coco_was_sev2_ticket_details_retrieve:" - ); - - // Parse the JSON part to compare the actual content rather than the exact string - let parsed_json: serde_json::Value = serde_json::from_str(json_part).expect("Failed to parse JSON part"); - - // Expected JSON structure - let expected_json = json!([ - { - "code": "invalid_type", - "expected": "object", - "received": "undefined", - "path": [], - "message": "Required" - } - ]); - - // Compare the parsed JSON values - assert_eq!(parsed_json, expected_json); - } - - #[test] - fn test_format_mcp_error_without_nested_json() { - let error = json!({ - "code": -32602, - "message": "MCP error -32602: Invalid arguments for prompt" - }); - - let formatted = format_mcp_error(&error); - - assert_eq!( - formatted, - "{\n \"code\": -32602,\n \"message\": \"MCP error -32602: Invalid arguments for prompt\"\n}" - ); - } - - #[test] - fn test_format_mcp_error_non_mcp_error() { - let error = json!({ - "error": "Unknown error occurred" - }); - - let formatted = format_mcp_error(&error); - - // Should pretty-print the entire error - assert_eq!(formatted, "{\n \"error\": \"Unknown error occurred\"\n}"); - } - - #[test] - fn test_format_mcp_error_empty_message() { - let error = json!({ - "code": -32602, - "message": "" - }); - - let formatted = format_mcp_error(&error); - - assert_eq!(formatted, "{\n \"code\": -32602,\n \"message\": \"\"\n}"); - } - - #[test] - fn test_format_mcp_error_missing_message() { - let error = json!({ - "code": -32602 - }); - - let formatted = format_mcp_error(&error); - - assert_eq!(formatted, "{\n \"code\": -32602\n}"); - } - - #[test] - fn test_format_mcp_error_malformed_nested_json() { - let error = json!({ - "code": -32602, - "message": "MCP error -32602: Invalid arguments for prompt: [{\n \"code\": \"invalid_type\",\n \"expected\": \"object\",\n \"received\": \"undefined\",\n \"path\": [],\n \"message\": \"Required\"\n" - }); - - let formatted = format_mcp_error(&error); - - // Should return the pretty-printed JSON since the nested JSON is malformed - assert_eq!( - formatted, - "{\n \"code\": -32602,\n \"message\": \"MCP error -32602: Invalid arguments for prompt: [{\\n \\\"code\\\": \\\"invalid_type\\\",\\n \\\"expected\\\": \\\"object\\\",\\n \\\"received\\\": \\\"undefined\\\",\\n \\\"path\\\": [],\\n \\\"message\\\": \\\"Required\\\"\\n\"\n}" - ); - } -} diff --git a/crates/chat-cli/src/cli/chat/input_source.rs b/crates/chat-cli/src/cli/chat/input_source.rs index 5d88abf6f3..0c0830852c 100644 --- a/crates/chat-cli/src/cli/chat/input_source.rs +++ b/crates/chat-cli/src/cli/chat/input_source.rs @@ -31,11 +31,33 @@ mod inner { } } +impl Drop for InputSource { + fn drop(&mut self) { + self.save_history().unwrap(); + } +} impl InputSource { pub fn new(os: &Os, sender: PromptQuerySender, receiver: PromptQueryResponseReceiver) -> Result { Ok(Self(inner::Inner::Readline(rl(os, sender, receiver)?))) } + /// Save history to file + pub fn save_history(&mut self) -> Result<()> { + if let inner::Inner::Readline(rl) = &mut self.0 { + if let Some(helper) = rl.helper() { + let history_path = helper.get_history_path(); + + // Create directory if it doesn't exist + if let Some(parent) = history_path.parent() { + std::fs::create_dir_all(parent)?; + } + + rl.append_history(&history_path)?; + } + } + Ok(()) + } + #[cfg(unix)] pub fn put_skim_command_selector( &mut self, @@ -78,12 +100,9 @@ impl InputSource { let curr_line = rl.readline(prompt); match curr_line { Ok(line) => { - let _ = rl.add_history_entry(line.as_str()); - - if let Some(helper) = rl.helper_mut() { - helper.update_hinter_history(&line); + if Self::should_append_history(&line) { + let _ = rl.add_history_entry(line.as_str()); } - Ok(Some(line)) }, Err(ReadlineError::Interrupted | ReadlineError::Eof) => Ok(None), @@ -97,6 +116,18 @@ impl InputSource { } } + fn should_append_history(line: &str) -> bool { + let trimmed = line.trim().to_lowercase(); + if trimmed.is_empty() { + return false; + } + + if matches!(trimmed.as_str(), "y" | "n" | "t") { + return false; + } + true + } + // We're keeping this method for potential future use #[allow(dead_code)] pub fn set_buffer(&mut self, content: &str) { diff --git a/crates/chat-cli/src/cli/chat/line_tracker.rs b/crates/chat-cli/src/cli/chat/line_tracker.rs index 80f640ecd9..1717d16fe7 100644 --- a/crates/chat-cli/src/cli/chat/line_tracker.rs +++ b/crates/chat-cli/src/cli/chat/line_tracker.rs @@ -13,6 +13,10 @@ pub struct FileLineTracker { pub before_fswrite_lines: usize, /// Line count after `fs_write` executes pub after_fswrite_lines: usize, + /// Lines added by agent in the current operation + pub lines_added_by_agent: usize, + /// Lines removed by agent in the current operation + pub lines_removed_by_agent: usize, /// Whether or not this is the first `fs_write` invocation pub is_first_write: bool, } @@ -23,6 +27,8 @@ impl Default for FileLineTracker { prev_fswrite_lines: 0, before_fswrite_lines: 0, after_fswrite_lines: 0, + lines_added_by_agent: 0, + lines_removed_by_agent: 0, is_first_write: true, } } @@ -34,7 +40,6 @@ impl FileLineTracker { } pub fn lines_by_agent(&self) -> isize { - let lines = (self.after_fswrite_lines as isize) - (self.before_fswrite_lines as isize); - lines.abs() + (self.lines_added_by_agent + self.lines_removed_by_agent) as isize } } diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index c9b6b058aa..fcdb8b30ef 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -2,16 +2,17 @@ pub mod cli; mod consts; pub mod context; mod conversation; -mod error_formatter; mod input_source; mod message; mod parse; use std::path::MAIN_SEPARATOR; +pub mod checkpoint; mod line_tracker; mod parser; mod prompt; mod prompt_parser; -mod server_messenger; +pub mod server_messenger; +use crate::cli::chat::checkpoint::CHECKPOINT_MESSAGE_MAX_LENGTH; #[cfg(unix)] mod skim_integration; mod token_counter; @@ -40,9 +41,12 @@ use clap::{ Args, CommandFactory, Parser, + ValueEnum, }; use cli::compact::CompactStrategy; +use cli::hooks::ToolContext; use cli::model::{ + find_model, get_available_models, select_model, }; @@ -83,6 +87,7 @@ use parser::{ SendMessageStream, }; use regex::Regex; +use rmcp::model::PromptMessage; use spinners::{ Spinner, Spinners, @@ -139,9 +144,12 @@ use crate::auth::AuthError; use crate::auth::builder_id::is_idc_user; use crate::cli::TodoListState; use crate::cli::agent::Agents; +use crate::cli::chat::checkpoint::{ + CheckpointManager, + truncate_message, +}; use crate::cli::chat::cli::SlashCommand; use crate::cli::chat::cli::editor::open_editor; -use crate::cli::chat::cli::model::find_model; use crate::cli::chat::cli::prompts::{ GetPromptError, PromptsSubcommand, @@ -149,7 +157,6 @@ use crate::cli::chat::cli::prompts::{ use crate::cli::chat::message::UserMessage; use crate::cli::chat::util::sanitize_unicode_tags; use crate::database::settings::Setting; -use crate::mcp_client::Prompt; use crate::os::Os; use crate::telemetry::core::{ AgentConfigInitArgs, @@ -164,9 +171,11 @@ use crate::telemetry::{ TelemetryResult, get_error_reason, }; +use crate::util::directories::get_shadow_repo_dir; use crate::util::{ MCP_SERVER_TOOL_DELIMITER, directories, + ui, }; const LIMIT_REACHED_TEXT: &str = color_print::cstr! { "You've used all your free requests for this month. You have two options: @@ -190,6 +199,16 @@ pub const EXTRA_HELP: &str = color_print::cstr! {" Change using: q settings chat.skimCommandKey x "}; +#[derive(Copy, Clone, Debug, PartialEq, Eq, ValueEnum)] +pub enum WrapMode { + /// Always wrap at terminal width + Always, + /// Never wrap (raw output) + Never, + /// Auto-detect based on output target (default) + Auto, +} + #[derive(Debug, Clone, PartialEq, Eq, Default, Args)] pub struct ChatArgs { /// Resumes the previous conversation from this directory. @@ -213,6 +232,9 @@ pub struct ChatArgs { pub no_interactive: bool, /// The first question to ask pub input: Option, + /// Control line wrapping behavior (default: auto-detect) + #[arg(short = 'w', long, value_enum)] + pub wrap: Option, } impl ChatArgs { @@ -341,7 +363,19 @@ impl ChatArgs { // If modelId is specified, verify it exists before starting the chat // Otherwise, CLI will use a default model when starting chat let (models, default_model_opt) = get_available_models(os).await?; + // Fallback logic: try user's saved default, then system default + let fallback_model_id = || { + if let Some(saved) = os.database.settings.get_string(Setting::ChatDefaultModel) { + find_model(&models, &saved) + .map(|m| m.model_id.clone()) + .or(Some(default_model_opt.model_id.clone())) + } else { + Some(default_model_opt.model_id.clone()) + } + }; + let model_id: Option = if let Some(requested) = self.model.as_ref() { + // CLI argument takes highest priority if let Some(m) = find_model(&models, requested) { Some(m.model_id.clone()) } else { @@ -352,12 +386,26 @@ impl ChatArgs { .join(", "); bail!("Model '{}' does not exist. Available models: {}", requested, available); } - } else if let Some(saved) = os.database.settings.get_string(Setting::ChatDefaultModel) { - find_model(&models, &saved) - .map(|m| m.model_id.clone()) - .or(Some(default_model_opt.model_id.clone())) + } else if let Some(agent_model) = agents.get_active().and_then(|a| a.model.as_ref()) { + // Agent model takes second priority + if let Some(m) = find_model(&models, agent_model) { + Some(m.model_id.clone()) + } else { + let _ = execute!( + stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("WARNING: "), + style::SetForegroundColor(Color::Reset), + style::Print("Agent specifies model '"), + style::SetForegroundColor(Color::Cyan), + style::Print(agent_model), + style::SetForegroundColor(Color::Reset), + style::Print("' which is not available. Falling back to configured defaults.\n"), + ); + fallback_model_id() + } } else { - Some(default_model_opt.model_id.clone()) + fallback_model_id() }; let (prompt_request_sender, prompt_request_receiver) = tokio::sync::broadcast::channel::(5); @@ -389,6 +437,7 @@ impl ChatArgs { tool_config, !self.no_interactive, mcp_enabled, + self.wrap, ) .await? .spawn(os) @@ -409,8 +458,11 @@ const WELCOME_TEXT: &str = color_print::cstr! {" const SMALL_SCREEN_WELCOME_TEXT: &str = color_print::cstr! {"Welcome to Amazon Q!"}; const RESUME_TEXT: &str = color_print::cstr! {"Picking up where we left off..."}; +// Maximum number of times to show the changelog announcement per version +const CHANGELOG_MAX_SHOW_COUNT: i64 = 2; + // Only show the model-related tip for now to make users aware of this feature. -const ROTATING_TIPS: [&str; 18] = [ +const ROTATING_TIPS: [&str; 20] = [ color_print::cstr! {"You can resume the last conversation from your current directory by launching with q chat --resume"}, color_print::cstr! {"Get notified whenever Q CLI finishes responding. @@ -444,6 +496,8 @@ const ROTATING_TIPS: [&str; 18] = [ color_print::cstr! {"Run /prompts to learn how to build & run repeatable workflows"}, color_print::cstr! {"Use /tangent or ctrl + t (customizable) to start isolated conversations ( ↯ ) that don't affect your main chat history"}, color_print::cstr! {"Ask me directly about my capabilities! Try questions like \"What can you do?\" or \"Can you save conversations?\""}, + color_print::cstr! {"Stay up to date with the latest features and improvements! Use /changelog to see what's new in Amazon Q CLI"}, + color_print::cstr! {"Enable workspace checkpoints to snapshot & restore changes. Just run q settings chat.enableCheckpoint true"}, ]; const GREETING_BREAK_POINT: usize = 80; @@ -594,10 +648,11 @@ pub struct ChatSession { /// Any failed requests that could be useful for error report/debugging failed_request_ids: Vec, /// Pending prompts to be sent - pending_prompts: VecDeque, + pending_prompts: VecDeque, interactive: bool, inner: Option, ctrlc_rx: broadcast::Receiver<()>, + wrap: Option, } impl ChatSession { @@ -617,6 +672,7 @@ impl ChatSession { tool_config: HashMap, interactive: bool, mcp_enabled: bool, + wrap: Option, ) -> Result { // Reload prior conversation let mut existing_conversation = false; @@ -708,6 +764,7 @@ impl ChatSession { interactive, inner: Some(ChatState::default()), ctrlc_rx, + wrap, }) } @@ -1066,6 +1123,34 @@ impl ChatSession { Ok(()) } + + async fn show_changelog_announcement(&mut self, os: &mut Os) -> Result<()> { + let current_version = env!("CARGO_PKG_VERSION"); + let last_version = os.database.get_changelog_last_version()?; + let show_count = os.database.get_changelog_show_count()?.unwrap_or(0); + + // Check if version changed or if we haven't shown it max times yet + let should_show = match &last_version { + Some(last) if last == current_version => show_count < CHANGELOG_MAX_SHOW_COUNT, + _ => true, // New version or no previous version + }; + + if should_show { + // Use the shared rendering function + ui::render_changelog_content(&mut self.stderr)?; + + // Update the database entries + os.database.set_changelog_last_version(current_version)?; + let new_count = if last_version.as_deref() == Some(current_version) { + show_count + 1 + } else { + 1 + }; + os.database.set_changelog_show_count(new_count)?; + } + + Ok(()) + } } impl Drop for ChatSession { @@ -1212,6 +1297,9 @@ impl ChatSession { execute!(self.stderr, style::Print("\n"), style::SetForegroundColor(Color::Reset))?; } + // Check if we should show the whats-new announcement + self.show_changelog_announcement(os).await?; + if self.all_tools_trusted() { queue!( self.stderr, @@ -1242,6 +1330,38 @@ impl ChatSession { } } + // Initialize capturing if possible + if os + .database + .settings + .get_bool(Setting::EnabledCheckpoint) + .unwrap_or(false) + { + let path = get_shadow_repo_dir(os, self.conversation.conversation_id().to_string())?; + let start = std::time::Instant::now(); + let checkpoint_manager = match CheckpointManager::auto_init(os, &path, self.conversation.history()).await { + Ok(manager) => { + execute!( + self.stderr, + style::Print( + format!( + "📷 Checkpoints are enabled! (took {:.2}s)\n\n", + start.elapsed().as_secs_f32() + ) + .blue() + .bold() + ) + )?; + Some(manager) + }, + Err(e) => { + execute!(self.stderr, style::Print(format!("{e}\n\n").blue()))?; + None + }, + }; + self.conversation.checkpoint_manager = checkpoint_manager; + } + if let Some(user_input) = self.initial_input.take() { self.inner = Some(ChatState::HandleInput { input: user_input }); } @@ -1840,7 +1960,7 @@ impl ChatSession { style::SetForegroundColor(Color::Reset), style::SetAttribute(Attribute::Reset) )?; - let prompt = self.generate_tool_trust_prompt(); + let prompt = self.generate_tool_trust_prompt(os).await; let user_input = match self.read_user_input(&prompt, false) { Some(input) => input, None => return Ok(ChatState::Exit), @@ -1968,7 +2088,7 @@ impl ChatSession { name: prompt_name, arguments, }; - return subcommand.execute(self).await; + return subcommand.execute(os, self).await; } else if let Some(command) = input.strip_prefix("!") { // Use platform-appropriate shell let result = if cfg!(target_os = "windows") { @@ -2003,6 +2123,23 @@ impl ChatSession { skip_printing_tools: false, }) } else { + // Track the message for checkpoint descriptions, but only if not already set + // This prevents tool approval responses (y/n/t) from overwriting the original message + if os + .database + .settings + .get_bool(Setting::EnabledCheckpoint) + .unwrap_or(false) + && !self.conversation.is_in_tangent_mode() + { + if let Some(manager) = self.conversation.checkpoint_manager.as_mut() { + if !manager.message_locked && self.pending_tool_index.is_none() { + manager.pending_user_message = Some(user_input.clone()); + manager.message_locked = true; + } + } + } + // Check for a pending tool approval if let Some(index) = self.pending_tool_index { let is_trust = ["t", "T"].contains(&input); @@ -2182,6 +2319,7 @@ impl ChatSession { }); } + // All tools are allowed now // Execute the requested tools. let mut tool_results = vec![]; let mut image_blocks: Vec = Vec::new(); @@ -2225,6 +2363,74 @@ impl ChatSession { } execute!(self.stdout, style::Print("\n"))?; + // Handle checkpoint after tool execution - store tag for later display + let checkpoint_tag: Option = { + let enabled = os + .database + .settings + .get_bool(Setting::EnabledCheckpoint) + .unwrap_or(false) + && !self.conversation.is_in_tangent_mode(); + if invoke_result.is_err() || !enabled { + None + } + // Take manager out temporarily to avoid borrow conflicts + else if let Some(mut manager) = self.conversation.checkpoint_manager.take() { + // Check if there are uncommitted changes + let has_changes = match manager.has_changes() { + Ok(b) => b, + Err(e) => { + execute!( + self.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print(format!("Could not check if uncommitted changes exist: {e}\n")), + style::Print("Saving anyways...\n"), + style::SetForegroundColor(Color::Reset), + )?; + true + }, + }; + let tag = if has_changes { + // Generate tag for this tool use + let tag = format!("{}.{}", manager.current_turn + 1, manager.tools_in_turn + 1); + + // Get tool summary for commit message + let is_fs_read = matches!(&tool.tool, Tool::FsRead(_)); + let description = if is_fs_read { + "External edits detected (likely manual change)".to_string() + } else { + match tool.tool.get_summary() { + Some(summary) => summary, + None => tool.tool.display_name(), + } + }; + + // Create checkpoint + if let Err(e) = manager.create_checkpoint( + &tag, + &description, + &self.conversation.history().clone(), + false, + Some(tool.name.clone()), + ) { + debug!("Failed to create tool checkpoint: {}", e); + None + } else { + manager.tools_in_turn += 1; + Some(tag) + } + } else { + None + }; + + // Put manager back + self.conversation.checkpoint_manager = Some(manager); + tag + } else { + None + } + }; + let tool_end_time = Instant::now(); let tool_time = tool_end_time.duration_since(tool_start); tool_telemetry = tool_telemetry.and_modify(|ev| { @@ -2267,8 +2473,18 @@ impl ChatSession { style::SetAttribute(Attribute::Bold), style::Print(format!(" ● Completed in {}s", tool_time)), style::SetForegroundColor(Color::Reset), - style::Print("\n\n"), )?; + if let Some(tag) = checkpoint_tag { + execute!( + self.stdout, + style::SetForegroundColor(Color::Blue), + style::SetAttribute(Attribute::Bold), + style::Print(format!(" [{tag}]")), + style::SetForegroundColor(Color::Reset), + style::SetAttribute(Attribute::Reset), + )?; + } + execute!(self.stdout, style::Print("\n\n"))?; tool_telemetry = tool_telemetry.and_modify(|ev| ev.is_success = Some(true)); if let Tool::Custom(_) = &tool.tool { @@ -2347,6 +2563,50 @@ impl ChatSession { } } + // Run PostToolUse hooks for all executed tools after we have the tool_results + if let Some(cm) = self.conversation.context_manager.as_mut() { + for result in &tool_results { + if let Some(tool) = self.tool_uses.iter().find(|t| t.id == result.tool_use_id) { + let content: Vec = result + .content + .iter() + .map(|block| match block { + ToolUseResultBlock::Text(text) => serde_json::Value::String(text.clone()), + ToolUseResultBlock::Json(json) => json.clone(), + }) + .collect(); + + let tool_response = match result.status { + ToolResultStatus::Success => serde_json::json!({"success": true, "result": content}), + ToolResultStatus::Error => serde_json::json!({"success": false, "error": content}), + }; + + let tool_context = ToolContext { + tool_name: match &tool.tool { + Tool::Custom(custom_tool) => custom_tool.namespaced_tool_name(), /* for MCP tool, pass MCP name to the hook */ + _ => tool.name.clone(), + }, + tool_input: tool.tool_input.clone(), + tool_response: Some(tool_response), + }; + + // Here is how we handle postToolUse output: + // Exit code is 0: nothing. stdout is not shown to user. We don't support processing the PostToolUse + // hook output yet. Exit code is non-zero: display an error to user (already + // taken care of by the ContextManager.run_hooks) + let _ = cm + .run_hooks( + crate::cli::agent::hook::HookTrigger::PostToolUse, + &mut std::io::stderr(), + os, + None, + Some(tool_context), + ) + .await; + } + } + } + if !image_blocks.is_empty() { let images = image_blocks.into_iter().map(|(block, _)| block).collect(); self.conversation.add_tool_results_with_images(tool_results, images); @@ -2396,8 +2656,20 @@ impl ChatSession { let mut buf = String::new(); let mut offset = 0; let mut ended = false; + let terminal_width = match self.wrap { + Some(WrapMode::Never) => None, + Some(WrapMode::Always) => Some(self.terminal_width()), + Some(WrapMode::Auto) | None => { + if std::io::stdout().is_terminal() { + Some(self.terminal_width()) + } else { + None + } + }, + }; + let mut state = ParseState::new( - Some(self.terminal_width()), + terminal_width, os.database.settings.get_bool(Setting::ChatDisableMarkdownRendering), ); let mut response_prefix_printed = false; @@ -2557,6 +2829,56 @@ impl ChatSession { .await?, )); }, + RecvErrorKind::ToolValidationError { + tool_use_id, + name, + message, + error_message, + } => { + self.send_chat_telemetry( + os, + TelemetryResult::Failed, + Some(reason), + Some(reason_desc), + status_code, + false, // We retry the request, so don't end the current turn yet. + ) + .await; + + error!( + recv_error.request_metadata.request_id, + tool_use_id, name, error_message, "Tool validation failed" + ); + self.conversation + .push_assistant_message(os, *message, Some(recv_error.request_metadata)); + let tool_results = vec![ToolUseResult { + tool_use_id, + content: vec![ToolUseResultBlock::Text(format!( + "Tool validation failed: {}. Please ensure tool arguments are provided as a valid JSON object.", + error_message + ))], + status: ToolResultStatus::Error, + }]; + // User hint of what happened + let _ = queue!( + self.stdout, + style::Print("\n\n"), + style::SetForegroundColor(Color::Yellow), + style::Print(format!( + "Tool validation failed: {}\n Retrying the request...", + error_message + )), + style::ResetColor, + style::Print("\n"), + ); + self.conversation.add_tool_results(tool_results); + self.send_tool_use_telemetry(os).await; + return Ok(ChatState::HandleResponseStream( + self.conversation + .as_sendable_conversation_state(os, &mut self.stderr, false) + .await?, + )); + }, _ => { self.send_chat_telemetry( os, @@ -2661,6 +2983,64 @@ impl ChatSession { self.pending_tool_index = None; self.tool_turn_start_time = None; + // Create turn checkpoint if tools were used + if os + .database + .settings + .get_bool(Setting::EnabledCheckpoint) + .unwrap_or(false) + && !self.conversation.is_in_tangent_mode() + { + if let Some(mut manager) = self.conversation.checkpoint_manager.take() { + if manager.tools_in_turn > 0 { + // Increment turn counter + manager.current_turn += 1; + + // Get user message for description + let description = manager.pending_user_message.take().map_or_else( + || "Turn completed".to_string(), + |msg| truncate_message(&msg, CHECKPOINT_MESSAGE_MAX_LENGTH), + ); + + // Create turn checkpoint + let tag = manager.current_turn.to_string(); + if let Err(e) = manager.create_checkpoint( + &tag, + &description, + &self.conversation.history().clone(), + true, + None, + ) { + execute!( + self.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print(format!("⚠️ Could not create automatic checkpoint: {}\n\n", e)), + style::SetForegroundColor(Color::Reset), + )?; + } else { + execute!( + self.stderr, + style::SetForegroundColor(Color::Blue), + style::SetAttribute(Attribute::Bold), + style::Print(format!("✓ Created checkpoint {}\n\n", tag)), + style::SetForegroundColor(Color::Reset), + style::SetAttribute(Attribute::Reset), + )?; + } + + // Reset for next turn + manager.tools_in_turn = 0; + manager.message_locked = false; // Unlock for next turn + } else { + // Clear pending message even if no tools were used + manager.pending_user_message = None; + } + + // Put manager back + self.conversation.checkpoint_manager = Some(manager); + } + } + self.send_chat_telemetry(os, TelemetryResult::Succeeded, None, None, None, true) .await; @@ -2670,6 +3050,8 @@ impl ChatSession { } } + // Validate the tool use request from LLM, including basic checks like fs_read file should exist, as + // well as user-defined preToolUse hook check. async fn validate_tools(&mut self, os: &Os, tool_uses: Vec) -> Result { let conv_id = self.conversation.conversation_id().to_owned(); debug!(?tool_uses, "Validating tool uses"); @@ -2679,6 +3061,7 @@ impl ChatSession { for tool_use in tool_uses { let tool_use_id = tool_use.id.clone(); let tool_use_name = tool_use.name.clone(); + let tool_input = tool_use.args.clone(); let mut tool_telemetry = ToolUseEventBuilder::new( conv_id.clone(), tool_use.id.clone(), @@ -2687,7 +3070,7 @@ impl ChatSession { .set_tool_use_id(tool_use_id.clone()) .set_tool_name(tool_use.name.clone()) .utterance_id(self.conversation.message_id().map(|s| s.to_string())); - match self.conversation.tool_manager.get_tool_from_tool_use(tool_use) { + match self.conversation.tool_manager.get_tool_from_tool_use(tool_use).await { Ok(mut tool) => { // Apply non-Q-generated context to tools self.contextualize_tool(&mut tool); @@ -2700,6 +3083,7 @@ impl ChatSession { name: tool_use_name, tool, accepted: false, + tool_input, }); }, Err(err) => { @@ -2771,9 +3155,81 @@ impl ChatSession { )); } + // Execute PreToolUse hooks for all validated tools + // The mental model is preToolHook is like validate tools, but its behavior can be customized by + // user Note that after preTookUse hook, user can still reject the took run + if let Some(cm) = self.conversation.context_manager.as_mut() { + for tool in &queued_tools { + let tool_context = ToolContext { + tool_name: match &tool.tool { + Tool::Custom(custom_tool) => custom_tool.namespaced_tool_name(), // for MCP tool, pass MCP + // name to the hook + _ => tool.name.clone(), + }, + tool_input: tool.tool_input.clone(), + tool_response: None, + }; + + let hook_results = cm + .run_hooks( + crate::cli::agent::hook::HookTrigger::PreToolUse, + &mut std::io::stderr(), + os, + None, // prompt + Some(tool_context), + ) + .await?; + + // Here is how we handle the preToolUse hook output: + // Exit code is 0: nothing. stdout is not shown to user. + // Exit code is 2: block the tool use. return stderr to LLM. show warning to user + // Other error: show warning to user. + + // Check for exit code 2 and add to tool_results + for (_, (exit_code, output)) in &hook_results { + if *exit_code == 2 { + tool_results.push(ToolUseResult { + tool_use_id: tool.id.clone(), + content: vec![ToolUseResultBlock::Text(format!( + "PreToolHook blocked the tool execution: {}", + output + ))], + status: ToolResultStatus::Error, + }); + } + } + } + } + + // If we have any hook validation errors, return them immediately to the model + if !tool_results.is_empty() { + debug!(?tool_results, "Error found in PreToolUse hooks"); + for tool_result in &tool_results { + for block in &tool_result.content { + if let ToolUseResultBlock::Text(content) = block { + queue!( + self.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Red), + style::Print(format!("{}\n", content)), + style::SetForegroundColor(Color::Reset), + )?; + } + } + } + + self.conversation.add_tool_results(tool_results); + return Ok(ChatState::HandleResponseStream( + self.conversation + .as_sendable_conversation_state(os, &mut self.stderr, false) + .await?, + )); + } + self.tool_uses = queued_tools; self.pending_tool_index = Some(0); self.tool_turn_start_time = Some(Instant::now()); + Ok(ChatState::ExecuteTools) } @@ -2848,7 +3304,7 @@ impl ChatSession { style::SetForegroundColor(Color::Reset), style::Print(" from mcp server "), style::SetForegroundColor(Color::Magenta), - style::Print(tool.client.get_server_name()), + style::Print(&tool.server_name), style::SetForegroundColor(Color::Reset), )?; } @@ -2902,11 +3358,25 @@ impl ChatSession { } /// Helper function to generate a prompt based on the current context - fn generate_tool_trust_prompt(&mut self) -> String { + async fn generate_tool_trust_prompt(&mut self, os: &Os) -> String { let profile = self.conversation.current_profile().map(|s| s.to_string()); let all_trusted = self.all_tools_trusted(); let tangent_mode = self.conversation.is_in_tangent_mode(); - prompt::generate_prompt(profile.as_deref(), all_trusted, tangent_mode) + + // Check if context usage indicator is enabled + let usage_percentage = if os + .database + .settings + .get_bool(crate::database::settings::Setting::EnabledContextUsageIndicator) + .unwrap_or(false) + { + use crate::cli::chat::cli::usage::get_total_usage_percentage; + get_total_usage_percentage(self, os).await.ok() + } else { + None + }; + + prompt::generate_prompt(profile.as_deref(), all_trusted, tangent_mode, usage_percentage) } async fn send_tool_use_telemetry(&mut self, os: &Os) { @@ -3317,6 +3787,7 @@ mod tests { tool_config, true, false, + None, ) .await .unwrap() @@ -3459,6 +3930,7 @@ mod tests { tool_config, true, false, + None, ) .await .unwrap() @@ -3556,6 +4028,7 @@ mod tests { tool_config, true, false, + None, ) .await .unwrap() @@ -3631,6 +4104,7 @@ mod tests { tool_config, true, false, + None, ) .await .unwrap() @@ -3682,6 +4156,7 @@ mod tests { tool_config, true, false, + None, ) .await .unwrap() @@ -3690,6 +4165,252 @@ mod tests { .unwrap(); } + // Integration test for PreToolUse hook functionality. + // + // In this integration test we create a preToolUse hook that logs tool info into a file + // and we run fs_read and verify the log is generated with the correct ToolContext data. + #[tokio::test] + async fn test_tool_hook_integration() { + use std::collections::HashMap; + + use crate::cli::agent::hook::{ + Hook, + HookTrigger, + }; + + let mut os = Os::new().await.unwrap(); + os.client.set_mock_output(serde_json::json!([ + [ + "I'll read that file for you", + { + "tool_use_id": "1", + "name": "fs_read", + "args": { + "operations": [ + { + "mode": "Line", + "path": "/test.txt", + "start_line": 1, + "end_line": 3 + } + ] + } + } + ], + [ + "Here's the file content!", + ], + ])); + + // Create test file + os.fs.write("/test.txt", "line1\nline2\nline3\n").await.unwrap(); + + // Create agent with PreToolUse and PostToolUse hooks + let mut agents = Agents::default(); + let mut hooks = HashMap::new(); + + // Get the real path in the temp directory for the hooks to write to + let pre_hook_log_path = os.fs.chroot_path_str("/pre-hook-test.log"); + let post_hook_log_path = os.fs.chroot_path_str("/post-hook-test.log"); + let pre_hook_command = format!("cat > {}", pre_hook_log_path); + let post_hook_command = format!("cat > {}", post_hook_log_path); + + hooks.insert(HookTrigger::PreToolUse, vec![Hook { + command: pre_hook_command, + timeout_ms: 5000, + max_output_size: 1024, + cache_ttl_seconds: 0, + matcher: Some("fs_*".to_string()), // Match fs_read, fs_write, etc. + source: crate::cli::agent::hook::Source::Agent, + }]); + + hooks.insert(HookTrigger::PostToolUse, vec![Hook { + command: post_hook_command, + timeout_ms: 5000, + max_output_size: 1024, + cache_ttl_seconds: 0, + matcher: Some("fs_*".to_string()), // Match fs_read, fs_write, etc. + source: crate::cli::agent::hook::Source::Agent, + }]); + + let agent = Agent { + name: "TestAgent".to_string(), + hooks, + ..Default::default() + }; + agents.agents.insert("TestAgent".to_string(), agent); + agents.switch("TestAgent").expect("Failed to switch agent"); + + let tool_manager = ToolManager::default(); + let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) + .expect("Tools failed to load"); + + // Test that PreToolUse hook runs + ChatSession::new( + &mut os, + std::io::stdout(), + std::io::stderr(), + "fake_conv_id", + agents, + None, // No initial input + InputSource::new_mock(vec![ + "read /test.txt".to_string(), + "y".to_string(), // Accept tool execution + "exit".to_string(), + ]), + false, + || Some(80), + tool_manager, + None, + tool_config, + true, + false, + None, + ) + .await + .unwrap() + .spawn(&mut os) + .await + .unwrap(); + + // Verify the PreToolUse hook was called + if let Ok(pre_log_content) = os.fs.read_to_string("/pre-hook-test.log").await { + let pre_hook_data: serde_json::Value = + serde_json::from_str(&pre_log_content).expect("PreToolUse hook output should be valid JSON"); + + assert_eq!(pre_hook_data["hook_event_name"], "preToolUse"); + assert_eq!(pre_hook_data["tool_name"], "fs_read"); + assert_eq!(pre_hook_data["tool_response"], serde_json::Value::Null); + + let tool_input = &pre_hook_data["tool_input"]; + assert!(tool_input["operations"].is_array()); + + println!("✓ PreToolUse hook validation passed: {}", pre_log_content); + } else { + panic!("PreToolUse hook log file not found - hook may not have been called"); + } + + // Verify the PostToolUse hook was called + if let Ok(post_log_content) = os.fs.read_to_string("/post-hook-test.log").await { + let post_hook_data: serde_json::Value = + serde_json::from_str(&post_log_content).expect("PostToolUse hook output should be valid JSON"); + + assert_eq!(post_hook_data["hook_event_name"], "postToolUse"); + assert_eq!(post_hook_data["tool_name"], "fs_read"); + + // Validate tool_response structure for successful execution + let tool_response = &post_hook_data["tool_response"]; + assert_eq!(tool_response["success"], true); + assert!(tool_response["result"].is_array()); + + let result_blocks = tool_response["result"].as_array().unwrap(); + assert!(!result_blocks.is_empty()); + let content = result_blocks[0].as_str().unwrap(); + assert!(content.contains("line1\nline2\nline3")); + + println!("✓ PostToolUse hook validation passed: {}", post_log_content); + } else { + panic!("PostToolUse hook log file not found - hook may not have been called"); + } + } + + #[tokio::test] + async fn test_pretool_hook_blocking_integration() { + use std::collections::HashMap; + + use crate::cli::agent::hook::{ + Hook, + HookTrigger, + }; + + let mut os = Os::new().await.unwrap(); + + // Create a test file to read + os.fs.write("/sensitive.txt", "classified information").await.unwrap(); + + // Mock LLM responses: first tries fs_read, gets blocked, then responds to error + os.client.set_mock_output(serde_json::json!([ + [ + "I'll read that file for you", + { + "tool_use_id": "1", + "name": "fs_read", + "args": { + "operations": [ + { + "mode": "Line", + "path": "/sensitive.txt" + } + ] + } + } + ], + [ + "I understand the security policy blocked access to that file.", + ], + ])); + + // Create agent with blocking PreToolUse hook + let mut agents = Agents::default(); + let mut hooks = HashMap::new(); + + // Create a hook that blocks fs_read of sensitive files with exit code 2 + #[cfg(unix)] + let hook_command = "echo 'Security policy violation: cannot read sensitive files' >&2; exit 2"; + #[cfg(windows)] + let hook_command = "echo Security policy violation: cannot read sensitive files 1>&2 & exit /b 2"; + + hooks.insert(HookTrigger::PreToolUse, vec![Hook { + command: hook_command.to_string(), + timeout_ms: 5000, + max_output_size: 1024, + cache_ttl_seconds: 0, + matcher: Some("fs_read".to_string()), + source: crate::cli::agent::hook::Source::Agent, + }]); + + let agent = Agent { + name: "SecurityAgent".to_string(), + hooks, + ..Default::default() + }; + agents.agents.insert("SecurityAgent".to_string(), agent); + agents.switch("SecurityAgent").expect("Failed to switch agent"); + + let tool_manager = ToolManager::default(); + let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) + .expect("Tools failed to load"); + + // Run chat session - hook should block tool execution + let result = ChatSession::new( + &mut os, + std::io::stdout(), + std::io::stderr(), + "test_conv_id", + agents, + None, + InputSource::new_mock(vec!["read /sensitive.txt".to_string(), "exit".to_string()]), + false, + || Some(80), + tool_manager, + None, + tool_config, + true, + false, + None, + ) + .await + .unwrap() + .spawn(&mut os) + .await; + + // The session should complete successfully (hook blocks tool but doesn't crash) + assert!( + result.is_ok(), + "Chat session should complete successfully even when hook blocks tool" + ); + } + #[test] fn test_does_input_reference_file() { let tests = &[ diff --git a/crates/chat-cli/src/cli/chat/parse.rs b/crates/chat-cli/src/cli/chat/parse.rs index 72f0ab94c2..ed30f3944b 100644 --- a/crates/chat-cli/src/cli/chat/parse.rs +++ b/crates/chat-cli/src/cli/chat/parse.rs @@ -81,6 +81,7 @@ impl<'a> ParserError> for Error<'a> { #[derive(Debug)] pub struct ParseState { + pub is_first_line: bool, pub terminal_width: Option, pub markdown_disabled: Option, pub column: usize, @@ -96,6 +97,7 @@ pub struct ParseState { impl ParseState { pub fn new(terminal_width: Option, markdown_disabled: Option) -> Self { Self { + is_first_line: true, terminal_width, markdown_disabled, column: 0, @@ -198,8 +200,17 @@ fn text<'a, 'b>( ) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { move |i| { let content = take_while(1.., |t| AsChar::is_alphanum(t) || "+,.!?\"".contains(t)).parse_next(i)?; - queue_newline_or_advance(&mut o, state, content.width())?; + if state.is_first_line { + state.is_first_line = false; + // The extra space here is reserved for the prompt pointer ("> "). + // Essentially we want the input to wrap as if the prompt pointer is a part of it + // but only display what is received. + queue_newline_or_advance(&mut o, state, content.width() + 2)?; + } else { + queue_newline_or_advance(&mut o, state, content.width())?; + } queue(&mut o, style::Print(content))?; + Ok(()) } } diff --git a/crates/chat-cli/src/cli/chat/parser.rs b/crates/chat-cli/src/cli/chat/parser.rs index 55e57519ba..2e0cdfb03c 100644 --- a/crates/chat-cli/src/cli/chat/parser.rs +++ b/crates/chat-cli/src/cli/chat/parser.rs @@ -91,6 +91,7 @@ impl RecvError { RecvErrorKind::StreamTimeout { .. } => None, RecvErrorKind::UnexpectedToolUseEos { .. } => None, RecvErrorKind::Cancelled => None, + RecvErrorKind::ToolValidationError { .. } => None, } } } @@ -103,6 +104,7 @@ impl ReasonCode for RecvError { RecvErrorKind::StreamTimeout { .. } => "RecvErrorStreamTimeout".to_string(), RecvErrorKind::UnexpectedToolUseEos { .. } => "RecvErrorUnexpectedToolUseEos".to_string(), RecvErrorKind::Cancelled => "Interrupted".to_string(), + RecvErrorKind::ToolValidationError { .. } => "RecvErrorToolValidation".to_string(), } } } @@ -151,6 +153,14 @@ pub enum RecvErrorKind { /// The stream processing task was cancelled #[error("Stream handling was cancelled")] Cancelled, + /// Tool validation failed due to invalid arguments + #[error("Tool validation failed for tool: {} with id: {}", .name, .tool_use_id)] + ToolValidationError { + tool_use_id: String, + name: String, + message: Box, + error_message: String, + }, } /// Represents a response stream from a call to the SendMessage API. @@ -472,7 +482,43 @@ impl ResponseParser { } let args = match serde_json::from_str(&tool_string) { - Ok(args) => args, + Ok(args) => { + // Ensure we have a valid JSON object + match args { + serde_json::Value::Object(_) => args, + _ => { + error!("Received non-object JSON for tool arguments: {:?}", args); + let warning_args = serde_json::Value::Object( + [( + "key".to_string(), + serde_json::Value::String( + "WARNING: the actual tool use arguments were not a valid JSON object".to_string(), + ), + )] + .into_iter() + .collect(), + ); + self.tool_uses.push(AssistantToolUse { + id: id.clone(), + name: name.clone(), + orig_name: name.clone(), + args: warning_args.clone(), + orig_args: warning_args.clone(), + }); + let message = Box::new(AssistantMessage::new_tool_use( + Some(self.message_id.clone()), + std::mem::take(&mut self.assistant_text), + self.tool_uses.clone().into_iter().collect(), + )); + return Err(self.error(RecvErrorKind::ToolValidationError { + tool_use_id: id, + name, + message, + error_message: format!("Expected JSON object, got: {:?}", args), + })); + }, + } + }, Err(err) if !tool_string.is_empty() => { // If we failed deserializing after waiting for a long time, then this is most // likely bedrock responding with a stop event for some reason without actually @@ -753,4 +799,75 @@ mod tests { "assistant text preceding a code reference should be ignored as this indicates licensed code is being returned" ); } + + #[tokio::test] + async fn test_response_parser_avoid_invalid_json() { + let content_to_ignore = "IGNORE ME PLEASE"; + let tool_use_id = "TEST_ID".to_string(); + let tool_name = "execute_bash".to_string(); + let tool_args = serde_json::json!("invalid json").to_string(); + let mut events = vec![ + ChatResponseStream::AssistantResponseEvent { + content: "hi".to_string(), + }, + ChatResponseStream::AssistantResponseEvent { + content: " there".to_string(), + }, + ChatResponseStream::AssistantResponseEvent { + content: content_to_ignore.to_string(), + }, + ChatResponseStream::CodeReferenceEvent(()), + ChatResponseStream::ToolUseEvent { + tool_use_id: tool_use_id.clone(), + name: tool_name.clone(), + input: None, + stop: None, + }, + ChatResponseStream::ToolUseEvent { + tool_use_id: tool_use_id.clone(), + name: tool_name.clone(), + input: Some(tool_args), + stop: None, + }, + ]; + events.reverse(); + let mock = SendMessageOutput::Mock(events); + let mut parser = ResponseParser::new( + mock, + "".to_string(), + None, + 1, + vec![], + mpsc::channel(32).0, + Instant::now(), + SystemTime::now(), + CancellationToken::new(), + Arc::new(Mutex::new(None)), + ); + + let mut output = String::new(); + let mut found_validation_error = false; + for _ in 0..5 { + match parser.recv().await { + Ok(event) => { + output.push_str(&format!("{:?}", event)); + }, + Err(recv_error) => { + if matches!(recv_error.source, RecvErrorKind::ToolValidationError { .. }) { + found_validation_error = true; + } + break; + }, + } + } + + assert!( + !output.contains(content_to_ignore), + "assistant text preceding a code reference should be ignored as this indicates licensed code is being returned" + ); + assert!( + found_validation_error, + "Expected to find tool validation error for non-object JSON" + ); + } } diff --git a/crates/chat-cli/src/cli/chat/prompt.rs b/crates/chat-cli/src/cli/chat/prompt.rs index a2ac52c7db..210a0f635c 100644 --- a/crates/chat-cli/src/cli/chat/prompt.rs +++ b/crates/chat-cli/src/cli/chat/prompt.rs @@ -1,5 +1,6 @@ use std::borrow::Cow; use std::cell::RefCell; +use std::path::PathBuf; use eyre::Result; use rustyline::completion::{ @@ -13,7 +14,10 @@ use rustyline::highlight::{ Highlighter, }; use rustyline::hint::Hinter as RustylineHinter; -use rustyline::history::DefaultHistory; +use rustyline::history::{ + FileHistory, + SearchDirection, +}; use rustyline::validate::{ ValidationContext, ValidationResult, @@ -44,6 +48,7 @@ use super::tool_manager::{ }; use crate::database::settings::Setting; use crate::os::Os; +use crate::util::directories::chat_cli_bash_history_path; pub const COMMANDS: &[&str] = &[ "/clear", @@ -67,6 +72,7 @@ pub const COMMANDS: &[&str] = &[ "/agent rename", "/agent set", "/agent schema", + "/agent generate", "/prompts", "/context", "/context help", @@ -86,6 +92,7 @@ pub const COMMANDS: &[&str] = &[ "/compact", "/compact help", "/usage", + "/changelog", "/save", "/load", "/subscribe", @@ -261,31 +268,26 @@ impl Completer for ChatCompleter { /// Custom hinter that provides shadowtext suggestions pub struct ChatHinter { - /// Command history for providing suggestions based on past commands - history: Vec, /// Whether history-based hints are enabled history_hints_enabled: bool, + history_path: PathBuf, } impl ChatHinter { /// Creates a new ChatHinter instance - pub fn new(history_hints_enabled: bool) -> Self { + pub fn new(history_hints_enabled: bool, history_path: PathBuf) -> Self { Self { - history: Vec::new(), history_hints_enabled, + history_path, } } - /// Updates the history with a new command - pub fn update_history(&mut self, command: &str) { - let command = command.trim(); - if !command.is_empty() && !command.contains('\n') && !command.contains('\r') { - self.history.push(command.to_string()); - } + pub fn get_history_path(&self) -> PathBuf { + self.history_path.clone() } - /// Finds the best hint for the current input - fn find_hint(&self, line: &str) -> Option { + /// Finds the best hint for the current input using rustyline's history + fn find_hint(&self, line: &str, ctx: &Context<'_>) -> Option { // If line is empty, no hint if line.is_empty() { return None; @@ -299,13 +301,20 @@ impl ChatHinter { .map(|cmd| cmd[line.len()..].to_string()); } - // Try to find a hint from history if history hints are enabled + // Try to find a hint from rustyline's history if history hints are enabled if self.history_hints_enabled { - return self.history - .iter() - .rev() // Start from most recent - .find(|cmd| cmd.starts_with(line) && cmd.len() > line.len()) - .map(|cmd| cmd[line.len()..].to_string()); + let history = ctx.history(); + let history_len = history.len(); + if history_len == 0 { + return None; + } + + if let Ok(Some(search_result)) = history.starts_with(line, history_len - 1, SearchDirection::Reverse) { + let entry = search_result.entry.to_string(); + if entry.len() > line.len() { + return Some(entry[line.len()..].to_string()); + } + } } None @@ -315,13 +324,13 @@ impl ChatHinter { impl RustylineHinter for ChatHinter { type Hint = String; - fn hint(&self, line: &str, pos: usize, _ctx: &Context<'_>) -> Option { + fn hint(&self, line: &str, pos: usize, ctx: &Context<'_>) -> Option { // Only provide hints when cursor is at the end of the line if pos < line.len() { return None; } - self.find_hint(line) + self.find_hint(line, ctx) } } @@ -362,9 +371,8 @@ pub struct ChatHelper { } impl ChatHelper { - /// Updates the history of the ChatHinter with a new command - pub fn update_hinter_history(&mut self, command: &str) { - self.hinter.update_history(command); + pub fn get_history_path(&self) -> PathBuf { + self.hinter.get_history_path() } } @@ -400,6 +408,18 @@ impl Highlighter for ChatHelper { result.push_str(&format!("[{}] ", profile).cyan().to_string()); } + // Add percentage part if present (colored by usage level) + if let Some(percentage) = components.usage_percentage { + let colored_percentage = if percentage < 50.0 { + format!("{}% ", percentage as u32).green() + } else if percentage < 90.0 { + format!("{}% ", percentage as u32).yellow() + } else { + format!("{}% ", percentage as u32).red() + }; + result.push_str(&colored_percentage.to_string()); + } + // Add tangent indicator if present (yellow) if components.tangent_mode { result.push_str(&"↯ ".yellow().to_string()); @@ -425,7 +445,7 @@ pub fn rl( os: &Os, sender: PromptQuerySender, receiver: PromptQueryResponseReceiver, -) -> Result> { +) -> Result> { let edit_mode = match os.database.settings.get_string(Setting::ChatEditMode).as_deref() { Some("vi" | "vim") => EditMode::Vi, _ => EditMode::Emacs, @@ -436,40 +456,53 @@ pub fn rl( .edit_mode(edit_mode) .build(); - // Default to disabled if setting doesn't exist let history_hints_enabled = os .database .settings .get_bool(Setting::ChatEnableHistoryHints) .unwrap_or(false); + + let history_path = chat_cli_bash_history_path(os)?; + let h = ChatHelper { completer: ChatCompleter::new(sender, receiver), - hinter: ChatHinter::new(history_hints_enabled), + hinter: ChatHinter::new(history_hints_enabled, history_path), validator: MultiLineValidator, }; let mut rl = Editor::with_config(config)?; rl.set_helper(Some(h)); + // Load history from ~/.aws/amazonq/cli_history + if let Err(e) = rl.load_history(&rl.helper().unwrap().get_history_path()) { + if !matches!(e, ReadlineError::Io(ref io_err) if io_err.kind() == std::io::ErrorKind::NotFound) { + eprintln!("Warning: Failed to load history: {}", e); + } + } + // Add custom keybinding for Alt+Enter to insert a newline rl.bind_sequence( KeyEvent(KeyCode::Enter, Modifiers::ALT), EventHandler::Simple(Cmd::Insert(1, "\n".to_string())), ); - // Add custom keybinding for Ctrl+J to insert a newline + // Add custom keybinding for Ctrl+j to insert a newline rl.bind_sequence( KeyEvent(KeyCode::Char('j'), Modifiers::CTRL), EventHandler::Simple(Cmd::Insert(1, "\n".to_string())), ); - // Add custom keybinding for Ctrl+F to accept hint (like fish shell) + // Add custom keybinding for autocompletion hint acceptance (configurable) + let autocompletion_key_char = match os.database.settings.get_string(Setting::AutocompletionKey) { + Some(key) if key.len() == 1 => key.chars().next().unwrap_or('g'), + _ => 'g', // Default to 'g' if setting is missing or invalid + }; rl.bind_sequence( - KeyEvent(KeyCode::Char('f'), Modifiers::CTRL), + KeyEvent(KeyCode::Char(autocompletion_key_char), Modifiers::CTRL), EventHandler::Simple(Cmd::CompleteHint), ); - // Add custom keybinding for Ctrl+T to toggle tangent mode (configurable) + // Add custom keybinding for Ctrl+t to toggle tangent mode (configurable) let tangent_key_char = match os.database.settings.get_string(Setting::TangentModeKey) { Some(key) if key.len() == 1 => key.chars().next().unwrap_or('t'), _ => 't', // Default to 't' if setting is missing or invalid @@ -486,6 +519,7 @@ pub fn rl( mod tests { use crossterm::style::Stylize; use rustyline::highlight::Highlighter; + use rustyline::history::DefaultHistory; use super::*; @@ -536,7 +570,7 @@ mod tests { let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), + hinter: ChatHinter::new(true, PathBuf::new()), validator: MultiLineValidator, }; @@ -552,7 +586,7 @@ mod tests { let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), + hinter: ChatHinter::new(true, PathBuf::new()), validator: MultiLineValidator, }; @@ -568,7 +602,7 @@ mod tests { let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), + hinter: ChatHinter::new(true, PathBuf::new()), validator: MultiLineValidator, }; @@ -584,7 +618,7 @@ mod tests { let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), + hinter: ChatHinter::new(true, PathBuf::new()), validator: MultiLineValidator, }; @@ -603,7 +637,7 @@ mod tests { let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), + hinter: ChatHinter::new(true, PathBuf::new()), validator: MultiLineValidator, }; @@ -619,7 +653,7 @@ mod tests { let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(1); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), + hinter: ChatHinter::new(true, PathBuf::new()), validator: MultiLineValidator, }; @@ -634,7 +668,7 @@ mod tests { let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(1); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), + hinter: ChatHinter::new(true, PathBuf::new()), validator: MultiLineValidator, }; @@ -649,7 +683,7 @@ mod tests { let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(1); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), + hinter: ChatHinter::new(true, PathBuf::new()), validator: MultiLineValidator, }; @@ -663,7 +697,7 @@ mod tests { #[test] fn test_chat_hinter_command_hint() { - let hinter = ChatHinter::new(true); + let hinter = ChatHinter::new(true, PathBuf::new()); // Test hint for a command let line = "/he"; @@ -693,11 +727,7 @@ mod tests { #[test] fn test_chat_hinter_history_hint_disabled() { - let mut hinter = ChatHinter::new(false); - - // Add some history - hinter.update_history("Hello, world!"); - hinter.update_history("How are you?"); + let hinter = ChatHinter::new(false, PathBuf::new()); // Test hint from history - should be None since history hints are disabled let line = "How"; @@ -708,4 +738,32 @@ mod tests { let hint = hinter.hint(line, pos, &ctx); assert_eq!(hint, None); } + + #[tokio::test] + // If you get a unit test failure for key override, please consider using a new key binding instead. + // The list of reserved keybindings here are the standard in UNIX world so please don't take them + async fn test_no_emacs_keybindings_overridden() { + let (sender, _) = tokio::sync::broadcast::channel::(1); + let (_, receiver) = tokio::sync::broadcast::channel::(1); + + // Create a mock Os for testing + let mock_os = crate::os::Os::new().await.unwrap(); + let mut test_editor = rl(&mock_os, sender, receiver).unwrap(); + + // Reserved Emacs keybindings that should not be overridden + let reserved_keys = ['a', 'e', 'f', 'b', 'k']; + + for &key in &reserved_keys { + let key_event = KeyEvent(KeyCode::Char(key), Modifiers::CTRL); + + // Try to bind and get the previous handler + let previous_handler = test_editor.bind_sequence(key_event, EventHandler::Simple(Cmd::Noop)); + + // If there was a previous handler, it means the key was already bound + // (which could be our custom binding overriding Emacs) + if previous_handler.is_some() { + panic!("Ctrl+{} appears to be overridden (found existing binding)", key); + } + } + } } diff --git a/crates/chat-cli/src/cli/chat/prompt_parser.rs b/crates/chat-cli/src/cli/chat/prompt_parser.rs index daf67a9859..87f0639054 100644 --- a/crates/chat-cli/src/cli/chat/prompt_parser.rs +++ b/crates/chat-cli/src/cli/chat/prompt_parser.rs @@ -6,15 +6,16 @@ pub struct PromptComponents { pub profile: Option, pub warning: bool, pub tangent_mode: bool, + pub usage_percentage: Option, } /// Parse prompt components from a plain text prompt pub fn parse_prompt_components(prompt: &str) -> Option { - // Expected format: "[agent] !> " or "> " or "!> " or "[agent] ↯ > " or "↯ > " or "[agent] ↯ !> " - // etc. + // Expected format: "[agent] 6% !> " or "> " or "!> " or "[agent] ↯ > " or "6% ↯ > " etc. let mut profile = None; let mut warning = false; let mut tangent_mode = false; + let mut usage_percentage = None; let mut remaining = prompt.trim(); // Check for agent pattern [agent] first @@ -28,6 +29,17 @@ pub fn parse_prompt_components(prompt: &str) -> Option { } } + // Check for percentage pattern (e.g., "6% ") + if let Some(percent_pos) = remaining.find('%') { + let before_percent = &remaining[..percent_pos]; + if let Ok(percentage) = before_percent.trim().parse::() { + usage_percentage = Some(percentage); + if let Some(space_after_percent) = remaining[percent_pos..].find(' ') { + remaining = remaining[percent_pos + space_after_percent + 1..].trim_start(); + } + } + } + // Check for tangent mode ↯ first if let Some(after_tangent) = remaining.strip_prefix('↯') { tangent_mode = true; @@ -46,13 +58,19 @@ pub fn parse_prompt_components(prompt: &str) -> Option { profile, warning, tangent_mode, + usage_percentage, }) } else { None } } -pub fn generate_prompt(current_profile: Option<&str>, warning: bool, tangent_mode: bool) -> String { +pub fn generate_prompt( + current_profile: Option<&str>, + warning: bool, + tangent_mode: bool, + usage_percentage: Option, +) -> String { // Generate plain text prompt that will be colored by highlight_prompt let warning_symbol = if warning { "!" } else { "" }; let profile_part = current_profile @@ -60,10 +78,12 @@ pub fn generate_prompt(current_profile: Option<&str>, warning: bool, tangent_mod .map(|p| format!("[{p}] ")) .unwrap_or_default(); + let percentage_part = usage_percentage.map(|p| format!("{:.0}% ", p)).unwrap_or_default(); + if tangent_mode { - format!("{profile_part}↯ {warning_symbol}> ") + format!("{profile_part}{percentage_part}↯ {warning_symbol}> ") } else { - format!("{profile_part}{warning_symbol}> ") + format!("{profile_part}{percentage_part}{warning_symbol}> ") } } @@ -74,26 +94,43 @@ mod tests { #[test] fn test_generate_prompt() { // Test default prompt (no profile) - assert_eq!(generate_prompt(None, false, false), "> "); + assert_eq!(generate_prompt(None, false, false, None), "> "); // Test default prompt with warning - assert_eq!(generate_prompt(None, true, false), "!> "); + assert_eq!(generate_prompt(None, true, false, None), "!> "); // Test tangent mode - assert_eq!(generate_prompt(None, false, true), "↯ > "); + assert_eq!(generate_prompt(None, false, true, None), "↯ > "); // Test tangent mode with warning - assert_eq!(generate_prompt(None, true, true), "↯ !> "); + assert_eq!(generate_prompt(None, true, true, None), "↯ !> "); // Test default profile (should be same as no profile) - assert_eq!(generate_prompt(Some(DEFAULT_AGENT_NAME), false, false), "> "); + assert_eq!(generate_prompt(Some(DEFAULT_AGENT_NAME), false, false, None), "> "); // Test custom profile - assert_eq!(generate_prompt(Some("test-profile"), false, false), "[test-profile] > "); + assert_eq!( + generate_prompt(Some("test-profile"), false, false, None), + "[test-profile] > " + ); // Test custom profile with tangent mode assert_eq!( - generate_prompt(Some("test-profile"), false, true), + generate_prompt(Some("test-profile"), false, true, None), "[test-profile] ↯ > " ); // Test another custom profile with warning - assert_eq!(generate_prompt(Some("dev"), true, false), "[dev] !> "); + assert_eq!(generate_prompt(Some("dev"), true, false, None), "[dev] !> "); // Test custom profile with warning and tangent mode - assert_eq!(generate_prompt(Some("dev"), true, true), "[dev] ↯ !> "); + assert_eq!(generate_prompt(Some("dev"), true, true, None), "[dev] ↯ !> "); + // Test custom profile with usage percentage + assert_eq!( + generate_prompt(Some("rust-agent"), false, false, Some(6.2)), + "[rust-agent] 6% > " + ); + // Test custom profile with usage percentage and warning + assert_eq!( + generate_prompt(Some("rust-agent"), true, false, Some(15.7)), + "[rust-agent] 16% !> " + ); + // Test usage percentage without profile + assert_eq!(generate_prompt(None, false, false, Some(25.3)), "25% > "); + // Test usage percentage with tangent mode + assert_eq!(generate_prompt(None, false, true, Some(8.9)), "9% ↯ > "); } #[test] @@ -103,48 +140,75 @@ mod tests { assert!(components.profile.is_none()); assert!(!components.warning); assert!(!components.tangent_mode); + assert!(components.usage_percentage.is_none()); // Test warning prompt let components = parse_prompt_components("!> ").unwrap(); assert!(components.profile.is_none()); assert!(components.warning); assert!(!components.tangent_mode); + assert!(components.usage_percentage.is_none()); // Test tangent mode let components = parse_prompt_components("↯ > ").unwrap(); assert!(components.profile.is_none()); assert!(!components.warning); assert!(components.tangent_mode); + assert!(components.usage_percentage.is_none()); // Test tangent mode with warning let components = parse_prompt_components("↯ !> ").unwrap(); assert!(components.profile.is_none()); assert!(components.warning); assert!(components.tangent_mode); + assert!(components.usage_percentage.is_none()); // Test profile prompt let components = parse_prompt_components("[test] > ").unwrap(); assert_eq!(components.profile.as_deref(), Some("test")); assert!(!components.warning); assert!(!components.tangent_mode); + assert!(components.usage_percentage.is_none()); // Test profile with warning let components = parse_prompt_components("[dev] !> ").unwrap(); assert_eq!(components.profile.as_deref(), Some("dev")); assert!(components.warning); assert!(!components.tangent_mode); + assert!(components.usage_percentage.is_none()); // Test profile with tangent mode let components = parse_prompt_components("[dev] ↯ > ").unwrap(); assert_eq!(components.profile.as_deref(), Some("dev")); assert!(!components.warning); assert!(components.tangent_mode); + assert!(components.usage_percentage.is_none()); // Test profile with warning and tangent mode let components = parse_prompt_components("[dev] ↯ !> ").unwrap(); assert_eq!(components.profile.as_deref(), Some("dev")); assert!(components.warning); assert!(components.tangent_mode); + assert!(components.usage_percentage.is_none()); + + // Test prompts with percentages + let components = parse_prompt_components("[rust-agent] 6% > ").unwrap(); + assert_eq!(components.profile.as_deref(), Some("rust-agent")); + assert!(!components.warning); + assert!(!components.tangent_mode); + assert_eq!(components.usage_percentage, Some(6.0)); + + let components = parse_prompt_components("25% > ").unwrap(); + assert!(components.profile.is_none()); + assert!(!components.warning); + assert!(!components.tangent_mode); + assert_eq!(components.usage_percentage, Some(25.0)); + + let components = parse_prompt_components("8% ↯ > ").unwrap(); + assert!(components.profile.is_none()); + assert!(!components.warning); + assert!(components.tangent_mode); + assert_eq!(components.usage_percentage, Some(8.0)); // Test invalid prompt assert!(parse_prompt_components("invalid").is_none()); diff --git a/crates/chat-cli/src/cli/chat/server_messenger.rs b/crates/chat-cli/src/cli/chat/server_messenger.rs index aaf685c399..be86d891a9 100644 --- a/crates/chat-cli/src/cli/chat/server_messenger.rs +++ b/crates/chat-cli/src/cli/chat/server_messenger.rs @@ -1,48 +1,58 @@ +use rmcp::model::{ + ListPromptsResult, + ListResourceTemplatesResult, + ListResourcesResult, + ListToolsResult, +}; +use rmcp::{ + Peer, + RoleClient, +}; use tokio::sync::mpsc::{ Receiver, Sender, channel, }; -use crate::mcp_client::{ +use crate::mcp_client::messenger::{ Messenger, MessengerError, - PromptsListResult, - ResourceTemplatesListResult, - ResourcesListResult, - ToolsListResult, + MessengerResult, + Result, }; #[allow(dead_code)] #[derive(Debug)] pub enum UpdateEventMessage { - ToolsListResult { + ListToolsResult { server_name: String, - result: eyre::Result, - pid: Option, + result: Result, + peer: Option>, }, - PromptsListResult { + ListPromptsResult { server_name: String, - result: eyre::Result, - pid: Option, + result: Result, + peer: Option>, }, - ResourcesListResult { + ListResourcesResult { server_name: String, - result: eyre::Result, - pid: Option, + result: Result, + peer: Option>, }, ResourceTemplatesListResult { server_name: String, - result: eyre::Result, - pid: Option, + result: Result, + peer: Option>, + }, + OauthLink { + server_name: String, + link: String, }, InitStart { server_name: String, - pid: Option, }, Deinit { server_name: String, - pid: Option, }, } @@ -64,7 +74,6 @@ impl ServerMessengerBuilder { ServerMessenger { server_name, update_event_sender: self.update_event_sender.clone(), - pid: None, } } } @@ -73,30 +82,37 @@ impl ServerMessengerBuilder { pub struct ServerMessenger { pub server_name: String, pub update_event_sender: Sender, - pub pid: Option, } #[async_trait::async_trait] impl Messenger for ServerMessenger { - async fn send_tools_list_result(&self, result: eyre::Result) -> Result<(), MessengerError> { + async fn send_tools_list_result( + &self, + result: Result, + peer: Option>, + ) -> MessengerResult { Ok(self .update_event_sender - .send(UpdateEventMessage::ToolsListResult { + .send(UpdateEventMessage::ListToolsResult { server_name: self.server_name.clone(), result, - pid: self.pid, + peer, }) .await .map_err(|e| MessengerError::Custom(e.to_string()))?) } - async fn send_prompts_list_result(&self, result: eyre::Result) -> Result<(), MessengerError> { + async fn send_prompts_list_result( + &self, + result: Result, + peer: Option>, + ) -> MessengerResult { Ok(self .update_event_sender - .send(UpdateEventMessage::PromptsListResult { + .send(UpdateEventMessage::ListPromptsResult { server_name: self.server_name.clone(), result, - pid: self.pid, + peer, }) .await .map_err(|e| MessengerError::Custom(e.to_string()))?) @@ -104,14 +120,15 @@ impl Messenger for ServerMessenger { async fn send_resources_list_result( &self, - result: eyre::Result, - ) -> Result<(), MessengerError> { + result: Result, + peer: Option>, + ) -> MessengerResult { Ok(self .update_event_sender - .send(UpdateEventMessage::ResourcesListResult { + .send(UpdateEventMessage::ListResourcesResult { server_name: self.server_name.clone(), result, - pid: self.pid, + peer, }) .await .map_err(|e| MessengerError::Custom(e.to_string()))?) @@ -119,25 +136,36 @@ impl Messenger for ServerMessenger { async fn send_resource_templates_list_result( &self, - result: eyre::Result, - ) -> Result<(), MessengerError> { + result: Result, + peer: Option>, + ) -> MessengerResult { Ok(self .update_event_sender .send(UpdateEventMessage::ResourceTemplatesListResult { server_name: self.server_name.clone(), result, - pid: self.pid, + peer, + }) + .await + .map_err(|e| MessengerError::Custom(e.to_string()))?) + } + + async fn send_oauth_link(&self, link: String) -> MessengerResult { + Ok(self + .update_event_sender + .send(UpdateEventMessage::OauthLink { + server_name: self.server_name.clone(), + link, }) .await .map_err(|e| MessengerError::Custom(e.to_string()))?) } - async fn send_init_msg(&self) -> Result<(), MessengerError> { + async fn send_init_msg(&self) -> MessengerResult { Ok(self .update_event_sender .send(UpdateEventMessage::InitStart { server_name: self.server_name.clone(), - pid: self.pid, }) .await .map_err(|e| MessengerError::Custom(e.to_string()))?) @@ -146,9 +174,8 @@ impl Messenger for ServerMessenger { fn send_deinit_msg(&self) { let sender = self.update_event_sender.clone(); let server_name = self.server_name.clone(); - let pid = self.pid; tokio::spawn(async move { - let _ = sender.send(UpdateEventMessage::Deinit { server_name, pid }).await; + let _ = sender.send(UpdateEventMessage::Deinit { server_name }).await; }); } diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index 3171459915..6ef93bea26 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -32,12 +32,14 @@ use crossterm::{ terminal, }; use eyre::Report; -use futures::{ - StreamExt, - future, - stream, -}; +use futures::future; use regex::Regex; +use rmcp::ServiceError; +use rmcp::model::{ + GetPromptRequestParam, + GetPromptResult, + Prompt, +}; use tokio::signal::ctrl_c; use tokio::sync::{ Mutex, @@ -68,10 +70,7 @@ use crate::cli::chat::server_messenger::{ ServerMessengerBuilder, UpdateEventMessage, }; -use crate::cli::chat::tools::custom_tool::{ - CustomTool, - CustomToolClient, -}; +use crate::cli::chat::tools::custom_tool::CustomTool; use crate::cli::chat::tools::execute::ExecuteCommand; use crate::cli::chat::tools::fs_read::FsRead; use crate::cli::chat::tools::fs_write::FsWrite; @@ -88,10 +87,11 @@ use crate::cli::chat::tools::{ }; use crate::database::Database; use crate::database::settings::Setting; +use crate::mcp_client::messenger::Messenger; use crate::mcp_client::{ - JsonRpcResponse, - Messenger, - PromptGet, + InitializedMcpClient, + InnerService, + McpClientService, }; use crate::os::Os; use crate::telemetry::TelemetryThread; @@ -99,8 +99,7 @@ use crate::util::MCP_SERVER_TOOL_DELIMITER; use crate::util::directories::home_dir; const NAMESPACE_DELIMITER: &str = "___"; -// This applies for both mcp server and tool name since in the end the tool name as seen by the -// model is just {server_name}{NAMESPACE_DELIMITER}{tool_name} +// This applies for both mcp server and tool name const VALID_TOOL_NAME: &str = "^[a-zA-Z][a-zA-Z0-9_]*$"; const SPINNER_CHARS: [char; 10] = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']; @@ -138,6 +137,11 @@ enum LoadingMsg { /// This is sent when all tool initialization is complete or when the application is shutting /// down. Terminate { still_loading: Vec }, + /// Indicates that a server requires user authentication and provides a sign-in link. + /// This message is used to notify the user about authentication requirements for MCP servers + /// that need OAuth or other authentication methods. Contains the server name and the + /// authentication message (typically a URL or instructions). + SignInNotice { name: String }, } /// Used to denote the loading outcome associated with a server. @@ -145,9 +149,26 @@ enum LoadingMsg { /// surface (since we would only want to surface fatal errors in non-interactive mode). #[derive(Clone, Debug)] pub enum LoadingRecord { - Success(String), - Warn(String), - Err(String), + Success(String, String), + Warn(String, String), + Err(String, String), +} + +impl LoadingRecord { + pub fn success(msg: String) -> Self { + let timestamp = chrono::Local::now().format("%Y:%H:%S").to_string(); + LoadingRecord::Success(timestamp, msg) + } + + pub fn warn(msg: String) -> Self { + let timestamp = chrono::Local::now().format("%Y:%H:%S").to_string(); + LoadingRecord::Warn(timestamp, msg) + } + + pub fn err(msg: String) -> Self { + let timestamp = chrono::Local::now().format("%Y:%H:%S").to_string(); + LoadingRecord::Err(timestamp, msg) + } } pub struct ToolManagerBuilder { @@ -160,6 +181,7 @@ pub struct ToolManagerBuilder { has_new_stuff: Arc, mcp_load_record: Arc>>>, new_tool_specs: NewToolSpecs, + pending_clients: Option>>>, is_first_launch: bool, agent: Option>>, } @@ -176,6 +198,7 @@ impl Default for ToolManagerBuilder { has_new_stuff: Default::default(), mcp_load_record: Default::default(), new_tool_specs: Default::default(), + pending_clients: Default::default(), is_first_launch: true, agent: Default::default(), } @@ -196,6 +219,7 @@ impl From<&mut ToolManager> for ToolManagerBuilder { has_new_stuff: value.has_new_stuff.clone(), mcp_load_record: value.mcp_load_record.clone(), new_tool_specs: value.new_tool_specs.clone(), + pending_clients: Some(value.pending_clients.clone()), // if we are getting a builder from an instantiated tool manager this field would be // false is_first_launch: false, @@ -271,8 +295,8 @@ impl ToolManagerBuilder { .collect(); let pre_initialized = enabled_servers - .into_iter() - .filter_map(|(server_name, server_config)| { + .iter() + .filter(|(server_name, _)| { if server_name == "builtin" { let _ = queue!( output, @@ -287,13 +311,26 @@ impl ToolManagerBuilder { style::ResetColor, style::Print(" (it is used to denote native tools)\n") ); - None + false } else { - let custom_tool_client = CustomToolClient::from_config(server_name.clone(), server_config, os); - Some((server_name, custom_tool_client)) + true } }) - .collect::>(); + .collect::>(); + + let mut clients = HashMap::::new(); + let new_tool_specs = self.new_tool_specs; + let has_new_stuff = self.has_new_stuff; + let pending = self.pending_clients.unwrap_or(Arc::new(RwLock::new({ + let mut pending = HashSet::::new(); + pending.extend(pre_initialized.iter().map(|(name, _)| name.clone())); + pending + }))); + let notify = Arc::new(Notify::new()); + let load_record = self.mcp_load_record; + let agent = self.agent.unwrap_or_default(); + let database = os.database.clone(); + let mut messenger_builder = self.messenger_builder.take(); let mut loading_servers = HashMap::::new(); for (server_name, _) in &pre_initialized { @@ -308,16 +345,6 @@ impl ToolManagerBuilder { let (loading_display_task, loading_status_sender) = spawn_display_task(interactive, total, disabled_servers, output); - let mut clients = HashMap::>::new(); - let new_tool_specs = self.new_tool_specs; - let has_new_stuff = self.has_new_stuff; - let pending = Arc::new(RwLock::new(HashSet::::new())); - let notify = Arc::new(Notify::new()); - let load_record = self.mcp_load_record; - let agent = self.agent.unwrap_or_default(); - let database = os.database.clone(); - let mut messenger_builder = self.messenger_builder.take(); - // This is the orchestrator task that serves as a bridge between tool manager and mcp // clients for server initiated async events if let (Some(prompt_list_sender), Some(prompt_list_receiver)) = ( @@ -358,19 +385,29 @@ impl ToolManagerBuilder { debug_assert!(messenger_builder.is_some()); let messenger_builder = messenger_builder.unwrap(); - for (mut name, init_res) in pre_initialized { - let mut messenger = messenger_builder.build_with_name(name.clone()); + let pre_initialized = enabled_servers + .into_iter() + .map(|(server_name, server_config)| { + ( + server_name.clone(), + McpClientService::new( + server_name.clone(), + server_config, + messenger_builder.build_with_name(server_name), + ), + ) + }) + .collect::>(); + + for (mut name, mcp_client) in pre_initialized { + let init_res = mcp_client.init(os).await; match init_res { - Ok(mut client) => { - let pid = client.get_pid(); - messenger.pid = pid; - client.assign_messenger(Box::new(messenger)); - let mut client = Arc::new(client); - while let Some(collided_client) = clients.insert(name.clone(), client) { + Ok(mut running_service) => { + while let Some(collided_service) = clients.insert(name.clone(), running_service) { // to avoid server name collision we are going to circumvent this by // appending the name with 1 name.push('1'); - client = collided_client; + running_service = collided_service; } }, Err(e) => { @@ -379,7 +416,7 @@ impl ToolManagerBuilder { .send_mcp_server_init( &os.database, conversation_id.clone(), - name, + name.clone(), Some(e.to_string()), 0, Some("".to_string()), @@ -388,7 +425,11 @@ impl ToolManagerBuilder { ) .await .ok(); - let _ = messenger.send_tools_list_result(Err(e)).await; + + let temp_messenger = messenger_builder.build_with_name(name); + let _ = temp_messenger + .send_tools_list_result(Err(ServiceError::UnexpectedResponse), None) + .await; }, } } @@ -428,7 +469,7 @@ pub struct PromptBundle { /// The server name from which the prompt is offered / exposed pub server_name: String, /// The prompt get (info with which a prompt is retrieved) cached - pub prompt_get: PromptGet, + pub prompt_get: Prompt, } #[derive(Clone, Debug)] @@ -448,10 +489,11 @@ pub enum PromptQueryResult { /// - `IllegalChar`: The tool name contains characters that are not allowed /// - `EmptyDescription`: The tool description is empty or missing #[allow(dead_code)] -enum OutOfSpecName { +enum ToolValidationViolation { TooLong(String), IllegalChar(String), EmptyDescription(String), + DescriptionTooLong(String), } #[derive(Clone, Default, Debug, Eq, PartialEq)] @@ -509,7 +551,7 @@ pub struct ToolManager { /// Map of server names to their corresponding client instances. /// These clients are used to communicate with MCP servers. - pub clients: HashMap>, + pub clients: HashMap, /// A list of client names that are still in the process of being initialized pub pending_clients: Arc>>, @@ -579,7 +621,6 @@ impl Clone for ToolManager { fn clone(&self) -> Self { Self { conversation_id: self.conversation_id.clone(), - clients: self.clients.clone(), has_new_stuff: self.has_new_stuff.clone(), new_tool_specs: self.new_tool_specs.clone(), tn_map: self.tn_map.clone(), @@ -603,7 +644,42 @@ impl ToolManager { /// function) /// - Calling load tools pub async fn swap_agent(&mut self, os: &mut Os, output: &mut impl Write, agent: &Agent) -> eyre::Result<()> { - self.clients.clear(); + let to_evict = self.clients.drain().collect::>(); + tokio::spawn(async move { + for (server_name, initialized_client) in to_evict { + info!("Evicting {server_name} due to agent swap"); + match initialized_client { + InitializedMcpClient::Pending(handle) => { + let server_name_clone = server_name.clone(); + tokio::spawn(async move { + match handle.await { + Ok(Ok(client)) => { + let InnerService::Original(client) = client.inner_service else { + unreachable!(); + }; + match client.cancel().await { + Ok(_) => info!("Server {server_name_clone} evicted due to agent swap"), + Err(e) => error!("Server {server_name_clone} has failed to cancel: {e}"), + } + }, + Ok(Err(_)) | Err(_) => { + error!("Server {server_name_clone} has failed to cancel"); + }, + } + }); + }, + InitializedMcpClient::Ready(running_service) => { + let InnerService::Original(client) = running_service.inner_service else { + unreachable!(); + }; + match client.cancel().await { + Ok(_) => info!("Server {server_name} evicted due to agent swap"), + Err(e) => error!("Server {server_name} has failed to cancel: {e}"), + } + }, + } + } + }); let mut agent_lock = self.agent.lock().await; *agent_lock = agent.clone(); @@ -615,9 +691,7 @@ impl ToolManager { let mut new_tool_manager = builder.build(os, Box::new(std::io::sink()), true).await?; std::mem::swap(self, &mut new_tool_manager); - // we can discard the output here and let background server load take care of getting the - // new tools - let _ = self.load_tools(os, output).await?; + self.load_tools(os, output).await?; Ok(()) } @@ -684,20 +758,7 @@ impl ToolManager { tool_specs }; - let load_tools = self - .clients - .values() - .map(|c| { - let clone = Arc::clone(c); - async move { clone.init().await } - }) - .collect::>(); - let initial_poll = stream::iter(load_tools) - .map(|async_closure| tokio::spawn(async_closure)) - .buffer_unordered(20); - tokio::spawn(async move { - initial_poll.collect::>().await; - }); + // We need to cast it to erase the type otherwise the compiler will default to static // dispatch, which would result in an error of inconsistent match arm return type. let timeout_fut: Pin>> = if self.clients.is_empty() || !self.is_first_launch { @@ -770,7 +831,7 @@ impl ToolManager { .lock() .await .iter() - .any(|(_, records)| records.iter().any(|record| matches!(record, LoadingRecord::Err(_)))) + .any(|(_, records)| records.iter().any(|record| matches!(record, LoadingRecord::Err(..)))) { queue!( stderr, @@ -785,7 +846,7 @@ impl ToolManager { Ok(self.schema.clone()) } - pub fn get_tool_from_tool_use(&self, value: AssistantToolUse) -> Result { + pub async fn get_tool_from_tool_use(&mut self, value: AssistantToolUse) -> Result { let map_err = |parse_error| ToolResult { tool_use_id: value.id.clone(), content: vec![ToolResultContentBlock::Text(format!( @@ -811,7 +872,7 @@ impl ToolManager { "thinking" => Tool::Thinking(serde_json::from_value::(value.args).map_err(map_err)?), "knowledge" => Tool::Knowledge(serde_json::from_value::(value.args).map_err(map_err)?), "todo_list" => Tool::Todo(serde_json::from_value::(value.args).map_err(map_err)?), - // Note that this name is namespaced with server_name{DELIMITER}tool_name + // Note that this name is NO LONGER namespaced with server_name{DELIMITER}tool_name name => { // Note: tn_map also has tools that underwent no transformation. In otherwords, if // it is a valid tool name, we should get a hit. @@ -831,7 +892,7 @@ impl ToolManager { }) }, }?; - let Some(client) = self.clients.get(server_name) else { + let Some(client) = self.clients.get_mut(server_name) else { return Err(ToolResult { tool_use_id: value.id, content: vec![ToolResultContentBlock::Text(format!( @@ -840,22 +901,19 @@ impl ToolManager { status: ToolResultStatus::Error, }); }; - // The tool input schema has the shape of { type, properties }. - // The field "params" expected by MCP is { name, arguments }, where name is the - // name of the tool being invoked, - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#calling-tools. - // The field "arguments" is where ToolUse::args belong. - let mut params = serde_json::Map::::new(); - params.insert("name".to_owned(), serde_json::Value::String(tool_name.to_owned())); - params.insert("arguments".to_owned(), value.args); - let params = serde_json::Value::Object(params); - let custom_tool = CustomTool { + + let running_service = client.get_running_service().await.map_err(|e| ToolResult { + tool_use_id: value.id.clone(), + content: vec![ToolResultContentBlock::Text(format!("Mcp tool client not ready: {e}"))], + status: ToolResultStatus::Error, + })?; + + Tool::Custom(CustomTool { name: tool_name.to_owned(), - client: client.clone(), - method: "tools/call".to_owned(), - params: Some(params), - }; - Tool::Custom(custom_tool) + server_name: server_name.to_owned(), + client: running_service.clone(), + params: value.args.as_object().cloned(), + }) }, }) } @@ -921,7 +979,7 @@ impl ToolManager { if !conflicts.is_empty() { let mut record_lock = self.mcp_load_record.lock().await; for (server_name, msg) in conflicts { - let record = LoadingRecord::Err(msg); + let record = LoadingRecord::err(msg); record_lock .entry(server_name) .and_modify(|v| v.push(record.clone())) @@ -951,10 +1009,10 @@ impl ToolManager { } pub async fn get_prompt( - &self, + &mut self, name: String, arguments: Option>, - ) -> Result { + ) -> Result { let (server_name, prompt_name) = match name.split_once('/') { None => (None::, Some(name.clone())), Some((server_name, prompt_name)) => (Some(server_name.to_string()), Some(prompt_name.to_string())), @@ -1013,9 +1071,9 @@ impl ToolManager { }; let server_name = &bundle.server_name; - let client = self.clients.get(server_name).ok_or(GetPromptError::MissingClient)?; + let client = self.clients.get_mut(server_name).ok_or(GetPromptError::MissingClient)?; let PromptBundle { prompt_get, .. } = bundle; - let args = if let (Some(schema), Some(value)) = (&prompt_get.arguments, &arguments) { + let arguments = if let (Some(schema), Some(value)) = (&prompt_get.arguments, &arguments) { let params = schema.iter().zip(value.iter()).fold( HashMap::::new(), |mut acc, (prompt_get_arg, value)| { @@ -1023,19 +1081,20 @@ impl ToolManager { acc }, ); - Some(serde_json::json!(params)) + Some( + params + .into_iter() + .map(|(k, v)| (k, serde_json::Value::String(v))) + .collect(), + ) } else { None }; - let params = { - let mut params = serde_json::Map::new(); - params.insert("name".to_string(), serde_json::Value::String(prompt_name)); - if let Some(args) = args { - params.insert("arguments".to_string(), args); - } - Some(serde_json::Value::Object(params)) - }; - let resp = client.request("prompts/get", params).await?; + + let params = GetPromptRequestParam { name, arguments }; + let running_service = client.get_running_service().await?; + let resp = running_service.get_prompt(params).await?; + Ok(resp) }, (None, _) => Err(GetPromptError::PromptNotFound(prompt_name)), @@ -1143,6 +1202,16 @@ fn spawn_display_task( execute!(output, style::Print("\n"),)?; break; }, + LoadingMsg::SignInNotice { name } => { + execute!( + output, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + queue_oauth_message(&name, &mut output)?; + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; + }, }, Err(_e) => { spinner_logo_idx = (spinner_logo_idx + 1) % SPINNER_CHARS.len(); @@ -1296,10 +1365,10 @@ fn spawn_orchestrator_task( // request method on the mcp client no longer buffers all the pages from // list calls. match msg { - UpdateEventMessage::ToolsListResult { + UpdateEventMessage::ListToolsResult { server_name, result, - pid, + peer, } => { let time_taken = loading_servers .remove(&server_name) @@ -1311,11 +1380,8 @@ fn spawn_orchestrator_task( let result_tools = match &result { Ok(tools_result) => { - let names: Vec = tools_result - .tools - .iter() - .filter_map(|tool| tool.get("name")?.as_str().map(String::from)) - .collect(); + let names: Vec = + tools_result.tools.iter().map(|tool| tool.name.to_string()).collect(); names }, Err(_) => vec![], @@ -1369,40 +1435,27 @@ fn spawn_orchestrator_task( match result { Ok(result) => { - if pid.is_none_or(|pid| !is_process_running(pid)) { - let pid = pid.map_or("unknown".to_string(), |pid| pid.to_string()); - info!( - "Received tool list result from {server_name} but its associated process {pid} is no longer running. Ignoring." - ); - - let mut buf_writer = BufWriter::new(&mut *record_temp_buf); - let _ = queue_failure_message( - &server_name, - &eyre::eyre!("Process associated is no longer running"), - &time_taken, - &mut buf_writer, - ); - let _ = buf_writer.flush(); - drop(buf_writer); - let record_content = String::from_utf8_lossy(record_temp_buf).to_string(); - let record = LoadingRecord::Err(record_content); - - load_record - .lock() - .await - .entry(server_name.clone()) - .and_modify(|load_record| { - load_record.push(record.clone()); - }) - .or_insert(vec![record]); - + if let Some(peer) = peer { + if peer.is_transport_closed() { + error!( + "Received tool list result from {server_name} but transport has been closed. Ignoring." + ); + return; + } + } else { + error!("Received tool list result from {server_name} without a peer. Ignoring."); return; } let mut specs = result .tools .into_iter() - .filter_map(|v| serde_json::from_value::(v).ok()) + .map(|v| ToolSpec { + name: v.name.to_string(), + description: v.description.as_ref().map(|d| d.to_string()).unwrap_or_default(), + input_schema: crate::cli::chat::tools::InputSchema(v.schema_as_json_value()), + tool_origin: ToolOrigin::Native, + }) .filter(|spec| tool_filter.should_include(&spec.name)) .collect::>(); let mut sanitized_mapping = HashMap::::new(); @@ -1418,6 +1471,7 @@ fn spawn_orchestrator_task( &result_tools, ) .await; + if let Some(sender) = &loading_status_sender { // Anomalies here are not considered fatal, thus we shall give // warnings. @@ -1458,9 +1512,9 @@ fn spawn_orchestrator_task( drop(buf_writer); let record = String::from_utf8_lossy(record_temp_buf).to_string(); let record = if process_result.is_err() { - LoadingRecord::Warn(record) + LoadingRecord::warn(record) } else { - LoadingRecord::Success(record) + LoadingRecord::success(record) }; load_record .lock() @@ -1476,11 +1530,17 @@ fn spawn_orchestrator_task( error!("Error loading server {server_name}: {:?}", e); // Maintain a record of the server load: let mut buf_writer = BufWriter::new(&mut *record_temp_buf); - let _ = queue_failure_message(server_name.as_str(), &e, &time_taken, &mut buf_writer); + let fail_load_msg = eyre::eyre!("{}", e); + let _ = queue_failure_message( + server_name.as_str(), + &fail_load_msg, + &time_taken, + &mut buf_writer, + ); let _ = buf_writer.flush(); drop(buf_writer); let record = String::from_utf8_lossy(record_temp_buf).to_string(); - let record = LoadingRecord::Err(record); + let record = LoadingRecord::err(record); load_record .lock() .await @@ -1494,7 +1554,7 @@ fn spawn_orchestrator_task( if let Some(sender) = &loading_status_sender { let msg = LoadingMsg::Error { name: server_name.clone(), - msg: e, + msg: eyre::eyre!("{}", e.to_string()), time: time_taken, }; if let Err(e) = sender.send(msg).await { @@ -1514,17 +1574,21 @@ fn spawn_orchestrator_task( } } }, - UpdateEventMessage::PromptsListResult { + UpdateEventMessage::ListPromptsResult { server_name, result, - pid, + peer, } => match result { - Ok(prompt_list_result) if pid.is_some() => { - let pid = pid.unwrap(); - if !is_process_running(pid) { - info!( - "Received prompt list result from {server_name} but its associated process {pid} is no longer running. Ignoring." - ); + Ok(prompt_list_result) => { + if let Some(peer) = peer { + if peer.is_transport_closed() { + error!( + "Received prompt list result from {server_name} but transport has been closed. Ignoring." + ); + return; + } + } else { + error!("Received prompt list result from {server_name} without a peer. Ignoring."); return; } // We first need to clear all the PromptGets that are associated with @@ -1535,38 +1599,32 @@ fn spawn_orchestrator_task( .for_each(|bundles| bundles.retain(|bundle| bundle.server_name != server_name)); // And then we update them with the new comers - for result in prompt_list_result.prompts { - let Ok(prompt_get) = serde_json::from_value::(result) else { - error!("Failed to deserialize prompt get from server {server_name}"); - continue; - }; + for prompt in prompt_list_result.prompts { prompts - .entry(prompt_get.name.clone()) + .entry(prompt.name.clone()) .and_modify(|bundles| { bundles.push(PromptBundle { server_name: server_name.clone(), - prompt_get: prompt_get.clone(), + prompt_get: prompt.clone(), }); }) .or_insert_with(|| { vec![PromptBundle { server_name: server_name.clone(), - prompt_get, + prompt_get: prompt, }] }); } }, - Ok(_) => { - error!("Received prompt list result without pid from {server_name}. Ignoring."); - }, Err(e) => { error!("Error fetching prompts from server {server_name}: {:?}", e); let mut buf_writer = BufWriter::new(&mut *record_temp_buf); - let _ = queue_prompts_load_error_message(&server_name, &e, &mut buf_writer); + let msg = eyre::eyre!("{}", e); + let _ = queue_prompts_load_error_message(&server_name, &msg, &mut buf_writer); let _ = buf_writer.flush(); drop(buf_writer); let record = String::from_utf8_lossy(record_temp_buf).to_string(); - let record = LoadingRecord::Err(record); + let record = LoadingRecord::err(record); load_record .lock() .await @@ -1577,16 +1635,37 @@ fn spawn_orchestrator_task( .or_insert(vec![record]); }, }, - UpdateEventMessage::ResourcesListResult { - server_name: _, - result: _, - pid: _, - } => {}, - UpdateEventMessage::ResourceTemplatesListResult { - server_name: _, - result: _, - pid: _, - } => {}, + UpdateEventMessage::ListResourcesResult { .. } => {}, + UpdateEventMessage::ResourceTemplatesListResult { .. } => {}, + UpdateEventMessage::OauthLink { server_name, link } => { + let mut buf_writer = BufWriter::new(&mut *record_temp_buf); + let msg = eyre::eyre!(link); + let _ = queue_oauth_message_with_link(server_name.as_str(), &msg, &mut buf_writer); + let _ = buf_writer.flush(); + drop(buf_writer); + let record_str = String::from_utf8_lossy(record_temp_buf).to_string(); + let record = LoadingRecord::warn(record_str.clone()); + load_record + .lock() + .await + .entry(server_name.clone()) + .and_modify(|load_record| { + load_record.push(record.clone()); + }) + .or_insert(vec![record]); + if let Some(sender) = &loading_status_sender { + let msg = LoadingMsg::SignInNotice { + name: server_name.clone(), + }; + if let Err(e) = sender.send(msg).await { + warn!( + "Error sending update message to display task: {:?}\nAssume display task has completed", + e + ); + loading_status_sender.take(); + } + } + }, UpdateEventMessage::InitStart { server_name, .. } => { pending.write().await.insert(server_name.clone()); loading_servers.insert(server_name, std::time::Instant::now()); @@ -1659,7 +1738,7 @@ async fn process_tool_specs( // // For non-compliance due to point 1, we shall change it on behalf of the users. // For the rest, we simply throw a warning and reject the tool. - let mut out_of_spec_tool_names = Vec::::new(); + let mut out_of_spec_tool_names = Vec::::new(); let mut hasher = DefaultHasher::new(); let mut number_of_tools = 0_usize; @@ -1684,12 +1763,18 @@ async fn process_tool_specs( } }); if model_tool_name.len() > 64 { - out_of_spec_tool_names.push(OutOfSpecName::TooLong(spec.name.clone())); + out_of_spec_tool_names.push(ToolValidationViolation::TooLong(spec.name.clone())); continue; } else if spec.description.is_empty() { - out_of_spec_tool_names.push(OutOfSpecName::EmptyDescription(spec.name.clone())); + out_of_spec_tool_names.push(ToolValidationViolation::EmptyDescription(spec.name.clone())); continue; } + + if spec.description.len() > 10_004 { + spec.description.truncate(10_004); + out_of_spec_tool_names.push(ToolValidationViolation::DescriptionTooLong(spec.name.clone())); + } + tn_map.insert(model_tool_name.clone(), ToolInfo { server_name: server_name.to_string(), host_tool_name: spec.name.clone(), @@ -1727,21 +1812,25 @@ async fn process_tool_specs( if !out_of_spec_tool_names.is_empty() { Err(eyre::eyre!(out_of_spec_tool_names.iter().fold( String::from( - "The following tools are out of spec. They will be excluded from the list of available tools:\n", + "The following tools are out of spec. They may have been excluded from the list of available tools:\n", ), |mut acc, name| { let (tool_name, msg) = match name { - OutOfSpecName::TooLong(tool_name) => ( + ToolValidationViolation::TooLong(tool_name) => ( tool_name.as_str(), "tool name exceeds max length of 64 when combined with server name", ), - OutOfSpecName::IllegalChar(tool_name) => ( + ToolValidationViolation::IllegalChar(tool_name) => ( tool_name.as_str(), "tool name must be compliant with ^[a-zA-Z][a-zA-Z0-9_]*$", ), - OutOfSpecName::EmptyDescription(tool_name) => { + ToolValidationViolation::EmptyDescription(tool_name) => { (tool_name.as_str(), "tool schema contains empty description") }, + ToolValidationViolation::DescriptionTooLong(tool_name) => ( + tool_name.as_str(), + "tool description is longer than 10024 characters and has been truncated", + ), }; acc.push_str(format!(" - {} ({})\n", tool_name, msg).as_str()); acc @@ -1778,22 +1867,6 @@ fn sanitize_name(orig: String, regex: ®ex::Regex, hasher: &mut impl Hasher) - } } -// Add this function to check if a process is still running -fn is_process_running(pid: u32) -> bool { - #[cfg(unix)] - { - let system = sysinfo::System::new_all(); - system.process(sysinfo::Pid::from(pid as usize)).is_some() - } - #[cfg(windows)] - { - // TODO: fill in the process health check for windows when when we officially support - // windows - _ = pid; - true - } -} - fn queue_success_message(name: &str, time_taken: &str, output: &mut impl Write) -> eyre::Result<()> { Ok(queue!( output, @@ -1878,12 +1951,40 @@ fn queue_failure_message( style::Print(fail_load_msg), style::Print("\n"), style::Print(format!( - " - run with Q_LOG_LEVEL=trace and see $TMPDIR/{CHAT_BINARY_NAME} for detail\n" + " - run with Q_LOG_LEVEL=trace and see $TMPDIR/qlog/{CHAT_BINARY_NAME}.log for detail\n" )), style::ResetColor, )?) } +fn queue_oauth_message(name: &str, output: &mut impl Write) -> eyre::Result<()> { + Ok(queue!( + output, + style::SetForegroundColor(style::Color::Yellow), + style::Print("⚠ "), + style::SetForegroundColor(style::Color::Blue), + style::Print(name), + style::ResetColor, + style::Print(" requires OAuth authentication. Use /mcp to see the auth link\n"), + )?) +} + +fn queue_oauth_message_with_link(name: &str, msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> { + Ok(queue!( + output, + style::SetForegroundColor(style::Color::Yellow), + style::Print("⚠ "), + style::SetForegroundColor(style::Color::Blue), + style::Print(name), + style::ResetColor, + style::Print(" requires OAuth authentication. Follow this link to proceed: \n"), + style::SetForegroundColor(style::Color::Yellow), + style::Print(msg), + style::ResetColor, + style::Print("\n") + )?) +} + fn queue_warn_message(name: &str, msg: &eyre::Report, time: &str, output: &mut impl Write) -> eyre::Result<()> { Ok(queue!( output, diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index fafb55a9a4..0e45205678 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -1,19 +1,18 @@ +use std::borrow::Cow; use std::collections::HashMap; use std::io::Write; -use std::sync::Arc; use crossterm::{ queue, style, }; use eyre::Result; -use regex::Regex; +use rmcp::model::CallToolRequestParam; use schemars::JsonSchema; use serde::{ Deserialize, Serialize, }; -use tokio::sync::RwLock; use tracing::warn; use super::InvokeOutput; @@ -24,24 +23,44 @@ use crate::cli::agent::{ use crate::cli::chat::CONTINUATION_LINE; use crate::cli::chat::token_counter::TokenCounter; use crate::mcp_client::{ - Client as McpClient, - ClientConfig as McpClientConfig, - JsonRpcResponse, - JsonRpcStdioTransport, - MessageContent, - Messenger, - ServerCapabilities, - StdioTransport, - ToolCallResult, + RunningService, + oauth_util, }; use crate::os::Os; use crate::util::MCP_SERVER_TOOL_DELIMITER; -use crate::util::pattern_matching::matches_any_pattern; -// TODO: support http transport type #[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)] +#[serde(rename_all = "camelCase")] +pub enum TransportType { + /// Standard input/output transport (default) + Stdio, + /// HTTP transport for web-based communication + Http, +} + +impl Default for TransportType { + fn default() -> Self { + Self::Stdio + } +} + +#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)] +#[serde(rename_all = "camelCase")] pub struct CustomToolConfig { + /// The transport type to use for communication with the MCP server + #[serde(default)] + pub r#type: TransportType, + /// The URL for HTTP-based MCP server communication + #[serde(default)] + pub url: String, + /// HTTP headers to include when communicating with HTTP-based MCP servers + #[serde(default)] + pub headers: HashMap, + /// Scopes with which oauth is done + #[serde(default = "get_default_scopes")] + pub oauth_scopes: Vec, /// The command string used to initialize the mcp server + #[serde(default)] pub command: String, /// A list of arguments to be used to run the command with #[serde(default)] @@ -60,130 +79,15 @@ pub struct CustomToolConfig { pub is_from_legacy_mcp_json: bool, } -pub fn default_timeout() -> u64 { - 120 * 1000 -} - -/// Substitutes environment variables in the format ${env:VAR_NAME} with their actual values -fn substitute_env_vars(input: &str, env: &crate::os::Env) -> String { - // Create a regex to match ${env:VAR_NAME} pattern - let re = Regex::new(r"\$\{env:([^}]+)\}").unwrap(); - - re.replace_all(input, |caps: ®ex::Captures<'_>| { - let var_name = &caps[1]; - env.get(var_name).unwrap_or_else(|_| format!("${{{}}}", var_name)) - }) - .to_string() -} - -/// Process a HashMap of environment variables, substituting any ${env:VAR_NAME} patterns -/// with their actual values from the environment -fn process_env_vars(env_vars: &mut HashMap, env: &crate::os::Env) { - for (_, value) in env_vars.iter_mut() { - *value = substitute_env_vars(value, env); - } -} - -#[derive(Debug)] -pub enum CustomToolClient { - Stdio { - /// This is the server name as recognized by the model (post sanitized) - server_name: String, - client: McpClient, - server_capabilities: RwLock>, - }, +pub fn get_default_scopes() -> Vec { + oauth_util::get_default_scopes() + .iter() + .map(|s| (*s).to_string()) + .collect::>() } -impl CustomToolClient { - // TODO: add support for http transport - pub fn from_config(server_name: String, config: CustomToolConfig, os: &crate::os::Os) -> Result { - let CustomToolConfig { - command, - args, - env, - timeout, - disabled: _, - .. - } = config; - - // Process environment variables if present - let processed_env = env.map(|mut env_vars| { - process_env_vars(&mut env_vars, &os.env); - env_vars - }); - - let mcp_client_config = McpClientConfig { - server_name: server_name.clone(), - bin_path: command.clone(), - args, - timeout, - client_info: serde_json::json!({ - "name": "Q CLI Chat", - "version": "1.0.0" - }), - env: processed_env, - }; - let client = McpClient::::from_config(mcp_client_config)?; - Ok(CustomToolClient::Stdio { - server_name, - client, - server_capabilities: RwLock::new(None), - }) - } - - pub async fn init(&self) -> Result<()> { - match self { - CustomToolClient::Stdio { - client, - server_capabilities, - .. - } => { - if let Some(messenger) = &client.messenger { - let _ = messenger.send_init_msg().await; - } - // We'll need to first initialize. This is the handshake every client and server - // needs to do before proceeding to anything else - let cap = client.init().await?; - // We'll be scrapping this for background server load: https://github.com/aws/amazon-q-developer-cli/issues/1466 - // So don't worry about the tidiness for now - server_capabilities.write().await.replace(cap); - Ok(()) - }, - } - } - - pub fn assign_messenger(&mut self, messenger: Box) { - match self { - CustomToolClient::Stdio { client, .. } => { - client.messenger = Some(messenger); - }, - } - } - - pub fn get_server_name(&self) -> &str { - match self { - CustomToolClient::Stdio { server_name, .. } => server_name.as_str(), - } - } - - pub async fn request(&self, method: &str, params: Option) -> Result { - match self { - CustomToolClient::Stdio { client, .. } => Ok(client.request(method, params).await?), - } - } - - pub fn get_pid(&self) -> Option { - match self { - CustomToolClient::Stdio { client, .. } => client.server_process_id.as_ref().map(|pid| pid.as_u32()), - } - } - - #[allow(dead_code)] - pub async fn notify(&self, method: &str, params: Option) -> Result<()> { - match self { - CustomToolClient::Stdio { client, .. } => Ok(client.notify(method, params).await?), - } - } +pub fn default_timeout() -> u64 { + 120 * 1000 } /// Represents a custom tool that can be invoked through the Model Context Protocol (MCP). @@ -192,47 +96,40 @@ pub struct CustomTool { /// Actual tool name as recognized by its MCP server. This differs from the tool names as they /// are seen by the model since they are not prefixed by its MCP server name. pub name: String, + /// The name of the MCP (Model Context Protocol) server that hosts this tool. + /// This is used to identify which server instance the tool belongs to and is + /// prefixed to the tool name when presented to the model for disambiguation. + pub server_name: String, /// Reference to the client that manages communication with the tool's server process. - pub client: Arc, - /// The method name to call on the tool's server, following the JSON-RPC convention. - /// This corresponds to a specific functionality provided by the tool. - pub method: String, + pub client: RunningService, /// Optional parameters to pass to the tool when invoking the method. /// Structured as a JSON value to accommodate various parameter types and structures. - pub params: Option, + pub params: Option>, } impl CustomTool { - pub async fn invoke(&self, _os: &Os, _updates: impl Write) -> Result { - // Assuming a response shape as per https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#calling-tools - let resp = self.client.request(self.method.as_str(), self.params.clone()).await?; - let result = match resp.result { - Some(result) => result, - None => { - let failure = resp.error.map_or("Unknown error encountered".to_string(), |err| { - serde_json::to_string(&err).unwrap_or_default() - }); - return Err(eyre::eyre!(failure)); - }, + /// Returns the full tool name with server prefix in the format @server_name/tool_name + pub fn namespaced_tool_name(&self) -> String { + format!("@{}{}{}", self.server_name, MCP_SERVER_TOOL_DELIMITER, self.name) + } + + pub async fn invoke(&self, _os: &Os, _updates: &mut impl Write) -> Result { + let params = CallToolRequestParam { + name: Cow::from(self.name.clone()), + arguments: self.params.clone(), }; - match serde_json::from_value::(result.clone()) { - Ok(mut de_result) => { - for content in &mut de_result.content { - if let MessageContent::Image { data, .. } = content { - *data = format!("Redacted base64 encoded string of an image of size {}", data.len()); - } - } - Ok(InvokeOutput { - output: super::OutputKind::Json(serde_json::json!(de_result)), - }) - }, - Err(e) => { - warn!("Tool call result deserialization failed: {:?}", e); - Ok(InvokeOutput { - output: super::OutputKind::Json(result.clone()), - }) - }, + let resp = self.client.call_tool(params.clone()).await?; + + if resp.is_error.is_none_or(|v| !v) { + Ok(InvokeOutput { + output: super::OutputKind::Json(serde_json::json!(resp)), + }) + } else { + warn!("Tool call for {} failed", self.name); + Ok(InvokeOutput { + output: super::OutputKind::Json(serde_json::json!(resp)), + }) } } @@ -271,83 +168,18 @@ impl CustomTool { } pub fn get_input_token_size(&self) -> usize { - TokenCounter::count_tokens(self.method.as_str()) - + TokenCounter::count_tokens(self.params.as_ref().map_or("", |p| p.as_str().unwrap_or_default())) + TokenCounter::count_tokens( + &serde_json::to_string(self.params.as_ref().unwrap_or(&serde_json::Map::new())).unwrap_or_default(), + ) } pub fn eval_perm(&self, _os: &Os, agent: &Agent) -> PermissionEvalResult { - let Self { - name: tool_name, - client, - .. - } = self; - let server_name = client.get_server_name(); - - let server_pattern = format!("@{server_name}"); - if agent.allowed_tools.contains(&server_pattern) { - return PermissionEvalResult::Allow; - } + use crate::util::tool_permission_checker::is_tool_in_allowlist; - let tool_pattern = format!("@{server_name}{MCP_SERVER_TOOL_DELIMITER}{tool_name}"); - if matches_any_pattern(&agent.allowed_tools, &tool_pattern) { - return PermissionEvalResult::Allow; - } - - PermissionEvalResult::Ask - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_substitute_env_vars() { - // Set a test environment variable - let os = Os::new().await.unwrap(); - unsafe { - os.env.set_var("TEST_VAR", "test_value"); - } - - // Test basic substitution - assert_eq!( - substitute_env_vars("Value is ${env:TEST_VAR}", &os.env), - "Value is test_value" - ); - - // Test multiple substitutions - assert_eq!( - substitute_env_vars("${env:TEST_VAR} and ${env:TEST_VAR}", &os.env), - "test_value and test_value" - ); - - // Test non-existent variable - assert_eq!( - substitute_env_vars("${env:NON_EXISTENT_VAR}", &os.env), - "${NON_EXISTENT_VAR}" - ); - - // Test mixed content - assert_eq!( - substitute_env_vars("Prefix ${env:TEST_VAR} suffix", &os.env), - "Prefix test_value suffix" - ); - } - - #[tokio::test] - async fn test_process_env_vars() { - let os = Os::new().await.unwrap(); - unsafe { - os.env.set_var("TEST_VAR", "test_value"); + if is_tool_in_allowlist(&agent.allowed_tools, &self.name, Some(&self.server_name)) { + PermissionEvalResult::Allow + } else { + PermissionEvalResult::Ask } - - let mut env_vars = HashMap::new(); - env_vars.insert("KEY1".to_string(), "Value is ${env:TEST_VAR}".to_string()); - env_vars.insert("KEY2".to_string(), "No substitution".to_string()); - - process_env_vars(&mut env_vars, &os.env); - - assert_eq!(env_vars.get("KEY1").unwrap(), "Value is test_value"); - assert_eq!(env_vars.get("KEY2").unwrap(), "No substitution"); } } diff --git a/crates/chat-cli/src/cli/chat/tools/execute/mod.rs b/crates/chat-cli/src/cli/chat/tools/execute/mod.rs index 388c48476b..200cac641a 100644 --- a/crates/chat-cli/src/cli/chat/tools/execute/mod.rs +++ b/crates/chat-cli/src/cli/chat/tools/execute/mod.rs @@ -23,7 +23,7 @@ use crate::cli::chat::tools::{ }; use crate::cli::chat::util::truncate_safe; use crate::os::Os; -use crate::util::pattern_matching::matches_any_pattern; +use crate::util::tool_permission_checker::is_tool_in_allowlist; // Platform-specific modules #[cfg(windows)] @@ -70,7 +70,7 @@ impl ExecuteCommand { let Some(args) = shlex::split(&self.command) else { return true; }; - const DANGEROUS_PATTERNS: &[&str] = &["<(", "$(", "`", ">", "&&", "||", "&", ";", "${", "\n", "\r", "IFS"]; + const DANGEROUS_PATTERNS: &[&str] = &["<(", "$(", "`", ">", "&&", "||", "&", ";", "$", "\n", "\r", "IFS"]; if args .iter() @@ -196,22 +196,22 @@ impl ExecuteCommand { #[serde(default)] denied_commands: Vec, #[serde(default = "default_allow_read_only")] - allow_read_only: bool, + auto_allow_readonly: bool, } fn default_allow_read_only() -> bool { - true + false } let Self { command, .. } = self; let tool_name = if cfg!(windows) { "execute_cmd" } else { "execute_bash" }; - let is_in_allowlist = matches_any_pattern(&agent.allowed_tools, tool_name); + let is_in_allowlist = is_tool_in_allowlist(&agent.allowed_tools, tool_name, None); match agent.tools_settings.get(tool_name) { Some(settings) => { let Settings { allowed_commands, denied_commands, - allow_read_only, + auto_allow_readonly, } = match serde_json::from_value::(settings.clone()) { Ok(settings) => settings, Err(e) => { @@ -233,7 +233,7 @@ impl ExecuteCommand { if is_in_allowlist { PermissionEvalResult::Allow - } else if self.requires_acceptance(Some(&allowed_commands), allow_read_only) { + } else if self.requires_acceptance(Some(&allowed_commands), auto_allow_readonly) { PermissionEvalResult::Ask } else { PermissionEvalResult::Allow @@ -328,6 +328,7 @@ mod tests { (r#"find / -fprintf "/path/to/file" -quit"#, true), (r"find . -${t}exec touch asdf \{\} +", true), (r"find . -${t:=exec} touch asdf2 \{\} +", true), + (r#"find /tmp -name "*" -exe$9c touch /tmp/find_result {} +"#, true), // `grep` command arguments ("echo 'test data' | grep -P '(?{system(\"date\")})'", true), ("echo 'test data' | grep --perl-regexp '(?{system(\"date\")})'", true), @@ -488,6 +489,133 @@ mod tests { assert!(matches!(res, PermissionEvalResult::Deny(ref rules) if rules.contains(&"\\Agit .*\\z".to_string()))); } + #[tokio::test] + async fn test_eval_perm_allow_read_only_default() { + use crate::cli::agent::Agent; + + let os = Os::new().await.unwrap(); + + // Test read-only command with default settings (allow_read_only = false) + let readonly_cmd = serde_json::from_value::(serde_json::json!({ + "command": "ls -la", + })) + .unwrap(); + + let agent = Agent::default(); + let res = readonly_cmd.eval_perm(&os, &agent); + // Should ask for confirmation even for read-only commands by default + assert!(matches!(res, PermissionEvalResult::Ask)); + + // Test non-read-only command with default settings + let write_cmd = serde_json::from_value::(serde_json::json!({ + "command": "rm file.txt", + })) + .unwrap(); + + let res = write_cmd.eval_perm(&os, &agent); + // Should ask for confirmation for write commands + assert!(matches!(res, PermissionEvalResult::Ask)); + } + + #[tokio::test] + async fn test_eval_perm_allow_read_only_enabled() { + use std::collections::HashMap; + + use crate::cli::agent::{ + Agent, + ToolSettingTarget, + }; + + let os = Os::new().await.unwrap(); + let tool_name = if cfg!(windows) { "execute_cmd" } else { "execute_bash" }; + + let agent = Agent { + name: "test_agent".to_string(), + tools_settings: { + let mut map = HashMap::::new(); + map.insert( + ToolSettingTarget(tool_name.to_string()), + serde_json::json!({ + "autoAllowReadonly": true + }), + ); + map + }, + ..Default::default() + }; + + // Test read-only command with allow_read_only = true + let readonly_cmd = serde_json::from_value::(serde_json::json!({ + "command": "ls -la", + })) + .unwrap(); + + let res = readonly_cmd.eval_perm(&os, &agent); + // Should allow read-only commands without confirmation + assert!(matches!(res, PermissionEvalResult::Allow)); + + // Test write command with allow_read_only = true + let write_cmd = serde_json::from_value::(serde_json::json!({ + "command": "rm file.txt", + })) + .unwrap(); + + let res = write_cmd.eval_perm(&os, &agent); + // Should still ask for confirmation for write commands + assert!(matches!(res, PermissionEvalResult::Ask)); + } + + #[tokio::test] + async fn test_eval_perm_allow_read_only_with_denied_commands() { + use std::collections::HashMap; + + use crate::cli::agent::{ + Agent, + ToolSettingTarget, + }; + + let os = Os::new().await.unwrap(); + let tool_name = if cfg!(windows) { "execute_cmd" } else { "execute_bash" }; + + let agent = Agent { + name: "test_agent".to_string(), + tools_settings: { + let mut map = HashMap::::new(); + map.insert( + ToolSettingTarget(tool_name.to_string()), + serde_json::json!({ + "autoAllowReadonly": true, + "deniedCommands": ["ls .*"] + }), + ); + map + }, + ..Default::default() + }; + + // Test read-only command that's in denied list + let denied_readonly_cmd = serde_json::from_value::(serde_json::json!({ + "command": "ls -la", + })) + .unwrap(); + + let res = denied_readonly_cmd.eval_perm(&os, &agent); + // Should deny even read-only commands if they're in denied list + assert!( + matches!(res, PermissionEvalResult::Deny(ref commands) if commands.contains(&"\\Als .*\\z".to_string())) + ); + + // Test different read-only command not in denied list + let allowed_readonly_cmd = serde_json::from_value::(serde_json::json!({ + "command": "cat file.txt", + })) + .unwrap(); + + let res = allowed_readonly_cmd.eval_perm(&os, &agent); + // Should allow read-only commands not in denied list + assert!(matches!(res, PermissionEvalResult::Allow)); + } + #[tokio::test] async fn test_cloudtrail_tracking() { use crate::cli::chat::consts::{ diff --git a/crates/chat-cli/src/cli/chat/tools/fs_read.rs b/crates/chat-cli/src/cli/chat/tools/fs_read.rs index a11924e9a2..e6dc7e31ba 100644 --- a/crates/chat-cli/src/cli/chat/tools/fs_read.rs +++ b/crates/chat-cli/src/cli/chat/tools/fs_read.rs @@ -46,7 +46,7 @@ use crate::cli::chat::{ }; use crate::os::Os; use crate::util::directories; -use crate::util::pattern_matching::matches_any_pattern; +use crate::util::tool_permission_checker::is_tool_in_allowlist; #[derive(Debug, Clone, Deserialize)] pub struct FsRead { @@ -109,159 +109,160 @@ impl FsRead { allowed_paths: Vec, #[serde(default)] denied_paths: Vec, - #[serde(default = "default_allow_read_only")] + #[serde(default)] allow_read_only: bool, } - fn default_allow_read_only() -> bool { - true - } + let is_in_allowlist = is_tool_in_allowlist(&agent.allowed_tools, "fs_read", None); + let settings = agent + .tools_settings + .get("fs_read") + .cloned() + .unwrap_or_else(|| serde_json::json!({})); + + { + let Settings { + mut allowed_paths, + denied_paths, + allow_read_only, + } = match serde_json::from_value::(settings) { + Ok(settings) => settings, + Err(e) => { + error!("Failed to deserialize tool settings for fs_read: {:?}", e); + return PermissionEvalResult::Ask; + }, + }; - let is_in_allowlist = matches_any_pattern(&agent.allowed_tools, "fs_read"); - match agent.tools_settings.get("fs_read") { - Some(settings) => { - let Settings { - allowed_paths, - denied_paths, - allow_read_only, - } = match serde_json::from_value::(settings.clone()) { - Ok(settings) => settings, - Err(e) => { - error!("Failed to deserialize tool settings for fs_read: {:?}", e); - return PermissionEvalResult::Ask; - }, - }; - let allow_set = { - let mut builder = GlobSetBuilder::new(); - for path in &allowed_paths { - let Ok(path) = directories::canonicalizes_path(os, path) else { - continue; - }; - if let Err(e) = directories::add_gitignore_globs(&mut builder, path.as_str()) { - warn!("Failed to create glob from path given: {path}: {e}. Ignoring."); - } + // Always add current working directory to allowed paths + if let Ok(cwd) = os.env.current_dir() { + allowed_paths.push(cwd.to_string_lossy().to_string()); + } + + let allow_set = { + let mut builder = GlobSetBuilder::new(); + for path in &allowed_paths { + let Ok(path) = directories::canonicalizes_path(os, path) else { + continue; + }; + if let Err(e) = directories::add_gitignore_globs(&mut builder, path.as_str()) { + warn!("Failed to create glob from path given: {path}: {e}. Ignoring."); } - builder.build() - }; + } + builder.build() + }; - let mut sanitized_deny_list = Vec::<&String>::new(); - let deny_set = { - let mut builder = GlobSetBuilder::new(); - for path in &denied_paths { - let Ok(processed_path) = directories::canonicalizes_path(os, path) else { - continue; - }; - match directories::add_gitignore_globs(&mut builder, processed_path.as_str()) { - Ok(_) => { - // Note that we need to push twice here because for each rule we - // are creating two globs (one for file and one for directory) - sanitized_deny_list.push(path); - sanitized_deny_list.push(path); - }, - Err(e) => warn!("Failed to create glob from path given: {path}: {e}. Ignoring."), - } + let mut sanitized_deny_list = Vec::<&String>::new(); + let deny_set = { + let mut builder = GlobSetBuilder::new(); + for path in &denied_paths { + let Ok(processed_path) = directories::canonicalizes_path(os, path) else { + continue; + }; + match directories::add_gitignore_globs(&mut builder, processed_path.as_str()) { + Ok(_) => { + // Note that we need to push twice here because for each rule we + // are creating two globs (one for file and one for directory) + sanitized_deny_list.push(path); + sanitized_deny_list.push(path); + }, + Err(e) => warn!("Failed to create glob from path given: {path}: {e}. Ignoring."), } - builder.build() - }; + } + builder.build() + }; - match (allow_set, deny_set) { - (Ok(allow_set), Ok(deny_set)) => { - let mut deny_list = Vec::::new(); - let mut ask = false; - - for op in &self.operations { - match op { - FsReadOperation::Line(FsLine { path, .. }) - | FsReadOperation::Directory(FsDirectory { path, .. }) - | FsReadOperation::Search(FsSearch { path, .. }) => { - let Ok(path) = directories::canonicalizes_path(os, path) else { - ask = true; - continue; - }; - let denied_match_set = deny_set.matches(path.as_ref() as &str); - if !denied_match_set.is_empty() { - let deny_res = PermissionEvalResult::Deny({ - denied_match_set - .iter() - .filter_map(|i| sanitized_deny_list.get(*i).map(|s| (*s).clone())) - .collect::>() - }); - deny_list.push(deny_res); - continue; - } - - // We only want to ask if we are not allowing read only - // operation - if !is_in_allowlist - && !allow_read_only - && !allow_set.is_match(path.as_ref() as &str) - { - ask = true; - } - }, - FsReadOperation::Image(fs_image) => { - let paths = &fs_image.image_paths; - let denied_match_set = paths - .iter() - .flat_map(|path| { - let Ok(path) = directories::canonicalizes_path(os, path) else { - return vec![]; - }; - deny_set.matches(path.as_ref() as &str) - }) - .collect::>(); - if !denied_match_set.is_empty() { - let deny_res = PermissionEvalResult::Deny({ - denied_match_set - .iter() - .filter_map(|i| sanitized_deny_list.get(*i).map(|s| (*s).clone())) - .collect::>() - }); - deny_list.push(deny_res); - continue; - } - - // We only want to ask if we are not allowing read only - // operation - if !is_in_allowlist - && !allow_read_only - && !paths.iter().any(|path| allow_set.is_match(path)) - { - ask = true; - } - }, - } + match (allow_set, deny_set) { + (Ok(allow_set), Ok(deny_set)) => { + let mut deny_list = Vec::::new(); + let mut ask = false; + + for op in &self.operations { + match op { + FsReadOperation::Line(FsLine { path, .. }) + | FsReadOperation::Directory(FsDirectory { path, .. }) + | FsReadOperation::Search(FsSearch { path, .. }) => { + let Ok(path) = directories::canonicalizes_path(os, path) else { + ask = true; + continue; + }; + let denied_match_set = deny_set.matches(path.as_ref() as &str); + if !denied_match_set.is_empty() { + let deny_res = PermissionEvalResult::Deny({ + denied_match_set + .iter() + .filter_map(|i| sanitized_deny_list.get(*i).map(|s| (*s).clone())) + .collect::>() + }); + deny_list.push(deny_res); + continue; + } + + // We only want to ask if we are not allowing read only + // operation + if !is_in_allowlist && !allow_read_only && !allow_set.is_match(path.as_ref() as &str) { + ask = true; + } + }, + FsReadOperation::Image(fs_image) => { + let paths = &fs_image.image_paths; + let denied_match_set = paths + .iter() + .flat_map(|path| { + let Ok(path) = directories::canonicalizes_path(os, path) else { + return vec![]; + }; + deny_set.matches(path.as_ref() as &str) + }) + .collect::>(); + if !denied_match_set.is_empty() { + let deny_res = PermissionEvalResult::Deny({ + denied_match_set + .iter() + .filter_map(|i| sanitized_deny_list.get(*i).map(|s| (*s).clone())) + .collect::>() + }); + deny_list.push(deny_res); + continue; + } + + // We only want to ask if we are not allowing read only + // operation + if !is_in_allowlist + && !allow_read_only + && !paths.iter().any(|path| allow_set.is_match(path)) + { + ask = true; + } + }, } + } - if !deny_list.is_empty() { - PermissionEvalResult::Deny({ - deny_list.into_iter().fold(Vec::::new(), |mut acc, res| { - if let PermissionEvalResult::Deny(mut rules) = res { - acc.append(&mut rules); - } - acc - }) + if !deny_list.is_empty() { + PermissionEvalResult::Deny({ + deny_list.into_iter().fold(Vec::::new(), |mut acc, res| { + if let PermissionEvalResult::Deny(mut rules) = res { + acc.append(&mut rules); + } + acc }) - } else if ask { - PermissionEvalResult::Ask - } else { - PermissionEvalResult::Allow - } - }, - (allow_res, deny_res) => { - if let Err(e) = allow_res { - warn!("fs_read failed to build allow set: {:?}", e); - } - if let Err(e) = deny_res { - warn!("fs_read failed to build deny set: {:?}", e); - } - warn!("One or more detailed args failed to parse, falling back to ask"); + }) + } else if ask { PermissionEvalResult::Ask - }, - } - }, - None if is_in_allowlist => PermissionEvalResult::Allow, - _ => PermissionEvalResult::Ask, + } else { + PermissionEvalResult::Allow + } + }, + (allow_res, deny_res) => { + if let Err(e) = allow_res { + warn!("fs_read failed to build allow set: {:?}", e); + } + if let Err(e) = deny_res { + warn!("fs_read failed to build deny set: {:?}", e); + } + warn!("One or more detailed args failed to parse, falling back to ask"); + PermissionEvalResult::Ask + }, + } } } @@ -862,6 +863,7 @@ fn format_mode(mode: u32) -> [char; 9] { #[cfg(test)] mod tests { use std::collections::HashMap; + use std::path::PathBuf; use super::*; use crate::cli::agent::ToolSettingTarget; @@ -1397,7 +1399,7 @@ mod tests { } #[tokio::test] - async fn test_eval_perm() { + async fn test_eval_perm_denied_path() { const DENIED_PATH_OR_FILE: &str = "/some/denied/path"; const DENIED_PATH_OR_FILE_GLOB: &str = "/denied/glob/**/path"; @@ -1447,4 +1449,92 @@ mod tests { && deny_list.iter().filter(|p| *p == DENIED_PATH_OR_FILE).collect::>().len() == 2 )); } + + #[tokio::test] + async fn test_eval_perm_allowed_path_and_cwd() { + // by default the fake env uses "/" as the CWD. + // change it to a sub folder so we can test fs_read reading files outside CWD + let os = Os::new().await.unwrap(); + os.env.set_current_dir_for_test(PathBuf::from("/home/user")); + + let agent = Agent { + name: "test_agent".to_string(), + tools_settings: { + let mut map = HashMap::new(); + map.insert( + ToolSettingTarget("fs_read".to_string()), + serde_json::json!({ + "allowedPaths": ["/explicitly/allowed/path"] + }), + ); + map + }, + ..Default::default() // Not in allowed_tools, allow_read_only = false + }; + + // Test 1: Explicitly allowed path should work + let allowed_tool = serde_json::from_value::(serde_json::json!({ + "operations": [ + { "path": "/explicitly/allowed/path", "mode": "Directory" }, + { "path": "/explicitly/allowed/path/file.txt", "mode": "Line" }, + ] + })) + .unwrap(); + let res = allowed_tool.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Allow)); + + // Test 2: CWD should always be allowed + let cwd_tool = serde_json::from_value::(serde_json::json!({ + "operations": [ + { "path": "/home/user/", "mode": "Directory" }, + { "path": "/home/user/file.txt", "mode": "Line" }, + ] + })) + .unwrap(); + let res = cwd_tool.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Allow)); + + // Test 3: Outside CWD and not explicitly allowed should ask + let outside_tool = serde_json::from_value::(serde_json::json!({ + "operations": [ + { "path": "/tmp/not/allowed/file.txt", "mode": "Line" } + ] + })) + .unwrap(); + let res = outside_tool.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Ask)); + } + + #[tokio::test] + async fn test_eval_perm_no_settings_cwd_behavior() { + let os = Os::new().await.unwrap(); + os.env.set_current_dir_for_test(PathBuf::from("/home/user")); + + let agent = Agent { + name: "test_agent".to_string(), + tools_settings: HashMap::new(), // No fs_read settings + ..Default::default() + }; + + // Test 1: CWD should be allowed even with no settings + let cwd_tool = serde_json::from_value::(serde_json::json!({ + "operations": [ + { "path": "/home/user/", "mode": "Directory" }, + { "path": "/home/user/file.txt", "mode": "Line" }, + ] + })) + .unwrap(); + let res = cwd_tool.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Allow)); + + // Test 2: Outside CWD should ask for permission + let outside_tool = serde_json::from_value::(serde_json::json!({ + "operations": [ + { "path": "/tmp/not/allowed/file.txt", "mode": "Line" } + ] + })) + .unwrap(); + let res = outside_tool.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Ask)); + } } diff --git a/crates/chat-cli/src/cli/chat/tools/fs_write.rs b/crates/chat-cli/src/cli/chat/tools/fs_write.rs index 6222b0cd57..d72ccb2be6 100644 --- a/crates/chat-cli/src/cli/chat/tools/fs_write.rs +++ b/crates/chat-cli/src/cli/chat/tools/fs_write.rs @@ -45,7 +45,7 @@ use crate::cli::agent::{ use crate::cli::chat::line_tracker::FileLineTracker; use crate::os::Os; use crate::util::directories; -use crate::util::pattern_matching::matches_any_pattern; +use crate::util::tool_permission_checker::is_tool_in_allowlist; static SYNTAX_SET: LazyLock = LazyLock::new(SyntaxSet::load_defaults_newlines); static THEME_SET: LazyLock = LazyLock::new(ThemeSet::load_defaults); @@ -247,11 +247,57 @@ impl FsWrite { let tracker = line_tracker.entry(path.to_string_lossy().to_string()).or_default(); tracker.after_fswrite_lines = after_lines; + + // Calculate actual lines added and removed by analyzing the diff + let (lines_added, lines_removed) = self.calculate_diff_lines(os).await?; + tracker.lines_added_by_agent = lines_added; + tracker.lines_removed_by_agent = lines_removed; + tracker.is_first_write = false; Ok(()) } + async fn calculate_diff_lines(&self, os: &Os) -> Result<(usize, usize)> { + let path = self.path(os); + + let result = match self { + FsWrite::Create { .. } => { + // For create operations, all lines in the new file are added + let new_content = os.fs.read_to_string(&path).await?; + let lines_added = new_content.lines().count(); + (lines_added, 0) + }, + FsWrite::StrReplace { old_str, new_str, .. } => { + // Use actual diff analysis for accurate line counting + let diff = similar::TextDiff::from_lines(old_str, new_str); + let mut lines_added = 0; + let mut lines_removed = 0; + + for change in diff.iter_all_changes() { + match change.tag() { + similar::ChangeTag::Insert => lines_added += 1, + similar::ChangeTag::Delete => lines_removed += 1, + similar::ChangeTag::Equal => {}, + } + } + (lines_added, lines_removed) + }, + FsWrite::Insert { new_str, .. } => { + // For insert operations, all lines in new_str are added + let lines_added = new_str.lines().count(); + (lines_added, 0) + }, + FsWrite::Append { new_str, .. } => { + // For append operations, all lines in new_str are added + let lines_added = new_str.lines().count(); + (lines_added, 0) + }, + }; + + Ok(result) + } + pub fn queue_description(&self, os: &Os, output: &mut impl Write) -> Result<()> { let cwd = os.env.current_dir()?; self.print_relative_path(os, output)?; @@ -405,7 +451,7 @@ impl FsWrite { } /// Returns the summary from any variant of the FsWrite enum - fn get_summary(&self) -> Option<&String> { + pub fn get_summary(&self) -> Option<&String> { match self { FsWrite::Create { summary, .. } => summary.as_ref(), FsWrite::StrReplace { summary, .. } => summary.as_ref(), @@ -424,7 +470,7 @@ impl FsWrite { denied_paths: Vec, } - let is_in_allowlist = matches_any_pattern(&agent.allowed_tools, "fs_write"); + let is_in_allowlist = is_tool_in_allowlist(&agent.allowed_tools, "fs_write", None); match agent.tools_settings.get("fs_write") { Some(settings) => { let Settings { diff --git a/crates/chat-cli/src/cli/chat/tools/introspect.rs b/crates/chat-cli/src/cli/chat/tools/introspect.rs index 9a8d7be9ad..4968cd8e94 100644 --- a/crates/chat-cli/src/cli/chat/tools/introspect.rs +++ b/crates/chat-cli/src/cli/chat/tools/introspect.rs @@ -62,6 +62,29 @@ impl Introspect { documentation.push_str("\n\n--- docs/agent-file-locations.md ---\n"); documentation.push_str(include_str!("../../../../../../docs/agent-file-locations.md")); + documentation.push_str("\n\n--- docs/tangent-mode.md ---\n"); + documentation.push_str(include_str!("../../../../../../docs/tangent-mode.md")); + + documentation.push_str("\n\n--- docs/introspect-tool.md ---\n"); + documentation.push_str(include_str!("../../../../../../docs/introspect-tool.md")); + + documentation.push_str("\n\n--- docs/todo-lists.md ---\n"); + documentation.push_str(include_str!("../../../../../../docs/todo-lists.md")); + + documentation.push_str("\n\n--- docs/hooks.md ---\n"); + documentation.push_str(include_str!("../../../../../../docs/hooks.md")); + + documentation.push_str("\n\n--- changelog (from feed.json) ---\n"); + // Include recent changelog entries from feed.json + let feed = crate::cli::feed::Feed::load(); + let recent_entries = feed.get_all_changelogs().into_iter().take(5).collect::>(); + for entry in recent_entries { + documentation.push_str(&format!("\n## {} ({})\n", entry.version, entry.date)); + for change in &entry.changes { + documentation.push_str(&format!("- {}: {}\n", change.change_type, change.description)); + } + } + documentation.push_str("\n\n--- CONTRIBUTING.md ---\n"); documentation.push_str(include_str!("../../../../../../CONTRIBUTING.md")); @@ -93,6 +116,14 @@ impl Introspect { documentation .push_str("• Experiments: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/experiments.md\n"); documentation.push_str("• Agent File Locations: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/agent-file-locations.md\n"); + documentation + .push_str("• Tangent Mode: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/tangent-mode.md\n"); + documentation.push_str( + "• Introspect Tool: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/introspect-tool.md\n", + ); + documentation + .push_str("• Todo Lists: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/todo-lists.md\n"); + documentation.push_str("• Hooks: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/hooks.md\n"); documentation .push_str("• Contributing: https://github.com/aws/amazon-q-developer-cli/blob/main/CONTRIBUTING.md\n"); diff --git a/crates/chat-cli/src/cli/chat/tools/knowledge.rs b/crates/chat-cli/src/cli/chat/tools/knowledge.rs index 7bde4f0651..7f8428e6f8 100644 --- a/crates/chat-cli/src/cli/chat/tools/knowledge.rs +++ b/crates/chat-cli/src/cli/chat/tools/knowledge.rs @@ -20,7 +20,7 @@ use crate::cli::agent::{ use crate::database::settings::Setting; use crate::os::Os; use crate::util::knowledge_store::KnowledgeStore; -use crate::util::pattern_matching::matches_any_pattern; +use crate::util::tool_permission_checker::is_tool_in_allowlist; /// The Knowledge tool allows storing and retrieving information across chat sessions. /// It provides semantic search capabilities for files, directories, and text content. @@ -497,7 +497,7 @@ impl Knowledge { _ = self; _ = os; - if matches_any_pattern(&agent.allowed_tools, "knowledge") { + if is_tool_in_allowlist(&agent.allowed_tools, "knowledge", None) { PermissionEvalResult::Allow } else { PermissionEvalResult::Ask diff --git a/crates/chat-cli/src/cli/chat/tools/mod.rs b/crates/chat-cli/src/cli/chat/tools/mod.rs index 6b8baec18f..9b90e1d052 100644 --- a/crates/chat-cli/src/cli/chat/tools/mod.rs +++ b/crates/chat-cli/src/cli/chat/tools/mod.rs @@ -57,7 +57,7 @@ use crate::cli::agent::{ use crate::cli::chat::line_tracker::FileLineTracker; use crate::os::Os; -pub const DEFAULT_APPROVE: [&str; 1] = ["fs_read"]; +pub const DEFAULT_APPROVE: [&str; 0] = []; pub const NATIVE_TOOLS: [&str; 8] = [ "fs_read", "fs_write", @@ -187,6 +187,16 @@ impl Tool { _ => None, } } + + /// Returns the tool's summary if available + pub fn get_summary(&self) -> Option { + match self { + Tool::FsWrite(fs_write) => fs_write.get_summary().cloned(), + Tool::ExecuteCommand(execute_cmd) => execute_cmd.summary.clone(), + Tool::FsRead(fs_read) => fs_read.summary.clone(), + _ => None, + } + } } /// A tool specification to be sent to the model as part of a conversation. Maps to @@ -271,6 +281,7 @@ pub struct QueuedTool { pub name: String, pub accepted: bool, pub tool: Tool, + pub tool_input: serde_json::Value, } /// The schema specification describing a tool's fields. diff --git a/crates/chat-cli/src/cli/chat/tools/use_aws.rs b/crates/chat-cli/src/cli/chat/tools/use_aws.rs index 456510b5bf..b7390744cd 100644 --- a/crates/chat-cli/src/cli/chat/tools/use_aws.rs +++ b/crates/chat-cli/src/cli/chat/tools/use_aws.rs @@ -29,7 +29,7 @@ use crate::cli::agent::{ PermissionEvalResult, }; use crate::os::Os; -use crate::util::pattern_matching::matches_any_pattern; +use crate::util::tool_permission_checker::is_tool_in_allowlist; const READONLY_OPS: [&str; 6] = ["get", "describe", "list", "ls", "search", "batch_get"]; @@ -182,10 +182,12 @@ impl UseAws { allowed_services: Vec, #[serde(default)] denied_services: Vec, + #[serde(default)] + auto_allow_readonly: bool, } let Self { service_name, .. } = self; - let is_in_allowlist = matches_any_pattern(&agent.allowed_tools, "use_aws"); + let is_in_allowlist = is_tool_in_allowlist(&agent.allowed_tools, "use_aws", None); match agent.tools_settings.get("use_aws") { Some(settings) => { let settings = match serde_json::from_value::(settings.clone()) { @@ -201,15 +203,16 @@ impl UseAws { if is_in_allowlist || settings.allowed_services.contains(service_name) { return PermissionEvalResult::Allow; } + // Check auto_allow_readonly setting for read-only operations + if settings.auto_allow_readonly && !self.requires_acceptance() { + return PermissionEvalResult::Allow; + } PermissionEvalResult::Ask }, None if is_in_allowlist => PermissionEvalResult::Allow, _ => { - if self.requires_acceptance() { - PermissionEvalResult::Ask - } else { - PermissionEvalResult::Allow - } + // Default behavior: always ask for confirmation (no auto-approval for read-only) + PermissionEvalResult::Ask }, } } @@ -390,4 +393,116 @@ mod tests { let res = cmd_one.eval_perm(&os, &agent); assert!(matches!(res, PermissionEvalResult::Deny(ref services) if services.contains(&"s3".to_string()))); } + + #[tokio::test] + async fn test_eval_perm_auto_allow_readonly_default() { + let os = Os::new().await.unwrap(); + + // Test read-only operation with default settings (auto_allow_readonly = false) + let readonly_cmd = use_aws! {{ + "service_name": "s3", + "operation_name": "list-objects", + "region": "us-west-2", + "profile_name": "default", + "label": "" + }}; + + let agent = Agent::default(); + let res = readonly_cmd.eval_perm(&os, &agent); + // Should ask for confirmation even for read-only operations by default + assert!(matches!(res, PermissionEvalResult::Ask)); + + // Test write operation with default settings + let write_cmd = use_aws! {{ + "service_name": "s3", + "operation_name": "put-object", + "region": "us-west-2", + "profile_name": "default", + "label": "" + }}; + + let res = write_cmd.eval_perm(&os, &agent); + // Should ask for confirmation for write operations + assert!(matches!(res, PermissionEvalResult::Ask)); + } + + #[tokio::test] + async fn test_eval_perm_auto_allow_readonly_enabled() { + let os = Os::new().await.unwrap(); + + let agent = Agent { + name: "test_agent".to_string(), + tools_settings: { + let mut map = HashMap::::new(); + map.insert( + ToolSettingTarget("use_aws".to_string()), + serde_json::json!({ + "autoAllowReadonly": true + }), + ); + map + }, + ..Default::default() + }; + + // Test read-only operation with auto_allow_readonly = true + let readonly_cmd = use_aws! {{ + "service_name": "s3", + "operation_name": "list-objects", + "region": "us-west-2", + "profile_name": "default", + "label": "" + }}; + + let res = readonly_cmd.eval_perm(&os, &agent); + // Should allow read-only operations without confirmation + assert!(matches!(res, PermissionEvalResult::Allow)); + + // Test write operation with auto_allow_readonly = true + let write_cmd = use_aws! {{ + "service_name": "s3", + "operation_name": "put-object", + "region": "us-west-2", + "profile_name": "default", + "label": "" + }}; + + let res = write_cmd.eval_perm(&os, &agent); + // Should still ask for confirmation for write operations + assert!(matches!(res, PermissionEvalResult::Ask)); + } + + #[tokio::test] + async fn test_eval_perm_auto_allow_readonly_with_denied_services() { + let os = Os::new().await.unwrap(); + + let agent = Agent { + name: "test_agent".to_string(), + tools_settings: { + let mut map = HashMap::::new(); + map.insert( + ToolSettingTarget("use_aws".to_string()), + serde_json::json!({ + "autoAllowReadonly": true, + "deniedServices": ["s3"] + }), + ); + map + }, + ..Default::default() + }; + + // Test read-only operation on denied service + let readonly_cmd = use_aws! {{ + "service_name": "s3", + "operation_name": "list-objects", + "region": "us-west-2", + "profile_name": "default", + "label": "" + }}; + + let res = readonly_cmd.eval_perm(&os, &agent); + // Should deny even read-only operations on denied services + assert!(matches!(res, PermissionEvalResult::Deny(ref services) if services.contains(&"s3".to_string()))); + } } diff --git a/crates/chat-cli/src/cli/feed.json b/crates/chat-cli/src/cli/feed.json index 439e455621..53c38195c2 100644 --- a/crates/chat-cli/src/cli/feed.json +++ b/crates/chat-cli/src/cli/feed.json @@ -10,6 +10,394 @@ "hidden": true, "changes": [] }, + { + "type": "release", + "date": "2025-09-26", + "version": "1.16.3", + "title": "Version 1.16.3", + "changes": [ + { + "type": "added", + "description": "[Experimental] Adds checkpointing functionality using Git CLI commands - [#2896](https://github.com/aws/amazon-q-developer-cli/pull/2896)" + }, + { + "type": "fixed", + "description": "Validation issues with MCP tool arguments - [#2986](https://github.com/aws/amazon-q-developer-cli/pull/2986)" + }, + { + "type": "added", + "description": "[Experimental] Add context usage percentage indicator to prompt - [#2994](https://github.com/aws/amazon-q-developer-cli/pull/2994)" + }, + { + "type": "added", + "description": "Support for custom prompts, see `/prompts` - [#2799](https://github.com/aws/amazon-q-developer-cli/pull/2799)" + }, + { + "type": "fixed", + "description": "Various issues with MCP OAuth - [#2976](https://github.com/aws/amazon-q-developer-cli/pull/2976)" + }, + { + "type": "fixed", + "description": "Improve error messages for dispatch failures - [#2969](https://github.com/aws/amazon-q-developer-cli/pull/2969)" + } + ] + }, + { + "type": "release", + "date": "2025-09-19", + "version": "1.16.2", + "title": "Version 1.16.2", + "changes": [ + { + "type": "added", + "description": "Add support for preToolUse and postToolUse hook - [#2875](https://github.com/aws/amazon-q-developer-cli/pull/2875)" + }, + { + "type": "added", + "description": "Support for specifying oauth scopes via config - [#2925]( https://github.com/aws/amazon-q-developer-cli/pull/2925)" + }, + { + "type": "fixed", + "description": "Support for headers ingestion for remote mcp - [#2925]( https://github.com/aws/amazon-q-developer-cli/pull/2925)" + }, + { + "type": "added", + "description": "Change autocomplete shortcut from ctrl-f to ctrl-g - [#2634](https://github.com/aws/amazon-q-developer-cli/pull/2634)" + }, + { + "type": "fixed", + "description": "Fix file-path expansion in mcp-config - [#2915]( https://github.com/aws/amazon-q-developer-cli/pull/2915)" + }, + { + "type": "fixed", + "description": "Fix filepath expansion to use absolute paths - [#2933](https://github.com/aws/amazon-q-developer-cli/pull/2933)" + } + ] + }, + { + "type": "release", + "date": "2025-09-17", + "version": "1.16.1", + "title": "Version 1.16.1", + "changes": [ + { + "type": "fixed", + "description": "Dashboard not updating after logging in - [#688](https://github.com/aws/amazon-q-developer-cli-autocomplete/pull/688)" + } + ] + }, + { + "type": "release", + "date": "2025-09-16", + "version": "1.16.0", + "title": "Version 1.16.0", + "changes": [ + { + "type": "added", + "description": "Support for remote MCP connections - [#2836](https://github.com/aws/amazon-q-developer-cli/pull/2836)" + }, + { + "type": "added", + "description": "A new `/tangent tail` command to preserve the last tangent conversation - [#2838](https://github.com/aws/amazon-q-developer-cli/pull/2838)" + }, + { + "type": "added", + "description": "A new edit subcommand to `/agent` slash command for modifying existing agents - [#2854](https://github.com/aws/amazon-q-developer-cli/pull/2854)" + }, + { + "type": "added", + "description": "A new auto-announcement feature with `/changelog` command - [#2833](https://github.com/aws/amazon-q-developer-cli/pull/2833)" + }, + { + "type": "added", + "description": "A new CLI history persistence feature with file storage - [#2769](https://github.com/aws/amazon-q-developer-cli/pull/2769)" + }, + { + "type": "added", + "description": "Support for comma-containing arguments in MCP --args parameter - [#2754](https://github.com/aws/amazon-q-developer-cli/pull/2754)" + }, + { + "type": "added", + "description": "Support for configurable autoAllowReadonly setting in use_aws tool - [#2828](https://github.com/aws/amazon-q-developer-cli/pull/2828)" + }, + { + "type": "added", + "description": "Support for configurable line wrapping in chat interface - [#2816](https://github.com/aws/amazon-q-developer-cli/pull/2816)" + }, + { + "type": "added", + "description": "Support for model field in agent configuration format - [#2815](https://github.com/aws/amazon-q-developer-cli/pull/2815)" + }, + { + "type": "added", + "description": "AGENTS.md documentation to default agent resources - [#2812](https://github.com/aws/amazon-q-developer-cli/pull/2812)" + }, + { + "type": "security", + "description": "Reduced default fs_read trust permission to current working directory only - [#2824](https://github.com/aws/amazon-q-developer-cli/pull/2824)" + }, + { + "type": "security", + "description": "Changed autoAllowReadonly default to false for security in execute_bash - [#2846](https://github.com/aws/amazon-q-developer-cli/pull/2846)" + }, + { + "type": "security", + "description": "Updated dangerous patterns for execute_bash to include $ character - [#2811](https://github.com/aws/amazon-q-developer-cli/pull/2811)" + }, + { + "type": "fixed", + "description": "Path with trailing slash not being handled in file matching - [#2817](https://github.com/aws/amazon-q-developer-cli/pull/2817)" + }, + { + "type": "fixed", + "description": "Summary being erroneously preserved when conversation is cleared - [#2793](https://github.com/aws/amazon-q-developer-cli/pull/2793)" + } + ] + }, + { + "type": "release", + "date": "2025-09-02", + "version": "1.15.0", + "title": "Version 1.15.0", + "changes": [ + { + "type": "added", + "description": "A new command `/experiment` for toggling experimental features - [#2711](https://github.com/aws/amazon-q-developer-cli/pull/2711)" + }, + { + "type": "added", + "description": "A new command `/agent generate` for generating agent config with Q - [#2690](https://github.com/aws/amazon-q-developer-cli/pull/2690)" + }, + { + "type": "added", + "description": "A new command `/tangent` for going on a tangent without context pollution - [#2634](https://github.com/aws/amazon-q-developer-cli/pull/2634)" + }, + { + "type": "added", + "description": "A new to-do list tool for handling complex multi-step prompts - [#2533](https://github.com/aws/amazon-q-developer-cli/pull/2533)" + }, + { + "type": "added", + "description": "Agent-scoped knowledge base and context-specific search - [#2647](https://github.com/aws/amazon-q-developer-cli/pull/2647)" + }, + { + "type": "added", + "description": "A new tool `introspect` that allows Q CLI to answer questions about itself - [#2677](https://github.com/aws/amazon-q-developer-cli/pull/2677)" + } + ] + }, + { + "type": "release", + "date": "2025-08-21", + "version": "1.14.1", + "title": "Version 1.14.1", + "changes": [ + { + "type": "fixed", + "description": "Tool permission issue in agent - [#2619](https://github.com/aws/amazon-q-developer-cli/pull/2619)" + }, + { + "type": "added", + "description": "MCP admin-level configuration with GetProfile - [#2639](https://github.com/aws/amazon-q-developer-cli/pull/2639)" + }, + { + "type": "added", + "description": "Wildcard pattern matching support for agent allowedTools - [#2612](https://github.com/aws/amazon-q-developer-cli/pull/2612)" + }, + { + "type": "added", + "description": "Agent hot swap capability - [#2637](https://github.com/aws/amazon-q-developer-cli/pull/2637)" + }, + { + "type": "fixed", + "description": "Agent default profile printing issue in `use_aws`, plus minor doc updates - [#2617](https://github.com/aws/amazon-q-developer-cli/pull/2617)" + }, + { + "type": "changed", + "description": "Knowledge beta improvements (phase 2): Refactored async_client and added BM25 support - [#2608](https://github.com/aws/amazon-q-developer-cli/pull/2608)" + } + ] + }, + { + "type": "release", + "date": "2025-08-15", + "version": "1.14.0", + "title": "Version 1.14.0", + "changes": [ + { + "type": "added", + "description": "Additional supported models in `q chat`, see `/model` - [#2419](https://github.com/aws/amazon-q-developer-cli/pull/2419)" + }, + { + "type": "added", + "description": "`--include` and `--exclude` flags for the `/knowledge add` command - [#2545](https://github.com/aws/amazon-q-developer-cli/pull/2545)" + }, + { + "type": "added", + "description": "Notifications on API retries - [#2607](https://github.com/aws/amazon-q-developer-cli/pull/2607)" + } + ] + }, + { + "type": "release", + "date": "2025-08-11", + "version": "1.13.3", + "title": "Version 1.13.3", + "changes": [ + { + "type": "added", + "description": "Support for setting denied shell commands with `toolsSettings.execute_bash.deniedCommands` - [#2512](https://github.com/aws/amazon-q-developer-cli/pull/2512)" + }, + { + "type": "added", + "description": "Support for setting denied AWS services with `toolsSettings.use_aws.deniedServices` - [#2512](https://github.com/aws/amazon-q-developer-cli/pull/2512)" + }, + { + "type": "added", + "description": "Support for setting denied file paths for `fs_read` and `fs_write` using `deniedPaths` in `toolsSettings` - [#2512](https://github.com/aws/amazon-q-developer-cli/pull/2512)" + }, + { + "type": "added", + "description": "Support for setting environment variables in MCP config - [#2241](https://github.com/aws/amazon-q-developer-cli/pull/2241)" + }, + { + "type": "fixed", + "description": "`q mcp add` from failing when the targeted `mcp.json` file does not exist - [#2561](https://github.com/aws/amazon-q-developer-cli/pull/2561)" + } + ] + }, + { + "type": "release", + "date": "2025-08-08", + "version": "1.13.2", + "title": "Version 1.13.2", + "changes": [ + { + "type": "added", + "description": "Regex matching for the `toolsSettings.execute_bash.allowedCommands` agent configuration - [#2483](https://github.com/aws/amazon-q-developer-cli/pull/2483)" + }, + { + "type": "added", + "description": "Support for workspace `mcp.json` configuration when `useLegacyMcpJson` is enabled - [#2516](https://github.com/aws/amazon-q-developer-cli/pull/2516)" + }, + { + "type": "added", + "description": "`/context show` to differentiate between agent and session context - [#2494](https://github.com/aws/amazon-q-developer-cli/pull/2494)" + }, + { + "type": "fixed", + "description": "Issues with `q mcp` subcommands failing - [#2475](https://github.com/aws/amazon-q-developer-cli/pull/2475)" + }, + { + "type": "fixed", + "description": "The `knowledge` tool always requiring permission - [#2501](https://github.com/aws/amazon-q-developer-cli/pull/2501)" + } + ] + }, + { + "type": "release", + "date": "2025-08-01", + "version": "1.13.1", + "title": "Version 1.13.1", + "changes": [ + { + "type": "added", + "description": "JSON schema support for the agent specification. Try it with `/agent create` - [#2440](https://github.com/aws/amazon-q-developer-cli/pull/2440)" + }, + { + "type": "deprecated", + "description": "The `/profile` command - [#2468](https://github.com/aws/amazon-q-developer-cli/pull/2468)" + }, + { + "type": "fixed", + "description": "Tool permissioning not being reset - [#2469](https://github.com/aws/amazon-q-developer-cli/pull/2469)" + }, + { + "type": "fixed", + "description": "An issue with history compaction not being applied on context overflow" + } + ] + }, + { + "type": "release", + "date": "2025-07-31", + "version": "1.13.0", + "title": "Version 1.13.0", + "changes": [ + { + "type": "added", + "description": "A new paradigm for working with `q chat` using agents. [See the documentation for more details](https://github.com/aws/amazon-q-developer-cli/blob/main/docs/SUMMARY.md)" + }, + { + "type": "added", + "description": "A new setting to disable markdown rendering in `qchat` with `chat.disableMarkdownRendering` - [#2223](https://github.com/aws/amazon-q-developer-cli/pull/2223)" + }, + { + "type": "added", + "description": "A new setting to disable markdown rendering in `qchat` with `chat.disableMarkdownRendering` - [#2236](https://github.com/aws/amazon-q-developer-cli/pull/2236)" + }, + { + "type": "fixed", + "description": "An issue with `/compact` failing for large initial messages - [#2375](https://github.com/aws/amazon-q-developer-cli/pull/2375)" + }, + { + "type": "fixed", + "description": "Images being removed from the conversation history - [#2333](https://github.com/aws/amazon-q-developer-cli/pull/2333)" + }, + { + "type": "fixed", + "description": "Code block detection for multi-line input - [#2384](https://github.com/aws/amazon-q-developer-cli/pull/2384)" + } + ] + }, + { + "type": "release", + "date": "2025-07-22", + "version": "1.12.7", + "title": "Version 1.12.7", + "changes": [ + { + "type": "fixed", + "description": "Issues with `q chat` requests not being cached correctly - [#461](https://github.com/aws/amazon-q-developer-cli-autocomplete/pull/461)" + } + ] + }, + { + "type": "release", + "date": "2025-07-17", + "version": "1.12.6", + "title": "Version 1.12.6", + "changes": [ + { + "type": "fixed", + "description": "Issues with read-only commands with the `execute_bash` tool - [#444](https://github.com/aws/amazon-q-developer-cli-autocomplete/pull/444)" + } + ] + }, + { + "type": "release", + "date": "2025-07-14", + "version": "1.12.5", + "title": "Version 1.12.5", + "changes": [ + { + "type": "added", + "description": "(Experimental) Support for Sigv4 authentication with `q chat`. Launch chat with the environment variable `AMAZON_Q_SIGV4=1` - [#207](https://github.com/aws/amazon-q-developer-cli-autocomplete/pull/207)" + }, + { + "type": "fixed", + "description": "An issue with authentication failing for long chat sessions - [#424](https://github.com/aws/amazon-q-developer-cli-autocomplete/pull/424)" + }, + { + "type": "fixed", + "description": "Issues with parsing `/compact` and `/editor` arguments - [#425](https://github.com/aws/amazon-q-developer-cli-autocomplete/pull/425)" + }, + { + "type": "changed", + "description": "`q chat` inline prompt hints to be disabled by default. To enable, run `q settings chat.enableHistoryHints true` - [#429](https://github.com/aws/amazon-q-developer-cli-autocomplete/pull/429)" + } + ] + }, { "type": "release", "date": "2025-07-09", diff --git a/crates/chat-cli/src/cli/mcp.rs b/crates/chat-cli/src/cli/mcp.rs index c70951a9b5..f0e8b97886 100644 --- a/crates/chat-cli/src/cli/mcp.rs +++ b/crates/chat-cli/src/cli/mcp.rs @@ -96,8 +96,11 @@ pub struct AddArgs { /// The command used to launch the server #[arg(long)] pub command: String, - /// Arguments to pass to the command - #[arg(long, action = ArgAction::Append, allow_hyphen_values = true, value_delimiter = ',')] + /// Arguments to pass to the command. Can be provided as: + /// 1. Multiple --args flags: --args arg1 --args arg2 --args "arg,with,commas" + /// 2. Comma-separated with escaping: --args "arg1,arg2,arg\,with\,commas" + /// 3. JSON array format: --args '["arg1", "arg2", "arg,with,commas"]' + #[arg(long, action = ArgAction::Append, allow_hyphen_values = true)] pub args: Vec, /// Where to add the server to. If an agent name is not supplied, the changes shall be made to /// the global mcp.json @@ -119,6 +122,9 @@ pub struct AddArgs { impl AddArgs { pub async fn execute(self, os: &Os, output: &mut impl Write) -> Result<()> { + // Process args to handle comma-separated values, escaping, and JSON arrays + let processed_args = self.process_args()?; + match self.agent.as_deref() { Some(agent_name) => { let (mut agent, config_path) = Agent::get_agent_by_name(os, agent_name).await?; @@ -136,7 +142,7 @@ impl AddArgs { let merged_env = self.env.into_iter().flatten().collect::>(); let tool: CustomToolConfig = serde_json::from_value(serde_json::json!({ "command": self.command, - "args": self.args, + "args": processed_args, "env": merged_env, "timeout": self.timeout.unwrap_or(default_timeout()), "disabled": self.disabled, @@ -169,7 +175,7 @@ impl AddArgs { let merged_env = self.env.into_iter().flatten().collect::>(); let tool: CustomToolConfig = serde_json::from_value(serde_json::json!({ "command": self.command, - "args": self.args, + "args": processed_args, "env": merged_env, "timeout": self.timeout.unwrap_or(default_timeout()), "disabled": self.disabled, @@ -188,6 +194,17 @@ impl AddArgs { Ok(()) } + + fn process_args(&self) -> Result> { + let mut processed_args = Vec::new(); + + for arg in &self.args { + let parsed = parse_args(arg)?; + processed_args.extend(parsed); + } + + Ok(processed_args) + } } #[derive(Debug, Clone, PartialEq, Eq, Args)] @@ -389,7 +406,7 @@ impl StatusArgs { style::Print(format!("Disabled: {}\n", cfg.disabled)), style::Print(format!( "Env Vars: {}\n", - cfg.env.as_ref().map_or_else( + cfg.env.map_or_else( || "(none)".into(), |e| e .iter() @@ -507,6 +524,65 @@ fn parse_env_vars(arg: &str) -> Result> { Ok(vars) } +fn parse_args(arg: &str) -> Result> { + // Try to parse as JSON array first + if arg.trim_start().starts_with('[') { + match serde_json::from_str::>(arg) { + Ok(args) => return Ok(args), + Err(_) => { + bail!( + "Failed to parse arguments as JSON array. Expected format: '[\"arg1\", \"arg2\", \"arg,with,commas\"]'" + ); + }, + } + } + + // Check if the string contains escaped commas + let has_escaped_commas = arg.contains("\\,"); + + if has_escaped_commas { + // Parse with escape support + let mut args = Vec::new(); + let mut current_arg = String::new(); + let mut chars = arg.chars().peekable(); + + while let Some(ch) = chars.next() { + match ch { + '\\' => { + // Handle escape sequences + if let Some(&next_ch) = chars.peek() { + if next_ch == ',' || next_ch == '\\' { + current_arg.push(chars.next().unwrap()); + } else { + current_arg.push(ch); + } + } else { + current_arg.push(ch); + } + }, + ',' => { + // Split on unescaped comma + args.push(current_arg.trim().to_string()); + current_arg.clear(); + }, + _ => { + current_arg.push(ch); + }, + } + } + + // Add the last argument + if !current_arg.is_empty() || !args.is_empty() { + args.push(current_arg.trim().to_string()); + } + + Ok(args) + } else { + // Default behavior: split on commas (backward compatibility) + Ok(arg.split(',').map(|s| s.trim().to_string()).collect()) + } +} + async fn load_cfg(os: &Os, p: &PathBuf) -> Result { Ok(if os.fs.exists(p) { McpServerConfig::load_from_file(os, p).await? @@ -618,11 +694,7 @@ mod tests { name: "test_server".to_string(), scope: None, command: "test_command".to_string(), - args: vec![ - "awslabs.eks-mcp-server".to_string(), - "--allow-write".to_string(), - "--allow-sensitive-data-access".to_string(), - ], + args: vec!["awslabs.eks-mcp-server,--allow-write,--allow-sensitive-data-access".to_string(),], agent: None, env: vec![ [ @@ -680,4 +752,46 @@ mod tests { })) ); } + + #[test] + fn test_parse_args_comma_separated() { + let result = parse_args("arg1,arg2,arg3").unwrap(); + assert_eq!(result, vec!["arg1", "arg2", "arg3"]); + } + + #[test] + fn test_parse_args_with_escaped_commas() { + let result = parse_args("arg1,arg2\\,with\\,commas,arg3").unwrap(); + assert_eq!(result, vec!["arg1", "arg2,with,commas", "arg3"]); + } + + #[test] + fn test_parse_args_json_array() { + let result = parse_args(r#"["arg1", "arg2", "arg,with,commas"]"#).unwrap(); + assert_eq!(result, vec!["arg1", "arg2", "arg,with,commas"]); + } + + #[test] + fn test_parse_args_single_arg_with_commas() { + let result = parse_args("--config=key1=val1\\,key2=val2").unwrap(); + assert_eq!(result, vec!["--config=key1=val1,key2=val2"]); + } + + #[test] + fn test_parse_args_backward_compatibility() { + let result = parse_args("--config=key1=val1,key2=val2").unwrap(); + assert_eq!(result, vec!["--config=key1=val1", "key2=val2"]); + } + + #[test] + fn test_parse_args_mixed_escaping() { + let result = parse_args("normal,escaped\\,comma,--flag=val1\\,val2").unwrap(); + assert_eq!(result, vec!["normal", "escaped,comma", "--flag=val1,val2"]); + } + + #[test] + fn test_parse_args_json_array_invalid() { + let result = parse_args(r#"["invalid json"#); + assert!(result.is_err()); + } } diff --git a/crates/chat-cli/src/cli/mod.rs b/crates/chat-cli/src/cli/mod.rs index c2185e36be..1e384cf63e 100644 --- a/crates/chat-cli/src/cli/mod.rs +++ b/crates/chat-cli/src/cli/mod.rs @@ -1,8 +1,8 @@ mod agent; -mod chat; +pub mod chat; mod debug; mod diagnostics; -mod feed; +pub mod feed; mod issue; mod mcp; mod settings; @@ -144,6 +144,11 @@ impl RootSubcommand { ); } + // Daily heartbeat check + if os.database.should_send_heartbeat() && os.telemetry.send_daily_heartbeat().is_ok() { + os.database.record_heartbeat_sent().ok(); + } + // Send executed telemetry. if self.valid_for_telemetry() { os.telemetry @@ -331,6 +336,12 @@ impl Cli { #[cfg(test)] mod test { + use chat::WrapMode::{ + Always, + Auto, + Never, + }; + use super::*; use crate::util::CHAT_BINARY_NAME; use crate::util::test::assert_parse; @@ -370,6 +381,7 @@ mod test { trust_all_tools: false, trust_tools: None, no_interactive: false, + wrap: None, })), verbose: 2, help_all: false, @@ -409,6 +421,7 @@ mod test { trust_all_tools: false, trust_tools: None, no_interactive: false, + wrap: None, }) ); } @@ -425,6 +438,7 @@ mod test { trust_all_tools: false, trust_tools: None, no_interactive: false, + wrap: None, }) ); } @@ -441,6 +455,7 @@ mod test { trust_all_tools: true, trust_tools: None, no_interactive: false, + wrap: None, }) ); } @@ -457,6 +472,7 @@ mod test { trust_all_tools: false, trust_tools: None, no_interactive: true, + wrap: None, }) ); assert_parse!( @@ -469,6 +485,7 @@ mod test { trust_all_tools: false, trust_tools: None, no_interactive: true, + wrap: None, }) ); } @@ -485,6 +502,7 @@ mod test { trust_all_tools: true, trust_tools: None, no_interactive: false, + wrap: None, }) ); } @@ -501,6 +519,7 @@ mod test { trust_all_tools: false, trust_tools: Some(vec!["".to_string()]), no_interactive: false, + wrap: None, }) ); } @@ -517,6 +536,50 @@ mod test { trust_all_tools: false, trust_tools: Some(vec!["fs_read".to_string(), "fs_write".to_string()]), no_interactive: false, + wrap: None, + }) + ); + } + + #[test] + fn test_chat_with_different_wrap_modes() { + assert_parse!( + ["chat", "-w", "never"], + RootSubcommand::Chat(ChatArgs { + resume: false, + input: None, + agent: None, + model: None, + trust_all_tools: false, + trust_tools: None, + no_interactive: false, + wrap: Some(Never), + }) + ); + assert_parse!( + ["chat", "--wrap", "always"], + RootSubcommand::Chat(ChatArgs { + resume: false, + input: None, + agent: None, + model: None, + trust_all_tools: false, + trust_tools: None, + no_interactive: false, + wrap: Some(Always), + }) + ); + assert_parse!( + ["chat", "--wrap", "auto"], + RootSubcommand::Chat(ChatArgs { + resume: false, + input: None, + agent: None, + model: None, + trust_all_tools: false, + trust_tools: None, + no_interactive: false, + wrap: Some(Auto), }) ); } diff --git a/crates/chat-cli/src/cli/user.rs b/crates/chat-cli/src/cli/user.rs index 50e761b9b0..9394746c29 100644 --- a/crates/chat-cli/src/cli/user.rs +++ b/crates/chat-cli/src/cli/user.rs @@ -118,7 +118,7 @@ impl LoginArgs { }; let start_url = input("Enter Start URL", default_start_url.as_deref())?; - let region = input("Enter Region", default_region.as_deref())?; + let region = input("Enter Region", default_region.as_deref())?.trim().to_string(); let _ = os.database.set_start_url(start_url.clone()); let _ = os.database.set_idc_region(region.clone()); diff --git a/crates/chat-cli/src/database/mod.rs b/crates/chat-cli/src/database/mod.rs index 9b5a48ee10..80c667b8a8 100644 --- a/crates/chat-cli/src/database/mod.rs +++ b/crates/chat-cli/src/database/mod.rs @@ -61,6 +61,7 @@ const IDC_REGION_KEY: &str = "auth.idc.region"; // We include this key to remove for backwards compatibility const CUSTOMIZATION_STATE_KEY: &str = "api.selectedCustomization"; const PROFILE_MIGRATION_KEY: &str = "profile.Migrated"; +const HEARTBEAT_DATE_KEY: &str = "telemetry.lastHeartbeatDate"; const MIGRATIONS: &[Migration] = migrations![ "000_migration_table", @@ -274,6 +275,28 @@ impl Database { .and_then(|s| Uuid::from_str(&s).ok())) } + /// Get changelog last version from state table + pub fn get_changelog_last_version(&self) -> Result, DatabaseError> { + self.get_entry::(Table::State, "changelog.lastVersion") + } + + /// Set changelog last version in state table + pub fn set_changelog_last_version(&self, version: &str) -> Result<(), DatabaseError> { + self.set_entry(Table::State, "changelog.lastVersion", version)?; + Ok(()) + } + + /// Get changelog show count from state table + pub fn get_changelog_show_count(&self) -> Result, DatabaseError> { + self.get_entry::(Table::State, "changelog.showCount") + } + + /// Set changelog show count in state table + pub fn set_changelog_show_count(&self, count: i64) -> Result<(), DatabaseError> { + self.set_entry(Table::State, "changelog.showCount", count)?; + Ok(()) + } + /// Set the client ID used for telemetry requests. pub fn set_client_id(&mut self, client_id: Uuid) -> Result { self.set_json_entry(Table::State, CLIENT_ID_KEY, client_id.to_string()) @@ -311,6 +334,26 @@ impl Database { self.set_entry(Table::State, PROFILE_MIGRATION_KEY, true) } + /// Check if daily heartbeat should be sent + pub fn should_send_heartbeat(&self) -> bool { + use chrono::Utc; + let today = Utc::now().format("%Y-%m-%d").to_string(); + + match self.get_entry::(Table::State, HEARTBEAT_DATE_KEY) { + Ok(Some(last_date)) => last_date != today, + Ok(None) => true, // First time - definitely send + Err(_) => false, // Database error - don't send (might have already sent) + } + } + + /// Record that heartbeat was sent today + pub fn record_heartbeat_sent(&self) -> Result<(), DatabaseError> { + use chrono::Utc; + let today = Utc::now().format("%Y-%m-%d").to_string(); + self.set_entry(Table::State, HEARTBEAT_DATE_KEY, today)?; + Ok(()) + } + // /// Get the model id used for last conversation state. // pub fn get_last_used_model_id(&self) -> Result, DatabaseError> { // self.get_json_entry::(Table::State, LAST_USED_MODEL_ID) diff --git a/crates/chat-cli/src/database/settings.rs b/crates/chat-cli/src/database/settings.rs index 21e8e98097..9f440f3ab4 100644 --- a/crates/chat-cli/src/database/settings.rs +++ b/crates/chat-cli/src/database/settings.rs @@ -41,6 +41,8 @@ pub enum Setting { KnowledgeIndexType, #[strum(message = "Key binding for fuzzy search command (single character)")] SkimCommandKey, + #[strum(message = "Key binding for autocompletion hint acceptance (single character)")] + AutocompletionKey, #[strum(message = "Enable tangent mode feature (boolean)")] EnabledTangentMode, #[strum(message = "Key binding for tangent mode toggle (single character)")] @@ -65,6 +67,8 @@ pub enum Setting { McpNoInteractiveTimeout, #[strum(message = "Track previously loaded MCP servers (boolean)")] McpLoadedBefore, + #[strum(message = "Show context usage percentage in prompt (boolean)")] + EnabledContextUsageIndicator, #[strum(message = "Default AI model for conversations (string)")] ChatDefaultModel, #[strum(message = "Disable markdown formatting in chat (boolean)")] @@ -77,6 +81,8 @@ pub enum Setting { ChatEnableHistoryHints, #[strum(message = "Enable the todo list feature (boolean)")] EnabledTodoList, + #[strum(message = "Enable the checkpoint feature (boolean)")] + EnabledCheckpoint, } impl AsRef for Setting { @@ -94,6 +100,7 @@ impl AsRef for Setting { Self::KnowledgeChunkOverlap => "knowledge.chunkOverlap", Self::KnowledgeIndexType => "knowledge.indexType", Self::SkimCommandKey => "chat.skimCommandKey", + Self::AutocompletionKey => "chat.autocompletionKey", Self::EnabledTangentMode => "chat.enableTangentMode", Self::TangentModeKey => "chat.tangentModeKey", Self::IntrospectTangentMode => "introspect.tangentMode", @@ -112,6 +119,8 @@ impl AsRef for Setting { Self::ChatDisableAutoCompaction => "chat.disableAutoCompaction", Self::ChatEnableHistoryHints => "chat.enableHistoryHints", Self::EnabledTodoList => "chat.enableTodoList", + Self::EnabledCheckpoint => "chat.enableCheckpoint", + Self::EnabledContextUsageIndicator => "chat.enableContextUsageIndicator", } } } @@ -139,6 +148,7 @@ impl TryFrom<&str> for Setting { "knowledge.chunkOverlap" => Ok(Self::KnowledgeChunkOverlap), "knowledge.indexType" => Ok(Self::KnowledgeIndexType), "chat.skimCommandKey" => Ok(Self::SkimCommandKey), + "chat.autocompletionKey" => Ok(Self::AutocompletionKey), "chat.enableTangentMode" => Ok(Self::EnabledTangentMode), "chat.tangentModeKey" => Ok(Self::TangentModeKey), "introspect.tangentMode" => Ok(Self::IntrospectTangentMode), @@ -157,6 +167,8 @@ impl TryFrom<&str> for Setting { "chat.disableAutoCompaction" => Ok(Self::ChatDisableAutoCompaction), "chat.enableHistoryHints" => Ok(Self::ChatEnableHistoryHints), "chat.enableTodoList" => Ok(Self::EnabledTodoList), + "chat.enableCheckpoint" => Ok(Self::EnabledCheckpoint), + "chat.enableContextUsageIndicator" => Ok(Self::EnabledContextUsageIndicator), _ => Err(DatabaseError::InvalidSetting(value.to_string())), } } @@ -293,6 +305,7 @@ mod test { .set(Setting::ChatDisableMarkdownRendering, false) .await .unwrap(); + settings.set(Setting::EnabledCheckpoint, true).await.unwrap(); assert_eq!(settings.get(Setting::TelemetryEnabled), Some(&Value::Bool(true))); assert_eq!( @@ -316,6 +329,7 @@ mod test { settings.get(Setting::ChatDisableMarkdownRendering), Some(&Value::Bool(false)) ); + assert_eq!(settings.get(Setting::EnabledCheckpoint), Some(&Value::Bool(true))); settings.remove(Setting::TelemetryEnabled).await.unwrap(); settings.remove(Setting::OldClientId).await.unwrap(); @@ -323,6 +337,7 @@ mod test { settings.remove(Setting::KnowledgeIndexType).await.unwrap(); settings.remove(Setting::McpLoadedBefore).await.unwrap(); settings.remove(Setting::ChatDisableMarkdownRendering).await.unwrap(); + settings.remove(Setting::EnabledCheckpoint).await.unwrap(); assert_eq!(settings.get(Setting::TelemetryEnabled), None); assert_eq!(settings.get(Setting::OldClientId), None); @@ -330,5 +345,6 @@ mod test { assert_eq!(settings.get(Setting::KnowledgeIndexType), None); assert_eq!(settings.get(Setting::McpLoadedBefore), None); assert_eq!(settings.get(Setting::ChatDisableMarkdownRendering), None); + assert_eq!(settings.get(Setting::EnabledCheckpoint), None); } } diff --git a/crates/chat-cli/src/mcp_client/client.rs b/crates/chat-cli/src/mcp_client/client.rs index 27918c5773..44473c3a74 100644 --- a/crates/chat-cli/src/mcp_client/client.rs +++ b/crates/chat-cli/src/mcp_client/client.rs @@ -1,1140 +1,761 @@ +use std::borrow::Cow; use std::collections::HashMap; use std::process::Stdio; -use std::sync::atomic::{ - AtomicBool, - AtomicU64, - Ordering, + +use regex::Regex; +use rmcp::model::{ + CallToolRequestParam, + CallToolResult, + ErrorCode, + GetPromptRequestParam, + GetPromptResult, + Implementation, + InitializeRequestParam, + ListPromptsResult, + ListToolsResult, + LoggingLevel, + LoggingMessageNotificationParam, + PaginatedRequestParam, + ServerNotification, + ServerRequest, }; -use std::sync::{ - Arc, - RwLock as SyncRwLock, +use rmcp::service::{ + ClientInitializeError, + DynService, + NotificationContext, }; -use std::time::Duration; - -use serde::{ - Deserialize, - Serialize, +use rmcp::transport::{ + ConfigureCommandExt, + TokioChildProcess, +}; +use rmcp::{ + ErrorData, + RoleClient, + Service, + ServiceError, + ServiceExt, }; -use thiserror::Error; -use tokio::time; -use tokio::time::error::Elapsed; - -use super::transport::base_protocol::{ - JsonRpcMessage, - JsonRpcNotification, - JsonRpcRequest, - JsonRpcVersion, +use tokio::io::AsyncReadExt as _; +use tokio::process::{ + ChildStderr, + Command, }; -use super::transport::stdio::JsonRpcStdioTransport; -use super::transport::{ - self, - Transport, - TransportError, +use tokio::task::JoinHandle; +use tracing::{ + debug, + error, + info, }; + +use super::messenger::Messenger; +use super::oauth_util::HttpTransport; use super::{ - JsonRpcResponse, - Listener as _, - LogListener, - Messenger, - PaginationSupportedOps, - PromptGet, - PromptsListResult, - ResourceTemplatesListResult, - ResourcesListResult, - ServerCapabilities, - ToolsListResult, + AuthClientWrapper, + OauthUtilError, + get_http_transport, }; -use crate::util::process::{ - Pid, - terminate_process, +use crate::cli::chat::server_messenger::ServerMessenger; +use crate::cli::chat::tools::custom_tool::{ + CustomToolConfig, + TransportType, }; +use crate::os::Os; +use crate::util::directories::DirectoryError; + +/// Fetches all pages of specified resources from a server +macro_rules! paginated_fetch { + ( + final_result_type: $final_result_type:ty, + content_type: $content_type:ty, + service_method: $service_method:ident, + result_field: $result_field:ident, + messenger_method: $messenger_method:ident, + service: $service:expr, + messenger: $messenger:expr, + server_name: $server_name:expr + ) => { + { + let mut cursor = None::; + let mut final_result = Ok(<$final_result_type>::with_all_items(Default::default())); + let mut content = Vec::<$content_type>::new(); -pub type ClientInfo = serde_json::Value; -pub type StdioTransport = JsonRpcStdioTransport; - -/// Represents the capabilities of a client in the Model Context Protocol. -/// This structure is sent to the server during initialization to communicate -/// what features the client supports and provide information about the client. -/// When features are added to the client, these should be declared in the [From] trait implemented -/// for the struct. -#[derive(Default, Debug, Serialize)] -#[serde(rename_all = "camelCase")] -struct ClientCapabilities { - protocol_version: JsonRpcVersion, - capabilities: HashMap, - client_info: serde_json::Value, -} + loop { + let param = Some(PaginatedRequestParam { cursor: cursor.clone() }); + match $service.$service_method(param).await { + Ok(mut result) => { + if let Some(s) = result.next_cursor { + cursor.replace(s); + } + content.append(&mut result.$result_field); + }, + Err(e) => { + final_result = Err(e); + break; + }, + } + if cursor.is_none() { + break; + } + } + + if let Ok(final_result) = &mut final_result { + final_result.$result_field.append(&mut content); + } -impl From for ClientCapabilities { - fn from(client_info: ClientInfo) -> Self { - ClientCapabilities { - client_info, - ..Default::default() + if let Err(e) = $messenger.$messenger_method(final_result, Some($service)).await { + error!(target: "mcp", "Initial {} result failed to send for server {}: {}", + stringify!($result_field), $server_name, e); + } } - } + }; +} + +/// Substitutes environment variables in the format ${env:VAR_NAME} with their actual values +fn substitute_env_vars(input: &str, env: &crate::os::Env) -> String { + // Create a regex to match ${env:VAR_NAME} pattern + let re = Regex::new(r"\$\{env:([^}]+)\}").unwrap(); + + re.replace_all(input, |caps: ®ex::Captures<'_>| { + let var_name = &caps[1]; + env.get(var_name).unwrap_or_else(|_| format!("${{{}}}", var_name)) + }) + .to_string() } -#[derive(Debug, Deserialize)] -pub struct ClientConfig { - pub server_name: String, - pub bin_path: String, - pub args: Vec, - pub timeout: u64, - pub client_info: serde_json::Value, - pub env: Option>, +/// Process a HashMap of environment variables, substituting any ${env:VAR_NAME} patterns +/// with their actual values from the environment +fn process_env_vars(env_vars: &mut HashMap, env: &crate::os::Env) { + for (_, value) in env_vars.iter_mut() { + *value = substitute_env_vars(value, env); + } } -#[allow(dead_code)] -#[derive(Debug, Error)] -pub enum ClientError { +#[derive(Debug, thiserror::Error)] +pub enum McpClientError { #[error(transparent)] - TransportError(#[from] TransportError), + ClientInitializeError(#[from] Box), #[error(transparent)] Io(#[from] std::io::Error), #[error(transparent)] - Serialization(#[from] serde_json::Error), - #[error("Operation timed out: {context}")] - RuntimeError { - #[source] - source: tokio::time::error::Elapsed, - context: String, - }, - #[error("Unexpected msg type encountered")] - UnexpectedMsgType, - #[error("{0}")] - NegotiationError(String), - #[error("Failed to obtain process id")] - MissingProcessId, - #[error("Invalid path received")] - InvalidPath, - #[error("{0}")] - ProcessKillError(String), + JoinError(#[from] tokio::task::JoinError), + #[error("Client has not finished initializing")] + NotReady, + #[error(transparent)] + Directory(#[from] DirectoryError), + #[error(transparent)] + OauthUtil(#[from] OauthUtilError), + #[error(transparent)] + Parse(#[from] url::ParseError), + #[error(transparent)] + Auth(#[from] crate::auth::AuthError), #[error("{0}")] - PoisonError(String), + MalformedConfig(&'static str), + #[error(transparent)] + LookUp(#[from] shellexpand::LookupError), } -impl From<(tokio::time::error::Elapsed, String)> for ClientError { - fn from((error, context): (tokio::time::error::Elapsed, String)) -> Self { - ClientError::RuntimeError { source: error, context } - } +/// Decorates the method passed in with retry logic, but only if the [RunningService] has an +/// instance of [AuthClientDropGuard]. +/// The various methods to interact with the mcp server provided by RMCP supposedly does refresh +/// token once the token expires but that logic would require us to also note down the time at +/// which a token is obtained since the only time related information in the token is the duration +/// for which a token is valid. However, if we do solely rely on the internals of these methods to +/// refresh tokens, we would have no way of knowing when a token is obtained. (Maybe there is a +/// method that would allow us to configure what extra info to include in the token. If you find it, +/// feel free to remove this. That would also enable us to simplify the definition of +/// [RunningService]) +macro_rules! decorate_with_auth_retry { + ($param_type:ty, $method_name:ident, $return_type:ty) => { + pub async fn $method_name(&self, param: $param_type) -> Result<$return_type, rmcp::ServiceError> { + let first_attempt = match &self.inner_service { + InnerService::Original(rs) => rs.$method_name(param.clone()).await, + InnerService::Peer(peer) => peer.$method_name(param.clone()).await, + }; + + match first_attempt { + Ok(result) => Ok(result), + Err(e) => { + // TODO: discern error type prior to retrying + // Not entirely sure what is thrown when auth is required + if let Some(auth_client) = self.auth_client.as_ref() { + let refresh_result = auth_client.refresh_token().await; + match refresh_result { + Ok(_) => { + info!("Token refreshed"); + // Retry the operation after token refresh + match &self.inner_service { + InnerService::Original(rs) => rs.$method_name(param).await, + InnerService::Peer(peer) => peer.$method_name(param).await, + } + }, + Err(_) => { + // If refresh fails, return the original error + // Currently our event loop just does not allow us easy ways to + // reauth entirely once a session starts since this would mean + // swapping of transport (which also means swapping of client) + Err(e) + }, + } + } else { + // No auth client available, return original error + Err(e) + } + }, + } + } + }; } -#[derive(Debug)] -pub struct Client { - server_name: String, - transport: Arc, - timeout: u64, - pub server_process_id: Option, - client_info: serde_json::Value, - current_id: Arc, - pub messenger: Option>, - // TODO: move this to tool manager that way all the assets are treated equally - pub prompt_gets: Arc>>, - pub is_prompts_out_of_date: Arc, +/// Wrapper around rmcp service types to enable cloning. +/// +/// This exists because `rmcp::service::RunningService` is not directly cloneable as it is a +/// pointer type to `Peer`. This enum allows us to hold either the original service or its +/// peer representation, enabling cloning by converting the original service to a peer when needed. +pub enum InnerService { + Original(rmcp::service::RunningService>>), + Peer(rmcp::service::Peer), } -impl Clone for Client { - fn clone(&self) -> Self { - Self { - server_name: self.server_name.clone(), - transport: self.transport.clone(), - timeout: self.timeout, - // Note that we cannot have an id for the clone because we would kill the original - // process when we drop the clone - server_process_id: None, - client_info: self.client_info.clone(), - current_id: self.current_id.clone(), - messenger: None, - prompt_gets: self.prompt_gets.clone(), - is_prompts_out_of_date: self.is_prompts_out_of_date.clone(), +impl std::fmt::Debug for InnerService { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + InnerService::Original(_) => f.debug_tuple("Original").field(&"RunningService<..>").finish(), + InnerService::Peer(peer) => f.debug_tuple("Peer").field(peer).finish(), } } } -impl Client { - pub fn from_config(config: ClientConfig) -> Result { - let ClientConfig { - server_name, - bin_path, - args, - timeout, - client_info, - env, - } = config; - let child = { - let expanded_bin_path = shellexpand::tilde(&bin_path); - - // On Windows, we need to use cmd.exe to run the binary with arguments because Tokio - // always assumes that the program has an .exe extension, which is not the case for - // helpers like `uvx` or `npx`. - let mut command = if cfg!(windows) { - let mut cmd = tokio::process::Command::new("cmd.exe"); - cmd.args(["/C", &Self::build_windows_command(&expanded_bin_path, args)]); - cmd - } else { - let mut cmd = tokio::process::Command::new(expanded_bin_path.to_string()); - cmd.args(args); - cmd - }; - - command - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .envs(std::env::vars()); - - #[cfg(not(windows))] - command.process_group(0); - - if let Some(env) = env { - for (env_name, env_value) in env { - command.env(env_name, env_value); - } - } - - command.spawn()? - }; - - let server_process_id = child.id().ok_or(ClientError::MissingProcessId)?; - let server_process_id = Some(Pid::from_u32(server_process_id)); - - let transport = Arc::new(transport::stdio::JsonRpcStdioTransport::client(child)?); - Ok(Self { - server_name, - transport, - timeout, - server_process_id, - client_info, - current_id: Arc::new(AtomicU64::new(0)), - messenger: None, - prompt_gets: Arc::new(SyncRwLock::new(HashMap::new())), - is_prompts_out_of_date: Arc::new(AtomicBool::new(false)), - }) +impl Clone for InnerService { + fn clone(&self) -> Self { + match self { + InnerService::Original(rs) => InnerService::Peer((*rs).clone()), + InnerService::Peer(peer) => InnerService::Peer(peer.clone()), + } } +} - fn build_windows_command(bin_path: &str, args: Vec) -> String { - let mut parts = Vec::new(); - - // Add the binary path, quoted if necessary - parts.push(Self::quote_windows_arg(bin_path)); +/// A wrapper around MCP (Model Context Protocol) service instances that manages +/// authentication and enables cloning functionality. +/// +/// This struct holds either an original `RunningService` or its peer representation, +/// along with an optional authentication drop guard for managing OAuth tokens. +/// The authentication drop guard handles token lifecycle and cleanup when the +/// service is dropped. +/// +/// # Fields +/// * `inner_service` - The underlying MCP service instance (original or peer) +/// * `auth_dropguard` - Optional authentication manager for OAuth token handling +#[derive(Debug)] +pub struct RunningService { + pub inner_service: InnerService, + auth_client: Option, +} - // Add all arguments, quoted if necessary - for arg in args { - parts.push(Self::quote_windows_arg(&arg)); +impl Clone for RunningService { + fn clone(&self) -> Self { + RunningService { + inner_service: self.inner_service.clone(), + auth_client: self.auth_client.clone(), } - - parts.join(" ") } +} - fn quote_windows_arg(arg: &str) -> String { - // If the argument doesn't need quoting, return as-is - if !arg.chars().any(|c| " \t\n\r\"".contains(c)) { - return arg.to_string(); - } - - let mut result = String::from("\""); - let mut backslashes = 0; - - for c in arg.chars() { - match c { - '\\' => { - backslashes += 1; - result.push('\\'); - }, - '"' => { - // Escape all preceding backslashes and the quote - for _ in 0..backslashes { - result.push('\\'); - } - result.push_str("\\\""); - backslashes = 0; - }, - _ => { - backslashes = 0; - result.push(c); - }, - } - } +impl RunningService { + decorate_with_auth_retry!(CallToolRequestParam, call_tool, CallToolResult); - // Escape trailing backslashes before the closing quote - for _ in 0..backslashes { - result.push('\\'); - } + decorate_with_auth_retry!(GetPromptRequestParam, get_prompt, GetPromptResult); +} - result.push('"'); - result - } +pub type StdioTransport = (TokioChildProcess, Option); + +// TODO: add sse support (even though it's deprecated) +/// Represents the different transport mechanisms available for MCP (Model Context Protocol) +/// communication. +/// +/// This enum encapsulates the two primary ways to communicate with MCP servers: +/// - HTTP-based transport for remote servers +/// - Standard I/O transport for local process-based servers +pub enum Transport { + /// HTTP transport for communicating with remote MCP servers over network protocols. + /// Uses a streamable HTTP client with authentication support. + Http(HttpTransport), + /// Standard I/O transport for communicating with local MCP servers via child processes. + /// Communication happens through stdin/stdout pipes. + Stdio(StdioTransport), } -impl Drop for Client -where - T: Transport, -{ - // IF the servers are implemented well, they will shutdown once the pipe closes. - // This drop trait is here as a fail safe to ensure we don't leave behind any orphans. - fn drop(&mut self) { - if let Some(process_id) = self.server_process_id { - let _ = terminate_process(process_id); - } - if let Some(ref messenger) = self.messenger { - messenger.send_deinit_msg(); +impl std::fmt::Debug for Transport { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Transport::Http(_) => f.debug_tuple("Http").field(&"HttpTransport").finish(), + Transport::Stdio(_) => f.debug_tuple("Stdio").field(&"TokioChildProcess").finish(), } } } -impl Client -where - T: Transport, -{ - /// Exchange of information specified as per https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization - /// - /// Also done are the following: - /// - Spawns task for listening to server driven workflows - /// - Spawns tasks to ask for relevant info such as tools and prompts in accordance to server - /// capabilities received - pub async fn init(&self) -> Result { - let transport_ref = self.transport.clone(); - let server_name = self.server_name.clone(); - - // Spawning a task to listen and log stderr output - tokio::spawn(async move { - let mut log_listener = transport_ref.get_log_listener(); - loop { - match log_listener.recv().await { - Ok(msg) => { - tracing::trace!(target: "mcp", "{server_name} logged {}", msg); - }, - Err(e) => { - tracing::error!( - "Error encountered while reading from stderr for {server_name}: {:?}\nEnding stderr listening task.", - e - ); - break; - }, - } - } - }); +/// This struct implements the [Service] trait from rmcp. It is within this trait the logic of +/// server driven data flow (i.e. requests and notifications that are sent from the server) are +/// handled. +#[derive(Debug)] +pub struct McpClientService { + pub config: CustomToolConfig, + server_name: String, + messenger: ServerMessenger, +} - let init_params = Some({ - let client_cap = ClientCapabilities::from(self.client_info.clone()); - serde_json::json!(client_cap) - }); - let init_resp = self.request("initialize", init_params).await?; - if let Err(e) = examine_server_capabilities(&init_resp) { - return Err(ClientError::NegotiationError(format!( - "Client {} has failed to negotiate server capabilities with server: {:?}", - self.server_name, e - ))); - } - let cap = { - let result = init_resp.result.ok_or(ClientError::NegotiationError(format!( - "Server {} init resp is missing result", - self.server_name - )))?; - let cap = result - .get("capabilities") - .ok_or(ClientError::NegotiationError(format!( - "Server {} init resp result is missing capabilities", - self.server_name - )))? - .clone(); - serde_json::from_value::(cap)? - }; - self.notify("initialized", None).await?; - - // TODO: group this into examine_server_capabilities - // Prefetch prompts in the background. We should only do this after the server has been - // initialized - if cap.prompts.is_some() { - self.is_prompts_out_of_date.store(true, Ordering::Relaxed); - let client_ref = (*self).clone(); - let messenger_ref = self.messenger.as_ref().map(|m| m.duplicate()); - tokio::spawn(async move { - fetch_prompts_and_notify_with_messenger(&client_ref, messenger_ref.as_ref()).await; - }); - } - if cap.tools.is_some() { - let client_ref = (*self).clone(); - let messenger_ref = self.messenger.as_ref().map(|m| m.duplicate()); - tokio::spawn(async move { - fetch_tools_and_notify_with_messenger(&client_ref, messenger_ref.as_ref()).await; - }); +impl McpClientService { + pub fn new(server_name: String, config: CustomToolConfig, messenger: ServerMessenger) -> Self { + Self { + server_name, + config, + messenger, } + } - let transport_ref = self.transport.clone(); - let server_name = self.server_name.clone(); - let messenger_ref = self.messenger.as_ref().map(|m| m.duplicate()); - let client_ref = (*self).clone(); - - let prompts_list_changed_supported = cap.prompts.as_ref().is_some_and(|p| p.get("listChanged").is_some()); - let tools_list_changed_supported = cap.tools.as_ref().is_some_and(|t| t.get("listChanged").is_some()); - tokio::spawn(async move { - let mut listener = transport_ref.get_listener(); - loop { - match listener.recv().await { - Ok(msg) => { - match msg { - JsonRpcMessage::Request(_req) => {}, - JsonRpcMessage::Notification(notif) => { - let JsonRpcNotification { method, params, .. } = notif; - match method.as_str() { - "notifications/message" | "message" => { - let level = params - .as_ref() - .and_then(|p| p.get("level")) - .and_then(|v| serde_json::to_string(v).ok()); - let data = params - .as_ref() - .and_then(|p| p.get("data")) - .and_then(|v| serde_json::to_string(v).ok()); - if let (Some(level), Some(data)) = (level, data) { - match level.to_lowercase().as_str() { - "error" => { - tracing::error!(target: "mcp", "{}: {}", server_name, data); - }, - "warn" => { - tracing::warn!(target: "mcp", "{}: {}", server_name, data); - }, - "info" => { - tracing::info!(target: "mcp", "{}: {}", server_name, data); - }, - "debug" => { - tracing::debug!(target: "mcp", "{}: {}", server_name, data); - }, - "trace" => { - tracing::trace!(target: "mcp", "{}: {}", server_name, data); - }, - _ => {}, - } + pub async fn init(mut self, os: &Os) -> Result { + let os_clone = os.clone(); + + let handle: JoinHandle> = tokio::spawn(async move { + let messenger_clone = self.messenger.clone(); + let server_name = self.server_name.clone(); + let backup_config = self.config.clone(); + + let result: Result<_, McpClientError> = async { + let messenger_dup = messenger_clone.duplicate(); + let (service, stderr, auth_client) = match self.get_transport(&os_clone, &*messenger_dup).await? { + Transport::Stdio((child_process, stderr)) => { + let service = self + .into_dyn() + .serve::(child_process) + .await + .map_err(Box::new)?; + + (service, stderr, None) + }, + Transport::Http(http_transport) => { + match http_transport { + HttpTransport::WithAuth((transport, mut auth_client)) => { + // The crate does not automatically refresh tokens when they expire. We + // would need to handle that here + let url = &backup_config.url; + let service = match self.into_dyn().serve(transport).await.map_err(Box::new) { + Ok(service) => service, + Err(e) if matches!(*e, ClientInitializeError::ConnectionClosed(_)) => { + debug!("## mcp: first hand shake attempt failed: {:?}", e); + let refresh_res = auth_client.refresh_token().await; + let new_self = McpClientService::new( + server_name.clone(), + backup_config.clone(), + messenger_clone.clone(), + ); + + let scopes = &backup_config.oauth_scopes; + let timeout = backup_config.timeout; + let headers = &backup_config.headers; + let new_transport = + get_http_transport(&os_clone, url, timeout, scopes, headers,Some(auth_client.auth_client.clone()), &*messenger_dup).await?; + + match new_transport { + HttpTransport::WithAuth((new_transport, new_auth_client)) => { + auth_client = new_auth_client; + + match refresh_res { + Ok(_) => { + new_self.into_dyn().serve(new_transport).await.map_err(Box::new)? + }, + Err(e) => { + error!("## mcp: token refresh attempt failed: {:?}", e); + info!("Retry for http transport failed {e}. Possible reauth needed"); + // This could be because the refresh token is expired, in which + // case we would need to have user go through the auth flow + // again. We do this by deleting the cred + // and discarding the client to trigger a full auth flow + tokio::fs::remove_file(&auth_client.cred_full_path).await?; + let new_transport = + get_http_transport(&os_clone, url, timeout, scopes,headers,None, &*messenger_dup).await?; + + match new_transport { + HttpTransport::WithAuth((new_transport, new_auth_client)) => { + auth_client = new_auth_client; + new_self.into_dyn().serve(new_transport).await.map_err(Box::new)? + }, + HttpTransport::WithoutAuth(new_transport) => { + new_self.into_dyn().serve(new_transport).await.map_err(Box::new)? + }, + } + }, + } + }, + HttpTransport::WithoutAuth(new_transport) => + new_self.into_dyn().serve(new_transport).await.map_err(Box::new)?, } }, - "notifications/prompts/list_changed" | "prompts/list_changed" - if prompts_list_changed_supported => - { - // TODO: after we have moved the prompts to the tool - // manager we follow the same workflow as the list changed - // for tools - fetch_prompts_and_notify_with_messenger(&client_ref, messenger_ref.as_ref()) - .await; - client_ref.is_prompts_out_of_date.store(true, Ordering::Release); - }, - "notifications/tools/list_changed" | "tools/list_changed" - if tools_list_changed_supported => - { - fetch_tools_and_notify_with_messenger(&client_ref, messenger_ref.as_ref()) - .await; - }, - _ => {}, - } + Err(e) => return Err(e.into()), + }; + + (service, None, Some(auth_client)) }, - JsonRpcMessage::Response(_resp) => { /* noop since direct response is handled inside the request api */ + HttpTransport::WithoutAuth(transport) => { + let service = self.into_dyn().serve(transport).await.map_err(Box::new)?; + + (service, None, None) }, } }, - Err(e) => { - tracing::error!("Background listening thread for client {}: {:?}", server_name, e); - // If we don't have anything on the other end, we should just end the task - // now - if let TransportError::RecvError(tokio::sync::broadcast::error::RecvError::Closed) = e { - tracing::error!( - "All senders dropped for transport layer for server {}: {:?}. This likely means the mcp server process is no longer running.", - server_name, - e - ); - break; - } - }, - } - } - }); - - Ok(cap) - } + }; - /// Sends a request to the server associated. - /// This call will yield until a response is received. - pub async fn request( - &self, - method: &str, - params: Option, - ) -> Result { - let send_map_err = |e: Elapsed| (e, method.to_string()); - let recv_map_err = |e: Elapsed| (e, format!("recv for {method}")); - let mut id = self.get_id(); - let request = JsonRpcRequest { - jsonrpc: JsonRpcVersion::default(), - id, - method: method.to_owned(), - params, - }; - tracing::trace!(target: "mcp", "To {}:\n{:#?}", self.server_name, request); - let msg = JsonRpcMessage::Request(request); - time::timeout(Duration::from_millis(self.timeout), self.transport.send(&msg)) - .await - .map_err(send_map_err)??; - let mut listener = self.transport.get_listener(); - let mut resp = time::timeout(Duration::from_millis(self.timeout), async { - // we want to ignore all other messages sent by the server at this point and let the - // background loop handle them - // We also want to ignore all messages emitted by the server to its stdout that does - // not deserialize into a valid JsonRpcMessage (they are not supposed to do this but - // too many people complained about this so we are adding this safeguard in) - loop { - if let Ok(JsonRpcMessage::Response(resp)) = listener.recv().await { - if resp.id == id { - break Ok::(resp); - } - } + Ok((service, stderr, auth_client)) } - }) - .await - .map_err(recv_map_err)??; - // Pagination support: https://spec.modelcontextprotocol.io/specification/2024-11-05/server/utilities/pagination/#pagination-model - let mut next_cursor = resp.result.as_ref().and_then(|v| v.get("nextCursor")); - if next_cursor.is_some() { - let mut current_resp = resp.clone(); - let mut results = Vec::::new(); - let pagination_supported_ops = { - let maybe_pagination_supported_op: Result = method.try_into(); - maybe_pagination_supported_op.ok() - }; - if let Some(ops) = pagination_supported_ops { - loop { - let result = current_resp.result.as_ref().cloned().unwrap(); - let mut list: Vec = match ops { - PaginationSupportedOps::ResourcesList => { - let ResourcesListResult { resources: list, .. } = - serde_json::from_value::(result) - .map_err(ClientError::Serialization)?; - list - }, - PaginationSupportedOps::ResourceTemplatesList => { - let ResourceTemplatesListResult { - resource_templates: list, - .. - } = serde_json::from_value::(result) - .map_err(ClientError::Serialization)?; - list - }, - PaginationSupportedOps::PromptsList => { - let PromptsListResult { prompts: list, .. } = - serde_json::from_value::(result) - .map_err(ClientError::Serialization)?; - list - }, - PaginationSupportedOps::ToolsList => { - let ToolsListResult { tools: list, .. } = serde_json::from_value::(result) - .map_err(ClientError::Serialization)?; - list - }, + .await; + + let (service, child_stderr, auth_dropguard) = match result { + Ok((service, stderr, auth_dg)) => (service, stderr, auth_dg), + Err(e) => { + let msg = e.to_string(); + let error_data = ErrorData { + code: ErrorCode::RESOURCE_NOT_FOUND, + message: Cow::from(msg), + data: None, }; - results.append(&mut list); - if next_cursor.is_none() { - break; + let err = ServiceError::McpError(error_data); + + if let Err(send_err) = messenger_clone.send_tools_list_result(Err(err), None).await { + error!("Error sending tool result for {server_name}: {send_err}"); } - id = self.get_id(); - let next_request = JsonRpcRequest { - jsonrpc: JsonRpcVersion::default(), - id, - method: method.to_owned(), - params: Some(serde_json::json!({ - "cursor": next_cursor, - })), - }; - let msg = JsonRpcMessage::Request(next_request); - time::timeout(Duration::from_millis(self.timeout), self.transport.send(&msg)) - .await - .map_err(send_map_err)??; - let resp = time::timeout(Duration::from_millis(self.timeout), async { - loop { - if let Ok(JsonRpcMessage::Response(resp)) = listener.recv().await { - if resp.id == id { - break Ok::(resp); - } - } + + return Err(e); + }, + }; + + if let Some(mut stderr) = child_stderr { + let server_name_clone = server_name.clone(); + tokio::spawn(async move { + let mut buf = [0u8; 1024]; + loop { + match stderr.read(&mut buf).await { + Ok(0) => { + tracing::info!(target: "mcp", "{server_name_clone} stderr listening process exited due to EOF"); + break; + }, + Ok(size) => { + tracing::info!(target: "mcp", "{server_name_clone} logged to its stderr: {}", String::from_utf8_lossy(&buf[0..size])); + }, + Err(e) => { + tracing::info!(target: "mcp", "{server_name_clone} stderr listening process exited due to error: {e}"); + break; // Error reading + }, } - }) - .await - .map_err(recv_map_err)??; - current_resp = resp; - next_cursor = current_resp.result.as_ref().and_then(|v| v.get("nextCursor")); - } - resp.result = Some({ - let mut map = serde_json::Map::new(); - map.insert(ops.as_key().to_owned(), serde_json::to_value(results)?); - serde_json::to_value(map)? + } }); } - } - tracing::trace!(target: "mcp", "From {}:\n{:#?}", self.server_name, resp); - Ok(resp) - } - - /// Sends a notification to the server associated. - /// Notifications are requests that expect no responses. - pub async fn notify(&self, method: &str, params: Option) -> Result<(), ClientError> { - let send_map_err = |e: Elapsed| (e, method.to_string()); - let notification = JsonRpcNotification { - jsonrpc: JsonRpcVersion::default(), - method: format!("notifications/{}", method), - params, - }; - let msg = JsonRpcMessage::Notification(notification); - Ok( - time::timeout(Duration::from_millis(self.timeout), self.transport.send(&msg)) - .await - .map_err(send_map_err)??, - ) - } - fn get_id(&self) -> u64 { - self.current_id.fetch_add(1, Ordering::SeqCst) - } -} - -fn examine_server_capabilities(ser_cap: &JsonRpcResponse) -> Result<(), ClientError> { - // Check the jrpc version. - // Currently we are only proceeding if the versions are EXACTLY the same. - let jrpc_version = ser_cap.jsonrpc.as_u32_vec(); - let client_jrpc_version = JsonRpcVersion::default().as_u32_vec(); - for (sv, cv) in jrpc_version.iter().zip(client_jrpc_version.iter()) { - if sv != cv { - return Err(ClientError::NegotiationError( - "Incompatible jrpc version between server and client".to_owned(), - )); - } - } - Ok(()) -} + let service_clone = service.clone(); + tokio::spawn(async move { + let result: Result<(), Box> = async { + let init_result = service_clone.peer_info(); + if let Some(init_result) = init_result { + if init_result.capabilities.tools.is_some() { + paginated_fetch! { + final_result_type: ListToolsResult, + content_type: rmcp::model::Tool, + service_method: list_tools, + result_field: tools, + messenger_method: send_tools_list_result, + service: service_clone.clone(), + messenger: messenger_clone, + server_name: server_name + }; + } -#[allow(clippy::borrowed_box)] -async fn fetch_prompts_and_notify_with_messenger(client: &Client, messenger: Option<&Box>) -where - T: Transport, -{ - let prompt_list_result = 'prompt_list_result: { - let Ok(resp) = client.request("prompts/list", None).await else { - tracing::error!("Prompt list query failed for {0}", client.server_name); - return; - }; - let Some(result) = resp.result else { - tracing::warn!("Prompt list query returned no result for {0}", client.server_name); - return; - }; - let prompt_list_result = match serde_json::from_value::(result) { - Ok(res) => res, - Err(e) => { - let msg = format!("Failed to deserialize tool result from {}: {:?}", client.server_name, e); - break 'prompt_list_result Err(eyre::eyre!(msg)); - }, - }; - Ok::(prompt_list_result) - }; + if init_result.capabilities.prompts.is_some() { + paginated_fetch! { + final_result_type: ListPromptsResult, + content_type: rmcp::model::Prompt, + service_method: list_prompts, + result_field: prompts, + messenger_method: send_prompts_list_result, + service: service_clone, + messenger: messenger_clone, + server_name: server_name + }; + } + } + Ok(()) + } + .await; - if let Some(messenger) = messenger { - if let Err(e) = messenger.send_prompts_list_result(prompt_list_result).await { - tracing::error!("Failed to send prompt result through messenger: {:?}", e); - } - } -} + if let Err(e) = result { + error!(target: "mcp", "Error in MCP client initialization: {}", e); + } + }); -#[allow(clippy::borrowed_box)] -async fn fetch_tools_and_notify_with_messenger(client: &Client, messenger: Option<&Box>) -where - T: Transport, -{ - // TODO: decouple pagination logic from request and have page fetching logic here - // instead - let tool_list_result = 'tool_list_result: { - let resp = match client.request("tools/list", None).await { - Ok(resp) => resp, - Err(e) => break 'tool_list_result Err(e.into()), - }; - if let Some(error) = resp.error { - let msg = format!("Failed to retrieve tool list for {}: {:?}", client.server_name, error); - break 'tool_list_result Err(eyre::eyre!(msg)); - } - let Some(result) = resp.result else { - let msg = format!("Tool list response from {} is missing result", client.server_name); - break 'tool_list_result Err(eyre::eyre!(msg)); - }; - let tool_list_result = match serde_json::from_value::(result) { - Ok(result) => result, - Err(e) => { - let msg = format!("Failed to deserialize tool result from {}: {:?}", client.server_name, e); - break 'tool_list_result Err(eyre::eyre!(msg)); - }, - }; - Ok::(tool_list_result) - }; + Ok(RunningService { + inner_service: InnerService::Original(service), + auth_client: auth_dropguard, + }) + }); - if let Some(messenger) = messenger { - if let Err(e) = messenger.send_tools_list_result(tool_list_result).await { - tracing::error!("Failed to send tool result through messenger {:?}", e); - } + Ok(InitializedMcpClient::Pending(handle)) } -} -#[cfg(test)] -mod tests { - use std::path::PathBuf; + async fn get_transport(&mut self, os: &Os, messenger: &dyn Messenger) -> Result { + let CustomToolConfig { + r#type, + url, + headers, + oauth_scopes: scopes, + command: command_as_str, + args, + env: config_envs, + timeout, + .. + } = &mut self.config; - use serde_json::Value; + let is_malformed_http = matches!(r#type, TransportType::Http) && url.is_empty(); + let is_malformed_stdio = matches!(r#type, TransportType::Stdio) && command_as_str.is_empty(); - use super::*; - const TEST_BIN_OUT_DIR: &str = "target/debug"; - const TEST_SERVER_NAME: &str = "test_mcp_server"; + if is_malformed_http { + return Err(McpClientError::MalformedConfig( + "MCP config is malformed: transport type is specified to be http but url is empty", + )); + } else if is_malformed_stdio { + return Err(McpClientError::MalformedConfig( + "MCP config is malformed: transport type is specified to be stdio but command is empty", + )); + } - fn get_workspace_root() -> PathBuf { - let output = std::process::Command::new("cargo") - .args(["metadata", "--format-version=1", "--no-deps"]) - .output() - .expect("Failed to execute cargo metadata"); + match r#type { + TransportType::Stdio => { + let context = |input: &str| Ok(os.env.get(input).ok()); + let home_dir = || os.env.home().map(|p| p.to_string_lossy().to_string()); + let expanded_cmd = shellexpand::full_with_context(command_as_str, home_dir, context)?; - let metadata: serde_json::Value = - serde_json::from_slice(&output.stdout).expect("Failed to parse cargo metadata"); + let command = Command::new(expanded_cmd.as_ref() as &str).configure(|cmd| { + if let Some(envs) = config_envs { + process_env_vars(envs, &os.env); + cmd.envs(envs); + } + cmd.envs(std::env::vars()).args(args); - let workspace_root = metadata["workspace_root"] - .as_str() - .expect("Failed to find workspace_root in metadata"); + #[cfg(not(windows))] + cmd.process_group(0); + }); - PathBuf::from(workspace_root) - } + let (tokio_child_process, child_stderr) = + TokioChildProcess::builder(command).stderr(Stdio::piped()).spawn()?; - #[tokio::test(flavor = "multi_thread")] - // For some reason this test is quite flakey when ran in the CI but not on developer's - // machines. As a result it is hard to debug, hence we are ignoring it for now. - #[ignore] - async fn test_client_stdio() { - std::process::Command::new("cargo") - .args(["build", "--bin", TEST_SERVER_NAME]) - .status() - .expect("Failed to build binary"); - let workspace_root = get_workspace_root(); - let bin_path = workspace_root.join(TEST_BIN_OUT_DIR).join(TEST_SERVER_NAME); - println!("bin path: {}", bin_path.to_str().unwrap_or("no path found")); - - // Testing 2 concurrent sessions to make sure transport layer does not overlap. - let client_info_one = serde_json::json!({ - "name": "TestClientOne", - "version": "1.0.0" - }); - let client_config_one = ClientConfig { - server_name: "test_tool".to_owned(), - bin_path: bin_path.to_str().unwrap().to_string(), - args: ["1".to_owned()].to_vec(), - timeout: 120 * 1000, - client_info: client_info_one.clone(), - env: { - let mut map = HashMap::::new(); - map.insert("ENV_ONE".to_owned(), "1".to_owned()); - map.insert("ENV_TWO".to_owned(), "2".to_owned()); - Some(map) + Ok(Transport::Stdio((tokio_child_process, child_stderr))) }, - }; - let client_info_two = serde_json::json!({ - "name": "TestClientTwo", - "version": "1.0.0" - }); - let client_config_two = ClientConfig { - server_name: "test_tool".to_owned(), - bin_path: bin_path.to_str().unwrap().to_string(), - args: ["2".to_owned()].to_vec(), - timeout: 120 * 1000, - client_info: client_info_two.clone(), - env: { - let mut map = HashMap::::new(); - map.insert("ENV_ONE".to_owned(), "1".to_owned()); - map.insert("ENV_TWO".to_owned(), "2".to_owned()); - Some(map) - }, - }; - let mut client_one = Client::::from_config(client_config_one).expect("Failed to create client"); - let mut client_two = Client::::from_config(client_config_two).expect("Failed to create client"); - let client_one_cap = ClientCapabilities::from(client_info_one); - let client_two_cap = ClientCapabilities::from(client_info_two); - - let (res_one, res_two) = tokio::join!( - time::timeout( - time::Duration::from_secs(10), - test_client_routine(&mut client_one, serde_json::json!(client_one_cap)) - ), - time::timeout( - time::Duration::from_secs(10), - test_client_routine(&mut client_two, serde_json::json!(client_two_cap)) - ) - ); - let res_one = res_one.expect("Client one timed out"); - let res_two = res_two.expect("Client two timed out"); - assert!(res_one.is_ok()); - assert!(res_two.is_ok()); - } + TransportType::Http => { + let http_transport = get_http_transport(os, url, *timeout, scopes, headers, None, messenger).await?; - #[allow(clippy::await_holding_lock)] - async fn test_client_routine( - client: &mut Client, - cap_sent: serde_json::Value, - ) -> Result<(), Box> { - // Test init - let _ = client.init().await.expect("Client init failed"); - tokio::time::sleep(time::Duration::from_millis(1500)).await; - let client_capabilities_sent = client - .request("verify_init_ack_sent", None) - .await - .expect("Verify init ack mock request failed"); - let has_server_recvd_init_ack = client_capabilities_sent - .result - .expect("Failed to retrieve client capabilities sent."); - assert_eq!(has_server_recvd_init_ack.to_string(), "true"); - let cap_recvd = client - .request("verify_init_params_sent", None) - .await - .expect("Verify init params mock request failed"); - let cap_recvd = cap_recvd - .result - .expect("Verify init params mock request does not contain required field (result)"); - assert!(are_json_values_equal(&cap_sent, &cap_recvd)); - - // test list tools - let fake_tool_names = ["get_weather_one", "get_weather_two", "get_weather_three"]; - let mock_result_spec = fake_tool_names.map(create_fake_tool_spec); - let mock_tool_specs_for_verify = serde_json::json!(mock_result_spec.clone()); - let mock_tool_specs_prep_param = mock_result_spec - .iter() - .zip(fake_tool_names.iter()) - .map(|(v, n)| { - serde_json::json!({ - "key": (*n).to_string(), - "value": v - }) - }) - .collect::>(); - let mock_tool_specs_prep_param = - serde_json::to_value(mock_tool_specs_prep_param).expect("Failed to create mock tool specs prep param"); - let _ = client - .request("store_mock_tool_spec", Some(mock_tool_specs_prep_param)) - .await - .expect("Mock tool spec prep failed"); - let tool_spec_recvd = client.request("tools/list", None).await.expect("List tools failed"); - assert!(are_json_values_equal( - tool_spec_recvd - .result - .as_ref() - .and_then(|v| v.get("tools")) - .expect("Failed to retrieve tool specs from result received"), - &mock_tool_specs_for_verify - )); - - // Test list prompts directly - let fake_prompt_names = ["code_review_one", "code_review_two", "code_review_three"]; - let mock_result_prompts = fake_prompt_names.map(create_fake_prompts); - let mock_prompts_for_verify = serde_json::json!(mock_result_prompts.clone()); - let mock_prompts_prep_param = mock_result_prompts - .iter() - .zip(fake_prompt_names.iter()) - .map(|(v, n)| { - serde_json::json!({ - "key": (*n).to_string(), - "value": v - }) - }) - .collect::>(); - let mock_prompts_prep_param = - serde_json::to_value(mock_prompts_prep_param).expect("Failed to create mock prompts prep param"); - let _ = client - .request("store_mock_prompts", Some(mock_prompts_prep_param)) - .await - .expect("Mock prompt prep failed"); - let prompts_recvd = client.request("prompts/list", None).await.expect("List prompts failed"); - client.is_prompts_out_of_date.store(false, Ordering::Release); - assert!(are_json_values_equal( - prompts_recvd - .result - .as_ref() - .and_then(|v| v.get("prompts")) - .expect("Failed to retrieve prompts from results received"), - &mock_prompts_for_verify - )); - - // Test prompts list changed - let fake_prompt_names = ["code_review_four", "code_review_five", "code_review_six"]; - let mock_result_prompts = fake_prompt_names.map(create_fake_prompts); - let mock_prompts_prep_param = mock_result_prompts - .iter() - .zip(fake_prompt_names.iter()) - .map(|(v, n)| { - serde_json::json!({ - "key": (*n).to_string(), - "value": v - }) - }) - .collect::>(); - let mock_prompts_prep_param = - serde_json::to_value(mock_prompts_prep_param).expect("Failed to create mock prompts prep param"); - let _ = client - .request("store_mock_prompts", Some(mock_prompts_prep_param)) - .await - .expect("Mock new prompt request failed"); - // After we send the signal for the server to clear prompts, we should be receiving signal - // to fetch for new prompts, after which we should be getting no prompts. - let is_prompts_out_of_date = client.is_prompts_out_of_date.clone(); - let wait_for_new_prompts = async move { - while !is_prompts_out_of_date.load(Ordering::Acquire) { - tokio::time::sleep(time::Duration::from_millis(100)).await; - } - }; - time::timeout(time::Duration::from_secs(5), wait_for_new_prompts) - .await - .expect("Timed out while waiting for new prompts"); - let new_prompts = client.prompt_gets.read().expect("Failed to read new prompts"); - for k in new_prompts.keys() { - assert!(fake_prompt_names.contains(&k.as_str())); + Ok(Transport::Http(http_transport)) + }, } - - // Test env var inclusion - let env_vars = client.request("get_env_vars", None).await.expect("Get env vars failed"); - let env_one = env_vars - .result - .as_ref() - .expect("Failed to retrieve results from env var request") - .get("ENV_ONE") - .expect("Failed to retrieve env one from env var request"); - let env_two = env_vars - .result - .as_ref() - .expect("Failed to retrieve results from env var request") - .get("ENV_TWO") - .expect("Failed to retrieve env two from env var request"); - let env_one_as_str = serde_json::to_string(env_one).expect("Failed to convert env one to string"); - let env_two_as_str = serde_json::to_string(env_two).expect("Failed to convert env two to string"); - assert_eq!(env_one_as_str, "\"1\"".to_string()); - assert_eq!(env_two_as_str, "\"2\"".to_string()); - - Ok(()) } - fn are_json_values_equal(a: &Value, b: &Value) -> bool { - match (a, b) { - (Value::Null, Value::Null) => true, - (Value::Bool(a_val), Value::Bool(b_val)) => a_val == b_val, - (Value::Number(a_val), Value::Number(b_val)) => a_val == b_val, - (Value::String(a_val), Value::String(b_val)) => a_val == b_val, - (Value::Array(a_arr), Value::Array(b_arr)) => { - if a_arr.len() != b_arr.len() { - return false; - } - a_arr - .iter() - .zip(b_arr.iter()) - .all(|(a_item, b_item)| are_json_values_equal(a_item, b_item)) + async fn on_logging_message( + &self, + params: LoggingMessageNotificationParam, + _context: NotificationContext, + ) { + let level = params.level; + let data = params.data; + let server_name = &self.server_name; + + match level { + LoggingLevel::Error | LoggingLevel::Critical | LoggingLevel::Emergency | LoggingLevel::Alert => { + tracing::error!(target: "mcp", "{}: {}", server_name, data); }, - (Value::Object(a_obj), Value::Object(b_obj)) => { - if a_obj.len() != b_obj.len() { - return false; - } - a_obj.iter().all(|(key, a_value)| match b_obj.get(key) { - Some(b_value) => are_json_values_equal(a_value, b_value), - None => false, - }) + LoggingLevel::Warning => { + tracing::warn!(target: "mcp", "{}: {}", server_name, data); + }, + LoggingLevel::Info => { + tracing::info!(target: "mcp", "{}: {}", server_name, data); + }, + LoggingLevel::Debug => { + tracing::debug!(target: "mcp", "{}: {}", server_name, data); + }, + LoggingLevel::Notice => { + tracing::trace!(target: "mcp", "{}: {}", server_name, data); }, - _ => false, } } - fn create_fake_tool_spec(name: &str) -> serde_json::Value { - serde_json::json!({ - "name": name, - "description": "Get current weather information for a location", - "inputSchema": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "City name or zip code" - } - }, - "required": ["location"] - } - }) + async fn on_tool_list_changed(&self, context: NotificationContext) { + let NotificationContext { peer, .. } = context; + + paginated_fetch! { + final_result_type: ListToolsResult, + content_type: rmcp::model::Tool, + service_method: list_tools, + result_field: tools, + messenger_method: send_tools_list_result, + service: peer, + messenger: self.messenger, + server_name: self.server_name + }; } - fn create_fake_prompts(name: &str) -> serde_json::Value { - serde_json::json!({ - "name": name, - "description": "Asks the LLM to analyze code quality and suggest improvements", - "arguments": [ - { - "name": "code", - "description": "The code to review", - "required": true - } - ] - }) + async fn on_prompt_list_changed(&self, context: NotificationContext) { + let NotificationContext { peer, .. } = context; + + paginated_fetch! { + final_result_type: ListPromptsResult, + content_type: rmcp::model::Prompt, + service_method: list_prompts, + result_field: prompts, + messenger_method: send_prompts_list_result, + service: peer, + messenger: self.messenger, + server_name: self.server_name + }; } +} - #[cfg(windows)] - mod windows_command_tests { - use super::*; - use crate::mcp_client::transport::stdio::JsonRpcStdioTransport as StdioTransport; - - #[test] - fn test_quote_windows_arg_no_special_chars() { - let result = Client::::quote_windows_arg("simple"); - assert_eq!(result, "simple"); - } - - #[test] - fn test_quote_windows_arg_with_spaces() { - let result = Client::::quote_windows_arg("with spaces"); - assert_eq!(result, "\"with spaces\""); - } - - #[test] - fn test_quote_windows_arg_with_quotes() { - let result = Client::::quote_windows_arg("with \"quotes\""); - assert_eq!(result, "\"with \\\"quotes\\\"\""); - } - - #[test] - fn test_quote_windows_arg_with_backslashes() { - let result = Client::::quote_windows_arg("path\\to\\file"); - assert_eq!(result, "path\\to\\file"); +impl Service for McpClientService { + async fn handle_request( + &self, + request: ::PeerReq, + _context: rmcp::service::RequestContext, + ) -> Result<::Resp, rmcp::ErrorData> { + match request { + ServerRequest::PingRequest(_) => Err(rmcp::ErrorData::method_not_found::()), + ServerRequest::CreateMessageRequest(_) => Err(rmcp::ErrorData::method_not_found::< + rmcp::model::CreateMessageRequestMethod, + >()), + ServerRequest::ListRootsRequest(_) => { + Err(rmcp::ErrorData::method_not_found::()) + }, + ServerRequest::CreateElicitationRequest(_) => Err(rmcp::ErrorData::method_not_found::< + rmcp::model::ElicitationCreateRequestMethod, + >()), } + } - #[test] - fn test_quote_windows_arg_with_trailing_backslashes() { - let result = Client::::quote_windows_arg("path\\to\\dir\\"); - assert_eq!(result, "path\\to\\dir\\"); - } + async fn handle_notification( + &self, + notification: ::PeerNot, + context: NotificationContext, + ) -> Result<(), rmcp::ErrorData> { + match notification { + ServerNotification::ToolListChangedNotification(_) => self.on_tool_list_changed(context).await, + ServerNotification::LoggingMessageNotification(notification) => { + self.on_logging_message(notification.params, context).await; + }, + ServerNotification::PromptListChangedNotification(_) => self.on_prompt_list_changed(context).await, + // TODO: support these + ServerNotification::CancelledNotification(_) => (), + ServerNotification::ResourceUpdatedNotification(_) => (), + ServerNotification::ResourceListChangedNotification(_) => (), + ServerNotification::ProgressNotification(_) => (), + }; + Ok(()) + } - #[test] - fn test_quote_windows_arg_with_backslashes_before_quote() { - let result = Client::::quote_windows_arg("path\\\\\"quoted\""); - assert_eq!(result, "\"path\\\\\\\\\\\"quoted\\\"\""); + fn get_info(&self) -> ::Info { + InitializeRequestParam { + protocol_version: Default::default(), + capabilities: Default::default(), + client_info: Implementation { + name: "Q DEV CLI".to_string(), + version: "1.0.0".to_string(), + }, } + } +} - #[test] - fn test_quote_windows_arg_complex_case() { - let result = Client::::quote_windows_arg("C:\\Program Files\\My App\\bin\\app.exe"); - assert_eq!(result, "\"C:\\Program Files\\My App\\bin\\app.exe\""); - } +/// InitializedMcpClient is the return of [McpClientService::init]. +/// This is necessitated by the fact that [Service::serve], the command to spawn the process, is +/// async and does not resolve immediately. This delay can be significant and causes long perceived +/// latency during start up. However, our current architecture still requires the main chat loop to +/// have ownership of [RunningService]. +/// The solution chosen here is to instead spawn a task and have [Service::serve] called there and +/// return the handle to said task, stored in the [InitializedMcpClient::Pending] variant. This +/// enum is then flipped lazily (if applicable) when a [RunningService] is needed. +pub enum InitializedMcpClient { + Pending(JoinHandle>), + Ready(RunningService), +} - #[test] - fn test_quote_windows_arg_with_tabs_and_newlines() { - let result = Client::::quote_windows_arg("with\ttabs\nand\rnewlines"); - assert_eq!(result, "\"with\ttabs\nand\rnewlines\""); +impl std::fmt::Debug for InitializedMcpClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + InitializedMcpClient::Pending(_) => f.debug_tuple("Pending").field(&"JoinHandle<..>").finish(), + InitializedMcpClient::Ready(_) => f.debug_tuple("Ready").field(&"RunningService<..>").finish(), } + } +} - #[test] - fn test_quote_windows_arg_edge_case_only_backslashes() { - let result = Client::::quote_windows_arg("\\\\\\"); - assert_eq!(result, "\\\\\\"); +impl InitializedMcpClient { + pub async fn get_running_service(&mut self) -> Result<&RunningService, McpClientError> { + match self { + InitializedMcpClient::Pending(handle) if handle.is_finished() => { + let running_service = handle.await??; + *self = InitializedMcpClient::Ready(running_service); + let InitializedMcpClient::Ready(running_service) = self else { + unreachable!() + }; + + Ok(running_service) + }, + InitializedMcpClient::Ready(running_service) => Ok(running_service), + InitializedMcpClient::Pending(_) => Err(McpClientError::NotReady), } + } +} - #[test] - fn test_quote_windows_arg_edge_case_only_quotes() { - let result = Client::::quote_windows_arg("\"\"\""); - assert_eq!(result, "\"\\\"\\\"\\\"\""); - } +#[cfg(test)] +mod tests { + use super::*; - // Tests for build_windows_command function - #[test] - fn test_build_windows_command_empty_args() { - let bin_path = "myapp"; - let args = vec![]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!(result, "myapp"); + #[tokio::test] + async fn test_substitute_env_vars() { + // Set a test environment variable + let os = Os::new().await.unwrap(); + unsafe { + os.env.set_var("TEST_VAR", "test_value"); } - #[test] - fn test_build_windows_command_uvx_example() { - let bin_path = "uvx"; - let args = vec!["mcp-server-fetch".to_string()]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!(result, "uvx mcp-server-fetch"); - } + // Test basic substitution + assert_eq!( + substitute_env_vars("Value is ${env:TEST_VAR}", &os.env), + "Value is test_value" + ); - #[test] - fn test_build_windows_command_npx_example() { - let bin_path = "npx"; - let args = vec!["-y".to_string(), "@modelcontextprotocol/server-memory".to_string()]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!(result, "npx -y @modelcontextprotocol/server-memory"); - } + // Test multiple substitutions + assert_eq!( + substitute_env_vars("${env:TEST_VAR} and ${env:TEST_VAR}", &os.env), + "test_value and test_value" + ); - #[test] - fn test_build_windows_command_docker_example() { - let bin_path = "docker"; - let args = vec![ - "run".to_string(), - "-i".to_string(), - "--rm".to_string(), - "-e".to_string(), - "GITHUB_PERSONAL_ACCESS_TOKEN".to_string(), - "ghcr.io/github/github-mcp-server".to_string(), - ]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!( - result, - "docker run -i --rm -e GITHUB_PERSONAL_ACCESS_TOKEN ghcr.io/github/github-mcp-server" - ); - } + // Test non-existent variable + assert_eq!( + substitute_env_vars("${env:NON_EXISTENT_VAR}", &os.env), + "${NON_EXISTENT_VAR}" + ); - #[test] - fn test_build_windows_command_with_quotes_in_args() { - let bin_path = "myapp"; - let args = vec!["--config".to_string(), "{\"key\": \"value\"}".to_string()]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!(result, "myapp --config \"{\\\"key\\\": \\\"value\\\"}\""); - } + // Test mixed content + assert_eq!( + substitute_env_vars("Prefix ${env:TEST_VAR} suffix", &os.env), + "Prefix test_value suffix" + ); + } - #[test] - fn test_build_windows_command_with_spaces_in_path() { - let bin_path = "C:\\Program Files\\My App\\bin\\app.exe"; - let args = vec!["--input".to_string(), "file with spaces.txt".to_string()]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!( - result, - "\"C:\\Program Files\\My App\\bin\\app.exe\" --input \"file with spaces.txt\"" - ); + #[tokio::test] + async fn test_process_env_vars() { + let os = Os::new().await.unwrap(); + unsafe { + os.env.set_var("TEST_VAR", "test_value"); } - #[test] - fn test_build_windows_command_complex_args() { - let bin_path = "myapp"; - let args = vec![ - "--config".to_string(), - "C:\\Users\\test\\config.json".to_string(), - "--output".to_string(), - "C:\\Output\\result file.txt".to_string(), - "--verbose".to_string(), - ]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!( - result, - "myapp --config C:\\Users\\test\\config.json --output \"C:\\Output\\result file.txt\" --verbose" - ); - } + let mut env_vars = HashMap::new(); + env_vars.insert("KEY1".to_string(), "Value is ${env:TEST_VAR}".to_string()); + env_vars.insert("KEY2".to_string(), "No substitution".to_string()); - #[test] - fn test_build_windows_command_with_environment_variables() { - let bin_path = "cmd"; - let args = vec!["/c".to_string(), "echo %PATH%".to_string()]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!(result, "cmd /c \"echo %PATH%\""); - } + process_env_vars(&mut env_vars, &os.env); - #[test] - fn test_build_windows_command_real_world_python() { - let bin_path = "python"; - let args = vec![ - "-m".to_string(), - "mcp_server".to_string(), - "--config".to_string(), - "C:\\configs\\server.json".to_string(), - ]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!(result, "python -m mcp_server --config C:\\configs\\server.json"); - } + assert_eq!(env_vars.get("KEY1").unwrap(), "Value is test_value"); + assert_eq!(env_vars.get("KEY2").unwrap(), "No substitution"); } } diff --git a/crates/chat-cli/src/mcp_client/error.rs b/crates/chat-cli/src/mcp_client/error.rs deleted file mode 100644 index 01f77cfa8b..0000000000 --- a/crates/chat-cli/src/mcp_client/error.rs +++ /dev/null @@ -1,66 +0,0 @@ -/// Error codes as defined in the MCP protocol. -/// -/// These error codes are based on the JSON-RPC 2.0 specification with additional -/// MCP-specific error codes in the -32000 to -32099 range. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[repr(i32)] -pub enum ErrorCode { - /// Invalid JSON was received by the server. - /// An error occurred on the server while parsing the JSON text. - ParseError = -32700, - - /// The JSON sent is not a valid Request object. - InvalidRequest = -32600, - - /// The method does not exist / is not available. - MethodNotFound = -32601, - - /// Invalid method parameter(s). - InvalidParams = -32602, - - /// Internal JSON-RPC error. - InternalError = -32603, - - /// Server has not been initialized. - /// This error is returned when a request is made before the server - /// has been properly initialized. - ServerNotInitialized = -32002, - - /// Unknown error code. - /// This error is returned when an error code is received that is not - /// recognized by the implementation. - Unknown = -32001, - - /// Request failed. - /// This error is returned when a request fails for a reason not covered - /// by other error codes. - RequestFailed = -32000, -} - -impl From for ErrorCode { - fn from(code: i32) -> Self { - match code { - -32700 => ErrorCode::ParseError, - -32600 => ErrorCode::InvalidRequest, - -32601 => ErrorCode::MethodNotFound, - -32602 => ErrorCode::InvalidParams, - -32603 => ErrorCode::InternalError, - -32002 => ErrorCode::ServerNotInitialized, - -32001 => ErrorCode::Unknown, - -32000 => ErrorCode::RequestFailed, - _ => ErrorCode::Unknown, - } - } -} - -impl From for i32 { - fn from(code: ErrorCode) -> Self { - code as i32 - } -} - -impl std::fmt::Display for ErrorCode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) - } -} diff --git a/crates/chat-cli/src/mcp_client/facilitator_types.rs b/crates/chat-cli/src/mcp_client/facilitator_types.rs deleted file mode 100644 index 87fbd79b27..0000000000 --- a/crates/chat-cli/src/mcp_client/facilitator_types.rs +++ /dev/null @@ -1,248 +0,0 @@ -use serde::{ - Deserialize, - Serialize, -}; -use thiserror::Error; - -/// https://spec.modelcontextprotocol.io/specification/2024-11-05/server/utilities/pagination/#operations-supporting-pagination -#[allow(clippy::enum_variant_names)] -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum PaginationSupportedOps { - ResourcesList, - ResourceTemplatesList, - PromptsList, - ToolsList, -} - -impl PaginationSupportedOps { - pub fn as_key(&self) -> &str { - match self { - PaginationSupportedOps::ResourcesList => "resources", - PaginationSupportedOps::ResourceTemplatesList => "resourceTemplates", - PaginationSupportedOps::PromptsList => "prompts", - PaginationSupportedOps::ToolsList => "tools", - } - } -} - -impl TryFrom<&str> for PaginationSupportedOps { - type Error = OpsConversionError; - - fn try_from(value: &str) -> Result { - match value { - "resources/list" => Ok(PaginationSupportedOps::ResourcesList), - "resources/templates/list" => Ok(PaginationSupportedOps::ResourceTemplatesList), - "prompts/list" => Ok(PaginationSupportedOps::PromptsList), - "tools/list" => Ok(PaginationSupportedOps::ToolsList), - _ => Err(OpsConversionError::InvalidMethod), - } - } -} - -#[derive(Error, Debug)] -pub enum OpsConversionError { - #[error("Invalid method encountered")] - InvalidMethod, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] -#[serde(rename_all = "camelCase")] -/// Role assumed for a particular message -pub enum Role { - User, - Assistant, -} - -impl std::fmt::Display for Role { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Role::User => write!(f, "user"), - Role::Assistant => write!(f, "assistant"), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Result of listing resources operation -pub struct ResourcesListResult { - /// List of resources - pub resources: Vec, - /// Optional cursor for pagination - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -/// Result of listing resource templates operation -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ResourceTemplatesListResult { - /// List of resource templates - pub resource_templates: Vec, - /// Optional cursor for pagination - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Result of prompt listing query -pub struct PromptsListResult { - /// List of prompts - pub prompts: Vec, - /// Optional cursor for pagination - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Represents an argument to be supplied to a [PromptGet] -pub struct PromptGetArg { - /// The name identifier of the prompt - pub name: String, - /// Optional description providing context about the prompt - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - /// Indicates whether a response to this prompt is required - /// If not specified, defaults to false - #[serde(skip_serializing_if = "Option::is_none")] - pub required: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Represents a request to get a prompt from a mcp server -pub struct PromptGet { - /// Unique identifier for the prompt - pub name: String, - /// Optional description providing context about the prompt's purpose - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - /// Optional list of arguments that define the structure of information to be collected - #[serde(skip_serializing_if = "Option::is_none")] - pub arguments: Option>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// `result` field in [JsonRpcResponse] from a `prompts/get` request -pub struct PromptGetResult { - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - pub messages: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Completed prompt from `prompts/get` to be returned by a mcp server -pub struct Prompt { - pub role: Role, - pub content: MessageContent, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Result of listing tools operation -pub struct ToolsListResult { - /// List of tools - pub tools: Vec, - /// Optional cursor for pagination - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ToolCallResult { - pub content: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub is_error: Option, -} - -/// Content of a message -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "camelCase")] -pub enum MessageContent { - /// Text content - Text { - /// The text content - text: String, - }, - /// Image content - #[serde(rename_all = "camelCase")] - Image { - /// base64-encoded-data - data: String, - mime_type: String, - }, - /// Resource content - Resource { - /// The resource - resource: Resource, - }, -} - -impl From for String { - fn from(val: MessageContent) -> Self { - match val { - MessageContent::Text { text } => text, - MessageContent::Image { data, mime_type } => serde_json::json!({ - "data": data, - "mime_type": mime_type - }) - .to_string(), - MessageContent::Resource { resource } => serde_json::json!(resource).to_string(), - } - } -} - -impl std::fmt::Display for MessageContent { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - MessageContent::Text { text } => write!(f, "{}", text), - MessageContent::Image { data: _, mime_type } => write!(f, "Image [base64-encoded-string] ({})", mime_type), - MessageContent::Resource { resource } => write!(f, "Resource: {} ({})", resource.title, resource.uri), - } - } -} - -/// Resource contents -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "camelCase")] -pub enum ResourceContents { - Text { text: String }, - Blob { data: Vec }, -} - -/// A resource in the system -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Resource { - /// Unique identifier for the resource - pub uri: String, - /// Human-readable title - pub title: String, - /// Optional description - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - /// Resource contents - pub contents: ResourceContents, -} - -/// Represents the capabilities supported by a Model Context Protocol server -/// This is the "capabilities" field in the result of a response for init -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ServerCapabilities { - /// Configuration for server logging capabilities - #[serde(skip_serializing_if = "Option::is_none")] - pub logging: Option, - /// Configuration for prompt-related capabilities - #[serde(skip_serializing_if = "Option::is_none")] - pub prompts: Option, - /// Configuration for resource management capabilities - #[serde(skip_serializing_if = "Option::is_none")] - pub resources: Option, - /// Configuration for tool integration capabilities - #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option, -} diff --git a/crates/chat-cli/src/mcp_client/messenger.rs b/crates/chat-cli/src/mcp_client/messenger.rs index 75723cd9c7..40e9bc84ca 100644 --- a/crates/chat-cli/src/mcp_client/messenger.rs +++ b/crates/chat-cli/src/mcp_client/messenger.rs @@ -1,11 +1,18 @@ +use rmcp::model::{ + ListPromptsResult, + ListResourceTemplatesResult, + ListResourcesResult, + ListToolsResult, +}; +use rmcp::{ + Peer, + RoleClient, + ServiceError, +}; use thiserror::Error; -use super::{ - PromptsListResult, - ResourceTemplatesListResult, - ResourcesListResult, - ToolsListResult, -}; +pub type Result = core::result::Result; +pub type MessengerResult = core::result::Result<(), MessengerError>; /// An interface that abstracts the implementation for information delivery from client and its /// consumer. It is through this interface secondary information (i.e. information that are needed @@ -16,26 +23,42 @@ use super::{ pub trait Messenger: std::fmt::Debug + Send + Sync + 'static { /// Sends the result of a tools list operation to the consumer /// This function is used to deliver information about available tools - async fn send_tools_list_result(&self, result: eyre::Result) -> Result<(), MessengerError>; + async fn send_tools_list_result( + &self, + result: Result, + peer: Option>, + ) -> MessengerResult; /// Sends the result of a prompts list operation to the consumer /// This function is used to deliver information about available prompts - async fn send_prompts_list_result(&self, result: eyre::Result) -> Result<(), MessengerError>; + async fn send_prompts_list_result( + &self, + result: Result, + peer: Option>, + ) -> MessengerResult; /// Sends the result of a resources list operation to the consumer /// This function is used to deliver information about available resources - async fn send_resources_list_result(&self, result: eyre::Result) - -> Result<(), MessengerError>; + async fn send_resources_list_result( + &self, + result: Result, + peer: Option>, + ) -> MessengerResult; /// Sends the result of a resource templates list operation to the consumer /// This function is used to deliver information about available resource templates async fn send_resource_templates_list_result( &self, - result: eyre::Result, - ) -> Result<(), MessengerError>; + result: Result, + peer: Option>, + ) -> MessengerResult; + + /// Sends an OAuth authorization link to the consumer + /// This function is used to deliver OAuth links that users need to visit for authentication + async fn send_oauth_link(&self, link: String) -> MessengerResult; /// Signals to the orchestrator that a server has started initializing - async fn send_init_msg(&self) -> Result<(), MessengerError>; + async fn send_init_msg(&self) -> MessengerResult; /// Signals to the orchestrator that a server has deinitialized fn send_deinit_msg(&self); @@ -56,29 +79,43 @@ pub struct NullMessenger; #[async_trait::async_trait] impl Messenger for NullMessenger { - async fn send_tools_list_result(&self, _result: eyre::Result) -> Result<(), MessengerError> { + async fn send_tools_list_result( + &self, + _result: Result, + _peer: Option>, + ) -> MessengerResult { Ok(()) } - async fn send_prompts_list_result(&self, _result: eyre::Result) -> Result<(), MessengerError> { + async fn send_prompts_list_result( + &self, + _result: Result, + _peer: Option>, + ) -> MessengerResult { Ok(()) } async fn send_resources_list_result( &self, - _result: eyre::Result, - ) -> Result<(), MessengerError> { + _result: Result, + _peer: Option>, + ) -> MessengerResult { Ok(()) } async fn send_resource_templates_list_result( &self, - _result: eyre::Result, - ) -> Result<(), MessengerError> { + _result: Result, + _peer: Option>, + ) -> MessengerResult { + Ok(()) + } + + async fn send_oauth_link(&self, _link: String) -> MessengerResult { Ok(()) } - async fn send_init_msg(&self) -> Result<(), MessengerError> { + async fn send_init_msg(&self) -> MessengerResult { Ok(()) } diff --git a/crates/chat-cli/src/mcp_client/mod.rs b/crates/chat-cli/src/mcp_client/mod.rs index 51f8b178fd..432e4421f3 100644 --- a/crates/chat-cli/src/mcp_client/mod.rs +++ b/crates/chat-cli/src/mcp_client/mod.rs @@ -1,13 +1,6 @@ pub mod client; -pub mod error; -pub mod facilitator_types; pub mod messenger; -pub mod server; -pub mod transport; +pub mod oauth_util; pub use client::*; -pub use facilitator_types::*; -pub use messenger::*; -#[allow(unused_imports)] -pub use server::*; -pub use transport::*; +pub use oauth_util::*; diff --git a/crates/chat-cli/src/mcp_client/oauth_util.rs b/crates/chat-cli/src/mcp_client/oauth_util.rs new file mode 100644 index 0000000000..8af59a6c13 --- /dev/null +++ b/crates/chat-cli/src/mcp_client/oauth_util.rs @@ -0,0 +1,509 @@ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::path::PathBuf; +use std::pin::Pin; +use std::str::FromStr; +use std::sync::Arc; + +use http::{ + HeaderMap, + StatusCode, +}; +use http_body_util::Full; +use hyper::Response; +use hyper::body::Bytes; +use hyper::server::conn::http1; +use hyper_util::rt::TokioIo; +use reqwest::Client; +use rmcp::serde_json; +use rmcp::transport::auth::{ + AuthClient, + OAuthClientConfig, + OAuthState, + OAuthTokenResponse, +}; +use rmcp::transport::streamable_http_client::{ + StreamableHttpClientTransportConfig, + StreamableHttpClientWorker, +}; +use rmcp::transport::{ + AuthorizationManager, + AuthorizationSession, + StreamableHttpClientTransport, + WorkerTransport, +}; +use serde::{ + Deserialize, + Serialize, +}; +use sha2::{ + Digest, + Sha256, +}; +use tokio::sync::oneshot::Sender; +use tokio_util::sync::CancellationToken; +use tracing::{ + debug, + error, + info, +}; +use url::Url; + +use super::messenger::Messenger; +use crate::os::Os; +use crate::util::directories::{ + DirectoryError, + get_mcp_auth_dir, +}; + +#[derive(Debug, thiserror::Error)] +pub enum OauthUtilError { + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + Parse(#[from] url::ParseError), + #[error(transparent)] + Auth(#[from] rmcp::transport::AuthError), + #[error(transparent)] + Serde(#[from] serde_json::Error), + #[error("Missing authorization manager")] + MissingAuthorizationManager, + #[error(transparent)] + OneshotRecv(#[from] tokio::sync::oneshot::error::RecvError), + #[error(transparent)] + Directory(#[from] DirectoryError), + #[error(transparent)] + Reqwest(#[from] reqwest::Error), + #[error("{0}")] + Http(String), + #[error("Malformed directory")] + MalformDirectory, + #[error("Missing credential")] + MissingCredentials, +} + +/// A guard that automatically cancels the cancellation token when dropped. +/// This ensures that the OAuth loopback server is properly cleaned up +/// when the guard goes out of scope. +struct LoopBackDropGuard { + cancellation_token: CancellationToken, +} + +impl Drop for LoopBackDropGuard { + fn drop(&mut self) { + self.cancellation_token.cancel(); + } +} + +/// This is modeled after [OAuthClientConfig] +/// It's only here because [OAuthClientConfig] does not implement Serialize and Deserialize +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct Registration { + pub client_id: String, + pub client_secret: Option, + pub scopes: Vec, + pub redirect_uri: String, +} + +impl From for Registration { + fn from(value: OAuthClientConfig) -> Self { + Self { + client_id: value.client_id, + client_secret: value.client_secret, + scopes: value.scopes, + redirect_uri: value.redirect_uri, + } + } +} + +/// A wrapper that manages an authenticated MCP client. +/// +/// This struct wraps an `AuthClient` and provides access to OAuth credentials +/// for MCP server connections that require authentication. The credentials +/// are managed separately from this wrapper's lifecycle. +#[derive(Clone, Debug)] +pub struct AuthClientWrapper { + pub cred_full_path: PathBuf, + pub auth_client: AuthClient, +} + +impl AuthClientWrapper { + pub fn new(cred_full_path: PathBuf, auth_client: AuthClient) -> Self { + Self { + cred_full_path, + auth_client, + } + } + + /// Refreshes token in memory using the registration read from when the auth client was + /// spawned. This also persists the retrieved token + pub async fn refresh_token(&self) -> Result<(), OauthUtilError> { + let cred = self.auth_client.auth_manager.lock().await.refresh_token().await?; + let parent_path = self.cred_full_path.parent().ok_or(OauthUtilError::MalformDirectory)?; + tokio::fs::create_dir_all(parent_path).await?; + + let cred_as_bytes = serde_json::to_string_pretty(&cred)?; + tokio::fs::write(&self.cred_full_path, &cred_as_bytes).await?; + + Ok(()) + } +} + +/// HTTP transport wrapper that handles both authenticated and non-authenticated MCP connections. +/// +/// This enum provides two variants for different authentication scenarios: +/// - `WithAuth`: Used when the MCP server requires OAuth authentication, containing both the +/// transport worker and an auth client guard that manages credential persistence +/// - `WithoutAuth`: Used for servers that don't require authentication, containing only the basic +/// transport worker +/// +/// The appropriate variant is automatically selected based on the server's response to +/// an initial probe request during transport creation. +pub enum HttpTransport { + WithAuth( + ( + WorkerTransport>>, + AuthClientWrapper, + ), + ), + WithoutAuth(WorkerTransport>), +} + +pub fn get_default_scopes() -> &'static [&'static str] { + &["openid", "email", "profile", "offline_access"] +} + +pub async fn get_http_transport( + os: &Os, + url: &str, + timeout: u64, + scopes: &[String], + headers: &HashMap, + auth_client: Option>, + messenger: &dyn Messenger, +) -> Result { + let cred_dir = get_mcp_auth_dir(os)?; + let url = Url::from_str(url)?; + let key = compute_key(&url); + let cred_full_path = cred_dir.join(format!("{key}.token.json")); + let reg_full_path = cred_dir.join(format!("{key}.registration.json")); + + let mut client_builder = reqwest::ClientBuilder::new().timeout(std::time::Duration::from_millis(timeout)); + if !headers.is_empty() { + let headers = HeaderMap::try_from(headers).map_err(|e| OauthUtilError::Http(e.to_string()))?; + client_builder = client_builder.default_headers(headers); + }; + let reqwest_client = client_builder.build()?; + + // The probe request, like all other request, should adhere to the standards as per https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server + let mut probe_request = reqwest_client.post(url.clone()); + probe_request = probe_request.header("Accept", "application/json, text/event-stream"); + let probe_resp = probe_request.send().await?; + match probe_resp.status() { + StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { + let auth_client = match auth_client { + Some(auth_client) => auth_client, + None => { + let am = get_auth_manager( + url.clone(), + cred_full_path.clone(), + reg_full_path.clone(), + scopes, + messenger, + ) + .await?; + AuthClient::new(reqwest_client, am) + }, + }; + let transport = + StreamableHttpClientTransport::with_client(auth_client.clone(), StreamableHttpClientTransportConfig { + uri: url.as_str().into(), + allow_stateless: true, + ..Default::default() + }); + + let auth_dg = AuthClientWrapper::new(cred_full_path, auth_client); + + Ok(HttpTransport::WithAuth((transport, auth_dg))) + }, + _ => { + let transport = + StreamableHttpClientTransport::with_client(reqwest_client, StreamableHttpClientTransportConfig { + uri: url.as_str().into(), + allow_stateless: true, + ..Default::default() + }); + + Ok(HttpTransport::WithoutAuth(transport)) + }, + } +} + +async fn get_auth_manager( + url: Url, + cred_full_path: PathBuf, + reg_full_path: PathBuf, + scopes: &[String], + messenger: &dyn Messenger, +) -> Result { + let cred_as_bytes = tokio::fs::read(&cred_full_path).await; + let reg_as_bytes = tokio::fs::read(®_full_path).await; + let mut oauth_state = OAuthState::new(url, None).await?; + + match (cred_as_bytes, reg_as_bytes) { + (Ok(cred_as_bytes), Ok(reg_as_bytes)) => { + let token = serde_json::from_slice::(&cred_as_bytes)?; + let reg = serde_json::from_slice::(®_as_bytes)?; + + oauth_state.set_credentials(®.client_id, token).await?; + + debug!("## mcp: credentials set with cache"); + + Ok(oauth_state + .into_authorization_manager() + .ok_or(OauthUtilError::MissingAuthorizationManager)?) + }, + _ => { + info!("Error reading cached credentials"); + debug!("## mcp: cache read failed. constructing auth manager from scratch"); + let (am, redirect_uri) = get_auth_manager_impl(oauth_state, scopes, messenger).await?; + + // Client registration is done in [start_authorization] + // If we have gotten past that point that means we have the info to persist the + // registration on disk. + let (client_id, credentials) = am.get_credentials().await?; + let reg = Registration { + client_id, + client_secret: None, + scopes: get_default_scopes() + .iter() + .map(|s| (*s).to_string()) + .collect::>(), + redirect_uri, + }; + let reg_as_str = serde_json::to_string_pretty(®)?; + let reg_parent_path = reg_full_path.parent().ok_or(OauthUtilError::MalformDirectory)?; + tokio::fs::create_dir_all(reg_parent_path).await?; + tokio::fs::write(reg_full_path, ®_as_str).await?; + + let credentials = credentials.ok_or(OauthUtilError::MissingCredentials)?; + + let cred_parent_path = cred_full_path.parent().ok_or(OauthUtilError::MalformDirectory)?; + tokio::fs::create_dir_all(cred_parent_path).await?; + let reg_as_str = serde_json::to_string_pretty(&credentials)?; + tokio::fs::write(cred_full_path, ®_as_str).await?; + + Ok(am) + }, + } +} + +async fn get_auth_manager_impl( + mut oauth_state: OAuthState, + scopes: &[String], + messenger: &dyn Messenger, +) -> Result<(AuthorizationManager, String), OauthUtilError> { + let socket_addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let cancellation_token = tokio_util::sync::CancellationToken::new(); + let (tx, rx) = tokio::sync::oneshot::channel::(); + + let (actual_addr, _dg) = make_svc(tx, socket_addr, cancellation_token).await?; + info!("Listening on local host port {:?} for oauth", actual_addr); + + let redirect_uri = format!("http://{}", actual_addr); + let scopes_as_str = scopes.iter().map(String::as_str).collect::>(); + let scopes_as_slice = scopes_as_str.as_slice(); + start_authorization(&mut oauth_state, scopes_as_slice, &redirect_uri).await?; + + let auth_url = oauth_state.get_authorization_url().await?; + _ = messenger.send_oauth_link(auth_url).await; + + let auth_code = rx.await?; + oauth_state.handle_callback(&auth_code).await?; + let am = oauth_state + .into_authorization_manager() + .ok_or(OauthUtilError::MissingAuthorizationManager)?; + + Ok((am, redirect_uri)) +} + +pub fn compute_key(rs: &Url) -> String { + let mut hasher = Sha256::new(); + let input = format!("{}{}", rs.origin().ascii_serialization(), rs.path()); + hasher.update(input.as_bytes()); + format!("{:x}", hasher.finalize()) +} + +/// This is our own implementation of [OAuthState::start_authorization]. +/// This differs from [OAuthState::start_authorization] by assigning our own client_id for DCR. +/// We need this because the SDK hardcodes their own client id. And some servers will use client_id +/// to identify if a client is even allowed to perform the auth handshake. +async fn start_authorization( + oauth_state: &mut OAuthState, + scopes: &[&str], + redirect_uri: &str, +) -> Result<(), OauthUtilError> { + // DO NOT CHANGE THIS + // This string has significance as it is used for remote servers to identify us + const CLIENT_ID: &str = "Q DEV CLI"; + + let stub_cred = get_stub_credentials()?; + oauth_state.set_credentials(CLIENT_ID, stub_cred).await?; + + // The setting of credentials would put the oauth state into authorize. + if let OAuthState::Authorized(auth_manager) = oauth_state { + // set redirect uri + let config = OAuthClientConfig { + client_id: CLIENT_ID.to_string(), + client_secret: None, + scopes: scopes.iter().map(|s| (*s).to_string()).collect(), + redirect_uri: redirect_uri.to_string(), + }; + + // try to dynamic register client + let config = match auth_manager.register_client(CLIENT_ID, redirect_uri).await { + Ok(config) => config, + Err(e) => { + eprintln!("Dynamic registration failed: {}", e); + // fallback to default config + config + }, + }; + // reset client config + auth_manager.configure_client(config)?; + let auth_url = auth_manager.get_authorization_url(scopes).await?; + + let mut stub_auth_manager = AuthorizationManager::new("http://localhost").await?; + std::mem::swap(auth_manager, &mut stub_auth_manager); + + let session = AuthorizationSession { + auth_manager: stub_auth_manager, + auth_url, + redirect_uri: redirect_uri.to_string(), + }; + + let mut new_oauth_state = OAuthState::Session(session); + std::mem::swap(oauth_state, &mut new_oauth_state); + } else { + unreachable!() + } + + Ok(()) +} + +/// This looks silly but [rmcp::transport::auth::OAuthTokenResponse] is private and there is no +/// other way to create this directly +fn get_stub_credentials() -> Result { + const STUB_TOKEN: &str = r#" + { + "access_token": "stub", + "token_type": "bearer", + "expires_in": 3600, + "refresh_token": "stub", + "scope": "stub" + } + "#; + + serde_json::from_str::(STUB_TOKEN) +} + +async fn make_svc( + one_shot_sender: Sender, + socket_addr: SocketAddr, + cancellation_token: CancellationToken, +) -> Result<(SocketAddr, LoopBackDropGuard), OauthUtilError> { + #[derive(Clone, Debug)] + struct LoopBackForSendingAuthCode { + one_shot_sender: Arc>>>, + } + + #[derive(Debug, thiserror::Error)] + enum LoopBackError { + #[error("Poison error encountered: {0}")] + Poison(String), + #[error(transparent)] + Http(#[from] http::Error), + #[error("Failed to send auth code: {0}")] + Send(String), + } + + fn mk_response(s: String) -> Result>, LoopBackError> { + Ok(Response::builder().body(Full::new(Bytes::from(s)))?) + } + + impl hyper::service::Service> for LoopBackForSendingAuthCode { + type Error = LoopBackError; + type Future = Pin> + Send>>; + type Response = Response>; + + fn call(&self, req: hyper::Request) -> Self::Future { + let uri = req.uri(); + let query = uri.query().unwrap_or(""); + let params: std::collections::HashMap = + url::form_urlencoded::parse(query.as_bytes()).into_owned().collect(); + debug!("## mcp: uri: {}, query: {}, params: {:?}", uri, query, params); + + let self_clone = self.clone(); + Box::pin(async move { + let error = params.get("error"); + let resp = if let Some(err) = error { + mk_response(format!( + "OAuth failed. Check URL for precise reasons. Possible reasons: {}.\n\ + If this is scope related, you can try configuring the server scopes \n\ + to be an empty array by adding \"oauthScopes\": [] to your server config.\n\ + Example: {{\"type\": \"http\", \"uri\": \"https://example.com/mcp\", \"oauthScopes\": []}}\n", + err + )) + } else { + mk_response("You can close this page now".to_string()) + }; + + let code = params.get("code").cloned().unwrap_or_default(); + if let Some(sender) = self_clone + .one_shot_sender + .lock() + .map_err(|e| LoopBackError::Poison(e.to_string()))? + .take() + { + sender.send(code).map_err(LoopBackError::Send)?; + } + + resp + }) + } + } + + let listener = tokio::net::TcpListener::bind(socket_addr).await?; + let actual_addr = listener.local_addr()?; + let cancellation_token_clone = cancellation_token.clone(); + let dg = LoopBackDropGuard { + cancellation_token: cancellation_token_clone, + }; + + let loop_back = LoopBackForSendingAuthCode { + one_shot_sender: Arc::new(std::sync::Mutex::new(Some(one_shot_sender))), + }; + + // This is one and done + // This server only needs to last as long as it takes to send the auth code or to fail the auth + // flow + tokio::spawn(async move { + let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); + + tokio::select! { + _ = cancellation_token.cancelled() => { + info!("Oauth loopback server cancelled"); + }, + res = http1::Builder::new().serve_connection(io, loop_back) => { + if let Err(err) = res { + error!("Auth code loop back has failed: {:?}", err); + } + } + } + + Ok::<(), eyre::Report>(()) + }); + + Ok((actual_addr, dg)) +} diff --git a/crates/chat-cli/src/mcp_client/server.rs b/crates/chat-cli/src/mcp_client/server.rs deleted file mode 100644 index 7b320a2c6e..0000000000 --- a/crates/chat-cli/src/mcp_client/server.rs +++ /dev/null @@ -1,311 +0,0 @@ -#![allow(dead_code)] -use std::collections::HashMap; -use std::sync::atomic::{ - AtomicBool, - AtomicU64, - Ordering, -}; -use std::sync::{ - Arc, - Mutex, -}; - -use tokio::io::{ - Stdin, - Stdout, -}; -use tokio::task::JoinHandle; - -use super::Listener as _; -use super::client::StdioTransport; -use super::error::ErrorCode; -use super::transport::base_protocol::{ - JsonRpcError, - JsonRpcMessage, - JsonRpcNotification, - JsonRpcRequest, - JsonRpcResponse, -}; -use super::transport::stdio::JsonRpcStdioTransport; -use super::transport::{ - JsonRpcVersion, - Transport, - TransportError, -}; - -pub type Request = serde_json::Value; -pub type Response = Option; -pub type InitializedServer = JoinHandle>; - -pub trait PreServerRequestHandler { - fn register_pending_request_callback(&mut self, cb: impl Fn(u64) -> Option + Send + Sync + 'static); - fn register_send_request_callback( - &mut self, - cb: impl Fn(&str, Option) -> Result<(), ServerError> + Send + Sync + 'static, - ); -} - -#[async_trait::async_trait] -pub trait ServerRequestHandler: PreServerRequestHandler + Send + Sync + 'static { - async fn handle_initialize(&self, params: Option) -> Result; - async fn handle_incoming(&self, method: &str, params: Option) -> Result; - async fn handle_response(&self, resp: JsonRpcResponse) -> Result<(), ServerError>; - async fn handle_shutdown(&self) -> Result<(), ServerError>; -} - -pub struct Server { - transport: Option>, - handler: Option, - #[allow(dead_code)] - pending_requests: Arc>>, - #[allow(dead_code)] - current_id: Arc, -} - -#[derive(Debug, thiserror::Error)] -pub enum ServerError { - #[error(transparent)] - TransportError(#[from] TransportError), - #[error(transparent)] - Io(#[from] std::io::Error), - #[error(transparent)] - Serialization(#[from] serde_json::Error), - #[error("Unexpected msg type encountered")] - UnexpectedMsgType, - #[error("{0}")] - NegotiationError(String), - #[error(transparent)] - TokioJoinError(#[from] tokio::task::JoinError), - #[error("Failed to obtain mutex lock")] - MutexError, - #[error("Failed to obtain request method")] - MissingMethod, - #[error("Failed to obtain request id")] - MissingId, - #[error("Failed to initialize server. Missing transport")] - MissingTransport, - #[error("Failed to initialize server. Missing handler")] - MissingHandler, -} - -impl Server -where - H: ServerRequestHandler, -{ - pub fn new(mut handler: H, stdin: Stdin, stdout: Stdout) -> Result { - let transport = Arc::new(JsonRpcStdioTransport::server(stdin, stdout)?); - let pending_requests = Arc::new(Mutex::new(HashMap::::new())); - let pending_requests_clone_one = pending_requests.clone(); - let current_id = Arc::new(AtomicU64::new(0)); - let pending_request_getter = move |id: u64| -> Option { - match pending_requests_clone_one.lock() { - Ok(mut p) => p.remove(&id), - Err(_) => None, - } - }; - handler.register_pending_request_callback(pending_request_getter); - let transport_clone = transport.clone(); - let pending_request_clone_two = pending_requests.clone(); - let current_id_clone = current_id.clone(); - let request_sender = move |method: &str, params: Option| -> Result<(), ServerError> { - let id = current_id_clone.fetch_add(1, Ordering::SeqCst); - let msg = match method.split_once("/") { - Some(("request", _)) => { - let request = JsonRpcRequest { - jsonrpc: JsonRpcVersion::default(), - id, - method: method.to_owned(), - params, - }; - let msg = JsonRpcMessage::Request(request.clone()); - #[allow(clippy::map_err_ignore)] - let mut pending_request = pending_request_clone_two.lock().map_err(|_| ServerError::MutexError)?; - pending_request.insert(id, request); - Some(msg) - }, - Some(("notifications", _)) => { - let notif = JsonRpcNotification { - jsonrpc: JsonRpcVersion::default(), - method: method.to_owned(), - params, - }; - let msg = JsonRpcMessage::Notification(notif); - Some(msg) - }, - _ => None, - }; - if let Some(msg) = msg { - let transport = transport_clone.clone(); - tokio::task::spawn(async move { - let _ = transport.send(&msg).await; - }); - } - Ok(()) - }; - handler.register_send_request_callback(request_sender); - let server = Self { - transport: Some(transport), - handler: Some(handler), - pending_requests, - current_id, - }; - Ok(server) - } -} - -impl Server -where - T: Transport, - H: ServerRequestHandler, -{ - pub fn init(mut self) -> Result { - let transport = self.transport.take().ok_or(ServerError::MissingTransport)?; - let handler = Arc::new(self.handler.take().ok_or(ServerError::MissingHandler)?); - let has_initialized = Arc::new(AtomicBool::new(false)); - let listener = tokio::spawn(async move { - let mut listener = transport.get_listener(); - loop { - let request = listener.recv().await; - let transport_clone = transport.clone(); - let has_init_clone = has_initialized.clone(); - let handler_clone = handler.clone(); - tokio::task::spawn(async move { - process_request(has_init_clone, transport_clone, handler_clone, request).await; - }); - } - }); - Ok(listener) - } -} - -async fn process_request( - has_initialized: Arc, - transport: Arc, - handler: Arc, - request: Result, -) where - T: Transport, - H: ServerRequestHandler, -{ - match request { - Ok(msg) if msg.is_initialize() => { - let id = msg.id().unwrap_or_default(); - if has_initialized.load(Ordering::SeqCst) { - let resp = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion::default(), - id, - error: Some(JsonRpcError { - code: ErrorCode::InvalidRequest.into(), - message: "Server has already been initialized".to_owned(), - data: None, - }), - ..Default::default() - }); - let _ = transport.send(&resp).await; - return; - } - let JsonRpcMessage::Request(req) = msg else { - let resp = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion::default(), - id, - error: Some(JsonRpcError { - code: ErrorCode::InvalidRequest.into(), - message: "Invalid method for initialization (use request)".to_owned(), - data: None, - }), - ..Default::default() - }); - let _ = transport.send(&resp).await; - return; - }; - let JsonRpcRequest { params, .. } = req; - match handler.handle_initialize(params).await { - Ok(result) => { - let resp = JsonRpcMessage::Response(JsonRpcResponse { - id, - result, - ..Default::default() - }); - let _ = transport.send(&resp).await; - has_initialized.store(true, Ordering::SeqCst); - }, - Err(_e) => { - let resp = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion::default(), - id, - error: Some(JsonRpcError { - code: ErrorCode::InternalError.into(), - message: "Error producing initialization response".to_owned(), - data: None, - }), - ..Default::default() - }); - let _ = transport.send(&resp).await; - }, - } - }, - Ok(msg) if msg.is_shutdown() => { - // TODO: add shutdown routine - }, - Ok(msg) if has_initialized.load(Ordering::SeqCst) => match msg { - JsonRpcMessage::Request(req) => { - let JsonRpcRequest { - id, - jsonrpc, - params, - ref method, - } = req; - let resp = handler.handle_incoming(method, params).await.map_or_else( - |error| { - let err = JsonRpcError { - code: ErrorCode::InternalError.into(), - message: error.to_string(), - data: None, - }; - let resp = JsonRpcResponse { - jsonrpc: jsonrpc.clone(), - id, - result: None, - error: Some(err), - }; - JsonRpcMessage::Response(resp) - }, - |result| { - let resp = JsonRpcResponse { - jsonrpc: jsonrpc.clone(), - id, - result, - error: None, - }; - JsonRpcMessage::Response(resp) - }, - ); - let _ = transport.send(&resp).await; - }, - JsonRpcMessage::Notification(notif) => { - let JsonRpcNotification { ref method, params, .. } = notif; - let _ = handler.handle_incoming(method, params).await; - }, - JsonRpcMessage::Response(resp) => { - let _ = handler.handle_response(resp).await; - }, - }, - Ok(msg) => { - let id = msg.id().unwrap_or_default(); - let resp = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion::default(), - id, - error: Some(JsonRpcError { - code: ErrorCode::ServerNotInitialized.into(), - message: "Server has not been initialized".to_owned(), - data: None, - }), - ..Default::default() - }); - let _ = transport.send(&resp).await; - }, - Err(_e) => { - // TODO: error handling - }, - } -} diff --git a/crates/chat-cli/src/mcp_client/transport/base_protocol.rs b/crates/chat-cli/src/mcp_client/transport/base_protocol.rs deleted file mode 100644 index b0394e6e0c..0000000000 --- a/crates/chat-cli/src/mcp_client/transport/base_protocol.rs +++ /dev/null @@ -1,108 +0,0 @@ -//! Referencing https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/messages/ -//! Protocol Revision 2024-11-05 -use serde::{ - Deserialize, - Serialize, -}; - -pub type RequestId = u64; - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct JsonRpcVersion(String); - -impl Default for JsonRpcVersion { - fn default() -> Self { - JsonRpcVersion("2.0".to_owned()) - } -} - -impl JsonRpcVersion { - pub fn as_u32_vec(&self) -> Vec { - self.0 - .split(".") - .map(|n| n.parse::().unwrap()) - .collect::>() - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -#[serde(untagged)] -#[serde(deny_unknown_fields)] -// DO NOT change the order of these variants. This body of json is [untagged](https://serde.rs/enum-representations.html#untagged) -// The categorization of the deserialization depends on the order in which the variants are -// declared. -pub enum JsonRpcMessage { - Response(JsonRpcResponse), - Notification(JsonRpcNotification), - Request(JsonRpcRequest), -} - -impl JsonRpcMessage { - pub fn is_initialize(&self) -> bool { - match self { - JsonRpcMessage::Request(req) => req.method == "initialize", - _ => false, - } - } - - pub fn is_shutdown(&self) -> bool { - match self { - JsonRpcMessage::Notification(notif) => notif.method == "notification/shutdown", - _ => false, - } - } - - pub fn id(&self) -> Option { - match self { - JsonRpcMessage::Request(req) => Some(req.id), - JsonRpcMessage::Response(resp) => Some(resp.id), - JsonRpcMessage::Notification(_) => None, - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -#[serde(default, deny_unknown_fields)] -pub struct JsonRpcRequest { - pub jsonrpc: JsonRpcVersion, - pub id: RequestId, - pub method: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub params: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -#[serde(default, deny_unknown_fields)] -pub struct JsonRpcResponse { - pub jsonrpc: JsonRpcVersion, - pub id: RequestId, - #[serde(skip_serializing_if = "Option::is_none")] - pub result: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -#[serde(default, deny_unknown_fields)] -pub struct JsonRpcNotification { - pub jsonrpc: JsonRpcVersion, - pub method: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub params: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -#[serde(default, deny_unknown_fields)] -pub struct JsonRpcError { - pub code: i32, - pub message: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub data: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -pub enum TransportType { - #[default] - Stdio, - Websocket, -} diff --git a/crates/chat-cli/src/mcp_client/transport/mod.rs b/crates/chat-cli/src/mcp_client/transport/mod.rs deleted file mode 100644 index f752b1675a..0000000000 --- a/crates/chat-cli/src/mcp_client/transport/mod.rs +++ /dev/null @@ -1,57 +0,0 @@ -pub mod base_protocol; -pub mod stdio; - -use std::fmt::Debug; - -pub use base_protocol::*; -pub use stdio::*; -use thiserror::Error; - -#[derive(Clone, Debug, Error)] -pub enum TransportError { - #[error("Serialization error: {0}")] - Serialization(String), - #[error("IO error: {0}")] - Stdio(String), - #[error("{0}")] - Custom(String), - #[error(transparent)] - RecvError(#[from] tokio::sync::broadcast::error::RecvError), -} - -impl From for TransportError { - fn from(err: serde_json::Error) -> Self { - TransportError::Serialization(err.to_string()) - } -} - -impl From for TransportError { - fn from(err: std::io::Error) -> Self { - TransportError::Stdio(err.to_string()) - } -} - -#[allow(dead_code)] -#[async_trait::async_trait] -pub trait Transport: Send + Sync + Debug + 'static { - /// Sends a message over the transport layer. - async fn send(&self, msg: &JsonRpcMessage) -> Result<(), TransportError>; - /// Listens to awaits for a response. This is a call that should be used after `send` is called - /// to listen for a response from the message recipient. - fn get_listener(&self) -> impl Listener; - /// Gracefully terminates the transport connection, cleaning up any resources. - /// This should be called when the transport is no longer needed to ensure proper cleanup. - async fn shutdown(&self) -> Result<(), TransportError>; - /// Listener that listens for logging messages. - fn get_log_listener(&self) -> impl LogListener; -} - -#[async_trait::async_trait] -pub trait Listener: Send + Sync + 'static { - async fn recv(&mut self) -> Result; -} - -#[async_trait::async_trait] -pub trait LogListener: Send + Sync + 'static { - async fn recv(&mut self) -> Result; -} diff --git a/crates/chat-cli/src/mcp_client/transport/stdio.rs b/crates/chat-cli/src/mcp_client/transport/stdio.rs deleted file mode 100644 index 89266a183d..0000000000 --- a/crates/chat-cli/src/mcp_client/transport/stdio.rs +++ /dev/null @@ -1,285 +0,0 @@ -use std::sync::Arc; - -use tokio::io::{ - AsyncBufReadExt, - AsyncRead, - AsyncWriteExt as _, - BufReader, - Stdin, - Stdout, -}; -use tokio::process::{ - Child, - ChildStdin, -}; -use tokio::sync::{ - Mutex, - broadcast, -}; - -use super::base_protocol::JsonRpcMessage; -use super::{ - Listener, - LogListener, - Transport, - TransportError, -}; - -#[derive(Debug)] -pub enum JsonRpcStdioTransport { - Client { - stdin: Arc>, - receiver: broadcast::Receiver>, - log_receiver: broadcast::Receiver, - }, - Server { - stdout: Arc>, - receiver: broadcast::Receiver>, - }, -} - -impl JsonRpcStdioTransport { - fn spawn_reader( - reader: R, - tx: broadcast::Sender>, - ) { - tokio::spawn(async move { - let mut buffer = Vec::::new(); - let mut buf_reader = BufReader::new(reader); - loop { - buffer.clear(); - // Messages are delimited by newlines and assumed to contain no embedded newlines - // See https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio - match buf_reader.read_until(b'\n', &mut buffer).await { - Ok(0) => break, - Ok(_) => match serde_json::from_slice::(buffer.as_slice()) { - Ok(msg) => { - let _ = tx.send(Ok(msg)); - }, - Err(e) => { - let _ = tx.send(Err(e.into())); - }, - }, - Err(e) => { - let _ = tx.send(Err(e.into())); - }, - } - } - }); - } - - pub fn client(child_process: Child) -> Result { - let (tx, receiver) = broadcast::channel::>(100); - let Some(stdout) = child_process.stdout else { - return Err(TransportError::Custom("No stdout found on child process".to_owned())); - }; - let Some(stdin) = child_process.stdin else { - return Err(TransportError::Custom("No stdin found on child process".to_owned())); - }; - let Some(stderr) = child_process.stderr else { - return Err(TransportError::Custom("No stderr found on child process".to_owned())); - }; - let (log_tx, log_receiver) = broadcast::channel::(100); - tokio::task::spawn(async move { - let stderr = tokio::io::BufReader::new(stderr); - let mut lines = stderr.lines(); - while let Ok(Some(line)) = lines.next_line().await { - let _ = log_tx.send(line); - } - }); - let stdin = Arc::new(Mutex::new(stdin)); - Self::spawn_reader(stdout, tx); - Ok(JsonRpcStdioTransport::Client { - stdin, - receiver, - log_receiver, - }) - } - - pub fn server(stdin: Stdin, stdout: Stdout) -> Result { - let (tx, receiver) = broadcast::channel::>(100); - Self::spawn_reader(stdin, tx); - let stdout = Arc::new(Mutex::new(stdout)); - Ok(JsonRpcStdioTransport::Server { stdout, receiver }) - } -} - -#[async_trait::async_trait] -impl Transport for JsonRpcStdioTransport { - async fn send(&self, msg: &JsonRpcMessage) -> Result<(), TransportError> { - match self { - JsonRpcStdioTransport::Client { stdin, .. } => { - let mut serialized = serde_json::to_vec(msg)?; - serialized.push(b'\n'); - let mut stdin = stdin.lock().await; - stdin - .write_all(&serialized) - .await - .map_err(|e| TransportError::Custom(format!("Error writing to server: {:?}", e)))?; - stdin - .flush() - .await - .map_err(|e| TransportError::Custom(format!("Error writing to server: {:?}", e)))?; - Ok(()) - }, - JsonRpcStdioTransport::Server { stdout, .. } => { - let mut serialized = serde_json::to_vec(msg)?; - serialized.push(b'\n'); - let mut stdout = stdout.lock().await; - stdout - .write_all(&serialized) - .await - .map_err(|e| TransportError::Custom(format!("Error writing to client: {:?}", e)))?; - stdout - .flush() - .await - .map_err(|e| TransportError::Custom(format!("Error writing to client: {:?}", e)))?; - Ok(()) - }, - } - } - - fn get_listener(&self) -> impl Listener { - match self { - JsonRpcStdioTransport::Client { receiver, .. } | JsonRpcStdioTransport::Server { receiver, .. } => { - StdioListener { - receiver: receiver.resubscribe(), - } - }, - } - } - - async fn shutdown(&self) -> Result<(), TransportError> { - match self { - JsonRpcStdioTransport::Client { stdin, .. } => { - let mut stdin = stdin.lock().await; - Ok(stdin.shutdown().await?) - }, - JsonRpcStdioTransport::Server { stdout, .. } => { - let mut stdout = stdout.lock().await; - Ok(stdout.shutdown().await?) - }, - } - } - - fn get_log_listener(&self) -> impl LogListener { - match self { - JsonRpcStdioTransport::Client { log_receiver, .. } => StdioLogListener { - receiver: log_receiver.resubscribe(), - }, - JsonRpcStdioTransport::Server { .. } => unreachable!("server does not need a log listener"), - } - } -} - -pub struct StdioListener { - pub receiver: broadcast::Receiver>, -} - -#[async_trait::async_trait] -impl Listener for StdioListener { - async fn recv(&mut self) -> Result { - self.receiver.recv().await? - } -} - -pub struct StdioLogListener { - pub receiver: broadcast::Receiver, -} - -#[async_trait::async_trait] -impl LogListener for StdioLogListener { - async fn recv(&mut self) -> Result { - Ok(self.receiver.recv().await?) - } -} - -#[cfg(test)] -mod tests { - use std::process::Stdio; - - use serde_json::{ - Value, - json, - }; - use tokio::process::Command; - - use super::{ - JsonRpcMessage, - JsonRpcStdioTransport, - Listener, - Transport, - }; - - // Helpers for testing - fn create_test_message() -> JsonRpcMessage { - serde_json::from_value(json!({ - "jsonrpc": "2.0", - "id": 1, - "method": "test_method", - "params": { - "test_param": "test_value" - } - })) - .unwrap() - } - - #[tokio::test] - async fn test_client_transport() { - #[cfg(windows)] - let mut cmd = { - let mut cmd = Command::new("powershell"); - cmd.args(&["cat"]); - cmd - }; - #[cfg(not(windows))] - let mut cmd = Command::new("cat"); - - cmd.stdin(Stdio::piped()).stdout(Stdio::piped()).stderr(Stdio::piped()); - - // Inject our mock transport instead - let child = cmd.spawn().expect("Failed to spawn command"); - let transport = JsonRpcStdioTransport::client(child).expect("Failed to create client transport"); - - let message = create_test_message(); - let result = transport.send(&message).await; - assert!(result.is_ok(), "Failed to send message: {:?}", result); - - let echo = transport - .get_listener() - .recv() - .await - .expect("Failed to receive message"); - let echo_value = serde_json::to_value(&echo).expect("Failed to convert echo to value"); - let message_value = serde_json::to_value(&message).expect("Failed to convert message to value"); - assert!(are_json_values_equal(&echo_value, &message_value)); - } - - fn are_json_values_equal(a: &Value, b: &Value) -> bool { - match (a, b) { - (Value::Null, Value::Null) => true, - (Value::Bool(a_val), Value::Bool(b_val)) => a_val == b_val, - (Value::Number(a_val), Value::Number(b_val)) => a_val == b_val, - (Value::String(a_val), Value::String(b_val)) => a_val == b_val, - (Value::Array(a_arr), Value::Array(b_arr)) => { - if a_arr.len() != b_arr.len() { - return false; - } - a_arr - .iter() - .zip(b_arr.iter()) - .all(|(a_item, b_item)| are_json_values_equal(a_item, b_item)) - }, - (Value::Object(a_obj), Value::Object(b_obj)) => { - if a_obj.len() != b_obj.len() { - return false; - } - a_obj.iter().all(|(key, a_value)| match b_obj.get(key) { - Some(b_value) => are_json_values_equal(a_value, b_value), - None => false, - }) - }, - _ => false, - } - } -} diff --git a/crates/chat-cli/src/mcp_client/transport/websocket.rs b/crates/chat-cli/src/mcp_client/transport/websocket.rs deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/crates/chat-cli/src/os/env.rs b/crates/chat-cli/src/os/env.rs index 40aac3b5fe..63e5449419 100644 --- a/crates/chat-cli/src/os/env.rs +++ b/crates/chat-cli/src/os/env.rs @@ -132,6 +132,13 @@ impl Env { } } + pub fn set_current_dir_for_test(&self, path: PathBuf) { + use inner::Inner; + if let Inner::Fake(fake) = &self.0 { + fake.lock().unwrap().cwd = path; + } + } + pub fn current_exe(&self) -> Result { use inner::Inner; match &self.0 { diff --git a/crates/chat-cli/src/telemetry/core.rs b/crates/chat-cli/src/telemetry/core.rs index 48d23bbef2..58091ae8ff 100644 --- a/crates/chat-cli/src/telemetry/core.rs +++ b/crates/chat-cli/src/telemetry/core.rs @@ -19,6 +19,7 @@ use crate::telemetry::definitions::metrics::{ AmazonqMessageResponseError, AmazonqProfileState, AmazonqStartChat, + AmazonqcliDailyHeartbeat, CodewhispererterminalAddChatMessage, CodewhispererterminalAgentConfigInit, CodewhispererterminalAgentContribution, @@ -499,6 +500,14 @@ impl Event { } .into_metric_datum(), ), + EventType::DailyHeartbeat {} => Some( + AmazonqcliDailyHeartbeat { + create_time: self.created_time, + value: None, + source: None, + } + .into_metric_datum(), + ), } } } @@ -689,6 +698,7 @@ pub enum EventType { message_id: Option, context_file_length: Option, }, + DailyHeartbeat {}, } #[derive(Debug)] diff --git a/crates/chat-cli/src/telemetry/mod.rs b/crates/chat-cli/src/telemetry/mod.rs index 0b0a535a5f..90a9faa8b2 100644 --- a/crates/chat-cli/src/telemetry/mod.rs +++ b/crates/chat-cli/src/telemetry/mod.rs @@ -235,6 +235,10 @@ impl TelemetryThread { Ok(self.tx.send(Event::new(EventType::UserLoggedIn {}))?) } + pub fn send_daily_heartbeat(&self) -> Result<(), TelemetryError> { + Ok(self.tx.send(Event::new(EventType::DailyHeartbeat {}))?) + } + pub async fn send_cli_subcommand_executed( &self, database: &Database, diff --git a/crates/chat-cli/src/util/directories.rs b/crates/chat-cli/src/util/directories.rs index 50091ce87c..2e89ce1e09 100644 --- a/crates/chat-cli/src/util/directories.rs +++ b/crates/chat-cli/src/util/directories.rs @@ -43,7 +43,11 @@ pub enum DirectoryError { type Result = std::result::Result; const WORKSPACE_AGENT_DIR_RELATIVE: &str = ".amazonq/cli-agents"; +const GLOBAL_SHADOW_REPO_DIR: &str = ".aws/amazonq/cli-checkpoints"; const GLOBAL_AGENT_DIR_RELATIVE_TO_HOME: &str = ".aws/amazonq/cli-agents"; +const WORKSPACE_PROMPTS_DIR_RELATIVE: &str = ".amazonq/prompts"; +const GLOBAL_PROMPTS_DIR_RELATIVE_TO_HOME: &str = ".aws/amazonq/prompts"; +const CLI_BASH_HISTORY_PATH: &str = ".aws/amazonq/.cli_bash_history"; /// The directory of the users home /// @@ -158,6 +162,10 @@ pub fn chat_legacy_global_mcp_config(os: &Os) -> Result { Ok(home_dir(os)?.join(".aws").join("amazonq").join("mcp.json")) } +pub fn chat_cli_bash_history_path(os: &Os) -> Result { + Ok(home_dir(os)?.join(CLI_BASH_HISTORY_PATH)) +} + /// Legacy workspace MCP server config path pub fn chat_legacy_workspace_mcp_config(os: &Os) -> Result { let cwd = os.env.current_dir()?; @@ -175,12 +183,61 @@ pub fn chat_local_agent_dir(os: &Os) -> Result { Ok(cwd.join(WORKSPACE_AGENT_DIR_RELATIVE)) } +/// The directory containing global prompts +pub fn chat_global_prompts_dir(os: &Os) -> Result { + Ok(home_dir(os)?.join(GLOBAL_PROMPTS_DIR_RELATIVE_TO_HOME)) +} + +/// The directory containing local prompts +pub fn chat_local_prompts_dir(os: &Os) -> Result { + let cwd = os.env.current_dir()?; + Ok(cwd.join(WORKSPACE_PROMPTS_DIR_RELATIVE)) +} + /// Canonicalizes path given by expanding the path given pub fn canonicalizes_path(os: &Os, path_as_str: &str) -> Result { let context = |input: &str| Ok(os.env.get(input).ok()); let home_dir = || os.env.home().map(|p| p.to_string_lossy().to_string()); - Ok(shellexpand::full_with_context(path_as_str, home_dir, context)?.to_string()) + let expanded = shellexpand::full_with_context(path_as_str, home_dir, context)?; + let path_buf = if !expanded.starts_with("/") { + // Convert relative paths to absolute paths + let current_dir = os.env.current_dir()?; + current_dir.join(expanded.as_ref() as &str) + } else { + // Already absolute path + PathBuf::from(expanded.as_ref() as &str) + }; + + // Try canonicalize first, fallback to manual normalization if it fails + match path_buf.canonicalize() { + Ok(normalized) => Ok(normalized.as_path().to_string_lossy().to_string()), + Err(_) => { + // If canonicalize fails (e.g., path doesn't exist), do manual normalization + let normalized = normalize_path(&path_buf); + Ok(normalized.to_string_lossy().to_string()) + }, + } +} + +/// Manually normalize a path by resolving . and .. components +fn normalize_path(path: &std::path::Path) -> std::path::PathBuf { + let mut components = Vec::new(); + for component in path.components() { + match component { + std::path::Component::CurDir => { + // Skip current directory components + }, + std::path::Component::ParentDir => { + // Pop the last component for parent directory + components.pop(); + }, + _ => { + components.push(component); + }, + } + } + components.iter().collect() } /// Given a globset builder and a path, build globs for both the file and directory patterns @@ -188,7 +245,11 @@ pub fn canonicalizes_path(os: &Os, path_as_str: &str) -> Result { /// patterns to exist in a globset. pub fn add_gitignore_globs(builder: &mut GlobSetBuilder, path: &str) -> Result<()> { let glob_for_file = Glob::new(path)?; - let glob_for_dir = Glob::new(&format!("{path}/**"))?; + + // remove existing slash in path so we don't end up with double slash + // Glob doesn't normalize the path so it doesn't work with double slash + let dir_pattern: String = format!("{}/**", path.trim_end_matches('/')); + let glob_for_dir = Glob::new(&dir_pattern)?; builder.add(glob_for_file); builder.add(glob_for_dir); @@ -232,6 +293,18 @@ pub fn agent_knowledge_dir(os: &Os, agent: Option<&crate::cli::Agent>) -> Result Ok(knowledge_bases_dir(os)?.join(unique_id)) } +/// The directory for MCP authentication cache +/// +/// This is the same directory used by IDE for SSO cache storage. +/// - All platforms: `$HOME/.aws/sso/cache` +pub fn get_mcp_auth_dir(os: &Os) -> Result { + Ok(home_dir(os)?.join(".aws").join("sso").join("cache")) +} + +pub fn get_shadow_repo_dir(os: &Os, conversation_id: String) -> Result { + Ok(home_dir(os)?.join(GLOBAL_SHADOW_REPO_DIR).join(conversation_id)) +} + /// Generate a unique identifier for an agent based on its path and name fn generate_agent_unique_id(agent: &crate::cli::Agent) -> String { use std::collections::hash_map::DefaultHasher; @@ -273,6 +346,40 @@ mod linux_tests { assert!(logs_dir().is_ok()); assert!(settings_path().is_ok()); } + + #[test] + fn test_add_gitignore_globs() { + let direct_file = "/home/user/a.txt"; + let nested_file = "/home/user/folder/a.txt"; + let other_file = "/home/admin/a.txt"; + + // Case 1: Path with trailing slash + let mut builder1 = GlobSetBuilder::new(); + add_gitignore_globs(&mut builder1, "/home/user/").unwrap(); + let globset1 = builder1.build().unwrap(); + + assert!(globset1.is_match(direct_file)); + assert!(globset1.is_match(nested_file)); + assert!(!globset1.is_match(other_file)); + + // Case 2: Path without trailing slash - should behave same as case 1 + let mut builder2 = GlobSetBuilder::new(); + add_gitignore_globs(&mut builder2, "/home/user").unwrap(); + let globset2 = builder2.build().unwrap(); + + assert!(globset2.is_match(direct_file)); + assert!(globset2.is_match(nested_file)); + assert!(!globset1.is_match(other_file)); + + // Case 3: File path - should only match exact file + let mut builder3 = GlobSetBuilder::new(); + add_gitignore_globs(&mut builder3, "/home/user/a.txt").unwrap(); + let globset3 = builder3.build().unwrap(); + + assert!(globset3.is_match(direct_file)); + assert!(!globset3.is_match(nested_file)); + assert!(!globset1.is_match(other_file)); + } } // TODO(grant): Add back path tests on linux @@ -394,28 +501,71 @@ mod tests { // Test home directory expansion let result = canonicalizes_path(&test_os, "~/test").unwrap(); + #[cfg(windows)] + assert_eq!(result, "\\home\\testuser\\test"); + #[cfg(unix)] assert_eq!(result, "/home/testuser/test"); // Test environment variable expansion let result = canonicalizes_path(&test_os, "$TEST_VAR/path").unwrap(); - assert_eq!(result, "test_value/path"); + #[cfg(windows)] + assert_eq!(result, "\\test_value\\path"); + #[cfg(unix)] + assert_eq!(result, "/test_value/path"); // Test combined expansion let result = canonicalizes_path(&test_os, "~/$TEST_VAR").unwrap(); + #[cfg(windows)] + assert_eq!(result, "\\home\\testuser\\test_value"); + #[cfg(unix)] assert_eq!(result, "/home/testuser/test_value"); + // Test ~, . and .. expansion + let result = canonicalizes_path(&test_os, "~/./.././testuser").unwrap(); + #[cfg(windows)] + assert_eq!(result, "\\home\\testuser"); + #[cfg(unix)] + assert_eq!(result, "/home/testuser"); + // Test absolute path (no expansion needed) let result = canonicalizes_path(&test_os, "/absolute/path").unwrap(); + #[cfg(windows)] + assert_eq!(result, "\\absolute\\path"); + #[cfg(unix)] assert_eq!(result, "/absolute/path"); - // Test relative path (no expansion needed) + // Test ~, . and .. expansion for a path that does not exist + let result = canonicalizes_path(&test_os, "~/./.././testuser/new/path/../../new").unwrap(); + #[cfg(windows)] + assert_eq!(result, "\\home\\testuser\\new"); + #[cfg(unix)] + assert_eq!(result, "/home/testuser/new"); + + // Test path with . and .. + let result = canonicalizes_path(&test_os, "/absolute/./../path").unwrap(); + #[cfg(windows)] + assert_eq!(result, "\\path"); + #[cfg(unix)] + assert_eq!(result, "/path"); + + // Test relative path (which should be expanded because now all inputs are converted to + // absolute) let result = canonicalizes_path(&test_os, "relative/path").unwrap(); - assert_eq!(result, "relative/path"); + #[cfg(windows)] + assert_eq!(result, "\\relative\\path"); + #[cfg(unix)] + assert_eq!(result, "/relative/path"); // Test glob prefixed paths let result = canonicalizes_path(&test_os, "**/path").unwrap(); - assert_eq!(result, "**/path"); + #[cfg(windows)] + assert_eq!(result, "\\**\\path"); + #[cfg(unix)] + assert_eq!(result, "/**/path"); let result = canonicalizes_path(&test_os, "**/middle/**/path").unwrap(); - assert_eq!(result, "**/middle/**/path"); + #[cfg(windows)] + assert_eq!(result, "\\**\\middle\\**\\path"); + #[cfg(unix)] + assert_eq!(result, "/**/middle/**/path"); } } diff --git a/crates/chat-cli/src/util/mod.rs b/crates/chat-cli/src/util/mod.rs index ad5ef15898..48d8c94c97 100644 --- a/crates/chat-cli/src/util/mod.rs +++ b/crates/chat-cli/src/util/mod.rs @@ -3,11 +3,12 @@ pub mod directories; pub mod knowledge_store; pub mod open; pub mod pattern_matching; -pub mod process; pub mod spinner; pub mod system_info; #[cfg(test)] pub mod test; +pub mod tool_permission_checker; +pub mod ui; use std::fmt::Display; use std::io; diff --git a/crates/chat-cli/src/util/process/mod.rs b/crates/chat-cli/src/util/process/mod.rs deleted file mode 100644 index e0a8414592..0000000000 --- a/crates/chat-cli/src/util/process/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -pub use sysinfo::Pid; - -#[cfg(target_os = "windows")] -mod windows; -#[cfg(target_os = "windows")] -pub use windows::*; - -#[cfg(not(windows))] -mod unix; -#[cfg(not(windows))] -pub use unix::*; diff --git a/crates/chat-cli/src/util/process/unix.rs b/crates/chat-cli/src/util/process/unix.rs deleted file mode 100644 index b0ffc60935..0000000000 --- a/crates/chat-cli/src/util/process/unix.rs +++ /dev/null @@ -1,64 +0,0 @@ -use nix::sys::signal::Signal; -use sysinfo::Pid; - -pub fn terminate_process(pid: Pid) -> Result<(), String> { - let nix_pid = nix::unistd::Pid::from_raw(pid.as_u32() as i32); - nix::sys::signal::kill(nix_pid, Signal::SIGTERM).map_err(|e| format!("Failed to terminate process: {}", e)) -} - -#[cfg(test)] -#[cfg(not(windows))] -mod tests { - use std::process::Command; - use std::time::Duration; - - use super::*; - - // Helper to create a long-running process for testing - fn spawn_test_process() -> std::process::Child { - let mut command = Command::new("sleep"); - command.arg("30"); - command.spawn().expect("Failed to spawn test process") - } - - #[test] - fn test_terminate_process() { - // Spawn a test process - let mut child = spawn_test_process(); - let pid = Pid::from_u32(child.id()); - - // Terminate the process - let result = terminate_process(pid); - - // Verify termination was successful - assert!(result.is_ok(), "Process termination failed: {:?}", result.err()); - - // Give it a moment to terminate - std::thread::sleep(Duration::from_millis(100)); - - // Verify the process is actually terminated - match child.try_wait() { - Ok(Some(_)) => { - // Process exited, which is what we expect - }, - Ok(None) => { - panic!("Process is still running after termination"); - }, - Err(e) => { - panic!("Error checking process status: {}", e); - }, - } - } - - #[test] - fn test_terminate_nonexistent_process() { - // Use a likely invalid PID - let invalid_pid = Pid::from_u32(u32::MAX - 1); - - // Attempt to terminate a non-existent process - let result = terminate_process(invalid_pid); - - // Should return an error - assert!(result.is_err(), "Terminating non-existent process should fail"); - } -} diff --git a/crates/chat-cli/src/util/process/windows.rs b/crates/chat-cli/src/util/process/windows.rs deleted file mode 100644 index 12e0389bd8..0000000000 --- a/crates/chat-cli/src/util/process/windows.rs +++ /dev/null @@ -1,120 +0,0 @@ -use std::ops::Deref; - -use sysinfo::Pid; -use windows::Win32::Foundation::{ - CloseHandle, - HANDLE, -}; -use windows::Win32::System::Threading::{ - OpenProcess, - PROCESS_TERMINATE, - TerminateProcess, -}; - -/// Terminate a process on Windows using the Windows API -pub fn terminate_process(pid: Pid) -> Result<(), String> { - unsafe { - // Open the process with termination rights - let handle = OpenProcess(PROCESS_TERMINATE, false, pid.as_u32()) - .map_err(|e| format!("Failed to open process: {}", e))?; - - // Create a safe handle that will be closed automatically when dropped - let safe_handle = SafeHandle::new(handle).ok_or_else(|| "Invalid process handle".to_string())?; - - // Terminate the process with exit code 1 - TerminateProcess(*safe_handle, 1).map_err(|e| format!("Failed to terminate process: {}", e))?; - - Ok(()) - } -} - -struct SafeHandle(HANDLE); - -impl SafeHandle { - fn new(handle: HANDLE) -> Option { - if !handle.is_invalid() { Some(Self(handle)) } else { None } - } -} - -impl Drop for SafeHandle { - fn drop(&mut self) { - unsafe { - let _ = CloseHandle(self.0); - } - } -} - -impl Deref for SafeHandle { - type Target = HANDLE; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -#[cfg(test)] -mod tests { - use std::process::Command; - use std::time::Duration; - - use super::*; - - // Helper to create a long-running process for testing - fn spawn_test_process() -> std::process::Child { - let mut command = Command::new("cmd"); - command.args(["/C", "timeout 30 > nul"]); - command.spawn().expect("Failed to spawn test process") - } - - #[test] - fn test_terminate_process() { - // Spawn a test process - let mut child = spawn_test_process(); - let pid = Pid::from_u32(child.id()); - - // Terminate the process - let result = terminate_process(pid); - - // Verify termination was successful - assert!(result.is_ok(), "Process termination failed: {:?}", result.err()); - - // Give it a moment to terminate - std::thread::sleep(Duration::from_millis(100)); - - // Verify the process is actually terminated - match child.try_wait() { - Ok(Some(_)) => { - // Process exited, which is what we expect - }, - Ok(None) => { - panic!("Process is still running after termination"); - }, - Err(e) => { - panic!("Error checking process status: {}", e); - }, - } - } - - #[test] - fn test_terminate_nonexistent_process() { - // Use a likely invalid PID - let invalid_pid = Pid::from_u32(u32::MAX - 1); - - // Attempt to terminate a non-existent process - let result = terminate_process(invalid_pid); - - // Should return an error - assert!(result.is_err(), "Terminating non-existent process should fail"); - } - - #[test] - fn test_safe_handle() { - // Test creating a SafeHandle with an invalid handle - let invalid_handle = HANDLE(std::ptr::null_mut()); - let safe_handle = SafeHandle::new(invalid_handle); - assert!(safe_handle.is_none(), "SafeHandle should be None for invalid handle"); - - // We can't easily test a valid handle without actually opening a process, - // which would require additional setup and teardown - } -} diff --git a/crates/chat-cli/src/util/tool_permission_checker.rs b/crates/chat-cli/src/util/tool_permission_checker.rs new file mode 100644 index 0000000000..f1cc04f895 --- /dev/null +++ b/crates/chat-cli/src/util/tool_permission_checker.rs @@ -0,0 +1,82 @@ +use std::collections::HashSet; + +use tracing::debug; + +use crate::util::MCP_SERVER_TOOL_DELIMITER; +use crate::util::pattern_matching::matches_any_pattern; + +/// Checks if a tool is allowed based on the agent's allowed_tools configuration. +/// This function handles both native tools and MCP tools with wildcard pattern support. +pub fn is_tool_in_allowlist(allowed_tools: &HashSet, tool_name: &str, server_name: Option<&str>) -> bool { + let filter_patterns = |predicate: fn(&str) -> bool| -> HashSet { + allowed_tools + .iter() + .filter(|pattern| predicate(pattern)) + .cloned() + .collect() + }; + + match server_name { + // Native tool + None => { + let patterns = filter_patterns(|p| !p.starts_with('@')); + debug!("Native patterns: {:?}", patterns); + let result = matches_any_pattern(&patterns, tool_name); + debug!("Native tool '{}' permission check result: {}", tool_name, result); + result + }, + // MCP tool + Some(server) => { + let patterns = filter_patterns(|p| p.starts_with('@')); + debug!("MCP patterns: {:?}", patterns); + + // Check server-level permission first: @server_name + let server_pattern = format!("@{}", server); + debug!("Checking server-level pattern: '{}'", server_pattern); + if matches_any_pattern(&patterns, &server_pattern) { + debug!("Server-level permission granted for '{}'", server_pattern); + return true; + } + + // Check tool-specific permission: @server_name/tool_name + let tool_pattern = format!("@{}{}{}", server, MCP_SERVER_TOOL_DELIMITER, tool_name); + debug!("Checking tool-specific pattern: '{}'", tool_pattern); + let result = matches_any_pattern(&patterns, &tool_pattern); + debug!("Tool-specific permission result for '{}': {}", tool_pattern, result); + result + }, + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use super::*; + + #[test] + fn test_native_vs_mcp_separation() { + let mut allowed = HashSet::new(); + allowed.insert("fs_*".to_string()); + allowed.insert("@git".to_string()); + + // Native patterns only apply to native tools + assert!(is_tool_in_allowlist(&allowed, "fs_read", None)); + assert!(!is_tool_in_allowlist(&allowed, "fs_read", Some("server"))); + + // MCP patterns only apply to MCP tools + assert!(is_tool_in_allowlist(&allowed, "status", Some("git"))); + assert!(!is_tool_in_allowlist(&allowed, "git", None)); + } + + #[test] + fn test_mcp_wildcard_patterns() { + let mut allowed = HashSet::new(); + allowed.insert("@*quip*".to_string()); + allowed.insert("@git/read_*".to_string()); + + assert!(is_tool_in_allowlist(&allowed, "tool", Some("quip-server"))); + assert!(is_tool_in_allowlist(&allowed, "read_file", Some("git"))); + assert!(!is_tool_in_allowlist(&allowed, "write_file", Some("git"))); + } +} diff --git a/crates/chat-cli/src/util/ui.rs b/crates/chat-cli/src/util/ui.rs new file mode 100644 index 0000000000..ee93f0abae --- /dev/null +++ b/crates/chat-cli/src/util/ui.rs @@ -0,0 +1,136 @@ +use std::io::Write; + +use crossterm::execute; +use crossterm::style::{ + self, + Attribute, + Color, +}; +use eyre::Result; + +use crate::cli::feed::Feed; + +/// Render changelog content from feed.json with manual formatting +pub fn render_changelog_content(output: &mut impl Write) -> Result<()> { + let feed = Feed::load(); + let recent_entries = feed.get_all_changelogs() + .into_iter() + .take(2) // Show last 2 releases + .collect::>(); + + execute!(output, style::Print("\n"))?; + + // Title + execute!( + output, + style::SetForegroundColor(Color::Magenta), + style::SetAttribute(Attribute::Bold), + style::Print("What's New in Amazon Q CLI\n\n"), + style::SetAttribute(Attribute::Reset), + style::SetForegroundColor(Color::Reset), + )?; + + // Render recent entries + for entry in recent_entries { + // Show version header + execute!( + output, + style::SetForegroundColor(Color::Blue), + style::SetAttribute(Attribute::Bold), + style::Print(format!("## {} ({})\n", entry.version, entry.date)), + style::SetAttribute(Attribute::Reset), + style::SetForegroundColor(Color::Reset), + )?; + + for change in &entry.changes { + // Process **bold** syntax and remove PR links + let cleaned_description = clean_pr_links(&change.description); + let processed_description = process_bold_text(&cleaned_description); + execute!(output, style::Print("• "))?; + print_with_bold(output, &processed_description)?; + execute!(output, style::Print("\n"))?; + } + execute!(output, style::Print("\n"))?; // Add spacing between versions + } + + execute!( + output, + style::Print("\nRun `/changelog` anytime to see the latest updates and features!\n\n") + )?; + Ok(()) +} + +/// Removes PR links and numbers from changelog descriptions to improve readability. +/// +/// Removes text matching the pattern " - [#NUMBER](URL)" from the end of descriptions. +/// +/// Example input: "A new feature - [#2711](https://github.com/aws/amazon-q-developer-cli/pull/2711)" +/// Example output: "A new feature" +fn clean_pr_links(text: &str) -> String { + // Remove PR links like " - [#2711](https://github.com/aws/amazon-q-developer-cli/pull/2711)" + if let Some(pos) = text.find(" - [#") { + text[..pos].to_string() + } else { + text.to_string() + } +} + +/// Processes text to identify **bold** markdown syntax and returns segments with formatting info. +/// +/// Returns a vector of tuples where each tuple contains: +/// - `String`: The text segment +/// - `bool`: Whether this segment should be rendered in bold +/// +/// Example input: "This is **bold** text" +/// Example output: [("This is ", false), ("bold", true), (" text", false)] +fn process_bold_text(text: &str) -> Vec<(String, bool)> { + let mut result = Vec::new(); + let mut current = String::new(); + let mut in_bold = false; + let mut chars = text.chars().peekable(); + + while let Some(ch) = chars.next() { + if ch == '*' && chars.peek() == Some(&'*') { + chars.next(); // consume second * + if !current.is_empty() { + result.push((current.clone(), in_bold)); + current.clear(); + } + in_bold = !in_bold; + } else { + current.push(ch); + } + } + + if !current.is_empty() { + result.push((current, in_bold)); + } + + result +} + +/// Renders text segments with proper bold formatting using crossterm. +/// +/// # Arguments +/// +/// * `output` - The writer to output formatted text to +/// * `segments` - Vector of (text, is_bold) tuples from `process_bold_text` +/// +/// # Errors +/// +/// Returns an error if writing to the output fails. +fn print_with_bold(output: &mut impl Write, segments: &[(String, bool)]) -> Result<()> { + for (text, is_bold) in segments { + if *is_bold { + execute!( + output, + style::SetAttribute(Attribute::Bold), + style::Print(text), + style::SetAttribute(Attribute::Reset), + )?; + } else { + execute!(output, style::Print(text))?; + } + } + Ok(()) +} diff --git a/crates/chat-cli/telemetry_definitions.json b/crates/chat-cli/telemetry_definitions.json index 3e52e5d3b2..5dac4ec712 100644 --- a/crates/chat-cli/telemetry_definitions.json +++ b/crates/chat-cli/telemetry_definitions.json @@ -530,6 +530,14 @@ { "type": "statusCode", "required": false }, { "type": "codewhispererterminal_clientApplication" } ] + }, + { + "name": "amazonqcli_dailyHeartbeat", + "description": "Daily heartbeat to track active CLI usage", + "unit": "None", + "metadata": [ + { "type": "source", "required": false } + ] } ] } diff --git a/crates/chat-cli/test_mcp_server/test_server.rs b/crates/chat-cli/test_mcp_server/test_server.rs deleted file mode 100644 index 970157f96b..0000000000 --- a/crates/chat-cli/test_mcp_server/test_server.rs +++ /dev/null @@ -1,340 +0,0 @@ -//! This is a bin used solely for testing the client -use std::collections::HashMap; -use std::str::FromStr; -use std::sync::atomic::{ - AtomicU8, - Ordering, -}; - -use chat_cli::{ - self, - JsonRpcRequest, - JsonRpcResponse, - JsonRpcStdioTransport, - PreServerRequestHandler, - Response, - Server, - ServerError, - ServerRequestHandler, -}; -use tokio::sync::Mutex; - -#[derive(Default)] -struct Handler { - pending_request: Option Option + Send + Sync>>, - #[allow(clippy::type_complexity)] - send_request: Option) -> Result<(), ServerError> + Send + Sync>>, - storage: Mutex>, - tool_spec: Mutex>, - tool_spec_key_list: Mutex>, - prompts: Mutex>, - prompt_key_list: Mutex>, - prompt_list_call_no: AtomicU8, -} - -impl PreServerRequestHandler for Handler { - fn register_pending_request_callback( - &mut self, - cb: impl Fn(u64) -> Option + Send + Sync + 'static, - ) { - self.pending_request = Some(Box::new(cb)); - } - - fn register_send_request_callback( - &mut self, - cb: impl Fn(&str, Option) -> Result<(), ServerError> + Send + Sync + 'static, - ) { - self.send_request = Some(Box::new(cb)); - } -} - -#[async_trait::async_trait] -impl ServerRequestHandler for Handler { - async fn handle_initialize(&self, params: Option) -> Result { - let mut storage = self.storage.lock().await; - if let Some(params) = params { - storage.insert("client_cap".to_owned(), params); - } - let capabilities = serde_json::json!({ - "protocolVersion": "2024-11-05", - "capabilities": { - "logging": {}, - "prompts": { - "listChanged": true - }, - "resources": { - "subscribe": true, - "listChanged": true - }, - "tools": { - "listChanged": true - } - }, - "serverInfo": { - "name": "TestServer", - "version": "1.0.0" - } - }); - Ok(Some(capabilities)) - } - - async fn handle_incoming(&self, method: &str, params: Option) -> Result { - match method { - "notifications/initialized" => { - { - let mut storage = self.storage.lock().await; - storage.insert( - "init_ack_sent".to_owned(), - serde_json::Value::from_str("true").expect("Failed to convert string to value"), - ); - } - Ok(None) - }, - "verify_init_params_sent" => { - let client_capabilities = { - let storage = self.storage.lock().await; - storage.get("client_cap").cloned() - }; - Ok(client_capabilities) - }, - "verify_init_ack_sent" => { - let result = { - let storage = self.storage.lock().await; - storage.get("init_ack_sent").cloned() - }; - Ok(result) - }, - "store_mock_tool_spec" => { - let Some(params) = params else { - eprintln!("Params missing from store mock tool spec"); - return Ok(None); - }; - // expecting a mock_specs: { key: String, value: serde_json::Value }[]; - let Ok(mock_specs) = serde_json::from_value::>(params) else { - eprintln!("Failed to convert to mock specs from value"); - return Ok(None); - }; - let self_tool_specs = self.tool_spec.lock().await; - let mut self_tool_spec_key_list = self.tool_spec_key_list.lock().await; - let _ = mock_specs.iter().fold(self_tool_specs, |mut acc, spec| { - let Some(key) = spec.get("key").cloned() else { - return acc; - }; - let Ok(key) = serde_json::from_value::(key) else { - eprintln!("Failed to convert serde value to string for key"); - return acc; - }; - self_tool_spec_key_list.push(key.clone()); - acc.insert(key, spec.get("value").cloned()); - acc - }); - Ok(None) - }, - "tools/list" => { - if let Some(params) = params { - if let Some(cursor) = params.get("cursor").cloned() { - let Ok(cursor) = serde_json::from_value::(cursor) else { - eprintln!("Failed to convert cursor to string: {:#?}", params); - return Ok(None); - }; - let self_tool_spec_key_list = self.tool_spec_key_list.lock().await; - let self_tool_spec = self.tool_spec.lock().await; - let (next_cursor, spec) = { - 'blk: { - for (i, item) in self_tool_spec_key_list.iter().enumerate() { - if item == &cursor { - break 'blk ( - self_tool_spec_key_list.get(i + 1).cloned(), - self_tool_spec.get(&cursor).cloned().unwrap(), - ); - } - } - (None, None) - } - }; - if let Some(next_cursor) = next_cursor { - return Ok(Some(serde_json::json!({ - "tools": [spec.unwrap()], - "nextCursor": next_cursor, - }))); - } else { - return Ok(Some(serde_json::json!({ - "tools": [spec.unwrap()], - }))); - } - } else { - eprintln!("Params exist but cursor is missing"); - return Ok(None); - } - } else { - let tool_spec_key_list = self.tool_spec_key_list.lock().await; - let tool_spec = self.tool_spec.lock().await; - let first_key = tool_spec_key_list - .first() - .expect("First key missing from tool specs") - .clone(); - let first_value = tool_spec - .get(&first_key) - .expect("First value missing from tool specs") - .clone(); - let second_key = tool_spec_key_list - .get(1) - .expect("Second key missing from tool specs") - .clone(); - return Ok(Some(serde_json::json!({ - "tools": [first_value], - "nextCursor": second_key - }))); - }; - }, - "get_env_vars" => { - let kv = std::env::vars().fold(HashMap::::new(), |mut acc, (k, v)| { - acc.insert(k, v); - acc - }); - Ok(Some(serde_json::json!(kv))) - }, - // This is a test path relevant only to sampling - "trigger_server_request" => { - let Some(ref send_request) = self.send_request else { - return Err(ServerError::MissingMethod); - }; - let params = Some(serde_json::json!({ - "messages": [ - { - "role": "user", - "content": { - "type": "text", - "text": "What is the capital of France?" - } - } - ], - "modelPreferences": { - "hints": [ - { - "name": "claude-3-sonnet" - } - ], - "intelligencePriority": 0.8, - "speedPriority": 0.5 - }, - "systemPrompt": "You are a helpful assistant.", - "maxTokens": 100 - })); - send_request("sampling/createMessage", params)?; - Ok(None) - }, - "store_mock_prompts" => { - let Some(params) = params else { - eprintln!("Params missing from store mock prompts"); - return Ok(None); - }; - // expecting a mock_prompts: { key: String, value: serde_json::Value }[]; - let Ok(mock_prompts) = serde_json::from_value::>(params) else { - eprintln!("Failed to convert to mock specs from value"); - return Ok(None); - }; - let mut self_prompts = self.prompts.lock().await; - let mut self_prompt_key_list = self.prompt_key_list.lock().await; - let is_first_mock = self_prompts.is_empty(); - self_prompts.clear(); - self_prompt_key_list.clear(); - let _ = mock_prompts.iter().fold(self_prompts, |mut acc, spec| { - let Some(key) = spec.get("key").cloned() else { - return acc; - }; - let Ok(key) = serde_json::from_value::(key) else { - eprintln!("Failed to convert serde value to string for key"); - return acc; - }; - self_prompt_key_list.push(key.clone()); - acc.insert(key, spec.get("value").cloned()); - acc - }); - if !is_first_mock { - if let Some(sender) = &self.send_request { - let _ = sender("notifications/prompts/list_changed", None); - } - } - Ok(None) - }, - "prompts/list" => { - // We expect this method to be called after the mock prompts have already been - // stored. - self.prompt_list_call_no.fetch_add(1, Ordering::Relaxed); - if let Some(params) = params { - if let Some(cursor) = params.get("cursor").cloned() { - let Ok(cursor) = serde_json::from_value::(cursor) else { - eprintln!("Failed to convert cursor to string: {:#?}", params); - return Ok(None); - }; - let self_prompt_key_list = self.prompt_key_list.lock().await; - let self_prompts = self.prompts.lock().await; - let (next_cursor, spec) = { - 'blk: { - for (i, item) in self_prompt_key_list.iter().enumerate() { - if item == &cursor { - break 'blk ( - self_prompt_key_list.get(i + 1).cloned(), - self_prompts.get(&cursor).cloned().unwrap(), - ); - } - } - (None, None) - } - }; - if let Some(next_cursor) = next_cursor { - return Ok(Some(serde_json::json!({ - "prompts": [spec.unwrap()], - "nextCursor": next_cursor, - }))); - } else { - return Ok(Some(serde_json::json!({ - "prompts": [spec.unwrap()], - }))); - } - } else { - eprintln!("Params exist but cursor is missing"); - return Ok(None); - } - } else { - // If there is no parameter, this is the request to retrieve the first page - let prompt_key_list = self.prompt_key_list.lock().await; - let prompts = self.prompts.lock().await; - let first_key = prompt_key_list.first().expect("first key missing"); - let first_value = prompts.get(first_key).cloned().unwrap().unwrap(); - let second_key = prompt_key_list.get(1).expect("second key missing"); - return Ok(Some(serde_json::json!({ - "prompts": [first_value], - "nextCursor": second_key - }))); - }; - }, - "get_prompt_list_call_no" => Ok(Some( - serde_json::to_value::(self.prompt_list_call_no.load(Ordering::Relaxed)) - .expect("Failed to convert list call no to u8"), - )), - _ => Err(ServerError::MissingMethod), - } - } - - // This is a test path relevant only to sampling - async fn handle_response(&self, resp: JsonRpcResponse) -> Result<(), ServerError> { - let JsonRpcResponse { id, .. } = resp; - let _pending = self.pending_request.as_ref().and_then(|f| f(id)); - Ok(()) - } - - async fn handle_shutdown(&self) -> Result<(), ServerError> { - Ok(()) - } -} - -#[tokio::main] -async fn main() { - let handler = Handler::default(); - let stdin = tokio::io::stdin(); - let stdout = tokio::io::stdout(); - let test_server = Server::::new(handler, stdin, stdout).expect("Failed to create server"); - let _ = test_server.init().expect("Test server failed to init").await; -} diff --git a/crates/semantic-search-client/Cargo.toml b/crates/semantic-search-client/Cargo.toml index 5024a5f4c6..fcf6c49447 100644 --- a/crates/semantic-search-client/Cargo.toml +++ b/crates/semantic-search-client/Cargo.toml @@ -30,7 +30,7 @@ glob.workspace = true hnsw_rs = "=0.3.1" # BM25 implementation - works on all platforms including ARM -bm25 = { version = "2.2.1", features = ["language_detection"] } +bm25 = { version = "2.3.2", features = ["language_detection"] } # Common dependencies for all platforms anyhow = "1.0" diff --git a/deny.toml b/deny.toml index 60971d3c22..6a974b0974 100644 --- a/deny.toml +++ b/deny.toml @@ -27,6 +27,8 @@ ignore = [ "RUSTSEC-2024-0429", # paste is used in core deps "RUSTSEC-2024-0436", + # fxhash is unmaintained but used by bm25, waiting for bm25 to migrate + "RUSTSEC-2025-0057", ] [licenses] diff --git a/docs/agent-file-locations.md b/docs/agent-file-locations.md index c09ed9311b..b5be46e1b3 100644 --- a/docs/agent-file-locations.md +++ b/docs/agent-file-locations.md @@ -47,7 +47,7 @@ These agents are available from any directory when using Q CLI. When Q CLI looks for an agent, it follows this precedence order: -1. **Local first**: Checks `.aws/amazonq/cli-agents/` in the current working directory +1. **Local first**: Checks `.amazonq/cli-agents/` in the current working directory 2. **Global fallback**: If not found locally, checks `~/.aws/amazonq/cli-agents/` in the home directory ## Naming Conflicts diff --git a/docs/agent-format.md b/docs/agent-format.md index c005ad13cd..35adc8cfaa 100644 --- a/docs/agent-format.md +++ b/docs/agent-format.md @@ -2,6 +2,9 @@ The agent configuration file for each agent is a JSON file. The filename (without the `.json` extension) becomes the agent's name. It contains configuration needed to instantiate and run the agent. +> [!TIP] +> We recommend using the `/agent generate` slash command within your active Q session to intelligently generate your agent configuration with the help of Q. + Every agent configuration file can include the following sections: - [`name`](#name-field) — The name of the agent (optional, derived from filename if not specified). @@ -15,6 +18,7 @@ Every agent configuration file can include the following sections: - [`resources`](#resources-field) — Resources available to the agent. - [`hooks`](#hooks-field) — Commands run at specific trigger points. - [`useLegacyMcpJson`](#uselegacymcpjson-field) — Whether to include legacy MCP configuration. +- [`model`](#model-field) — The model ID to use for this agent. ## Name Field @@ -252,19 +256,37 @@ Resources can include: ## Hooks Field -The `hooks` field defines commands to run at specific trigger points. The output of these commands is added to the agent's context. +The `hooks` field defines commands to run at specific trigger points during agent lifecycle and tool execution. + +For detailed information about hook behavior, input/output formats, and examples, see the [Hooks documentation](hooks.md). ```json { "hooks": { "agentSpawn": [ { - "command": "git status", + "command": "git status" } ], "userPromptSubmit": [ { - "command": "ls -la", + "command": "ls -la" + } + ], + "preToolUse": [ + { + "matcher": "execute_bash", + "command": "{ echo \"$(date) - Bash command:\"; cat; echo; } >> /tmp/bash_audit_log" + }, + { + "matcher": "use_aws", + "command": "{ echo \"$(date) - AWS CLI call:\"; cat; echo; } >> /tmp/aws_audit_log" + } + ], + "postToolUse": [ + { + "matcher": "fs_write", + "command": "cargo fmt --all" } ] } @@ -273,10 +295,13 @@ The `hooks` field defines commands to run at specific trigger points. The output Each hook is defined with: - `command` (required): The command to execute +- `matcher` (optional): Pattern to match tool names for `preToolUse` and `postToolUse` hooks. See [built-in tools documentation](./built-in-tools.md) for available tool names. Available hook triggers: -- `agentSpawn`: Triggered when the agent is initialized -- `userPromptSubmit`: Triggered when the user submits a message +- `agentSpawn`: Triggered when the agent is initialized. +- `userPromptSubmit`: Triggered when the user submits a message. +- `preToolUse`: Triggered before a tool is executed. Can block the tool use. +- `postToolUse`: Triggered after a tool is executed. ## UseLegacyMcpJson Field @@ -290,6 +315,20 @@ The `useLegacyMcpJson` field determines whether to include MCP servers defined i When set to `true`, the agent will have access to all MCP servers defined in the global and local configurations in addition to those defined in the agent's `mcpServers` field. +## Model Field + +The `model` field specifies the model ID to use for this agent. If not specified, the agent will use the default model. + +```json +{ + "model": "claude-sonnet-4" +} +``` + +The model ID must match one of the available models returned by the Q CLI's model service. You can see available models by using the `/model` command in an active chat session. + +If the specified model is not available, the agent will fall back to the default model and display a warning. + ## Complete Example Here's a complete example of an agent configuration file: @@ -348,6 +387,7 @@ Here's a complete example of an agent configuration file: } ] }, - "useLegacyMcpJson": true + "useLegacyMcpJson": true, + "model": "claude-sonnet-4" } ``` diff --git a/docs/built-in-tools.md b/docs/built-in-tools.md index 337b7ec0f0..6d5012ba57 100644 --- a/docs/built-in-tools.md +++ b/docs/built-in-tools.md @@ -24,7 +24,7 @@ Execute the specified bash command. "execute_bash": { "allowedCommands": ["git status", "git fetch"], "deniedCommands": ["git commit .*", "git push .*"], - "allowReadOnly": true + "autoAllowReadonly": true } } } @@ -36,7 +36,7 @@ Execute the specified bash command. |--------|------|---------|------------------------------------------------------------------------------------------| | `allowedCommands` | array of strings | `[]` | List of specific commands that are allowed without prompting. Supports regex formatting. Note that regex entered are anchored with \A and \z | | `deniedCommands` | array of strings | `[]` | List of specific commands that are denied. Supports regex formatting. Note that regex entered are anchored with \A and \z. Deny rules are evaluated before allow rules | -| `allowReadOnly` | boolean | `true` | Whether to allow read-only commands without prompting | +| `autoAllowReadonly` | boolean | `false` | Whether to allow read-only commands without prompting | ## Fs_read Tool @@ -110,19 +110,19 @@ Opens the browser to a pre-filled GitHub issue template to report chat issues, b This tool has no configuration options. -## Knowledge Tool +## Knowledge Tool (experimental) Store and retrieve information in a knowledge base across chat sessions. Provides semantic search capabilities for files, directories, and text content. This tool has no configuration options. -## Thinking Tool +## Thinking Tool (experimental) An internal reasoning mechanism that improves the quality of complex tasks by breaking them down into atomic actions. This tool has no configuration options. -## Todo_list Tool +## TODO List Tool (experimental) Create and manage TODO lists for tracking multi-step tasks. Lists are stored locally in `.amazonq/cli-todo-lists/`. @@ -139,7 +139,8 @@ Make AWS CLI API calls with the specified service, operation, and parameters. "toolsSettings": { "use_aws": { "allowedServices": ["s3", "lambda", "ec2"], - "deniedServices": ["eks", "rds"] + "deniedServices": ["eks", "rds"], + "autoAllowReadonly": true } } } @@ -151,6 +152,7 @@ Make AWS CLI API calls with the specified service, operation, and parameters. |--------|------|---------|-------------| | `allowedServices` | array of strings | `[]` | List of AWS services that can be accessed without prompting | | `deniedServices` | array of strings | `[]` | List of AWS services to deny. Deny rules are evaluated before allow rules | +| `autoAllowReadonly` | boolean | `false` | Whether to automatically allow read-only operations (get, describe, list, ls, search, batch_get) without prompting | ## Using Tool Settings in Agent Configuration diff --git a/docs/default-agent-behavior.md b/docs/default-agent-behavior.md index 0510906a60..727de2e93f 100644 --- a/docs/default-agent-behavior.md +++ b/docs/default-agent-behavior.md @@ -96,7 +96,7 @@ q chat --agent specialized-agent ### Create a Custom Default You can create your own "default" agent by placing an agent file with the name `q_cli_default` in either: -- `.aws/amazonq/cli-agents/` (local) +- `.amazonq/cli-agents/` (local) - `~/.aws/amazonq/cli-agents/` (global) This will override the built-in default agent configuration. diff --git a/docs/experiments.md b/docs/experiments.md index 92da46fd1c..3864cc3670 100644 --- a/docs/experiments.md +++ b/docs/experiments.md @@ -4,6 +4,56 @@ Amazon Q CLI includes experimental features that can be toggled on/off using the ## Available Experiments +### Checkpointing +**Description:** Enables session-scoped checkpoints for tracking file changes using Git CLI commands + +**Features:** +- Snapshots file changes into a shadow bare git repo +- List, expand, diff, and restore to any checkpoint +- Conversation history unwinds when restoring checkpoints +- Auto-enables in git repositories (ephemeral, cleaned on session end) +- Manual initialization available for non-git directories + +**Usage:** +``` +/checkpoint init # Manually enable checkpoints (if not in git repo) +/checkpoint list [--limit N] # Show turn-level checkpoints with file stats +/checkpoint expand # Show tool-level checkpoints under a turn +/checkpoint diff [tag2|HEAD] # Compare checkpoints or with current state +/checkpoint restore [] [--hard] # Restore to checkpoint (interactive picker if no tag) +/checkpoint clean # Delete session shadow repo +``` + +**Restore Options:** +- Default: Revert tracked changes & deletions; keep files created after checkpoint +- `--hard`: Make workspace exactly match checkpoint; deletes tracked files created after it + +**Example:** +``` +/checkpoint list +[0] 2025-09-18 14:00:00 - Initial checkpoint +[1] 2025-09-18 14:05:31 - add two_sum.py (+1 file) +[2] 2025-09-18 14:07:10 - add tests (modified 1) + +/checkpoint expand 2 +[2] 2025-09-18 14:07:10 - add tests + └─ [2.1] fs_write: Add minimal test cases to two_sum.py (modified 1) +``` + +### Context Usage Percentage +**Description:** Shows context window usage as a percentage in the chat prompt + +**Features:** +- Displays percentage of context window used in prompt (e.g., "[rust-agent] 6% >") +- Color-coded indicators: + - Green: <50% usage + - Yellow: 50-89% usage + - Red: 90-100% usage +- Helps monitor context window consumption +- Disabled by default + +**When enabled:** The chat prompt will show your current context usage percentage with color coding to help you understand how much of the available context window is being used. + ### Knowledge **Command:** `/knowledge` **Description:** Enables persistent context storage and retrieval across chat sessions @@ -58,6 +108,28 @@ Amazon Q CLI includes experimental features that can be toggled on/off using the **When enabled:** Use `/tangent` or the keyboard shortcut to create a checkpoint and explore tangential topics. Use the same command to return to your main conversation. +### TODO Lists +**Tool name**: `todo_list` +**Command:** `/todos` +**Description:** Enables Q to create and modify TODO lists using the `todo_list` tool and the user to view and manage existing TODO lists using `/todos`. + +**Features:** +- Q will automatically make TODO lists when appropriate or when asked +- View, manage, and delete TODOs using `/todos` +- Resume existing TODO lists stored in `.amazonq/cli-todo-lists` + +**Usage:** +``` +/todos clear-finished # Delete completed TODOs in your working directory +/todos resume # Select and resume an existing TODO list +/todos view # Select and view and existing TODO list +/todos delete # Select and delete an existing TODO list +``` + +**Settings:** +- `chat.enableTodoList` - Enable/disable TODO list functionality (boolean) + + ## Managing Experiments Use the `/experiment` command to toggle experimental features: @@ -84,11 +156,15 @@ These features are provided to gather feedback and test new capabilities. Please All experimental commands are available in the fuzzy search (Ctrl+S): - `/experiment` - Manage experimental features - `/knowledge` - Knowledge base commands (when enabled) +- `/todos` - User-controlled TODO list commands (when enabled) ## Settings Integration Experiments are stored as settings and persist across sessions: +- `EnabledCheckpointing` - Checkpointing experiment state +- `EnabledContextUsagePercentage` - Context usage percentage experiment state - `EnabledKnowledge` - Knowledge experiment state - `EnabledThinking` - Thinking experiment state +- `EnabledTodoList` - TODO list experiment state You can also manage these through the settings system if needed. diff --git a/docs/hooks.md b/docs/hooks.md new file mode 100644 index 0000000000..d7cfa3d50f --- /dev/null +++ b/docs/hooks.md @@ -0,0 +1,161 @@ +# Hooks + +Hooks allow you to execute custom commands at specific points during agent lifecycle and tool execution. This enables security validation, logging, formatting, context gathering, and other custom behaviors. + +## Defining Hooks + +Hooks are defined in the agent configuration file. See the [agent format documentation](agent-format.md#hooks-field) for the complete syntax and examples. + +## Hook Event + +Hooks receive hook event in JSON format via STDIN: + +```json +{ + "hook_event_name": "agentSpawn", + "cwd": "/current/working/directory" +} +``` + +For tool-related hooks, additional fields are included: +- `tool_name`: Name of the tool being executed +- `tool_input`: Tool-specific parameters (see individual tool documentation) +- `tool_response`: Tool execution results (PostToolUse only) + +## Hook Output + +- **Exit code 0**: Hook succeeded. STDOUT is captured but not shown to user. +- **Exit code 2**: (PreToolUse only) Block tool execution. STDERR is returned to the LLM. +- **Other exit codes**: Hook failed. STDERR is shown as warning to user. + +## Tool Matching + +Use the `matcher` field to specify which tools the hook applies to: + +### Examples +- `"fs_write"` - Exact match for built-in tools +- `"fs_*"` - Wildcard pattern for built-in tools +- `"@git"` - All tools from git MCP server +- `"@git/status"` - Specific tool from git MCP server +- `"*"` - All tools (built-in and MCP) +- `"@builtin"` - All built-in tools only +- No matcher - Applies to all tools + +For complete tool reference format, see [agent format documentation](agent-format.md#tools-field). + +## Hook Types + +### AgentSpawn + +Runs when agent is activated. No tool context provided. + +**Hook Event** +```json +{ + "hook_event_name": "agentSpawn", + "cwd": "/current/working/directory" +} +``` + +**Exit Code Behavior:** +- **0**: Hook succeeded, STDOUT is added to agent's context +- **Other**: Show STDERR warning to user + +### UserPromptSubmit + +Runs when user submits a prompt. Output is added to conversation context. + +**Hook Event** +```json +{ + "hook_event_name": "userPromptSubmit", + "cwd": "/current/working/directory", + "prompt": "user's input prompt" +} +``` + +**Exit Code Behavior:** +- **0**: Hook succeeded, STDOUT is added to agent's context +- **Other**: Show STDERR warning to user + +### PreToolUse + +Runs before tool execution. Can validate and block tool usage. + +**Hook Event** +```json +{ + "hook_event_name": "preToolUse", + "cwd": "/current/working/directory", + "tool_name": "fs_read", + "tool_input": { + "operations": [ + { + "mode": "Line", + "path": "/current/working/directory/docs/hooks.md" + } + ] + } +} +``` + +**Exit Code Behavior:** +- **0**: Allow tool execution. +- **2**: Block tool execution, return STDERR to LLM. +- **Other**: Show STDERR warning to user, allow tool execution. + +### PostToolUse + +Runs after tool execution with access to tool results. + +**Hook Event** +```json +{ + "hook_event_name": "postToolUse", + "cwd": "/current/working/directory", + "tool_name": "fs_read", + "tool_input": { + "operations": [ + { + "mode": "Line", + "path": "/current/working/directory/docs/hooks.md" + } + ] + }, + "tool_response": { + "success": true, + "result": ["# Hooks\n\nHooks allow you to execute..."] + } +} +``` + +**Exit Code Behavior:** +- **0**: Hook succeeded. +- **Other**: Show STDERR warning to user. Tool already ran. + +### MCP Example + +For MCP tools, the tool name includes the full namespaced format including the MCP Server name: + +**Hook Event** +```json +{ + "hook_event_name": "preToolUse", + "cwd": "/current/working/directory", + "tool_name": "@postgres/query", + "tool_input": { + "sql": "SELECT * FROM orders LIMIT 10;" + } +} +``` + +## Timeout + +Default timeout is 30 seconds (30,000ms). Configure with `timeout_ms` field. + +## Caching + +Successful hook results are cached based on `cache_ttl_seconds`: +- `0`: No caching (default) +- `> 0`: Cache successful results for specified seconds +- AgentSpawn hooks are never cached \ No newline at end of file diff --git a/docs/introspect-tool.md b/docs/introspect-tool.md new file mode 100644 index 0000000000..53f6b50029 --- /dev/null +++ b/docs/introspect-tool.md @@ -0,0 +1,66 @@ +# Introspect Tool + +The introspect tool provides Q CLI with self-awareness, automatically answering questions about Q CLI's features, commands, and functionality using official documentation. + +## How It Works + +The introspect tool activates automatically when you ask Q CLI questions like: +- "How do I save conversations with Q CLI?" +- "What experimental features does Q CLI have?" +- "Can Q CLI read files?" + +## What It Provides + +- **Command Help**: Real-time help for all slash commands (`/save`, `/load`, etc.) +- **Documentation**: Access to README, built-in tools, experiments, and feature guides +- **Settings**: All configuration options and how to change them +- **GitHub Links**: Direct links to official documentation for verification + +## Important Limitations + +**Hallucination Risk**: Despite safeguards, the AI may occasionally provide inaccurate information or make assumptions. **Always verify important details** using the GitHub documentation links provided in responses. + +## Usage Examples + +``` +> How do I save conversations with Q CLI? +You can save conversations using `/save` or `/save name`. +Load them later with `/load`. + +> What experimental features does Q CLI have? +Q CLI offers Tangent Mode and Thinking Mode. +Use `/experiment` to enable them. + +> Can Q CLI read and write files? +Yes, Q CLI has fs_read, fs_write, and execute_bash tools +for file operations. +``` + +## Auto-Tangent Mode + +Enable automatic tangent mode for Q CLI help questions: + +```bash +q settings introspect.tangentMode true +``` + +This keeps help separate from your main conversation. + +## Best Practices + +1. **Be explicit**: Ask "How does Q CLI handle files?" not "How do you handle files?" +2. **Verify information**: Check the GitHub links provided in responses +3. **Use proper syntax**: Reference commands with `/` (e.g., `/save`) +4. **Enable auto-tangent**: Keep help isolated from main conversations + +## Configuration + +```bash +# Enable auto-tangent for introspect questions +q settings introspect.tangentMode true +``` + +## Related Features + +- **Tangent Mode**: Isolate help conversations +- **Experiments**: Enable experimental features with `/experiment` diff --git a/docs/knowledge-management.md b/docs/knowledge-management.md index a403092d4b..afa42a1dc9 100644 --- a/docs/knowledge-management.md +++ b/docs/knowledge-management.md @@ -2,7 +2,8 @@ The /knowledge command provides persistent knowledge base functionality for Amazon Q CLI, allowing you to store, search, and manage contextual information that persists across chat sessions. -> Note: This is a beta feature that must be enabled before use. +> [!NOTE] +> This is a beta feature that must be enabled before use. ## Getting Started @@ -168,6 +169,8 @@ Configure knowledge base behavior: ## Agent-Specific Knowledge Bases +> **Note**: Agent-specific knowledge bases are available in development versions but not yet released. In current releases (v1.14.1 and earlier), all knowledge bases are stored globally at `~/.aws/amazonq/knowledge_bases/` and shared across all agents. + ### Isolated Knowledge Storage Each agent maintains its own isolated knowledge base, ensuring that knowledge contexts are scoped to the specific agent you're working with. This provides better organization and prevents knowledge conflicts between different agents. @@ -177,7 +180,7 @@ Each agent maintains its own isolated knowledge base, ensuring that knowledge co Knowledge bases are stored in the following directory structure: ``` -~/.q/knowledge_bases/ +~/.aws/amazonq/knowledge_bases/ ├── q_cli_default/ # Default agent knowledge base │ ├── contexts.json # Metadata for all contexts │ ├── context-id-1/ # Individual context storage @@ -186,13 +189,13 @@ Knowledge bases are stored in the following directory structure: │ └── context-id-2/ │ ├── data.json │ └── bm25_data.json -├── my-custom-agent/ # Custom agent knowledge base +├── my-custom-agent_/ # Custom agent knowledge base │ ├── contexts.json │ ├── context-id-3/ │ │ └── data.json │ └── context-id-4/ │ └── data.json -└── another-agent/ # Another agent's knowledge base +└── another-agent_/ # Another agent's knowledge base ├── contexts.json └── context-id-5/ └── data.json diff --git a/docs/tangent-mode.md b/docs/tangent-mode.md new file mode 100644 index 0000000000..2bdbc346aa --- /dev/null +++ b/docs/tangent-mode.md @@ -0,0 +1,196 @@ +# Tangent Mode + +Tangent mode creates conversation checkpoints, allowing you to explore side topics without disrupting your main conversation flow. Enter tangent mode, ask questions or explore ideas, then return to your original conversation exactly where you left off. + +## Enabling Tangent Mode + +Tangent mode is experimental and must be enabled: + +**Via Experiment Command**: Run `/experiment` and select tangent mode from the list. + +**Via Settings**: `q settings chat.enableTangentMode true` + +## Basic Usage + +### Enter Tangent Mode +Use `/tangent` or Ctrl+T: +``` +> /tangent +Created a conversation checkpoint (↯). Use ctrl + t or /tangent to restore the conversation later. +``` + +### In Tangent Mode +You'll see a yellow `↯` symbol in your prompt: +``` +↯ > What is the difference between async and sync functions? +``` + +### Exit Tangent Mode +Use `/tangent` or Ctrl+T again: +``` +↯ > /tangent +Restored conversation from checkpoint (↯). - Returned to main conversation. +``` + +### Exit Tangent Mode with Tail +Use `/tangent tail` to preserve the last conversation entry (question + answer): +``` +↯ > /tangent tail +Restored conversation from checkpoint (↯) with last conversation entry preserved. +``` + +## Usage Examples + +### Example 1: Exploring Alternatives +``` +> I need to process a large CSV file in Python. What's the best approach? + +I recommend using pandas for CSV processing... + +> /tangent +Created a conversation checkpoint (↯). + +↯ > What about using the csv module instead of pandas? + +The csv module is lighter weight... + +↯ > /tangent +Restored conversation from checkpoint (↯). + +> Thanks! I'll go with pandas. Can you show me error handling? +``` + +### Example 2: Getting Q CLI Help +``` +> Help me write a deployment script + +I can help you create a deployment script... + +> /tangent +Created a conversation checkpoint (↯). + +↯ > What Q CLI commands are available for file operations? + +Q CLI provides fs_read, fs_write, execute_bash... + +↯ > /tangent +Restored conversation from checkpoint (↯). + +> It's a Node.js application for AWS +``` + +### Example 3: Clarifying Requirements +``` +> I need to optimize this SQL query + +Could you share the query you'd like to optimize? + +> /tangent +Created a conversation checkpoint (↯). + +↯ > What information do you need to help optimize a query? + +To optimize SQL queries effectively, I need: +1. The current query +2. Table schemas and indexes... + +↯ > /tangent +Restored conversation from checkpoint (↯). + +> Here's my query: SELECT * FROM orders... +``` + +### Example 4: Keeping Useful Information +``` +> Help me debug this Python error + +I can help you debug that. Could you share the error message? + +> /tangent +Created a conversation checkpoint (↯). + +↯ > What are the most common Python debugging techniques? + +Here are the most effective Python debugging techniques: +1. Use print statements strategically +2. Leverage the Python debugger (pdb)... + +↯ > /tangent tail +Restored conversation from checkpoint (↯) with last conversation entry preserved. + +> Here's my error: TypeError: unsupported operand type(s)... + +# The preserved entry (question + answer about debugging techniques) is now part of main conversation +``` + +## Configuration + +### Keyboard Shortcut +```bash +# Change shortcut key (default: t) +q settings chat.tangentModeKey y +``` + +### Auto-Tangent for Introspect +```bash +# Auto-enter tangent mode for Q CLI help questions +q settings introspect.tangentMode true +``` + +## Visual Indicators + +- **Normal mode**: `> ` (magenta) +- **Tangent mode**: `↯ > ` (yellow ↯ + magenta) +- **With profile**: `[dev] ↯ > ` (cyan + yellow ↯ + magenta) + +## Best Practices + +### When to Use Tangent Mode +- Asking clarifying questions about the current topic +- Exploring alternative approaches before deciding +- Getting help with Q CLI commands or features +- Testing understanding of concepts + +### When NOT to Use +- Completely unrelated topics (start new conversation) +- Long, complex discussions (use regular flow) +- When you want the side discussion in main context + +### Tips +1. **Keep tangents focused** - Brief explorations, not extended discussions +2. **Return promptly** - Don't forget you're in tangent mode +3. **Use for clarification** - Perfect for "wait, what does X mean?" questions +4. **Experiment safely** - Test ideas without affecting main conversation +5. **Use `/tangent tail`** - When both the tangent question and answer are useful for main conversation + +## Limitations + +- Tangent conversations are discarded when you exit +- Only one level of tangent supported (no nested tangents) +- Experimental feature that may change or be removed +- Must be explicitly enabled + +## Troubleshooting + +### Tangent Mode Not Working +```bash +# Enable via experiment (select from list) +/experiment + +# Or enable via settings +q settings chat.enableTangentMode true +``` + +### Keyboard Shortcut Not Working +```bash +# Check/reset shortcut key +q settings chat.tangentModeKey t +``` + +### Lost in Tangent Mode +Look for the `↯` symbol in your prompt. Use `/tangent` to exit and return to main conversation. + +## Related Features + +- **Introspect**: Q CLI help (auto-enters tangent if configured) +- **Experiments**: Manage experimental features with `/experiment` diff --git a/docs/todo-lists.md b/docs/todo-lists.md index 2db9d93910..8f218d3daa 100644 --- a/docs/todo-lists.md +++ b/docs/todo-lists.md @@ -91,3 +91,56 @@ If lists exist but won't load: 1. **Check permissions**: Ensure read access to `.amazonq/cli-todo-lists/` 2. **Verify format**: Lists should be valid JSON files 3. **Check file integrity**: Corrupted files may prevent loading + +## `todo_list` vs. `/todos` +The `todo_list` tool is specifically for the model to call. The model is allowed to create TODO lists, mark tasks as complete, add/remove +tasks, load TODO lists with a given ID (which are automatically provided when resuming TODO lists), and search for existing TODO lists. + +The `/todos` command is for the user to manage existing TODO lists created by the model. The user can view, resume, and delete TODO lists +by using the appropriate subcommand and selecting the TODO list to perform the action on. + +## Examples +#### Asking Q to make a TODO list: +``` +> Make a todo list with 3 read-only tasks. + +> I'll create a todo list with 3 read-only tasks for you. + +🛠️ Using tool: todo_list (trusted) + ⋮ + ● TODO: +[ ] Review project documentation +[ ] Check system status +[ ] Read latest updates + ⋮ + ● Completed in 0.4s +``` + +#### Selecting a TODO list to view: +``` +> /todos view + +? Select a to-do list to view: › +❯ ✗ Unfinished todo list (0/3) + ✔ Completed todo list (3/3) +``` + +#### Resuming a TODO list (after selecting): +``` +> /todos resume + +⟳ Resuming: Read-only tasks for information gathering + +🛠️ Using tool: todo_list (trusted) + ⋮ + ● TODO: +[x] Review project documentation +[ ] Check system status +[ ] Read latest updates + ⋮ + ● Completed in 0.1s + ``` + + + + diff --git a/schemas/agent-v1.json b/schemas/agent-v1.json index 15b626dea8..5e72b08476 100644 --- a/schemas/agent-v1.json +++ b/schemas/agent-v1.json @@ -16,6 +16,20 @@ }, "required": ["command"] } + }, + "TransportType": { + "oneOf": [ + { + "description": "Standard input/output transport (default)", + "type": "string", + "const": "stdio" + }, + { + "description": "HTTP transport for web-based communication", + "type": "string", + "const": "http" + } + ] } }, "properties": { @@ -49,9 +63,41 @@ "additionalProperties": { "type": "object", "properties": { + "type": { + "description": "The type of transport the mcp server is expecting. For http transport, only url (for now) is taken into account", + "$ref": "#/$definitions/TransportType", + "default": "stdio" + }, + "url": { + "description": "The URL endpoint for HTTP-based MCP servers", + "type": "string", + "default": "" + }, + "headers": { + "description": "HTTP headers to include when communicating with HTTP-based MCP servers", + "type": "object", + "additionalProperties": { + "type": "string" + }, + "default": {} + }, + "oauthScopes": { + "description": "Scopes with which oauth is done", + "type": "array", + "items": { + "type": "string" + }, + "default": [ + "openid", + "email", + "profile", + "offline_access" + ] + }, "command": { "description": "The command string used to initialize the mcp server", - "type": "string" + "type": "string", + "default": "" }, "args": { "description": "A list of arguments to be used to run the command with", @@ -159,6 +205,14 @@ "description": "Whether or not to include the legacy ~/.aws/amazonq/mcp.json in the agent\nYou can reference tools brought in by these servers as just as you would with the servers\nyou configure in the mcpServers field in this config", "type": "boolean", "default": false + }, + "model": { + "description": "The model ID to use for this agent. If not specified, uses the default model.", + "type": [ + "string", + "null" + ], + "default": null } }, "additionalProperties": false,