From 8733a516d135618d827faa57a915ec650ade219c Mon Sep 17 00:00:00 2001 From: xianwwu Date: Tue, 2 Sep 2025 09:27:08 -0700 Subject: [PATCH 01/71] fix: CTRL+C handling during multi-select, auto completion for /agent generate (#2741) * fixing bugs * formatting * fix: CTRL+C handling during multi-select, auto completion for /agent generate * set use legacy mcp config to false --------- Co-authored-by: Xian Wu --- crates/chat-cli/src/cli/chat/cli/profile.rs | 30 +++++++++++++++----- crates/chat-cli/src/cli/chat/conversation.rs | 1 + crates/chat-cli/src/cli/chat/prompt.rs | 1 + 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/cli/profile.rs b/crates/chat-cli/src/cli/chat/cli/profile.rs index 14150524bc..4b33fcb420 100644 --- a/crates/chat-cli/src/cli/chat/cli/profile.rs +++ b/crates/chat-cli/src/cli/chat/cli/profile.rs @@ -96,20 +96,31 @@ pub enum AgentSubcommand { Swap { name: Option }, } -fn prompt_mcp_server_selection(servers: &[McpServerInfo]) -> eyre::Result> { +fn prompt_mcp_server_selection(servers: &[McpServerInfo]) -> eyre::Result>> { let items: Vec = servers .iter() .map(|server| format!("{} ({})", server.name, server.config.command)) .collect(); - let selections = MultiSelect::new() + let selections = match MultiSelect::new() .with_prompt("Select MCP servers (use Space to toggle, Enter to confirm)") .items(&items) - .interact()?; - - let selected_servers: Vec<&McpServerInfo> = selections.iter().filter_map(|&i| servers.get(i)).collect(); + .interact_on_opt(&dialoguer::console::Term::stdout()) + { + Ok(sel) => sel, + Err(dialoguer::Error::IO(ref e)) if e.kind() == std::io::ErrorKind::Interrupted => { + return Ok(None); + }, + Err(e) => return Err(eyre::eyre!("Failed to get MCP server selection: {e}")), + }; + + let selected_servers: Vec<&McpServerInfo> = selections + .unwrap_or_default() + .iter() + .filter_map(|&i| servers.get(i)) + .collect(); - Ok(selected_servers) + Ok(Some(selected_servers)) } impl AgentSubcommand { @@ -280,7 +291,12 @@ impl AgentSubcommand { let selected_servers = if mcp_servers.is_empty() { Vec::new() } else { - prompt_mcp_server_selection(&mcp_servers).map_err(|e| ChatError::Custom(e.to_string().into()))? + match prompt_mcp_server_selection(&mcp_servers) + .map_err(|e| ChatError::Custom(e.to_string().into()))? + { + Some(servers) => servers, + None => return Ok(ChatState::default()), + } }; let mcp_servers_json = if !selected_servers.is_empty() { diff --git a/crates/chat-cli/src/cli/chat/conversation.rs b/crates/chat-cli/src/cli/chat/conversation.rs index 48fd13f991..803970ac9d 100644 --- a/crates/chat-cli/src/cli/chat/conversation.rs +++ b/crates/chat-cli/src/cli/chat/conversation.rs @@ -665,6 +665,7 @@ IMPORTANT: Return ONLY raw JSON with NO markdown formatting, NO code blocks, NO Your task is to generate an agent configuration file for an agent named '{}' with the following description: {}\n\n\ The configuration must conform to this JSON schema:\n{}\n\n\ We have a prepopulated template: {} \n\n\ +Please change the useLegacyMcpJson field to false. Please generate the prompt field using user provided description, and fill in the MCP tools that user has selected {}. Return only the JSON configuration, no additional text.", agent_name, agent_description, schema, prepopulated_content, selected_servers diff --git a/crates/chat-cli/src/cli/chat/prompt.rs b/crates/chat-cli/src/cli/chat/prompt.rs index a2ac52c7db..37fcec7739 100644 --- a/crates/chat-cli/src/cli/chat/prompt.rs +++ b/crates/chat-cli/src/cli/chat/prompt.rs @@ -67,6 +67,7 @@ pub const COMMANDS: &[&str] = &[ "/agent rename", "/agent set", "/agent schema", + "/agent generate", "/prompts", "/context", "/context help", From a8a0426eaff32ac9b1599ca1b1f1881ffd16054d Mon Sep 17 00:00:00 2001 From: kkashilk <93673379+kkashilk@users.noreply.github.com> Date: Tue, 2 Sep 2025 10:12:51 -0700 Subject: [PATCH 02/71] Fix calculation for num-lines contributed by q-cli (#2738) --- crates/chat-cli/src/cli/chat/line_tracker.rs | 9 +++- .../chat-cli/src/cli/chat/tools/fs_write.rs | 46 +++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/line_tracker.rs b/crates/chat-cli/src/cli/chat/line_tracker.rs index 80f640ecd9..1717d16fe7 100644 --- a/crates/chat-cli/src/cli/chat/line_tracker.rs +++ b/crates/chat-cli/src/cli/chat/line_tracker.rs @@ -13,6 +13,10 @@ pub struct FileLineTracker { pub before_fswrite_lines: usize, /// Line count after `fs_write` executes pub after_fswrite_lines: usize, + /// Lines added by agent in the current operation + pub lines_added_by_agent: usize, + /// Lines removed by agent in the current operation + pub lines_removed_by_agent: usize, /// Whether or not this is the first `fs_write` invocation pub is_first_write: bool, } @@ -23,6 +27,8 @@ impl Default for FileLineTracker { prev_fswrite_lines: 0, before_fswrite_lines: 0, after_fswrite_lines: 0, + lines_added_by_agent: 0, + lines_removed_by_agent: 0, is_first_write: true, } } @@ -34,7 +40,6 @@ impl FileLineTracker { } pub fn lines_by_agent(&self) -> isize { - let lines = (self.after_fswrite_lines as isize) - (self.before_fswrite_lines as isize); - lines.abs() + (self.lines_added_by_agent + self.lines_removed_by_agent) as isize } } diff --git a/crates/chat-cli/src/cli/chat/tools/fs_write.rs b/crates/chat-cli/src/cli/chat/tools/fs_write.rs index 6222b0cd57..284706f43a 100644 --- a/crates/chat-cli/src/cli/chat/tools/fs_write.rs +++ b/crates/chat-cli/src/cli/chat/tools/fs_write.rs @@ -247,11 +247,57 @@ impl FsWrite { let tracker = line_tracker.entry(path.to_string_lossy().to_string()).or_default(); tracker.after_fswrite_lines = after_lines; + + // Calculate actual lines added and removed by analyzing the diff + let (lines_added, lines_removed) = self.calculate_diff_lines(os).await?; + tracker.lines_added_by_agent = lines_added; + tracker.lines_removed_by_agent = lines_removed; + tracker.is_first_write = false; Ok(()) } + async fn calculate_diff_lines(&self, os: &Os) -> Result<(usize, usize)> { + let path = self.path(os); + + let result = match self { + FsWrite::Create { .. } => { + // For create operations, all lines in the new file are added + let new_content = os.fs.read_to_string(&path).await?; + let lines_added = new_content.lines().count(); + (lines_added, 0) + }, + FsWrite::StrReplace { old_str, new_str, .. } => { + // Use actual diff analysis for accurate line counting + let diff = similar::TextDiff::from_lines(old_str, new_str); + let mut lines_added = 0; + let mut lines_removed = 0; + + for change in diff.iter_all_changes() { + match change.tag() { + similar::ChangeTag::Insert => lines_added += 1, + similar::ChangeTag::Delete => lines_removed += 1, + similar::ChangeTag::Equal => {}, + } + } + (lines_added, lines_removed) + }, + FsWrite::Insert { new_str, .. } => { + // For insert operations, all lines in new_str are added + let lines_added = new_str.lines().count(); + (lines_added, 0) + }, + FsWrite::Append { new_str, .. } => { + // For append operations, all lines in new_str are added + let lines_added = new_str.lines().count(); + (lines_added, 0) + }, + }; + + Ok(result) + } + pub fn queue_description(&self, os: &Os, output: &mut impl Write) -> Result<()> { let cwd = os.env.current_dir()?; self.print_relative_path(os, output)?; From 5eb5d5519ab0953a12e5c06e2cfeb6eb57b88651 Mon Sep 17 00:00:00 2001 From: Kenneth Sanchez V Date: Wed, 3 Sep 2025 11:13:04 -0700 Subject: [PATCH 03/71] Update knowledge base directory path documentation (#2763) - Changed from ~/.q/knowledge_bases/ to ~/.aws/amazonq/knowledge_bases/ - Default agent uses q_cli_default/ (no alphanumeric suffix) - Custom agents use _/ format Co-authored-by: Kenneth S. --- docs/knowledge-management.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/knowledge-management.md b/docs/knowledge-management.md index a403092d4b..cb29cfd997 100644 --- a/docs/knowledge-management.md +++ b/docs/knowledge-management.md @@ -177,7 +177,7 @@ Each agent maintains its own isolated knowledge base, ensuring that knowledge co Knowledge bases are stored in the following directory structure: ``` -~/.q/knowledge_bases/ +~/.aws/amazonq/knowledge_bases/ ├── q_cli_default/ # Default agent knowledge base │ ├── contexts.json # Metadata for all contexts │ ├── context-id-1/ # Individual context storage @@ -186,13 +186,13 @@ Knowledge bases are stored in the following directory structure: │ └── context-id-2/ │ ├── data.json │ └── bm25_data.json -├── my-custom-agent/ # Custom agent knowledge base +├── my-custom-agent_/ # Custom agent knowledge base │ ├── contexts.json │ ├── context-id-3/ │ │ └── data.json │ └── context-id-4/ │ └── data.json -└── another-agent/ # Another agent's knowledge base +└── another-agent_/ # Another agent's knowledge base ├── contexts.json └── context-id-5/ └── data.json From 8038416a74270fcbed3e00d8fcf5be87a96d378b Mon Sep 17 00:00:00 2001 From: kiran-garre <137448023+kiran-garre@users.noreply.github.com> Date: Wed, 3 Sep 2025 15:45:58 -0700 Subject: [PATCH 04/71] docs: Update todo list docs for introspect (#2776) --- Cargo.lock | 77 ++++++------------- .../chat-cli/src/cli/chat/tools/introspect.rs | 5 ++ docs/todo-lists.md | 53 +++++++++++++ 3 files changed, 81 insertions(+), 54 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3a59446f65..c8bdaadb7c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -983,7 +983,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "234113d19d0d7d613b40e86fb654acf958910802bcceab913a4f9e7cda03b1a4" dependencies = [ "memchr", - "regex-automata 0.4.9", + "regex-automata", "serde", ] @@ -2257,8 +2257,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2" dependencies = [ "bit-set 0.5.3", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", + "regex-automata", + "regex-syntax", ] [[package]] @@ -2268,8 +2268,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e24cb5a94bcae1e5408b0effca5cd7172ea3c5755049c5f3af4cd283a165298" dependencies = [ "bit-set 0.8.0", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", + "regex-automata", + "regex-syntax", ] [[package]] @@ -2775,8 +2775,8 @@ dependencies = [ "aho-corasick", "bstr", "log", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", + "regex-automata", + "regex-syntax", ] [[package]] @@ -3454,7 +3454,7 @@ dependencies = [ "percent-encoding", "referencing", "regex", - "regex-syntax 0.8.5", + "regex-syntax", "reqwest", "serde", "serde_json", @@ -3617,7 +3617,7 @@ version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53304fff6ab1e597661eee37e42ea8c47a146fca280af902bb76bff8a896e523" dependencies = [ - "nu-ansi-term 0.50.1", + "nu-ansi-term", ] [[package]] @@ -3653,11 +3653,11 @@ checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" [[package]] name = "matchers" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" dependencies = [ - "regex-automata 0.1.10", + "regex-automata", ] [[package]] @@ -3899,16 +3899,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "nu-ansi-term" -version = "0.46.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" -dependencies = [ - "overload", - "winapi", -] - [[package]] name = "nu-ansi-term" version = "0.50.1" @@ -3924,7 +3914,7 @@ version = "0.104.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5185420e479f45c9afabfb534b26282d3de13b9b286ac16851221cc17d04def3" dependencies = [ - "nu-ansi-term 0.50.1", + "nu-ansi-term", "nu-engine", "nu-json", "nu-protocol", @@ -4454,12 +4444,6 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "owo-colors" version = "4.2.2" @@ -5110,17 +5094,8 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", -] - -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" -dependencies = [ - "regex-syntax 0.6.29", + "regex-automata", + "regex-syntax", ] [[package]] @@ -5131,7 +5106,7 @@ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.5", + "regex-syntax", ] [[package]] @@ -5140,12 +5115,6 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" -[[package]] -name = "regex-syntax" -version = "0.6.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" - [[package]] name = "regex-syntax" version = "0.8.5" @@ -5727,7 +5696,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fce6d5bc71503c9ec2337c80dc41f4fb2ac62fe52d6ab7500d899db19ae436f8" dependencies = [ "bitflags 2.9.1", - "nu-ansi-term 0.50.1", + "nu-ansi-term", "nu-color-config", ] @@ -6069,7 +6038,7 @@ dependencies = [ "once_cell", "onig", "plist", - "regex-syntax 0.8.5", + "regex-syntax", "serde", "serde_derive", "serde_json", @@ -6338,7 +6307,7 @@ dependencies = [ "rayon", "rayon-cond", "regex", - "regex-syntax 0.8.5", + "regex-syntax", "serde", "serde_json", "spm_precompiled", @@ -6600,15 +6569,15 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.19" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" dependencies = [ "matchers", - "nu-ansi-term 0.46.0", + "nu-ansi-term", "once_cell", "parking_lot", - "regex", + "regex-automata", "serde", "serde_json", "sharded-slab", diff --git a/crates/chat-cli/src/cli/chat/tools/introspect.rs b/crates/chat-cli/src/cli/chat/tools/introspect.rs index 9a8d7be9ad..d64c4a51d6 100644 --- a/crates/chat-cli/src/cli/chat/tools/introspect.rs +++ b/crates/chat-cli/src/cli/chat/tools/introspect.rs @@ -62,6 +62,9 @@ impl Introspect { documentation.push_str("\n\n--- docs/agent-file-locations.md ---\n"); documentation.push_str(include_str!("../../../../../../docs/agent-file-locations.md")); + documentation.push_str("\n\n--- docs/todo-lists.md ---\n"); + documentation.push_str(include_str!("../../../../../../docs/todo-lists.md")); + documentation.push_str("\n\n--- CONTRIBUTING.md ---\n"); documentation.push_str(include_str!("../../../../../../CONTRIBUTING.md")); @@ -93,6 +96,8 @@ impl Introspect { documentation .push_str("• Experiments: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/experiments.md\n"); documentation.push_str("• Agent File Locations: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/agent-file-locations.md\n"); + documentation + .push_str("• Todo Lists: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/todo-lists.md\n"); documentation .push_str("• Contributing: https://github.com/aws/amazon-q-developer-cli/blob/main/CONTRIBUTING.md\n"); diff --git a/docs/todo-lists.md b/docs/todo-lists.md index 2db9d93910..8f218d3daa 100644 --- a/docs/todo-lists.md +++ b/docs/todo-lists.md @@ -91,3 +91,56 @@ If lists exist but won't load: 1. **Check permissions**: Ensure read access to `.amazonq/cli-todo-lists/` 2. **Verify format**: Lists should be valid JSON files 3. **Check file integrity**: Corrupted files may prevent loading + +## `todo_list` vs. `/todos` +The `todo_list` tool is specifically for the model to call. The model is allowed to create TODO lists, mark tasks as complete, add/remove +tasks, load TODO lists with a given ID (which are automatically provided when resuming TODO lists), and search for existing TODO lists. + +The `/todos` command is for the user to manage existing TODO lists created by the model. The user can view, resume, and delete TODO lists +by using the appropriate subcommand and selecting the TODO list to perform the action on. + +## Examples +#### Asking Q to make a TODO list: +``` +> Make a todo list with 3 read-only tasks. + +> I'll create a todo list with 3 read-only tasks for you. + +🛠️ Using tool: todo_list (trusted) + ⋮ + ● TODO: +[ ] Review project documentation +[ ] Check system status +[ ] Read latest updates + ⋮ + ● Completed in 0.4s +``` + +#### Selecting a TODO list to view: +``` +> /todos view + +? Select a to-do list to view: › +❯ ✗ Unfinished todo list (0/3) + ✔ Completed todo list (3/3) +``` + +#### Resuming a TODO list (after selecting): +``` +> /todos resume + +⟳ Resuming: Read-only tasks for information gathering + +🛠️ Using tool: todo_list (trusted) + ⋮ + ● TODO: +[x] Review project documentation +[ ] Check system status +[ ] Read latest updates + ⋮ + ● Completed in 0.1s + ``` + + + + From 46f1f326c62a6ad7b3ec733fbcb5b7b75d17cb9c Mon Sep 17 00:00:00 2001 From: abhraina-aws Date: Wed, 3 Sep 2025 16:09:01 -0700 Subject: [PATCH 05/71] feat: added tangent & introspect docs & provided to introspect (#2775) --- .../chat-cli/src/cli/chat/tools/introspect.rs | 11 ++ docs/introspect-tool.md | 66 +++++++ docs/tangent-mode.md | 165 ++++++++++++++++++ 3 files changed, 242 insertions(+) create mode 100644 docs/introspect-tool.md create mode 100644 docs/tangent-mode.md diff --git a/crates/chat-cli/src/cli/chat/tools/introspect.rs b/crates/chat-cli/src/cli/chat/tools/introspect.rs index d64c4a51d6..0b431ae4d3 100644 --- a/crates/chat-cli/src/cli/chat/tools/introspect.rs +++ b/crates/chat-cli/src/cli/chat/tools/introspect.rs @@ -62,6 +62,12 @@ impl Introspect { documentation.push_str("\n\n--- docs/agent-file-locations.md ---\n"); documentation.push_str(include_str!("../../../../../../docs/agent-file-locations.md")); + documentation.push_str("\n\n--- docs/tangent-mode.md ---\n"); + documentation.push_str(include_str!("../../../../../../docs/tangent-mode.md")); + + documentation.push_str("\n\n--- docs/introspect-tool.md ---\n"); + documentation.push_str(include_str!("../../../../../../docs/introspect-tool.md")); + documentation.push_str("\n\n--- docs/todo-lists.md ---\n"); documentation.push_str(include_str!("../../../../../../docs/todo-lists.md")); @@ -96,6 +102,11 @@ impl Introspect { documentation .push_str("• Experiments: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/experiments.md\n"); documentation.push_str("• Agent File Locations: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/agent-file-locations.md\n"); + documentation + .push_str("• Tangent Mode: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/tangent-mode.md\n"); + documentation.push_str( + "• Introspect Tool: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/introspect-tool.md\n", + ); documentation .push_str("• Todo Lists: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/todo-lists.md\n"); documentation diff --git a/docs/introspect-tool.md b/docs/introspect-tool.md new file mode 100644 index 0000000000..53f6b50029 --- /dev/null +++ b/docs/introspect-tool.md @@ -0,0 +1,66 @@ +# Introspect Tool + +The introspect tool provides Q CLI with self-awareness, automatically answering questions about Q CLI's features, commands, and functionality using official documentation. + +## How It Works + +The introspect tool activates automatically when you ask Q CLI questions like: +- "How do I save conversations with Q CLI?" +- "What experimental features does Q CLI have?" +- "Can Q CLI read files?" + +## What It Provides + +- **Command Help**: Real-time help for all slash commands (`/save`, `/load`, etc.) +- **Documentation**: Access to README, built-in tools, experiments, and feature guides +- **Settings**: All configuration options and how to change them +- **GitHub Links**: Direct links to official documentation for verification + +## Important Limitations + +**Hallucination Risk**: Despite safeguards, the AI may occasionally provide inaccurate information or make assumptions. **Always verify important details** using the GitHub documentation links provided in responses. + +## Usage Examples + +``` +> How do I save conversations with Q CLI? +You can save conversations using `/save` or `/save name`. +Load them later with `/load`. + +> What experimental features does Q CLI have? +Q CLI offers Tangent Mode and Thinking Mode. +Use `/experiment` to enable them. + +> Can Q CLI read and write files? +Yes, Q CLI has fs_read, fs_write, and execute_bash tools +for file operations. +``` + +## Auto-Tangent Mode + +Enable automatic tangent mode for Q CLI help questions: + +```bash +q settings introspect.tangentMode true +``` + +This keeps help separate from your main conversation. + +## Best Practices + +1. **Be explicit**: Ask "How does Q CLI handle files?" not "How do you handle files?" +2. **Verify information**: Check the GitHub links provided in responses +3. **Use proper syntax**: Reference commands with `/` (e.g., `/save`) +4. **Enable auto-tangent**: Keep help isolated from main conversations + +## Configuration + +```bash +# Enable auto-tangent for introspect questions +q settings introspect.tangentMode true +``` + +## Related Features + +- **Tangent Mode**: Isolate help conversations +- **Experiments**: Enable experimental features with `/experiment` diff --git a/docs/tangent-mode.md b/docs/tangent-mode.md new file mode 100644 index 0000000000..0ee785c99b --- /dev/null +++ b/docs/tangent-mode.md @@ -0,0 +1,165 @@ +# Tangent Mode + +Tangent mode creates conversation checkpoints, allowing you to explore side topics without disrupting your main conversation flow. Enter tangent mode, ask questions or explore ideas, then return to your original conversation exactly where you left off. + +## Enabling Tangent Mode + +Tangent mode is experimental and must be enabled: + +**Via Experiment Command**: Run `/experiment` and select tangent mode from the list. + +**Via Settings**: `q settings chat.enableTangentMode true` + +## Basic Usage + +### Enter Tangent Mode +Use `/tangent` or Ctrl+T: +``` +> /tangent +Created a conversation checkpoint (↯). Use ctrl + t or /tangent to restore the conversation later. +``` + +### In Tangent Mode +You'll see a yellow `↯` symbol in your prompt: +``` +↯ > What is the difference between async and sync functions? +``` + +### Exit Tangent Mode +Use `/tangent` or Ctrl+T again: +``` +↯ > /tangent +Restored conversation from checkpoint (↯). - Returned to main conversation. +``` + +## Usage Examples + +### Example 1: Exploring Alternatives +``` +> I need to process a large CSV file in Python. What's the best approach? + +I recommend using pandas for CSV processing... + +> /tangent +Created a conversation checkpoint (↯). + +↯ > What about using the csv module instead of pandas? + +The csv module is lighter weight... + +↯ > /tangent +Restored conversation from checkpoint (↯). + +> Thanks! I'll go with pandas. Can you show me error handling? +``` + +### Example 2: Getting Q CLI Help +``` +> Help me write a deployment script + +I can help you create a deployment script... + +> /tangent +Created a conversation checkpoint (↯). + +↯ > What Q CLI commands are available for file operations? + +Q CLI provides fs_read, fs_write, execute_bash... + +↯ > /tangent +Restored conversation from checkpoint (↯). + +> It's a Node.js application for AWS +``` + +### Example 3: Clarifying Requirements +``` +> I need to optimize this SQL query + +Could you share the query you'd like to optimize? + +> /tangent +Created a conversation checkpoint (↯). + +↯ > What information do you need to help optimize a query? + +To optimize SQL queries effectively, I need: +1. The current query +2. Table schemas and indexes... + +↯ > /tangent +Restored conversation from checkpoint (↯). + +> Here's my query: SELECT * FROM orders... +``` + +## Configuration + +### Keyboard Shortcut +```bash +# Change shortcut key (default: t) +q settings chat.tangentModeKey y +``` + +### Auto-Tangent for Introspect +```bash +# Auto-enter tangent mode for Q CLI help questions +q settings introspect.tangentMode true +``` + +## Visual Indicators + +- **Normal mode**: `> ` (magenta) +- **Tangent mode**: `↯ > ` (yellow ↯ + magenta) +- **With profile**: `[dev] ↯ > ` (cyan + yellow ↯ + magenta) + +## Best Practices + +### When to Use Tangent Mode +- Asking clarifying questions about the current topic +- Exploring alternative approaches before deciding +- Getting help with Q CLI commands or features +- Testing understanding of concepts + +### When NOT to Use +- Completely unrelated topics (start new conversation) +- Long, complex discussions (use regular flow) +- When you want the side discussion in main context + +### Tips +1. **Keep tangents focused** - Brief explorations, not extended discussions +2. **Return promptly** - Don't forget you're in tangent mode +3. **Use for clarification** - Perfect for "wait, what does X mean?" questions +4. **Experiment safely** - Test ideas without affecting main conversation + +## Limitations + +- Tangent conversations are discarded when you exit +- Only one level of tangent supported (no nested tangents) +- Experimental feature that may change or be removed +- Must be explicitly enabled + +## Troubleshooting + +### Tangent Mode Not Working +```bash +# Enable via experiment (select from list) +/experiment + +# Or enable via settings +q settings chat.enableTangentMode true +``` + +### Keyboard Shortcut Not Working +```bash +# Check/reset shortcut key +q settings chat.tangentModeKey t +``` + +### Lost in Tangent Mode +Look for the `↯` symbol in your prompt. Use `/tangent` to exit and return to main conversation. + +## Related Features + +- **Introspect**: Q CLI help (auto-enters tangent if configured) +- **Experiments**: Manage experimental features with `/experiment` From 3aa06ebcfde82da531a95faa87e0ffd75a3fbb71 Mon Sep 17 00:00:00 2001 From: kkashilk <93673379+kkashilk@users.noreply.github.com> Date: Thu, 4 Sep 2025 12:24:27 -0700 Subject: [PATCH 06/71] feat: implement persistent CLI history with file storage (#2769) - Add Drop trait to InputSource for automatic history saving - Replace DefaultHistory with FileHistory for persistence - Store history in ~/.aws/amazonq/cli_history - Refactor ChatHinter to use rustyline's built-in history search - Remove manual history tracking in favor of rustyline's implementation - Add history loading on startup with error handling - Clean up unused hinter history update methods --- crates/chat-cli/src/cli/chat/input_source.rs | 41 +++++++-- crates/chat-cli/src/cli/chat/prompt.rs | 94 +++++++++++--------- crates/chat-cli/src/util/directories.rs | 5 ++ 3 files changed, 94 insertions(+), 46 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/input_source.rs b/crates/chat-cli/src/cli/chat/input_source.rs index 5d88abf6f3..0c0830852c 100644 --- a/crates/chat-cli/src/cli/chat/input_source.rs +++ b/crates/chat-cli/src/cli/chat/input_source.rs @@ -31,11 +31,33 @@ mod inner { } } +impl Drop for InputSource { + fn drop(&mut self) { + self.save_history().unwrap(); + } +} impl InputSource { pub fn new(os: &Os, sender: PromptQuerySender, receiver: PromptQueryResponseReceiver) -> Result { Ok(Self(inner::Inner::Readline(rl(os, sender, receiver)?))) } + /// Save history to file + pub fn save_history(&mut self) -> Result<()> { + if let inner::Inner::Readline(rl) = &mut self.0 { + if let Some(helper) = rl.helper() { + let history_path = helper.get_history_path(); + + // Create directory if it doesn't exist + if let Some(parent) = history_path.parent() { + std::fs::create_dir_all(parent)?; + } + + rl.append_history(&history_path)?; + } + } + Ok(()) + } + #[cfg(unix)] pub fn put_skim_command_selector( &mut self, @@ -78,12 +100,9 @@ impl InputSource { let curr_line = rl.readline(prompt); match curr_line { Ok(line) => { - let _ = rl.add_history_entry(line.as_str()); - - if let Some(helper) = rl.helper_mut() { - helper.update_hinter_history(&line); + if Self::should_append_history(&line) { + let _ = rl.add_history_entry(line.as_str()); } - Ok(Some(line)) }, Err(ReadlineError::Interrupted | ReadlineError::Eof) => Ok(None), @@ -97,6 +116,18 @@ impl InputSource { } } + fn should_append_history(line: &str) -> bool { + let trimmed = line.trim().to_lowercase(); + if trimmed.is_empty() { + return false; + } + + if matches!(trimmed.as_str(), "y" | "n" | "t") { + return false; + } + true + } + // We're keeping this method for potential future use #[allow(dead_code)] pub fn set_buffer(&mut self, content: &str) { diff --git a/crates/chat-cli/src/cli/chat/prompt.rs b/crates/chat-cli/src/cli/chat/prompt.rs index 37fcec7739..77fff472a1 100644 --- a/crates/chat-cli/src/cli/chat/prompt.rs +++ b/crates/chat-cli/src/cli/chat/prompt.rs @@ -1,5 +1,6 @@ use std::borrow::Cow; use std::cell::RefCell; +use std::path::PathBuf; use eyre::Result; use rustyline::completion::{ @@ -13,7 +14,10 @@ use rustyline::highlight::{ Highlighter, }; use rustyline::hint::Hinter as RustylineHinter; -use rustyline::history::DefaultHistory; +use rustyline::history::{ + FileHistory, + SearchDirection, +}; use rustyline::validate::{ ValidationContext, ValidationResult, @@ -44,6 +48,7 @@ use super::tool_manager::{ }; use crate::database::settings::Setting; use crate::os::Os; +use crate::util::directories::chat_cli_bash_history_path; pub const COMMANDS: &[&str] = &[ "/clear", @@ -262,31 +267,26 @@ impl Completer for ChatCompleter { /// Custom hinter that provides shadowtext suggestions pub struct ChatHinter { - /// Command history for providing suggestions based on past commands - history: Vec, /// Whether history-based hints are enabled history_hints_enabled: bool, + history_path: PathBuf, } impl ChatHinter { /// Creates a new ChatHinter instance - pub fn new(history_hints_enabled: bool) -> Self { + pub fn new(history_hints_enabled: bool, history_path: PathBuf) -> Self { Self { - history: Vec::new(), history_hints_enabled, + history_path, } } - /// Updates the history with a new command - pub fn update_history(&mut self, command: &str) { - let command = command.trim(); - if !command.is_empty() && !command.contains('\n') && !command.contains('\r') { - self.history.push(command.to_string()); - } + pub fn get_history_path(&self) -> PathBuf { + self.history_path.clone() } - /// Finds the best hint for the current input - fn find_hint(&self, line: &str) -> Option { + /// Finds the best hint for the current input using rustyline's history + fn find_hint(&self, line: &str, ctx: &Context<'_>) -> Option { // If line is empty, no hint if line.is_empty() { return None; @@ -300,13 +300,20 @@ impl ChatHinter { .map(|cmd| cmd[line.len()..].to_string()); } - // Try to find a hint from history if history hints are enabled + // Try to find a hint from rustyline's history if history hints are enabled if self.history_hints_enabled { - return self.history - .iter() - .rev() // Start from most recent - .find(|cmd| cmd.starts_with(line) && cmd.len() > line.len()) - .map(|cmd| cmd[line.len()..].to_string()); + let history = ctx.history(); + let history_len = history.len(); + if history_len == 0 { + return None; + } + + if let Ok(Some(search_result)) = history.starts_with(line, history_len - 1, SearchDirection::Reverse) { + let entry = search_result.entry.to_string(); + if entry.len() > line.len() { + return Some(entry[line.len()..].to_string()); + } + } } None @@ -316,13 +323,13 @@ impl ChatHinter { impl RustylineHinter for ChatHinter { type Hint = String; - fn hint(&self, line: &str, pos: usize, _ctx: &Context<'_>) -> Option { + fn hint(&self, line: &str, pos: usize, ctx: &Context<'_>) -> Option { // Only provide hints when cursor is at the end of the line if pos < line.len() { return None; } - self.find_hint(line) + self.find_hint(line, ctx) } } @@ -363,9 +370,8 @@ pub struct ChatHelper { } impl ChatHelper { - /// Updates the history of the ChatHinter with a new command - pub fn update_hinter_history(&mut self, command: &str) { - self.hinter.update_history(command); + pub fn get_history_path(&self) -> PathBuf { + self.hinter.get_history_path() } } @@ -426,7 +432,7 @@ pub fn rl( os: &Os, sender: PromptQuerySender, receiver: PromptQueryResponseReceiver, -) -> Result> { +) -> Result> { let edit_mode = match os.database.settings.get_string(Setting::ChatEditMode).as_deref() { Some("vi" | "vim") => EditMode::Vi, _ => EditMode::Emacs, @@ -437,21 +443,30 @@ pub fn rl( .edit_mode(edit_mode) .build(); - // Default to disabled if setting doesn't exist let history_hints_enabled = os .database .settings .get_bool(Setting::ChatEnableHistoryHints) .unwrap_or(false); + + let history_path = chat_cli_bash_history_path(os)?; + let h = ChatHelper { completer: ChatCompleter::new(sender, receiver), - hinter: ChatHinter::new(history_hints_enabled), + hinter: ChatHinter::new(history_hints_enabled, history_path), validator: MultiLineValidator, }; let mut rl = Editor::with_config(config)?; rl.set_helper(Some(h)); + // Load history from ~/.aws/amazonq/cli_history + if let Err(e) = rl.load_history(&rl.helper().unwrap().get_history_path()) { + if !matches!(e, ReadlineError::Io(ref io_err) if io_err.kind() == std::io::ErrorKind::NotFound) { + eprintln!("Warning: Failed to load history: {}", e); + } + } + // Add custom keybinding for Alt+Enter to insert a newline rl.bind_sequence( KeyEvent(KeyCode::Enter, Modifiers::ALT), @@ -487,6 +502,7 @@ pub fn rl( mod tests { use crossterm::style::Stylize; use rustyline::highlight::Highlighter; + use rustyline::history::DefaultHistory; use super::*; @@ -537,7 +553,7 @@ mod tests { let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), + hinter: ChatHinter::new(true, PathBuf::new()), validator: MultiLineValidator, }; @@ -553,7 +569,7 @@ mod tests { let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), + hinter: ChatHinter::new(true, PathBuf::new()), validator: MultiLineValidator, }; @@ -569,7 +585,7 @@ mod tests { let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), + hinter: ChatHinter::new(true, PathBuf::new()), validator: MultiLineValidator, }; @@ -585,7 +601,7 @@ mod tests { let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), + hinter: ChatHinter::new(true, PathBuf::new()), validator: MultiLineValidator, }; @@ -604,7 +620,7 @@ mod tests { let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), + hinter: ChatHinter::new(true, PathBuf::new()), validator: MultiLineValidator, }; @@ -620,7 +636,7 @@ mod tests { let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(1); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), + hinter: ChatHinter::new(true, PathBuf::new()), validator: MultiLineValidator, }; @@ -635,7 +651,7 @@ mod tests { let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(1); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), + hinter: ChatHinter::new(true, PathBuf::new()), validator: MultiLineValidator, }; @@ -650,7 +666,7 @@ mod tests { let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(1); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), + hinter: ChatHinter::new(true, PathBuf::new()), validator: MultiLineValidator, }; @@ -664,7 +680,7 @@ mod tests { #[test] fn test_chat_hinter_command_hint() { - let hinter = ChatHinter::new(true); + let hinter = ChatHinter::new(true, PathBuf::new()); // Test hint for a command let line = "/he"; @@ -694,11 +710,7 @@ mod tests { #[test] fn test_chat_hinter_history_hint_disabled() { - let mut hinter = ChatHinter::new(false); - - // Add some history - hinter.update_history("Hello, world!"); - hinter.update_history("How are you?"); + let hinter = ChatHinter::new(false, PathBuf::new()); // Test hint from history - should be None since history hints are disabled let line = "How"; diff --git a/crates/chat-cli/src/util/directories.rs b/crates/chat-cli/src/util/directories.rs index 50091ce87c..a34b71b6f1 100644 --- a/crates/chat-cli/src/util/directories.rs +++ b/crates/chat-cli/src/util/directories.rs @@ -44,6 +44,7 @@ type Result = std::result::Result; const WORKSPACE_AGENT_DIR_RELATIVE: &str = ".amazonq/cli-agents"; const GLOBAL_AGENT_DIR_RELATIVE_TO_HOME: &str = ".aws/amazonq/cli-agents"; +const CLI_BASH_HISTORY_PATH: &str = ".aws/amazonq/.cli_bash_history"; /// The directory of the users home /// @@ -158,6 +159,10 @@ pub fn chat_legacy_global_mcp_config(os: &Os) -> Result { Ok(home_dir(os)?.join(".aws").join("amazonq").join("mcp.json")) } +pub fn chat_cli_bash_history_path(os: &Os) -> Result { + Ok(home_dir(os)?.join(CLI_BASH_HISTORY_PATH)) +} + /// Legacy workspace MCP server config path pub fn chat_legacy_workspace_mcp_config(os: &Os) -> Result { let cwd = os.env.current_dir()?; From f1a76645ee5e04fd01973f7e0ab5cda24279a7fc Mon Sep 17 00:00:00 2001 From: evanliu048 Date: Thu, 4 Sep 2025 14:23:49 -0700 Subject: [PATCH 07/71] chore: Skip sending profileArn when using custom endpoints (#2777) * comment profile set * comment profile in apiclient * add a helper func * fix compile issue * remove dead code tag --- crates/chat-cli/src/api_client/mod.rs | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/crates/chat-cli/src/api_client/mod.rs b/crates/chat-cli/src/api_client/mod.rs index caa2507f2d..f21b448b77 100644 --- a/crates/chat-cli/src/api_client/mod.rs +++ b/crates/chat-cli/src/api_client/mod.rs @@ -193,12 +193,19 @@ impl ApiClient { }, } - let profile = match database.get_auth_profile() { - Ok(profile) => profile, - Err(err) => { - error!("Failed to get auth profile: {err}"); - None - }, + // Check if using custom endpoint + let use_profile = !Self::is_custom_endpoint(database); + let profile = if use_profile { + match database.get_auth_profile() { + Ok(profile) => profile, + Err(err) => { + error!("Failed to get auth profile: {err}"); + None + }, + } + } else { + debug!("Custom endpoint detected, skipping profile ARN"); + None }; Ok(Self { @@ -598,6 +605,11 @@ impl ApiClient { self.mock_client = Some(Arc::new(Mutex::new(mock.into_iter()))); } + + // Add a helper method to check if using non-default endpoint + fn is_custom_endpoint(database: &Database) -> bool { + database.settings.get(Setting::ApiCodeWhispererService).is_some() + } } fn timeout_config(database: &Database) -> TimeoutConfig { From 1813bb775d92d0876488c4494f28f799e089dfe4 Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Thu, 4 Sep 2025 18:06:10 -0700 Subject: [PATCH 08/71] chore(mcp): migrate to rmcp (#2700) * client struct definition * clean up unused code * adds mechanism for checking if server is alive * prefetches prompts if applicable * fixes agent swap fixes agent swap * only applies process group leader promo for unix * removes unused import for windows * renames abstractions for different stages of mcp config --- Cargo.lock | 97 +- Cargo.toml | 1 + crates/chat-cli/Cargo.toml | 7 +- crates/chat-cli/src/cli/chat/cli/clear.rs | 1 + crates/chat-cli/src/cli/chat/cli/compact.rs | 6 + crates/chat-cli/src/cli/chat/cli/context.rs | 4 + crates/chat-cli/src/cli/chat/cli/editor.rs | 2 + crates/chat-cli/src/cli/chat/cli/hooks.rs | 1 + crates/chat-cli/src/cli/chat/cli/mcp.rs | 4 + crates/chat-cli/src/cli/chat/cli/model.rs | 3 +- crates/chat-cli/src/cli/chat/cli/persist.rs | 8 +- crates/chat-cli/src/cli/chat/cli/profile.rs | 17 +- crates/chat-cli/src/cli/chat/cli/prompts.rs | 53 +- crates/chat-cli/src/cli/chat/cli/subscribe.rs | 2 + crates/chat-cli/src/cli/chat/cli/tools.rs | 4 + crates/chat-cli/src/cli/chat/cli/usage.rs | 6 + crates/chat-cli/src/cli/chat/conversation.rs | 60 +- .../chat-cli/src/cli/chat/error_formatter.rs | 148 -- crates/chat-cli/src/cli/chat/mod.rs | 11 +- .../chat-cli/src/cli/chat/server_messenger.rs | 86 +- crates/chat-cli/src/cli/chat/tool_manager.rs | 340 ++-- .../src/cli/chat/tools/custom_tool.rs | 259 +-- crates/chat-cli/src/cli/mod.rs | 2 +- crates/chat-cli/src/mcp_client/client.rs | 1423 +++++------------ crates/chat-cli/src/mcp_client/error.rs | 66 - .../src/mcp_client/facilitator_types.rs | 248 --- crates/chat-cli/src/mcp_client/messenger.rs | 69 +- crates/chat-cli/src/mcp_client/mod.rs | 9 - crates/chat-cli/src/mcp_client/server.rs | 311 ---- .../src/mcp_client/transport/base_protocol.rs | 108 -- .../chat-cli/src/mcp_client/transport/mod.rs | 57 - .../src/mcp_client/transport/stdio.rs | 285 ---- .../src/mcp_client/transport/websocket.rs | 0 crates/chat-cli/src/util/mod.rs | 1 - crates/chat-cli/src/util/process/mod.rs | 11 - crates/chat-cli/src/util/process/unix.rs | 64 - crates/chat-cli/src/util/process/windows.rs | 120 -- .../chat-cli/test_mcp_server/test_server.rs | 340 ---- 38 files changed, 890 insertions(+), 3344 deletions(-) delete mode 100644 crates/chat-cli/src/cli/chat/error_formatter.rs delete mode 100644 crates/chat-cli/src/mcp_client/error.rs delete mode 100644 crates/chat-cli/src/mcp_client/facilitator_types.rs delete mode 100644 crates/chat-cli/src/mcp_client/server.rs delete mode 100644 crates/chat-cli/src/mcp_client/transport/base_protocol.rs delete mode 100644 crates/chat-cli/src/mcp_client/transport/mod.rs delete mode 100644 crates/chat-cli/src/mcp_client/transport/stdio.rs delete mode 100644 crates/chat-cli/src/mcp_client/transport/websocket.rs delete mode 100644 crates/chat-cli/src/util/process/mod.rs delete mode 100644 crates/chat-cli/src/util/process/unix.rs delete mode 100644 crates/chat-cli/src/util/process/windows.rs delete mode 100644 crates/chat-cli/test_mcp_server/test_server.rs diff --git a/Cargo.lock b/Cargo.lock index c8bdaadb7c..c7f5d0043a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1062,7 +1062,7 @@ version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9225bdcf4e4a9a4c08bf16607908eb2fbf746828d5e0b5e019726dbf6571f201" dependencies = [ - "darling", + "darling 0.20.11", "proc-macro2", "quote", "syn 2.0.104", @@ -1280,6 +1280,7 @@ dependencies = [ "regex", "reqwest", "ring", + "rmcp", "rusqlite", "rustls 0.23.31", "rustls-native-certs 0.8.1", @@ -1812,8 +1813,18 @@ version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" dependencies = [ - "darling_core", - "darling_macro", + "darling_core 0.20.11", + "darling_macro 0.20.11", +] + +[[package]] +name = "darling" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" +dependencies = [ + "darling_core 0.21.3", + "darling_macro 0.21.3", ] [[package]] @@ -1830,13 +1841,38 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "darling_core" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.104", +] + [[package]] name = "darling_macro" version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ - "darling_core", + "darling_core 0.20.11", + "quote", + "syn 2.0.104", +] + +[[package]] +name = "darling_macro" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" +dependencies = [ + "darling_core 0.21.3", "quote", "syn 2.0.104", ] @@ -1902,7 +1938,7 @@ version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" dependencies = [ - "darling", + "darling 0.20.11", "proc-macro2", "quote", "syn 2.0.104", @@ -4683,6 +4719,20 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "process-wrap" +version = "8.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3ef4f2f0422f23a82ec9f628ea2acd12871c81a9362b02c43c1aa86acfc3ba1" +dependencies = [ + "futures", + "indexmap", + "nix 0.30.1", + "tokio", + "tracing", + "windows 0.61.3", +] + [[package]] name = "procfs" version = "0.17.0" @@ -5184,6 +5234,42 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rmcp" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7dd163d26e254725137b7933e4ba042ea6bf2d756a4260559aaea8b6ad4c27e" +dependencies = [ + "base64 0.22.1", + "chrono", + "futures", + "paste", + "pin-project-lite", + "process-wrap", + "rmcp-macros", + "schemars", + "serde", + "serde_json", + "thiserror 2.0.14", + "tokio", + "tokio-stream", + "tokio-util", + "tracing", +] + +[[package]] +name = "rmcp-macros" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a43bb4c90a0d4b12f7315eb681a73115d335a2cee81322eca96f3467fe4cd06f" +dependencies = [ + "darling 0.21.3", + "proc-macro2", + "quote", + "serde_json", + "syn 2.0.104", +] + [[package]] name = "roxmltree" version = "0.14.1" @@ -5461,6 +5547,7 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "82d20c4491bc164fa2f6c5d44565947a52ad80b9505d8e36f8d54c27c739fcd0" dependencies = [ + "chrono", "dyn-clone", "ref-cast", "schemars_derive", diff --git a/Cargo.toml b/Cargo.toml index 48d9e5d937..0bbee837a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -129,6 +129,7 @@ winnow = "=0.6.2" winreg = "0.55.0" schemars = "1.0.4" jsonschema = "0.30.0" +rmcp = { version = "0.6.0", features = ["client", "transport-child-process"] } [workspace.lints.rust] future_incompatible = "warn" diff --git a/crates/chat-cli/Cargo.toml b/crates/chat-cli/Cargo.toml index 1b9b78d869..37911de5f1 100644 --- a/crates/chat-cli/Cargo.toml +++ b/crates/chat-cli/Cargo.toml @@ -15,12 +15,6 @@ workspace = true default = [] wayland = ["arboard/wayland-data-control"] -[[bin]] -name = "test_mcp_server" -path = "test_mcp_server/test_server.rs" -test = true -doc = false - [dependencies] amzn-codewhisperer-client.workspace = true amzn-codewhisperer-streaming-client.workspace = true @@ -123,6 +117,7 @@ whoami.workspace = true winnow.workspace = true schemars.workspace = true jsonschema.workspace = true +rmcp.workspace = true [target.'cfg(unix)'.dependencies] nix.workspace = true diff --git a/crates/chat-cli/src/cli/chat/cli/clear.rs b/crates/chat-cli/src/cli/chat/cli/clear.rs index 7f2bd9d9ae..8da854abea 100644 --- a/crates/chat-cli/src/cli/chat/cli/clear.rs +++ b/crates/chat-cli/src/cli/chat/cli/clear.rs @@ -17,6 +17,7 @@ use crate::cli::chat::{ #[deny(missing_docs)] #[derive(Debug, PartialEq, Args)] +/// Arguments for the clear command that erases conversation history and context. pub struct ClearArgs; impl ClearArgs { diff --git a/crates/chat-cli/src/cli/chat/cli/compact.rs b/crates/chat-cli/src/cli/chat/cli/compact.rs index 79e5727ef4..27b9b1a465 100644 --- a/crates/chat-cli/src/cli/chat/cli/compact.rs +++ b/crates/chat-cli/src/cli/chat/cli/compact.rs @@ -31,6 +31,12 @@ How it works Compaction will be automatically performed whenever the context window overflows. To disable this behavior, run: `q settings chat.disableAutoCompaction true`" )] +/// Arguments for the `/compact` command that summarizes conversation history to free up context +/// space. +/// +/// This command creates an AI-generated summary of the conversation while preserving essential +/// information, code, and tool executions. It's useful for long-running conversations that +/// may reach memory constraints. pub struct CompactArgs { /// The prompt to use when generating the summary prompt: Vec, diff --git a/crates/chat-cli/src/cli/chat/cli/context.rs b/crates/chat-cli/src/cli/chat/cli/context.rs index df008330cf..025ef440a0 100644 --- a/crates/chat-cli/src/cli/chat/cli/context.rs +++ b/crates/chat-cli/src/cli/chat/cli/context.rs @@ -38,6 +38,7 @@ Notes: • Agent rules apply only to the current agent • Context changes are NOT preserved between chat sessions. To make these changes permanent, edit the agent config file." )] +/// Subcommands for managing context rules and files in Amazon Q chat sessions pub enum ContextSubcommand { /// Display the context rule configuration and matched files Show { @@ -52,17 +53,20 @@ pub enum ContextSubcommand { #[arg(short, long)] force: bool, #[arg(required = true)] + /// Paths or glob patterns to remove from context rules paths: Vec, }, /// Remove specified rules #[command(alias = "rm")] Remove { + /// Paths or glob patterns to remove from context rules #[arg(required = true)] paths: Vec, }, /// Remove all rules Clear, #[command(hide = true)] + /// Display information about agent format hooks (deprecated) Hooks, } diff --git a/crates/chat-cli/src/cli/chat/cli/editor.rs b/crates/chat-cli/src/cli/chat/cli/editor.rs index 53ddc54ddf..fdf8d0d501 100644 --- a/crates/chat-cli/src/cli/chat/cli/editor.rs +++ b/crates/chat-cli/src/cli/chat/cli/editor.rs @@ -15,7 +15,9 @@ use crate::cli::chat::{ #[deny(missing_docs)] #[derive(Debug, PartialEq, Args)] +/// Command-line arguments for the editor functionality pub struct EditorArgs { + /// Initial text to populate in the editor pub initial_text: Vec, } diff --git a/crates/chat-cli/src/cli/chat/cli/hooks.rs b/crates/chat-cli/src/cli/chat/cli/hooks.rs index fab1989464..38763285c6 100644 --- a/crates/chat-cli/src/cli/chat/cli/hooks.rs +++ b/crates/chat-cli/src/cli/chat/cli/hooks.rs @@ -286,6 +286,7 @@ Notes: • 'conversation_start' hooks run on the first user prompt and are attached once to the conversation history sent to Amazon Q • 'per_prompt' hooks run on each user prompt and are attached to the prompt, but are not stored in conversation history" )] +/// Arguments for the hooks command that displays configured context hooks pub struct HooksArgs; impl HooksArgs { diff --git a/crates/chat-cli/src/cli/chat/cli/mcp.rs b/crates/chat-cli/src/cli/chat/cli/mcp.rs index 82a9740c5e..bbb4883ff5 100644 --- a/crates/chat-cli/src/cli/chat/cli/mcp.rs +++ b/crates/chat-cli/src/cli/chat/cli/mcp.rs @@ -14,6 +14,10 @@ use crate::cli::chat::{ ChatState, }; +/// Arguments for the MCP (Model Context Protocol) command. +/// +/// This struct handles MCP-related functionality, allowing users to view +/// the status of MCP servers and their loading progress. #[deny(missing_docs)] #[derive(Debug, PartialEq, Args)] pub struct McpArgs; diff --git a/crates/chat-cli/src/cli/chat/cli/model.rs b/crates/chat-cli/src/cli/chat/cli/model.rs index 1e484666f0..de9819fe41 100644 --- a/crates/chat-cli/src/cli/chat/cli/model.rs +++ b/crates/chat-cli/src/cli/chat/cli/model.rs @@ -60,10 +60,11 @@ impl ModelInfo { self.model_name.as_deref().unwrap_or(&self.model_id) } } + +/// Command-line arguments for model selection operations #[deny(missing_docs)] #[derive(Debug, PartialEq, Args)] pub struct ModelArgs; - impl ModelArgs { pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { Ok(select_model(os, session).await?.unwrap_or(ChatState::PromptUser { diff --git a/crates/chat-cli/src/cli/chat/cli/persist.rs b/crates/chat-cli/src/cli/chat/cli/persist.rs index 1f4568f7ed..b1f5d0da19 100644 --- a/crates/chat-cli/src/cli/chat/cli/persist.rs +++ b/crates/chat-cli/src/cli/chat/cli/persist.rs @@ -14,17 +14,23 @@ use crate::cli::chat::{ }; use crate::os::Os; +/// Commands for persisting and loading conversation state #[deny(missing_docs)] #[derive(Debug, PartialEq, Subcommand)] pub enum PersistSubcommand { /// Save the current conversation Save { + /// Path where the conversation will be saved path: String, #[arg(short, long)] + /// Force overwrite if file already exists force: bool, }, /// Load a previous conversation - Load { path: String }, + Load { + /// Path to the conversation file to load + path: String, + }, } impl PersistSubcommand { diff --git a/crates/chat-cli/src/cli/chat/cli/profile.rs b/crates/chat-cli/src/cli/chat/cli/profile.rs index 4b33fcb420..2edca042f3 100644 --- a/crates/chat-cli/src/cli/chat/cli/profile.rs +++ b/crates/chat-cli/src/cli/chat/cli/profile.rs @@ -60,6 +60,7 @@ Notes • Set default agent to assume with settings by running \"q settings chat.defaultAgent agent_name\" • Each agent maintains its own set of context and customizations" )] +/// Subcommands for managing agents in the chat CLI pub enum AgentSubcommand { /// List all available agents List, @@ -80,20 +81,30 @@ pub enum AgentSubcommand { Generate {}, /// Delete the specified agent #[command(hide = true)] - Delete { name: String }, + Delete { + /// Name of the agent to delete + name: String, + }, /// Switch to the specified agent #[command(hide = true)] - Set { name: String }, + Set { + /// Name of the agent to switch to + name: String, + }, /// Show agent config schema Schema, /// Define a default agent to use when q chat launches SetDefault { + /// Name of the agent to set as default #[arg(long, short)] name: String, }, /// Swap to a new agent at runtime #[command(alias = "switch")] - Swap { name: Option }, + Swap { + /// Optional name of the agent to swap to. If not provided, a selection dialog will be shown + name: Option, + }, } fn prompt_mcp_server_selection(servers: &[McpServerInfo]) -> eyre::Result>> { diff --git a/crates/chat-cli/src/cli/chat/cli/prompts.rs b/crates/chat-cli/src/cli/chat/cli/prompts.rs index 53b0012a57..55f7661f76 100644 --- a/crates/chat-cli/src/cli/chat/cli/prompts.rs +++ b/crates/chat-cli/src/cli/chat/cli/prompts.rs @@ -19,14 +19,13 @@ use crossterm::{ use thiserror::Error; use unicode_width::UnicodeWidthStr; -use crate::cli::chat::error_formatter::format_mcp_error; use crate::cli::chat::tool_manager::PromptBundle; use crate::cli::chat::{ ChatError, ChatSession, ChatState, }; -use crate::mcp_client::PromptGetResult; +use crate::mcp_client::McpClientError; #[derive(Debug, Error)] pub enum GetPromptError { @@ -46,8 +45,13 @@ pub enum GetPromptError { IncorrectResponseType, #[error("Missing channel")] MissingChannel, + #[error(transparent)] + McpClient(#[from] McpClientError), + #[error(transparent)] + Service(#[from] rmcp::ServiceError), } +/// Command-line arguments for prompt operations #[deny(missing_docs)] #[derive(Debug, PartialEq, Args)] #[command(color = clap::ColorChoice::Always, @@ -205,15 +209,23 @@ impl PromptsArgs { } } +/// Subcommands for prompt operations #[deny(missing_docs)] #[derive(Clone, Debug, PartialEq, Subcommand)] pub enum PromptsSubcommand { /// List available prompts from a tool or show all available prompt - List { search_word: Option }, + List { + /// Optional search word to filter prompts + search_word: Option, + }, + /// Get a specific prompt by name Get { #[arg(long, hide = true)] + /// Original input string (hidden) orig_input: Option, + /// Name of the prompt to retrieve name: String, + /// Optional arguments for the prompt arguments: Option>, }, } @@ -273,39 +285,12 @@ impl PromptsSubcommand { }); }, }; - if let Some(err) = prompts.error { - // If we are running into error we should just display the error - // and abort. - let to_display = serde_json::json!(err); - queue!( - session.stderr, - style::Print("\n"), - style::SetAttribute(Attribute::Bold), - style::Print("Error encountered while retrieving prompt:"), - style::SetAttribute(Attribute::Reset), - style::Print("\n"), - style::SetForegroundColor(Color::Red), - style::Print(format_mcp_error(&to_display)), - style::SetForegroundColor(Color::Reset), - style::Print("\n"), - )?; - } else { - let prompts = prompts - .result - .ok_or(ChatError::Custom("Result field missing from prompt/get request".into()))?; - let prompts = serde_json::from_value::(prompts) - .map_err(|e| ChatError::Custom(format!("Failed to deserialize prompt/get result: {:?}", e).into()))?; - session.pending_prompts.clear(); - session.pending_prompts.append(&mut VecDeque::from(prompts.messages)); - return Ok(ChatState::HandleInput { - input: orig_input.unwrap_or_default(), - }); - } - execute!(session.stderr, style::Print("\n"))?; + session.pending_prompts.clear(); + session.pending_prompts.append(&mut VecDeque::from(prompts.messages)); - Ok(ChatState::PromptUser { - skip_printing_tools: true, + Ok(ChatState::HandleInput { + input: orig_input.unwrap_or_default(), }) } diff --git a/crates/chat-cli/src/cli/chat/cli/subscribe.rs b/crates/chat-cli/src/cli/chat/cli/subscribe.rs index c920908743..36dd670f04 100644 --- a/crates/chat-cli/src/cli/chat/cli/subscribe.rs +++ b/crates/chat-cli/src/cli/chat/cli/subscribe.rs @@ -28,9 +28,11 @@ const SUBSCRIBE_TEXT: &str = color_print::cstr! { "During the upgrade, you'll be Need help? Visit our subscription support page> https://docs.aws.amazon.com/console/amazonq/upgrade-builder-id" }; +/// Arguments for the subscribe command to manage Q Developer Pro subscriptions #[deny(missing_docs)] #[derive(Debug, PartialEq, Args)] pub struct SubscribeArgs { + /// Open the AWS console to manage an existing subscription #[arg(long)] manage: bool, } diff --git a/crates/chat-cli/src/cli/chat/cli/tools.rs b/crates/chat-cli/src/cli/chat/cli/tools.rs index ce05dce3cb..717649070d 100644 --- a/crates/chat-cli/src/cli/chat/cli/tools.rs +++ b/crates/chat-cli/src/cli/chat/cli/tools.rs @@ -35,6 +35,7 @@ use crate::cli::chat::{ }; use crate::util::consts::MCP_SERVER_TOOL_DELIMITER; +/// Command-line arguments for managing tools in the chat session #[deny(missing_docs)] #[derive(Debug, PartialEq, Args)] pub struct ToolsArgs { @@ -197,17 +198,20 @@ trust so that no confirmation is required. Refer to the documentation for how to configure tools with your agent: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/agent-format.md#tools-field" )] +/// Subcommands for managing tool permissions and configurations pub enum ToolsSubcommand { /// Show the input schema for all available tools Schema, /// Trust a specific tool or tools for the session Trust { #[arg(required = true)] + /// Names of tools to trust tool_names: Vec, }, /// Revert a tool or tools to per-request confirmation Untrust { #[arg(required = true)] + /// Names of tools to untrust tool_names: Vec, }, /// Trust all tools (equivalent to deprecated /acceptall) diff --git a/crates/chat-cli/src/cli/chat/cli/usage.rs b/crates/chat-cli/src/cli/chat/cli/usage.rs index eca538e2b6..6e6fe0c961 100644 --- a/crates/chat-cli/src/cli/chat/cli/usage.rs +++ b/crates/chat-cli/src/cli/chat/cli/usage.rs @@ -20,6 +20,12 @@ use crate::cli::chat::{ ChatState, }; use crate::os::Os; + +/// Arguments for the usage command that displays token usage statistics and context window +/// information. +/// +/// This command shows how many tokens are being used by different components (context files, tools, +/// assistant responses, and user prompts) within the current chat session's context window. #[deny(missing_docs)] #[derive(Debug, PartialEq, Args)] pub struct UsageArgs; diff --git a/crates/chat-cli/src/cli/chat/conversation.rs b/crates/chat-cli/src/cli/chat/conversation.rs index 803970ac9d..56bef53885 100644 --- a/crates/chat-cli/src/cli/chat/conversation.rs +++ b/crates/chat-cli/src/cli/chat/conversation.rs @@ -13,6 +13,12 @@ use crossterm::{ style, }; use eyre::Result; +use rmcp::model::{ + PromptMessage, + PromptMessageContent, + PromptMessageRole, + ResourceContents, +}; use serde::{ Deserialize, Serialize, @@ -72,7 +78,6 @@ use crate::cli::chat::cli::model::{ get_model_info, }; use crate::cli::chat::tools::custom_tool::CustomToolConfig; -use crate::mcp_client::Prompt; use crate::os::Os; pub const CONTEXT_ENTRY_START_HEADER: &str = "--- CONTEXT ENTRY BEGIN ---\n"; @@ -268,23 +273,57 @@ impl ConversationState { /// Appends a collection prompts into history and returns the last message in the collection. /// It asserts that the collection ends with a prompt that assumes the role of user. - pub fn append_prompts(&mut self, mut prompts: VecDeque) -> Option { + pub fn append_prompts(&mut self, mut prompts: VecDeque) -> Option { + fn stringify_prompt_message_content(prompt_msg_content: PromptMessageContent) -> String { + match prompt_msg_content { + PromptMessageContent::Text { text } => text, + PromptMessageContent::Image { image } => image.raw.data, + PromptMessageContent::Resource { resource } => { + // TODO: add support for resources for prompt + match resource.raw.resource { + ResourceContents::TextResourceContents { + uri, mime_type, text, .. + } => { + let mime_type = mime_type.as_deref().unwrap_or("unknown"); + format!("Text resource of uri: {uri}, mime_type: {mime_type}, text: {text}") + }, + ResourceContents::BlobResourceContents { + uri, mime_type, blob, .. + } => { + let mime_type = mime_type.as_deref().unwrap_or("unknown"); + format!("Blob resource of uri: {uri}, mime_type: {mime_type}, blob: {blob}") + }, + } + }, + PromptMessageContent::ResourceLink { link } => serde_json::to_string(&link.raw).unwrap_or(format!( + "Resource link with uri: {}, name: {}", + link.raw.uri, link.raw.name + )), + } + } + debug_assert!(self.next_message.is_none(), "next_message should not exist"); - debug_assert!(prompts.back().is_some_and(|p| p.role == crate::mcp_client::Role::User)); + debug_assert!(prompts.back().is_some_and(|p| p.role == PromptMessageRole::User)); let last_msg = prompts.pop_back()?; let (mut candidate_user, mut candidate_asst) = (None::, None::); - while let Some(prompt) = prompts.pop_front() { - let Prompt { role, content } = prompt; + while let Some(prompt_msg) = prompts.pop_front() { + let PromptMessage { + role, + content: prompt_msg_content, + } = prompt_msg; + let content_str = stringify_prompt_message_content(prompt_msg_content); + match role { - crate::mcp_client::Role::User => { - let user_msg = UserMessage::new_prompt(content.to_string(), None); + PromptMessageRole::User => { + let user_msg = UserMessage::new_prompt(content_str, None); candidate_user.replace(user_msg); }, - crate::mcp_client::Role::Assistant => { - let assistant_msg = AssistantMessage::new_response(None, content.into()); + PromptMessageRole::Assistant => { + let assistant_msg = AssistantMessage::new_response(None, content_str); candidate_asst.replace(assistant_msg); }, } + if candidate_asst.is_some() && candidate_user.is_some() { let assistant = candidate_asst.take().unwrap(); let user = candidate_user.take().unwrap(); @@ -296,7 +335,8 @@ impl ConversationState { }); } } - Some(last_msg.content.to_string()) + + Some(stringify_prompt_message_content(last_msg.content)) } pub fn next_user_message(&self) -> Option<&UserMessage> { diff --git a/crates/chat-cli/src/cli/chat/error_formatter.rs b/crates/chat-cli/src/cli/chat/error_formatter.rs deleted file mode 100644 index 96a604bdbe..0000000000 --- a/crates/chat-cli/src/cli/chat/error_formatter.rs +++ /dev/null @@ -1,148 +0,0 @@ -/// Formats an MCP error message to be more user-friendly. -/// -/// This function extracts nested JSON from the error message and formats it -/// with proper indentation and newlines. -/// -/// # Arguments -/// -/// * `err` - A reference to a serde_json::Value containing the error information -/// -/// # Returns -/// -/// A formatted string representation of the error message -pub fn format_mcp_error(err: &serde_json::Value) -> String { - // Extract the message field from the error JSON - if let Some(message) = err.get("message").and_then(|m| m.as_str()) { - // Check if the message contains a nested JSON array - if let Some(start_idx) = message.find('[') { - if let Some(end_idx) = message.rfind(']') { - let prefix = &message[..start_idx].trim(); - let nested_json = &message[start_idx..=end_idx]; - - // Try to parse the nested JSON - if let Ok(nested_value) = serde_json::from_str::(nested_json) { - // Format the error message with the prefix and pretty-printed nested JSON - return format!( - "{}\n{}", - prefix, - serde_json::to_string_pretty(&nested_value).unwrap_or_else(|_| nested_json.to_string()) - ); - } - } - } - } - - // Fallback if message field is missing or if we couldn't extract and parse nested JSON - serde_json::to_string_pretty(err).unwrap_or_else(|_| format!("{:?}", err)) -} - -#[cfg(test)] -mod tests { - use serde_json::json; - - use super::*; - - #[test] - fn test_format_mcp_error_with_nested_json() { - let error = json!({ - "code": -32602, - "message": "MCP error -32602: Invalid arguments for prompt agent_script_coco_was_sev2_ticket_details_retrieve: [\n {\n \"code\": \"invalid_type\",\n \"expected\": \"object\",\n \"received\": \"undefined\",\n \"path\": [],\n \"message\": \"Required\"\n }\n]" - }); - - let formatted = format_mcp_error(&error); - - // Extract the prefix and JSON part from the formatted string - let parts: Vec<&str> = formatted.split('\n').collect(); - let prefix = parts[0]; - let json_part = &formatted[prefix.len() + 1..]; - - // Check that the prefix is correct - assert_eq!( - prefix, - "MCP error -32602: Invalid arguments for prompt agent_script_coco_was_sev2_ticket_details_retrieve:" - ); - - // Parse the JSON part to compare the actual content rather than the exact string - let parsed_json: serde_json::Value = serde_json::from_str(json_part).expect("Failed to parse JSON part"); - - // Expected JSON structure - let expected_json = json!([ - { - "code": "invalid_type", - "expected": "object", - "received": "undefined", - "path": [], - "message": "Required" - } - ]); - - // Compare the parsed JSON values - assert_eq!(parsed_json, expected_json); - } - - #[test] - fn test_format_mcp_error_without_nested_json() { - let error = json!({ - "code": -32602, - "message": "MCP error -32602: Invalid arguments for prompt" - }); - - let formatted = format_mcp_error(&error); - - assert_eq!( - formatted, - "{\n \"code\": -32602,\n \"message\": \"MCP error -32602: Invalid arguments for prompt\"\n}" - ); - } - - #[test] - fn test_format_mcp_error_non_mcp_error() { - let error = json!({ - "error": "Unknown error occurred" - }); - - let formatted = format_mcp_error(&error); - - // Should pretty-print the entire error - assert_eq!(formatted, "{\n \"error\": \"Unknown error occurred\"\n}"); - } - - #[test] - fn test_format_mcp_error_empty_message() { - let error = json!({ - "code": -32602, - "message": "" - }); - - let formatted = format_mcp_error(&error); - - assert_eq!(formatted, "{\n \"code\": -32602,\n \"message\": \"\"\n}"); - } - - #[test] - fn test_format_mcp_error_missing_message() { - let error = json!({ - "code": -32602 - }); - - let formatted = format_mcp_error(&error); - - assert_eq!(formatted, "{\n \"code\": -32602\n}"); - } - - #[test] - fn test_format_mcp_error_malformed_nested_json() { - let error = json!({ - "code": -32602, - "message": "MCP error -32602: Invalid arguments for prompt: [{\n \"code\": \"invalid_type\",\n \"expected\": \"object\",\n \"received\": \"undefined\",\n \"path\": [],\n \"message\": \"Required\"\n" - }); - - let formatted = format_mcp_error(&error); - - // Should return the pretty-printed JSON since the nested JSON is malformed - assert_eq!( - formatted, - "{\n \"code\": -32602,\n \"message\": \"MCP error -32602: Invalid arguments for prompt: [{\\n \\\"code\\\": \\\"invalid_type\\\",\\n \\\"expected\\\": \\\"object\\\",\\n \\\"received\\\": \\\"undefined\\\",\\n \\\"path\\\": [],\\n \\\"message\\\": \\\"Required\\\"\\n\"\n}" - ); - } -} diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index c9b6b058aa..43014bb857 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -2,7 +2,6 @@ pub mod cli; mod consts; pub mod context; mod conversation; -mod error_formatter; mod input_source; mod message; mod parse; @@ -11,7 +10,7 @@ mod line_tracker; mod parser; mod prompt; mod prompt_parser; -mod server_messenger; +pub mod server_messenger; #[cfg(unix)] mod skim_integration; mod token_counter; @@ -83,6 +82,7 @@ use parser::{ SendMessageStream, }; use regex::Regex; +use rmcp::model::PromptMessage; use spinners::{ Spinner, Spinners, @@ -149,7 +149,6 @@ use crate::cli::chat::cli::prompts::{ use crate::cli::chat::message::UserMessage; use crate::cli::chat::util::sanitize_unicode_tags; use crate::database::settings::Setting; -use crate::mcp_client::Prompt; use crate::os::Os; use crate::telemetry::core::{ AgentConfigInitArgs, @@ -594,7 +593,7 @@ pub struct ChatSession { /// Any failed requests that could be useful for error report/debugging failed_request_ids: Vec, /// Pending prompts to be sent - pending_prompts: VecDeque, + pending_prompts: VecDeque, interactive: bool, inner: Option, ctrlc_rx: broadcast::Receiver<()>, @@ -2687,7 +2686,7 @@ impl ChatSession { .set_tool_use_id(tool_use_id.clone()) .set_tool_name(tool_use.name.clone()) .utterance_id(self.conversation.message_id().map(|s| s.to_string())); - match self.conversation.tool_manager.get_tool_from_tool_use(tool_use) { + match self.conversation.tool_manager.get_tool_from_tool_use(tool_use).await { Ok(mut tool) => { // Apply non-Q-generated context to tools self.contextualize_tool(&mut tool); @@ -2848,7 +2847,7 @@ impl ChatSession { style::SetForegroundColor(Color::Reset), style::Print(" from mcp server "), style::SetForegroundColor(Color::Magenta), - style::Print(tool.client.get_server_name()), + style::Print(&tool.server_name), style::SetForegroundColor(Color::Reset), )?; } diff --git a/crates/chat-cli/src/cli/chat/server_messenger.rs b/crates/chat-cli/src/cli/chat/server_messenger.rs index aaf685c399..a15bd8d696 100644 --- a/crates/chat-cli/src/cli/chat/server_messenger.rs +++ b/crates/chat-cli/src/cli/chat/server_messenger.rs @@ -1,48 +1,54 @@ +use rmcp::model::{ + ListPromptsResult, + ListResourceTemplatesResult, + ListResourcesResult, + ListToolsResult, +}; +use rmcp::{ + Peer, + RoleClient, +}; use tokio::sync::mpsc::{ Receiver, Sender, channel, }; -use crate::mcp_client::{ +use crate::mcp_client::messenger::{ Messenger, MessengerError, - PromptsListResult, - ResourceTemplatesListResult, - ResourcesListResult, - ToolsListResult, + MessengerResult, + Result, }; #[allow(dead_code)] #[derive(Debug)] pub enum UpdateEventMessage { - ToolsListResult { + ListToolsResult { server_name: String, - result: eyre::Result, - pid: Option, + result: Result, + peer: Option>, }, - PromptsListResult { + ListPromptsResult { server_name: String, - result: eyre::Result, - pid: Option, + result: Result, + peer: Option>, }, - ResourcesListResult { + ListResourcesResult { server_name: String, - result: eyre::Result, - pid: Option, + result: Result, + peer: Option>, }, ResourceTemplatesListResult { server_name: String, - result: eyre::Result, - pid: Option, + result: Result, + peer: Option>, }, InitStart { server_name: String, - pid: Option, }, Deinit { server_name: String, - pid: Option, }, } @@ -64,7 +70,6 @@ impl ServerMessengerBuilder { ServerMessenger { server_name, update_event_sender: self.update_event_sender.clone(), - pid: None, } } } @@ -73,30 +78,37 @@ impl ServerMessengerBuilder { pub struct ServerMessenger { pub server_name: String, pub update_event_sender: Sender, - pub pid: Option, } #[async_trait::async_trait] impl Messenger for ServerMessenger { - async fn send_tools_list_result(&self, result: eyre::Result) -> Result<(), MessengerError> { + async fn send_tools_list_result( + &self, + result: Result, + peer: Option>, + ) -> MessengerResult { Ok(self .update_event_sender - .send(UpdateEventMessage::ToolsListResult { + .send(UpdateEventMessage::ListToolsResult { server_name: self.server_name.clone(), result, - pid: self.pid, + peer, }) .await .map_err(|e| MessengerError::Custom(e.to_string()))?) } - async fn send_prompts_list_result(&self, result: eyre::Result) -> Result<(), MessengerError> { + async fn send_prompts_list_result( + &self, + result: Result, + peer: Option>, + ) -> MessengerResult { Ok(self .update_event_sender - .send(UpdateEventMessage::PromptsListResult { + .send(UpdateEventMessage::ListPromptsResult { server_name: self.server_name.clone(), result, - pid: self.pid, + peer, }) .await .map_err(|e| MessengerError::Custom(e.to_string()))?) @@ -104,14 +116,15 @@ impl Messenger for ServerMessenger { async fn send_resources_list_result( &self, - result: eyre::Result, - ) -> Result<(), MessengerError> { + result: Result, + peer: Option>, + ) -> MessengerResult { Ok(self .update_event_sender - .send(UpdateEventMessage::ResourcesListResult { + .send(UpdateEventMessage::ListResourcesResult { server_name: self.server_name.clone(), result, - pid: self.pid, + peer, }) .await .map_err(|e| MessengerError::Custom(e.to_string()))?) @@ -119,25 +132,25 @@ impl Messenger for ServerMessenger { async fn send_resource_templates_list_result( &self, - result: eyre::Result, - ) -> Result<(), MessengerError> { + result: Result, + peer: Option>, + ) -> MessengerResult { Ok(self .update_event_sender .send(UpdateEventMessage::ResourceTemplatesListResult { server_name: self.server_name.clone(), result, - pid: self.pid, + peer, }) .await .map_err(|e| MessengerError::Custom(e.to_string()))?) } - async fn send_init_msg(&self) -> Result<(), MessengerError> { + async fn send_init_msg(&self) -> MessengerResult { Ok(self .update_event_sender .send(UpdateEventMessage::InitStart { server_name: self.server_name.clone(), - pid: self.pid, }) .await .map_err(|e| MessengerError::Custom(e.to_string()))?) @@ -146,9 +159,8 @@ impl Messenger for ServerMessenger { fn send_deinit_msg(&self) { let sender = self.update_event_sender.clone(); let server_name = self.server_name.clone(); - let pid = self.pid; tokio::spawn(async move { - let _ = sender.send(UpdateEventMessage::Deinit { server_name, pid }).await; + let _ = sender.send(UpdateEventMessage::Deinit { server_name }).await; }); } diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index 3171459915..7ca372779c 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -32,12 +32,14 @@ use crossterm::{ terminal, }; use eyre::Report; -use futures::{ - StreamExt, - future, - stream, -}; +use futures::future; use regex::Regex; +use rmcp::ServiceError; +use rmcp::model::{ + GetPromptRequestParam, + GetPromptResult, + Prompt, +}; use tokio::signal::ctrl_c; use tokio::sync::{ Mutex, @@ -68,10 +70,7 @@ use crate::cli::chat::server_messenger::{ ServerMessengerBuilder, UpdateEventMessage, }; -use crate::cli::chat::tools::custom_tool::{ - CustomTool, - CustomToolClient, -}; +use crate::cli::chat::tools::custom_tool::CustomTool; use crate::cli::chat::tools::execute::ExecuteCommand; use crate::cli::chat::tools::fs_read::FsRead; use crate::cli::chat::tools::fs_write::FsWrite; @@ -88,10 +87,10 @@ use crate::cli::chat::tools::{ }; use crate::database::Database; use crate::database::settings::Setting; +use crate::mcp_client::messenger::Messenger; use crate::mcp_client::{ - JsonRpcResponse, - Messenger, - PromptGet, + InitializedMcpClient, + McpClientService, }; use crate::os::Os; use crate::telemetry::TelemetryThread; @@ -160,6 +159,7 @@ pub struct ToolManagerBuilder { has_new_stuff: Arc, mcp_load_record: Arc>>>, new_tool_specs: NewToolSpecs, + pending_clients: Option>>>, is_first_launch: bool, agent: Option>>, } @@ -176,6 +176,7 @@ impl Default for ToolManagerBuilder { has_new_stuff: Default::default(), mcp_load_record: Default::default(), new_tool_specs: Default::default(), + pending_clients: Default::default(), is_first_launch: true, agent: Default::default(), } @@ -196,6 +197,7 @@ impl From<&mut ToolManager> for ToolManagerBuilder { has_new_stuff: value.has_new_stuff.clone(), mcp_load_record: value.mcp_load_record.clone(), new_tool_specs: value.new_tool_specs.clone(), + pending_clients: Some(value.pending_clients.clone()), // if we are getting a builder from an instantiated tool manager this field would be // false is_first_launch: false, @@ -271,8 +273,8 @@ impl ToolManagerBuilder { .collect(); let pre_initialized = enabled_servers - .into_iter() - .filter_map(|(server_name, server_config)| { + .iter() + .filter(|(server_name, _)| { if server_name == "builtin" { let _ = queue!( output, @@ -287,13 +289,26 @@ impl ToolManagerBuilder { style::ResetColor, style::Print(" (it is used to denote native tools)\n") ); - None + false } else { - let custom_tool_client = CustomToolClient::from_config(server_name.clone(), server_config, os); - Some((server_name, custom_tool_client)) + true } }) - .collect::>(); + .collect::>(); + + let mut clients = HashMap::::new(); + let new_tool_specs = self.new_tool_specs; + let has_new_stuff = self.has_new_stuff; + let pending = self.pending_clients.unwrap_or(Arc::new(RwLock::new({ + let mut pending = HashSet::::new(); + pending.extend(pre_initialized.iter().map(|(name, _)| name.clone())); + pending + }))); + let notify = Arc::new(Notify::new()); + let load_record = self.mcp_load_record; + let agent = self.agent.unwrap_or_default(); + let database = os.database.clone(); + let mut messenger_builder = self.messenger_builder.take(); let mut loading_servers = HashMap::::new(); for (server_name, _) in &pre_initialized { @@ -308,16 +323,6 @@ impl ToolManagerBuilder { let (loading_display_task, loading_status_sender) = spawn_display_task(interactive, total, disabled_servers, output); - let mut clients = HashMap::>::new(); - let new_tool_specs = self.new_tool_specs; - let has_new_stuff = self.has_new_stuff; - let pending = Arc::new(RwLock::new(HashSet::::new())); - let notify = Arc::new(Notify::new()); - let load_record = self.mcp_load_record; - let agent = self.agent.unwrap_or_default(); - let database = os.database.clone(); - let mut messenger_builder = self.messenger_builder.take(); - // This is the orchestrator task that serves as a bridge between tool manager and mcp // clients for server initiated async events if let (Some(prompt_list_sender), Some(prompt_list_receiver)) = ( @@ -358,19 +363,29 @@ impl ToolManagerBuilder { debug_assert!(messenger_builder.is_some()); let messenger_builder = messenger_builder.unwrap(); - for (mut name, init_res) in pre_initialized { - let mut messenger = messenger_builder.build_with_name(name.clone()); + let pre_initialized = enabled_servers + .into_iter() + .map(|(server_name, server_config)| { + ( + server_name.clone(), + McpClientService::new( + server_name.clone(), + server_config, + messenger_builder.build_with_name(server_name), + ), + ) + }) + .collect::>(); + + for (mut name, mcp_client) in pre_initialized { + let init_res = mcp_client.init(os).await; match init_res { - Ok(mut client) => { - let pid = client.get_pid(); - messenger.pid = pid; - client.assign_messenger(Box::new(messenger)); - let mut client = Arc::new(client); - while let Some(collided_client) = clients.insert(name.clone(), client) { + Ok(mut running_service) => { + while let Some(collided_service) = clients.insert(name.clone(), running_service) { // to avoid server name collision we are going to circumvent this by // appending the name with 1 name.push('1'); - client = collided_client; + running_service = collided_service; } }, Err(e) => { @@ -379,7 +394,7 @@ impl ToolManagerBuilder { .send_mcp_server_init( &os.database, conversation_id.clone(), - name, + name.clone(), Some(e.to_string()), 0, Some("".to_string()), @@ -388,7 +403,11 @@ impl ToolManagerBuilder { ) .await .ok(); - let _ = messenger.send_tools_list_result(Err(e)).await; + + let temp_messenger = messenger_builder.build_with_name(name); + let _ = temp_messenger + .send_tools_list_result(Err(ServiceError::UnexpectedResponse), None) + .await; }, } } @@ -428,7 +447,7 @@ pub struct PromptBundle { /// The server name from which the prompt is offered / exposed pub server_name: String, /// The prompt get (info with which a prompt is retrieved) cached - pub prompt_get: PromptGet, + pub prompt_get: Prompt, } #[derive(Clone, Debug)] @@ -509,7 +528,7 @@ pub struct ToolManager { /// Map of server names to their corresponding client instances. /// These clients are used to communicate with MCP servers. - pub clients: HashMap>, + pub clients: HashMap, /// A list of client names that are still in the process of being initialized pub pending_clients: Arc>>, @@ -579,7 +598,6 @@ impl Clone for ToolManager { fn clone(&self) -> Self { Self { conversation_id: self.conversation_id.clone(), - clients: self.clients.clone(), has_new_stuff: self.has_new_stuff.clone(), new_tool_specs: self.new_tool_specs.clone(), tn_map: self.tn_map.clone(), @@ -603,7 +621,32 @@ impl ToolManager { /// function) /// - Calling load tools pub async fn swap_agent(&mut self, os: &mut Os, output: &mut impl Write, agent: &Agent) -> eyre::Result<()> { - self.clients.clear(); + let to_evict = self.clients.drain().collect::>(); + tokio::spawn(async move { + for (server_name, initialized_client) in to_evict { + info!("Evicting {server_name} due to agent swap"); + match initialized_client { + InitializedMcpClient::Pending(handle) => { + let server_name_clone = server_name.clone(); + tokio::spawn(async move { + match handle.await { + Ok(Ok(client)) => match client.cancel().await { + Ok(_) => info!("Server {server_name_clone} evicted due to agent swap"), + Err(e) => error!("Server {server_name_clone} has failed to cancel: {e}"), + }, + Ok(Err(_)) | Err(_) => { + error!("Server {server_name_clone} has failed to cancel"); + }, + } + }); + }, + InitializedMcpClient::Ready(running_service) => match running_service.cancel().await { + Ok(_) => info!("Server {server_name} evicted due to agent swap"), + Err(e) => error!("Server {server_name} has failed to cancel: {e}"), + }, + } + } + }); let mut agent_lock = self.agent.lock().await; *agent_lock = agent.clone(); @@ -615,9 +658,7 @@ impl ToolManager { let mut new_tool_manager = builder.build(os, Box::new(std::io::sink()), true).await?; std::mem::swap(self, &mut new_tool_manager); - // we can discard the output here and let background server load take care of getting the - // new tools - let _ = self.load_tools(os, output).await?; + self.load_tools(os, output).await?; Ok(()) } @@ -684,20 +725,7 @@ impl ToolManager { tool_specs }; - let load_tools = self - .clients - .values() - .map(|c| { - let clone = Arc::clone(c); - async move { clone.init().await } - }) - .collect::>(); - let initial_poll = stream::iter(load_tools) - .map(|async_closure| tokio::spawn(async_closure)) - .buffer_unordered(20); - tokio::spawn(async move { - initial_poll.collect::>().await; - }); + // We need to cast it to erase the type otherwise the compiler will default to static // dispatch, which would result in an error of inconsistent match arm return type. let timeout_fut: Pin>> = if self.clients.is_empty() || !self.is_first_launch { @@ -785,7 +813,7 @@ impl ToolManager { Ok(self.schema.clone()) } - pub fn get_tool_from_tool_use(&self, value: AssistantToolUse) -> Result { + pub async fn get_tool_from_tool_use(&mut self, value: AssistantToolUse) -> Result { let map_err = |parse_error| ToolResult { tool_use_id: value.id.clone(), content: vec![ToolResultContentBlock::Text(format!( @@ -831,7 +859,7 @@ impl ToolManager { }) }, }?; - let Some(client) = self.clients.get(server_name) else { + let Some(client) = self.clients.get_mut(server_name) else { return Err(ToolResult { tool_use_id: value.id, content: vec![ToolResultContentBlock::Text(format!( @@ -840,22 +868,20 @@ impl ToolManager { status: ToolResultStatus::Error, }); }; - // The tool input schema has the shape of { type, properties }. - // The field "params" expected by MCP is { name, arguments }, where name is the - // name of the tool being invoked, - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#calling-tools. - // The field "arguments" is where ToolUse::args belong. - let mut params = serde_json::Map::::new(); - params.insert("name".to_owned(), serde_json::Value::String(tool_name.to_owned())); - params.insert("arguments".to_owned(), value.args); - let params = serde_json::Value::Object(params); - let custom_tool = CustomTool { + + let running_service = (*client.get_running_service().await.map_err(|e| ToolResult { + tool_use_id: value.id.clone(), + content: vec![ToolResultContentBlock::Text(format!("Mcp tool client not ready: {e}"))], + status: ToolResultStatus::Error, + })?) + .clone(); + + Tool::Custom(CustomTool { name: tool_name.to_owned(), - client: client.clone(), - method: "tools/call".to_owned(), - params: Some(params), - }; - Tool::Custom(custom_tool) + server_name: server_name.to_owned(), + client: running_service, + params: value.args.as_object().cloned(), + }) }, }) } @@ -951,10 +977,10 @@ impl ToolManager { } pub async fn get_prompt( - &self, + &mut self, name: String, arguments: Option>, - ) -> Result { + ) -> Result { let (server_name, prompt_name) = match name.split_once('/') { None => (None::, Some(name.clone())), Some((server_name, prompt_name)) => (Some(server_name.to_string()), Some(prompt_name.to_string())), @@ -1013,9 +1039,9 @@ impl ToolManager { }; let server_name = &bundle.server_name; - let client = self.clients.get(server_name).ok_or(GetPromptError::MissingClient)?; + let client = self.clients.get_mut(server_name).ok_or(GetPromptError::MissingClient)?; let PromptBundle { prompt_get, .. } = bundle; - let args = if let (Some(schema), Some(value)) = (&prompt_get.arguments, &arguments) { + let arguments = if let (Some(schema), Some(value)) = (&prompt_get.arguments, &arguments) { let params = schema.iter().zip(value.iter()).fold( HashMap::::new(), |mut acc, (prompt_get_arg, value)| { @@ -1023,19 +1049,20 @@ impl ToolManager { acc }, ); - Some(serde_json::json!(params)) + Some( + params + .into_iter() + .map(|(k, v)| (k, serde_json::Value::String(v))) + .collect(), + ) } else { None }; - let params = { - let mut params = serde_json::Map::new(); - params.insert("name".to_string(), serde_json::Value::String(prompt_name)); - if let Some(args) = args { - params.insert("arguments".to_string(), args); - } - Some(serde_json::Value::Object(params)) - }; - let resp = client.request("prompts/get", params).await?; + + let params = GetPromptRequestParam { name, arguments }; + let running_service = client.get_running_service().await?; + let resp = running_service.get_prompt(params).await?; + Ok(resp) }, (None, _) => Err(GetPromptError::PromptNotFound(prompt_name)), @@ -1296,10 +1323,10 @@ fn spawn_orchestrator_task( // request method on the mcp client no longer buffers all the pages from // list calls. match msg { - UpdateEventMessage::ToolsListResult { + UpdateEventMessage::ListToolsResult { server_name, result, - pid, + peer, } => { let time_taken = loading_servers .remove(&server_name) @@ -1311,11 +1338,8 @@ fn spawn_orchestrator_task( let result_tools = match &result { Ok(tools_result) => { - let names: Vec = tools_result - .tools - .iter() - .filter_map(|tool| tool.get("name")?.as_str().map(String::from)) - .collect(); + let names: Vec = + tools_result.tools.iter().map(|tool| tool.name.to_string()).collect(); names }, Err(_) => vec![], @@ -1369,40 +1393,27 @@ fn spawn_orchestrator_task( match result { Ok(result) => { - if pid.is_none_or(|pid| !is_process_running(pid)) { - let pid = pid.map_or("unknown".to_string(), |pid| pid.to_string()); - info!( - "Received tool list result from {server_name} but its associated process {pid} is no longer running. Ignoring." - ); - - let mut buf_writer = BufWriter::new(&mut *record_temp_buf); - let _ = queue_failure_message( - &server_name, - &eyre::eyre!("Process associated is no longer running"), - &time_taken, - &mut buf_writer, - ); - let _ = buf_writer.flush(); - drop(buf_writer); - let record_content = String::from_utf8_lossy(record_temp_buf).to_string(); - let record = LoadingRecord::Err(record_content); - - load_record - .lock() - .await - .entry(server_name.clone()) - .and_modify(|load_record| { - load_record.push(record.clone()); - }) - .or_insert(vec![record]); - + if let Some(peer) = peer { + if peer.is_transport_closed() { + error!( + "Received tool list result from {server_name} but transport has been closed. Ignoring." + ); + return; + } + } else { + error!("Received tool list result from {server_name} without a peer. Ignoring."); return; } let mut specs = result .tools .into_iter() - .filter_map(|v| serde_json::from_value::(v).ok()) + .map(|v| ToolSpec { + name: v.name.to_string(), + description: v.description.as_ref().map(|d| d.to_string()).unwrap_or_default(), + input_schema: crate::cli::chat::tools::InputSchema(v.schema_as_json_value()), + tool_origin: ToolOrigin::Native, + }) .filter(|spec| tool_filter.should_include(&spec.name)) .collect::>(); let mut sanitized_mapping = HashMap::::new(); @@ -1418,6 +1429,7 @@ fn spawn_orchestrator_task( &result_tools, ) .await; + if let Some(sender) = &loading_status_sender { // Anomalies here are not considered fatal, thus we shall give // warnings. @@ -1476,7 +1488,13 @@ fn spawn_orchestrator_task( error!("Error loading server {server_name}: {:?}", e); // Maintain a record of the server load: let mut buf_writer = BufWriter::new(&mut *record_temp_buf); - let _ = queue_failure_message(server_name.as_str(), &e, &time_taken, &mut buf_writer); + let fail_load_msg = eyre::eyre!("{}", e); + let _ = queue_failure_message( + server_name.as_str(), + &fail_load_msg, + &time_taken, + &mut buf_writer, + ); let _ = buf_writer.flush(); drop(buf_writer); let record = String::from_utf8_lossy(record_temp_buf).to_string(); @@ -1494,7 +1512,7 @@ fn spawn_orchestrator_task( if let Some(sender) = &loading_status_sender { let msg = LoadingMsg::Error { name: server_name.clone(), - msg: e, + msg: eyre::eyre!("{}", e.to_string()), time: time_taken, }; if let Err(e) = sender.send(msg).await { @@ -1514,17 +1532,21 @@ fn spawn_orchestrator_task( } } }, - UpdateEventMessage::PromptsListResult { + UpdateEventMessage::ListPromptsResult { server_name, result, - pid, + peer, } => match result { - Ok(prompt_list_result) if pid.is_some() => { - let pid = pid.unwrap(); - if !is_process_running(pid) { - info!( - "Received prompt list result from {server_name} but its associated process {pid} is no longer running. Ignoring." - ); + Ok(prompt_list_result) => { + if let Some(peer) = peer { + if peer.is_transport_closed() { + error!( + "Received prompt list result from {server_name} but transport has been closed. Ignoring." + ); + return; + } + } else { + error!("Received prompt list result from {server_name} without a peer. Ignoring."); return; } // We first need to clear all the PromptGets that are associated with @@ -1535,34 +1557,28 @@ fn spawn_orchestrator_task( .for_each(|bundles| bundles.retain(|bundle| bundle.server_name != server_name)); // And then we update them with the new comers - for result in prompt_list_result.prompts { - let Ok(prompt_get) = serde_json::from_value::(result) else { - error!("Failed to deserialize prompt get from server {server_name}"); - continue; - }; + for prompt in prompt_list_result.prompts { prompts - .entry(prompt_get.name.clone()) + .entry(prompt.name.clone()) .and_modify(|bundles| { bundles.push(PromptBundle { server_name: server_name.clone(), - prompt_get: prompt_get.clone(), + prompt_get: prompt.clone(), }); }) .or_insert_with(|| { vec![PromptBundle { server_name: server_name.clone(), - prompt_get, + prompt_get: prompt, }] }); } }, - Ok(_) => { - error!("Received prompt list result without pid from {server_name}. Ignoring."); - }, Err(e) => { error!("Error fetching prompts from server {server_name}: {:?}", e); let mut buf_writer = BufWriter::new(&mut *record_temp_buf); - let _ = queue_prompts_load_error_message(&server_name, &e, &mut buf_writer); + let msg = eyre::eyre!("{}", e); + let _ = queue_prompts_load_error_message(&server_name, &msg, &mut buf_writer); let _ = buf_writer.flush(); drop(buf_writer); let record = String::from_utf8_lossy(record_temp_buf).to_string(); @@ -1577,16 +1593,8 @@ fn spawn_orchestrator_task( .or_insert(vec![record]); }, }, - UpdateEventMessage::ResourcesListResult { - server_name: _, - result: _, - pid: _, - } => {}, - UpdateEventMessage::ResourceTemplatesListResult { - server_name: _, - result: _, - pid: _, - } => {}, + UpdateEventMessage::ListResourcesResult { .. } => {}, + UpdateEventMessage::ResourceTemplatesListResult { .. } => {}, UpdateEventMessage::InitStart { server_name, .. } => { pending.write().await.insert(server_name.clone()); loading_servers.insert(server_name, std::time::Instant::now()); @@ -1778,22 +1786,6 @@ fn sanitize_name(orig: String, regex: ®ex::Regex, hasher: &mut impl Hasher) - } } -// Add this function to check if a process is still running -fn is_process_running(pid: u32) -> bool { - #[cfg(unix)] - { - let system = sysinfo::System::new_all(); - system.process(sysinfo::Pid::from(pid as usize)).is_some() - } - #[cfg(windows)] - { - // TODO: fill in the process health check for windows when when we officially support - // windows - _ = pid; - true - } -} - fn queue_success_message(name: &str, time_taken: &str, output: &mut impl Write) -> eyre::Result<()> { Ok(queue!( output, diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index fafb55a9a4..be2a7a8831 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -1,19 +1,19 @@ +use std::borrow::Cow; use std::collections::HashMap; use std::io::Write; -use std::sync::Arc; use crossterm::{ queue, style, }; use eyre::Result; -use regex::Regex; +use rmcp::RoleClient; +use rmcp::model::CallToolRequestParam; use schemars::JsonSchema; use serde::{ Deserialize, Serialize, }; -use tokio::sync::RwLock; use tracing::warn; use super::InvokeOutput; @@ -23,17 +23,6 @@ use crate::cli::agent::{ }; use crate::cli::chat::CONTINUATION_LINE; use crate::cli::chat::token_counter::TokenCounter; -use crate::mcp_client::{ - Client as McpClient, - ClientConfig as McpClientConfig, - JsonRpcResponse, - JsonRpcStdioTransport, - MessageContent, - Messenger, - ServerCapabilities, - StdioTransport, - ToolCallResult, -}; use crate::os::Os; use crate::util::MCP_SERVER_TOOL_DELIMITER; use crate::util::pattern_matching::matches_any_pattern; @@ -64,175 +53,41 @@ pub fn default_timeout() -> u64 { 120 * 1000 } -/// Substitutes environment variables in the format ${env:VAR_NAME} with their actual values -fn substitute_env_vars(input: &str, env: &crate::os::Env) -> String { - // Create a regex to match ${env:VAR_NAME} pattern - let re = Regex::new(r"\$\{env:([^}]+)\}").unwrap(); - - re.replace_all(input, |caps: ®ex::Captures<'_>| { - let var_name = &caps[1]; - env.get(var_name).unwrap_or_else(|_| format!("${{{}}}", var_name)) - }) - .to_string() -} - -/// Process a HashMap of environment variables, substituting any ${env:VAR_NAME} patterns -/// with their actual values from the environment -fn process_env_vars(env_vars: &mut HashMap, env: &crate::os::Env) { - for (_, value) in env_vars.iter_mut() { - *value = substitute_env_vars(value, env); - } -} - -#[derive(Debug)] -pub enum CustomToolClient { - Stdio { - /// This is the server name as recognized by the model (post sanitized) - server_name: String, - client: McpClient, - server_capabilities: RwLock>, - }, -} - -impl CustomToolClient { - // TODO: add support for http transport - pub fn from_config(server_name: String, config: CustomToolConfig, os: &crate::os::Os) -> Result { - let CustomToolConfig { - command, - args, - env, - timeout, - disabled: _, - .. - } = config; - - // Process environment variables if present - let processed_env = env.map(|mut env_vars| { - process_env_vars(&mut env_vars, &os.env); - env_vars - }); - - let mcp_client_config = McpClientConfig { - server_name: server_name.clone(), - bin_path: command.clone(), - args, - timeout, - client_info: serde_json::json!({ - "name": "Q CLI Chat", - "version": "1.0.0" - }), - env: processed_env, - }; - let client = McpClient::::from_config(mcp_client_config)?; - Ok(CustomToolClient::Stdio { - server_name, - client, - server_capabilities: RwLock::new(None), - }) - } - - pub async fn init(&self) -> Result<()> { - match self { - CustomToolClient::Stdio { - client, - server_capabilities, - .. - } => { - if let Some(messenger) = &client.messenger { - let _ = messenger.send_init_msg().await; - } - // We'll need to first initialize. This is the handshake every client and server - // needs to do before proceeding to anything else - let cap = client.init().await?; - // We'll be scrapping this for background server load: https://github.com/aws/amazon-q-developer-cli/issues/1466 - // So don't worry about the tidiness for now - server_capabilities.write().await.replace(cap); - Ok(()) - }, - } - } - - pub fn assign_messenger(&mut self, messenger: Box) { - match self { - CustomToolClient::Stdio { client, .. } => { - client.messenger = Some(messenger); - }, - } - } - - pub fn get_server_name(&self) -> &str { - match self { - CustomToolClient::Stdio { server_name, .. } => server_name.as_str(), - } - } - - pub async fn request(&self, method: &str, params: Option) -> Result { - match self { - CustomToolClient::Stdio { client, .. } => Ok(client.request(method, params).await?), - } - } - - pub fn get_pid(&self) -> Option { - match self { - CustomToolClient::Stdio { client, .. } => client.server_process_id.as_ref().map(|pid| pid.as_u32()), - } - } - - #[allow(dead_code)] - pub async fn notify(&self, method: &str, params: Option) -> Result<()> { - match self { - CustomToolClient::Stdio { client, .. } => Ok(client.notify(method, params).await?), - } - } -} - /// Represents a custom tool that can be invoked through the Model Context Protocol (MCP). #[derive(Clone, Debug)] pub struct CustomTool { /// Actual tool name as recognized by its MCP server. This differs from the tool names as they /// are seen by the model since they are not prefixed by its MCP server name. pub name: String, + /// The name of the MCP (Model Context Protocol) server that hosts this tool. + /// This is used to identify which server instance the tool belongs to and is + /// prefixed to the tool name when presented to the model for disambiguation. + pub server_name: String, /// Reference to the client that manages communication with the tool's server process. - pub client: Arc, - /// The method name to call on the tool's server, following the JSON-RPC convention. - /// This corresponds to a specific functionality provided by the tool. - pub method: String, + pub client: rmcp::Peer, /// Optional parameters to pass to the tool when invoking the method. /// Structured as a JSON value to accommodate various parameter types and structures. - pub params: Option, + pub params: Option>, } impl CustomTool { pub async fn invoke(&self, _os: &Os, _updates: impl Write) -> Result { - // Assuming a response shape as per https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#calling-tools - let resp = self.client.request(self.method.as_str(), self.params.clone()).await?; - let result = match resp.result { - Some(result) => result, - None => { - let failure = resp.error.map_or("Unknown error encountered".to_string(), |err| { - serde_json::to_string(&err).unwrap_or_default() - }); - return Err(eyre::eyre!(failure)); - }, + let params = CallToolRequestParam { + name: Cow::from(self.name.clone()), + arguments: self.params.clone(), }; - match serde_json::from_value::(result.clone()) { - Ok(mut de_result) => { - for content in &mut de_result.content { - if let MessageContent::Image { data, .. } = content { - *data = format!("Redacted base64 encoded string of an image of size {}", data.len()); - } - } - Ok(InvokeOutput { - output: super::OutputKind::Json(serde_json::json!(de_result)), - }) - }, - Err(e) => { - warn!("Tool call result deserialization failed: {:?}", e); - Ok(InvokeOutput { - output: super::OutputKind::Json(result.clone()), - }) - }, + let resp = self.client.call_tool(params).await?; + + if resp.is_error.is_none_or(|v| !v) { + Ok(InvokeOutput { + output: super::OutputKind::Json(serde_json::json!(resp)), + }) + } else { + warn!("Tool call for {} failed", self.name); + Ok(InvokeOutput { + output: super::OutputKind::Json(serde_json::json!(resp)), + }) } } @@ -271,17 +126,14 @@ impl CustomTool { } pub fn get_input_token_size(&self) -> usize { - TokenCounter::count_tokens(self.method.as_str()) - + TokenCounter::count_tokens(self.params.as_ref().map_or("", |p| p.as_str().unwrap_or_default())) + TokenCounter::count_tokens( + &serde_json::to_string(self.params.as_ref().unwrap_or(&serde_json::Map::new())).unwrap_or_default(), + ) } pub fn eval_perm(&self, _os: &Os, agent: &Agent) -> PermissionEvalResult { - let Self { - name: tool_name, - client, - .. - } = self; - let server_name = client.get_server_name(); + let Self { name: tool_name, .. } = self; + let server_name = &self.server_name; let server_pattern = format!("@{server_name}"); if agent.allowed_tools.contains(&server_pattern) { @@ -296,58 +148,3 @@ impl CustomTool { PermissionEvalResult::Ask } } - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_substitute_env_vars() { - // Set a test environment variable - let os = Os::new().await.unwrap(); - unsafe { - os.env.set_var("TEST_VAR", "test_value"); - } - - // Test basic substitution - assert_eq!( - substitute_env_vars("Value is ${env:TEST_VAR}", &os.env), - "Value is test_value" - ); - - // Test multiple substitutions - assert_eq!( - substitute_env_vars("${env:TEST_VAR} and ${env:TEST_VAR}", &os.env), - "test_value and test_value" - ); - - // Test non-existent variable - assert_eq!( - substitute_env_vars("${env:NON_EXISTENT_VAR}", &os.env), - "${NON_EXISTENT_VAR}" - ); - - // Test mixed content - assert_eq!( - substitute_env_vars("Prefix ${env:TEST_VAR} suffix", &os.env), - "Prefix test_value suffix" - ); - } - - #[tokio::test] - async fn test_process_env_vars() { - let os = Os::new().await.unwrap(); - unsafe { - os.env.set_var("TEST_VAR", "test_value"); - } - - let mut env_vars = HashMap::new(); - env_vars.insert("KEY1".to_string(), "Value is ${env:TEST_VAR}".to_string()); - env_vars.insert("KEY2".to_string(), "No substitution".to_string()); - - process_env_vars(&mut env_vars, &os.env); - - assert_eq!(env_vars.get("KEY1").unwrap(), "Value is test_value"); - assert_eq!(env_vars.get("KEY2").unwrap(), "No substitution"); - } -} diff --git a/crates/chat-cli/src/cli/mod.rs b/crates/chat-cli/src/cli/mod.rs index c2185e36be..56c3cb2178 100644 --- a/crates/chat-cli/src/cli/mod.rs +++ b/crates/chat-cli/src/cli/mod.rs @@ -1,5 +1,5 @@ mod agent; -mod chat; +pub mod chat; mod debug; mod diagnostics; mod feed; diff --git a/crates/chat-cli/src/mcp_client/client.rs b/crates/chat-cli/src/mcp_client/client.rs index 27918c5773..92c33f074c 100644 --- a/crates/chat-cli/src/mcp_client/client.rs +++ b/crates/chat-cli/src/mcp_client/client.rs @@ -1,1140 +1,469 @@ +use std::borrow::Cow; use std::collections::HashMap; use std::process::Stdio; -use std::sync::atomic::{ - AtomicBool, - AtomicU64, - Ordering, -}; -use std::sync::{ - Arc, - RwLock as SyncRwLock, -}; -use std::time::Duration; - -use serde::{ - Deserialize, - Serialize, -}; -use thiserror::Error; -use tokio::time; -use tokio::time::error::Elapsed; -use super::transport::base_protocol::{ - JsonRpcMessage, - JsonRpcNotification, - JsonRpcRequest, - JsonRpcVersion, +use regex::Regex; +use rmcp::model::{ + ErrorCode, + Implementation, + InitializeRequestParam, + ListPromptsResult, + ListToolsResult, + LoggingLevel, + LoggingMessageNotificationParam, + PaginatedRequestParam, + ServerNotification, + ServerRequest, }; -use super::transport::stdio::JsonRpcStdioTransport; -use super::transport::{ - self, - Transport, - TransportError, +use rmcp::service::{ + ClientInitializeError, + NotificationContext, }; -use super::{ - JsonRpcResponse, - Listener as _, - LogListener, - Messenger, - PaginationSupportedOps, - PromptGet, - PromptsListResult, - ResourceTemplatesListResult, - ResourcesListResult, - ServerCapabilities, - ToolsListResult, +use rmcp::transport::{ + ConfigureCommandExt, + TokioChildProcess, }; -use crate::util::process::{ - Pid, - terminate_process, +use rmcp::{ + ErrorData, + RoleClient, + Service, + ServiceError, + ServiceExt, }; +use tokio::io::AsyncReadExt as _; +use tokio::process::Command; +use tokio::task::JoinHandle; +use tracing::error; + +use super::messenger::Messenger; +use crate::cli::chat::server_messenger::ServerMessenger; +use crate::cli::chat::tools::custom_tool::CustomToolConfig; +use crate::os::Os; + +/// Fetches all pages of specified resources from a server +macro_rules! paginated_fetch { + ( + final_result_type: $final_result_type:ty, + content_type: $content_type:ty, + service_method: $service_method:ident, + result_field: $result_field:ident, + messenger_method: $messenger_method:ident, + service: $service:expr, + messenger: $messenger:expr, + server_name: $server_name:expr + ) => { + { + let mut cursor = None::; + let mut final_result = Ok(<$final_result_type>::with_all_items(Default::default())); + let mut content = Vec::<$content_type>::new(); -pub type ClientInfo = serde_json::Value; -pub type StdioTransport = JsonRpcStdioTransport; + loop { + let param = Some(PaginatedRequestParam { cursor: cursor.clone() }); + match $service.$service_method(param).await { + Ok(mut result) => { + if let Some(s) = result.next_cursor { + cursor.replace(s); + } + content.append(&mut result.$result_field); + }, + Err(e) => { + final_result = Err(e); + break; + }, + } + if cursor.is_none() { + break; + } + } -/// Represents the capabilities of a client in the Model Context Protocol. -/// This structure is sent to the server during initialization to communicate -/// what features the client supports and provide information about the client. -/// When features are added to the client, these should be declared in the [From] trait implemented -/// for the struct. -#[derive(Default, Debug, Serialize)] -#[serde(rename_all = "camelCase")] -struct ClientCapabilities { - protocol_version: JsonRpcVersion, - capabilities: HashMap, - client_info: serde_json::Value, -} + if let Ok(final_result) = &mut final_result { + final_result.$result_field.append(&mut content); + } -impl From for ClientCapabilities { - fn from(client_info: ClientInfo) -> Self { - ClientCapabilities { - client_info, - ..Default::default() + if let Err(e) = $messenger.$messenger_method(final_result, Some($service)).await { + error!(target: "mcp", "Initial {} result failed to send for server {}: {}", + stringify!($result_field), $server_name, e); + } } - } + }; } -#[derive(Debug, Deserialize)] -pub struct ClientConfig { - pub server_name: String, - pub bin_path: String, - pub args: Vec, - pub timeout: u64, - pub client_info: serde_json::Value, - pub env: Option>, +/// Substitutes environment variables in the format ${env:VAR_NAME} with their actual values +fn substitute_env_vars(input: &str, env: &crate::os::Env) -> String { + // Create a regex to match ${env:VAR_NAME} pattern + let re = Regex::new(r"\$\{env:([^}]+)\}").unwrap(); + + re.replace_all(input, |caps: ®ex::Captures<'_>| { + let var_name = &caps[1]; + env.get(var_name).unwrap_or_else(|_| format!("${{{}}}", var_name)) + }) + .to_string() } -#[allow(dead_code)] -#[derive(Debug, Error)] -pub enum ClientError { +/// Process a HashMap of environment variables, substituting any ${env:VAR_NAME} patterns +/// with their actual values from the environment +fn process_env_vars(env_vars: &mut HashMap, env: &crate::os::Env) { + for (_, value) in env_vars.iter_mut() { + *value = substitute_env_vars(value, env); + } +} + +#[derive(Debug, thiserror::Error)] +pub enum McpClientError { #[error(transparent)] - TransportError(#[from] TransportError), + ClientInitializeError(#[from] Box), #[error(transparent)] Io(#[from] std::io::Error), #[error(transparent)] - Serialization(#[from] serde_json::Error), - #[error("Operation timed out: {context}")] - RuntimeError { - #[source] - source: tokio::time::error::Elapsed, - context: String, - }, - #[error("Unexpected msg type encountered")] - UnexpectedMsgType, - #[error("{0}")] - NegotiationError(String), - #[error("Failed to obtain process id")] - MissingProcessId, - #[error("Invalid path received")] - InvalidPath, - #[error("{0}")] - ProcessKillError(String), - #[error("{0}")] - PoisonError(String), + JoinError(#[from] tokio::task::JoinError), + #[error("Client has not finished initializing")] + NotReady, } -impl From<(tokio::time::error::Elapsed, String)> for ClientError { - fn from((error, context): (tokio::time::error::Elapsed, String)) -> Self { - ClientError::RuntimeError { source: error, context } - } -} +pub type RunningService = rmcp::service::RunningService; +/// This struct implements the [Service] trait from rmcp. It is within this trait the logic of +/// server driven data flow (i.e. requests and notifications that are sent from the server) are +/// handled. #[derive(Debug)] -pub struct Client { +pub struct McpClientService { + pub config: CustomToolConfig, server_name: String, - transport: Arc, - timeout: u64, - pub server_process_id: Option, - client_info: serde_json::Value, - current_id: Arc, - pub messenger: Option>, - // TODO: move this to tool manager that way all the assets are treated equally - pub prompt_gets: Arc>>, - pub is_prompts_out_of_date: Arc, + messenger: ServerMessenger, } -impl Clone for Client { - fn clone(&self) -> Self { +impl McpClientService { + pub fn new(server_name: String, config: CustomToolConfig, messenger: ServerMessenger) -> Self { Self { - server_name: self.server_name.clone(), - transport: self.transport.clone(), - timeout: self.timeout, - // Note that we cannot have an id for the clone because we would kill the original - // process when we drop the clone - server_process_id: None, - client_info: self.client_info.clone(), - current_id: self.current_id.clone(), - messenger: None, - prompt_gets: self.prompt_gets.clone(), - is_prompts_out_of_date: self.is_prompts_out_of_date.clone(), + server_name, + config, + messenger, } } -} -impl Client { - pub fn from_config(config: ClientConfig) -> Result { - let ClientConfig { - server_name, - bin_path, - args, - timeout, - client_info, - env, - } = config; - let child = { - let expanded_bin_path = shellexpand::tilde(&bin_path); - - // On Windows, we need to use cmd.exe to run the binary with arguments because Tokio - // always assumes that the program has an .exe extension, which is not the case for - // helpers like `uvx` or `npx`. - let mut command = if cfg!(windows) { - let mut cmd = tokio::process::Command::new("cmd.exe"); - cmd.args(["/C", &Self::build_windows_command(&expanded_bin_path, args)]); - cmd - } else { - let mut cmd = tokio::process::Command::new(expanded_bin_path.to_string()); - cmd.args(args); - cmd - }; - - command - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .envs(std::env::vars()); - - #[cfg(not(windows))] - command.process_group(0); - - if let Some(env) = env { - for (env_name, env_value) in env { - command.env(env_name, env_value); + pub async fn init(mut self, os: &Os) -> Result { + let os_clone = os.clone(); + + let handle: JoinHandle> = tokio::spawn(async move { + let CustomToolConfig { + command: command_as_str, + args, + env: config_envs, + .. + } = &mut self.config; + + let command = Command::new(command_as_str).configure(|cmd| { + if let Some(envs) = config_envs { + process_env_vars(envs, &os_clone.env); + cmd.envs(envs); } - } + cmd.envs(std::env::vars()).args(args); - command.spawn()? - }; - - let server_process_id = child.id().ok_or(ClientError::MissingProcessId)?; - let server_process_id = Some(Pid::from_u32(server_process_id)); - - let transport = Arc::new(transport::stdio::JsonRpcStdioTransport::client(child)?); - Ok(Self { - server_name, - transport, - timeout, - server_process_id, - client_info, - current_id: Arc::new(AtomicU64::new(0)), - messenger: None, - prompt_gets: Arc::new(SyncRwLock::new(HashMap::new())), - is_prompts_out_of_date: Arc::new(AtomicBool::new(false)), - }) - } - - fn build_windows_command(bin_path: &str, args: Vec) -> String { - let mut parts = Vec::new(); + #[cfg(not(windows))] + cmd.process_group(0); + }); - // Add the binary path, quoted if necessary - parts.push(Self::quote_windows_arg(bin_path)); + let messenger_clone = self.messenger.duplicate(); + let server_name = self.server_name.clone(); - // Add all arguments, quoted if necessary - for arg in args { - parts.push(Self::quote_windows_arg(&arg)); - } + let result: Result<_, McpClientError> = async { + // Spawn the child process with stderr piped + let (tokio_child_process, child_stderr) = + TokioChildProcess::builder(command).stderr(Stdio::piped()).spawn()?; - parts.join(" ") - } + // Attempt to serve the process + let service = self + .serve::(tokio_child_process) + .await + .map_err(Box::new)?; - fn quote_windows_arg(arg: &str) -> String { - // If the argument doesn't need quoting, return as-is - if !arg.chars().any(|c| " \t\n\r\"".contains(c)) { - return arg.to_string(); - } + Ok((service, child_stderr)) + } + .await; + + let (service, child_stderr) = match result { + Ok((service, stderr)) => (service, stderr), + Err(e) => { + let msg = e.to_string(); + let error_data = ErrorData { + code: ErrorCode::RESOURCE_NOT_FOUND, + message: Cow::from(msg), + data: None, + }; + let err = ServiceError::McpError(error_data); - let mut result = String::from("\""); - let mut backslashes = 0; + if let Err(send_err) = messenger_clone.send_tools_list_result(Err(err), None).await { + error!("Error sending tool result for {server_name}: {send_err}"); + } - for c in arg.chars() { - match c { - '\\' => { - backslashes += 1; - result.push('\\'); + return Err(e); }, - '"' => { - // Escape all preceding backslashes and the quote - for _ in 0..backslashes { - result.push('\\'); + }; + + if let Some(mut stderr) = child_stderr { + let server_name_clone = server_name.clone(); + tokio::spawn(async move { + let mut buf = [0u8; 1024]; + loop { + match stderr.read(&mut buf).await { + Ok(0) => { + tracing::info!(target: "mcp", "{server_name_clone} stderr listening process exited due to EOF"); + break; + }, + Ok(size) => { + tracing::info!(target: "mcp", "{server_name_clone} logged to its stderr: {}", String::from_utf8_lossy(&buf[0..size])); + }, + Err(e) => { + tracing::info!(target: "mcp", "{server_name_clone} stderr listening process exited due to error: {e}"); + break; // Error reading + }, + } } - result.push_str("\\\""); - backslashes = 0; - }, - _ => { - backslashes = 0; - result.push(c); - }, + }); } - } - - // Escape trailing backslashes before the closing quote - for _ in 0..backslashes { - result.push('\\'); - } - - result.push('"'); - result - } -} - -impl Drop for Client -where - T: Transport, -{ - // IF the servers are implemented well, they will shutdown once the pipe closes. - // This drop trait is here as a fail safe to ensure we don't leave behind any orphans. - fn drop(&mut self) { - if let Some(process_id) = self.server_process_id { - let _ = terminate_process(process_id); - } - if let Some(ref messenger) = self.messenger { - messenger.send_deinit_msg(); - } - } -} -impl Client -where - T: Transport, -{ - /// Exchange of information specified as per https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization - /// - /// Also done are the following: - /// - Spawns task for listening to server driven workflows - /// - Spawns tasks to ask for relevant info such as tools and prompts in accordance to server - /// capabilities received - pub async fn init(&self) -> Result { - let transport_ref = self.transport.clone(); - let server_name = self.server_name.clone(); + let service_clone = service.clone(); + tokio::spawn(async move { + let result: Result<(), Box> = async { + let init_result = service_clone.peer_info(); + if let Some(init_result) = init_result { + if init_result.capabilities.tools.is_some() { + paginated_fetch! { + final_result_type: ListToolsResult, + content_type: rmcp::model::Tool, + service_method: list_tools, + result_field: tools, + messenger_method: send_tools_list_result, + service: service_clone.clone(), + messenger: messenger_clone, + server_name: server_name + }; + } - // Spawning a task to listen and log stderr output - tokio::spawn(async move { - let mut log_listener = transport_ref.get_log_listener(); - loop { - match log_listener.recv().await { - Ok(msg) => { - tracing::trace!(target: "mcp", "{server_name} logged {}", msg); - }, - Err(e) => { - tracing::error!( - "Error encountered while reading from stderr for {server_name}: {:?}\nEnding stderr listening task.", - e - ); - break; - }, + if init_result.capabilities.prompts.is_some() { + paginated_fetch! { + final_result_type: ListPromptsResult, + content_type: rmcp::model::Prompt, + service_method: list_prompts, + result_field: prompts, + messenger_method: send_prompts_list_result, + service: service_clone, + messenger: messenger_clone, + server_name: server_name + }; + } + } + Ok(()) } - } - }); + .await; - let init_params = Some({ - let client_cap = ClientCapabilities::from(self.client_info.clone()); - serde_json::json!(client_cap) - }); - let init_resp = self.request("initialize", init_params).await?; - if let Err(e) = examine_server_capabilities(&init_resp) { - return Err(ClientError::NegotiationError(format!( - "Client {} has failed to negotiate server capabilities with server: {:?}", - self.server_name, e - ))); - } - let cap = { - let result = init_resp.result.ok_or(ClientError::NegotiationError(format!( - "Server {} init resp is missing result", - self.server_name - )))?; - let cap = result - .get("capabilities") - .ok_or(ClientError::NegotiationError(format!( - "Server {} init resp result is missing capabilities", - self.server_name - )))? - .clone(); - serde_json::from_value::(cap)? - }; - self.notify("initialized", None).await?; - - // TODO: group this into examine_server_capabilities - // Prefetch prompts in the background. We should only do this after the server has been - // initialized - if cap.prompts.is_some() { - self.is_prompts_out_of_date.store(true, Ordering::Relaxed); - let client_ref = (*self).clone(); - let messenger_ref = self.messenger.as_ref().map(|m| m.duplicate()); - tokio::spawn(async move { - fetch_prompts_and_notify_with_messenger(&client_ref, messenger_ref.as_ref()).await; - }); - } - if cap.tools.is_some() { - let client_ref = (*self).clone(); - let messenger_ref = self.messenger.as_ref().map(|m| m.duplicate()); - tokio::spawn(async move { - fetch_tools_and_notify_with_messenger(&client_ref, messenger_ref.as_ref()).await; + if let Err(e) = result { + error!(target: "mcp", "Error in MCP client initialization: {}", e); + } }); - } - - let transport_ref = self.transport.clone(); - let server_name = self.server_name.clone(); - let messenger_ref = self.messenger.as_ref().map(|m| m.duplicate()); - let client_ref = (*self).clone(); - let prompts_list_changed_supported = cap.prompts.as_ref().is_some_and(|p| p.get("listChanged").is_some()); - let tools_list_changed_supported = cap.tools.as_ref().is_some_and(|t| t.get("listChanged").is_some()); - tokio::spawn(async move { - let mut listener = transport_ref.get_listener(); - loop { - match listener.recv().await { - Ok(msg) => { - match msg { - JsonRpcMessage::Request(_req) => {}, - JsonRpcMessage::Notification(notif) => { - let JsonRpcNotification { method, params, .. } = notif; - match method.as_str() { - "notifications/message" | "message" => { - let level = params - .as_ref() - .and_then(|p| p.get("level")) - .and_then(|v| serde_json::to_string(v).ok()); - let data = params - .as_ref() - .and_then(|p| p.get("data")) - .and_then(|v| serde_json::to_string(v).ok()); - if let (Some(level), Some(data)) = (level, data) { - match level.to_lowercase().as_str() { - "error" => { - tracing::error!(target: "mcp", "{}: {}", server_name, data); - }, - "warn" => { - tracing::warn!(target: "mcp", "{}: {}", server_name, data); - }, - "info" => { - tracing::info!(target: "mcp", "{}: {}", server_name, data); - }, - "debug" => { - tracing::debug!(target: "mcp", "{}: {}", server_name, data); - }, - "trace" => { - tracing::trace!(target: "mcp", "{}: {}", server_name, data); - }, - _ => {}, - } - } - }, - "notifications/prompts/list_changed" | "prompts/list_changed" - if prompts_list_changed_supported => - { - // TODO: after we have moved the prompts to the tool - // manager we follow the same workflow as the list changed - // for tools - fetch_prompts_and_notify_with_messenger(&client_ref, messenger_ref.as_ref()) - .await; - client_ref.is_prompts_out_of_date.store(true, Ordering::Release); - }, - "notifications/tools/list_changed" | "tools/list_changed" - if tools_list_changed_supported => - { - fetch_tools_and_notify_with_messenger(&client_ref, messenger_ref.as_ref()) - .await; - }, - _ => {}, - } - }, - JsonRpcMessage::Response(_resp) => { /* noop since direct response is handled inside the request api */ - }, - } - }, - Err(e) => { - tracing::error!("Background listening thread for client {}: {:?}", server_name, e); - // If we don't have anything on the other end, we should just end the task - // now - if let TransportError::RecvError(tokio::sync::broadcast::error::RecvError::Closed) = e { - tracing::error!( - "All senders dropped for transport layer for server {}: {:?}. This likely means the mcp server process is no longer running.", - server_name, - e - ); - break; - } - }, - } - } + Ok(service) }); - Ok(cap) + Ok(InitializedMcpClient::Pending(handle)) } - /// Sends a request to the server associated. - /// This call will yield until a response is received. - pub async fn request( + async fn on_logging_message( &self, - method: &str, - params: Option, - ) -> Result { - let send_map_err = |e: Elapsed| (e, method.to_string()); - let recv_map_err = |e: Elapsed| (e, format!("recv for {method}")); - let mut id = self.get_id(); - let request = JsonRpcRequest { - jsonrpc: JsonRpcVersion::default(), - id, - method: method.to_owned(), - params, - }; - tracing::trace!(target: "mcp", "To {}:\n{:#?}", self.server_name, request); - let msg = JsonRpcMessage::Request(request); - time::timeout(Duration::from_millis(self.timeout), self.transport.send(&msg)) - .await - .map_err(send_map_err)??; - let mut listener = self.transport.get_listener(); - let mut resp = time::timeout(Duration::from_millis(self.timeout), async { - // we want to ignore all other messages sent by the server at this point and let the - // background loop handle them - // We also want to ignore all messages emitted by the server to its stdout that does - // not deserialize into a valid JsonRpcMessage (they are not supposed to do this but - // too many people complained about this so we are adding this safeguard in) - loop { - if let Ok(JsonRpcMessage::Response(resp)) = listener.recv().await { - if resp.id == id { - break Ok::(resp); - } - } - } - }) - .await - .map_err(recv_map_err)??; - // Pagination support: https://spec.modelcontextprotocol.io/specification/2024-11-05/server/utilities/pagination/#pagination-model - let mut next_cursor = resp.result.as_ref().and_then(|v| v.get("nextCursor")); - if next_cursor.is_some() { - let mut current_resp = resp.clone(); - let mut results = Vec::::new(); - let pagination_supported_ops = { - let maybe_pagination_supported_op: Result = method.try_into(); - maybe_pagination_supported_op.ok() - }; - if let Some(ops) = pagination_supported_ops { - loop { - let result = current_resp.result.as_ref().cloned().unwrap(); - let mut list: Vec = match ops { - PaginationSupportedOps::ResourcesList => { - let ResourcesListResult { resources: list, .. } = - serde_json::from_value::(result) - .map_err(ClientError::Serialization)?; - list - }, - PaginationSupportedOps::ResourceTemplatesList => { - let ResourceTemplatesListResult { - resource_templates: list, - .. - } = serde_json::from_value::(result) - .map_err(ClientError::Serialization)?; - list - }, - PaginationSupportedOps::PromptsList => { - let PromptsListResult { prompts: list, .. } = - serde_json::from_value::(result) - .map_err(ClientError::Serialization)?; - list - }, - PaginationSupportedOps::ToolsList => { - let ToolsListResult { tools: list, .. } = serde_json::from_value::(result) - .map_err(ClientError::Serialization)?; - list - }, - }; - results.append(&mut list); - if next_cursor.is_none() { - break; - } - id = self.get_id(); - let next_request = JsonRpcRequest { - jsonrpc: JsonRpcVersion::default(), - id, - method: method.to_owned(), - params: Some(serde_json::json!({ - "cursor": next_cursor, - })), - }; - let msg = JsonRpcMessage::Request(next_request); - time::timeout(Duration::from_millis(self.timeout), self.transport.send(&msg)) - .await - .map_err(send_map_err)??; - let resp = time::timeout(Duration::from_millis(self.timeout), async { - loop { - if let Ok(JsonRpcMessage::Response(resp)) = listener.recv().await { - if resp.id == id { - break Ok::(resp); - } - } - } - }) - .await - .map_err(recv_map_err)??; - current_resp = resp; - next_cursor = current_resp.result.as_ref().and_then(|v| v.get("nextCursor")); - } - resp.result = Some({ - let mut map = serde_json::Map::new(); - map.insert(ops.as_key().to_owned(), serde_json::to_value(results)?); - serde_json::to_value(map)? - }); - } + params: LoggingMessageNotificationParam, + _context: NotificationContext, + ) { + let level = params.level; + let data = params.data; + let server_name = &self.server_name; + + match level { + LoggingLevel::Error | LoggingLevel::Critical | LoggingLevel::Emergency | LoggingLevel::Alert => { + tracing::error!(target: "mcp", "{}: {}", server_name, data); + }, + LoggingLevel::Warning => { + tracing::warn!(target: "mcp", "{}: {}", server_name, data); + }, + LoggingLevel::Info => { + tracing::info!(target: "mcp", "{}: {}", server_name, data); + }, + LoggingLevel::Debug => { + tracing::debug!(target: "mcp", "{}: {}", server_name, data); + }, + LoggingLevel::Notice => { + tracing::trace!(target: "mcp", "{}: {}", server_name, data); + }, } - tracing::trace!(target: "mcp", "From {}:\n{:#?}", self.server_name, resp); - Ok(resp) } - /// Sends a notification to the server associated. - /// Notifications are requests that expect no responses. - pub async fn notify(&self, method: &str, params: Option) -> Result<(), ClientError> { - let send_map_err = |e: Elapsed| (e, method.to_string()); - let notification = JsonRpcNotification { - jsonrpc: JsonRpcVersion::default(), - method: format!("notifications/{}", method), - params, + async fn on_tool_list_changed(&self, context: NotificationContext) { + let NotificationContext { peer, .. } = context; + let _timeout = self.config.timeout; + + paginated_fetch! { + final_result_type: ListToolsResult, + content_type: rmcp::model::Tool, + service_method: list_tools, + result_field: tools, + messenger_method: send_tools_list_result, + service: peer, + messenger: self.messenger, + server_name: self.server_name }; - let msg = JsonRpcMessage::Notification(notification); - Ok( - time::timeout(Duration::from_millis(self.timeout), self.transport.send(&msg)) - .await - .map_err(send_map_err)??, - ) } - fn get_id(&self) -> u64 { - self.current_id.fetch_add(1, Ordering::SeqCst) + async fn on_prompt_list_changed(&self, context: NotificationContext) { + let NotificationContext { peer, .. } = context; + let _timeout = self.config.timeout; + + paginated_fetch! { + final_result_type: ListPromptsResult, + content_type: rmcp::model::Prompt, + service_method: list_prompts, + result_field: prompts, + messenger_method: send_prompts_list_result, + service: peer, + messenger: self.messenger, + server_name: self.server_name + }; } } -fn examine_server_capabilities(ser_cap: &JsonRpcResponse) -> Result<(), ClientError> { - // Check the jrpc version. - // Currently we are only proceeding if the versions are EXACTLY the same. - let jrpc_version = ser_cap.jsonrpc.as_u32_vec(); - let client_jrpc_version = JsonRpcVersion::default().as_u32_vec(); - for (sv, cv) in jrpc_version.iter().zip(client_jrpc_version.iter()) { - if sv != cv { - return Err(ClientError::NegotiationError( - "Incompatible jrpc version between server and client".to_owned(), - )); +impl Service for McpClientService { + async fn handle_request( + &self, + request: ::PeerReq, + _context: rmcp::service::RequestContext, + ) -> Result<::Resp, rmcp::ErrorData> { + match request { + ServerRequest::PingRequest(_) => Err(rmcp::ErrorData::method_not_found::()), + ServerRequest::CreateMessageRequest(_) => Err(rmcp::ErrorData::method_not_found::< + rmcp::model::CreateMessageRequestMethod, + >()), + ServerRequest::ListRootsRequest(_) => { + Err(rmcp::ErrorData::method_not_found::()) + }, + ServerRequest::CreateElicitationRequest(_) => Err(rmcp::ErrorData::method_not_found::< + rmcp::model::ElicitationCreateRequestMethod, + >()), } } - Ok(()) -} -#[allow(clippy::borrowed_box)] -async fn fetch_prompts_and_notify_with_messenger(client: &Client, messenger: Option<&Box>) -where - T: Transport, -{ - let prompt_list_result = 'prompt_list_result: { - let Ok(resp) = client.request("prompts/list", None).await else { - tracing::error!("Prompt list query failed for {0}", client.server_name); - return; - }; - let Some(result) = resp.result else { - tracing::warn!("Prompt list query returned no result for {0}", client.server_name); - return; - }; - let prompt_list_result = match serde_json::from_value::(result) { - Ok(res) => res, - Err(e) => { - let msg = format!("Failed to deserialize tool result from {}: {:?}", client.server_name, e); - break 'prompt_list_result Err(eyre::eyre!(msg)); + async fn handle_notification( + &self, + notification: ::PeerNot, + context: NotificationContext, + ) -> Result<(), rmcp::ErrorData> { + match notification { + ServerNotification::ToolListChangedNotification(_) => self.on_tool_list_changed(context).await, + ServerNotification::LoggingMessageNotification(notification) => { + self.on_logging_message(notification.params, context).await; }, + ServerNotification::PromptListChangedNotification(_) => self.on_prompt_list_changed(context).await, + // TODO: support these + ServerNotification::CancelledNotification(_) => (), + ServerNotification::ResourceUpdatedNotification(_) => (), + ServerNotification::ResourceListChangedNotification(_) => (), + ServerNotification::ProgressNotification(_) => (), }; - Ok::(prompt_list_result) - }; + Ok(()) + } - if let Some(messenger) = messenger { - if let Err(e) = messenger.send_prompts_list_result(prompt_list_result).await { - tracing::error!("Failed to send prompt result through messenger: {:?}", e); + fn get_info(&self) -> ::Info { + InitializeRequestParam { + protocol_version: Default::default(), + capabilities: Default::default(), + client_info: Implementation { + name: "Q DEV CLI".to_string(), + version: "1.0.0".to_string(), + }, } } } -#[allow(clippy::borrowed_box)] -async fn fetch_tools_and_notify_with_messenger(client: &Client, messenger: Option<&Box>) -where - T: Transport, -{ - // TODO: decouple pagination logic from request and have page fetching logic here - // instead - let tool_list_result = 'tool_list_result: { - let resp = match client.request("tools/list", None).await { - Ok(resp) => resp, - Err(e) => break 'tool_list_result Err(e.into()), - }; - if let Some(error) = resp.error { - let msg = format!("Failed to retrieve tool list for {}: {:?}", client.server_name, error); - break 'tool_list_result Err(eyre::eyre!(msg)); - } - let Some(result) = resp.result else { - let msg = format!("Tool list response from {} is missing result", client.server_name); - break 'tool_list_result Err(eyre::eyre!(msg)); - }; - let tool_list_result = match serde_json::from_value::(result) { - Ok(result) => result, - Err(e) => { - let msg = format!("Failed to deserialize tool result from {}: {:?}", client.server_name, e); - break 'tool_list_result Err(eyre::eyre!(msg)); - }, - }; - Ok::(tool_list_result) - }; +/// InitializedMcpClient is the return of [McpClientService::init]. +/// This is necessitated by the fact that [Service::serve], the command to spawn the process, is +/// async and does not resolve immediately. This delay can be significant and causes long perceived +/// latency during start up. However, our current architecture still requires the main chat loop to +/// have ownership of [RunningService]. +/// The solution chosen here is to instead spawn a task and have [Service::serve] called there and +/// return the handle to said task, stored in the [InitializedMcpClient::Pending] variant. This +/// enum is then flipped lazily (if applicable) when a [RunningService] is needed. +#[derive(Debug)] +pub enum InitializedMcpClient { + Pending(JoinHandle>), + Ready(RunningService), +} - if let Some(messenger) = messenger { - if let Err(e) = messenger.send_tools_list_result(tool_list_result).await { - tracing::error!("Failed to send tool result through messenger {:?}", e); +impl InitializedMcpClient { + pub async fn get_running_service(&mut self) -> Result<&RunningService, McpClientError> { + match self { + InitializedMcpClient::Pending(handle) if handle.is_finished() => { + let running_service = handle.await??; + *self = InitializedMcpClient::Ready(running_service); + let InitializedMcpClient::Ready(running_service) = self else { + unreachable!() + }; + + Ok(running_service) + }, + InitializedMcpClient::Ready(running_service) => Ok(running_service), + InitializedMcpClient::Pending(_) => Err(McpClientError::NotReady), } } } #[cfg(test)] mod tests { - use std::path::PathBuf; - - use serde_json::Value; - use super::*; - const TEST_BIN_OUT_DIR: &str = "target/debug"; - const TEST_SERVER_NAME: &str = "test_mcp_server"; - - fn get_workspace_root() -> PathBuf { - let output = std::process::Command::new("cargo") - .args(["metadata", "--format-version=1", "--no-deps"]) - .output() - .expect("Failed to execute cargo metadata"); - - let metadata: serde_json::Value = - serde_json::from_slice(&output.stdout).expect("Failed to parse cargo metadata"); - - let workspace_root = metadata["workspace_root"] - .as_str() - .expect("Failed to find workspace_root in metadata"); - - PathBuf::from(workspace_root) - } - - #[tokio::test(flavor = "multi_thread")] - // For some reason this test is quite flakey when ran in the CI but not on developer's - // machines. As a result it is hard to debug, hence we are ignoring it for now. - #[ignore] - async fn test_client_stdio() { - std::process::Command::new("cargo") - .args(["build", "--bin", TEST_SERVER_NAME]) - .status() - .expect("Failed to build binary"); - let workspace_root = get_workspace_root(); - let bin_path = workspace_root.join(TEST_BIN_OUT_DIR).join(TEST_SERVER_NAME); - println!("bin path: {}", bin_path.to_str().unwrap_or("no path found")); - - // Testing 2 concurrent sessions to make sure transport layer does not overlap. - let client_info_one = serde_json::json!({ - "name": "TestClientOne", - "version": "1.0.0" - }); - let client_config_one = ClientConfig { - server_name: "test_tool".to_owned(), - bin_path: bin_path.to_str().unwrap().to_string(), - args: ["1".to_owned()].to_vec(), - timeout: 120 * 1000, - client_info: client_info_one.clone(), - env: { - let mut map = HashMap::::new(); - map.insert("ENV_ONE".to_owned(), "1".to_owned()); - map.insert("ENV_TWO".to_owned(), "2".to_owned()); - Some(map) - }, - }; - let client_info_two = serde_json::json!({ - "name": "TestClientTwo", - "version": "1.0.0" - }); - let client_config_two = ClientConfig { - server_name: "test_tool".to_owned(), - bin_path: bin_path.to_str().unwrap().to_string(), - args: ["2".to_owned()].to_vec(), - timeout: 120 * 1000, - client_info: client_info_two.clone(), - env: { - let mut map = HashMap::::new(); - map.insert("ENV_ONE".to_owned(), "1".to_owned()); - map.insert("ENV_TWO".to_owned(), "2".to_owned()); - Some(map) - }, - }; - let mut client_one = Client::::from_config(client_config_one).expect("Failed to create client"); - let mut client_two = Client::::from_config(client_config_two).expect("Failed to create client"); - let client_one_cap = ClientCapabilities::from(client_info_one); - let client_two_cap = ClientCapabilities::from(client_info_two); - - let (res_one, res_two) = tokio::join!( - time::timeout( - time::Duration::from_secs(10), - test_client_routine(&mut client_one, serde_json::json!(client_one_cap)) - ), - time::timeout( - time::Duration::from_secs(10), - test_client_routine(&mut client_two, serde_json::json!(client_two_cap)) - ) - ); - let res_one = res_one.expect("Client one timed out"); - let res_two = res_two.expect("Client two timed out"); - assert!(res_one.is_ok()); - assert!(res_two.is_ok()); - } - - #[allow(clippy::await_holding_lock)] - async fn test_client_routine( - client: &mut Client, - cap_sent: serde_json::Value, - ) -> Result<(), Box> { - // Test init - let _ = client.init().await.expect("Client init failed"); - tokio::time::sleep(time::Duration::from_millis(1500)).await; - let client_capabilities_sent = client - .request("verify_init_ack_sent", None) - .await - .expect("Verify init ack mock request failed"); - let has_server_recvd_init_ack = client_capabilities_sent - .result - .expect("Failed to retrieve client capabilities sent."); - assert_eq!(has_server_recvd_init_ack.to_string(), "true"); - let cap_recvd = client - .request("verify_init_params_sent", None) - .await - .expect("Verify init params mock request failed"); - let cap_recvd = cap_recvd - .result - .expect("Verify init params mock request does not contain required field (result)"); - assert!(are_json_values_equal(&cap_sent, &cap_recvd)); - - // test list tools - let fake_tool_names = ["get_weather_one", "get_weather_two", "get_weather_three"]; - let mock_result_spec = fake_tool_names.map(create_fake_tool_spec); - let mock_tool_specs_for_verify = serde_json::json!(mock_result_spec.clone()); - let mock_tool_specs_prep_param = mock_result_spec - .iter() - .zip(fake_tool_names.iter()) - .map(|(v, n)| { - serde_json::json!({ - "key": (*n).to_string(), - "value": v - }) - }) - .collect::>(); - let mock_tool_specs_prep_param = - serde_json::to_value(mock_tool_specs_prep_param).expect("Failed to create mock tool specs prep param"); - let _ = client - .request("store_mock_tool_spec", Some(mock_tool_specs_prep_param)) - .await - .expect("Mock tool spec prep failed"); - let tool_spec_recvd = client.request("tools/list", None).await.expect("List tools failed"); - assert!(are_json_values_equal( - tool_spec_recvd - .result - .as_ref() - .and_then(|v| v.get("tools")) - .expect("Failed to retrieve tool specs from result received"), - &mock_tool_specs_for_verify - )); - // Test list prompts directly - let fake_prompt_names = ["code_review_one", "code_review_two", "code_review_three"]; - let mock_result_prompts = fake_prompt_names.map(create_fake_prompts); - let mock_prompts_for_verify = serde_json::json!(mock_result_prompts.clone()); - let mock_prompts_prep_param = mock_result_prompts - .iter() - .zip(fake_prompt_names.iter()) - .map(|(v, n)| { - serde_json::json!({ - "key": (*n).to_string(), - "value": v - }) - }) - .collect::>(); - let mock_prompts_prep_param = - serde_json::to_value(mock_prompts_prep_param).expect("Failed to create mock prompts prep param"); - let _ = client - .request("store_mock_prompts", Some(mock_prompts_prep_param)) - .await - .expect("Mock prompt prep failed"); - let prompts_recvd = client.request("prompts/list", None).await.expect("List prompts failed"); - client.is_prompts_out_of_date.store(false, Ordering::Release); - assert!(are_json_values_equal( - prompts_recvd - .result - .as_ref() - .and_then(|v| v.get("prompts")) - .expect("Failed to retrieve prompts from results received"), - &mock_prompts_for_verify - )); - - // Test prompts list changed - let fake_prompt_names = ["code_review_four", "code_review_five", "code_review_six"]; - let mock_result_prompts = fake_prompt_names.map(create_fake_prompts); - let mock_prompts_prep_param = mock_result_prompts - .iter() - .zip(fake_prompt_names.iter()) - .map(|(v, n)| { - serde_json::json!({ - "key": (*n).to_string(), - "value": v - }) - }) - .collect::>(); - let mock_prompts_prep_param = - serde_json::to_value(mock_prompts_prep_param).expect("Failed to create mock prompts prep param"); - let _ = client - .request("store_mock_prompts", Some(mock_prompts_prep_param)) - .await - .expect("Mock new prompt request failed"); - // After we send the signal for the server to clear prompts, we should be receiving signal - // to fetch for new prompts, after which we should be getting no prompts. - let is_prompts_out_of_date = client.is_prompts_out_of_date.clone(); - let wait_for_new_prompts = async move { - while !is_prompts_out_of_date.load(Ordering::Acquire) { - tokio::time::sleep(time::Duration::from_millis(100)).await; - } - }; - time::timeout(time::Duration::from_secs(5), wait_for_new_prompts) - .await - .expect("Timed out while waiting for new prompts"); - let new_prompts = client.prompt_gets.read().expect("Failed to read new prompts"); - for k in new_prompts.keys() { - assert!(fake_prompt_names.contains(&k.as_str())); + #[tokio::test] + async fn test_substitute_env_vars() { + // Set a test environment variable + let os = Os::new().await.unwrap(); + unsafe { + os.env.set_var("TEST_VAR", "test_value"); } - // Test env var inclusion - let env_vars = client.request("get_env_vars", None).await.expect("Get env vars failed"); - let env_one = env_vars - .result - .as_ref() - .expect("Failed to retrieve results from env var request") - .get("ENV_ONE") - .expect("Failed to retrieve env one from env var request"); - let env_two = env_vars - .result - .as_ref() - .expect("Failed to retrieve results from env var request") - .get("ENV_TWO") - .expect("Failed to retrieve env two from env var request"); - let env_one_as_str = serde_json::to_string(env_one).expect("Failed to convert env one to string"); - let env_two_as_str = serde_json::to_string(env_two).expect("Failed to convert env two to string"); - assert_eq!(env_one_as_str, "\"1\"".to_string()); - assert_eq!(env_two_as_str, "\"2\"".to_string()); - - Ok(()) - } + // Test basic substitution + assert_eq!( + substitute_env_vars("Value is ${env:TEST_VAR}", &os.env), + "Value is test_value" + ); - fn are_json_values_equal(a: &Value, b: &Value) -> bool { - match (a, b) { - (Value::Null, Value::Null) => true, - (Value::Bool(a_val), Value::Bool(b_val)) => a_val == b_val, - (Value::Number(a_val), Value::Number(b_val)) => a_val == b_val, - (Value::String(a_val), Value::String(b_val)) => a_val == b_val, - (Value::Array(a_arr), Value::Array(b_arr)) => { - if a_arr.len() != b_arr.len() { - return false; - } - a_arr - .iter() - .zip(b_arr.iter()) - .all(|(a_item, b_item)| are_json_values_equal(a_item, b_item)) - }, - (Value::Object(a_obj), Value::Object(b_obj)) => { - if a_obj.len() != b_obj.len() { - return false; - } - a_obj.iter().all(|(key, a_value)| match b_obj.get(key) { - Some(b_value) => are_json_values_equal(a_value, b_value), - None => false, - }) - }, - _ => false, - } - } + // Test multiple substitutions + assert_eq!( + substitute_env_vars("${env:TEST_VAR} and ${env:TEST_VAR}", &os.env), + "test_value and test_value" + ); - fn create_fake_tool_spec(name: &str) -> serde_json::Value { - serde_json::json!({ - "name": name, - "description": "Get current weather information for a location", - "inputSchema": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "City name or zip code" - } - }, - "required": ["location"] - } - }) - } + // Test non-existent variable + assert_eq!( + substitute_env_vars("${env:NON_EXISTENT_VAR}", &os.env), + "${NON_EXISTENT_VAR}" + ); - fn create_fake_prompts(name: &str) -> serde_json::Value { - serde_json::json!({ - "name": name, - "description": "Asks the LLM to analyze code quality and suggest improvements", - "arguments": [ - { - "name": "code", - "description": "The code to review", - "required": true - } - ] - }) + // Test mixed content + assert_eq!( + substitute_env_vars("Prefix ${env:TEST_VAR} suffix", &os.env), + "Prefix test_value suffix" + ); } - #[cfg(windows)] - mod windows_command_tests { - use super::*; - use crate::mcp_client::transport::stdio::JsonRpcStdioTransport as StdioTransport; - - #[test] - fn test_quote_windows_arg_no_special_chars() { - let result = Client::::quote_windows_arg("simple"); - assert_eq!(result, "simple"); + #[tokio::test] + async fn test_process_env_vars() { + let os = Os::new().await.unwrap(); + unsafe { + os.env.set_var("TEST_VAR", "test_value"); } - #[test] - fn test_quote_windows_arg_with_spaces() { - let result = Client::::quote_windows_arg("with spaces"); - assert_eq!(result, "\"with spaces\""); - } + let mut env_vars = HashMap::new(); + env_vars.insert("KEY1".to_string(), "Value is ${env:TEST_VAR}".to_string()); + env_vars.insert("KEY2".to_string(), "No substitution".to_string()); - #[test] - fn test_quote_windows_arg_with_quotes() { - let result = Client::::quote_windows_arg("with \"quotes\""); - assert_eq!(result, "\"with \\\"quotes\\\"\""); - } - - #[test] - fn test_quote_windows_arg_with_backslashes() { - let result = Client::::quote_windows_arg("path\\to\\file"); - assert_eq!(result, "path\\to\\file"); - } + process_env_vars(&mut env_vars, &os.env); - #[test] - fn test_quote_windows_arg_with_trailing_backslashes() { - let result = Client::::quote_windows_arg("path\\to\\dir\\"); - assert_eq!(result, "path\\to\\dir\\"); - } - - #[test] - fn test_quote_windows_arg_with_backslashes_before_quote() { - let result = Client::::quote_windows_arg("path\\\\\"quoted\""); - assert_eq!(result, "\"path\\\\\\\\\\\"quoted\\\"\""); - } - - #[test] - fn test_quote_windows_arg_complex_case() { - let result = Client::::quote_windows_arg("C:\\Program Files\\My App\\bin\\app.exe"); - assert_eq!(result, "\"C:\\Program Files\\My App\\bin\\app.exe\""); - } - - #[test] - fn test_quote_windows_arg_with_tabs_and_newlines() { - let result = Client::::quote_windows_arg("with\ttabs\nand\rnewlines"); - assert_eq!(result, "\"with\ttabs\nand\rnewlines\""); - } - - #[test] - fn test_quote_windows_arg_edge_case_only_backslashes() { - let result = Client::::quote_windows_arg("\\\\\\"); - assert_eq!(result, "\\\\\\"); - } - - #[test] - fn test_quote_windows_arg_edge_case_only_quotes() { - let result = Client::::quote_windows_arg("\"\"\""); - assert_eq!(result, "\"\\\"\\\"\\\"\""); - } - - // Tests for build_windows_command function - #[test] - fn test_build_windows_command_empty_args() { - let bin_path = "myapp"; - let args = vec![]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!(result, "myapp"); - } - - #[test] - fn test_build_windows_command_uvx_example() { - let bin_path = "uvx"; - let args = vec!["mcp-server-fetch".to_string()]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!(result, "uvx mcp-server-fetch"); - } - - #[test] - fn test_build_windows_command_npx_example() { - let bin_path = "npx"; - let args = vec!["-y".to_string(), "@modelcontextprotocol/server-memory".to_string()]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!(result, "npx -y @modelcontextprotocol/server-memory"); - } - - #[test] - fn test_build_windows_command_docker_example() { - let bin_path = "docker"; - let args = vec![ - "run".to_string(), - "-i".to_string(), - "--rm".to_string(), - "-e".to_string(), - "GITHUB_PERSONAL_ACCESS_TOKEN".to_string(), - "ghcr.io/github/github-mcp-server".to_string(), - ]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!( - result, - "docker run -i --rm -e GITHUB_PERSONAL_ACCESS_TOKEN ghcr.io/github/github-mcp-server" - ); - } - - #[test] - fn test_build_windows_command_with_quotes_in_args() { - let bin_path = "myapp"; - let args = vec!["--config".to_string(), "{\"key\": \"value\"}".to_string()]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!(result, "myapp --config \"{\\\"key\\\": \\\"value\\\"}\""); - } - - #[test] - fn test_build_windows_command_with_spaces_in_path() { - let bin_path = "C:\\Program Files\\My App\\bin\\app.exe"; - let args = vec!["--input".to_string(), "file with spaces.txt".to_string()]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!( - result, - "\"C:\\Program Files\\My App\\bin\\app.exe\" --input \"file with spaces.txt\"" - ); - } - - #[test] - fn test_build_windows_command_complex_args() { - let bin_path = "myapp"; - let args = vec![ - "--config".to_string(), - "C:\\Users\\test\\config.json".to_string(), - "--output".to_string(), - "C:\\Output\\result file.txt".to_string(), - "--verbose".to_string(), - ]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!( - result, - "myapp --config C:\\Users\\test\\config.json --output \"C:\\Output\\result file.txt\" --verbose" - ); - } - - #[test] - fn test_build_windows_command_with_environment_variables() { - let bin_path = "cmd"; - let args = vec!["/c".to_string(), "echo %PATH%".to_string()]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!(result, "cmd /c \"echo %PATH%\""); - } - - #[test] - fn test_build_windows_command_real_world_python() { - let bin_path = "python"; - let args = vec![ - "-m".to_string(), - "mcp_server".to_string(), - "--config".to_string(), - "C:\\configs\\server.json".to_string(), - ]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!(result, "python -m mcp_server --config C:\\configs\\server.json"); - } + assert_eq!(env_vars.get("KEY1").unwrap(), "Value is test_value"); + assert_eq!(env_vars.get("KEY2").unwrap(), "No substitution"); } } diff --git a/crates/chat-cli/src/mcp_client/error.rs b/crates/chat-cli/src/mcp_client/error.rs deleted file mode 100644 index 01f77cfa8b..0000000000 --- a/crates/chat-cli/src/mcp_client/error.rs +++ /dev/null @@ -1,66 +0,0 @@ -/// Error codes as defined in the MCP protocol. -/// -/// These error codes are based on the JSON-RPC 2.0 specification with additional -/// MCP-specific error codes in the -32000 to -32099 range. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[repr(i32)] -pub enum ErrorCode { - /// Invalid JSON was received by the server. - /// An error occurred on the server while parsing the JSON text. - ParseError = -32700, - - /// The JSON sent is not a valid Request object. - InvalidRequest = -32600, - - /// The method does not exist / is not available. - MethodNotFound = -32601, - - /// Invalid method parameter(s). - InvalidParams = -32602, - - /// Internal JSON-RPC error. - InternalError = -32603, - - /// Server has not been initialized. - /// This error is returned when a request is made before the server - /// has been properly initialized. - ServerNotInitialized = -32002, - - /// Unknown error code. - /// This error is returned when an error code is received that is not - /// recognized by the implementation. - Unknown = -32001, - - /// Request failed. - /// This error is returned when a request fails for a reason not covered - /// by other error codes. - RequestFailed = -32000, -} - -impl From for ErrorCode { - fn from(code: i32) -> Self { - match code { - -32700 => ErrorCode::ParseError, - -32600 => ErrorCode::InvalidRequest, - -32601 => ErrorCode::MethodNotFound, - -32602 => ErrorCode::InvalidParams, - -32603 => ErrorCode::InternalError, - -32002 => ErrorCode::ServerNotInitialized, - -32001 => ErrorCode::Unknown, - -32000 => ErrorCode::RequestFailed, - _ => ErrorCode::Unknown, - } - } -} - -impl From for i32 { - fn from(code: ErrorCode) -> Self { - code as i32 - } -} - -impl std::fmt::Display for ErrorCode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) - } -} diff --git a/crates/chat-cli/src/mcp_client/facilitator_types.rs b/crates/chat-cli/src/mcp_client/facilitator_types.rs deleted file mode 100644 index 87fbd79b27..0000000000 --- a/crates/chat-cli/src/mcp_client/facilitator_types.rs +++ /dev/null @@ -1,248 +0,0 @@ -use serde::{ - Deserialize, - Serialize, -}; -use thiserror::Error; - -/// https://spec.modelcontextprotocol.io/specification/2024-11-05/server/utilities/pagination/#operations-supporting-pagination -#[allow(clippy::enum_variant_names)] -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum PaginationSupportedOps { - ResourcesList, - ResourceTemplatesList, - PromptsList, - ToolsList, -} - -impl PaginationSupportedOps { - pub fn as_key(&self) -> &str { - match self { - PaginationSupportedOps::ResourcesList => "resources", - PaginationSupportedOps::ResourceTemplatesList => "resourceTemplates", - PaginationSupportedOps::PromptsList => "prompts", - PaginationSupportedOps::ToolsList => "tools", - } - } -} - -impl TryFrom<&str> for PaginationSupportedOps { - type Error = OpsConversionError; - - fn try_from(value: &str) -> Result { - match value { - "resources/list" => Ok(PaginationSupportedOps::ResourcesList), - "resources/templates/list" => Ok(PaginationSupportedOps::ResourceTemplatesList), - "prompts/list" => Ok(PaginationSupportedOps::PromptsList), - "tools/list" => Ok(PaginationSupportedOps::ToolsList), - _ => Err(OpsConversionError::InvalidMethod), - } - } -} - -#[derive(Error, Debug)] -pub enum OpsConversionError { - #[error("Invalid method encountered")] - InvalidMethod, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] -#[serde(rename_all = "camelCase")] -/// Role assumed for a particular message -pub enum Role { - User, - Assistant, -} - -impl std::fmt::Display for Role { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Role::User => write!(f, "user"), - Role::Assistant => write!(f, "assistant"), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Result of listing resources operation -pub struct ResourcesListResult { - /// List of resources - pub resources: Vec, - /// Optional cursor for pagination - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -/// Result of listing resource templates operation -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ResourceTemplatesListResult { - /// List of resource templates - pub resource_templates: Vec, - /// Optional cursor for pagination - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Result of prompt listing query -pub struct PromptsListResult { - /// List of prompts - pub prompts: Vec, - /// Optional cursor for pagination - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Represents an argument to be supplied to a [PromptGet] -pub struct PromptGetArg { - /// The name identifier of the prompt - pub name: String, - /// Optional description providing context about the prompt - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - /// Indicates whether a response to this prompt is required - /// If not specified, defaults to false - #[serde(skip_serializing_if = "Option::is_none")] - pub required: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Represents a request to get a prompt from a mcp server -pub struct PromptGet { - /// Unique identifier for the prompt - pub name: String, - /// Optional description providing context about the prompt's purpose - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - /// Optional list of arguments that define the structure of information to be collected - #[serde(skip_serializing_if = "Option::is_none")] - pub arguments: Option>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// `result` field in [JsonRpcResponse] from a `prompts/get` request -pub struct PromptGetResult { - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - pub messages: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Completed prompt from `prompts/get` to be returned by a mcp server -pub struct Prompt { - pub role: Role, - pub content: MessageContent, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Result of listing tools operation -pub struct ToolsListResult { - /// List of tools - pub tools: Vec, - /// Optional cursor for pagination - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ToolCallResult { - pub content: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub is_error: Option, -} - -/// Content of a message -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "camelCase")] -pub enum MessageContent { - /// Text content - Text { - /// The text content - text: String, - }, - /// Image content - #[serde(rename_all = "camelCase")] - Image { - /// base64-encoded-data - data: String, - mime_type: String, - }, - /// Resource content - Resource { - /// The resource - resource: Resource, - }, -} - -impl From for String { - fn from(val: MessageContent) -> Self { - match val { - MessageContent::Text { text } => text, - MessageContent::Image { data, mime_type } => serde_json::json!({ - "data": data, - "mime_type": mime_type - }) - .to_string(), - MessageContent::Resource { resource } => serde_json::json!(resource).to_string(), - } - } -} - -impl std::fmt::Display for MessageContent { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - MessageContent::Text { text } => write!(f, "{}", text), - MessageContent::Image { data: _, mime_type } => write!(f, "Image [base64-encoded-string] ({})", mime_type), - MessageContent::Resource { resource } => write!(f, "Resource: {} ({})", resource.title, resource.uri), - } - } -} - -/// Resource contents -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "camelCase")] -pub enum ResourceContents { - Text { text: String }, - Blob { data: Vec }, -} - -/// A resource in the system -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Resource { - /// Unique identifier for the resource - pub uri: String, - /// Human-readable title - pub title: String, - /// Optional description - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - /// Resource contents - pub contents: ResourceContents, -} - -/// Represents the capabilities supported by a Model Context Protocol server -/// This is the "capabilities" field in the result of a response for init -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ServerCapabilities { - /// Configuration for server logging capabilities - #[serde(skip_serializing_if = "Option::is_none")] - pub logging: Option, - /// Configuration for prompt-related capabilities - #[serde(skip_serializing_if = "Option::is_none")] - pub prompts: Option, - /// Configuration for resource management capabilities - #[serde(skip_serializing_if = "Option::is_none")] - pub resources: Option, - /// Configuration for tool integration capabilities - #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option, -} diff --git a/crates/chat-cli/src/mcp_client/messenger.rs b/crates/chat-cli/src/mcp_client/messenger.rs index 75723cd9c7..e9202b7dae 100644 --- a/crates/chat-cli/src/mcp_client/messenger.rs +++ b/crates/chat-cli/src/mcp_client/messenger.rs @@ -1,11 +1,18 @@ +use rmcp::model::{ + ListPromptsResult, + ListResourceTemplatesResult, + ListResourcesResult, + ListToolsResult, +}; +use rmcp::{ + Peer, + RoleClient, + ServiceError, +}; use thiserror::Error; -use super::{ - PromptsListResult, - ResourceTemplatesListResult, - ResourcesListResult, - ToolsListResult, -}; +pub type Result = core::result::Result; +pub type MessengerResult = core::result::Result<(), MessengerError>; /// An interface that abstracts the implementation for information delivery from client and its /// consumer. It is through this interface secondary information (i.e. information that are needed @@ -16,26 +23,38 @@ use super::{ pub trait Messenger: std::fmt::Debug + Send + Sync + 'static { /// Sends the result of a tools list operation to the consumer /// This function is used to deliver information about available tools - async fn send_tools_list_result(&self, result: eyre::Result) -> Result<(), MessengerError>; + async fn send_tools_list_result( + &self, + result: Result, + peer: Option>, + ) -> MessengerResult; /// Sends the result of a prompts list operation to the consumer /// This function is used to deliver information about available prompts - async fn send_prompts_list_result(&self, result: eyre::Result) -> Result<(), MessengerError>; + async fn send_prompts_list_result( + &self, + result: Result, + peer: Option>, + ) -> MessengerResult; /// Sends the result of a resources list operation to the consumer /// This function is used to deliver information about available resources - async fn send_resources_list_result(&self, result: eyre::Result) - -> Result<(), MessengerError>; + async fn send_resources_list_result( + &self, + result: Result, + peer: Option>, + ) -> MessengerResult; /// Sends the result of a resource templates list operation to the consumer /// This function is used to deliver information about available resource templates async fn send_resource_templates_list_result( &self, - result: eyre::Result, - ) -> Result<(), MessengerError>; + result: Result, + peer: Option>, + ) -> MessengerResult; /// Signals to the orchestrator that a server has started initializing - async fn send_init_msg(&self) -> Result<(), MessengerError>; + async fn send_init_msg(&self) -> MessengerResult; /// Signals to the orchestrator that a server has deinitialized fn send_deinit_msg(&self); @@ -56,29 +75,39 @@ pub struct NullMessenger; #[async_trait::async_trait] impl Messenger for NullMessenger { - async fn send_tools_list_result(&self, _result: eyre::Result) -> Result<(), MessengerError> { + async fn send_tools_list_result( + &self, + _result: Result, + _peer: Option>, + ) -> MessengerResult { Ok(()) } - async fn send_prompts_list_result(&self, _result: eyre::Result) -> Result<(), MessengerError> { + async fn send_prompts_list_result( + &self, + _result: Result, + _peer: Option>, + ) -> MessengerResult { Ok(()) } async fn send_resources_list_result( &self, - _result: eyre::Result, - ) -> Result<(), MessengerError> { + _result: Result, + _peer: Option>, + ) -> MessengerResult { Ok(()) } async fn send_resource_templates_list_result( &self, - _result: eyre::Result, - ) -> Result<(), MessengerError> { + _result: Result, + _peer: Option>, + ) -> MessengerResult { Ok(()) } - async fn send_init_msg(&self) -> Result<(), MessengerError> { + async fn send_init_msg(&self) -> MessengerResult { Ok(()) } diff --git a/crates/chat-cli/src/mcp_client/mod.rs b/crates/chat-cli/src/mcp_client/mod.rs index 51f8b178fd..7bc6d76f5a 100644 --- a/crates/chat-cli/src/mcp_client/mod.rs +++ b/crates/chat-cli/src/mcp_client/mod.rs @@ -1,13 +1,4 @@ pub mod client; -pub mod error; -pub mod facilitator_types; pub mod messenger; -pub mod server; -pub mod transport; pub use client::*; -pub use facilitator_types::*; -pub use messenger::*; -#[allow(unused_imports)] -pub use server::*; -pub use transport::*; diff --git a/crates/chat-cli/src/mcp_client/server.rs b/crates/chat-cli/src/mcp_client/server.rs deleted file mode 100644 index 7b320a2c6e..0000000000 --- a/crates/chat-cli/src/mcp_client/server.rs +++ /dev/null @@ -1,311 +0,0 @@ -#![allow(dead_code)] -use std::collections::HashMap; -use std::sync::atomic::{ - AtomicBool, - AtomicU64, - Ordering, -}; -use std::sync::{ - Arc, - Mutex, -}; - -use tokio::io::{ - Stdin, - Stdout, -}; -use tokio::task::JoinHandle; - -use super::Listener as _; -use super::client::StdioTransport; -use super::error::ErrorCode; -use super::transport::base_protocol::{ - JsonRpcError, - JsonRpcMessage, - JsonRpcNotification, - JsonRpcRequest, - JsonRpcResponse, -}; -use super::transport::stdio::JsonRpcStdioTransport; -use super::transport::{ - JsonRpcVersion, - Transport, - TransportError, -}; - -pub type Request = serde_json::Value; -pub type Response = Option; -pub type InitializedServer = JoinHandle>; - -pub trait PreServerRequestHandler { - fn register_pending_request_callback(&mut self, cb: impl Fn(u64) -> Option + Send + Sync + 'static); - fn register_send_request_callback( - &mut self, - cb: impl Fn(&str, Option) -> Result<(), ServerError> + Send + Sync + 'static, - ); -} - -#[async_trait::async_trait] -pub trait ServerRequestHandler: PreServerRequestHandler + Send + Sync + 'static { - async fn handle_initialize(&self, params: Option) -> Result; - async fn handle_incoming(&self, method: &str, params: Option) -> Result; - async fn handle_response(&self, resp: JsonRpcResponse) -> Result<(), ServerError>; - async fn handle_shutdown(&self) -> Result<(), ServerError>; -} - -pub struct Server { - transport: Option>, - handler: Option, - #[allow(dead_code)] - pending_requests: Arc>>, - #[allow(dead_code)] - current_id: Arc, -} - -#[derive(Debug, thiserror::Error)] -pub enum ServerError { - #[error(transparent)] - TransportError(#[from] TransportError), - #[error(transparent)] - Io(#[from] std::io::Error), - #[error(transparent)] - Serialization(#[from] serde_json::Error), - #[error("Unexpected msg type encountered")] - UnexpectedMsgType, - #[error("{0}")] - NegotiationError(String), - #[error(transparent)] - TokioJoinError(#[from] tokio::task::JoinError), - #[error("Failed to obtain mutex lock")] - MutexError, - #[error("Failed to obtain request method")] - MissingMethod, - #[error("Failed to obtain request id")] - MissingId, - #[error("Failed to initialize server. Missing transport")] - MissingTransport, - #[error("Failed to initialize server. Missing handler")] - MissingHandler, -} - -impl Server -where - H: ServerRequestHandler, -{ - pub fn new(mut handler: H, stdin: Stdin, stdout: Stdout) -> Result { - let transport = Arc::new(JsonRpcStdioTransport::server(stdin, stdout)?); - let pending_requests = Arc::new(Mutex::new(HashMap::::new())); - let pending_requests_clone_one = pending_requests.clone(); - let current_id = Arc::new(AtomicU64::new(0)); - let pending_request_getter = move |id: u64| -> Option { - match pending_requests_clone_one.lock() { - Ok(mut p) => p.remove(&id), - Err(_) => None, - } - }; - handler.register_pending_request_callback(pending_request_getter); - let transport_clone = transport.clone(); - let pending_request_clone_two = pending_requests.clone(); - let current_id_clone = current_id.clone(); - let request_sender = move |method: &str, params: Option| -> Result<(), ServerError> { - let id = current_id_clone.fetch_add(1, Ordering::SeqCst); - let msg = match method.split_once("/") { - Some(("request", _)) => { - let request = JsonRpcRequest { - jsonrpc: JsonRpcVersion::default(), - id, - method: method.to_owned(), - params, - }; - let msg = JsonRpcMessage::Request(request.clone()); - #[allow(clippy::map_err_ignore)] - let mut pending_request = pending_request_clone_two.lock().map_err(|_| ServerError::MutexError)?; - pending_request.insert(id, request); - Some(msg) - }, - Some(("notifications", _)) => { - let notif = JsonRpcNotification { - jsonrpc: JsonRpcVersion::default(), - method: method.to_owned(), - params, - }; - let msg = JsonRpcMessage::Notification(notif); - Some(msg) - }, - _ => None, - }; - if let Some(msg) = msg { - let transport = transport_clone.clone(); - tokio::task::spawn(async move { - let _ = transport.send(&msg).await; - }); - } - Ok(()) - }; - handler.register_send_request_callback(request_sender); - let server = Self { - transport: Some(transport), - handler: Some(handler), - pending_requests, - current_id, - }; - Ok(server) - } -} - -impl Server -where - T: Transport, - H: ServerRequestHandler, -{ - pub fn init(mut self) -> Result { - let transport = self.transport.take().ok_or(ServerError::MissingTransport)?; - let handler = Arc::new(self.handler.take().ok_or(ServerError::MissingHandler)?); - let has_initialized = Arc::new(AtomicBool::new(false)); - let listener = tokio::spawn(async move { - let mut listener = transport.get_listener(); - loop { - let request = listener.recv().await; - let transport_clone = transport.clone(); - let has_init_clone = has_initialized.clone(); - let handler_clone = handler.clone(); - tokio::task::spawn(async move { - process_request(has_init_clone, transport_clone, handler_clone, request).await; - }); - } - }); - Ok(listener) - } -} - -async fn process_request( - has_initialized: Arc, - transport: Arc, - handler: Arc, - request: Result, -) where - T: Transport, - H: ServerRequestHandler, -{ - match request { - Ok(msg) if msg.is_initialize() => { - let id = msg.id().unwrap_or_default(); - if has_initialized.load(Ordering::SeqCst) { - let resp = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion::default(), - id, - error: Some(JsonRpcError { - code: ErrorCode::InvalidRequest.into(), - message: "Server has already been initialized".to_owned(), - data: None, - }), - ..Default::default() - }); - let _ = transport.send(&resp).await; - return; - } - let JsonRpcMessage::Request(req) = msg else { - let resp = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion::default(), - id, - error: Some(JsonRpcError { - code: ErrorCode::InvalidRequest.into(), - message: "Invalid method for initialization (use request)".to_owned(), - data: None, - }), - ..Default::default() - }); - let _ = transport.send(&resp).await; - return; - }; - let JsonRpcRequest { params, .. } = req; - match handler.handle_initialize(params).await { - Ok(result) => { - let resp = JsonRpcMessage::Response(JsonRpcResponse { - id, - result, - ..Default::default() - }); - let _ = transport.send(&resp).await; - has_initialized.store(true, Ordering::SeqCst); - }, - Err(_e) => { - let resp = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion::default(), - id, - error: Some(JsonRpcError { - code: ErrorCode::InternalError.into(), - message: "Error producing initialization response".to_owned(), - data: None, - }), - ..Default::default() - }); - let _ = transport.send(&resp).await; - }, - } - }, - Ok(msg) if msg.is_shutdown() => { - // TODO: add shutdown routine - }, - Ok(msg) if has_initialized.load(Ordering::SeqCst) => match msg { - JsonRpcMessage::Request(req) => { - let JsonRpcRequest { - id, - jsonrpc, - params, - ref method, - } = req; - let resp = handler.handle_incoming(method, params).await.map_or_else( - |error| { - let err = JsonRpcError { - code: ErrorCode::InternalError.into(), - message: error.to_string(), - data: None, - }; - let resp = JsonRpcResponse { - jsonrpc: jsonrpc.clone(), - id, - result: None, - error: Some(err), - }; - JsonRpcMessage::Response(resp) - }, - |result| { - let resp = JsonRpcResponse { - jsonrpc: jsonrpc.clone(), - id, - result, - error: None, - }; - JsonRpcMessage::Response(resp) - }, - ); - let _ = transport.send(&resp).await; - }, - JsonRpcMessage::Notification(notif) => { - let JsonRpcNotification { ref method, params, .. } = notif; - let _ = handler.handle_incoming(method, params).await; - }, - JsonRpcMessage::Response(resp) => { - let _ = handler.handle_response(resp).await; - }, - }, - Ok(msg) => { - let id = msg.id().unwrap_or_default(); - let resp = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion::default(), - id, - error: Some(JsonRpcError { - code: ErrorCode::ServerNotInitialized.into(), - message: "Server has not been initialized".to_owned(), - data: None, - }), - ..Default::default() - }); - let _ = transport.send(&resp).await; - }, - Err(_e) => { - // TODO: error handling - }, - } -} diff --git a/crates/chat-cli/src/mcp_client/transport/base_protocol.rs b/crates/chat-cli/src/mcp_client/transport/base_protocol.rs deleted file mode 100644 index b0394e6e0c..0000000000 --- a/crates/chat-cli/src/mcp_client/transport/base_protocol.rs +++ /dev/null @@ -1,108 +0,0 @@ -//! Referencing https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/messages/ -//! Protocol Revision 2024-11-05 -use serde::{ - Deserialize, - Serialize, -}; - -pub type RequestId = u64; - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct JsonRpcVersion(String); - -impl Default for JsonRpcVersion { - fn default() -> Self { - JsonRpcVersion("2.0".to_owned()) - } -} - -impl JsonRpcVersion { - pub fn as_u32_vec(&self) -> Vec { - self.0 - .split(".") - .map(|n| n.parse::().unwrap()) - .collect::>() - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -#[serde(untagged)] -#[serde(deny_unknown_fields)] -// DO NOT change the order of these variants. This body of json is [untagged](https://serde.rs/enum-representations.html#untagged) -// The categorization of the deserialization depends on the order in which the variants are -// declared. -pub enum JsonRpcMessage { - Response(JsonRpcResponse), - Notification(JsonRpcNotification), - Request(JsonRpcRequest), -} - -impl JsonRpcMessage { - pub fn is_initialize(&self) -> bool { - match self { - JsonRpcMessage::Request(req) => req.method == "initialize", - _ => false, - } - } - - pub fn is_shutdown(&self) -> bool { - match self { - JsonRpcMessage::Notification(notif) => notif.method == "notification/shutdown", - _ => false, - } - } - - pub fn id(&self) -> Option { - match self { - JsonRpcMessage::Request(req) => Some(req.id), - JsonRpcMessage::Response(resp) => Some(resp.id), - JsonRpcMessage::Notification(_) => None, - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -#[serde(default, deny_unknown_fields)] -pub struct JsonRpcRequest { - pub jsonrpc: JsonRpcVersion, - pub id: RequestId, - pub method: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub params: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -#[serde(default, deny_unknown_fields)] -pub struct JsonRpcResponse { - pub jsonrpc: JsonRpcVersion, - pub id: RequestId, - #[serde(skip_serializing_if = "Option::is_none")] - pub result: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -#[serde(default, deny_unknown_fields)] -pub struct JsonRpcNotification { - pub jsonrpc: JsonRpcVersion, - pub method: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub params: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -#[serde(default, deny_unknown_fields)] -pub struct JsonRpcError { - pub code: i32, - pub message: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub data: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -pub enum TransportType { - #[default] - Stdio, - Websocket, -} diff --git a/crates/chat-cli/src/mcp_client/transport/mod.rs b/crates/chat-cli/src/mcp_client/transport/mod.rs deleted file mode 100644 index f752b1675a..0000000000 --- a/crates/chat-cli/src/mcp_client/transport/mod.rs +++ /dev/null @@ -1,57 +0,0 @@ -pub mod base_protocol; -pub mod stdio; - -use std::fmt::Debug; - -pub use base_protocol::*; -pub use stdio::*; -use thiserror::Error; - -#[derive(Clone, Debug, Error)] -pub enum TransportError { - #[error("Serialization error: {0}")] - Serialization(String), - #[error("IO error: {0}")] - Stdio(String), - #[error("{0}")] - Custom(String), - #[error(transparent)] - RecvError(#[from] tokio::sync::broadcast::error::RecvError), -} - -impl From for TransportError { - fn from(err: serde_json::Error) -> Self { - TransportError::Serialization(err.to_string()) - } -} - -impl From for TransportError { - fn from(err: std::io::Error) -> Self { - TransportError::Stdio(err.to_string()) - } -} - -#[allow(dead_code)] -#[async_trait::async_trait] -pub trait Transport: Send + Sync + Debug + 'static { - /// Sends a message over the transport layer. - async fn send(&self, msg: &JsonRpcMessage) -> Result<(), TransportError>; - /// Listens to awaits for a response. This is a call that should be used after `send` is called - /// to listen for a response from the message recipient. - fn get_listener(&self) -> impl Listener; - /// Gracefully terminates the transport connection, cleaning up any resources. - /// This should be called when the transport is no longer needed to ensure proper cleanup. - async fn shutdown(&self) -> Result<(), TransportError>; - /// Listener that listens for logging messages. - fn get_log_listener(&self) -> impl LogListener; -} - -#[async_trait::async_trait] -pub trait Listener: Send + Sync + 'static { - async fn recv(&mut self) -> Result; -} - -#[async_trait::async_trait] -pub trait LogListener: Send + Sync + 'static { - async fn recv(&mut self) -> Result; -} diff --git a/crates/chat-cli/src/mcp_client/transport/stdio.rs b/crates/chat-cli/src/mcp_client/transport/stdio.rs deleted file mode 100644 index 89266a183d..0000000000 --- a/crates/chat-cli/src/mcp_client/transport/stdio.rs +++ /dev/null @@ -1,285 +0,0 @@ -use std::sync::Arc; - -use tokio::io::{ - AsyncBufReadExt, - AsyncRead, - AsyncWriteExt as _, - BufReader, - Stdin, - Stdout, -}; -use tokio::process::{ - Child, - ChildStdin, -}; -use tokio::sync::{ - Mutex, - broadcast, -}; - -use super::base_protocol::JsonRpcMessage; -use super::{ - Listener, - LogListener, - Transport, - TransportError, -}; - -#[derive(Debug)] -pub enum JsonRpcStdioTransport { - Client { - stdin: Arc>, - receiver: broadcast::Receiver>, - log_receiver: broadcast::Receiver, - }, - Server { - stdout: Arc>, - receiver: broadcast::Receiver>, - }, -} - -impl JsonRpcStdioTransport { - fn spawn_reader( - reader: R, - tx: broadcast::Sender>, - ) { - tokio::spawn(async move { - let mut buffer = Vec::::new(); - let mut buf_reader = BufReader::new(reader); - loop { - buffer.clear(); - // Messages are delimited by newlines and assumed to contain no embedded newlines - // See https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio - match buf_reader.read_until(b'\n', &mut buffer).await { - Ok(0) => break, - Ok(_) => match serde_json::from_slice::(buffer.as_slice()) { - Ok(msg) => { - let _ = tx.send(Ok(msg)); - }, - Err(e) => { - let _ = tx.send(Err(e.into())); - }, - }, - Err(e) => { - let _ = tx.send(Err(e.into())); - }, - } - } - }); - } - - pub fn client(child_process: Child) -> Result { - let (tx, receiver) = broadcast::channel::>(100); - let Some(stdout) = child_process.stdout else { - return Err(TransportError::Custom("No stdout found on child process".to_owned())); - }; - let Some(stdin) = child_process.stdin else { - return Err(TransportError::Custom("No stdin found on child process".to_owned())); - }; - let Some(stderr) = child_process.stderr else { - return Err(TransportError::Custom("No stderr found on child process".to_owned())); - }; - let (log_tx, log_receiver) = broadcast::channel::(100); - tokio::task::spawn(async move { - let stderr = tokio::io::BufReader::new(stderr); - let mut lines = stderr.lines(); - while let Ok(Some(line)) = lines.next_line().await { - let _ = log_tx.send(line); - } - }); - let stdin = Arc::new(Mutex::new(stdin)); - Self::spawn_reader(stdout, tx); - Ok(JsonRpcStdioTransport::Client { - stdin, - receiver, - log_receiver, - }) - } - - pub fn server(stdin: Stdin, stdout: Stdout) -> Result { - let (tx, receiver) = broadcast::channel::>(100); - Self::spawn_reader(stdin, tx); - let stdout = Arc::new(Mutex::new(stdout)); - Ok(JsonRpcStdioTransport::Server { stdout, receiver }) - } -} - -#[async_trait::async_trait] -impl Transport for JsonRpcStdioTransport { - async fn send(&self, msg: &JsonRpcMessage) -> Result<(), TransportError> { - match self { - JsonRpcStdioTransport::Client { stdin, .. } => { - let mut serialized = serde_json::to_vec(msg)?; - serialized.push(b'\n'); - let mut stdin = stdin.lock().await; - stdin - .write_all(&serialized) - .await - .map_err(|e| TransportError::Custom(format!("Error writing to server: {:?}", e)))?; - stdin - .flush() - .await - .map_err(|e| TransportError::Custom(format!("Error writing to server: {:?}", e)))?; - Ok(()) - }, - JsonRpcStdioTransport::Server { stdout, .. } => { - let mut serialized = serde_json::to_vec(msg)?; - serialized.push(b'\n'); - let mut stdout = stdout.lock().await; - stdout - .write_all(&serialized) - .await - .map_err(|e| TransportError::Custom(format!("Error writing to client: {:?}", e)))?; - stdout - .flush() - .await - .map_err(|e| TransportError::Custom(format!("Error writing to client: {:?}", e)))?; - Ok(()) - }, - } - } - - fn get_listener(&self) -> impl Listener { - match self { - JsonRpcStdioTransport::Client { receiver, .. } | JsonRpcStdioTransport::Server { receiver, .. } => { - StdioListener { - receiver: receiver.resubscribe(), - } - }, - } - } - - async fn shutdown(&self) -> Result<(), TransportError> { - match self { - JsonRpcStdioTransport::Client { stdin, .. } => { - let mut stdin = stdin.lock().await; - Ok(stdin.shutdown().await?) - }, - JsonRpcStdioTransport::Server { stdout, .. } => { - let mut stdout = stdout.lock().await; - Ok(stdout.shutdown().await?) - }, - } - } - - fn get_log_listener(&self) -> impl LogListener { - match self { - JsonRpcStdioTransport::Client { log_receiver, .. } => StdioLogListener { - receiver: log_receiver.resubscribe(), - }, - JsonRpcStdioTransport::Server { .. } => unreachable!("server does not need a log listener"), - } - } -} - -pub struct StdioListener { - pub receiver: broadcast::Receiver>, -} - -#[async_trait::async_trait] -impl Listener for StdioListener { - async fn recv(&mut self) -> Result { - self.receiver.recv().await? - } -} - -pub struct StdioLogListener { - pub receiver: broadcast::Receiver, -} - -#[async_trait::async_trait] -impl LogListener for StdioLogListener { - async fn recv(&mut self) -> Result { - Ok(self.receiver.recv().await?) - } -} - -#[cfg(test)] -mod tests { - use std::process::Stdio; - - use serde_json::{ - Value, - json, - }; - use tokio::process::Command; - - use super::{ - JsonRpcMessage, - JsonRpcStdioTransport, - Listener, - Transport, - }; - - // Helpers for testing - fn create_test_message() -> JsonRpcMessage { - serde_json::from_value(json!({ - "jsonrpc": "2.0", - "id": 1, - "method": "test_method", - "params": { - "test_param": "test_value" - } - })) - .unwrap() - } - - #[tokio::test] - async fn test_client_transport() { - #[cfg(windows)] - let mut cmd = { - let mut cmd = Command::new("powershell"); - cmd.args(&["cat"]); - cmd - }; - #[cfg(not(windows))] - let mut cmd = Command::new("cat"); - - cmd.stdin(Stdio::piped()).stdout(Stdio::piped()).stderr(Stdio::piped()); - - // Inject our mock transport instead - let child = cmd.spawn().expect("Failed to spawn command"); - let transport = JsonRpcStdioTransport::client(child).expect("Failed to create client transport"); - - let message = create_test_message(); - let result = transport.send(&message).await; - assert!(result.is_ok(), "Failed to send message: {:?}", result); - - let echo = transport - .get_listener() - .recv() - .await - .expect("Failed to receive message"); - let echo_value = serde_json::to_value(&echo).expect("Failed to convert echo to value"); - let message_value = serde_json::to_value(&message).expect("Failed to convert message to value"); - assert!(are_json_values_equal(&echo_value, &message_value)); - } - - fn are_json_values_equal(a: &Value, b: &Value) -> bool { - match (a, b) { - (Value::Null, Value::Null) => true, - (Value::Bool(a_val), Value::Bool(b_val)) => a_val == b_val, - (Value::Number(a_val), Value::Number(b_val)) => a_val == b_val, - (Value::String(a_val), Value::String(b_val)) => a_val == b_val, - (Value::Array(a_arr), Value::Array(b_arr)) => { - if a_arr.len() != b_arr.len() { - return false; - } - a_arr - .iter() - .zip(b_arr.iter()) - .all(|(a_item, b_item)| are_json_values_equal(a_item, b_item)) - }, - (Value::Object(a_obj), Value::Object(b_obj)) => { - if a_obj.len() != b_obj.len() { - return false; - } - a_obj.iter().all(|(key, a_value)| match b_obj.get(key) { - Some(b_value) => are_json_values_equal(a_value, b_value), - None => false, - }) - }, - _ => false, - } - } -} diff --git a/crates/chat-cli/src/mcp_client/transport/websocket.rs b/crates/chat-cli/src/mcp_client/transport/websocket.rs deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/crates/chat-cli/src/util/mod.rs b/crates/chat-cli/src/util/mod.rs index ad5ef15898..ac48310ad5 100644 --- a/crates/chat-cli/src/util/mod.rs +++ b/crates/chat-cli/src/util/mod.rs @@ -3,7 +3,6 @@ pub mod directories; pub mod knowledge_store; pub mod open; pub mod pattern_matching; -pub mod process; pub mod spinner; pub mod system_info; #[cfg(test)] diff --git a/crates/chat-cli/src/util/process/mod.rs b/crates/chat-cli/src/util/process/mod.rs deleted file mode 100644 index e0a8414592..0000000000 --- a/crates/chat-cli/src/util/process/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -pub use sysinfo::Pid; - -#[cfg(target_os = "windows")] -mod windows; -#[cfg(target_os = "windows")] -pub use windows::*; - -#[cfg(not(windows))] -mod unix; -#[cfg(not(windows))] -pub use unix::*; diff --git a/crates/chat-cli/src/util/process/unix.rs b/crates/chat-cli/src/util/process/unix.rs deleted file mode 100644 index b0ffc60935..0000000000 --- a/crates/chat-cli/src/util/process/unix.rs +++ /dev/null @@ -1,64 +0,0 @@ -use nix::sys::signal::Signal; -use sysinfo::Pid; - -pub fn terminate_process(pid: Pid) -> Result<(), String> { - let nix_pid = nix::unistd::Pid::from_raw(pid.as_u32() as i32); - nix::sys::signal::kill(nix_pid, Signal::SIGTERM).map_err(|e| format!("Failed to terminate process: {}", e)) -} - -#[cfg(test)] -#[cfg(not(windows))] -mod tests { - use std::process::Command; - use std::time::Duration; - - use super::*; - - // Helper to create a long-running process for testing - fn spawn_test_process() -> std::process::Child { - let mut command = Command::new("sleep"); - command.arg("30"); - command.spawn().expect("Failed to spawn test process") - } - - #[test] - fn test_terminate_process() { - // Spawn a test process - let mut child = spawn_test_process(); - let pid = Pid::from_u32(child.id()); - - // Terminate the process - let result = terminate_process(pid); - - // Verify termination was successful - assert!(result.is_ok(), "Process termination failed: {:?}", result.err()); - - // Give it a moment to terminate - std::thread::sleep(Duration::from_millis(100)); - - // Verify the process is actually terminated - match child.try_wait() { - Ok(Some(_)) => { - // Process exited, which is what we expect - }, - Ok(None) => { - panic!("Process is still running after termination"); - }, - Err(e) => { - panic!("Error checking process status: {}", e); - }, - } - } - - #[test] - fn test_terminate_nonexistent_process() { - // Use a likely invalid PID - let invalid_pid = Pid::from_u32(u32::MAX - 1); - - // Attempt to terminate a non-existent process - let result = terminate_process(invalid_pid); - - // Should return an error - assert!(result.is_err(), "Terminating non-existent process should fail"); - } -} diff --git a/crates/chat-cli/src/util/process/windows.rs b/crates/chat-cli/src/util/process/windows.rs deleted file mode 100644 index 12e0389bd8..0000000000 --- a/crates/chat-cli/src/util/process/windows.rs +++ /dev/null @@ -1,120 +0,0 @@ -use std::ops::Deref; - -use sysinfo::Pid; -use windows::Win32::Foundation::{ - CloseHandle, - HANDLE, -}; -use windows::Win32::System::Threading::{ - OpenProcess, - PROCESS_TERMINATE, - TerminateProcess, -}; - -/// Terminate a process on Windows using the Windows API -pub fn terminate_process(pid: Pid) -> Result<(), String> { - unsafe { - // Open the process with termination rights - let handle = OpenProcess(PROCESS_TERMINATE, false, pid.as_u32()) - .map_err(|e| format!("Failed to open process: {}", e))?; - - // Create a safe handle that will be closed automatically when dropped - let safe_handle = SafeHandle::new(handle).ok_or_else(|| "Invalid process handle".to_string())?; - - // Terminate the process with exit code 1 - TerminateProcess(*safe_handle, 1).map_err(|e| format!("Failed to terminate process: {}", e))?; - - Ok(()) - } -} - -struct SafeHandle(HANDLE); - -impl SafeHandle { - fn new(handle: HANDLE) -> Option { - if !handle.is_invalid() { Some(Self(handle)) } else { None } - } -} - -impl Drop for SafeHandle { - fn drop(&mut self) { - unsafe { - let _ = CloseHandle(self.0); - } - } -} - -impl Deref for SafeHandle { - type Target = HANDLE; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -#[cfg(test)] -mod tests { - use std::process::Command; - use std::time::Duration; - - use super::*; - - // Helper to create a long-running process for testing - fn spawn_test_process() -> std::process::Child { - let mut command = Command::new("cmd"); - command.args(["/C", "timeout 30 > nul"]); - command.spawn().expect("Failed to spawn test process") - } - - #[test] - fn test_terminate_process() { - // Spawn a test process - let mut child = spawn_test_process(); - let pid = Pid::from_u32(child.id()); - - // Terminate the process - let result = terminate_process(pid); - - // Verify termination was successful - assert!(result.is_ok(), "Process termination failed: {:?}", result.err()); - - // Give it a moment to terminate - std::thread::sleep(Duration::from_millis(100)); - - // Verify the process is actually terminated - match child.try_wait() { - Ok(Some(_)) => { - // Process exited, which is what we expect - }, - Ok(None) => { - panic!("Process is still running after termination"); - }, - Err(e) => { - panic!("Error checking process status: {}", e); - }, - } - } - - #[test] - fn test_terminate_nonexistent_process() { - // Use a likely invalid PID - let invalid_pid = Pid::from_u32(u32::MAX - 1); - - // Attempt to terminate a non-existent process - let result = terminate_process(invalid_pid); - - // Should return an error - assert!(result.is_err(), "Terminating non-existent process should fail"); - } - - #[test] - fn test_safe_handle() { - // Test creating a SafeHandle with an invalid handle - let invalid_handle = HANDLE(std::ptr::null_mut()); - let safe_handle = SafeHandle::new(invalid_handle); - assert!(safe_handle.is_none(), "SafeHandle should be None for invalid handle"); - - // We can't easily test a valid handle without actually opening a process, - // which would require additional setup and teardown - } -} diff --git a/crates/chat-cli/test_mcp_server/test_server.rs b/crates/chat-cli/test_mcp_server/test_server.rs deleted file mode 100644 index 970157f96b..0000000000 --- a/crates/chat-cli/test_mcp_server/test_server.rs +++ /dev/null @@ -1,340 +0,0 @@ -//! This is a bin used solely for testing the client -use std::collections::HashMap; -use std::str::FromStr; -use std::sync::atomic::{ - AtomicU8, - Ordering, -}; - -use chat_cli::{ - self, - JsonRpcRequest, - JsonRpcResponse, - JsonRpcStdioTransport, - PreServerRequestHandler, - Response, - Server, - ServerError, - ServerRequestHandler, -}; -use tokio::sync::Mutex; - -#[derive(Default)] -struct Handler { - pending_request: Option Option + Send + Sync>>, - #[allow(clippy::type_complexity)] - send_request: Option) -> Result<(), ServerError> + Send + Sync>>, - storage: Mutex>, - tool_spec: Mutex>, - tool_spec_key_list: Mutex>, - prompts: Mutex>, - prompt_key_list: Mutex>, - prompt_list_call_no: AtomicU8, -} - -impl PreServerRequestHandler for Handler { - fn register_pending_request_callback( - &mut self, - cb: impl Fn(u64) -> Option + Send + Sync + 'static, - ) { - self.pending_request = Some(Box::new(cb)); - } - - fn register_send_request_callback( - &mut self, - cb: impl Fn(&str, Option) -> Result<(), ServerError> + Send + Sync + 'static, - ) { - self.send_request = Some(Box::new(cb)); - } -} - -#[async_trait::async_trait] -impl ServerRequestHandler for Handler { - async fn handle_initialize(&self, params: Option) -> Result { - let mut storage = self.storage.lock().await; - if let Some(params) = params { - storage.insert("client_cap".to_owned(), params); - } - let capabilities = serde_json::json!({ - "protocolVersion": "2024-11-05", - "capabilities": { - "logging": {}, - "prompts": { - "listChanged": true - }, - "resources": { - "subscribe": true, - "listChanged": true - }, - "tools": { - "listChanged": true - } - }, - "serverInfo": { - "name": "TestServer", - "version": "1.0.0" - } - }); - Ok(Some(capabilities)) - } - - async fn handle_incoming(&self, method: &str, params: Option) -> Result { - match method { - "notifications/initialized" => { - { - let mut storage = self.storage.lock().await; - storage.insert( - "init_ack_sent".to_owned(), - serde_json::Value::from_str("true").expect("Failed to convert string to value"), - ); - } - Ok(None) - }, - "verify_init_params_sent" => { - let client_capabilities = { - let storage = self.storage.lock().await; - storage.get("client_cap").cloned() - }; - Ok(client_capabilities) - }, - "verify_init_ack_sent" => { - let result = { - let storage = self.storage.lock().await; - storage.get("init_ack_sent").cloned() - }; - Ok(result) - }, - "store_mock_tool_spec" => { - let Some(params) = params else { - eprintln!("Params missing from store mock tool spec"); - return Ok(None); - }; - // expecting a mock_specs: { key: String, value: serde_json::Value }[]; - let Ok(mock_specs) = serde_json::from_value::>(params) else { - eprintln!("Failed to convert to mock specs from value"); - return Ok(None); - }; - let self_tool_specs = self.tool_spec.lock().await; - let mut self_tool_spec_key_list = self.tool_spec_key_list.lock().await; - let _ = mock_specs.iter().fold(self_tool_specs, |mut acc, spec| { - let Some(key) = spec.get("key").cloned() else { - return acc; - }; - let Ok(key) = serde_json::from_value::(key) else { - eprintln!("Failed to convert serde value to string for key"); - return acc; - }; - self_tool_spec_key_list.push(key.clone()); - acc.insert(key, spec.get("value").cloned()); - acc - }); - Ok(None) - }, - "tools/list" => { - if let Some(params) = params { - if let Some(cursor) = params.get("cursor").cloned() { - let Ok(cursor) = serde_json::from_value::(cursor) else { - eprintln!("Failed to convert cursor to string: {:#?}", params); - return Ok(None); - }; - let self_tool_spec_key_list = self.tool_spec_key_list.lock().await; - let self_tool_spec = self.tool_spec.lock().await; - let (next_cursor, spec) = { - 'blk: { - for (i, item) in self_tool_spec_key_list.iter().enumerate() { - if item == &cursor { - break 'blk ( - self_tool_spec_key_list.get(i + 1).cloned(), - self_tool_spec.get(&cursor).cloned().unwrap(), - ); - } - } - (None, None) - } - }; - if let Some(next_cursor) = next_cursor { - return Ok(Some(serde_json::json!({ - "tools": [spec.unwrap()], - "nextCursor": next_cursor, - }))); - } else { - return Ok(Some(serde_json::json!({ - "tools": [spec.unwrap()], - }))); - } - } else { - eprintln!("Params exist but cursor is missing"); - return Ok(None); - } - } else { - let tool_spec_key_list = self.tool_spec_key_list.lock().await; - let tool_spec = self.tool_spec.lock().await; - let first_key = tool_spec_key_list - .first() - .expect("First key missing from tool specs") - .clone(); - let first_value = tool_spec - .get(&first_key) - .expect("First value missing from tool specs") - .clone(); - let second_key = tool_spec_key_list - .get(1) - .expect("Second key missing from tool specs") - .clone(); - return Ok(Some(serde_json::json!({ - "tools": [first_value], - "nextCursor": second_key - }))); - }; - }, - "get_env_vars" => { - let kv = std::env::vars().fold(HashMap::::new(), |mut acc, (k, v)| { - acc.insert(k, v); - acc - }); - Ok(Some(serde_json::json!(kv))) - }, - // This is a test path relevant only to sampling - "trigger_server_request" => { - let Some(ref send_request) = self.send_request else { - return Err(ServerError::MissingMethod); - }; - let params = Some(serde_json::json!({ - "messages": [ - { - "role": "user", - "content": { - "type": "text", - "text": "What is the capital of France?" - } - } - ], - "modelPreferences": { - "hints": [ - { - "name": "claude-3-sonnet" - } - ], - "intelligencePriority": 0.8, - "speedPriority": 0.5 - }, - "systemPrompt": "You are a helpful assistant.", - "maxTokens": 100 - })); - send_request("sampling/createMessage", params)?; - Ok(None) - }, - "store_mock_prompts" => { - let Some(params) = params else { - eprintln!("Params missing from store mock prompts"); - return Ok(None); - }; - // expecting a mock_prompts: { key: String, value: serde_json::Value }[]; - let Ok(mock_prompts) = serde_json::from_value::>(params) else { - eprintln!("Failed to convert to mock specs from value"); - return Ok(None); - }; - let mut self_prompts = self.prompts.lock().await; - let mut self_prompt_key_list = self.prompt_key_list.lock().await; - let is_first_mock = self_prompts.is_empty(); - self_prompts.clear(); - self_prompt_key_list.clear(); - let _ = mock_prompts.iter().fold(self_prompts, |mut acc, spec| { - let Some(key) = spec.get("key").cloned() else { - return acc; - }; - let Ok(key) = serde_json::from_value::(key) else { - eprintln!("Failed to convert serde value to string for key"); - return acc; - }; - self_prompt_key_list.push(key.clone()); - acc.insert(key, spec.get("value").cloned()); - acc - }); - if !is_first_mock { - if let Some(sender) = &self.send_request { - let _ = sender("notifications/prompts/list_changed", None); - } - } - Ok(None) - }, - "prompts/list" => { - // We expect this method to be called after the mock prompts have already been - // stored. - self.prompt_list_call_no.fetch_add(1, Ordering::Relaxed); - if let Some(params) = params { - if let Some(cursor) = params.get("cursor").cloned() { - let Ok(cursor) = serde_json::from_value::(cursor) else { - eprintln!("Failed to convert cursor to string: {:#?}", params); - return Ok(None); - }; - let self_prompt_key_list = self.prompt_key_list.lock().await; - let self_prompts = self.prompts.lock().await; - let (next_cursor, spec) = { - 'blk: { - for (i, item) in self_prompt_key_list.iter().enumerate() { - if item == &cursor { - break 'blk ( - self_prompt_key_list.get(i + 1).cloned(), - self_prompts.get(&cursor).cloned().unwrap(), - ); - } - } - (None, None) - } - }; - if let Some(next_cursor) = next_cursor { - return Ok(Some(serde_json::json!({ - "prompts": [spec.unwrap()], - "nextCursor": next_cursor, - }))); - } else { - return Ok(Some(serde_json::json!({ - "prompts": [spec.unwrap()], - }))); - } - } else { - eprintln!("Params exist but cursor is missing"); - return Ok(None); - } - } else { - // If there is no parameter, this is the request to retrieve the first page - let prompt_key_list = self.prompt_key_list.lock().await; - let prompts = self.prompts.lock().await; - let first_key = prompt_key_list.first().expect("first key missing"); - let first_value = prompts.get(first_key).cloned().unwrap().unwrap(); - let second_key = prompt_key_list.get(1).expect("second key missing"); - return Ok(Some(serde_json::json!({ - "prompts": [first_value], - "nextCursor": second_key - }))); - }; - }, - "get_prompt_list_call_no" => Ok(Some( - serde_json::to_value::(self.prompt_list_call_no.load(Ordering::Relaxed)) - .expect("Failed to convert list call no to u8"), - )), - _ => Err(ServerError::MissingMethod), - } - } - - // This is a test path relevant only to sampling - async fn handle_response(&self, resp: JsonRpcResponse) -> Result<(), ServerError> { - let JsonRpcResponse { id, .. } = resp; - let _pending = self.pending_request.as_ref().and_then(|f| f(id)); - Ok(()) - } - - async fn handle_shutdown(&self) -> Result<(), ServerError> { - Ok(()) - } -} - -#[tokio::main] -async fn main() { - let handler = Handler::default(); - let stdin = tokio::io::stdin(); - let stdout = tokio::io::stdout(); - let test_server = Server::::new(handler, stdin, stdout).expect("Failed to create server"); - let _ = test_server.init().expect("Test server failed to init").await; -} From 4b2dc04fee2b65c11f1dbe6fe8f8c137adfa2acc Mon Sep 17 00:00:00 2001 From: Justin Moser Date: Fri, 5 Sep 2025 17:46:31 -0400 Subject: [PATCH 09/71] Dont preserve summary when conversation is cleared (#2793) --- crates/chat-cli/src/cli/chat/cli/clear.rs | 2 +- crates/chat-cli/src/cli/chat/conversation.rs | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/cli/clear.rs b/crates/chat-cli/src/cli/chat/cli/clear.rs index 8da854abea..b31bccd28e 100644 --- a/crates/chat-cli/src/cli/chat/cli/clear.rs +++ b/crates/chat-cli/src/cli/chat/cli/clear.rs @@ -48,7 +48,7 @@ impl ClearArgs { }; if ["y", "Y"].contains(&user_input.as_str()) { - session.conversation.clear(true); + session.conversation.clear(); if let Some(cm) = session.conversation.context_manager.as_mut() { cm.hook_executor.cache.clear(); } diff --git a/crates/chat-cli/src/cli/chat/conversation.rs b/crates/chat-cli/src/cli/chat/conversation.rs index 56bef53885..73d952478e 100644 --- a/crates/chat-cli/src/cli/chat/conversation.rs +++ b/crates/chat-cli/src/cli/chat/conversation.rs @@ -215,13 +215,11 @@ impl ConversationState { &self.history } - /// Clears the conversation history and optionally the summary. - pub fn clear(&mut self, preserve_summary: bool) { + /// Clears the conversation history and summary. + pub fn clear(&mut self) { self.next_message = None; self.history.clear(); - if !preserve_summary { - self.latest_summary = None; - } + self.latest_summary = None; } /// Check if currently in tangent mode From 9969208da36fe02d0c5a39073adb2f3750ff767c Mon Sep 17 00:00:00 2001 From: Matt Lee <1302416+mr-lee@users.noreply.github.com> Date: Mon, 8 Sep 2025 14:54:05 -0400 Subject: [PATCH 10/71] feat: add AGENTS.md to default agent resources (#2812) - Add file://AGENTS.md to default resources list alongside AmazonQ.md - Update test to include both AmazonQ.md and AGENTS.md files - Ensures AGENTS.md is included everywhere AmazonQ.md was previously included Co-authored-by: Matt Lee --- crates/chat-cli/src/cli/agent/mod.rs | 2 +- crates/chat-cli/src/cli/chat/conversation.rs | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/crates/chat-cli/src/cli/agent/mod.rs b/crates/chat-cli/src/cli/agent/mod.rs index e3d89b4846..032891705d 100644 --- a/crates/chat-cli/src/cli/agent/mod.rs +++ b/crates/chat-cli/src/cli/agent/mod.rs @@ -181,7 +181,7 @@ impl Default for Agent { set.extend(default_approve); set }, - resources: vec!["file://AmazonQ.md", "file://README.md", "file://.amazonq/rules/**/*.md"] + resources: vec!["file://AmazonQ.md", "file://AGENTS.md", "file://README.md", "file://.amazonq/rules/**/*.md"] .into_iter() .map(Into::into) .collect::>(), diff --git a/crates/chat-cli/src/cli/chat/conversation.rs b/crates/chat-cli/src/cli/chat/conversation.rs index 73d952478e..8ea506f929 100644 --- a/crates/chat-cli/src/cli/chat/conversation.rs +++ b/crates/chat-cli/src/cli/chat/conversation.rs @@ -1195,6 +1195,7 @@ mod tests { use crate::cli::chat::tool_manager::ToolManager; const AMAZONQ_FILENAME: &str = "AmazonQ.md"; + const AGENTS_FILENAME: &str = "AGENTS.md"; fn assert_conversation_state_invariants(state: FigConversationState, assertion_iteration: usize) { if let Some(Some(msg)) = state.history.as_ref().map(|h| h.first()) { @@ -1407,11 +1408,13 @@ mod tests { let mut agents = Agents::default(); let mut agent = Agent::default(); agent.resources.push(AMAZONQ_FILENAME.into()); + agent.resources.push(AGENTS_FILENAME.into()); agents.agents.insert("TestAgent".to_string(), agent); agents.switch("TestAgent").expect("Agent switch failed"); agents }; os.fs.write(AMAZONQ_FILENAME, "test context").await.unwrap(); + os.fs.write(AGENTS_FILENAME, "test agents context").await.unwrap(); let mut output = vec![]; let mut tool_manager = ToolManager::default(); From 8a091be674ab0f8380336766154c041d7897ae84 Mon Sep 17 00:00:00 2001 From: Matt Lee <1302416+mr-lee@users.noreply.github.com> Date: Tue, 9 Sep 2025 11:14:09 -0400 Subject: [PATCH 11/71] feat: add model field support to agent format (#2815) - Add optional 'model' field to Agent struct for specifying model per agent - Update JSON schema and documentation with model field usage - Integrate agent model into model selection priority: 1. CLI argument (--model) 2. Agent's model field (new) 3. User's saved default model 4. System default model - Add proper fallback when agent specifies unavailable model - Extract fallback logic to eliminate code duplication - Include comprehensive unit tests for model field functionality - Maintain backward compatibility with existing agent configurations Co-authored-by: Matt Lee --- crates/chat-cli/src/cli/agent/mod.rs | 63 ++++++++++++++++++++++++++++ crates/chat-cli/src/cli/chat/mod.rs | 36 +++++++++++++--- docs/agent-format.md | 18 +++++++- schemas/agent-v1.json | 8 ++++ 4 files changed, 118 insertions(+), 7 deletions(-) diff --git a/crates/chat-cli/src/cli/agent/mod.rs b/crates/chat-cli/src/cli/agent/mod.rs index 032891705d..bf69dcba03 100644 --- a/crates/chat-cli/src/cli/agent/mod.rs +++ b/crates/chat-cli/src/cli/agent/mod.rs @@ -161,6 +161,9 @@ pub struct Agent { /// you configure in the mcpServers field in this config #[serde(default)] pub use_legacy_mcp_json: bool, + /// The model ID to use for this agent. If not specified, uses the default model. + #[serde(default)] + pub model: Option, #[serde(skip)] pub path: Option, } @@ -188,6 +191,7 @@ impl Default for Agent { hooks: Default::default(), tools_settings: Default::default(), use_legacy_mcp_json: true, + model: None, path: None, } } @@ -1215,6 +1219,7 @@ mod tests { resources: Vec::new(), hooks: Default::default(), use_legacy_mcp_json: false, + model: None, path: None, }; @@ -1285,4 +1290,62 @@ mod tests { label ); } + + #[test] + fn test_agent_model_field() { + // Test deserialization with model field + let agent_json = r#"{ + "name": "test-agent", + "model": "claude-sonnet-4" + }"#; + + let agent: Agent = serde_json::from_str(agent_json).expect("Failed to deserialize agent with model"); + assert_eq!(agent.model, Some("claude-sonnet-4".to_string())); + + // Test default agent has no model + let default_agent = Agent::default(); + assert_eq!(default_agent.model, None); + + // Test serialization includes model field + let agent_with_model = Agent { + model: Some("test-model".to_string()), + ..Default::default() + }; + let serialized = serde_json::to_string(&agent_with_model).expect("Failed to serialize"); + assert!(serialized.contains("\"model\":\"test-model\"")); + } + + #[test] + fn test_agent_model_fallback_priority() { + // Test that agent model is checked and falls back correctly + let mut agents = Agents::default(); + + // Create agent with unavailable model + let agent_with_invalid_model = Agent { + name: "test-agent".to_string(), + model: Some("unavailable-model".to_string()), + ..Default::default() + }; + + agents.agents.insert("test-agent".to_string(), agent_with_invalid_model); + agents.active_idx = "test-agent".to_string(); + + // Verify the agent has the model set + assert_eq!( + agents.get_active().and_then(|a| a.model.as_ref()), + Some(&"unavailable-model".to_string()) + ); + + // Test agent without model + let agent_without_model = Agent { + name: "no-model-agent".to_string(), + model: None, + ..Default::default() + }; + + agents.agents.insert("no-model-agent".to_string(), agent_without_model); + agents.active_idx = "no-model-agent".to_string(); + + assert_eq!(agents.get_active().and_then(|a| a.model.as_ref()), None); + } } diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 43014bb857..d92c5f44f6 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -44,6 +44,7 @@ use cli::compact::CompactStrategy; use cli::model::{ get_available_models, select_model, + find_model, }; pub use conversation::ConversationState; use conversation::TokenWarningLevel; @@ -141,7 +142,6 @@ use crate::cli::TodoListState; use crate::cli::agent::Agents; use crate::cli::chat::cli::SlashCommand; use crate::cli::chat::cli::editor::open_editor; -use crate::cli::chat::cli::model::find_model; use crate::cli::chat::cli::prompts::{ GetPromptError, PromptsSubcommand, @@ -340,7 +340,17 @@ impl ChatArgs { // If modelId is specified, verify it exists before starting the chat // Otherwise, CLI will use a default model when starting chat let (models, default_model_opt) = get_available_models(os).await?; + // Fallback logic: try user's saved default, then system default + let fallback_model_id = || { + if let Some(saved) = os.database.settings.get_string(Setting::ChatDefaultModel) { + find_model(&models, &saved).map(|m| m.model_id.clone()).or(Some(default_model_opt.model_id.clone())) + } else { + Some(default_model_opt.model_id.clone()) + } + }; + let model_id: Option = if let Some(requested) = self.model.as_ref() { + // CLI argument takes highest priority if let Some(m) = find_model(&models, requested) { Some(m.model_id.clone()) } else { @@ -351,12 +361,26 @@ impl ChatArgs { .join(", "); bail!("Model '{}' does not exist. Available models: {}", requested, available); } - } else if let Some(saved) = os.database.settings.get_string(Setting::ChatDefaultModel) { - find_model(&models, &saved) - .map(|m| m.model_id.clone()) - .or(Some(default_model_opt.model_id.clone())) + } else if let Some(agent_model) = agents.get_active().and_then(|a| a.model.as_ref()) { + // Agent model takes second priority + if let Some(m) = find_model(&models, agent_model) { + Some(m.model_id.clone()) + } else { + let _ = execute!( + stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("WARNING: "), + style::SetForegroundColor(Color::Reset), + style::Print("Agent specifies model '"), + style::SetForegroundColor(Color::Cyan), + style::Print(agent_model), + style::SetForegroundColor(Color::Reset), + style::Print("' which is not available. Falling back to configured defaults.\n"), + ); + fallback_model_id() + } } else { - Some(default_model_opt.model_id.clone()) + fallback_model_id() }; let (prompt_request_sender, prompt_request_receiver) = tokio::sync::broadcast::channel::(5); diff --git a/docs/agent-format.md b/docs/agent-format.md index c005ad13cd..b467686774 100644 --- a/docs/agent-format.md +++ b/docs/agent-format.md @@ -15,6 +15,7 @@ Every agent configuration file can include the following sections: - [`resources`](#resources-field) — Resources available to the agent. - [`hooks`](#hooks-field) — Commands run at specific trigger points. - [`useLegacyMcpJson`](#uselegacymcpjson-field) — Whether to include legacy MCP configuration. +- [`model`](#model-field) — The model ID to use for this agent. ## Name Field @@ -290,6 +291,20 @@ The `useLegacyMcpJson` field determines whether to include MCP servers defined i When set to `true`, the agent will have access to all MCP servers defined in the global and local configurations in addition to those defined in the agent's `mcpServers` field. +## Model Field + +The `model` field specifies the model ID to use for this agent. If not specified, the agent will use the default model. + +```json +{ + "model": "claude-sonnet-4" +} +``` + +The model ID must match one of the available models returned by the Q CLI's model service. You can see available models by using the `/model` command in an active chat session. + +If the specified model is not available, the agent will fall back to the default model and display a warning. + ## Complete Example Here's a complete example of an agent configuration file: @@ -348,6 +363,7 @@ Here's a complete example of an agent configuration file: } ] }, - "useLegacyMcpJson": true + "useLegacyMcpJson": true, + "model": "claude-sonnet-4" } ``` diff --git a/schemas/agent-v1.json b/schemas/agent-v1.json index 15b626dea8..d49fc53406 100644 --- a/schemas/agent-v1.json +++ b/schemas/agent-v1.json @@ -159,6 +159,14 @@ "description": "Whether or not to include the legacy ~/.aws/amazonq/mcp.json in the agent\nYou can reference tools brought in by these servers as just as you would with the servers\nyou configure in the mcpServers field in this config", "type": "boolean", "default": false + }, + "model": { + "description": "The model ID to use for this agent. If not specified, uses the default model.", + "type": [ + "string", + "null" + ], + "default": null } }, "additionalProperties": false, From f1a7d5bdb6d0aa4730d9ee05d6313e21ffcb9bc6 Mon Sep 17 00:00:00 2001 From: xianwwu Date: Tue, 9 Sep 2025 12:53:04 -0700 Subject: [PATCH 12/71] chore: updating doc to surface /agent generate and note block for /knowledge (#2823) --- docs/agent-format.md | 3 +++ docs/knowledge-management.md | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/agent-format.md b/docs/agent-format.md index b467686774..c707b476e7 100644 --- a/docs/agent-format.md +++ b/docs/agent-format.md @@ -2,6 +2,9 @@ The agent configuration file for each agent is a JSON file. The filename (without the `.json` extension) becomes the agent's name. It contains configuration needed to instantiate and run the agent. +> [!TIP] +> We recommend using the `/agent generate` slash command within your active Q session to intelligently generate your agent configuration with the help of Q. + Every agent configuration file can include the following sections: - [`name`](#name-field) — The name of the agent (optional, derived from filename if not specified). diff --git a/docs/knowledge-management.md b/docs/knowledge-management.md index cb29cfd997..82e3a291af 100644 --- a/docs/knowledge-management.md +++ b/docs/knowledge-management.md @@ -2,7 +2,8 @@ The /knowledge command provides persistent knowledge base functionality for Amazon Q CLI, allowing you to store, search, and manage contextual information that persists across chat sessions. -> Note: This is a beta feature that must be enabled before use. +> [!NOTE] +> This is a beta feature that must be enabled before use. ## Getting Started From 46ddc72ac2025e007368d869005aa8d0f10192fb Mon Sep 17 00:00:00 2001 From: Erben Mo Date: Tue, 9 Sep 2025 14:57:43 -0700 Subject: [PATCH 13/71] Properly handle path with trailing slash in file matching (#2817) * Properly handle path with trailing slash in file matching Today if a path has a trailing slash, the glob pattern will look like "/path-to-folder//**" (note the double slash). Glob doesn't work with double slash actually (it doesn't match anything). As a result, the permission management for fs_read and fs_write is broken when allowed or denied path has trailing slash. The fix is to just manually remove the trailing slash. * format change --- crates/chat-cli/src/util/directories.rs | 40 ++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/crates/chat-cli/src/util/directories.rs b/crates/chat-cli/src/util/directories.rs index a34b71b6f1..66bfbcbdf7 100644 --- a/crates/chat-cli/src/util/directories.rs +++ b/crates/chat-cli/src/util/directories.rs @@ -193,7 +193,11 @@ pub fn canonicalizes_path(os: &Os, path_as_str: &str) -> Result { /// patterns to exist in a globset. pub fn add_gitignore_globs(builder: &mut GlobSetBuilder, path: &str) -> Result<()> { let glob_for_file = Glob::new(path)?; - let glob_for_dir = Glob::new(&format!("{path}/**"))?; + + // remove existing slash in path so we don't end up with double slash + // Glob doesn't normalize the path so it doesn't work with double slash + let dir_pattern: String = format!("{}/**", path.trim_end_matches('/')); + let glob_for_dir = Glob::new(&dir_pattern)?; builder.add(glob_for_file); builder.add(glob_for_dir); @@ -278,6 +282,40 @@ mod linux_tests { assert!(logs_dir().is_ok()); assert!(settings_path().is_ok()); } + + #[test] + fn test_add_gitignore_globs() { + let direct_file = "/home/user/a.txt"; + let nested_file = "/home/user/folder/a.txt"; + let other_file = "/home/admin/a.txt"; + + // Case 1: Path with trailing slash + let mut builder1 = GlobSetBuilder::new(); + add_gitignore_globs(&mut builder1, "/home/user/").unwrap(); + let globset1 = builder1.build().unwrap(); + + assert!(globset1.is_match(direct_file)); + assert!(globset1.is_match(nested_file)); + assert!(!globset1.is_match(other_file)); + + // Case 2: Path without trailing slash - should behave same as case 1 + let mut builder2 = GlobSetBuilder::new(); + add_gitignore_globs(&mut builder2, "/home/user").unwrap(); + let globset2 = builder2.build().unwrap(); + + assert!(globset2.is_match(direct_file)); + assert!(globset2.is_match(nested_file)); + assert!(!globset1.is_match(other_file)); + + // Case 3: File path - should only match exact file + let mut builder3 = GlobSetBuilder::new(); + add_gitignore_globs(&mut builder3, "/home/user/a.txt").unwrap(); + let globset3 = builder3.build().unwrap(); + + assert!(globset3.is_match(direct_file)); + assert!(!globset3.is_match(nested_file)); + assert!(!globset1.is_match(other_file)); + } } // TODO(grant): Add back path tests on linux From 5ea7045a4d5f9cd8fc0fc9335d30c18a356615b3 Mon Sep 17 00:00:00 2001 From: evanliu048 Date: Wed, 10 Sep 2025 15:07:08 -0700 Subject: [PATCH 14/71] Fix: Add configurable line wrapping for chat (#2816) * add a wrapmode in chat args * add a ut for the wrap arg --- crates/chat-cli/src/cli/agent/mod.rs | 31 ++++++++------- crates/chat-cli/src/cli/chat/mod.rs | 43 +++++++++++++++++++-- crates/chat-cli/src/cli/mod.rs | 58 ++++++++++++++++++++++++++++ 3 files changed, 116 insertions(+), 16 deletions(-) diff --git a/crates/chat-cli/src/cli/agent/mod.rs b/crates/chat-cli/src/cli/agent/mod.rs index bf69dcba03..7d8d301723 100644 --- a/crates/chat-cli/src/cli/agent/mod.rs +++ b/crates/chat-cli/src/cli/agent/mod.rs @@ -184,10 +184,15 @@ impl Default for Agent { set.extend(default_approve); set }, - resources: vec!["file://AmazonQ.md", "file://AGENTS.md", "file://README.md", "file://.amazonq/rules/**/*.md"] - .into_iter() - .map(Into::into) - .collect::>(), + resources: vec![ + "file://AmazonQ.md", + "file://AGENTS.md", + "file://README.md", + "file://.amazonq/rules/**/*.md", + ] + .into_iter() + .map(Into::into) + .collect::>(), hooks: Default::default(), tools_settings: Default::default(), use_legacy_mcp_json: true, @@ -1298,14 +1303,14 @@ mod tests { "name": "test-agent", "model": "claude-sonnet-4" }"#; - + let agent: Agent = serde_json::from_str(agent_json).expect("Failed to deserialize agent with model"); assert_eq!(agent.model, Some("claude-sonnet-4".to_string())); - + // Test default agent has no model let default_agent = Agent::default(); assert_eq!(default_agent.model, None); - + // Test serialization includes model field let agent_with_model = Agent { model: Some("test-model".to_string()), @@ -1319,33 +1324,33 @@ mod tests { fn test_agent_model_fallback_priority() { // Test that agent model is checked and falls back correctly let mut agents = Agents::default(); - + // Create agent with unavailable model let agent_with_invalid_model = Agent { name: "test-agent".to_string(), model: Some("unavailable-model".to_string()), ..Default::default() }; - + agents.agents.insert("test-agent".to_string(), agent_with_invalid_model); agents.active_idx = "test-agent".to_string(); - + // Verify the agent has the model set assert_eq!( agents.get_active().and_then(|a| a.model.as_ref()), Some(&"unavailable-model".to_string()) ); - + // Test agent without model let agent_without_model = Agent { name: "no-model-agent".to_string(), model: None, ..Default::default() }; - + agents.agents.insert("no-model-agent".to_string(), agent_without_model); agents.active_idx = "no-model-agent".to_string(); - + assert_eq!(agents.get_active().and_then(|a| a.model.as_ref()), None); } } diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index d92c5f44f6..ece167c23b 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -39,12 +39,13 @@ use clap::{ Args, CommandFactory, Parser, + ValueEnum, }; use cli::compact::CompactStrategy; use cli::model::{ + find_model, get_available_models, select_model, - find_model, }; pub use conversation::ConversationState; use conversation::TokenWarningLevel; @@ -189,6 +190,16 @@ pub const EXTRA_HELP: &str = color_print::cstr! {" Change using: q settings chat.skimCommandKey x "}; +#[derive(Copy, Clone, Debug, PartialEq, Eq, ValueEnum)] +pub enum WrapMode { + /// Always wrap at terminal width + Always, + /// Never wrap (raw output) + Never, + /// Auto-detect based on output target (default) + Auto, +} + #[derive(Debug, Clone, PartialEq, Eq, Default, Args)] pub struct ChatArgs { /// Resumes the previous conversation from this directory. @@ -212,6 +223,9 @@ pub struct ChatArgs { pub no_interactive: bool, /// The first question to ask pub input: Option, + /// Control line wrapping behavior (default: auto-detect) + #[arg(short = 'w', long, value_enum)] + pub wrap: Option, } impl ChatArgs { @@ -343,7 +357,9 @@ impl ChatArgs { // Fallback logic: try user's saved default, then system default let fallback_model_id = || { if let Some(saved) = os.database.settings.get_string(Setting::ChatDefaultModel) { - find_model(&models, &saved).map(|m| m.model_id.clone()).or(Some(default_model_opt.model_id.clone())) + find_model(&models, &saved) + .map(|m| m.model_id.clone()) + .or(Some(default_model_opt.model_id.clone())) } else { Some(default_model_opt.model_id.clone()) } @@ -412,6 +428,7 @@ impl ChatArgs { tool_config, !self.no_interactive, mcp_enabled, + self.wrap, ) .await? .spawn(os) @@ -621,6 +638,7 @@ pub struct ChatSession { interactive: bool, inner: Option, ctrlc_rx: broadcast::Receiver<()>, + wrap: Option, } impl ChatSession { @@ -640,6 +658,7 @@ impl ChatSession { tool_config: HashMap, interactive: bool, mcp_enabled: bool, + wrap: Option, ) -> Result { // Reload prior conversation let mut existing_conversation = false; @@ -731,6 +750,7 @@ impl ChatSession { interactive, inner: Some(ChatState::default()), ctrlc_rx, + wrap, }) } @@ -2419,8 +2439,20 @@ impl ChatSession { let mut buf = String::new(); let mut offset = 0; let mut ended = false; + let terminal_width = match self.wrap { + Some(WrapMode::Never) => None, + Some(WrapMode::Always) => Some(self.terminal_width()), + Some(WrapMode::Auto) | None => { + if std::io::stdout().is_terminal() { + Some(self.terminal_width()) + } else { + None + } + }, + }; + let mut state = ParseState::new( - Some(self.terminal_width()), + terminal_width, os.database.settings.get_bool(Setting::ChatDisableMarkdownRendering), ); let mut response_prefix_printed = false; @@ -3340,6 +3372,7 @@ mod tests { tool_config, true, false, + None, ) .await .unwrap() @@ -3482,6 +3515,7 @@ mod tests { tool_config, true, false, + None, ) .await .unwrap() @@ -3579,6 +3613,7 @@ mod tests { tool_config, true, false, + None, ) .await .unwrap() @@ -3654,6 +3689,7 @@ mod tests { tool_config, true, false, + None, ) .await .unwrap() @@ -3705,6 +3741,7 @@ mod tests { tool_config, true, false, + None, ) .await .unwrap() diff --git a/crates/chat-cli/src/cli/mod.rs b/crates/chat-cli/src/cli/mod.rs index 56c3cb2178..6ac5bf6a0c 100644 --- a/crates/chat-cli/src/cli/mod.rs +++ b/crates/chat-cli/src/cli/mod.rs @@ -331,6 +331,12 @@ impl Cli { #[cfg(test)] mod test { + use chat::WrapMode::{ + Always, + Auto, + Never, + }; + use super::*; use crate::util::CHAT_BINARY_NAME; use crate::util::test::assert_parse; @@ -370,6 +376,7 @@ mod test { trust_all_tools: false, trust_tools: None, no_interactive: false, + wrap: None, })), verbose: 2, help_all: false, @@ -409,6 +416,7 @@ mod test { trust_all_tools: false, trust_tools: None, no_interactive: false, + wrap: None, }) ); } @@ -425,6 +433,7 @@ mod test { trust_all_tools: false, trust_tools: None, no_interactive: false, + wrap: None, }) ); } @@ -441,6 +450,7 @@ mod test { trust_all_tools: true, trust_tools: None, no_interactive: false, + wrap: None, }) ); } @@ -457,6 +467,7 @@ mod test { trust_all_tools: false, trust_tools: None, no_interactive: true, + wrap: None, }) ); assert_parse!( @@ -469,6 +480,7 @@ mod test { trust_all_tools: false, trust_tools: None, no_interactive: true, + wrap: None, }) ); } @@ -485,6 +497,7 @@ mod test { trust_all_tools: true, trust_tools: None, no_interactive: false, + wrap: None, }) ); } @@ -501,6 +514,7 @@ mod test { trust_all_tools: false, trust_tools: Some(vec!["".to_string()]), no_interactive: false, + wrap: None, }) ); } @@ -517,6 +531,50 @@ mod test { trust_all_tools: false, trust_tools: Some(vec!["fs_read".to_string(), "fs_write".to_string()]), no_interactive: false, + wrap: None, + }) + ); + } + + #[test] + fn test_chat_with_different_wrap_modes() { + assert_parse!( + ["chat", "-w", "never"], + RootSubcommand::Chat(ChatArgs { + resume: false, + input: None, + agent: None, + model: None, + trust_all_tools: false, + trust_tools: None, + no_interactive: false, + wrap: Some(Never), + }) + ); + assert_parse!( + ["chat", "--wrap", "always"], + RootSubcommand::Chat(ChatArgs { + resume: false, + input: None, + agent: None, + model: None, + trust_all_tools: false, + trust_tools: None, + no_interactive: false, + wrap: Some(Always), + }) + ); + assert_parse!( + ["chat", "--wrap", "auto"], + RootSubcommand::Chat(ChatArgs { + resume: false, + input: None, + agent: None, + model: None, + trust_all_tools: false, + trust_tools: None, + no_interactive: false, + wrap: Some(Auto), }) ); } From c36892a49643903ed2a89107a43c8df0c435843e Mon Sep 17 00:00:00 2001 From: Matt Lee <1302416+mr-lee@users.noreply.github.com> Date: Thu, 11 Sep 2025 12:17:07 -0400 Subject: [PATCH 15/71] feat(use_aws): add configurable autoAllowReadonly setting (#2828) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add auto_allow_readonly field to use_aws Settings struct (defaults to false) - Update eval_perm method to use auto_allow_readonly setting instead of hardcoded behavior - Default behavior: all AWS operations require user confirmation (secure by default) - Opt-in behavior: when autoAllowReadonly=true, read-only operations are auto-approved - Add comprehensive tests covering all scenarios - Maintains backward compatibility through configuration 🤖 Assisted by Amazon Q Developer Co-authored-by: Matt Lee --- crates/chat-cli/src/cli/chat/tools/use_aws.rs | 125 +++++++++++++++++- docs/built-in-tools.md | 4 +- 2 files changed, 123 insertions(+), 6 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/tools/use_aws.rs b/crates/chat-cli/src/cli/chat/tools/use_aws.rs index 456510b5bf..4ea40c80c4 100644 --- a/crates/chat-cli/src/cli/chat/tools/use_aws.rs +++ b/crates/chat-cli/src/cli/chat/tools/use_aws.rs @@ -182,6 +182,8 @@ impl UseAws { allowed_services: Vec, #[serde(default)] denied_services: Vec, + #[serde(default)] + auto_allow_readonly: bool, } let Self { service_name, .. } = self; @@ -201,15 +203,16 @@ impl UseAws { if is_in_allowlist || settings.allowed_services.contains(service_name) { return PermissionEvalResult::Allow; } + // Check auto_allow_readonly setting for read-only operations + if settings.auto_allow_readonly && !self.requires_acceptance() { + return PermissionEvalResult::Allow; + } PermissionEvalResult::Ask }, None if is_in_allowlist => PermissionEvalResult::Allow, _ => { - if self.requires_acceptance() { - PermissionEvalResult::Ask - } else { - PermissionEvalResult::Allow - } + // Default behavior: always ask for confirmation (no auto-approval for read-only) + PermissionEvalResult::Ask }, } } @@ -390,4 +393,116 @@ mod tests { let res = cmd_one.eval_perm(&os, &agent); assert!(matches!(res, PermissionEvalResult::Deny(ref services) if services.contains(&"s3".to_string()))); } + + #[tokio::test] + async fn test_eval_perm_auto_allow_readonly_default() { + let os = Os::new().await.unwrap(); + + // Test read-only operation with default settings (auto_allow_readonly = false) + let readonly_cmd = use_aws! {{ + "service_name": "s3", + "operation_name": "list-objects", + "region": "us-west-2", + "profile_name": "default", + "label": "" + }}; + + let agent = Agent::default(); + let res = readonly_cmd.eval_perm(&os, &agent); + // Should ask for confirmation even for read-only operations by default + assert!(matches!(res, PermissionEvalResult::Ask)); + + // Test write operation with default settings + let write_cmd = use_aws! {{ + "service_name": "s3", + "operation_name": "put-object", + "region": "us-west-2", + "profile_name": "default", + "label": "" + }}; + + let res = write_cmd.eval_perm(&os, &agent); + // Should ask for confirmation for write operations + assert!(matches!(res, PermissionEvalResult::Ask)); + } + + #[tokio::test] + async fn test_eval_perm_auto_allow_readonly_enabled() { + let os = Os::new().await.unwrap(); + + let agent = Agent { + name: "test_agent".to_string(), + tools_settings: { + let mut map = HashMap::::new(); + map.insert( + ToolSettingTarget("use_aws".to_string()), + serde_json::json!({ + "autoAllowReadonly": true + }), + ); + map + }, + ..Default::default() + }; + + // Test read-only operation with auto_allow_readonly = true + let readonly_cmd = use_aws! {{ + "service_name": "s3", + "operation_name": "list-objects", + "region": "us-west-2", + "profile_name": "default", + "label": "" + }}; + + let res = readonly_cmd.eval_perm(&os, &agent); + // Should allow read-only operations without confirmation + assert!(matches!(res, PermissionEvalResult::Allow)); + + // Test write operation with auto_allow_readonly = true + let write_cmd = use_aws! {{ + "service_name": "s3", + "operation_name": "put-object", + "region": "us-west-2", + "profile_name": "default", + "label": "" + }}; + + let res = write_cmd.eval_perm(&os, &agent); + // Should still ask for confirmation for write operations + assert!(matches!(res, PermissionEvalResult::Ask)); + } + + #[tokio::test] + async fn test_eval_perm_auto_allow_readonly_with_denied_services() { + let os = Os::new().await.unwrap(); + + let agent = Agent { + name: "test_agent".to_string(), + tools_settings: { + let mut map = HashMap::::new(); + map.insert( + ToolSettingTarget("use_aws".to_string()), + serde_json::json!({ + "autoAllowReadonly": true, + "deniedServices": ["s3"] + }), + ); + map + }, + ..Default::default() + }; + + // Test read-only operation on denied service + let readonly_cmd = use_aws! {{ + "service_name": "s3", + "operation_name": "list-objects", + "region": "us-west-2", + "profile_name": "default", + "label": "" + }}; + + let res = readonly_cmd.eval_perm(&os, &agent); + // Should deny even read-only operations on denied services + assert!(matches!(res, PermissionEvalResult::Deny(ref services) if services.contains(&"s3".to_string()))); + } } diff --git a/docs/built-in-tools.md b/docs/built-in-tools.md index 337b7ec0f0..39778740bb 100644 --- a/docs/built-in-tools.md +++ b/docs/built-in-tools.md @@ -139,7 +139,8 @@ Make AWS CLI API calls with the specified service, operation, and parameters. "toolsSettings": { "use_aws": { "allowedServices": ["s3", "lambda", "ec2"], - "deniedServices": ["eks", "rds"] + "deniedServices": ["eks", "rds"], + "autoAllowReadonly": true } } } @@ -151,6 +152,7 @@ Make AWS CLI API calls with the specified service, operation, and parameters. |--------|------|---------|-------------| | `allowedServices` | array of strings | `[]` | List of AWS services that can be accessed without prompting | | `deniedServices` | array of strings | `[]` | List of AWS services to deny. Deny rules are evaluated before allow rules | +| `autoAllowReadonly` | boolean | `false` | Whether to automatically allow read-only operations (get, describe, list, ls, search, batch_get) without prompting | ## Using Tool Settings in Agent Configuration From 3b213e8ac49f6cff804fdc4a4b151bfc58e2e80e Mon Sep 17 00:00:00 2001 From: abhraina-aws Date: Thu, 11 Sep 2025 10:59:19 -0700 Subject: [PATCH 16/71] feat: add auto-announcement feature with /changelog command (#2833) --- crates/chat-cli/build.rs | 60 +++++ crates/chat-cli/src/cli/chat/cli/changelog.rs | 23 ++ crates/chat-cli/src/cli/chat/cli/mod.rs | 7 + crates/chat-cli/src/cli/chat/mod.rs | 38 ++- crates/chat-cli/src/cli/chat/prompt.rs | 1 + .../chat-cli/src/cli/chat/tools/introspect.rs | 11 + crates/chat-cli/src/cli/feed.json | 244 ++++++++++++++++++ crates/chat-cli/src/cli/mod.rs | 2 +- crates/chat-cli/src/database/mod.rs | 22 ++ crates/chat-cli/src/util/mod.rs | 1 + crates/chat-cli/src/util/ui.rs | 136 ++++++++++ scripts/build.py | 1 + 12 files changed, 544 insertions(+), 2 deletions(-) create mode 100644 crates/chat-cli/src/cli/chat/cli/changelog.rs create mode 100644 crates/chat-cli/src/util/ui.rs diff --git a/crates/chat-cli/build.rs b/crates/chat-cli/build.rs index fe39e2dbc3..8c6683ef28 100644 --- a/crates/chat-cli/build.rs +++ b/crates/chat-cli/build.rs @@ -83,6 +83,11 @@ fn write_plist() { fn main() { println!("cargo:rerun-if-changed=def.json"); + // Download feed.json if FETCH_FEED environment variable is set + if std::env::var("FETCH_FEED").is_ok() { + download_feed_json(); + } + #[cfg(target_os = "macos")] write_plist(); @@ -322,3 +327,58 @@ fn main() { // write an empty file to the output directory std::fs::write(format!("{}/mod.rs", outdir), pp).unwrap(); } + +/// Downloads the latest feed.json from the autocomplete repository. +/// This ensures official builds have the most up-to-date changelog information. +/// +/// # Errors +/// +/// Prints cargo warnings if: +/// - `curl` command is not available +/// - Network request fails +/// - File write operation fails +fn download_feed_json() { + use std::process::Command; + + println!("cargo:warning=Downloading latest feed.json from autocomplete repo..."); + + // Check if curl is available first + let curl_check = Command::new("curl").arg("--version").output(); + + if curl_check.is_err() { + panic!( + "curl command not found. Cannot download latest feed.json. Please install curl or build without FETCH_FEED=1 to use existing feed.json." + ); + } + + let output = Command::new("curl") + .args([ + "-H", + "Accept: application/vnd.github.v3.raw", + "-s", // silent + "-f", // fail on HTTP errors + "https://api.github.com/repos/aws/amazon-q-developer-cli-autocomplete/contents/feed.json", + ]) + .output(); + + match output { + Ok(result) if result.status.success() => { + if let Err(e) = std::fs::write("src/cli/feed.json", result.stdout) { + panic!("Failed to write feed.json: {}", e); + } else { + println!("cargo:warning=Successfully downloaded latest feed.json"); + } + }, + Ok(result) => { + let error_msg = if !result.stderr.is_empty() { + format!("HTTP error: {}", String::from_utf8_lossy(&result.stderr)) + } else { + "HTTP error occurred".to_string() + }; + panic!("Failed to download feed.json: {}", error_msg); + }, + Err(e) => { + panic!("Failed to execute curl: {}", e); + }, + } +} diff --git a/crates/chat-cli/src/cli/chat/cli/changelog.rs b/crates/chat-cli/src/cli/chat/cli/changelog.rs new file mode 100644 index 0000000000..5578c599c4 --- /dev/null +++ b/crates/chat-cli/src/cli/chat/cli/changelog.rs @@ -0,0 +1,23 @@ +use clap::Args; +use eyre::Result; + +use crate::cli::chat::{ + ChatError, + ChatSession, + ChatState, +}; +use crate::util::ui; + +#[derive(Debug, PartialEq, Args)] +pub struct ChangelogArgs {} + +impl ChangelogArgs { + pub async fn execute(self, session: &mut ChatSession) -> Result { + // Use the shared rendering function from util::ui + ui::render_changelog_content(&mut session.stderr).map_err(|e| ChatError::Std(std::io::Error::other(e)))?; + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + } +} diff --git a/crates/chat-cli/src/cli/chat/cli/mod.rs b/crates/chat-cli/src/cli/chat/cli/mod.rs index 4e0f38a3d4..1d095d1e4f 100644 --- a/crates/chat-cli/src/cli/chat/cli/mod.rs +++ b/crates/chat-cli/src/cli/chat/cli/mod.rs @@ -1,3 +1,4 @@ +pub mod changelog; pub mod clear; pub mod compact; pub mod context; @@ -16,6 +17,7 @@ pub mod todos; pub mod tools; pub mod usage; +use changelog::ChangelogArgs; use clap::Parser; use clear::ClearArgs; use compact::CompactArgs; @@ -75,6 +77,9 @@ pub enum SlashCommand { Tools(ToolsArgs), /// Create a new Github issue or make a feature request Issue(issue::IssueArgs), + /// View changelog for Amazon Q CLI + #[command(name = "changelog")] + Changelog(ChangelogArgs), /// View and retrieve prompts Prompts(PromptsArgs), /// View context hooks @@ -145,6 +150,7 @@ impl SlashCommand { skip_printing_tools: true, }) }, + Self::Changelog(args) => args.execute(session).await, Self::Prompts(args) => args.execute(session).await, Self::Hooks(args) => args.execute(session).await, Self::Usage(args) => args.execute(os, session).await, @@ -179,6 +185,7 @@ impl SlashCommand { Self::Compact(_) => "compact", Self::Tools(_) => "tools", Self::Issue(_) => "issue", + Self::Changelog(_) => "changelog", Self::Prompts(_) => "prompts", Self::Hooks(_) => "hooks", Self::Usage(_) => "usage", diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index ece167c23b..942653bf07 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -167,6 +167,7 @@ use crate::telemetry::{ use crate::util::{ MCP_SERVER_TOOL_DELIMITER, directories, + ui, }; const LIMIT_REACHED_TEXT: &str = color_print::cstr! { "You've used all your free requests for this month. You have two options: @@ -449,8 +450,11 @@ const WELCOME_TEXT: &str = color_print::cstr! {" const SMALL_SCREEN_WELCOME_TEXT: &str = color_print::cstr! {"Welcome to Amazon Q!"}; const RESUME_TEXT: &str = color_print::cstr! {"Picking up where we left off..."}; +// Maximum number of times to show the changelog announcement per version +const CHANGELOG_MAX_SHOW_COUNT: i64 = 2; + // Only show the model-related tip for now to make users aware of this feature. -const ROTATING_TIPS: [&str; 18] = [ +const ROTATING_TIPS: [&str; 19] = [ color_print::cstr! {"You can resume the last conversation from your current directory by launching with q chat --resume"}, color_print::cstr! {"Get notified whenever Q CLI finishes responding. @@ -484,6 +488,7 @@ const ROTATING_TIPS: [&str; 18] = [ color_print::cstr! {"Run /prompts to learn how to build & run repeatable workflows"}, color_print::cstr! {"Use /tangent or ctrl + t (customizable) to start isolated conversations ( ↯ ) that don't affect your main chat history"}, color_print::cstr! {"Ask me directly about my capabilities! Try questions like \"What can you do?\" or \"Can you save conversations?\""}, + color_print::cstr! {"Stay up to date with the latest features and improvements! Use /changelog to see what's new in Amazon Q CLI"}, ]; const GREETING_BREAK_POINT: usize = 80; @@ -1109,6 +1114,34 @@ impl ChatSession { Ok(()) } + + async fn show_changelog_announcement(&mut self, os: &mut Os) -> Result<()> { + let current_version = env!("CARGO_PKG_VERSION"); + let last_version = os.database.get_changelog_last_version()?; + let show_count = os.database.get_changelog_show_count()?.unwrap_or(0); + + // Check if version changed or if we haven't shown it max times yet + let should_show = match &last_version { + Some(last) if last == current_version => show_count < CHANGELOG_MAX_SHOW_COUNT, + _ => true, // New version or no previous version + }; + + if should_show { + // Use the shared rendering function + ui::render_changelog_content(&mut self.stderr)?; + + // Update the database entries + os.database.set_changelog_last_version(current_version)?; + let new_count = if last_version.as_deref() == Some(current_version) { + show_count + 1 + } else { + 1 + }; + os.database.set_changelog_show_count(new_count)?; + } + + Ok(()) + } } impl Drop for ChatSession { @@ -1255,6 +1288,9 @@ impl ChatSession { execute!(self.stderr, style::Print("\n"), style::SetForegroundColor(Color::Reset))?; } + // Check if we should show the whats-new announcement + self.show_changelog_announcement(os).await?; + if self.all_tools_trusted() { queue!( self.stderr, diff --git a/crates/chat-cli/src/cli/chat/prompt.rs b/crates/chat-cli/src/cli/chat/prompt.rs index 77fff472a1..e7addd8f32 100644 --- a/crates/chat-cli/src/cli/chat/prompt.rs +++ b/crates/chat-cli/src/cli/chat/prompt.rs @@ -92,6 +92,7 @@ pub const COMMANDS: &[&str] = &[ "/compact", "/compact help", "/usage", + "/changelog", "/save", "/load", "/subscribe", diff --git a/crates/chat-cli/src/cli/chat/tools/introspect.rs b/crates/chat-cli/src/cli/chat/tools/introspect.rs index 0b431ae4d3..9e42ec0f6e 100644 --- a/crates/chat-cli/src/cli/chat/tools/introspect.rs +++ b/crates/chat-cli/src/cli/chat/tools/introspect.rs @@ -71,6 +71,17 @@ impl Introspect { documentation.push_str("\n\n--- docs/todo-lists.md ---\n"); documentation.push_str(include_str!("../../../../../../docs/todo-lists.md")); + documentation.push_str("\n\n--- changelog (from feed.json) ---\n"); + // Include recent changelog entries from feed.json + let feed = crate::cli::feed::Feed::load(); + let recent_entries = feed.get_all_changelogs().into_iter().take(5).collect::>(); + for entry in recent_entries { + documentation.push_str(&format!("\n## {} ({})\n", entry.version, entry.date)); + for change in &entry.changes { + documentation.push_str(&format!("- {}: {}\n", change.change_type, change.description)); + } + } + documentation.push_str("\n\n--- CONTRIBUTING.md ---\n"); documentation.push_str(include_str!("../../../../../../CONTRIBUTING.md")); diff --git a/crates/chat-cli/src/cli/feed.json b/crates/chat-cli/src/cli/feed.json index 439e455621..7baece70c0 100644 --- a/crates/chat-cli/src/cli/feed.json +++ b/crates/chat-cli/src/cli/feed.json @@ -10,6 +10,250 @@ "hidden": true, "changes": [] }, + { + "type": "release", + "date": "2025-09-02", + "version": "1.15.0", + "title": "Version 1.15.0", + "changes": [ + { + "type": "added", + "description": "A new command `/experiment` for toggling experimental features - [#2711](https://github.com/aws/amazon-q-developer-cli/pull/2711)" + }, + { + "type": "added", + "description": "A new command `/agent generate` for generating agent config with Q - [#2690](https://github.com/aws/amazon-q-developer-cli/pull/2690)" + }, + { + "type": "added", + "description": "A new command `/tangent` for going on a tangent without context pollution - [#2634](https://github.com/aws/amazon-q-developer-cli/pull/2634)" + }, + { + "type": "added", + "description": "A new to-do list tool for handling complex multi-step prompts - [#2533](https://github.com/aws/amazon-q-developer-cli/pull/2533)" + }, + { + "type": "added", + "description": "Agent-scoped knowledge base and context-specific search - [#2647](https://github.com/aws/amazon-q-developer-cli/pull/2647)" + }, + { + "type": "added", + "description": "A new tool `introspect` that allows Q CLI to answer questions about itself - [#2677](https://github.com/aws/amazon-q-developer-cli/pull/2677)" + } + ] + }, + { + "type": "release", + "date": "2025-08-21", + "version": "1.14.1", + "title": "Version 1.14.1", + "changes": [ + { + "type": "fixed", + "description": "Tool permission issue in agent - [#2619](https://github.com/aws/amazon-q-developer-cli/pull/2619)" + }, + { + "type": "added", + "description": "MCP admin-level configuration with GetProfile - [#2639](https://github.com/aws/amazon-q-developer-cli/pull/2639)" + }, + { + "type": "added", + "description": "Wildcard pattern matching support for agent allowedTools - [#2612](https://github.com/aws/amazon-q-developer-cli/pull/2612)" + }, + { + "type": "added", + "description": "Agent hot swap capability - [#2637](https://github.com/aws/amazon-q-developer-cli/pull/2637)" + }, + { + "type": "fixed", + "description": "Agent default profile printing issue in `use_aws`, plus minor doc updates - [#2617](https://github.com/aws/amazon-q-developer-cli/pull/2617)" + }, + { + "type": "changed", + "description": "Knowledge beta improvements (phase 2): Refactored async_client and added BM25 support - [#2608](https://github.com/aws/amazon-q-developer-cli/pull/2608)" + } + ] + }, + { + "type": "release", + "date": "2025-08-15", + "version": "1.14.0", + "title": "Version 1.14.0", + "changes": [ + { + "type": "added", + "description": "Additional supported models in `q chat`, see `/model` - [#2419](https://github.com/aws/amazon-q-developer-cli/pull/2419)" + }, + { + "type": "added", + "description": "`--include` and `--exclude` flags for the `/knowledge add` command - [#2545](https://github.com/aws/amazon-q-developer-cli/pull/2545)" + }, + { + "type": "added", + "description": "Notifications on API retries - [#2607](https://github.com/aws/amazon-q-developer-cli/pull/2607)" + } + ] + }, + { + "type": "release", + "date": "2025-08-11", + "version": "1.13.3", + "title": "Version 1.13.3", + "changes": [ + { + "type": "added", + "description": "Support for setting denied shell commands with `toolsSettings.execute_bash.deniedCommands` - [#2512](https://github.com/aws/amazon-q-developer-cli/pull/2512)" + }, + { + "type": "added", + "description": "Support for setting denied AWS services with `toolsSettings.use_aws.deniedServices` - [#2512](https://github.com/aws/amazon-q-developer-cli/pull/2512)" + }, + { + "type": "added", + "description": "Support for setting denied file paths for `fs_read` and `fs_write` using `deniedPaths` in `toolsSettings` - [#2512](https://github.com/aws/amazon-q-developer-cli/pull/2512)" + }, + { + "type": "added", + "description": "Support for setting environment variables in MCP config - [#2241](https://github.com/aws/amazon-q-developer-cli/pull/2241)" + }, + { + "type": "fixed", + "description": "`q mcp add` from failing when the targeted `mcp.json` file does not exist - [#2561](https://github.com/aws/amazon-q-developer-cli/pull/2561)" + } + ] + }, + { + "type": "release", + "date": "2025-08-08", + "version": "1.13.2", + "title": "Version 1.13.2", + "changes": [ + { + "type": "added", + "description": "Regex matching for the `toolsSettings.execute_bash.allowedCommands` agent configuration - [#2483](https://github.com/aws/amazon-q-developer-cli/pull/2483)" + }, + { + "type": "added", + "description": "Support for workspace `mcp.json` configuration when `useLegacyMcpJson` is enabled - [#2516](https://github.com/aws/amazon-q-developer-cli/pull/2516)" + }, + { + "type": "added", + "description": "`/context show` to differentiate between agent and session context - [#2494](https://github.com/aws/amazon-q-developer-cli/pull/2494)" + }, + { + "type": "fixed", + "description": "Issues with `q mcp` subcommands failing - [#2475](https://github.com/aws/amazon-q-developer-cli/pull/2475)" + }, + { + "type": "fixed", + "description": "The `knowledge` tool always requiring permission - [#2501](https://github.com/aws/amazon-q-developer-cli/pull/2501)" + } + ] + }, + { + "type": "release", + "date": "2025-08-01", + "version": "1.13.1", + "title": "Version 1.13.1", + "changes": [ + { + "type": "added", + "description": "JSON schema support for the agent specification. Try it with `/agent create` - [#2440](https://github.com/aws/amazon-q-developer-cli/pull/2440)" + }, + { + "type": "deprecated", + "description": "The `/profile` command - [#2468](https://github.com/aws/amazon-q-developer-cli/pull/2468)" + }, + { + "type": "fixed", + "description": "Tool permissioning not being reset - [#2469](https://github.com/aws/amazon-q-developer-cli/pull/2469)" + }, + { + "type": "fixed", + "description": "An issue with history compaction not being applied on context overflow" + } + ] + }, + { + "type": "release", + "date": "2025-07-31", + "version": "1.13.0", + "title": "Version 1.13.0", + "changes": [ + { + "type": "added", + "description": "A new paradigm for working with `q chat` using agents. [See the documentation for more details](https://github.com/aws/amazon-q-developer-cli/blob/main/docs/SUMMARY.md)" + }, + { + "type": "added", + "description": "A new setting to disable markdown rendering in `qchat` with `chat.disableMarkdownRendering` - [#2223](https://github.com/aws/amazon-q-developer-cli/pull/2223)" + }, + { + "type": "added", + "description": "A new setting to disable markdown rendering in `qchat` with `chat.disableMarkdownRendering` - [#2236](https://github.com/aws/amazon-q-developer-cli/pull/2236)" + }, + { + "type": "fixed", + "description": "An issue with `/compact` failing for large initial messages - [#2375](https://github.com/aws/amazon-q-developer-cli/pull/2375)" + }, + { + "type": "fixed", + "description": "Images being removed from the conversation history - [#2333](https://github.com/aws/amazon-q-developer-cli/pull/2333)" + }, + { + "type": "fixed", + "description": "Code block detection for multi-line input - [#2384](https://github.com/aws/amazon-q-developer-cli/pull/2384)" + } + ] + }, + { + "type": "release", + "date": "2025-07-22", + "version": "1.12.7", + "title": "Version 1.12.7", + "changes": [ + { + "type": "fixed", + "description": "Issues with `q chat` requests not being cached correctly - [#461](https://github.com/aws/amazon-q-developer-cli-autocomplete/pull/461)" + } + ] + }, + { + "type": "release", + "date": "2025-07-17", + "version": "1.12.6", + "title": "Version 1.12.6", + "changes": [ + { + "type": "fixed", + "description": "Issues with read-only commands with the `execute_bash` tool - [#444](https://github.com/aws/amazon-q-developer-cli-autocomplete/pull/444)" + } + ] + }, + { + "type": "release", + "date": "2025-07-14", + "version": "1.12.5", + "title": "Version 1.12.5", + "changes": [ + { + "type": "added", + "description": "(Experimental) Support for Sigv4 authentication with `q chat`. Launch chat with the environment variable `AMAZON_Q_SIGV4=1` - [#207](https://github.com/aws/amazon-q-developer-cli-autocomplete/pull/207)" + }, + { + "type": "fixed", + "description": "An issue with authentication failing for long chat sessions - [#424](https://github.com/aws/amazon-q-developer-cli-autocomplete/pull/424)" + }, + { + "type": "fixed", + "description": "Issues with parsing `/compact` and `/editor` arguments - [#425](https://github.com/aws/amazon-q-developer-cli-autocomplete/pull/425)" + }, + { + "type": "changed", + "description": "`q chat` inline prompt hints to be disabled by default. To enable, run `q settings chat.enableHistoryHints true` - [#429](https://github.com/aws/amazon-q-developer-cli-autocomplete/pull/429)" + } + ] + }, { "type": "release", "date": "2025-07-09", diff --git a/crates/chat-cli/src/cli/mod.rs b/crates/chat-cli/src/cli/mod.rs index 6ac5bf6a0c..62beb50f48 100644 --- a/crates/chat-cli/src/cli/mod.rs +++ b/crates/chat-cli/src/cli/mod.rs @@ -2,7 +2,7 @@ mod agent; pub mod chat; mod debug; mod diagnostics; -mod feed; +pub mod feed; mod issue; mod mcp; mod settings; diff --git a/crates/chat-cli/src/database/mod.rs b/crates/chat-cli/src/database/mod.rs index 9b5a48ee10..b184ea6fef 100644 --- a/crates/chat-cli/src/database/mod.rs +++ b/crates/chat-cli/src/database/mod.rs @@ -274,6 +274,28 @@ impl Database { .and_then(|s| Uuid::from_str(&s).ok())) } + /// Get changelog last version from state table + pub fn get_changelog_last_version(&self) -> Result, DatabaseError> { + self.get_entry::(Table::State, "changelog.lastVersion") + } + + /// Set changelog last version in state table + pub fn set_changelog_last_version(&self, version: &str) -> Result<(), DatabaseError> { + self.set_entry(Table::State, "changelog.lastVersion", version)?; + Ok(()) + } + + /// Get changelog show count from state table + pub fn get_changelog_show_count(&self) -> Result, DatabaseError> { + self.get_entry::(Table::State, "changelog.showCount") + } + + /// Set changelog show count in state table + pub fn set_changelog_show_count(&self, count: i64) -> Result<(), DatabaseError> { + self.set_entry(Table::State, "changelog.showCount", count)?; + Ok(()) + } + /// Set the client ID used for telemetry requests. pub fn set_client_id(&mut self, client_id: Uuid) -> Result { self.set_json_entry(Table::State, CLIENT_ID_KEY, client_id.to_string()) diff --git a/crates/chat-cli/src/util/mod.rs b/crates/chat-cli/src/util/mod.rs index ac48310ad5..648d90cad1 100644 --- a/crates/chat-cli/src/util/mod.rs +++ b/crates/chat-cli/src/util/mod.rs @@ -7,6 +7,7 @@ pub mod spinner; pub mod system_info; #[cfg(test)] pub mod test; +pub mod ui; use std::fmt::Display; use std::io; diff --git a/crates/chat-cli/src/util/ui.rs b/crates/chat-cli/src/util/ui.rs new file mode 100644 index 0000000000..ee93f0abae --- /dev/null +++ b/crates/chat-cli/src/util/ui.rs @@ -0,0 +1,136 @@ +use std::io::Write; + +use crossterm::execute; +use crossterm::style::{ + self, + Attribute, + Color, +}; +use eyre::Result; + +use crate::cli::feed::Feed; + +/// Render changelog content from feed.json with manual formatting +pub fn render_changelog_content(output: &mut impl Write) -> Result<()> { + let feed = Feed::load(); + let recent_entries = feed.get_all_changelogs() + .into_iter() + .take(2) // Show last 2 releases + .collect::>(); + + execute!(output, style::Print("\n"))?; + + // Title + execute!( + output, + style::SetForegroundColor(Color::Magenta), + style::SetAttribute(Attribute::Bold), + style::Print("What's New in Amazon Q CLI\n\n"), + style::SetAttribute(Attribute::Reset), + style::SetForegroundColor(Color::Reset), + )?; + + // Render recent entries + for entry in recent_entries { + // Show version header + execute!( + output, + style::SetForegroundColor(Color::Blue), + style::SetAttribute(Attribute::Bold), + style::Print(format!("## {} ({})\n", entry.version, entry.date)), + style::SetAttribute(Attribute::Reset), + style::SetForegroundColor(Color::Reset), + )?; + + for change in &entry.changes { + // Process **bold** syntax and remove PR links + let cleaned_description = clean_pr_links(&change.description); + let processed_description = process_bold_text(&cleaned_description); + execute!(output, style::Print("• "))?; + print_with_bold(output, &processed_description)?; + execute!(output, style::Print("\n"))?; + } + execute!(output, style::Print("\n"))?; // Add spacing between versions + } + + execute!( + output, + style::Print("\nRun `/changelog` anytime to see the latest updates and features!\n\n") + )?; + Ok(()) +} + +/// Removes PR links and numbers from changelog descriptions to improve readability. +/// +/// Removes text matching the pattern " - [#NUMBER](URL)" from the end of descriptions. +/// +/// Example input: "A new feature - [#2711](https://github.com/aws/amazon-q-developer-cli/pull/2711)" +/// Example output: "A new feature" +fn clean_pr_links(text: &str) -> String { + // Remove PR links like " - [#2711](https://github.com/aws/amazon-q-developer-cli/pull/2711)" + if let Some(pos) = text.find(" - [#") { + text[..pos].to_string() + } else { + text.to_string() + } +} + +/// Processes text to identify **bold** markdown syntax and returns segments with formatting info. +/// +/// Returns a vector of tuples where each tuple contains: +/// - `String`: The text segment +/// - `bool`: Whether this segment should be rendered in bold +/// +/// Example input: "This is **bold** text" +/// Example output: [("This is ", false), ("bold", true), (" text", false)] +fn process_bold_text(text: &str) -> Vec<(String, bool)> { + let mut result = Vec::new(); + let mut current = String::new(); + let mut in_bold = false; + let mut chars = text.chars().peekable(); + + while let Some(ch) = chars.next() { + if ch == '*' && chars.peek() == Some(&'*') { + chars.next(); // consume second * + if !current.is_empty() { + result.push((current.clone(), in_bold)); + current.clear(); + } + in_bold = !in_bold; + } else { + current.push(ch); + } + } + + if !current.is_empty() { + result.push((current, in_bold)); + } + + result +} + +/// Renders text segments with proper bold formatting using crossterm. +/// +/// # Arguments +/// +/// * `output` - The writer to output formatted text to +/// * `segments` - Vector of (text, is_bold) tuples from `process_bold_text` +/// +/// # Errors +/// +/// Returns an error if writing to the output fails. +fn print_with_bold(output: &mut impl Write, segments: &[(String, bool)]) -> Result<()> { + for (text, is_bold) in segments { + if *is_bold { + execute!( + output, + style::SetAttribute(Attribute::Bold), + style::Print(text), + style::SetAttribute(Attribute::Reset), + )?; + } else { + execute!(output, style::Print(text))?; + } + } + Ok(()) +} diff --git a/scripts/build.py b/scripts/build.py index edbb8c27bc..2176045d74 100644 --- a/scripts/build.py +++ b/scripts/build.py @@ -87,6 +87,7 @@ def build_chat_bin( env={ **os.environ, **rust_env(release=release), + "FETCH_FEED": "1", # Always fetch latest feed.json for official builds }, ) From 39a09648ac4ecc034b9efcc3397d073e70dc4852 Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Thu, 11 Sep 2025 11:10:41 -0700 Subject: [PATCH 17/71] feat(mcp): enables remote mcp (#2836) --- Cargo.lock | 50 +++ Cargo.toml | 2 +- .../chat-cli/src/cli/chat/server_messenger.rs | 15 + crates/chat-cli/src/cli/chat/tool_manager.rs | 101 ++++- .../src/cli/chat/tools/custom_tool.rs | 34 +- crates/chat-cli/src/mcp_client/client.rs | 342 +++++++++++++++-- crates/chat-cli/src/mcp_client/messenger.rs | 8 + crates/chat-cli/src/mcp_client/mod.rs | 2 + crates/chat-cli/src/mcp_client/oauth_util.rs | 361 ++++++++++++++++++ crates/chat-cli/src/util/directories.rs | 8 + 10 files changed, 871 insertions(+), 52 deletions(-) create mode 100644 crates/chat-cli/src/mcp_client/oauth_util.rs diff --git a/Cargo.lock b/Cargo.lock index c7f5d0043a..869e29ea84 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4231,6 +4231,26 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +[[package]] +name = "oauth2" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51e219e79014df21a225b1860a479e2dcd7cbd9130f4defd4bd0e191ea31d67d" +dependencies = [ + "base64 0.21.7", + "chrono", + "getrandom 0.2.16", + "http 1.3.1", + "rand 0.8.5", + "reqwest", + "serde", + "serde_json", + "serde_path_to_error", + "sha2", + "thiserror 1.0.69", + "url", +] + [[package]] name = "objc-sys" version = "0.3.5" @@ -5243,18 +5263,24 @@ dependencies = [ "base64 0.22.1", "chrono", "futures", + "http 1.3.1", + "oauth2", "paste", "pin-project-lite", "process-wrap", + "reqwest", "rmcp-macros", "schemars", "serde", "serde_json", + "sse-stream", "thiserror 2.0.14", "tokio", "tokio-stream", "tokio-util", + "tower-service", "tracing", + "url", ] [[package]] @@ -5715,6 +5741,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_plain" version = "1.0.2" @@ -5945,6 +5981,19 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "sse-stream" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb4dc4d33c68ec1f27d386b5610a351922656e1fdf5c05bbaad930cd1519479a" +dependencies = [ + "bytes", + "futures-util", + "http-body 1.0.1", + "http-body-util", + "pin-project-lite", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -6876,6 +6925,7 @@ dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 0bbee837a0..99d56615c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -129,7 +129,7 @@ winnow = "=0.6.2" winreg = "0.55.0" schemars = "1.0.4" jsonschema = "0.30.0" -rmcp = { version = "0.6.0", features = ["client", "transport-child-process"] } +rmcp = { version = "0.6.3", features = ["client", "transport-sse-client-reqwest", "reqwest", "transport-streamable-http-client-reqwest", "transport-child-process", "tower", "auth"] } [workspace.lints.rust] future_incompatible = "warn" diff --git a/crates/chat-cli/src/cli/chat/server_messenger.rs b/crates/chat-cli/src/cli/chat/server_messenger.rs index a15bd8d696..be86d891a9 100644 --- a/crates/chat-cli/src/cli/chat/server_messenger.rs +++ b/crates/chat-cli/src/cli/chat/server_messenger.rs @@ -44,6 +44,10 @@ pub enum UpdateEventMessage { result: Result, peer: Option>, }, + OauthLink { + server_name: String, + link: String, + }, InitStart { server_name: String, }, @@ -146,6 +150,17 @@ impl Messenger for ServerMessenger { .map_err(|e| MessengerError::Custom(e.to_string()))?) } + async fn send_oauth_link(&self, link: String) -> MessengerResult { + Ok(self + .update_event_sender + .send(UpdateEventMessage::OauthLink { + server_name: self.server_name.clone(), + link, + }) + .await + .map_err(|e| MessengerError::Custom(e.to_string()))?) + } + async fn send_init_msg(&self) -> MessengerResult { Ok(self .update_event_sender diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index 7ca372779c..d2fd119ce1 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -90,6 +90,7 @@ use crate::database::settings::Setting; use crate::mcp_client::messenger::Messenger; use crate::mcp_client::{ InitializedMcpClient, + InnerService, McpClientService, }; use crate::os::Os; @@ -137,6 +138,11 @@ enum LoadingMsg { /// This is sent when all tool initialization is complete or when the application is shutting /// down. Terminate { still_loading: Vec }, + /// Indicates that a server requires user authentication and provides a sign-in link. + /// This message is used to notify the user about authentication requirements for MCP servers + /// that need OAuth or other authentication methods. Contains the server name and the + /// authentication message (typically a URL or instructions). + SignInNotice { name: String }, } /// Used to denote the loading outcome associated with a server. @@ -630,9 +636,14 @@ impl ToolManager { let server_name_clone = server_name.clone(); tokio::spawn(async move { match handle.await { - Ok(Ok(client)) => match client.cancel().await { - Ok(_) => info!("Server {server_name_clone} evicted due to agent swap"), - Err(e) => error!("Server {server_name_clone} has failed to cancel: {e}"), + Ok(Ok(client)) => { + let InnerService::Original(client) = client.inner_service else { + unreachable!(); + }; + match client.cancel().await { + Ok(_) => info!("Server {server_name_clone} evicted due to agent swap"), + Err(e) => error!("Server {server_name_clone} has failed to cancel: {e}"), + } }, Ok(Err(_)) | Err(_) => { error!("Server {server_name_clone} has failed to cancel"); @@ -640,9 +651,14 @@ impl ToolManager { } }); }, - InitializedMcpClient::Ready(running_service) => match running_service.cancel().await { - Ok(_) => info!("Server {server_name} evicted due to agent swap"), - Err(e) => error!("Server {server_name} has failed to cancel: {e}"), + InitializedMcpClient::Ready(running_service) => { + let InnerService::Original(client) = running_service.inner_service else { + unreachable!(); + }; + match client.cancel().await { + Ok(_) => info!("Server {server_name} evicted due to agent swap"), + Err(e) => error!("Server {server_name} has failed to cancel: {e}"), + } }, } } @@ -869,17 +885,16 @@ impl ToolManager { }); }; - let running_service = (*client.get_running_service().await.map_err(|e| ToolResult { + let running_service = client.get_running_service().await.map_err(|e| ToolResult { tool_use_id: value.id.clone(), content: vec![ToolResultContentBlock::Text(format!("Mcp tool client not ready: {e}"))], status: ToolResultStatus::Error, - })?) - .clone(); + })?; Tool::Custom(CustomTool { name: tool_name.to_owned(), server_name: server_name.to_owned(), - client: running_service, + client: running_service.clone(), params: value.args.as_object().cloned(), }) }, @@ -1170,6 +1185,15 @@ fn spawn_display_task( execute!(output, style::Print("\n"),)?; break; }, + LoadingMsg::SignInNotice { name } => { + execute!( + output, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + queue_oauth_message(&name, &mut output)?; + }, }, Err(_e) => { spinner_logo_idx = (spinner_logo_idx + 1) % SPINNER_CHARS.len(); @@ -1595,6 +1619,35 @@ fn spawn_orchestrator_task( }, UpdateEventMessage::ListResourcesResult { .. } => {}, UpdateEventMessage::ResourceTemplatesListResult { .. } => {}, + UpdateEventMessage::OauthLink { server_name, link } => { + let mut buf_writer = BufWriter::new(&mut *record_temp_buf); + let msg = eyre::eyre!(link); + let _ = queue_oauth_message_with_link(server_name.as_str(), &msg, &mut buf_writer); + let _ = buf_writer.flush(); + drop(buf_writer); + let record_str = String::from_utf8_lossy(record_temp_buf).to_string(); + let record = LoadingRecord::Warn(record_str.clone()); + load_record + .lock() + .await + .entry(server_name.clone()) + .and_modify(|load_record| { + load_record.push(record.clone()); + }) + .or_insert(vec![record]); + if let Some(sender) = &loading_status_sender { + let msg = LoadingMsg::SignInNotice { + name: server_name.clone(), + }; + if let Err(e) = sender.send(msg).await { + warn!( + "Error sending update message to display task: {:?}\nAssume display task has completed", + e + ); + loading_status_sender.take(); + } + } + }, UpdateEventMessage::InitStart { server_name, .. } => { pending.write().await.insert(server_name.clone()); loading_servers.insert(server_name, std::time::Instant::now()); @@ -1876,6 +1929,34 @@ fn queue_failure_message( )?) } +fn queue_oauth_message(name: &str, output: &mut impl Write) -> eyre::Result<()> { + Ok(queue!( + output, + style::SetForegroundColor(style::Color::Yellow), + style::Print("⚠ "), + style::SetForegroundColor(style::Color::Blue), + style::Print(name), + style::ResetColor, + style::Print(" requires OAuth authentication. Use /mcp to see the auth link\n"), + )?) +} + +fn queue_oauth_message_with_link(name: &str, msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> { + Ok(queue!( + output, + style::SetForegroundColor(style::Color::Yellow), + style::Print("⚠ "), + style::SetForegroundColor(style::Color::Blue), + style::Print(name), + style::ResetColor, + style::Print(" requires OAuth authentication. Follow this link to proceed: \n"), + style::SetForegroundColor(style::Color::Yellow), + style::Print(msg), + style::ResetColor, + style::Print("\n") + )?) +} + fn queue_warn_message(name: &str, msg: &eyre::Report, time: &str, output: &mut impl Write) -> eyre::Result<()> { Ok(queue!( output, diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index be2a7a8831..be5acd37fd 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -7,7 +7,6 @@ use crossterm::{ style, }; use eyre::Result; -use rmcp::RoleClient; use rmcp::model::CallToolRequestParam; use schemars::JsonSchema; use serde::{ @@ -23,14 +22,39 @@ use crate::cli::agent::{ }; use crate::cli::chat::CONTINUATION_LINE; use crate::cli::chat::token_counter::TokenCounter; +use crate::mcp_client::RunningService; use crate::os::Os; use crate::util::MCP_SERVER_TOOL_DELIMITER; use crate::util::pattern_matching::matches_any_pattern; -// TODO: support http transport type +#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)] +#[serde(rename_all = "camelCase")] +pub enum TransportType { + /// Standard input/output transport (default) + Stdio, + /// HTTP transport for web-based communication + Http, +} + +impl Default for TransportType { + fn default() -> Self { + Self::Stdio + } +} + #[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)] pub struct CustomToolConfig { + /// The type of transport the mcp server is expecting + #[serde(default)] + pub r#type: TransportType, + /// The URL endpoint for HTTP-based MCP servers + #[serde(default)] + pub url: String, + /// HTTP headers to include when communicating with HTTP-based MCP servers + #[serde(default)] + pub headers: HashMap, /// The command string used to initialize the mcp server + #[serde(default)] pub command: String, /// A list of arguments to be used to run the command with #[serde(default)] @@ -64,20 +88,20 @@ pub struct CustomTool { /// prefixed to the tool name when presented to the model for disambiguation. pub server_name: String, /// Reference to the client that manages communication with the tool's server process. - pub client: rmcp::Peer, + pub client: RunningService, /// Optional parameters to pass to the tool when invoking the method. /// Structured as a JSON value to accommodate various parameter types and structures. pub params: Option>, } impl CustomTool { - pub async fn invoke(&self, _os: &Os, _updates: impl Write) -> Result { + pub async fn invoke(&self, _os: &Os, _updates: &mut impl Write) -> Result { let params = CallToolRequestParam { name: Cow::from(self.name.clone()), arguments: self.params.clone(), }; - let resp = self.client.call_tool(params).await?; + let resp = self.client.call_tool(params.clone()).await?; if resp.is_error.is_none_or(|v| !v) { Ok(InvokeOutput { diff --git a/crates/chat-cli/src/mcp_client/client.rs b/crates/chat-cli/src/mcp_client/client.rs index 92c33f074c..c50d43a585 100644 --- a/crates/chat-cli/src/mcp_client/client.rs +++ b/crates/chat-cli/src/mcp_client/client.rs @@ -3,8 +3,13 @@ use std::collections::HashMap; use std::process::Stdio; use regex::Regex; +use reqwest::Client; use rmcp::model::{ + CallToolRequestParam, + CallToolResult, ErrorCode, + GetPromptRequestParam, + GetPromptResult, Implementation, InitializeRequestParam, ListPromptsResult, @@ -17,8 +22,10 @@ use rmcp::model::{ }; use rmcp::service::{ ClientInitializeError, + DynService, NotificationContext, }; +use rmcp::transport::auth::AuthClient; use rmcp::transport::{ ConfigureCommandExt, TokioChildProcess, @@ -31,14 +38,31 @@ use rmcp::{ ServiceExt, }; use tokio::io::AsyncReadExt as _; -use tokio::process::Command; +use tokio::process::{ + ChildStderr, + Command, +}; use tokio::task::JoinHandle; -use tracing::error; +use tracing::{ + debug, + error, + info, +}; use super::messenger::Messenger; +use super::oauth_util::HttpTransport; +use super::{ + AuthClientDropGuard, + OauthUtilError, + get_http_transport, +}; use crate::cli::chat::server_messenger::ServerMessenger; -use crate::cli::chat::tools::custom_tool::CustomToolConfig; +use crate::cli::chat::tools::custom_tool::{ + CustomToolConfig, + TransportType, +}; use crate::os::Os; +use crate::util::directories::DirectoryError; /// Fetches all pages of specified resources from a server macro_rules! paginated_fetch { @@ -118,9 +142,153 @@ pub enum McpClientError { JoinError(#[from] tokio::task::JoinError), #[error("Client has not finished initializing")] NotReady, + #[error(transparent)] + Directory(#[from] DirectoryError), + #[error(transparent)] + OauthUtil(#[from] OauthUtilError), + #[error(transparent)] + Parse(#[from] url::ParseError), + #[error(transparent)] + Auth(#[from] crate::auth::AuthError), +} + +macro_rules! decorate_with_auth_retry { + ($param_type:ty, $method_name:ident, $return_type:ty) => { + pub async fn $method_name(&self, param: $param_type) -> Result<$return_type, rmcp::ServiceError> { + let first_attempt = match &self.inner_service { + InnerService::Original(rs) => rs.$method_name(param.clone()).await, + InnerService::Peer(peer) => peer.$method_name(param.clone()).await, + }; + + match first_attempt { + Ok(result) => Ok(result), + Err(e) => { + // TODO: discern error type prior to retrying + // Not entirely sure what is thrown when auth is required + if let Some(auth_client) = self.get_auth_client() { + let refresh_result = auth_client.get_access_token().await; + match refresh_result { + Ok(_) => { + // Retry the operation after token refresh + match &self.inner_service { + InnerService::Original(rs) => rs.$method_name(param).await, + InnerService::Peer(peer) => peer.$method_name(param).await, + } + }, + Err(_) => { + // If refresh fails, return the original error + // Currently our event loop just does not allow us easy ways to + // reauth entirely once a session starts since this would mean + // swapping of transport (which also means swapping of client) + Err(e) + }, + } + } else { + // No auth client available, return original error + Err(e) + } + }, + } + } + }; +} + +/// Wrapper around rmcp service types to enable cloning. +/// +/// This exists because `rmcp::service::RunningService` is not directly cloneable as it is a +/// pointer type to `Peer`. This enum allows us to hold either the original service or its +/// peer representation, enabling cloning by converting the original service to a peer when needed. +pub enum InnerService { + Original(rmcp::service::RunningService>>), + Peer(rmcp::service::Peer), +} + +impl std::fmt::Debug for InnerService { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + InnerService::Original(_) => f.debug_tuple("Original").field(&"RunningService<..>").finish(), + InnerService::Peer(peer) => f.debug_tuple("Peer").field(peer).finish(), + } + } +} + +impl Clone for InnerService { + fn clone(&self) -> Self { + match self { + InnerService::Original(rs) => InnerService::Peer((*rs).clone()), + InnerService::Peer(peer) => InnerService::Peer(peer.clone()), + } + } +} + +/// A wrapper around MCP (Model Context Protocol) service instances that manages +/// authentication and enables cloning functionality. +/// +/// This struct holds either an original `RunningService` or its peer representation, +/// along with an optional authentication drop guard for managing OAuth tokens. +/// The authentication drop guard handles token lifecycle and cleanup when the +/// service is dropped. +/// +/// # Fields +/// * `inner_service` - The underlying MCP service instance (original or peer) +/// * `auth_dropguard` - Optional authentication manager for OAuth token handling +#[derive(Debug)] +pub struct RunningService { + pub inner_service: InnerService, + auth_dropguard: Option, +} + +impl Clone for RunningService { + fn clone(&self) -> Self { + let auth_dropguard = self.auth_dropguard.as_ref().map(|dg| { + let mut dg = dg.clone(); + dg.should_write = false; + dg + }); + + RunningService { + inner_service: self.inner_service.clone(), + auth_dropguard, + } + } } -pub type RunningService = rmcp::service::RunningService; +impl RunningService { + decorate_with_auth_retry!(CallToolRequestParam, call_tool, CallToolResult); + + decorate_with_auth_retry!(GetPromptRequestParam, get_prompt, GetPromptResult); + + pub fn get_auth_client(&self) -> Option> { + self.auth_dropguard.as_ref().map(|a| a.auth_client.clone()) + } +} + +pub type StdioTransport = (TokioChildProcess, Option); + +// TODO: add sse support (even though it's deprecated) +/// Represents the different transport mechanisms available for MCP (Model Context Protocol) +/// communication. +/// +/// This enum encapsulates the two primary ways to communicate with MCP servers: +/// - HTTP-based transport for remote servers +/// - Standard I/O transport for local process-based servers +pub enum Transport { + /// HTTP transport for communicating with remote MCP servers over network protocols. + /// Uses a streamable HTTP client with authentication support. + Http(HttpTransport), + /// Standard I/O transport for communicating with local MCP servers via child processes. + /// Communication happens through stdin/stdout pipes. + Stdio(StdioTransport), +} + +impl std::fmt::Debug for Transport { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Transport::Http(_) => f.debug_tuple("Http").field(&"HttpTransport").finish(), + Transport::Stdio(_) => f.debug_tuple("Stdio").field(&"TokioChildProcess").finish(), + } + } +} /// This struct implements the [Service] trait from rmcp. It is within this trait the logic of /// server driven data flow (i.e. requests and notifications that are sent from the server) are @@ -145,44 +313,98 @@ impl McpClientService { let os_clone = os.clone(); let handle: JoinHandle> = tokio::spawn(async move { - let CustomToolConfig { - command: command_as_str, - args, - env: config_envs, - .. - } = &mut self.config; - - let command = Command::new(command_as_str).configure(|cmd| { - if let Some(envs) = config_envs { - process_env_vars(envs, &os_clone.env); - cmd.envs(envs); - } - cmd.envs(std::env::vars()).args(args); - - #[cfg(not(windows))] - cmd.process_group(0); - }); - - let messenger_clone = self.messenger.duplicate(); + let messenger_clone = self.messenger.clone(); let server_name = self.server_name.clone(); + let backup_config = self.config.clone(); let result: Result<_, McpClientError> = async { - // Spawn the child process with stderr piped - let (tokio_child_process, child_stderr) = - TokioChildProcess::builder(command).stderr(Stdio::piped()).spawn()?; + let messenger_dup = messenger_clone.duplicate(); + let (service, stderr, auth_client) = match self.get_transport(&os_clone, &*messenger_dup).await? { + Transport::Stdio((child_process, stderr)) => { + let service = self + .into_dyn() + .serve::(child_process) + .await + .map_err(Box::new)?; + + (service, stderr, None) + }, + Transport::Http(http_transport) => { + match http_transport { + HttpTransport::WithAuth((transport, mut auth_dg)) => { + // The crate does not automatically refresh tokens when they expire. We + // would need to handle that here + let url = self.config.url.clone(); + let service = match self.into_dyn().serve(transport).await.map_err(Box::new) { + Ok(service) => service, + Err(e) if matches!(*e, ClientInitializeError::ConnectionClosed(_)) => { + debug!("## mcp: first hand shake attempt failed: {:?}", e); + let refresh_res = + auth_dg.auth_client.get_access_token().await; + let new_self = McpClientService::new( + server_name.clone(), + backup_config, + messenger_clone.clone(), + ); + + let new_transport = + get_http_transport(&os_clone, true, &url, Some(auth_dg.auth_client.clone()), &*messenger_dup).await?; + + match new_transport { + HttpTransport::WithAuth((new_transport, new_auth_dg)) => { + auth_dg.should_write = false; + auth_dg = new_auth_dg; + + match refresh_res { + Ok(_token) => { + new_self.into_dyn().serve(new_transport).await.map_err(Box::new)? + }, + Err(e) => { + error!("## mcp: token refresh attempt failed: {:?}", e); + info!("Retry for http transport failed {e}. Possible reauth needed"); + // This could be because the refresh token is expired, in which + // case we would need to have user go through the auth flow + // again + let new_transport = + get_http_transport(&os_clone, true, &url, None, &*messenger_dup).await?; + + match new_transport { + HttpTransport::WithAuth((new_transport, new_auth_dg)) => { + auth_dg = new_auth_dg; + auth_dg.should_write = false; + new_self.into_dyn().serve(new_transport).await.map_err(Box::new)? + }, + HttpTransport::WithoutAuth(new_transport) => { + new_self.into_dyn().serve(new_transport).await.map_err(Box::new)? + }, + } + }, + } + }, + HttpTransport::WithoutAuth(new_transport) => + new_self.into_dyn().serve(new_transport).await.map_err(Box::new)?, + } + }, + Err(e) => return Err(e.into()), + }; + + (service, None, Some(auth_dg)) + }, + HttpTransport::WithoutAuth(transport) => { + let service = self.into_dyn().serve(transport).await.map_err(Box::new)?; - // Attempt to serve the process - let service = self - .serve::(tokio_child_process) - .await - .map_err(Box::new)?; + (service, None, None) + }, + } + }, + }; - Ok((service, child_stderr)) + Ok((service, stderr, auth_client)) } .await; - let (service, child_stderr) = match result { - Ok((service, stderr)) => (service, stderr), + let (service, child_stderr, auth_dropguard) = match result { + Ok((service, stderr, auth_dg)) => (service, stderr, auth_dg), Err(e) => { let msg = e.to_string(); let error_data = ErrorData { @@ -262,12 +484,52 @@ impl McpClientService { } }); - Ok(service) + Ok(RunningService { + inner_service: InnerService::Original(service), + auth_dropguard, + }) }); Ok(InitializedMcpClient::Pending(handle)) } + async fn get_transport(&mut self, os: &Os, messenger: &dyn Messenger) -> Result { + // TODO: figure out what to do with headers + let CustomToolConfig { + r#type: transport_type, + url, + command: command_as_str, + args, + env: config_envs, + .. + } = &mut self.config; + + match transport_type { + TransportType::Stdio => { + let command = Command::new(command_as_str).configure(|cmd| { + if let Some(envs) = config_envs { + process_env_vars(envs, &os.env); + cmd.envs(envs); + } + cmd.envs(std::env::vars()).args(args); + + #[cfg(not(windows))] + cmd.process_group(0); + }); + + let (tokio_child_process, child_stderr) = + TokioChildProcess::builder(command).stderr(Stdio::piped()).spawn()?; + + Ok(Transport::Stdio((tokio_child_process, child_stderr))) + }, + TransportType::Http => { + let http_transport = get_http_transport(os, false, url, None, messenger).await?; + + Ok(Transport::Http(http_transport)) + }, + } + } + async fn on_logging_message( &self, params: LoggingMessageNotificationParam, @@ -389,12 +651,20 @@ impl Service for McpClientService { /// The solution chosen here is to instead spawn a task and have [Service::serve] called there and /// return the handle to said task, stored in the [InitializedMcpClient::Pending] variant. This /// enum is then flipped lazily (if applicable) when a [RunningService] is needed. -#[derive(Debug)] pub enum InitializedMcpClient { Pending(JoinHandle>), Ready(RunningService), } +impl std::fmt::Debug for InitializedMcpClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + InitializedMcpClient::Pending(_) => f.debug_tuple("Pending").field(&"JoinHandle<..>").finish(), + InitializedMcpClient::Ready(_) => f.debug_tuple("Ready").field(&"RunningService<..>").finish(), + } + } +} + impl InitializedMcpClient { pub async fn get_running_service(&mut self) -> Result<&RunningService, McpClientError> { match self { diff --git a/crates/chat-cli/src/mcp_client/messenger.rs b/crates/chat-cli/src/mcp_client/messenger.rs index e9202b7dae..40e9bc84ca 100644 --- a/crates/chat-cli/src/mcp_client/messenger.rs +++ b/crates/chat-cli/src/mcp_client/messenger.rs @@ -53,6 +53,10 @@ pub trait Messenger: std::fmt::Debug + Send + Sync + 'static { peer: Option>, ) -> MessengerResult; + /// Sends an OAuth authorization link to the consumer + /// This function is used to deliver OAuth links that users need to visit for authentication + async fn send_oauth_link(&self, link: String) -> MessengerResult; + /// Signals to the orchestrator that a server has started initializing async fn send_init_msg(&self) -> MessengerResult; @@ -107,6 +111,10 @@ impl Messenger for NullMessenger { Ok(()) } + async fn send_oauth_link(&self, _link: String) -> MessengerResult { + Ok(()) + } + async fn send_init_msg(&self) -> MessengerResult { Ok(()) } diff --git a/crates/chat-cli/src/mcp_client/mod.rs b/crates/chat-cli/src/mcp_client/mod.rs index 7bc6d76f5a..432e4421f3 100644 --- a/crates/chat-cli/src/mcp_client/mod.rs +++ b/crates/chat-cli/src/mcp_client/mod.rs @@ -1,4 +1,6 @@ pub mod client; pub mod messenger; +pub mod oauth_util; pub use client::*; +pub use oauth_util::*; diff --git a/crates/chat-cli/src/mcp_client/oauth_util.rs b/crates/chat-cli/src/mcp_client/oauth_util.rs new file mode 100644 index 0000000000..a705454b98 --- /dev/null +++ b/crates/chat-cli/src/mcp_client/oauth_util.rs @@ -0,0 +1,361 @@ +use std::net::SocketAddr; +use std::path::PathBuf; +use std::pin::Pin; +use std::str::FromStr; +use std::sync::Arc; + +use http::StatusCode; +use http_body_util::Full; +use hyper::Response; +use hyper::body::Bytes; +use hyper::server::conn::http1; +use hyper_util::rt::TokioIo; +use reqwest::Client; +use rmcp::serde_json; +use rmcp::transport::auth::{ + AuthClient, + OAuthState, + OAuthTokenResponse, +}; +use rmcp::transport::streamable_http_client::{ + StreamableHttpClientTransportConfig, + StreamableHttpClientWorker, +}; +use rmcp::transport::{ + AuthorizationManager, + StreamableHttpClientTransport, + WorkerTransport, +}; +use sha2::{ + Digest, + Sha256, +}; +use tokio::sync::oneshot::Sender; +use tokio_util::sync::CancellationToken; +use tracing::{ + debug, + error, + info, +}; +use url::Url; + +use super::messenger::Messenger; +use crate::os::Os; +use crate::util::directories::{ + DirectoryError, + get_mcp_auth_dir, +}; + +#[derive(Debug, thiserror::Error)] +pub enum OauthUtilError { + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + Parse(#[from] url::ParseError), + #[error(transparent)] + Auth(#[from] rmcp::transport::AuthError), + #[error(transparent)] + Serde(#[from] serde_json::Error), + #[error("Missing authorization manager")] + MissingAuthorizationManager, + #[error(transparent)] + OneshotRecv(#[from] tokio::sync::oneshot::error::RecvError), + #[error(transparent)] + Directory(#[from] DirectoryError), + #[error(transparent)] + Reqwest(#[from] reqwest::Error), +} + +/// A guard that automatically cancels the cancellation token when dropped. +/// This ensures that the OAuth loopback server is properly cleaned up +/// when the guard goes out of scope. +struct LoopBackDropGuard { + cancellation_token: CancellationToken, +} + +impl Drop for LoopBackDropGuard { + fn drop(&mut self) { + self.cancellation_token.cancel(); + } +} + +/// A guard that manages the lifecycle of an authenticated MCP client and automatically +/// persists OAuth credentials when dropped. +/// +/// This struct wraps an `AuthClient` and ensures that OAuth tokens are written to disk +/// when the guard goes out of scope, unless explicitly disabled via `should_write`. +/// This provides automatic credential caching for MCP server connections that require +/// OAuth authentication. +#[derive(Clone, Debug)] +pub struct AuthClientDropGuard { + pub should_write: bool, + pub cred_full_path: PathBuf, + pub auth_client: AuthClient, +} + +impl AuthClientDropGuard { + pub fn new(cred_full_path: PathBuf, auth_client: AuthClient) -> Self { + Self { + should_write: true, + cred_full_path, + auth_client, + } + } +} + +impl Drop for AuthClientDropGuard { + fn drop(&mut self) { + if !self.should_write { + return; + } + + let auth_client_clone = self.auth_client.clone(); + let path = self.cred_full_path.clone(); + + tokio::spawn(async move { + let Ok((client_id, cred)) = auth_client_clone.auth_manager.lock().await.get_credentials().await else { + error!("Failed to retrieve credentials in drop routine"); + return; + }; + let Some(cred) = cred else { + error!("Failed to retrieve credentials in drop routine from {client_id}"); + return; + }; + let Some(parent_path) = path.parent() else { + error!("Failed to retrieve parent path for token in drop routine for {client_id}"); + return; + }; + if let Err(e) = tokio::fs::create_dir_all(parent_path).await { + error!("Error making parent directory for token cache in drop routine for {client_id}: {e}"); + return; + } + + let serialized_cred = match serde_json::to_string_pretty(&cred) { + Ok(cred) => cred, + Err(e) => { + error!("Failed to serialize credentials for {client_id}: {e}"); + return; + }, + }; + if let Err(e) = tokio::fs::write(path, &serialized_cred).await { + error!("Error making writing token cache in drop routine: {e}"); + } + }); + } +} + +/// HTTP transport wrapper that handles both authenticated and non-authenticated MCP connections. +/// +/// This enum provides two variants for different authentication scenarios: +/// - `WithAuth`: Used when the MCP server requires OAuth authentication, containing both the +/// transport worker and an auth client guard that manages credential persistence +/// - `WithoutAuth`: Used for servers that don't require authentication, containing only the basic +/// transport worker +/// +/// The appropriate variant is automatically selected based on the server's response to +/// an initial probe request during transport creation. +pub enum HttpTransport { + WithAuth( + ( + WorkerTransport>>, + AuthClientDropGuard, + ), + ), + WithoutAuth(WorkerTransport>), +} + +pub async fn get_http_transport( + os: &Os, + delete_cache: bool, + url: &str, + auth_client: Option>, + messenger: &dyn Messenger, +) -> Result { + let cred_dir = get_mcp_auth_dir(os)?; + let url = Url::from_str(url)?; + let key = compute_key(&url); + let cred_full_path = cred_dir.join(format!("{key}.token.json")); + + if delete_cache && cred_full_path.is_file() { + tokio::fs::remove_file(&cred_full_path).await?; + } + + let reqwest_client = reqwest::Client::default(); + let probe_resp = reqwest_client.get(url.clone()).send().await?; + match probe_resp.status() { + StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { + debug!("## mcp: requires auth, auth client passed in is {:?}", auth_client); + let auth_client = match auth_client { + Some(auth_client) => auth_client, + None => { + let am = get_auth_manager(url.clone(), cred_full_path.clone(), messenger).await?; + AuthClient::new(reqwest_client, am) + }, + }; + let transport = + StreamableHttpClientTransport::with_client(auth_client.clone(), StreamableHttpClientTransportConfig { + uri: url.as_str().into(), + allow_stateless: false, + ..Default::default() + }); + + let auth_dg = AuthClientDropGuard::new(cred_full_path, auth_client); + debug!("## mcp: transport obtained"); + + Ok(HttpTransport::WithAuth((transport, auth_dg))) + }, + _ => { + let transport = StreamableHttpClientTransport::from_uri(url.as_str()); + + Ok(HttpTransport::WithoutAuth(transport)) + }, + } +} + +async fn get_auth_manager( + url: Url, + cred_full_path: PathBuf, + messenger: &dyn Messenger, +) -> Result { + let content_as_bytes = tokio::fs::read(&cred_full_path).await; + let mut oauth_state = OAuthState::new(url, None).await?; + + match content_as_bytes { + Ok(bytes) => { + let token = serde_json::from_slice::(&bytes)?; + + oauth_state.set_credentials("id", token).await?; + + debug!("## mcp: credentials set with cache"); + + Ok(oauth_state + .into_authorization_manager() + .ok_or(OauthUtilError::MissingAuthorizationManager)?) + }, + Err(e) => { + info!("Error reading cached credentials: {e}"); + debug!("## mcp: cache read failed. constructing auth manager from scratch"); + get_auth_manager_impl(oauth_state, messenger).await + }, + } +} + +async fn get_auth_manager_impl( + mut oauth_state: OAuthState, + messenger: &dyn Messenger, +) -> Result { + let socket_addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let cancellation_token = tokio_util::sync::CancellationToken::new(); + let (tx, rx) = tokio::sync::oneshot::channel::(); + + let (actual_addr, _dg) = make_svc(tx, socket_addr, cancellation_token).await?; + info!("Listening on local host port {:?} for oauth", actual_addr); + + oauth_state + .start_authorization(&["mcp", "profile", "email"], &format!("http://{}", actual_addr)) + .await?; + + let auth_url = oauth_state.get_authorization_url().await?; + _ = messenger.send_oauth_link(auth_url).await; + + let auth_code = rx.await?; + oauth_state.handle_callback(&auth_code).await?; + let am = oauth_state + .into_authorization_manager() + .ok_or(OauthUtilError::MissingAuthorizationManager)?; + + Ok(am) +} + +pub fn compute_key(rs: &Url) -> String { + let mut hasher = Sha256::new(); + let input = format!("{}{}", rs.origin().ascii_serialization(), rs.path()); + hasher.update(input.as_bytes()); + format!("{:x}", hasher.finalize()) +} + +async fn make_svc( + one_shot_sender: Sender, + socket_addr: SocketAddr, + cancellation_token: CancellationToken, +) -> Result<(SocketAddr, LoopBackDropGuard), OauthUtilError> { + #[derive(Clone, Debug)] + struct LoopBackForSendingAuthCode { + one_shot_sender: Arc>>>, + } + + #[derive(Debug, thiserror::Error)] + enum LoopBackError { + #[error("Poison error encountered: {0}")] + Poison(String), + #[error(transparent)] + Http(#[from] http::Error), + #[error("Failed to send auth code: {0}")] + Send(String), + } + + fn mk_response(s: String) -> Result>, LoopBackError> { + Ok(Response::builder().body(Full::new(Bytes::from(s)))?) + } + + impl hyper::service::Service> for LoopBackForSendingAuthCode { + type Error = LoopBackError; + type Future = Pin> + Send>>; + type Response = Response>; + + fn call(&self, req: hyper::Request) -> Self::Future { + let uri = req.uri(); + let query = uri.query().unwrap_or(""); + let params: std::collections::HashMap = + url::form_urlencoded::parse(query.as_bytes()).into_owned().collect(); + + let self_clone = self.clone(); + Box::pin(async move { + let code = params.get("code").cloned().unwrap_or_default(); + if let Some(sender) = self_clone + .one_shot_sender + .lock() + .map_err(|e| LoopBackError::Poison(e.to_string()))? + .take() + { + sender.send(code).map_err(LoopBackError::Send)?; + } + mk_response("Auth code sent".to_string()) + }) + } + } + + let listener = tokio::net::TcpListener::bind(socket_addr).await?; + let actual_addr = listener.local_addr()?; + let cancellation_token_clone = cancellation_token.clone(); + let dg = LoopBackDropGuard { + cancellation_token: cancellation_token_clone, + }; + + let loop_back = LoopBackForSendingAuthCode { + one_shot_sender: Arc::new(std::sync::Mutex::new(Some(one_shot_sender))), + }; + + // This is one and done + // This server only needs to last as long as it takes to send the auth code or to fail the auth + // flow + tokio::spawn(async move { + let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); + + tokio::select! { + _ = cancellation_token.cancelled() => { + info!("Oauth loopback server cancelled"); + }, + res = http1::Builder::new().serve_connection(io, loop_back) => { + if let Err(err) = res { + error!("Auth code loop back has failed: {:?}", err); + } + } + } + + Ok::<(), eyre::Report>(()) + }); + + Ok((actual_addr, dg)) +} diff --git a/crates/chat-cli/src/util/directories.rs b/crates/chat-cli/src/util/directories.rs index 66bfbcbdf7..89c6f3bc4e 100644 --- a/crates/chat-cli/src/util/directories.rs +++ b/crates/chat-cli/src/util/directories.rs @@ -241,6 +241,14 @@ pub fn agent_knowledge_dir(os: &Os, agent: Option<&crate::cli::Agent>) -> Result Ok(knowledge_bases_dir(os)?.join(unique_id)) } +/// The directory for MCP authentication cache +/// +/// This is the same directory used by IDE for SSO cache storage. +/// - All platforms: `$HOME/.aws/sso/cache` +pub fn get_mcp_auth_dir(os: &Os) -> Result { + Ok(home_dir(os)?.join(".aws").join("sso").join("cache")) +} + /// Generate a unique identifier for an agent based on its path and name fn generate_agent_unique_id(agent: &crate::cli::Agent) -> String { use std::collections::hash_map::DefaultHasher; From f27de810b487aaafa9ba819778993275feca1686 Mon Sep 17 00:00:00 2001 From: abhraina-aws Date: Thu, 11 Sep 2025 11:19:36 -0700 Subject: [PATCH 18/71] feat: Add /tangent tail to preserve the last tangent conversation (#2838) Users can now keep the final question and answer from tangent mode by using `/tangent tail` instead of `/tangent`. This preserves the last Q&A pair when returning to the main conversation, making it easy to retain helpful insights discovered during exploration. - `/tangent` - exits tangent mode (existing behavior unchanged) - `/tangent tail` - exits tangent mode but keeps the last Q&A pair This enables users to safely explore topics without losing the final valuable insight that could benefit their main conversation flow. --- crates/chat-cli/src/cli/chat/cli/tangent.rs | 170 ++++++++++++------- crates/chat-cli/src/cli/chat/conversation.rs | 116 +++++++++++++ docs/tangent-mode.md | 31 ++++ 3 files changed, 254 insertions(+), 63 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/cli/tangent.rs b/crates/chat-cli/src/cli/chat/cli/tangent.rs index 5bd04eb9bf..94184c4828 100644 --- a/crates/chat-cli/src/cli/chat/cli/tangent.rs +++ b/crates/chat-cli/src/cli/chat/cli/tangent.rs @@ -1,4 +1,7 @@ -use clap::Args; +use clap::{ + Args, + Subcommand, +}; use crossterm::execute; use crossterm::style::{ self, @@ -14,9 +17,33 @@ use crate::database::settings::Setting; use crate::os::Os; #[derive(Debug, PartialEq, Args)] -pub struct TangentArgs; +pub struct TangentArgs { + #[command(subcommand)] + pub subcommand: Option, +} + +#[derive(Debug, PartialEq, Subcommand)] +pub enum TangentSubcommand { + /// Exit tangent mode and keep the last conversation entry (user question + assistant response) + Tail, +} impl TangentArgs { + async fn send_tangent_telemetry(os: &Os, session: &ChatSession, duration_seconds: i64) { + if let Err(err) = os + .telemetry + .send_tangent_mode_session( + &os.database, + session.conversation.conversation_id().to_string(), + crate::telemetry::TelemetryResult::Succeeded, + crate::telemetry::core::TangentModeSessionArgs { duration_seconds }, + ) + .await + { + tracing::warn!(?err, "Failed to send tangent mode session telemetry"); + } + } + pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { // Check if tangent mode is enabled if !os @@ -35,69 +62,86 @@ impl TangentArgs { skip_printing_tools: true, }); } - if session.conversation.is_in_tangent_mode() { - // Get duration before exiting tangent mode - let duration_seconds = session.conversation.get_tangent_duration_seconds().unwrap_or(0); - - session.conversation.exit_tangent_mode(); - - // Send telemetry for tangent mode session - if let Err(err) = os - .telemetry - .send_tangent_mode_session( - &os.database, - session.conversation.conversation_id().to_string(), - crate::telemetry::TelemetryResult::Succeeded, - crate::telemetry::core::TangentModeSessionArgs { duration_seconds }, - ) - .await - { - tracing::warn!(?err, "Failed to send tangent mode session telemetry"); - } - execute!( - session.stderr, - style::SetForegroundColor(Color::DarkGrey), - style::Print("Restored conversation from checkpoint ("), - style::SetForegroundColor(Color::Yellow), - style::Print("↯"), - style::SetForegroundColor(Color::DarkGrey), - style::Print("). - Returned to main conversation.\n"), - style::SetForegroundColor(Color::Reset) - )?; - } else { - session.conversation.enter_tangent_mode(); - - // Get the configured tangent mode key for display - let tangent_key_char = match os - .database - .settings - .get_string(crate::database::settings::Setting::TangentModeKey) - { - Some(key) if key.len() == 1 => key.chars().next().unwrap_or('t'), - _ => 't', // Default to 't' if setting is missing or invalid - }; - let tangent_key_display = format!("ctrl + {}", tangent_key_char.to_lowercase()); + match self.subcommand { + Some(TangentSubcommand::Tail) => { + if session.conversation.is_in_tangent_mode() { + let duration_seconds = session.conversation.get_tangent_duration_seconds().unwrap_or(0); + session.conversation.exit_tangent_mode_with_tail(); + Self::send_tangent_telemetry(os, session, duration_seconds).await; - execute!( - session.stderr, - style::SetForegroundColor(Color::DarkGrey), - style::Print("Created a conversation checkpoint ("), - style::SetForegroundColor(Color::Yellow), - style::Print("↯"), - style::SetForegroundColor(Color::DarkGrey), - style::Print("). Use "), - style::SetForegroundColor(Color::Green), - style::Print(&tangent_key_display), - style::SetForegroundColor(Color::DarkGrey), - style::Print(" or "), - style::SetForegroundColor(Color::Green), - style::Print("/tangent"), - style::SetForegroundColor(Color::DarkGrey), - style::Print(" to restore the conversation later.\n"), - style::Print("Note: this functionality is experimental and may change or be removed in the future.\n"), - style::SetForegroundColor(Color::Reset) - )?; + execute!( + session.stderr, + style::SetForegroundColor(Color::DarkGrey), + style::Print("Restored conversation from checkpoint ("), + style::SetForegroundColor(Color::Yellow), + style::Print("↯"), + style::SetForegroundColor(Color::DarkGrey), + style::Print(") with last conversation entry preserved.\n"), + style::SetForegroundColor(Color::Reset) + )?; + } else { + execute!( + session.stderr, + style::SetForegroundColor(Color::Red), + style::Print("You need to be in tangent mode to use tail.\n"), + style::SetForegroundColor(Color::Reset) + )?; + } + }, + None => { + if session.conversation.is_in_tangent_mode() { + let duration_seconds = session.conversation.get_tangent_duration_seconds().unwrap_or(0); + session.conversation.exit_tangent_mode(); + Self::send_tangent_telemetry(os, session, duration_seconds).await; + + execute!( + session.stderr, + style::SetForegroundColor(Color::DarkGrey), + style::Print("Restored conversation from checkpoint ("), + style::SetForegroundColor(Color::Yellow), + style::Print("↯"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("). - Returned to main conversation.\n"), + style::SetForegroundColor(Color::Reset) + )?; + } else { + session.conversation.enter_tangent_mode(); + + // Get the configured tangent mode key for display + let tangent_key_char = match os + .database + .settings + .get_string(crate::database::settings::Setting::TangentModeKey) + { + Some(key) if key.len() == 1 => key.chars().next().unwrap_or('t'), + _ => 't', // Default to 't' if setting is missing or invalid + }; + let tangent_key_display = format!("ctrl + {}", tangent_key_char.to_lowercase()); + + execute!( + session.stderr, + style::SetForegroundColor(Color::DarkGrey), + style::Print("Created a conversation checkpoint ("), + style::SetForegroundColor(Color::Yellow), + style::Print("↯"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("). Use "), + style::SetForegroundColor(Color::Green), + style::Print(&tangent_key_display), + style::SetForegroundColor(Color::DarkGrey), + style::Print(" or "), + style::SetForegroundColor(Color::Green), + style::Print("/tangent"), + style::SetForegroundColor(Color::DarkGrey), + style::Print(" to restore the conversation later.\n"), + style::Print( + "Note: this functionality is experimental and may change or be removed in the future.\n" + ), + style::SetForegroundColor(Color::Reset) + )?; + } + }, } Ok(ChatState::PromptUser { diff --git a/crates/chat-cli/src/cli/chat/conversation.rs b/crates/chat-cli/src/cli/chat/conversation.rs index 8ea506f929..50025f229d 100644 --- a/crates/chat-cli/src/cli/chat/conversation.rs +++ b/crates/chat-cli/src/cli/chat/conversation.rs @@ -269,6 +269,27 @@ impl ConversationState { } } + /// Exit tangent mode and preserve the last conversation entry (user + assistant) + pub fn exit_tangent_mode_with_tail(&mut self) { + if let Some(checkpoint) = self.tangent_state.take() { + // Capture the last history entry from tangent conversation if it exists + // and if it's different from what was in the main conversation + let last_entry = if self.history.len() > checkpoint.main_history.len() { + self.history.back().cloned() + } else { + None // No new entries in tangent mode + }; + + // Restore from checkpoint + self.restore_from_checkpoint(checkpoint); + + // Add the last entry if it exists + if let Some(entry) = last_entry { + self.history.push_back(entry); + } + } + } + /// Appends a collection prompts into history and returns the last message in the collection. /// It asserts that the collection ends with a prompt that assumes the role of user. pub fn append_prompts(&mut self, mut prompts: VecDeque) -> Option { @@ -1568,4 +1589,99 @@ mod tests { // No duration when not in tangent mode assert!(conversation.get_tangent_duration_seconds().is_none()); } + + #[tokio::test] + async fn test_tangent_mode_with_tail() { + let mut os = Os::new().await.unwrap(); + let agents = Agents::default(); + let mut tool_manager = ToolManager::default(); + let mut conversation = ConversationState::new( + "test_conv_id", + agents, + tool_manager.load_tools(&mut os, &mut vec![]).await.unwrap(), + tool_manager, + None, + &os, + false, + ) + .await; + + // Add main conversation + conversation.set_next_user_message("main question".to_string()).await; + conversation.push_assistant_message( + &mut os, + AssistantMessage::new_response(None, "main response".to_string()), + None, + ); + + let main_history_len = conversation.history.len(); + + // Enter tangent mode + conversation.enter_tangent_mode(); + assert!(conversation.is_in_tangent_mode()); + + // Add tangent conversation + conversation.set_next_user_message("tangent question".to_string()).await; + conversation.push_assistant_message( + &mut os, + AssistantMessage::new_response(None, "tangent response".to_string()), + None, + ); + + // Exit tangent mode with tail + conversation.exit_tangent_mode_with_tail(); + assert!(!conversation.is_in_tangent_mode()); + + // Should have main conversation + last assistant message from tangent + assert_eq!(conversation.history.len(), main_history_len + 1); + + // Check that the last message is the tangent response + if let Some(entry) = conversation.history.back() { + assert_eq!(entry.assistant.content(), "tangent response"); + } else { + panic!("Expected history entry at the end"); + } + } + + #[tokio::test] + async fn test_tangent_mode_with_tail_edge_cases() { + let mut os = Os::new().await.unwrap(); + let agents = Agents::default(); + let mut tool_manager = ToolManager::default(); + let mut conversation = ConversationState::new( + "test_conv_id", + agents, + tool_manager.load_tools(&mut os, &mut vec![]).await.unwrap(), + tool_manager, + None, + &os, + false, + ) + .await; + + // Add main conversation + conversation.set_next_user_message("main question".to_string()).await; + conversation.push_assistant_message( + &mut os, + AssistantMessage::new_response(None, "main response".to_string()), + None, + ); + + let main_history_len = conversation.history.len(); + + // Test: Enter tangent mode but don't add any new conversation + conversation.enter_tangent_mode(); + assert!(conversation.is_in_tangent_mode()); + + // Exit tangent mode with tail (should not add anything since no new entries) + conversation.exit_tangent_mode_with_tail(); + assert!(!conversation.is_in_tangent_mode()); + + // Should have same length as before (no new entries added) + assert_eq!(conversation.history.len(), main_history_len); + + // Test: Call exit_tangent_mode_with_tail when not in tangent mode (should do nothing) + conversation.exit_tangent_mode_with_tail(); + assert_eq!(conversation.history.len(), main_history_len); + } } diff --git a/docs/tangent-mode.md b/docs/tangent-mode.md index 0ee785c99b..2bdbc346aa 100644 --- a/docs/tangent-mode.md +++ b/docs/tangent-mode.md @@ -32,6 +32,13 @@ Use `/tangent` or Ctrl+T again: Restored conversation from checkpoint (↯). - Returned to main conversation. ``` +### Exit Tangent Mode with Tail +Use `/tangent tail` to preserve the last conversation entry (question + answer): +``` +↯ > /tangent tail +Restored conversation from checkpoint (↯) with last conversation entry preserved. +``` + ## Usage Examples ### Example 1: Exploring Alternatives @@ -93,6 +100,29 @@ Restored conversation from checkpoint (↯). > Here's my query: SELECT * FROM orders... ``` +### Example 4: Keeping Useful Information +``` +> Help me debug this Python error + +I can help you debug that. Could you share the error message? + +> /tangent +Created a conversation checkpoint (↯). + +↯ > What are the most common Python debugging techniques? + +Here are the most effective Python debugging techniques: +1. Use print statements strategically +2. Leverage the Python debugger (pdb)... + +↯ > /tangent tail +Restored conversation from checkpoint (↯) with last conversation entry preserved. + +> Here's my error: TypeError: unsupported operand type(s)... + +# The preserved entry (question + answer about debugging techniques) is now part of main conversation +``` + ## Configuration ### Keyboard Shortcut @@ -131,6 +161,7 @@ q settings introspect.tangentMode true 2. **Return promptly** - Don't forget you're in tangent mode 3. **Use for clarification** - Perfect for "wait, what does X mean?" questions 4. **Experiment safely** - Test ideas without affecting main conversation +5. **Use `/tangent tail`** - When both the tangent question and answer are useful for main conversation ## Limitations From 178ccf1a5c7063ef5c28762c9a0da0fba7dfba28 Mon Sep 17 00:00:00 2001 From: abhraina-aws Date: Thu, 11 Sep 2025 11:20:18 -0700 Subject: [PATCH 19/71] feat: add daily heartbeat telemetry (#2839) Tracks daily active users by sending amazonqcli_dailyHeartbeat event once per day. Uses fail-closed logic to prevent spam during database errors. --- crates/chat-cli/src/cli/mod.rs | 5 +++++ crates/chat-cli/src/database/mod.rs | 21 +++++++++++++++++++++ crates/chat-cli/src/telemetry/core.rs | 10 ++++++++++ crates/chat-cli/src/telemetry/mod.rs | 4 ++++ crates/chat-cli/telemetry_definitions.json | 8 ++++++++ 5 files changed, 48 insertions(+) diff --git a/crates/chat-cli/src/cli/mod.rs b/crates/chat-cli/src/cli/mod.rs index 62beb50f48..1e384cf63e 100644 --- a/crates/chat-cli/src/cli/mod.rs +++ b/crates/chat-cli/src/cli/mod.rs @@ -144,6 +144,11 @@ impl RootSubcommand { ); } + // Daily heartbeat check + if os.database.should_send_heartbeat() && os.telemetry.send_daily_heartbeat().is_ok() { + os.database.record_heartbeat_sent().ok(); + } + // Send executed telemetry. if self.valid_for_telemetry() { os.telemetry diff --git a/crates/chat-cli/src/database/mod.rs b/crates/chat-cli/src/database/mod.rs index b184ea6fef..80c667b8a8 100644 --- a/crates/chat-cli/src/database/mod.rs +++ b/crates/chat-cli/src/database/mod.rs @@ -61,6 +61,7 @@ const IDC_REGION_KEY: &str = "auth.idc.region"; // We include this key to remove for backwards compatibility const CUSTOMIZATION_STATE_KEY: &str = "api.selectedCustomization"; const PROFILE_MIGRATION_KEY: &str = "profile.Migrated"; +const HEARTBEAT_DATE_KEY: &str = "telemetry.lastHeartbeatDate"; const MIGRATIONS: &[Migration] = migrations![ "000_migration_table", @@ -333,6 +334,26 @@ impl Database { self.set_entry(Table::State, PROFILE_MIGRATION_KEY, true) } + /// Check if daily heartbeat should be sent + pub fn should_send_heartbeat(&self) -> bool { + use chrono::Utc; + let today = Utc::now().format("%Y-%m-%d").to_string(); + + match self.get_entry::(Table::State, HEARTBEAT_DATE_KEY) { + Ok(Some(last_date)) => last_date != today, + Ok(None) => true, // First time - definitely send + Err(_) => false, // Database error - don't send (might have already sent) + } + } + + /// Record that heartbeat was sent today + pub fn record_heartbeat_sent(&self) -> Result<(), DatabaseError> { + use chrono::Utc; + let today = Utc::now().format("%Y-%m-%d").to_string(); + self.set_entry(Table::State, HEARTBEAT_DATE_KEY, today)?; + Ok(()) + } + // /// Get the model id used for last conversation state. // pub fn get_last_used_model_id(&self) -> Result, DatabaseError> { // self.get_json_entry::(Table::State, LAST_USED_MODEL_ID) diff --git a/crates/chat-cli/src/telemetry/core.rs b/crates/chat-cli/src/telemetry/core.rs index 48d23bbef2..58091ae8ff 100644 --- a/crates/chat-cli/src/telemetry/core.rs +++ b/crates/chat-cli/src/telemetry/core.rs @@ -19,6 +19,7 @@ use crate::telemetry::definitions::metrics::{ AmazonqMessageResponseError, AmazonqProfileState, AmazonqStartChat, + AmazonqcliDailyHeartbeat, CodewhispererterminalAddChatMessage, CodewhispererterminalAgentConfigInit, CodewhispererterminalAgentContribution, @@ -499,6 +500,14 @@ impl Event { } .into_metric_datum(), ), + EventType::DailyHeartbeat {} => Some( + AmazonqcliDailyHeartbeat { + create_time: self.created_time, + value: None, + source: None, + } + .into_metric_datum(), + ), } } } @@ -689,6 +698,7 @@ pub enum EventType { message_id: Option, context_file_length: Option, }, + DailyHeartbeat {}, } #[derive(Debug)] diff --git a/crates/chat-cli/src/telemetry/mod.rs b/crates/chat-cli/src/telemetry/mod.rs index 0b0a535a5f..90a9faa8b2 100644 --- a/crates/chat-cli/src/telemetry/mod.rs +++ b/crates/chat-cli/src/telemetry/mod.rs @@ -235,6 +235,10 @@ impl TelemetryThread { Ok(self.tx.send(Event::new(EventType::UserLoggedIn {}))?) } + pub fn send_daily_heartbeat(&self) -> Result<(), TelemetryError> { + Ok(self.tx.send(Event::new(EventType::DailyHeartbeat {}))?) + } + pub async fn send_cli_subcommand_executed( &self, database: &Database, diff --git a/crates/chat-cli/telemetry_definitions.json b/crates/chat-cli/telemetry_definitions.json index 3e52e5d3b2..5dac4ec712 100644 --- a/crates/chat-cli/telemetry_definitions.json +++ b/crates/chat-cli/telemetry_definitions.json @@ -530,6 +530,14 @@ { "type": "statusCode", "required": false }, { "type": "codewhispererterminal_clientApplication" } ] + }, + { + "name": "amazonqcli_dailyHeartbeat", + "description": "Daily heartbeat to track active CLI usage", + "unit": "None", + "metadata": [ + { "type": "source", "required": false } + ] } ] } From d9c34dc7878e93d9ccd0dbb01de6f4b8eb5808c9 Mon Sep 17 00:00:00 2001 From: Brandon Kiser <51934408+brandonskiser@users.noreply.github.com> Date: Thu, 11 Sep 2025 11:25:22 -0700 Subject: [PATCH 20/71] fix: update dangerous patterns for execute bash to include $ (#2811) --- crates/chat-cli/src/cli/chat/tools/execute/mod.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/chat-cli/src/cli/chat/tools/execute/mod.rs b/crates/chat-cli/src/cli/chat/tools/execute/mod.rs index 388c48476b..2dce20fbb5 100644 --- a/crates/chat-cli/src/cli/chat/tools/execute/mod.rs +++ b/crates/chat-cli/src/cli/chat/tools/execute/mod.rs @@ -70,7 +70,7 @@ impl ExecuteCommand { let Some(args) = shlex::split(&self.command) else { return true; }; - const DANGEROUS_PATTERNS: &[&str] = &["<(", "$(", "`", ">", "&&", "||", "&", ";", "${", "\n", "\r", "IFS"]; + const DANGEROUS_PATTERNS: &[&str] = &["<(", "$(", "`", ">", "&&", "||", "&", ";", "$", "\n", "\r", "IFS"]; if args .iter() @@ -328,6 +328,7 @@ mod tests { (r#"find / -fprintf "/path/to/file" -quit"#, true), (r"find . -${t}exec touch asdf \{\} +", true), (r"find . -${t:=exec} touch asdf2 \{\} +", true), + (r#"find /tmp -name "*" -exe$9c touch /tmp/find_result {} +"#, true), // `grep` command arguments ("echo 'test data' | grep -P '(?{system(\"date\")})'", true), ("echo 'test data' | grep --perl-regexp '(?{system(\"date\")})'", true), From 6ba1416cd02517e820beaf9b906241819db20645 Mon Sep 17 00:00:00 2001 From: yayami3 <116920988+yayami3@users.noreply.github.com> Date: Fri, 12 Sep 2025 03:26:22 +0900 Subject: [PATCH 21/71] docs: fix local agent directory path in documentation (#2749) * docs: fix local agent directory path - Fix local agent path from .aws/amazonq/cli-agents/ to .amazonq/cli-agents/ - Global paths (~/.aws/amazonq/cli-agents/) remain correct - Aligns documentation with source code implementation * fix: correct workspace agent path in /agent help message The help message for the /agent command incorrectly showed the workspace agent path as 'cwd/.aws/amazonq/cli-agents' when it should be 'cwd/.amazonq/cli-agents' (without the .aws directory). This fix aligns the help text with the actual WORKSPACE_AGENT_DIR_RELATIVE constant defined in directories.rs. --- crates/chat-cli/src/cli/chat/cli/profile.rs | 2 +- docs/agent-file-locations.md | 2 +- docs/default-agent-behavior.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/cli/profile.rs b/crates/chat-cli/src/cli/chat/cli/profile.rs index 2edca042f3..233ab7888d 100644 --- a/crates/chat-cli/src/cli/chat/cli/profile.rs +++ b/crates/chat-cli/src/cli/chat/cli/profile.rs @@ -55,7 +55,7 @@ use crate::util::{ Notes • Launch q chat with a specific agent with --agent -• Construct an agent under ~/.aws/amazonq/cli-agents/ (accessible globally) or cwd/.aws/amazonq/cli-agents (accessible in workspace) +• Construct an agent under ~/.aws/amazonq/cli-agents/ (accessible globally) or cwd/.amazonq/cli-agents (accessible in workspace) • See example config under global directory • Set default agent to assume with settings by running \"q settings chat.defaultAgent agent_name\" • Each agent maintains its own set of context and customizations" diff --git a/docs/agent-file-locations.md b/docs/agent-file-locations.md index c09ed9311b..b5be46e1b3 100644 --- a/docs/agent-file-locations.md +++ b/docs/agent-file-locations.md @@ -47,7 +47,7 @@ These agents are available from any directory when using Q CLI. When Q CLI looks for an agent, it follows this precedence order: -1. **Local first**: Checks `.aws/amazonq/cli-agents/` in the current working directory +1. **Local first**: Checks `.amazonq/cli-agents/` in the current working directory 2. **Global fallback**: If not found locally, checks `~/.aws/amazonq/cli-agents/` in the home directory ## Naming Conflicts diff --git a/docs/default-agent-behavior.md b/docs/default-agent-behavior.md index 0510906a60..727de2e93f 100644 --- a/docs/default-agent-behavior.md +++ b/docs/default-agent-behavior.md @@ -96,7 +96,7 @@ q chat --agent specialized-agent ### Create a Custom Default You can create your own "default" agent by placing an agent file with the name `q_cli_default` in either: -- `.aws/amazonq/cli-agents/` (local) +- `.amazonq/cli-agents/` (local) - `~/.aws/amazonq/cli-agents/` (global) This will override the built-in default agent configuration. From 22783a8265546129fabaee52dece83f9f2d1088c Mon Sep 17 00:00:00 2001 From: Bart van Bragt Date: Thu, 11 Sep 2025 20:33:55 +0200 Subject: [PATCH 22/71] Invalid pointer to trace log location (#2734) --- crates/chat-cli/src/cli/chat/tool_manager.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index d2fd119ce1..0d4bad09c8 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -1923,7 +1923,7 @@ fn queue_failure_message( style::Print(fail_load_msg), style::Print("\n"), style::Print(format!( - " - run with Q_LOG_LEVEL=trace and see $TMPDIR/{CHAT_BINARY_NAME} for detail\n" + " - run with Q_LOG_LEVEL=trace and see $TMPDIR/qlog/{CHAT_BINARY_NAME}.log for detail\n" )), style::ResetColor, )?) From fcd52f9e07dc77cc66e935690125a13cc7357980 Mon Sep 17 00:00:00 2001 From: Ennio Pastore Date: Thu, 11 Sep 2025 20:34:36 +0200 Subject: [PATCH 23/71] Fix bug README.md (#2569) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 3554b6c239..ecfe529a4e 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ cargo install typos-cli ## Project Layout -- [`chat_cli`](crates/chat_cli/) - the `q` CLI, allows users to interface with Amazon Q Developer from +- [`chat_cli`](crates/chat-cli/) - the `q` CLI, allows users to interface with Amazon Q Developer from the command line - [`scripts/`](scripts/) - Contains ops and build related scripts - [`crates/`](crates/) - Contains all rust crates From b801481fcdef88a9fd246803f2cae051b3f715e6 Mon Sep 17 00:00:00 2001 From: Kenneth Sanchez V Date: Thu, 11 Sep 2025 11:44:41 -0700 Subject: [PATCH 24/71] fix: Layout fix (#2798) --- crates/chat-cli/src/cli/chat/cli/experiment.rs | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/crates/chat-cli/src/cli/chat/cli/experiment.rs b/crates/chat-cli/src/cli/chat/cli/experiment.rs index 7854974c42..0dd981d9d3 100644 --- a/crates/chat-cli/src/cli/chat/cli/experiment.rs +++ b/crates/chat-cli/src/cli/chat/cli/experiment.rs @@ -108,7 +108,16 @@ async fn select_experiment(os: &mut Os, session: &mut ChatSession) -> Result Err(Interrupted) - Err(dialoguer::Error::IO(ref e)) if e.kind() == std::io::ErrorKind::Interrupted => return Ok(None), + Err(dialoguer::Error::IO(ref e)) if e.kind() == std::io::ErrorKind::Interrupted => { + // Move to beginning of line and clear everything from warning message down + queue!( + session.stderr, + crossterm::cursor::MoveToColumn(0), + crossterm::cursor::MoveUp(experiment_labels.len() as u16 + 3), + crossterm::terminal::Clear(crossterm::terminal::ClearType::FromCursorDown), + )?; + return Ok(None); + }, Err(e) => return Err(ChatError::Custom(format!("Failed to choose experiment: {e}").into())), }; @@ -161,6 +170,13 @@ async fn select_experiment(os: &mut Os, session: &mut ChatSession) -> Result Date: Thu, 11 Sep 2025 11:44:54 -0700 Subject: [PATCH 25/71] docs: Update experiment docs to contain todo lists (#2791) --- docs/built-in-tools.md | 6 +++--- docs/experiments.md | 24 ++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/docs/built-in-tools.md b/docs/built-in-tools.md index 39778740bb..c66c18ebfa 100644 --- a/docs/built-in-tools.md +++ b/docs/built-in-tools.md @@ -110,19 +110,19 @@ Opens the browser to a pre-filled GitHub issue template to report chat issues, b This tool has no configuration options. -## Knowledge Tool +## Knowledge Tool (experimental) Store and retrieve information in a knowledge base across chat sessions. Provides semantic search capabilities for files, directories, and text content. This tool has no configuration options. -## Thinking Tool +## Thinking Tool (experimental) An internal reasoning mechanism that improves the quality of complex tasks by breaking them down into atomic actions. This tool has no configuration options. -## Todo_list Tool +## TODO List Tool (experimental) Create and manage TODO lists for tracking multi-step tasks. Lists are stored locally in `.amazonq/cli-todo-lists/`. diff --git a/docs/experiments.md b/docs/experiments.md index 92da46fd1c..3c8ce310b4 100644 --- a/docs/experiments.md +++ b/docs/experiments.md @@ -58,6 +58,28 @@ Amazon Q CLI includes experimental features that can be toggled on/off using the **When enabled:** Use `/tangent` or the keyboard shortcut to create a checkpoint and explore tangential topics. Use the same command to return to your main conversation. +### TODO Lists +**Tool name**: `todo_list` +**Command:** `/todos` +**Description:** Enables Q to create and modify TODO lists using the `todo_list` tool and the user to view and manage existing TODO lists using `/todos`. + +**Features:** +- Q will automatically make TODO lists when appropriate or when asked +- View, manage, and delete TODOs using `/todos` +- Resume existing TODO lists stored in `.amazonq/cli-todo-lists` + +**Usage:** +``` +/todos clear-finished # Delete completed TODOs in your working directory +/todos resume # Select and resume an existing TODO list +/todos view # Select and view and existing TODO list +/todos delete # Select and delete an existing TODO list +``` + +**Settings:** +- `chat.enableTodoList` - Enable/disable TODO list functionality (boolean) + + ## Managing Experiments Use the `/experiment` command to toggle experimental features: @@ -84,11 +106,13 @@ These features are provided to gather feedback and test new capabilities. Please All experimental commands are available in the fuzzy search (Ctrl+S): - `/experiment` - Manage experimental features - `/knowledge` - Knowledge base commands (when enabled) +- `/todos` - User-controlled TODO list commands (when enabled) ## Settings Integration Experiments are stored as settings and persist across sessions: - `EnabledKnowledge` - Knowledge experiment state - `EnabledThinking` - Thinking experiment state +- `EnabledTodoList` - TODO list experiment state You can also manage these through the settings system if needed. From 2274a7b6c36d9540864fce14d9c4a8a04ec92807 Mon Sep 17 00:00:00 2001 From: Michael Orlov <34108460+harleylrn@users.noreply.github.com> Date: Thu, 11 Sep 2025 14:45:21 -0400 Subject: [PATCH 26/71] feat: Add support for comma-containing arguments in MCP --args parameter (#2754) --- crates/chat-cli/src/cli/mcp.rs | 133 ++++++++++++++++++++++++++++++--- 1 file changed, 124 insertions(+), 9 deletions(-) diff --git a/crates/chat-cli/src/cli/mcp.rs b/crates/chat-cli/src/cli/mcp.rs index c70951a9b5..7b7a9dbb30 100644 --- a/crates/chat-cli/src/cli/mcp.rs +++ b/crates/chat-cli/src/cli/mcp.rs @@ -20,6 +20,7 @@ use eyre::{ Result, bail, }; +use serde_json; use super::agent::{ Agent, @@ -96,8 +97,11 @@ pub struct AddArgs { /// The command used to launch the server #[arg(long)] pub command: String, - /// Arguments to pass to the command - #[arg(long, action = ArgAction::Append, allow_hyphen_values = true, value_delimiter = ',')] + /// Arguments to pass to the command. Can be provided as: + /// 1. Multiple --args flags: --args arg1 --args arg2 --args "arg,with,commas" + /// 2. Comma-separated with escaping: --args "arg1,arg2,arg\,with\,commas" + /// 3. JSON array format: --args '["arg1", "arg2", "arg,with,commas"]' + #[arg(long, action = ArgAction::Append, allow_hyphen_values = true)] pub args: Vec, /// Where to add the server to. If an agent name is not supplied, the changes shall be made to /// the global mcp.json @@ -119,6 +123,9 @@ pub struct AddArgs { impl AddArgs { pub async fn execute(self, os: &Os, output: &mut impl Write) -> Result<()> { + // Process args to handle comma-separated values, escaping, and JSON arrays + let processed_args = self.process_args()?; + match self.agent.as_deref() { Some(agent_name) => { let (mut agent, config_path) = Agent::get_agent_by_name(os, agent_name).await?; @@ -136,7 +143,7 @@ impl AddArgs { let merged_env = self.env.into_iter().flatten().collect::>(); let tool: CustomToolConfig = serde_json::from_value(serde_json::json!({ "command": self.command, - "args": self.args, + "args": processed_args, "env": merged_env, "timeout": self.timeout.unwrap_or(default_timeout()), "disabled": self.disabled, @@ -169,7 +176,7 @@ impl AddArgs { let merged_env = self.env.into_iter().flatten().collect::>(); let tool: CustomToolConfig = serde_json::from_value(serde_json::json!({ "command": self.command, - "args": self.args, + "args": processed_args, "env": merged_env, "timeout": self.timeout.unwrap_or(default_timeout()), "disabled": self.disabled, @@ -188,6 +195,17 @@ impl AddArgs { Ok(()) } + + fn process_args(&self) -> Result> { + let mut processed_args = Vec::new(); + + for arg in &self.args { + let parsed = parse_args(arg)?; + processed_args.extend(parsed); + } + + Ok(processed_args) + } } #[derive(Debug, Clone, PartialEq, Eq, Args)] @@ -507,6 +525,65 @@ fn parse_env_vars(arg: &str) -> Result> { Ok(vars) } +fn parse_args(arg: &str) -> Result> { + // Try to parse as JSON array first + if arg.trim_start().starts_with('[') { + match serde_json::from_str::>(arg) { + Ok(args) => return Ok(args), + Err(_) => { + bail!( + "Failed to parse arguments as JSON array. Expected format: '[\"arg1\", \"arg2\", \"arg,with,commas\"]'" + ); + }, + } + } + + // Check if the string contains escaped commas + let has_escaped_commas = arg.contains("\\,"); + + if has_escaped_commas { + // Parse with escape support + let mut args = Vec::new(); + let mut current_arg = String::new(); + let mut chars = arg.chars().peekable(); + + while let Some(ch) = chars.next() { + match ch { + '\\' => { + // Handle escape sequences + if let Some(&next_ch) = chars.peek() { + if next_ch == ',' || next_ch == '\\' { + current_arg.push(chars.next().unwrap()); + } else { + current_arg.push(ch); + } + } else { + current_arg.push(ch); + } + }, + ',' => { + // Split on unescaped comma + args.push(current_arg.trim().to_string()); + current_arg.clear(); + }, + _ => { + current_arg.push(ch); + }, + } + } + + // Add the last argument + if !current_arg.is_empty() || !args.is_empty() { + args.push(current_arg.trim().to_string()); + } + + Ok(args) + } else { + // Default behavior: split on commas (backward compatibility) + Ok(arg.split(',').map(|s| s.trim().to_string()).collect()) + } +} + async fn load_cfg(os: &Os, p: &PathBuf) -> Result { Ok(if os.fs.exists(p) { McpServerConfig::load_from_file(os, p).await? @@ -618,11 +695,7 @@ mod tests { name: "test_server".to_string(), scope: None, command: "test_command".to_string(), - args: vec![ - "awslabs.eks-mcp-server".to_string(), - "--allow-write".to_string(), - "--allow-sensitive-data-access".to_string(), - ], + args: vec!["awslabs.eks-mcp-server,--allow-write,--allow-sensitive-data-access".to_string(),], agent: None, env: vec![ [ @@ -680,4 +753,46 @@ mod tests { })) ); } + + #[test] + fn test_parse_args_comma_separated() { + let result = parse_args("arg1,arg2,arg3").unwrap(); + assert_eq!(result, vec!["arg1", "arg2", "arg3"]); + } + + #[test] + fn test_parse_args_with_escaped_commas() { + let result = parse_args("arg1,arg2\\,with\\,commas,arg3").unwrap(); + assert_eq!(result, vec!["arg1", "arg2,with,commas", "arg3"]); + } + + #[test] + fn test_parse_args_json_array() { + let result = parse_args(r#"["arg1", "arg2", "arg,with,commas"]"#).unwrap(); + assert_eq!(result, vec!["arg1", "arg2", "arg,with,commas"]); + } + + #[test] + fn test_parse_args_single_arg_with_commas() { + let result = parse_args("--config=key1=val1\\,key2=val2").unwrap(); + assert_eq!(result, vec!["--config=key1=val1,key2=val2"]); + } + + #[test] + fn test_parse_args_backward_compatibility() { + let result = parse_args("--config=key1=val1,key2=val2").unwrap(); + assert_eq!(result, vec!["--config=key1=val1", "key2=val2"]); + } + + #[test] + fn test_parse_args_mixed_escaping() { + let result = parse_args("normal,escaped\\,comma,--flag=val1\\,val2").unwrap(); + assert_eq!(result, vec!["normal", "escaped,comma", "--flag=val1,val2"]); + } + + #[test] + fn test_parse_args_json_array_invalid() { + let result = parse_args(r#"["invalid json"#); + assert!(result.is_err()); + } } From c05eb003b137355de8dafe98a0f0595b0dc970b7 Mon Sep 17 00:00:00 2001 From: Brandon Kiser <51934408+brandonskiser@users.noreply.github.com> Date: Thu, 11 Sep 2025 11:53:41 -0700 Subject: [PATCH 27/71] chore: add extra curl flags for debugging build during feed.json failures (#2843) --- crates/chat-cli/build.rs | 10 ++++++---- crates/chat-cli/src/cli/chat/tools/use_aws.rs | 6 +++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/crates/chat-cli/build.rs b/crates/chat-cli/build.rs index 8c6683ef28..5d8f6343e2 100644 --- a/crates/chat-cli/build.rs +++ b/crates/chat-cli/build.rs @@ -355,8 +355,10 @@ fn download_feed_json() { .args([ "-H", "Accept: application/vnd.github.v3.raw", - "-s", // silent - "-f", // fail on HTTP errors + "-f", // fail on HTTP errors + "-s", // silent + "-v", // verbose output printed to stderr + "--show-error", // print error message to stderr (since -s is used) "https://api.github.com/repos/aws/amazon-q-developer-cli-autocomplete/contents/feed.json", ]) .output(); @@ -371,9 +373,9 @@ fn download_feed_json() { }, Ok(result) => { let error_msg = if !result.stderr.is_empty() { - format!("HTTP error: {}", String::from_utf8_lossy(&result.stderr)) + format!("{}", String::from_utf8_lossy(&result.stderr)) } else { - "HTTP error occurred".to_string() + "An unknown error occurred".to_string() }; panic!("Failed to download feed.json: {}", error_msg); }, diff --git a/crates/chat-cli/src/cli/chat/tools/use_aws.rs b/crates/chat-cli/src/cli/chat/tools/use_aws.rs index 4ea40c80c4..3bb9611ea4 100644 --- a/crates/chat-cli/src/cli/chat/tools/use_aws.rs +++ b/crates/chat-cli/src/cli/chat/tools/use_aws.rs @@ -397,7 +397,7 @@ mod tests { #[tokio::test] async fn test_eval_perm_auto_allow_readonly_default() { let os = Os::new().await.unwrap(); - + // Test read-only operation with default settings (auto_allow_readonly = false) let readonly_cmd = use_aws! {{ "service_name": "s3", @@ -429,7 +429,7 @@ mod tests { #[tokio::test] async fn test_eval_perm_auto_allow_readonly_enabled() { let os = Os::new().await.unwrap(); - + let agent = Agent { name: "test_agent".to_string(), tools_settings: { @@ -475,7 +475,7 @@ mod tests { #[tokio::test] async fn test_eval_perm_auto_allow_readonly_with_denied_services() { let os = Os::new().await.unwrap(); - + let agent = Agent { name: "test_agent".to_string(), tools_settings: { From 640efb7113bf96065da0346f92ce57357c2a9e9b Mon Sep 17 00:00:00 2001 From: Brandon Kiser <51934408+brandonskiser@users.noreply.github.com> Date: Thu, 11 Sep 2025 13:02:30 -0700 Subject: [PATCH 28/71] fix: remove downloading feed during build (#2844) --- crates/chat-cli/src/cli/mcp.rs | 1 - scripts/build.py | 1 - 2 files changed, 2 deletions(-) diff --git a/crates/chat-cli/src/cli/mcp.rs b/crates/chat-cli/src/cli/mcp.rs index 7b7a9dbb30..0b709ef887 100644 --- a/crates/chat-cli/src/cli/mcp.rs +++ b/crates/chat-cli/src/cli/mcp.rs @@ -20,7 +20,6 @@ use eyre::{ Result, bail, }; -use serde_json; use super::agent::{ Agent, diff --git a/scripts/build.py b/scripts/build.py index 2176045d74..edbb8c27bc 100644 --- a/scripts/build.py +++ b/scripts/build.py @@ -87,7 +87,6 @@ def build_chat_bin( env={ **os.environ, **rust_env(release=release), - "FETCH_FEED": "1", # Always fetch latest feed.json for official builds }, ) From 458e3c1d23a768ab73bb198628a159abdbea2fb1 Mon Sep 17 00:00:00 2001 From: Matt Lee <1302416+mr-lee@users.noreply.github.com> Date: Fri, 12 Sep 2025 11:01:53 -0400 Subject: [PATCH 29/71] feat(execute_bash): change autoAllowReadonly default to false for security (#2846) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Change default_allow_read_only() from true to false for secure by default behavior - Default behavior: all bash commands require user confirmation (secure by default) - Opt-in behavior: when autoAllowReadonly=true, read-only commands are auto-approved - Use autoAllowReadonly casing to match use_aws tool pattern - Update documentation to reflect new default value and consistent naming - Add comprehensive tests covering all scenarios - Maintains backward compatibility through configuration - Follows same pattern as use_aws autoAllowReadonly setting 🤖 Assisted by Amazon Q Developer Co-authored-by: Matt Lee --- .../src/cli/chat/tools/execute/mod.rs | 131 +++++++++++++++++- docs/built-in-tools.md | 4 +- 2 files changed, 129 insertions(+), 6 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/tools/execute/mod.rs b/crates/chat-cli/src/cli/chat/tools/execute/mod.rs index 2dce20fbb5..0ce3fdeac1 100644 --- a/crates/chat-cli/src/cli/chat/tools/execute/mod.rs +++ b/crates/chat-cli/src/cli/chat/tools/execute/mod.rs @@ -196,11 +196,11 @@ impl ExecuteCommand { #[serde(default)] denied_commands: Vec, #[serde(default = "default_allow_read_only")] - allow_read_only: bool, + auto_allow_readonly: bool, } fn default_allow_read_only() -> bool { - true + false } let Self { command, .. } = self; @@ -211,7 +211,7 @@ impl ExecuteCommand { let Settings { allowed_commands, denied_commands, - allow_read_only, + auto_allow_readonly, } = match serde_json::from_value::(settings.clone()) { Ok(settings) => settings, Err(e) => { @@ -233,7 +233,7 @@ impl ExecuteCommand { if is_in_allowlist { PermissionEvalResult::Allow - } else if self.requires_acceptance(Some(&allowed_commands), allow_read_only) { + } else if self.requires_acceptance(Some(&allowed_commands), auto_allow_readonly) { PermissionEvalResult::Ask } else { PermissionEvalResult::Allow @@ -489,6 +489,129 @@ mod tests { assert!(matches!(res, PermissionEvalResult::Deny(ref rules) if rules.contains(&"\\Agit .*\\z".to_string()))); } + #[tokio::test] + async fn test_eval_perm_allow_read_only_default() { + use crate::cli::agent::Agent; + + let os = Os::new().await.unwrap(); + + // Test read-only command with default settings (allow_read_only = false) + let readonly_cmd = serde_json::from_value::(serde_json::json!({ + "command": "ls -la", + })) + .unwrap(); + + let agent = Agent::default(); + let res = readonly_cmd.eval_perm(&os, &agent); + // Should ask for confirmation even for read-only commands by default + assert!(matches!(res, PermissionEvalResult::Ask)); + + // Test non-read-only command with default settings + let write_cmd = serde_json::from_value::(serde_json::json!({ + "command": "rm file.txt", + })) + .unwrap(); + + let res = write_cmd.eval_perm(&os, &agent); + // Should ask for confirmation for write commands + assert!(matches!(res, PermissionEvalResult::Ask)); + } + + #[tokio::test] + async fn test_eval_perm_allow_read_only_enabled() { + use crate::cli::agent::{ + Agent, + ToolSettingTarget, + }; + use std::collections::HashMap; + + let os = Os::new().await.unwrap(); + let tool_name = if cfg!(windows) { "execute_cmd" } else { "execute_bash" }; + + let agent = Agent { + name: "test_agent".to_string(), + tools_settings: { + let mut map = HashMap::::new(); + map.insert( + ToolSettingTarget(tool_name.to_string()), + serde_json::json!({ + "autoAllowReadonly": true + }), + ); + map + }, + ..Default::default() + }; + + // Test read-only command with allow_read_only = true + let readonly_cmd = serde_json::from_value::(serde_json::json!({ + "command": "ls -la", + })) + .unwrap(); + + let res = readonly_cmd.eval_perm(&os, &agent); + // Should allow read-only commands without confirmation + assert!(matches!(res, PermissionEvalResult::Allow)); + + // Test write command with allow_read_only = true + let write_cmd = serde_json::from_value::(serde_json::json!({ + "command": "rm file.txt", + })) + .unwrap(); + + let res = write_cmd.eval_perm(&os, &agent); + // Should still ask for confirmation for write commands + assert!(matches!(res, PermissionEvalResult::Ask)); + } + + #[tokio::test] + async fn test_eval_perm_allow_read_only_with_denied_commands() { + use crate::cli::agent::{ + Agent, + ToolSettingTarget, + }; + use std::collections::HashMap; + + let os = Os::new().await.unwrap(); + let tool_name = if cfg!(windows) { "execute_cmd" } else { "execute_bash" }; + + let agent = Agent { + name: "test_agent".to_string(), + tools_settings: { + let mut map = HashMap::::new(); + map.insert( + ToolSettingTarget(tool_name.to_string()), + serde_json::json!({ + "autoAllowReadonly": true, + "deniedCommands": ["ls .*"] + }), + ); + map + }, + ..Default::default() + }; + + // Test read-only command that's in denied list + let denied_readonly_cmd = serde_json::from_value::(serde_json::json!({ + "command": "ls -la", + })) + .unwrap(); + + let res = denied_readonly_cmd.eval_perm(&os, &agent); + // Should deny even read-only commands if they're in denied list + assert!(matches!(res, PermissionEvalResult::Deny(ref commands) if commands.contains(&"\\Als .*\\z".to_string()))); + + // Test different read-only command not in denied list + let allowed_readonly_cmd = serde_json::from_value::(serde_json::json!({ + "command": "cat file.txt", + })) + .unwrap(); + + let res = allowed_readonly_cmd.eval_perm(&os, &agent); + // Should allow read-only commands not in denied list + assert!(matches!(res, PermissionEvalResult::Allow)); + } + #[tokio::test] async fn test_cloudtrail_tracking() { use crate::cli::chat::consts::{ diff --git a/docs/built-in-tools.md b/docs/built-in-tools.md index c66c18ebfa..6d5012ba57 100644 --- a/docs/built-in-tools.md +++ b/docs/built-in-tools.md @@ -24,7 +24,7 @@ Execute the specified bash command. "execute_bash": { "allowedCommands": ["git status", "git fetch"], "deniedCommands": ["git commit .*", "git push .*"], - "allowReadOnly": true + "autoAllowReadonly": true } } } @@ -36,7 +36,7 @@ Execute the specified bash command. |--------|------|---------|------------------------------------------------------------------------------------------| | `allowedCommands` | array of strings | `[]` | List of specific commands that are allowed without prompting. Supports regex formatting. Note that regex entered are anchored with \A and \z | | `deniedCommands` | array of strings | `[]` | List of specific commands that are denied. Supports regex formatting. Note that regex entered are anchored with \A and \z. Deny rules are evaluated before allow rules | -| `allowReadOnly` | boolean | `true` | Whether to allow read-only commands without prompting | +| `autoAllowReadonly` | boolean | `false` | Whether to allow read-only commands without prompting | ## Fs_read Tool From 4ea78b9bb13f65f2c782d8ed36b6de4d2ad7efdd Mon Sep 17 00:00:00 2001 From: Matt Lee <1302416+mr-lee@users.noreply.github.com> Date: Fri, 12 Sep 2025 12:28:35 -0400 Subject: [PATCH 30/71] feat(agent): add edit subcommand to modify existing agents (#2845) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add Edit subcommand to AgentSubcommands enum - Implement edit functionality that opens existing agent files in editor - Use Agent::get_agent_by_name to locate and load existing agents - Include post-edit validation to ensure JSON remains valid - Add comprehensive tests for the new edit subcommand - Support both --name and -n flags for agent name specification 🤖 Assisted by Amazon Q Developer Co-authored-by: Matt Lee --- .../src/cli/agent/root_command_args.rs | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/crates/chat-cli/src/cli/agent/root_command_args.rs b/crates/chat-cli/src/cli/agent/root_command_args.rs index 469e0982ba..b470ba3e46 100644 --- a/crates/chat-cli/src/cli/agent/root_command_args.rs +++ b/crates/chat-cli/src/cli/agent/root_command_args.rs @@ -46,6 +46,12 @@ pub enum AgentSubcommands { #[arg(long, short)] from: Option, }, + /// Edit an existing agent config + Edit { + /// Name of the agent to edit + #[arg(long, short)] + name: String, + }, /// Validate a config with the given path Validate { #[arg(long, short)] @@ -138,6 +144,38 @@ impl AgentArgs { path_with_file_name.display() )?; }, + Some(AgentSubcommands::Edit { name }) => { + let _agents = Agents::load(os, None, true, &mut stderr, mcp_enabled).await.0; + let (_agent, path_with_file_name) = Agent::get_agent_by_name(os, &name).await?; + + let editor_cmd = std::env::var("EDITOR").unwrap_or_else(|_| "vi".to_string()); + let mut cmd = std::process::Command::new(editor_cmd); + + let status = cmd.arg(&path_with_file_name).status()?; + if !status.success() { + bail!("Editor process did not exit with success"); + } + + let Ok(content) = os.fs.read(&path_with_file_name).await else { + bail!( + "Post edit validation failed. Error opening {}. Aborting", + path_with_file_name.display() + ); + }; + if let Err(e) = serde_json::from_slice::(&content) { + bail!( + "Post edit validation failed for agent '{name}' at path: {}. Malformed config detected: {e}", + path_with_file_name.display() + ); + } + + writeln!( + stderr, + "\n✏️ Edited agent {} '{}'\n", + name, + path_with_file_name.display() + )?; + }, Some(AgentSubcommands::Validate { path }) => { let mut global_mcp_config = None::; let agent = Agent::load(os, path.as_str(), &mut global_mcp_config, mcp_enabled, &mut stderr).await; @@ -386,4 +424,24 @@ mod tests { }) ); } + + #[test] + fn test_agent_subcommand_edit() { + assert_parse!( + ["agent", "edit", "--name", "existing_agent"], + RootSubcommand::Agent(AgentArgs { + cmd: Some(AgentSubcommands::Edit { + name: "existing_agent".to_string(), + }) + }) + ); + assert_parse!( + ["agent", "edit", "-n", "existing_agent"], + RootSubcommand::Agent(AgentArgs { + cmd: Some(AgentSubcommands::Edit { + name: "existing_agent".to_string(), + }) + }) + ); + } } From c4a09109748b4ed6e251d88ebfb87e87093e4087 Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Fri, 12 Sep 2025 14:01:49 -0700 Subject: [PATCH 31/71] fix(mcp): not being able to refresh tokens for remote mcp (#2849) * adds registration persistance for token refresh * truncates on tool description * Modifies oauth success message * adds time stamps on mcp logs --- crates/chat-cli/src/cli/chat/cli/mcp.rs | 6 +- crates/chat-cli/src/cli/chat/tool_manager.rs | 64 +++++++++++---- crates/chat-cli/src/mcp_client/client.rs | 14 +++- crates/chat-cli/src/mcp_client/oauth_util.rs | 86 ++++++++++++++++---- 4 files changed, 132 insertions(+), 38 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/cli/mcp.rs b/crates/chat-cli/src/cli/chat/cli/mcp.rs index bbb4883ff5..1cabd5344b 100644 --- a/crates/chat-cli/src/cli/chat/cli/mcp.rs +++ b/crates/chat-cli/src/cli/chat/cli/mcp.rs @@ -54,9 +54,9 @@ impl McpArgs { let msg = msg .iter() .map(|record| match record { - LoadingRecord::Err(content) | LoadingRecord::Warn(content) | LoadingRecord::Success(content) => { - content.clone() - }, + LoadingRecord::Err(timestamp, content) + | LoadingRecord::Warn(timestamp, content) + | LoadingRecord::Success(timestamp, content) => format!("[{timestamp}]: {content}"), }) .collect::>() .join("\n--- tools refreshed ---\n"); diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index 0d4bad09c8..7f72b874a6 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -150,9 +150,26 @@ enum LoadingMsg { /// surface (since we would only want to surface fatal errors in non-interactive mode). #[derive(Clone, Debug)] pub enum LoadingRecord { - Success(String), - Warn(String), - Err(String), + Success(String, String), + Warn(String, String), + Err(String, String), +} + +impl LoadingRecord { + pub fn success(msg: String) -> Self { + let timestamp = chrono::Local::now().format("%Y:%H:%S").to_string(); + LoadingRecord::Success(timestamp, msg) + } + + pub fn warn(msg: String) -> Self { + let timestamp = chrono::Local::now().format("%Y:%H:%S").to_string(); + LoadingRecord::Warn(timestamp, msg) + } + + pub fn err(msg: String) -> Self { + let timestamp = chrono::Local::now().format("%Y:%H:%S").to_string(); + LoadingRecord::Err(timestamp, msg) + } } pub struct ToolManagerBuilder { @@ -473,10 +490,11 @@ pub enum PromptQueryResult { /// - `IllegalChar`: The tool name contains characters that are not allowed /// - `EmptyDescription`: The tool description is empty or missing #[allow(dead_code)] -enum OutOfSpecName { +enum ToolValidationViolation { TooLong(String), IllegalChar(String), EmptyDescription(String), + DescriptionTooLong(String), } #[derive(Clone, Default, Debug, Eq, PartialEq)] @@ -814,7 +832,7 @@ impl ToolManager { .lock() .await .iter() - .any(|(_, records)| records.iter().any(|record| matches!(record, LoadingRecord::Err(_)))) + .any(|(_, records)| records.iter().any(|record| matches!(record, LoadingRecord::Err(..)))) { queue!( stderr, @@ -962,7 +980,7 @@ impl ToolManager { if !conflicts.is_empty() { let mut record_lock = self.mcp_load_record.lock().await; for (server_name, msg) in conflicts { - let record = LoadingRecord::Err(msg); + let record = LoadingRecord::err(msg); record_lock .entry(server_name) .and_modify(|v| v.push(record.clone())) @@ -1494,9 +1512,9 @@ fn spawn_orchestrator_task( drop(buf_writer); let record = String::from_utf8_lossy(record_temp_buf).to_string(); let record = if process_result.is_err() { - LoadingRecord::Warn(record) + LoadingRecord::warn(record) } else { - LoadingRecord::Success(record) + LoadingRecord::success(record) }; load_record .lock() @@ -1522,7 +1540,7 @@ fn spawn_orchestrator_task( let _ = buf_writer.flush(); drop(buf_writer); let record = String::from_utf8_lossy(record_temp_buf).to_string(); - let record = LoadingRecord::Err(record); + let record = LoadingRecord::err(record); load_record .lock() .await @@ -1606,7 +1624,7 @@ fn spawn_orchestrator_task( let _ = buf_writer.flush(); drop(buf_writer); let record = String::from_utf8_lossy(record_temp_buf).to_string(); - let record = LoadingRecord::Err(record); + let record = LoadingRecord::err(record); load_record .lock() .await @@ -1626,7 +1644,7 @@ fn spawn_orchestrator_task( let _ = buf_writer.flush(); drop(buf_writer); let record_str = String::from_utf8_lossy(record_temp_buf).to_string(); - let record = LoadingRecord::Warn(record_str.clone()); + let record = LoadingRecord::warn(record_str.clone()); load_record .lock() .await @@ -1720,7 +1738,7 @@ async fn process_tool_specs( // // For non-compliance due to point 1, we shall change it on behalf of the users. // For the rest, we simply throw a warning and reject the tool. - let mut out_of_spec_tool_names = Vec::::new(); + let mut out_of_spec_tool_names = Vec::::new(); let mut hasher = DefaultHasher::new(); let mut number_of_tools = 0_usize; @@ -1745,12 +1763,18 @@ async fn process_tool_specs( } }); if model_tool_name.len() > 64 { - out_of_spec_tool_names.push(OutOfSpecName::TooLong(spec.name.clone())); + out_of_spec_tool_names.push(ToolValidationViolation::TooLong(spec.name.clone())); continue; } else if spec.description.is_empty() { - out_of_spec_tool_names.push(OutOfSpecName::EmptyDescription(spec.name.clone())); + out_of_spec_tool_names.push(ToolValidationViolation::EmptyDescription(spec.name.clone())); continue; } + + if spec.description.len() > 10_004 { + spec.description.truncate(10_004); + out_of_spec_tool_names.push(ToolValidationViolation::DescriptionTooLong(spec.name.clone())); + } + tn_map.insert(model_tool_name.clone(), ToolInfo { server_name: server_name.to_string(), host_tool_name: spec.name.clone(), @@ -1788,21 +1812,25 @@ async fn process_tool_specs( if !out_of_spec_tool_names.is_empty() { Err(eyre::eyre!(out_of_spec_tool_names.iter().fold( String::from( - "The following tools are out of spec. They will be excluded from the list of available tools:\n", + "The following tools are out of spec. They may have been excluded from the list of available tools:\n", ), |mut acc, name| { let (tool_name, msg) = match name { - OutOfSpecName::TooLong(tool_name) => ( + ToolValidationViolation::TooLong(tool_name) => ( tool_name.as_str(), "tool name exceeds max length of 64 when combined with server name", ), - OutOfSpecName::IllegalChar(tool_name) => ( + ToolValidationViolation::IllegalChar(tool_name) => ( tool_name.as_str(), "tool name must be compliant with ^[a-zA-Z][a-zA-Z0-9_]*$", ), - OutOfSpecName::EmptyDescription(tool_name) => { + ToolValidationViolation::EmptyDescription(tool_name) => { (tool_name.as_str(), "tool schema contains empty description") }, + ToolValidationViolation::DescriptionTooLong(tool_name) => ( + tool_name.as_str(), + "tool description is longer than 10024 characters and has been truncated", + ), }; acc.push_str(format!(" - {} ({})\n", tool_name, msg).as_str()); acc diff --git a/crates/chat-cli/src/mcp_client/client.rs b/crates/chat-cli/src/mcp_client/client.rs index c50d43a585..52a7c89e38 100644 --- a/crates/chat-cli/src/mcp_client/client.rs +++ b/crates/chat-cli/src/mcp_client/client.rs @@ -152,6 +152,16 @@ pub enum McpClientError { Auth(#[from] crate::auth::AuthError), } +/// Decorates the method passed in with retry logic, but only if the [RunningService] has an +/// instance of [AuthClientDropGuard]. +/// The various methods to interact with the mcp server provided by RMCP supposedly does refresh +/// token once the token expires but that logic would require us to also note down the time at +/// which a token is obtained since the only time related information in the token is the duration +/// for which a token is valid. However, if we do solely rely on the internals of these methods to +/// refresh tokens, we would have no way of knowing when a token is obtained. (Maybe there is a +/// method that would allow us to configure what extra info to include in the token. If you find it, +/// feel free to remove this. That would also enable us to simplify the definition of +/// [RunningService]) macro_rules! decorate_with_auth_retry { ($param_type:ty, $method_name:ident, $return_type:ty) => { pub async fn $method_name(&self, param: $param_type) -> Result<$return_type, rmcp::ServiceError> { @@ -166,7 +176,7 @@ macro_rules! decorate_with_auth_retry { // TODO: discern error type prior to retrying // Not entirely sure what is thrown when auth is required if let Some(auth_client) = self.get_auth_client() { - let refresh_result = auth_client.get_access_token().await; + let refresh_result = auth_client.auth_manager.lock().await.refresh_token().await; match refresh_result { Ok(_) => { // Retry the operation after token refresh @@ -340,7 +350,7 @@ impl McpClientService { Err(e) if matches!(*e, ClientInitializeError::ConnectionClosed(_)) => { debug!("## mcp: first hand shake attempt failed: {:?}", e); let refresh_res = - auth_dg.auth_client.get_access_token().await; + auth_dg.auth_client.auth_manager.lock().await.refresh_token().await; let new_self = McpClientService::new( server_name.clone(), backup_config, diff --git a/crates/chat-cli/src/mcp_client/oauth_util.rs b/crates/chat-cli/src/mcp_client/oauth_util.rs index a705454b98..9ec9cd18d6 100644 --- a/crates/chat-cli/src/mcp_client/oauth_util.rs +++ b/crates/chat-cli/src/mcp_client/oauth_util.rs @@ -14,6 +14,7 @@ use reqwest::Client; use rmcp::serde_json; use rmcp::transport::auth::{ AuthClient, + OAuthClientConfig, OAuthState, OAuthTokenResponse, }; @@ -26,6 +27,10 @@ use rmcp::transport::{ StreamableHttpClientTransport, WorkerTransport, }; +use serde::{ + Deserialize, + Serialize, +}; use sha2::{ Digest, Sha256, @@ -64,6 +69,8 @@ pub enum OauthUtilError { Directory(#[from] DirectoryError), #[error(transparent)] Reqwest(#[from] reqwest::Error), + #[error("Malformed directory")] + MalformDirectory, } /// A guard that automatically cancels the cancellation token when dropped. @@ -79,6 +86,27 @@ impl Drop for LoopBackDropGuard { } } +/// This is modeled after [OAuthClientConfig] +/// It's only here because [OAuthClientConfig] does not implement Serialize and Deserialize +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct Registration { + pub client_id: String, + pub client_secret: Option, + pub scopes: Vec, + pub redirect_uri: String, +} + +impl From for Registration { + fn from(value: OAuthClientConfig) -> Self { + Self { + client_id: value.client_id, + client_secret: value.client_secret, + scopes: value.scopes, + redirect_uri: value.redirect_uri, + } + } +} + /// A guard that manages the lifecycle of an authenticated MCP client and automatically /// persists OAuth credentials when dropped. /// @@ -164,6 +192,10 @@ pub enum HttpTransport { WithoutAuth(WorkerTransport>), } +fn get_scopes() -> &'static [&'static str] { + &["openid", "mcp", "email", "profile"] +} + pub async fn get_http_transport( os: &Os, delete_cache: bool, @@ -175,6 +207,7 @@ pub async fn get_http_transport( let url = Url::from_str(url)?; let key = compute_key(&url); let cred_full_path = cred_dir.join(format!("{key}.token.json")); + let reg_full_path = cred_dir.join(format!("{key}.registration.json")); if delete_cache && cred_full_path.is_file() { tokio::fs::remove_file(&cred_full_path).await?; @@ -188,7 +221,8 @@ pub async fn get_http_transport( let auth_client = match auth_client { Some(auth_client) => auth_client, None => { - let am = get_auth_manager(url.clone(), cred_full_path.clone(), messenger).await?; + let am = + get_auth_manager(url.clone(), cred_full_path.clone(), reg_full_path.clone(), messenger).await?; AuthClient::new(reqwest_client, am) }, }; @@ -215,16 +249,19 @@ pub async fn get_http_transport( async fn get_auth_manager( url: Url, cred_full_path: PathBuf, + reg_full_path: PathBuf, messenger: &dyn Messenger, ) -> Result { - let content_as_bytes = tokio::fs::read(&cred_full_path).await; + let cred_as_bytes = tokio::fs::read(&cred_full_path).await; + let reg_as_bytes = tokio::fs::read(®_full_path).await; let mut oauth_state = OAuthState::new(url, None).await?; - match content_as_bytes { - Ok(bytes) => { - let token = serde_json::from_slice::(&bytes)?; + match (cred_as_bytes, reg_as_bytes) { + (Ok(cred_as_bytes), Ok(reg_as_bytes)) => { + let token = serde_json::from_slice::(&cred_as_bytes)?; + let reg = serde_json::from_slice::(®_as_bytes)?; - oauth_state.set_credentials("id", token).await?; + oauth_state.set_credentials(®.client_id, token).await?; debug!("## mcp: credentials set with cache"); @@ -232,10 +269,30 @@ async fn get_auth_manager( .into_authorization_manager() .ok_or(OauthUtilError::MissingAuthorizationManager)?) }, - Err(e) => { - info!("Error reading cached credentials: {e}"); + _ => { + info!("Error reading cached credentials"); debug!("## mcp: cache read failed. constructing auth manager from scratch"); - get_auth_manager_impl(oauth_state, messenger).await + let (am, redirect_uri) = get_auth_manager_impl(oauth_state, messenger).await?; + + // Client registration is done in [start_authorization] + // If we have gotten past that point that means we have the info to persist the + // registration on disk. These are info that we need to refresh stake + // tokens. This is in contrast to tokens, which we only persist when we drop + // the client (because that way we can write once and ensure what is on the + // disk always the most up to date) + let (client_id, _credentials) = am.get_credentials().await?; + let reg = Registration { + client_id, + client_secret: None, + scopes: get_scopes().iter().map(|s| (*s).to_string()).collect::>(), + redirect_uri, + }; + let reg_as_str = serde_json::to_string_pretty(®)?; + let reg_parent_path = reg_full_path.parent().ok_or(OauthUtilError::MalformDirectory)?; + tokio::fs::create_dir(reg_parent_path).await?; + tokio::fs::write(reg_full_path, ®_as_str).await?; + + Ok(am) }, } } @@ -243,7 +300,7 @@ async fn get_auth_manager( async fn get_auth_manager_impl( mut oauth_state: OAuthState, messenger: &dyn Messenger, -) -> Result { +) -> Result<(AuthorizationManager, String), OauthUtilError> { let socket_addr = SocketAddr::from(([127, 0, 0, 1], 0)); let cancellation_token = tokio_util::sync::CancellationToken::new(); let (tx, rx) = tokio::sync::oneshot::channel::(); @@ -251,9 +308,8 @@ async fn get_auth_manager_impl( let (actual_addr, _dg) = make_svc(tx, socket_addr, cancellation_token).await?; info!("Listening on local host port {:?} for oauth", actual_addr); - oauth_state - .start_authorization(&["mcp", "profile", "email"], &format!("http://{}", actual_addr)) - .await?; + let redirect_uri = format!("http://{}", actual_addr); + oauth_state.start_authorization(get_scopes(), &redirect_uri).await?; let auth_url = oauth_state.get_authorization_url().await?; _ = messenger.send_oauth_link(auth_url).await; @@ -264,7 +320,7 @@ async fn get_auth_manager_impl( .into_authorization_manager() .ok_or(OauthUtilError::MissingAuthorizationManager)?; - Ok(am) + Ok((am, redirect_uri)) } pub fn compute_key(rs: &Url) -> String { @@ -320,7 +376,7 @@ async fn make_svc( { sender.send(code).map_err(LoopBackError::Send)?; } - mk_response("Auth code sent".to_string()) + mk_response("You can close this page now".to_string()) }) } } From 697dc649bcaf2cfdf4877cc2d6d6e4a204be5c10 Mon Sep 17 00:00:00 2001 From: Matt Lee <1302416+mr-lee@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:41:38 -0400 Subject: [PATCH 32/71] fix(agent): add edit subcommand support to /agent slash command (#2854) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add Edit variant to AgentSubcommand enum in profile.rs - Implement edit functionality for slash command usage - Use Agent::get_agent_by_name to locate existing agents - Include post-edit validation and agent reloading - Add edit case to name() method for proper command routing - Enables /agent edit --name usage in chat sessions 🤖 Assisted by Amazon Q Developer Co-authored-by: Matt Lee --- crates/chat-cli/src/cli/chat/cli/profile.rs | 65 +++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/crates/chat-cli/src/cli/chat/cli/profile.rs b/crates/chat-cli/src/cli/chat/cli/profile.rs index 233ab7888d..4b99f373d3 100644 --- a/crates/chat-cli/src/cli/chat/cli/profile.rs +++ b/crates/chat-cli/src/cli/chat/cli/profile.rs @@ -77,6 +77,12 @@ pub enum AgentSubcommand { #[arg(long, short)] from: Option, }, + /// Edit an existing agent configuration + Edit { + /// Name of the agent to edit + #[arg(long, short)] + name: String, + }, /// Generate an agent configuration using AI Generate {}, /// Delete the specified agent @@ -242,6 +248,64 @@ impl AgentSubcommand { )?; }, + Self::Edit { name } => { + let (_agent, path_with_file_name) = Agent::get_agent_by_name(os, &name) + .await + .map_err(|e| ChatError::Custom(Cow::Owned(e.to_string())))?; + + let editor_cmd = std::env::var("EDITOR").unwrap_or_else(|_| "vi".to_string()); + let mut cmd = std::process::Command::new(editor_cmd); + + let status = cmd.arg(&path_with_file_name).status()?; + if !status.success() { + return Err(ChatError::Custom("Editor process did not exit with success".into())); + } + + let updated_agent = Agent::load( + os, + &path_with_file_name, + &mut None, + session.conversation.mcp_enabled, + &mut session.stderr, + ) + .await; + match updated_agent { + Ok(agent) => { + session.conversation.agents.agents.insert(agent.name.clone(), agent); + }, + Err(e) => { + execute!( + session.stderr, + style::SetForegroundColor(Color::Red), + style::Print("Error: "), + style::ResetColor, + style::Print(&e), + style::Print("\n"), + )?; + + return Err(ChatError::Custom( + format!("Post edit validation failed for agent '{name}'. Malformed config detected: {e}") + .into(), + )); + }, + } + + execute!( + session.stderr, + style::SetForegroundColor(Color::Green), + style::Print("Agent "), + style::SetForegroundColor(Color::Cyan), + style::Print(name), + style::SetForegroundColor(Color::Green), + style::Print(" has been edited successfully"), + style::SetForegroundColor(Color::Reset), + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Changes take effect on next launch"), + style::SetForegroundColor(Color::Reset) + )?; + }, + Self::Generate {} => { let agent_name = match crate::util::input("Enter agent name: ", None) { Ok(input) => input.trim().to_string(), @@ -440,6 +504,7 @@ impl AgentSubcommand { match self { Self::List => "list", Self::Create { .. } => "create", + Self::Edit { .. } => "edit", Self::Generate { .. } => "generate", Self::Delete { .. } => "delete", Self::Set { .. } => "set", From 776f2edc540516a327e8ebae0f537cfc1c3adfd2 Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Fri, 12 Sep 2025 16:58:42 -0700 Subject: [PATCH 33/71] fixes creation of cache directory panic when path already exists (#2857) --- .../chat-cli/src/cli/agent/root_command_args.rs | 2 +- crates/chat-cli/src/cli/chat/cli/profile.rs | 2 +- .../chat-cli/src/cli/chat/tools/execute/mod.rs | 16 ++++++++++------ crates/chat-cli/src/mcp_client/oauth_util.rs | 2 +- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/crates/chat-cli/src/cli/agent/root_command_args.rs b/crates/chat-cli/src/cli/agent/root_command_args.rs index b470ba3e46..0f02028e50 100644 --- a/crates/chat-cli/src/cli/agent/root_command_args.rs +++ b/crates/chat-cli/src/cli/agent/root_command_args.rs @@ -147,7 +147,7 @@ impl AgentArgs { Some(AgentSubcommands::Edit { name }) => { let _agents = Agents::load(os, None, true, &mut stderr, mcp_enabled).await.0; let (_agent, path_with_file_name) = Agent::get_agent_by_name(os, &name).await?; - + let editor_cmd = std::env::var("EDITOR").unwrap_or_else(|_| "vi".to_string()); let mut cmd = std::process::Command::new(editor_cmd); diff --git a/crates/chat-cli/src/cli/chat/cli/profile.rs b/crates/chat-cli/src/cli/chat/cli/profile.rs index 4b99f373d3..83a7b634cd 100644 --- a/crates/chat-cli/src/cli/chat/cli/profile.rs +++ b/crates/chat-cli/src/cli/chat/cli/profile.rs @@ -252,7 +252,7 @@ impl AgentSubcommand { let (_agent, path_with_file_name) = Agent::get_agent_by_name(os, &name) .await .map_err(|e| ChatError::Custom(Cow::Owned(e.to_string())))?; - + let editor_cmd = std::env::var("EDITOR").unwrap_or_else(|_| "vi".to_string()); let mut cmd = std::process::Command::new(editor_cmd); diff --git a/crates/chat-cli/src/cli/chat/tools/execute/mod.rs b/crates/chat-cli/src/cli/chat/tools/execute/mod.rs index 0ce3fdeac1..e53daa6ef6 100644 --- a/crates/chat-cli/src/cli/chat/tools/execute/mod.rs +++ b/crates/chat-cli/src/cli/chat/tools/execute/mod.rs @@ -494,7 +494,7 @@ mod tests { use crate::cli::agent::Agent; let os = Os::new().await.unwrap(); - + // Test read-only command with default settings (allow_read_only = false) let readonly_cmd = serde_json::from_value::(serde_json::json!({ "command": "ls -la", @@ -519,15 +519,16 @@ mod tests { #[tokio::test] async fn test_eval_perm_allow_read_only_enabled() { + use std::collections::HashMap; + use crate::cli::agent::{ Agent, ToolSettingTarget, }; - use std::collections::HashMap; let os = Os::new().await.unwrap(); let tool_name = if cfg!(windows) { "execute_cmd" } else { "execute_bash" }; - + let agent = Agent { name: "test_agent".to_string(), tools_settings: { @@ -566,15 +567,16 @@ mod tests { #[tokio::test] async fn test_eval_perm_allow_read_only_with_denied_commands() { + use std::collections::HashMap; + use crate::cli::agent::{ Agent, ToolSettingTarget, }; - use std::collections::HashMap; let os = Os::new().await.unwrap(); let tool_name = if cfg!(windows) { "execute_cmd" } else { "execute_bash" }; - + let agent = Agent { name: "test_agent".to_string(), tools_settings: { @@ -599,7 +601,9 @@ mod tests { let res = denied_readonly_cmd.eval_perm(&os, &agent); // Should deny even read-only commands if they're in denied list - assert!(matches!(res, PermissionEvalResult::Deny(ref commands) if commands.contains(&"\\Als .*\\z".to_string()))); + assert!( + matches!(res, PermissionEvalResult::Deny(ref commands) if commands.contains(&"\\Als .*\\z".to_string())) + ); // Test different read-only command not in denied list let allowed_readonly_cmd = serde_json::from_value::(serde_json::json!({ diff --git a/crates/chat-cli/src/mcp_client/oauth_util.rs b/crates/chat-cli/src/mcp_client/oauth_util.rs index 9ec9cd18d6..4c862fb67a 100644 --- a/crates/chat-cli/src/mcp_client/oauth_util.rs +++ b/crates/chat-cli/src/mcp_client/oauth_util.rs @@ -289,7 +289,7 @@ async fn get_auth_manager( }; let reg_as_str = serde_json::to_string_pretty(®)?; let reg_parent_path = reg_full_path.parent().ok_or(OauthUtilError::MalformDirectory)?; - tokio::fs::create_dir(reg_parent_path).await?; + tokio::fs::create_dir_all(reg_parent_path).await?; tokio::fs::write(reg_full_path, ®_as_str).await?; Ok(am) From f1b48c96806af66c7b3baa5c55457937fc7a36f7 Mon Sep 17 00:00:00 2001 From: Erben Mo Date: Mon, 15 Sep 2025 13:56:01 -0700 Subject: [PATCH 34/71] Reduce default fs_read trust permission to current working directory only (#2824) * Reduce default fs_read trust permission to current working directory only Previously by default fs_read is trusted to read any file on user's file system. This PR reduces the fs_read permission to CWD only. This means user can still access any file under CWD without prompt. But if user needs to access file outside CWD, she will be prompted for explicit approval. User can still explicitly add fs_read to trusted tools in chat / agent definition so fs_read can read any file without prompt. This change essentially adds a layer of defense against prompt injection by following the least-privilege principle. * remove allow_read_only since it is always false now --- crates/chat-cli/src/cli/agent/mod.rs | 8 +- crates/chat-cli/src/cli/chat/tools/fs_read.rs | 114 +++++++++++++++--- crates/chat-cli/src/cli/chat/tools/mod.rs | 2 +- crates/chat-cli/src/os/env.rs | 7 ++ 4 files changed, 112 insertions(+), 19 deletions(-) diff --git a/crates/chat-cli/src/cli/agent/mod.rs b/crates/chat-cli/src/cli/agent/mod.rs index 7d8d301723..9ef82c15f2 100644 --- a/crates/chat-cli/src/cli/agent/mod.rs +++ b/crates/chat-cli/src/cli/agent/mod.rs @@ -815,7 +815,7 @@ impl Agents { // This "static" way avoids needing to construct a tool instance. fn default_permission_label(&self, tool_name: &str) -> String { let label = match tool_name { - "fs_read" => "trusted".dark_green().bold(), + "fs_read" => "trust working directory".dark_grey(), "fs_write" => "not trusted".dark_grey(), #[cfg(not(windows))] "execute_bash" => "trust read-only commands".dark_grey(), @@ -1142,9 +1142,9 @@ mod tests { let label = agents.display_label("fs_read", &ToolOrigin::Native); // With no active agent, it should fall back to default permissions - // fs_read has a default of "trusted" + // fs_read has a default of "trust working directory" assert!( - label.contains("trusted"), + label.contains("trust working directory"), "fs_read should show default trusted permission, instead found: {}", label ); @@ -1173,7 +1173,7 @@ mod tests { // Test default permissions for known tools let fs_read_label = agents.display_label("fs_read", &ToolOrigin::Native); assert!( - fs_read_label.contains("trusted"), + fs_read_label.contains("trust working directory"), "fs_read should be trusted by default, instead found: {}", fs_read_label ); diff --git a/crates/chat-cli/src/cli/chat/tools/fs_read.rs b/crates/chat-cli/src/cli/chat/tools/fs_read.rs index a11924e9a2..c4e60ae6cc 100644 --- a/crates/chat-cli/src/cli/chat/tools/fs_read.rs +++ b/crates/chat-cli/src/cli/chat/tools/fs_read.rs @@ -109,28 +109,32 @@ impl FsRead { allowed_paths: Vec, #[serde(default)] denied_paths: Vec, - #[serde(default = "default_allow_read_only")] + #[serde(default)] allow_read_only: bool, } - fn default_allow_read_only() -> bool { - true - } - let is_in_allowlist = matches_any_pattern(&agent.allowed_tools, "fs_read"); - match agent.tools_settings.get("fs_read") { - Some(settings) => { + let settings = agent.tools_settings.get("fs_read").cloned() + .unwrap_or_else(|| serde_json::json!({})); + + { let Settings { - allowed_paths, + mut allowed_paths, denied_paths, allow_read_only, - } = match serde_json::from_value::(settings.clone()) { + } = match serde_json::from_value::(settings) { Ok(settings) => settings, Err(e) => { error!("Failed to deserialize tool settings for fs_read: {:?}", e); return PermissionEvalResult::Ask; }, }; + + // Always add current working directory to allowed paths + if let Ok(cwd) = os.env.current_dir() { + allowed_paths.push(cwd.to_string_lossy().to_string()); + } + let allow_set = { let mut builder = GlobSetBuilder::new(); for path in &allowed_paths { @@ -259,10 +263,7 @@ impl FsRead { PermissionEvalResult::Ask }, } - }, - None if is_in_allowlist => PermissionEvalResult::Allow, - _ => PermissionEvalResult::Ask, - } + } } pub async fn invoke(&self, os: &Os, updates: &mut impl Write) -> Result { @@ -862,6 +863,7 @@ fn format_mode(mode: u32) -> [char; 9] { #[cfg(test)] mod tests { use std::collections::HashMap; + use std::path::PathBuf; use super::*; use crate::cli::agent::ToolSettingTarget; @@ -1397,7 +1399,7 @@ mod tests { } #[tokio::test] - async fn test_eval_perm() { + async fn test_eval_perm_denied_path() { const DENIED_PATH_OR_FILE: &str = "/some/denied/path"; const DENIED_PATH_OR_FILE_GLOB: &str = "/denied/glob/**/path"; @@ -1447,4 +1449,88 @@ mod tests { && deny_list.iter().filter(|p| *p == DENIED_PATH_OR_FILE).collect::>().len() == 2 )); } + + #[tokio::test] + async fn test_eval_perm_allowed_path_and_cwd() { + + // by default the fake env uses "/" as the CWD. + // change it to a sub folder so we can test fs_read reading files outside CWD + let os = Os::new().await.unwrap(); + os.env.set_current_dir_for_test(PathBuf::from("/home/user")); + + let agent = Agent { + name: "test_agent".to_string(), + tools_settings: { + let mut map = HashMap::new(); + map.insert( + ToolSettingTarget("fs_read".to_string()), + serde_json::json!({ + "allowedPaths": ["/explicitly/allowed/path"] + }), + ); + map + }, + ..Default::default() // Not in allowed_tools, allow_read_only = false + }; + + // Test 1: Explicitly allowed path should work + let allowed_tool = serde_json::from_value::(serde_json::json!({ + "operations": [ + { "path": "/explicitly/allowed/path", "mode": "Directory" }, + { "path": "/explicitly/allowed/path/file.txt", "mode": "Line" }, + ] + })).unwrap(); + let res = allowed_tool.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Allow)); + + // Test 2: CWD should always be allowed + let cwd_tool = serde_json::from_value::(serde_json::json!({ + "operations": [ + { "path": "/home/user/", "mode": "Directory" }, + { "path": "/home/user/file.txt", "mode": "Line" }, + ] + })).unwrap(); + let res = cwd_tool.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Allow)); + + // Test 3: Outside CWD and not explicitly allowed should ask + let outside_tool = serde_json::from_value::(serde_json::json!({ + "operations": [ + { "path": "/tmp/not/allowed/file.txt", "mode": "Line" } + ] + })).unwrap(); + let res = outside_tool.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Ask)); + } + + #[tokio::test] + async fn test_eval_perm_no_settings_cwd_behavior() { + let os = Os::new().await.unwrap(); + os.env.set_current_dir_for_test(PathBuf::from("/home/user")); + + let agent = Agent { + name: "test_agent".to_string(), + tools_settings: HashMap::new(), // No fs_read settings + ..Default::default() + }; + + // Test 1: CWD should be allowed even with no settings + let cwd_tool = serde_json::from_value::(serde_json::json!({ + "operations": [ + { "path": "/home/user/", "mode": "Directory" }, + { "path": "/home/user/file.txt", "mode": "Line" }, + ] + })).unwrap(); + let res = cwd_tool.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Allow)); + + // Test 2: Outside CWD should ask for permission + let outside_tool = serde_json::from_value::(serde_json::json!({ + "operations": [ + { "path": "/tmp/not/allowed/file.txt", "mode": "Line" } + ] + })).unwrap(); + let res = outside_tool.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Ask)); + } } diff --git a/crates/chat-cli/src/cli/chat/tools/mod.rs b/crates/chat-cli/src/cli/chat/tools/mod.rs index 6b8baec18f..43ccb006cf 100644 --- a/crates/chat-cli/src/cli/chat/tools/mod.rs +++ b/crates/chat-cli/src/cli/chat/tools/mod.rs @@ -57,7 +57,7 @@ use crate::cli::agent::{ use crate::cli::chat::line_tracker::FileLineTracker; use crate::os::Os; -pub const DEFAULT_APPROVE: [&str; 1] = ["fs_read"]; +pub const DEFAULT_APPROVE: [&str; 0] = []; pub const NATIVE_TOOLS: [&str; 8] = [ "fs_read", "fs_write", diff --git a/crates/chat-cli/src/os/env.rs b/crates/chat-cli/src/os/env.rs index 40aac3b5fe..63e5449419 100644 --- a/crates/chat-cli/src/os/env.rs +++ b/crates/chat-cli/src/os/env.rs @@ -132,6 +132,13 @@ impl Env { } } + pub fn set_current_dir_for_test(&self, path: PathBuf) { + use inner::Inner; + if let Inner::Fake(fake) = &self.0 { + fake.lock().unwrap().cwd = path; + } + } + pub fn current_exe(&self) -> Result { use inner::Inner; match &self.0 { From a9f170c571784f5da5604d22fb22f640b0ac4605 Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Mon, 15 Sep 2025 16:19:50 -0700 Subject: [PATCH 35/71] fixes remote mcp creds not being written when obtained (#2878) --- crates/chat-cli/src/mcp_client/client.rs | 46 ++++------- crates/chat-cli/src/mcp_client/oauth_util.rs | 82 +++++++------------- 2 files changed, 44 insertions(+), 84 deletions(-) diff --git a/crates/chat-cli/src/mcp_client/client.rs b/crates/chat-cli/src/mcp_client/client.rs index 52a7c89e38..36f667ef03 100644 --- a/crates/chat-cli/src/mcp_client/client.rs +++ b/crates/chat-cli/src/mcp_client/client.rs @@ -3,7 +3,6 @@ use std::collections::HashMap; use std::process::Stdio; use regex::Regex; -use reqwest::Client; use rmcp::model::{ CallToolRequestParam, CallToolResult, @@ -25,7 +24,6 @@ use rmcp::service::{ DynService, NotificationContext, }; -use rmcp::transport::auth::AuthClient; use rmcp::transport::{ ConfigureCommandExt, TokioChildProcess, @@ -52,7 +50,7 @@ use tracing::{ use super::messenger::Messenger; use super::oauth_util::HttpTransport; use super::{ - AuthClientDropGuard, + AuthClientWrapper, OauthUtilError, get_http_transport, }; @@ -175,10 +173,11 @@ macro_rules! decorate_with_auth_retry { Err(e) => { // TODO: discern error type prior to retrying // Not entirely sure what is thrown when auth is required - if let Some(auth_client) = self.get_auth_client() { - let refresh_result = auth_client.auth_manager.lock().await.refresh_token().await; + if let Some(auth_client) = self.auth_client.as_ref() { + let refresh_result = auth_client.refresh_token().await; match refresh_result { Ok(_) => { + info!("Token refreshed"); // Retry the operation after token refresh match &self.inner_service { InnerService::Original(rs) => rs.$method_name(param).await, @@ -245,20 +244,14 @@ impl Clone for InnerService { #[derive(Debug)] pub struct RunningService { pub inner_service: InnerService, - auth_dropguard: Option, + auth_client: Option, } impl Clone for RunningService { fn clone(&self) -> Self { - let auth_dropguard = self.auth_dropguard.as_ref().map(|dg| { - let mut dg = dg.clone(); - dg.should_write = false; - dg - }); - RunningService { inner_service: self.inner_service.clone(), - auth_dropguard, + auth_client: self.auth_client.clone(), } } } @@ -267,10 +260,6 @@ impl RunningService { decorate_with_auth_retry!(CallToolRequestParam, call_tool, CallToolResult); decorate_with_auth_retry!(GetPromptRequestParam, get_prompt, GetPromptResult); - - pub fn get_auth_client(&self) -> Option> { - self.auth_dropguard.as_ref().map(|a| a.auth_client.clone()) - } } pub type StdioTransport = (TokioChildProcess, Option); @@ -341,7 +330,7 @@ impl McpClientService { }, Transport::Http(http_transport) => { match http_transport { - HttpTransport::WithAuth((transport, mut auth_dg)) => { + HttpTransport::WithAuth((transport, mut auth_client)) => { // The crate does not automatically refresh tokens when they expire. We // would need to handle that here let url = self.config.url.clone(); @@ -349,8 +338,7 @@ impl McpClientService { Ok(service) => service, Err(e) if matches!(*e, ClientInitializeError::ConnectionClosed(_)) => { debug!("## mcp: first hand shake attempt failed: {:?}", e); - let refresh_res = - auth_dg.auth_client.auth_manager.lock().await.refresh_token().await; + let refresh_res = auth_client.refresh_token().await; let new_self = McpClientService::new( server_name.clone(), backup_config, @@ -358,15 +346,14 @@ impl McpClientService { ); let new_transport = - get_http_transport(&os_clone, true, &url, Some(auth_dg.auth_client.clone()), &*messenger_dup).await?; + get_http_transport(&os_clone, true, &url, Some(auth_client.auth_client.clone()), &*messenger_dup).await?; match new_transport { - HttpTransport::WithAuth((new_transport, new_auth_dg)) => { - auth_dg.should_write = false; - auth_dg = new_auth_dg; + HttpTransport::WithAuth((new_transport, new_auth_client)) => { + auth_client = new_auth_client; match refresh_res { - Ok(_token) => { + Ok(_) => { new_self.into_dyn().serve(new_transport).await.map_err(Box::new)? }, Err(e) => { @@ -379,9 +366,8 @@ impl McpClientService { get_http_transport(&os_clone, true, &url, None, &*messenger_dup).await?; match new_transport { - HttpTransport::WithAuth((new_transport, new_auth_dg)) => { - auth_dg = new_auth_dg; - auth_dg.should_write = false; + HttpTransport::WithAuth((new_transport, new_auth_client)) => { + auth_client = new_auth_client; new_self.into_dyn().serve(new_transport).await.map_err(Box::new)? }, HttpTransport::WithoutAuth(new_transport) => { @@ -398,7 +384,7 @@ impl McpClientService { Err(e) => return Err(e.into()), }; - (service, None, Some(auth_dg)) + (service, None, Some(auth_client)) }, HttpTransport::WithoutAuth(transport) => { let service = self.into_dyn().serve(transport).await.map_err(Box::new)?; @@ -496,7 +482,7 @@ impl McpClientService { Ok(RunningService { inner_service: InnerService::Original(service), - auth_dropguard, + auth_client: auth_dropguard, }) }); diff --git a/crates/chat-cli/src/mcp_client/oauth_util.rs b/crates/chat-cli/src/mcp_client/oauth_util.rs index 4c862fb67a..d7b3b50bbd 100644 --- a/crates/chat-cli/src/mcp_client/oauth_util.rs +++ b/crates/chat-cli/src/mcp_client/oauth_util.rs @@ -71,6 +71,8 @@ pub enum OauthUtilError { Reqwest(#[from] reqwest::Error), #[error("Malformed directory")] MalformDirectory, + #[error("Missing credential")] + MissingCredentials, } /// A guard that automatically cancels the cancellation token when dropped. @@ -107,68 +109,36 @@ impl From for Registration { } } -/// A guard that manages the lifecycle of an authenticated MCP client and automatically -/// persists OAuth credentials when dropped. +/// A wrapper that manages an authenticated MCP client. /// -/// This struct wraps an `AuthClient` and ensures that OAuth tokens are written to disk -/// when the guard goes out of scope, unless explicitly disabled via `should_write`. -/// This provides automatic credential caching for MCP server connections that require -/// OAuth authentication. +/// This struct wraps an `AuthClient` and provides access to OAuth credentials +/// for MCP server connections that require authentication. The credentials +/// are managed separately from this wrapper's lifecycle. #[derive(Clone, Debug)] -pub struct AuthClientDropGuard { - pub should_write: bool, +pub struct AuthClientWrapper { pub cred_full_path: PathBuf, pub auth_client: AuthClient, } -impl AuthClientDropGuard { +impl AuthClientWrapper { pub fn new(cred_full_path: PathBuf, auth_client: AuthClient) -> Self { Self { - should_write: true, cred_full_path, auth_client, } } -} -impl Drop for AuthClientDropGuard { - fn drop(&mut self) { - if !self.should_write { - return; - } + /// Refreshes token in memory using the registration read from when the auth client was + /// spawned. This also persists the retrieved token + pub async fn refresh_token(&self) -> Result<(), OauthUtilError> { + let cred = self.auth_client.auth_manager.lock().await.refresh_token().await?; + let parent_path = self.cred_full_path.parent().ok_or(OauthUtilError::MalformDirectory)?; + tokio::fs::create_dir_all(parent_path).await?; - let auth_client_clone = self.auth_client.clone(); - let path = self.cred_full_path.clone(); + let cred_as_bytes = serde_json::to_string_pretty(&cred)?; + tokio::fs::write(&self.cred_full_path, &cred_as_bytes).await?; - tokio::spawn(async move { - let Ok((client_id, cred)) = auth_client_clone.auth_manager.lock().await.get_credentials().await else { - error!("Failed to retrieve credentials in drop routine"); - return; - }; - let Some(cred) = cred else { - error!("Failed to retrieve credentials in drop routine from {client_id}"); - return; - }; - let Some(parent_path) = path.parent() else { - error!("Failed to retrieve parent path for token in drop routine for {client_id}"); - return; - }; - if let Err(e) = tokio::fs::create_dir_all(parent_path).await { - error!("Error making parent directory for token cache in drop routine for {client_id}: {e}"); - return; - } - - let serialized_cred = match serde_json::to_string_pretty(&cred) { - Ok(cred) => cred, - Err(e) => { - error!("Failed to serialize credentials for {client_id}: {e}"); - return; - }, - }; - if let Err(e) = tokio::fs::write(path, &serialized_cred).await { - error!("Error making writing token cache in drop routine: {e}"); - } - }); + Ok(()) } } @@ -186,7 +156,7 @@ pub enum HttpTransport { WithAuth( ( WorkerTransport>>, - AuthClientDropGuard, + AuthClientWrapper, ), ), WithoutAuth(WorkerTransport>), @@ -233,7 +203,7 @@ pub async fn get_http_transport( ..Default::default() }); - let auth_dg = AuthClientDropGuard::new(cred_full_path, auth_client); + let auth_dg = AuthClientWrapper::new(cred_full_path, auth_client); debug!("## mcp: transport obtained"); Ok(HttpTransport::WithAuth((transport, auth_dg))) @@ -276,11 +246,8 @@ async fn get_auth_manager( // Client registration is done in [start_authorization] // If we have gotten past that point that means we have the info to persist the - // registration on disk. These are info that we need to refresh stake - // tokens. This is in contrast to tokens, which we only persist when we drop - // the client (because that way we can write once and ensure what is on the - // disk always the most up to date) - let (client_id, _credentials) = am.get_credentials().await?; + // registration on disk. + let (client_id, credentials) = am.get_credentials().await?; let reg = Registration { client_id, client_secret: None, @@ -292,6 +259,13 @@ async fn get_auth_manager( tokio::fs::create_dir_all(reg_parent_path).await?; tokio::fs::write(reg_full_path, ®_as_str).await?; + let credentials = credentials.ok_or(OauthUtilError::MissingCredentials)?; + + let cred_parent_path = cred_full_path.parent().ok_or(OauthUtilError::MalformDirectory)?; + tokio::fs::create_dir_all(cred_parent_path).await?; + let reg_as_str = serde_json::to_string_pretty(&credentials)?; + tokio::fs::write(cred_full_path, ®_as_str).await?; + Ok(am) }, } From 5b379eb2b81497dfd634882675dcc81937612cf8 Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Mon, 15 Sep 2025 17:11:18 -0700 Subject: [PATCH 36/71] fixes bug where refreshed credentials gets deleted (#2879) --- crates/chat-cli/src/mcp_client/client.rs | 6 +++--- crates/chat-cli/src/mcp_client/oauth_util.rs | 5 ----- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/crates/chat-cli/src/mcp_client/client.rs b/crates/chat-cli/src/mcp_client/client.rs index 36f667ef03..dc1fc1dd2f 100644 --- a/crates/chat-cli/src/mcp_client/client.rs +++ b/crates/chat-cli/src/mcp_client/client.rs @@ -346,7 +346,7 @@ impl McpClientService { ); let new_transport = - get_http_transport(&os_clone, true, &url, Some(auth_client.auth_client.clone()), &*messenger_dup).await?; + get_http_transport(&os_clone, &url, Some(auth_client.auth_client.clone()), &*messenger_dup).await?; match new_transport { HttpTransport::WithAuth((new_transport, new_auth_client)) => { @@ -363,7 +363,7 @@ impl McpClientService { // case we would need to have user go through the auth flow // again let new_transport = - get_http_transport(&os_clone, true, &url, None, &*messenger_dup).await?; + get_http_transport(&os_clone, &url, None, &*messenger_dup).await?; match new_transport { HttpTransport::WithAuth((new_transport, new_auth_client)) => { @@ -519,7 +519,7 @@ impl McpClientService { Ok(Transport::Stdio((tokio_child_process, child_stderr))) }, TransportType::Http => { - let http_transport = get_http_transport(os, false, url, None, messenger).await?; + let http_transport = get_http_transport(os, url, None, messenger).await?; Ok(Transport::Http(http_transport)) }, diff --git a/crates/chat-cli/src/mcp_client/oauth_util.rs b/crates/chat-cli/src/mcp_client/oauth_util.rs index d7b3b50bbd..f284265cf0 100644 --- a/crates/chat-cli/src/mcp_client/oauth_util.rs +++ b/crates/chat-cli/src/mcp_client/oauth_util.rs @@ -168,7 +168,6 @@ fn get_scopes() -> &'static [&'static str] { pub async fn get_http_transport( os: &Os, - delete_cache: bool, url: &str, auth_client: Option>, messenger: &dyn Messenger, @@ -179,10 +178,6 @@ pub async fn get_http_transport( let cred_full_path = cred_dir.join(format!("{key}.token.json")); let reg_full_path = cred_dir.join(format!("{key}.registration.json")); - if delete_cache && cred_full_path.is_file() { - tokio::fs::remove_file(&cred_full_path).await?; - } - let reqwest_client = reqwest::Client::default(); let probe_resp = reqwest_client.get(url.clone()).send().await?; match probe_resp.status() { From 0de1451e152b70c2149d9917bc9c6e16eb2fbf5a Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Mon, 15 Sep 2025 17:31:02 -0700 Subject: [PATCH 37/71] fix(mcp): bug where refreshed credentials gets deleted (#2880) --- crates/chat-cli/src/mcp_client/client.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/crates/chat-cli/src/mcp_client/client.rs b/crates/chat-cli/src/mcp_client/client.rs index dc1fc1dd2f..fbf5d745de 100644 --- a/crates/chat-cli/src/mcp_client/client.rs +++ b/crates/chat-cli/src/mcp_client/client.rs @@ -361,7 +361,9 @@ impl McpClientService { info!("Retry for http transport failed {e}. Possible reauth needed"); // This could be because the refresh token is expired, in which // case we would need to have user go through the auth flow - // again + // again. We do this by deleting the cred + // and discarding the client to trigger a full auth flow + tokio::fs::remove_file(&auth_client.cred_full_path).await?; let new_transport = get_http_transport(&os_clone, &url, None, &*messenger_dup).await?; From 63f022a477c527dca1df9788829c704afa73a5b7 Mon Sep 17 00:00:00 2001 From: Kenneth Sanchez V Date: Tue, 16 Sep 2025 10:20:14 -0700 Subject: [PATCH 38/71] Update bm25 to v2.3.2 and ignore unmaintained fxhash advisory (#2872) - Updated bm25 from v2.2.1 to v2.3.2 (latest version) - Added RUSTSEC-2025-0057 to deny.toml ignore list for unmaintained fxhash - fxhash is a transitive dependency of bm25, waiting for upstream migration to rustc-hash Co-authored-by: Kenneth S. --- Cargo.lock | 10 +++++----- crates/semantic-search-client/Cargo.toml | 2 +- deny.toml | 2 ++ 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 869e29ea84..049bc6c0f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -948,9 +948,9 @@ dependencies = [ [[package]] name = "bm25" -version = "2.3.1" +version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b84ff0d57042bc263e2ebadb3703424b59b65870902649a2b3d0f4d7ab863244" +checksum = "1cbd8ffdfb7b4c2ff038726178a780a94f90525ed0ad264c0afaa75dd8c18a64" dependencies = [ "cached", "deunicode", @@ -4237,7 +4237,7 @@ version = "5.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51e219e79014df21a225b1860a479e2dcd7cbd9130f4defd4bd0e191ea31d67d" dependencies = [ - "base64 0.21.7", + "base64 0.22.1", "chrono", "getrandom 0.2.16", "http 1.3.1", @@ -6008,9 +6008,9 @@ checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" [[package]] name = "stop-words" -version = "0.8.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c6a86be9f7fa4559b7339669e72026eb437f5e9c5a85c207fe1033079033a17" +checksum = "645a3d441ccf4bf47f2e4b7681461986681a6eeea9937d4c3bc9febd61d17c71" dependencies = [ "serde_json", ] diff --git a/crates/semantic-search-client/Cargo.toml b/crates/semantic-search-client/Cargo.toml index 5024a5f4c6..fcf6c49447 100644 --- a/crates/semantic-search-client/Cargo.toml +++ b/crates/semantic-search-client/Cargo.toml @@ -30,7 +30,7 @@ glob.workspace = true hnsw_rs = "=0.3.1" # BM25 implementation - works on all platforms including ARM -bm25 = { version = "2.2.1", features = ["language_detection"] } +bm25 = { version = "2.3.2", features = ["language_detection"] } # Common dependencies for all platforms anyhow = "1.0" diff --git a/deny.toml b/deny.toml index 60971d3c22..6a974b0974 100644 --- a/deny.toml +++ b/deny.toml @@ -27,6 +27,8 @@ ignore = [ "RUSTSEC-2024-0429", # paste is used in core deps "RUSTSEC-2024-0436", + # fxhash is unmaintained but used by bm25, waiting for bm25 to migrate + "RUSTSEC-2025-0057", ] [licenses] From 214b3f76f5a893c0a772d071767eb58e47e430d8 Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Tue, 16 Sep 2025 10:47:02 -0700 Subject: [PATCH 39/71] adds temp message to still loading section of /tools (#2881) --- crates/chat-cli/src/cli/chat/cli/tools.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/crates/chat-cli/src/cli/chat/cli/tools.rs b/crates/chat-cli/src/cli/chat/cli/tools.rs index 717649070d..1f1e5267ff 100644 --- a/crates/chat-cli/src/cli/chat/cli/tools.rs +++ b/crates/chat-cli/src/cli/chat/cli/tools.rs @@ -147,7 +147,11 @@ impl ToolsArgs { queue!( session.stderr, style::SetAttribute(Attribute::Bold), - style::Print("Servers still loading"), + style::Print("Servers loading (Some of these might need auth. See "), + style::SetForegroundColor(Color::Green), + style::Print("/mcp"), + style::SetForegroundColor(Color::Reset), + style::Print(" for details)"), style::SetAttribute(Attribute::Reset), style::Print("\n"), style::Print("▔".repeat(terminal_width)), From 0c240550195e46b47750b5d7a7e5af3a7ebaf525 Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Tue, 16 Sep 2025 11:41:24 -0700 Subject: [PATCH 40/71] chore: removes codeowner for schema (#2892) --- .github/CODEOWNERS | 1 - 1 file changed, 1 deletion(-) delete mode 100644 .github/CODEOWNERS diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS deleted file mode 100644 index 8d5534ec90..0000000000 --- a/.github/CODEOWNERS +++ /dev/null @@ -1 +0,0 @@ -schemas/ @chaynabors From 584def035172ec12c88d51a2d3b345fe182ae12d Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Tue, 16 Sep 2025 11:46:00 -0700 Subject: [PATCH 41/71] chore(agent): updates agent config schema (#2891) --- .../src/cli/chat/tools/custom_tool.rs | 3 ++- schemas/agent-v1.json | 27 ++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index be5acd37fd..907f70d35c 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -44,7 +44,8 @@ impl Default for TransportType { #[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)] pub struct CustomToolConfig { - /// The type of transport the mcp server is expecting + /// The type of transport the mcp server is expecting. For http transport, only url (for now) + /// is taken into account. #[serde(default)] pub r#type: TransportType, /// The URL endpoint for HTTP-based MCP servers diff --git a/schemas/agent-v1.json b/schemas/agent-v1.json index d49fc53406..c7b63492d3 100644 --- a/schemas/agent-v1.json +++ b/schemas/agent-v1.json @@ -16,6 +16,20 @@ }, "required": ["command"] } + }, + "TransportType": { + "oneOf": [ + { + "description": "Standard input/output transport (default)", + "type": "string", + "const": "stdio" + }, + { + "description": "HTTP transport for web-based communication", + "type": "string", + "const": "http" + } + ] } }, "properties": { @@ -49,9 +63,20 @@ "additionalProperties": { "type": "object", "properties": { + "type": { + "description": "The type of transport the mcp server is expecting. For http transport, only url (for now) is taken into account", + "$ref": "#/$definitions/TransportType", + "default": "stdio" + }, + "url": { + "description": "The URL endpoint for HTTP-based MCP servers", + "type": "string", + "default": "" + }, "command": { "description": "The command string used to initialize the mcp server", - "type": "string" + "type": "string", + "default": "" }, "args": { "description": "A list of arguments to be used to run the command with", From 8417df71be2c2090b5fd08c91561fda44d7e7566 Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Tue, 16 Sep 2025 14:34:59 -0700 Subject: [PATCH 42/71] chore: version bump (#2894) --- Cargo.lock | 4 +- Cargo.toml | 2 +- crates/chat-cli/src/cli/feed.json | 68 +++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 049bc6c0f7..cc2f536f71 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1207,7 +1207,7 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chat_cli" -version = "1.15.0" +version = "1.16.0" dependencies = [ "amzn-codewhisperer-client", "amzn-codewhisperer-streaming-client", @@ -5647,7 +5647,7 @@ dependencies = [ [[package]] name = "semantic_search_client" -version = "1.15.0" +version = "1.16.0" dependencies = [ "anyhow", "bm25", diff --git a/Cargo.toml b/Cargo.toml index 99d56615c5..f35a1a154d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ authors = ["Amazon Q CLI Team (q-cli@amazon.com)", "Chay Nabors (nabochay@amazon edition = "2024" homepage = "https://aws.amazon.com/q/" publish = false -version = "1.15.0" +version = "1.16.0" license = "MIT OR Apache-2.0" [workspace.dependencies] diff --git a/crates/chat-cli/src/cli/feed.json b/crates/chat-cli/src/cli/feed.json index 7baece70c0..d4c928005c 100644 --- a/crates/chat-cli/src/cli/feed.json +++ b/crates/chat-cli/src/cli/feed.json @@ -10,6 +10,74 @@ "hidden": true, "changes": [] }, + { + "type": "release", + "date": "2025-09-16", + "version": "1.16.0", + "title": "Version 1.16.0", + "changes": [ + { + "type": "added", + "description": "Support for remote MCP connections - [#2836](https://github.com/aws/amazon-q-developer-cli/pull/2836)" + }, + { + "type": "added", + "description": "A new `/tangent tail` command to preserve the last tangent conversation - [#2838](https://github.com/aws/amazon-q-developer-cli/pull/2838)" + }, + { + "type": "added", + "description": "A new edit subcommand to `/agent` slash command for modifying existing agents - [#2854](https://github.com/aws/amazon-q-developer-cli/pull/2854)" + }, + { + "type": "added", + "description": "A new auto-announcement feature with `/changelog` command - [#2833](https://github.com/aws/amazon-q-developer-cli/pull/2833)" + }, + { + "type": "added", + "description": "A new CLI history persistence feature with file storage - [#2769](https://github.com/aws/amazon-q-developer-cli/pull/2769)" + }, + { + "type": "added", + "description": "Support for comma-containing arguments in MCP --args parameter - [#2754](https://github.com/aws/amazon-q-developer-cli/pull/2754)" + }, + { + "type": "added", + "description": "Support for configurable autoAllowReadonly setting in use_aws tool - [#2828](https://github.com/aws/amazon-q-developer-cli/pull/2828)" + }, + { + "type": "added", + "description": "Support for configurable line wrapping in chat interface - [#2816](https://github.com/aws/amazon-q-developer-cli/pull/2816)" + }, + { + "type": "added", + "description": "Support for model field in agent configuration format - [#2815](https://github.com/aws/amazon-q-developer-cli/pull/2815)" + }, + { + "type": "added", + "description": "AGENTS.md documentation to default agent resources - [#2812](https://github.com/aws/amazon-q-developer-cli/pull/2812)" + }, + { + "type": "security", + "description": "Reduced default fs_read trust permission to current working directory only - [#2824](https://github.com/aws/amazon-q-developer-cli/pull/2824)" + }, + { + "type": "security", + "description": "Changed autoAllowReadonly default to false for security in execute_bash - [#2846](https://github.com/aws/amazon-q-developer-cli/pull/2846)" + }, + { + "type": "security", + "description": "Updated dangerous patterns for execute_bash to include $ character - [#2811](https://github.com/aws/amazon-q-developer-cli/pull/2811)" + }, + { + "type": "fixed", + "description": "Path with trailing slash not being handled in file matching - [#2817](https://github.com/aws/amazon-q-developer-cli/pull/2817)" + }, + { + "type": "fixed", + "description": "Summary being erroneously preserved when conversation is cleared - [#2793](https://github.com/aws/amazon-q-developer-cli/pull/2793)" + } + ] + }, { "type": "release", "date": "2025-09-02", From 864a27b208f751f82d60c60483d158673ce7dd63 Mon Sep 17 00:00:00 2001 From: evanliu048 Date: Tue, 16 Sep 2025 14:50:16 -0700 Subject: [PATCH 43/71] fix format (#2895) --- crates/chat-cli/src/cli/chat/tools/fs_read.rs | 296 +++++++++--------- 1 file changed, 150 insertions(+), 146 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/tools/fs_read.rs b/crates/chat-cli/src/cli/chat/tools/fs_read.rs index c4e60ae6cc..5c5c0abf25 100644 --- a/crates/chat-cli/src/cli/chat/tools/fs_read.rs +++ b/crates/chat-cli/src/cli/chat/tools/fs_read.rs @@ -114,156 +114,156 @@ impl FsRead { } let is_in_allowlist = matches_any_pattern(&agent.allowed_tools, "fs_read"); - let settings = agent.tools_settings.get("fs_read").cloned() + let settings = agent + .tools_settings + .get("fs_read") + .cloned() .unwrap_or_else(|| serde_json::json!({})); - + { - let Settings { - mut allowed_paths, - denied_paths, - allow_read_only, - } = match serde_json::from_value::(settings) { - Ok(settings) => settings, - Err(e) => { - error!("Failed to deserialize tool settings for fs_read: {:?}", e); - return PermissionEvalResult::Ask; - }, - }; - - // Always add current working directory to allowed paths - if let Ok(cwd) = os.env.current_dir() { - allowed_paths.push(cwd.to_string_lossy().to_string()); + let Settings { + mut allowed_paths, + denied_paths, + allow_read_only, + } = match serde_json::from_value::(settings) { + Ok(settings) => settings, + Err(e) => { + error!("Failed to deserialize tool settings for fs_read: {:?}", e); + return PermissionEvalResult::Ask; + }, + }; + + // Always add current working directory to allowed paths + if let Ok(cwd) = os.env.current_dir() { + allowed_paths.push(cwd.to_string_lossy().to_string()); + } + + let allow_set = { + let mut builder = GlobSetBuilder::new(); + for path in &allowed_paths { + let Ok(path) = directories::canonicalizes_path(os, path) else { + continue; + }; + if let Err(e) = directories::add_gitignore_globs(&mut builder, path.as_str()) { + warn!("Failed to create glob from path given: {path}: {e}. Ignoring."); + } } - - let allow_set = { - let mut builder = GlobSetBuilder::new(); - for path in &allowed_paths { - let Ok(path) = directories::canonicalizes_path(os, path) else { - continue; - }; - if let Err(e) = directories::add_gitignore_globs(&mut builder, path.as_str()) { - warn!("Failed to create glob from path given: {path}: {e}. Ignoring."); - } + builder.build() + }; + + let mut sanitized_deny_list = Vec::<&String>::new(); + let deny_set = { + let mut builder = GlobSetBuilder::new(); + for path in &denied_paths { + let Ok(processed_path) = directories::canonicalizes_path(os, path) else { + continue; + }; + match directories::add_gitignore_globs(&mut builder, processed_path.as_str()) { + Ok(_) => { + // Note that we need to push twice here because for each rule we + // are creating two globs (one for file and one for directory) + sanitized_deny_list.push(path); + sanitized_deny_list.push(path); + }, + Err(e) => warn!("Failed to create glob from path given: {path}: {e}. Ignoring."), } - builder.build() - }; + } + builder.build() + }; - let mut sanitized_deny_list = Vec::<&String>::new(); - let deny_set = { - let mut builder = GlobSetBuilder::new(); - for path in &denied_paths { - let Ok(processed_path) = directories::canonicalizes_path(os, path) else { - continue; - }; - match directories::add_gitignore_globs(&mut builder, processed_path.as_str()) { - Ok(_) => { - // Note that we need to push twice here because for each rule we - // are creating two globs (one for file and one for directory) - sanitized_deny_list.push(path); - sanitized_deny_list.push(path); + match (allow_set, deny_set) { + (Ok(allow_set), Ok(deny_set)) => { + let mut deny_list = Vec::::new(); + let mut ask = false; + + for op in &self.operations { + match op { + FsReadOperation::Line(FsLine { path, .. }) + | FsReadOperation::Directory(FsDirectory { path, .. }) + | FsReadOperation::Search(FsSearch { path, .. }) => { + let Ok(path) = directories::canonicalizes_path(os, path) else { + ask = true; + continue; + }; + let denied_match_set = deny_set.matches(path.as_ref() as &str); + if !denied_match_set.is_empty() { + let deny_res = PermissionEvalResult::Deny({ + denied_match_set + .iter() + .filter_map(|i| sanitized_deny_list.get(*i).map(|s| (*s).clone())) + .collect::>() + }); + deny_list.push(deny_res); + continue; + } + + // We only want to ask if we are not allowing read only + // operation + if !is_in_allowlist && !allow_read_only && !allow_set.is_match(path.as_ref() as &str) { + ask = true; + } + }, + FsReadOperation::Image(fs_image) => { + let paths = &fs_image.image_paths; + let denied_match_set = paths + .iter() + .flat_map(|path| { + let Ok(path) = directories::canonicalizes_path(os, path) else { + return vec![]; + }; + deny_set.matches(path.as_ref() as &str) + }) + .collect::>(); + if !denied_match_set.is_empty() { + let deny_res = PermissionEvalResult::Deny({ + denied_match_set + .iter() + .filter_map(|i| sanitized_deny_list.get(*i).map(|s| (*s).clone())) + .collect::>() + }); + deny_list.push(deny_res); + continue; + } + + // We only want to ask if we are not allowing read only + // operation + if !is_in_allowlist + && !allow_read_only + && !paths.iter().any(|path| allow_set.is_match(path)) + { + ask = true; + } }, - Err(e) => warn!("Failed to create glob from path given: {path}: {e}. Ignoring."), } } - builder.build() - }; - - match (allow_set, deny_set) { - (Ok(allow_set), Ok(deny_set)) => { - let mut deny_list = Vec::::new(); - let mut ask = false; - - for op in &self.operations { - match op { - FsReadOperation::Line(FsLine { path, .. }) - | FsReadOperation::Directory(FsDirectory { path, .. }) - | FsReadOperation::Search(FsSearch { path, .. }) => { - let Ok(path) = directories::canonicalizes_path(os, path) else { - ask = true; - continue; - }; - let denied_match_set = deny_set.matches(path.as_ref() as &str); - if !denied_match_set.is_empty() { - let deny_res = PermissionEvalResult::Deny({ - denied_match_set - .iter() - .filter_map(|i| sanitized_deny_list.get(*i).map(|s| (*s).clone())) - .collect::>() - }); - deny_list.push(deny_res); - continue; - } - - // We only want to ask if we are not allowing read only - // operation - if !is_in_allowlist - && !allow_read_only - && !allow_set.is_match(path.as_ref() as &str) - { - ask = true; - } - }, - FsReadOperation::Image(fs_image) => { - let paths = &fs_image.image_paths; - let denied_match_set = paths - .iter() - .flat_map(|path| { - let Ok(path) = directories::canonicalizes_path(os, path) else { - return vec![]; - }; - deny_set.matches(path.as_ref() as &str) - }) - .collect::>(); - if !denied_match_set.is_empty() { - let deny_res = PermissionEvalResult::Deny({ - denied_match_set - .iter() - .filter_map(|i| sanitized_deny_list.get(*i).map(|s| (*s).clone())) - .collect::>() - }); - deny_list.push(deny_res); - continue; - } - - // We only want to ask if we are not allowing read only - // operation - if !is_in_allowlist - && !allow_read_only - && !paths.iter().any(|path| allow_set.is_match(path)) - { - ask = true; - } - }, - } - } - if !deny_list.is_empty() { - PermissionEvalResult::Deny({ - deny_list.into_iter().fold(Vec::::new(), |mut acc, res| { - if let PermissionEvalResult::Deny(mut rules) = res { - acc.append(&mut rules); - } - acc - }) + if !deny_list.is_empty() { + PermissionEvalResult::Deny({ + deny_list.into_iter().fold(Vec::::new(), |mut acc, res| { + if let PermissionEvalResult::Deny(mut rules) = res { + acc.append(&mut rules); + } + acc }) - } else if ask { - PermissionEvalResult::Ask - } else { - PermissionEvalResult::Allow - } - }, - (allow_res, deny_res) => { - if let Err(e) = allow_res { - warn!("fs_read failed to build allow set: {:?}", e); - } - if let Err(e) = deny_res { - warn!("fs_read failed to build deny set: {:?}", e); - } - warn!("One or more detailed args failed to parse, falling back to ask"); + }) + } else if ask { PermissionEvalResult::Ask - }, - } + } else { + PermissionEvalResult::Allow + } + }, + (allow_res, deny_res) => { + if let Err(e) = allow_res { + warn!("fs_read failed to build allow set: {:?}", e); + } + if let Err(e) = deny_res { + warn!("fs_read failed to build deny set: {:?}", e); + } + warn!("One or more detailed args failed to parse, falling back to ask"); + PermissionEvalResult::Ask + }, } + } } pub async fn invoke(&self, os: &Os, updates: &mut impl Write) -> Result { @@ -1452,12 +1452,11 @@ mod tests { #[tokio::test] async fn test_eval_perm_allowed_path_and_cwd() { - // by default the fake env uses "/" as the CWD. // change it to a sub folder so we can test fs_read reading files outside CWD let os = Os::new().await.unwrap(); os.env.set_current_dir_for_test(PathBuf::from("/home/user")); - + let agent = Agent { name: "test_agent".to_string(), tools_settings: { @@ -1479,7 +1478,8 @@ mod tests { { "path": "/explicitly/allowed/path", "mode": "Directory" }, { "path": "/explicitly/allowed/path/file.txt", "mode": "Line" }, ] - })).unwrap(); + })) + .unwrap(); let res = allowed_tool.eval_perm(&os, &agent); assert!(matches!(res, PermissionEvalResult::Allow)); @@ -1489,7 +1489,8 @@ mod tests { { "path": "/home/user/", "mode": "Directory" }, { "path": "/home/user/file.txt", "mode": "Line" }, ] - })).unwrap(); + })) + .unwrap(); let res = cwd_tool.eval_perm(&os, &agent); assert!(matches!(res, PermissionEvalResult::Allow)); @@ -1498,7 +1499,8 @@ mod tests { "operations": [ { "path": "/tmp/not/allowed/file.txt", "mode": "Line" } ] - })).unwrap(); + })) + .unwrap(); let res = outside_tool.eval_perm(&os, &agent); assert!(matches!(res, PermissionEvalResult::Ask)); } @@ -1507,7 +1509,7 @@ mod tests { async fn test_eval_perm_no_settings_cwd_behavior() { let os = Os::new().await.unwrap(); os.env.set_current_dir_for_test(PathBuf::from("/home/user")); - + let agent = Agent { name: "test_agent".to_string(), tools_settings: HashMap::new(), // No fs_read settings @@ -1520,7 +1522,8 @@ mod tests { { "path": "/home/user/", "mode": "Directory" }, { "path": "/home/user/file.txt", "mode": "Line" }, ] - })).unwrap(); + })) + .unwrap(); let res = cwd_tool.eval_perm(&os, &agent); assert!(matches!(res, PermissionEvalResult::Allow)); @@ -1529,7 +1532,8 @@ mod tests { "operations": [ { "path": "/tmp/not/allowed/file.txt", "mode": "Line" } ] - })).unwrap(); + })) + .unwrap(); let res = outside_tool.eval_perm(&os, &agent); assert!(matches!(res, PermissionEvalResult::Ask)); } From 0aa3c2198334d91ef6eb80be7e384a7fe13f55e8 Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Wed, 17 Sep 2025 10:52:23 -0700 Subject: [PATCH 44/71] fix(mcp): shell expansion not being done for stdio command (#2915) --- crates/chat-cli/src/mcp_client/client.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/crates/chat-cli/src/mcp_client/client.rs b/crates/chat-cli/src/mcp_client/client.rs index fbf5d745de..4156cf34ae 100644 --- a/crates/chat-cli/src/mcp_client/client.rs +++ b/crates/chat-cli/src/mcp_client/client.rs @@ -60,7 +60,10 @@ use crate::cli::chat::tools::custom_tool::{ TransportType, }; use crate::os::Os; -use crate::util::directories::DirectoryError; +use crate::util::directories::{ + DirectoryError, + canonicalizes_path, +}; /// Fetches all pages of specified resources from a server macro_rules! paginated_fetch { @@ -504,7 +507,8 @@ impl McpClientService { match transport_type { TransportType::Stdio => { - let command = Command::new(command_as_str).configure(|cmd| { + let expanded_cmd = canonicalizes_path(os, command_as_str)?; + let command = Command::new(expanded_cmd).configure(|cmd| { if let Some(envs) = config_envs { process_env_vars(envs, &os.env); cmd.envs(envs); From 6e8f9c99ff0b5a9ad9317b068b827d2143dbbece Mon Sep 17 00:00:00 2001 From: kkashilk <93673379+kkashilk@users.noreply.github.com> Date: Wed, 17 Sep 2025 11:16:06 -0700 Subject: [PATCH 45/71] Bump up version for hotfix (#2916) --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cc2f536f71..f85bfb292f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1207,7 +1207,7 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chat_cli" -version = "1.16.0" +version = "1.16.1" dependencies = [ "amzn-codewhisperer-client", "amzn-codewhisperer-streaming-client", @@ -5647,7 +5647,7 @@ dependencies = [ [[package]] name = "semantic_search_client" -version = "1.16.0" +version = "1.16.1" dependencies = [ "anyhow", "bm25", diff --git a/Cargo.toml b/Cargo.toml index f35a1a154d..bc7b16ff38 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ authors = ["Amazon Q CLI Team (q-cli@amazon.com)", "Chay Nabors (nabochay@amazon edition = "2024" homepage = "https://aws.amazon.com/q/" publish = false -version = "1.16.0" +version = "1.16.1" license = "MIT OR Apache-2.0" [workspace.dependencies] From 08c9a45d31b0054e74e8da705f3a673f53600bb4 Mon Sep 17 00:00:00 2001 From: Brandon Kiser <51934408+brandonskiser@users.noreply.github.com> Date: Wed, 17 Sep 2025 12:18:20 -0700 Subject: [PATCH 46/71] chore: add 1.16.1 feed.json entry (#2919) --- crates/chat-cli/src/cli/feed.json | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/crates/chat-cli/src/cli/feed.json b/crates/chat-cli/src/cli/feed.json index d4c928005c..a9d0f68eb6 100644 --- a/crates/chat-cli/src/cli/feed.json +++ b/crates/chat-cli/src/cli/feed.json @@ -10,6 +10,18 @@ "hidden": true, "changes": [] }, + { + "type": "release", + "date": "2025-09-17", + "version": "1.16.1", + "title": "Version 1.16.1", + "changes": [ + { + "type": "fixed", + "description": "Dashboard not updating after logging in - [#688](https://github.com/aws/amazon-q-developer-cli-autocomplete/pull/688)" + } + ] + }, { "type": "release", "date": "2025-09-16", From bc4ec5ce38e1b4bc920678be675e5e6cf352312e Mon Sep 17 00:00:00 2001 From: Erben Mo Date: Wed, 17 Sep 2025 15:24:37 -0700 Subject: [PATCH 47/71] Change autocomplete shortcut from ctrl-f to ctrl-g (#2825) * Change autocomplete shortcut from ctrl-f to ctrl-g The reason is ctrl-f is the standard shortcut in UNIX for moving cursor forward by 1 character. You can find it being supported everywhere... in your browser, your terminal, etc. * make the autocompletion key configurable --- crates/chat-cli/src/cli/chat/prompt.rs | 43 +++++++++++++++++++++--- crates/chat-cli/src/database/settings.rs | 4 +++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/prompt.rs b/crates/chat-cli/src/cli/chat/prompt.rs index e7addd8f32..06d449fdce 100644 --- a/crates/chat-cli/src/cli/chat/prompt.rs +++ b/crates/chat-cli/src/cli/chat/prompt.rs @@ -474,19 +474,23 @@ pub fn rl( EventHandler::Simple(Cmd::Insert(1, "\n".to_string())), ); - // Add custom keybinding for Ctrl+J to insert a newline + // Add custom keybinding for Ctrl+j to insert a newline rl.bind_sequence( KeyEvent(KeyCode::Char('j'), Modifiers::CTRL), EventHandler::Simple(Cmd::Insert(1, "\n".to_string())), ); - // Add custom keybinding for Ctrl+F to accept hint (like fish shell) + // Add custom keybinding for autocompletion hint acceptance (configurable) + let autocompletion_key_char = match os.database.settings.get_string(Setting::AutocompletionKey) { + Some(key) if key.len() == 1 => key.chars().next().unwrap_or('g'), + _ => 'g', // Default to 'g' if setting is missing or invalid + }; rl.bind_sequence( - KeyEvent(KeyCode::Char('f'), Modifiers::CTRL), + KeyEvent(KeyCode::Char(autocompletion_key_char), Modifiers::CTRL), EventHandler::Simple(Cmd::CompleteHint), ); - // Add custom keybinding for Ctrl+T to toggle tangent mode (configurable) + // Add custom keybinding for Ctrl+t to toggle tangent mode (configurable) let tangent_key_char = match os.database.settings.get_string(Setting::TangentModeKey) { Some(key) if key.len() == 1 => key.chars().next().unwrap_or('t'), _ => 't', // Default to 't' if setting is missing or invalid @@ -722,4 +726,35 @@ mod tests { let hint = hinter.hint(line, pos, &ctx); assert_eq!(hint, None); } + + #[tokio::test] + // If you get a unit test failure for key override, please consider using a new key binding instead. + // The list of reserved keybindings here are the standard in UNIX world so please don't take them + async fn test_no_emacs_keybindings_overridden() { + let (sender, _) = tokio::sync::broadcast::channel::(1); + let (_, receiver) = tokio::sync::broadcast::channel::(1); + + // Create a mock Os for testing + let mock_os = crate::os::Os::new().await.unwrap(); + let mut test_editor = rl(&mock_os, sender, receiver).unwrap(); + + // Reserved Emacs keybindings that should not be overridden + let reserved_keys = ['a', 'e', 'f', 'b', 'k']; + + for &key in &reserved_keys { + let key_event = KeyEvent(KeyCode::Char(key), Modifiers::CTRL); + + // Try to bind and get the previous handler + let previous_handler = test_editor.bind_sequence( + key_event, + EventHandler::Simple(Cmd::Noop) + ); + + // If there was a previous handler, it means the key was already bound + // (which could be our custom binding overriding Emacs) + if previous_handler.is_some() { + panic!("Ctrl+{} appears to be overridden (found existing binding)", key); + } + } + } } diff --git a/crates/chat-cli/src/database/settings.rs b/crates/chat-cli/src/database/settings.rs index 21e8e98097..bc005438cb 100644 --- a/crates/chat-cli/src/database/settings.rs +++ b/crates/chat-cli/src/database/settings.rs @@ -41,6 +41,8 @@ pub enum Setting { KnowledgeIndexType, #[strum(message = "Key binding for fuzzy search command (single character)")] SkimCommandKey, + #[strum(message = "Key binding for autocompletion hint acceptance (single character)")] + AutocompletionKey, #[strum(message = "Enable tangent mode feature (boolean)")] EnabledTangentMode, #[strum(message = "Key binding for tangent mode toggle (single character)")] @@ -94,6 +96,7 @@ impl AsRef for Setting { Self::KnowledgeChunkOverlap => "knowledge.chunkOverlap", Self::KnowledgeIndexType => "knowledge.indexType", Self::SkimCommandKey => "chat.skimCommandKey", + Self::AutocompletionKey => "chat.autocompletionKey", Self::EnabledTangentMode => "chat.enableTangentMode", Self::TangentModeKey => "chat.tangentModeKey", Self::IntrospectTangentMode => "introspect.tangentMode", @@ -139,6 +142,7 @@ impl TryFrom<&str> for Setting { "knowledge.chunkOverlap" => Ok(Self::KnowledgeChunkOverlap), "knowledge.indexType" => Ok(Self::KnowledgeIndexType), "chat.skimCommandKey" => Ok(Self::SkimCommandKey), + "chat.autocompletionKey" => Ok(Self::AutocompletionKey), "chat.enableTangentMode" => Ok(Self::EnabledTangentMode), "chat.tangentModeKey" => Ok(Self::TangentModeKey), "introspect.tangentMode" => Ok(Self::IntrospectTangentMode), From dd04793ef5785f3c917910c4008194f443ce9761 Mon Sep 17 00:00:00 2001 From: Erben Mo Date: Wed, 17 Sep 2025 18:02:12 -0700 Subject: [PATCH 48/71] feat: add support for preToolUse and postToolUse hook (#2875) --- crates/chat-cli/src/cli/agent/hook.rs | 16 +- crates/chat-cli/src/cli/agent/legacy/hooks.rs | 1 + crates/chat-cli/src/cli/agent/mod.rs | 77 +++- crates/chat-cli/src/cli/chat/cli/hooks.rs | 416 ++++++++++++++++-- crates/chat-cli/src/cli/chat/context.rs | 8 +- crates/chat-cli/src/cli/chat/conversation.rs | 12 +- crates/chat-cli/src/cli/chat/mod.rs | 348 +++++++++++++++ crates/chat-cli/src/cli/chat/tool_manager.rs | 5 +- .../src/cli/chat/tools/custom_tool.rs | 8 +- .../chat-cli/src/cli/chat/tools/introspect.rs | 5 + crates/chat-cli/src/cli/chat/tools/mod.rs | 1 + docs/agent-format.md | 31 +- docs/hooks.md | 161 +++++++ 13 files changed, 1037 insertions(+), 52 deletions(-) create mode 100644 docs/hooks.md diff --git a/crates/chat-cli/src/cli/agent/hook.rs b/crates/chat-cli/src/cli/agent/hook.rs index 1cb899f5a7..89ca74146b 100644 --- a/crates/chat-cli/src/cli/agent/hook.rs +++ b/crates/chat-cli/src/cli/agent/hook.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::fmt::Display; use schemars::JsonSchema; @@ -11,9 +10,6 @@ const DEFAULT_TIMEOUT_MS: u64 = 30_000; const DEFAULT_MAX_OUTPUT_SIZE: usize = 1024 * 10; const DEFAULT_CACHE_TTL_SECONDS: u64 = 0; -#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, JsonSchema)] -pub struct Hooks(HashMap); - #[derive(Debug, Clone, Copy, Serialize, Deserialize, Eq, PartialEq, JsonSchema, Hash)] #[serde(rename_all = "camelCase")] pub enum HookTrigger { @@ -21,6 +17,10 @@ pub enum HookTrigger { AgentSpawn, /// Triggered per user message submission UserPromptSubmit, + /// Triggered before tool execution + PreToolUse, + /// Triggered after tool execution + PostToolUse, } impl Display for HookTrigger { @@ -28,6 +28,8 @@ impl Display for HookTrigger { match self { HookTrigger::AgentSpawn => write!(f, "agentSpawn"), HookTrigger::UserPromptSubmit => write!(f, "userPromptSubmit"), + HookTrigger::PreToolUse => write!(f, "preToolUse"), + HookTrigger::PostToolUse => write!(f, "postToolUse"), } } } @@ -61,6 +63,11 @@ pub struct Hook { #[serde(default = "Hook::default_cache_ttl_seconds")] pub cache_ttl_seconds: u64, + /// Optional glob matcher for hook + /// Currently used for matching tool name of PreToolUse and PostToolUse hook + #[serde(skip_serializing_if = "Option::is_none")] + pub matcher: Option, + #[schemars(skip)] #[serde(default, skip_serializing)] pub source: Source, @@ -73,6 +80,7 @@ impl Hook { timeout_ms: Self::default_timeout_ms(), max_output_size: Self::default_max_output_size(), cache_ttl_seconds: Self::default_cache_ttl_seconds(), + matcher: None, source, } } diff --git a/crates/chat-cli/src/cli/agent/legacy/hooks.rs b/crates/chat-cli/src/cli/agent/legacy/hooks.rs index 2a7a639d7f..6929049b3a 100644 --- a/crates/chat-cli/src/cli/agent/legacy/hooks.rs +++ b/crates/chat-cli/src/cli/agent/legacy/hooks.rs @@ -80,6 +80,7 @@ impl From for Option { timeout_ms: value.timeout_ms, max_output_size: value.max_output_size, cache_ttl_seconds: value.cache_ttl_seconds, + matcher: None, source: Default::default(), }) } diff --git a/crates/chat-cli/src/cli/agent/mod.rs b/crates/chat-cli/src/cli/agent/mod.rs index 9ef82c15f2..3dcc659203 100644 --- a/crates/chat-cli/src/cli/agent/mod.rs +++ b/crates/chat-cli/src/cli/agent/mod.rs @@ -959,6 +959,7 @@ mod tests { use serde_json::json; use super::*; + use crate::cli::agent::hook::Source; const INPUT: &str = r#" { "name": "some_agent", @@ -968,21 +969,21 @@ mod tests { "fetch": { "command": "fetch3.1", "args": [] }, "git": { "command": "git-mcp", "args": [] } }, - "tools": [ + "tools": [ "@git" ], "toolAliases": { "@gits/some_tool": "some_tool2" }, - "allowedTools": [ - "fs_read", + "allowedTools": [ + "fs_read", "@fetch", "@gits/git_status" ], - "resources": [ + "resources": [ "file://~/my-genai-prompts/unittest.md" ], - "toolsSettings": { + "toolsSettings": { "fs_write": { "allowedPaths": ["~/**"] }, "@git/git_status": { "git_user": "$GIT_USER" } } @@ -1353,4 +1354,70 @@ mod tests { assert_eq!(agents.get_active().and_then(|a| a.model.as_ref()), None); } + + #[test] + fn test_agent_with_hooks() { + let agent_json = json!({ + "name": "test-agent", + "hooks": { + "agentSpawn": [ + { + "command": "git status" + } + ], + "preToolUse": [ + { + "matcher": "fs_write", + "command": "validate-tool.sh" + }, + { + "matcher": "fs_read", + "command": "enforce-tdd.sh" + } + ], + "postToolUse": [ + { + "matcher": "fs_write", + "command": "format-python.sh" + } + ] + } + }); + + let agent: Agent = serde_json::from_value(agent_json).expect("Failed to deserialize agent"); + + // Verify agent name + assert_eq!(agent.name, "test-agent"); + + // Verify agentSpawn hook + assert!(agent.hooks.contains_key(&HookTrigger::AgentSpawn)); + let agent_spawn_hooks = &agent.hooks[&HookTrigger::AgentSpawn]; + assert_eq!(agent_spawn_hooks.len(), 1); + assert_eq!(agent_spawn_hooks[0].command, "git status"); + assert_eq!(agent_spawn_hooks[0].matcher, None); + + // Verify preToolUse hooks + assert!(agent.hooks.contains_key(&HookTrigger::PreToolUse)); + let pre_tool_hooks = &agent.hooks[&HookTrigger::PreToolUse]; + assert_eq!(pre_tool_hooks.len(), 2); + + assert_eq!(pre_tool_hooks[0].command, "validate-tool.sh"); + assert_eq!(pre_tool_hooks[0].matcher, Some("fs_write".to_string())); + + assert_eq!(pre_tool_hooks[1].command, "enforce-tdd.sh"); + assert_eq!(pre_tool_hooks[1].matcher, Some("fs_read".to_string())); + + // Verify postToolUse hooks + assert!(agent.hooks.contains_key(&HookTrigger::PostToolUse)); + + // Verify default values are set correctly + for hooks in agent.hooks.values() { + for hook in hooks { + assert_eq!(hook.timeout_ms, 30_000); + assert_eq!(hook.max_output_size, 10_240); + assert_eq!(hook.cache_ttl_seconds, 0); + assert_eq!(hook.source, Source::Agent); + } + } + } } diff --git a/crates/chat-cli/src/cli/chat/cli/hooks.rs b/crates/chat-cli/src/cli/chat/cli/hooks.rs index 38763285c6..96e1e842dc 100644 --- a/crates/chat-cli/src/cli/chat/cli/hooks.rs +++ b/crates/chat-cli/src/cli/chat/cli/hooks.rs @@ -37,6 +37,8 @@ use crate::cli::agent::hook::{ Hook, HookTrigger, }; +use crate::cli::agent::is_mcp_tool_ref; +use crate::util::MCP_SERVER_TOOL_DELIMITER; use crate::cli::chat::consts::AGENT_FORMAT_HOOKS_DOC_URL; use crate::cli::chat::util::truncate_safe; use crate::cli::chat::{ @@ -44,6 +46,47 @@ use crate::cli::chat::{ ChatSession, ChatState, }; +use crate::util::pattern_matching::matches_any_pattern; + +/// Hook execution result: (exit_code, output) +/// Output is stdout if exit_code is 0, stderr otherwise. +pub type HookOutput = (i32, String); + +/// Check if a hook matches a tool name based on its matcher pattern +fn hook_matches_tool(hook: &Hook, tool_name: &str) -> bool { + match &hook.matcher { + None => true, // No matcher means the hook runs for all tools + Some(pattern) => { + match pattern.as_str() { + "*" => true, // Wildcard matches all tools + "@builtin" => !is_mcp_tool_ref(tool_name), // Built-in tools are not MCP tools + _ => { + // If tool_name is MCP, check server pattern first + if is_mcp_tool_ref(tool_name) { + if let Some(server_name) = tool_name.strip_prefix('@').and_then(|s| s.split(MCP_SERVER_TOOL_DELIMITER).next()) { + let server_pattern = format!("@{}", server_name); + if pattern == &server_pattern { + return true; + } + } + } + + // Use matches_any_pattern for both MCP and built-in tools + let mut patterns = std::collections::HashSet::new(); + patterns.insert(pattern.clone()); + matches_any_pattern(&patterns, tool_name) + } + } + } + } +} + +#[derive(Debug, Clone)] +pub struct ToolContext { + pub tool_name: String, + pub tool_input: serde_json::Value, + pub tool_response: Option, +} #[derive(Debug, Clone)] pub struct CachedHook { @@ -74,22 +117,32 @@ impl HookExecutor { &mut self, hooks: HashMap>, output: &mut impl Write, + cwd: &str, prompt: Option<&str>, - ) -> Result, ChatError> { + tool_context: Option, + ) -> Result, ChatError> { let mut cached = vec![]; let mut futures = FuturesUnordered::new(); for hook in hooks .into_iter() .flat_map(|(trigger, hooks)| hooks.into_iter().map(move |hook| (trigger, hook))) { + // Filter hooks by tool matcher + if let Some(tool_ctx) = &tool_context { + if !hook_matches_tool(&hook.1, &tool_ctx.tool_name) { + continue; // Skip this hook - doesn't match tool + } + } + if let Some(cache) = self.get_cache(&hook) { - cached.push((hook.clone(), cache.clone())); + // Note: we only cache successful hook run. hence always using 0 as exit code for cached hook + cached.push((hook.clone(), (0, cache))); continue; } - futures.push(self.run_hook(hook, prompt)); + futures.push(self.run_hook(hook, cwd, prompt, tool_context.clone())); } - let mut complete = 0; + let mut complete = 0; // number of hooks that are run successfully with exit code 0 let total = futures.len(); let mut spinner = None; let spinner_text = |complete: usize, total: usize| { @@ -138,9 +191,25 @@ impl HookExecutor { } // Process results regardless of output enabled - if let Ok(output) = result { - complete += 1; - results.push((hook, output)); + if let Ok((exit_code, hook_output)) = &result { + // Print warning if exit code is not 0 + if *exit_code != 0 { + queue!( + output, + style::SetForegroundColor(style::Color::Red), + style::Print("✗ "), + style::ResetColor, + style::Print(format!("{} \"", hook.0)), + style::Print(&hook.1.command), + style::Print("\""), + style::SetForegroundColor(style::Color::Red), + style::Print(format!(" failed with exit code: {}, stderr: {})\n", exit_code, hook_output.trim_end())), + style::ResetColor, + )?; + } else { + complete += 1; + } + results.push((hook, result.unwrap())); } // Display ending summary or add a new spinner @@ -167,12 +236,17 @@ impl HookExecutor { drop(futures); // Fill cache with executed results, skipping what was already from cache - for ((trigger, hook), output) in &results { + for ((trigger, hook), (exit_code, output)) in &results { + if *exit_code != 0 { + continue; // Only cache successful hooks + } self.cache.insert((*trigger, hook.clone()), CachedHook { output: output.clone(), expiry: match trigger { HookTrigger::AgentSpawn => None, HookTrigger::UserPromptSubmit => Some(Instant::now() + Duration::from_secs(hook.cache_ttl_seconds)), + HookTrigger::PreToolUse => Some(Instant::now() + Duration::from_secs(hook.cache_ttl_seconds)), + HookTrigger::PostToolUse => Some(Instant::now() + Duration::from_secs(hook.cache_ttl_seconds)), }, }); } @@ -185,8 +259,10 @@ impl HookExecutor { async fn run_hook( &self, hook: (HookTrigger, Hook), + cwd: &str, prompt: Option<&str>, - ) -> ((HookTrigger, Hook), Result, Duration) { + tool_context: Option, + ) -> ((HookTrigger, Hook), Result, Duration) { let start_time = Instant::now(); let command = &hook.1.command; @@ -213,33 +289,61 @@ impl HookExecutor { let timeout = Duration::from_millis(hook.1.timeout_ms); - // Set USER_PROMPT environment variable if provided + // Generate hook command input in JSON format + let mut hook_input = serde_json::json!({ + "hook_event_name": hook.0.to_string(), + "cwd": cwd + }); + + // Set USER_PROMPT environment variable and add to JSON input if provided if let Some(prompt) = prompt { // Sanitize the prompt to avoid issues with special characters let sanitized_prompt = sanitize_user_prompt(prompt); cmd.env("USER_PROMPT", sanitized_prompt); + hook_input["prompt"] = serde_json::Value::String(prompt.to_string()); } - let command_future = cmd.output(); + // ToolUse specific input + if let Some(tool_ctx) = tool_context { + hook_input["tool_name"] = serde_json::Value::String(tool_ctx.tool_name); + hook_input["tool_input"] = tool_ctx.tool_input; + if let Some(response) = tool_ctx.tool_response { + hook_input["tool_response"] = response; + } + } + let json_input = serde_json::to_string(&hook_input).unwrap_or_default(); + + // Build a future for hook command w/ the JSON input passed in through STDIN + let command_future = async move { + let mut child = cmd.spawn()?; + if let Some(stdin) = child.stdin.take() { + use tokio::io::AsyncWriteExt; + let mut stdin = stdin; + let _ = stdin.write_all(json_input.as_bytes()).await; + let _ = stdin.shutdown().await; + } + child.wait_with_output().await + }; // Run with timeout let result = match tokio::time::timeout(timeout, command_future).await { - Ok(Ok(result)) => { - if result.status.success() { - let stdout = result.stdout.to_str_lossy(); - let stdout = format!( - "{}{}", - truncate_safe(&stdout, hook.1.max_output_size), - if stdout.len() > hook.1.max_output_size { - " ... truncated" - } else { - "" - } - ); - Ok(stdout) + Ok(Ok(output)) => { + let exit_code = output.status.code().unwrap_or(-1); + let raw_output = if exit_code == 0 { + output.stdout.to_str_lossy() } else { - Err(eyre!("command returned non-zero exit code: {}", result.status)) - } + output.stderr.to_str_lossy() + }; + let formatted_output = format!( + "{}{}", + truncate_safe(&raw_output, hook.1.max_output_size), + if raw_output.len() > hook.1.max_output_size { + " ... truncated" + } else { + "" + } + ); + Ok((exit_code, formatted_output)) }, Ok(Err(err)) => Err(eyre!("failed to execute command: {}", err)), Err(_) => Err(eyre!("command timed out after {} ms", timeout.as_millis())), @@ -330,3 +434,263 @@ impl HooksArgs { }) } } + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + use crate::cli::agent::hook::{Hook, HookTrigger}; + use tempfile::TempDir; + + #[test] + fn test_hook_matches_tool() { + let hook_no_matcher = Hook { + command: "echo test".to_string(), + timeout_ms: 5000, + cache_ttl_seconds: 0, + max_output_size: 1000, + matcher: None, + source: crate::cli::agent::hook::Source::Session, + }; + + let fs_write_hook = Hook { + command: "echo test".to_string(), + timeout_ms: 5000, + cache_ttl_seconds: 0, + max_output_size: 1000, + matcher: Some("fs_write".to_string()), + source: crate::cli::agent::hook::Source::Session, + }; + + let fs_wildcard_hook = Hook { + command: "echo test".to_string(), + timeout_ms: 5000, + cache_ttl_seconds: 0, + max_output_size: 1000, + matcher: Some("fs_*".to_string()), + source: crate::cli::agent::hook::Source::Session, + }; + + let all_tools_hook = Hook { + command: "echo test".to_string(), + timeout_ms: 5000, + cache_ttl_seconds: 0, + max_output_size: 1000, + matcher: Some("*".to_string()), + source: crate::cli::agent::hook::Source::Session, + }; + + let builtin_hook = Hook { + command: "echo test".to_string(), + timeout_ms: 5000, + cache_ttl_seconds: 0, + max_output_size: 1000, + matcher: Some("@builtin".to_string()), + source: crate::cli::agent::hook::Source::Session, + }; + + let git_server_hook = Hook { + command: "echo test".to_string(), + timeout_ms: 5000, + cache_ttl_seconds: 0, + max_output_size: 1000, + matcher: Some("@git".to_string()), + source: crate::cli::agent::hook::Source::Session, + }; + + let git_status_hook = Hook { + command: "echo test".to_string(), + timeout_ms: 5000, + cache_ttl_seconds: 0, + max_output_size: 1000, + matcher: Some("@git/status".to_string()), + source: crate::cli::agent::hook::Source::Session, + }; + + // No matcher should match all tools + assert!(hook_matches_tool(&hook_no_matcher, "fs_write")); + assert!(hook_matches_tool(&hook_no_matcher, "execute_bash")); + assert!(hook_matches_tool(&hook_no_matcher, "@git/status")); + + // Exact matcher should only match exact tool + assert!(hook_matches_tool(&fs_write_hook, "fs_write")); + assert!(!hook_matches_tool(&fs_write_hook, "fs_read")); + + // Wildcard matcher should match pattern + assert!(hook_matches_tool(&fs_wildcard_hook, "fs_write")); + assert!(hook_matches_tool(&fs_wildcard_hook, "fs_read")); + assert!(!hook_matches_tool(&fs_wildcard_hook, "execute_bash")); + + // * should match all tools + assert!(hook_matches_tool(&all_tools_hook, "fs_write")); + assert!(hook_matches_tool(&all_tools_hook, "execute_bash")); + assert!(hook_matches_tool(&all_tools_hook, "@git/status")); + + // @builtin should match built-in tools only + assert!(hook_matches_tool(&builtin_hook, "fs_write")); + assert!(hook_matches_tool(&builtin_hook, "execute_bash")); + assert!(!hook_matches_tool(&builtin_hook, "@git/status")); + + // @git should match all git server tools + assert!(hook_matches_tool(&git_server_hook, "@git/status")); + assert!(!hook_matches_tool(&git_server_hook, "@other/tool")); + assert!(!hook_matches_tool(&git_server_hook, "fs_write")); + + // @git/status should match exact MCP tool + assert!(hook_matches_tool(&git_status_hook, "@git/status")); + assert!(!hook_matches_tool(&git_status_hook, "@git/commit")); + assert!(!hook_matches_tool(&git_status_hook, "fs_write")); + } + + #[tokio::test] + async fn test_hook_executor_with_tool_context() { + let mut executor = HookExecutor::new(); + let mut output = Vec::new(); + + // Create temp directory and file + let temp_dir = TempDir::new().unwrap(); + let test_file = temp_dir.path().join("hook_output.json"); + let test_file_str = test_file.to_string_lossy(); + + // Create a simple hook that writes JSON input to a file + #[cfg(unix)] + let command = format!("cat > {}", test_file_str); + #[cfg(windows)] + let command = format!("type > {}", test_file_str); + + let hook = Hook { + command, + timeout_ms: 5000, + cache_ttl_seconds: 0, + max_output_size: 1000, + matcher: Some("fs_write".to_string()), + source: crate::cli::agent::hook::Source::Session, + }; + + let mut hooks = HashMap::new(); + hooks.insert(HookTrigger::PreToolUse, vec![hook]); + + let tool_context = ToolContext { + tool_name: "fs_write".to_string(), + tool_input: serde_json::json!({ + "command": "create", + "path": "/test/file.py" + }), + tool_response: None, + }; + + // Run the hook + let result = executor.run_hooks( + hooks, + &mut output, + ".", + None, + Some(tool_context) + ).await; + + assert!(result.is_ok()); + + // Verify the hook wrote the JSON input to the file + if let Ok(content) = std::fs::read_to_string(&test_file) { + let json: serde_json::Value = serde_json::from_str(&content).unwrap(); + assert_eq!(json["hook_event_name"], "preToolUse"); + assert_eq!(json["tool_name"], "fs_write"); + assert_eq!(json["tool_input"]["command"], "create"); + assert_eq!(json["cwd"], "."); + } + // TempDir automatically cleans up when dropped + } + + #[tokio::test] + async fn test_hook_filtering_no_match() { + let mut executor = HookExecutor::new(); + let mut output = Vec::new(); + + // Hook that matches execute_bash (should NOT run for fs_write tool call) + let execute_bash_hook = Hook { + command: "echo 'should not run'".to_string(), + timeout_ms: 5000, + cache_ttl_seconds: 0, + max_output_size: 1000, + matcher: Some("execute_bash".to_string()), + source: crate::cli::agent::hook::Source::Session, + }; + + let mut hooks = HashMap::new(); + hooks.insert(HookTrigger::PostToolUse, vec![execute_bash_hook]); + + let tool_context = ToolContext { + tool_name: "fs_write".to_string(), + tool_input: serde_json::json!({"command": "create"}), + tool_response: Some(serde_json::json!({"success": true})), + }; + + // Run the hooks + let result = executor.run_hooks( + hooks, + &mut output, + ".", // cwd - using current directory for now + None, // prompt - no user prompt for this test + Some(tool_context) + ).await; + + assert!(result.is_ok()); + let hook_results = result.unwrap(); + + // Should run 0 hooks because matcher doesn't match tool_name + assert_eq!(hook_results.len(), 0); + + // Output should be empty since no hooks ran + assert!(output.is_empty()); + } + + #[tokio::test] + async fn test_hook_exit_code_2() { + let mut executor = HookExecutor::new(); + let mut output = Vec::new(); + + // Create a hook that exits with code 2 and outputs to stderr + #[cfg(unix)] + let command = "echo 'Tool execution blocked by security policy' >&2; exit 2"; + #[cfg(windows)] + let command = "echo Tool execution blocked by security policy 1>&2 & exit /b 2"; + + let hook = Hook { + command: command.to_string(), + timeout_ms: 5000, + cache_ttl_seconds: 0, + max_output_size: 1000, + matcher: Some("fs_write".to_string()), + source: crate::cli::agent::hook::Source::Session, + }; + + let hooks = HashMap::from([ + (HookTrigger::PreToolUse, vec![hook]) + ]); + + let tool_context = ToolContext { + tool_name: "fs_write".to_string(), + tool_input: serde_json::json!({ + "command": "create", + "path": "/sensitive/file.py" + }), + tool_response: None, + }; + + let results = executor.run_hooks( + hooks, + &mut output, + ".", // cwd + None, // prompt + Some(tool_context) + ).await.unwrap(); + + // Should have one result + assert_eq!(results.len(), 1); + + let ((trigger, _hook), (exit_code, hook_output)) = &results[0]; + assert_eq!(*trigger, HookTrigger::PreToolUse); + assert_eq!(*exit_code, 2); + assert!(hook_output.contains("Tool execution blocked by security policy")); + } +} diff --git a/crates/chat-cli/src/cli/chat/context.rs b/crates/chat-cli/src/cli/chat/context.rs index 1fdcc5e8ee..fa39f80397 100644 --- a/crates/chat-cli/src/cli/chat/context.rs +++ b/crates/chat-cli/src/cli/chat/context.rs @@ -16,6 +16,7 @@ use serde::{ use super::cli::model::context_window_tokens; use super::util::drop_matched_context_files; +use super::cli::hooks::HookOutput; use crate::cli::agent::Agent; use crate::cli::agent::hook::{ Hook, @@ -247,11 +248,14 @@ impl ContextManager { &mut self, trigger: HookTrigger, output: &mut impl Write, + os: &crate::os::Os, prompt: Option<&str>, - ) -> Result, ChatError> { + tool_context: Option, + ) -> Result, ChatError> { let mut hooks = self.hooks.clone(); hooks.retain(|t, _| *t == trigger); - self.hook_executor.run_hooks(hooks, output, prompt).await + let cwd = os.env.current_dir()?.to_string_lossy().to_string(); + self.hook_executor.run_hooks(hooks, output, &cwd, prompt, tool_context).await } } diff --git a/crates/chat-cli/src/cli/chat/conversation.rs b/crates/chat-cli/src/cli/chat/conversation.rs index 50025f229d..9179697807 100644 --- a/crates/chat-cli/src/cli/chat/conversation.rs +++ b/crates/chat-cli/src/cli/chat/conversation.rs @@ -29,6 +29,7 @@ use tracing::{ }; use super::cli::compact::CompactStrategy; +use super::cli::hooks::HookOutput; use super::cli::model::context_window_tokens; use super::consts::{ DUMMY_TOOL_NAME, @@ -563,12 +564,12 @@ impl ConversationState { let mut agent_spawn_context = None; if let Some(cm) = self.context_manager.as_mut() { let user_prompt = self.next_message.as_ref().and_then(|m| m.prompt()); - let agent_spawn = cm.run_hooks(HookTrigger::AgentSpawn, output, user_prompt).await?; + let agent_spawn = cm.run_hooks(HookTrigger::AgentSpawn, output, os, user_prompt, None /* tool_context */).await?; agent_spawn_context = format_hook_context(&agent_spawn, HookTrigger::AgentSpawn); if let (true, Some(next_message)) = (run_perprompt_hooks, self.next_message.as_mut()) { let per_prompt = cm - .run_hooks(HookTrigger::UserPromptSubmit, output, next_message.prompt()) + .run_hooks(HookTrigger::UserPromptSubmit, output, os, next_message.prompt(), None /* tool_context */) .await?; if let Some(ctx) = format_hook_context(&per_prompt, HookTrigger::UserPromptSubmit) { next_message.additional_context = ctx; @@ -1030,8 +1031,9 @@ impl From for ToolInputSchema { /// # Returns /// [Option::Some] if `hook_results` is not empty and at least one hook has content. Otherwise, /// [Option::None] -fn format_hook_context(hook_results: &[((HookTrigger, Hook), String)], trigger: HookTrigger) -> Option { - if hook_results.iter().all(|(_, content)| content.is_empty()) { +fn format_hook_context(hook_results: &[((HookTrigger, Hook), HookOutput)], trigger: HookTrigger) -> Option { + // Note: only format context when hook command exit code is 0 + if hook_results.iter().all(|(_, (exit_code, content))| *exit_code != 0 || content.is_empty()) { return None; } @@ -1044,7 +1046,7 @@ fn format_hook_context(hook_results: &[((HookTrigger, Hook), String)], trigger: } context_content.push_str("\n\n"); - for (_, output) in hook_results.iter().filter(|((h_trigger, _), _)| *h_trigger == trigger) { + for (_, (_, output)) in hook_results.iter().filter(|((h_trigger, _), (exit_code, _))| *h_trigger == trigger && *exit_code == 0) { context_content.push_str(&format!("{output}\n\n")); } context_content.push_str(CONTEXT_ENTRY_END_HEADER); diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 942653bf07..2da6b3d5ca 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -42,6 +42,7 @@ use clap::{ ValueEnum, }; use cli::compact::CompactStrategy; +use cli::hooks::ToolContext; use cli::model::{ find_model, get_available_models, @@ -2261,6 +2262,7 @@ impl ChatSession { }); } + // All tools are allowed now // Execute the requested tools. let mut tool_results = vec![]; let mut image_blocks: Vec = Vec::new(); @@ -2426,6 +2428,45 @@ impl ChatSession { } } + // Run PostToolUse hooks for all executed tools after we have the tool_results + if let Some(cm) = self.conversation.context_manager.as_mut() { + for result in &tool_results { + if let Some(tool) = self.tool_uses.iter().find(|t| t.id == result.tool_use_id) { + let content: Vec = result.content.iter().map(|block| { + match block { + ToolUseResultBlock::Text(text) => serde_json::Value::String(text.clone()), + ToolUseResultBlock::Json(json) => json.clone(), + } + }).collect(); + + let tool_response = match result.status { + ToolResultStatus::Success => serde_json::json!({"success": true, "result": content}), + ToolResultStatus::Error => serde_json::json!({"success": false, "error": content}), + }; + + let tool_context = ToolContext { + tool_name: match &tool.tool { + Tool::Custom(custom_tool) => custom_tool.namespaced_tool_name(), // for MCP tool, pass MCP name to the hook + _ => tool.name.clone(), + }, + tool_input: tool.tool_input.clone(), + tool_response: Some(tool_response), + }; + + // Here is how we handle postToolUse output: + // Exit code is 0: nothing. stdout is not shown to user. We don't support processing the PostToolUse hook output yet. + // Exit code is non-zero: display an error to user (already taken care of by the ContextManager.run_hooks) + let _ = cm.run_hooks( + crate::cli::agent::hook::HookTrigger::PostToolUse, + &mut std::io::stderr(), + os, + None, + Some(tool_context) + ).await; + } + } + } + if !image_blocks.is_empty() { let images = image_blocks.into_iter().map(|(block, _)| block).collect(); self.conversation.add_tool_results_with_images(tool_results, images); @@ -2761,6 +2802,7 @@ impl ChatSession { } } + // Validate the tool use request from LLM, including basic checks like fs_read file should exist, as well as user-defined preToolUse hook check. async fn validate_tools(&mut self, os: &Os, tool_uses: Vec) -> Result { let conv_id = self.conversation.conversation_id().to_owned(); debug!(?tool_uses, "Validating tool uses"); @@ -2770,6 +2812,7 @@ impl ChatSession { for tool_use in tool_uses { let tool_use_id = tool_use.id.clone(); let tool_use_name = tool_use.name.clone(); + let tool_input = tool_use.args.clone(); let mut tool_telemetry = ToolUseEventBuilder::new( conv_id.clone(), tool_use.id.clone(), @@ -2791,6 +2834,7 @@ impl ChatSession { name: tool_use_name, tool, accepted: false, + tool_input, }); }, Err(err) => { @@ -2862,9 +2906,75 @@ impl ChatSession { )); } + // Execute PreToolUse hooks for all validated tools + // The mental model is preToolHook is like validate tools, but its behavior can be customized by user + // Note that after preTookUse hook, user can still reject the took run + if let Some(cm) = self.conversation.context_manager.as_mut() { + for tool in &queued_tools { + let tool_context = ToolContext { + tool_name: match &tool.tool { + Tool::Custom(custom_tool) => custom_tool.namespaced_tool_name(), // for MCP tool, pass MCP name to the hook + _ => tool.name.clone(), + }, + tool_input: tool.tool_input.clone(), + tool_response: None, + }; + + let hook_results = cm.run_hooks( + crate::cli::agent::hook::HookTrigger::PreToolUse, + &mut std::io::stderr(), + os, + None, /* prompt */ + Some(tool_context) + ).await?; + + // Here is how we handle the preToolUse hook output: + // Exit code is 0: nothing. stdout is not shown to user. + // Exit code is 2: block the tool use. return stderr to LLM. show warning to user + // Other error: show warning to user. + + // Check for exit code 2 and add to tool_results + for (_, (exit_code, output)) in &hook_results { + if *exit_code == 2 { + tool_results.push(ToolUseResult { + tool_use_id: tool.id.clone(), + content: vec![ToolUseResultBlock::Text(format!("PreToolHook blocked the tool execution: {}", output))], + status: ToolResultStatus::Error, + }); + } + } + } + } + + // If we have any hook validation errors, return them immediately to the model + if !tool_results.is_empty() { + debug!(?tool_results, "Error found in PreToolUse hooks"); + for tool_result in &tool_results { + for block in &tool_result.content { + if let ToolUseResultBlock::Text(content) = block { + queue!( + self.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Red), + style::Print(format!("{}\n", content)), + style::SetForegroundColor(Color::Reset), + )?; + } + } + } + + self.conversation.add_tool_results(tool_results); + return Ok(ChatState::HandleResponseStream( + self.conversation + .as_sendable_conversation_state(os, &mut self.stderr, false) + .await?, + )); + } + self.tool_uses = queued_tools; self.pending_tool_index = Some(0); self.tool_turn_start_time = Some(Instant::now()); + Ok(ChatState::ExecuteTools) } @@ -3786,6 +3896,244 @@ mod tests { .unwrap(); } + // Integration test for PreToolUse hook functionality. + // + // In this integration test we create a preToolUse hook that logs tool info into a file + // and we run fs_read and verify the log is generated with the correct ToolContext data. + #[tokio::test] + async fn test_tool_hook_integration() { + use crate::cli::agent::hook::{Hook, HookTrigger}; + use std::collections::HashMap; + + let mut os = Os::new().await.unwrap(); + os.client.set_mock_output(serde_json::json!([ + [ + "I'll read that file for you", + { + "tool_use_id": "1", + "name": "fs_read", + "args": { + "operations": [ + { + "mode": "Line", + "path": "/test.txt", + "start_line": 1, + "end_line": 3 + } + ] + } + } + ], + [ + "Here's the file content!", + ], + ])); + + // Create test file + os.fs.write("/test.txt", "line1\nline2\nline3\n").await.unwrap(); + + // Create agent with PreToolUse and PostToolUse hooks + let mut agents = Agents::default(); + let mut hooks = HashMap::new(); + + // Get the real path in the temp directory for the hooks to write to + let pre_hook_log_path = os.fs.chroot_path_str("/pre-hook-test.log"); + let post_hook_log_path = os.fs.chroot_path_str("/post-hook-test.log"); + let pre_hook_command = format!("cat > {}", pre_hook_log_path); + let post_hook_command = format!("cat > {}", post_hook_log_path); + + hooks.insert(HookTrigger::PreToolUse, vec![Hook { + command: pre_hook_command, + timeout_ms: 5000, + max_output_size: 1024, + cache_ttl_seconds: 0, + matcher: Some("fs_*".to_string()), // Match fs_read, fs_write, etc. + source: crate::cli::agent::hook::Source::Agent, + }]); + + hooks.insert(HookTrigger::PostToolUse, vec![Hook { + command: post_hook_command, + timeout_ms: 5000, + max_output_size: 1024, + cache_ttl_seconds: 0, + matcher: Some("fs_*".to_string()), // Match fs_read, fs_write, etc. + source: crate::cli::agent::hook::Source::Agent, + }]); + + let agent = Agent { + name: "TestAgent".to_string(), + hooks, + ..Default::default() + }; + agents.agents.insert("TestAgent".to_string(), agent); + agents.switch("TestAgent").expect("Failed to switch agent"); + + let tool_manager = ToolManager::default(); + let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) + .expect("Tools failed to load"); + + // Test that PreToolUse hook runs + ChatSession::new( + &mut os, + std::io::stdout(), + std::io::stderr(), + "fake_conv_id", + agents, + None, // No initial input + InputSource::new_mock(vec![ + "read /test.txt".to_string(), + "y".to_string(), // Accept tool execution + "exit".to_string(), + ]), + false, + || Some(80), + tool_manager, + None, + tool_config, + true, + false, + None, + ) + .await + .unwrap() + .spawn(&mut os) + .await + .unwrap(); + + // Verify the PreToolUse hook was called + if let Ok(pre_log_content) = os.fs.read_to_string("/pre-hook-test.log").await { + let pre_hook_data: serde_json::Value = serde_json::from_str(&pre_log_content) + .expect("PreToolUse hook output should be valid JSON"); + + assert_eq!(pre_hook_data["hook_event_name"], "preToolUse"); + assert_eq!(pre_hook_data["tool_name"], "fs_read"); + assert_eq!(pre_hook_data["tool_response"], serde_json::Value::Null); + + let tool_input = &pre_hook_data["tool_input"]; + assert!(tool_input["operations"].is_array()); + + println!("✓ PreToolUse hook validation passed: {}", pre_log_content); + } else { + panic!("PreToolUse hook log file not found - hook may not have been called"); + } + + // Verify the PostToolUse hook was called + if let Ok(post_log_content) = os.fs.read_to_string("/post-hook-test.log").await { + let post_hook_data: serde_json::Value = serde_json::from_str(&post_log_content) + .expect("PostToolUse hook output should be valid JSON"); + + assert_eq!(post_hook_data["hook_event_name"], "postToolUse"); + assert_eq!(post_hook_data["tool_name"], "fs_read"); + + // Validate tool_response structure for successful execution + let tool_response = &post_hook_data["tool_response"]; + assert_eq!(tool_response["success"], true); + assert!(tool_response["result"].is_array()); + + let result_blocks = tool_response["result"].as_array().unwrap(); + assert!(!result_blocks.is_empty()); + let content = result_blocks[0].as_str().unwrap(); + assert!(content.contains("line1\nline2\nline3")); + + println!("✓ PostToolUse hook validation passed: {}", post_log_content); + } else { + panic!("PostToolUse hook log file not found - hook may not have been called"); + } + } + + #[tokio::test] + async fn test_pretool_hook_blocking_integration() { + use crate::cli::agent::hook::{Hook, HookTrigger}; + use std::collections::HashMap; + + let mut os = Os::new().await.unwrap(); + + // Create a test file to read + os.fs.write("/sensitive.txt", "classified information").await.unwrap(); + + // Mock LLM responses: first tries fs_read, gets blocked, then responds to error + os.client.set_mock_output(serde_json::json!([ + [ + "I'll read that file for you", + { + "tool_use_id": "1", + "name": "fs_read", + "args": { + "operations": [ + { + "mode": "Line", + "path": "/sensitive.txt" + } + ] + } + } + ], + [ + "I understand the security policy blocked access to that file.", + ], + ])); + + // Create agent with blocking PreToolUse hook + let mut agents = Agents::default(); + let mut hooks = HashMap::new(); + + // Create a hook that blocks fs_read of sensitive files with exit code 2 + #[cfg(unix)] + let hook_command = "echo 'Security policy violation: cannot read sensitive files' >&2; exit 2"; + #[cfg(windows)] + let hook_command = "echo Security policy violation: cannot read sensitive files 1>&2 & exit /b 2"; + + hooks.insert(HookTrigger::PreToolUse, vec![Hook { + command: hook_command.to_string(), + timeout_ms: 5000, + max_output_size: 1024, + cache_ttl_seconds: 0, + matcher: Some("fs_read".to_string()), + source: crate::cli::agent::hook::Source::Agent, + }]); + + let agent = Agent { + name: "SecurityAgent".to_string(), + hooks, + ..Default::default() + }; + agents.agents.insert("SecurityAgent".to_string(), agent); + agents.switch("SecurityAgent").expect("Failed to switch agent"); + + let tool_manager = ToolManager::default(); + let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) + .expect("Tools failed to load"); + + // Run chat session - hook should block tool execution + let result = ChatSession::new( + &mut os, + std::io::stdout(), + std::io::stderr(), + "test_conv_id", + agents, + None, + InputSource::new_mock(vec![ + "read /sensitive.txt".to_string(), + "exit".to_string(), + ]), + false, + || Some(80), + tool_manager, + None, + tool_config, + true, + false, + None, + ) + .await + .unwrap() + .spawn(&mut os) + .await; + + // The session should complete successfully (hook blocks tool but doesn't crash) + assert!(result.is_ok(), "Chat session should complete successfully even when hook blocks tool"); + } + #[test] fn test_does_input_reference_file() { let tests = &[ diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index 7f72b874a6..30c0473a77 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -99,8 +99,7 @@ use crate::util::MCP_SERVER_TOOL_DELIMITER; use crate::util::directories::home_dir; const NAMESPACE_DELIMITER: &str = "___"; -// This applies for both mcp server and tool name since in the end the tool name as seen by the -// model is just {server_name}{NAMESPACE_DELIMITER}{tool_name} +// This applies for both mcp server and tool name const VALID_TOOL_NAME: &str = "^[a-zA-Z][a-zA-Z0-9_]*$"; const SPINNER_CHARS: [char; 10] = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']; @@ -873,7 +872,7 @@ impl ToolManager { "thinking" => Tool::Thinking(serde_json::from_value::(value.args).map_err(map_err)?), "knowledge" => Tool::Knowledge(serde_json::from_value::(value.args).map_err(map_err)?), "todo_list" => Tool::Todo(serde_json::from_value::(value.args).map_err(map_err)?), - // Note that this name is namespaced with server_name{DELIMITER}tool_name + // Note that this name is NO LONGER namespaced with server_name{DELIMITER}tool_name name => { // Note: tn_map also has tools that underwent no transformation. In otherwords, if // it is a valid tool name, we should get a hit. diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index 907f70d35c..45b37b2103 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -96,6 +96,11 @@ pub struct CustomTool { } impl CustomTool { + /// Returns the full tool name with server prefix in the format @server_name/tool_name + pub fn namespaced_tool_name(&self) -> String { + format!("@{}{}{}", self.server_name, MCP_SERVER_TOOL_DELIMITER, self.name) + } + pub async fn invoke(&self, _os: &Os, _updates: &mut impl Write) -> Result { let params = CallToolRequestParam { name: Cow::from(self.name.clone()), @@ -157,7 +162,6 @@ impl CustomTool { } pub fn eval_perm(&self, _os: &Os, agent: &Agent) -> PermissionEvalResult { - let Self { name: tool_name, .. } = self; let server_name = &self.server_name; let server_pattern = format!("@{server_name}"); @@ -165,7 +169,7 @@ impl CustomTool { return PermissionEvalResult::Allow; } - let tool_pattern = format!("@{server_name}{MCP_SERVER_TOOL_DELIMITER}{tool_name}"); + let tool_pattern = self.namespaced_tool_name(); if matches_any_pattern(&agent.allowed_tools, &tool_pattern) { return PermissionEvalResult::Allow; } diff --git a/crates/chat-cli/src/cli/chat/tools/introspect.rs b/crates/chat-cli/src/cli/chat/tools/introspect.rs index 9e42ec0f6e..f14351ab54 100644 --- a/crates/chat-cli/src/cli/chat/tools/introspect.rs +++ b/crates/chat-cli/src/cli/chat/tools/introspect.rs @@ -71,6 +71,9 @@ impl Introspect { documentation.push_str("\n\n--- docs/todo-lists.md ---\n"); documentation.push_str(include_str!("../../../../../../docs/todo-lists.md")); + documentation.push_str("\n\n--- docs/hooks.md ---\n"); + documentation.push_str(include_str!("../../../../../../docs/hooks.md")); + documentation.push_str("\n\n--- changelog (from feed.json) ---\n"); // Include recent changelog entries from feed.json let feed = crate::cli::feed::Feed::load(); @@ -120,6 +123,8 @@ impl Introspect { ); documentation .push_str("• Todo Lists: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/todo-lists.md\n"); + documentation + .push_str("• Hooks: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/hooks.md\n"); documentation .push_str("• Contributing: https://github.com/aws/amazon-q-developer-cli/blob/main/CONTRIBUTING.md\n"); diff --git a/crates/chat-cli/src/cli/chat/tools/mod.rs b/crates/chat-cli/src/cli/chat/tools/mod.rs index 43ccb006cf..96cc0d76f2 100644 --- a/crates/chat-cli/src/cli/chat/tools/mod.rs +++ b/crates/chat-cli/src/cli/chat/tools/mod.rs @@ -271,6 +271,7 @@ pub struct QueuedTool { pub name: String, pub accepted: bool, pub tool: Tool, + pub tool_input: serde_json::Value, } /// The schema specification describing a tool's fields. diff --git a/docs/agent-format.md b/docs/agent-format.md index c707b476e7..35adc8cfaa 100644 --- a/docs/agent-format.md +++ b/docs/agent-format.md @@ -256,19 +256,37 @@ Resources can include: ## Hooks Field -The `hooks` field defines commands to run at specific trigger points. The output of these commands is added to the agent's context. +The `hooks` field defines commands to run at specific trigger points during agent lifecycle and tool execution. + +For detailed information about hook behavior, input/output formats, and examples, see the [Hooks documentation](hooks.md). ```json { "hooks": { "agentSpawn": [ { - "command": "git status", + "command": "git status" } ], "userPromptSubmit": [ { - "command": "ls -la", + "command": "ls -la" + } + ], + "preToolUse": [ + { + "matcher": "execute_bash", + "command": "{ echo \"$(date) - Bash command:\"; cat; echo; } >> /tmp/bash_audit_log" + }, + { + "matcher": "use_aws", + "command": "{ echo \"$(date) - AWS CLI call:\"; cat; echo; } >> /tmp/aws_audit_log" + } + ], + "postToolUse": [ + { + "matcher": "fs_write", + "command": "cargo fmt --all" } ] } @@ -277,10 +295,13 @@ The `hooks` field defines commands to run at specific trigger points. The output Each hook is defined with: - `command` (required): The command to execute +- `matcher` (optional): Pattern to match tool names for `preToolUse` and `postToolUse` hooks. See [built-in tools documentation](./built-in-tools.md) for available tool names. Available hook triggers: -- `agentSpawn`: Triggered when the agent is initialized -- `userPromptSubmit`: Triggered when the user submits a message +- `agentSpawn`: Triggered when the agent is initialized. +- `userPromptSubmit`: Triggered when the user submits a message. +- `preToolUse`: Triggered before a tool is executed. Can block the tool use. +- `postToolUse`: Triggered after a tool is executed. ## UseLegacyMcpJson Field diff --git a/docs/hooks.md b/docs/hooks.md new file mode 100644 index 0000000000..4a0636a06b --- /dev/null +++ b/docs/hooks.md @@ -0,0 +1,161 @@ +# Hooks + +Hooks allow you to execute custom commands at specific points during agent lifecycle and tool execution. This enables security validation, logging, formatting, context gathering, and other custom behaviors. + +## Defining Hooks + +Hooks are defined in the agent configuration file. See the [agent format documentation](agent-format.md#hooks-field) for the complete syntax and examples. + +## Hook Event + +Hooks receive hook event in JSON format via STDIN: + +```json +{ + "hook_event_name": "agentSpawn", + "cwd": "/current/working/directory" +} +``` + +For tool-related hooks, additional fields are included: +- `tool_name`: Name of the tool being executed +- `tool_input`: Tool-specific parameters (see individual tool documentation) +- `tool_response`: Tool execution results (PostToolUse only) + +## Hook Output + +- **Exit code 0**: Hook succeeded. STDOUT is captured but not shown to user. +- **Exit code 2**: (PreToolUse only) Block tool execution. STDERR is returned to the LLM. +- **Other exit codes**: Hook failed. STDERR is shown as warning to user. + +## Tool Matching + +Use the `matcher` field to specify which tools the hook applies to: + +### Examples +- `"fs_write"` - Exact match for built-in tools +- `"fs_*"` - Wildcard pattern for built-in tools +- `"@git"` - All tools from git MCP server +- `"@git/status"` - Specific tool from git MCP server +- `"*"` - All tools (built-in and MCP) +- `"@builtin"` - All built-in tools only +- No matcher - Applies to all tools + +For complete tool reference format, see [agent format documentation](agent-format.md#tools-field). + +## Hook Types + +### AgentSpawn + +Runs when agent is activated. No tool context provided. + +**Hook Event** +```json +{ + "hook_event_name": "agentSpawn", + "cwd": "/current/working/directory" +} +``` + +**Exit Code Behavior:** +- **0**: Hook succeeded, STDOUT is added to agent's context +- **Other**: Show STDERR warning to user + +### UserPromptSubmit + +Runs when user submits a prompt. Output is added to conversation context. + +**Hook Event** +```json +{ + "hook_event_name": "userPromptSubmit", + "cwd": "/current/working/directory", + "prompt": "user's input prompt" +} +``` + +**Exit Code Behavior:** +- **0**: Hook succeeded, STDOUT is added to agent's context +- **Other**: Show STDERR warning to user + +### PreToolUse + +Runs before tool execution. Can validate and block tool usage. + +**Hook Event** +```json +{ + "hook_event_name": "preToolUse", + "cwd": "/current/working/directory", + "tool_name": "fs_read", + "tool_input": { + "operations": [ + { + "mode": "Line", + "path": "/current/working/directory/docs/hooks.md" + } + ] + } +} +``` + +**Exit Code Behavior:** +- **0**: Allow tool execution. +- **2**: Block tool execution, return STDERR to LLM. +- **Other**: Show STDERR warning to user, allow tool execution. + +### PostToolUse + +Runs after tool execution with access to tool results. + +**Hook Event** +```json +{ + "hook_event_name": "postToolUse", + "cwd": "/current/working/directory", + "tool_name": "fs_read", + "tool_input": { + "operations": [ + { + "mode": "Line", + "path": "/current/working/directory/docs/hooks.md" + } + ] + }, + "tool_response": { + "success": true, + "result": ["# Hooks\n\nHooks allow you to execute..."] + } +} +``` + +**Exit Code Behavior:** +- **0**: Hook succeeded. +- **Other**: Show STDERR warning to user. Tool already ran. + +### MCP Example + +For MCP tools, the tool name includes the full namespaced format including the MCP Server name: + +**Hook Event** +```json +{ + "hook_event_name": "preToolUse", + "cwd": "/current/working/directory", + "tool_name": "@postgres/query", + "tool_input": { + "sql": "SELECT * FROM orders LIMIT 10;" + } +} +``` + +## Timeout + +Default timeout is 30 seconds (30,000ms). Configure with `timeout_ms` field. + +## Caching + +Successfull hook results are cached based on `cache_ttl_seconds`: +- `0`: No caching (default) +- `> 0`: Cache successful results for specified seconds +- AgentSpawn hooks are never cached \ No newline at end of file From 5d0c62beb4b0847cbd0417be7a3381e9077fe381 Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Thu, 18 Sep 2025 09:48:17 -0700 Subject: [PATCH 49/71] fix(mcp): oauth issues (#2925) * fix incorrect scope for mcp oauth * reverts custom tool config enum change * fixes display task overriding sign in notice * updates schema --- crates/chat-cli/src/cli/chat/prompt.rs | 5 +- crates/chat-cli/src/cli/chat/tool_manager.rs | 1 + .../src/cli/chat/tools/custom_tool.rs | 21 ++++-- crates/chat-cli/src/cli/mcp.rs | 2 +- crates/chat-cli/src/mcp_client/client.rs | 40 +++++++---- crates/chat-cli/src/mcp_client/oauth_util.rs | 66 +++++++++++++++---- schemas/agent-v1.json | 21 ++++++ 7 files changed, 125 insertions(+), 31 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/prompt.rs b/crates/chat-cli/src/cli/chat/prompt.rs index 06d449fdce..2ee6640baa 100644 --- a/crates/chat-cli/src/cli/chat/prompt.rs +++ b/crates/chat-cli/src/cli/chat/prompt.rs @@ -745,10 +745,7 @@ mod tests { let key_event = KeyEvent(KeyCode::Char(key), Modifiers::CTRL); // Try to bind and get the previous handler - let previous_handler = test_editor.bind_sequence( - key_event, - EventHandler::Simple(Cmd::Noop) - ); + let previous_handler = test_editor.bind_sequence(key_event, EventHandler::Simple(Cmd::Noop)); // If there was a previous handler, it means the key was already bound // (which could be our custom binding overriding Emacs) diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index 30c0473a77..6ef93bea26 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -1210,6 +1210,7 @@ fn spawn_display_task( terminal::Clear(terminal::ClearType::CurrentLine), )?; queue_oauth_message(&name, &mut output)?; + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; }, }, Err(_e) => { diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index 45b37b2103..5457c9ee17 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -22,7 +22,10 @@ use crate::cli::agent::{ }; use crate::cli::chat::CONTINUATION_LINE; use crate::cli::chat::token_counter::TokenCounter; -use crate::mcp_client::RunningService; +use crate::mcp_client::{ + RunningService, + oauth_util, +}; use crate::os::Os; use crate::util::MCP_SERVER_TOOL_DELIMITER; use crate::util::pattern_matching::matches_any_pattern; @@ -43,17 +46,20 @@ impl Default for TransportType { } #[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] pub struct CustomToolConfig { - /// The type of transport the mcp server is expecting. For http transport, only url (for now) - /// is taken into account. + /// The transport type to use for communication with the MCP server #[serde(default)] pub r#type: TransportType, - /// The URL endpoint for HTTP-based MCP servers + /// The URL for HTTP-based MCP server communication #[serde(default)] pub url: String, /// HTTP headers to include when communicating with HTTP-based MCP servers #[serde(default)] pub headers: HashMap, + /// Scopes with which oauth is done + #[serde(default = "get_default_scopes")] + pub oauth_scopes: Vec, /// The command string used to initialize the mcp server #[serde(default)] pub command: String, @@ -74,6 +80,13 @@ pub struct CustomToolConfig { pub is_from_legacy_mcp_json: bool, } +pub fn get_default_scopes() -> Vec { + oauth_util::get_default_scopes() + .iter() + .map(|s| (*s).to_string()) + .collect::>() +} + pub fn default_timeout() -> u64 { 120 * 1000 } diff --git a/crates/chat-cli/src/cli/mcp.rs b/crates/chat-cli/src/cli/mcp.rs index 0b709ef887..f0e8b97886 100644 --- a/crates/chat-cli/src/cli/mcp.rs +++ b/crates/chat-cli/src/cli/mcp.rs @@ -406,7 +406,7 @@ impl StatusArgs { style::Print(format!("Disabled: {}\n", cfg.disabled)), style::Print(format!( "Env Vars: {}\n", - cfg.env.as_ref().map_or_else( + cfg.env.map_or_else( || "(none)".into(), |e| e .iter() diff --git a/crates/chat-cli/src/mcp_client/client.rs b/crates/chat-cli/src/mcp_client/client.rs index 4156cf34ae..b05ed3ad24 100644 --- a/crates/chat-cli/src/mcp_client/client.rs +++ b/crates/chat-cli/src/mcp_client/client.rs @@ -151,6 +151,8 @@ pub enum McpClientError { Parse(#[from] url::ParseError), #[error(transparent)] Auth(#[from] crate::auth::AuthError), + #[error("{0}")] + MalformedConfig(&'static str), } /// Decorates the method passed in with retry logic, but only if the [RunningService] has an @@ -336,7 +338,7 @@ impl McpClientService { HttpTransport::WithAuth((transport, mut auth_client)) => { // The crate does not automatically refresh tokens when they expire. We // would need to handle that here - let url = self.config.url.clone(); + let url = &backup_config.url; let service = match self.into_dyn().serve(transport).await.map_err(Box::new) { Ok(service) => service, Err(e) if matches!(*e, ClientInitializeError::ConnectionClosed(_)) => { @@ -344,12 +346,15 @@ impl McpClientService { let refresh_res = auth_client.refresh_token().await; let new_self = McpClientService::new( server_name.clone(), - backup_config, + backup_config.clone(), messenger_clone.clone(), ); + let scopes = &backup_config.oauth_scopes; + let timeout = backup_config.timeout; + let headers = &backup_config.headers; let new_transport = - get_http_transport(&os_clone, &url, Some(auth_client.auth_client.clone()), &*messenger_dup).await?; + get_http_transport(&os_clone, url, timeout, scopes, headers,Some(auth_client.auth_client.clone()), &*messenger_dup).await?; match new_transport { HttpTransport::WithAuth((new_transport, new_auth_client)) => { @@ -367,8 +372,8 @@ impl McpClientService { // again. We do this by deleting the cred // and discarding the client to trigger a full auth flow tokio::fs::remove_file(&auth_client.cred_full_path).await?; - let new_transport = - get_http_transport(&os_clone, &url, None, &*messenger_dup).await?; + let new_transport = + get_http_transport(&os_clone, url, timeout, scopes,headers,None, &*messenger_dup).await?; match new_transport { HttpTransport::WithAuth((new_transport, new_auth_client)) => { @@ -495,17 +500,32 @@ impl McpClientService { } async fn get_transport(&mut self, os: &Os, messenger: &dyn Messenger) -> Result { - // TODO: figure out what to do with headers let CustomToolConfig { - r#type: transport_type, + r#type, url, + headers, + oauth_scopes: scopes, command: command_as_str, args, env: config_envs, + timeout, .. } = &mut self.config; - match transport_type { + let is_malformed_http = matches!(r#type, TransportType::Http) && url.is_empty(); + let is_malformed_stdio = matches!(r#type, TransportType::Stdio) && command_as_str.is_empty(); + + if is_malformed_http { + return Err(McpClientError::MalformedConfig( + "MCP config is malformed: transport type is specified to be http but url is empty", + )); + } else if is_malformed_stdio { + return Err(McpClientError::MalformedConfig( + "MCP config is malformed: transport type is specified to be stdio but command is empty", + )); + } + + match r#type { TransportType::Stdio => { let expanded_cmd = canonicalizes_path(os, command_as_str)?; let command = Command::new(expanded_cmd).configure(|cmd| { @@ -525,7 +545,7 @@ impl McpClientService { Ok(Transport::Stdio((tokio_child_process, child_stderr))) }, TransportType::Http => { - let http_transport = get_http_transport(os, url, None, messenger).await?; + let http_transport = get_http_transport(os, url, *timeout, scopes, headers, None, messenger).await?; Ok(Transport::Http(http_transport)) }, @@ -562,7 +582,6 @@ impl McpClientService { async fn on_tool_list_changed(&self, context: NotificationContext) { let NotificationContext { peer, .. } = context; - let _timeout = self.config.timeout; paginated_fetch! { final_result_type: ListToolsResult, @@ -578,7 +597,6 @@ impl McpClientService { async fn on_prompt_list_changed(&self, context: NotificationContext) { let NotificationContext { peer, .. } = context; - let _timeout = self.config.timeout; paginated_fetch! { final_result_type: ListPromptsResult, diff --git a/crates/chat-cli/src/mcp_client/oauth_util.rs b/crates/chat-cli/src/mcp_client/oauth_util.rs index f284265cf0..9c53193922 100644 --- a/crates/chat-cli/src/mcp_client/oauth_util.rs +++ b/crates/chat-cli/src/mcp_client/oauth_util.rs @@ -1,10 +1,14 @@ +use std::collections::HashMap; use std::net::SocketAddr; use std::path::PathBuf; use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; -use http::StatusCode; +use http::{ + HeaderMap, + StatusCode, +}; use http_body_util::Full; use hyper::Response; use hyper::body::Bytes; @@ -69,6 +73,8 @@ pub enum OauthUtilError { Directory(#[from] DirectoryError), #[error(transparent)] Reqwest(#[from] reqwest::Error), + #[error("{0}")] + Http(String), #[error("Malformed directory")] MalformDirectory, #[error("Missing credential")] @@ -162,13 +168,16 @@ pub enum HttpTransport { WithoutAuth(WorkerTransport>), } -fn get_scopes() -> &'static [&'static str] { - &["openid", "mcp", "email", "profile"] +pub fn get_default_scopes() -> &'static [&'static str] { + &["openid", "email", "profile", "offline_access"] } pub async fn get_http_transport( os: &Os, url: &str, + timeout: u64, + scopes: &[String], + headers: &HashMap, auth_client: Option>, messenger: &dyn Messenger, ) -> Result { @@ -178,7 +187,13 @@ pub async fn get_http_transport( let cred_full_path = cred_dir.join(format!("{key}.token.json")); let reg_full_path = cred_dir.join(format!("{key}.registration.json")); - let reqwest_client = reqwest::Client::default(); + let mut client_builder = reqwest::ClientBuilder::new().timeout(std::time::Duration::from_millis(timeout)); + if !headers.is_empty() { + let headers = HeaderMap::try_from(headers).map_err(|e| OauthUtilError::Http(e.to_string()))?; + client_builder = client_builder.default_headers(headers); + }; + let reqwest_client = client_builder.build()?; + let probe_resp = reqwest_client.get(url.clone()).send().await?; match probe_resp.status() { StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { @@ -186,8 +201,14 @@ pub async fn get_http_transport( let auth_client = match auth_client { Some(auth_client) => auth_client, None => { - let am = - get_auth_manager(url.clone(), cred_full_path.clone(), reg_full_path.clone(), messenger).await?; + let am = get_auth_manager( + url.clone(), + cred_full_path.clone(), + reg_full_path.clone(), + scopes, + messenger, + ) + .await?; AuthClient::new(reqwest_client, am) }, }; @@ -204,7 +225,12 @@ pub async fn get_http_transport( Ok(HttpTransport::WithAuth((transport, auth_dg))) }, _ => { - let transport = StreamableHttpClientTransport::from_uri(url.as_str()); + let transport = + StreamableHttpClientTransport::with_client(reqwest_client, StreamableHttpClientTransportConfig { + uri: url.as_str().into(), + allow_stateless: false, + ..Default::default() + }); Ok(HttpTransport::WithoutAuth(transport)) }, @@ -215,6 +241,7 @@ async fn get_auth_manager( url: Url, cred_full_path: PathBuf, reg_full_path: PathBuf, + scopes: &[String], messenger: &dyn Messenger, ) -> Result { let cred_as_bytes = tokio::fs::read(&cred_full_path).await; @@ -237,7 +264,7 @@ async fn get_auth_manager( _ => { info!("Error reading cached credentials"); debug!("## mcp: cache read failed. constructing auth manager from scratch"); - let (am, redirect_uri) = get_auth_manager_impl(oauth_state, messenger).await?; + let (am, redirect_uri) = get_auth_manager_impl(oauth_state, scopes, messenger).await?; // Client registration is done in [start_authorization] // If we have gotten past that point that means we have the info to persist the @@ -246,7 +273,10 @@ async fn get_auth_manager( let reg = Registration { client_id, client_secret: None, - scopes: get_scopes().iter().map(|s| (*s).to_string()).collect::>(), + scopes: get_default_scopes() + .iter() + .map(|s| (*s).to_string()) + .collect::>(), redirect_uri, }; let reg_as_str = serde_json::to_string_pretty(®)?; @@ -268,6 +298,7 @@ async fn get_auth_manager( async fn get_auth_manager_impl( mut oauth_state: OAuthState, + scopes: &[String], messenger: &dyn Messenger, ) -> Result<(AuthorizationManager, String), OauthUtilError> { let socket_addr = SocketAddr::from(([127, 0, 0, 1], 0)); @@ -278,7 +309,9 @@ async fn get_auth_manager_impl( info!("Listening on local host port {:?} for oauth", actual_addr); let redirect_uri = format!("http://{}", actual_addr); - oauth_state.start_authorization(get_scopes(), &redirect_uri).await?; + let scopes_as_str = scopes.iter().map(String::as_str).collect::>(); + let scopes_as_slice = scopes_as_str.as_slice(); + oauth_state.start_authorization(scopes_as_slice, &redirect_uri).await?; let auth_url = oauth_state.get_authorization_url().await?; _ = messenger.send_oauth_link(auth_url).await; @@ -333,9 +366,19 @@ async fn make_svc( let query = uri.query().unwrap_or(""); let params: std::collections::HashMap = url::form_urlencoded::parse(query.as_bytes()).into_owned().collect(); + debug!("## mcp: uri: {}, query: {}, params: {:?}", uri, query, params); let self_clone = self.clone(); Box::pin(async move { + let error = params.get("error"); + let resp = if let Some(err) = error { + mk_response(format!( + "Oauth failed. Check url for precise reasons. Possible reasons: {err}.\nIf this is scope related. You can try configuring the server scopes to be an empty array via adding oauth_scopes: []" + )) + } else { + mk_response("You can close this page now".to_string()) + }; + let code = params.get("code").cloned().unwrap_or_default(); if let Some(sender) = self_clone .one_shot_sender @@ -345,7 +388,8 @@ async fn make_svc( { sender.send(code).map_err(LoopBackError::Send)?; } - mk_response("You can close this page now".to_string()) + + resp }) } } diff --git a/schemas/agent-v1.json b/schemas/agent-v1.json index c7b63492d3..5e72b08476 100644 --- a/schemas/agent-v1.json +++ b/schemas/agent-v1.json @@ -73,6 +73,27 @@ "type": "string", "default": "" }, + "headers": { + "description": "HTTP headers to include when communicating with HTTP-based MCP servers", + "type": "object", + "additionalProperties": { + "type": "string" + }, + "default": {} + }, + "oauthScopes": { + "description": "Scopes with which oauth is done", + "type": "array", + "items": { + "type": "string" + }, + "default": [ + "openid", + "email", + "profile", + "offline_access" + ] + }, "command": { "description": "The command string used to initialize the mcp server", "type": "string", From 61efff3a7441b3fcbbc6ec6cd4a6ed2470e5e62c Mon Sep 17 00:00:00 2001 From: nirajchowdhary <226941436+nirajchowdhary@users.noreply.github.com> Date: Thu, 18 Sep 2025 11:01:52 -0700 Subject: [PATCH 50/71] fix(chat): reset pending tool state when clearing conversation (#2855) Reset tool_uses, pending_tool_index, and tool_turn_start_time to prevent orphaned tool approval prompts after conversation history is cleared. Co-authored-by: Niraj Chowdhary --- crates/chat-cli/src/cli/chat/cli/clear.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/crates/chat-cli/src/cli/chat/cli/clear.rs b/crates/chat-cli/src/cli/chat/cli/clear.rs index b31bccd28e..de274e2c01 100644 --- a/crates/chat-cli/src/cli/chat/cli/clear.rs +++ b/crates/chat-cli/src/cli/chat/cli/clear.rs @@ -52,6 +52,12 @@ impl ClearArgs { if let Some(cm) = session.conversation.context_manager.as_mut() { cm.hook_executor.cache.clear(); } + + // Reset pending tool state to prevent orphaned tool approval prompts + session.tool_uses.clear(); + session.pending_tool_index = None; + session.tool_turn_start_time = None; + execute!( session.stderr, style::SetForegroundColor(Color::Green), From 515b4df668b57cbdb0d69ccb6e4cdf089a85b0de Mon Sep 17 00:00:00 2001 From: kkashilk <93673379+kkashilk@users.noreply.github.com> Date: Thu, 18 Sep 2025 11:46:47 -0700 Subject: [PATCH 51/71] Trim region to avoid login failures (#2930) --- crates/chat-cli/src/cli/chat/cli/hooks.rs | 139 ++++++++++-------- crates/chat-cli/src/cli/chat/context.rs | 6 +- crates/chat-cli/src/cli/chat/conversation.rs | 28 +++- crates/chat-cli/src/cli/chat/mod.rs | 138 +++++++++-------- .../chat-cli/src/cli/chat/tools/introspect.rs | 3 +- crates/chat-cli/src/cli/user.rs | 2 +- docs/hooks.md | 2 +- 7 files changed, 185 insertions(+), 133 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/cli/hooks.rs b/crates/chat-cli/src/cli/chat/cli/hooks.rs index 96e1e842dc..05baee5bb7 100644 --- a/crates/chat-cli/src/cli/chat/cli/hooks.rs +++ b/crates/chat-cli/src/cli/chat/cli/hooks.rs @@ -38,7 +38,6 @@ use crate::cli::agent::hook::{ HookTrigger, }; use crate::cli::agent::is_mcp_tool_ref; -use crate::util::MCP_SERVER_TOOL_DELIMITER; use crate::cli::chat::consts::AGENT_FORMAT_HOOKS_DOC_URL; use crate::cli::chat::util::truncate_safe; use crate::cli::chat::{ @@ -46,6 +45,7 @@ use crate::cli::chat::{ ChatSession, ChatState, }; +use crate::util::MCP_SERVER_TOOL_DELIMITER; use crate::util::pattern_matching::matches_any_pattern; /// Hook execution result: (exit_code, output) @@ -58,26 +58,29 @@ fn hook_matches_tool(hook: &Hook, tool_name: &str) -> bool { None => true, // No matcher means the hook runs for all tools Some(pattern) => { match pattern.as_str() { - "*" => true, // Wildcard matches all tools + "*" => true, // Wildcard matches all tools "@builtin" => !is_mcp_tool_ref(tool_name), // Built-in tools are not MCP tools _ => { // If tool_name is MCP, check server pattern first if is_mcp_tool_ref(tool_name) { - if let Some(server_name) = tool_name.strip_prefix('@').and_then(|s| s.split(MCP_SERVER_TOOL_DELIMITER).next()) { + if let Some(server_name) = tool_name + .strip_prefix('@') + .and_then(|s| s.split(MCP_SERVER_TOOL_DELIMITER).next()) + { let server_pattern = format!("@{}", server_name); if pattern == &server_pattern { return true; } } } - + // Use matches_any_pattern for both MCP and built-in tools let mut patterns = std::collections::HashSet::new(); patterns.insert(pattern.clone()); matches_any_pattern(&patterns, tool_name) - } + }, } - } + }, } } @@ -133,7 +136,7 @@ impl HookExecutor { continue; // Skip this hook - doesn't match tool } } - + if let Some(cache) = self.get_cache(&hook) { // Note: we only cache successful hook run. hence always using 0 as exit code for cached hook cached.push((hook.clone(), (0, cache))); @@ -203,7 +206,11 @@ impl HookExecutor { style::Print(&hook.1.command), style::Print("\""), style::SetForegroundColor(style::Color::Red), - style::Print(format!(" failed with exit code: {}, stderr: {})\n", exit_code, hook_output.trim_end())), + style::Print(format!( + " failed with exit code: {}, stderr: {})\n", + exit_code, + hook_output.trim_end() + )), style::ResetColor, )?; } else { @@ -437,11 +444,16 @@ impl HooksArgs { #[cfg(test)] mod tests { - use super::*; use std::collections::HashMap; - use crate::cli::agent::hook::{Hook, HookTrigger}; + use tempfile::TempDir; + use super::*; + use crate::cli::agent::hook::{ + Hook, + HookTrigger, + }; + #[test] fn test_hook_matches_tool() { let hook_no_matcher = Hook { @@ -452,7 +464,7 @@ mod tests { matcher: None, source: crate::cli::agent::hook::Source::Session, }; - + let fs_write_hook = Hook { command: "echo test".to_string(), timeout_ms: 5000, @@ -461,7 +473,7 @@ mod tests { matcher: Some("fs_write".to_string()), source: crate::cli::agent::hook::Source::Session, }; - + let fs_wildcard_hook = Hook { command: "echo test".to_string(), timeout_ms: 5000, @@ -470,7 +482,7 @@ mod tests { matcher: Some("fs_*".to_string()), source: crate::cli::agent::hook::Source::Session, }; - + let all_tools_hook = Hook { command: "echo test".to_string(), timeout_ms: 5000, @@ -479,7 +491,7 @@ mod tests { matcher: Some("*".to_string()), source: crate::cli::agent::hook::Source::Session, }; - + let builtin_hook = Hook { command: "echo test".to_string(), timeout_ms: 5000, @@ -488,7 +500,7 @@ mod tests { matcher: Some("@builtin".to_string()), source: crate::cli::agent::hook::Source::Session, }; - + let git_server_hook = Hook { command: "echo test".to_string(), timeout_ms: 5000, @@ -497,7 +509,7 @@ mod tests { matcher: Some("@git".to_string()), source: crate::cli::agent::hook::Source::Session, }; - + let git_status_hook = Hook { command: "echo test".to_string(), timeout_ms: 5000, @@ -506,36 +518,36 @@ mod tests { matcher: Some("@git/status".to_string()), source: crate::cli::agent::hook::Source::Session, }; - + // No matcher should match all tools assert!(hook_matches_tool(&hook_no_matcher, "fs_write")); assert!(hook_matches_tool(&hook_no_matcher, "execute_bash")); assert!(hook_matches_tool(&hook_no_matcher, "@git/status")); - + // Exact matcher should only match exact tool assert!(hook_matches_tool(&fs_write_hook, "fs_write")); assert!(!hook_matches_tool(&fs_write_hook, "fs_read")); - + // Wildcard matcher should match pattern assert!(hook_matches_tool(&fs_wildcard_hook, "fs_write")); assert!(hook_matches_tool(&fs_wildcard_hook, "fs_read")); assert!(!hook_matches_tool(&fs_wildcard_hook, "execute_bash")); - + // * should match all tools assert!(hook_matches_tool(&all_tools_hook, "fs_write")); assert!(hook_matches_tool(&all_tools_hook, "execute_bash")); assert!(hook_matches_tool(&all_tools_hook, "@git/status")); - + // @builtin should match built-in tools only assert!(hook_matches_tool(&builtin_hook, "fs_write")); assert!(hook_matches_tool(&builtin_hook, "execute_bash")); assert!(!hook_matches_tool(&builtin_hook, "@git/status")); - + // @git should match all git server tools assert!(hook_matches_tool(&git_server_hook, "@git/status")); assert!(!hook_matches_tool(&git_server_hook, "@other/tool")); assert!(!hook_matches_tool(&git_server_hook, "fs_write")); - + // @git/status should match exact MCP tool assert!(hook_matches_tool(&git_status_hook, "@git/status")); assert!(!hook_matches_tool(&git_status_hook, "@git/commit")); @@ -546,18 +558,18 @@ mod tests { async fn test_hook_executor_with_tool_context() { let mut executor = HookExecutor::new(); let mut output = Vec::new(); - + // Create temp directory and file let temp_dir = TempDir::new().unwrap(); let test_file = temp_dir.path().join("hook_output.json"); let test_file_str = test_file.to_string_lossy(); - + // Create a simple hook that writes JSON input to a file #[cfg(unix)] let command = format!("cat > {}", test_file_str); #[cfg(windows)] let command = format!("type > {}", test_file_str); - + let hook = Hook { command, timeout_ms: 5000, @@ -566,10 +578,10 @@ mod tests { matcher: Some("fs_write".to_string()), source: crate::cli::agent::hook::Source::Session, }; - + let mut hooks = HashMap::new(); hooks.insert(HookTrigger::PreToolUse, vec![hook]); - + let tool_context = ToolContext { tool_name: "fs_write".to_string(), tool_input: serde_json::json!({ @@ -578,18 +590,14 @@ mod tests { }), tool_response: None, }; - + // Run the hook - let result = executor.run_hooks( - hooks, - &mut output, - ".", - None, - Some(tool_context) - ).await; - + let result = executor + .run_hooks(hooks, &mut output, ".", None, Some(tool_context)) + .await; + assert!(result.is_ok()); - + // Verify the hook wrote the JSON input to the file if let Ok(content) = std::fs::read_to_string(&test_file) { let json: serde_json::Value = serde_json::from_str(&content).unwrap(); @@ -605,7 +613,7 @@ mod tests { async fn test_hook_filtering_no_match() { let mut executor = HookExecutor::new(); let mut output = Vec::new(); - + // Hook that matches execute_bash (should NOT run for fs_write tool call) let execute_bash_hook = Hook { command: "echo 'should not run'".to_string(), @@ -615,31 +623,33 @@ mod tests { matcher: Some("execute_bash".to_string()), source: crate::cli::agent::hook::Source::Session, }; - + let mut hooks = HashMap::new(); hooks.insert(HookTrigger::PostToolUse, vec![execute_bash_hook]); - + let tool_context = ToolContext { tool_name: "fs_write".to_string(), tool_input: serde_json::json!({"command": "create"}), tool_response: Some(serde_json::json!({"success": true})), }; - + // Run the hooks - let result = executor.run_hooks( - hooks, - &mut output, - ".", // cwd - using current directory for now - None, // prompt - no user prompt for this test - Some(tool_context) - ).await; - + let result = executor + .run_hooks( + hooks, + &mut output, + ".", // cwd - using current directory for now + None, // prompt - no user prompt for this test + Some(tool_context), + ) + .await; + assert!(result.is_ok()); let hook_results = result.unwrap(); - + // Should run 0 hooks because matcher doesn't match tool_name assert_eq!(hook_results.len(), 0); - + // Output should be empty since no hooks ran assert!(output.is_empty()); } @@ -654,7 +664,7 @@ mod tests { let command = "echo 'Tool execution blocked by security policy' >&2; exit 2"; #[cfg(windows)] let command = "echo Tool execution blocked by security policy 1>&2 & exit /b 2"; - + let hook = Hook { command: command.to_string(), timeout_ms: 5000, @@ -664,9 +674,7 @@ mod tests { source: crate::cli::agent::hook::Source::Session, }; - let hooks = HashMap::from([ - (HookTrigger::PreToolUse, vec![hook]) - ]); + let hooks = HashMap::from([(HookTrigger::PreToolUse, vec![hook])]); let tool_context = ToolContext { tool_name: "fs_write".to_string(), @@ -677,17 +685,20 @@ mod tests { tool_response: None, }; - let results = executor.run_hooks( - hooks, - &mut output, - ".", // cwd - None, // prompt - Some(tool_context) - ).await.unwrap(); + let results = executor + .run_hooks( + hooks, + &mut output, + ".", // cwd + None, // prompt + Some(tool_context), + ) + .await + .unwrap(); // Should have one result assert_eq!(results.len(), 1); - + let ((trigger, _hook), (exit_code, hook_output)) = &results[0]; assert_eq!(*trigger, HookTrigger::PreToolUse); assert_eq!(*exit_code, 2); diff --git a/crates/chat-cli/src/cli/chat/context.rs b/crates/chat-cli/src/cli/chat/context.rs index fa39f80397..89edfb47aa 100644 --- a/crates/chat-cli/src/cli/chat/context.rs +++ b/crates/chat-cli/src/cli/chat/context.rs @@ -14,9 +14,9 @@ use serde::{ Serializer, }; +use super::cli::hooks::HookOutput; use super::cli::model::context_window_tokens; use super::util::drop_matched_context_files; -use super::cli::hooks::HookOutput; use crate::cli::agent::Agent; use crate::cli::agent::hook::{ Hook, @@ -255,7 +255,9 @@ impl ContextManager { let mut hooks = self.hooks.clone(); hooks.retain(|t, _| *t == trigger); let cwd = os.env.current_dir()?.to_string_lossy().to_string(); - self.hook_executor.run_hooks(hooks, output, &cwd, prompt, tool_context).await + self.hook_executor + .run_hooks(hooks, output, &cwd, prompt, tool_context) + .await } } diff --git a/crates/chat-cli/src/cli/chat/conversation.rs b/crates/chat-cli/src/cli/chat/conversation.rs index 9179697807..3d064c0180 100644 --- a/crates/chat-cli/src/cli/chat/conversation.rs +++ b/crates/chat-cli/src/cli/chat/conversation.rs @@ -564,12 +564,26 @@ impl ConversationState { let mut agent_spawn_context = None; if let Some(cm) = self.context_manager.as_mut() { let user_prompt = self.next_message.as_ref().and_then(|m| m.prompt()); - let agent_spawn = cm.run_hooks(HookTrigger::AgentSpawn, output, os, user_prompt, None /* tool_context */).await?; + let agent_spawn = cm + .run_hooks( + HookTrigger::AgentSpawn, + output, + os, + user_prompt, + None, // tool_context + ) + .await?; agent_spawn_context = format_hook_context(&agent_spawn, HookTrigger::AgentSpawn); if let (true, Some(next_message)) = (run_perprompt_hooks, self.next_message.as_mut()) { let per_prompt = cm - .run_hooks(HookTrigger::UserPromptSubmit, output, os, next_message.prompt(), None /* tool_context */) + .run_hooks( + HookTrigger::UserPromptSubmit, + output, + os, + next_message.prompt(), + None, // tool_context + ) .await?; if let Some(ctx) = format_hook_context(&per_prompt, HookTrigger::UserPromptSubmit) { next_message.additional_context = ctx; @@ -1033,7 +1047,10 @@ impl From for ToolInputSchema { /// [Option::None] fn format_hook_context(hook_results: &[((HookTrigger, Hook), HookOutput)], trigger: HookTrigger) -> Option { // Note: only format context when hook command exit code is 0 - if hook_results.iter().all(|(_, (exit_code, content))| *exit_code != 0 || content.is_empty()) { + if hook_results + .iter() + .all(|(_, (exit_code, content))| *exit_code != 0 || content.is_empty()) + { return None; } @@ -1046,7 +1063,10 @@ fn format_hook_context(hook_results: &[((HookTrigger, Hook), HookOutput)], trigg } context_content.push_str("\n\n"); - for (_, (_, output)) in hook_results.iter().filter(|((h_trigger, _), (exit_code, _))| *h_trigger == trigger && *exit_code == 0) { + for (_, (_, output)) in hook_results + .iter() + .filter(|((h_trigger, _), (exit_code, _))| *h_trigger == trigger && *exit_code == 0) + { context_content.push_str(&format!("{output}\n\n")); } context_content.push_str(CONTEXT_ENTRY_END_HEADER); diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 2da6b3d5ca..e155099450 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -2432,37 +2432,42 @@ impl ChatSession { if let Some(cm) = self.conversation.context_manager.as_mut() { for result in &tool_results { if let Some(tool) = self.tool_uses.iter().find(|t| t.id == result.tool_use_id) { - let content: Vec = result.content.iter().map(|block| { - match block { + let content: Vec = result + .content + .iter() + .map(|block| match block { ToolUseResultBlock::Text(text) => serde_json::Value::String(text.clone()), ToolUseResultBlock::Json(json) => json.clone(), - } - }).collect(); - + }) + .collect(); + let tool_response = match result.status { ToolResultStatus::Success => serde_json::json!({"success": true, "result": content}), ToolResultStatus::Error => serde_json::json!({"success": false, "error": content}), }; - + let tool_context = ToolContext { tool_name: match &tool.tool { - Tool::Custom(custom_tool) => custom_tool.namespaced_tool_name(), // for MCP tool, pass MCP name to the hook + Tool::Custom(custom_tool) => custom_tool.namespaced_tool_name(), /* for MCP tool, pass MCP name to the hook */ _ => tool.name.clone(), }, tool_input: tool.tool_input.clone(), tool_response: Some(tool_response), }; - + // Here is how we handle postToolUse output: - // Exit code is 0: nothing. stdout is not shown to user. We don't support processing the PostToolUse hook output yet. - // Exit code is non-zero: display an error to user (already taken care of by the ContextManager.run_hooks) - let _ = cm.run_hooks( - crate::cli::agent::hook::HookTrigger::PostToolUse, - &mut std::io::stderr(), - os, - None, - Some(tool_context) - ).await; + // Exit code is 0: nothing. stdout is not shown to user. We don't support processing the PostToolUse + // hook output yet. Exit code is non-zero: display an error to user (already + // taken care of by the ContextManager.run_hooks) + let _ = cm + .run_hooks( + crate::cli::agent::hook::HookTrigger::PostToolUse, + &mut std::io::stderr(), + os, + None, + Some(tool_context), + ) + .await; } } } @@ -2802,7 +2807,8 @@ impl ChatSession { } } - // Validate the tool use request from LLM, including basic checks like fs_read file should exist, as well as user-defined preToolUse hook check. + // Validate the tool use request from LLM, including basic checks like fs_read file should exist, as + // well as user-defined preToolUse hook check. async fn validate_tools(&mut self, os: &Os, tool_uses: Vec) -> Result { let conv_id = self.conversation.conversation_id().to_owned(); debug!(?tool_uses, "Validating tool uses"); @@ -2907,38 +2913,44 @@ impl ChatSession { } // Execute PreToolUse hooks for all validated tools - // The mental model is preToolHook is like validate tools, but its behavior can be customized by user - // Note that after preTookUse hook, user can still reject the took run + // The mental model is preToolHook is like validate tools, but its behavior can be customized by + // user Note that after preTookUse hook, user can still reject the took run if let Some(cm) = self.conversation.context_manager.as_mut() { for tool in &queued_tools { let tool_context = ToolContext { tool_name: match &tool.tool { - Tool::Custom(custom_tool) => custom_tool.namespaced_tool_name(), // for MCP tool, pass MCP name to the hook + Tool::Custom(custom_tool) => custom_tool.namespaced_tool_name(), // for MCP tool, pass MCP + // name to the hook _ => tool.name.clone(), }, tool_input: tool.tool_input.clone(), tool_response: None, }; - - let hook_results = cm.run_hooks( - crate::cli::agent::hook::HookTrigger::PreToolUse, - &mut std::io::stderr(), - os, - None, /* prompt */ - Some(tool_context) - ).await?; + + let hook_results = cm + .run_hooks( + crate::cli::agent::hook::HookTrigger::PreToolUse, + &mut std::io::stderr(), + os, + None, // prompt + Some(tool_context), + ) + .await?; // Here is how we handle the preToolUse hook output: // Exit code is 0: nothing. stdout is not shown to user. // Exit code is 2: block the tool use. return stderr to LLM. show warning to user // Other error: show warning to user. - + // Check for exit code 2 and add to tool_results for (_, (exit_code, output)) in &hook_results { if *exit_code == 2 { tool_results.push(ToolUseResult { tool_use_id: tool.id.clone(), - content: vec![ToolUseResultBlock::Text(format!("PreToolHook blocked the tool execution: {}", output))], + content: vec![ToolUseResultBlock::Text(format!( + "PreToolHook blocked the tool execution: {}", + output + ))], status: ToolResultStatus::Error, }); } @@ -2974,7 +2986,7 @@ impl ChatSession { self.tool_uses = queued_tools; self.pending_tool_index = Some(0); self.tool_turn_start_time = Some(Instant::now()); - + Ok(ChatState::ExecuteTools) } @@ -3897,14 +3909,18 @@ mod tests { } // Integration test for PreToolUse hook functionality. - // - // In this integration test we create a preToolUse hook that logs tool info into a file + // + // In this integration test we create a preToolUse hook that logs tool info into a file // and we run fs_read and verify the log is generated with the correct ToolContext data. #[tokio::test] async fn test_tool_hook_integration() { - use crate::cli::agent::hook::{Hook, HookTrigger}; use std::collections::HashMap; + use crate::cli::agent::hook::{ + Hook, + HookTrigger, + }; + let mut os = Os::new().await.unwrap(); os.client.set_mock_output(serde_json::json!([ [ @@ -3935,13 +3951,13 @@ mod tests { // Create agent with PreToolUse and PostToolUse hooks let mut agents = Agents::default(); let mut hooks = HashMap::new(); - + // Get the real path in the temp directory for the hooks to write to let pre_hook_log_path = os.fs.chroot_path_str("/pre-hook-test.log"); let post_hook_log_path = os.fs.chroot_path_str("/post-hook-test.log"); let pre_hook_command = format!("cat > {}", pre_hook_log_path); let post_hook_command = format!("cat > {}", post_hook_log_path); - + hooks.insert(HookTrigger::PreToolUse, vec![Hook { command: pre_hook_command, timeout_ms: 5000, @@ -4002,16 +4018,16 @@ mod tests { // Verify the PreToolUse hook was called if let Ok(pre_log_content) = os.fs.read_to_string("/pre-hook-test.log").await { - let pre_hook_data: serde_json::Value = serde_json::from_str(&pre_log_content) - .expect("PreToolUse hook output should be valid JSON"); - + let pre_hook_data: serde_json::Value = + serde_json::from_str(&pre_log_content).expect("PreToolUse hook output should be valid JSON"); + assert_eq!(pre_hook_data["hook_event_name"], "preToolUse"); assert_eq!(pre_hook_data["tool_name"], "fs_read"); assert_eq!(pre_hook_data["tool_response"], serde_json::Value::Null); - + let tool_input = &pre_hook_data["tool_input"]; assert!(tool_input["operations"].is_array()); - + println!("✓ PreToolUse hook validation passed: {}", pre_log_content); } else { panic!("PreToolUse hook log file not found - hook may not have been called"); @@ -4019,22 +4035,22 @@ mod tests { // Verify the PostToolUse hook was called if let Ok(post_log_content) = os.fs.read_to_string("/post-hook-test.log").await { - let post_hook_data: serde_json::Value = serde_json::from_str(&post_log_content) - .expect("PostToolUse hook output should be valid JSON"); - + let post_hook_data: serde_json::Value = + serde_json::from_str(&post_log_content).expect("PostToolUse hook output should be valid JSON"); + assert_eq!(post_hook_data["hook_event_name"], "postToolUse"); assert_eq!(post_hook_data["tool_name"], "fs_read"); - + // Validate tool_response structure for successful execution let tool_response = &post_hook_data["tool_response"]; assert_eq!(tool_response["success"], true); assert!(tool_response["result"].is_array()); - + let result_blocks = tool_response["result"].as_array().unwrap(); assert!(!result_blocks.is_empty()); let content = result_blocks[0].as_str().unwrap(); assert!(content.contains("line1\nline2\nline3")); - + println!("✓ PostToolUse hook validation passed: {}", post_log_content); } else { panic!("PostToolUse hook log file not found - hook may not have been called"); @@ -4043,20 +4059,24 @@ mod tests { #[tokio::test] async fn test_pretool_hook_blocking_integration() { - use crate::cli::agent::hook::{Hook, HookTrigger}; use std::collections::HashMap; + use crate::cli::agent::hook::{ + Hook, + HookTrigger, + }; + let mut os = Os::new().await.unwrap(); - + // Create a test file to read os.fs.write("/sensitive.txt", "classified information").await.unwrap(); - + // Mock LLM responses: first tries fs_read, gets blocked, then responds to error os.client.set_mock_output(serde_json::json!([ [ "I'll read that file for you", { - "tool_use_id": "1", + "tool_use_id": "1", "name": "fs_read", "args": { "operations": [ @@ -4076,13 +4096,13 @@ mod tests { // Create agent with blocking PreToolUse hook let mut agents = Agents::default(); let mut hooks = HashMap::new(); - + // Create a hook that blocks fs_read of sensitive files with exit code 2 #[cfg(unix)] let hook_command = "echo 'Security policy violation: cannot read sensitive files' >&2; exit 2"; #[cfg(windows)] let hook_command = "echo Security policy violation: cannot read sensitive files 1>&2 & exit /b 2"; - + hooks.insert(HookTrigger::PreToolUse, vec![Hook { command: hook_command.to_string(), timeout_ms: 5000, @@ -4112,10 +4132,7 @@ mod tests { "test_conv_id", agents, None, - InputSource::new_mock(vec![ - "read /sensitive.txt".to_string(), - "exit".to_string(), - ]), + InputSource::new_mock(vec!["read /sensitive.txt".to_string(), "exit".to_string()]), false, || Some(80), tool_manager, @@ -4131,7 +4148,10 @@ mod tests { .await; // The session should complete successfully (hook blocks tool but doesn't crash) - assert!(result.is_ok(), "Chat session should complete successfully even when hook blocks tool"); + assert!( + result.is_ok(), + "Chat session should complete successfully even when hook blocks tool" + ); } #[test] diff --git a/crates/chat-cli/src/cli/chat/tools/introspect.rs b/crates/chat-cli/src/cli/chat/tools/introspect.rs index f14351ab54..4968cd8e94 100644 --- a/crates/chat-cli/src/cli/chat/tools/introspect.rs +++ b/crates/chat-cli/src/cli/chat/tools/introspect.rs @@ -123,8 +123,7 @@ impl Introspect { ); documentation .push_str("• Todo Lists: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/todo-lists.md\n"); - documentation - .push_str("• Hooks: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/hooks.md\n"); + documentation.push_str("• Hooks: https://github.com/aws/amazon-q-developer-cli/blob/main/docs/hooks.md\n"); documentation .push_str("• Contributing: https://github.com/aws/amazon-q-developer-cli/blob/main/CONTRIBUTING.md\n"); diff --git a/crates/chat-cli/src/cli/user.rs b/crates/chat-cli/src/cli/user.rs index 50e761b9b0..9394746c29 100644 --- a/crates/chat-cli/src/cli/user.rs +++ b/crates/chat-cli/src/cli/user.rs @@ -118,7 +118,7 @@ impl LoginArgs { }; let start_url = input("Enter Start URL", default_start_url.as_deref())?; - let region = input("Enter Region", default_region.as_deref())?; + let region = input("Enter Region", default_region.as_deref())?.trim().to_string(); let _ = os.database.set_start_url(start_url.clone()); let _ = os.database.set_idc_region(region.clone()); diff --git a/docs/hooks.md b/docs/hooks.md index 4a0636a06b..d7cfa3d50f 100644 --- a/docs/hooks.md +++ b/docs/hooks.md @@ -155,7 +155,7 @@ Default timeout is 30 seconds (30,000ms). Configure with `timeout_ms` field. ## Caching -Successfull hook results are cached based on `cache_ttl_seconds`: +Successful hook results are cached based on `cache_ttl_seconds`: - `0`: No caching (default) - `> 0`: Cache successful results for specified seconds - AgentSpawn hooks are never cached \ No newline at end of file From dbf21ef2952fcd15e080dc45f5d4abde40788a64 Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Thu, 18 Sep 2025 12:11:42 -0700 Subject: [PATCH 52/71] chore: copy change for warning message for oauth redirect page (#2931) --- crates/chat-cli/src/mcp_client/oauth_util.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/crates/chat-cli/src/mcp_client/oauth_util.rs b/crates/chat-cli/src/mcp_client/oauth_util.rs index 9c53193922..c4d4869bf1 100644 --- a/crates/chat-cli/src/mcp_client/oauth_util.rs +++ b/crates/chat-cli/src/mcp_client/oauth_util.rs @@ -373,7 +373,11 @@ async fn make_svc( let error = params.get("error"); let resp = if let Some(err) = error { mk_response(format!( - "Oauth failed. Check url for precise reasons. Possible reasons: {err}.\nIf this is scope related. You can try configuring the server scopes to be an empty array via adding oauth_scopes: []" + "OAuth failed. Check URL for precise reasons. Possible reasons: {}.\n\ + If this is scope related, you can try configuring the server scopes \n\ + to be an empty array by adding \"oauthScopes\": [] to your server config.\n\ + Example: {{\"type\": \"http\", \"uri\": \"https://example.com/mcp\", \"oauthScopes\": []}}\n", + err )) } else { mk_response("You can close this page now".to_string()) From 63ca7ef61768f694c19f9b14a957dfcdb7ba95e2 Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Thu, 18 Sep 2025 14:42:50 -0700 Subject: [PATCH 53/71] fix: removes deny unknown fields in mcp config (#2935) --- crates/chat-cli/src/cli/chat/tools/custom_tool.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index 5457c9ee17..976f2d3820 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -46,7 +46,7 @@ impl Default for TransportType { } #[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)] -#[serde(rename_all = "camelCase", deny_unknown_fields)] +#[serde(rename_all = "camelCase")] pub struct CustomToolConfig { /// The transport type to use for communication with the MCP server #[serde(default)] From 18bd27041c0fcb240d8dd58c2cc273d1a41098ba Mon Sep 17 00:00:00 2001 From: kkashilk <93673379+kkashilk@users.noreply.github.com> Date: Thu, 18 Sep 2025 15:27:41 -0700 Subject: [PATCH 54/71] Normalize and expand relative-paths to absolute paths (#2933) --- crates/chat-cli/src/cli/chat/cli/clear.rs | 2 +- crates/chat-cli/src/util/directories.rs | 93 +++++++++++++++++++++-- 2 files changed, 88 insertions(+), 7 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/cli/clear.rs b/crates/chat-cli/src/cli/chat/cli/clear.rs index de274e2c01..994ed35e03 100644 --- a/crates/chat-cli/src/cli/chat/cli/clear.rs +++ b/crates/chat-cli/src/cli/chat/cli/clear.rs @@ -57,7 +57,7 @@ impl ClearArgs { session.tool_uses.clear(); session.pending_tool_index = None; session.tool_turn_start_time = None; - + execute!( session.stderr, style::SetForegroundColor(Color::Green), diff --git a/crates/chat-cli/src/util/directories.rs b/crates/chat-cli/src/util/directories.rs index 89c6f3bc4e..2bb53381d7 100644 --- a/crates/chat-cli/src/util/directories.rs +++ b/crates/chat-cli/src/util/directories.rs @@ -185,7 +185,45 @@ pub fn canonicalizes_path(os: &Os, path_as_str: &str) -> Result { let context = |input: &str| Ok(os.env.get(input).ok()); let home_dir = || os.env.home().map(|p| p.to_string_lossy().to_string()); - Ok(shellexpand::full_with_context(path_as_str, home_dir, context)?.to_string()) + let expanded = shellexpand::full_with_context(path_as_str, home_dir, context)?; + let path_buf = if !expanded.starts_with("/") { + // Convert relative paths to absolute paths + let current_dir = os.env.current_dir()?; + current_dir.join(expanded.as_ref() as &str) + } else { + // Already absolute path + PathBuf::from(expanded.as_ref() as &str) + }; + + // Try canonicalize first, fallback to manual normalization if it fails + match path_buf.canonicalize() { + Ok(normalized) => Ok(normalized.as_path().to_string_lossy().to_string()), + Err(_) => { + // If canonicalize fails (e.g., path doesn't exist), do manual normalization + let normalized = normalize_path(&path_buf); + Ok(normalized.to_string_lossy().to_string()) + }, + } +} + +/// Manually normalize a path by resolving . and .. components +fn normalize_path(path: &std::path::Path) -> std::path::PathBuf { + let mut components = Vec::new(); + for component in path.components() { + match component { + std::path::Component::CurDir => { + // Skip current directory components + }, + std::path::Component::ParentDir => { + // Pop the last component for parent directory + components.pop(); + }, + _ => { + components.push(component); + }, + } + } + components.iter().collect() } /// Given a globset builder and a path, build globs for both the file and directory patterns @@ -445,28 +483,71 @@ mod tests { // Test home directory expansion let result = canonicalizes_path(&test_os, "~/test").unwrap(); + #[cfg(windows)] + assert_eq!(result, "\\home\\testuser\\test"); + #[cfg(unix)] assert_eq!(result, "/home/testuser/test"); // Test environment variable expansion let result = canonicalizes_path(&test_os, "$TEST_VAR/path").unwrap(); - assert_eq!(result, "test_value/path"); + #[cfg(windows)] + assert_eq!(result, "\\test_value\\path"); + #[cfg(unix)] + assert_eq!(result, "/test_value/path"); // Test combined expansion let result = canonicalizes_path(&test_os, "~/$TEST_VAR").unwrap(); + #[cfg(windows)] + assert_eq!(result, "\\home\\testuser\\test_value"); + #[cfg(unix)] assert_eq!(result, "/home/testuser/test_value"); + // Test ~, . and .. expansion + let result = canonicalizes_path(&test_os, "~/./.././testuser").unwrap(); + #[cfg(windows)] + assert_eq!(result, "\\home\\testuser"); + #[cfg(unix)] + assert_eq!(result, "/home/testuser"); + // Test absolute path (no expansion needed) let result = canonicalizes_path(&test_os, "/absolute/path").unwrap(); + #[cfg(windows)] + assert_eq!(result, "\\absolute\\path"); + #[cfg(unix)] assert_eq!(result, "/absolute/path"); - // Test relative path (no expansion needed) + // Test ~, . and .. expansion for a path that does not exist + let result = canonicalizes_path(&test_os, "~/./.././testuser/new/path/../../new").unwrap(); + #[cfg(windows)] + assert_eq!(result, "\\home\\testuser\\new"); + #[cfg(unix)] + assert_eq!(result, "/home/testuser/new"); + + // Test path with . and .. + let result = canonicalizes_path(&test_os, "/absolute/./../path").unwrap(); + #[cfg(windows)] + assert_eq!(result, "\\path"); + #[cfg(unix)] + assert_eq!(result, "/path"); + + // Test relative path (which should be expanded because now all inputs are converted to + // absolute) let result = canonicalizes_path(&test_os, "relative/path").unwrap(); - assert_eq!(result, "relative/path"); + #[cfg(windows)] + assert_eq!(result, "\\relative\\path"); + #[cfg(unix)] + assert_eq!(result, "/relative/path"); // Test glob prefixed paths let result = canonicalizes_path(&test_os, "**/path").unwrap(); - assert_eq!(result, "**/path"); + #[cfg(windows)] + assert_eq!(result, "\\**\\path"); + #[cfg(unix)] + assert_eq!(result, "/**/path"); let result = canonicalizes_path(&test_os, "**/middle/**/path").unwrap(); - assert_eq!(result, "**/middle/**/path"); + #[cfg(windows)] + assert_eq!(result, "\\**\\middle\\**\\path"); + #[cfg(unix)] + assert_eq!(result, "/**/middle/**/path"); } } From 46403c12549feddb07b3ffec0ad442e5c2237e00 Mon Sep 17 00:00:00 2001 From: kkashilk <93673379+kkashilk@users.noreply.github.com> Date: Thu, 18 Sep 2025 16:54:17 -0700 Subject: [PATCH 55/71] Bump version to 1.16.2 and update feed.json (#2938) --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- crates/chat-cli/src/cli/feed.json | 32 +++++++++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f85bfb292f..7230af9f7e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1207,7 +1207,7 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chat_cli" -version = "1.16.1" +version = "1.16.2" dependencies = [ "amzn-codewhisperer-client", "amzn-codewhisperer-streaming-client", @@ -5647,7 +5647,7 @@ dependencies = [ [[package]] name = "semantic_search_client" -version = "1.16.1" +version = "1.16.2" dependencies = [ "anyhow", "bm25", diff --git a/Cargo.toml b/Cargo.toml index bc7b16ff38..7f0b79de2a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ authors = ["Amazon Q CLI Team (q-cli@amazon.com)", "Chay Nabors (nabochay@amazon edition = "2024" homepage = "https://aws.amazon.com/q/" publish = false -version = "1.16.1" +version = "1.16.2" license = "MIT OR Apache-2.0" [workspace.dependencies] diff --git a/crates/chat-cli/src/cli/feed.json b/crates/chat-cli/src/cli/feed.json index a9d0f68eb6..e81a34d24d 100644 --- a/crates/chat-cli/src/cli/feed.json +++ b/crates/chat-cli/src/cli/feed.json @@ -10,6 +10,38 @@ "hidden": true, "changes": [] }, + { + "type": "release", + "date": "2025-09-19", + "version": "1.16.2", + "title": "Version 1.16.2", + "changes": [ + { + "type": "added", + "description": "Add support for preToolUse and postToolUse hook - [#2875](https://github.com/aws/amazon-q-developer-cli/pull/2875)" + }, + { + "type": "added", + "description": "Support for specifying oauth scopes via config - [#2925]( https://github.com/aws/amazon-q-developer-cli/pull/2925)" + }, + { + "type": "fixed", + "description": "Support for headers ingestion for remote mcp - [#2925]( https://github.com/aws/amazon-q-developer-cli/pull/2925)" + }, + { + "type": "added", + "description": "Change autocomplete shortcut from ctrl-f to ctrl-g - [#2634](https://github.com/aws/amazon-q-developer-cli/pull/2634)" + }, + { + "type": "fixed", + "description": "Fix file-path expansion in mcp-config - [#2915]( https://github.com/aws/amazon-q-developer-cli/pull/2915)" + }, + { + "type": "fixed", + "description": "Fix filepath expansion to use absolute paths - [#2933](https://github.com/aws/amazon-q-developer-cli/pull/2933)" + } + ] + }, { "type": "release", "date": "2025-09-17", From c90c87d8762e736c053fd25a04919224310d2b77 Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Thu, 18 Sep 2025 17:41:18 -0700 Subject: [PATCH 56/71] changes how mcp command gets expanded (#2940) --- crates/chat-cli/src/mcp_client/client.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/crates/chat-cli/src/mcp_client/client.rs b/crates/chat-cli/src/mcp_client/client.rs index b05ed3ad24..44473c3a74 100644 --- a/crates/chat-cli/src/mcp_client/client.rs +++ b/crates/chat-cli/src/mcp_client/client.rs @@ -60,10 +60,7 @@ use crate::cli::chat::tools::custom_tool::{ TransportType, }; use crate::os::Os; -use crate::util::directories::{ - DirectoryError, - canonicalizes_path, -}; +use crate::util::directories::DirectoryError; /// Fetches all pages of specified resources from a server macro_rules! paginated_fetch { @@ -153,6 +150,8 @@ pub enum McpClientError { Auth(#[from] crate::auth::AuthError), #[error("{0}")] MalformedConfig(&'static str), + #[error(transparent)] + LookUp(#[from] shellexpand::LookupError), } /// Decorates the method passed in with retry logic, but only if the [RunningService] has an @@ -527,8 +526,11 @@ impl McpClientService { match r#type { TransportType::Stdio => { - let expanded_cmd = canonicalizes_path(os, command_as_str)?; - let command = Command::new(expanded_cmd).configure(|cmd| { + let context = |input: &str| Ok(os.env.get(input).ok()); + let home_dir = || os.env.home().map(|p| p.to_string_lossy().to_string()); + let expanded_cmd = shellexpand::full_with_context(command_as_str, home_dir, context)?; + + let command = Command::new(expanded_cmd.as_ref() as &str).configure(|cmd| { if let Some(envs) = config_envs { process_env_vars(envs, &os.env); cmd.envs(envs); From c031c331b8683995ad514c26c1eab3eaf927ff96 Mon Sep 17 00:00:00 2001 From: kkashilk <93673379+kkashilk@users.noreply.github.com> Date: Fri, 19 Sep 2025 11:53:32 -0700 Subject: [PATCH 57/71] Update default tool label for execute bash (#2945) --- crates/chat-cli/src/cli/agent/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/chat-cli/src/cli/agent/mod.rs b/crates/chat-cli/src/cli/agent/mod.rs index 3dcc659203..f81133d001 100644 --- a/crates/chat-cli/src/cli/agent/mod.rs +++ b/crates/chat-cli/src/cli/agent/mod.rs @@ -818,9 +818,9 @@ impl Agents { "fs_read" => "trust working directory".dark_grey(), "fs_write" => "not trusted".dark_grey(), #[cfg(not(windows))] - "execute_bash" => "trust read-only commands".dark_grey(), + "execute_bash" => "not trusted".dark_grey(), #[cfg(windows)] - "execute_cmd" => "trust read-only commands".dark_grey(), + "execute_cmd" => "not trusted".dark_grey(), "use_aws" => "trust read-only commands".dark_grey(), "report_issue" => "trusted".dark_green().bold(), "introspect" => "trusted".dark_green().bold(), From 37a77e0a6246dbf0ca5e71d16683a9969daa86d7 Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Fri, 19 Sep 2025 11:57:30 -0700 Subject: [PATCH 58/71] fix: incorrect wrapping for response text (#2900) --- crates/chat-cli/src/cli/chat/parse.rs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/crates/chat-cli/src/cli/chat/parse.rs b/crates/chat-cli/src/cli/chat/parse.rs index 72f0ab94c2..ed30f3944b 100644 --- a/crates/chat-cli/src/cli/chat/parse.rs +++ b/crates/chat-cli/src/cli/chat/parse.rs @@ -81,6 +81,7 @@ impl<'a> ParserError> for Error<'a> { #[derive(Debug)] pub struct ParseState { + pub is_first_line: bool, pub terminal_width: Option, pub markdown_disabled: Option, pub column: usize, @@ -96,6 +97,7 @@ pub struct ParseState { impl ParseState { pub fn new(terminal_width: Option, markdown_disabled: Option) -> Self { Self { + is_first_line: true, terminal_width, markdown_disabled, column: 0, @@ -198,8 +200,17 @@ fn text<'a, 'b>( ) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { move |i| { let content = take_while(1.., |t| AsChar::is_alphanum(t) || "+,.!?\"".contains(t)).parse_next(i)?; - queue_newline_or_advance(&mut o, state, content.width())?; + if state.is_first_line { + state.is_first_line = false; + // The extra space here is reserved for the prompt pointer ("> "). + // Essentially we want the input to wrap as if the prompt pointer is a part of it + // but only display what is received. + queue_newline_or_advance(&mut o, state, content.width() + 2)?; + } else { + queue_newline_or_advance(&mut o, state, content.width())?; + } queue(&mut o, style::Print(content))?; + Ok(()) } } From 13fedbda8a9f0798e9699dff9874cd9cca67f73b Mon Sep 17 00:00:00 2001 From: kkashilk <93673379+kkashilk@users.noreply.github.com> Date: Tue, 23 Sep 2025 10:36:44 -0700 Subject: [PATCH 59/71] Improve error messages for dispatch failures (#2969) --- crates/chat-cli/src/auth/mod.rs | 8 +++++--- crates/chat-cli/src/cli/agent/mod.rs | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/crates/chat-cli/src/auth/mod.rs b/crates/chat-cli/src/auth/mod.rs index 4b425f2a6f..1e38864750 100644 --- a/crates/chat-cli/src/auth/mod.rs +++ b/crates/chat-cli/src/auth/mod.rs @@ -14,15 +14,17 @@ pub use builder_id::{ pub use consts::START_URL; use thiserror::Error; +use crate::aws_common::SdkErrorDisplay; + #[derive(Debug, Error)] pub enum AuthError { #[error(transparent)] Ssooidc(Box), - #[error(transparent)] + #[error("{}", SdkErrorDisplay(.0))] SdkRegisterClient(Box>), - #[error(transparent)] + #[error("{}", SdkErrorDisplay(.0))] SdkCreateToken(Box>), - #[error(transparent)] + #[error("{}", SdkErrorDisplay(.0))] SdkStartDeviceAuthorization(Box>), #[error(transparent)] Io(#[from] std::io::Error), diff --git a/crates/chat-cli/src/cli/agent/mod.rs b/crates/chat-cli/src/cli/agent/mod.rs index f81133d001..02b984dd2d 100644 --- a/crates/chat-cli/src/cli/agent/mod.rs +++ b/crates/chat-cli/src/cli/agent/mod.rs @@ -1189,8 +1189,8 @@ mod tests { let execute_name = if cfg!(windows) { "execute_cmd" } else { "execute_bash" }; let execute_bash_label = agents.display_label(execute_name, &ToolOrigin::Native); assert!( - execute_bash_label.contains("read-only"), - "execute_bash should show read-only by default, instead found: {}", + execute_bash_label.contains("not trusted"), + "execute_bash should not be trusted by default, instead found: {}", execute_bash_label ); } From bf183f0c58650357312ef0c11caabb9546e4b189 Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Wed, 24 Sep 2025 11:08:28 -0700 Subject: [PATCH 60/71] fix(mcp): hardcodes client id for oauth (#2976) --- crates/chat-cli/src/mcp_client/oauth_util.rs | 87 ++++++++++++++++++-- 1 file changed, 81 insertions(+), 6 deletions(-) diff --git a/crates/chat-cli/src/mcp_client/oauth_util.rs b/crates/chat-cli/src/mcp_client/oauth_util.rs index c4d4869bf1..8af59a6c13 100644 --- a/crates/chat-cli/src/mcp_client/oauth_util.rs +++ b/crates/chat-cli/src/mcp_client/oauth_util.rs @@ -28,6 +28,7 @@ use rmcp::transport::streamable_http_client::{ }; use rmcp::transport::{ AuthorizationManager, + AuthorizationSession, StreamableHttpClientTransport, WorkerTransport, }; @@ -194,10 +195,12 @@ pub async fn get_http_transport( }; let reqwest_client = client_builder.build()?; - let probe_resp = reqwest_client.get(url.clone()).send().await?; + // The probe request, like all other request, should adhere to the standards as per https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server + let mut probe_request = reqwest_client.post(url.clone()); + probe_request = probe_request.header("Accept", "application/json, text/event-stream"); + let probe_resp = probe_request.send().await?; match probe_resp.status() { StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { - debug!("## mcp: requires auth, auth client passed in is {:?}", auth_client); let auth_client = match auth_client { Some(auth_client) => auth_client, None => { @@ -215,12 +218,11 @@ pub async fn get_http_transport( let transport = StreamableHttpClientTransport::with_client(auth_client.clone(), StreamableHttpClientTransportConfig { uri: url.as_str().into(), - allow_stateless: false, + allow_stateless: true, ..Default::default() }); let auth_dg = AuthClientWrapper::new(cred_full_path, auth_client); - debug!("## mcp: transport obtained"); Ok(HttpTransport::WithAuth((transport, auth_dg))) }, @@ -228,7 +230,7 @@ pub async fn get_http_transport( let transport = StreamableHttpClientTransport::with_client(reqwest_client, StreamableHttpClientTransportConfig { uri: url.as_str().into(), - allow_stateless: false, + allow_stateless: true, ..Default::default() }); @@ -311,7 +313,7 @@ async fn get_auth_manager_impl( let redirect_uri = format!("http://{}", actual_addr); let scopes_as_str = scopes.iter().map(String::as_str).collect::>(); let scopes_as_slice = scopes_as_str.as_slice(); - oauth_state.start_authorization(scopes_as_slice, &redirect_uri).await?; + start_authorization(&mut oauth_state, scopes_as_slice, &redirect_uri).await?; let auth_url = oauth_state.get_authorization_url().await?; _ = messenger.send_oauth_link(auth_url).await; @@ -332,6 +334,79 @@ pub fn compute_key(rs: &Url) -> String { format!("{:x}", hasher.finalize()) } +/// This is our own implementation of [OAuthState::start_authorization]. +/// This differs from [OAuthState::start_authorization] by assigning our own client_id for DCR. +/// We need this because the SDK hardcodes their own client id. And some servers will use client_id +/// to identify if a client is even allowed to perform the auth handshake. +async fn start_authorization( + oauth_state: &mut OAuthState, + scopes: &[&str], + redirect_uri: &str, +) -> Result<(), OauthUtilError> { + // DO NOT CHANGE THIS + // This string has significance as it is used for remote servers to identify us + const CLIENT_ID: &str = "Q DEV CLI"; + + let stub_cred = get_stub_credentials()?; + oauth_state.set_credentials(CLIENT_ID, stub_cred).await?; + + // The setting of credentials would put the oauth state into authorize. + if let OAuthState::Authorized(auth_manager) = oauth_state { + // set redirect uri + let config = OAuthClientConfig { + client_id: CLIENT_ID.to_string(), + client_secret: None, + scopes: scopes.iter().map(|s| (*s).to_string()).collect(), + redirect_uri: redirect_uri.to_string(), + }; + + // try to dynamic register client + let config = match auth_manager.register_client(CLIENT_ID, redirect_uri).await { + Ok(config) => config, + Err(e) => { + eprintln!("Dynamic registration failed: {}", e); + // fallback to default config + config + }, + }; + // reset client config + auth_manager.configure_client(config)?; + let auth_url = auth_manager.get_authorization_url(scopes).await?; + + let mut stub_auth_manager = AuthorizationManager::new("http://localhost").await?; + std::mem::swap(auth_manager, &mut stub_auth_manager); + + let session = AuthorizationSession { + auth_manager: stub_auth_manager, + auth_url, + redirect_uri: redirect_uri.to_string(), + }; + + let mut new_oauth_state = OAuthState::Session(session); + std::mem::swap(oauth_state, &mut new_oauth_state); + } else { + unreachable!() + } + + Ok(()) +} + +/// This looks silly but [rmcp::transport::auth::OAuthTokenResponse] is private and there is no +/// other way to create this directly +fn get_stub_credentials() -> Result { + const STUB_TOKEN: &str = r#" + { + "access_token": "stub", + "token_type": "bearer", + "expires_in": 3600, + "refresh_token": "stub", + "scope": "stub" + } + "#; + + serde_json::from_str::(STUB_TOKEN) +} + async fn make_svc( one_shot_sender: Sender, socket_addr: SocketAddr, From 5acd7e3f3c4607380b4424cfa70147eaf93183a0 Mon Sep 17 00:00:00 2001 From: nirajchowdhary <226941436+nirajchowdhary@users.noreply.github.com> Date: Wed, 24 Sep 2025 17:08:32 -0700 Subject: [PATCH 61/71] =?UTF-8?q?fix:=20consolidate=20tool=20permission=20?= =?UTF-8?q?logic=20for=20consistent=20display=20and=20exe=E2=80=A6=20(#297?= =?UTF-8?q?5)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: consolidate tool permission logic for consistent display and execution * fix: centralize tool permission checking logic --------- Co-authored-by: Niraj Chowdhary --- crates/chat-cli/src/cli/agent/mod.rs | 30 ++----- .../src/cli/chat/tools/custom_tool.rs | 17 ++-- .../src/cli/chat/tools/execute/mod.rs | 4 +- crates/chat-cli/src/cli/chat/tools/fs_read.rs | 4 +- .../chat-cli/src/cli/chat/tools/fs_write.rs | 4 +- .../chat-cli/src/cli/chat/tools/knowledge.rs | 4 +- crates/chat-cli/src/cli/chat/tools/use_aws.rs | 4 +- crates/chat-cli/src/util/mod.rs | 1 + .../src/util/tool_permission_checker.rs | 82 +++++++++++++++++++ 9 files changed, 104 insertions(+), 46 deletions(-) create mode 100644 crates/chat-cli/src/util/tool_permission_checker.rs diff --git a/crates/chat-cli/src/cli/agent/mod.rs b/crates/chat-cli/src/cli/agent/mod.rs index 02b984dd2d..3b6c7fbbee 100644 --- a/crates/chat-cli/src/cli/agent/mod.rs +++ b/crates/chat-cli/src/cli/agent/mod.rs @@ -776,32 +776,14 @@ impl Agents { /// Returns a label to describe the permission status for a given tool. pub fn display_label(&self, tool_name: &str, origin: &ToolOrigin) -> String { - use crate::util::pattern_matching::matches_any_pattern; + use crate::util::tool_permission_checker::is_tool_in_allowlist; let tool_trusted = self.get_active().is_some_and(|a| { - if matches!(origin, &ToolOrigin::Native) { - return matches_any_pattern(&a.allowed_tools, tool_name); - } - - a.allowed_tools.iter().any(|name| { - name.strip_prefix("@").is_some_and(|remainder| { - remainder - .split_once(MCP_SERVER_TOOL_DELIMITER) - .is_some_and(|(_left, right)| right == tool_name) - || remainder == >::borrow(origin) - }) || { - if let Some(server_name) = name.strip_prefix("@").and_then(|s| s.split('/').next()) { - if server_name == >::borrow(origin) { - let tool_pattern = format!("@{}/{}", server_name, tool_name); - matches_any_pattern(&a.allowed_tools, &tool_pattern) - } else { - false - } - } else { - false - } - } - }) + let server_name = match origin { + ToolOrigin::Native => None, + _ => Some(>::borrow(origin)), + }; + is_tool_in_allowlist(&a.allowed_tools, tool_name, server_name) }); if tool_trusted || self.trust_all_tools { diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index 976f2d3820..0e45205678 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -28,7 +28,6 @@ use crate::mcp_client::{ }; use crate::os::Os; use crate::util::MCP_SERVER_TOOL_DELIMITER; -use crate::util::pattern_matching::matches_any_pattern; #[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)] #[serde(rename_all = "camelCase")] @@ -175,18 +174,12 @@ impl CustomTool { } pub fn eval_perm(&self, _os: &Os, agent: &Agent) -> PermissionEvalResult { - let server_name = &self.server_name; + use crate::util::tool_permission_checker::is_tool_in_allowlist; - let server_pattern = format!("@{server_name}"); - if agent.allowed_tools.contains(&server_pattern) { - return PermissionEvalResult::Allow; - } - - let tool_pattern = self.namespaced_tool_name(); - if matches_any_pattern(&agent.allowed_tools, &tool_pattern) { - return PermissionEvalResult::Allow; + if is_tool_in_allowlist(&agent.allowed_tools, &self.name, Some(&self.server_name)) { + PermissionEvalResult::Allow + } else { + PermissionEvalResult::Ask } - - PermissionEvalResult::Ask } } diff --git a/crates/chat-cli/src/cli/chat/tools/execute/mod.rs b/crates/chat-cli/src/cli/chat/tools/execute/mod.rs index e53daa6ef6..200cac641a 100644 --- a/crates/chat-cli/src/cli/chat/tools/execute/mod.rs +++ b/crates/chat-cli/src/cli/chat/tools/execute/mod.rs @@ -23,7 +23,7 @@ use crate::cli::chat::tools::{ }; use crate::cli::chat::util::truncate_safe; use crate::os::Os; -use crate::util::pattern_matching::matches_any_pattern; +use crate::util::tool_permission_checker::is_tool_in_allowlist; // Platform-specific modules #[cfg(windows)] @@ -205,7 +205,7 @@ impl ExecuteCommand { let Self { command, .. } = self; let tool_name = if cfg!(windows) { "execute_cmd" } else { "execute_bash" }; - let is_in_allowlist = matches_any_pattern(&agent.allowed_tools, tool_name); + let is_in_allowlist = is_tool_in_allowlist(&agent.allowed_tools, tool_name, None); match agent.tools_settings.get(tool_name) { Some(settings) => { let Settings { diff --git a/crates/chat-cli/src/cli/chat/tools/fs_read.rs b/crates/chat-cli/src/cli/chat/tools/fs_read.rs index 5c5c0abf25..e6dc7e31ba 100644 --- a/crates/chat-cli/src/cli/chat/tools/fs_read.rs +++ b/crates/chat-cli/src/cli/chat/tools/fs_read.rs @@ -46,7 +46,7 @@ use crate::cli::chat::{ }; use crate::os::Os; use crate::util::directories; -use crate::util::pattern_matching::matches_any_pattern; +use crate::util::tool_permission_checker::is_tool_in_allowlist; #[derive(Debug, Clone, Deserialize)] pub struct FsRead { @@ -113,7 +113,7 @@ impl FsRead { allow_read_only: bool, } - let is_in_allowlist = matches_any_pattern(&agent.allowed_tools, "fs_read"); + let is_in_allowlist = is_tool_in_allowlist(&agent.allowed_tools, "fs_read", None); let settings = agent .tools_settings .get("fs_read") diff --git a/crates/chat-cli/src/cli/chat/tools/fs_write.rs b/crates/chat-cli/src/cli/chat/tools/fs_write.rs index 284706f43a..d23a957b60 100644 --- a/crates/chat-cli/src/cli/chat/tools/fs_write.rs +++ b/crates/chat-cli/src/cli/chat/tools/fs_write.rs @@ -45,7 +45,7 @@ use crate::cli::agent::{ use crate::cli::chat::line_tracker::FileLineTracker; use crate::os::Os; use crate::util::directories; -use crate::util::pattern_matching::matches_any_pattern; +use crate::util::tool_permission_checker::is_tool_in_allowlist; static SYNTAX_SET: LazyLock = LazyLock::new(SyntaxSet::load_defaults_newlines); static THEME_SET: LazyLock = LazyLock::new(ThemeSet::load_defaults); @@ -470,7 +470,7 @@ impl FsWrite { denied_paths: Vec, } - let is_in_allowlist = matches_any_pattern(&agent.allowed_tools, "fs_write"); + let is_in_allowlist = is_tool_in_allowlist(&agent.allowed_tools, "fs_write", None); match agent.tools_settings.get("fs_write") { Some(settings) => { let Settings { diff --git a/crates/chat-cli/src/cli/chat/tools/knowledge.rs b/crates/chat-cli/src/cli/chat/tools/knowledge.rs index 7bde4f0651..7f8428e6f8 100644 --- a/crates/chat-cli/src/cli/chat/tools/knowledge.rs +++ b/crates/chat-cli/src/cli/chat/tools/knowledge.rs @@ -20,7 +20,7 @@ use crate::cli::agent::{ use crate::database::settings::Setting; use crate::os::Os; use crate::util::knowledge_store::KnowledgeStore; -use crate::util::pattern_matching::matches_any_pattern; +use crate::util::tool_permission_checker::is_tool_in_allowlist; /// The Knowledge tool allows storing and retrieving information across chat sessions. /// It provides semantic search capabilities for files, directories, and text content. @@ -497,7 +497,7 @@ impl Knowledge { _ = self; _ = os; - if matches_any_pattern(&agent.allowed_tools, "knowledge") { + if is_tool_in_allowlist(&agent.allowed_tools, "knowledge", None) { PermissionEvalResult::Allow } else { PermissionEvalResult::Ask diff --git a/crates/chat-cli/src/cli/chat/tools/use_aws.rs b/crates/chat-cli/src/cli/chat/tools/use_aws.rs index 3bb9611ea4..b7390744cd 100644 --- a/crates/chat-cli/src/cli/chat/tools/use_aws.rs +++ b/crates/chat-cli/src/cli/chat/tools/use_aws.rs @@ -29,7 +29,7 @@ use crate::cli::agent::{ PermissionEvalResult, }; use crate::os::Os; -use crate::util::pattern_matching::matches_any_pattern; +use crate::util::tool_permission_checker::is_tool_in_allowlist; const READONLY_OPS: [&str; 6] = ["get", "describe", "list", "ls", "search", "batch_get"]; @@ -187,7 +187,7 @@ impl UseAws { } let Self { service_name, .. } = self; - let is_in_allowlist = matches_any_pattern(&agent.allowed_tools, "use_aws"); + let is_in_allowlist = is_tool_in_allowlist(&agent.allowed_tools, "use_aws", None); match agent.tools_settings.get("use_aws") { Some(settings) => { let settings = match serde_json::from_value::(settings.clone()) { diff --git a/crates/chat-cli/src/util/mod.rs b/crates/chat-cli/src/util/mod.rs index 648d90cad1..48d8c94c97 100644 --- a/crates/chat-cli/src/util/mod.rs +++ b/crates/chat-cli/src/util/mod.rs @@ -7,6 +7,7 @@ pub mod spinner; pub mod system_info; #[cfg(test)] pub mod test; +pub mod tool_permission_checker; pub mod ui; use std::fmt::Display; diff --git a/crates/chat-cli/src/util/tool_permission_checker.rs b/crates/chat-cli/src/util/tool_permission_checker.rs new file mode 100644 index 0000000000..f1cc04f895 --- /dev/null +++ b/crates/chat-cli/src/util/tool_permission_checker.rs @@ -0,0 +1,82 @@ +use std::collections::HashSet; + +use tracing::debug; + +use crate::util::MCP_SERVER_TOOL_DELIMITER; +use crate::util::pattern_matching::matches_any_pattern; + +/// Checks if a tool is allowed based on the agent's allowed_tools configuration. +/// This function handles both native tools and MCP tools with wildcard pattern support. +pub fn is_tool_in_allowlist(allowed_tools: &HashSet, tool_name: &str, server_name: Option<&str>) -> bool { + let filter_patterns = |predicate: fn(&str) -> bool| -> HashSet { + allowed_tools + .iter() + .filter(|pattern| predicate(pattern)) + .cloned() + .collect() + }; + + match server_name { + // Native tool + None => { + let patterns = filter_patterns(|p| !p.starts_with('@')); + debug!("Native patterns: {:?}", patterns); + let result = matches_any_pattern(&patterns, tool_name); + debug!("Native tool '{}' permission check result: {}", tool_name, result); + result + }, + // MCP tool + Some(server) => { + let patterns = filter_patterns(|p| p.starts_with('@')); + debug!("MCP patterns: {:?}", patterns); + + // Check server-level permission first: @server_name + let server_pattern = format!("@{}", server); + debug!("Checking server-level pattern: '{}'", server_pattern); + if matches_any_pattern(&patterns, &server_pattern) { + debug!("Server-level permission granted for '{}'", server_pattern); + return true; + } + + // Check tool-specific permission: @server_name/tool_name + let tool_pattern = format!("@{}{}{}", server, MCP_SERVER_TOOL_DELIMITER, tool_name); + debug!("Checking tool-specific pattern: '{}'", tool_pattern); + let result = matches_any_pattern(&patterns, &tool_pattern); + debug!("Tool-specific permission result for '{}': {}", tool_pattern, result); + result + }, + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use super::*; + + #[test] + fn test_native_vs_mcp_separation() { + let mut allowed = HashSet::new(); + allowed.insert("fs_*".to_string()); + allowed.insert("@git".to_string()); + + // Native patterns only apply to native tools + assert!(is_tool_in_allowlist(&allowed, "fs_read", None)); + assert!(!is_tool_in_allowlist(&allowed, "fs_read", Some("server"))); + + // MCP patterns only apply to MCP tools + assert!(is_tool_in_allowlist(&allowed, "status", Some("git"))); + assert!(!is_tool_in_allowlist(&allowed, "git", None)); + } + + #[test] + fn test_mcp_wildcard_patterns() { + let mut allowed = HashSet::new(); + allowed.insert("@*quip*".to_string()); + allowed.insert("@git/read_*".to_string()); + + assert!(is_tool_in_allowlist(&allowed, "tool", Some("quip-server"))); + assert!(is_tool_in_allowlist(&allowed, "read_file", Some("git"))); + assert!(!is_tool_in_allowlist(&allowed, "write_file", Some("git"))); + } +} From 5526328ae5e2028885bb5b4ba3468740217bcc9b Mon Sep 17 00:00:00 2001 From: Kyosuke Konishi <86059523+Konippi@users.noreply.github.com> Date: Fri, 26 Sep 2025 02:57:18 +0900 Subject: [PATCH 62/71] feat(chat): expand support for /prompts command (#2799) * feat: expand support for /prompts command * fix: prompts spec * fix: add /prompts delete cmd * fix(prompts): improve validation and user input handling * fix(prompts): manage prompt using structs --- crates/chat-cli/src/cli/chat/cli/editor.rs | 39 +- crates/chat-cli/src/cli/chat/cli/mod.rs | 2 +- crates/chat-cli/src/cli/chat/cli/prompts.rs | 1044 ++++++++++++++++++- crates/chat-cli/src/cli/chat/mod.rs | 2 +- crates/chat-cli/src/util/directories.rs | 13 + 5 files changed, 1071 insertions(+), 29 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/cli/editor.rs b/crates/chat-cli/src/cli/chat/cli/editor.rs index fdf8d0d501..ff0433e9e4 100644 --- a/crates/chat-cli/src/cli/chat/cli/editor.rs +++ b/crates/chat-cli/src/cli/chat/cli/editor.rs @@ -84,13 +84,8 @@ impl EditorArgs { } } -/// Opens the user's preferred editor to compose a prompt -pub fn open_editor(initial_text: Option) -> Result { - // Create a temporary file with a unique name - let temp_dir = std::env::temp_dir(); - let file_name = format!("q_prompt_{}.md", Uuid::new_v4()); - let temp_file_path = temp_dir.join(file_name); - +/// Launch the user's preferred editor with the given file path +fn launch_editor(file_path: &std::path::Path) -> Result<(), ChatError> { // Get the editor from environment variable or use a default let editor_cmd = std::env::var("EDITOR").unwrap_or_else(|_| "vi".to_string()); @@ -104,11 +99,6 @@ pub fn open_editor(initial_text: Option) -> Result { let editor_bin = parts.remove(0); - // Write initial content to the file if provided - let initial_content = initial_text.unwrap_or_default(); - std::fs::write(&temp_file_path, &initial_content) - .map_err(|e| ChatError::Custom(format!("Failed to create temporary file: {}", e).into()))?; - // Open the editor with the parsed command and arguments let mut cmd = std::process::Command::new(editor_bin); // Add any arguments that were part of the EDITOR variable @@ -117,7 +107,7 @@ pub fn open_editor(initial_text: Option) -> Result { } // Add the file path as the last argument let status = cmd - .arg(&temp_file_path) + .arg(file_path) .status() .map_err(|e| ChatError::Custom(format!("Failed to open editor: {}", e).into()))?; @@ -125,6 +115,29 @@ pub fn open_editor(initial_text: Option) -> Result { return Err(ChatError::Custom("Editor exited with non-zero status".into())); } + Ok(()) +} + +/// Opens the user's preferred editor to edit an existing file +pub fn open_editor_file(file_path: &std::path::Path) -> Result<(), ChatError> { + launch_editor(file_path) +} + +/// Opens the user's preferred editor to compose a prompt +pub fn open_editor(initial_text: Option) -> Result { + // Create a temporary file with a unique name + let temp_dir = std::env::temp_dir(); + let file_name = format!("q_prompt_{}.md", Uuid::new_v4()); + let temp_file_path = temp_dir.join(file_name); + + // Write initial content to the file if provided + let initial_content = initial_text.unwrap_or_default(); + std::fs::write(&temp_file_path, &initial_content) + .map_err(|e| ChatError::Custom(format!("Failed to create temporary file: {}", e).into()))?; + + // Launch the editor + launch_editor(&temp_file_path)?; + // Read the content back let content = std::fs::read_to_string(&temp_file_path) .map_err(|e| ChatError::Custom(format!("Failed to read temporary file: {}", e).into()))?; diff --git a/crates/chat-cli/src/cli/chat/cli/mod.rs b/crates/chat-cli/src/cli/chat/cli/mod.rs index 1d095d1e4f..e5a20ed1c2 100644 --- a/crates/chat-cli/src/cli/chat/cli/mod.rs +++ b/crates/chat-cli/src/cli/chat/cli/mod.rs @@ -151,7 +151,7 @@ impl SlashCommand { }) }, Self::Changelog(args) => args.execute(session).await, - Self::Prompts(args) => args.execute(session).await, + Self::Prompts(args) => args.execute(os, session).await, Self::Hooks(args) => args.execute(session).await, Self::Usage(args) => args.execute(os, session).await, Self::Mcp(args) => args.execute(session).await, diff --git a/crates/chat-cli/src/cli/chat/cli/prompts.rs b/crates/chat-cli/src/cli/chat/cli/prompts.rs index 55f7661f76..c8cd310e2d 100644 --- a/crates/chat-cli/src/cli/chat/cli/prompts.rs +++ b/crates/chat-cli/src/cli/chat/cli/prompts.rs @@ -2,6 +2,9 @@ use std::collections::{ HashMap, VecDeque, }; +use std::fs; +use std::path::PathBuf; +use std::sync::LazyLock; use clap::{ Args, @@ -16,9 +19,16 @@ use crossterm::{ execute, queue, }; +use regex::Regex; +use rmcp::model::{ + PromptMessage, + PromptMessageContent, + PromptMessageRole, +}; use thiserror::Error; use unicode_width::UnicodeWidthStr; +use crate::cli::chat::cli::editor::open_editor_file; use crate::cli::chat::tool_manager::PromptBundle; use crate::cli::chat::{ ChatError, @@ -26,6 +36,17 @@ use crate::cli::chat::{ ChatState, }; use crate::mcp_client::McpClientError; +use crate::os::Os; +use crate::util::directories::{ + chat_global_prompts_dir, + chat_local_prompts_dir, +}; + +/// Maximum allowed length for prompt names +const MAX_PROMPT_NAME_LENGTH: usize = 50; + +/// Regex for validating prompt names (alphanumeric, hyphens, underscores only) +static PROMPT_NAME_REGEX: LazyLock = LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap()); #[derive(Debug, Error)] pub enum GetPromptError { @@ -49,6 +70,153 @@ pub enum GetPromptError { McpClient(#[from] McpClientError), #[error(transparent)] Service(#[from] rmcp::ServiceError), + #[error(transparent)] + Io(#[from] std::io::Error), +} + +/// Represents a single prompt (local or global) +#[derive(Debug, Clone)] +struct Prompt { + name: String, + path: PathBuf, +} + +impl std::fmt::Display for Prompt { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.name) + } +} + +impl Prompt { + /// Create a new prompt with the given name in the specified directory + fn new(name: &str, base_dir: PathBuf) -> Self { + let path = base_dir.join(format!("{}.md", name)); + Self { + name: name.to_string(), + path, + } + } + + /// Check if the prompt file exists + fn exists(&self) -> bool { + self.path.exists() + } + + /// Load the content of the prompt file + fn load_content(&self) -> Result { + fs::read_to_string(&self.path).map_err(GetPromptError::Io) + } + + /// Save content to the prompt file + fn save_content(&self, content: &str) -> Result<(), GetPromptError> { + // Ensure parent directory exists + if let Some(parent) = self.path.parent() { + fs::create_dir_all(parent).map_err(GetPromptError::Io)?; + } + fs::write(&self.path, content).map_err(GetPromptError::Io) + } + + /// Delete the prompt file + fn delete(&self) -> Result<(), GetPromptError> { + fs::remove_file(&self.path).map_err(GetPromptError::Io) + } +} + +/// Represents both local and global prompts for a given name +#[derive(Debug)] +struct Prompts { + local: Prompt, + global: Prompt, +} + +impl Prompts { + /// Create a new Prompts instance for the given name + fn new(name: &str, os: &Os) -> Result { + let local_dir = chat_local_prompts_dir(os).map_err(|e| GetPromptError::General(e.into()))?; + let global_dir = chat_global_prompts_dir(os).map_err(|e| GetPromptError::General(e.into()))?; + + Ok(Self { + local: Prompt::new(name, local_dir), + global: Prompt::new(name, global_dir), + }) + } + + /// Check if local prompt overrides a global one (both local and global exist) + fn has_local_override(&self) -> bool { + self.local.exists() && self.global.exists() + } + + /// Find and load existing prompt content (local takes priority) + fn load_existing(&self) -> Result, GetPromptError> { + if self.local.exists() { + let content = self.local.load_content()?; + Ok(Some((content, self.local.path.clone()))) + } else if self.global.exists() { + let content = self.global.load_content()?; + Ok(Some((content, self.global.path.clone()))) + } else { + Ok(None) + } + } + + /// Get all available prompt names from both directories + fn get_available_names(os: &Os) -> Result, GetPromptError> { + let mut prompt_names = std::collections::HashSet::new(); + + // Helper function to collect prompt names from a directory + let collect_from_dir = + |dir: PathBuf, names: &mut std::collections::HashSet| -> Result<(), GetPromptError> { + if dir.exists() { + for entry in fs::read_dir(&dir)? { + let entry = entry?; + let path = entry.path(); + if path.is_file() && path.extension().and_then(|s| s.to_str()) == Some("md") { + if let Some(file_stem) = path.file_stem().and_then(|s| s.to_str()) { + let prompt = Prompt::new(file_stem, dir.clone()); + names.insert(prompt.name); + } + } + } + } + Ok(()) + }; + + // Check global prompts + if let Ok(global_dir) = chat_global_prompts_dir(os) { + collect_from_dir(global_dir, &mut prompt_names)?; + } + + // Check local prompts + if let Ok(local_dir) = chat_local_prompts_dir(os) { + collect_from_dir(local_dir, &mut prompt_names)?; + } + + Ok(prompt_names.into_iter().collect()) + } +} + +/// Validate prompt name to ensure it's safe and follows naming conventions +fn validate_prompt_name(name: &str) -> Result<(), String> { + // Check for empty name + if name.trim().is_empty() { + return Err("Prompt name cannot be empty. Please provide a valid name for your prompt.".to_string()); + } + + // Check length limit + if name.len() > MAX_PROMPT_NAME_LENGTH { + return Err(format!( + "Prompt name must be {} characters or less. Current length: {} characters.", + MAX_PROMPT_NAME_LENGTH, + name.len() + )); + } + + // Check for valid characters using regex (alphanumeric, hyphens, underscores only) + if !PROMPT_NAME_REGEX.is_match(name) { + return Err("Prompt name can only contain letters, numbers, hyphens (-), and underscores (_). Special characters, spaces, and path separators are not allowed.".to_string()); + } + + Ok(()) } /// Command-line arguments for prompt operations @@ -69,21 +237,41 @@ pub struct PromptsArgs { } impl PromptsArgs { - pub async fn execute(self, session: &mut ChatSession) -> Result { + pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { let search_word = match &self.subcommand { Some(PromptsSubcommand::List { search_word }) => search_word.clone(), _ => None, }; if let Some(subcommand) = self.subcommand { - if matches!(subcommand, PromptsSubcommand::Get { .. }) { - return subcommand.execute(session).await; + if matches!( + subcommand, + PromptsSubcommand::Get { .. } + | PromptsSubcommand::Create { .. } + | PromptsSubcommand::Edit { .. } + | PromptsSubcommand::Remove { .. } + ) { + return subcommand.execute(os, session).await; } } let terminal_width = session.terminal_width(); let prompts = session.conversation.tool_manager.list_prompts().await?; + + // Get available prompt names + let prompt_names = Prompts::get_available_names(os).map_err(|e| ChatError::Custom(e.to_string().into()))?; + let mut longest_name = ""; + + // Update longest_name to include local prompts + for name in &prompt_names { + if name.contains(search_word.as_deref().unwrap_or("")) { + if name.len() > longest_name.len() { + longest_name = name; + } + } + } + let arg_pos = { let optimal_case = UnicodeWidthStr::width(longest_name) + terminal_width / 4; if optimal_case > terminal_width { @@ -146,10 +334,96 @@ impl PromptsArgs { .collect(); prompts_by_server.sort_by_key(|(server_name, _)| server_name.as_str()); + // Display prompts by category + let filtered_names: Vec<_> = prompt_names + .iter() + .filter(|name| name.contains(search_word.as_deref().unwrap_or(""))) + .collect(); + + if !filtered_names.is_empty() { + // Separate global and local prompts for display + let _global_dir = chat_global_prompts_dir(os).ok(); + let _local_dir = chat_local_prompts_dir(os).ok(); + + let mut global_prompts = Vec::new(); + let mut local_prompts = Vec::new(); + let mut overridden_globals = Vec::new(); + + for name in &filtered_names { + // Use the Prompts struct to check for conflicts + if let Ok(prompts) = Prompts::new(name, os) { + let (local_exists, global_exists) = (prompts.local.exists(), prompts.global.exists()); + + if global_exists { + global_prompts.push(name); + } + + if local_exists { + local_prompts.push(name); + // Check for overrides using has_local_override method + if global_exists { + overridden_globals.push(name); + } + } + } + } + + if !global_prompts.is_empty() { + queue!( + session.stderr, + style::SetAttribute(Attribute::Bold), + style::Print("Global (.aws/amazonq/prompts):"), + style::SetAttribute(Attribute::Reset), + style::Print("\n"), + )?; + for name in &global_prompts { + queue!(session.stderr, style::Print("- "), style::Print(name))?; + queue!(session.stderr, style::Print("\n"))?; + } + } + + if !local_prompts.is_empty() { + if !global_prompts.is_empty() { + queue!(session.stderr, style::Print("\n"))?; + } + queue!( + session.stderr, + style::SetAttribute(Attribute::Bold), + style::Print("Local (.amazonq/prompts):"), + style::SetAttribute(Attribute::Reset), + style::Print("\n"), + )?; + for name in &local_prompts { + let has_global_version = overridden_globals.contains(name); + queue!(session.stderr, style::Print("- "), style::Print(name),)?; + if has_global_version { + queue!( + session.stderr, + style::SetForegroundColor(Color::Green), + style::Print(" (overrides global)"), + style::SetForegroundColor(Color::Reset), + )?; + } + + // Show override indicator if this local prompt overrides a global one + if overridden_globals.contains(name) { + queue!( + session.stderr, + style::SetForegroundColor(Color::DarkGrey), + style::Print(" (overrides global)"), + style::SetForegroundColor(Color::Reset), + )?; + } + + queue!(session.stderr, style::Print("\n"))?; + } + } + } + for (i, (server_name, bundles)) in prompts_by_server.iter_mut().enumerate() { bundles.sort_by_key(|bundle| &bundle.prompt_get.name); - if i > 0 { + if i > 0 || !filtered_names.is_empty() { queue!(session.stderr, style::Print("\n"))?; } queue!( @@ -228,20 +502,84 @@ pub enum PromptsSubcommand { /// Optional arguments for the prompt arguments: Option>, }, + /// Create a new prompt + Create { + /// Name of the prompt to create + #[arg(short = 'n', long)] + name: String, + /// Content of the prompt (if not provided, opens editor) + #[arg(long)] + content: Option, + /// Create in global directory instead of local + #[arg(long)] + global: bool, + }, + /// Edit an existing prompt + Edit { + /// Name of the prompt to edit + name: String, + /// Edit global prompt instead of local + #[arg(long)] + global: bool, + }, + /// Remove an existing prompt + Remove { + /// Name of the prompt to remove + name: String, + /// Remove global prompt instead of local + #[arg(long)] + global: bool, + }, } impl PromptsSubcommand { - pub async fn execute(self, session: &mut ChatSession) -> Result { - let PromptsSubcommand::Get { - orig_input, - name, - arguments, - } = self - else { - unreachable!("List has already been parsed out at this point"); - }; + pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { + match self { + PromptsSubcommand::Get { + orig_input, + name, + arguments: _, + } => Self::execute_get(os, session, orig_input, name).await, + PromptsSubcommand::Create { name, content, global } => { + Self::execute_create(os, session, name, content, global).await + }, + PromptsSubcommand::Edit { name, global } => Self::execute_edit(os, session, name, global).await, + PromptsSubcommand::Remove { name, global } => Self::execute_remove(os, session, name, global).await, + PromptsSubcommand::List { .. } => { + unreachable!("List has already been parsed out at this point"); + }, + } + } - let prompts = match session.conversation.tool_manager.get_prompt(name, arguments).await { + async fn execute_get( + os: &Os, + session: &mut ChatSession, + orig_input: Option, + name: String, + ) -> Result { + // First try to find prompt (global or local) + let prompts = Prompts::new(&name, os).map_err(|e| ChatError::Custom(e.to_string().into()))?; + if let Some((content, _)) = prompts + .load_existing() + .map_err(|e| ChatError::Custom(e.to_string().into()))? + { + // Handle local prompt + session.pending_prompts.clear(); + + // Create a PromptMessage from the local prompt content + let prompt_message = PromptMessage { + role: PromptMessageRole::User, + content: PromptMessageContent::Text { text: content.clone() }, + }; + session.pending_prompts.push_back(prompt_message); + + return Ok(ChatState::HandleInput { + input: orig_input.unwrap_or_default(), + }); + } + + // If not found locally, try MCP prompts + let prompts = match session.conversation.tool_manager.get_prompt(name, None).await { Ok(resp) => resp, Err(e) => { match e { @@ -294,10 +632,688 @@ impl PromptsSubcommand { }) } + async fn execute_create( + os: &Os, + session: &mut ChatSession, + name: String, + content: Option, + global: bool, + ) -> Result { + // Create prompts instance and validate name + let mut prompts = Prompts::new(&name, os).map_err(|e| ChatError::Custom(e.to_string().into()))?; + + if let Err(validation_error) = validate_prompt_name(&name) { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Red), + style::Print("❌ Invalid prompt name: "), + style::Print(validation_error), + style::Print("\n"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("Valid names contain only letters, numbers, hyphens, and underscores (1-50 characters)\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + + // Check if prompt already exists in target location + let (local_exists, global_exists) = (prompts.local.exists(), prompts.global.exists()); + let target_exists = if global { global_exists } else { local_exists }; + + if target_exists { + let location = if global { "global" } else { "local" }; + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" already exists in "), + style::Print(location), + style::Print(" directory. Use "), + style::SetForegroundColor(Color::Cyan), + style::Print("/prompts edit "), + style::Print(&name), + if global { + style::Print(" --global") + } else { + style::Print("") + }, + style::SetForegroundColor(Color::Yellow), + style::Print(" to modify it.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + + // Check if creating this prompt will cause or involve a conflict + let opposite_exists = if global { local_exists } else { global_exists }; + + if prompts.has_local_override() || opposite_exists { + let (existing_scope, _creating_scope, override_message) = if !global { + ( + "global", + "local", + "Creating this local prompt will override the global one.", + ) + } else { + ( + "local", + "global", + "The local prompt will continue to override this global one.", + ) + }; + + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("⚠ Warning: A "), + style::Print(existing_scope), + style::Print(" prompt named '"), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print("' already exists.\n"), + style::Print(override_message), + style::Print("\n"), + style::SetForegroundColor(Color::Reset), + )?; + + // Flush stderr to ensure the warning is displayed before asking for input + execute!(session.stderr)?; + + // Ask for user confirmation + let user_input = match crate::util::input("Do you want to continue? (y/n): ", None) { + Ok(input) => input.trim().to_lowercase(), + Err(_) => { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Green), + style::Print("✓ Prompt creation cancelled.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + }, + }; + + if user_input != "y" && user_input != "yes" { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Green), + style::Print("✓ Prompt creation cancelled.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + } + + match content { + Some(content) => { + // Write the prompt file with provided content + let target_prompt = if global { + &mut prompts.global + } else { + &mut prompts.local + }; + + target_prompt + .save_content(&content) + .map_err(|e| ChatError::Custom(e.to_string().into()))?; + + let location = if global { "global" } else { "local" }; + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Green), + style::Print("✓ Created "), + style::Print(location), + style::Print(" prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Green), + style::Print(" at "), + style::SetForegroundColor(Color::DarkGrey), + style::Print(target_prompt.path.display().to_string()), + style::SetForegroundColor(Color::Reset), + style::Print("\n\n"), + )?; + }, + None => { + // Create file with default template and open editor + let default_content = "# Enter your prompt content here\n\nDescribe what this prompt should do..."; + let target_prompt = if global { + &mut prompts.global + } else { + &mut prompts.local + }; + + target_prompt + .save_content(default_content) + .map_err(|e| ChatError::Custom(e.to_string().into()))?; + + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Green), + style::Print("Opening editor to create prompt content...\n"), + style::SetForegroundColor(Color::Reset), + )?; + + // Try to open the editor + match open_editor_file(&target_prompt.path) { + Ok(()) => { + let location = if global { "global" } else { "local" }; + queue!( + session.stderr, + style::SetForegroundColor(Color::Green), + style::Print("✓ Created "), + style::Print(location), + style::Print(" prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Green), + style::Print(" at "), + style::SetForegroundColor(Color::DarkGrey), + style::Print(target_prompt.path.display().to_string()), + style::SetForegroundColor(Color::Reset), + style::Print("\n\n"), + )?; + }, + Err(err) => { + queue!( + session.stderr, + style::SetForegroundColor(Color::Red), + style::Print("Error opening editor: "), + style::Print(err.to_string()), + style::SetForegroundColor(Color::Reset), + style::Print("\n"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("Tip: You can edit this file directly: "), + style::Print(target_prompt.path.display().to_string()), + style::SetForegroundColor(Color::Reset), + style::Print("\n\n"), + )?; + }, + } + }, + }; + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + } + + async fn execute_edit( + os: &Os, + session: &mut ChatSession, + name: String, + global: bool, + ) -> Result { + // Validate prompt name + if let Err(validation_error) = validate_prompt_name(&name) { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Red), + style::Print("❌ Invalid prompt name: "), + style::Print(validation_error), + style::Print("\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + + let prompts = Prompts::new(&name, os).map_err(|e| ChatError::Custom(e.to_string().into()))?; + let (local_exists, global_exists) = (prompts.local.exists(), prompts.global.exists()); + + // Find the target prompt to edit + let target_prompt = if global { + if !global_exists { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Global prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" not found.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + &prompts.global + } else { + if local_exists { + &prompts.local + } else if global_exists { + // Found global prompt, but user wants to edit local + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Local prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" not found, but global version exists.\n"), + style::Print("Use "), + style::SetForegroundColor(Color::Cyan), + style::Print("/prompts edit "), + style::Print(&name), + style::Print(" --global"), + style::SetForegroundColor(Color::Yellow), + style::Print(" to edit the global version, or\n"), + style::Print("use "), + style::SetForegroundColor(Color::Cyan), + style::Print("/prompts create "), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" to create a local override.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } else { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" not found.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + }; + + let location = if global { "global" } else { "local" }; + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Green), + style::Print("Opening editor for "), + style::Print(location), + style::Print(" prompt: "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Reset), + style::Print("\n"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("File: "), + style::Print(target_prompt.path.display().to_string()), + style::SetForegroundColor(Color::Reset), + style::Print("\n\n"), + )?; + + // Try to open the editor + match open_editor_file(&target_prompt.path) { + Ok(()) => { + queue!( + session.stderr, + style::SetForegroundColor(Color::Green), + style::Print("✓ Prompt edited successfully.\n\n"), + style::SetForegroundColor(Color::Reset), + )?; + }, + Err(err) => { + queue!( + session.stderr, + style::SetForegroundColor(Color::Red), + style::Print("Error opening editor: "), + style::Print(err.to_string()), + style::SetForegroundColor(Color::Reset), + style::Print("\n"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("Tip: You can edit this file directly: "), + style::Print(target_prompt.path.display().to_string()), + style::SetForegroundColor(Color::Reset), + style::Print("\n\n"), + )?; + }, + } + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + } + + async fn execute_remove( + os: &Os, + session: &mut ChatSession, + name: String, + global: bool, + ) -> Result { + let prompts = Prompts::new(&name, os).map_err(|e| ChatError::Custom(e.to_string().into()))?; + let (local_exists, global_exists) = (prompts.local.exists(), prompts.global.exists()); + + // Find the target prompt to remove + let target_prompt = if global { + if !global_exists { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Global prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" not found.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + &prompts.global + } else { + if local_exists { + &prompts.local + } else if global_exists { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Local prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" not found, but global version exists.\n"), + style::Print("Use "), + style::SetForegroundColor(Color::Cyan), + style::Print("/prompts remove "), + style::Print(&name), + style::Print(" --global"), + style::SetForegroundColor(Color::Yellow), + style::Print(" to remove the global version.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } else { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" not found.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + }; + + let location = if global { "global" } else { "local" }; + + // Ask for confirmation + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("⚠ Warning: This will permanently remove the "), + style::Print(location), + style::Print(" prompt '"), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print("'.\n"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("File: "), + style::Print(target_prompt.path.display().to_string()), + style::SetForegroundColor(Color::Reset), + style::Print("\n"), + )?; + + // Flush stderr to ensure the warning is displayed before asking for input + execute!(session.stderr)?; + + // Ask for user confirmation + let user_input = match crate::util::input("Are you sure you want to remove this prompt? (y/n): ", None) { + Ok(input) => input.trim().to_lowercase(), + Err(_) => { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Green), + style::Print("✓ Removal cancelled.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + }, + }; + + if user_input != "y" && user_input != "yes" { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Green), + style::Print("✓ Removal cancelled.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + + // Remove the file + match target_prompt.delete() { + Ok(()) => { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Green), + style::Print("✓ Removed "), + style::Print(location), + style::Print(" prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Green), + style::Print(" successfully.\n\n"), + style::SetForegroundColor(Color::Reset), + )?; + }, + Err(err) => { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Red), + style::Print("Error deleting prompt: "), + style::Print(err.to_string()), + style::SetForegroundColor(Color::Reset), + style::Print("\n\n"), + )?; + }, + } + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + } + pub fn name(&self) -> &'static str { match self { PromptsSubcommand::List { .. } => "list", PromptsSubcommand::Get { .. } => "get", + PromptsSubcommand::Create { .. } => "create", + PromptsSubcommand::Edit { .. } => "edit", + PromptsSubcommand::Remove { .. } => "remove", } } } +#[cfg(test)] +mod tests { + use std::fs; + + use tempfile::TempDir; + + use super::*; + + fn create_prompt_file(dir: &PathBuf, name: &str, content: &str) { + fs::create_dir_all(dir).unwrap(); + fs::write(dir.join(format!("{}.md", name)), content).unwrap(); + } + + #[tokio::test] + async fn test_prompt_file_operations() { + let temp_dir = TempDir::new().unwrap(); + + // Create test prompts in temp directory structure + let global_dir = temp_dir.path().join(".aws/amazonq/prompts"); + let local_dir = temp_dir.path().join(".amazonq/prompts"); + + create_prompt_file(&global_dir, "global_only", "Global content"); + create_prompt_file(&global_dir, "shared", "Global shared"); + create_prompt_file(&local_dir, "local_only", "Local content"); + create_prompt_file(&local_dir, "shared", "Local shared"); + + // Test that we can read the files directly + assert_eq!( + fs::read_to_string(global_dir.join("global_only.md")).unwrap(), + "Global content" + ); + assert_eq!(fs::read_to_string(local_dir.join("shared.md")).unwrap(), "Local shared"); + } + + #[test] + fn test_local_prompts_override_global() { + let temp_dir = TempDir::new().unwrap(); + + // Create global and local directories + let global_dir = temp_dir.path().join(".aws/amazonq/prompts"); + let local_dir = temp_dir.path().join(".amazonq/prompts"); + + // Create prompts: one with same name in both directories, one unique to each + create_prompt_file(&global_dir, "shared", "Global version"); + create_prompt_file(&global_dir, "global_only", "Global only"); + create_prompt_file(&local_dir, "shared", "Local version"); + create_prompt_file(&local_dir, "local_only", "Local only"); + + // Simulate the priority logic from get_available_prompt_names() + let mut names = Vec::new(); + + // Add global prompts first + for entry in fs::read_dir(&global_dir).unwrap() { + let entry = entry.unwrap(); + let path = entry.path(); + if path.extension().and_then(|s| s.to_str()) == Some("md") { + if let Some(file_stem) = path.file_stem().and_then(|s| s.to_str()) { + let prompt = Prompt::new(file_stem, global_dir.clone()); + names.push(prompt.name); + } + } + } + + // Add local prompts (with override logic) + for entry in fs::read_dir(&local_dir).unwrap() { + let entry = entry.unwrap(); + let path = entry.path(); + if path.extension().and_then(|s| s.to_str()) == Some("md") { + if let Some(file_stem) = path.file_stem().and_then(|s| s.to_str()) { + let prompt = Prompt::new(file_stem, local_dir.clone()); + let name = prompt.name; + // Remove duplicate if it exists (local overrides global) + names.retain(|n| n != &name); + names.push(name); + } + } + } + + // Verify: should have 3 unique prompts (shared, global_only, local_only) + assert_eq!(names.len(), 3); + assert!(names.contains(&"shared".to_string())); + assert!(names.contains(&"global_only".to_string())); + assert!(names.contains(&"local_only".to_string())); + + // Verify only one "shared" exists (local overrode global) + let shared_count = names.iter().filter(|&name| name == "shared").count(); + assert_eq!(shared_count, 1); + + // Simulate load_prompt_by_name() priority: local first, then global + let shared_content = if local_dir.join("shared.md").exists() { + fs::read_to_string(local_dir.join("shared.md")).unwrap() + } else { + fs::read_to_string(global_dir.join("shared.md")).unwrap() + }; + + // Verify local version was loaded + assert_eq!(shared_content, "Local version"); + } + + #[test] + fn test_validate_prompt_name() { + // Empty name + assert!(validate_prompt_name("").is_err()); + assert!(validate_prompt_name(" ").is_err()); + + // Too long name (over 50 characters) + let long_name = "a".repeat(51); + assert!(validate_prompt_name(&long_name).is_err()); + + // Exactly 50 characters should be valid + let max_name = "a".repeat(50); + assert!(validate_prompt_name(&max_name).is_ok()); + + // Valid names with allowed characters + assert!(validate_prompt_name("valid_name").is_ok()); + assert!(validate_prompt_name("valid-name-v2").is_ok()); + + // Invalid characters (spaces, special chars, path separators) + assert!(validate_prompt_name("invalid name").is_err()); // space + assert!(validate_prompt_name("path/name").is_err()); // forward slash + assert!(validate_prompt_name("path\\name").is_err()); // backslash + assert!(validate_prompt_name("name.ext").is_err()); // dot + assert!(validate_prompt_name("name@host").is_err()); // at symbol + assert!(validate_prompt_name("name#tag").is_err()); // hash + assert!(validate_prompt_name("name$var").is_err()); // dollar sign + assert!(validate_prompt_name("name%percent").is_err()); // percent + assert!(validate_prompt_name("name&and").is_err()); // ampersand + assert!(validate_prompt_name("name*star").is_err()); // asterisk + assert!(validate_prompt_name("name+plus").is_err()); // plus + assert!(validate_prompt_name("name=equals").is_err()); // equals + assert!(validate_prompt_name("name?question").is_err()); // question mark + assert!(validate_prompt_name("name[bracket]").is_err()); // brackets + assert!(validate_prompt_name("name{brace}").is_err()); // braces + assert!(validate_prompt_name("name(paren)").is_err()); // parentheses + assert!(validate_prompt_name("name").is_err()); // angle brackets + assert!(validate_prompt_name("name|pipe").is_err()); // pipe + assert!(validate_prompt_name("name;semicolon").is_err()); // semicolon + assert!(validate_prompt_name("name:colon").is_err()); // colon + assert!(validate_prompt_name("name\"quote").is_err()); // double quote + assert!(validate_prompt_name("name'apostrophe").is_err()); // single quote + assert!(validate_prompt_name("name`backtick").is_err()); // backtick + assert!(validate_prompt_name("name~tilde").is_err()); // tilde + assert!(validate_prompt_name("name!exclamation").is_err()); // exclamation + } +} diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index e155099450..53448af02d 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -2048,7 +2048,7 @@ impl ChatSession { name: prompt_name, arguments, }; - return subcommand.execute(self).await; + return subcommand.execute(os, self).await; } else if let Some(command) = input.strip_prefix("!") { // Use platform-appropriate shell let result = if cfg!(target_os = "windows") { diff --git a/crates/chat-cli/src/util/directories.rs b/crates/chat-cli/src/util/directories.rs index 2bb53381d7..6833cb828b 100644 --- a/crates/chat-cli/src/util/directories.rs +++ b/crates/chat-cli/src/util/directories.rs @@ -44,6 +44,8 @@ type Result = std::result::Result; const WORKSPACE_AGENT_DIR_RELATIVE: &str = ".amazonq/cli-agents"; const GLOBAL_AGENT_DIR_RELATIVE_TO_HOME: &str = ".aws/amazonq/cli-agents"; +const WORKSPACE_PROMPTS_DIR_RELATIVE: &str = ".amazonq/prompts"; +const GLOBAL_PROMPTS_DIR_RELATIVE_TO_HOME: &str = ".aws/amazonq/prompts"; const CLI_BASH_HISTORY_PATH: &str = ".aws/amazonq/.cli_bash_history"; /// The directory of the users home @@ -180,6 +182,17 @@ pub fn chat_local_agent_dir(os: &Os) -> Result { Ok(cwd.join(WORKSPACE_AGENT_DIR_RELATIVE)) } +/// The directory containing global prompts +pub fn chat_global_prompts_dir(os: &Os) -> Result { + Ok(home_dir(os)?.join(GLOBAL_PROMPTS_DIR_RELATIVE_TO_HOME)) +} + +/// The directory containing local prompts +pub fn chat_local_prompts_dir(os: &Os) -> Result { + let cwd = os.env.current_dir()?; + Ok(cwd.join(WORKSPACE_PROMPTS_DIR_RELATIVE)) +} + /// Canonicalizes path given by expanding the path given pub fn canonicalizes_path(os: &Os, path_as_str: &str) -> Result { let context = |input: &str| Ok(os.env.get(input).ok()); From daed93c33de840abf1f9031fe31b16aa43337d74 Mon Sep 17 00:00:00 2001 From: Kenneth Sanchez V Date: Thu, 25 Sep 2025 11:44:21 -0700 Subject: [PATCH 63/71] [fomat] Runs Clippy and cargo +nightly fmt (#2991) Co-authored-by: Kenneth S. --- crates/chat-cli/src/cli/agent/mod.rs | 2 +- crates/chat-cli/src/cli/chat/cli/prompts.rs | 176 ++++++++++---------- 2 files changed, 86 insertions(+), 92 deletions(-) diff --git a/crates/chat-cli/src/cli/agent/mod.rs b/crates/chat-cli/src/cli/agent/mod.rs index 3b6c7fbbee..369932091a 100644 --- a/crates/chat-cli/src/cli/agent/mod.rs +++ b/crates/chat-cli/src/cli/agent/mod.rs @@ -781,7 +781,7 @@ impl Agents { let tool_trusted = self.get_active().is_some_and(|a| { let server_name = match origin { ToolOrigin::Native => None, - _ => Some(>::borrow(origin)), + ToolOrigin::McpServer(_) => Some(>::borrow(origin)), }; is_tool_in_allowlist(&a.allowed_tools, tool_name, server_name) }); diff --git a/crates/chat-cli/src/cli/chat/cli/prompts.rs b/crates/chat-cli/src/cli/chat/cli/prompts.rs index c8cd310e2d..1dffde5169 100644 --- a/crates/chat-cli/src/cli/chat/cli/prompts.rs +++ b/crates/chat-cli/src/cli/chat/cli/prompts.rs @@ -265,10 +265,8 @@ impl PromptsArgs { // Update longest_name to include local prompts for name in &prompt_names { - if name.contains(search_word.as_deref().unwrap_or("")) { - if name.len() > longest_name.len() { - longest_name = name; - } + if name.contains(search_word.as_deref().unwrap_or("")) && name.len() > longest_name.len() { + longest_name = name; } } @@ -901,54 +899,52 @@ impl PromptsSubcommand { }); } &prompts.global + } else if local_exists { + &prompts.local + } else if global_exists { + // Found global prompt, but user wants to edit local + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Local prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" not found, but global version exists.\n"), + style::Print("Use "), + style::SetForegroundColor(Color::Cyan), + style::Print("/prompts edit "), + style::Print(&name), + style::Print(" --global"), + style::SetForegroundColor(Color::Yellow), + style::Print(" to edit the global version, or\n"), + style::Print("use "), + style::SetForegroundColor(Color::Cyan), + style::Print("/prompts create "), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" to create a local override.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); } else { - if local_exists { - &prompts.local - } else if global_exists { - // Found global prompt, but user wants to edit local - queue!( - session.stderr, - style::Print("\n"), - style::SetForegroundColor(Color::Yellow), - style::Print("Local prompt "), - style::SetForegroundColor(Color::Cyan), - style::Print(&name), - style::SetForegroundColor(Color::Yellow), - style::Print(" not found, but global version exists.\n"), - style::Print("Use "), - style::SetForegroundColor(Color::Cyan), - style::Print("/prompts edit "), - style::Print(&name), - style::Print(" --global"), - style::SetForegroundColor(Color::Yellow), - style::Print(" to edit the global version, or\n"), - style::Print("use "), - style::SetForegroundColor(Color::Cyan), - style::Print("/prompts create "), - style::Print(&name), - style::SetForegroundColor(Color::Yellow), - style::Print(" to create a local override.\n"), - style::SetForegroundColor(Color::Reset), - )?; - return Ok(ChatState::PromptUser { - skip_printing_tools: true, - }); - } else { - queue!( - session.stderr, - style::Print("\n"), - style::SetForegroundColor(Color::Yellow), - style::Print("Prompt "), - style::SetForegroundColor(Color::Cyan), - style::Print(&name), - style::SetForegroundColor(Color::Yellow), - style::Print(" not found.\n"), - style::SetForegroundColor(Color::Reset), - )?; - return Ok(ChatState::PromptUser { - skip_printing_tools: true, - }); - } + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" not found.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); }; let location = if global { "global" } else { "local" }; @@ -1030,47 +1026,45 @@ impl PromptsSubcommand { }); } &prompts.global + } else if local_exists { + &prompts.local + } else if global_exists { + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Local prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" not found, but global version exists.\n"), + style::Print("Use "), + style::SetForegroundColor(Color::Cyan), + style::Print("/prompts remove "), + style::Print(&name), + style::Print(" --global"), + style::SetForegroundColor(Color::Yellow), + style::Print(" to remove the global version.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); } else { - if local_exists { - &prompts.local - } else if global_exists { - queue!( - session.stderr, - style::Print("\n"), - style::SetForegroundColor(Color::Yellow), - style::Print("Local prompt "), - style::SetForegroundColor(Color::Cyan), - style::Print(&name), - style::SetForegroundColor(Color::Yellow), - style::Print(" not found, but global version exists.\n"), - style::Print("Use "), - style::SetForegroundColor(Color::Cyan), - style::Print("/prompts remove "), - style::Print(&name), - style::Print(" --global"), - style::SetForegroundColor(Color::Yellow), - style::Print(" to remove the global version.\n"), - style::SetForegroundColor(Color::Reset), - )?; - return Ok(ChatState::PromptUser { - skip_printing_tools: true, - }); - } else { - queue!( - session.stderr, - style::Print("\n"), - style::SetForegroundColor(Color::Yellow), - style::Print("Prompt "), - style::SetForegroundColor(Color::Cyan), - style::Print(&name), - style::SetForegroundColor(Color::Yellow), - style::Print(" not found.\n"), - style::SetForegroundColor(Color::Reset), - )?; - return Ok(ChatState::PromptUser { - skip_printing_tools: true, - }); - } + queue!( + session.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(&name), + style::SetForegroundColor(Color::Yellow), + style::Print(" not found.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); }; let location = if global { "global" } else { "local" }; From 814f149316d57820af72c3be8fc528f7b7e7fb92 Mon Sep 17 00:00:00 2001 From: abhraina-aws Date: Thu, 25 Sep 2025 15:19:58 -0700 Subject: [PATCH 64/71] feat: add context usage percentage indicator to prompt (#2994) Add experimental feature to show context window usage as a percentage in the chat prompt (e.g., "[rust-agent] 6% >"). The percentage is color-coded: green (<50%), yellow (50-89%), red (90-100%). The feature is disabled by default and can be enabled via /experiment. --- .../chat-cli/src/cli/chat/cli/experiment.rs | 5 + crates/chat-cli/src/cli/chat/cli/usage.rs | 158 +++++++++++------- crates/chat-cli/src/cli/chat/mod.rs | 20 ++- crates/chat-cli/src/cli/chat/prompt.rs | 12 ++ crates/chat-cli/src/cli/chat/prompt_parser.rs | 92 ++++++++-- crates/chat-cli/src/database/settings.rs | 4 + 6 files changed, 215 insertions(+), 76 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/cli/experiment.rs b/crates/chat-cli/src/cli/chat/cli/experiment.rs index 0dd981d9d3..121419b244 100644 --- a/crates/chat-cli/src/cli/chat/cli/experiment.rs +++ b/crates/chat-cli/src/cli/chat/cli/experiment.rs @@ -50,6 +50,11 @@ static AVAILABLE_EXPERIMENTS: &[Experiment] = &[ description: "Enables Q to create todo lists that can be viewed and managed using /todos", setting_key: Setting::EnabledTodoList, }, + Experiment { + name: "Context Usage Indicator", + description: "Shows context usage percentage in the prompt (e.g., [rust-agent] 6% >)", + setting_key: Setting::EnabledContextUsageIndicator, + }, ]; #[derive(Debug, PartialEq, Args)] diff --git a/crates/chat-cli/src/cli/chat/cli/usage.rs b/crates/chat-cli/src/cli/chat/cli/usage.rs index 6e6fe0c961..4240bf4c2f 100644 --- a/crates/chat-cli/src/cli/chat/cli/usage.rs +++ b/crates/chat-cli/src/cli/chat/cli/usage.rs @@ -21,6 +21,60 @@ use crate::cli::chat::{ }; use crate::os::Os; +/// Detailed usage data for context window analysis +#[derive(Debug)] +pub struct DetailedUsageData { + pub total_tokens: TokenCount, + pub context_tokens: TokenCount, + pub assistant_tokens: TokenCount, + pub user_tokens: TokenCount, + pub tools_tokens: TokenCount, + pub context_window_size: usize, + pub dropped_context_files: Vec<(String, String)>, +} + +/// Calculate usage percentage from token counts +pub fn calculate_usage_percentage(tokens: TokenCount, context_window_size: usize) -> f32 { + (tokens.value() as f32 / context_window_size as f32) * 100.0 +} + +/// Get detailed usage data for context window analysis +pub async fn get_detailed_usage_data(session: &mut ChatSession, os: &Os) -> Result { + let context_window_size = context_window_tokens(session.conversation.model_info.as_ref()); + + let state = session + .conversation + .backend_conversation_state(os, true, &mut std::io::stderr()) + .await?; + + let data = state.calculate_conversation_size(); + let tool_specs_json: String = state + .tools + .values() + .filter_map(|s| serde_json::to_string(s).ok()) + .collect::>() + .join(""); + let tools_char_count: CharCount = tool_specs_json.len().into(); + let total_tokens: TokenCount = + (data.context_messages + data.user_messages + data.assistant_messages + tools_char_count).into(); + + Ok(DetailedUsageData { + total_tokens, + context_tokens: data.context_messages.into(), + assistant_tokens: data.assistant_messages.into(), + user_tokens: data.user_messages.into(), + tools_tokens: tools_char_count.into(), + context_window_size, + dropped_context_files: state.dropped_context_files, + }) +} + +/// Get total usage percentage (simple interface for prompt generation) +pub async fn get_total_usage_percentage(session: &mut ChatSession, os: &Os) -> Result { + let data = get_detailed_usage_data(session, os).await?; + Ok(calculate_usage_percentage(data.total_tokens, data.context_window_size)) +} + /// Arguments for the usage command that displays token usage statistics and context window /// information. /// @@ -32,12 +86,9 @@ pub struct UsageArgs; impl UsageArgs { pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { - let state = session - .conversation - .backend_conversation_state(os, true, &mut session.stderr) - .await?; + let usage_data = get_detailed_usage_data(session, os).await?; - if !state.dropped_context_files.is_empty() { + if !usage_data.dropped_context_files.is_empty() { execute!( session.stderr, style::SetForegroundColor(Color::DarkYellow), @@ -50,33 +101,18 @@ impl UsageArgs { )?; } - let data = state.calculate_conversation_size(); - let tool_specs_json: String = state - .tools - .values() - .filter_map(|s| serde_json::to_string(s).ok()) - .collect::>() - .join(""); - let context_token_count: TokenCount = data.context_messages.into(); - let assistant_token_count: TokenCount = data.assistant_messages.into(); - let user_token_count: TokenCount = data.user_messages.into(); - let tools_char_count: CharCount = tool_specs_json.len().into(); // usize → CharCount - let tools_token_count: TokenCount = tools_char_count.into(); // CharCount → TokenCount - let total_token_used: TokenCount = - (data.context_messages + data.user_messages + data.assistant_messages + tools_char_count).into(); let window_width = session.terminal_width(); // set a max width for the progress bar for better aesthetic let progress_bar_width = std::cmp::min(window_width, 80); - let context_window_size = context_window_tokens(session.conversation.model_info.as_ref()); - let context_width = - ((context_token_count.value() as f64 / context_window_size as f64) * progress_bar_width as f64) as usize; - let assistant_width = - ((assistant_token_count.value() as f64 / context_window_size as f64) * progress_bar_width as f64) as usize; - let tools_width = - ((tools_token_count.value() as f64 / context_window_size as f64) * progress_bar_width as f64) as usize; - let user_width = - ((user_token_count.value() as f64 / context_window_size as f64) * progress_bar_width as f64) as usize; + let context_width = ((usage_data.context_tokens.value() as f64 / usage_data.context_window_size as f64) + * progress_bar_width as f64) as usize; + let assistant_width = ((usage_data.assistant_tokens.value() as f64 / usage_data.context_window_size as f64) + * progress_bar_width as f64) as usize; + let tools_width = ((usage_data.tools_tokens.value() as f64 / usage_data.context_window_size as f64) + * progress_bar_width as f64) as usize; + let user_width = ((usage_data.user_tokens.value() as f64 / usage_data.context_window_size as f64) + * progress_bar_width as f64) as usize; let left_over_width = progress_bar_width - std::cmp::min( @@ -86,44 +122,45 @@ impl UsageArgs { let is_overflow = (context_width + assistant_width + user_width + tools_width) > progress_bar_width; + let total_percentage = calculate_usage_percentage(usage_data.total_tokens, usage_data.context_window_size); + if is_overflow { queue!( session.stderr, style::Print(format!( "\nCurrent context window ({} of {}k tokens used)\n", - total_token_used, - context_window_size / 1000 + usage_data.total_tokens, + usage_data.context_window_size / 1000 )), style::SetForegroundColor(Color::DarkRed), style::Print("█".repeat(progress_bar_width)), style::SetForegroundColor(Color::Reset), style::Print(" "), - style::Print(format!( - "{:.2}%", - (total_token_used.value() as f32 / context_window_size as f32) * 100.0 - )), + style::Print(format!("{:.2}%", total_percentage)), )?; } else { queue!( session.stderr, style::Print(format!( "\nCurrent context window ({} of {}k tokens used)\n", - total_token_used, - context_window_size / 1000 + usage_data.total_tokens, + usage_data.context_window_size / 1000 )), // Context files style::SetForegroundColor(Color::DarkCyan), // add a nice visual to mimic "tiny" progress, so the overrall progress bar doesn't look too // empty - style::Print("|".repeat(if context_width == 0 && *context_token_count > 0 { - 1 - } else { - 0 - })), + style::Print( + "|".repeat(if context_width == 0 && usage_data.context_tokens.value() > 0 { + 1 + } else { + 0 + }) + ), style::Print("█".repeat(context_width)), // Tools style::SetForegroundColor(Color::DarkRed), - style::Print("|".repeat(if tools_width == 0 && *tools_token_count > 0 { + style::Print("|".repeat(if tools_width == 0 && usage_data.tools_tokens.value() > 0 { 1 } else { 0 @@ -131,24 +168,27 @@ impl UsageArgs { style::Print("█".repeat(tools_width)), // Assistant responses style::SetForegroundColor(Color::Blue), - style::Print("|".repeat(if assistant_width == 0 && *assistant_token_count > 0 { + style::Print( + "|".repeat(if assistant_width == 0 && usage_data.assistant_tokens.value() > 0 { + 1 + } else { + 0 + }) + ), + style::Print("█".repeat(assistant_width)), + // User prompts + style::SetForegroundColor(Color::Magenta), + style::Print("|".repeat(if user_width == 0 && usage_data.user_tokens.value() > 0 { 1 } else { 0 })), - style::Print("█".repeat(assistant_width)), - // User prompts - style::SetForegroundColor(Color::Magenta), - style::Print("|".repeat(if user_width == 0 && *user_token_count > 0 { 1 } else { 0 })), style::Print("█".repeat(user_width)), style::SetForegroundColor(Color::DarkGrey), style::Print("█".repeat(left_over_width)), style::Print(" "), style::SetForegroundColor(Color::Reset), - style::Print(format!( - "{:.2}%", - (total_token_used.value() as f32 / context_window_size as f32) * 100.0 - )), + style::Print(format!("{:.2}%", total_percentage)), )?; } @@ -161,32 +201,32 @@ impl UsageArgs { style::SetForegroundColor(Color::Reset), style::Print(format!( "~{} tokens ({:.2}%)\n", - context_token_count, - (context_token_count.value() as f32 / context_window_size as f32) * 100.0 + usage_data.context_tokens, + calculate_usage_percentage(usage_data.context_tokens, usage_data.context_window_size) )), style::SetForegroundColor(Color::DarkRed), style::Print("█ Tools: "), style::SetForegroundColor(Color::Reset), style::Print(format!( " ~{} tokens ({:.2}%)\n", - tools_token_count, - (tools_token_count.value() as f32 / context_window_size as f32) * 100.0 + usage_data.tools_tokens, + calculate_usage_percentage(usage_data.tools_tokens, usage_data.context_window_size) )), style::SetForegroundColor(Color::Blue), style::Print("█ Q responses: "), style::SetForegroundColor(Color::Reset), style::Print(format!( " ~{} tokens ({:.2}%)\n", - assistant_token_count, - (assistant_token_count.value() as f32 / context_window_size as f32) * 100.0 + usage_data.assistant_tokens, + calculate_usage_percentage(usage_data.assistant_tokens, usage_data.context_window_size) )), style::SetForegroundColor(Color::Magenta), style::Print("█ Your prompts: "), style::SetForegroundColor(Color::Reset), style::Print(format!( " ~{} tokens ({:.2}%)\n\n", - user_token_count, - (user_token_count.value() as f32 / context_window_size as f32) * 100.0 + usage_data.user_tokens, + calculate_usage_percentage(usage_data.user_tokens, usage_data.context_window_size) )), )?; diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 53448af02d..cd198645ab 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -1920,7 +1920,7 @@ impl ChatSession { style::SetForegroundColor(Color::Reset), style::SetAttribute(Attribute::Reset) )?; - let prompt = self.generate_tool_trust_prompt(); + let prompt = self.generate_tool_trust_prompt(os).await; let user_input = match self.read_user_input(&prompt, false) { Some(input) => input, None => return Ok(ChatState::Exit), @@ -3115,11 +3115,25 @@ impl ChatSession { } /// Helper function to generate a prompt based on the current context - fn generate_tool_trust_prompt(&mut self) -> String { + async fn generate_tool_trust_prompt(&mut self, os: &Os) -> String { let profile = self.conversation.current_profile().map(|s| s.to_string()); let all_trusted = self.all_tools_trusted(); let tangent_mode = self.conversation.is_in_tangent_mode(); - prompt::generate_prompt(profile.as_deref(), all_trusted, tangent_mode) + + // Check if context usage indicator is enabled + let usage_percentage = if os + .database + .settings + .get_bool(crate::database::settings::Setting::EnabledContextUsageIndicator) + .unwrap_or(false) + { + use crate::cli::chat::cli::usage::get_total_usage_percentage; + get_total_usage_percentage(self, os).await.ok() + } else { + None + }; + + prompt::generate_prompt(profile.as_deref(), all_trusted, tangent_mode, usage_percentage) } async fn send_tool_use_telemetry(&mut self, os: &Os) { diff --git a/crates/chat-cli/src/cli/chat/prompt.rs b/crates/chat-cli/src/cli/chat/prompt.rs index 2ee6640baa..210a0f635c 100644 --- a/crates/chat-cli/src/cli/chat/prompt.rs +++ b/crates/chat-cli/src/cli/chat/prompt.rs @@ -408,6 +408,18 @@ impl Highlighter for ChatHelper { result.push_str(&format!("[{}] ", profile).cyan().to_string()); } + // Add percentage part if present (colored by usage level) + if let Some(percentage) = components.usage_percentage { + let colored_percentage = if percentage < 50.0 { + format!("{}% ", percentage as u32).green() + } else if percentage < 90.0 { + format!("{}% ", percentage as u32).yellow() + } else { + format!("{}% ", percentage as u32).red() + }; + result.push_str(&colored_percentage.to_string()); + } + // Add tangent indicator if present (yellow) if components.tangent_mode { result.push_str(&"↯ ".yellow().to_string()); diff --git a/crates/chat-cli/src/cli/chat/prompt_parser.rs b/crates/chat-cli/src/cli/chat/prompt_parser.rs index daf67a9859..87f0639054 100644 --- a/crates/chat-cli/src/cli/chat/prompt_parser.rs +++ b/crates/chat-cli/src/cli/chat/prompt_parser.rs @@ -6,15 +6,16 @@ pub struct PromptComponents { pub profile: Option, pub warning: bool, pub tangent_mode: bool, + pub usage_percentage: Option, } /// Parse prompt components from a plain text prompt pub fn parse_prompt_components(prompt: &str) -> Option { - // Expected format: "[agent] !> " or "> " or "!> " or "[agent] ↯ > " or "↯ > " or "[agent] ↯ !> " - // etc. + // Expected format: "[agent] 6% !> " or "> " or "!> " or "[agent] ↯ > " or "6% ↯ > " etc. let mut profile = None; let mut warning = false; let mut tangent_mode = false; + let mut usage_percentage = None; let mut remaining = prompt.trim(); // Check for agent pattern [agent] first @@ -28,6 +29,17 @@ pub fn parse_prompt_components(prompt: &str) -> Option { } } + // Check for percentage pattern (e.g., "6% ") + if let Some(percent_pos) = remaining.find('%') { + let before_percent = &remaining[..percent_pos]; + if let Ok(percentage) = before_percent.trim().parse::() { + usage_percentage = Some(percentage); + if let Some(space_after_percent) = remaining[percent_pos..].find(' ') { + remaining = remaining[percent_pos + space_after_percent + 1..].trim_start(); + } + } + } + // Check for tangent mode ↯ first if let Some(after_tangent) = remaining.strip_prefix('↯') { tangent_mode = true; @@ -46,13 +58,19 @@ pub fn parse_prompt_components(prompt: &str) -> Option { profile, warning, tangent_mode, + usage_percentage, }) } else { None } } -pub fn generate_prompt(current_profile: Option<&str>, warning: bool, tangent_mode: bool) -> String { +pub fn generate_prompt( + current_profile: Option<&str>, + warning: bool, + tangent_mode: bool, + usage_percentage: Option, +) -> String { // Generate plain text prompt that will be colored by highlight_prompt let warning_symbol = if warning { "!" } else { "" }; let profile_part = current_profile @@ -60,10 +78,12 @@ pub fn generate_prompt(current_profile: Option<&str>, warning: bool, tangent_mod .map(|p| format!("[{p}] ")) .unwrap_or_default(); + let percentage_part = usage_percentage.map(|p| format!("{:.0}% ", p)).unwrap_or_default(); + if tangent_mode { - format!("{profile_part}↯ {warning_symbol}> ") + format!("{profile_part}{percentage_part}↯ {warning_symbol}> ") } else { - format!("{profile_part}{warning_symbol}> ") + format!("{profile_part}{percentage_part}{warning_symbol}> ") } } @@ -74,26 +94,43 @@ mod tests { #[test] fn test_generate_prompt() { // Test default prompt (no profile) - assert_eq!(generate_prompt(None, false, false), "> "); + assert_eq!(generate_prompt(None, false, false, None), "> "); // Test default prompt with warning - assert_eq!(generate_prompt(None, true, false), "!> "); + assert_eq!(generate_prompt(None, true, false, None), "!> "); // Test tangent mode - assert_eq!(generate_prompt(None, false, true), "↯ > "); + assert_eq!(generate_prompt(None, false, true, None), "↯ > "); // Test tangent mode with warning - assert_eq!(generate_prompt(None, true, true), "↯ !> "); + assert_eq!(generate_prompt(None, true, true, None), "↯ !> "); // Test default profile (should be same as no profile) - assert_eq!(generate_prompt(Some(DEFAULT_AGENT_NAME), false, false), "> "); + assert_eq!(generate_prompt(Some(DEFAULT_AGENT_NAME), false, false, None), "> "); // Test custom profile - assert_eq!(generate_prompt(Some("test-profile"), false, false), "[test-profile] > "); + assert_eq!( + generate_prompt(Some("test-profile"), false, false, None), + "[test-profile] > " + ); // Test custom profile with tangent mode assert_eq!( - generate_prompt(Some("test-profile"), false, true), + generate_prompt(Some("test-profile"), false, true, None), "[test-profile] ↯ > " ); // Test another custom profile with warning - assert_eq!(generate_prompt(Some("dev"), true, false), "[dev] !> "); + assert_eq!(generate_prompt(Some("dev"), true, false, None), "[dev] !> "); // Test custom profile with warning and tangent mode - assert_eq!(generate_prompt(Some("dev"), true, true), "[dev] ↯ !> "); + assert_eq!(generate_prompt(Some("dev"), true, true, None), "[dev] ↯ !> "); + // Test custom profile with usage percentage + assert_eq!( + generate_prompt(Some("rust-agent"), false, false, Some(6.2)), + "[rust-agent] 6% > " + ); + // Test custom profile with usage percentage and warning + assert_eq!( + generate_prompt(Some("rust-agent"), true, false, Some(15.7)), + "[rust-agent] 16% !> " + ); + // Test usage percentage without profile + assert_eq!(generate_prompt(None, false, false, Some(25.3)), "25% > "); + // Test usage percentage with tangent mode + assert_eq!(generate_prompt(None, false, true, Some(8.9)), "9% ↯ > "); } #[test] @@ -103,48 +140,75 @@ mod tests { assert!(components.profile.is_none()); assert!(!components.warning); assert!(!components.tangent_mode); + assert!(components.usage_percentage.is_none()); // Test warning prompt let components = parse_prompt_components("!> ").unwrap(); assert!(components.profile.is_none()); assert!(components.warning); assert!(!components.tangent_mode); + assert!(components.usage_percentage.is_none()); // Test tangent mode let components = parse_prompt_components("↯ > ").unwrap(); assert!(components.profile.is_none()); assert!(!components.warning); assert!(components.tangent_mode); + assert!(components.usage_percentage.is_none()); // Test tangent mode with warning let components = parse_prompt_components("↯ !> ").unwrap(); assert!(components.profile.is_none()); assert!(components.warning); assert!(components.tangent_mode); + assert!(components.usage_percentage.is_none()); // Test profile prompt let components = parse_prompt_components("[test] > ").unwrap(); assert_eq!(components.profile.as_deref(), Some("test")); assert!(!components.warning); assert!(!components.tangent_mode); + assert!(components.usage_percentage.is_none()); // Test profile with warning let components = parse_prompt_components("[dev] !> ").unwrap(); assert_eq!(components.profile.as_deref(), Some("dev")); assert!(components.warning); assert!(!components.tangent_mode); + assert!(components.usage_percentage.is_none()); // Test profile with tangent mode let components = parse_prompt_components("[dev] ↯ > ").unwrap(); assert_eq!(components.profile.as_deref(), Some("dev")); assert!(!components.warning); assert!(components.tangent_mode); + assert!(components.usage_percentage.is_none()); // Test profile with warning and tangent mode let components = parse_prompt_components("[dev] ↯ !> ").unwrap(); assert_eq!(components.profile.as_deref(), Some("dev")); assert!(components.warning); assert!(components.tangent_mode); + assert!(components.usage_percentage.is_none()); + + // Test prompts with percentages + let components = parse_prompt_components("[rust-agent] 6% > ").unwrap(); + assert_eq!(components.profile.as_deref(), Some("rust-agent")); + assert!(!components.warning); + assert!(!components.tangent_mode); + assert_eq!(components.usage_percentage, Some(6.0)); + + let components = parse_prompt_components("25% > ").unwrap(); + assert!(components.profile.is_none()); + assert!(!components.warning); + assert!(!components.tangent_mode); + assert_eq!(components.usage_percentage, Some(25.0)); + + let components = parse_prompt_components("8% ↯ > ").unwrap(); + assert!(components.profile.is_none()); + assert!(!components.warning); + assert!(components.tangent_mode); + assert_eq!(components.usage_percentage, Some(8.0)); // Test invalid prompt assert!(parse_prompt_components("invalid").is_none()); diff --git a/crates/chat-cli/src/database/settings.rs b/crates/chat-cli/src/database/settings.rs index bc005438cb..2ea8510e2a 100644 --- a/crates/chat-cli/src/database/settings.rs +++ b/crates/chat-cli/src/database/settings.rs @@ -67,6 +67,8 @@ pub enum Setting { McpNoInteractiveTimeout, #[strum(message = "Track previously loaded MCP servers (boolean)")] McpLoadedBefore, + #[strum(message = "Show context usage percentage in prompt (boolean)")] + EnabledContextUsageIndicator, #[strum(message = "Default AI model for conversations (string)")] ChatDefaultModel, #[strum(message = "Disable markdown formatting in chat (boolean)")] @@ -115,6 +117,7 @@ impl AsRef for Setting { Self::ChatDisableAutoCompaction => "chat.disableAutoCompaction", Self::ChatEnableHistoryHints => "chat.enableHistoryHints", Self::EnabledTodoList => "chat.enableTodoList", + Self::EnabledContextUsageIndicator => "chat.enableContextUsageIndicator", } } } @@ -161,6 +164,7 @@ impl TryFrom<&str> for Setting { "chat.disableAutoCompaction" => Ok(Self::ChatDisableAutoCompaction), "chat.enableHistoryHints" => Ok(Self::ChatEnableHistoryHints), "chat.enableTodoList" => Ok(Self::EnabledTodoList), + "chat.enableContextUsageIndicator" => Ok(Self::EnabledContextUsageIndicator), _ => Err(DatabaseError::InvalidSetting(value.to_string())), } } From 0c235260102c2f82cf16e4722faef89814ba7572 Mon Sep 17 00:00:00 2001 From: Kenneth Sanchez V Date: Thu, 25 Sep 2025 15:28:15 -0700 Subject: [PATCH 65/71] [fix] Fixes issues with Tool Input parsing. (#2986) * [fix] Fixes issues with Tool Input parsing. * Ocassionally the model will generate a tool use which parameters are not a valid json. When this happens it corrupts the conversation history. * Here we first avoid storing the tool use and add the propert validation logic to the conversation history. * adds validation logic to safety * [fix] Update to use a new RecvErrorKind instead of custom error handling. * [fix] Gives visual hint to the user, that request is being retried. --------- Co-authored-by: Kenneth S. --- crates/chat-cli/src/cli/chat/mod.rs | 47 ++++++++++ crates/chat-cli/src/cli/chat/parser.rs | 119 ++++++++++++++++++++++++- 2 files changed, 165 insertions(+), 1 deletion(-) diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index cd198645ab..a576440c80 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -2694,6 +2694,53 @@ impl ChatSession { .await?, )); }, + RecvErrorKind::ToolValidationError { + tool_use_id, + name, + message, + error_message, + } => { + self.send_chat_telemetry( + os, + TelemetryResult::Failed, + Some(reason), + Some(reason_desc), + status_code, + false, // We retry the request, so don't end the current turn yet. + ) + .await; + + error!( + recv_error.request_metadata.request_id, + tool_use_id, name, error_message, "Tool validation failed" + ); + self.conversation + .push_assistant_message(os, *message, Some(recv_error.request_metadata)); + let tool_results = vec![ToolUseResult { + tool_use_id, + content: vec![ToolUseResultBlock::Text(format!( + "Tool validation failed: {}. Please ensure tool arguments are provided as a valid JSON object.", + error_message + ))], + status: ToolResultStatus::Error, + }]; + // User hint of what happened + let _ = queue!( + self.stdout, + style::Print("\n\n"), + style::SetForegroundColor(Color::Yellow), + style::Print(format!("Tool validation failed: {}\n Retrying the request...", error_message)), + style::ResetColor, + style::Print("\n"), + ); + self.conversation.add_tool_results(tool_results); + self.send_tool_use_telemetry(os).await; + return Ok(ChatState::HandleResponseStream( + self.conversation + .as_sendable_conversation_state(os, &mut self.stderr, false) + .await?, + )); + }, _ => { self.send_chat_telemetry( os, diff --git a/crates/chat-cli/src/cli/chat/parser.rs b/crates/chat-cli/src/cli/chat/parser.rs index 55e57519ba..2e0cdfb03c 100644 --- a/crates/chat-cli/src/cli/chat/parser.rs +++ b/crates/chat-cli/src/cli/chat/parser.rs @@ -91,6 +91,7 @@ impl RecvError { RecvErrorKind::StreamTimeout { .. } => None, RecvErrorKind::UnexpectedToolUseEos { .. } => None, RecvErrorKind::Cancelled => None, + RecvErrorKind::ToolValidationError { .. } => None, } } } @@ -103,6 +104,7 @@ impl ReasonCode for RecvError { RecvErrorKind::StreamTimeout { .. } => "RecvErrorStreamTimeout".to_string(), RecvErrorKind::UnexpectedToolUseEos { .. } => "RecvErrorUnexpectedToolUseEos".to_string(), RecvErrorKind::Cancelled => "Interrupted".to_string(), + RecvErrorKind::ToolValidationError { .. } => "RecvErrorToolValidation".to_string(), } } } @@ -151,6 +153,14 @@ pub enum RecvErrorKind { /// The stream processing task was cancelled #[error("Stream handling was cancelled")] Cancelled, + /// Tool validation failed due to invalid arguments + #[error("Tool validation failed for tool: {} with id: {}", .name, .tool_use_id)] + ToolValidationError { + tool_use_id: String, + name: String, + message: Box, + error_message: String, + }, } /// Represents a response stream from a call to the SendMessage API. @@ -472,7 +482,43 @@ impl ResponseParser { } let args = match serde_json::from_str(&tool_string) { - Ok(args) => args, + Ok(args) => { + // Ensure we have a valid JSON object + match args { + serde_json::Value::Object(_) => args, + _ => { + error!("Received non-object JSON for tool arguments: {:?}", args); + let warning_args = serde_json::Value::Object( + [( + "key".to_string(), + serde_json::Value::String( + "WARNING: the actual tool use arguments were not a valid JSON object".to_string(), + ), + )] + .into_iter() + .collect(), + ); + self.tool_uses.push(AssistantToolUse { + id: id.clone(), + name: name.clone(), + orig_name: name.clone(), + args: warning_args.clone(), + orig_args: warning_args.clone(), + }); + let message = Box::new(AssistantMessage::new_tool_use( + Some(self.message_id.clone()), + std::mem::take(&mut self.assistant_text), + self.tool_uses.clone().into_iter().collect(), + )); + return Err(self.error(RecvErrorKind::ToolValidationError { + tool_use_id: id, + name, + message, + error_message: format!("Expected JSON object, got: {:?}", args), + })); + }, + } + }, Err(err) if !tool_string.is_empty() => { // If we failed deserializing after waiting for a long time, then this is most // likely bedrock responding with a stop event for some reason without actually @@ -753,4 +799,75 @@ mod tests { "assistant text preceding a code reference should be ignored as this indicates licensed code is being returned" ); } + + #[tokio::test] + async fn test_response_parser_avoid_invalid_json() { + let content_to_ignore = "IGNORE ME PLEASE"; + let tool_use_id = "TEST_ID".to_string(); + let tool_name = "execute_bash".to_string(); + let tool_args = serde_json::json!("invalid json").to_string(); + let mut events = vec![ + ChatResponseStream::AssistantResponseEvent { + content: "hi".to_string(), + }, + ChatResponseStream::AssistantResponseEvent { + content: " there".to_string(), + }, + ChatResponseStream::AssistantResponseEvent { + content: content_to_ignore.to_string(), + }, + ChatResponseStream::CodeReferenceEvent(()), + ChatResponseStream::ToolUseEvent { + tool_use_id: tool_use_id.clone(), + name: tool_name.clone(), + input: None, + stop: None, + }, + ChatResponseStream::ToolUseEvent { + tool_use_id: tool_use_id.clone(), + name: tool_name.clone(), + input: Some(tool_args), + stop: None, + }, + ]; + events.reverse(); + let mock = SendMessageOutput::Mock(events); + let mut parser = ResponseParser::new( + mock, + "".to_string(), + None, + 1, + vec![], + mpsc::channel(32).0, + Instant::now(), + SystemTime::now(), + CancellationToken::new(), + Arc::new(Mutex::new(None)), + ); + + let mut output = String::new(); + let mut found_validation_error = false; + for _ in 0..5 { + match parser.recv().await { + Ok(event) => { + output.push_str(&format!("{:?}", event)); + }, + Err(recv_error) => { + if matches!(recv_error.source, RecvErrorKind::ToolValidationError { .. }) { + found_validation_error = true; + } + break; + }, + } + } + + assert!( + !output.contains(content_to_ignore), + "assistant text preceding a code reference should be ignored as this indicates licensed code is being returned" + ); + assert!( + found_validation_error, + "Expected to find tool validation error for non-object JSON" + ); + } } From 1f4f96f89437d720803e6e17a45d00a83ea7a765 Mon Sep 17 00:00:00 2001 From: evanliu048 Date: Thu, 25 Sep 2025 16:12:09 -0700 Subject: [PATCH 66/71] feat: Adds checkpointing functionality using Git CLI commands (#2896) * (in progress) Implement checkpointing using git CLI commands * feat: Add new checkpointing functionality using git CLI Updates: - Only works if the user has git installed - Supports auto initialization if the user is in a git repo, manual if not - UI ported over from dedicated file tools implementation * feat: Add user message for turn-level checkpoints, clean command Updates: - The clean subcommand will delete the shadow repo - The description for turn-level checkpoints is a truncated version of the user's last message * fix: Fix shadow repo deletion logic Updates: - Running the clean subcommand now properly deletes the entire shadow repo for both automatic and manual modes * chore: Run formatter and fix clippy warnings * feat: Add checkpoint diff Updates: - Users can now view diffs between checkpoints - Fixed tool-level checkpoint display handling * fix: Fix last messsage handling for checkpoints Updates: - Checkpoints now (hopefully) correctly display the correct turn-specific user message - Added slash command auto completion * fix: Fix commit message handling again * chore: Run formatter * Removed old comment * define a global capture dirctory * revise the capture path * fix cpature clean bug * add a clean all flag * add auto drop method for capture feature * support file details when expand * add the file summary when list and expand * revise structure and print no diff msg * delete all flag, add summry when fs read * refactor code * revise ui * add capture into experiement * clippy * rename to checkpoint * reverse false renaming * recover history * disable tangent mode in checkpoint * fix cr * nit: keep checkpoint name * allow both tangent & checkpoint enabled * ci --------- Co-authored-by: kiran-garre --- crates/chat-cli/src/cli/chat/checkpoint.rs | 422 +++++++++++++ .../chat-cli/src/cli/chat/cli/checkpoint.rs | 573 ++++++++++++++++++ .../chat-cli/src/cli/chat/cli/experiment.rs | 9 + crates/chat-cli/src/cli/chat/cli/mod.rs | 6 + crates/chat-cli/src/cli/chat/cli/tangent.rs | 33 + crates/chat-cli/src/cli/chat/conversation.rs | 21 + crates/chat-cli/src/cli/chat/mod.rs | 202 +++++- .../chat-cli/src/cli/chat/tools/fs_write.rs | 2 +- crates/chat-cli/src/cli/chat/tools/mod.rs | 10 + crates/chat-cli/src/database/settings.rs | 8 + crates/chat-cli/src/util/directories.rs | 5 + 11 files changed, 1287 insertions(+), 4 deletions(-) create mode 100644 crates/chat-cli/src/cli/chat/checkpoint.rs create mode 100644 crates/chat-cli/src/cli/chat/cli/checkpoint.rs diff --git a/crates/chat-cli/src/cli/chat/checkpoint.rs b/crates/chat-cli/src/cli/chat/checkpoint.rs new file mode 100644 index 0000000000..c5fb0b8183 --- /dev/null +++ b/crates/chat-cli/src/cli/chat/checkpoint.rs @@ -0,0 +1,422 @@ +use std::collections::{ + HashMap, + VecDeque, +}; +use std::path::{ + Path, + PathBuf, +}; +use std::process::{ + Command, + Output, +}; + +use chrono::{ + DateTime, + Local, +}; +use crossterm::style::Stylize; +use eyre::{ + Result, + bail, + eyre, +}; +use serde::{ + Deserialize, + Serialize, +}; + +use crate::cli::ConversationState; +use crate::cli::chat::conversation::HistoryEntry; +use crate::os::Os; + +/// Manages a shadow git repository for tracking and restoring workspace changes +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CheckpointManager { + /// Path to the shadow (bare) git repository + pub shadow_repo_path: PathBuf, + + /// All checkpoints in chronological order + pub checkpoints: Vec, + + /// Fast lookup: tag -> index in checkpoints vector + pub tag_index: HashMap, + + /// Track the current turn number + pub current_turn: usize, + + /// Track tool uses within current turn + pub tools_in_turn: usize, + + /// Last user message for commit description + pub pending_user_message: Option, + + /// Whether the message has been locked for this turn + pub message_locked: bool, + + /// Cached file change statistics + #[serde(default)] + pub file_stats_cache: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct FileStats { + pub added: usize, + pub modified: usize, + pub deleted: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Checkpoint { + pub tag: String, + pub timestamp: DateTime, + pub description: String, + pub history_snapshot: VecDeque, + pub is_turn: bool, + pub tool_name: Option, +} + +impl CheckpointManager { + /// Initialize checkpoint manager automatically (when in a git repo) + pub async fn auto_init( + os: &Os, + shadow_path: impl AsRef, + current_history: &VecDeque, + ) -> Result { + if !is_git_installed() { + bail!("Git is not installed. Checkpoints require git to function."); + } + if !is_in_git_repo() { + bail!("Not in a git repository. Use '/checkpoint init' to manually enable checkpoints."); + } + + let manager = Self::manual_init(os, shadow_path, current_history).await?; + Ok(manager) + } + + /// Initialize checkpoint manager manually + pub async fn manual_init( + os: &Os, + path: impl AsRef, + current_history: &VecDeque, + ) -> Result { + let path = path.as_ref(); + os.fs.create_dir_all(path).await?; + + // Initialize bare repository + run_git(path, false, &["init", "--bare", &path.to_string_lossy()])?; + + // Configure git + configure_git(&path.to_string_lossy())?; + + // Create initial checkpoint + stage_commit_tag(&path.to_string_lossy(), "Initial state", "0")?; + + let initial_checkpoint = Checkpoint { + tag: "0".to_string(), + timestamp: Local::now(), + description: "Initial state".to_string(), + history_snapshot: current_history.clone(), + is_turn: true, + tool_name: None, + }; + + let mut tag_index = HashMap::new(); + tag_index.insert("0".to_string(), 0); + + Ok(Self { + shadow_repo_path: path.to_path_buf(), + checkpoints: vec![initial_checkpoint], + tag_index, + current_turn: 0, + tools_in_turn: 0, + pending_user_message: None, + message_locked: false, + file_stats_cache: HashMap::new(), + }) + } + + /// Create a new checkpoint point + pub fn create_checkpoint( + &mut self, + tag: &str, + description: &str, + history: &VecDeque, + is_turn: bool, + tool_name: Option, + ) -> Result<()> { + // Stage, commit and tag + stage_commit_tag(&self.shadow_repo_path.to_string_lossy(), description, tag)?; + + // Record checkpoint metadata + let checkpoint = Checkpoint { + tag: tag.to_string(), + timestamp: Local::now(), + description: description.to_string(), + history_snapshot: history.clone(), + is_turn, + tool_name, + }; + + self.checkpoints.push(checkpoint); + self.tag_index.insert(tag.to_string(), self.checkpoints.len() - 1); + + // Cache file stats for this checkpoint + if let Ok(stats) = self.compute_file_stats(tag) { + self.file_stats_cache.insert(tag.to_string(), stats); + } + + Ok(()) + } + + /// Restore workspace to a specific checkpoint + pub fn restore(&self, conversation: &mut ConversationState, tag: &str, hard: bool) -> Result<()> { + let checkpoint = self.get_checkpoint(tag)?; + + if hard { + // Hard: reset the whole work-tree to the tag + let output = run_git(&self.shadow_repo_path, true, &["reset", "--hard", tag])?; + if !output.status.success() { + bail!("Failed to restore: {}", String::from_utf8_lossy(&output.stderr)); + } + } else { + // Soft: only restore tracked files. If the tag is an empty tree, this is a no-op. + if !self.tag_has_any_paths(tag)? { + // Nothing tracked in this checkpoint -> nothing to restore; treat as success. + conversation.restore_to_checkpoint(checkpoint)?; + return Ok(()); + } + // Use checkout against work-tree + let output = run_git(&self.shadow_repo_path, true, &["checkout", tag, "--", "."])?; + if !output.status.success() { + bail!("Failed to restore: {}", String::from_utf8_lossy(&output.stderr)); + } + } + + // Restore conversation history + conversation.restore_to_checkpoint(checkpoint)?; + + Ok(()) + } + + /// Return true iff the given tag/tree has any tracked paths. + fn tag_has_any_paths(&self, tag: &str) -> eyre::Result { + // Use `git ls-tree -r --name-only ` to check if the tree is empty + let out = run_git( + &self.shadow_repo_path, + // work_tree + false, + &["ls-tree", "-r", "--name-only", tag], + )?; + Ok(!out.stdout.is_empty()) + } + + /// Get file change statistics for a checkpoint + pub fn compute_file_stats(&self, tag: &str) -> Result { + if tag == "0" { + return Ok(FileStats::default()); + } + + let prev_tag = get_previous_tag(tag); + self.compute_stats_between(&prev_tag, tag) + } + + /// Compute file statistics between two checkpoints + pub fn compute_stats_between(&self, from: &str, to: &str) -> Result { + let output = run_git(&self.shadow_repo_path, false, &["diff", "--name-status", from, to])?; + + let mut stats = FileStats::default(); + for line in String::from_utf8_lossy(&output.stdout).lines() { + if let Some((status, _)) = line.split_once('\t') { + match status.chars().next() { + Some('A') => stats.added += 1, + Some('M') => stats.modified += 1, + Some('D') => stats.deleted += 1, + Some('R' | 'C') => stats.modified += 1, + _ => {}, + } + } + } + + Ok(stats) + } + + /// Generate detailed diff between checkpoints + pub fn diff(&self, from: &str, to: &str) -> Result { + let mut result = String::new(); + + // Get file changes + let output = run_git(&self.shadow_repo_path, false, &["diff", "--name-status", from, to])?; + + for line in String::from_utf8_lossy(&output.stdout).lines() { + if let Some((status, file)) = line.split_once('\t') { + match status.chars().next() { + Some('A') => result.push_str(&format!(" + {} (added)\n", file).green().to_string()), + Some('M') => result.push_str(&format!(" ~ {} (modified)\n", file).yellow().to_string()), + Some('D') => result.push_str(&format!(" - {} (deleted)\n", file).red().to_string()), + Some('R' | 'C') => result.push_str(&format!(" ~ {} (renamed)\n", file).yellow().to_string()), + _ => {}, + } + } + } + + // Add statistics + let stat_output = run_git(&self.shadow_repo_path, false, &[ + "diff", + from, + to, + "--stat", + "--color=always", + ])?; + + if stat_output.status.success() { + result.push('\n'); + result.push_str(&String::from_utf8_lossy(&stat_output.stdout)); + } + + Ok(result) + } + + /// Check for uncommitted changes + pub fn has_changes(&self) -> Result { + let output = run_git(&self.shadow_repo_path, true, &["status", "--porcelain"])?; + Ok(!output.stdout.is_empty()) + } + + /// Clean up shadow repository + pub async fn cleanup(&self, os: &Os) -> Result<()> { + if self.shadow_repo_path.exists() { + os.fs.remove_dir_all(&self.shadow_repo_path).await?; + } + Ok(()) + } + + fn get_checkpoint(&self, tag: &str) -> Result<&Checkpoint> { + self.tag_index + .get(tag) + .and_then(|&idx| self.checkpoints.get(idx)) + .ok_or_else(|| eyre!("Checkpoint '{}' not found", tag)) + } +} + +impl Drop for CheckpointManager { + fn drop(&mut self) { + let path = self.shadow_repo_path.clone(); + // Try to spawn cleanup task + if let Ok(handle) = tokio::runtime::Handle::try_current() { + handle.spawn(async move { + let _ = tokio::fs::remove_dir_all(path).await; + }); + } else { + // Fallback to thread + std::thread::spawn(move || { + let _ = std::fs::remove_dir_all(path); + }); + } + } +} + +// Helper functions + +/// Truncate message for display +pub fn truncate_message(s: &str, max_len: usize) -> String { + if s.len() <= max_len { + return s.to_string(); + } + + let truncated = &s[..max_len]; + if let Some(pos) = truncated.rfind(' ') { + format!("{}...", &truncated[..pos]) + } else { + format!("{}...", truncated) + } +} + +pub const CHECKPOINT_MESSAGE_MAX_LENGTH: usize = 60; + +fn is_git_installed() -> bool { + Command::new("git") + .arg("--version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) +} + +fn is_in_git_repo() -> bool { + Command::new("git") + .args(["rev-parse", "--is-inside-work-tree"]) + .output() + .map(|o| o.status.success()) + .unwrap_or(false) +} + +fn configure_git(shadow_path: &str) -> Result<()> { + run_git(Path::new(shadow_path), false, &["config", "user.name", "Q"])?; + run_git(Path::new(shadow_path), false, &["config", "user.email", "qcli@local"])?; + run_git(Path::new(shadow_path), false, &["config", "core.preloadindex", "true"])?; + Ok(()) +} + +fn stage_commit_tag(shadow_path: &str, message: &str, tag: &str) -> Result<()> { + // Stage all changes + run_git(Path::new(shadow_path), true, &["add", "-A"])?; + + // Commit + let output = run_git(Path::new(shadow_path), true, &[ + "commit", + "--allow-empty", + "--no-verify", + "-m", + message, + ])?; + + if !output.status.success() { + bail!("Git commit failed: {}", String::from_utf8_lossy(&output.stderr)); + } + + // Tag + let output = run_git(Path::new(shadow_path), false, &["tag", tag])?; + if !output.status.success() { + bail!("Git tag failed: {}", String::from_utf8_lossy(&output.stderr)); + } + + Ok(()) +} + +fn run_git(dir: &Path, with_work_tree: bool, args: &[&str]) -> Result { + let mut cmd = Command::new("git"); + cmd.arg(format!("--git-dir={}", dir.display())); + + if with_work_tree { + cmd.arg("--work-tree=."); + } + + cmd.args(args); + + let output = cmd.output()?; + if !output.status.success() && !output.stderr.is_empty() { + bail!(String::from_utf8_lossy(&output.stderr).to_string()); + } + + Ok(output) +} + +fn get_previous_tag(tag: &str) -> String { + // Parse turn.tool format + if let Some((turn_str, tool_str)) = tag.split_once('.') { + if let Ok(tool_num) = tool_str.parse::() { + return if tool_num > 1 { + format!("{}.{}", turn_str, tool_num - 1) + } else { + turn_str.to_string() + }; + } + } + + // Parse turn-only format + if let Ok(turn) = tag.parse::() { + return turn.saturating_sub(1).to_string(); + } + + "0".to_string() +} diff --git a/crates/chat-cli/src/cli/chat/cli/checkpoint.rs b/crates/chat-cli/src/cli/chat/cli/checkpoint.rs new file mode 100644 index 0000000000..634da119c3 --- /dev/null +++ b/crates/chat-cli/src/cli/chat/cli/checkpoint.rs @@ -0,0 +1,573 @@ +use std::io::Write; + +use clap::Subcommand; +use crossterm::style::{ + Attribute, + Color, + StyledContent, + Stylize, +}; +use crossterm::{ + execute, + style, +}; +use dialoguer::Select; + +use crate::cli::chat::checkpoint::{ + Checkpoint, + CheckpointManager, + FileStats, +}; +use crate::cli::chat::{ + ChatError, + ChatSession, + ChatState, +}; +use crate::database::settings::Setting; +use crate::os::Os; +use crate::util::directories::get_shadow_repo_dir; + +#[derive(Debug, PartialEq, Subcommand)] +pub enum CheckpointSubcommand { + /// Initialize checkpoints manually + Init, + + /// Restore workspace to a checkpoint + #[command( + about = "Restore workspace to a checkpoint", + long_about = r#"Restore files to a checkpoint . If is omitted, you'll pick one interactively. + +Default mode: + • Restores tracked file changes + • Keeps new files created after the checkpoint + +With --hard: + • Exactly matches the checkpoint state + • Removes files created after the checkpoint"# + )] + Restore { + /// Checkpoint tag (e.g., 3 or 3.1). Leave empty to select interactively. + tag: Option, + + /// Exactly match checkpoint state (removes newer files) + #[arg(long)] + hard: bool, + }, + + /// List all checkpoints + List { + /// Limit number of results shown + #[arg(short, long)] + limit: Option, + }, + + /// Delete the shadow repository + Clean, + + /// Show details of a checkpoint + Expand { + /// Checkpoint tag to expand + tag: String, + }, + + /// Show differences between checkpoints + Diff { + /// First checkpoint tag + tag1: String, + + /// Second checkpoint tag (defaults to current state) + #[arg(required = false)] + tag2: Option, + }, +} + +impl CheckpointSubcommand { + pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { + // Check if checkpoint is enabled + if !os + .database + .settings + .get_bool(Setting::EnabledCheckpoint) + .unwrap_or(false) + { + execute!( + session.stderr, + style::SetForegroundColor(Color::Red), + style::Print("\nCheckpoint is disabled. Enable it with: q settings chat.enableCheckpoint true\n"), + style::SetForegroundColor(Color::Reset) + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + + // Check if in tangent mode - captures are disabled during tangent mode + if session.conversation.is_in_tangent_mode() { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print( + "⚠️ Checkpoint is disabled while in tangent mode. Disable tangent mode with: q settings -d chat.enableTangentMode.\n\n" + ), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + match self { + Self::Init => self.handle_init(os, session).await, + Self::Restore { ref tag, hard } => self.handle_restore(session, tag.clone(), hard).await, + Self::List { limit } => Self::handle_list(session, limit), + Self::Clean => self.handle_clean(os, session).await, + Self::Expand { ref tag } => Self::handle_expand(session, tag.clone()), + Self::Diff { ref tag1, ref tag2 } => Self::handle_diff(session, tag1.clone(), tag2.clone()), + } + } + + async fn handle_init(&self, os: &Os, session: &mut ChatSession) -> Result { + if session.conversation.checkpoint_manager.is_some() { + execute!( + session.stderr, + style::SetForegroundColor(Color::Blue), + style::Print( + "✓ Checkpoints are already enabled for this session! Use /checkpoint list to see current checkpoints.\n" + ), + style::SetForegroundColor(Color::Reset) + )?; + } else { + let path = get_shadow_repo_dir(os, session.conversation.conversation_id().to_string()) + .map_err(|e| ChatError::Custom(e.to_string().into()))?; + + let start = std::time::Instant::now(); + session.conversation.checkpoint_manager = Some( + CheckpointManager::manual_init(os, path, session.conversation.history()) + .await + .map_err(|e| ChatError::Custom(format!("Checkpoints could not be initialized: {e}").into()))?, + ); + + execute!( + session.stderr, + style::SetForegroundColor(Color::Blue), + style::SetAttribute(Attribute::Bold), + style::Print(format!( + "📷 Checkpoints are enabled! (took {:.2}s)\n", + start.elapsed().as_secs_f32() + )), + style::SetForegroundColor(Color::Reset), + style::SetAttribute(Attribute::Reset), + )?; + } + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + } + + async fn handle_restore( + &self, + session: &mut ChatSession, + tag: Option, + hard: bool, + ) -> Result { + // Take manager out temporarily to avoid borrow issues + let Some(manager) = session.conversation.checkpoint_manager.take() else { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("⚠️ Checkpoints not enabled. Use '/checkpoint init' to enable.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + }; + + let tag_result = if let Some(tag) = tag { + Ok(tag) + } else { + // Interactive selection + match gather_turn_checkpoints(&manager) { + Ok(entries) => { + if let Some(idx) = select_checkpoint(&entries, "Select checkpoint to restore:") { + Ok(entries[idx].tag.clone()) + } else { + Err(()) + } + }, + Err(e) => { + session.conversation.checkpoint_manager = Some(manager); + return Err(ChatError::Custom(format!("Failed to gather checkpoints: {}", e).into())); + }, + } + }; + + let tag = match tag_result { + Ok(tag) => tag, + Err(_) => { + session.conversation.checkpoint_manager = Some(manager); + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + }, + }; + + match manager.restore(&mut session.conversation, &tag, hard) { + Ok(_) => { + execute!( + session.stderr, + style::SetForegroundColor(Color::Blue), + style::SetAttribute(Attribute::Bold), + style::Print(format!("✓ Restored to checkpoint {}\n", tag)), + style::SetForegroundColor(Color::Reset), + style::SetAttribute(Attribute::Reset), + )?; + session.conversation.checkpoint_manager = Some(manager); + }, + Err(e) => { + session.conversation.checkpoint_manager = Some(manager); + return Err(ChatError::Custom(format!("Failed to restore: {}", e).into())); + }, + } + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + } + + fn handle_list(session: &mut ChatSession, limit: Option) -> Result { + let Some(manager) = session.conversation.checkpoint_manager.as_ref() else { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("⚠️ Checkpoints not enabled. Use '/checkpoint init' to enable.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + }; + + print_checkpoints(manager, &mut session.stderr, limit) + .map_err(|e| ChatError::Custom(format!("Could not display all checkpoints: {}", e).into()))?; + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + } + + async fn handle_clean(&self, os: &Os, session: &mut ChatSession) -> Result { + let Some(manager) = session.conversation.checkpoint_manager.take() else { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("⚠️ ️Checkpoints not enabled.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + }; + + // Print the path that will be deleted + execute!( + session.stderr, + style::Print(format!("Deleting: {}\n", manager.shadow_repo_path.display())) + )?; + + match manager.cleanup(os).await { + Ok(()) => { + execute!( + session.stderr, + style::SetAttribute(Attribute::Bold), + style::Print("✓ Deleted shadow repository for this session.\n"), + style::SetAttribute(Attribute::Reset), + )?; + }, + Err(e) => { + session.conversation.checkpoint_manager = Some(manager); + return Err(ChatError::Custom(format!("Failed to clean: {e}").into())); + }, + } + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + } + + fn handle_expand(session: &mut ChatSession, tag: String) -> Result { + let Some(manager) = session.conversation.checkpoint_manager.as_ref() else { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("⚠️ ️Checkpoints not enabled. Use '/checkpoint init' to enable.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + }; + + expand_checkpoint(manager, &mut session.stderr, &tag) + .map_err(|e| ChatError::Custom(format!("Failed to expand checkpoint: {}", e).into()))?; + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + } + + fn handle_diff(session: &mut ChatSession, tag1: String, tag2: Option) -> Result { + let Some(manager) = session.conversation.checkpoint_manager.as_ref() else { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("⚠️ Checkpoints not enabled. Use '/checkpoint init' to enable.\n"), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + }; + + let tag2 = tag2.unwrap_or_else(|| "HEAD".to_string()); + + // Validate tags exist + if tag1 != "HEAD" && !manager.tag_index.contains_key(&tag1) { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print(format!( + "⚠️ Checkpoint '{}' not found! Use /checkpoint list to see available checkpoints\n", + tag1 + )), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + + if tag2 != "HEAD" && !manager.tag_index.contains_key(&tag2) { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print(format!( + "⚠️ Checkpoint '{}' not found! Use /checkpoint list to see available checkpoints\n", + tag2 + )), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + + let header = if tag2 == "HEAD" { + format!("Changes since checkpoint {}:\n", tag1) + } else { + format!("Changes from {} to {}:\n", tag1, tag2) + }; + + execute!( + session.stderr, + style::SetForegroundColor(Color::Blue), + style::Print(header), + style::SetForegroundColor(Color::Reset), + )?; + + match manager.diff(&tag1, &tag2) { + Ok(diff) => { + if diff.trim().is_empty() { + execute!( + session.stderr, + style::SetForegroundColor(Color::DarkGrey), + style::Print("No changes.\n"), + style::SetForegroundColor(Color::Reset), + )?; + } else { + execute!(session.stderr, style::Print(diff))?; + } + }, + Err(e) => { + return Err(ChatError::Custom(format!("Failed to generate diff: {e}").into())); + }, + } + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + } +} + +// Display helpers + +struct CheckpointDisplay { + tag: String, + parts: Vec>, +} + +impl CheckpointDisplay { + fn from_checkpoint(checkpoint: &Checkpoint, manager: &CheckpointManager) -> Result { + let mut parts = Vec::new(); + + // Tag + parts.push(format!("[{}] ", checkpoint.tag).blue()); + + // Content + if checkpoint.is_turn { + // Turn checkpoint: show timestamp and description + parts.push( + format!( + "{} - {}", + checkpoint.timestamp.format("%Y-%m-%d %H:%M:%S"), + checkpoint.description + ) + .reset(), + ); + + // Add file stats if available + if let Some(stats) = manager.file_stats_cache.get(&checkpoint.tag) { + let stats_str = format_stats(stats); + if !stats_str.is_empty() { + parts.push(format!(" ({})", stats_str).dark_grey()); + } + } + } else { + // Tool checkpoint: show tool name and description + let tool_name = checkpoint.tool_name.clone().unwrap_or_else(|| "Tool".to_string()); + parts.push(format!("{}: ", tool_name).magenta()); + parts.push(checkpoint.description.clone().reset()); + } + + Ok(Self { + tag: checkpoint.tag.clone(), + parts, + }) + } +} + +impl std::fmt::Display for CheckpointDisplay { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for part in &self.parts { + write!(f, "{}", part)?; + } + Ok(()) + } +} + +fn format_stats(stats: &FileStats) -> String { + let mut parts = Vec::new(); + + if stats.added > 0 { + parts.push(format!("+{}", stats.added)); + } + if stats.modified > 0 { + parts.push(format!("~{}", stats.modified)); + } + if stats.deleted > 0 { + parts.push(format!("-{}", stats.deleted)); + } + + parts.join(" ") +} + +fn gather_turn_checkpoints(manager: &CheckpointManager) -> Result, eyre::Report> { + manager + .checkpoints + .iter() + .filter(|c| c.is_turn) + .map(|c| CheckpointDisplay::from_checkpoint(c, manager)) + .collect() +} + +fn print_checkpoints( + manager: &CheckpointManager, + output: &mut impl Write, + limit: Option, +) -> Result<(), eyre::Report> { + let entries = gather_turn_checkpoints(manager)?; + let limit = limit.unwrap_or(entries.len()); + + for entry in entries.iter().take(limit) { + execute!(output, style::Print(&entry), style::Print("\n"))?; + } + + Ok(()) +} + +fn expand_checkpoint(manager: &CheckpointManager, output: &mut impl Write, tag: &str) -> Result<(), eyre::Report> { + let Some(&idx) = manager.tag_index.get(tag) else { + execute!( + output, + style::SetForegroundColor(Color::Yellow), + style::Print(format!("⚠️ checkpoint '{}' not found\n", tag)), + style::SetForegroundColor(Color::Reset), + )?; + return Ok(()); + }; + + let checkpoint = &manager.checkpoints[idx]; + + // Print main checkpoint + let display = CheckpointDisplay::from_checkpoint(checkpoint, manager)?; + execute!(output, style::Print(&display), style::Print("\n"))?; + + if !checkpoint.is_turn { + return Ok(()); + } + + // Print tool checkpoints for this turn + let mut tool_checkpoints = Vec::new(); + for i in (0..idx).rev() { + let c = &manager.checkpoints[i]; + if c.is_turn { + break; + } + tool_checkpoints.push((i, CheckpointDisplay::from_checkpoint(c, manager)?)); + } + + for (checkpoint_idx, display) in tool_checkpoints.iter().rev() { + // Compute stats for this tool + let curr_tag = &manager.checkpoints[*checkpoint_idx].tag; + let prev_tag = if *checkpoint_idx > 0 { + &manager.checkpoints[checkpoint_idx - 1].tag + } else { + "0" + }; + + let stats_str = manager + .compute_stats_between(prev_tag, curr_tag) + .map(|s| format_stats(&s)) + .unwrap_or_default(); + + execute!( + output, + style::SetForegroundColor(Color::Blue), + style::Print(" └─ "), + style::Print(display), + style::SetForegroundColor(Color::Reset), + )?; + + if !stats_str.is_empty() { + execute!( + output, + style::SetForegroundColor(Color::DarkGrey), + style::Print(format!(" ({})", stats_str)), + style::SetForegroundColor(Color::Reset), + )?; + } + + execute!(output, style::Print("\n"))?; + } + + Ok(()) +} + +fn select_checkpoint(entries: &[CheckpointDisplay], prompt: &str) -> Option { + Select::with_theme(&crate::util::dialoguer_theme()) + .with_prompt(prompt) + .items(entries) + .report(false) + .interact_opt() + .unwrap_or(None) +} diff --git a/crates/chat-cli/src/cli/chat/cli/experiment.rs b/crates/chat-cli/src/cli/chat/cli/experiment.rs index 121419b244..9b7c3a8cd2 100644 --- a/crates/chat-cli/src/cli/chat/cli/experiment.rs +++ b/crates/chat-cli/src/cli/chat/cli/experiment.rs @@ -50,6 +50,15 @@ static AVAILABLE_EXPERIMENTS: &[Experiment] = &[ description: "Enables Q to create todo lists that can be viewed and managed using /todos", setting_key: Setting::EnabledTodoList, }, + Experiment { + name: "Checkpoint", + description: concat!( + "Enables workspace checkpoints to snapshot, list, expand, diff, and restore files (/checkpoint)\n", + " ", + "Cannot be used in tangent mode (to avoid mixing up conversation history)" + ), + setting_key: Setting::EnabledCheckpoint, + }, Experiment { name: "Context Usage Indicator", description: "Shows context usage percentage in the prompt (e.g., [rust-agent] 6% >)", diff --git a/crates/chat-cli/src/cli/chat/cli/mod.rs b/crates/chat-cli/src/cli/chat/cli/mod.rs index e5a20ed1c2..bf951596e6 100644 --- a/crates/chat-cli/src/cli/chat/cli/mod.rs +++ b/crates/chat-cli/src/cli/chat/cli/mod.rs @@ -1,4 +1,5 @@ pub mod changelog; +pub mod checkpoint; pub mod clear; pub mod compact; pub mod context; @@ -35,6 +36,7 @@ use tangent::TangentArgs; use todos::TodoSubcommand; use tools::ToolsArgs; +use crate::cli::chat::cli::checkpoint::CheckpointSubcommand; use crate::cli::chat::cli::subscribe::SubscribeArgs; use crate::cli::chat::cli::usage::UsageArgs; use crate::cli::chat::consts::AGENT_MIGRATION_DOC_URL; @@ -102,6 +104,8 @@ pub enum SlashCommand { Persist(PersistSubcommand), // #[command(flatten)] // Root(RootSubcommand), + #[command(subcommand)] + Checkpoint(CheckpointSubcommand), /// View, manage, and resume to-do lists #[command(subcommand)] Todos(TodoSubcommand), @@ -169,6 +173,7 @@ impl SlashCommand { // skip_printing_tools: true, // }) // }, + Self::Checkpoint(subcommand) => subcommand.execute(os, session).await, Self::Todos(subcommand) => subcommand.execute(os, session).await, } } @@ -198,6 +203,7 @@ impl SlashCommand { PersistSubcommand::Save { .. } => "save", PersistSubcommand::Load { .. } => "load", }, + Self::Checkpoint(_) => "checkpoint", Self::Todos(_) => "todos", } } diff --git a/crates/chat-cli/src/cli/chat/cli/tangent.rs b/crates/chat-cli/src/cli/chat/cli/tangent.rs index 94184c4828..65165c84f3 100644 --- a/crates/chat-cli/src/cli/chat/cli/tangent.rs +++ b/crates/chat-cli/src/cli/chat/cli/tangent.rs @@ -65,6 +65,22 @@ impl TangentArgs { match self.subcommand { Some(TangentSubcommand::Tail) => { + // Check if checkpoint is enabled + if os + .database + .settings + .get_bool(Setting::EnabledCheckpoint) + .unwrap_or(false) + { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print( + "⚠️ Checkpoint is disabled while in tangent mode. Please exit tangent mode if you want to use checkpoint.\n" + ), + style::SetForegroundColor(Color::Reset), + )?; + } if session.conversation.is_in_tangent_mode() { let duration_seconds = session.conversation.get_tangent_duration_seconds().unwrap_or(0); session.conversation.exit_tangent_mode_with_tail(); @@ -106,6 +122,23 @@ impl TangentArgs { style::SetForegroundColor(Color::Reset) )?; } else { + // Check if checkpoint is enabled + if os + .database + .settings + .get_bool(Setting::EnabledCheckpoint) + .unwrap_or(false) + { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print( + "⚠️ Checkpoint is disabled while in tangent mode. Please exit tangent mode if you want to use checkpoint.\n" + ), + style::SetForegroundColor(Color::Reset), + )?; + } + session.conversation.enter_tangent_mode(); // Get the configured tangent mode key for display diff --git a/crates/chat-cli/src/cli/chat/conversation.rs b/crates/chat-cli/src/cli/chat/conversation.rs index 3d064c0180..1217c0289b 100644 --- a/crates/chat-cli/src/cli/chat/conversation.rs +++ b/crates/chat-cli/src/cli/chat/conversation.rs @@ -74,6 +74,10 @@ use crate::cli::agent::hook::{ HookTrigger, }; use crate::cli::chat::ChatError; +use crate::cli::chat::checkpoint::{ + Checkpoint, + CheckpointManager, +}; use crate::cli::chat::cli::model::{ ModelInfo, get_model_info, @@ -138,6 +142,8 @@ pub struct ConversationState { /// Maps from a file path to [FileLineTracker] #[serde(default)] pub file_line_tracker: HashMap, + + pub checkpoint_manager: Option, #[serde(default = "default_true")] pub mcp_enabled: bool, /// Tangent mode checkpoint - stores main conversation when in tangent mode @@ -203,6 +209,7 @@ impl ConversationState { model: None, model_info: model, file_line_tracker: HashMap::new(), + checkpoint_manager: None, mcp_enabled, tangent_state: None, } @@ -891,6 +898,20 @@ Return only the JSON configuration, no additional text.", self.transcript.push_back(message); } + /// Restore conversation from a checkpoint's history snapshot + pub fn restore_to_checkpoint(&mut self, checkpoint: &Checkpoint) -> Result<(), eyre::Report> { + // 1. Restore history from snapshot + self.history = checkpoint.history_snapshot.clone(); + + // 2. Clear any pending next message (uncommitted state) + self.next_message = None; + + // 3. Update valid history range + self.valid_history_range = (0, self.history.len()); + + Ok(()) + } + /// Swapping agent involves the following: /// - Reinstantiate the context manager /// - Swap agent on tool manager diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index a576440c80..fcdb8b30ef 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -6,11 +6,13 @@ mod input_source; mod message; mod parse; use std::path::MAIN_SEPARATOR; +pub mod checkpoint; mod line_tracker; mod parser; mod prompt; mod prompt_parser; pub mod server_messenger; +use crate::cli::chat::checkpoint::CHECKPOINT_MESSAGE_MAX_LENGTH; #[cfg(unix)] mod skim_integration; mod token_counter; @@ -142,6 +144,10 @@ use crate::auth::AuthError; use crate::auth::builder_id::is_idc_user; use crate::cli::TodoListState; use crate::cli::agent::Agents; +use crate::cli::chat::checkpoint::{ + CheckpointManager, + truncate_message, +}; use crate::cli::chat::cli::SlashCommand; use crate::cli::chat::cli::editor::open_editor; use crate::cli::chat::cli::prompts::{ @@ -165,6 +171,7 @@ use crate::telemetry::{ TelemetryResult, get_error_reason, }; +use crate::util::directories::get_shadow_repo_dir; use crate::util::{ MCP_SERVER_TOOL_DELIMITER, directories, @@ -455,7 +462,7 @@ const RESUME_TEXT: &str = color_print::cstr! {"Picking up where we left off. const CHANGELOG_MAX_SHOW_COUNT: i64 = 2; // Only show the model-related tip for now to make users aware of this feature. -const ROTATING_TIPS: [&str; 19] = [ +const ROTATING_TIPS: [&str; 20] = [ color_print::cstr! {"You can resume the last conversation from your current directory by launching with q chat --resume"}, color_print::cstr! {"Get notified whenever Q CLI finishes responding. @@ -490,6 +497,7 @@ const ROTATING_TIPS: [&str; 19] = [ color_print::cstr! {"Use /tangent or ctrl + t (customizable) to start isolated conversations ( ↯ ) that don't affect your main chat history"}, color_print::cstr! {"Ask me directly about my capabilities! Try questions like \"What can you do?\" or \"Can you save conversations?\""}, color_print::cstr! {"Stay up to date with the latest features and improvements! Use /changelog to see what's new in Amazon Q CLI"}, + color_print::cstr! {"Enable workspace checkpoints to snapshot & restore changes. Just run q settings chat.enableCheckpoint true"}, ]; const GREETING_BREAK_POINT: usize = 80; @@ -1322,6 +1330,38 @@ impl ChatSession { } } + // Initialize capturing if possible + if os + .database + .settings + .get_bool(Setting::EnabledCheckpoint) + .unwrap_or(false) + { + let path = get_shadow_repo_dir(os, self.conversation.conversation_id().to_string())?; + let start = std::time::Instant::now(); + let checkpoint_manager = match CheckpointManager::auto_init(os, &path, self.conversation.history()).await { + Ok(manager) => { + execute!( + self.stderr, + style::Print( + format!( + "📷 Checkpoints are enabled! (took {:.2}s)\n\n", + start.elapsed().as_secs_f32() + ) + .blue() + .bold() + ) + )?; + Some(manager) + }, + Err(e) => { + execute!(self.stderr, style::Print(format!("{e}\n\n").blue()))?; + None + }, + }; + self.conversation.checkpoint_manager = checkpoint_manager; + } + if let Some(user_input) = self.initial_input.take() { self.inner = Some(ChatState::HandleInput { input: user_input }); } @@ -2083,6 +2123,23 @@ impl ChatSession { skip_printing_tools: false, }) } else { + // Track the message for checkpoint descriptions, but only if not already set + // This prevents tool approval responses (y/n/t) from overwriting the original message + if os + .database + .settings + .get_bool(Setting::EnabledCheckpoint) + .unwrap_or(false) + && !self.conversation.is_in_tangent_mode() + { + if let Some(manager) = self.conversation.checkpoint_manager.as_mut() { + if !manager.message_locked && self.pending_tool_index.is_none() { + manager.pending_user_message = Some(user_input.clone()); + manager.message_locked = true; + } + } + } + // Check for a pending tool approval if let Some(index) = self.pending_tool_index { let is_trust = ["t", "T"].contains(&input); @@ -2306,6 +2363,74 @@ impl ChatSession { } execute!(self.stdout, style::Print("\n"))?; + // Handle checkpoint after tool execution - store tag for later display + let checkpoint_tag: Option = { + let enabled = os + .database + .settings + .get_bool(Setting::EnabledCheckpoint) + .unwrap_or(false) + && !self.conversation.is_in_tangent_mode(); + if invoke_result.is_err() || !enabled { + None + } + // Take manager out temporarily to avoid borrow conflicts + else if let Some(mut manager) = self.conversation.checkpoint_manager.take() { + // Check if there are uncommitted changes + let has_changes = match manager.has_changes() { + Ok(b) => b, + Err(e) => { + execute!( + self.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print(format!("Could not check if uncommitted changes exist: {e}\n")), + style::Print("Saving anyways...\n"), + style::SetForegroundColor(Color::Reset), + )?; + true + }, + }; + let tag = if has_changes { + // Generate tag for this tool use + let tag = format!("{}.{}", manager.current_turn + 1, manager.tools_in_turn + 1); + + // Get tool summary for commit message + let is_fs_read = matches!(&tool.tool, Tool::FsRead(_)); + let description = if is_fs_read { + "External edits detected (likely manual change)".to_string() + } else { + match tool.tool.get_summary() { + Some(summary) => summary, + None => tool.tool.display_name(), + } + }; + + // Create checkpoint + if let Err(e) = manager.create_checkpoint( + &tag, + &description, + &self.conversation.history().clone(), + false, + Some(tool.name.clone()), + ) { + debug!("Failed to create tool checkpoint: {}", e); + None + } else { + manager.tools_in_turn += 1; + Some(tag) + } + } else { + None + }; + + // Put manager back + self.conversation.checkpoint_manager = Some(manager); + tag + } else { + None + } + }; + let tool_end_time = Instant::now(); let tool_time = tool_end_time.duration_since(tool_start); tool_telemetry = tool_telemetry.and_modify(|ev| { @@ -2348,8 +2473,18 @@ impl ChatSession { style::SetAttribute(Attribute::Bold), style::Print(format!(" ● Completed in {}s", tool_time)), style::SetForegroundColor(Color::Reset), - style::Print("\n\n"), )?; + if let Some(tag) = checkpoint_tag { + execute!( + self.stdout, + style::SetForegroundColor(Color::Blue), + style::SetAttribute(Attribute::Bold), + style::Print(format!(" [{tag}]")), + style::SetForegroundColor(Color::Reset), + style::SetAttribute(Attribute::Reset), + )?; + } + execute!(self.stdout, style::Print("\n\n"))?; tool_telemetry = tool_telemetry.and_modify(|ev| ev.is_success = Some(true)); if let Tool::Custom(_) = &tool.tool { @@ -2729,7 +2864,10 @@ impl ChatSession { self.stdout, style::Print("\n\n"), style::SetForegroundColor(Color::Yellow), - style::Print(format!("Tool validation failed: {}\n Retrying the request...", error_message)), + style::Print(format!( + "Tool validation failed: {}\n Retrying the request...", + error_message + )), style::ResetColor, style::Print("\n"), ); @@ -2845,6 +2983,64 @@ impl ChatSession { self.pending_tool_index = None; self.tool_turn_start_time = None; + // Create turn checkpoint if tools were used + if os + .database + .settings + .get_bool(Setting::EnabledCheckpoint) + .unwrap_or(false) + && !self.conversation.is_in_tangent_mode() + { + if let Some(mut manager) = self.conversation.checkpoint_manager.take() { + if manager.tools_in_turn > 0 { + // Increment turn counter + manager.current_turn += 1; + + // Get user message for description + let description = manager.pending_user_message.take().map_or_else( + || "Turn completed".to_string(), + |msg| truncate_message(&msg, CHECKPOINT_MESSAGE_MAX_LENGTH), + ); + + // Create turn checkpoint + let tag = manager.current_turn.to_string(); + if let Err(e) = manager.create_checkpoint( + &tag, + &description, + &self.conversation.history().clone(), + true, + None, + ) { + execute!( + self.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print(format!("⚠️ Could not create automatic checkpoint: {}\n\n", e)), + style::SetForegroundColor(Color::Reset), + )?; + } else { + execute!( + self.stderr, + style::SetForegroundColor(Color::Blue), + style::SetAttribute(Attribute::Bold), + style::Print(format!("✓ Created checkpoint {}\n\n", tag)), + style::SetForegroundColor(Color::Reset), + style::SetAttribute(Attribute::Reset), + )?; + } + + // Reset for next turn + manager.tools_in_turn = 0; + manager.message_locked = false; // Unlock for next turn + } else { + // Clear pending message even if no tools were used + manager.pending_user_message = None; + } + + // Put manager back + self.conversation.checkpoint_manager = Some(manager); + } + } + self.send_chat_telemetry(os, TelemetryResult::Succeeded, None, None, None, true) .await; diff --git a/crates/chat-cli/src/cli/chat/tools/fs_write.rs b/crates/chat-cli/src/cli/chat/tools/fs_write.rs index d23a957b60..d72ccb2be6 100644 --- a/crates/chat-cli/src/cli/chat/tools/fs_write.rs +++ b/crates/chat-cli/src/cli/chat/tools/fs_write.rs @@ -451,7 +451,7 @@ impl FsWrite { } /// Returns the summary from any variant of the FsWrite enum - fn get_summary(&self) -> Option<&String> { + pub fn get_summary(&self) -> Option<&String> { match self { FsWrite::Create { summary, .. } => summary.as_ref(), FsWrite::StrReplace { summary, .. } => summary.as_ref(), diff --git a/crates/chat-cli/src/cli/chat/tools/mod.rs b/crates/chat-cli/src/cli/chat/tools/mod.rs index 96cc0d76f2..9b90e1d052 100644 --- a/crates/chat-cli/src/cli/chat/tools/mod.rs +++ b/crates/chat-cli/src/cli/chat/tools/mod.rs @@ -187,6 +187,16 @@ impl Tool { _ => None, } } + + /// Returns the tool's summary if available + pub fn get_summary(&self) -> Option { + match self { + Tool::FsWrite(fs_write) => fs_write.get_summary().cloned(), + Tool::ExecuteCommand(execute_cmd) => execute_cmd.summary.clone(), + Tool::FsRead(fs_read) => fs_read.summary.clone(), + _ => None, + } + } } /// A tool specification to be sent to the model as part of a conversation. Maps to diff --git a/crates/chat-cli/src/database/settings.rs b/crates/chat-cli/src/database/settings.rs index 2ea8510e2a..9f440f3ab4 100644 --- a/crates/chat-cli/src/database/settings.rs +++ b/crates/chat-cli/src/database/settings.rs @@ -81,6 +81,8 @@ pub enum Setting { ChatEnableHistoryHints, #[strum(message = "Enable the todo list feature (boolean)")] EnabledTodoList, + #[strum(message = "Enable the checkpoint feature (boolean)")] + EnabledCheckpoint, } impl AsRef for Setting { @@ -117,6 +119,7 @@ impl AsRef for Setting { Self::ChatDisableAutoCompaction => "chat.disableAutoCompaction", Self::ChatEnableHistoryHints => "chat.enableHistoryHints", Self::EnabledTodoList => "chat.enableTodoList", + Self::EnabledCheckpoint => "chat.enableCheckpoint", Self::EnabledContextUsageIndicator => "chat.enableContextUsageIndicator", } } @@ -164,6 +167,7 @@ impl TryFrom<&str> for Setting { "chat.disableAutoCompaction" => Ok(Self::ChatDisableAutoCompaction), "chat.enableHistoryHints" => Ok(Self::ChatEnableHistoryHints), "chat.enableTodoList" => Ok(Self::EnabledTodoList), + "chat.enableCheckpoint" => Ok(Self::EnabledCheckpoint), "chat.enableContextUsageIndicator" => Ok(Self::EnabledContextUsageIndicator), _ => Err(DatabaseError::InvalidSetting(value.to_string())), } @@ -301,6 +305,7 @@ mod test { .set(Setting::ChatDisableMarkdownRendering, false) .await .unwrap(); + settings.set(Setting::EnabledCheckpoint, true).await.unwrap(); assert_eq!(settings.get(Setting::TelemetryEnabled), Some(&Value::Bool(true))); assert_eq!( @@ -324,6 +329,7 @@ mod test { settings.get(Setting::ChatDisableMarkdownRendering), Some(&Value::Bool(false)) ); + assert_eq!(settings.get(Setting::EnabledCheckpoint), Some(&Value::Bool(true))); settings.remove(Setting::TelemetryEnabled).await.unwrap(); settings.remove(Setting::OldClientId).await.unwrap(); @@ -331,6 +337,7 @@ mod test { settings.remove(Setting::KnowledgeIndexType).await.unwrap(); settings.remove(Setting::McpLoadedBefore).await.unwrap(); settings.remove(Setting::ChatDisableMarkdownRendering).await.unwrap(); + settings.remove(Setting::EnabledCheckpoint).await.unwrap(); assert_eq!(settings.get(Setting::TelemetryEnabled), None); assert_eq!(settings.get(Setting::OldClientId), None); @@ -338,5 +345,6 @@ mod test { assert_eq!(settings.get(Setting::KnowledgeIndexType), None); assert_eq!(settings.get(Setting::McpLoadedBefore), None); assert_eq!(settings.get(Setting::ChatDisableMarkdownRendering), None); + assert_eq!(settings.get(Setting::EnabledCheckpoint), None); } } diff --git a/crates/chat-cli/src/util/directories.rs b/crates/chat-cli/src/util/directories.rs index 6833cb828b..2e89ce1e09 100644 --- a/crates/chat-cli/src/util/directories.rs +++ b/crates/chat-cli/src/util/directories.rs @@ -43,6 +43,7 @@ pub enum DirectoryError { type Result = std::result::Result; const WORKSPACE_AGENT_DIR_RELATIVE: &str = ".amazonq/cli-agents"; +const GLOBAL_SHADOW_REPO_DIR: &str = ".aws/amazonq/cli-checkpoints"; const GLOBAL_AGENT_DIR_RELATIVE_TO_HOME: &str = ".aws/amazonq/cli-agents"; const WORKSPACE_PROMPTS_DIR_RELATIVE: &str = ".amazonq/prompts"; const GLOBAL_PROMPTS_DIR_RELATIVE_TO_HOME: &str = ".aws/amazonq/prompts"; @@ -300,6 +301,10 @@ pub fn get_mcp_auth_dir(os: &Os) -> Result { Ok(home_dir(os)?.join(".aws").join("sso").join("cache")) } +pub fn get_shadow_repo_dir(os: &Os, conversation_id: String) -> Result { + Ok(home_dir(os)?.join(GLOBAL_SHADOW_REPO_DIR).join(conversation_id)) +} + /// Generate a unique identifier for an agent based on its path and name fn generate_agent_unique_id(agent: &crate::cli::Agent) -> String { use std::collections::hash_map::DefaultHasher; From dae8168b907d2ccaa821f4dece4535bc223d22ba Mon Sep 17 00:00:00 2001 From: Kenneth Sanchez V Date: Thu, 25 Sep 2025 18:01:08 -0700 Subject: [PATCH 67/71] Bump version to 1.16.3 and update feed.json (#2996) Co-authored-by: Kenneth S. --- Cargo.lock | 4 +-- Cargo.toml | 2 +- crates/chat-cli/src/cli/feed.json | 36 +++++++++++++++++++++ docs/experiments.md | 52 +++++++++++++++++++++++++++++++ 4 files changed, 91 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7230af9f7e..74a6c9f8e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1207,7 +1207,7 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chat_cli" -version = "1.16.2" +version = "1.16.3" dependencies = [ "amzn-codewhisperer-client", "amzn-codewhisperer-streaming-client", @@ -5647,7 +5647,7 @@ dependencies = [ [[package]] name = "semantic_search_client" -version = "1.16.2" +version = "1.16.3" dependencies = [ "anyhow", "bm25", diff --git a/Cargo.toml b/Cargo.toml index 7f0b79de2a..d2fb418eb6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ authors = ["Amazon Q CLI Team (q-cli@amazon.com)", "Chay Nabors (nabochay@amazon edition = "2024" homepage = "https://aws.amazon.com/q/" publish = false -version = "1.16.2" +version = "1.16.3" license = "MIT OR Apache-2.0" [workspace.dependencies] diff --git a/crates/chat-cli/src/cli/feed.json b/crates/chat-cli/src/cli/feed.json index e81a34d24d..2ae3451824 100644 --- a/crates/chat-cli/src/cli/feed.json +++ b/crates/chat-cli/src/cli/feed.json @@ -10,6 +10,42 @@ "hidden": true, "changes": [] }, + { + "type": "release", + "date": "2025-09-26", + "version": "1.16.3", + "title": "Version 1.16.3", + "changes": [ + { + "type": "added", + "description": "[Experimental] Adds checkpointing functionality using Git CLI commands - [#2896](https://github.com/aws/amazon-q-developer-cli/pull/2896)" + }, + { + "type": "fixed", + "description": "Fixes issues with Tool Input parsing - [#2986](https://github.com/aws/amazon-q-developer-cli/pull/2986)" + }, + { + "type": "added", + "description": "[Experimental] Add context usage percentage indicator to prompt - [#2994](https://github.com/aws/amazon-q-developer-cli/pull/2994)" + }, + { + "type": "added", + "description": "Expand support for /prompts command - [#2799](https://github.com/aws/amazon-q-developer-cli/pull/2799)" + }, + { + "type": "fixed", + "description": "Consolidate tool permission logic for consistent display and execution - [#2975](https://github.com/aws/amazon-q-developer-cli/pull/2975)" + }, + { + "type": "fixed", + "description": "Hardcode client id for oauth in MCP - [#2976](https://github.com/aws/amazon-q-developer-cli/pull/2976)" + }, + { + "type": "improved", + "description": "Improve error messages for dispatch failures - [#2969](https://github.com/aws/amazon-q-developer-cli/pull/2969)" + } + ] + }, { "type": "release", "date": "2025-09-19", diff --git a/docs/experiments.md b/docs/experiments.md index 3c8ce310b4..3864cc3670 100644 --- a/docs/experiments.md +++ b/docs/experiments.md @@ -4,6 +4,56 @@ Amazon Q CLI includes experimental features that can be toggled on/off using the ## Available Experiments +### Checkpointing +**Description:** Enables session-scoped checkpoints for tracking file changes using Git CLI commands + +**Features:** +- Snapshots file changes into a shadow bare git repo +- List, expand, diff, and restore to any checkpoint +- Conversation history unwinds when restoring checkpoints +- Auto-enables in git repositories (ephemeral, cleaned on session end) +- Manual initialization available for non-git directories + +**Usage:** +``` +/checkpoint init # Manually enable checkpoints (if not in git repo) +/checkpoint list [--limit N] # Show turn-level checkpoints with file stats +/checkpoint expand # Show tool-level checkpoints under a turn +/checkpoint diff [tag2|HEAD] # Compare checkpoints or with current state +/checkpoint restore [] [--hard] # Restore to checkpoint (interactive picker if no tag) +/checkpoint clean # Delete session shadow repo +``` + +**Restore Options:** +- Default: Revert tracked changes & deletions; keep files created after checkpoint +- `--hard`: Make workspace exactly match checkpoint; deletes tracked files created after it + +**Example:** +``` +/checkpoint list +[0] 2025-09-18 14:00:00 - Initial checkpoint +[1] 2025-09-18 14:05:31 - add two_sum.py (+1 file) +[2] 2025-09-18 14:07:10 - add tests (modified 1) + +/checkpoint expand 2 +[2] 2025-09-18 14:07:10 - add tests + └─ [2.1] fs_write: Add minimal test cases to two_sum.py (modified 1) +``` + +### Context Usage Percentage +**Description:** Shows context window usage as a percentage in the chat prompt + +**Features:** +- Displays percentage of context window used in prompt (e.g., "[rust-agent] 6% >") +- Color-coded indicators: + - Green: <50% usage + - Yellow: 50-89% usage + - Red: 90-100% usage +- Helps monitor context window consumption +- Disabled by default + +**When enabled:** The chat prompt will show your current context usage percentage with color coding to help you understand how much of the available context window is being used. + ### Knowledge **Command:** `/knowledge` **Description:** Enables persistent context storage and retrieval across chat sessions @@ -111,6 +161,8 @@ All experimental commands are available in the fuzzy search (Ctrl+S): ## Settings Integration Experiments are stored as settings and persist across sessions: +- `EnabledCheckpointing` - Checkpointing experiment state +- `EnabledContextUsagePercentage` - Context usage percentage experiment state - `EnabledKnowledge` - Knowledge experiment state - `EnabledThinking` - Thinking experiment state - `EnabledTodoList` - TODO list experiment state From a8a815ade49524e4d8841dc2eb74e1bb2ff7b027 Mon Sep 17 00:00:00 2001 From: Kenneth Sanchez V Date: Thu, 25 Sep 2025 18:33:54 -0700 Subject: [PATCH 68/71] chore: update feed (#2998) --- crates/chat-cli/src/cli/feed.json | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/crates/chat-cli/src/cli/feed.json b/crates/chat-cli/src/cli/feed.json index 2ae3451824..53c38195c2 100644 --- a/crates/chat-cli/src/cli/feed.json +++ b/crates/chat-cli/src/cli/feed.json @@ -22,7 +22,7 @@ }, { "type": "fixed", - "description": "Fixes issues with Tool Input parsing - [#2986](https://github.com/aws/amazon-q-developer-cli/pull/2986)" + "description": "Validation issues with MCP tool arguments - [#2986](https://github.com/aws/amazon-q-developer-cli/pull/2986)" }, { "type": "added", @@ -30,18 +30,14 @@ }, { "type": "added", - "description": "Expand support for /prompts command - [#2799](https://github.com/aws/amazon-q-developer-cli/pull/2799)" + "description": "Support for custom prompts, see `/prompts` - [#2799](https://github.com/aws/amazon-q-developer-cli/pull/2799)" }, { "type": "fixed", - "description": "Consolidate tool permission logic for consistent display and execution - [#2975](https://github.com/aws/amazon-q-developer-cli/pull/2975)" + "description": "Various issues with MCP OAuth - [#2976](https://github.com/aws/amazon-q-developer-cli/pull/2976)" }, { "type": "fixed", - "description": "Hardcode client id for oauth in MCP - [#2976](https://github.com/aws/amazon-q-developer-cli/pull/2976)" - }, - { - "type": "improved", "description": "Improve error messages for dispatch failures - [#2969](https://github.com/aws/amazon-q-developer-cli/pull/2969)" } ] From bf6de44f084d1aca901e166bb1b15bdbc7cb3e62 Mon Sep 17 00:00:00 2001 From: mbcohn Date: Sun, 31 Aug 2025 08:19:27 +0000 Subject: [PATCH 69/71] docs: clarify agent-specific knowledge bases not yet released Add version note indicating that agent-specific knowledge bases are available in development but not in current releases (v1.14.1 and earlier). Current releases use global storage at ~/.aws/amazonq/knowledge_bases/. --- docs/knowledge-management.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/knowledge-management.md b/docs/knowledge-management.md index 82e3a291af..afa42a1dc9 100644 --- a/docs/knowledge-management.md +++ b/docs/knowledge-management.md @@ -169,6 +169,8 @@ Configure knowledge base behavior: ## Agent-Specific Knowledge Bases +> **Note**: Agent-specific knowledge bases are available in development versions but not yet released. In current releases (v1.14.1 and earlier), all knowledge bases are stored globally at `~/.aws/amazonq/knowledge_bases/` and shared across all agents. + ### Isolated Knowledge Storage Each agent maintains its own isolated knowledge base, ensuring that knowledge contexts are scoped to the specific agent you're working with. This provides better organization and prevents knowledge conflicts between different agents. From 7d8d476df8e34a098588d091b17ccaa35de8750f Mon Sep 17 00:00:00 2001 From: mbcohn Date: Sat, 27 Sep 2025 00:20:01 +0000 Subject: [PATCH 70/71] fix: treat empty string as unset for chat.defaultAgent setting When users run 'q settings chat.defaultAgent ""' to reset their default agent (as documented), the empty string was being treated as a valid agent name, causing an error message on every chat session start. This change treats empty strings as 'no default set', allowing users to cleanly reset to the built-in default agent without error messages. Fixes the misleading behavior described in the AWS documentation where setting an empty string should reset to built-in default silently. --- crates/chat-cli/src/cli/agent/mod.rs | 31 +++++++++++++++------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/crates/chat-cli/src/cli/agent/mod.rs b/crates/chat-cli/src/cli/agent/mod.rs index 369932091a..7ee44c601d 100644 --- a/crates/chat-cli/src/cli/agent/mod.rs +++ b/crates/chat-cli/src/cli/agent/mod.rs @@ -668,21 +668,24 @@ impl Agents { } if let Some(user_set_default) = os.database.settings.get_string(Setting::ChatDefaultAgent) { - if all_agents.iter().any(|a| a.name == user_set_default) { - break 'active_idx user_set_default; + // Treat empty strings as "no default set" to allow clean reset + if !user_set_default.is_empty() { + if all_agents.iter().any(|a| a.name == user_set_default) { + break 'active_idx user_set_default; + } + let _ = queue!( + output, + style::SetForegroundColor(Color::Red), + style::Print("Error"), + style::SetForegroundColor(Color::Yellow), + style::Print(format!( + ": user defined default {} not found. Falling back to in-memory default", + user_set_default + )), + style::Print("\n"), + style::SetForegroundColor(Color::Reset) + ); } - let _ = queue!( - output, - style::SetForegroundColor(Color::Red), - style::Print("Error"), - style::SetForegroundColor(Color::Yellow), - style::Print(format!( - ": user defined default {} not found. Falling back to in-memory default", - user_set_default - )), - style::Print("\n"), - style::SetForegroundColor(Color::Reset) - ); } all_agents.push({ From 2482110ee25cd444bfcb5a4ebc295a4bb22ae8c1 Mon Sep 17 00:00:00 2001 From: Michael Bennett Cohn <99377421+mibeco@users.noreply.github.com> Date: Fri, 26 Sep 2025 17:27:40 -0700 Subject: [PATCH 71/71] Update crates/chat-cli/build.rs Co-authored-by: amazon-q-developer[bot] <208079219+amazon-q-developer[bot]@users.noreply.github.com> --- crates/chat-cli/build.rs | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/crates/chat-cli/build.rs b/crates/chat-cli/build.rs index 5d8f6343e2..0735d05e78 100644 --- a/crates/chat-cli/build.rs +++ b/crates/chat-cli/build.rs @@ -324,19 +324,16 @@ fn main() { let file: syn::File = syn::parse_str(&out).unwrap(); let pp = prettyplease::unparse(&file); - // write an empty file to the output directory - std::fs::write(format!("{}/mod.rs", outdir), pp).unwrap(); -} - /// Downloads the latest feed.json from the autocomplete repository. /// This ensures official builds have the most up-to-date changelog information. /// /// # Errors /// -/// Prints cargo warnings if: +/// Returns errors if: /// - `curl` command is not available /// - Network request fails /// - File write operation fails +/// - Downloaded content fails validation fn download_feed_json() { use std::process::Command; @@ -346,9 +343,8 @@ fn download_feed_json() { let curl_check = Command::new("curl").arg("--version").output(); if curl_check.is_err() { - panic!( - "curl command not found. Cannot download latest feed.json. Please install curl or build without FETCH_FEED=1 to use existing feed.json." - ); + eprintln!("curl command not found. Cannot download latest feed.json. Please install curl or build without FETCH_FEED=1 to use existing feed.json."); + std::process::exit(1); } let output = Command::new("curl") @@ -359,14 +355,22 @@ fn download_feed_json() { "-s", // silent "-v", // verbose output printed to stderr "--show-error", // print error message to stderr (since -s is used) - "https://api.github.com/repos/aws/amazon-q-developer-cli-autocomplete/contents/feed.json", + "--max-filesize", "1048576", // 1MB limit + "", ]) .output(); match output { Ok(result) if result.status.success() => { + // Basic validation - ensure it's valid JSON + if let Err(e) = serde_json::from_slice::(&result.stdout) { + eprintln!("Downloaded content is not valid JSON: {}", e); + std::process::exit(1); + } + if let Err(e) = std::fs::write("src/cli/feed.json", result.stdout) { - panic!("Failed to write feed.json: {}", e); + eprintln!("Failed to write feed.json: {}", e); + std::process::exit(1); } else { println!("cargo:warning=Successfully downloaded latest feed.json"); } @@ -377,10 +381,13 @@ fn download_feed_json() { } else { "An unknown error occurred".to_string() }; - panic!("Failed to download feed.json: {}", error_msg); + eprintln!("Failed to download feed.json: {}", error_msg); + std::process::exit(1); }, Err(e) => { - panic!("Failed to execute curl: {}", e); + eprintln!("Failed to execute curl: {}", e); + std::process::exit(1); }, } } +}