diff --git a/backend-architecture.md b/backend-architecture.md new file mode 100644 index 0000000..2337000 --- /dev/null +++ b/backend-architecture.md @@ -0,0 +1,290 @@ +# Free-Agent-Vtuber 后端架构图 + +## 整体架构概览 + +```mermaid +graph TB + %% 外部接口层 + subgraph "外部接口" + WebClient[Web 客户端] + API[API 调用] + end + + %% 网关层 + subgraph "API 网关层" + Gateway[Gateway Service
FastAPI + Flask
:8000] + end + + %% 输入输出处理层 + subgraph "输入输出处理层" + InputHandler[Input Handler
WebSocket Server
:8001] + OutputHandler[Output Handler
WebSocket Server
:8002] + end + + %% 核心服务层 + subgraph "核心处理服务" + ASR[ASR Service
语音识别
OpenAI Whisper] + Memory[Memory Service
短期记忆管理
对话上下文] + ChatAI[Chat AI Service
AI 对话生成
OpenAI GPT] + TTS[TTS Service
语音合成
Edge-TTS] + LTM[Long Term Memory
长期记忆
向量搜索] + end + + %% 消息总线 + subgraph "消息总线" + Redis[(Redis
事件驱动消息总线
:6379)] + end + + %% 数据存储层 + subgraph "数据存储" + PostgreSQL[(PostgreSQL
pgvector
:5432)] + MemoryData[(Memory Data
对话历史)] + TempFiles[(Temp Files
音频文件)] + end + + %% 管理层 + subgraph "管理与监控" + Manager[Manager Service
Flask Web UI
:5000] + end + + %% 连接关系 + WebClient --> Gateway + API --> Gateway + + Gateway --> InputHandler + Gateway --> OutputHandler + Gateway --> Redis + + InputHandler --> Redis + OutputHandler --> Redis + + ASR --> Redis + Memory --> Redis + ChatAI --> Redis + TTS --> Redis + LTM --> Redis + + Memory --> MemoryData + LTM --> PostgreSQL + TTS --> TempFiles + ASR --> TempFiles + + Manager --> Redis + Manager -.-> ASR + Manager -.-> Memory + Manager -.-> ChatAI + Manager -.-> TTS + Manager -.-> LTM + + %% 样式 + classDef service fill:#e1f5fe,stroke:#0277bd,stroke-width:2px + classDef storage fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px + classDef message fill:#fff3e0,stroke:#ef6c00,stroke-width:2px + classDef external fill:#e8f5e8,stroke:#2e7d32,stroke-width:2px + + class Gateway,InputHandler,OutputHandler,ASR,Memory,ChatAI,TTS,LTM,Manager service + class PostgreSQL,MemoryData,TempFiles storage + class Redis message + class WebClient,API external +``` + +## 消息流架构详图 + +```mermaid +sequenceDiagram + participant Client as 客户端 + participant GW as Gateway + participant IH as Input Handler + participant ASR as ASR Service + participant MEM as Memory Service + participant AI as Chat AI Service + participant LTM as Long Term Memory + participant TTS as TTS Service + participant OH as Output Handler + participant Redis as Redis Bus + + Note over Client,Redis: 完整的语音交互流程 + + %% 1. 语音输入阶段 + Client->>GW: POST /api/asr (音频文件) + GW->>Redis: LPUSH asr_tasks + Redis->>ASR: 消费 asr_tasks + ASR-->>Redis: PUBLISH asr_results + + %% 2. 输入处理阶段 + Redis->>IH: 订阅 asr_results + IH->>Redis: LPUSH user_input_queue + + %% 3. 记忆处理阶段 + Redis->>MEM: 消费 user_input_queue + MEM-->>Redis: PUBLISH memory_updates + + %% 4. AI对话阶段 + Redis->>AI: 订阅 memory_updates + + %% 5. 长期记忆查询(可选) + alt 启用长期记忆 + AI->>Redis: LPUSH ltm_requests + Redis->>LTM: 消费 ltm_requests + LTM-->>Redis: PUBLISH ltm_responses + Redis->>AI: 订阅 ltm_responses + end + + %% 6. 
AI响应生成 + AI-->>Redis: PUBLISH ai_responses + AI->>Redis: LPUSH tts_requests + + %% 7. 语音合成阶段 + Redis->>TTS: 消费 tts_requests + TTS-->>Redis: PUBLISH task_response:task_id + + %% 8. 输出处理阶段 + Redis->>OH: 订阅 task_response:* + OH->>GW: WebSocket 推送 + GW->>Client: 返回合成语音 +``` + +## 核心服务详细说明 + +### 1. Gateway Service (网关服务) +- **技术栈**: FastAPI + Flask Blueprint +- **端口**: 8000 +- **职责**: + - 提供统一的API入口 + - 处理HTTP请求和WebSocket连接 + - 路由ASR请求到消息队列 + - 管理客户端连接状态 + +### 2. Input Handler (输入处理服务) +- **技术栈**: Python + WebSocket +- **端口**: 8001 +- **职责**: + - 订阅ASR识别结果 + - 标准化用户输入格式 + - 将处理后的输入推送到用户输入队列 + +### 3. ASR Service (语音识别服务) +- **技术栈**: Python + OpenAI Whisper +- **职责**: + - 消费语音识别任务队列 + - 将音频文件转换为文本 + - 发布识别结果到消息总线 + +### 4. Memory Service (记忆管理服务) +- **技术栈**: Python +- **职责**: + - 管理对话上下文和短期记忆 + - 维护用户会话状态 + - 为AI服务提供对话历史 + +### 5. Chat AI Service (AI对话服务) +- **技术栈**: Python + OpenAI GPT +- **职责**: + - 生成AI回复 + - 整合长期记忆内容(如果启用) + - 触发语音合成请求 + +### 6. Long Term Memory Service (长期记忆服务) +- **技术栈**: Python + pgvector + mem0 +- **职责**: + - 存储和检索长期记忆 + - 向量相似度搜索 + - 为对话提供相关历史信息 + +### 7. TTS Service (语音合成服务) +- **技术栈**: Python + Edge-TTS +- **职责**: + - 将文本转换为语音 + - 生成音频文件 + - 发布合成结果 + +### 8. Output Handler (输出处理服务) +- **技术栈**: Python + WebSocket +- **端口**: 8002 +- **职责**: + - 处理服务输出 + - 通过WebSocket推送结果到客户端 + +## 数据存储架构 + +### Redis 消息总线 +- **队列 (List)**: + - `asr_tasks`: ASR识别任务 + - `user_input_queue`: 用户输入队列 + - `tts_requests`: TTS合成请求 + - `ltm_requests`: 长期记忆查询请求 + +- **发布/订阅 (Pub/Sub)**: + - `asr_results`: ASR识别结果 + - `memory_updates`: 记忆更新通知 + - `ai_responses`: AI回复 + - `ltm_responses`: 长期记忆查询结果 + - `task_response:{task_id}`: 任务响应 + +### PostgreSQL + pgvector +- 存储长期记忆向量数据 +- 支持语义相似度搜索 +- 维护用户历史对话记录 + +### 文件存储 +- 临时音频文件存储 (`/tmp/aivtuber_tasks`) +- 内存数据持久化 (`memory_data`) +- 长期记忆数据 (`ltm_data`) + +## 架构特点 + +1. **事件驱动**: 基于Redis的消息总线实现松耦合 +2. **微服务**: 每个功能模块独立部署和扩展 +3. **异步处理**: 支持并发处理多个用户请求 +4. **可插拔**: 支持不同的AI、TTS、ASR提供商 +5. **容器化**: 所有服务均支持Docker部署 +6. 
**可观测**: 完整的日志和监控体系 + +## 部署架构 + +```mermaid +graph TB + subgraph "Docker Network: aivtuber-network" + subgraph "计算服务" + GW[gateway:8000] + IH[input-handler:8001] + OH[output-handler:8002] + ASR[asr] + MEM[memory] + AI[chat-ai] + TTS[tts] + LTM[long-term-memory] + end + + subgraph "基础设施" + Redis[redis:6379] + PG[postgres:5432] + end + + subgraph "存储卷" + RedisData[redis_data] + PostgresData[postgres_data] + MemoryData[memory_data] + LTMData[ltm_data] + TempFiles[temp_files] + end + end + + subgraph "外部" + Client[客户端] + Manager[Manager:5000
可选管理界面] + end + + Client --> GW + Manager -.-> GW + + Redis --- RedisData + PG --- PostgresData + MEM --- MemoryData + LTM --- LTMData + TTS --- TempFiles + ASR --- TempFiles + OH --- TempFiles +``` + +这个架构设计实现了高度模块化、可扩展的AI虚拟主播系统,通过事件驱动的方式确保了系统的松耦合和高可用性。 \ No newline at end of file diff --git a/services/dialog-engine/src/dialog_engine/app.py b/services/dialog-engine/src/dialog_engine/app.py index a44cd92..69462d5 100644 --- a/services/dialog-engine/src/dialog_engine/app.py +++ b/services/dialog-engine/src/dialog_engine/app.py @@ -17,11 +17,22 @@ from .asr import AsrOptions, AsrService from .tts_streamer import stream_text as tts_stream_text from .ltm_outbox import add_event as outbox_add_event, start_flush_task as outbox_start_flush +from .internal_state_store import InternalStateStore app = FastAPI() -chat_service = ChatService() logger = logging.getLogger(__name__) + +# Initialize internal state store +try: + import os + db_path = os.getenv("INTERNAL_STATE_DB_PATH", "internal_states.db") + state_store = InternalStateStore(db_path=db_path) +except Exception as exc: + logger.exception("Failed to initialize InternalStateStore", extra={"error": repr(exc)}) + state_store = None + +chat_service = ChatService(state_store=state_store) SYNC_TTS_STREAMING = os.getenv("SYNC_TTS_STREAMING", "false").lower() in {"1", "true", "yes", "on"} ENABLE_ASYNC_EXT = os.getenv("ENABLE_ASYNC_EXT", "false").lower() in {"1", "true", "yes", "on"} VISION_MAX_BYTES = int(os.getenv("VISION_MAX_BYTES", 4 * 1024 * 1024)) @@ -192,6 +203,12 @@ async def event_generator() -> AsyncGenerator[bytes, None]: return stats = {"ttft_ms": round(ttft_ms or 0.0, 1), "tokens": chat_service.last_token_count} + + # Include internal states in the done event + internal_states = await chat_service.get_internal_states(session_id) + if internal_states: + stats["internal_states"] = internal_states + yield _sse_format("done", {"stats": stats}) # Emit async events via outbox @@ -418,6 +435,11 @@ async def 
event_generator() -> AsyncGenerator[bytes, None]: "stats": stats, } + # Include internal states in the done event + internal_states = await chat_service.get_internal_states(session_id) + if internal_states: + stats["internal_states"] = internal_states + yield _sse_format("done", done_payload) _emit_async_events( diff --git a/services/dialog-engine/src/dialog_engine/chat_service.py b/services/dialog-engine/src/dialog_engine/chat_service.py index 75746eb..5dfb4cf 100644 --- a/services/dialog-engine/src/dialog_engine/chat_service.py +++ b/services/dialog-engine/src/dialog_engine/chat_service.py @@ -10,6 +10,8 @@ from .ltm_client import LTMInlineClient from .memory_store import MemoryTurn, ShortTermMemoryStore from .settings import Settings, settings as runtime_settings +from .internal_state_store import InternalStateStore +from .llm_functions import FUNCTION_DEFINITIONS, handle_tool_call class ChatService: @@ -22,12 +24,14 @@ def __init__( llm_client_factory: Optional[Callable[[], OpenAIChatClient]] = None, memory_store: Optional[ShortTermMemoryStore] = None, ltm_client: Optional[LTMInlineClient] = None, + state_store: Optional[InternalStateStore] = None, ) -> None: self._settings = settings or runtime_settings self._llm_client_factory = llm_client_factory self._llm_client: Optional[OpenAIChatClient] = None self._memory_store = memory_store self._ltm_client = ltm_client + self._state_store = state_store self.last_token_count: int = 0 self.last_ttft_ms: Optional[float] = None @@ -71,6 +75,9 @@ async def stream_reply( return except LLMStreamEmptyError as exc: self.last_error = "llm_empty_stream" + # Process tool calls if present + if exc.tool_calls and self._state_store: + await self._process_tool_calls(exc.tool_calls, session_id) self._log_llm_fallback(reason=f"empty_stream:{exc.tool_calls}") except LLMNotConfiguredError as exc: self.last_error = "llm_not_configured" @@ -163,15 +170,23 @@ async def _stream_llm( ltm_snippets: List[str], ) -> AsyncGenerator[str, None]: 
client = await self._ensure_llm_client() - messages = self._compose_messages( + meta_with_session = dict(meta) + meta_with_session["session_id"] = session_id + messages = await self._compose_messages( user_text=user_text, - meta=meta, + meta=meta_with_session, context=context, ltm_snippets=ltm_snippets, ) extra_options: Dict[str, Any] = { "extra_headers": {"x-session-id": session_id}, } + + # Add function calling support if state store is available + if self._state_store: + extra_options["functions"] = FUNCTION_DEFINITIONS + extra_options["tool_choice"] = "auto" + async for delta in client.stream_chat(messages, extra_options=extra_options): yield delta @@ -198,9 +213,11 @@ async def _generate_vision_reply( ltm_snippets: List[str], ) -> str: client = await self._ensure_llm_client() - messages = self._compose_messages( + meta_with_session = dict(meta) + meta_with_session["session_id"] = session_id + messages = await self._compose_messages( user_text=prompt_text, - meta=meta, + meta=meta_with_session, context=context, ltm_snippets=ltm_snippets, ) @@ -261,7 +278,7 @@ async def _ensure_llm_client(self) -> OpenAIChatClient: self._llm_client = client return client - def _compose_messages( + async def _compose_messages( self, *, user_text: str, @@ -274,6 +291,18 @@ def _compose_messages( if system_prompt: messages.append({"role": "system", "content": str(system_prompt)}) + # Inject internal states as context if available + if self._state_store: + session_id = meta.get("session_id", "default") + state_dict = await self.get_internal_states(session_id) + if state_dict: + mood_summary = "; ".join([f"{k}:{v:.2f}" for k, v in state_dict.items()]) + state_message = { + "role": "system", + "content": f"当前内部状态:{mood_summary}。请据此调整语气与行为。" + } + messages.append(state_message) + for turn in context: role = turn.role if turn.role in {"user", "assistant", "system"} else "assistant" messages.append({"role": role, "content": turn.content}) @@ -292,6 +321,30 @@ def _reset_metrics(self) -> 
async def _process_tool_calls(self, tool_calls: List[Any], session_id: str) -> None:
    """Forward LLM tool calls to ``handle_tool_call`` to update internal states.

    Does nothing when no state store is configured. A failure while handling
    one tool call is logged through ``_log_context_warning`` and does not stop
    the remaining calls from being processed.

    Args:
        tool_calls: Tool calls reported by the LLM client. The exact shape is
            not visible from this method, so each entry may be a dict or an
            attribute-style object, either nested under ``function``
            (``{"function": {"name": ..., "arguments": ...}}``) or flat
            (``{"name": ..., "arguments": ...}``).
            NOTE(review): confirm against what ``LLMStreamEmptyError.tool_calls``
            actually carries.
        session_id: Session whose internal states should be updated.
    """
    if not self._state_store:
        return

    def _field(container: Any, key: str) -> Any:
        # Read ``key`` from either a mapping or an attribute-style object.
        # The previous ``getattr(call, "function", {}).get(...)`` worked for
        # neither: dicts have no ``function`` attribute (so the name was always
        # None) and SDK objects have no ``.get`` (so it raised AttributeError).
        if isinstance(container, dict):
            return container.get(key)
        return getattr(container, key, None)

    for tool_call in tool_calls or []:
        try:
            function = _field(tool_call, "function")
            # Fall back to the call itself when it is already flat.
            payload = function if function is not None else tool_call
            call_info = {
                "name": _field(payload, "name"),
                "arguments": _field(payload, "arguments") or "{}",
            }
            await handle_tool_call(call_info, session_id, self._state_store)
        except Exception as exc:
            # Best effort: a bad tool call must not break the chat reply.
            self._log_context_warning("tool_call.error", exc)
RuntimeError("Database path not configured") + + try: + conn = sqlite3.connect(self._db_path) + if row_factory: + conn.row_factory = sqlite3.Row + yield conn + except Exception as exc: + logger.debug("internal_states.connect.error", exc_info=True) + raise RuntimeError("failed to open internal_states database") from exc + finally: + if 'conn' in locals(): + conn.close() + + def _ensure_table_exists(self) -> None: + """Create the internal_states table if it doesn't exist.""" + if not self._db_path: + return + + os.makedirs(os.path.dirname(self._db_path) or ".", exist_ok=True) + + with self._get_connection() as conn: + try: + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS internal_states ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + state_key TEXT NOT NULL, + state_value REAL NOT NULL, + updated_at INTEGER NOT NULL, + UNIQUE(session_id, state_key) + ) + """, + ) + # Create index for efficient queries + cur.execute( + """ + CREATE INDEX IF NOT EXISTS idx_internal_states_session_key + ON internal_states(session_id, state_key) + """, + ) + conn.commit() + except Exception as exc: + logger.error("Failed to create internal_states table", exc_info=True) + raise RuntimeError("failed to create internal_states table") from exc + + async def get_state(self, session_id: str, state_key: str) -> Optional[float]: + """Get a specific state value for a session.""" + if not self._db_path or not os.path.exists(self._db_path): + return None + + def _query() -> Optional[float]: + try: + with self._get_connection(row_factory=True) as conn: + row = conn.execute( + """ + SELECT state_value + FROM internal_states + WHERE session_id = ? AND state_key = ? 
+ """, + (session_id, state_key), + ).fetchone() + return float(row["state_value"]) if row else None + except Exception as exc: + logger.debug("internal_states.query.error", exc_info=True) + raise RuntimeError("failed to query internal_states database") from exc + + try: + return await asyncio.to_thread(_query) + except (FileNotFoundError, RuntimeError): + return None + + async def update_state(self, session_id: str, state_key: str, new_value: float) -> None: + """Update or insert a state value for a session.""" + if not self._db_path: + return + + def _upsert() -> None: + os.makedirs(os.path.dirname(self._db_path) or ".", exist_ok=True) + try: + with self._get_connection() as conn: + cur = conn.cursor() + # Ensure table exists (defensive) + cur.execute( + """ + CREATE TABLE IF NOT EXISTS internal_states ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + state_key TEXT NOT NULL, + state_value REAL NOT NULL, + updated_at INTEGER NOT NULL, + UNIQUE(session_id, state_key) + ) + """, + ) + + # Insert or replace the state value + cur.execute( + """ + INSERT OR REPLACE INTO internal_states(session_id, state_key, state_value, updated_at) + VALUES(?, ?, ?, ?) + """, + (session_id, state_key, float(new_value), int(time.time())), + ) + conn.commit() + except Exception as exc: + logger.error("Failed to update internal state", exc_info=True) + raise RuntimeError("failed to update internal state") from exc + + try: + await asyncio.to_thread(_upsert) + except RuntimeError: + logger.error("Failed to update internal state", exc_info=True) + + async def list_states(self, session_id: str) -> Dict[str, float]: + """Get all state values for a session.""" + if not self._db_path or not os.path.exists(self._db_path): + return {} + + def _query() -> Dict[str, float]: + try: + with self._get_connection(row_factory=True) as conn: + rows = conn.execute( + """ + SELECT state_key, state_value + FROM internal_states + WHERE session_id = ? 
+ """, + (session_id,), + ).fetchall() + + return {row["state_key"]: float(row["state_value"]) for row in rows} + except Exception as exc: + logger.debug("internal_states.query.error", exc_info=True) + raise RuntimeError("failed to query internal_states database") from exc + + try: + return await asyncio.to_thread(_query) + except (FileNotFoundError, RuntimeError): + return {} + + async def delete_state(self, session_id: str, state_key: str) -> bool: + """Delete a specific state for a session. Returns True if deleted.""" + if not self._db_path: + return False + + def _delete() -> bool: + try: + with self._get_connection() as conn: + cur = conn.cursor() + cur.execute( + """ + DELETE FROM internal_states + WHERE session_id = ? AND state_key = ? + """, + (session_id, state_key), + ) + conn.commit() + return cur.rowcount > 0 + except Exception as exc: + logger.error("Failed to delete internal state", exc_info=True) + raise RuntimeError("failed to delete internal state") from exc + + try: + return await asyncio.to_thread(_delete) + except RuntimeError: + return False + + async def clear_session(self, session_id: str) -> int: + """Clear all states for a session. Returns number of deleted records.""" + if not self._db_path: + return 0 + + def _clear() -> int: + try: + with self._get_connection() as conn: + cur = conn.cursor() + cur.execute( + """ + DELETE FROM internal_states + WHERE session_id = ? 
async def handle_tool_call(tool_call: Dict[str, Any], session_id: str, state_store) -> Dict[str, Any]:
    """Validate and execute a single tool call issued by the LLM.

    Only ``update_internal_state`` is supported. The call never raises: every
    failure is reported through the returned dict.

    Args:
        tool_call: Mapping with ``name`` (function name) and ``arguments``
            (a JSON-encoded object, or an already-decoded dict).
        session_id: Session whose state is being updated.
        state_store: Object exposing awaitable ``update_state(session_id, key,
            value)`` and ``get_state(session_id, key)``.

    Returns:
        On success: ``success=True`` plus ``message``, ``session_id``,
        ``state_key``, ``old_value`` (always ``None``; the previous value is
        not tracked) and ``new_value`` (the value read back from the store).
        On failure: ``success=False`` and a human-readable ``error``.
    """

    def _failure(message: str) -> Dict[str, Any]:
        # Uniform error payload.
        return {"success": False, "error": message}

    try:
        function_name = tool_call.get("name")
        if function_name != "update_internal_state":
            return _failure(f"Unknown function: {function_name}")

        # Arguments normally arrive as a JSON string, but some clients hand
        # over an already-decoded dict; accept both.
        raw_arguments = tool_call.get("arguments", "{}")
        if isinstance(raw_arguments, dict):
            args = raw_arguments
        else:
            try:
                args = json.loads(raw_arguments or "{}")
            except (json.JSONDecodeError, TypeError) as e:
                logger.error(
                    "Failed to parse tool call arguments: %s",
                    raw_arguments,
                    exc_info=True,
                )
                return _failure(f"Invalid arguments: {str(e)}")
            if not isinstance(args, dict):
                # Valid JSON that is not an object (e.g. a list) has no fields.
                return _failure("Invalid arguments: expected a JSON object")

        state_key = args.get("state_key")
        value = args.get("value")

        if state_key is None:
            return _failure("Missing required parameter: state_key")
        if value is None:
            return _failure("Missing required parameter: value")
        if not isinstance(state_key, str) or not state_key.strip():
            return _failure("Invalid state_key: must be a non-empty string")

        # bool is an int subclass, so float(True) would silently become 1.0.
        if isinstance(value, bool):
            return _failure(f"Invalid value type: {value} must be a number")
        try:
            value = float(value)
        except (ValueError, TypeError):
            return _failure(f"Invalid value type: {value} must be a number")

        # The chained comparison is False for NaN, so NaN is rejected here too.
        if not 0 <= value <= 100:
            return _failure(
                f"Value {value} out of valid range (0-100). "
                "State values should be between 0 and 100."
            )

        await state_store.update_state(session_id, state_key, value)

        # The store may swallow write errors, so confirm the value was
        # actually persisted before reporting success.
        updated_value = await state_store.get_state(session_id, state_key)
        if updated_value is None or abs(updated_value - value) > 1e-9:
            logger.error(
                "Internal state not persisted: session=%s, key=%s, value=%s, stored=%s",
                session_id,
                state_key,
                value,
                updated_value,
            )
            return _failure(f"Failed to persist {state_key} for session {session_id}")

        logger.info(
            "Updated internal state: session=%s, key=%s, value=%s",
            session_id,
            state_key,
            value,
        )

        return {
            "success": True,
            "message": f"Successfully updated {state_key} to {value}",
            "session_id": session_id,
            "state_key": state_key,
            "old_value": None,  # Previous value is not tracked.
            "new_value": updated_value,
        }

    except Exception as exc:
        # Last-resort guard: tool handling must never crash the caller.
        logger.error("Error handling tool call: %s", exc, exc_info=True)
        return _failure(f"Internal error: {str(exc)}")
+ + Args: + states: Dictionary of state key -> value + + Returns: + System message dict for inclusion in conversation + """ + state_summary = format_state_for_context(states) + + return { + "role": "system", + "content": f"{state_summary}。请根据这些内部状态调整你的语气和表达方式。例如:情绪值高时可以更热情友好,情绪值低时可以更冷静克制。好感度高时可以更亲近自然,好感度低时可以保持适当距离。" + } + + +__all__ = [ + "FUNCTION_DEFINITIONS", + "handle_tool_call", + "format_state_for_context", + "create_state_system_message" +] \ No newline at end of file diff --git a/services/dialog-engine/tests/unit/test_internal_state_store.py b/services/dialog-engine/tests/unit/test_internal_state_store.py new file mode 100644 index 0000000..5d7c149 --- /dev/null +++ b/services/dialog-engine/tests/unit/test_internal_state_store.py @@ -0,0 +1,193 @@ +"""Tests for InternalStateStore.""" + +import os +import tempfile +import pytest +import asyncio +from dialog_engine.internal_state_store import InternalStateStore + + +@pytest.fixture +def temp_db_path(): + """Create a temporary database file path.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + db_path = f.name + yield db_path + # Cleanup + if os.path.exists(db_path): + os.unlink(db_path) + + +@pytest.fixture +def state_store(temp_db_path): + """Create an InternalStateStore instance with temporary database.""" + store = InternalStateStore(db_path=temp_db_path) + yield store + # No explicit cleanup needed as temp file will be deleted + + +@pytest.mark.asyncio +async def test_update_and_get_state(state_store): + """Test updating and retrieving a single state.""" + session_id = "test_session" + state_key = "emotion" + state_value = 75.5 + + # Update state + await state_store.update_state(session_id, state_key, state_value) + + # Get state + retrieved_value = await state_store.get_state(session_id, state_key) + assert retrieved_value == state_value + + +@pytest.mark.asyncio +async def test_get_nonexistent_state(state_store): + """Test getting a state that doesn't exist.""" + session_id = 
"nonexistent_session" + state_key = "nonexistent_state" + + retrieved_value = await state_store.get_state(session_id, state_key) + assert retrieved_value is None + + +@pytest.mark.asyncio +async def test_list_states(state_store): + """Test listing all states for a session.""" + session_id = "test_session" + states = { + "emotion": 75.5, + "affinity": 60.0, + "energy": 80.2 + } + + # Update multiple states + for key, value in states.items(): + await state_store.update_state(session_id, key, value) + + # List all states + retrieved_states = await state_store.list_states(session_id) + assert retrieved_states == states + + +@pytest.mark.asyncio +async def test_list_states_empty(state_store): + """Test listing states for a session with no states.""" + session_id = "empty_session" + + retrieved_states = await state_store.list_states(session_id) + assert retrieved_states == {} + + +@pytest.mark.asyncio +async def test_update_existing_state(state_store): + """Test updating an existing state value.""" + session_id = "test_session" + state_key = "emotion" + initial_value = 50.0 + updated_value = 80.0 + + # Set initial value + await state_store.update_state(session_id, state_key, initial_value) + assert await state_store.get_state(session_id, state_key) == initial_value + + # Update to new value + await state_store.update_state(session_id, state_key, updated_value) + assert await state_store.get_state(session_id, state_key) == updated_value + + +@pytest.mark.asyncio +async def test_delete_state(state_store): + """Test deleting a specific state.""" + session_id = "test_session" + state_key = "emotion" + state_value = 75.5 + + # Set initial state + await state_store.update_state(session_id, state_key, state_value) + assert await state_store.get_state(session_id, state_key) == state_value + + # Delete state + deleted = await state_store.delete_state(session_id, state_key) + assert deleted is True + assert await state_store.get_state(session_id, state_key) is None + + 
+@pytest.mark.asyncio +async def test_delete_nonexistent_state(state_store): + """Test deleting a state that doesn't exist.""" + session_id = "test_session" + state_key = "nonexistent_state" + + # Try to delete nonexistent state + deleted = await state_store.delete_state(session_id, state_key) + assert deleted is False + + +@pytest.mark.asyncio +async def test_clear_session(state_store): + """Test clearing all states for a session.""" + session_id = "test_session" + states = { + "emotion": 75.5, + "affinity": 60.0, + "energy": 80.2 + } + + # Set up states + for key, value in states.items(): + await state_store.update_state(session_id, key, value) + assert await state_store.list_states(session_id) == states + + # Clear session + deleted_count = await state_store.clear_session(session_id) + assert deleted_count == len(states) + assert await state_store.list_states(session_id) == {} + + +@pytest.mark.asyncio +async def test_clear_empty_session(state_store): + """Test clearing a session with no states.""" + session_id = "empty_session" + + deleted_count = await state_store.clear_session(session_id) + assert deleted_count == 0 + + +@pytest.mark.asyncio +async def test_multiple_sessions_isolation(state_store): + """Test that different sessions don't interfere with each other.""" + session1_id = "session1" + session2_id = "session2" + + # Set different states for different sessions + await state_store.update_state(session1_id, "emotion", 70.0) + await state_store.update_state(session2_id, "emotion", 80.0) + await state_store.update_state(session1_id, "affinity", 60.0) + + # Verify isolation + session1_states = await state_store.list_states(session1_id) + session2_states = await state_store.list_states(session2_id) + + assert session1_states == {"emotion": 70.0, "affinity": 60.0} + assert session2_states == {"emotion": 80.0} + + +@pytest.mark.asyncio +async def test_invalid_db_path(): + """Test behavior with invalid database path.""" + # Store with None path should not 
"""Tests for LLM function calling functionality.

Covers the four public names imported from ``dialog_engine.llm_functions``:

* ``handle_tool_call`` -- success path plus each validation/error branch.
* ``format_state_for_context`` / ``create_state_system_message`` -- text
  rendering of the internal state mapping.
* ``FUNCTION_DEFINITIONS`` -- shape of the function schema.
"""

import json
from unittest.mock import AsyncMock, MagicMock

import pytest

from dialog_engine.llm_functions import (
    FUNCTION_DEFINITIONS,
    create_state_system_message,
    format_state_for_context,
    handle_tool_call,
)

# Session id shared by the error-path tests.
SESSION_ID = "test_session"


def _make_tool_call(arguments, name="update_internal_state"):
    """Build a tool-call payload for ``handle_tool_call``.

    Args:
        arguments: Either a ``dict`` (JSON-encoded here, as the tests expect
            the ``arguments`` field to be a JSON string) or a raw ``str`` that
            is passed through untouched, which lets a test inject malformed
            JSON.
        name: Function name placed in the payload.

    Returns:
        A ``dict`` with the keys ``"name"`` and ``"arguments"``.
    """
    if not isinstance(arguments, str):
        arguments = json.dumps(arguments)
    return {"name": name, "arguments": arguments}


@pytest.fixture
def mock_state_store():
    """Provide a fresh ``AsyncMock`` standing in for the state store."""
    return AsyncMock()


@pytest.mark.asyncio
async def test_handle_tool_call_update_internal_state_success(mock_state_store):
    """Test successful internal state update via tool call."""
    # Setup: the store reports the value that was just written.
    mock_state_store.get_state.return_value = 75.5
    tool_call = _make_tool_call({"state_key": "emotion", "value": 75.5})
    session_id = "test_session"

    # Execute
    result = await handle_tool_call(tool_call, session_id, mock_state_store)

    # Verify
    assert result["success"] is True
    assert result["state_key"] == "emotion"
    assert result["new_value"] == 75.5
    assert result["session_id"] == session_id
    mock_state_store.update_state.assert_called_once_with(session_id, "emotion", 75.5)
    mock_state_store.get_state.assert_called_once_with(session_id, "emotion")


@pytest.mark.asyncio
async def test_handle_tool_call_invalid_json(mock_state_store):
    """Test handling of invalid JSON in tool call arguments."""
    # A raw string bypasses JSON encoding in the helper on purpose.
    tool_call = _make_tool_call("invalid json")

    result = await handle_tool_call(tool_call, SESSION_ID, mock_state_store)

    assert result["success"] is False
    assert "Invalid arguments" in result["error"]


@pytest.mark.asyncio
async def test_handle_tool_call_missing_state_key(mock_state_store):
    """Test handling of missing state_key parameter."""
    tool_call = _make_tool_call({"value": 75.5})

    result = await handle_tool_call(tool_call, SESSION_ID, mock_state_store)

    assert result["success"] is False
    assert "Missing required parameter: state_key" in result["error"]


@pytest.mark.asyncio
async def test_handle_tool_call_missing_value(mock_state_store):
    """Test handling of missing value parameter."""
    tool_call = _make_tool_call({"state_key": "emotion"})

    result = await handle_tool_call(tool_call, SESSION_ID, mock_state_store)

    assert result["success"] is False
    assert "Missing required parameter: value" in result["error"]


@pytest.mark.asyncio
async def test_handle_tool_call_invalid_value_type(mock_state_store):
    """Test handling of invalid value type."""
    tool_call = _make_tool_call({"state_key": "emotion", "value": "not_a_number"})

    result = await handle_tool_call(tool_call, SESSION_ID, mock_state_store)

    assert result["success"] is False
    assert "must be a number" in result["error"]


@pytest.mark.asyncio
async def test_handle_tool_call_unknown_function(mock_state_store):
    """Test handling of unknown function name."""
    tool_call = _make_tool_call(
        {"state_key": "emotion", "value": 75.5},
        name="unknown_function",
    )

    result = await handle_tool_call(tool_call, SESSION_ID, mock_state_store)

    assert result["success"] is False
    assert "Unknown function: unknown_function" in result["error"]


@pytest.mark.asyncio
async def test_handle_tool_call_store_exception(mock_state_store):
    """Test handling of exceptions from state store."""
    # Make the write fail so the error-reporting branch is exercised.
    mock_state_store.update_state.side_effect = Exception("Database error")
    tool_call = _make_tool_call({"state_key": "emotion", "value": 75.5})

    result = await handle_tool_call(tool_call, SESSION_ID, mock_state_store)

    assert result["success"] is False
    assert "Internal error" in result["error"]


def test_format_state_for_context_empty():
    """Test formatting empty states."""
    result = format_state_for_context({})
    assert result == "暂无内部状态数据"


def test_format_state_for_context_basic():
    """Test formatting basic states."""
    states = {
        "emotion": 75.5,
        "affinity": 60.0,
    }
    result = format_state_for_context(states)
    assert "情绪: 75.5" in result
    assert "好感度: 60.0" in result
    assert "当前内部状态:" in result


def test_format_state_for_context_unknown_keys():
    """Test formatting states with unknown keys."""
    states = {
        "unknown_state": 50.0,
        "another_unknown": 80.0,
    }
    result = format_state_for_context(states)
    # Keys without a known label are expected to be rendered verbatim.
    assert "unknown_state: 50.0" in result
    assert "another_unknown: 80.0" in result


def test_create_state_system_message():
    """Test creating system message with states."""
    states = {
        "emotion": 75.5,
        "affinity": 60.0,
    }

    result = create_state_system_message(states)

    assert result["role"] == "system"
    assert "情绪: 75.5" in result["content"]
    assert "好感度: 60.0" in result["content"]
    assert "语气和表达方式" in result["content"]


def test_create_state_system_message_empty():
    """Test creating system message with empty states."""
    result = create_state_system_message({})

    assert result["role"] == "system"
    assert "暂无内部状态数据" in result["content"]


def test_function_definitions_structure():
    """Test that function definitions have the correct structure."""
    assert len(FUNCTION_DEFINITIONS) == 1

    func_def = FUNCTION_DEFINITIONS[0]
    assert func_def["name"] == "update_internal_state"
    assert "description" in func_def
    assert "parameters" in func_def

    params = func_def["parameters"]
    assert params["type"] == "object"
    assert "properties" in params
    assert "required" in params

    properties = params["properties"]
    assert "state_key" in properties
    assert "value" in properties
    assert properties["state_key"]["type"] == "string"
    assert properties["value"]["type"] == "number"

    required = params["required"]
    assert "state_key" in required
    assert "value" in required


if __name__ == "__main__":
    # Propagate pytest's exit status so a failing run is visible to the shell
    # or CI when this file is executed directly.
    raise SystemExit(pytest.main([__file__]))