diff --git a/package-lock.json b/package-lock.json index b84cf99..2eb3c86 100644 --- a/package-lock.json +++ b/package-lock.json @@ -15,9 +15,11 @@ "@radix-ui/react-alert-dialog": "^1.1.15", "@radix-ui/react-avatar": "^1.1.11", "@radix-ui/react-checkbox": "^1.3.3", + "@radix-ui/react-collapsible": "^1.1.12", "@radix-ui/react-dialog": "^1.1.15", "@radix-ui/react-dropdown-menu": "^2.1.16", "@radix-ui/react-label": "^2.1.8", + "@radix-ui/react-popover": "^1.1.15", "@radix-ui/react-progress": "^1.1.8", "@radix-ui/react-radio-group": "^1.3.8", "@radix-ui/react-scroll-area": "^1.2.10", @@ -37,8 +39,9 @@ "ipaddr.js": "^2.3.0", "lucide-react": "^0.552.0", "react": "^19.2.1", + "react-day-picker": "^9.13.0", "react-dom": "^19.2.1", - "recharts": "^3.3.0", + "recharts": "^3.6.0", "tailwind-merge": "^3.3.1" }, "devDependencies": { @@ -47,6 +50,7 @@ "@types/node": "^20.19.25", "@types/react": "^19.2.2", "@types/react-dom": "^19.2.2", + "@types/recharts": "^1.8.29", "@vitejs/plugin-react": "^5.1.0", "autoprefixer": "^10.4.16", "concurrently": "^9.2.1", @@ -362,6 +366,12 @@ "node": ">=6.9.0" } }, + "node_modules/@date-fns/tz": { + "version": "1.4.1", + "resolved": "https://registry.npmmirror.com/@date-fns/tz/-/tz-1.4.1.tgz", + "integrity": "sha512-P5LUNhtbj6YfI3iJjw5EL9eUAG6OitD0W3fWQcpQjDRc/QIsL0tRNuO1PcDvPccWL1fSTXXdE1ds+l95DV/OFA==", + "license": "MIT" + }, "node_modules/@dnd-kit/accessibility": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/@dnd-kit/accessibility/-/accessibility-3.1.1.tgz", @@ -1432,6 +1442,36 @@ } } }, + "node_modules/@radix-ui/react-collapsible": { + "version": "1.1.12", + "resolved": "https://registry.npmmirror.com/@radix-ui/react-collapsible/-/react-collapsible-1.1.12.tgz", + "integrity": "sha512-Uu+mSh4agx2ib1uIGPP4/CKNULyajb3p92LsVXmH2EHVMTfZWpll88XJ0j4W0z3f8NK1eYl1+Mf/szHPmcHzyA==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.3", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-id": "1.1.1", + "@radix-ui/react-presence": "1.1.5", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-controllable-state": "1.2.2", + "@radix-ui/react-use-layout-effect": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-collection": { "version": "1.1.7", "resolved": "https://registry.npmjs.org/@radix-ui/react-collection/-/react-collection-1.1.7.tgz", @@ -1793,6 +1833,61 @@ } } }, + "node_modules/@radix-ui/react-popover": { + "version": "1.1.15", + "resolved": "https://registry.npmmirror.com/@radix-ui/react-popover/-/react-popover-1.1.15.tgz", + "integrity": "sha512-kr0X2+6Yy/vJzLYJUPCZEc8SfQcf+1COFoAqauJm74umQhta9M7lNJHP7QQS3vkvcGLQUbWpMzwrXYwrYztHKA==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.3", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-dismissable-layer": "1.1.11", + "@radix-ui/react-focus-guards": "1.1.3", + "@radix-ui/react-focus-scope": "1.1.7", + "@radix-ui/react-id": "1.1.1", + "@radix-ui/react-popper": "1.2.8", + "@radix-ui/react-portal": "1.1.9", + "@radix-ui/react-presence": "1.1.5", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3", + "@radix-ui/react-use-controllable-state": "1.2.2", + "aria-hidden": "^1.2.4", + "react-remove-scroll": "^2.6.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-popover/node_modules/@radix-ui/react-slot": { + "version": "1.2.3", + "resolved": "https://registry.npmmirror.com/@radix-ui/react-slot/-/react-slot-1.2.3.tgz", + "integrity": "sha512-aeNmHnBxbi2St0au6VBVC7JXFlhLlOnvIIlePNniyUNAClzmtAUEY8/pBiK3iHjufOlwA+c20/8jngo7xcrg8A==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.2" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-popper": { "version": "1.2.8", "resolved": "https://registry.npmjs.org/@radix-ui/react-popper/-/react-popper-1.2.8.tgz", @@ -3273,6 +3368,34 @@ "@types/react": "^19.2.0" } }, + "node_modules/@types/recharts": { + "version": "1.8.29", + "resolved": "https://registry.npmmirror.com/@types/recharts/-/recharts-1.8.29.tgz", + "integrity": "sha512-ulKklaVsnFIIhTQsQw226TnOibrddW1qUQNFVhoQEyY1Z7FRQrNecFCGt7msRuJseudzE9czVawZb17dK/aPXw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/d3-shape": "^1", + "@types/react": "*" + } + }, + "node_modules/@types/recharts/node_modules/@types/d3-path": { + "version": "1.0.11", + "resolved": "https://registry.npmmirror.com/@types/d3-path/-/d3-path-1.0.11.tgz", + "integrity": "sha512-4pQMp8ldf7UaB/gR8Fvvy69psNHkTpD/pVw3vmEi8iZAB9EPMBruB1JvHO4BIq9QkUUd2lV1F5YXpMNj7JPBpw==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/recharts/node_modules/@types/d3-shape": { + "version": "1.3.12", + "resolved": "https://registry.npmmirror.com/@types/d3-shape/-/d3-shape-1.3.12.tgz", + "integrity": "sha512-8oMzcd4+poSLGgV0R1Q1rOlx/xdmozS4Xab7np0eamFFUYq71AU9pOCJEFnkXW2aI/oXdVYJzw6pssbSut7Z9Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/d3-path": "^1" + } + }, "node_modules/@types/use-sync-external-store": { "version": "0.0.6", "resolved": "https://registry.npmjs.org/@types/use-sync-external-store/-/use-sync-external-store-0.0.6.tgz", @@ -4529,6 +4652,12 @@ "url": "https://github.com/sponsors/kossnocorp" } }, + "node_modules/date-fns-jalali": { + "version": "4.1.0-0", + "resolved": "https://registry.npmmirror.com/date-fns-jalali/-/date-fns-jalali-4.1.0-0.tgz", + "integrity": "sha512-hTIP/z+t+qKwBDcmmsnmjWTduxCg+5KfdqWQvb2X/8C9+knYY6epN/pfxdDuyVlSVeFz0sM5eEfwIUQ70U4ckg==", + "license": "MIT" + }, "node_modules/debug": { "version": "4.4.3", "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.3.tgz", @@ -7102,6 +7231,27 @@ "node": ">=0.10.0" } }, + "node_modules/react-day-picker": { + "version": "9.13.0", + "resolved": "https://registry.npmmirror.com/react-day-picker/-/react-day-picker-9.13.0.tgz", + "integrity": "sha512-euzj5Hlq+lOHqI53NiuNhCP8HWgsPf/bBAVijR50hNaY1XwjKjShAnIe8jm8RD2W9IJUvihDIZ+KrmqfFzNhFQ==", + "license": "MIT", + "dependencies": { + "@date-fns/tz": "^1.4.1", + "date-fns": "^4.1.0", + "date-fns-jalali": "^4.1.0-0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "type": "individual", + "url": "https://github.com/sponsors/gpbl" + }, + "peerDependencies": { + "react": ">=16.8.0" + } + }, "node_modules/react-dom": { "version": "19.2.1", "resolved": "https://registry.npmmirror.com/react-dom/-/react-dom-19.2.1.tgz", @@ -7247,10 +7397,13 @@ } }, "node_modules/recharts": { - "version": "3.3.0", - "resolved": "https://registry.npmjs.org/recharts/-/recharts-3.3.0.tgz", - "integrity": "sha512-Vi0qmTB0iz1+/Cz9o5B7irVyUjX2ynvEgImbgMt/3sKRREcUM07QiYjS1QpAVrkmVlXqy5gykq4nGWMz9AS4Rg==", + "version": "3.6.0", + "resolved": "https://registry.npmmirror.com/recharts/-/recharts-3.6.0.tgz", + "integrity": "sha512-L5bjxvQRAe26RlToBAziKUB7whaGKEwD3znoM6fz3DrTowCIC/FnJYnuq1GEzB8Zv2kdTfaxQfi5GoH0tBinyg==", "license": "MIT", + "workspaces": [ + "www" + ], "dependencies": { "@reduxjs/toolkit": "1.x.x || 2.x.x", "clsx": "^2.1.1", diff --git a/package.json b/package.json index f4fd39a..03498fd 100644 --- a/package.json +++ b/package.json @@ -48,9 +48,11 @@ "@radix-ui/react-alert-dialog": "^1.1.15", "@radix-ui/react-avatar": "^1.1.11", "@radix-ui/react-checkbox": "^1.3.3", + "@radix-ui/react-collapsible": "^1.1.12", "@radix-ui/react-dialog": "^1.1.15", "@radix-ui/react-dropdown-menu": "^2.1.16", "@radix-ui/react-label": "^2.1.8", + "@radix-ui/react-popover": "^1.1.15", "@radix-ui/react-progress": "^1.1.8", "@radix-ui/react-radio-group": "^1.3.8", "@radix-ui/react-scroll-area": "^1.2.10", @@ -70,8 +72,9 @@ "ipaddr.js": "^2.3.0", "lucide-react": "^0.552.0", "react": "^19.2.1", + "react-day-picker": "^9.13.0", "react-dom": "^19.2.1", - "recharts": "^3.3.0", + "recharts": "^3.6.0", "tailwind-merge": "^3.3.1" }, "devDependencies": { @@ -80,6 +83,7 @@ "@types/node": "^20.19.25", "@types/react": "^19.2.2", "@types/react-dom": "^19.2.2", + "@types/recharts": "^1.8.29", "@vitejs/plugin-react": "^5.1.0", "autoprefixer": "^10.4.16", "concurrently": "^9.2.1", diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index 7196c80..8ae1afa 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -891,6 +891,7 @@ dependencies = [ "chrono", "cocoa", "dirs", + "flate2", "fs2", "futures-util", "http-body-util", @@ -927,6 +928,7 @@ dependencies = [ "tracing-subscriber", "url", "urlencoding", + "uuid", "winreg 0.52.0", ] diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 2996be3..e3f9e03 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -50,6 +50,7 @@ pin-project-lite = "0.2" bytes = "1" futures-util = "0.3" async-trait = "0.1" +flate2 = "1.0" # gzip 解压缩支持 # 文件锁 fs2 = "0.4" # 数据库 @@ -61,6 +62,7 @@ notify = "6" linked-hash-map = "0.5" # 序列化/反序列化 bincode = "1.3" +uuid = { version = "1.18.1", features = ["v4"] } [dev-dependencies] tempfile = "3.8" diff --git a/src-tauri/src/commands/analytics_commands.rs b/src-tauri/src/commands/analytics_commands.rs new file mode 100644 index 0000000..c2d3c96 --- /dev/null +++ b/src-tauri/src/commands/analytics_commands.rs @@ -0,0 +1,391 @@ +//! Token统计分析相关的Tauri命令 + +use anyhow::Result; +use duckcoding::services::token_stats::{ + CostGroupBy, CostSummaryQuery, TimeGranularity, TokenStatsAnalytics, TrendDataPoint, TrendQuery, +}; +use duckcoding::utils::config_dir; +use serde::{Deserialize, Serialize}; + +/// 按模型分组的成本统计 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelCostStat { + /// 模型名称 + pub model: String, + /// 总成本(USD) + pub total_cost: f64, + /// 请求数 + pub request_count: i64, +} + +/// 按配置分组的成本统计 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConfigCostStat { + /// 配置名称 + pub config_name: String, + /// 总成本(USD) + pub total_cost: f64, + /// 请求数 + pub request_count: i64, +} + +/// 成本汇总数据(前端期望的格式) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CostSummary { + /// 总成本(USD) + pub total_cost: f64, + /// 总请求数 + pub total_requests: i64, + /// 成功请求数 + pub successful_requests: i64, + /// 失败请求数 + pub failed_requests: i64, + /// 平均响应时间(毫秒) + pub avg_response_time: Option, + /// 按模型分组的成本 + pub cost_by_model: Vec, + /// 按配置分组的成本 + pub cost_by_config: Vec, + /// 按天的成本趋势 + pub daily_costs: Vec, +} + +/// 按天的成本统计 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DailyCost { + /// 日期(时间戳毫秒) + pub date: i64, + /// 总成本(USD) + pub cost: f64, +} + +/// 查询趋势数据 +/// +/// # 参数 +/// - `query`: 趋势查询参数 +/// +/// # 返回 +/// - `Ok(Vec)`: 按时间排序的趋势数据点列表 +/// - `Err`: 查询失败 +#[tauri::command] +pub async fn query_token_trends(query: TrendQuery) -> Result, String> { + let db_path = config_dir() + .map_err(|e| format!("Failed to get config dir: {}", e))? + .join("token_stats.db"); + + let analytics = TokenStatsAnalytics::new(db_path.clone()); + + analytics + .query_trends(&query) + .map_err(|e| format!("Failed to query trends: {}", e)) +} + +/// 查询成本汇总数据 +/// +/// # 参数 +/// - `start_time`: 开始时间戳(毫秒) +/// - `end_time`: 结束时间戳(毫秒) +/// - `tool_type`: 工具类型过滤(可选) +/// - `session_id`: 会话 ID 过滤(可选) +/// +/// # 返回 +/// - `Ok(CostSummary)`: 成本汇总数据 +/// - `Err`: 查询失败 +#[tauri::command] +pub async fn query_cost_summary( + start_time: i64, + end_time: i64, + tool_type: Option, + session_id: Option, +) -> Result { + let db_path = config_dir() + .map_err(|e| format!("Failed to get config dir: {}", e))? + .join("token_stats.db"); + + let analytics = TokenStatsAnalytics::new(db_path.clone()); + + // 构建基础查询参数 + let base_query = CostSummaryQuery { + start_time: Some(start_time), + end_time: Some(end_time), + tool_type: tool_type.clone(), + session_id: session_id.clone(), + group_by: CostGroupBy::Model, // 默认分组,实际查询时会覆盖 + }; + + // 1. 查询按模型分组的成本 + let model_query = CostSummaryQuery { + group_by: CostGroupBy::Model, + ..base_query.clone() + }; + let model_summaries = analytics + .query_cost_summary(&model_query) + .map_err(|e| format!("Failed to query cost by model: {}", e))?; + + // 2. 查询按配置分组的成本 + let config_query = CostSummaryQuery { + group_by: CostGroupBy::Config, + ..base_query.clone() + }; + let config_summaries = analytics + .query_cost_summary(&config_query) + .map_err(|e| format!("Failed to query cost by config: {}", e))?; + + // 3. 查询按天的成本趋势 + let trend_query = TrendQuery { + start_time: Some(start_time), + end_time: Some(end_time), + tool_type: tool_type.clone(), + session_id: session_id.clone(), + granularity: TimeGranularity::Day, + ..Default::default() + }; + let daily_trends = analytics + .query_trends(&trend_query) + .map_err(|e| format!("Failed to query daily trends: {}", e))?; + + // 4. 计算总计指标(通过聚合所有数据) + let total_cost: f64 = model_summaries.iter().map(|s| s.total_cost).sum(); + let total_requests: i64 = model_summaries.iter().map(|s| s.request_count).sum(); + + // 5. 查询成功和失败请求数(需要额外查询) + use duckcoding::data::DataManager; + let manager = DataManager::global() + .sqlite(&db_path) + .map_err(|e| format!("Failed to get SQLite manager: {}", e))?; + + // 构建 WHERE 子句 + let mut where_clauses = vec!["timestamp >= ?1", "timestamp <= ?2"]; + let mut params: Vec> = vec![ + Box::new(start_time) as Box, + Box::new(end_time), + ]; + + if let Some(ref tt) = tool_type { + where_clauses.push("tool_type = ?"); + params.push(Box::new(tt.clone())); + } + + if let Some(ref sid) = session_id { + where_clauses.push("session_id = ?"); + params.push(Box::new(sid.clone())); + } + + let where_clause = where_clauses.join(" AND "); + + let sql = format!( + "SELECT + COUNT(*) as total, + COALESCE(SUM(CASE WHEN request_status = 'success' THEN 1 ELSE 0 END), 0) as successful, + COALESCE(SUM(CASE WHEN request_status = 'failed' THEN 1 ELSE 0 END), 0) as failed, + AVG(response_time_ms) as avg_response_time + FROM token_logs + WHERE {}", + where_clause + ); + + let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter().map(|p| p.as_ref()).collect(); + + let (successful_requests, failed_requests, avg_response_time) = manager + .transaction(|tx| { + let mut stmt = tx + .prepare(&sql) + .map_err(duckcoding::data::DataError::Database)?; + + let result = stmt + .query_row(param_refs.as_slice(), |row| { + Ok(( + row.get::<_, i64>(1)?, + row.get::<_, i64>(2)?, + row.get::<_, Option>(3)?, + )) + }) + .map_err(duckcoding::data::DataError::Database)?; + + Ok(result) + }) + .map_err(|e| format!("Failed to query request stats: {}", e))?; + + // 6. 构建返回结果 + Ok(CostSummary { + total_cost, + total_requests, + successful_requests, + failed_requests, + avg_response_time, + cost_by_model: model_summaries + .into_iter() + .map(|s| ModelCostStat { + model: s.group_name, + total_cost: s.total_cost, + request_count: s.request_count, + }) + .collect(), + cost_by_config: config_summaries + .into_iter() + .map(|s| ConfigCostStat { + config_name: s.group_name, + total_cost: s.total_cost, + request_count: s.request_count, + }) + .collect(), + daily_costs: daily_trends + .into_iter() + .map(|d| DailyCost { + date: d.timestamp, + cost: d.total_cost, + }) + .collect(), + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use chrono::TimeZone; + use duckcoding::models::token_stats::TokenLog; + use duckcoding::services::token_stats::db::TokenStatsDb; + use tempfile::tempdir; + + #[tokio::test] + async fn test_query_token_trends_command() { + // 创建临时数据库 + let dir = tempdir().unwrap(); + let db_path = dir.path().join("test_trends.db"); + let db = TokenStatsDb::new(db_path.clone()); + db.init_table().unwrap(); + + // 插入测试数据(使用固定时间避免跨日期边界) + let base_time = chrono::Utc + .with_ymd_and_hms(2026, 1, 10, 12, 0, 0) + .unwrap() + .timestamp_millis(); + + for i in 0..10 { + let log = TokenLog::new( + "claude_code".to_string(), + base_time - (i * 3600 * 1000), // 每小时一条 + "127.0.0.1".to_string(), + "test_session".to_string(), + "default".to_string(), + "claude-sonnet-4-5-20250929".to_string(), + Some(format!("msg_{}", i)), + 100, + 50, + 10, + 20, + "success".to_string(), + "json".to_string(), + None, + None, + Some(100), + Some(0.001), + Some(0.002), + Some(0.0001), + Some(0.0002), + 0.0033, + Some("test_template".to_string()), + ); + db.insert_log(&log).unwrap(); + } + + // 创建查询 + let query = TrendQuery { + tool_type: Some("claude_code".to_string()), + granularity: TimeGranularity::Hour, + ..Default::default() + }; + + // 执行查询(通过直接调用analytics而不是tauri命令) + let analytics = TokenStatsAnalytics::new(db_path); + let trends = analytics.query_trends(&query).unwrap(); + + // 验证结果 + assert_eq!(trends.len(), 10); + assert!(trends[0].input_tokens > 0); + assert!(trends[0].total_cost > 0.0); + } + + #[tokio::test] + async fn test_query_cost_summary_command() { + // 创建临时数据库 + let dir = tempdir().unwrap(); + let db_path = dir.path().join("test_cost_summary.db"); + let db = TokenStatsDb::new(db_path.clone()); + db.init_table().unwrap(); + + // 插入测试数据(多个模型和配置,使用固定时间) + let base_time = chrono::Utc + .with_ymd_and_hms(2026, 1, 10, 12, 0, 0) + .unwrap() + .timestamp_millis(); + + let models = ["claude-sonnet-4-5-20250929", "claude-3-opus-20240229"]; + let configs = ["default", "custom"]; + + for (i, model) in models.iter().enumerate() { + for (j, config) in configs.iter().enumerate() { + for k in 0..3 { + let log = TokenLog::new( + "claude_code".to_string(), + base_time - (k * 1000), + "127.0.0.1".to_string(), + format!("session_{}_{}", i, j), + config.to_string(), + model.to_string(), + Some(format!("msg_{}_{}_{}", i, j, k)), + 100, + 50, + 10, + 20, + "success".to_string(), + "json".to_string(), + None, + None, + Some(100), + Some(0.001), + Some(0.002), + Some(0.0001), + Some(0.0002), + 0.0033, + Some("test_template".to_string()), + ); + db.insert_log(&log).unwrap(); + } + } + } + + // 执行查询 + let analytics = TokenStatsAnalytics::new(db_path); + + // 按模型分组 + let model_query = CostSummaryQuery { + tool_type: Some("claude_code".to_string()), + group_by: CostGroupBy::Model, + ..Default::default() + }; + let model_summaries = analytics.query_cost_summary(&model_query).unwrap(); + + // 验证结果 + assert_eq!(model_summaries.len(), 2); // 2个模型 + for summary in &model_summaries { + assert_eq!(summary.request_count, 6); // 每个模型6条记录(2个配置 × 3条) + assert!(summary.total_cost > 0.0); + } + + // 按配置分组 + let config_query = CostSummaryQuery { + tool_type: Some("claude_code".to_string()), + group_by: CostGroupBy::Config, + ..Default::default() + }; + let config_summaries = analytics.query_cost_summary(&config_query).unwrap(); + + // 验证结果 + assert_eq!(config_summaries.len(), 2); // 2个配置 + for summary in &config_summaries { + assert_eq!(summary.request_count, 6); // 每个配置6条记录(2个模型 × 3条) + assert!(summary.total_cost > 0.0); + } + } +} diff --git a/src-tauri/src/commands/config_commands.rs b/src-tauri/src/commands/config_commands.rs index d33064b..26f9d0f 100644 --- a/src-tauri/src/commands/config_commands.rs +++ b/src-tauri/src/commands/config_commands.rs @@ -51,6 +51,23 @@ pub async fn save_global_config(config: GlobalConfig) -> Result<(), String> { write_global_config(&config) } +/// 更新 Token 统计配置(部分更新,避免竞态条件) +#[tauri::command] +pub async fn update_token_stats_config( + config: ::duckcoding::models::config::TokenStatsConfig, +) -> Result<(), String> { + use ::duckcoding::utils::config::{read_global_config, write_global_config}; + + // 读取当前配置 + let mut global_config = read_global_config()?.ok_or_else(|| "全局配置不存在".to_string())?; + + // 仅更新 token_stats_config 字段 + global_config.token_stats_config = config; + + // 写回配置 + write_global_config(&global_config) +} + #[tauri::command] pub async fn get_global_config() -> Result, String> { read_global_config() diff --git a/src-tauri/src/commands/mod.rs b/src-tauri/src/commands/mod.rs index b537a9a..c449b41 100644 --- a/src-tauri/src/commands/mod.rs +++ b/src-tauri/src/commands/mod.rs @@ -1,10 +1,12 @@ pub mod amp_commands; // AMP 用户认证命令 +pub mod analytics_commands; // Token统计分析命令(Phase 4) pub mod balance_commands; pub mod config_commands; pub mod dashboard_commands; // 仪表板状态管理命令 pub mod error; // 错误处理统一模块 pub mod log_commands; pub mod onboarding; +pub mod pricing_commands; // 价格配置管理命令(Phase 6) pub mod profile_commands; // Profile 管理命令(v2.0) pub mod provider_commands; // 供应商管理命令(v1.5.0) pub mod proxy_commands; @@ -12,6 +14,7 @@ pub mod session_commands; pub mod startup_commands; // 开机自启动管理命令 pub mod stats_commands; pub mod token_commands; // 令牌资产管理命令(NEW API 集成) +pub mod token_stats_commands; // Token统计命令 pub mod tool_commands; pub mod tool_management; pub mod types; @@ -20,11 +23,13 @@ pub mod window_commands; // 重新导出所有命令函数 pub use amp_commands::*; // AMP 用户认证命令 +pub use analytics_commands::*; // Token统计分析命令(Phase 4) pub use balance_commands::*; pub use config_commands::*; pub use dashboard_commands::*; // 仪表板状态管理命令 pub use log_commands::*; pub use onboarding::*; +pub use pricing_commands::*; // 价格配置管理命令(Phase 6) pub use profile_commands::*; // Profile 管理命令(v2.0) pub use provider_commands::*; // 供应商管理命令(v1.5.0) pub use proxy_commands::*; @@ -32,6 +37,7 @@ pub use session_commands::*; pub use startup_commands::*; // 开机自启动管理命令 pub use stats_commands::*; pub use token_commands::*; // 令牌资产管理命令(NEW API 集成) +pub use token_stats_commands::*; // Token统计命令 pub use tool_commands::*; pub use tool_management::*; pub use update_commands::*; diff --git a/src-tauri/src/commands/onboarding.rs b/src-tauri/src/commands/onboarding.rs index 782d361..fd3ec91 100644 --- a/src-tauri/src/commands/onboarding.rs +++ b/src-tauri/src/commands/onboarding.rs @@ -29,6 +29,7 @@ fn create_minimal_config() -> GlobalConfig { single_instance_enabled: true, startup_enabled: false, config_watch: duckcoding::models::config::ConfigWatchConfig::default(), + token_stats_config: duckcoding::models::config::TokenStatsConfig::default(), } } diff --git a/src-tauri/src/commands/pricing_commands.rs b/src-tauri/src/commands/pricing_commands.rs new file mode 100644 index 0000000..de2b0b2 --- /dev/null +++ b/src-tauri/src/commands/pricing_commands.rs @@ -0,0 +1,102 @@ +/// 价格配置管理命令 +/// +/// 提供价格模板的 CRUD 操作和工具默认模板管理 +use duckcoding::models::pricing::PricingTemplate; +use duckcoding::services::pricing::PRICING_MANAGER; + +use super::error::AppResult; + +/// 列出所有价格模板 +/// +/// # 返回 +/// +/// 所有可用价格模板的列表 +#[tauri::command] +pub async fn list_pricing_templates() -> AppResult> { + let templates = PRICING_MANAGER.list_templates()?; + Ok(templates) +} + +/// 获取指定价格模板 +/// +/// # 参数 +/// +/// - `template_id`: 模板 ID +/// +/// # 返回 +/// +/// 价格模板详细信息 +#[tauri::command] +pub async fn get_pricing_template(template_id: String) -> AppResult { + let template = PRICING_MANAGER.get_template(&template_id)?; + Ok(template) +} + +/// 保存价格模板(创建或更新) +/// +/// # 参数 +/// +/// - `template`: 价格模板数据 +/// +/// # 注意 +/// +/// - 如果模板 ID 已存在,将覆盖现有模板 +/// - 不允许覆盖内置预设模板(is_default_preset = true) +#[tauri::command] +pub async fn save_pricing_template(template: PricingTemplate) -> AppResult<()> { + // 检查是否尝试覆盖内置模板 + if let Ok(existing) = PRICING_MANAGER.get_template(&template.id) { + if existing.is_default_preset && !template.is_default_preset { + return Err(anyhow::anyhow!("Cannot overwrite built-in preset template").into()); + } + } + + PRICING_MANAGER.save_template(&template)?; + Ok(()) +} + +/// 删除价格模板 +/// +/// # 参数 +/// +/// - `template_id`: 模板 ID +/// +/// # 注意 +/// +/// - 不允许删除内置预设模板 +#[tauri::command] +pub async fn delete_pricing_template(template_id: String) -> AppResult<()> { + PRICING_MANAGER.delete_template(&template_id)?; + Ok(()) +} + +/// 设置工具的默认价格模板 +/// +/// # 参数 +/// +/// - `tool_id`: 工具 ID(claude-code / codex / gemini-cli) +/// - `template_id`: 模板 ID +/// +/// # 注意 +/// +/// - 模板必须存在才能设置为默认模板 +#[tauri::command] +pub async fn set_default_template(tool_id: String, template_id: String) -> AppResult<()> { + PRICING_MANAGER.set_default_template(&tool_id, &template_id)?; + Ok(()) +} + +/// 获取工具的默认价格模板 +/// +/// # 参数 +/// +/// - `tool_id`: 工具 ID(claude-code / codex / gemini-cli) +/// +/// # 返回 +/// +/// 该工具当前使用的默认价格模板 +#[tauri::command] +pub async fn get_default_template(tool_id: String) -> AppResult { + let template = PRICING_MANAGER.get_default_template(&tool_id)?; + Ok(template) +} diff --git a/src-tauri/src/commands/profile_commands.rs b/src-tauri/src/commands/profile_commands.rs index 24d723c..996a2d3 100644 --- a/src-tauri/src/commands/profile_commands.rs +++ b/src-tauri/src/commands/profile_commands.rs @@ -16,12 +16,19 @@ pub struct ProfileManagerState { #[serde(tag = "type", rename_all = "kebab-case")] pub enum ProfileInput { #[serde(rename = "claude-code")] - Claude { api_key: String, base_url: String }, + Claude { + api_key: String, + base_url: String, + #[serde(default)] + pricing_template_id: Option, // 🆕 Phase 6: 价格模板 ID + }, #[serde(rename = "codex")] Codex { api_key: String, base_url: String, wire_api: String, + #[serde(default)] + pricing_template_id: Option, // 🆕 Phase 6: 价格模板 ID }, #[serde(rename = "gemini-cli")] Gemini { @@ -29,6 +36,8 @@ pub enum ProfileInput { base_url: String, #[serde(default)] model: Option, + #[serde(default)] + pricing_template_id: Option, // 🆕 Phase 6: 价格模板 ID }, } @@ -108,8 +117,18 @@ pub async fn pm_save_profile( match tool_id.as_str() { "claude-code" => { - if let ProfileInput::Claude { api_key, base_url } = input { - Ok(manager.save_claude_profile(&name, api_key, base_url)?) + if let ProfileInput::Claude { + api_key, + base_url, + pricing_template_id, + } = input + { + Ok(manager.save_claude_profile_with_template( + &name, + api_key, + base_url, + pricing_template_id, + )?) } else { Err(super::error::AppError::ValidationError { field: "input".to_string(), @@ -122,9 +141,16 @@ pub async fn pm_save_profile( api_key, base_url, wire_api, + pricing_template_id, } = input { - Ok(manager.save_codex_profile(&name, api_key, base_url, Some(wire_api))?) + Ok(manager.save_codex_profile_with_template( + &name, + api_key, + base_url, + Some(wire_api), + pricing_template_id, + )?) } else { Err(super::error::AppError::ValidationError { field: "input".to_string(), @@ -137,9 +163,16 @@ pub async fn pm_save_profile( api_key, base_url, model, + pricing_template_id, } = input { - Ok(manager.save_gemini_profile(&name, api_key, base_url, model)?) + Ok(manager.save_gemini_profile_with_template( + &name, + api_key, + base_url, + model, + pricing_template_id, + )?) } else { Err(super::error::AppError::ValidationError { field: "input".to_string(), diff --git a/src-tauri/src/commands/proxy_commands.rs b/src-tauri/src/commands/proxy_commands.rs index 790cc5f..dadd813 100644 --- a/src-tauri/src/commands/proxy_commands.rs +++ b/src-tauri/src/commands/proxy_commands.rs @@ -460,24 +460,36 @@ pub async fn update_proxy_from_profile( let proxy_config_mgr = ProxyConfigManager::new().map_err(|e| e.to_string())?; // 根据工具类型读取 Profile - let (api_key, base_url) = match tool_id.as_str() { + let (api_key, base_url, pricing_template_id) = match tool_id.as_str() { "claude-code" => { let profile = profile_mgr .get_claude_profile(&profile_name) .map_err(|e| e.to_string())?; - (profile.api_key, profile.base_url) + ( + profile.api_key, + profile.base_url, + profile.pricing_template_id, + ) } "codex" => { let profile = profile_mgr .get_codex_profile(&profile_name) .map_err(|e| e.to_string())?; - (profile.api_key, profile.base_url) + ( + profile.api_key, + profile.base_url, + profile.pricing_template_id, + ) } "gemini-cli" => { let profile = profile_mgr .get_gemini_profile(&profile_name) .map_err(|e| e.to_string())?; - (profile.api_key, profile.base_url) + ( + profile.api_key, + profile.base_url, + profile.pricing_template_id, + ) } _ => return Err(format!("不支持的工具: {}", tool_id)), }; @@ -494,6 +506,7 @@ pub async fn update_proxy_from_profile( proxy_config.real_api_key = Some(api_key); proxy_config.real_base_url = Some(base_url); proxy_config.real_profile_name = Some(profile_name.clone()); + proxy_config.pricing_template_id = pricing_template_id; // Phase 6: 价格模板 proxy_config_mgr .update_config(&tool_id, proxy_config.clone()) diff --git a/src-tauri/src/commands/session_commands.rs b/src-tauri/src/commands/session_commands.rs index e4c3dc4..ffbbd72 100644 --- a/src-tauri/src/commands/session_commands.rs +++ b/src-tauri/src/commands/session_commands.rs @@ -33,6 +33,7 @@ pub async fn update_session_config( custom_profile_name: Option, url: String, api_key: String, + pricing_template_id: Option, // Phase 6: 价格模板 ) -> AppResult<()> { Ok(SESSION_MANAGER.update_session_config( &session_id, @@ -40,6 +41,7 @@ pub async fn update_session_config( custom_profile_name.as_deref(), &url, &api_key, + pricing_template_id.as_deref(), )?) } diff --git a/src-tauri/src/commands/token_commands.rs b/src-tauri/src/commands/token_commands.rs index bd95837..ee7d53d 100644 --- a/src-tauri/src/commands/token_commands.rs +++ b/src-tauri/src/commands/token_commands.rs @@ -100,6 +100,7 @@ pub async fn import_token_as_profile( remote_token: RemoteToken, tool_id: String, profile_name: String, + pricing_template_id: Option, // 🆕 Phase 6: 可选的价格模板 ID ) -> Result<(), String> { // 验证 tool_id if tool_id != "claude-code" && tool_id != "codex" && tool_id != "gemini-cli" { @@ -139,6 +140,7 @@ pub async fn import_token_as_profile( updated_at: Utc::now(), raw_settings: None, raw_config_json: None, + pricing_template_id: pricing_template_id.clone(), }; store.claude_code.insert(profile_name.clone(), profile); } @@ -152,6 +154,7 @@ pub async fn import_token_as_profile( updated_at: Utc::now(), raw_config_toml: None, raw_auth_json: None, + pricing_template_id: pricing_template_id.clone(), }; store.codex.insert(profile_name.clone(), profile); } @@ -165,6 +168,7 @@ pub async fn import_token_as_profile( updated_at: Utc::now(), raw_settings: None, raw_env: None, + pricing_template_id: pricing_template_id.clone(), }; store.gemini_cli.insert(profile_name.clone(), profile); } @@ -200,6 +204,13 @@ pub async fn create_custom_profile( let manager = profile_manager.manager.read().await; let mut store = manager.load_profiles_store().map_err(|e| e.to_string())?; + // 从 extra_config 中提取 pricing_template_id + let pricing_template_id = extra_config + .as_ref() + .and_then(|v| v.get("pricing_template_id")) + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + // 根据工具类型创建对应的 Profile match tool_id.as_str() { "claude-code" => { @@ -211,6 +222,7 @@ pub async fn create_custom_profile( updated_at: Utc::now(), raw_settings: None, raw_config_json: None, + pricing_template_id: pricing_template_id.clone(), }; store.claude_code.insert(profile_name.clone(), profile); } @@ -232,6 +244,7 @@ pub async fn create_custom_profile( updated_at: Utc::now(), raw_config_toml: None, raw_auth_json: None, + pricing_template_id: pricing_template_id.clone(), }; store.codex.insert(profile_name.clone(), profile); } @@ -252,6 +265,7 @@ pub async fn create_custom_profile( updated_at: Utc::now(), raw_settings: None, raw_env: None, + pricing_template_id: pricing_template_id.clone(), }; store.gemini_cli.insert(profile_name.clone(), profile); } diff --git a/src-tauri/src/commands/token_stats_commands.rs b/src-tauri/src/commands/token_stats_commands.rs new file mode 100644 index 0000000..f22feaa --- /dev/null +++ b/src-tauri/src/commands/token_stats_commands.rs @@ -0,0 +1,81 @@ +use duckcoding::models::token_stats::{SessionStats, TokenLogsPage, TokenStatsQuery}; +use duckcoding::services::token_stats::TokenStatsManager; + +/// 查询会话实时统计 +#[tauri::command] +pub async fn get_session_stats( + tool_type: String, + session_id: String, +) -> Result { + TokenStatsManager::get() + .get_session_stats(&tool_type, &session_id) + .map_err(|e| e.to_string()) +} + +/// 分页查询Token日志 +#[tauri::command] +pub async fn query_token_logs(query_params: TokenStatsQuery) -> Result { + TokenStatsManager::get() + .query_logs(query_params) + .map_err(|e| e.to_string()) +} + +/// 手动清理旧日志 +#[tauri::command] +pub async fn cleanup_token_logs( + retention_days: Option, + max_count: Option, +) -> Result { + TokenStatsManager::get() + .cleanup_by_config(retention_days, max_count) + .map_err(|e| e.to_string()) +} + +/// 获取数据库统计摘要 +#[tauri::command] +pub async fn get_token_stats_summary() -> Result<(i64, Option, Option), String> { + TokenStatsManager::get() + .get_stats_summary() + .map_err(|e| e.to_string()) +} + +/// 强制执行 WAL checkpoint +/// +/// 将 WAL 文件中的所有数据回写到主数据库文件, +/// 用于手动清理过大的 WAL 文件 +#[tauri::command] +pub async fn force_token_stats_checkpoint() -> Result<(), String> { + TokenStatsManager::get() + .force_checkpoint() + .map_err(|e| e.to_string()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_get_session_stats() { + let result = get_session_stats("claude_code".to_string(), "test_session".to_string()).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_query_token_logs() { + let query = TokenStatsQuery::default(); + let result = query_token_logs(query).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_cleanup_token_logs() { + let result = cleanup_token_logs(Some(30), Some(10000)).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_get_token_stats_summary() { + let result = get_token_stats_summary().await; + assert!(result.is_ok()); + } +} diff --git a/src-tauri/src/core/http.rs b/src-tauri/src/core/http.rs index 17e34d4..89304e8 100644 --- a/src-tauri/src/core/http.rs +++ b/src-tauri/src/core/http.rs @@ -115,6 +115,7 @@ mod tests { single_instance_enabled: true, startup_enabled: false, config_watch: crate::models::config::ConfigWatchConfig::default(), + token_stats_config: crate::models::config::TokenStatsConfig::default(), }; let url = build_proxy_url(&config).unwrap(); @@ -145,6 +146,7 @@ mod tests { single_instance_enabled: true, startup_enabled: false, config_watch: crate::models::config::ConfigWatchConfig::default(), + token_stats_config: crate::models::config::TokenStatsConfig::default(), }; let url = build_proxy_url(&config).unwrap(); diff --git a/src-tauri/src/data/checkpoint_strategy.md b/src-tauri/src/data/checkpoint_strategy.md new file mode 100644 index 0000000..7d82f5b --- /dev/null +++ b/src-tauri/src/data/checkpoint_strategy.md @@ -0,0 +1,63 @@ +# SQLite WAL Checkpoint 策略设计 + +## 背景 + +SQLite WAL (Write-Ahead Logging) 模式通过先写入 WAL 文件再定期合并到主文件来提升并发性能。但如果不及时 checkpoint,WAL 文件会无限增长。 + +## Checkpoint 模式 + +| 模式 | 行为 | 性能影响 | 使用场景 | +| -------- | ------------------ | -------- | ----------------- | +| PASSIVE | 尝试回写但不阻塞 | 最小 | 高频操作后 | +| FULL | 等待读者完成后回写 | 中等 | 定期维护 | +| RESTART | FULL + 重置WAL | 中等 | 同FULL | +| TRUNCATE | 强制清空WAL | 最大 | 手动触发/应用关闭 | + +## Token 统计数据库策略 + +### 写入机制 + +- **批量写入**:收集写入事件到缓冲区 +- **触发条件**:10条记录或100ms间隔 +- **异步执行**:不阻塞代理响应 + +### Checkpoint 分层 + +1. **批量写入后**:PASSIVE checkpoint(每次批量写入) +2. **定期维护**:每5分钟执行TRUNCATE(后台任务) +3. **应用关闭**:强制TRUNCATE(刷盘缓冲区+清空WAL) +4. **手动触发**:提供命令执行TRUNCATE + +## 会话记录数据库策略 + +### 已有机制 + +- 批量写入:10条或100ms触发 +- 定期清理:每小时清理旧会话 + +### Checkpoint 优化 + +1. **批量写入后**:PASSIVE checkpoint +2. **清理任务后**:TRUNCATE checkpoint +3. **低频操作**:删除/更新使用PASSIVE +4. **应用关闭**:TRUNCATE checkpoint + +## 实现要点 + +### 性能考虑 + +- PASSIVE不阻塞,适合高频场景 +- TRUNCATE完全清空,适合低频场景 +- 避免在写入循环内执行TRUNCATE + +### 数据完整性 + +- 应用关闭时强制刷盘 +- 定期TRUNCATE防止WAL过大 +- 批量写入减少磁盘IO次数 + +### 监控指标 + +- WAL文件大小 +- Checkpoint执行频率 +- 写入延迟统计 diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index 1755291..d9a866f 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -245,6 +245,7 @@ fn main() { detect_tool_without_save, // 全局配置管理 save_global_config, + update_token_stats_config, get_global_config, generate_api_key_for_tool, // 使用统计 @@ -294,6 +295,15 @@ fn main() { clear_all_sessions, update_session_config, update_session_note, + // Token统计命令 + get_session_stats, + query_token_logs, + cleanup_token_logs, + get_token_stats_summary, + force_token_stats_checkpoint, + // Token统计分析命令(Phase 4) + query_token_trends, + query_cost_summary, // 配置监听控制 block_external_change, allow_external_change, @@ -376,17 +386,37 @@ fn main() { set_tool_instance_selection, get_selected_provider_id, set_selected_provider_id, + // 价格配置管理命令(Phase 6) + list_pricing_templates, + get_pricing_template, + save_pricing_template, + delete_pricing_template, + set_default_template, + get_default_template, // AMP 用户认证命令 get_amp_user_info, validate_and_save_amp_token, get_saved_amp_user_info, ]); - // 使用自定义事件循环处理 macOS Reopen 事件 + // 使用自定义事件循环处理 macOS Reopen 事件和应用关闭 builder .build(tauri::generate_context!()) .expect("error while building tauri application") .run(|app_handle, event| { + // 处理应用关闭事件 + if let tauri::RunEvent::ExitRequested { .. } = event { + tracing::info!("应用正在关闭,执行清理任务..."); + + // 关闭会话管理器后台任务 + duckcoding::services::session::shutdown_session_manager(); + + // 关闭 Token 统计后台任务 + duckcoding::services::token_stats::shutdown_token_stats_manager(); + + tracing::info!("清理任务完成"); + } + #[cfg(not(target_os = "macos"))] { let _ = app_handle; diff --git a/src-tauri/src/models/config.rs b/src-tauri/src/models/config.rs index 063af06..7fc37b1 100644 --- a/src-tauri/src/models/config.rs +++ b/src-tauri/src/models/config.rs @@ -114,6 +114,34 @@ pub struct ConfigWatchConfig { pub sensitive_fields: HashMap>, } +/// Token统计配置 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenStatsConfig { + /// 数据保留天数(None表示不限制) + #[serde(default)] + pub retention_days: Option, + /// 最大日志条数(None表示不限制) + #[serde(default)] + pub max_log_count: Option, + /// 是否启用自动清理 + #[serde(default = "default_auto_cleanup_enabled")] + pub auto_cleanup_enabled: bool, +} + +impl Default for TokenStatsConfig { + fn default() -> Self { + Self { + retention_days: Some(30), + max_log_count: Some(10000), + auto_cleanup_enabled: true, + } + } +} + +fn default_auto_cleanup_enabled() -> bool { + true +} + /// 配置文件快照 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ConfigSnapshot { @@ -168,6 +196,9 @@ pub struct ToolProxyConfig { /// 启动代理前激活的 Profile 名称(用于关闭时还原) #[serde(default, skip_serializing_if = "Option::is_none")] pub original_active_profile: Option, + /// 价格模板 ID(用于成本计算) + #[serde(default, skip_serializing_if = "Option::is_none")] + pub pricing_template_id: Option, } impl Default for ToolProxyConfig { @@ -184,6 +215,7 @@ impl Default for ToolProxyConfig { session_endpoint_config_enabled: false, auto_start: false, original_active_profile: None, + pricing_template_id: None, } } } @@ -247,6 +279,9 @@ pub struct GlobalConfig { /// 配置监听配置 #[serde(default)] pub config_watch: ConfigWatchConfig, + /// Token统计配置 + #[serde(default)] + pub token_stats_config: TokenStatsConfig, } fn default_proxy_configs() -> HashMap { @@ -266,6 +301,7 @@ fn default_proxy_configs() -> HashMap { session_endpoint_config_enabled: false, auto_start: false, original_active_profile: None, + pricing_template_id: None, }, ); @@ -283,6 +319,7 @@ fn default_proxy_configs() -> HashMap { session_endpoint_config_enabled: false, auto_start: false, original_active_profile: None, + pricing_template_id: None, }, ); @@ -300,6 +337,7 @@ fn default_proxy_configs() -> HashMap { session_endpoint_config_enabled: false, auto_start: false, original_active_profile: None, + pricing_template_id: None, }, ); @@ -317,6 +355,7 @@ fn default_proxy_configs() -> HashMap { session_endpoint_config_enabled: false, auto_start: false, original_active_profile: None, + pricing_template_id: None, }, ); @@ -350,6 +389,7 @@ impl GlobalConfig { session_endpoint_config_enabled: false, auto_start: false, original_active_profile: None, + pricing_template_id: None, }); } diff --git a/src-tauri/src/models/mod.rs b/src-tauri/src/models/mod.rs index d9b2002..df4b436 100644 --- a/src-tauri/src/models/mod.rs +++ b/src-tauri/src/models/mod.rs @@ -1,18 +1,22 @@ pub mod balance; pub mod config; pub mod dashboard; +pub mod pricing; pub mod provider; pub mod proxy_config; pub mod remote_token; +pub mod token_stats; pub mod tool; pub mod update; pub use balance::*; pub use config::*; pub use dashboard::*; +pub use pricing::*; pub use provider::*; // 只导出新的 proxy_config 类型,避免与 config.rs 中的旧类型冲突 pub use proxy_config::{ProxyMetadata, ProxyStore}; pub use remote_token::*; +pub use token_stats::*; pub use tool::*; pub use update::*; diff --git a/src-tauri/src/models/pricing.rs b/src-tauri/src/models/pricing.rs new file mode 100644 index 0000000..d66168f --- /dev/null +++ b/src-tauri/src/models/pricing.rs @@ -0,0 +1,332 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// 单个模型的价格定义 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelPrice { + /// 提供商(如:anthropic、openai) + pub provider: String, + + /// 输入价格(USD/百万 Token) + pub input_price_per_1m: f64, + + /// 输出价格(USD/百万 Token) + pub output_price_per_1m: f64, + + /// 缓存写入价格(USD/百万 Token,可选) + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_write_price_per_1m: Option, + + /// 缓存读取价格(USD/百万 Token,可选) + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_read_price_per_1m: Option, + + /// 货币类型(默认:USD) + #[serde(default = "default_currency")] + pub currency: String, + + /// 模型别名列表(支持多种 ID 格式) + #[serde(default)] + pub aliases: Vec, +} + +impl ModelPrice { + /// 创建新的模型价格定义 + #[allow(clippy::too_many_arguments)] + pub fn new( + provider: String, + input_price_per_1m: f64, + output_price_per_1m: f64, + cache_write_price_per_1m: Option, + cache_read_price_per_1m: Option, + aliases: Vec, + ) -> Self { + Self { + provider, + input_price_per_1m, + output_price_per_1m, + cache_write_price_per_1m, + cache_read_price_per_1m, + currency: default_currency(), + aliases, + } + } +} + +/// 单个模型的继承配置 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InheritedModel { + /// 模型名称(如:"claude-sonnet-4.5") + pub model_name: String, + + /// 从哪个模板继承(如:"claude_official_2025_01") + pub source_template_id: String, + + /// 倍率(应用到继承的价格上) + pub multiplier: f64, +} + +impl InheritedModel { + /// 创建新的继承模型配置 + pub fn new(model_name: String, source_template_id: String, multiplier: f64) -> Self { + Self { + model_name, + source_template_id, + multiplier, + } + } +} + +/// 价格模板(统一结构,支持三种模式) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PricingTemplate { + /// 模板ID(唯一标识) + pub id: String, + + /// 模板名称 + pub name: String, + + /// 模板描述 + pub description: String, + + /// 模板版本 + pub version: String, + + /// 创建时间(Unix 时间戳,毫秒) + pub created_at: i64, + + /// 更新时间(Unix 时间戳,毫秒) + pub updated_at: i64, + + /// 继承配置(每个模型独立配置,可从不同模板继承) + #[serde(default)] + pub inherited_models: Vec, + + /// 自定义模型(直接定义价格) + #[serde(default)] + pub custom_models: HashMap, + + /// 标签列表(用于分类和搜索) + #[serde(default)] + pub tags: Vec, + + /// 是否为内置预设模板 + #[serde(default)] + pub is_default_preset: bool, +} + +impl PricingTemplate { + /// 创建新的价格模板 + #[allow(clippy::too_many_arguments)] + pub fn new( + id: String, + name: String, + description: String, + version: String, + inherited_models: Vec, + custom_models: HashMap, + tags: Vec, + is_default_preset: bool, + ) -> Self { + let now = chrono::Utc::now().timestamp_millis(); + Self { + id, + name, + description, + version, + created_at: now, + updated_at: now, + inherited_models, + custom_models, + tags, + is_default_preset, + } + } + + /// 判断是否为完全自定义模式 + /// + /// 完全自定义:inherited_models 为空,custom_models 包含所有模型及其价格 + pub fn is_full_custom(&self) -> bool { + self.inherited_models.is_empty() && !self.custom_models.is_empty() + } + + /// 判断是否为纯继承模式 + /// + /// 纯继承:inherited_models 包含多个模型,custom_models 为空 + pub fn is_pure_inheritance(&self) -> bool { + !self.inherited_models.is_empty() && self.custom_models.is_empty() + } + + /// 判断是否为混合模式 + /// + /// 混合模式:inherited_models 和 custom_models 同时存在 + pub fn is_mixed(&self) -> bool { + !self.inherited_models.is_empty() && !self.custom_models.is_empty() + } +} + +/// 工具默认模板配置(存储在 default_templates.json) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DefaultTemplatesConfig { + /// 工具 -> 默认模板 ID 的映射 + /// + /// 例如: + /// ```json + /// { + /// "claude-code": "claude_official_2025_01", + /// "codex": "claude_official_2025_01", + /// "gemini-cli": "claude_official_2025_01" + /// } + /// ``` + #[serde(flatten)] + pub tool_defaults: HashMap, +} + +impl DefaultTemplatesConfig { + /// 创建新的默认模板配置 + pub fn new() -> Self { + Self { + tool_defaults: HashMap::new(), + } + } + + /// 获取工具的默认模板 ID + pub fn get_default(&self, tool_id: &str) -> Option<&String> { + self.tool_defaults.get(tool_id) + } + + /// 设置工具的默认模板 ID + pub fn set_default(&mut self, tool_id: String, template_id: String) { + self.tool_defaults.insert(tool_id, template_id); + } +} + +impl Default for DefaultTemplatesConfig { + fn default() -> Self { + Self::new() + } +} + +/// 默认货币类型 +fn default_currency() -> String { + "USD".to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_model_price_creation() { + let price = ModelPrice::new( + "anthropic".to_string(), + 3.0, + 15.0, + Some(3.75), + Some(0.3), + vec![ + "claude-sonnet-4.5".to_string(), + "claude-sonnet-4-5".to_string(), + ], + ); + + assert_eq!(price.provider, "anthropic"); + assert_eq!(price.input_price_per_1m, 3.0); + assert_eq!(price.output_price_per_1m, 15.0); + assert_eq!(price.cache_write_price_per_1m, Some(3.75)); + assert_eq!(price.cache_read_price_per_1m, Some(0.3)); + assert_eq!(price.currency, "USD"); + assert_eq!(price.aliases.len(), 2); + } + + #[test] + fn test_inherited_model_creation() { + let inherited = InheritedModel::new( + "claude-sonnet-4.5".to_string(), + "builtin_claude".to_string(), + 1.1, + ); + + assert_eq!(inherited.model_name, "claude-sonnet-4.5"); + assert_eq!(inherited.source_template_id, "builtin_claude"); + assert_eq!(inherited.multiplier, 1.1); + } + + #[test] + fn test_pricing_template_modes() { + // 完全自定义模式 + let mut custom_models = HashMap::new(); + custom_models.insert( + "model1".to_string(), + ModelPrice::new("provider1".to_string(), 1.0, 2.0, None, None, vec![]), + ); + + let full_custom = PricingTemplate::new( + "template1".to_string(), + "Full Custom".to_string(), + "Description".to_string(), + "1.0".to_string(), + vec![], + custom_models.clone(), + vec![], + false, + ); + + assert!(full_custom.is_full_custom()); + assert!(!full_custom.is_pure_inheritance()); + assert!(!full_custom.is_mixed()); + + // 纯继承模式 + let inherited_models = vec![InheritedModel::new( + "model1".to_string(), + "source_template".to_string(), + 1.0, + )]; + + let pure_inheritance = PricingTemplate::new( + "template2".to_string(), + "Pure Inheritance".to_string(), + "Description".to_string(), + "1.0".to_string(), + inherited_models.clone(), + HashMap::new(), + vec![], + false, + ); + + assert!(!pure_inheritance.is_full_custom()); + assert!(pure_inheritance.is_pure_inheritance()); + assert!(!pure_inheritance.is_mixed()); + + // 混合模式 + let mixed = PricingTemplate::new( + "template3".to_string(), + "Mixed".to_string(), + "Description".to_string(), + "1.0".to_string(), + inherited_models, + custom_models, + vec![], + false, + ); + + assert!(!mixed.is_full_custom()); + assert!(!mixed.is_pure_inheritance()); + assert!(mixed.is_mixed()); + } + + #[test] + fn test_default_templates_config() { + let mut config = DefaultTemplatesConfig::new(); + + config.set_default("claude-code".to_string(), "template1".to_string()); + config.set_default("codex".to_string(), "template2".to_string()); + + assert_eq!( + config.get_default("claude-code"), + Some(&"template1".to_string()) + ); + assert_eq!(config.get_default("codex"), Some(&"template2".to_string())); + assert_eq!(config.get_default("gemini-cli"), None); + } +} diff --git a/src-tauri/src/models/proxy_config.rs b/src-tauri/src/models/proxy_config.rs index a7a7bfa..ee32e89 100644 --- a/src-tauri/src/models/proxy_config.rs +++ b/src-tauri/src/models/proxy_config.rs @@ -26,6 +26,9 @@ pub struct ToolProxyConfig { /// 启动代理前激活的 Profile 名称(用于关闭时还原) #[serde(default, skip_serializing_if = "Option::is_none")] pub original_active_profile: Option, + /// 价格模板 ID(用于成本计算) + #[serde(default, skip_serializing_if = "Option::is_none")] + pub pricing_template_id: Option, /// AMP Code 原始 settings.json 完整内容(用于关闭时还原,语义备份) #[serde(default, skip_serializing_if = "Option::is_none")] pub original_amp_settings: Option, @@ -51,6 +54,7 @@ impl ToolProxyConfig { session_endpoint_config_enabled: false, auto_start: false, original_active_profile: None, + pricing_template_id: None, original_amp_settings: None, original_amp_secrets: None, tavily_api_key: None, diff --git a/src-tauri/src/models/token_stats.rs b/src-tauri/src/models/token_stats.rs new file mode 100644 index 0000000..e30de1a --- /dev/null +++ b/src-tauri/src/models/token_stats.rs @@ -0,0 +1,320 @@ +use serde::{Deserialize, Serialize}; + +/// Token日志记录 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenLog { + /// 主键ID(自增) + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + + /// 工具类型:claude_code, codex, gemini_cli + pub tool_type: String, + + /// 请求时间戳(毫秒) + pub timestamp: i64, + + /// 客户端IP地址 + pub client_ip: String, + + /// 会话ID + pub session_id: String, + + /// 使用的配置名称 + pub config_name: String, + + /// 模型名称 + pub model: String, + + /// API返回的消息ID + #[serde(skip_serializing_if = "Option::is_none")] + pub message_id: Option, + + /// 输入Token数量 + pub input_tokens: i64, + + /// 输出Token数量 + pub output_tokens: i64, + + /// 缓存创建Token数量 + pub cache_creation_tokens: i64, + + /// 缓存读取Token数量 + pub cache_read_tokens: i64, + + /// 请求状态:success, failed + pub request_status: String, + + /// 响应类型:sse, json, unknown + pub response_type: String, + + /// 错误类型:parse_error, request_interrupted, upstream_error(成功时为None) + #[serde(skip_serializing_if = "Option::is_none")] + pub error_type: Option, + + /// 错误详情(成功时为None) + #[serde(skip_serializing_if = "Option::is_none")] + pub error_detail: Option, + + /// 响应时间(毫秒) + #[serde(skip_serializing_if = "Option::is_none")] + pub response_time_ms: Option, + + /// 输入部分价格(USD) + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(with = "crate::utils::precision::option_price_precision")] + pub input_price: Option, + + /// 输出部分价格(USD) + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(with = "crate::utils::precision::option_price_precision")] + pub output_price: Option, + + /// 缓存写入部分价格(USD) + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(with = "crate::utils::precision::option_price_precision")] + pub cache_write_price: Option, + + /// 缓存读取部分价格(USD) + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(with = "crate::utils::precision::option_price_precision")] + pub cache_read_price: Option, + + /// 总成本(USD) + #[serde(default)] + #[serde(with = "crate::utils::precision::price_precision")] + pub total_cost: f64, + + /// 使用的价格模板ID + #[serde(skip_serializing_if = "Option::is_none")] + pub pricing_template_id: Option, +} + +impl TokenLog { + /// 创建新的Token日志记录 + #[allow(clippy::too_many_arguments)] + pub fn new( + tool_type: String, + timestamp: i64, + client_ip: String, + session_id: String, + config_name: String, + model: String, + message_id: Option, + input_tokens: i64, + output_tokens: i64, + cache_creation_tokens: i64, + cache_read_tokens: i64, + request_status: String, + response_type: String, + error_type: Option, + error_detail: Option, + response_time_ms: Option, + input_price: Option, + output_price: Option, + cache_write_price: Option, + cache_read_price: Option, + total_cost: f64, + pricing_template_id: Option, + ) -> Self { + Self { + id: None, + tool_type, + timestamp, + client_ip, + session_id, + config_name, + model, + message_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + request_status, + response_type, + error_type, + error_detail, + response_time_ms, + input_price, + output_price, + cache_write_price, + cache_read_price, + total_cost, + pricing_template_id, + } + } + + /// 计算总Token数量 + pub fn total_tokens(&self) -> i64 { + self.input_tokens + self.output_tokens + } + + /// 计算总缓存Token数量 + pub fn total_cache_tokens(&self) -> i64 { + self.cache_creation_tokens + self.cache_read_tokens + } + + /// 是否成功 + pub fn is_success(&self) -> bool { + self.request_status == "success" + } +} + +/// 会话统计数据 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionStats { + /// 总输入Token数量 + pub total_input: i64, + + /// 总输出Token数量 + pub total_output: i64, + + /// 总缓存创建Token数量 + pub total_cache_creation: i64, + + /// 总缓存读取Token数量 + pub total_cache_read: i64, + + /// 请求总数 + pub request_count: i64, +} + +impl SessionStats { + /// 创建空的统计数据 + pub fn empty() -> Self { + Self { + total_input: 0, + total_output: 0, + total_cache_creation: 0, + total_cache_read: 0, + request_count: 0, + } + } + + /// 计算总Token数量 + pub fn total_tokens(&self) -> i64 { + self.total_input + self.total_output + } + + /// 计算总缓存Token数量 + pub fn total_cache_tokens(&self) -> i64 { + self.total_cache_creation + self.total_cache_read + } +} + +/// Token日志查询参数 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenStatsQuery { + /// 工具类型筛选 + pub tool_type: Option, + + /// 会话ID筛选 + pub session_id: Option, + + /// 配置名称筛选 + pub config_name: Option, + + /// 开始时间戳(毫秒) + pub start_time: Option, + + /// 结束时间戳(毫秒) + pub end_time: Option, + + /// 分页:页码(从0开始) + pub page: u32, + + /// 分页:每页大小 + pub page_size: u32, +} + +impl Default for TokenStatsQuery { + fn default() -> Self { + Self { + tool_type: None, + session_id: None, + config_name: None, + start_time: None, + end_time: None, + page: 0, + page_size: 20, + } + } +} + +/// 分页查询结果 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenLogsPage { + /// 日志列表 + pub logs: Vec, + + /// 总记录数 + pub total: i64, + + /// 当前页码 + pub page: u32, + + /// 每页大小 + pub page_size: u32, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_token_log_creation() { + let log = TokenLog::new( + "claude_code".to_string(), + 1700000000000, + "127.0.0.1".to_string(), + "session_123".to_string(), + "default".to_string(), + "claude-sonnet-4-5-20250929".to_string(), + Some("msg_123".to_string()), + 1000, + 500, + 100, + 200, + "success".to_string(), + "sse".to_string(), + None, + None, + Some(1500), + Some(0.003), + Some(0.0075), + Some(0.000375), + Some(0.00006), + 0.011235, + Some("builtin_claude".to_string()), + ); + + assert_eq!(log.tool_type, "claude_code"); + assert_eq!(log.total_tokens(), 1500); + assert_eq!(log.total_cache_tokens(), 300); + assert!(log.is_success()); + assert_eq!(log.response_time_ms, Some(1500)); + assert_eq!(log.total_cost, 0.011235); + assert_eq!(log.pricing_template_id, Some("builtin_claude".to_string())); + } + + #[test] + fn test_session_stats_calculation() { + let stats = SessionStats { + total_input: 10000, + total_output: 5000, + total_cache_creation: 1000, + total_cache_read: 2000, + request_count: 10, + }; + + assert_eq!(stats.total_tokens(), 15000); + assert_eq!(stats.total_cache_tokens(), 3000); + } + + #[test] + fn test_query_default() { + let query = TokenStatsQuery::default(); + assert_eq!(query.page, 0); + assert_eq!(query.page_size, 20); + assert!(query.tool_type.is_none()); + } +} diff --git a/src-tauri/src/services/migration_manager/manager.rs b/src-tauri/src/services/migration_manager/manager.rs index 14f2918..1fe8328 100644 --- a/src-tauri/src/services/migration_manager/manager.rs +++ b/src-tauri/src/services/migration_manager/manager.rs @@ -192,6 +192,7 @@ impl MigrationManager { single_instance_enabled: true, startup_enabled: false, config_watch: crate::models::config::ConfigWatchConfig::default(), + token_stats_config: crate::models::config::TokenStatsConfig::default(), }); config.version = Some(new_version.to_string()); diff --git a/src-tauri/src/services/migration_manager/migrations/profile_v2.rs b/src-tauri/src/services/migration_manager/migrations/profile_v2.rs index b74e9f8..89dcf16 100644 --- a/src-tauri/src/services/migration_manager/migrations/profile_v2.rs +++ b/src-tauri/src/services/migration_manager/migrations/profile_v2.rs @@ -218,6 +218,7 @@ impl ProfileV2Migration { raw_settings: Some(settings_value), raw_config_json: None, source: ProfileSource::Custom, + pricing_template_id: None, }; profiles.insert(profile_name.clone(), profile); tracing::info!("已从原始 Claude Code 配置迁移 Profile: {}", profile_name); @@ -323,6 +324,7 @@ impl ProfileV2Migration { raw_config_toml, raw_auth_json: Some(auth_data), source: ProfileSource::Custom, + pricing_template_id: None, }; profiles.insert(profile_name.clone(), profile); tracing::info!("已从原始 Codex 配置迁移 Profile: {}", profile_name); @@ -398,6 +400,7 @@ impl ProfileV2Migration { raw_settings: None, raw_env, source: ProfileSource::Custom, + pricing_template_id: None, }; profiles.insert(profile_name.clone(), profile); tracing::info!("已从原始 Gemini CLI 配置迁移 Profile: {}", profile_name); @@ -495,6 +498,7 @@ impl ProfileV2Migration { raw_settings, raw_config_json, source: ProfileSource::Custom, + pricing_template_id: None, }, CodexProfile::default_placeholder(), GeminiProfile::default_placeholder(), @@ -533,6 +537,7 @@ impl ProfileV2Migration { raw_config_toml, raw_auth_json, source: ProfileSource::Custom, + pricing_template_id: None, }, GeminiProfile::default_placeholder(), )) @@ -570,6 +575,7 @@ impl ProfileV2Migration { raw_settings, raw_env, source: ProfileSource::Custom, + pricing_template_id: None, }, )) } @@ -866,6 +872,7 @@ impl ClaudeProfile { raw_settings: None, raw_config_json: None, source: ProfileSource::Custom, + pricing_template_id: None, } } } @@ -881,6 +888,7 @@ impl CodexProfile { raw_config_toml: None, raw_auth_json: None, source: ProfileSource::Custom, + pricing_template_id: None, } } } @@ -896,6 +904,7 @@ impl GeminiProfile { raw_settings: None, raw_env: None, source: ProfileSource::Custom, + pricing_template_id: None, } } } diff --git a/src-tauri/src/services/migration_manager/migrations/proxy_config_split.rs b/src-tauri/src/services/migration_manager/migrations/proxy_config_split.rs index dcf4a92..c4c94d0 100644 --- a/src-tauri/src/services/migration_manager/migrations/proxy_config_split.rs +++ b/src-tauri/src/services/migration_manager/migrations/proxy_config_split.rs @@ -282,6 +282,10 @@ fn parse_old_config(value: &Value) -> Result { .get("original_active_profile") .and_then(|v| v.as_str()) .map(|s| s.to_string()), + pricing_template_id: obj + .get("pricing_template_id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()), original_amp_settings: obj.get("original_amp_settings").cloned(), original_amp_secrets: obj.get("original_amp_secrets").cloned(), tavily_api_key: obj diff --git a/src-tauri/src/services/mod.rs b/src-tauri/src/services/mod.rs index ec3607e..ba207b3 100644 --- a/src-tauri/src/services/mod.rs +++ b/src-tauri/src/services/mod.rs @@ -10,6 +10,7 @@ // - balance: 余额监控配置管理 // - provider_manager: 供应商配置管理 // - new_api: NEW API 客户端服务 +// - token_stats: Token统计和请求记录 pub mod amp_native_config; // AMP Code 原生配置管理 pub mod balance; @@ -17,11 +18,13 @@ pub mod config; pub mod dashboard_manager; // 仪表板状态管理 pub mod migration_manager; pub mod new_api; // NEW API 客户端 +pub mod pricing; // 价格配置管理 pub mod profile_manager; // Profile管理(v2.1) pub mod provider_manager; // 供应商配置管理 pub mod proxy; pub mod proxy_config_manager; // 透明代理配置管理(v2.1) pub mod session; +pub mod token_stats; // Token统计服务 pub mod tool; pub mod update; @@ -39,6 +42,8 @@ pub use provider_manager::ProviderManager; pub use proxy::*; // session 模块:明确导出避免 db 名称冲突 pub use session::{manager::SESSION_MANAGER, models::*}; +// token_stats 模块:导出管理器和提取器 +pub use token_stats::{TokenStatsDb, TokenStatsManager}; // tool 模块:导出主要服务类和子模块 pub use tool::{ db::ToolInstanceDB, downloader, downloader::FileDownloader, installer, diff --git a/src-tauri/src/services/pricing/builtin.rs b/src-tauri/src/services/pricing/builtin.rs new file mode 100644 index 0000000..8966a2a --- /dev/null +++ b/src-tauri/src/services/pricing/builtin.rs @@ -0,0 +1,257 @@ +use crate::models::pricing::{ModelPrice, PricingTemplate}; +use std::collections::HashMap; + +/// 生成内置 Claude 价格模板 +/// +/// 包含 8 个 Claude 模型的官方定价 +pub fn builtin_claude_official_template() -> PricingTemplate { + let mut custom_models = HashMap::new(); + + // Claude Opus 4.5: $5 input / $25 output + custom_models.insert( + "claude-opus-4.5".to_string(), + ModelPrice::new( + "anthropic".to_string(), + 5.0, + 25.0, + Some(6.25), // Cache write: 5.0 * 1.25 + Some(0.5), // Cache read: 5.0 * 0.1 + vec![ + "claude-opus-4.5".to_string(), + "claude-opus-4-5".to_string(), + "opus-4.5".to_string(), + "claude-opus-4-5-20251101".to_string(), + ], + ), + ); + + // Claude Opus 4.1: $15 input / $75 output + custom_models.insert( + "claude-opus-4.1".to_string(), + ModelPrice::new( + "anthropic".to_string(), + 15.0, + 75.0, + Some(18.75), // Cache write: 15.0 * 1.25 + Some(1.5), // Cache read: 15.0 * 0.1 + vec![ + "claude-opus-4.1".to_string(), + "claude-opus-4-1".to_string(), + "claude-opus-4-1-20250805".to_string(), + ], + ), + ); + + // Claude Opus 4: $15 input / $75 output + custom_models.insert( + "claude-opus-4".to_string(), + ModelPrice::new( + "anthropic".to_string(), + 15.0, + 75.0, + Some(18.75), // Cache write: 15.0 * 1.25 + Some(1.5), // Cache read: 15.0 * 0.1 + vec![ + "claude-opus-4".to_string(), + "claude-opus-4-20250514".to_string(), + ], + ), + ); + + // Claude Sonnet 4.5: $3 input / $15 output + custom_models.insert( + "claude-sonnet-4.5".to_string(), + ModelPrice::new( + "anthropic".to_string(), + 3.0, + 15.0, + Some(3.75), // Cache write: 3.0 * 1.25 + Some(0.3), // Cache read: 3.0 * 0.1 + vec![ + "claude-sonnet-4.5".to_string(), + "claude-sonnet-4-5".to_string(), + "claude-sonnet-4-5-20250929".to_string(), + ], + ), + ); + + // Claude Sonnet 4: $3 input / $15 output + custom_models.insert( + "claude-sonnet-4".to_string(), + ModelPrice::new( + "anthropic".to_string(), + 3.0, + 15.0, + Some(3.75), // Cache write: 3.0 * 1.25 + Some(0.3), // Cache read: 3.0 * 0.1 + vec![ + "claude-sonnet-4".to_string(), + "claude-sonnet-4-20250514".to_string(), + ], + ), + ); + + // claude-3-7-sonnet : $3 input / $15 output + custom_models.insert( + "claude-3-7-sonnet".to_string(), + ModelPrice::new( + "anthropic".to_string(), + 3.0, + 15.0, + Some(3.75), // Cache write: 3.0 * 1.25 + Some(0.3), // Cache read: 3.0 * 0.1 + vec![ + "claude-3-7-sonnet".to_string(), + "claude-3-7-sonnet-20250219".to_string(), + "claude-3-sonnet-3-7".to_string(), + "sonnet-3.7".to_string(), + ], + ), + ); + + // Claude Haiku 4.5: $1 input / $5 output + custom_models.insert( + "claude-haiku-4.5".to_string(), + ModelPrice::new( + "anthropic".to_string(), + 1.0, + 5.0, + Some(1.25), // Cache write: 1.0 * 1.25 + Some(0.1), // Cache read: 1.0 * 0.1 + vec![ + "claude-haiku-4.5".to_string(), + "claude-haiku-4-5".to_string(), + "claude-haiku-4-5-20251001".to_string(), + ], + ), + ); + + // Claude Haiku 3.5: $0.8 input / $4 output + custom_models.insert( + "claude-haiku-3.5".to_string(), + ModelPrice::new( + "anthropic".to_string(), + 0.8, + 4.0, + Some(1.0), // Cache write: 0.8 * 1.25 + Some(0.08), // Cache read: 0.8 * 0.1 + vec![ + "claude-haiku-3.5".to_string(), + "claude-haiku-3-5".to_string(), + "claude-3-5-haiku-20241022".to_string(), + ], + ), + ); + + PricingTemplate::new( + "builtin_claude".to_string(), + "内置Claude价格".to_string(), + "Anthropic 官方定价,包含 8 个 Claude 模型".to_string(), + "1.0".to_string(), + vec![], // 内置模板不使用继承 + custom_models, + vec!["official".to_string(), "claude".to_string()], + true, // 标记为内置预设模板 + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_builtin_template() { + let template = builtin_claude_official_template(); + + // 验证基本信息 + assert_eq!(template.id, "builtin_claude"); + assert!(template.is_default_preset); + assert!(template.is_full_custom()); + + // 验证包含 8 个模型 + assert_eq!(template.custom_models.len(), 8); + + // 验证 Opus 4.5 价格 + let opus_4_5 = template.custom_models.get("claude-opus-4.5").unwrap(); + assert_eq!(opus_4_5.provider, "anthropic"); + assert_eq!(opus_4_5.input_price_per_1m, 5.0); + assert_eq!(opus_4_5.output_price_per_1m, 25.0); + assert_eq!(opus_4_5.cache_write_price_per_1m, Some(6.25)); + assert_eq!(opus_4_5.cache_read_price_per_1m, Some(0.5)); + assert_eq!(opus_4_5.aliases.len(), 4); + + // 验证 Sonnet 4.5 价格 + let sonnet_4_5 = template.custom_models.get("claude-sonnet-4.5").unwrap(); + assert_eq!(sonnet_4_5.input_price_per_1m, 3.0); + assert_eq!(sonnet_4_5.output_price_per_1m, 15.0); + assert_eq!(sonnet_4_5.cache_write_price_per_1m, Some(3.75)); + assert_eq!(sonnet_4_5.cache_read_price_per_1m, Some(0.3)); + + // // 验证 Claude 3.5 Sonnet (旧版本) 价格 + // let sonnet_3_5 = template.custom_models.get("claude-3-5-sonnet").unwrap(); + // assert_eq!(sonnet_3_5.input_price_per_1m, 3.0); + // assert_eq!(sonnet_3_5.output_price_per_1m, 15.0); + // assert_eq!(sonnet_3_5.cache_write_price_per_1m, Some(3.75)); + // assert_eq!(sonnet_3_5.cache_read_price_per_1m, Some(0.3)); + // assert!(sonnet_3_5 + // .aliases + // .contains(&"claude-sonnet-4-5-20250929".to_string())); + + // 验证 Haiku 3.5 价格 + let haiku_3_5 = template.custom_models.get("claude-haiku-3.5").unwrap(); + assert_eq!(haiku_3_5.input_price_per_1m, 0.8); + assert_eq!(haiku_3_5.output_price_per_1m, 4.0); + assert_eq!(haiku_3_5.cache_write_price_per_1m, Some(1.0)); + assert_eq!(haiku_3_5.cache_read_price_per_1m, Some(0.08)); + } + + #[test] + fn test_builtin_template_aliases() { + let template = builtin_claude_official_template(); + + // 验证 Sonnet 4.5 的别名 + let sonnet_4_5 = template.custom_models.get("claude-sonnet-4.5").unwrap(); + assert!(sonnet_4_5 + .aliases + .contains(&"claude-sonnet-4.5".to_string())); + assert!(sonnet_4_5 + .aliases + .contains(&"claude-sonnet-4-5".to_string())); + assert!(sonnet_4_5 + .aliases + .contains(&"claude-sonnet-4-5-20250929".to_string())); + } + + #[test] + fn test_cache_price_calculations() { + let template = builtin_claude_official_template(); + + // 验证缓存价格计算公式:write = input * 1.25, read = input * 0.1 + for (_, model_price) in template.custom_models.iter() { + let expected_cache_write = + (model_price.input_price_per_1m * 1.25 * 100.0).round() / 100.0; + let expected_cache_read = + (model_price.input_price_per_1m * 0.1 * 100.0).round() / 100.0; + + let actual_cache_write = model_price + .cache_write_price_per_1m + .map(|v| (v * 100.0).round() / 100.0) + .unwrap_or(0.0); + let actual_cache_read = model_price + .cache_read_price_per_1m + .map(|v| (v * 100.0).round() / 100.0) + .unwrap_or(0.0); + + assert_eq!( + actual_cache_write, expected_cache_write, + "Cache write price mismatch for model with input price {}", + model_price.input_price_per_1m + ); + assert_eq!( + actual_cache_read, expected_cache_read, + "Cache read price mismatch for model with input price {}", + model_price.input_price_per_1m + ); + } + } +} diff --git a/src-tauri/src/services/pricing/manager.rs b/src-tauri/src/services/pricing/manager.rs new file mode 100644 index 0000000..a7bb1aa --- /dev/null +++ b/src-tauri/src/services/pricing/manager.rs @@ -0,0 +1,554 @@ +use crate::data::DataManager; +use crate::models::pricing::{DefaultTemplatesConfig, ModelPrice, PricingTemplate}; +use crate::services::pricing::builtin::builtin_claude_official_template; +use crate::utils::precision::price_precision; +use anyhow::{anyhow, Context, Result}; +use lazy_static::lazy_static; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; +use std::sync::Arc; + +#[cfg(test)] +use crate::models::pricing::InheritedModel; + +/// 成本分解结果 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CostBreakdown { + /// 输入部分价格(USD) + #[serde(with = "price_precision")] + pub input_price: f64, + + /// 输出部分价格(USD) + #[serde(with = "price_precision")] + pub output_price: f64, + + /// 缓存写入部分价格(USD) + #[serde(with = "price_precision")] + pub cache_write_price: f64, + + /// 缓存读取部分价格(USD) + #[serde(with = "price_precision")] + pub cache_read_price: f64, + + /// 总成本(USD) + #[serde(with = "price_precision")] + pub total_cost: f64, + + /// 使用的价格模板 ID + pub template_id: String, +} + +lazy_static! { + /// 全局 PricingManager 实例 + pub static ref PRICING_MANAGER: PricingManager = { + PricingManager::init_global().expect("Failed to initialize PricingManager") + }; +} + +/// 价格管理服务 +pub struct PricingManager { + /// DataManager 实例(Arc 包装以支持克隆) + data_manager: Arc, + + /// 价格配置目录路径(保留用于未来扩展) + #[allow(dead_code)] + pricing_dir: PathBuf, + + /// 模板存储目录路径 + templates_dir: PathBuf, + + /// 默认模板配置文件路径 + default_templates_path: PathBuf, +} + +impl PricingManager { + /// 初始化全局实例(用于 lazy_static) + pub fn init_global() -> Result { + let home_dir = dirs::home_dir().ok_or_else(|| anyhow!("无法获取用户主目录"))?; + let base_dir = home_dir.join(".duckcoding"); + + let manager = Self::new(base_dir)?; + manager.initialize()?; + + Ok(manager) + } + + /// 创建新的价格管理服务实例(用于测试或自定义场景) + pub fn new_with_manager(base_dir: PathBuf, data_manager: Arc) -> Self { + let pricing_dir = base_dir.join("pricing"); + let templates_dir = pricing_dir.join("templates"); + let default_templates_path = pricing_dir.join("default_templates.json"); + + Self { + data_manager, + pricing_dir, + templates_dir, + default_templates_path, + } + } + + /// 创建新的价格管理服务实例(使用全局 DataManager) + pub fn new(base_dir: PathBuf) -> Result { + Ok(Self::new_with_manager( + base_dir, + Arc::new(DataManager::new()), + )) + } + + /// 初始化价格配置目录和默认模板 + pub fn initialize(&self) -> Result<()> { + // 创建目录 + std::fs::create_dir_all(&self.templates_dir) + .context("Failed to create templates directory")?; + + // 保存内置 Claude 官方模板 + let builtin_template = builtin_claude_official_template(); + self.save_template(&builtin_template)?; + + // 初始化默认模板配置(如果不存在) + if !self.default_templates_path.exists() { + let mut config = DefaultTemplatesConfig::new(); + config.set_default("claude-code".to_string(), "builtin_claude".to_string()); + config.set_default("codex".to_string(), "builtin_claude".to_string()); + config.set_default("gemini-cli".to_string(), "builtin_claude".to_string()); + + let value = serde_json::to_value(&config) + .context("Failed to serialize default templates config")?; + + self.data_manager + .json() + .write(&self.default_templates_path, &value) + .context("Failed to write default templates config")?; + } + + Ok(()) + } + + /// 获取指定的价格模板 + pub fn get_template(&self, template_id: &str) -> Result { + let template_path = self.templates_dir.join(format!("{}.json", template_id)); + + if !template_path.exists() { + return Err(anyhow!("Template {} not found", template_id)); + } + + let value = self + .data_manager + .json() + .read(&template_path) + .with_context(|| format!("Failed to read template {}", template_id))?; + + serde_json::from_value(value) + .with_context(|| format!("Failed to parse template {}", template_id)) + } + + /// 列出所有价格模板 + pub fn list_templates(&self) -> Result> { + let mut templates = Vec::new(); + + if !self.templates_dir.exists() { + return Ok(templates); + } + + let entries = + std::fs::read_dir(&self.templates_dir).context("Failed to read templates directory")?; + + for entry in entries.flatten() { + let path = entry.path(); + if path.extension().and_then(|s| s.to_str()) == Some("json") { + if let Ok(value) = self.data_manager.json().read(&path) { + if let Ok(template) = serde_json::from_value::(value) { + templates.push(template); + } + } + } + } + + Ok(templates) + } + + /// 保存价格模板 + pub fn save_template(&self, template: &PricingTemplate) -> Result<()> { + let template_path = self.templates_dir.join(format!("{}.json", template.id)); + + let value = serde_json::to_value(template) + .with_context(|| format!("Failed to serialize template {}", template.id))?; + + self.data_manager + .json() + .write(&template_path, &value) + .with_context(|| format!("Failed to save template {}", template.id)) + } + + /// 删除价格模板 + pub fn delete_template(&self, template_id: &str) -> Result<()> { + let template_path = self.templates_dir.join(format!("{}.json", template_id)); + + if !template_path.exists() { + return Err(anyhow!("Template {} not found", template_id)); + } + + // 不允许删除内置预设模板 + if let Ok(template) = self.get_template(template_id) { + if template.is_default_preset { + return Err(anyhow!("Cannot delete built-in preset template")); + } + } + + std::fs::remove_file(&template_path) + .with_context(|| format!("Failed to delete template {}", template_id)) + } + + /// 设置工具的默认模板 + pub fn set_default_template(&self, tool_id: &str, template_id: &str) -> Result<()> { + // 验证模板是否存在 + self.get_template(template_id)?; + + let mut config = self.get_default_templates_config()?; + config.set_default(tool_id.to_string(), template_id.to_string()); + + let value = serde_json::to_value(&config) + .context("Failed to serialize default templates config")?; + + self.data_manager + .json() + .write(&self.default_templates_path, &value) + .context("Failed to update default templates config") + } + + /// 获取工具的默认模板 + pub fn get_default_template(&self, tool_id: &str) -> Result { + let config = self.get_default_templates_config()?; + + let template_id = config + .get_default(tool_id) + .ok_or_else(|| anyhow!("No default template set for tool {}", tool_id))?; + + self.get_template(template_id) + } + + /// 获取默认模板配置 + fn get_default_templates_config(&self) -> Result { + if !self.default_templates_path.exists() { + return Ok(DefaultTemplatesConfig::new()); + } + + let value = self + .data_manager + .json() + .read(&self.default_templates_path) + .context("Failed to read default templates config")?; + + serde_json::from_value(value).context("Failed to parse default templates config") + } + + /// 计算成本(核心方法) + /// + /// # 参数 + /// + /// - `template_id`: 价格模板 ID(None 时使用工具默认模板) + /// - `model`: 模型名称 + /// - `input_tokens`: 输入 Token 数量 + /// - `output_tokens`: 输出 Token 数量 + /// - `cache_creation_tokens`: 缓存创建 Token 数量 + /// - `cache_read_tokens`: 缓存读取 Token 数量 + /// + /// # 返回 + /// + /// 成本分解结果 + pub fn calculate_cost( + &self, + template_id: Option<&str>, + model: &str, + input_tokens: i64, + output_tokens: i64, + cache_creation_tokens: i64, + cache_read_tokens: i64, + ) -> Result { + // 1. 获取模板 + let template = if let Some(id) = template_id { + self.get_template(id)? + } else { + // 使用 claude-code 的默认模板作为回退 + self.get_default_template("claude-code")? + }; + + // 2. 解析模型价格(别名 → 继承 → 倍率) + let model_price = self.resolve_model_price(&template, model)?; + + // 3. 计算各部分价格 + let input_price = input_tokens as f64 * model_price.input_price_per_1m / 1_000_000.0; + let output_price = output_tokens as f64 * model_price.output_price_per_1m / 1_000_000.0; + let cache_write_price = cache_creation_tokens as f64 + * model_price.cache_write_price_per_1m.unwrap_or(0.0) + / 1_000_000.0; + let cache_read_price = cache_read_tokens as f64 + * model_price.cache_read_price_per_1m.unwrap_or(0.0) + / 1_000_000.0; + + // 4. 计算总成本 + let total_cost = input_price + output_price + cache_write_price + cache_read_price; + + Ok(CostBreakdown { + input_price, + output_price, + cache_write_price, + cache_read_price, + total_cost, + template_id: template.id.clone(), + }) + } + + /// 解析模型价格(支持别名、继承、倍率) + fn resolve_model_price(&self, template: &PricingTemplate, model: &str) -> Result { + // 1. 优先查找自定义模型(直接匹配) + if let Some(price) = template.custom_models.get(model) { + return Ok(price.clone()); + } + + // 2. 别名匹配自定义模型 + for price in template.custom_models.values() { + if price.aliases.contains(&model.to_string()) { + return Ok(price.clone()); + } + } + + // 3. 查找继承配置(支持别名匹配) + for inherited in &template.inherited_models { + // 加载源模板并获取基础价格(包括别名信息) + if let Ok(source_template) = self.get_template(&inherited.source_template_id) { + if let Ok(base_price) = + self.resolve_model_price(&source_template, &inherited.model_name) + { + // 检查请求的模型名是否匹配模型名或别名 + if inherited.model_name == model + || base_price.aliases.contains(&model.to_string()) + { + // 应用倍率 + return Ok(ModelPrice { + provider: base_price.provider, + input_price_per_1m: base_price.input_price_per_1m + * inherited.multiplier, + output_price_per_1m: base_price.output_price_per_1m + * inherited.multiplier, + cache_write_price_per_1m: base_price + .cache_write_price_per_1m + .map(|p| p * inherited.multiplier), + cache_read_price_per_1m: base_price + .cache_read_price_per_1m + .map(|p| p * inherited.multiplier), + currency: base_price.currency, + aliases: base_price.aliases, + }); + } + } + } + } + + Err(anyhow!( + "Model {} not found in template {}", + model, + template.id + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::{tempdir, TempDir}; + + fn create_test_manager() -> (PricingManager, TempDir) { + let dir = tempdir().unwrap(); + let data_manager = Arc::new(DataManager::new()); + let manager = PricingManager::new_with_manager(dir.path().to_path_buf(), data_manager); + manager.initialize().unwrap(); + (manager, dir) + } + + #[test] + fn test_initialize() { + let (manager, _dir) = create_test_manager(); + + // 验证目录创建 + assert!(manager.pricing_dir.exists()); + assert!(manager.templates_dir.exists()); + assert!(manager.default_templates_path.exists()); + + // 验证内置模板存在 + let template = manager.get_template("builtin_claude").unwrap(); + assert_eq!(template.id, "builtin_claude"); + assert!(template.is_default_preset); + } + + #[test] + fn test_resolve_model_price_with_alias() { + let (manager, _dir) = create_test_manager(); + let template = manager.get_template("builtin_claude").unwrap(); + + // 测试直接匹配 + let price1 = manager + .resolve_model_price(&template, "claude-sonnet-4.5") + .unwrap(); + assert_eq!(price1.input_price_per_1m, 3.0); + + // 测试别名匹配 + let price2 = manager + .resolve_model_price(&template, "claude-sonnet-4-5") + .unwrap(); + assert_eq!(price2.input_price_per_1m, 3.0); + + let price3 = manager + .resolve_model_price(&template, "claude-sonnet-4-5-20250929") + .unwrap(); + assert_eq!(price3.input_price_per_1m, 3.0); + } + + #[test] + fn test_calculate_cost_breakdown() { + let (manager, _dir) = create_test_manager(); + + let breakdown = manager + .calculate_cost( + Some("builtin_claude"), + "claude-sonnet-4.5", + 1000, // input + 500, // output + 100, // cache write + 200, // cache read + ) + .unwrap(); + + // 验证各部分价格 + // input: 1000 * 3.0 / 1_000_000 = 0.003 + assert_eq!(breakdown.input_price, 0.003); + + // output: 500 * 15.0 / 1_000_000 = 0.0075 + assert_eq!(breakdown.output_price, 0.0075); + + // cache write: 100 * 3.75 / 1_000_000 = 0.000375 + assert_eq!(breakdown.cache_write_price, 0.000375); + + // cache read: 200 * 0.3 / 1_000_000 = 0.00006 + assert_eq!(breakdown.cache_read_price, 0.00006); + + // total: 0.003 + 0.0075 + 0.000375 + 0.00006 = 0.011235 + // 使用 assert_eq! 直接比较,因为各部分价格已验证精确 + let expected_total = breakdown.input_price + + breakdown.output_price + + breakdown.cache_write_price + + breakdown.cache_read_price; + assert_eq!(breakdown.total_cost, expected_total); + + assert_eq!(breakdown.template_id, "builtin_claude"); + } + + #[test] + fn test_multi_source_inheritance() { + let (manager, _dir) = create_test_manager(); + + // 创建一个使用多源继承的模板 + let template = PricingTemplate::new( + "test_multi_source".to_string(), + "Test Multi Source".to_string(), + "Test".to_string(), + "1.0".to_string(), + vec![ + InheritedModel::new( + "claude-sonnet-4.5".to_string(), + "builtin_claude".to_string(), + 1.1, + ), + InheritedModel::new( + "claude-opus-4.5".to_string(), + "builtin_claude".to_string(), + 1.5, + ), + ], + Default::default(), + vec![], + false, + ); + + manager.save_template(&template).unwrap(); + + // 测试 Sonnet 4.5 的继承(1.1 倍率) + let price1 = manager + .resolve_model_price(&template, "claude-sonnet-4.5") + .unwrap(); + assert_eq!(price1.input_price_per_1m, 3.0 * 1.1); + assert_eq!(price1.output_price_per_1m, 15.0 * 1.1); + + // 测试 Opus 4.5 的继承(1.5 倍率) + let price2 = manager + .resolve_model_price(&template, "claude-opus-4.5") + .unwrap(); + assert_eq!(price2.input_price_per_1m, 5.0 * 1.5); + assert_eq!(price2.output_price_per_1m, 25.0 * 1.5); + } + + #[test] + fn test_default_template_fallback() { + let (manager, _dir) = create_test_manager(); + + // 不指定模板 ID,应使用默认模板 + let breakdown = manager + .calculate_cost(None, "claude-sonnet-4.5", 1000, 500, 0, 0) + .unwrap(); + + assert_eq!(breakdown.template_id, "builtin_claude"); + assert_eq!(breakdown.input_price, 0.003); + assert_eq!(breakdown.output_price, 0.0075); + } + + #[test] + fn test_set_and_get_default_template() { + let (manager, _dir) = create_test_manager(); + + // 设置默认模板 + manager + .set_default_template("test-tool", "builtin_claude") + .unwrap(); + + // 获取默认模板 + let template = manager.get_default_template("test-tool").unwrap(); + assert_eq!(template.id, "builtin_claude"); + } + + #[test] + fn test_delete_template() { + let (manager, _dir) = create_test_manager(); + + // 创建测试模板 + let template = PricingTemplate::new( + "test_delete".to_string(), + "Test Delete".to_string(), + "Test".to_string(), + "1.0".to_string(), + vec![], + Default::default(), + vec![], + false, + ); + + manager.save_template(&template).unwrap(); + assert!(manager.get_template("test_delete").is_ok()); + + // 删除模板 + manager.delete_template("test_delete").unwrap(); + assert!(manager.get_template("test_delete").is_err()); + } + + #[test] + fn test_cannot_delete_builtin_template() { + let (manager, _dir) = create_test_manager(); + + // 尝试删除内置模板应该失败 + let result = manager.delete_template("builtin_claude"); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Cannot delete built-in preset template")); + } +} diff --git a/src-tauri/src/services/pricing/mod.rs b/src-tauri/src/services/pricing/mod.rs new file mode 100644 index 0000000..17f324d --- /dev/null +++ b/src-tauri/src/services/pricing/mod.rs @@ -0,0 +1,5 @@ +pub mod builtin; +pub mod manager; + +pub use builtin::*; +pub use manager::*; diff --git a/src-tauri/src/services/profile_manager/manager.rs b/src-tauri/src/services/profile_manager/manager.rs index 175b1a5..821bc34 100644 --- a/src-tauri/src/services/profile_manager/manager.rs +++ b/src-tauri/src/services/profile_manager/manager.rs @@ -94,6 +94,17 @@ impl ProfileManager { // ==================== Claude Code ==================== pub fn save_claude_profile(&self, name: &str, api_key: String, base_url: String) -> Result<()> { + self.save_claude_profile_with_template(name, api_key, base_url, None) + } + + /// 保存 Claude Profile(支持价格模板) + pub fn save_claude_profile_with_template( + &self, + name: &str, + api_key: String, + base_url: String, + pricing_template_id: Option, + ) -> Result<()> { // 保留字校验 validate_profile_name(name)?; @@ -107,6 +118,8 @@ impl ProfileManager { if !base_url.is_empty() { existing.base_url = base_url; } + // Phase 6: 更新价格模板 ID(允许清空) + existing.pricing_template_id = pricing_template_id; existing.updated_at = Utc::now(); existing.clone() } else { @@ -122,6 +135,7 @@ impl ProfileManager { raw_settings: None, raw_config_json: None, source: ProfileSource::Custom, + pricing_template_id, // Phase 6: 价格模板 ID } }; @@ -175,6 +189,18 @@ impl ProfileManager { api_key: String, base_url: String, wire_api: Option, + ) -> Result<()> { + self.save_codex_profile_with_template(name, api_key, base_url, wire_api, None) + } + + /// 保存 Codex Profile(支持价格模板) + pub fn save_codex_profile_with_template( + &self, + name: &str, + api_key: String, + base_url: String, + wire_api: Option, + pricing_template_id: Option, ) -> Result<()> { // 保留字校验 validate_profile_name(name)?; @@ -192,6 +218,8 @@ impl ProfileManager { if let Some(w) = wire_api { existing.wire_api = w; } + // Phase 6: 更新价格模板 ID(允许清空) + existing.pricing_template_id = pricing_template_id; existing.updated_at = Utc::now(); existing.clone() } else { @@ -208,6 +236,7 @@ impl ProfileManager { raw_config_toml: None, raw_auth_json: None, source: ProfileSource::Custom, + pricing_template_id, // Phase 6: 价格模板 ID } }; @@ -261,6 +290,18 @@ impl ProfileManager { api_key: String, base_url: String, model: Option, + ) -> Result<()> { + self.save_gemini_profile_with_template(name, api_key, base_url, model, None) + } + + /// 保存 Gemini Profile(支持价格模板) + pub fn save_gemini_profile_with_template( + &self, + name: &str, + api_key: String, + base_url: String, + model: Option, + pricing_template_id: Option, ) -> Result<()> { // 保留字校验 validate_profile_name(name)?; @@ -280,6 +321,8 @@ impl ProfileManager { existing.model = Some(m); } } + // Phase 6: 更新价格模板 ID(允许清空) + existing.pricing_template_id = pricing_template_id; existing.updated_at = Utc::now(); existing.clone() } else { @@ -296,6 +339,7 @@ impl ProfileManager { raw_settings: None, raw_env: None, source: ProfileSource::Custom, + pricing_template_id, // Phase 6: 价格模板 ID } }; @@ -411,13 +455,41 @@ impl ProfileManager { // 应用到原生配置文件 self.apply_to_native(tool_id, profile_name)?; - // 读取应用后的配置并保存快照 + // 读取应用后的配置并保存快照(为每个工具读取所有配置文件) let tool = crate::models::Tool::by_id(tool_id) .ok_or_else(|| anyhow!("未找到工具: {}", tool_id))?; - let config_path = tool.config_dir.join(&tool.config_file); - if config_path.exists() { - let manager = crate::data::DataManager::new(); - let snapshot = manager.json_uncached().read(&config_path)?; + let manager = crate::data::DataManager::new(); + let mut files_snapshot = std::collections::HashMap::new(); + + for filename in tool.config_files() { + let config_path = tool.config_dir.join(&filename); + if !config_path.exists() { + continue; + } + + let content = if filename.ends_with(".json") { + // JSON 文件:直接读取 + manager.json_uncached().read(&config_path)? + } else if filename.ends_with(".toml") { + // TOML 文件:读取后转换为 JSON + let doc = manager.toml().read_document(&config_path)?; + let toml_str = doc.to_string(); + let toml_value: toml::Value = + toml::from_str(&toml_str).map_err(|e| anyhow!("TOML 解析失败: {}", e))?; + serde_json::to_value(toml_value)? + } else if filename.ends_with(".env") || filename == ".env" { + // ENV 文件:读取为 HashMap 后转换为 JSON + let env_map = manager.env().read(&config_path)?; + serde_json::to_value(env_map)? + } else { + continue; + }; + + files_snapshot.insert(filename.clone(), content); + } + + if !files_snapshot.is_empty() { + let snapshot = serde_json::to_value(files_snapshot)?; self.save_native_snapshot(tool_id, snapshot)?; tracing::debug!("已保存 Profile 快照: {} / {}", tool_id, profile_name); } @@ -501,6 +573,7 @@ impl ProfileManager { raw_settings: None, raw_config_json: None, source: ProfileSource::Custom, + pricing_template_id: None, } }; @@ -549,6 +622,7 @@ impl ProfileManager { raw_config_toml: None, raw_auth_json: None, source: ProfileSource::Custom, + pricing_template_id: None, } }; @@ -599,6 +673,7 @@ impl ProfileManager { raw_settings: None, raw_env: None, source: ProfileSource::Custom, + pricing_template_id: None, } }; diff --git a/src-tauri/src/services/profile_manager/types.rs b/src-tauri/src/services/profile_manager/types.rs index 49d1d0f..0c166b1 100644 --- a/src-tauri/src/services/profile_manager/types.rs +++ b/src-tauri/src/services/profile_manager/types.rs @@ -80,6 +80,9 @@ pub struct ClaudeProfile { pub raw_settings: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub raw_config_json: Option, + /// 价格模板 ID(用于成本计算) + #[serde(default, skip_serializing_if = "Option::is_none")] + pub pricing_template_id: Option, } /// Codex Profile @@ -97,6 +100,9 @@ pub struct CodexProfile { pub raw_config_toml: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub raw_auth_json: Option, + /// 价格模板 ID(用于成本计算) + #[serde(default, skip_serializing_if = "Option::is_none")] + pub pricing_template_id: Option, } fn default_codex_wire_api() -> String { @@ -118,6 +124,9 @@ pub struct GeminiProfile { pub raw_settings: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub raw_env: Option, + /// 价格模板 ID(用于成本计算) + #[serde(default, skip_serializing_if = "Option::is_none")] + pub pricing_template_id: Option, } // ==================== profiles.json 结构 ==================== @@ -307,6 +316,9 @@ pub struct ProfileDescriptor { pub provider: Option, #[serde(skip_serializing_if = "Option::is_none")] pub model: Option, + /// 价格模板 ID(用于成本计算) + #[serde(default, skip_serializing_if = "Option::is_none")] + pub pricing_template_id: Option, } impl ProfileDescriptor { @@ -334,6 +346,7 @@ impl ProfileDescriptor { switched_at, provider: None, model: None, + pricing_template_id: profile.pricing_template_id.clone(), } } @@ -361,6 +374,7 @@ impl ProfileDescriptor { switched_at, provider: Some(profile.wire_api.clone()), // 前端仍使用 provider 字段名 model: None, + pricing_template_id: profile.pricing_template_id.clone(), } } @@ -388,6 +402,7 @@ impl ProfileDescriptor { switched_at, provider: None, model: profile.model.clone(), + pricing_template_id: profile.pricing_template_id.clone(), } } } diff --git a/src-tauri/src/services/proxy/headers/claude_processor.rs b/src-tauri/src/services/proxy/headers/claude_processor.rs index 25d9fe0..5272fd8 100644 --- a/src-tauri/src/services/proxy/headers/claude_processor.rs +++ b/src-tauri/src/services/proxy/headers/claude_processor.rs @@ -40,8 +40,13 @@ impl RequestProcessor for ClaudeHeadersProcessor { let timestamp = chrono::Utc::now().timestamp(); // 查询会话配置 - if let Ok(Some((config_name, session_url, session_api_key))) = - SESSION_MANAGER.get_session_config(user_id) + if let Ok(Some(( + config_name, + _custom_profile_name, + session_url, + session_api_key, + _session_pricing_template_id, + ))) = SESSION_MANAGER.get_session_config(user_id) { // 如果是自定义配置且有 URL 和 API Key,使用数据库的配置 if config_name == "custom" @@ -129,4 +134,58 @@ impl RequestProcessor for ClaudeHeadersProcessor { // Claude Code 不需要特殊的响应处理 // 使用默认实现即可 + + /// 提取模型名称 + fn extract_model(&self, request_body: &[u8]) -> Option { + if request_body.is_empty() { + return None; + } + + // 尝试解析请求体 JSON + if let Ok(json_body) = serde_json::from_slice::(request_body) { + // Claude API 的模型字段在顶层 + json_body + .get("model") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + } else { + None + } + } + + /// Claude Code 的请求日志记录实现 + /// + /// 使用统一的日志记录架构,自动处理所有错误场景 + async fn record_request_log( + &self, + client_ip: &str, + config_name: &str, + proxy_pricing_template_id: Option<&str>, + request_body: &[u8], + response_status: u16, + response_body: &[u8], + is_sse: bool, + _response_time_ms: Option, + ) -> Result<()> { + use crate::services::proxy::log_recorder::{ + LogRecorder, RequestLogContext, ResponseParser, + }; + + // 1. 创建请求上下文(一次性提取所有信息) + let context = RequestLogContext::from_request( + self.tool_id(), + config_name, + client_ip, + proxy_pricing_template_id, + request_body, + ); + + // 2. 解析响应 + let parsed = ResponseParser::parse(response_body, response_status, is_sse); + + // 3. 记录日志(自动处理成功/失败/解析错误) + LogRecorder::record(&context, response_status, parsed).await?; + + Ok(()) + } } diff --git a/src-tauri/src/services/proxy/headers/mod.rs b/src-tauri/src/services/proxy/headers/mod.rs index 1326002..4730e01 100644 --- a/src-tauri/src/services/proxy/headers/mod.rs +++ b/src-tauri/src/services/proxy/headers/mod.rs @@ -91,6 +91,53 @@ pub trait RequestProcessor: Send + Sync + std::fmt::Debug { fn should_process_response(&self) -> bool { false } + + /// 提取模型名称(用于成本计算) + /// + /// # 参数 + /// - `request_body`: 请求体字节数组 + /// + /// # 返回 + /// - `Some(String)`: 成功提取模型名称 + /// - `None`: 未找到模型名称或解析失败 + /// + /// # 默认实现 + /// 默认返回 None(不提取模型) + fn extract_model(&self, _request_body: &[u8]) -> Option { + None + } + + /// 记录请求日志(包括 Token 统计) + /// + /// 不同的 AI 工具有不同的数据格式和会话 ID 提取方式, + /// 因此每个工具需要实现自己的日志记录逻辑。 + /// + /// # 参数 + /// - `client_ip`: 客户端 IP 地址 + /// - `config_name`: 配置名称("global" 或 Profile 名称) + /// - `proxy_pricing_template_id`: 代理配置的价格模板 ID + /// - `request_body`: 请求体字节数组 + /// - `response_status`: HTTP 响应状态码 + /// - `response_body`: 响应体字节数组 + /// - `is_sse`: 是否为 SSE 流式响应 + /// - `response_time_ms`: 响应时间(毫秒) + /// + /// # 默认实现 + /// 默认不记录日志(空操作) + #[allow(clippy::too_many_arguments)] + async fn record_request_log( + &self, + _client_ip: &str, + _config_name: &str, + _proxy_pricing_template_id: Option<&str>, + _request_body: &[u8], + _response_status: u16, + _response_body: &[u8], + _is_sse: bool, + _response_time_ms: Option, + ) -> Result<()> { + Ok(()) + } } /// 创建请求处理器工厂函数 diff --git a/src-tauri/src/services/proxy/log_recorder/context.rs b/src-tauri/src/services/proxy/log_recorder/context.rs new file mode 100644 index 0000000..14d18cc --- /dev/null +++ b/src-tauri/src/services/proxy/log_recorder/context.rs @@ -0,0 +1,114 @@ +// 请求上下文提取层 +// +// 职责:在请求处理早期一次性提取所有必要信息,避免重复解析 + +use crate::services::session::manager::SESSION_MANAGER; +use crate::services::session::models::ProxySession; +use std::time::Instant; + +/// 请求日志上下文(在请求处理早期提取) +#[derive(Debug, Clone)] +pub struct RequestLogContext { + pub tool_id: String, + pub session_id: String, // 从 request_body 提取 + pub config_name: String, + pub client_ip: String, + pub pricing_template_id: Option, // 会话级 > 代理级 + pub model: Option, // 从 request_body 提取 + pub is_stream: bool, // 从 request_body 提取 stream 字段 + pub request_body: Vec, // 保留原始请求体 + pub start_time: Instant, +} + +impl RequestLogContext { + /// 从请求创建上下文(早期提取,仅解析一次) + pub fn from_request( + tool_id: &str, + config_name: &str, + client_ip: &str, + proxy_pricing_template_id: Option<&str>, + request_body: &[u8], + ) -> Self { + // 提取 user_id(完整)、display_id(用于日志)、model 和 stream(仅解析一次) + let (user_id, session_id, model, is_stream) = if !request_body.is_empty() { + match serde_json::from_slice::(request_body) { + Ok(json) => { + // 提取完整 user_id(用于查询配置) + let user_id = json["metadata"]["user_id"] + .as_str() + .map(|s| s.to_string()) + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); + + // 提取 display_id(用于存储日志) + let session_id = ProxySession::extract_display_id(&user_id) + .unwrap_or_else(|| user_id.clone()); + + let model = json["model"].as_str().map(|s| s.to_string()); + let is_stream = json["stream"].as_bool().unwrap_or(false); + (user_id, session_id, model, is_stream) + } + Err(_) => { + let fallback_id = uuid::Uuid::new_v4().to_string(); + (fallback_id.clone(), fallback_id, None, false) + } + } + } else { + let fallback_id = uuid::Uuid::new_v4().to_string(); + (fallback_id.clone(), fallback_id, None, false) + }; + + // 查询会话级别的配置(优先级:会话 > 代理),使用完整 user_id 查询 + let (config_name, pricing_template_id) = + Self::resolve_session_config(&user_id, config_name, proxy_pricing_template_id); + + Self { + tool_id: tool_id.to_string(), + session_id, + config_name, + client_ip: client_ip.to_string(), + pricing_template_id, + model, + is_stream, + request_body: request_body.to_vec(), + start_time: Instant::now(), + } + } + + /// 解析会话级配置(同时提取 config_name 和 pricing_template_id) + fn resolve_session_config( + session_id: &str, + proxy_config_name: &str, + proxy_template_id: Option<&str>, + ) -> (String, Option) { + // 查询会话配置 + if let Ok(Some(( + config_name, + custom_profile_name, + session_url, + session_api_key, + session_pricing_template_id, + ))) = SESSION_MANAGER.get_session_config(session_id) + { + // 判断是否为自定义配置:config_name == "custom" 且 URL、API Key、pricing_template_id 都不为空 + if config_name == "custom" + && !session_url.is_empty() + && !session_api_key.is_empty() + && session_pricing_template_id.is_some() + { + // 使用会话级配置:custom_profile_name 作为配置名(如果存在) + let final_config_name = custom_profile_name.unwrap_or_else(|| "custom".to_string()); + return (final_config_name, session_pricing_template_id); + } + } + + // 回退到代理级配置(包括会话存在但不是自定义配置的情况) + ( + proxy_config_name.to_string(), + proxy_template_id.map(|s| s.to_string()), + ) + } + + pub fn elapsed_ms(&self) -> i64 { + self.start_time.elapsed().as_millis() as i64 + } +} diff --git a/src-tauri/src/services/proxy/log_recorder/mod.rs b/src-tauri/src/services/proxy/log_recorder/mod.rs new file mode 100644 index 0000000..a598a39 --- /dev/null +++ b/src-tauri/src/services/proxy/log_recorder/mod.rs @@ -0,0 +1,16 @@ +// 日志记录模块 - 统一的请求日志记录架构 +// +// 职责: +// - 提取请求上下文 +// - 解析响应数据(SSE/JSON) +// - 提取 Token 统计 +// - 计算成本 +// - 记录到数据库 + +mod context; +mod parser; +mod recorder; + +pub use context::RequestLogContext; +pub use parser::{ParsedResponse, ResponseParser}; +pub use recorder::LogRecorder; diff --git a/src-tauri/src/services/proxy/log_recorder/parser.rs b/src-tauri/src/services/proxy/log_recorder/parser.rs new file mode 100644 index 0000000..f91468a --- /dev/null +++ b/src-tauri/src/services/proxy/log_recorder/parser.rs @@ -0,0 +1,84 @@ +// 响应解析层 +// +// 职责:安全解析响应数据,区分 SSE 流式和 JSON 非流式,永不 panic + +use serde_json::Value; + +/// 解析后的响应数据 +#[derive(Debug)] +pub enum ParsedResponse { + /// SSE 流式响应(已提取的 data 块) + Sse { data_lines: Vec }, + /// JSON 响应(已解析的 JSON) + Json { data: Value }, + /// 空响应(上游失败或连接中断) + Empty, + /// 解析失败(保留原始字节和错误信息) + ParseError { + raw_bytes: Vec, + error: String, + response_type: &'static str, // "sse" 或 "json" + }, +} + +pub struct ResponseParser; + +impl ResponseParser { + /// 安全解析响应(区分 SSE/JSON,永不 panic) + pub fn parse(response_body: &[u8], response_status: u16, is_sse: bool) -> ParsedResponse { + // 1. 检查空响应或无状态码 + if response_body.is_empty() || response_status == 0 { + return ParsedResponse::Empty; + } + + // 2. 根据类型解析 + if is_sse { + Self::parse_sse(response_body) + } else { + Self::parse_json(response_body) + } + } + + /// 解析 SSE 流式响应 + /// + /// SSE 格式示例: + /// ``` + /// data: {"type":"message_start","message":{...}} + /// + /// data: {"type":"content_block_delta",...} + /// + /// data: {"type":"message_delta","delta":{...},"usage":{...}} + /// ``` + fn parse_sse(response_body: &[u8]) -> ParsedResponse { + let body_str = String::from_utf8_lossy(response_body); + let data_lines: Vec = body_str + .lines() + .filter(|line| line.starts_with("data: ")) + .map(|line| line.trim_start_matches("data: ").to_string()) + .filter(|line| !line.is_empty() && line != "[DONE]") // 过滤空行和结束标记 + .collect(); + + if data_lines.is_empty() { + // SSE 流为空或仅包含无效数据 + return ParsedResponse::ParseError { + raw_bytes: response_body.to_vec(), + error: "SSE 流不包含有效的 data 块".to_string(), + response_type: "sse", + }; + } + + ParsedResponse::Sse { data_lines } + } + + /// 解析 JSON 响应 + fn parse_json(response_body: &[u8]) -> ParsedResponse { + match serde_json::from_slice::(response_body) { + Ok(data) => ParsedResponse::Json { data }, + Err(e) => ParsedResponse::ParseError { + raw_bytes: response_body.to_vec(), + error: e.to_string(), + response_type: "json", + }, + } + } +} diff --git a/src-tauri/src/services/proxy/log_recorder/recorder.rs b/src-tauri/src/services/proxy/log_recorder/recorder.rs new file mode 100644 index 0000000..20198f0 --- /dev/null +++ b/src-tauri/src/services/proxy/log_recorder/recorder.rs @@ -0,0 +1,257 @@ +// 日志记录层 +// +// 职责:统一的日志记录接口,处理成功/失败/解析错误等所有场景 + +use super::{ParsedResponse, RequestLogContext}; +use crate::services::token_stats::manager::{ResponseData, TokenStatsManager}; +use anyhow::Result; +use hyper::StatusCode; + +pub struct LogRecorder; + +impl LogRecorder { + /// 记录请求日志(统一入口) + pub async fn record( + context: &RequestLogContext, + response_status: u16, + parsed: ParsedResponse, + ) -> Result<()> { + // 1. 检查 HTTP 状态码 + let status_code = + StatusCode::from_u16(response_status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + + if status_code.is_client_error() || status_code.is_server_error() { + // HTTP 4xx/5xx 错误 + Self::record_http_error(context, response_status, &status_code).await + } else { + // HTTP 2xx/3xx 或无状态码,根据解析结果处理 + match parsed { + ParsedResponse::Sse { data_lines } => { + // SSE 成功响应 + Self::record_sse_success(context, data_lines).await + } + ParsedResponse::Json { data } => { + // JSON 成功响应 + Self::record_json_success(context, data).await + } + ParsedResponse::Empty => { + // 空响应(上游失败) + Self::record_upstream_error(context, "上游返回空响应体").await + } + ParsedResponse::ParseError { + error, + response_type, + .. + } => { + // 解析失败 + Self::record_parse_error(context, &error, response_type).await + } + } + } + } + + /// 记录 SSE 成功响应 + async fn record_sse_success( + context: &RequestLogContext, + data_lines: Vec, + ) -> Result<()> { + let manager = TokenStatsManager::get(); + + match manager + .log_request( + &context.tool_id, + &context.session_id, + &context.config_name, + &context.client_ip, + &context.request_body, + ResponseData::Sse(data_lines), + Some(context.elapsed_ms()), + context.pricing_template_id.clone(), + ) + .await + { + Ok(_) => { + tracing::debug!( + tool_id = %context.tool_id, + session_id = %context.session_id, + "SSE 流式响应记录成功" + ); + Ok(()) + } + Err(e) => { + tracing::error!( + tool_id = %context.tool_id, + session_id = %context.session_id, + error = ?e, + "SSE Token 提取失败,记录为 parse_error" + ); + + // Token 提取失败,记录为 parse_error + let error_detail = format!("SSE Token 提取失败: {}", e); + manager + .log_failed_request( + &context.tool_id, + &context.session_id, + &context.config_name, + &context.client_ip, + &context.request_body, + "parse_error", + &error_detail, + "sse", + Some(context.elapsed_ms()), + ) + .await + } + } + } + + /// 记录 JSON 成功响应 + async fn record_json_success( + context: &RequestLogContext, + data: serde_json::Value, + ) -> Result<()> { + let manager = TokenStatsManager::get(); + + match manager + .log_request( + &context.tool_id, + &context.session_id, + &context.config_name, + &context.client_ip, + &context.request_body, + ResponseData::Json(data), + Some(context.elapsed_ms()), + context.pricing_template_id.clone(), + ) + .await + { + Ok(_) => { + tracing::debug!( + tool_id = %context.tool_id, + session_id = %context.session_id, + "JSON 响应记录成功" + ); + Ok(()) + } + Err(e) => { + tracing::error!( + tool_id = %context.tool_id, + session_id = %context.session_id, + error = ?e, + "JSON Token 提取失败,记录为 parse_error" + ); + + // Token 提取失败,记录为 parse_error + let error_detail = format!("JSON Token 提取失败: {}", e); + manager + .log_failed_request( + &context.tool_id, + &context.session_id, + &context.config_name, + &context.client_ip, + &context.request_body, + "parse_error", + &error_detail, + "json", + Some(context.elapsed_ms()), + ) + .await + } + } + } + + /// 记录解析错误 + async fn record_parse_error( + context: &RequestLogContext, + error: &str, + response_type: &str, + ) -> Result<()> { + let error_detail = format!("响应解析失败: {}", error); + + tracing::warn!( + tool_id = %context.tool_id, + session_id = %context.session_id, + response_type = response_type, + error = error, + "响应解析失败" + ); + + TokenStatsManager::get() + .log_failed_request( + &context.tool_id, + &context.session_id, + &context.config_name, + &context.client_ip, + &context.request_body, + "parse_error", + &error_detail, + response_type, + Some(context.elapsed_ms()), + ) + .await + } + + /// 记录上游错误(空响应或连接失败) + pub async fn record_upstream_error(context: &RequestLogContext, detail: &str) -> Result<()> { + tracing::warn!( + tool_id = %context.tool_id, + session_id = %context.session_id, + detail = detail, + is_stream = context.is_stream, + "上游请求失败" + ); + + let response_type = if context.is_stream { "sse" } else { "json" }; + + TokenStatsManager::get() + .log_failed_request( + &context.tool_id, + &context.session_id, + &context.config_name, + &context.client_ip, + &context.request_body, + "upstream_error", + detail, + response_type, // 根据请求体的 stream 字段判断 + Some(context.elapsed_ms()), + ) + .await + } + + /// 记录 HTTP 错误(4xx/5xx) + async fn record_http_error( + context: &RequestLogContext, + status: u16, + status_code: &StatusCode, + ) -> Result<()> { + let error_detail = format!( + "HTTP {}: {}", + status, + status_code.canonical_reason().unwrap_or("Unknown") + ); + + tracing::warn!( + tool_id = %context.tool_id, + session_id = %context.session_id, + status = status, + is_stream = context.is_stream, + "HTTP 错误响应" + ); + + let response_type = if context.is_stream { "sse" } else { "json" }; + + TokenStatsManager::get() + .log_failed_request( + &context.tool_id, + &context.session_id, + &context.config_name, + &context.client_ip, + &context.request_body, + "upstream_error", + &error_detail, + response_type, // 根据请求体的 stream 字段判断 + Some(context.elapsed_ms()), + ) + .await + } +} diff --git a/src-tauri/src/services/proxy/mod.rs b/src-tauri/src/services/proxy/mod.rs index 2e67568..906e019 100644 --- a/src-tauri/src/services/proxy/mod.rs +++ b/src-tauri/src/services/proxy/mod.rs @@ -4,6 +4,7 @@ pub mod config; // 代理配置辅助模块 pub mod headers; +pub mod log_recorder; // 统一日志记录模块 pub mod proxy_instance; pub mod proxy_manager; pub mod proxy_service; diff --git a/src-tauri/src/services/proxy/proxy_instance.rs b/src-tauri/src/services/proxy/proxy_instance.rs index 84b6de2..f7be6d7 100644 --- a/src-tauri/src/services/proxy/proxy_instance.rs +++ b/src-tauri/src/services/proxy/proxy_instance.rs @@ -7,7 +7,7 @@ use anyhow::{Context, Result}; use bytes::Bytes; -use http_body_util::BodyExt; +use http_body_util::{BodyExt, Full}; use hyper::body::{Frame, Incoming}; use hyper::server::conn::http1; use hyper::service::service_fn; @@ -150,6 +150,24 @@ impl ProxyInstance { error = ?e, "接受连接失败" ); + + // 记录连接层错误到数据库(无 session_id) + let manager = + crate::services::token_stats::manager::TokenStatsManager::get(); + let error_detail = format!("连接处理失败: {:?}", e); + let _ = manager + .log_failed_request( + &tool_id, + "connection_error", // 通用会话 ID + "global", + "unknown", // 无法获取客户端 IP + &[], // 无请求体 + "connection_error", + &error_detail, + "unknown", // 无法确定响应类型 + None, // 无响应时间 + ) + .await; } } } @@ -242,6 +260,9 @@ async fn handle_request_inner( own_port: u16, tool_id: &str, ) -> Result> { + // 记录请求开始时间(用于计算响应时间) + let start_time = std::time::Instant::now(); + // 获取配置 let proxy_config = { let cfg = config.read().await; @@ -279,6 +300,35 @@ async fn handle_request_inner( let method = req.method().clone(); let headers = req.headers().clone(); + // 拦截 count_tokens 接口,不转发到上游,直接返回权限错误 + if path == "/v1/messages/count_tokens" { + tracing::warn!("拦截 count_tokens 请求,返回权限错误"); + let error_response = serde_json::json!({ + "type": "error", + "error": { + "type": "permission_error", + "message": "count_tokens endpoint is not enabled for this channel. Please enable it in channel settings." + } + }); + let response_body = serde_json::to_string(&error_response) + .unwrap_or_else(|_| r#"{"type":"error","error":{"type":"internal_error","message":"Failed to serialize error response"}}"#.to_string()); + + return Response::builder() + .status(StatusCode::FORBIDDEN) + .header("content-type", "application/json") + .body(box_body(Full::new(Bytes::from(response_body)))) + .map_err(|e| anyhow::anyhow!("Failed to build count_tokens error response: {}", e)); + } + + // 提取客户端IP(用于日志记录) + let client_ip = req + .headers() + .get("x-forwarded-for") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.split(',').next()) + .unwrap_or("unknown") + .to_string(); + // amp-code 在 processor 内部获取配置,这里传占位符 let base = proxy_config .real_base_url @@ -354,7 +404,45 @@ async fn handle_request_inner( } // 发送请求 - let upstream_res = reqwest_builder.send().await.context("上游请求失败")?; + let upstream_res = match reqwest_builder.send().await { + Ok(res) => res, + Err(e) => { + // 上游请求失败,记录错误到数据库 + let processor_clone = Arc::clone(&processor); + let client_ip_clone = client_ip.clone(); + let config_name_clone = proxy_config + .real_profile_name + .clone() + .unwrap_or_else(|| "default".to_string()); + let proxy_pricing_template_id_clone = proxy_config.pricing_template_id.clone(); + let request_body_clone = processed.body.clone(); + let error_msg = e.to_string(); + + // 从请求体中判断是否为流式请求 + let is_sse = serde_json::from_slice::(&processed.body) + .ok() + .and_then(|json| json.get("stream").and_then(|v| v.as_bool())) + .unwrap_or(false); + + tokio::spawn(async move { + // 调用 record_request_log,传递 response_status=0 标记为上游失败 + let _ = processor_clone + .record_request_log( + &client_ip_clone, + &config_name_clone, + proxy_pricing_template_id_clone.as_deref(), + &request_body_clone, + 0, // response_status=0 标记上游请求失败 + &[], // 空响应体 + is_sse, // 从请求体提取 + Some(start_time.elapsed().as_millis() as i64), + ) + .await; + }); + + return Err(anyhow::anyhow!("上游请求失败: {}", error_msg)); + } + }; // 构建响应 let status = StatusCode::from_u16(upstream_res.status().as_u16()) @@ -377,8 +465,22 @@ async fn handle_request_inner( if is_sse { tracing::debug!(tool_id = %tool_id, "SSE 流式响应"); + + // SSE 流式响应:收集响应体并调用 processor.record_request_log use futures_util::StreamExt; use regex::Regex; + use std::sync::{Arc, Mutex}; + + let config_name = proxy_config + .real_profile_name + .clone() + .unwrap_or_else(|| "default".to_string()); + + let proxy_pricing_template_id = proxy_config.pricing_template_id.clone(); + + // 使用 Arc> 在流处理过程中收集数据 + let sse_chunks = Arc::new(Mutex::new(Vec::new())); + let sse_chunks_clone = Arc::clone(&sse_chunks); let stream = upstream_res.bytes_stream(); @@ -386,7 +488,13 @@ async fn handle_request_inner( let is_amp_code = tool_id == "amp-code"; let prefix_regex = Regex::new(r#""name"\s*:\s*"mcp_([^"]+)""#).ok(); + // 拦截流数据并收集 let mapped_stream = stream.map(move |result| { + if let Ok(chunk) = &result { + if let Ok(mut chunks) = sse_chunks_clone.lock() { + chunks.push(chunk.clone()); + } + } result .map(|bytes| { if is_amp_code { @@ -404,21 +512,104 @@ async fn handle_request_inner( .map_err(|e| Box::new(e) as Box) }); + // 在流结束后异步记录日志 + let processor_clone = Arc::clone(&processor); + let client_ip_clone = client_ip.clone(); + let request_body_clone = processed.body.clone(); + let response_status = status.as_u16(); + let start_time_clone = start_time; // 捕获 start_time 用于计算响应时间 + let proxy_pricing_template_id_clone = proxy_pricing_template_id.clone(); + + tokio::spawn(async move { + // 等待流结束(延迟确保所有 chunks 已收集) + tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; + + let chunks = match sse_chunks.lock() { + Ok(guard) => guard.clone(), + Err(e) => { + tracing::error!(error = ?e, "获取 SSE chunks 锁失败"); + return; + } + }; + + // 将所有 chunk 合并为完整响应体 + let mut full_data = Vec::new(); + for chunk in &chunks { + full_data.extend_from_slice(chunk); + } + + // 计算响应时间 + let response_time_ms = start_time_clone.elapsed().as_millis() as i64; + + // 调用工具特定的日志记录 + if let Err(e) = processor_clone + .record_request_log( + &client_ip_clone, + &config_name, + proxy_pricing_template_id_clone.as_deref(), + &request_body_clone, + response_status, + &full_data, + true, // is_sse + Some(response_time_ms), + ) + .await + { + tracing::error!(error = ?e, "SSE 流日志记录失败"); + } + }); + let body = http_body_util::StreamBody::new(mapped_stream); Ok(response.body(box_body(body)).unwrap()) } else { - // 普通响应 + // 普通响应:读取响应体并调用 processor.record_request_log let body_bytes = upstream_res.bytes().await.context("读取响应体失败")?; + // amp-code 需要清理响应体中的工具名前缀 let final_body = if tool_id == "amp-code" { let text = String::from_utf8_lossy(&body_bytes); let re = regex::Regex::new(r#""name"\s*:\s*"mcp_([^"]+)""#).unwrap(); let cleaned = re.replace_all(&text, r#""name": "$1""#); Bytes::from(cleaned.into_owned()) } else { - body_bytes + body_bytes.clone() }; + // 获取配置名称 + let config_name = proxy_config + .real_profile_name + .clone() + .unwrap_or_else(|| "default".to_string()); + + let proxy_pricing_template_id = proxy_config.pricing_template_id.clone(); + + // 异步记录日志 + let processor_clone = Arc::clone(&processor); + let client_ip_clone = client_ip.clone(); + let request_body_clone = processed.body.clone(); + let response_body_clone = body_bytes.clone(); + let response_status = status.as_u16(); + let response_time_ms = start_time.elapsed().as_millis() as i64; // 计算响应时间 + + tokio::spawn(async move { + // 调用工具特定的日志记录 + if let Err(e) = processor_clone + .record_request_log( + &client_ip_clone, + &config_name, + proxy_pricing_template_id.as_deref(), + &request_body_clone, + response_status, + &response_body_clone, + false, // is_sse + Some(response_time_ms), + ) + .await + { + tracing::error!(error = ?e, "日志记录失败"); + } + }); + Ok(response .body(box_body(http_body_util::Full::new(final_body))) .unwrap()) diff --git a/src-tauri/src/services/proxy/proxy_service.rs b/src-tauri/src/services/proxy/proxy_service.rs index 8ec04c8..fcc09e8 100644 --- a/src-tauri/src/services/proxy/proxy_service.rs +++ b/src-tauri/src/services/proxy/proxy_service.rs @@ -231,6 +231,7 @@ mod tests { single_instance_enabled: true, startup_enabled: false, config_watch: crate::models::config::ConfigWatchConfig::default(), + token_stats_config: crate::models::config::TokenStatsConfig::default(), }; let url = ProxyService::build_proxy_url(&config); @@ -261,6 +262,7 @@ mod tests { single_instance_enabled: true, startup_enabled: false, config_watch: crate::models::config::ConfigWatchConfig::default(), + token_stats_config: crate::models::config::TokenStatsConfig::default(), }; let url = ProxyService::build_proxy_url(&config); @@ -294,6 +296,7 @@ mod tests { single_instance_enabled: true, startup_enabled: false, config_watch: crate::models::config::ConfigWatchConfig::default(), + token_stats_config: crate::models::config::TokenStatsConfig::default(), }; let url = ProxyService::build_proxy_url(&config); diff --git a/src-tauri/src/services/session/db_utils.rs b/src-tauri/src/services/session/db_utils.rs index 91d8102..fdac929 100644 --- a/src-tauri/src/services/session/db_utils.rs +++ b/src-tauri/src/services/session/db_utils.rs @@ -6,9 +6,12 @@ use crate::data::managers::sqlite::QueryRow; use crate::services::session::models::ProxySession; use anyhow::{anyhow, Context, Result}; +/// 会话配置类型:(config_name, custom_profile_name, url, api_key, pricing_template_id) +pub type SessionConfig = (String, Option, String, String, Option); + /// 标准会话查询的 SQL 语句 /// -/// **字段顺序(共 13 个):** +/// **字段顺序(共 14 个):** /// 1. session_id /// 2. display_id /// 3. tool_id @@ -22,10 +25,11 @@ use anyhow::{anyhow, Context, Result}; /// 11. request_count /// 12. created_at /// 13. updated_at +/// 14. pricing_template_id pub const SELECT_SESSION_FIELDS: &str = "session_id, display_id, tool_id, config_name, \ custom_profile_name, url, api_key, note, \ first_seen_at, last_seen_at, request_count, \ - created_at, updated_at"; + created_at, updated_at, pricing_template_id"; /// 创建表的 SQL 语句 pub const CREATE_TABLE_SQL: &str = " @@ -42,7 +46,8 @@ CREATE TABLE IF NOT EXISTS claude_proxy_sessions ( last_seen_at INTEGER NOT NULL, request_count INTEGER NOT NULL DEFAULT 0, created_at INTEGER NOT NULL, - updated_at INTEGER NOT NULL + updated_at INTEGER NOT NULL, + pricing_template_id TEXT ); CREATE INDEX IF NOT EXISTS idx_tool_id ON claude_proxy_sessions(tool_id); @@ -50,10 +55,23 @@ CREATE INDEX IF NOT EXISTS idx_display_id ON claude_proxy_sessions(display_id); CREATE INDEX IF NOT EXISTS idx_last_seen_at ON claude_proxy_sessions(last_seen_at); "; -/// 兼容旧数据库的字段添加语句 +/// 兼容旧数据库的字段添加语句(每个语句单独执行,忽略重复列错误) +pub const ALTER_TABLE_STATEMENTS: &[&str] = &[ + "ALTER TABLE claude_proxy_sessions ADD COLUMN custom_profile_name TEXT", + "ALTER TABLE claude_proxy_sessions ADD COLUMN note TEXT", + "ALTER TABLE claude_proxy_sessions ADD COLUMN pricing_template_id TEXT", +]; + +/// 兼容旧数据库的字段添加语句(已废弃,保留用于向后兼容) +#[deprecated( + since = "1.5.1", + note = "请使用 ALTER_TABLE_STATEMENTS 数组,每个语句单独执行" +)] +#[allow(dead_code)] pub const ALTER_TABLE_SQL: &str = " ALTER TABLE claude_proxy_sessions ADD COLUMN custom_profile_name TEXT; ALTER TABLE claude_proxy_sessions ADD COLUMN note TEXT; +ALTER TABLE claude_proxy_sessions ADD COLUMN pricing_template_id TEXT; "; /// 从 QueryRow 解析为 ProxySession @@ -73,9 +91,9 @@ ALTER TABLE claude_proxy_sessions ADD COLUMN note TEXT; /// - values[7]: note (可为 NULL) /// - values[8..12]: 整数字段 pub fn parse_proxy_session(row: &QueryRow) -> Result { - if row.values.len() != 13 { + if row.values.len() != 14 { return Err(anyhow!( - "Invalid row: expected 13 columns, got {}", + "Invalid row: expected 14 columns, got {}", row.values.len() )); } @@ -118,6 +136,7 @@ pub fn parse_proxy_session(row: &QueryRow) -> Result { request_count: get_i32(10).context("request_count")?, created_at: get_i64(11).context("created_at")?, updated_at: get_i64(12).context("updated_at")?, + pricing_template_id: get_optional_string(13), }) } @@ -144,13 +163,13 @@ pub fn parse_count(row: &QueryRow) -> Result { .map(|v| v as usize) } -/// 从 QueryRow 提取三元组配置 (config_name, url, api_key) +/// 从 QueryRow 提取五元组配置 (config_name, custom_profile_name, url, api_key, pricing_template_id) /// /// 用于 `get_session_config()` 方法的结果解析 -pub fn parse_session_config(row: &QueryRow) -> Result<(String, String, String)> { - if row.values.len() != 3 { +pub fn parse_session_config(row: &QueryRow) -> Result { + if row.values.len() != 5 { return Err(anyhow!( - "Invalid config row: expected 3 columns, got {}", + "Invalid config row: expected 5 columns, got {}", row.values.len() )); } @@ -160,17 +179,27 @@ pub fn parse_session_config(row: &QueryRow) -> Result<(String, String, String)> .ok_or_else(|| anyhow!("config_name is not a string"))? .to_string(); - let url = row.values[1] + let custom_profile_name = row.values[1].as_str().map(|s| s.to_string()); + + let url = row.values[2] .as_str() .ok_or_else(|| anyhow!("url is not a string"))? .to_string(); - let api_key = row.values[2] + let api_key = row.values[3] .as_str() .ok_or_else(|| anyhow!("api_key is not a string"))? .to_string(); - Ok((config_name, url, api_key)) + let pricing_template_id = row.values[4].as_str().map(|s| s.to_string()); + + Ok(( + config_name, + custom_profile_name, + url, + api_key, + pricing_template_id, + )) } #[cfg(test)] @@ -195,6 +224,7 @@ mod tests { "request_count".to_string(), "created_at".to_string(), "updated_at".to_string(), + "pricing_template_id".to_string(), ], values: vec![ json!("test_session_1"), @@ -210,6 +240,7 @@ mod tests { json!(5), json!(1000), json!(2000), + json!("anthropic_official"), ], }; @@ -228,6 +259,10 @@ mod tests { assert_eq!(session.request_count, 5); assert_eq!(session.created_at, 1000); assert_eq!(session.updated_at, 2000); + assert_eq!( + session.pricing_template_id, + Some("anthropic_official".to_string()) + ); } #[test] @@ -247,6 +282,7 @@ mod tests { "request_count".to_string(), "created_at".to_string(), "updated_at".to_string(), + "pricing_template_id".to_string(), ], values: vec![ json!("test_session_2"), @@ -262,6 +298,7 @@ mod tests { json!(10), json!(3000), json!(4000), + json!(null), // pricing_template_id ], }; @@ -271,6 +308,7 @@ mod tests { assert_eq!(session.config_name, "global"); assert_eq!(session.custom_profile_name, None); assert_eq!(session.note, None); + assert_eq!(session.pricing_template_id, None); assert_eq!(session.request_count, 10); } @@ -290,21 +328,57 @@ mod tests { let row = QueryRow { columns: vec![ "config_name".to_string(), + "custom_profile_name".to_string(), "url".to_string(), "api_key".to_string(), + "pricing_template_id".to_string(), ], values: vec![ json!("custom"), + json!("my-profile"), json!("https://api.test.com"), json!("sk-xxx"), + json!("anthropic_official"), ], }; - let (config_name, url, api_key) = parse_session_config(&row).unwrap(); + let (config_name, custom_profile_name, url, api_key, pricing_template_id) = + parse_session_config(&row).unwrap(); assert_eq!(config_name, "custom"); + assert_eq!(custom_profile_name, Some("my-profile".to_string())); assert_eq!(url, "https://api.test.com"); assert_eq!(api_key, "sk-xxx"); + assert_eq!(pricing_template_id, Some("anthropic_official".to_string())); + } + + #[test] + fn test_parse_session_config_with_null_pricing_template() { + let row = QueryRow { + columns: vec![ + "config_name".to_string(), + "custom_profile_name".to_string(), + "url".to_string(), + "api_key".to_string(), + "pricing_template_id".to_string(), + ], + values: vec![ + json!("global"), + json!(null), + json!("https://api.example.com"), + json!("sk-test"), + json!(null), + ], + }; + + let (config_name, custom_profile_name, url, api_key, pricing_template_id) = + parse_session_config(&row).unwrap(); + + assert_eq!(config_name, "global"); + assert_eq!(custom_profile_name, None); + assert_eq!(url, "https://api.example.com"); + assert_eq!(api_key, "sk-test"); + assert_eq!(pricing_template_id, None); } #[test] @@ -319,6 +393,6 @@ mod tests { assert!(result .unwrap_err() .to_string() - .contains("expected 13 columns")); + .contains("expected 14 columns")); } } diff --git a/src-tauri/src/services/session/manager.rs b/src-tauri/src/services/session/manager.rs index ba2d43b..cbc71d2 100644 --- a/src-tauri/src/services/session/manager.rs +++ b/src-tauri/src/services/session/manager.rs @@ -2,8 +2,8 @@ use crate::data::DataManager; use crate::services::session::db_utils::{ - parse_count, parse_proxy_session, parse_session_config, ALTER_TABLE_SQL, CREATE_TABLE_SQL, - SELECT_SESSION_FIELDS, + parse_count, parse_proxy_session, parse_session_config, SessionConfig, ALTER_TABLE_STATEMENTS, + CREATE_TABLE_SQL, SELECT_SESSION_FIELDS, }; use crate::services::session::models::{ProxySession, SessionEvent, SessionListResponse}; use anyhow::Result; @@ -43,8 +43,10 @@ impl SessionManager { let db = manager_instance.sqlite(&db_path)?; db.execute_raw(CREATE_TABLE_SQL)?; - // 兼容旧数据库(忽略错误) - let _ = db.execute_raw(ALTER_TABLE_SQL); + // 兼容旧数据库(逐个执行 ALTER TABLE,忽略重复列错误) + for stmt in ALTER_TABLE_STATEMENTS { + let _ = db.execute_raw(stmt); + } // 创建事件队列 let (event_sender, event_receiver) = mpsc::unbounded_channel(); @@ -139,6 +141,8 @@ impl SessionManager { /// 批量写入事件到数据库 fn flush_events(manager: &Arc, db_path: &Path, buffer: &mut Vec) { + let mut has_writes = false; + for event in buffer.drain(..) { match event { SessionEvent::NewRequest { @@ -150,8 +154,9 @@ impl SessionManager { if let Some(display_id) = ProxySession::extract_display_id(&session_id) { // Upsert 会话 if let Ok(db) = manager.sqlite(db_path) { - let _ = db.execute( - "INSERT INTO claude_proxy_sessions ( + if db + .execute( + "INSERT INTO claude_proxy_sessions ( session_id, display_id, tool_id, config_name, url, api_key, first_seen_at, last_seen_at, request_count, created_at, updated_at @@ -160,13 +165,24 @@ impl SessionManager { last_seen_at = ?4, request_count = request_count + 1, updated_at = ?4", - &[&session_id, &display_id, &tool_id, ×tamp.to_string()], - ); + &[&session_id, &display_id, &tool_id, ×tamp.to_string()], + ) + .is_ok() + { + has_writes = true; + } } } } } } + + // 批量写入后执行 PASSIVE checkpoint(不阻塞,性能最优) + if has_writes { + if let Ok(db) = manager.sqlite(db_path) { + let _ = db.execute_raw("PRAGMA wal_checkpoint(PASSIVE)"); + } + } } /// 内部清理方法(用于后台任务) @@ -209,7 +225,14 @@ impl SessionManager { 0 }; - Ok(deleted_by_age + deleted_by_count) + let total_deleted = deleted_by_age + deleted_by_count; + + // 执行 WAL checkpoint 回写主文件 + if total_deleted > 0 { + let _ = db.execute_raw("PRAGMA wal_checkpoint(TRUNCATE)"); + } + + Ok(total_deleted) } /// 发送会话事件(公共 API) @@ -264,20 +287,32 @@ impl SessionManager { /// 删除单个会话(公共 API) pub fn delete_session(&self, session_id: &str) -> Result<()> { let db = self.manager.sqlite(&self.db_path)?; - db.execute( + let deleted = db.execute( "DELETE FROM claude_proxy_sessions WHERE session_id = ?", &[session_id], )?; + + // 执行 PASSIVE checkpoint(不阻塞) + if deleted > 0 { + let _ = db.execute_raw("PRAGMA wal_checkpoint(PASSIVE)"); + } + Ok(()) } /// 清空工具所有会话(公共 API) pub fn clear_sessions(&self, tool_id: &str) -> Result<()> { let db = self.manager.sqlite(&self.db_path)?; - db.execute( + let deleted = db.execute( "DELETE FROM claude_proxy_sessions WHERE tool_id = ?", &[tool_id], )?; + + // 执行 PASSIVE checkpoint(不阻塞) + if deleted > 0 { + let _ = db.execute_raw("PRAGMA wal_checkpoint(PASSIVE)"); + } + Ok(()) } @@ -298,11 +333,11 @@ impl SessionManager { } /// 获取会话配置(公共 API,用于请求处理) - /// 返回 (config_name, url, api_key) - pub fn get_session_config(&self, session_id: &str) -> Result> { + /// 返回 (config_name, custom_profile_name, url, api_key, pricing_template_id) + pub fn get_session_config(&self, session_id: &str) -> Result> { let db = self.manager.sqlite(&self.db_path)?; let rows = db.query( - "SELECT config_name, url, api_key FROM claude_proxy_sessions WHERE session_id = ?", + "SELECT config_name, custom_profile_name, url, api_key, pricing_template_id FROM claude_proxy_sessions WHERE session_id = ?", &[session_id], )?; @@ -321,24 +356,31 @@ impl SessionManager { custom_profile_name: Option<&str>, url: &str, api_key: &str, + pricing_template_id: Option<&str>, // Phase 6: 价格模板 ) -> Result<()> { let db = self.manager.sqlite(&self.db_path)?; let now = chrono::Utc::now().timestamp(); - db.execute( + let updated = db.execute( "UPDATE claude_proxy_sessions - SET config_name = ?, custom_profile_name = ?, url = ?, api_key = ?, updated_at = ? + SET config_name = ?, custom_profile_name = ?, url = ?, api_key = ?, pricing_template_id = ?, updated_at = ? WHERE session_id = ?", &[ config_name, custom_profile_name.unwrap_or(""), url, api_key, + pricing_template_id.unwrap_or(""), &now.to_string(), session_id, ], )?; + // 执行 PASSIVE checkpoint(不阻塞) + if updated > 0 { + let _ = db.execute_raw("PRAGMA wal_checkpoint(PASSIVE)"); + } + Ok(()) } @@ -347,11 +389,16 @@ impl SessionManager { let db = self.manager.sqlite(&self.db_path)?; let now = chrono::Utc::now().timestamp(); - db.execute( + let updated = db.execute( "UPDATE claude_proxy_sessions SET note = ?, updated_at = ? WHERE session_id = ?", &[note.unwrap_or(""), &now.to_string(), session_id], )?; + // 执行 PASSIVE checkpoint(不阻塞) + if updated > 0 { + let _ = db.execute_raw("PRAGMA wal_checkpoint(PASSIVE)"); + } + Ok(()) } } @@ -381,7 +428,9 @@ mod tests { // 初始化数据库 let db = manager_instance.sqlite(&db_path).unwrap(); db.execute_raw(CREATE_TABLE_SQL).unwrap(); - let _ = db.execute_raw(ALTER_TABLE_SQL); + for stmt in ALTER_TABLE_STATEMENTS { + let _ = db.execute_raw(stmt); + } let (event_sender, event_receiver) = mpsc::unbounded_channel(); @@ -488,6 +537,7 @@ mod tests { Some("my-profile"), "https://api.test.com", "sk-test", + None, // pricing_template_id ) .unwrap(); diff --git a/src-tauri/src/services/session/mod.rs b/src-tauri/src/services/session/mod.rs index 2dfdee8..4d494f9 100644 --- a/src-tauri/src/services/session/mod.rs +++ b/src-tauri/src/services/session/mod.rs @@ -4,5 +4,5 @@ mod db_utils; pub mod manager; pub mod models; -pub use manager::SESSION_MANAGER; +pub use manager::{shutdown_session_manager, SESSION_MANAGER}; pub use models::{ProxySession, SessionEvent, SessionListResponse}; diff --git a/src-tauri/src/services/session/models.rs b/src-tauri/src/services/session/models.rs index 03065bb..5f06fce 100644 --- a/src-tauri/src/services/session/models.rs +++ b/src-tauri/src/services/session/models.rs @@ -31,6 +31,9 @@ pub struct ProxySession { pub created_at: i64, /// 更新时间(Unix 时间戳,秒) pub updated_at: i64, + /// 价格模板 ID(用于成本计算) + #[serde(default, skip_serializing_if = "Option::is_none")] + pub pricing_template_id: Option, } /// 会话事件(异步队列传递) diff --git a/src-tauri/src/services/token_stats/analytics.rs b/src-tauri/src/services/token_stats/analytics.rs new file mode 100644 index 0000000..c5319ba --- /dev/null +++ b/src-tauri/src/services/token_stats/analytics.rs @@ -0,0 +1,540 @@ +//! Token 统计分析模块 +//! +//! 提供趋势分析和成本汇总查询功能 + +use crate::data::DataManager; +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; + +/// 时间粒度 +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)] +#[serde(rename_all = "snake_case")] +pub enum TimeGranularity { + /// 15分钟粒度 + FifteenMinutes, + /// 30分钟粒度 + ThirtyMinutes, + /// 小时粒度 + Hour, + /// 12小时粒度 + TwelveHours, + /// 天粒度 + #[default] + Day, +} + +/// 趋势查询参数 +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct TrendQuery { + /// 开始时间戳(毫秒) + pub start_time: Option, + /// 结束时间戳(毫秒) + pub end_time: Option, + /// 工具类型过滤 + pub tool_type: Option, + /// 模型过滤 + pub model: Option, + /// 配置名称过滤 + pub config_name: Option, + /// 会话 ID 过滤 + pub session_id: Option, + /// 时间粒度 + pub granularity: TimeGranularity, +} + +/// 趋势数据点 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TrendDataPoint { + /// 时间戳(毫秒) + pub timestamp: i64, + /// 输入 Token 总数 + pub input_tokens: i64, + /// 输出 Token 总数 + pub output_tokens: i64, + /// 缓存写入 Token 总数 + pub cache_creation_tokens: i64, + /// 缓存读取 Token 总数 + pub cache_read_tokens: i64, + /// 总成本(USD) + pub total_cost: f64, + /// 输入部分成本(USD) + pub input_price: f64, + /// 输出部分成本(USD) + pub output_price: f64, + /// 缓存写入部分成本(USD) + pub cache_write_price: f64, + /// 缓存读取部分成本(USD) + pub cache_read_price: f64, + /// 请求总数 + pub request_count: i64, + /// 错误请求数 + pub error_count: i64, + /// 平均响应时间(毫秒) + pub avg_response_time: Option, +} + +/// 成本汇总分组方式 +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)] +#[serde(rename_all = "snake_case")] +pub enum CostGroupBy { + /// 按模型分组 + #[default] + Model, + /// 按配置分组 + Config, + /// 按会话分组 + Session, +} + +/// 成本汇总查询参数 +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct CostSummaryQuery { + /// 开始时间戳(毫秒) + pub start_time: Option, + /// 结束时间戳(毫秒) + pub end_time: Option, + /// 工具类型过滤 + pub tool_type: Option, + /// 会话 ID 过滤 + pub session_id: Option, + /// 分组方式 + pub group_by: CostGroupBy, +} + +/// 成本汇总数据 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CostSummary { + /// 分组字段名称(model/config_name/session_id) + pub group_name: String, + /// 总成本(USD) + pub total_cost: f64, + /// 请求总数 + pub request_count: i64, + /// 输入 Token 总数 + pub input_tokens: i64, + /// 输出 Token 总数 + pub output_tokens: i64, + /// 平均响应时间(毫秒) + pub avg_response_time: Option, +} + +/// Token 统计分析服务 +pub struct TokenStatsAnalytics { + db_path: PathBuf, +} + +impl TokenStatsAnalytics { + /// 创建新的分析服务实例 + pub fn new(db_path: PathBuf) -> Self { + Self { db_path } + } + + /// 查询趋势数据 + pub fn query_trends(&self, query: &TrendQuery) -> Result> { + let manager = DataManager::global() + .sqlite(&self.db_path) + .context("Failed to get SQLite manager")?; + + // 构建时间分组表达式 + let time_expr = match query.granularity { + TimeGranularity::FifteenMinutes => { + // 按15分钟分组:向下取整到最近的15分钟 + "CAST((timestamp / 900000) * 900000 AS INTEGER)" + } + TimeGranularity::ThirtyMinutes => { + // 按30分钟分组:向下取整到最近的30分钟 + "CAST((timestamp / 1800000) * 1800000 AS INTEGER)" + } + TimeGranularity::Hour => { + // 按小时分组 + "CAST((timestamp / 3600000) * 3600000 AS INTEGER)" + } + TimeGranularity::TwelveHours => { + // 按12小时分组 + "CAST((timestamp / 43200000) * 43200000 AS INTEGER)" + } + TimeGranularity::Day => { + // 按天分组 + "CAST((timestamp / 86400000) * 86400000 AS INTEGER)" + } + }; + + // 构建 WHERE 子句 + let mut where_clauses = Vec::new(); + let mut params: Vec> = Vec::new(); + + if let Some(start_time) = query.start_time { + where_clauses.push("timestamp >= ?"); + params.push(Box::new(start_time)); + } + + if let Some(end_time) = query.end_time { + where_clauses.push("timestamp <= ?"); + params.push(Box::new(end_time)); + } + + if let Some(ref tool_type) = query.tool_type { + where_clauses.push("tool_type = ?"); + params.push(Box::new(tool_type.clone())); + } + + if let Some(ref model) = query.model { + where_clauses.push("model = ?"); + params.push(Box::new(model.clone())); + } + + if let Some(ref config_name) = query.config_name { + where_clauses.push("config_name = ?"); + params.push(Box::new(config_name.clone())); + } + + if let Some(ref session_id) = query.session_id { + where_clauses.push("session_id = ?"); + params.push(Box::new(session_id.clone())); + } + + let where_clause = if where_clauses.is_empty() { + String::new() + } else { + format!("WHERE {}", where_clauses.join(" AND ")) + }; + + // 构建完整 SQL + let sql = format!( + "SELECT + {} as timestamp, + SUM(input_tokens) as input_tokens, + SUM(output_tokens) as output_tokens, + SUM(cache_creation_tokens) as cache_creation_tokens, + SUM(cache_read_tokens) as cache_read_tokens, + SUM(total_cost) as total_cost, + SUM(COALESCE(input_price, 0.0)) as input_price, + SUM(COALESCE(output_price, 0.0)) as output_price, + SUM(COALESCE(cache_write_price, 0.0)) as cache_write_price, + SUM(COALESCE(cache_read_price, 0.0)) as cache_read_price, + COUNT(*) as request_count, + SUM(CASE WHEN request_status = 'error' THEN 1 ELSE 0 END) as error_count, + AVG(response_time_ms) as avg_response_time + FROM token_logs + {} + GROUP BY {} + ORDER BY timestamp", + time_expr, where_clause, time_expr + ); + + // 执行查询 + let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter().map(|p| p.as_ref()).collect(); + + let db_trends = manager.transaction(|tx| { + let mut stmt = tx.prepare(&sql)?; + let trends = stmt + .query_map(param_refs.as_slice(), |row| { + Ok(TrendDataPoint { + timestamp: row.get(0)?, + input_tokens: row.get(1)?, + output_tokens: row.get(2)?, + cache_creation_tokens: row.get(3)?, + cache_read_tokens: row.get(4)?, + total_cost: row.get(5)?, + input_price: row.get(6)?, + output_price: row.get(7)?, + cache_write_price: row.get(8)?, + cache_read_price: row.get(9)?, + request_count: row.get(10)?, + error_count: row.get(11)?, + avg_response_time: row.get(12)?, + }) + })? + .collect::, _>>() + .map_err(crate::data::DataError::Database)?; + Ok(trends) + })?; + + // 如果没有指定时间范围,直接返回查询结果 + if query.start_time.is_none() || query.end_time.is_none() { + return Ok(db_trends); + } + + // 填充缺失的时间点 + let filled_trends = self.fill_missing_time_points( + db_trends, + query.start_time.unwrap(), + query.end_time.unwrap(), + query.granularity, + ); + + Ok(filled_trends) + } + + /// 填充缺失的时间点,确保所有时间段都有数据(即使为0) + fn fill_missing_time_points( + &self, + db_trends: Vec, + start_time: i64, + end_time: i64, + granularity: TimeGranularity, + ) -> Vec { + use std::collections::HashMap; + + // 计算时间间隔(毫秒) + let interval_ms = match granularity { + TimeGranularity::FifteenMinutes => 15 * 60 * 1000, + TimeGranularity::ThirtyMinutes => 30 * 60 * 1000, + TimeGranularity::Hour => 60 * 60 * 1000, + TimeGranularity::TwelveHours => 12 * 60 * 60 * 1000, + TimeGranularity::Day => 24 * 60 * 60 * 1000, + }; + + // 将数据库结果转换为 HashMap 以便快速查找 + let mut data_map: HashMap = HashMap::new(); + for point in db_trends { + data_map.insert(point.timestamp, point); + } + + // 生成完整的时间序列 + let mut result = Vec::new(); + let mut current_time = (start_time / interval_ms) * interval_ms; // 向下取整到粒度边界 + + while current_time <= end_time { + let point = if let Some(existing) = data_map.get(¤t_time) { + // 如果有数据,使用数据库的值 + existing.clone() + } else { + // 如果没有数据,创建零值数据点 + TrendDataPoint { + timestamp: current_time, + input_tokens: 0, + output_tokens: 0, + cache_creation_tokens: 0, + cache_read_tokens: 0, + total_cost: 0.0, + input_price: 0.0, + output_price: 0.0, + cache_write_price: 0.0, + cache_read_price: 0.0, + request_count: 0, + error_count: 0, + avg_response_time: None, + } + }; + result.push(point); + current_time += interval_ms; + } + + result + } + + /// 查询成本汇总数据 + pub fn query_cost_summary(&self, query: &CostSummaryQuery) -> Result> { + let manager = DataManager::global() + .sqlite(&self.db_path) + .context("Failed to get SQLite manager")?; + + // 确定分组字段 + let group_field = match query.group_by { + CostGroupBy::Model => "model", + CostGroupBy::Config => "config_name", + CostGroupBy::Session => "session_id", + }; + + // 构建 WHERE 子句 + let mut where_clauses = Vec::new(); + let mut params: Vec> = Vec::new(); + + if let Some(start_time) = query.start_time { + where_clauses.push("timestamp >= ?"); + params.push(Box::new(start_time)); + } + + if let Some(end_time) = query.end_time { + where_clauses.push("timestamp <= ?"); + params.push(Box::new(end_time)); + } + + if let Some(ref tool_type) = query.tool_type { + where_clauses.push("tool_type = ?"); + params.push(Box::new(tool_type.clone())); + } + + if let Some(ref session_id) = query.session_id { + where_clauses.push("session_id = ?"); + params.push(Box::new(session_id.clone())); + } + + let where_clause = if where_clauses.is_empty() { + String::new() + } else { + format!("WHERE {}", where_clauses.join(" AND ")) + }; + + // 构建完整 SQL + let sql = format!( + "SELECT + {} as group_name, + SUM(total_cost) as total_cost, + COUNT(*) as request_count, + SUM(input_tokens) as input_tokens, + SUM(output_tokens) as output_tokens, + AVG(response_time_ms) as avg_response_time + FROM token_logs + {} + GROUP BY {} + ORDER BY total_cost DESC", + group_field, where_clause, group_field + ); + + // 执行查询 + let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter().map(|p| p.as_ref()).collect(); + + Ok(manager.transaction(|tx| { + let mut stmt = tx.prepare(&sql)?; + let summaries = stmt + .query_map(param_refs.as_slice(), |row| { + Ok(CostSummary { + group_name: row.get(0)?, + total_cost: row.get(1)?, + request_count: row.get(2)?, + input_tokens: row.get(3)?, + output_tokens: row.get(4)?, + avg_response_time: row.get(5)?, + }) + })? + .collect::, _>>() + .map_err(crate::data::DataError::Database)?; + Ok(summaries) + })?) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::models::token_stats::TokenLog; + use crate::services::token_stats::db::TokenStatsDb; + use chrono::TimeZone; + use tempfile::tempdir; + + #[test] + fn test_query_trends() { + // 创建临时数据库 + let dir = tempdir().unwrap(); + let db_path = dir.path().join("test_trends.db"); + let db = TokenStatsDb::new(db_path.clone()); + db.init_table().unwrap(); + + // 插入测试数据(使用固定时间避免跨日期边界) + let base_time = chrono::Utc + .with_ymd_and_hms(2026, 1, 10, 12, 0, 0) + .unwrap() + .timestamp_millis(); + + for i in 0..10 { + let log = TokenLog::new( + "claude_code".to_string(), + base_time - (i * 3600 * 1000), // 每小时一条 + "127.0.0.1".to_string(), + "test_session".to_string(), + "default".to_string(), + "claude-sonnet-4-5-20250929".to_string(), + Some(format!("msg_{}", i)), + 100, + 50, + 10, + 20, + "success".to_string(), + "json".to_string(), + None, + None, + Some(100), + Some(0.001), + Some(0.002), + Some(0.0001), + Some(0.0002), + 0.0033, + Some("test_template".to_string()), + ); + db.insert_log(&log).unwrap(); + } + + // 查询趋势数据 + let analytics = TokenStatsAnalytics::new(db_path); + let query = TrendQuery { + tool_type: Some("claude_code".to_string()), + granularity: TimeGranularity::Hour, + ..Default::default() + }; + + let trends = analytics.query_trends(&query).unwrap(); + + // 验证结果 + assert_eq!(trends.len(), 10); + assert_eq!(trends[0].input_tokens, 100); + assert_eq!(trends[0].output_tokens, 50); + assert!((trends[0].total_cost - 0.0033).abs() < 0.0001); + assert_eq!(trends[0].request_count, 1); + assert_eq!(trends[0].error_count, 0); + } + + #[test] + fn test_query_cost_summary() { + // 创建临时数据库 + let dir = tempdir().unwrap(); + let db_path = dir.path().join("test_cost_summary.db"); + let db = TokenStatsDb::new(db_path.clone()); + db.init_table().unwrap(); + + // 插入测试数据(多个会话,使用固定时间) + let base_time = chrono::Utc + .with_ymd_and_hms(2026, 1, 10, 12, 0, 0) + .unwrap() + .timestamp_millis(); + + for session_idx in 0..3 { + for i in 0..5 { + let log = TokenLog::new( + "claude_code".to_string(), + base_time - (i * 1000), + "127.0.0.1".to_string(), + format!("session_{}", session_idx), + "default".to_string(), + "claude-sonnet-4-5-20250929".to_string(), + Some(format!("msg_{}_{}", session_idx, i)), + 100, + 50, + 10, + 20, + "success".to_string(), + "json".to_string(), + None, + None, + Some(100), + Some(0.001), + Some(0.002), + Some(0.0001), + Some(0.0002), + 0.0033, + Some("test_template".to_string()), + ); + db.insert_log(&log).unwrap(); + } + } + + // 查询成本汇总 + let analytics = TokenStatsAnalytics::new(db_path); + let query = CostSummaryQuery { + tool_type: Some("claude_code".to_string()), + group_by: CostGroupBy::Session, + ..Default::default() + }; + + let summaries = analytics.query_cost_summary(&query).unwrap(); + + // 验证结果 + assert_eq!(summaries.len(), 3); // 3个会话 + for summary in &summaries { + assert_eq!(summary.request_count, 5); // 每个会话5条记录 + assert!((summary.total_cost - 0.0165).abs() < 0.001); // 0.0033 * 5 + } + } +} diff --git a/src-tauri/src/services/token_stats/cost_calculation_test.rs b/src-tauri/src/services/token_stats/cost_calculation_test.rs new file mode 100644 index 0000000..fc33fff --- /dev/null +++ b/src-tauri/src/services/token_stats/cost_calculation_test.rs @@ -0,0 +1,202 @@ +//! 成本计算集成测试 +//! +//! 验证成本计算逻辑是否正确工作 + +#[cfg(test)] +mod tests { + use crate::services::pricing::PRICING_MANAGER; + use crate::services::token_stats::create_extractor; + use serde_json::json; + + #[test] + fn test_cost_calculation_with_claude_3_5_sonnet() { + // 测试 Claude 3.5 Sonnet 20241022 版本的成本计算 + let model = "claude-sonnet-4-5-20250929"; + let input_tokens = 100; + let output_tokens = 50; + let cache_creation_tokens = 10; + let cache_read_tokens = 20; + + // 使用默认模板计算成本 + let result = PRICING_MANAGER.calculate_cost( + None, // 使用默认模板 + model, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + ); + + // 验证计算成功 + assert!(result.is_ok(), "成本计算应该成功: {:?}", result.err()); + + let breakdown = result.unwrap(); + + // 验证使用了正确的模板 + assert_eq!(breakdown.template_id, "builtin_claude"); + + // 验证成本计算正确(Claude 3.5 Sonnet: $3/1M input, $15/1M output) + // input: 100 * 3.0 / 1,000,000 = 0.0003 + // output: 50 * 15.0 / 1,000,000 = 0.00075 + // cache_write: 10 * 3.75 / 1,000,000 = 0.0000375 + // cache_read: 20 * 0.3 / 1,000,000 = 0.000006 + // total: 0.0003 + 0.00075 + 0.0000375 + 0.000006 = 0.0010935 + + println!("实际计算结果:"); + println!(" 输入价格: {:.10}", breakdown.input_price); + println!(" 输出价格: {:.10}", breakdown.output_price); + println!(" 缓存写入价格: {:.10}", breakdown.cache_write_price); + println!(" 缓存读取价格: {:.10}", breakdown.cache_read_price); + println!(" 总成本: {:.10}", breakdown.total_cost); + + assert!((breakdown.input_price - 0.0003).abs() < 1e-9); + assert!((breakdown.output_price - 0.00075).abs() < 1e-9); + assert!((breakdown.cache_write_price - 0.0000375).abs() < 1e-9); + assert!((breakdown.cache_read_price - 0.000006).abs() < 1e-9); + assert!( + (breakdown.total_cost - 0.0010935).abs() < 1e-7, + "expected 0.0010935, got {}", + breakdown.total_cost + ); + + println!("✅ 成本计算测试通过:"); + println!(" 模型: {}", model); + println!(" 输入价格: {:.6}", breakdown.input_price); + println!(" 输出价格: {:.6}", breakdown.output_price); + println!(" 缓存写入价格: {:.6}", breakdown.cache_write_price); + println!(" 缓存读取价格: {:.6}", breakdown.cache_read_price); + println!(" 总成本: {:.6}", breakdown.total_cost); + } + + #[test] + fn test_cost_calculation_with_different_models() { + // 测试不同模型的成本计算 + + // Claude Opus 4.5: $5 input / $25 output + let opus_result = PRICING_MANAGER.calculate_cost(None, "claude-opus-4.5", 1000, 500, 0, 0); + assert!(opus_result.is_ok()); + let opus_breakdown = opus_result.unwrap(); + assert!((opus_breakdown.input_price - 0.005).abs() < 1e-9); // 1000 * 5 / 1M + assert!((opus_breakdown.output_price - 0.0125).abs() < 1e-9); // 500 * 25 / 1M + + // Claude Sonnet 4.5: $3 input / $15 output + let sonnet_result = + PRICING_MANAGER.calculate_cost(None, "claude-sonnet-4.5", 1000, 500, 0, 0); + assert!(sonnet_result.is_ok()); + let sonnet_breakdown = sonnet_result.unwrap(); + assert!((sonnet_breakdown.input_price - 0.003).abs() < 1e-9); // 1000 * 3 / 1M + assert!((sonnet_breakdown.output_price - 0.0075).abs() < 1e-9); // 500 * 15 / 1M + + // Claude Haiku 3.5: $0.8 input / $4 output + let haiku_result = + PRICING_MANAGER.calculate_cost(None, "claude-haiku-3.5", 1000, 500, 0, 0); + assert!(haiku_result.is_ok()); + let haiku_breakdown = haiku_result.unwrap(); + assert!((haiku_breakdown.input_price - 0.0008).abs() < 1e-9); // 1000 * 0.8 / 1M + assert!((haiku_breakdown.output_price - 0.002).abs() < 1e-9); // 500 * 4 / 1M + + println!("✅ 多模型成本计算测试通过"); + } + + #[test] + fn test_token_extraction_from_response() { + // 测试从响应中提取 Token 信息 + let extractor = create_extractor("claude_code").unwrap(); + + let response_json = json!({ + "id": "msg_test_123", + "model": "claude-sonnet-4-5-20250929", + "usage": { + "input_tokens": 100, + "output_tokens": 50, + "cache_creation_input_tokens": 10, + "cache_read_input_tokens": 20 + } + }); + + let token_info = extractor.extract_from_json(&response_json).unwrap(); + + assert_eq!(token_info.input_tokens, 100); + assert_eq!(token_info.output_tokens, 50); + assert_eq!(token_info.cache_creation_tokens, 10); + assert_eq!(token_info.cache_read_tokens, 20); + assert_eq!(token_info.message_id, "msg_test_123"); + + println!("✅ Token 提取测试通过"); + } + + #[test] + fn test_end_to_end_cost_calculation() { + // 端到端测试:从响应提取 Token -> 计算成本 + let extractor = create_extractor("claude_code").unwrap(); + + let response_json = json!({ + "id": "msg_end_to_end", + "model": "claude-sonnet-4-5-20250929", + "usage": { + "input_tokens": 1000, + "output_tokens": 500, + "cache_creation_input_tokens": 100, + "cache_read_input_tokens": 200 + } + }); + + // 步骤1: 提取 Token + let token_info = extractor.extract_from_json(&response_json).unwrap(); + + // 步骤2: 计算成本 + let result = PRICING_MANAGER.calculate_cost( + None, + "claude-sonnet-4-5-20250929", + token_info.input_tokens, + token_info.output_tokens, + token_info.cache_creation_tokens, + token_info.cache_read_tokens, + ); + + assert!(result.is_ok()); + let breakdown = result.unwrap(); + + // 验证总成本不为 0 + assert!(breakdown.total_cost > 0.0, "总成本应该大于 0"); + + // 预期成本计算: + // input: 1000 * 3.0 / 1,000,000 = 0.003 + // output: 500 * 15.0 / 1,000,000 = 0.0075 + // cache_write: 100 * 3.75 / 1,000,000 = 0.000375 + // cache_read: 200 * 0.3 / 1,000,000 = 0.00006 + // total: 0.003 + 0.0075 + 0.000375 + 0.00006 = 0.010935 + + println!("端到端实际计算结果:"); + println!(" 输入价格: {:.10}", breakdown.input_price); + println!(" 输出价格: {:.10}", breakdown.output_price); + println!(" 缓存写入价格: {:.10}", breakdown.cache_write_price); + println!(" 缓存读取价格: {:.10}", breakdown.cache_read_price); + println!(" 总成本: {:.10}", breakdown.total_cost); + + assert!( + (breakdown.total_cost - 0.010935).abs() < 1e-6, + "expected 0.010935, got {}", + breakdown.total_cost + ); + + println!("✅ 端到端成本计算测试通过"); + println!( + " 输入: {} tokens -> ${:.6}", + token_info.input_tokens, breakdown.input_price + ); + println!( + " 输出: {} tokens -> ${:.6}", + token_info.output_tokens, breakdown.output_price + ); + println!( + " 缓存写入: {} tokens -> ${:.6}", + token_info.cache_creation_tokens, breakdown.cache_write_price + ); + println!( + " 缓存读取: {} tokens -> ${:.6}", + token_info.cache_read_tokens, breakdown.cache_read_price + ); + println!(" 总成本: ${:.6}", breakdown.total_cost); + } +} diff --git a/src-tauri/src/services/token_stats/db.rs b/src-tauri/src/services/token_stats/db.rs new file mode 100644 index 0000000..fb71864 --- /dev/null +++ b/src-tauri/src/services/token_stats/db.rs @@ -0,0 +1,758 @@ +use crate::data::DataManager; +use crate::models::token_stats::{SessionStats, TokenLog, TokenLogsPage, TokenStatsQuery}; +use anyhow::{Context, Result}; +use std::path::PathBuf; + +/// Token统计数据库操作层 +pub struct TokenStatsDb { + db_path: PathBuf, +} + +impl TokenStatsDb { + /// 创建新的数据库操作实例 + pub fn new(db_path: PathBuf) -> Self { + Self { db_path } + } + + /// 初始化数据库表 + pub fn init_table(&self) -> Result<()> { + let manager = DataManager::global() + .sqlite(&self.db_path) + .context("Failed to get SQLite manager")?; + + // 启用 WAL 模式(提升并发性能) + manager + .execute_raw("PRAGMA journal_mode=WAL") + .context("Failed to enable WAL mode")?; + + // 创建表(Schema v3 - 包含成本分析字段) + manager + .execute_raw( + "CREATE TABLE IF NOT EXISTS token_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + tool_type TEXT NOT NULL, + timestamp INTEGER NOT NULL, + client_ip TEXT NOT NULL, + session_id TEXT NOT NULL, + config_name TEXT NOT NULL, + model TEXT NOT NULL, + message_id TEXT, + + -- Token 数量 + input_tokens INTEGER NOT NULL DEFAULT 0, + output_tokens INTEGER NOT NULL DEFAULT 0, + cache_creation_tokens INTEGER NOT NULL DEFAULT 0, + cache_read_tokens INTEGER NOT NULL DEFAULT 0, + + -- 请求状态 + request_status TEXT NOT NULL DEFAULT 'success', + response_type TEXT NOT NULL DEFAULT 'unknown', + error_type TEXT, + error_detail TEXT, + + -- 各部分的价格(USD) + input_price REAL, + output_price REAL, + cache_write_price REAL, + cache_read_price REAL, + + -- 总成本(USD) + total_cost REAL NOT NULL DEFAULT 0.0, + + -- 响应时间 + response_time_ms INTEGER, + + -- 价格模板 ID + pricing_template_id TEXT, + + created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP + )", + ) + .context("Failed to create token_logs table")?; + + // 创建索引 + manager + .execute_raw( + "CREATE INDEX IF NOT EXISTS idx_session_timestamp + ON token_logs(session_id, timestamp)", + ) + .context("Failed to create session_timestamp index")?; + + manager + .execute_raw( + "CREATE INDEX IF NOT EXISTS idx_timestamp + ON token_logs(timestamp)", + ) + .context("Failed to create timestamp index")?; + + manager + .execute_raw( + "CREATE INDEX IF NOT EXISTS idx_tool_type + ON token_logs(tool_type)", + ) + .context("Failed to create tool_type index")?; + + // 添加成本分析相关索引(Phase 1) + manager + .execute_raw( + "CREATE INDEX IF NOT EXISTS idx_model + ON token_logs(model)", + ) + .context("Failed to create model index")?; + + manager + .execute_raw( + "CREATE INDEX IF NOT EXISTS idx_total_cost + ON token_logs(total_cost)", + ) + .context("Failed to create total_cost index")?; + + manager + .execute_raw( + "CREATE INDEX IF NOT EXISTS idx_timestamp_cost + ON token_logs(timestamp, total_cost)", + ) + .context("Failed to create timestamp_cost index")?; + + manager + .execute_raw( + "CREATE INDEX IF NOT EXISTS idx_tool_model + ON token_logs(tool_type, model)", + ) + .context("Failed to create tool_model index")?; + + Ok(()) + } + + /// 插入单条日志记录 + pub fn insert_log(&self, log: &TokenLog) -> Result { + let manager = DataManager::global() + .sqlite(&self.db_path) + .context("Failed to get SQLite manager")?; + + let params = vec![ + log.tool_type.clone(), + log.timestamp.to_string(), + log.client_ip.clone(), + log.session_id.clone(), + log.config_name.clone(), + log.model.clone(), + log.message_id.clone().unwrap_or_default(), + log.input_tokens.to_string(), + log.output_tokens.to_string(), + log.cache_creation_tokens.to_string(), + log.cache_read_tokens.to_string(), + log.request_status.clone(), + log.response_type.clone(), + log.error_type.clone().unwrap_or_default(), + log.error_detail.clone().unwrap_or_default(), + log.response_time_ms + .map(|v| v.to_string()) + .unwrap_or_default(), + log.input_price.map(|v| v.to_string()).unwrap_or_default(), + log.output_price.map(|v| v.to_string()).unwrap_or_default(), + log.cache_write_price + .map(|v| v.to_string()) + .unwrap_or_default(), + log.cache_read_price + .map(|v| v.to_string()) + .unwrap_or_default(), + log.total_cost.to_string(), + log.pricing_template_id.clone().unwrap_or_default(), + ]; + + let params_refs: Vec<&str> = params.iter().map(|s| s.as_str()).collect(); + + manager + .execute( + "INSERT INTO token_logs ( + tool_type, timestamp, client_ip, session_id, config_name, + model, message_id, input_tokens, output_tokens, + cache_creation_tokens, cache_read_tokens, + request_status, response_type, error_type, error_detail, + response_time_ms, input_price, output_price, cache_write_price, cache_read_price, + total_cost, pricing_template_id + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18, ?19, ?20, ?21, ?22)", + ¶ms_refs, + ) + .context("Failed to insert token log")?; + + // 执行 WAL checkpoint(TRUNCATE 模式,立即回写主文件) + // 注意:这会稍微影响写入性能,但确保数据及时持久化 + let _ = manager.execute_raw("PRAGMA wal_checkpoint(TRUNCATE)"); + + // 获取最后插入的ID(通过查询max(id)) + let rows = manager + .query("SELECT max(id) as last_id FROM token_logs", &[]) + .context("Failed to query last insert id")?; + + let id = rows + .first() + .and_then(|row| row.values.first()) + .and_then(|v| v.as_i64()) + .unwrap_or(0); + + Ok(id) + } + + /// 插入单条日志记录(不执行 checkpoint) + /// + /// 用于批量写入场景,在批量写入后统一执行 checkpoint + pub fn insert_log_without_checkpoint(&self, log: &TokenLog) -> Result { + let manager = DataManager::global() + .sqlite(&self.db_path) + .context("Failed to get SQLite manager")?; + + let params = vec![ + log.tool_type.clone(), + log.timestamp.to_string(), + log.client_ip.clone(), + log.session_id.clone(), + log.config_name.clone(), + log.model.clone(), + log.message_id.clone().unwrap_or_default(), + log.input_tokens.to_string(), + log.output_tokens.to_string(), + log.cache_creation_tokens.to_string(), + log.cache_read_tokens.to_string(), + log.request_status.clone(), + log.response_type.clone(), + log.error_type.clone().unwrap_or_default(), + log.error_detail.clone().unwrap_or_default(), + log.response_time_ms + .map(|v| v.to_string()) + .unwrap_or_default(), + log.input_price.map(|v| v.to_string()).unwrap_or_default(), + log.output_price.map(|v| v.to_string()).unwrap_or_default(), + log.cache_write_price + .map(|v| v.to_string()) + .unwrap_or_default(), + log.cache_read_price + .map(|v| v.to_string()) + .unwrap_or_default(), + log.total_cost.to_string(), + log.pricing_template_id.clone().unwrap_or_default(), + ]; + + let params_refs: Vec<&str> = params.iter().map(|s| s.as_str()).collect(); + + manager + .execute( + "INSERT INTO token_logs ( + tool_type, timestamp, client_ip, session_id, config_name, + model, message_id, input_tokens, output_tokens, + cache_creation_tokens, cache_read_tokens, + request_status, response_type, error_type, error_detail, + response_time_ms, input_price, output_price, cache_write_price, cache_read_price, + total_cost, pricing_template_id + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18, ?19, ?20, ?21, ?22)", + ¶ms_refs, + ) + .context("Failed to insert token log")?; + + // 获取最后插入的ID + let rows = manager + .query("SELECT max(id) as last_id FROM token_logs", &[]) + .context("Failed to query last insert id")?; + + let id = rows + .first() + .and_then(|row| row.values.first()) + .and_then(|v| v.as_i64()) + .unwrap_or(0); + + Ok(id) + } + + /// 查询会话统计数据 + pub fn get_session_stats(&self, tool_type: &str, session_id: &str) -> Result { + let manager = DataManager::global() + .sqlite(&self.db_path) + .context("Failed to get SQLite manager")?; + + let rows = manager + .query( + "SELECT + COALESCE(SUM(input_tokens), 0) as total_input, + COALESCE(SUM(output_tokens), 0) as total_output, + COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation, + COALESCE(SUM(cache_read_tokens), 0) as total_cache_read, + COUNT(*) as request_count + FROM token_logs + WHERE session_id = ?1 AND tool_type = ?2", + &[session_id, tool_type], + ) + .context("Failed to query session stats")?; + + let row = rows.first().context("No stats row returned")?; + + Ok(SessionStats { + total_input: row.values.first().and_then(|v| v.as_i64()).unwrap_or(0), + total_output: row.values.get(1).and_then(|v| v.as_i64()).unwrap_or(0), + total_cache_creation: row.values.get(2).and_then(|v| v.as_i64()).unwrap_or(0), + total_cache_read: row.values.get(3).and_then(|v| v.as_i64()).unwrap_or(0), + request_count: row.values.get(4).and_then(|v| v.as_i64()).unwrap_or(0), + }) + } + + /// 分页查询日志记录 + pub fn query_logs(&self, query: &TokenStatsQuery) -> Result { + let manager = DataManager::global() + .sqlite(&self.db_path) + .context("Failed to get SQLite manager")?; + + // 构建查询条件 + let mut where_clauses = Vec::new(); + let mut params = Vec::new(); + + if let Some(ref tool_type) = query.tool_type { + where_clauses.push("tool_type = ?"); + params.push(tool_type.clone()); + } + + if let Some(ref session_id) = query.session_id { + where_clauses.push("session_id = ?"); + params.push(session_id.clone()); + } + + if let Some(ref config_name) = query.config_name { + where_clauses.push("config_name = ?"); + params.push(config_name.clone()); + } + + if let Some(start_time) = query.start_time { + where_clauses.push("timestamp >= ?"); + params.push(start_time.to_string()); + } + + if let Some(end_time) = query.end_time { + where_clauses.push("timestamp <= ?"); + params.push(end_time.to_string()); + } + + let where_clause = if where_clauses.is_empty() { + String::new() + } else { + format!("WHERE {}", where_clauses.join(" AND ")) + }; + + // 查询总数 + let count_sql = format!("SELECT COUNT(*) FROM token_logs {}", where_clause); + let params_refs: Vec<&str> = params.iter().map(|s| s.as_str()).collect(); + + let count_rows = manager + .query(&count_sql, ¶ms_refs) + .context("Failed to query total count")?; + + let total: i64 = count_rows + .first() + .and_then(|row| row.values.first()) + .and_then(|v| v.as_i64()) + .unwrap_or(0); + + // 查询日志列表 + let offset = query.page * query.page_size; + let list_sql = format!( + "SELECT id, tool_type, timestamp, client_ip, session_id, config_name, + model, message_id, input_tokens, output_tokens, + cache_creation_tokens, cache_read_tokens, + request_status, response_type, error_type, error_detail, + response_time_ms, input_price, output_price, cache_write_price, cache_read_price, + total_cost, pricing_template_id + FROM token_logs {} + ORDER BY timestamp DESC + LIMIT ? OFFSET ?", + where_clause + ); + + let mut list_params = params.clone(); + list_params.push(query.page_size.to_string()); + list_params.push(offset.to_string()); + + let list_params_refs: Vec<&str> = list_params.iter().map(|s| s.as_str()).collect(); + + let list_rows = manager + .query(&list_sql, &list_params_refs) + .context("Failed to query logs")?; + + let logs = list_rows + .iter() + .map(|row| { + Ok(TokenLog { + id: row.values.first().and_then(|v| v.as_i64()), + tool_type: row + .values + .get(1) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + timestamp: row.values.get(2).and_then(|v| v.as_i64()).unwrap_or(0), + client_ip: row + .values + .get(3) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + session_id: row + .values + .get(4) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + config_name: row + .values + .get(5) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + model: row + .values + .get(6) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + message_id: row.values.get(7).and_then(|v| v.as_str()).map(String::from), + input_tokens: row.values.get(8).and_then(|v| v.as_i64()).unwrap_or(0), + output_tokens: row.values.get(9).and_then(|v| v.as_i64()).unwrap_or(0), + cache_creation_tokens: row.values.get(10).and_then(|v| v.as_i64()).unwrap_or(0), + cache_read_tokens: row.values.get(11).and_then(|v| v.as_i64()).unwrap_or(0), + request_status: row + .values + .get(12) + .and_then(|v| v.as_str()) + .unwrap_or("success") + .to_string(), + response_type: row + .values + .get(13) + .and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(), + error_type: row + .values + .get(14) + .and_then(|v| v.as_str()) + .map(String::from), + error_detail: row + .values + .get(15) + .and_then(|v| v.as_str()) + .map(String::from), + response_time_ms: row.values.get(16).and_then(|v| v.as_i64()), + input_price: row.values.get(17).and_then(|v| v.as_f64()), + output_price: row.values.get(18).and_then(|v| v.as_f64()), + cache_write_price: row.values.get(19).and_then(|v| v.as_f64()), + cache_read_price: row.values.get(20).and_then(|v| v.as_f64()), + total_cost: row.values.get(21).and_then(|v| v.as_f64()).unwrap_or(0.0), + pricing_template_id: row + .values + .get(22) + .and_then(|v| v.as_str()) + .map(String::from), + }) + }) + .collect::>>()?; + + Ok(TokenLogsPage { + logs, + total, + page: query.page, + page_size: query.page_size, + }) + } + + /// 清理旧数据 + pub fn cleanup_old_logs( + &self, + retention_days: Option, + max_count: Option, + ) -> Result { + let manager = DataManager::global() + .sqlite(&self.db_path) + .context("Failed to get SQLite manager")?; + + let mut deleted_count = 0; + + // 按时间清理 + if let Some(days) = retention_days { + let cutoff_timestamp = + chrono::Utc::now().timestamp_millis() - (days as i64 * 86400 * 1000); + let count = manager + .execute( + "DELETE FROM token_logs WHERE timestamp < ?1", + &[&cutoff_timestamp.to_string()], + ) + .context("Failed to delete old logs by time")?; + deleted_count += count; + } + + // 按条数清理 + if let Some(max) = max_count { + let count = manager + .execute( + "DELETE FROM token_logs + WHERE id NOT IN ( + SELECT id FROM token_logs + ORDER BY timestamp DESC + LIMIT ?1 + )", + &[&max.to_string()], + ) + .context("Failed to delete old logs by count")?; + deleted_count += count; + } + + // 执行 WAL checkpoint 回写主文件 + if deleted_count > 0 { + manager + .execute_raw("PRAGMA wal_checkpoint(TRUNCATE)") + .context("Failed to checkpoint WAL")?; + } + + Ok(deleted_count) + } + + /// 获取数据库统计信息 + pub fn get_stats_summary(&self) -> Result<(i64, Option, Option)> { + let manager = DataManager::global() + .sqlite(&self.db_path) + .context("Failed to get SQLite manager")?; + + let rows = manager + .query( + "SELECT + COUNT(*) as total, + MIN(timestamp) as oldest, + MAX(timestamp) as newest + FROM token_logs", + &[], + ) + .context("Failed to query stats summary")?; + + let row = rows.first().context("No summary row returned")?; + + let total = row.values.first().and_then(|v| v.as_i64()).unwrap_or(0); + let oldest = row.values.get(1).and_then(|v| v.as_i64()); + let newest = row.values.get(2).and_then(|v| v.as_i64()); + + Ok((total, oldest, newest)) + } + + /// 强制执行 WAL checkpoint(手动触发) + /// + /// 将 WAL 文件中的所有数据回写到主数据库文件, + /// 用于清理过大的 WAL 文件 + pub fn force_checkpoint(&self) -> Result<()> { + let manager = DataManager::global() + .sqlite(&self.db_path) + .context("Failed to get SQLite manager")?; + + manager + .execute_raw("PRAGMA wal_checkpoint(TRUNCATE)") + .context("Failed to execute WAL checkpoint")?; + + Ok(()) + } + + /// 执行 PASSIVE checkpoint + /// + /// 尽可能多地将 WAL 数据回写到主文件,但不阻塞其他操作。 + /// 适合在批量写入后执行,性能影响最小。 + pub fn passive_checkpoint(&self) -> Result<()> { + let manager = DataManager::global() + .sqlite(&self.db_path) + .context("Failed to get SQLite manager")?; + + manager + .execute_raw("PRAGMA wal_checkpoint(PASSIVE)") + .context("Failed to execute PASSIVE checkpoint")?; + + Ok(()) + } +} + +impl Clone for TokenStatsDb { + fn clone(&self) -> Self { + Self::new(self.db_path.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::tempdir; + + fn create_test_db() -> (TokenStatsDb, PathBuf) { + let dir = tempdir().unwrap(); + let db_path = dir.path().join("test_token_stats.db"); + let db = TokenStatsDb::new(db_path.clone()); + db.init_table().unwrap(); + (db, db_path) + } + + #[test] + fn test_init_table() { + let (db, _) = create_test_db(); + // 重复初始化不应报错 + assert!(db.init_table().is_ok()); + } + + #[test] + fn test_insert_and_query() { + let (db, _) = create_test_db(); + + let log = TokenLog::new( + "claude_code".to_string(), + chrono::Utc::now().timestamp_millis(), + "127.0.0.1".to_string(), + "session_123".to_string(), + "default".to_string(), + "claude-sonnet-4-5-20250929".to_string(), + Some("msg_123".to_string()), + 1000, + 500, + 100, + 200, + "success".to_string(), + "json".to_string(), + None, + None, + None, + None, + None, + None, + None, + 0.0, + None, + ); + + let id = db.insert_log(&log).unwrap(); + assert!(id > 0); + + // 查询会话统计 + let stats = db.get_session_stats("claude_code", "session_123").unwrap(); + assert_eq!(stats.total_input, 1000); + assert_eq!(stats.total_output, 500); + assert_eq!(stats.request_count, 1); + } + + #[test] + fn test_query_logs_pagination() { + let (db, _) = create_test_db(); + + // 插入多条记录 + for i in 0..25 { + let log = TokenLog::new( + "claude_code".to_string(), + chrono::Utc::now().timestamp_millis() + i, + "127.0.0.1".to_string(), + "session_123".to_string(), + "default".to_string(), + "claude-sonnet-4-5-20250929".to_string(), + Some(format!("msg_{}", i)), + 100, + 50, + 10, + 20, + "success".to_string(), + "sse".to_string(), + None, + None, + None, + None, + None, + None, + None, + 0.0, + None, + ); + db.insert_log(&log).unwrap(); + } + + // 查询第一页 + let query = TokenStatsQuery { + page: 0, + page_size: 10, + ..Default::default() + }; + let page = db.query_logs(&query).unwrap(); + assert_eq!(page.logs.len(), 10); + assert_eq!(page.total, 25); + + // 查询第三页 + let query = TokenStatsQuery { + page: 2, + page_size: 10, + ..Default::default() + }; + let page = db.query_logs(&query).unwrap(); + assert_eq!(page.logs.len(), 5); + } + + #[test] + fn test_cleanup() { + let (db, _) = create_test_db(); + + // 插入旧数据和新数据 + let old_timestamp = chrono::Utc::now().timestamp_millis() - (40 * 86400 * 1000); // 40天前 + let old_log = TokenLog::new( + "claude_code".to_string(), + old_timestamp, + "127.0.0.1".to_string(), + "session_old".to_string(), + "default".to_string(), + "claude-3".to_string(), + None, + 100, + 50, + 0, + 0, + "success".to_string(), + "json".to_string(), + None, + None, + None, + None, + None, + None, + None, + 0.0, + None, + ); + db.insert_log(&old_log).unwrap(); + + let new_log = TokenLog::new( + "claude_code".to_string(), + chrono::Utc::now().timestamp_millis(), + "127.0.0.1".to_string(), + "session_new".to_string(), + "default".to_string(), + "claude-3".to_string(), + None, + 200, + 100, + 0, + 0, + "success".to_string(), + "json".to_string(), + None, + None, + None, + None, + None, + None, + None, + 0.0, + None, + ); + db.insert_log(&new_log).unwrap(); + + // 清理30天前的数据 + let deleted = db.cleanup_old_logs(Some(30), None).unwrap(); + assert_eq!(deleted, 1); + + // 验证新数据仍在 + let stats = db.get_session_stats("claude_code", "session_new").unwrap(); + assert_eq!(stats.request_count, 1); + } +} diff --git a/src-tauri/src/services/token_stats/extractor.rs b/src-tauri/src/services/token_stats/extractor.rs new file mode 100644 index 0000000..bad2157 --- /dev/null +++ b/src-tauri/src/services/token_stats/extractor.rs @@ -0,0 +1,520 @@ +use anyhow::{Context, Result}; +use serde_json::Value; + +/// Token提取器统一接口 +pub trait TokenExtractor: Send + Sync { + /// 从请求体中提取模型名称 + fn extract_model_from_request(&self, body: &[u8]) -> Result; + + /// 从SSE数据块中提取Token信息 + fn extract_from_sse_chunk(&self, chunk: &str) -> Result>; + + /// 从JSON响应中提取Token信息 + fn extract_from_json(&self, json: &Value) -> Result; +} + +/// SSE流式数据中的Token信息 +#[derive(Debug, Clone, Default)] +pub struct SseTokenData { + /// message_start块数据 + pub message_start: Option, + /// message_delta块数据(end_turn) + pub message_delta: Option, +} + +/// message_start块数据 +#[derive(Debug, Clone)] +pub struct MessageStartData { + pub model: String, + pub message_id: String, + pub input_tokens: i64, + pub output_tokens: i64, + pub cache_creation_tokens: i64, + pub cache_read_tokens: i64, +} + +/// message_delta块数据(end_turn) +#[derive(Debug, Clone)] +pub struct MessageDeltaData { + pub cache_creation_tokens: i64, + pub cache_read_tokens: i64, + pub output_tokens: i64, +} + +/// 响应Token信息(完整) +#[derive(Debug, Clone)] +pub struct ResponseTokenInfo { + pub model: String, + pub message_id: String, + pub input_tokens: i64, + pub output_tokens: i64, + pub cache_creation_tokens: i64, + pub cache_read_tokens: i64, +} + +impl ResponseTokenInfo { + /// 从SSE数据合并得到完整信息 + /// + /// 合并规则: + /// - model, message_id, input_tokens: 始终使用 message_start 的值 + /// - output_tokens, cache_*: 优先使用 message_delta 的值,回退到 message_start + pub fn from_sse_data(start: MessageStartData, delta: Option) -> Self { + let (cache_creation, cache_read, output) = if let Some(d) = delta { + // 优先使用 delta 的值(最终统计) + ( + d.cache_creation_tokens, + d.cache_read_tokens, + d.output_tokens, + ) + } else { + // 回退到 start 的值(初始统计) + ( + start.cache_creation_tokens, + start.cache_read_tokens, + start.output_tokens, + ) + }; + + Self { + model: start.model, + message_id: start.message_id, + input_tokens: start.input_tokens, + output_tokens: output, + cache_creation_tokens: cache_creation, + cache_read_tokens: cache_read, + } + } +} + +/// Claude Code工具的Token提取器 +pub struct ClaudeTokenExtractor; + +impl TokenExtractor for ClaudeTokenExtractor { + fn extract_model_from_request(&self, body: &[u8]) -> Result { + let json: Value = + serde_json::from_slice(body).context("Failed to parse request body as JSON")?; + + json.get("model") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .context("Missing 'model' field in request body") + } + + fn extract_from_sse_chunk(&self, chunk: &str) -> Result> { + // SSE格式: data: {...} 或直接 {...}(已去掉前缀) + let data_line = chunk.trim(); + + // 跳过空行 + if data_line.is_empty() { + return Ok(None); + } + + // 兼容处理:去掉 "data: " 前缀(如果存在) + let json_str = if let Some(stripped) = data_line.strip_prefix("data: ") { + stripped + } else { + data_line + }; + + // 跳过 [DONE] 标记 + if json_str.trim() == "[DONE]" { + return Ok(None); + } + + let json: Value = + serde_json::from_str(json_str).context("Failed to parse SSE chunk as JSON")?; + + let event_type = json.get("type").and_then(|v| v.as_str()).unwrap_or(""); + + let mut result = SseTokenData::default(); + + match event_type { + "message_start" => { + if let Some(message) = json.get("message") { + let model = message + .get("model") + .and_then(|v| v.as_str()) + .context("Missing model in message_start")? + .to_string(); + + let message_id = message + .get("id") + .and_then(|v| v.as_str()) + .context("Missing id in message_start")? + .to_string(); + + let usage = message + .get("usage") + .context("Missing usage in message_start")?; + + let input_tokens = usage + .get("input_tokens") + .and_then(|v| v.as_i64()) + .unwrap_or(0); + + let output_tokens = usage + .get("output_tokens") + .and_then(|v| v.as_i64()) + .unwrap_or(0); + + // 提取缓存创建 token:优先读取扁平字段,回退到嵌套对象 + let cache_creation_tokens = usage + .get("cache_creation_input_tokens") + .and_then(|v| v.as_i64()) + .unwrap_or_else(|| { + if let Some(cache_obj) = usage.get("cache_creation") { + let ephemeral_5m = cache_obj + .get("ephemeral_5m_input_tokens") + .and_then(|v| v.as_i64()) + .unwrap_or(0); + let ephemeral_1h = cache_obj + .get("ephemeral_1h_input_tokens") + .and_then(|v| v.as_i64()) + .unwrap_or(0); + ephemeral_5m + ephemeral_1h + } else { + 0 + } + }); + + let cache_read_tokens = usage + .get("cache_read_input_tokens") + .and_then(|v| v.as_i64()) + .unwrap_or(0); + + result.message_start = Some(MessageStartData { + model, + message_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + }); + } + } + "message_delta" => { + // 检查是否有 stop_reason(任何值都接受:end_turn, tool_use, max_tokens 等) + if let Some(delta) = json.get("delta") { + if delta.get("stop_reason").and_then(|v| v.as_str()).is_some() { + if let Some(usage) = json.get("usage") { + // 提取缓存创建 token:优先读取扁平字段,回退到嵌套对象 + let cache_creation = usage + .get("cache_creation_input_tokens") + .and_then(|v| v.as_i64()) + .unwrap_or_else(|| { + if let Some(cache_obj) = usage.get("cache_creation") { + let ephemeral_5m = cache_obj + .get("ephemeral_5m_input_tokens") + .and_then(|v| v.as_i64()) + .unwrap_or(0); + let ephemeral_1h = cache_obj + .get("ephemeral_1h_input_tokens") + .and_then(|v| v.as_i64()) + .unwrap_or(0); + ephemeral_5m + ephemeral_1h + } else { + 0 + } + }); + + let cache_read = usage + .get("cache_read_input_tokens") + .and_then(|v| v.as_i64()) + .unwrap_or(0); + + let output_tokens = usage + .get("output_tokens") + .and_then(|v| v.as_i64()) + .unwrap_or(0); + + result.message_delta = Some(MessageDeltaData { + cache_creation_tokens: cache_creation, + cache_read_tokens: cache_read, + output_tokens, + }); + } + } + } + } + _ => {} + } + + Ok( + if result.message_start.is_some() || result.message_delta.is_some() { + Some(result) + } else { + None + }, + ) + } + + fn extract_from_json(&self, json: &Value) -> Result { + let model = json + .get("model") + .and_then(|v| v.as_str()) + .context("Missing model field")? + .to_string(); + + let message_id = json + .get("id") + .and_then(|v| v.as_str()) + .context("Missing id field")? + .to_string(); + + let usage = json.get("usage").context("Missing usage field")?; + + let input_tokens = usage + .get("input_tokens") + .and_then(|v| v.as_i64()) + .unwrap_or(0); + + let output_tokens = usage + .get("output_tokens") + .and_then(|v| v.as_i64()) + .unwrap_or(0); + + // 提取 cache_creation_input_tokens: + // 优先读取扁平字段,如果不存在则尝试从嵌套对象聚合 + let cache_creation = usage + .get("cache_creation_input_tokens") + .and_then(|v| v.as_i64()) + .unwrap_or_else(|| { + // 回退:尝试从嵌套的 cache_creation 对象聚合 + if let Some(cache_obj) = usage.get("cache_creation") { + let ephemeral_5m = cache_obj + .get("ephemeral_5m_input_tokens") + .and_then(|v| v.as_i64()) + .unwrap_or(0); + let ephemeral_1h = cache_obj + .get("ephemeral_1h_input_tokens") + .and_then(|v| v.as_i64()) + .unwrap_or(0); + ephemeral_5m + ephemeral_1h + } else { + 0 + } + }); + + let cache_read = usage + .get("cache_read_input_tokens") + .and_then(|v| v.as_i64()) + .unwrap_or(0); + + Ok(ResponseTokenInfo { + model, + message_id, + input_tokens, + output_tokens, + cache_creation_tokens: cache_creation, + cache_read_tokens: cache_read, + }) + } +} + +/// 创建Token提取器工厂函数 +pub fn create_extractor(tool_type: &str) -> Result> { + // 支持破折号和下划线两种格式 + let normalized = tool_type.replace('-', "_"); + match normalized.as_str() { + "claude_code" => Ok(Box::new(ClaudeTokenExtractor)), + // 预留扩展点 + "codex" => anyhow::bail!("Codex token extractor not implemented yet"), + "gemini_cli" => anyhow::bail!("Gemini CLI token extractor not implemented yet"), + _ => anyhow::bail!("Unknown tool type: {}", tool_type), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_model_from_request() { + let extractor = ClaudeTokenExtractor; + let body = r#"{"model":"claude-sonnet-4-5-20250929","messages":[]}"#; + + let model = extractor + .extract_model_from_request(body.as_bytes()) + .unwrap(); + assert_eq!(model, "claude-sonnet-4-5-20250929"); + } + + #[test] + fn test_extract_from_sse_message_start() { + let extractor = ClaudeTokenExtractor; + let chunk = r#"data: {"type":"message_start","message":{"model":"claude-haiku-4-5-20251001","id":"msg_123","type":"message","role":"assistant","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":27592,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":1}}}"#; + + let result = extractor.extract_from_sse_chunk(chunk).unwrap().unwrap(); + assert!(result.message_start.is_some()); + + let start = result.message_start.unwrap(); + assert_eq!(start.model, "claude-haiku-4-5-20251001"); + assert_eq!(start.message_id, "msg_123"); + assert_eq!(start.input_tokens, 27592); + assert_eq!(start.output_tokens, 1); + assert_eq!(start.cache_creation_tokens, 0); + assert_eq!(start.cache_read_tokens, 0); + } + + #[test] + fn test_extract_from_sse_message_delta() { + let extractor = ClaudeTokenExtractor; + let chunk = r#"data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":27592,"cache_creation_input_tokens":100,"cache_read_input_tokens":200,"output_tokens":12}}"#; + + let result = extractor.extract_from_sse_chunk(chunk).unwrap().unwrap(); + assert!(result.message_delta.is_some()); + + let delta = result.message_delta.unwrap(); + assert_eq!(delta.cache_creation_tokens, 100); + assert_eq!(delta.cache_read_tokens, 200); + assert_eq!(delta.output_tokens, 12); + } + + #[test] + fn test_extract_from_json() { + let extractor = ClaudeTokenExtractor; + let json_str = r#"{ + "content": [{"text": "test", "type": "text"}], + "id": "msg_018K1Hs5Tm7sC7xdeYpYhUFN", + "model": "claude-haiku-4-5-20251001", + "role": "assistant", + "stop_reason": "end_turn", + "type": "message", + "usage": { + "cache_creation_input_tokens": 50, + "cache_read_input_tokens": 100, + "input_tokens": 119, + "output_tokens": 21 + } + }"#; + + let json: Value = serde_json::from_str(json_str).unwrap(); + let result = extractor.extract_from_json(&json).unwrap(); + + assert_eq!(result.model, "claude-haiku-4-5-20251001"); + assert_eq!(result.message_id, "msg_018K1Hs5Tm7sC7xdeYpYhUFN"); + assert_eq!(result.input_tokens, 119); + assert_eq!(result.output_tokens, 21); + assert_eq!(result.cache_creation_tokens, 50); + assert_eq!(result.cache_read_tokens, 100); + } + + #[test] + fn test_response_token_info_from_sse() { + let start = MessageStartData { + model: "claude-3".to_string(), + message_id: "msg_123".to_string(), + input_tokens: 1000, + output_tokens: 1, + cache_creation_tokens: 50, + cache_read_tokens: 100, + }; + + let delta = MessageDeltaData { + cache_creation_tokens: 50, + cache_read_tokens: 100, + output_tokens: 200, + }; + + let info = ResponseTokenInfo::from_sse_data(start, Some(delta)); + assert_eq!(info.model, "claude-3"); + assert_eq!(info.input_tokens, 1000); + assert_eq!(info.output_tokens, 200); + assert_eq!(info.cache_creation_tokens, 50); + assert_eq!(info.cache_read_tokens, 100); + } + + #[test] + fn test_create_extractor() { + assert!(create_extractor("claude_code").is_ok()); + assert!(create_extractor("codex").is_err()); + assert!(create_extractor("gemini_cli").is_err()); + assert!(create_extractor("unknown").is_err()); + } + + #[test] + fn test_extract_nested_cache_creation_json() { + // 测试嵌套 cache_creation 对象的提取(JSON 响应) + let extractor = ClaudeTokenExtractor; + let json_str = r#"{ + "id": "msg_013B8kRbTZdntKmHWE6AZzuU", + "model": "claude-sonnet-4-5-20250929", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "test"}], + "usage": { + "cache_creation": { + "ephemeral_1h_input_tokens": 0, + "ephemeral_5m_input_tokens": 73444 + }, + "cache_creation_input_tokens": 73444, + "cache_read_input_tokens": 19198, + "input_tokens": 12, + "output_tokens": 259, + "service_tier": "standard" + } + }"#; + + let json: Value = serde_json::from_str(json_str).unwrap(); + let result = extractor.extract_from_json(&json).unwrap(); + + assert_eq!(result.model, "claude-sonnet-4-5-20250929"); + assert_eq!(result.message_id, "msg_013B8kRbTZdntKmHWE6AZzuU"); + assert_eq!(result.input_tokens, 12); + assert_eq!(result.output_tokens, 259); + assert_eq!(result.cache_creation_tokens, 73444); + assert_eq!(result.cache_read_tokens, 19198); + } + + #[test] + fn test_extract_nested_cache_creation_sse_start() { + // 测试嵌套 cache_creation 对象的提取(SSE message_start) + let extractor = ClaudeTokenExtractor; + let chunk = r#"data: {"type":"message_start","message":{"model":"claude-sonnet-4-5-20250929","id":"msg_018GWR1gBaJBchrC6t5nnRui","type":"message","role":"assistant","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":9,"cache_creation_input_tokens":2122,"cache_read_input_tokens":123663,"cache_creation":{"ephemeral_5m_input_tokens":2122,"ephemeral_1h_input_tokens":0},"output_tokens":1,"service_tier":"standard"}}}"#; + + let result = extractor.extract_from_sse_chunk(chunk).unwrap().unwrap(); + assert!(result.message_start.is_some()); + + let start = result.message_start.unwrap(); + assert_eq!(start.model, "claude-sonnet-4-5-20250929"); + assert_eq!(start.message_id, "msg_018GWR1gBaJBchrC6t5nnRui"); + assert_eq!(start.input_tokens, 9); + assert_eq!(start.output_tokens, 1); + assert_eq!(start.cache_creation_tokens, 2122); + assert_eq!(start.cache_read_tokens, 123663); + } + + #[test] + fn test_extract_message_delta_with_tool_use() { + // 测试 stop_reason="tool_use" 的情况 + let extractor = ClaudeTokenExtractor; + let chunk = r#"data: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":9,"cache_creation_input_tokens":2122,"cache_read_input_tokens":123663,"output_tokens":566}}"#; + + let result = extractor.extract_from_sse_chunk(chunk).unwrap().unwrap(); + assert!(result.message_delta.is_some()); + + let delta = result.message_delta.unwrap(); + assert_eq!(delta.cache_creation_tokens, 2122); + assert_eq!(delta.cache_read_tokens, 123663); + assert_eq!(delta.output_tokens, 566); + } + + #[test] + fn test_from_sse_data_without_delta() { + // 测试没有 delta 时使用 start 的缓存值 + let start = MessageStartData { + model: "claude-3".to_string(), + message_id: "msg_test".to_string(), + input_tokens: 100, + output_tokens: 50, + cache_creation_tokens: 200, + cache_read_tokens: 300, + }; + + let info = ResponseTokenInfo::from_sse_data(start, None); + assert_eq!(info.input_tokens, 100); + assert_eq!(info.output_tokens, 50); + assert_eq!(info.cache_creation_tokens, 200); + assert_eq!(info.cache_read_tokens, 300); + } +} diff --git a/src-tauri/src/services/token_stats/manager.rs b/src-tauri/src/services/token_stats/manager.rs new file mode 100644 index 0000000..2732d50 --- /dev/null +++ b/src-tauri/src/services/token_stats/manager.rs @@ -0,0 +1,553 @@ +use crate::models::token_stats::{SessionStats, TokenLog, TokenLogsPage, TokenStatsQuery}; +use crate::services::pricing::PRICING_MANAGER; +use crate::services::token_stats::db::TokenStatsDb; +use crate::services::token_stats::extractor::{ + create_extractor, MessageDeltaData, MessageStartData, ResponseTokenInfo, +}; +use crate::utils::config_dir; +use anyhow::{Context, Result}; +use once_cell::sync::OnceCell; +use serde_json::Value; +use std::path::PathBuf; +use tokio::sync::mpsc; +use tokio::time::{interval, Duration}; +use tokio_util::sync::CancellationToken; + +/// 全局 TokenStatsManager 单例 +static TOKEN_STATS_MANAGER: OnceCell = OnceCell::new(); + +/// 全局取消令牌,用于优雅关闭后台任务 +static CANCELLATION_TOKEN: once_cell::sync::Lazy = + once_cell::sync::Lazy::new(CancellationToken::new); + +/// 响应数据类型 +pub enum ResponseData { + /// SSE流式响应(收集的所有data块) + Sse(Vec), + /// JSON响应 + Json(Value), +} + +/// Token统计管理器 +pub struct TokenStatsManager { + db: TokenStatsDb, + event_sender: mpsc::UnboundedSender, +} + +impl TokenStatsManager { + /// 获取全局单例实例 + pub fn get() -> &'static TokenStatsManager { + TOKEN_STATS_MANAGER.get_or_init(|| { + let db_path = Self::default_db_path(); + let db = TokenStatsDb::new(db_path); + + // 初始化数据库表 + if let Err(e) = db.init_table() { + eprintln!("Failed to initialize token stats database: {}", e); + } + + // 创建事件队列 + let (event_sender, event_receiver) = mpsc::unbounded_channel(); + + let manager = TokenStatsManager { db, event_sender }; + + // 启动后台任务 + manager.start_background_tasks(event_receiver); + + manager + }) + } + + /// 获取默认数据库路径 + fn default_db_path() -> PathBuf { + config_dir() + .map(|dir| dir.join("token_stats.db")) + .unwrap_or_else(|_| PathBuf::from("token_stats.db")) + } + + /// 启动后台任务 + fn start_background_tasks(&self, mut event_receiver: mpsc::UnboundedReceiver) { + let db = self.db.clone(); + + // 批量写入任务 + tokio::spawn(async move { + let mut buffer: Vec = Vec::new(); + let mut tick_interval = interval(Duration::from_millis(100)); + + loop { + tokio::select! { + _ = CANCELLATION_TOKEN.cancelled() => { + // 应用关闭,刷盘缓冲区 + if !buffer.is_empty() { + Self::flush_logs(&db, &mut buffer, true); + tracing::info!("Token 日志已刷盘: {} 条", buffer.len()); + } + tracing::info!("Token 批量写入任务已停止"); + break; + } + // 接收日志事件 + Some(log) = event_receiver.recv() => { + buffer.push(log); + + // 如果缓冲区达到 10 条,立即写入 + if buffer.len() >= 10 { + Self::flush_logs(&db, &mut buffer, false); + } + } + // 每 100ms 刷新一次 + _ = tick_interval.tick() => { + if !buffer.is_empty() { + Self::flush_logs(&db, &mut buffer, false); + } + } + } + } + }); + + // 定期 TRUNCATE checkpoint 任务(每 5 分钟) + let db_clone = self.db.clone(); + tokio::spawn(async move { + let mut checkpoint_interval = interval(Duration::from_secs(300)); // 5分钟 + + loop { + tokio::select! { + _ = CANCELLATION_TOKEN.cancelled() => { + tracing::info!("Token Checkpoint 任务已停止"); + break; + } + _ = checkpoint_interval.tick() => { + if let Err(e) = db_clone.force_checkpoint() { + tracing::error!("定期 Checkpoint 失败: {}", e); + } else { + tracing::debug!("Token 数据库 TRUNCATE checkpoint 完成"); + } + } + } + } + }); + } + + /// 批量写入日志到数据库 + /// + /// # 参数 + /// - `db`: 数据库实例 + /// - `buffer`: 日志缓冲区 + /// - `use_truncate`: 是否使用 TRUNCATE checkpoint(应用关闭时使用) + fn flush_logs(db: &TokenStatsDb, buffer: &mut Vec, use_truncate: bool) { + for log in buffer.drain(..) { + if let Err(e) = db.insert_log_without_checkpoint(&log) { + tracing::error!("插入 Token 日志失败: {}", e); + } + } + + // 批量写入后执行 checkpoint + let checkpoint_result = if use_truncate { + db.force_checkpoint() // TRUNCATE模式 + } else { + db.passive_checkpoint() // PASSIVE模式 + }; + + if let Err(e) = checkpoint_result { + tracing::error!("Checkpoint 失败: {}", e); + } + } + + /// 记录请求日志 + /// + /// # 参数 + /// + /// - `tool_type`: 工具类型(claude_code/codex/gemini_cli) + /// - `session_id`: 会话ID + /// - `config_name`: 使用的配置名称 + /// - `client_ip`: 客户端IP地址 + /// - `request_body`: 请求体(用于提取model) + /// - `response_data`: 响应数据(SSE流或JSON) + /// - `response_time_ms`: 响应时间(毫秒) + /// - `pricing_template_id`: 价格模板ID(None则使用默认模板) + #[allow(clippy::too_many_arguments)] + pub async fn log_request( + &self, + tool_type: &str, + session_id: &str, + config_name: &str, + client_ip: &str, + request_body: &[u8], + response_data: ResponseData, + response_time_ms: Option, + pricing_template_id: Option, + ) -> Result<()> { + // 创建提取器 + let extractor = create_extractor(tool_type).context("Failed to create token extractor")?; + + // 提取请求中的模型名称 + let model = extractor + .extract_model_from_request(request_body) + .context("Failed to extract model from request")?; + + // 确定响应类型 + let response_type = match &response_data { + ResponseData::Sse(_) => "sse", + ResponseData::Json(_) => "json", + }; + + // 提取响应中的Token信息 + let token_info = match response_data { + ResponseData::Sse(chunks) => self.parse_sse_chunks(&*extractor, chunks)?, + ResponseData::Json(json) => extractor.extract_from_json(&json)?, + }; + + // 使用价格模板计算成本 + let template_id_ref = pricing_template_id.as_deref(); + + let ( + final_input_price, + final_output_price, + final_cache_write_price, + final_cache_read_price, + final_total_cost, + final_pricing_template_id, + ) = match PRICING_MANAGER.calculate_cost( + template_id_ref, + &model, + token_info.input_tokens, + token_info.output_tokens, + token_info.cache_creation_tokens, + token_info.cache_read_tokens, + ) { + Ok(breakdown) => { + tracing::debug!( + model = %model, + template_id = %breakdown.template_id, + total_cost = breakdown.total_cost, + input_tokens = token_info.input_tokens, + output_tokens = token_info.output_tokens, + "成本计算成功" + ); + ( + Some(breakdown.input_price), + Some(breakdown.output_price), + Some(breakdown.cache_write_price), + Some(breakdown.cache_read_price), + breakdown.total_cost, + Some(breakdown.template_id), + ) + } + Err(e) => { + // 计算失败,使用 0 + tracing::warn!( + model = %model, + template_id = ?template_id_ref, + error = ?e, + "成本计算失败,使用默认值 0" + ); + (None, None, None, None, 0.0, None) + } + }; + + // 创建日志记录(成功) + let timestamp = chrono::Utc::now().timestamp_millis(); + let log = TokenLog::new( + tool_type.to_string(), + timestamp, + client_ip.to_string(), + session_id.to_string(), + config_name.to_string(), + model, + Some(token_info.message_id), + token_info.input_tokens, + token_info.output_tokens, + token_info.cache_creation_tokens, + token_info.cache_read_tokens, + "success".to_string(), + response_type.to_string(), + None, + None, + response_time_ms, + final_input_price, + final_output_price, + final_cache_write_price, + final_cache_read_price, + final_total_cost, + final_pricing_template_id, + ); + + // 发送到批量写入队列(异步,不阻塞) + if let Err(e) = self.event_sender.send(log) { + tracing::error!("发送 Token 日志事件失败: {}", e); + } + + Ok(()) + } + + /// 记录失败的请求 + /// + /// # 参数 + /// + /// - `tool_type`: 工具类型 + /// - `session_id`: 会话ID + /// - `config_name`: 配置名称 + /// - `client_ip`: 客户端IP + /// - `request_body`: 请求体(用于提取model,失败时可能为空) + /// - `error_type`: 错误类型(parse_error/request_interrupted/upstream_error) + /// - `error_detail`: 错误详情 + /// - `response_type`: 响应类型(sse/json/unknown) + /// - `response_time_ms`: 响应时间(毫秒) + #[allow(clippy::too_many_arguments)] + pub async fn log_failed_request( + &self, + tool_type: &str, + session_id: &str, + config_name: &str, + client_ip: &str, + request_body: &[u8], + error_type: &str, + error_detail: &str, + response_type: &str, + response_time_ms: Option, + ) -> Result<()> { + // 尝试提取模型名称(失败时使用 "unknown") + let model = if !request_body.is_empty() { + create_extractor(tool_type) + .and_then(|extractor| extractor.extract_model_from_request(request_body)) + .unwrap_or_else(|_| "unknown".to_string()) + } else { + "unknown".to_string() + }; + + // 创建日志记录(失败) + let timestamp = chrono::Utc::now().timestamp_millis(); + let log = TokenLog::new( + tool_type.to_string(), + timestamp, + client_ip.to_string(), + session_id.to_string(), + config_name.to_string(), + model, + None, // 失败时没有 message_id + 0, // 失败时 token 数量为 0 + 0, + 0, + 0, + "failed".to_string(), + response_type.to_string(), + Some(error_type.to_string()), + Some(error_detail.to_string()), + response_time_ms, + None, // 失败时没有价格信息 + None, + None, + None, + 0.0, // 失败时成本为 0 + None, + ); + + // 发送到批量写入队列 + if let Err(e) = self.event_sender.send(log) { + tracing::error!("发送失败请求日志事件失败: {}", e); + } + + Ok(()) + } + + /// 解析SSE流数据块 + fn parse_sse_chunks( + &self, + extractor: &dyn crate::services::token_stats::extractor::TokenExtractor, + chunks: Vec, + ) -> Result { + let mut message_start: Option = None; + let mut message_delta: Option = None; + + for (i, chunk) in chunks.iter().enumerate() { + match extractor.extract_from_sse_chunk(chunk) { + Ok(Some(data)) => { + if let Some(start) = data.message_start { + tracing::debug!(chunk_index = i, "找到 message_start 事件"); + message_start = Some(start); + } + if let Some(delta) = data.message_delta { + tracing::debug!(chunk_index = i, "找到 message_delta 事件"); + message_delta = Some(delta); + } + } + Ok(None) => { + // 正常跳过非数据块(如 ping、空行等) + } + Err(e) => { + tracing::warn!( + chunk_index = i, + error = ?e, + chunk_preview = %chunk.chars().take(100).collect::(), + "SSE chunk 解析失败" + ); + } + } + } + + if message_start.is_none() { + tracing::error!( + chunks_count = chunks.len(), + "所有 SSE chunks 中未找到 message_start 事件" + ); + } + + let start = message_start.context("Missing message_start in SSE stream")?; + + Ok(ResponseTokenInfo::from_sse_data(start, message_delta)) + } + + /// 查询会话实时统计 + pub fn get_session_stats(&self, tool_type: &str, session_id: &str) -> Result { + self.db.get_session_stats(tool_type, session_id) + } + + /// 分页查询历史日志 + pub fn query_logs(&self, query: TokenStatsQuery) -> Result { + self.db.query_logs(&query) + } + + /// 根据配置清理旧数据 + pub fn cleanup_by_config( + &self, + retention_days: Option, + max_count: Option, + ) -> Result { + self.db.cleanup_old_logs(retention_days, max_count) + } + + /// 获取数据库统计摘要 + pub fn get_stats_summary(&self) -> Result<(i64, Option, Option)> { + self.db.get_stats_summary() + } + + /// 强制执行 WAL checkpoint + /// + /// 将所有 WAL 数据回写到主数据库文件, + /// 用于手动清理过大的 WAL 文件 + pub fn force_checkpoint(&self) -> Result<()> { + self.db.force_checkpoint() + } +} + +/// 关闭 TokenStatsManager 后台任务 +/// +/// 在应用关闭时调用,优雅地停止所有后台任务并刷盘缓冲区数据 +pub fn shutdown_token_stats_manager() { + tracing::info!("TokenStatsManager 关闭信号已发送"); + CANCELLATION_TOKEN.cancel(); + + // 等待一小段时间让任务完成刷盘 + std::thread::sleep(std::time::Duration::from_millis(300)); +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[tokio::test] + async fn test_log_request_with_json() { + let manager = TokenStatsManager::get(); + + let request_body = json!({ + "model": "claude-sonnet-4-5-20250929", + "messages": [] + }) + .to_string(); + + let response_json = json!({ + "id": "msg_test_123", + "model": "claude-sonnet-4-5-20250929", + "usage": { + "input_tokens": 100, + "output_tokens": 50, + "cache_creation_input_tokens": 10, + "cache_read_input_tokens": 20 + } + }); + + let result = manager + .log_request( + "claude_code", + "test_session", + "default", + "127.0.0.1", + request_body.as_bytes(), + ResponseData::Json(response_json), + None, // response_time_ms + None, // pricing_template_id + ) + .await; + + assert!(result.is_ok()); + + // 等待异步插入完成 + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // 验证统计数据 + let stats = manager + .get_session_stats("claude_code", "test_session") + .unwrap(); + assert_eq!(stats.total_input, 100); + assert_eq!(stats.total_output, 50); + } + + #[test] + fn test_parse_sse_chunks() { + let manager = TokenStatsManager::get(); + let extractor = create_extractor("claude_code").unwrap(); + + let chunks = vec![ + r#"data: {"type":"message_start","message":{"model":"claude-3","id":"msg_123","usage":{"input_tokens":1000,"output_tokens":1}}}"#.to_string(), + r#"data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"cache_creation_input_tokens":50,"cache_read_input_tokens":100,"output_tokens":200}}"#.to_string(), + ]; + + let info = manager.parse_sse_chunks(&*extractor, chunks).unwrap(); + assert_eq!(info.model, "claude-3"); + assert_eq!(info.message_id, "msg_123"); + assert_eq!(info.input_tokens, 1000); + assert_eq!(info.output_tokens, 200); + assert_eq!(info.cache_creation_tokens, 50); + assert_eq!(info.cache_read_tokens, 100); + } + + #[test] + fn test_query_logs() { + let manager = TokenStatsManager::get(); + + // 插入测试数据 + let log = TokenLog::new( + "claude_code".to_string(), + chrono::Utc::now().timestamp_millis(), + "127.0.0.1".to_string(), + "test_query_session".to_string(), + "default".to_string(), + "claude-3".to_string(), + Some("msg_query_test".to_string()), + 100, + 50, + 10, + 20, + "success".to_string(), + "json".to_string(), + None, + None, + None, + None, + None, + None, + None, + 0.0, + None, + ); + manager.db.insert_log(&log).unwrap(); + + // 查询日志 + let query = TokenStatsQuery { + session_id: Some("test_query_session".to_string()), + ..Default::default() + }; + let page = manager.query_logs(query).unwrap(); + assert!(page.total >= 1); + } +} diff --git a/src-tauri/src/services/token_stats/mod.rs b/src-tauri/src/services/token_stats/mod.rs new file mode 100644 index 0000000..0ef2696 --- /dev/null +++ b/src-tauri/src/services/token_stats/mod.rs @@ -0,0 +1,22 @@ +//! Token统计服务模块 +//! +//! 提供透明代理的Token数据统计和请求记录功能。 + +pub mod analytics; +pub mod db; +pub mod extractor; +pub mod manager; + +#[cfg(test)] +mod cost_calculation_test; + +pub use analytics::{ + CostGroupBy, CostSummary, CostSummaryQuery, TimeGranularity, TokenStatsAnalytics, + TrendDataPoint, TrendQuery, +}; +pub use db::TokenStatsDb; +pub use extractor::{ + create_extractor, ClaudeTokenExtractor, MessageDeltaData, MessageStartData, ResponseTokenInfo, + SseTokenData, TokenExtractor, +}; +pub use manager::{shutdown_token_stats_manager, TokenStatsManager}; diff --git a/src-tauri/src/utils/mod.rs b/src-tauri/src/utils/mod.rs index dd44c99..560a09f 100644 --- a/src-tauri/src/utils/mod.rs +++ b/src-tauri/src/utils/mod.rs @@ -4,6 +4,7 @@ pub mod config; pub mod file_helpers; pub mod installer_scanner; pub mod platform; +pub mod precision; pub mod version; pub mod wsl_executor; diff --git a/src-tauri/src/utils/precision.rs b/src-tauri/src/utils/precision.rs new file mode 100644 index 0000000..304be36 --- /dev/null +++ b/src-tauri/src/utils/precision.rs @@ -0,0 +1,163 @@ +//! 数值精度工具模块 +//! +//! 提供价格等需要精确小数位数的序列化/反序列化支持 + +use serde::{Deserialize, Deserializer, Serializer}; + +/// 价格字段精度(小数点后6位) +/// +/// 用于 serde 的 serialize_with 和 deserialize_with 属性 +pub mod price_precision { + use super::*; + + /// 序列化 f64 为固定 6 位小数 + /// + /// 注意:对于非常小的数(< 0.0001),JSON可能使用科学计数法 + pub fn serialize(value: &f64, serializer: S) -> Result + where + S: Serializer, + { + // 四舍五入到 6 位小数 + let multiplier = 1_000_000.0; // 10^6 + let rounded = (value * multiplier).round() / multiplier; + serializer.serialize_f64(rounded) + } + + /// 反序列化保持原有精度 + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + f64::deserialize(deserializer) + } +} + +/// 可选价格字段精度(Option) +pub mod option_price_precision { + use super::*; + + /// 序列化 Option 为固定 6 位小数 + pub fn serialize(value: &Option, serializer: S) -> Result + where + S: Serializer, + { + match value { + Some(v) => { + // 四舍五入到 6 位小数 + let multiplier = 1_000_000.0; // 10^6 + let rounded = (v * multiplier).round() / multiplier; + serializer.serialize_some(&rounded) + } + None => serializer.serialize_none(), + } + } + + /// 反序列化保持原有精度 + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + Option::::deserialize(deserializer) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde::{Deserialize, Serialize}; + + #[derive(Serialize, Deserialize, Debug, PartialEq)] + struct TestStruct { + #[serde(with = "price_precision")] + price: f64, + #[serde(with = "option_price_precision")] + optional_price: Option, + } + + #[test] + fn test_price_precision_serialization() { + let test = TestStruct { + price: 0.099904123456789, + optional_price: Some(1.234567890123), + }; + + let json = serde_json::to_string(&test).unwrap(); + println!("Serialized: {}", json); + + // 反序列化验证精度 + let deserialized: TestStruct = serde_json::from_str(&json).unwrap(); + assert!((deserialized.price - 0.099904).abs() < 1e-9); + assert!((deserialized.optional_price.unwrap() - 1.234568).abs() < 1e-9); + } + + #[test] + fn test_price_precision_deserialization() { + let json = r#"{"price":0.099904,"optional_price":1.234568}"#; + let test: TestStruct = serde_json::from_str(json).unwrap(); + + assert!((test.price - 0.099904).abs() < 1e-9); + assert!((test.optional_price.unwrap() - 1.234568).abs() < 1e-9); + } + + #[test] + fn test_option_price_none() { + let test = TestStruct { + price: 0.5, + optional_price: None, + }; + + let json = serde_json::to_string(&test).unwrap(); + assert!(json.contains("\"optional_price\":null")); + + let deserialized: TestStruct = serde_json::from_str(&json).unwrap(); + assert!((deserialized.price - 0.5).abs() < 1e-9); + assert!(deserialized.optional_price.is_none()); + } + + #[test] + fn test_very_small_price() { + let test = TestStruct { + price: 0.000001234567, + optional_price: Some(0.000009876543), + }; + + let json = serde_json::to_string(&test).unwrap(); + println!("Small price serialized: {}", json); + + // 反序列化验证精度(四舍五入到 6 位小数) + let deserialized: TestStruct = serde_json::from_str(&json).unwrap(); + assert!((deserialized.price - 0.000001).abs() < 1e-9); // 0.000001234567 -> 0.000001 + assert!((deserialized.optional_price.unwrap() - 0.00001).abs() < 1e-9); // 0.000009876543 -> 0.00001 + } + + #[test] + fn test_typical_api_costs() { + // 测试典型的 API 成本(例如 Claude API) + let test = TestStruct { + price: 0.001234, + optional_price: Some(0.056789), + }; + + let json = serde_json::to_string(&test).unwrap(); + println!("Typical API cost serialized: {}", json); + + let deserialized: TestStruct = serde_json::from_str(&json).unwrap(); + assert!((deserialized.price - 0.001234).abs() < 1e-9); + assert!((deserialized.optional_price.unwrap() - 0.056789).abs() < 1e-9); + } + + #[test] + fn test_rounding_behavior() { + // 测试四舍五入行为 + let test = TestStruct { + price: 0.0000015, // 应该四舍五入到 0.000002 + optional_price: Some(0.0000014), // 应该四舍五入到 0.000001 + }; + + let json = serde_json::to_string(&test).unwrap(); + let deserialized: TestStruct = serde_json::from_str(&json).unwrap(); + + assert!((deserialized.price - 0.000002).abs() < 1e-9); + assert!((deserialized.optional_price.unwrap() - 0.000001).abs() < 1e-9); + } +} diff --git a/src/App.tsx b/src/App.tsx index 01e10a2..1eb4000 100644 --- a/src/App.tsx +++ b/src/App.tsx @@ -14,12 +14,13 @@ import { TransparentProxyPage } from '@/pages/TransparentProxyPage'; import { ToolManagementPage } from '@/pages/ToolManagementPage'; import { HelpPage } from '@/pages/HelpPage'; import { AboutPage } from '@/pages/AboutPage'; +import { BalancePage } from '@/pages/BalancePage'; +import TokenStatisticsPage from '@/pages/TokenStatisticsPage'; import { useToast } from '@/hooks/use-toast'; import { useAppEvents } from '@/hooks/useAppEvents'; import { useCloseAction } from '@/hooks/useCloseAction'; import { useConfigWatch } from '@/hooks/useConfigWatch'; import { Toaster } from '@/components/ui/toaster'; -import { BalancePage } from '@/pages/BalancePage'; import OnboardingOverlay from '@/components/Onboarding/OnboardingOverlay'; import { getRequiredSteps, @@ -36,6 +37,7 @@ import { type GlobalConfig, type UpdateInfo, } from '@/lib/tauri-commands'; +import type { ToolType } from '@/types/token-stats'; type TabType = | 'dashboard' @@ -44,6 +46,7 @@ type TabType = | 'profile-management' | 'balance' | 'transparent-proxy' + | 'token-statistics' | 'provider-management' | 'settings' | 'help' @@ -57,6 +60,12 @@ function App() { const [settingsRestrictToTab, setSettingsRestrictToTab] = useState(undefined); const [restrictedPage, setRestrictedPage] = useState(undefined); + // Token 统计页面导航参数 + const [tokenStatsParams, setTokenStatsParams] = useState<{ + sessionId?: string; + toolType?: ToolType; + }>({}); + // 引导状态管理 const [showOnboarding, setShowOnboarding] = useState(false); const [onboardingSteps, setOnboardingSteps] = useState([]); @@ -286,6 +295,24 @@ function App() { setSettingsRestrictToTab(undefined); }); + // 监听应用内导航事件 + const unlistenAppNavigate = listen<{ + tab: TabType; + params?: { sessionId?: string; toolType?: ToolType }; + }>('app-navigate', (event) => { + const { tab, params } = event.payload || {}; + if (tab) { + setActiveTab(tab); + // 如果是导航到 Token 统计页面,保存参数 + if (tab === 'token-statistics' && params) { + setTokenStatsParams(params); + } else if (tab !== 'token-statistics') { + // 清空参数 + setTokenStatsParams({}); + } + } + }); + return () => { unlistenUpdateAvailable.then((fn) => fn()); unlistenRequestCheck.then((fn) => fn()); @@ -293,6 +320,7 @@ function App() { unlistenOpenSettings.then((fn) => fn()); unlistenOnboardingNavigate.then((fn) => fn()); unlistenClearRestriction.then((fn) => fn()); + unlistenAppNavigate.then((fn) => fn()); }; }, [toast]); @@ -380,6 +408,12 @@ function App() { {activeTab === 'transparent-proxy' && ( )} + {activeTab === 'token-statistics' && ( + + )} {activeTab === 'settings' && ( void; + startTime: Date | null; + endTime: Date | null; + onStartTimeChange: (date: Date | null) => void; + onEndTimeChange: (date: Date | null) => void; + onConfirm: () => void; +} + +/** + * 日期时间选择器组件 + */ +function DateTimePicker({ + date, + onDateChange, + label, + maxDate, +}: { + date: Date | null; + onDateChange: (date: Date | null) => void; + label: string; + maxDate?: Date; +}) { + const [timeInput, setTimeInput] = useState('00:00'); + + useEffect(() => { + if (date) { + const hours = date.getHours().toString().padStart(2, '0'); + const minutes = date.getMinutes().toString().padStart(2, '0'); + setTimeInput(`${hours}:${minutes}`); + } + }, [date]); + + const handleDateSelect = (selectedDate: Date | undefined) => { + if (!selectedDate) { + onDateChange(null); + return; + } + + // 保留现有的时间部分 + const [hours, minutes] = timeInput.split(':').map(Number); + const newDate = new Date(selectedDate); + newDate.setHours(hours || 0, minutes || 0, 0, 0); + onDateChange(newDate); + }; + + const handleTimeChange = (e: React.ChangeEvent) => { + const value = e.target.value; + setTimeInput(value); + + if (!date) return; + + const [hours, minutes] = value.split(':').map(Number); + if (isNaN(hours) || isNaN(minutes)) return; + + const newDate = new Date(date); + newDate.setHours(hours, minutes, 0, 0); + onDateChange(newDate); + }; + + return ( +
+ +
+ {/* 日期选择器 */} + + + + + + { + if (maxDate && date > maxDate) return true; + return false; + }} + initialFocus + locale={zhCN} + /> + + + + {/* 时间输入框 */} +
+ + +
+
+
+ ); +} + +/** + * 自定义时间范围选择对话框 + */ +export function CustomTimeRangeDialog({ + open, + onOpenChange, + startTime, + endTime, + onStartTimeChange, + onEndTimeChange, + onConfirm, +}: CustomTimeRangeDialogProps) { + const validation = validateCustomTimeRange(startTime, endTime); + const maxDate = new Date(); // 结束时间不能晚于当前时刻 + + return ( + + + + 自定义时间范围 + + 选择起止时间(精确到分钟),时间跨度最多90天,结束时间不能晚于当前时刻 + + + +
+ {/* 开始时间 */} + + + {/* 结束时间 */} + + + {/* 错误提示 */} + {!validation.valid && validation.error && ( +
+ {validation.error} +
+ )} +
+ + + + + +
+
+ ); +} diff --git a/src/components/ui/calendar.tsx b/src/components/ui/calendar.tsx new file mode 100644 index 0000000..73126d7 --- /dev/null +++ b/src/components/ui/calendar.tsx @@ -0,0 +1,57 @@ +/** + * Calendar 组件 + * 基于 react-day-picker 的日历选择器组件 + */ + +import * as React from 'react'; +import { DayPicker } from 'react-day-picker'; + +import { cn } from '@/lib/utils'; +import { buttonVariants } from '@/components/ui/button-variants'; + +export type CalendarProps = React.ComponentProps; + +function Calendar({ className, classNames, showOutsideDays = true, ...props }: CalendarProps) { + return ( + + ); +} +Calendar.displayName = 'Calendar'; + +export { Calendar }; diff --git a/src/components/ui/collapsible.tsx b/src/components/ui/collapsible.tsx new file mode 100644 index 0000000..9605c4e --- /dev/null +++ b/src/components/ui/collapsible.tsx @@ -0,0 +1,9 @@ +import * as CollapsiblePrimitive from '@radix-ui/react-collapsible'; + +const Collapsible = CollapsiblePrimitive.Root; + +const CollapsibleTrigger = CollapsiblePrimitive.CollapsibleTrigger; + +const CollapsibleContent = CollapsiblePrimitive.CollapsibleContent; + +export { Collapsible, CollapsibleTrigger, CollapsibleContent }; diff --git a/src/components/ui/popover.tsx b/src/components/ui/popover.tsx new file mode 100644 index 0000000..c5d749a --- /dev/null +++ b/src/components/ui/popover.tsx @@ -0,0 +1,34 @@ +/** + * Popover 组件 + * 基于 Radix UI 的弹出层组件 + */ + +import * as React from 'react'; +import * as PopoverPrimitive from '@radix-ui/react-popover'; + +import { cn } from '@/lib/utils'; + +const Popover = PopoverPrimitive.Root; + +const PopoverTrigger = PopoverPrimitive.Trigger; + +const PopoverContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, align = 'center', sideOffset = 4, ...props }, ref) => ( + + + +)); +PopoverContent.displayName = PopoverPrimitive.Content.displayName; + +export { Popover, PopoverTrigger, PopoverContent }; diff --git a/src/hooks/useAnalyticsStats.ts b/src/hooks/useAnalyticsStats.ts new file mode 100644 index 0000000..5507fb3 --- /dev/null +++ b/src/hooks/useAnalyticsStats.ts @@ -0,0 +1,168 @@ +/** + * 统一的 Analytics 数据 Hook + * 整合 queryTokenTrends + queryCostSummary,提供完整的统计数据 + * + * 功能: + * - 支持全局统计 / 会话级统计 + * - 统一时间范围控制(15分钟~30天 + 自定义) + * - 自动聚合 Token 分布数据 + * - 处理响应时间 null 值 + * - 数据一致性保证 + */ + +import { useState, useEffect, useMemo, useCallback } from 'react'; +import { useTimeRangeControl, type UseTimeRangeControlReturn } from './useTimeRangeControl'; +import { queryTokenTrends, queryCostSummary } from '@/lib/tauri-commands/analytics'; +import type { TrendDataPoint, CostSummary } from '@/types/analytics'; + +/** + * Hook 参数 + */ +export interface UseAnalyticsStatsProps { + /** 工具 ID */ + toolId: string; + /** 会话 ID(可选,传入时为会话级统计) */ + sessionId?: string; + /** 是否启用自动加载(默认 true) */ + enabled?: boolean; +} + +/** + * Token 分布数据 + */ +export interface TokenBreakdown { + /** 输入 Token 总数 */ + input: number; + /** 输出 Token 总数 */ + output: number; + /** 缓存创建 Token 总数 */ + cacheCreation: number; + /** 缓存读取 Token 总数 */ + cacheRead: number; +} + +/** + * Analytics 统计数据 + */ +export interface AnalyticsStats { + /** 成本汇总数据(统计卡片) */ + summary: CostSummary | null; + /** 趋势数据(图表) */ + trends: TrendDataPoint[]; + /** Token 总数 */ + totalTokens: number; + /** Token 分布详情 */ + tokenBreakdown: TokenBreakdown; + /** 加载状态 */ + loading: boolean; + /** 错误信息 */ + error: string | null; + /** 时间范围控制 */ + timeControl: UseTimeRangeControlReturn; + /** 手动刷新 */ + refresh: () => Promise; +} + +/** + * 统一的 Analytics 数据 Hook + */ +export function useAnalyticsStats(props: UseAnalyticsStatsProps): AnalyticsStats { + const { toolId, sessionId, enabled = true } = props; + + // 时间范围控制 + const timeControl = useTimeRangeControl(); + + // 数据状态 + const [trendsData, setTrendsData] = useState([]); + const [summary, setSummary] = useState(null); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + + /** + * 加载 Analytics 数据 + */ + const loadAnalyticsData = useCallback(async () => { + if (!enabled) return; + + try { + setLoading(true); + setError(null); + + // 并行查询趋势数据和成本汇总 + const [trends, summaryData] = await Promise.all([ + queryTokenTrends({ + start_time: timeControl.startTimeMs, + end_time: timeControl.endTimeMs, + tool_type: toolId, + granularity: timeControl.granularity, + session_id: sessionId, + }), + queryCostSummary(timeControl.startTimeMs, timeControl.endTimeMs, toolId, sessionId), + ]); + + // 处理响应时间的 null 值,转换为 0 以便图表连线 + const processedTrends = trends.map((point) => ({ + ...point, + avg_response_time: point.avg_response_time ?? 0, + })); + + setTrendsData(processedTrends); + setSummary(summaryData); + } catch (err: any) { + console.error('Failed to load analytics data:', err); + setError(err?.message || String(err)); + } finally { + setLoading(false); + } + }, [ + enabled, + toolId, + sessionId, + timeControl.startTimeMs, + timeControl.endTimeMs, + timeControl.granularity, + ]); + + // 自动加载数据(依赖时间范围和粒度) + useEffect(() => { + loadAnalyticsData(); + }, [loadAnalyticsData]); + + /** + * 计算 Token 总数 + */ + const totalTokens = useMemo(() => { + return trendsData.reduce((acc, point) => { + return ( + acc + + point.input_tokens + + point.output_tokens + + point.cache_creation_tokens + + point.cache_read_tokens + ); + }, 0); + }, [trendsData]); + + /** + * 计算 Token 分布 + */ + const tokenBreakdown = useMemo(() => { + return { + input: trendsData.reduce((acc, point) => acc + point.input_tokens, 0), + output: trendsData.reduce((acc, point) => acc + point.output_tokens, 0), + cacheCreation: trendsData.reduce((acc, point) => acc + point.cache_creation_tokens, 0), + cacheRead: trendsData.reduce((acc, point) => acc + point.cache_read_tokens, 0), + }; + }, [trendsData]); + + return { + summary, + trends: trendsData, + totalTokens, + tokenBreakdown, + loading, + error, + timeControl, + refresh: loadAnalyticsData, + }; +} diff --git a/src/hooks/useTimeRangeControl.ts b/src/hooks/useTimeRangeControl.ts new file mode 100644 index 0000000..7cd45a9 --- /dev/null +++ b/src/hooks/useTimeRangeControl.ts @@ -0,0 +1,206 @@ +/** + * 时间范围控制 Hook + * 统一管理预设时间范围和自定义时间范围的状态和逻辑 + */ + +import { useState, useMemo, useCallback } from 'react'; +import type { TimeRange, TimeGranularity } from '@/types/analytics'; +import { + PRESET_ALLOWED_GRANULARITIES, + calculateAllowedGranularitiesFromTimeSpan, + selectDefaultGranularity, + calculatePresetStartTime, +} from '@/utils/time-range'; + +export interface UseTimeRangeControlReturn { + // 模式控制 + mode: 'preset' | 'custom'; + setMode: (mode: 'preset' | 'custom') => void; + + // 预设时间范围 + presetRange: Exclude; + setPresetRange: (range: Exclude) => void; + + // 自定义时间范围 + customStartTime: Date | null; + customEndTime: Date | null; + setCustomStartTime: (date: Date | null) => void; + setCustomEndTime: (date: Date | null) => void; + + // 粒度控制 + granularity: TimeGranularity; + setGranularity: (g: TimeGranularity) => void; + allowedGranularities: TimeGranularity[]; + + // 计算后的时间戳 + startTimeMs: number; + endTimeMs: number; + + // 自定义时间对话框控制 + showCustomDialog: boolean; + openCustomDialog: () => void; + closeCustomDialog: () => void; + confirmCustomTime: () => void; + + // 验证状态 + isCustomTimeValid: boolean; +} + +/** + * 时间范围控制 Hook + */ +export function useTimeRangeControl(): UseTimeRangeControlReturn { + // 模式:预设或自定义 + const [mode, setMode] = useState<'preset' | 'custom'>('preset'); + + // 预设时间范围 + const [presetRange, setPresetRange] = useState>('day'); + + // 自定义时间范围(临时状态,确认后才应用) + const [customStartTime, setCustomStartTime] = useState(null); + const [customEndTime, setCustomEndTime] = useState(null); + + // 已确认的自定义时间范围 + const [confirmedCustomStart, setConfirmedCustomStart] = useState(null); + const [confirmedCustomEnd, setConfirmedCustomEnd] = useState(null); + + // 粒度 + const [granularity, setGranularity] = useState('hour'); + + // 自定义时间对话框状态 + const [showCustomDialog, setShowCustomDialog] = useState(false); + + // 计算允许的粒度选项 + const allowedGranularities = useMemo(() => { + if (mode === 'preset') { + return PRESET_ALLOWED_GRANULARITIES[presetRange]; + } else { + // 自定义模式:基于已确认的时间范围计算 + if (confirmedCustomStart && confirmedCustomEnd) { + return calculateAllowedGranularitiesFromTimeSpan( + confirmedCustomStart.getTime(), + confirmedCustomEnd.getTime(), + ); + } + return ['day'] as TimeGranularity[]; // 兜底 + } + }, [mode, presetRange, confirmedCustomStart, confirmedCustomEnd]); + + // 计算实际使用的时间戳 + const { startTimeMs, endTimeMs } = useMemo(() => { + const now = Date.now(); + + if (mode === 'preset') { + return { + startTimeMs: calculatePresetStartTime(presetRange, now), + endTimeMs: now, + }; + } else { + // 自定义模式:使用已确认的时间 + if (confirmedCustomStart && confirmedCustomEnd) { + return { + startTimeMs: confirmedCustomStart.getTime(), + endTimeMs: confirmedCustomEnd.getTime(), + }; + } + // 兜底:返回最近1天 + return { + startTimeMs: calculatePresetStartTime('day', now), + endTimeMs: now, + }; + } + }, [mode, presetRange, confirmedCustomStart, confirmedCustomEnd]); + + // 验证自定义时间是否有效 + const isCustomTimeValid = useMemo(() => { + if (!customStartTime || !customEndTime) return false; + + const startMs = customStartTime.getTime(); + const endMs = customEndTime.getTime(); + const now = Date.now(); + + if (startMs >= endMs) return false; + if (endMs > now) return false; + + const spanMs = endMs - startMs; + const DAY_90 = 90 * 24 * 60 * 60 * 1000; + if (spanMs > DAY_90) return false; + + return true; + }, [customStartTime, customEndTime]); + + // 打开自定义时间对话框 + const openCustomDialog = useCallback(() => { + // 初始化为已确认的时间,如果没有则使用最近1小时 + if (confirmedCustomStart && confirmedCustomEnd) { + setCustomStartTime(confirmedCustomStart); + setCustomEndTime(confirmedCustomEnd); + } else { + const now = new Date(); + const oneHourAgo = new Date(now.getTime() - 60 * 60 * 1000); + setCustomStartTime(oneHourAgo); + setCustomEndTime(now); + } + setShowCustomDialog(true); + }, [confirmedCustomStart, confirmedCustomEnd]); + + // 关闭对话框 + const closeCustomDialog = useCallback(() => { + setShowCustomDialog(false); + }, []); + + // 确认自定义时间 + const confirmCustomTime = useCallback(() => { + if (!isCustomTimeValid || !customStartTime || !customEndTime) return; + + // 保存已确认的时间 + setConfirmedCustomStart(customStartTime); + setConfirmedCustomEnd(customEndTime); + + // 切换到自定义模式 + setMode('custom'); + + // 计算允许的粒度并选择默认值 + const allowed = calculateAllowedGranularitiesFromTimeSpan( + customStartTime.getTime(), + customEndTime.getTime(), + ); + const defaultGranularity = selectDefaultGranularity(allowed); + setGranularity(defaultGranularity); + + // 关闭对话框 + setShowCustomDialog(false); + }, [isCustomTimeValid, customStartTime, customEndTime]); + + // 切换预设范围时自动调整粒度 + const handleSetPresetRange = useCallback((range: Exclude) => { + setPresetRange(range); + setMode('preset'); + + // 自动选择默认粒度 + const allowed = PRESET_ALLOWED_GRANULARITIES[range]; + const defaultGranularity = selectDefaultGranularity(allowed); + setGranularity(defaultGranularity); + }, []); + + return { + mode, + setMode, + presetRange, + setPresetRange: handleSetPresetRange, + customStartTime, + customEndTime, + setCustomStartTime, + setCustomEndTime, + granularity, + setGranularity, + allowedGranularities, + startTimeMs, + endTimeMs, + showCustomDialog, + openCustomDialog, + closeCustomDialog, + confirmCustomTime, + isCustomTimeValid, + }; +} diff --git a/src/lib/tauri-commands/analytics.ts b/src/lib/tauri-commands/analytics.ts new file mode 100644 index 0000000..d583b33 --- /dev/null +++ b/src/lib/tauri-commands/analytics.ts @@ -0,0 +1,36 @@ +/** + * Token 统计分析相关 Tauri 命令 + */ +import { invoke } from '@tauri-apps/api/core'; +import type { TrendQuery, TrendDataPoint, CostSummary } from '@/types/analytics'; + +/** + * 查询 Token 使用趋势数据 + * @param query 查询参数 + * @returns 趋势数据点数组 + */ +export async function queryTokenTrends(query: TrendQuery): Promise { + return await invoke('query_token_trends', { query }); +} + +/** + * 查询成本汇总数据 + * @param startTime 开始时间戳(毫秒) + * @param endTime 结束时间戳(毫秒) + * @param toolType 工具类型过滤(可选) + * @param sessionId 会话 ID 过滤(可选) + * @returns 成本汇总数据 + */ +export async function queryCostSummary( + startTime: number, + endTime: number, + toolType?: string, + sessionId?: string, +): Promise { + return await invoke('query_cost_summary', { + startTime, + endTime, + toolType, + sessionId, + }); +} diff --git a/src/lib/tauri-commands/index.ts b/src/lib/tauri-commands/index.ts index f1a96a8..4d5e9aa 100644 --- a/src/lib/tauri-commands/index.ts +++ b/src/lib/tauri-commands/index.ts @@ -28,6 +28,9 @@ export * from './dashboard'; // 会话管理 export * from './session'; +// Token 统计 +export * from './token-stats'; + // 余额监控 export * from './balance'; diff --git a/src/lib/tauri-commands/pricing.ts b/src/lib/tauri-commands/pricing.ts new file mode 100644 index 0000000..bf3f11a --- /dev/null +++ b/src/lib/tauri-commands/pricing.ts @@ -0,0 +1,75 @@ +/** + * 价格配置管理命令包装器 + * + * 提供价格模板 CRUD 操作和工具默认模板管理 + */ + +import { invoke } from '@tauri-apps/api/core'; +import type { PricingTemplate, PricingToolId } from '@/types/pricing'; + +/** + * 列出所有价格模板 + * + * @returns 所有可用价格模板的列表 + */ +export async function listPricingTemplates(): Promise { + return invoke('list_pricing_templates'); +} + +/** + * 获取指定价格模板 + * + * @param templateId - 模板 ID + * @returns 价格模板详细信息 + */ +export async function getPricingTemplate(templateId: string): Promise { + return invoke('get_pricing_template', { templateId }); +} + +/** + * 保存价格模板(创建或更新) + * + * @param template - 价格模板数据 + * + * @note + * - 如果模板 ID 已存在,将覆盖现有模板 + * - 不允许覆盖内置预设模板(is_default_preset = true) + */ +export async function savePricingTemplate(template: PricingTemplate): Promise { + return invoke('save_pricing_template', { template }); +} + +/** + * 删除价格模板 + * + * @param templateId - 模板 ID + * + * @note + * - 不允许删除内置预设模板 + */ +export async function deletePricingTemplate(templateId: string): Promise { + return invoke('delete_pricing_template', { templateId }); +} + +/** + * 设置工具的默认价格模板 + * + * @param toolId - 工具 ID(claude-code / codex / gemini-cli) + * @param templateId - 模板 ID + * + * @note + * - 模板必须存在才能设置为默认模板 + */ +export async function setDefaultTemplate(toolId: PricingToolId, templateId: string): Promise { + return invoke('set_default_template', { toolId, templateId }); +} + +/** + * 获取工具的默认价格模板 + * + * @param toolId - 工具 ID(claude-code / codex / gemini-cli) + * @returns 该工具当前使用的默认价格模板 + */ +export async function getDefaultTemplate(toolId: PricingToolId): Promise { + return invoke('get_default_template', { toolId }); +} diff --git a/src/lib/tauri-commands/session.ts b/src/lib/tauri-commands/session.ts index 9931612..800256c 100644 --- a/src/lib/tauri-commands/session.ts +++ b/src/lib/tauri-commands/session.ts @@ -45,6 +45,7 @@ export async function clearAllSessions(toolId: string): Promise { * @param customProfileName - 自定义配置名称 (global 时为 null) * @param url - API Base URL (global 时为空字符串) * @param apiKey - API Key (global 时为空字符串) + * @param pricingTemplateId - 价格模板 ID (Phase 6) */ export async function updateSessionConfig( sessionId: string, @@ -52,6 +53,7 @@ export async function updateSessionConfig( customProfileName: string | null, url: string, apiKey: string, + pricingTemplateId?: string | null, ): Promise { return await invoke('update_session_config', { sessionId, @@ -59,6 +61,7 @@ export async function updateSessionConfig( customProfileName, url, apiKey, + pricingTemplateId: pricingTemplateId || null, }); } diff --git a/src/lib/tauri-commands/token-stats.ts b/src/lib/tauri-commands/token-stats.ts new file mode 100644 index 0000000..14cdee3 --- /dev/null +++ b/src/lib/tauri-commands/token-stats.ts @@ -0,0 +1,95 @@ +// Token 统计命令模块 +// 负责透明代理的 Token 使用统计和请求日志管理 + +import { invoke } from '@tauri-apps/api/core'; +import type { + SessionStats, + TokenStatsQuery, + TokenLogsPage, + TokenStatsConfig, + DatabaseSummary, +} from '@/types/token-stats'; + +/** + * 查询会话实时统计 + * @param toolType - 工具类型 ("claude-code", "codex", "gemini-cli") + * @param sessionId - 会话 ID(UUID 格式) + * @returns 会话统计数据(输入/输出/缓存 Token 数量、请求次数) + */ +export async function getSessionStats(toolType: string, sessionId: string): Promise { + return await invoke('get_session_stats', { + toolType, + sessionId, + }); +} + +/** + * 分页查询 Token 日志 + * @param queryParams - 查询参数(工具类型、会话ID、时间范围、分页参数) + * @returns 分页查询结果(日志列表、总数、分页信息) + */ +export async function queryTokenLogs(queryParams: TokenStatsQuery): Promise { + return await invoke('query_token_logs', { + queryParams, + }); +} + +/** + * 手动清理旧日志 + * @param retentionDays - 保留天数(可选,未提供则使用配置) + * @param maxCount - 最大日志条数(可选,未提供则使用配置) + * @returns 删除的日志条数 + */ +export async function cleanupTokenLogs(retentionDays?: number, maxCount?: number): Promise { + return await invoke('cleanup_token_logs', { + retentionDays: retentionDays ?? null, + maxCount: maxCount ?? null, + }); +} + +/** + * 获取数据库统计摘要 + * @returns 数据库摘要信息(总日志数、最早/最新时间戳) + */ +export async function getTokenStatsSummary(): Promise { + const result = await invoke<[number, number | null, number | null]>('get_token_stats_summary'); + return { + total_logs: result[0], + oldest_timestamp: result[1] ?? undefined, + newest_timestamp: result[2] ?? undefined, + }; +} + +/** + * 强制执行 WAL checkpoint + * + * 将 WAL 文件中的所有数据回写到主数据库文件, + * 用于手动清理过大的 WAL 文件 + */ +export async function forceTokenStatsCheckpoint(): Promise { + return await invoke('force_token_stats_checkpoint'); +} + +/** + * 获取 Token 统计配置 + * @returns Token 统计配置(保留天数、最大条数、自动清理开关) + */ +export async function getTokenStatsConfig(): Promise { + const config = await invoke<{ token_stats_config: TokenStatsConfig }>('get_global_config'); + return config.token_stats_config; +} + +/** + * 更新 Token 统计配置 + * @param config - 新配置(部分字段更新) + */ +export async function updateTokenStatsConfig(config: Partial): Promise { + // 先获取当前完整配置 + const currentConfig = await getTokenStatsConfig(); + const updatedConfig = { ...currentConfig, ...config }; + + // 调用后端专用命令(后端会原子性地读取-修改-保存) + return await invoke('update_token_stats_config', { + config: updatedConfig, + }); +} diff --git a/src/lib/tauri-commands/token.ts b/src/lib/tauri-commands/token.ts index ad265fd..9e75928 100644 --- a/src/lib/tauri-commands/token.ts +++ b/src/lib/tauri-commands/token.ts @@ -73,6 +73,7 @@ export async function importTokenAsProfile( remoteToken: RemoteToken, toolId: string, profileName: string, + pricingTemplateId?: string, // 🆕 Phase 6: 可选的价格模板 ID ): Promise { return invoke('import_token_as_profile', { profileManager: null, // Managed by Tauri State @@ -80,6 +81,7 @@ export async function importTokenAsProfile( remoteToken, toolId, profileName, + pricingTemplateId: pricingTemplateId || null, }); } @@ -91,7 +93,7 @@ export async function createCustomProfile( profileName: string, apiKey: string, baseUrl: string, - extraConfig?: { wire_api?: string; model?: string }, + extraConfig?: { wire_api?: string; model?: string; pricing_template_id?: string }, ): Promise { return invoke('create_custom_profile', { profileManager: null, // Managed by Tauri State diff --git a/src/pages/ProfileManagementPage/components/CreateCustomProfileDialog.tsx b/src/pages/ProfileManagementPage/components/CreateCustomProfileDialog.tsx index e09e4ea..7c26cd4 100644 --- a/src/pages/ProfileManagementPage/components/CreateCustomProfileDialog.tsx +++ b/src/pages/ProfileManagementPage/components/CreateCustomProfileDialog.tsx @@ -21,6 +21,7 @@ import { Loader2, Sparkles, X } from 'lucide-react'; import { useToast } from '@/hooks/use-toast'; import type { ToolId } from '@/types/profile'; import { ProfileNameInput } from './ProfileNameInput'; +import { PricingTemplateSelector } from './PricingTemplateSelector'; import { pmListToolProfiles } from '@/lib/tauri-commands/profile'; import { createCustomProfile } from '@/lib/tauri-commands/token'; @@ -55,6 +56,7 @@ export function CreateCustomProfileDialog({ const [apiKey, setApiKey] = useState(''); const [wireApi, setWireApi] = useState('responses'); // Codex 特定 const [model, setModel] = useState(''); // Gemini 特定 + const [pricingTemplateId, setPricingTemplateId] = useState(undefined); // Phase 6: 价格模板 // UI 状态 const [creating, setCreating] = useState(false); @@ -71,6 +73,7 @@ export function CreateCustomProfileDialog({ setApiKey(''); setWireApi('responses'); setModel(''); + setPricingTemplateId(undefined); // 重置横幅状态(每次打开时都显示) setBannerVisible(true); } @@ -141,12 +144,16 @@ export function CreateCustomProfileDialog({ } // 构建额外配置 - const extraConfig: { wire_api?: string; model?: string } = {}; + const extraConfig: { wire_api?: string; model?: string; pricing_template_id?: string } = {}; if (toolId === 'codex') { extraConfig.wire_api = wireApi; } else if (toolId === 'gemini-cli' && model.trim()) { extraConfig.model = model; } + // Phase 6: 添加价格模板 ID + if (pricingTemplateId) { + extraConfig.pricing_template_id = pricingTemplateId; + } // 调用创建 API await createCustomProfile(toolId, profileName, apiKey, baseUrl, extraConfig); @@ -248,6 +255,13 @@ export function CreateCustomProfileDialog({

您的 API 访问密钥

+ {/* Phase 6: 价格模板选择器 */} + + {/* Codex 特定配置 */} {toolId === 'codex' && (
diff --git a/src/pages/ProfileManagementPage/components/ImportFromProviderDialog.tsx b/src/pages/ProfileManagementPage/components/ImportFromProviderDialog.tsx index a656960..e7c37ef 100644 --- a/src/pages/ProfileManagementPage/components/ImportFromProviderDialog.tsx +++ b/src/pages/ProfileManagementPage/components/ImportFromProviderDialog.tsx @@ -51,6 +51,7 @@ import { generateApiKeyForTool, getGlobalConfig } from '@/lib/tauri-commands'; import { DuckCodingGroupHint } from './DuckCodingGroupHint'; import { TokenDetailCard } from './TokenDetailCard'; import { ProfileNameInput } from './ProfileNameInput'; +import { PricingTemplateSelector } from './PricingTemplateSelector'; interface ImportFromProviderDialogProps { /** 对话框打开状态 */ @@ -98,6 +99,7 @@ export const ImportFromProviderDialog = forwardRef< // ==================== 共享状态 ==================== const [profileName, setProfileName] = useState(''); + const [pricingTemplateId, setPricingTemplateId] = useState(undefined); // 🆕 Phase 6: 价格模板 // ==================== 加载状态 ==================== const [loadingProviders, setLoadingProviders] = useState(false); @@ -407,7 +409,13 @@ export const ImportFromProviderDialog = forwardRef< return; } - await importTokenAsProfile(selectedProvider, selectedToken, toolId, profileName); + await importTokenAsProfile( + selectedProvider, + selectedToken, + toolId, + profileName, + pricingTemplateId, // 🆕 Phase 6: 价格模板 ID + ); toast({ title: '导入成功', description: `令牌「${selectedToken.name}」已成功导入为 Profile「${profileName}」`, @@ -554,7 +562,13 @@ export const ImportFromProviderDialog = forwardRef< const newToken = sortedTokens[0]; // 直接导入为 Profile - await importTokenAsProfile(selectedProvider, newToken, toolId, profileName); + await importTokenAsProfile( + selectedProvider, + newToken, + toolId, + profileName, + pricingTemplateId, // 🆕 Phase 6: 价格模板 ID + ); toast({ title: '导入成功', @@ -722,6 +736,13 @@ export const ImportFromProviderDialog = forwardRef< placeholder="例如: my_token_profile" /> + {/* 🆕 Phase 6: 价格模板选择器 */} + + {/* 导入按钮 */}
+ {/* Phase 6: 价格模板选择器 */} + handleChange('pricing_template_id', value || '')} + /> + {/* Codex 特定:Wire API */} {toolId === 'codex' && (
diff --git a/src/pages/ProfileManagementPage/hooks/useProfileManagement.ts b/src/pages/ProfileManagementPage/hooks/useProfileManagement.ts index 4421e58..278bd83 100644 --- a/src/pages/ProfileManagementPage/hooks/useProfileManagement.ts +++ b/src/pages/ProfileManagementPage/hooks/useProfileManagement.ts @@ -245,6 +245,7 @@ function buildProfilePayload(toolId: ProfileToolId, data: ProfileFormData): Prof type: 'claude-code', api_key: data.api_key, base_url: data.base_url, + pricing_template_id: data.pricing_template_id, // Phase 6: 价格模板 }; case 'codex': @@ -253,6 +254,7 @@ function buildProfilePayload(toolId: ProfileToolId, data: ProfileFormData): Prof api_key: data.api_key, base_url: data.base_url, wire_api: data.wire_api || 'responses', // 确保有 wire_api + pricing_template_id: data.pricing_template_id, // Phase 6: 价格模板 }; case 'gemini-cli': @@ -261,6 +263,7 @@ function buildProfilePayload(toolId: ProfileToolId, data: ProfileFormData): Prof api_key: data.api_key, base_url: data.base_url, model: data.model && data.model !== '' ? data.model : undefined, // 空值不设置 model + pricing_template_id: data.pricing_template_id, // Phase 6: 价格模板 }; default: diff --git a/src/pages/ProfileManagementPage/index.tsx b/src/pages/ProfileManagementPage/index.tsx index 0b02098..a00829c 100644 --- a/src/pages/ProfileManagementPage/index.tsx +++ b/src/pages/ProfileManagementPage/index.tsx @@ -104,6 +104,7 @@ export default function ProfileManagementPage() { base_url: editingProfile.base_url, wire_api: editingProfile.wire_api || editingProfile.provider, // 兼容两个字段名 model: editingProfile.model, + pricing_template_id: editingProfile.pricing_template_id, // Phase 6: 价格模板 }; }; diff --git a/src/pages/SettingsPage/components/ConfigGuardTab.tsx b/src/pages/SettingsPage/components/ConfigGuardTab.tsx index 535fc8f..d82be02 100644 --- a/src/pages/SettingsPage/components/ConfigGuardTab.tsx +++ b/src/pages/SettingsPage/components/ConfigGuardTab.tsx @@ -35,6 +35,7 @@ export function ConfigGuardTab() { const [fieldsDialogOpen, setFieldsDialogOpen] = useState(false); const [historyDialogOpen, setHistoryDialogOpen] = useState(false); + // 加载配置 const loadConfig = useCallback(async () => { try { setLoading(true); @@ -54,7 +55,6 @@ export function ConfigGuardTab() { } }, [toast]); - // 加载配置 useEffect(() => { loadConfig(); }, [loadConfig]); diff --git a/src/pages/SettingsPage/components/PricingTab.tsx b/src/pages/SettingsPage/components/PricingTab.tsx new file mode 100644 index 0000000..7f592da --- /dev/null +++ b/src/pages/SettingsPage/components/PricingTab.tsx @@ -0,0 +1,264 @@ +// 价格配置管理 Tab +// 管理价格模板和工具默认模板设置 + +import { useState, useEffect, useCallback } from 'react'; +import { Button } from '@/components/ui/button'; +import { Separator } from '@/components/ui/separator'; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'; +import { Label } from '@/components/ui/label'; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from '@/components/ui/select'; +import { DollarSign, Loader2, Plus } from 'lucide-react'; +import { useToast } from '@/hooks/use-toast'; +import { + listPricingTemplates, + getDefaultTemplate, + setDefaultTemplate, + deletePricingTemplate, +} from '@/lib/tauri-commands/pricing'; +import type { PricingTemplate, PricingToolId } from '@/types/pricing'; +import { TOOL_NAMES } from '@/types/pricing'; +import { TemplateCard } from './PricingTab/TemplateCard'; +import { TemplateEditorDialog } from '@/pages/SettingsPage/components/PricingTab/TemplateEditorDialog'; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from '@/components/ui/alert-dialog'; + +export function PricingTab() { + const { toast } = useToast(); + const [loading, setLoading] = useState(true); + const [templates, setTemplates] = useState([]); + const [defaultTemplates, setDefaultTemplates] = useState>({ + 'claude-code': '', + codex: '', + 'gemini-cli': '', + }); + const [editorDialogOpen, setEditorDialogOpen] = useState(false); + const [editingTemplate, setEditingTemplate] = useState(null); + const [deletingTemplateId, setDeletingTemplateId] = useState(null); + + // 加载模板列表和默认模板配置 + const loadData = useCallback(async () => { + try { + setLoading(true); + const templateList = await listPricingTemplates(); + setTemplates(templateList); + + // 加载三个工具的默认模板 + const tools: PricingToolId[] = ['claude-code', 'codex', 'gemini-cli']; + const defaults: Record = { + 'claude-code': '', + codex: '', + 'gemini-cli': '', + }; + + for (const toolId of tools) { + try { + const defaultTpl = await getDefaultTemplate(toolId); + defaults[toolId] = defaultTpl.id; + } catch (error) { + console.warn(`No default template for ${toolId}:`, error); + } + } + + setDefaultTemplates(defaults); + } catch (error) { + console.error('Failed to load pricing data:', error); + toast({ + title: '加载失败', + description: String(error), + variant: 'destructive', + }); + } finally { + setLoading(false); + } + }, [toast]); + + useEffect(() => { + loadData(); + }, [loadData]); + + // 设置工具默认模板 + const handleSetDefault = async (toolId: PricingToolId, templateId: string) => { + try { + await setDefaultTemplate(toolId, templateId); + setDefaultTemplates((prev) => ({ ...prev, [toolId]: templateId })); + toast({ + title: '设置成功', + description: `已将 ${TOOL_NAMES[toolId]} 的默认模板设置为 ${templates.find((t) => t.id === templateId)?.name}`, + }); + } catch (error) { + console.error('Failed to set default template:', error); + toast({ + title: '设置失败', + description: String(error), + variant: 'destructive', + }); + } + }; + + // 打开创建模板对话框 + const handleCreate = () => { + setEditingTemplate(null); + setEditorDialogOpen(true); + }; + + // 打开编辑模板对话框 + const handleEdit = (template: PricingTemplate) => { + setEditingTemplate(template); + setEditorDialogOpen(true); + }; + + // 保存模板(创建或更新) + const handleSave = async () => { + // 重新加载数据 + await loadData(); + setEditorDialogOpen(false); + setEditingTemplate(null); + }; + + // 删除模板 + const handleDelete = async (templateId: string) => { + setDeletingTemplateId(templateId); + }; + + // 确认删除 + const confirmDelete = async () => { + if (!deletingTemplateId) return; + + try { + await deletePricingTemplate(deletingTemplateId); + toast({ + title: '删除成功', + description: '模板已删除', + }); + await loadData(); + } catch (error) { + console.error('Failed to delete template:', error); + toast({ + title: '删除失败', + description: String(error), + variant: 'destructive', + }); + } finally { + setDeletingTemplateId(null); + } + }; + + if (loading) { + return ( +
+ + 加载配置中... +
+ ); + } + + return ( +
+ {/* 页头 */} +
+
+ +
+

价格配置管理

+

管理模型价格模板,用于 Token 成本计算

+
+
+ +
+ + + + {/* 工具默认模板选择器 */} + + + 工具默认模板 + 为每个工具指定默认价格模板(未在 Profile 中指定时使用) + + + {(['claude-code', 'codex', 'gemini-cli'] as const).map((toolId) => ( +
+ + +
+ ))} +
+
+ + {/* 模板列表(卡片式布局) */} +
+

所有模板

+
+ {templates.map((template) => ( + handleEdit(template)} + onDelete={() => handleDelete(template.id)} + /> + ))} +
+ {templates.length === 0 && ( +
+

暂无模板,点击"新建模板"创建第一个模板

+
+ )} +
+ + {/* 编辑器对话框 */} + + + {/* 删除确认对话框 */} + setDeletingTemplateId(null)}> + + + 确认删除模板 + + 此操作将永久删除该模板,且无法撤销。如果该模板正在被 Profile + 使用,将回退到工具默认模板。 + + + + 取消 + 确认删除 + + + +
+ ); +} diff --git a/src/pages/SettingsPage/components/PricingTab/CustomModelsEditor.tsx b/src/pages/SettingsPage/components/PricingTab/CustomModelsEditor.tsx new file mode 100644 index 0000000..eeed0a2 --- /dev/null +++ b/src/pages/SettingsPage/components/PricingTab/CustomModelsEditor.tsx @@ -0,0 +1,409 @@ +// 自定义模型编辑器组件 +// 展开式列表,每个模型可编辑价格和别名 + +import { useState } from 'react'; +import { Button } from '@/components/ui/button'; +import { Input } from '@/components/ui/input'; +import { Label } from '@/components/ui/label'; +import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'; +import { Collapsible, CollapsibleContent, CollapsibleTrigger } from '@/components/ui/collapsible'; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from '@/components/ui/select'; +import { Plus, Trash2, ChevronDown, ChevronRight } from 'lucide-react'; +import type { ModelPrice } from '@/types/pricing'; + +/** + * 常见的 AI 提供商列表 + */ +const COMMON_PROVIDERS = [ + { value: 'anthropic', label: 'Anthropic (Claude)' }, + { value: 'openai', label: 'OpenAI (GPT)' }, + { value: 'google', label: 'Google (Gemini)' }, + { value: 'deepseek', label: 'DeepSeek' }, + { value: 'custom', label: '自定义提供商' }, +]; + +interface CustomModelsEditorProps { + /** 自定义模型数据 */ + data: Record; + /** 数据变更回调 */ + onChange: (data: Record) => void; + /** 只读模式(用于内置模板查看) */ + readOnly?: boolean; +} + +export function CustomModelsEditor({ data, onChange, readOnly = false }: CustomModelsEditorProps) { + const [expandedModels, setExpandedModels] = useState>(new Set()); + const [newModelName, setNewModelName] = useState(''); + const [newAliases, setNewAliases] = useState>({}); // 每个模型的新别名输入 + + const modelNames = Object.keys(data); + + // 切换展开状态 + const toggleExpand = (modelName: string) => { + const newSet = new Set(expandedModels); + if (newSet.has(modelName)) { + newSet.delete(modelName); + } else { + newSet.add(modelName); + } + setExpandedModels(newSet); + }; + + // 添加新模型 + const handleAddModel = () => { + if (!newModelName.trim()) return; + + const newModel: ModelPrice = { + provider: 'custom', + input_price_per_1m: 0, + output_price_per_1m: 0, + cache_write_price_per_1m: 0, + cache_read_price_per_1m: 0, + currency: 'USD', + aliases: [], + }; + + onChange({ + ...data, + [newModelName.trim()]: newModel, + }); + + // 自动展开新添加的模型 + const newSet = new Set(expandedModels); + newSet.add(newModelName.trim()); + setExpandedModels(newSet); + + setNewModelName(''); + }; + + // 删除模型 + const handleDeleteModel = (modelName: string) => { + const newData = { ...data }; + delete newData[modelName]; + onChange(newData); + + // 从展开列表中移除 + const newSet = new Set(expandedModels); + newSet.delete(modelName); + setExpandedModels(newSet); + }; + + // 更新模型字段 + const handleUpdateField = ( + modelName: string, + field: keyof ModelPrice, + value: string | number | string[], + ) => { + const newData = { + ...data, + [modelName]: { + ...data[modelName], + [field]: value, + }, + }; + onChange(newData); + }; + + // 添加别名 + const handleAddAlias = (modelName: string) => { + const newAlias = newAliases[modelName]?.trim(); + if (!newAlias) return; + + const currentAliases = data[modelName].aliases; + if (currentAliases.includes(newAlias)) { + // 别名已存在,不重复添加 + return; + } + + handleUpdateField(modelName, 'aliases', [...currentAliases, newAlias]); + // 清空输入框 + setNewAliases((prev) => ({ ...prev, [modelName]: '' })); + }; + + // 删除别名 + const handleDeleteAlias = (modelName: string, aliasIndex: number) => { + const currentAliases = data[modelName].aliases; + const newAliases = currentAliases.filter((_, i) => i !== aliasIndex); + handleUpdateField(modelName, 'aliases', newAliases); + }; + + return ( +
+
+
+

自定义模型

+

直接定义模型价格,不依赖继承

+
+
+ + {/* 添加新模型 */} + {!readOnly && ( + + + 添加新模型 + + +
+ setNewModelName(e.target.value)} + onKeyDown={(e) => { + if (e.key === 'Enter') { + handleAddModel(); + } + }} + /> + +
+
+
+ )} + + {/* 模型列表 */} + {modelNames.length === 0 ? ( +
+ {readOnly ? '此模板无自定义模型' : '暂无自定义模型,在上方输入框添加新模型'} +
+ ) : ( +
+ {modelNames.map((modelName) => { + const model = data[modelName]; + const isExpanded = expandedModels.has(modelName); + + return ( + toggleExpand(modelName)} + > + + + +
+
+ {isExpanded ? ( + + ) : ( + + )} + {modelName} +
+
+ 输入: ${model.input_price_per_1m.toFixed(2)}/1M + 输出: ${model.output_price_per_1m.toFixed(2)}/1M + {!readOnly && ( + + )} +
+
+
+
+ + + + {/* 提供商 */} +
+ + +
+ + {/* 货币 */} +
+ + handleUpdateField(modelName, 'currency', e.target.value)} + placeholder="USD" + disabled={readOnly} + /> +
+ + {/* 输入价格 */} +
+ + + handleUpdateField( + modelName, + 'input_price_per_1m', + parseFloat(e.target.value) || 0, + ) + } + disabled={readOnly} + /> +
+ + {/* 输出价格 */} +
+ + + handleUpdateField( + modelName, + 'output_price_per_1m', + parseFloat(e.target.value) || 0, + ) + } + disabled={readOnly} + /> +
+ + {/* 缓存写入价格 */} +
+ + + handleUpdateField( + modelName, + 'cache_write_price_per_1m', + e.target.value ? parseFloat(e.target.value) : 0, + ) + } + placeholder="留空表示不支持" + disabled={readOnly} + /> +
+ + {/* 缓存读取价格 */} +
+ + + handleUpdateField( + modelName, + 'cache_read_price_per_1m', + e.target.value ? parseFloat(e.target.value) : 0, + ) + } + placeholder="留空表示不支持" + disabled={readOnly} + /> +
+ + {/* 别名列表 */} +
+ +
+ {/* 添加别名输入框 */} + {!readOnly && ( +
+ + setNewAliases((prev) => ({ + ...prev, + [modelName]: e.target.value, + })) + } + onKeyDown={(e) => { + if (e.key === 'Enter') { + e.preventDefault(); + handleAddAlias(modelName); + } + }} + placeholder="输入别名(如:claude-sonnet-4-5)" + /> + +
+ )} + + {/* 别名列表展示 */} + {model.aliases.length > 0 ? ( +
+ {model.aliases.map((alias, index) => ( +
+ {alias} + {!readOnly && ( + + )} +
+ ))} +
+ ) : ( +
+ {readOnly ? '无别名' : '暂无别名,在上方输入框添加'} +
+ )} +

+ 用于匹配不同格式的模型 ID(如 claude-sonnet-4-5-20250929) +

+
+
+
+
+
+
+ ); + })} +
+ )} +
+ ); +} diff --git a/src/pages/SettingsPage/components/PricingTab/InheritedModelsTable.tsx b/src/pages/SettingsPage/components/PricingTab/InheritedModelsTable.tsx new file mode 100644 index 0000000..7b55013 --- /dev/null +++ b/src/pages/SettingsPage/components/PricingTab/InheritedModelsTable.tsx @@ -0,0 +1,268 @@ +// 继承配置表格组件 +// 支持多源继承:每个模型可从不同模板继承,并设置倍率 + +import { useState } from 'react'; +import { Button } from '@/components/ui/button'; +import { Input } from '@/components/ui/input'; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from '@/components/ui/select'; +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from '@/components/ui/table'; +import { Plus, Trash2 } from 'lucide-react'; +import type { InheritedModel, PricingTemplate } from '@/types/pricing'; + +interface InheritedModelsTableProps { + /** 继承配置数据 */ + data: InheritedModel[]; + /** 数据变更回调 */ + onChange: (data: InheritedModel[]) => void; + /** 可用的模板列表(用于选择源模板) */ + availableTemplates: PricingTemplate[]; + /** 只读模式(用于内置模板查看) */ + readOnly?: boolean; +} + +/** + * 从所有可用模板中提取唯一的模型名称列表 + */ +function extractAvailableModels(templates: PricingTemplate[]): string[] { + const modelSet = new Set(); + + templates.forEach((template) => { + // 从继承模型中提取 + template.inherited_models.forEach((m) => { + if (m.model_name) modelSet.add(m.model_name); + }); + // 从自定义模型中提取 + Object.keys(template.custom_models).forEach((modelName) => { + modelSet.add(modelName); + }); + }); + + return Array.from(modelSet).sort(); +} + +export function InheritedModelsTable({ + data, + onChange, + availableTemplates, + readOnly = false, +}: InheritedModelsTableProps) { + const [editingIndex, setEditingIndex] = useState(null); + const [editingRow, setEditingRow] = useState(null); + + // 提取所有可用的模型名称 + const availableModels = extractAvailableModels(availableTemplates); + + // 添加新行 + const handleAdd = () => { + const newRow: InheritedModel = { + model_name: '', + source_template_id: availableTemplates[0]?.id || '', + multiplier: 1.0, + }; + onChange([...data, newRow]); + setEditingIndex(data.length); + setEditingRow(newRow); + }; + + // 删除行 + const handleDelete = (index: number) => { + const newData = data.filter((_, i) => i !== index); + onChange(newData); + if (editingIndex === index) { + setEditingIndex(null); + setEditingRow(null); + } + }; + + // 开始编辑 + const handleEdit = (index: number) => { + setEditingIndex(index); + setEditingRow({ ...data[index] }); + }; + + // 保存编辑 + const handleSave = () => { + if (editingIndex !== null && editingRow) { + const newData = [...data]; + newData[editingIndex] = editingRow; + onChange(newData); + setEditingIndex(null); + setEditingRow(null); + } + }; + + // 取消编辑 + const handleCancel = () => { + if ( + editingIndex !== null && + editingIndex === data.length - 1 && + !data[editingIndex].model_name + ) { + // 如果是新添加的空行,取消时删除 + onChange(data.slice(0, -1)); + } + setEditingIndex(null); + setEditingRow(null); + }; + + return ( +
+
+
+

继承配置

+

每个模型可从不同模板继承,并应用倍率

+
+ {!readOnly && ( + + )} +
+ + {data.length === 0 ? ( +
+ {readOnly ? '此模板无继承配置' : '暂无继承配置,点击"添加模型"开始配置'} +
+ ) : ( +
+ + + + 模型名称 + 源模板 + 倍率 + {!readOnly && 操作} + + + + {data.map((row, index) => + editingIndex === index && editingRow && !readOnly ? ( + + + + + + + + + + setEditingRow({ + ...editingRow, + multiplier: parseFloat(e.target.value) || 0, + }) + } + className="h-8" + /> + + +
+ + +
+
+
+ ) : ( + + {row.model_name} + + {availableTemplates.find((t) => t.id === row.source_template_id)?.name || + row.source_template_id} + + {row.multiplier.toFixed(2)} + {!readOnly && ( + +
+ + +
+
+ )} +
+ ), + )} +
+
+
+ )} +
+ ); +} diff --git a/src/pages/SettingsPage/components/PricingTab/TemplateCard.tsx b/src/pages/SettingsPage/components/PricingTab/TemplateCard.tsx new file mode 100644 index 0000000..a2833d3 --- /dev/null +++ b/src/pages/SettingsPage/components/PricingTab/TemplateCard.tsx @@ -0,0 +1,91 @@ +// 价格模板卡片组件 +// 展示单个模板的信息和操作按钮 + +import { + Card, + CardContent, + CardFooter, + CardHeader, + CardTitle, + CardDescription, +} from '@/components/ui/card'; +import { Button } from '@/components/ui/button'; +import { Badge } from '@/components/ui/badge'; +import { Edit, Trash2 } from 'lucide-react'; +import type { PricingTemplate } from '@/types/pricing'; +import { + getTemplateMode, + getTemplateModeName, + getTotalModelCount, + formatTemplateTimestamp, +} from '@/types/pricing'; +import { cn } from '@/lib/utils'; + +interface TemplateCardProps { + /** 模板数据 */ + template: PricingTemplate; + /** 编辑回调 */ + onEdit: () => void; + /** 删除回调 */ + onDelete: () => void; +} + +export function TemplateCard({ template, onEdit, onDelete }: TemplateCardProps) { + const mode = getTemplateMode(template); + const isDefaultPreset = template.is_default_preset; + + return ( + + {/* 官方模板标记 */} + {isDefaultPreset && ( + + 官方模板 + + )} + + + {template.name} + {template.description} + + + + {/* 模板信息 */} +
+
+ 模式: + {getTemplateModeName(mode)} +
+
+ 模型数: + {getTotalModelCount(template)} +
+
+ 创建时间: + + {formatTemplateTimestamp(template.created_at)} + +
+
+
+ + + + {!isDefaultPreset && ( + + )} + +
+ ); +} diff --git a/src/pages/SettingsPage/components/PricingTab/TemplateEditorDialog.tsx b/src/pages/SettingsPage/components/PricingTab/TemplateEditorDialog.tsx new file mode 100644 index 0000000..ddecde6 --- /dev/null +++ b/src/pages/SettingsPage/components/PricingTab/TemplateEditorDialog.tsx @@ -0,0 +1,307 @@ +// 价格模板编辑器对话框 +// 创建或编辑价格模板(可视化表单) + +import { useState, useEffect } from 'react'; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, +} from '@/components/ui/dialog'; +import { Button } from '@/components/ui/button'; +import { Input } from '@/components/ui/input'; +import { Label } from '@/components/ui/label'; +import { Textarea } from '@/components/ui/textarea'; +import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'; +import { Alert, AlertDescription } from '@/components/ui/alert'; +import { Loader2, AlertCircle } from 'lucide-react'; +import { useToast } from '@/hooks/use-toast'; +import { savePricingTemplate, listPricingTemplates } from '@/lib/tauri-commands/pricing'; +import type { PricingTemplate, InheritedModel, ModelPrice } from '@/types/pricing'; +import { InheritedModelsTable } from './InheritedModelsTable'; +import { CustomModelsEditor } from './CustomModelsEditor'; + +interface TemplateEditorDialogProps { + /** 对话框打开状态 */ + open: boolean; + /** 对话框状态变更回调 */ + onOpenChange: (open: boolean) => void; + /** 编辑的模板(null 表示新建) */ + template: PricingTemplate | null; + /** 保存成功回调 */ + onSave: () => void; +} + +export function TemplateEditorDialog({ + open, + onOpenChange, + template, + onSave, +}: TemplateEditorDialogProps) { + const { toast } = useToast(); + const [saving, setSaving] = useState(false); + const [name, setName] = useState(''); + const [description, setDescription] = useState(''); + const [inheritedModels, setInheritedModels] = useState([]); + const [customModels, setCustomModels] = useState>({}); + const [availableTemplates, setAvailableTemplates] = useState([]); + const [loadingTemplates, setLoadingTemplates] = useState(false); + + // 判断是否为只读模式(内置模板) + const isReadOnly = template?.is_default_preset || false; + + // 加载可用模板列表(用于继承配置) + useEffect(() => { + if (open) { + setLoadingTemplates(true); + listPricingTemplates() + .then(setAvailableTemplates) + .catch((error) => { + console.error('Failed to load templates:', error); + }) + .finally(() => setLoadingTemplates(false)); + } + }, [open]); + + // 初始化表单数据 + useEffect(() => { + if (!open) return; + + if (template) { + // 编辑模式 + setName(template.name); + setDescription(template.description); + setInheritedModels(template.inherited_models); + setCustomModels(template.custom_models); + } else { + // 新建模式 + setName(''); + setDescription(''); + setInheritedModels([]); + setCustomModels({}); + } + }, [open, template]); + + // 保存模板 + const handleSave = async () => { + // 验证基础字段 + if (!name.trim()) { + toast({ + title: '验证失败', + description: '请输入模板名称', + variant: 'destructive', + }); + return; + } + + // 验证至少有一种配置 + if (inheritedModels.length === 0 && Object.keys(customModels).length === 0) { + toast({ + title: '验证失败', + description: '请至少配置继承模型或自定义模型', + variant: 'destructive', + }); + return; + } + + // 验证继承配置的完整性 + for (const inherited of inheritedModels) { + if (!inherited.model_name.trim()) { + toast({ + title: '验证失败', + description: '继承配置中存在空的模型名称', + variant: 'destructive', + }); + return; + } + if (!inherited.source_template_id) { + toast({ + title: '验证失败', + description: `模型 ${inherited.model_name} 未选择源模板`, + variant: 'destructive', + }); + return; + } + if (inherited.multiplier <= 0) { + toast({ + title: '验证失败', + description: `模型 ${inherited.model_name} 的倍率必须大于 0`, + variant: 'destructive', + }); + return; + } + } + + try { + setSaving(true); + + const now = Date.now(); + const newTemplate: PricingTemplate = { + id: template?.id || generateTemplateId(name), + name: name.trim(), + description: description.trim(), + version: template?.version || '1.0.0', + created_at: template?.created_at || now, + updated_at: now, + inherited_models: inheritedModels, + custom_models: customModels, + tags: template?.tags || [], + is_default_preset: template?.is_default_preset || false, + }; + + await savePricingTemplate(newTemplate); + + toast({ + title: '保存成功', + description: template ? '模板已更新' : '模板已创建', + }); + + onSave(); + } catch (error) { + console.error('Failed to save template:', error); + toast({ + title: '保存失败', + description: String(error), + variant: 'destructive', + }); + } finally { + setSaving(false); + } + }; + + // 生成模板 ID + const generateTemplateId = (templateName: string): string => { + const sanitizedName = templateName + .toLowerCase() + .replace(/[^a-z0-9]+/g, '_') + .replace(/^_|_$/g, ''); + const timestamp = Date.now().toString(36); + return `${sanitizedName}_${timestamp}`; + }; + + return ( + + + + + {isReadOnly ? '查看模板(只读)' : template ? '编辑模板' : '新建模板'} + + + {isReadOnly + ? '此为内置预设模板,仅供查看,无法编辑' + : '配置模型价格模板,支持完全自定义、继承模式和混合模式'} + + + + + + 基础信息 + + 继承配置 {inheritedModels.length > 0 && `(${inheritedModels.length})`} + + + 自定义模型{' '} + {Object.keys(customModels).length > 0 && `(${Object.keys(customModels).length})`} + + + + {/* Tab 1: 基础信息 */} + + {isReadOnly && ( + + + + 此模板为内置预设模板,仅供查看。如需修改,请创建新模板或从此模板继承。 + + + )} +
+ + setName(e.target.value)} + disabled={isReadOnly} + /> +
+
+ +