Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 71 additions & 7 deletions crates/coco-tui/src/components/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use code_combo::tools::{BashInput, Final};
use code_combo::{
Agent, Block as ChatBlock, ChatResponse, ChatStreamUpdate, Config, Content as ChatContent,
Message as ChatMessage, Output, RuntimeOverrides, SessionEnv, Starter, StarterCommand,
StarterError, StarterEvent, StopReason, TextEdit, ToolUse, bash_unsafe_ranges,
StarterError, StarterEvent, StopReason, TextEdit, ToolUse, UsageStats, bash_unsafe_ranges,
discover_starters, load_runtime_overrides, parse_primary_command, save_runtime_overrides,
};
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
Expand Down Expand Up @@ -63,6 +63,7 @@ pub struct Chat<'a> {
shortcut_hints: ShortcutHintsPanel,
prev_focus: Option<Focus>,
combo_thinking_active: bool,
last_usage: Option<UsageStats>,

token_schedule_session_save: Option<CancellationToken>,
cancellation_guard: CancellationGuard,
Expand Down Expand Up @@ -225,6 +226,7 @@ impl Chat<'static> {
shortcut_hints: ShortcutHintsPanel::default(),
prev_focus: None,
combo_thinking_active: false,
last_usage: None,
token_schedule_session_save: None,
cancellation_guard: CancellationGuard::default(),
}
Expand Down Expand Up @@ -677,6 +679,9 @@ impl Chat<'static> {
if let Some(line) = self.ctrl_c_reminder_line() {
block = block.title_bottom(line);
}
if let Some(line) = self.context_usage_indicator() {
block = block.title_bottom(line);
}
block
}

Expand Down Expand Up @@ -744,6 +749,31 @@ impl Chat<'static> {
])
}

/// Builds the right-aligned "ctx" label summarising the last recorded token usage.
///
/// Returns `None` when no usage has been recorded yet, or when the recorded
/// usage carries neither a total nor an input/output count. The reported
/// total is preferred; otherwise the input and output counts are summed
/// (saturating). When the agent reports a context window, the label also
/// shows the window size and the rounded percentage used.
fn context_usage_indicator(&self) -> Option<Line<'static>> {
    let stats = self.last_usage.as_ref()?;
    let used_tokens = match (stats.total_tokens, stats.input_tokens, stats.output_tokens) {
        // An explicit total always wins over the breakdown.
        (Some(total), _, _) => total,
        // Nothing was reported at all: there is nothing to display.
        (None, None, None) => return None,
        // Only a (possibly partial) breakdown: a missing side counts as zero.
        (None, input, output) => input.unwrap_or(0).saturating_add(output.unwrap_or(0)),
    };
    let theme = global::theme();
    let label = match self.agent.context_window() {
        Some(window) => {
            // A zero-sized window would divide by zero; report 0% instead.
            let percent = if window == 0 {
                0
            } else {
                ((used_tokens as f64 / window as f64) * 100.0).round() as usize
            };
            format!(" ctx: {used_tokens}/{window} ({percent}%) ")
        }
        None => format!(" ctx: {used_tokens} tok "),
    };
    let span = Span::styled(label, theme.ui.shortcut_desc);
    Some(Line::from(span).alignment(Alignment::Right))
}

fn auto_accept_indicator(&self) -> Line<'static> {
let theme = global::theme();
let (status, status_style) = if self.state.auto_accept_edits {
Expand Down Expand Up @@ -903,6 +933,7 @@ impl Chat<'static> {
let model_override = model_override.cloned();
self.state.write().model_override = model_override.clone();
self.agent.set_model_override(model_override);
self.last_usage = None;
self.persist_runtime_overrides();
global::trigger_schedule_session_save();
}
Expand Down Expand Up @@ -1018,6 +1049,15 @@ impl Component for Chat<'static> {
self.messages
.append_stream_text(*index, *kind, text.clone());
}
Event::Answer(AnswerEvent::Usage { usage }) => {
match &mut self.last_usage {
Some(total) => add_usage(total, usage),
None => {
self.last_usage = Some(usage.clone());
}
}
global::signal_dirty();
}
Event::Answer(AnswerEvent::Cancelled) => {
self.messages.reset_stream();
self.set_ready();
Expand Down Expand Up @@ -1680,6 +1720,21 @@ fn is_safe_command(command: &str) -> bool {
!trimmed.is_empty() && bash_unsafe_ranges(command).is_empty()
}

/// Folds one usage report (`delta`) into the running aggregate (`total`).
///
/// * When `delta` carries an input and/or output count, those counts are
///   added to the matching counters in `total` (a missing side counts as
///   zero) and the report contributes `input + output` to the running total.
/// * Otherwise, when `delta` only carries `total_tokens`, that value is added
///   to the running total and the per-direction counters are left untouched.
/// * A `delta` with no counts at all leaves `total` unchanged.
///
/// The running total continues from `total.total_tokens` when it is set and
/// from the sum of the existing input/output counters otherwise, so tokens
/// that were already counted are never dropped when reports of different
/// shapes are mixed. Every addition saturates instead of overflowing.
fn add_usage(total: &mut UsageStats, delta: &UsageStats) {
    let has_breakdown = delta.input_tokens.is_some() || delta.output_tokens.is_some();
    let delta_input = delta.input_tokens.unwrap_or(0);
    let delta_output = delta.output_tokens.unwrap_or(0);

    // How many tokens this report adds to the running total.
    let delta_total = if has_breakdown {
        delta_input.saturating_add(delta_output)
    } else if let Some(reported) = delta.total_tokens {
        reported
    } else {
        // Nothing was reported; keep the aggregate as it is.
        return;
    };

    // Capture the base BEFORE touching the counters, so a previous total that
    // had no breakdown (or a breakdown that had no total) is carried forward.
    let previous_total = total.total_tokens.unwrap_or_else(|| {
        total
            .input_tokens
            .unwrap_or(0)
            .saturating_add(total.output_tokens.unwrap_or(0))
    });

    if has_breakdown {
        total.input_tokens = Some(total.input_tokens.unwrap_or(0).saturating_add(delta_input));
        total.output_tokens = Some(
            total
                .output_tokens
                .unwrap_or(0)
                .saturating_add(delta_output),
        );
    }
    total.total_tokens = Some(previous_total.saturating_add(delta_total));
}

/// Parse coco reply stdout output as JSON fields and validate required schemas.
fn parse_coco_reply_output(
output: &Final,
Expand Down Expand Up @@ -1826,6 +1881,9 @@ async fn handle_offload_combo_reply(
.map_err(|e| ComboReplyError::ChatFailed {
message: e.to_string(),
})?;
if let Some(usage) = chat_response.usage.clone() {
tx.send(AnswerEvent::Usage { usage }.into()).ok();
}

// Extract Bash tool_use from response
let blocks = match &chat_response.message.content {
Expand Down Expand Up @@ -2169,12 +2227,12 @@ async fn task_combo_execute(
let stream_name = name.clone();
let thinking_seen = Arc::new(AtomicBool::new(false));
let thinking_seen_stream = thinking_seen.clone();
let reply = agent
.reply_prompt_stream_with_thinking(
&system_prompt,
schemas,
thinking.clone(),
cancel_token.clone(),
let reply = agent
.reply_prompt_stream_with_thinking(
&system_prompt,
schemas,
thinking.clone(),
cancel_token.clone(),
move |update| {
let (index, kind, text) = match update {
ChatStreamUpdate::Plain { index, text } => {
Expand Down Expand Up @@ -2205,6 +2263,9 @@ async fn task_combo_execute(
reply
};
if let Ok(reply) = &reply {
if let Some(usage) = reply.usage.clone() {
tx.send(AnswerEvent::Usage { usage }.into()).ok();
}
let thinking = if streamed_thinking {
Vec::new()
} else {
Expand Down Expand Up @@ -2623,6 +2684,9 @@ async fn handle_chat_response(
streamed_thinking: bool,
) {
let tx = global::event_tx();
if let Some(usage) = chat_resp.usage.clone() {
tx.send(AnswerEvent::Usage { usage }.into()).ok();
}
let mut to_execute: Vec<code_combo::ToolUse> = vec![];
let mut bot_messages = match chat_resp.message.content {
ChatContent::Text(text) => {
Expand Down
5 changes: 4 additions & 1 deletion crates/coco-tui/src/events.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use code_combo::{
OutputChunk, Starter, TextEdit, ThinkingConfig, ToolUse,
OutputChunk, Starter, TextEdit, ThinkingConfig, ToolUse, UsageStats,
tools::{Final, SubagentEvent},
};
use crossterm::event::{KeyEvent, MouseEvent};
Expand Down Expand Up @@ -48,6 +48,9 @@ pub enum AnswerEvent {
text: String,
},
Cancelled,
Usage {
usage: UsageStats,
},
// Below events come from User
ToolOutput {
id: String,
Expand Down
8 changes: 7 additions & 1 deletion crates/openai/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ use reqwest::{StatusCode, Url, header::HeaderMap};
use snafu::{ResultExt, Whatever};
use tracing::trace;

use crate::{ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, ErrorResponse};
use crate::{
ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, ErrorResponse,
StreamOptions,
};

pub struct Client {
model: String,
Expand Down Expand Up @@ -110,6 +113,9 @@ impl Client {
) -> Result<ChatCompletionStream> {
request.model = self.model.clone();
request.stream = Some(true);
request.stream_options = Some(StreamOptions {
include_usage: Some(true),
});
let url = self
.base_url
.join(self.api_path("chat/completions"))
Expand Down
12 changes: 12 additions & 0 deletions crates/openai/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ pub struct FunctionChoice {
/// Body of a chat completion response: the returned choices plus optional token usage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
/// Choices carried by the response.
pub choices: Vec<ChatChoice>,
/// Token accounting for the request, when present; left out of the
/// serialized output when it is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand All @@ -142,6 +144,16 @@ pub struct ChatChoice {
/// One chunk of a chat completion stream: choice deltas plus optional token usage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionChunk {
/// Incremental choice updates carried by this chunk.
pub choices: Vec<ChatChoiceDelta>,
/// Token accounting, when this chunk carries it; left out of the serialized
/// output when it is `None`.
// NOTE(review): presumably only populated when the request enables
// `stream_options.include_usage` — confirm against the provider's API docs.
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
}

/// Token counts attached to a completion response or stream chunk.
///
/// All three counters are required when deserializing. The field names are
/// already snake_case, so the `rename_all` attribute leaves the serialized
/// names identical to the Rust field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Usage {
/// Tokens counted for the prompt (request side).
pub prompt_tokens: usize,
/// Tokens counted for the completion (response side).
pub completion_tokens: usize,
/// Total token count as reported by the API.
// NOTE(review): presumably prompt_tokens + completion_tokens — confirm per provider.
pub total_tokens: usize,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down
Loading