diff --git a/crates/coco-tui/src/components/chat.rs b/crates/coco-tui/src/components/chat.rs index 76e5c34..112b104 100644 --- a/crates/coco-tui/src/components/chat.rs +++ b/crates/coco-tui/src/components/chat.rs @@ -3,7 +3,7 @@ use code_combo::tools::{BashInput, Final}; use code_combo::{ Agent, Block as ChatBlock, ChatResponse, ChatStreamUpdate, Config, Content as ChatContent, Message as ChatMessage, Output, RuntimeOverrides, SessionEnv, Starter, StarterCommand, - StarterError, StarterEvent, StopReason, TextEdit, ToolUse, bash_unsafe_ranges, + StarterError, StarterEvent, StopReason, TextEdit, ToolUse, UsageStats, bash_unsafe_ranges, discover_starters, load_runtime_overrides, parse_primary_command, save_runtime_overrides, }; use crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; @@ -63,6 +63,7 @@ pub struct Chat<'a> { shortcut_hints: ShortcutHintsPanel, prev_focus: Option, combo_thinking_active: bool, + last_usage: Option, token_schedule_session_save: Option, cancellation_guard: CancellationGuard, @@ -225,6 +226,7 @@ impl Chat<'static> { shortcut_hints: ShortcutHintsPanel::default(), prev_focus: None, combo_thinking_active: false, + last_usage: None, token_schedule_session_save: None, cancellation_guard: CancellationGuard::default(), } @@ -677,6 +679,9 @@ impl Chat<'static> { if let Some(line) = self.ctrl_c_reminder_line() { block = block.title_bottom(line); } + if let Some(line) = self.context_usage_indicator() { + block = block.title_bottom(line); + } block } @@ -744,6 +749,31 @@ impl Chat<'static> { ]) } + fn context_usage_indicator(&self) -> Option> { + let usage = self.last_usage.as_ref()?; + let input_tokens = usage.input_tokens.unwrap_or(0); + let output_tokens = usage.output_tokens.unwrap_or(0); + let used_tokens = usage.total_tokens.or_else(|| { + if usage.input_tokens.is_some() || usage.output_tokens.is_some() { + Some(input_tokens.saturating_add(output_tokens)) + } else { + None + } + })?; + let theme = global::theme(); + let text = if let Some(window) = self.agent.context_window() { + let percent = if window == 0 { + 0 + } else { + ((used_tokens as f64 / window as f64) * 100.0).round() as usize + }; + format!(" ctx: {used_tokens}/{window} ({percent}%) ") + } else { + format!(" ctx: {used_tokens} tok ") + }; + Some(Line::from(Span::styled(text, theme.ui.shortcut_desc)).alignment(Alignment::Right)) + } + fn auto_accept_indicator(&self) -> Line<'static> { let theme = global::theme(); let (status, status_style) = if self.state.auto_accept_edits { @@ -903,6 +933,7 @@ impl Chat<'static> { let model_override = model_override.cloned(); self.state.write().model_override = model_override.clone(); self.agent.set_model_override(model_override); + self.last_usage = None; self.persist_runtime_overrides(); global::trigger_schedule_session_save(); } @@ -1018,6 +1049,15 @@ impl Component for Chat<'static> { self.messages .append_stream_text(*index, *kind, text.clone()); } + Event::Answer(AnswerEvent::Usage { usage }) => { + match &mut self.last_usage { + Some(total) => add_usage(total, usage), + None => { + self.last_usage = Some(usage.clone()); + } + } + global::signal_dirty(); + } Event::Answer(AnswerEvent::Cancelled) => { self.messages.reset_stream(); self.set_ready(); @@ -1680,6 +1720,21 @@ fn is_safe_command(command: &str) -> bool { !trimmed.is_empty() && bash_unsafe_ranges(command).is_empty() } +fn add_usage(total: &mut UsageStats, delta: &UsageStats) { + let has_breakdown = delta.input_tokens.is_some() || delta.output_tokens.is_some(); + if has_breakdown { + let input = total.input_tokens.unwrap_or(0) + delta.input_tokens.unwrap_or(0); + let output = total.output_tokens.unwrap_or(0) + delta.output_tokens.unwrap_or(0); + total.input_tokens = Some(input); + total.output_tokens = Some(output); + total.total_tokens = Some(input + output); + return; + } + if let Some(delta_total) = delta.total_tokens { + total.total_tokens = Some(total.total_tokens.unwrap_or(0) + delta_total); + } +} + /// Parse coco reply stdout output as JSON fields and validate required schemas. fn parse_coco_reply_output( output: &Final, @@ -1826,6 +1881,9 @@ async fn handle_offload_combo_reply( .map_err(|e| ComboReplyError::ChatFailed { message: e.to_string(), })?; + if let Some(usage) = chat_response.usage.clone() { + tx.send(AnswerEvent::Usage { usage }.into()).ok(); + } // Extract Bash tool_use from response let blocks = match &chat_response.message.content { @@ -2169,12 +2227,12 @@ async fn task_combo_execute( let stream_name = name.clone(); let thinking_seen = Arc::new(AtomicBool::new(false)); let thinking_seen_stream = thinking_seen.clone(); - let reply = agent - .reply_prompt_stream_with_thinking( - &system_prompt, - schemas, - thinking.clone(), - cancel_token.clone(), + let reply = agent + .reply_prompt_stream_with_thinking( + &system_prompt, + schemas, + thinking.clone(), + cancel_token.clone(), move |update| { let (index, kind, text) = match update { ChatStreamUpdate::Plain { index, text } => { @@ -2205,6 +2263,9 @@ async fn task_combo_execute( reply }; if let Ok(reply) = &reply { + if let Some(usage) = reply.usage.clone() { + tx.send(AnswerEvent::Usage { usage }.into()).ok(); + } let thinking = if streamed_thinking { Vec::new() } else { @@ -2623,6 +2684,9 @@ async fn handle_chat_response( streamed_thinking: bool, ) { let tx = global::event_tx(); + if let Some(usage) = chat_resp.usage.clone() { + tx.send(AnswerEvent::Usage { usage }.into()).ok(); + } let mut to_execute: Vec = vec![]; let mut bot_messages = match chat_resp.message.content { ChatContent::Text(text) => { diff --git a/crates/coco-tui/src/events.rs b/crates/coco-tui/src/events.rs index 5ef410a..0dab7b1 100644 --- a/crates/coco-tui/src/events.rs +++ b/crates/coco-tui/src/events.rs @@ -1,5 +1,5 @@ use code_combo::{ - OutputChunk, Starter, TextEdit, ThinkingConfig, ToolUse, + OutputChunk, Starter, TextEdit, ThinkingConfig, ToolUse, UsageStats, tools::{Final, SubagentEvent}, }; use crossterm::event::{KeyEvent, MouseEvent}; @@ -48,6 +48,9 @@ pub enum AnswerEvent { text: String, }, Cancelled, + Usage { + usage: UsageStats, + }, // Below events come from User ToolOutput { id: String, diff --git a/crates/openai/src/client.rs b/crates/openai/src/client.rs index 933fd17..e14e858 100644 --- a/crates/openai/src/client.rs +++ b/crates/openai/src/client.rs @@ -10,7 +10,10 @@ use reqwest::{StatusCode, Url, header::HeaderMap}; use snafu::{ResultExt, Whatever}; use tracing::trace; -use crate::{ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, ErrorResponse}; +use crate::{ + ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, ErrorResponse, + StreamOptions, +}; pub struct Client { model: String, @@ -110,6 +113,9 @@ impl Client { ) -> Result { request.model = self.model.clone(); request.stream = Some(true); + request.stream_options = Some(StreamOptions { + include_usage: Some(true), + }); let url = self .base_url .join(self.api_path("chat/completions")) diff --git a/crates/openai/src/types.rs b/crates/openai/src/types.rs index dacbc82..25e5f0a 100644 --- a/crates/openai/src/types.rs +++ b/crates/openai/src/types.rs @@ -129,6 +129,8 @@ pub struct FunctionChoice { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatCompletionResponse { pub choices: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -142,6 +144,16 @@ pub struct ChatChoice { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatCompletionChunk { pub choices: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct Usage { + pub prompt_tokens: usize, + pub completion_tokens: usize, + pub total_tokens: usize, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/src/agent.rs b/src/agent.rs index a8ebac3..ad5793c 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -63,6 +63,7 @@ pub struct Agent { pub struct ChatResponse { pub message: Message, pub stop_reason: Option, + pub usage: Option, } #[derive(Debug, Clone)] @@ -75,6 +76,7 @@ pub struct PromptReply { pub tool_use: ToolUse, pub response: String, pub thinking: Vec, + pub usage: Option, } enum StreamAction { @@ -86,6 +88,7 @@ struct StreamAccumulator { blocks: Vec, tool_inputs: HashMap, stop_reason: Option, + usage: Option, } impl StreamAccumulator { @@ -94,11 +97,18 @@ impl StreamAccumulator { blocks: Vec::new(), tool_inputs: HashMap::new(), stop_reason: None, + usage: None, } } - fn finish(self) -> (Vec, Option) { - (self.blocks, self.stop_reason) + fn finish( + self, + ) -> ( + Vec, + Option, + Option, + ) { + (self.blocks, self.stop_reason, self.usage) } fn handle_event( @@ -110,7 +120,12 @@ impl StreamAccumulator { F: FnMut(ChatStreamUpdate), { match event { - MessagesStreamEvent::MessageStart { .. } => Ok(StreamAction::Continue), + MessagesStreamEvent::MessageStart { message } => { + if let Some(usage) = message.usage { + self.update_usage(usage); + } + Ok(StreamAction::Continue) + } MessagesStreamEvent::ContentBlockStart { index, content_block, @@ -142,10 +157,13 @@ impl StreamAccumulator { self.finalize_tool_input(index)?; Ok(StreamAction::Continue) } - MessagesStreamEvent::MessageDelta { delta, .. } => { + MessagesStreamEvent::MessageDelta { delta, usage } => { if let Some(reason) = delta.stop_reason { self.stop_reason = Some(reason); } + if let Some(usage) = usage { + self.update_usage(usage); + } Ok(StreamAction::Continue) } MessagesStreamEvent::MessageStop => Ok(StreamAction::Stop), @@ -164,6 +182,22 @@ impl StreamAccumulator { } } + fn update_usage(&mut self, usage: crate::provider::UsageStats) { + match &mut self.usage { + Some(current) => current.merge(usage), + None => { + let mut current = usage; + if current.total_tokens.is_none() + && let (Some(input), Some(output)) = + (current.input_tokens, current.output_tokens) + { + current.total_tokens = Some(input + output); + } + self.usage = Some(current); + } + } + } + fn ensure_block_slot(&mut self, index: usize) { if self.blocks.len() <= index { self.blocks.resize_with(index + 1, || Block::Text { @@ -437,6 +471,7 @@ impl Agent { .whatever_context_display("failed to send messages")?; let stop_reason = response.stop_reason.clone(); + let usage = response.usage.clone(); let message = if response.content.is_empty() { Message::assistant(Content::Multiple(Vec::default())) } else { @@ -448,6 +483,7 @@ impl Agent { Ok(ChatResponse { message, stop_reason, + usage, }) } @@ -474,6 +510,7 @@ impl Agent { .whatever_context_display("failed to send messages")?; let stop_reason = response.stop_reason.clone(); + let usage = response.usage.clone(); let message = if response.content.is_empty() { Message::assistant(Content::Multiple(Vec::default())) } else { @@ -485,6 +522,7 @@ impl Agent { Ok(ChatResponse { message, stop_reason, + usage, }) } @@ -568,7 +606,7 @@ impl Agent { } } - let (blocks, stop_reason) = accumulator.finish(); + let (blocks, stop_reason, usage) = accumulator.finish(); let message = if blocks.is_empty() { Message::assistant(Content::Multiple(Vec::default())) } else { @@ -580,6 +618,7 @@ impl Agent { Ok(ChatResponse { message, stop_reason, + usage, }) } @@ -657,6 +696,7 @@ impl Agent { .whatever_context_display("failed to request prompt reply")? }; let stop_reason = response.stop_reason.clone(); + let usage = response.usage.clone(); if !response.content.is_empty() { let mut history = self.messages.lock().await; history.push(Message::assistant(Content::Multiple( @@ -698,6 +738,7 @@ impl Agent { tool_use, response, thinking, + usage, }); } } @@ -794,7 +835,7 @@ impl Agent { } } - let (blocks, stop_reason) = accumulator.finish(); + let (blocks, stop_reason, usage) = accumulator.finish(); if !blocks.is_empty() { let msg = Message::assistant(Content::Multiple(blocks.clone())); self.messages.lock().await.push(msg); @@ -834,6 +875,7 @@ impl Agent { tool_use, response, thinking, + usage, }); } } @@ -871,6 +913,10 @@ impl Agent { self.request_options_for_current_model().combo_reply_retries } + pub fn context_window(&self) -> Option { + self.request_options_for_current_model().context_window + } + pub fn current_model(&self) -> String { let selected_model = self.selected_model(); match Self::select_provider_index(selected_model.as_deref(), &self.config.providers) { @@ -1257,7 +1303,7 @@ mod tests { other => panic!("unexpected update: {other:?}"), } - let (blocks, stop_reason) = accumulator.finish(); + let (blocks, stop_reason, _) = accumulator.finish(); assert_eq!(stop_reason, Some(StopReason::EndTurn)); assert_eq!(blocks.len(), 2); match &blocks[0] { diff --git a/src/config.rs b/src/config.rs index a8b7cf6..5ddce00 100644 --- a/src/config.rs +++ b/src/config.rs @@ -106,7 +106,7 @@ impl Config { fn apply_model_presets(options: &mut RequestOptions, presets: &[ModelRequestConfig], model: &str) { for preset in presets { - if preset.model == model { + if preset.model.eq_ignore_ascii_case(model) { options.apply_override(preset); } } diff --git a/src/provider/anthropic.rs b/src/provider/anthropic.rs index 1b22ee1..1902f00 100644 --- a/src/provider/anthropic.rs +++ b/src/provider/anthropic.rs @@ -1,11 +1,9 @@ -use serde_json::Value; - use ::anthropic as anthropic_api; use crate::provider::types::{ Block, Content, ContentBlockDelta, Message, MessageDelta, MessagesResponse, MessagesStreamEvent, Role, StopReason, StreamErrorDetail, StreamUsage, Thinking, Tool, - ToolChoice, ToolUse, + ToolChoice, ToolUse, UsageStats, }; impl From for Role { @@ -301,10 +299,12 @@ impl From for anthropic_api::MessageDelta { impl From for MessagesResponse { fn from(value: anthropic_api::MessagesResponse) -> Self { + let usage = usage_stats_from_usage(&value.usage); Self { content: value.content.into_iter().map(Into::into).collect(), stop_reason: value.stop_reason.map(Into::into), stop_sequence: value.stop_sequence, + usage, } } } @@ -319,8 +319,12 @@ impl From for StreamErrorDetail { } } -fn convert_stream_usage(usage: anthropic_api::StreamUsage) -> StreamUsage { - serde_json::to_value(usage).unwrap_or(Value::Null) +fn usage_stats_from_usage(usage: &anthropic_api::Usage) -> Option { + Some(UsageStats { + input_tokens: Some(usage.input_tokens), + output_tokens: Some(usage.output_tokens), + total_tokens: Some(usage.input_tokens + usage.output_tokens), + }) } impl From for MessagesStreamEvent { @@ -362,3 +366,15 @@ impl From for MessagesStreamEvent { } } } + +fn convert_stream_usage(usage: anthropic_api::StreamUsage) -> StreamUsage { + let total_tokens = match (usage.input_tokens, usage.output_tokens) { + (Some(input_tokens), Some(output_tokens)) => Some(input_tokens + output_tokens), + _ => None, + }; + UsageStats { + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + total_tokens, + } +} diff --git a/src/provider/openai.rs b/src/provider/openai.rs index a2babcc..7ac3faa 100644 --- a/src/provider/openai.rs +++ b/src/provider/openai.rs @@ -14,6 +14,7 @@ use crate::RequestOptions; use crate::provider::types::{ Block, Content, ContentBlockDelta, Message, MessageDelta, MessagesResponse, MessagesStreamEvent, Role, StopReason, StreamErrorDetail, Thinking, Tool, ToolChoice, ToolUse, + UsageStats, }; struct ToolCallState { @@ -311,7 +312,9 @@ fn response_into_messages(response: openai_api::ChatCompletionResponse) -> Messa let mut content_blocks = Vec::new(); let mut stop_reason = None; let stop_sequence = None; - if let Some(choice) = response.choices.into_iter().next() { + let openai_api::ChatCompletionResponse { choices, usage } = response; + let usage = usage.map(usage_stats_from_openai); + if let Some(choice) = choices.into_iter().next() { stop_reason = choice.finish_reason.and_then(map_finish_reason); if let Some(reasoning_content) = choice.message.reasoning_content && !reasoning_content.is_empty() @@ -341,6 +344,15 @@ fn response_into_messages(response: openai_api::ChatCompletionResponse) -> Messa content: content_blocks, stop_reason, stop_sequence, + usage, + } +} + +fn usage_stats_from_openai(usage: openai_api::Usage) -> UsageStats { + UsageStats { + input_tokens: Some(usage.prompt_tokens), + output_tokens: Some(usage.completion_tokens), + total_tokens: Some(usage.total_tokens), } } @@ -378,6 +390,7 @@ impl OpenAIStream { content: Vec::new(), stop_reason: None, stop_sequence: None, + usage: None, }, }); Self { @@ -454,7 +467,17 @@ impl Stream for OpenAIStream { Poll::Pending => return Poll::Pending, Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))), Poll::Ready(Some(Ok(chunk))) => { - if let Some(choice) = chunk.choices.into_iter().next() { + let openai_api::ChatCompletionChunk { choices, usage } = chunk; + if let Some(usage) = usage { + this.push_event(MessagesStreamEvent::MessageDelta { + delta: MessageDelta { + stop_reason: None, + stop_sequence: None, + }, + usage: Some(usage_stats_from_openai(usage)), + }); + } + if let Some(choice) = choices.into_iter().next() { let delta = choice.delta; if let Some(reasoning_content) = delta.reasoning_content && !reasoning_content.is_empty() diff --git a/src/provider/types.rs b/src/provider/types.rs index edee5fb..5c3fa9a 100644 --- a/src/provider/types.rs +++ b/src/provider/types.rs @@ -1,6 +1,40 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub struct UsageStats { + #[serde(skip_serializing_if = "Option::is_none")] + pub input_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub output_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub total_tokens: Option, +} + +impl UsageStats { + pub fn merge(&mut self, other: UsageStats) { + let updated_input = other.input_tokens.is_some(); + let updated_output = other.output_tokens.is_some(); + let updated_total = other.total_tokens.is_some(); + if other.input_tokens.is_some() { + self.input_tokens = other.input_tokens; + } + if other.output_tokens.is_some() { + self.output_tokens = other.output_tokens; + } + if other.total_tokens.is_some() { + self.total_tokens = other.total_tokens; + } + if !updated_total + && (updated_input || updated_output || self.total_tokens.is_none()) + && let (Some(input), Some(output)) = (self.input_tokens, self.output_tokens) + { + self.total_tokens = Some(input + output); + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum Role { @@ -208,6 +242,8 @@ pub struct MessagesResponse { pub stop_reason: Option, #[serde(skip_serializing_if = "Option::is_none")] pub stop_sequence: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -247,7 +283,7 @@ pub struct StreamErrorDetail { pub code: Option, } -pub type StreamUsage = Value; +pub type StreamUsage = UsageStats; #[derive(Debug, Clone)] pub enum MessagesStreamEvent { @@ -279,3 +315,25 @@ pub enum MessagesStreamEvent { data: Value, }, } + +#[cfg(test)] +mod tests { + use super::UsageStats; + + #[test] + fn usage_merge_recomputes_total() { + let mut usage = UsageStats { + input_tokens: Some(10), + output_tokens: Some(2), + total_tokens: None, + }; + usage.merge(UsageStats { + input_tokens: None, + output_tokens: Some(5), + total_tokens: None, + }); + assert_eq!(usage.input_tokens, Some(10)); + assert_eq!(usage.output_tokens, Some(5)); + assert_eq!(usage.total_tokens, Some(15)); + } +}