Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src-tauri/src/services/proxy/headers/claude_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ impl RequestProcessor for ClaudeHeadersProcessor {
response_status: u16,
response_body: &[u8],
is_sse: bool,
_response_time_ms: Option<i64>,
response_time_ms: Option<i64>,
) -> Result<()> {
use crate::services::proxy::log_recorder::{
LogRecorder, RequestLogContext, ResponseParser,
Expand All @@ -178,6 +178,7 @@ impl RequestProcessor for ClaudeHeadersProcessor {
client_ip,
proxy_pricing_template_id,
request_body,
response_time_ms,
);

// 2. 解析响应
Expand Down
10 changes: 3 additions & 7 deletions src-tauri/src/services/proxy/log_recorder/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

use crate::services::session::manager::SESSION_MANAGER;
use crate::services::session::models::ProxySession;
use std::time::Instant;

/// 请求日志上下文(在请求处理早期提取)
#[derive(Debug, Clone)]
Expand All @@ -17,7 +16,7 @@ pub struct RequestLogContext {
pub model: Option<String>, // 从 request_body 提取
pub is_stream: bool, // 从 request_body 提取 stream 字段
pub request_body: Vec<u8>, // 保留原始请求体
pub start_time: Instant,
pub response_time_ms: Option<i64>, // 响应时间(毫秒)
}

impl RequestLogContext {
Expand All @@ -28,6 +27,7 @@ impl RequestLogContext {
client_ip: &str,
proxy_pricing_template_id: Option<&str>,
request_body: &[u8],
response_time_ms: Option<i64>,
) -> Self {
// 提取 user_id(完整)、display_id(用于日志)、model 和 stream(仅解析一次)
let (user_id, session_id, model, is_stream) = if !request_body.is_empty() {
Expand Down Expand Up @@ -70,7 +70,7 @@ impl RequestLogContext {
model,
is_stream,
request_body: request_body.to_vec(),
start_time: Instant::now(),
response_time_ms,
}
}

Expand Down Expand Up @@ -107,8 +107,4 @@ impl RequestLogContext {
proxy_template_id.map(|s| s.to_string()),
)
}

/// Returns the time elapsed since `start_time`, in whole milliseconds.
///
/// The value is measured at call time, so successive calls on the same
/// context return increasing values.
pub fn elapsed_ms(&self) -> i64 {
    let since_start = self.start_time.elapsed();
    // `as_millis` yields a `u128`; narrow it to the `i64` this method returns.
    let millis = since_start.as_millis();
    millis as i64
}
}
14 changes: 7 additions & 7 deletions src-tauri/src/services/proxy/log_recorder/recorder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl LogRecorder {
&context.client_ip,
&context.request_body,
ResponseData::Sse(data_lines),
Some(context.elapsed_ms()),
context.response_time_ms,
context.pricing_template_id.clone(),
)
.await
Expand Down Expand Up @@ -98,7 +98,7 @@ impl LogRecorder {
"parse_error",
&error_detail,
"sse",
Some(context.elapsed_ms()),
context.response_time_ms,
)
.await
}
Expand All @@ -120,7 +120,7 @@ impl LogRecorder {
&context.client_ip,
&context.request_body,
ResponseData::Json(data),
Some(context.elapsed_ms()),
context.response_time_ms,
context.pricing_template_id.clone(),
)
.await
Expand Down Expand Up @@ -153,7 +153,7 @@ impl LogRecorder {
"parse_error",
&error_detail,
"json",
Some(context.elapsed_ms()),
context.response_time_ms,
)
.await
}
Expand Down Expand Up @@ -186,7 +186,7 @@ impl LogRecorder {
"parse_error",
&error_detail,
response_type,
Some(context.elapsed_ms()),
context.response_time_ms,
)
.await
}
Expand All @@ -213,7 +213,7 @@ impl LogRecorder {
"upstream_error",
detail,
response_type, // 根据请求体的 stream 字段判断
Some(context.elapsed_ms()),
context.response_time_ms,
)
.await
}
Expand Down Expand Up @@ -250,7 +250,7 @@ impl LogRecorder {
"upstream_error",
&error_detail,
response_type, // 根据请求体的 stream 字段判断
Some(context.elapsed_ms()),
context.response_time_ms,
)
.await
}
Expand Down
90 changes: 67 additions & 23 deletions src-tauri/src/services/proxy/proxy_instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,47 +482,86 @@ async fn handle_request_inner(
let sse_chunks = Arc::new(Mutex::new(Vec::new()));
let sse_chunks_clone = Arc::clone(&sse_chunks);

// 创建一个通道,在流完全消费后触发统计
let (stream_end_tx, stream_end_rx) = tokio::sync::oneshot::channel::<()>();

let stream = upstream_res.bytes_stream();

// amp-code 需要移除工具名前缀
let is_amp_code = tool_id == "amp-code";
let prefix_regex = Regex::new(r#""name"\s*:\s*"mcp_([^"]+)""#).ok();

// 拦截流数据并收集
let mapped_stream = stream.map(move |result| {
if let Ok(chunk) = &result {
if let Ok(mut chunks) = sse_chunks_clone.lock() {
chunks.push(chunk.clone());
let mapped_stream = stream
.map(move |result| {
match &result {
Ok(chunk) => {
if let Ok(mut chunks) = sse_chunks_clone.lock() {
chunks.push(chunk.clone());
}
}
Err(_) => {
// 流错误 - 注意: stream_completed_clone 已被移除,不再需要通知
}
}
}
result
.map(|bytes| {
if is_amp_code {
if let Some(ref re) = prefix_regex {
let text = String::from_utf8_lossy(&bytes);
let cleaned = re.replace_all(&text, r#""name": "$1""#);
Frame::data(Bytes::from(cleaned.into_owned()))

result
.map(|bytes| {
if is_amp_code {
if let Some(ref re) = prefix_regex {
let text = String::from_utf8_lossy(&bytes);
let cleaned = re.replace_all(&text, r#""name": "$1""#);
Frame::data(Bytes::from(cleaned.into_owned()))
} else {
Frame::data(bytes)
}
} else {
Frame::data(bytes)
}
} else {
Frame::data(bytes)
}
})
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
});
})
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
})
// 在流的最后一个元素之后插入完成信号
.chain(futures_util::stream::once(async move {
// 发送流完成信号
let _ = stream_end_tx.send(());
tracing::debug!("SSE 流已完全消费完毕,发送完成信号");
// 返回一个永远不会被使用的 Err,会被下游过滤掉
Err(Box::new(std::io::Error::other("__stream_end_marker__"))
as Box<dyn std::error::Error + Send + Sync>)
}))
// 过滤掉结束标记
.filter(|item| {
let is_end_marker = item
.as_ref()
.err()
.and_then(|e| e.downcast_ref::<std::io::Error>())
.map(|e| e.to_string().contains("__stream_end_marker__"))
.unwrap_or(false);
futures_util::future::ready(!is_end_marker)
});

// 在流结束后异步记录日志
// 在流真正结束后异步记录日志
let processor_clone = Arc::clone(&processor);
let client_ip_clone = client_ip.clone();
let request_body_clone = processed.body.clone();
let response_status = status.as_u16();
let start_time_clone = start_time; // 捕获 start_time 用于计算响应时间
let start_time_clone = start_time;
let proxy_pricing_template_id_clone = proxy_pricing_template_id.clone();

tokio::spawn(async move {
// 等待流结束(延迟确保所有 chunks 已收集)
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
// 等待流完全消费的信号(无超时,真正等待流结束)
match stream_end_rx.await {
Ok(_) => {
tracing::info!("✓ 收到 SSE 流完成信号,流已完全消费");
}
Err(_) => {
tracing::warn!("✗ 未收到 SSE 流完成信号(sender 被 drop),可能流被提前取消");
}
}

// 小延迟确保最后的 chunk 写入完成(异步锁竞争)
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;

let chunks = match sse_chunks.lock() {
Ok(guard) => guard.clone(),
Expand All @@ -532,13 +571,18 @@ async fn handle_request_inner(
}
};

tracing::info!(
chunks_count = chunks.len(),
"开始处理 SSE chunks 进行 token 统计"
);

// 将所有 chunk 合并为完整响应体
let mut full_data = Vec::new();
for chunk in &chunks {
full_data.extend_from_slice(chunk);
}

// 计算响应时间
// 计算响应时间(从请求开始到流完全消费的总时间)
let response_time_ms = start_time_clone.elapsed().as_millis() as i64;

// 调用工具特定的日志记录
Expand Down
94 changes: 53 additions & 41 deletions src-tauri/src/services/token_stats/extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ impl TokenExtractor for ClaudeTokenExtractor {

let event_type = json.get("type").and_then(|v| v.as_str()).unwrap_or("");

tracing::debug!(event_type = event_type, "解析 SSE 事件");

let mut result = SseTokenData::default();

match event_type {
Expand Down Expand Up @@ -193,47 +195,57 @@ impl TokenExtractor for ClaudeTokenExtractor {
}
}
"message_delta" => {
// 检查是否有 stop_reason(任何值都接受:end_turn, tool_use, max_tokens 等)
if let Some(delta) = json.get("delta") {
if delta.get("stop_reason").and_then(|v| v.as_str()).is_some() {
if let Some(usage) = json.get("usage") {
// 提取缓存创建 token:优先读取扁平字段,回退到嵌套对象
let cache_creation = usage
.get("cache_creation_input_tokens")
.and_then(|v| v.as_i64())
.unwrap_or_else(|| {
if let Some(cache_obj) = usage.get("cache_creation") {
let ephemeral_5m = cache_obj
.get("ephemeral_5m_input_tokens")
.and_then(|v| v.as_i64())
.unwrap_or(0);
let ephemeral_1h = cache_obj
.get("ephemeral_1h_input_tokens")
.and_then(|v| v.as_i64())
.unwrap_or(0);
ephemeral_5m + ephemeral_1h
} else {
0
}
});

let cache_read = usage
.get("cache_read_input_tokens")
.and_then(|v| v.as_i64())
.unwrap_or(0);

let output_tokens = usage
.get("output_tokens")
.and_then(|v| v.as_i64())
.unwrap_or(0);

result.message_delta = Some(MessageDeltaData {
cache_creation_tokens: cache_creation,
cache_read_tokens: cache_read,
output_tokens,
});
}
}
tracing::info!("检测到 message_delta 事件");

// message_delta 事件包含最终的usage统计
// 条件:必须有 usage 字段(无论是否有 stop_reason)
if let Some(usage) = json.get("usage") {
tracing::info!("message_delta 包含 usage 字段");

// 提取缓存创建 token:优先读取扁平字段,回退到嵌套对象
let cache_creation = usage
.get("cache_creation_input_tokens")
.and_then(|v| v.as_i64())
.unwrap_or_else(|| {
if let Some(cache_obj) = usage.get("cache_creation") {
let ephemeral_5m = cache_obj
.get("ephemeral_5m_input_tokens")
.and_then(|v| v.as_i64())
.unwrap_or(0);
let ephemeral_1h = cache_obj
.get("ephemeral_1h_input_tokens")
.and_then(|v| v.as_i64())
.unwrap_or(0);
ephemeral_5m + ephemeral_1h
} else {
0
}
});

let cache_read = usage
.get("cache_read_input_tokens")
.and_then(|v| v.as_i64())
.unwrap_or(0);

let output_tokens = usage
.get("output_tokens")
.and_then(|v| v.as_i64())
.unwrap_or(0);

tracing::info!(
output_tokens = output_tokens,
cache_creation = cache_creation,
cache_read = cache_read,
"message_delta 提取成功"
);

result.message_delta = Some(MessageDeltaData {
cache_creation_tokens: cache_creation,
cache_read_tokens: cache_read,
output_tokens,
});
} else {
tracing::warn!("message_delta 事件缺少 usage 字段");
}
}
_ => {}
Expand Down
Loading