diff --git a/backend-architecture.md b/backend-architecture.md
new file mode 100644
index 0000000..2337000
--- /dev/null
+++ b/backend-architecture.md
@@ -0,0 +1,290 @@
+# Free-Agent-Vtuber 后端架构图
+
+## 整体架构概览
+
+```mermaid
+graph TB
+ %% 外部接口层
+ subgraph "外部接口"
+ WebClient[Web 客户端]
+ API[API 调用]
+ end
+
+ %% 网关层
+ subgraph "API 网关层"
+ Gateway[Gateway Service
FastAPI + Flask
:8000]
+ end
+
+ %% 输入输出处理层
+ subgraph "输入输出处理层"
+ InputHandler[Input Handler
WebSocket Server
:8001]
+ OutputHandler[Output Handler
WebSocket Server
:8002]
+ end
+
+ %% 核心服务层
+ subgraph "核心处理服务"
+ ASR[ASR Service
语音识别
OpenAI Whisper]
+ Memory[Memory Service
短期记忆管理
对话上下文]
+ ChatAI[Chat AI Service
AI 对话生成
OpenAI GPT]
+ TTS[TTS Service
语音合成
Edge-TTS]
+ LTM[Long Term Memory
长期记忆
向量搜索]
+ end
+
+ %% 消息总线
+ subgraph "消息总线"
+ Redis[(Redis
事件驱动消息总线
:6379)]
+ end
+
+ %% 数据存储层
+ subgraph "数据存储"
+ PostgreSQL[(PostgreSQL
pgvector
:5432)]
+ MemoryData[(Memory Data
对话历史)]
+ TempFiles[(Temp Files
音频文件)]
+ end
+
+ %% 管理层
+ subgraph "管理与监控"
+ Manager[Manager Service
Flask Web UI
:5000]
+ end
+
+ %% 连接关系
+ WebClient --> Gateway
+ API --> Gateway
+
+ Gateway --> InputHandler
+ Gateway --> OutputHandler
+ Gateway --> Redis
+
+ InputHandler --> Redis
+ OutputHandler --> Redis
+
+ ASR --> Redis
+ Memory --> Redis
+ ChatAI --> Redis
+ TTS --> Redis
+ LTM --> Redis
+
+ Memory --> MemoryData
+ LTM --> PostgreSQL
+ TTS --> TempFiles
+ ASR --> TempFiles
+
+ Manager --> Redis
+ Manager -.-> ASR
+ Manager -.-> Memory
+ Manager -.-> ChatAI
+ Manager -.-> TTS
+ Manager -.-> LTM
+
+ %% 样式
+ classDef service fill:#e1f5fe,stroke:#0277bd,stroke-width:2px
+ classDef storage fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px
+ classDef message fill:#fff3e0,stroke:#ef6c00,stroke-width:2px
+ classDef external fill:#e8f5e8,stroke:#2e7d32,stroke-width:2px
+
+ class Gateway,InputHandler,OutputHandler,ASR,Memory,ChatAI,TTS,LTM,Manager service
+ class PostgreSQL,MemoryData,TempFiles storage
+ class Redis message
+ class WebClient,API external
+```
+
+## 消息流架构详图
+
+```mermaid
+sequenceDiagram
+ participant Client as 客户端
+ participant GW as Gateway
+ participant IH as Input Handler
+ participant ASR as ASR Service
+ participant MEM as Memory Service
+ participant AI as Chat AI Service
+ participant LTM as Long Term Memory
+ participant TTS as TTS Service
+ participant OH as Output Handler
+ participant Redis as Redis Bus
+
+ Note over Client,Redis: 完整的语音交互流程
+
+ %% 1. 语音输入阶段
+ Client->>GW: POST /api/asr (音频文件)
+ GW->>Redis: LPUSH asr_tasks
+ Redis->>ASR: 消费 asr_tasks
+ ASR-->>Redis: PUBLISH asr_results
+
+ %% 2. 输入处理阶段
+ Redis->>IH: 订阅 asr_results
+ IH->>Redis: LPUSH user_input_queue
+
+ %% 3. 记忆处理阶段
+ Redis->>MEM: 消费 user_input_queue
+ MEM-->>Redis: PUBLISH memory_updates
+
+ %% 4. AI对话阶段
+ Redis->>AI: 订阅 memory_updates
+
+ %% 5. 长期记忆查询(可选)
+ alt 启用长期记忆
+ AI->>Redis: LPUSH ltm_requests
+ Redis->>LTM: 消费 ltm_requests
+ LTM-->>Redis: PUBLISH ltm_responses
+ Redis->>AI: 订阅 ltm_responses
+ end
+
+ %% 6. AI响应生成
+ AI-->>Redis: PUBLISH ai_responses
+ AI->>Redis: LPUSH tts_requests
+
+ %% 7. 语音合成阶段
+ Redis->>TTS: 消费 tts_requests
+ TTS-->>Redis: PUBLISH task_response:task_id
+
+ %% 8. 输出处理阶段
+ Redis->>OH: 订阅 task_response:*
+ OH->>GW: WebSocket 推送
+ GW->>Client: 返回合成语音
+```
+
+## 核心服务详细说明
+
+### 1. Gateway Service (网关服务)
+- **技术栈**: FastAPI + Flask Blueprint
+- **端口**: 8000
+- **职责**:
+ - 提供统一的API入口
+ - 处理HTTP请求和WebSocket连接
+ - 路由ASR请求到消息队列
+ - 管理客户端连接状态
+
+### 2. Input Handler (输入处理服务)
+- **技术栈**: Python + WebSocket
+- **端口**: 8001
+- **职责**:
+ - 订阅ASR识别结果
+ - 标准化用户输入格式
+ - 将处理后的输入推送到用户输入队列
+
+### 3. ASR Service (语音识别服务)
+- **技术栈**: Python + OpenAI Whisper
+- **职责**:
+ - 消费语音识别任务队列
+ - 将音频文件转换为文本
+ - 发布识别结果到消息总线
+
+### 4. Memory Service (记忆管理服务)
+- **技术栈**: Python
+- **职责**:
+ - 管理对话上下文和短期记忆
+ - 维护用户会话状态
+ - 为AI服务提供对话历史
+
+### 5. Chat AI Service (AI对话服务)
+- **技术栈**: Python + OpenAI GPT
+- **职责**:
+ - 生成AI回复
+ - 整合长期记忆内容(如果启用)
+ - 触发语音合成请求
+
+### 6. Long Term Memory Service (长期记忆服务)
+- **技术栈**: Python + pgvector + mem0
+- **职责**:
+ - 存储和检索长期记忆
+ - 向量相似度搜索
+ - 为对话提供相关历史信息
+
+### 7. TTS Service (语音合成服务)
+- **技术栈**: Python + Edge-TTS
+- **职责**:
+ - 将文本转换为语音
+ - 生成音频文件
+ - 发布合成结果
+
+### 8. Output Handler (输出处理服务)
+- **技术栈**: Python + WebSocket
+- **端口**: 8002
+- **职责**:
+ - 处理服务输出
+ - 通过WebSocket推送结果到客户端
+
+## 数据存储架构
+
+### Redis 消息总线
+- **队列 (List)**:
+ - `asr_tasks`: ASR识别任务
+ - `user_input_queue`: 用户输入队列
+ - `tts_requests`: TTS合成请求
+ - `ltm_requests`: 长期记忆查询请求
+
+- **发布/订阅 (Pub/Sub)**:
+ - `asr_results`: ASR识别结果
+ - `memory_updates`: 记忆更新通知
+ - `ai_responses`: AI回复
+ - `ltm_responses`: 长期记忆查询结果
+ - `task_response:{task_id}`: 任务响应
+
+### PostgreSQL + pgvector
+- 存储长期记忆向量数据
+- 支持语义相似度搜索
+- 维护用户历史对话记录
+
+### 文件存储
+- 临时音频文件存储 (`/tmp/aivtuber_tasks`)
+- 内存数据持久化 (`memory_data`)
+- 长期记忆数据 (`ltm_data`)
+
+## 架构特点
+
+1. **事件驱动**: 基于Redis的消息总线实现松耦合
+2. **微服务**: 每个功能模块独立部署和扩展
+3. **异步处理**: 支持并发处理多个用户请求
+4. **可插拔**: 支持不同的AI、TTS、ASR提供商
+5. **容器化**: 所有服务均支持Docker部署
+6. **可观测**: 完整的日志和监控体系
+
+## 部署架构
+
+```mermaid
+graph TB
+ subgraph "Docker Network: aivtuber-network"
+ subgraph "计算服务"
+ GW[gateway:8000]
+ IH[input-handler:8001]
+ OH[output-handler:8002]
+ ASR[asr]
+ MEM[memory]
+ AI[chat-ai]
+ TTS[tts]
+ LTM[long-term-memory]
+ end
+
+ subgraph "基础设施"
+ Redis[redis:6379]
+ PG[postgres:5432]
+ end
+
+ subgraph "存储卷"
+ RedisData[redis_data]
+ PostgresData[postgres_data]
+ MemoryData[memory_data]
+ LTMData[ltm_data]
+ TempFiles[temp_files]
+ end
+ end
+
+ subgraph "外部"
+ Client[客户端]
+ Manager[Manager:5000
可选管理界面]
+ end
+
+ Client --> GW
+ Manager -.-> GW
+
+ Redis --- RedisData
+ PG --- PostgresData
+ MEM --- MemoryData
+ LTM --- LTMData
+ TTS --- TempFiles
+ ASR --- TempFiles
+ OH --- TempFiles
+```
+
+这个架构设计实现了高度模块化、可扩展的AI虚拟主播系统,通过事件驱动的方式确保了系统的松耦合和高可用性。
\ No newline at end of file
diff --git a/services/dialog-engine/src/dialog_engine/app.py b/services/dialog-engine/src/dialog_engine/app.py
index a44cd92..69462d5 100644
--- a/services/dialog-engine/src/dialog_engine/app.py
+++ b/services/dialog-engine/src/dialog_engine/app.py
@@ -17,11 +17,22 @@
from .asr import AsrOptions, AsrService
from .tts_streamer import stream_text as tts_stream_text
from .ltm_outbox import add_event as outbox_add_event, start_flush_task as outbox_start_flush
+from .internal_state_store import InternalStateStore
app = FastAPI()
-chat_service = ChatService()
logger = logging.getLogger(__name__)
+
+# Initialize internal state store
+try:
+ import os
+ db_path = os.getenv("INTERNAL_STATE_DB_PATH", "internal_states.db")
+ state_store = InternalStateStore(db_path=db_path)
+except Exception as exc:
+ logger.exception("Failed to initialize InternalStateStore", extra={"error": repr(exc)})
+ state_store = None
+
+chat_service = ChatService(state_store=state_store)
SYNC_TTS_STREAMING = os.getenv("SYNC_TTS_STREAMING", "false").lower() in {"1", "true", "yes", "on"}
ENABLE_ASYNC_EXT = os.getenv("ENABLE_ASYNC_EXT", "false").lower() in {"1", "true", "yes", "on"}
VISION_MAX_BYTES = int(os.getenv("VISION_MAX_BYTES", 4 * 1024 * 1024))
@@ -192,6 +203,12 @@ async def event_generator() -> AsyncGenerator[bytes, None]:
return
stats = {"ttft_ms": round(ttft_ms or 0.0, 1), "tokens": chat_service.last_token_count}
+
+ # Include internal states in the done event
+ internal_states = await chat_service.get_internal_states(session_id)
+ if internal_states:
+ stats["internal_states"] = internal_states
+
yield _sse_format("done", {"stats": stats})
# Emit async events via outbox
@@ -418,6 +435,11 @@ async def event_generator() -> AsyncGenerator[bytes, None]:
"stats": stats,
}
+ # Include internal states in the done event
+ internal_states = await chat_service.get_internal_states(session_id)
+ if internal_states:
+ stats["internal_states"] = internal_states
+
yield _sse_format("done", done_payload)
_emit_async_events(
diff --git a/services/dialog-engine/src/dialog_engine/chat_service.py b/services/dialog-engine/src/dialog_engine/chat_service.py
index 75746eb..5dfb4cf 100644
--- a/services/dialog-engine/src/dialog_engine/chat_service.py
+++ b/services/dialog-engine/src/dialog_engine/chat_service.py
@@ -10,6 +10,8 @@
from .ltm_client import LTMInlineClient
from .memory_store import MemoryTurn, ShortTermMemoryStore
from .settings import Settings, settings as runtime_settings
+from .internal_state_store import InternalStateStore
+from .llm_functions import FUNCTION_DEFINITIONS, handle_tool_call
class ChatService:
@@ -22,12 +24,14 @@ def __init__(
llm_client_factory: Optional[Callable[[], OpenAIChatClient]] = None,
memory_store: Optional[ShortTermMemoryStore] = None,
ltm_client: Optional[LTMInlineClient] = None,
+ state_store: Optional[InternalStateStore] = None,
) -> None:
self._settings = settings or runtime_settings
self._llm_client_factory = llm_client_factory
self._llm_client: Optional[OpenAIChatClient] = None
self._memory_store = memory_store
self._ltm_client = ltm_client
+ self._state_store = state_store
self.last_token_count: int = 0
self.last_ttft_ms: Optional[float] = None
@@ -71,6 +75,9 @@ async def stream_reply(
return
except LLMStreamEmptyError as exc:
self.last_error = "llm_empty_stream"
+ # Process tool calls if present
+ if exc.tool_calls and self._state_store:
+ await self._process_tool_calls(exc.tool_calls, session_id)
self._log_llm_fallback(reason=f"empty_stream:{exc.tool_calls}")
except LLMNotConfiguredError as exc:
self.last_error = "llm_not_configured"
@@ -163,15 +170,23 @@ async def _stream_llm(
ltm_snippets: List[str],
) -> AsyncGenerator[str, None]:
client = await self._ensure_llm_client()
- messages = self._compose_messages(
+ meta_with_session = dict(meta)
+ meta_with_session["session_id"] = session_id
+ messages = await self._compose_messages(
user_text=user_text,
- meta=meta,
+ meta=meta_with_session,
context=context,
ltm_snippets=ltm_snippets,
)
extra_options: Dict[str, Any] = {
"extra_headers": {"x-session-id": session_id},
}
+
+ # Add function calling support if state store is available
+ if self._state_store:
+ extra_options["functions"] = FUNCTION_DEFINITIONS
+ extra_options["tool_choice"] = "auto"
+
async for delta in client.stream_chat(messages, extra_options=extra_options):
yield delta
@@ -198,9 +213,11 @@ async def _generate_vision_reply(
ltm_snippets: List[str],
) -> str:
client = await self._ensure_llm_client()
- messages = self._compose_messages(
+ meta_with_session = dict(meta)
+ meta_with_session["session_id"] = session_id
+ messages = await self._compose_messages(
user_text=prompt_text,
- meta=meta,
+ meta=meta_with_session,
context=context,
ltm_snippets=ltm_snippets,
)
@@ -261,7 +278,7 @@ async def _ensure_llm_client(self) -> OpenAIChatClient:
self._llm_client = client
return client
- def _compose_messages(
+ async def _compose_messages(
self,
*,
user_text: str,
@@ -274,6 +291,18 @@ def _compose_messages(
if system_prompt:
messages.append({"role": "system", "content": str(system_prompt)})
+ # Inject internal states as context if available
+ if self._state_store:
+ session_id = meta.get("session_id", "default")
+ state_dict = await self.get_internal_states(session_id)
+ if state_dict:
+ mood_summary = "; ".join([f"{k}:{v:.2f}" for k, v in state_dict.items()])
+ state_message = {
+ "role": "system",
+ "content": f"当前内部状态:{mood_summary}。请据此调整语气与行为。"
+ }
+ messages.append(state_message)
+
for turn in context:
role = turn.role if turn.role in {"user", "assistant", "system"} else "assistant"
messages.append({"role": role, "content": turn.content})
@@ -292,6 +321,30 @@ def _reset_metrics(self) -> None:
self.last_source = "mock"
self.last_error = None
+ async def _process_tool_calls(self, tool_calls: List[Any], session_id: str) -> None:
+ """Process tool calls from LLM to update internal states."""
+ if not self._state_store:
+ return
+
+ for tool_call in tool_calls:
+ try:
+ call_info = {
+ "name": getattr(tool_call, "function", {}).get("name"),
+ "arguments": getattr(tool_call, "function", {}).get("arguments", "{}")
+ }
+ await handle_tool_call(call_info, session_id, self._state_store)
+ except Exception as exc:
+ self._log_context_warning("tool_call.error", exc)
+
+ async def get_internal_states(self, session_id: str) -> Dict[str, float]:
+ """Get current internal states for a session."""
+ if not self._state_store:
+ return {}
+ try:
+ return await self._state_store.list_states(session_id)
+ except Exception:
+ return {}
+
def _log_llm_fallback(self, *, reason: str) -> None:
# Deliberately late import to avoid global logging setup requirements.
from logging import getLogger
diff --git a/services/dialog-engine/src/dialog_engine/internal_state_store.py b/services/dialog-engine/src/dialog_engine/internal_state_store.py
new file mode 100644
index 0000000..ea0c07b
--- /dev/null
+++ b/services/dialog-engine/src/dialog_engine/internal_state_store.py
@@ -0,0 +1,226 @@
+from __future__ import annotations
+
+"""Internal state storage for AI emotions and affinity tracking."""
+
+import asyncio
+import os
+import sqlite3
+import time
+from contextlib import contextmanager
+from typing import Dict, Optional
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class InternalStateStore:
+ """SQLite-backed store for AI internal states (emotion, affinity, etc.)."""
+
+ def __init__(self, *, db_path: str) -> None:
+ self._db_path = db_path
+ self._ensure_table_exists()
+
+ @contextmanager
+ def _get_connection(self, row_factory: bool = False):
+ """Context manager for database connections to ensure proper cleanup."""
+ if not self._db_path:
+ raise RuntimeError("Database path not configured")
+
+ try:
+ conn = sqlite3.connect(self._db_path)
+ if row_factory:
+ conn.row_factory = sqlite3.Row
+ yield conn
+ except Exception as exc:
+ logger.debug("internal_states.connect.error", exc_info=True)
+ raise RuntimeError("failed to open internal_states database") from exc
+ finally:
+ if 'conn' in locals():
+ conn.close()
+
+ def _ensure_table_exists(self) -> None:
+ """Create the internal_states table if it doesn't exist."""
+ if not self._db_path:
+ return
+
+ os.makedirs(os.path.dirname(self._db_path) or ".", exist_ok=True)
+
+ with self._get_connection() as conn:
+ try:
+ cur = conn.cursor()
+ cur.execute(
+ """
+ CREATE TABLE IF NOT EXISTS internal_states (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ session_id TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ state_value REAL NOT NULL,
+ updated_at INTEGER NOT NULL,
+ UNIQUE(session_id, state_key)
+ )
+ """,
+ )
+ # Create index for efficient queries
+ cur.execute(
+ """
+ CREATE INDEX IF NOT EXISTS idx_internal_states_session_key
+ ON internal_states(session_id, state_key)
+ """,
+ )
+ conn.commit()
+ except Exception as exc:
+ logger.error("Failed to create internal_states table", exc_info=True)
+ raise RuntimeError("failed to create internal_states table") from exc
+
+ async def get_state(self, session_id: str, state_key: str) -> Optional[float]:
+ """Get a specific state value for a session."""
+ if not self._db_path or not os.path.exists(self._db_path):
+ return None
+
+ def _query() -> Optional[float]:
+ try:
+ with self._get_connection(row_factory=True) as conn:
+ row = conn.execute(
+ """
+ SELECT state_value
+ FROM internal_states
+ WHERE session_id = ? AND state_key = ?
+ """,
+ (session_id, state_key),
+ ).fetchone()
+ return float(row["state_value"]) if row else None
+ except Exception as exc:
+ logger.debug("internal_states.query.error", exc_info=True)
+ raise RuntimeError("failed to query internal_states database") from exc
+
+ try:
+ return await asyncio.to_thread(_query)
+ except (FileNotFoundError, RuntimeError):
+ return None
+
+ async def update_state(self, session_id: str, state_key: str, new_value: float) -> None:
+ """Update or insert a state value for a session."""
+ if not self._db_path:
+ return
+
+ def _upsert() -> None:
+ os.makedirs(os.path.dirname(self._db_path) or ".", exist_ok=True)
+ try:
+ with self._get_connection() as conn:
+ cur = conn.cursor()
+ # Ensure table exists (defensive)
+ cur.execute(
+ """
+ CREATE TABLE IF NOT EXISTS internal_states (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ session_id TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ state_value REAL NOT NULL,
+ updated_at INTEGER NOT NULL,
+ UNIQUE(session_id, state_key)
+ )
+ """,
+ )
+
+ # Insert or replace the state value
+ cur.execute(
+ """
+ INSERT OR REPLACE INTO internal_states(session_id, state_key, state_value, updated_at)
+ VALUES(?, ?, ?, ?)
+ """,
+ (session_id, state_key, float(new_value), int(time.time())),
+ )
+ conn.commit()
+ except Exception as exc:
+ logger.error("Failed to update internal state", exc_info=True)
+ raise RuntimeError("failed to update internal state") from exc
+
+ try:
+ await asyncio.to_thread(_upsert)
+ except RuntimeError:
+ logger.error("Failed to update internal state", exc_info=True)
+
+ async def list_states(self, session_id: str) -> Dict[str, float]:
+ """Get all state values for a session."""
+ if not self._db_path or not os.path.exists(self._db_path):
+ return {}
+
+ def _query() -> Dict[str, float]:
+ try:
+ with self._get_connection(row_factory=True) as conn:
+ rows = conn.execute(
+ """
+ SELECT state_key, state_value
+ FROM internal_states
+ WHERE session_id = ?
+ """,
+ (session_id,),
+ ).fetchall()
+
+ return {row["state_key"]: float(row["state_value"]) for row in rows}
+ except Exception as exc:
+ logger.debug("internal_states.query.error", exc_info=True)
+ raise RuntimeError("failed to query internal_states database") from exc
+
+ try:
+ return await asyncio.to_thread(_query)
+ except (FileNotFoundError, RuntimeError):
+ return {}
+
+ async def delete_state(self, session_id: str, state_key: str) -> bool:
+ """Delete a specific state for a session. Returns True if deleted."""
+ if not self._db_path:
+ return False
+
+ def _delete() -> bool:
+ try:
+ with self._get_connection() as conn:
+ cur = conn.cursor()
+ cur.execute(
+ """
+ DELETE FROM internal_states
+ WHERE session_id = ? AND state_key = ?
+ """,
+ (session_id, state_key),
+ )
+ conn.commit()
+ return cur.rowcount > 0
+ except Exception as exc:
+ logger.error("Failed to delete internal state", exc_info=True)
+ raise RuntimeError("failed to delete internal state") from exc
+
+ try:
+ return await asyncio.to_thread(_delete)
+ except RuntimeError:
+ return False
+
+ async def clear_session(self, session_id: str) -> int:
+ """Clear all states for a session. Returns number of deleted records."""
+ if not self._db_path:
+ return 0
+
+ def _clear() -> int:
+ try:
+ with self._get_connection() as conn:
+ cur = conn.cursor()
+ cur.execute(
+ """
+ DELETE FROM internal_states
+ WHERE session_id = ?
+ """,
+ (session_id,),
+ )
+ conn.commit()
+ return cur.rowcount
+ except Exception as exc:
+ logger.error("Failed to clear session states", exc_info=True)
+ raise RuntimeError("failed to clear session states") from exc
+
+ try:
+ return await asyncio.to_thread(_clear)
+ except RuntimeError:
+ return 0
+
+
+__all__ = ["InternalStateStore"]
\ No newline at end of file
diff --git a/services/dialog-engine/src/dialog_engine/llm_functions.py b/services/dialog-engine/src/dialog_engine/llm_functions.py
new file mode 100644
index 0000000..ca7b1b0
--- /dev/null
+++ b/services/dialog-engine/src/dialog_engine/llm_functions.py
@@ -0,0 +1,182 @@
+"""Function calling definitions and handlers for AI internal state management."""
+
+import json
+import logging
+from typing import Dict, Any
+
+logger = logging.getLogger(__name__)
+
+# Function definitions for OpenAI function calling
+FUNCTION_DEFINITIONS = [
+ {
+ "name": "update_internal_state",
+ "description": "更新 AI 的内部状态,如情绪值或好感度。可以用来表达当前的情感状态或对用户的态度变化。",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "state_key": {
+ "type": "string",
+ "description": "状态名,如 'emotion'(情绪)、'affinity'(好感度)、'energy'(能量)等"
+ },
+ "value": {
+ "type": "number",
+ "description": "新的数值,通常范围在 0-100 之间。数值越高表示正面情绪或好感度越高"
+ }
+ },
+ "required": ["state_key", "value"]
+ }
+ }
+]
+
+async def handle_tool_call(tool_call: Dict[str, Any], session_id: str, state_store) -> Dict[str, Any]:
+ """
+ Handle a tool call from the LLM.
+
+ Args:
+ tool_call: The tool call object from the LLM
+ session_id: The current session ID
+ state_store: The InternalStateStore instance
+
+ Returns:
+ Dict containing the result of the tool call
+ """
+ try:
+ function_name = tool_call.get("name")
+ arguments_str = tool_call.get("arguments", "{}")
+
+ if function_name == "update_internal_state":
+ # Parse arguments
+ try:
+ args = json.loads(arguments_str)
+ except json.JSONDecodeError as e:
+ logger.error(f"Failed to parse tool call arguments: {arguments_str}", exc_info=True)
+ return {
+ "success": False,
+ "error": f"Invalid arguments: {str(e)}"
+ }
+
+ # Validate required parameters
+ state_key = args.get("state_key")
+ value = args.get("value")
+
+ if state_key is None:
+ return {
+ "success": False,
+ "error": "Missing required parameter: state_key"
+ }
+
+ if value is None:
+ return {
+ "success": False,
+ "error": "Missing required parameter: value"
+ }
+
+ # Convert value to float
+ try:
+ value = float(value)
+ except (ValueError, TypeError):
+ return {
+ "success": False,
+ "error": f"Invalid value type: {value} must be a number"
+ }
+
+ # Validate value range (optional but recommended)
+ if not isinstance(value, (int, float)):
+ return {
+ "success": False,
+ "error": f"Invalid value: {value} must be numeric"
+ }
+
+ # Add boundary check for state values (0-100 range)
+ if not 0 <= value <= 100:
+ return {
+ "success": False,
+ "error": f"Value {value} out of valid range (0-100). State values should be between 0 and 100."
+ }
+
+ # Update the state
+ await state_store.update_state(session_id, state_key, value)
+
+ # Get updated value for confirmation
+ updated_value = await state_store.get_state(session_id, state_key)
+
+ logger.info(f"Updated internal state: session={session_id}, key={state_key}, value={value}")
+
+ return {
+ "success": True,
+ "message": f"Successfully updated {state_key} to {value}",
+ "session_id": session_id,
+ "state_key": state_key,
+ "old_value": None, # We could track this if needed
+ "new_value": updated_value
+ }
+
+ else:
+ return {
+ "success": False,
+ "error": f"Unknown function: {function_name}"
+ }
+
+ except Exception as exc:
+ logger.error(f"Error handling tool call: {exc}", exc_info=True)
+ return {
+ "success": False,
+ "error": f"Internal error: {str(exc)}"
+ }
+
+
+def format_state_for_context(states: Dict[str, float]) -> str:
+ """
+ Format internal states for inclusion in LLM context.
+
+ Args:
+ states: Dictionary of state key -> value
+
+ Returns:
+ Formatted string describing current states
+ """
+ if not states:
+ return "暂无内部状态数据"
+
+ # Define some common state descriptions
+ state_descriptions = {
+ "emotion": "情绪",
+ "affinity": "好感度",
+ "energy": "能量",
+ "mood": "心情",
+ "trust": "信任度",
+ "engagement": "参与度"
+ }
+
+ formatted_states = []
+ for key, value in states.items():
+ description = state_descriptions.get(key, key)
+ formatted_states.append(f"{description}: {value:.1f}")
+
+ return "当前内部状态:" + ",".join(formatted_states)
+
+
+def create_state_system_message(states: Dict[str, float]) -> Dict[str, str]:
+ """
+ Create a system message that includes current internal states.
+
+ Args:
+ states: Dictionary of state key -> value
+
+ Returns:
+ System message dict for inclusion in conversation
+ """
+ state_summary = format_state_for_context(states)
+
+ return {
+ "role": "system",
+ "content": f"{state_summary}。请根据这些内部状态调整你的语气和表达方式。例如:情绪值高时可以更热情友好,情绪值低时可以更冷静克制。好感度高时可以更亲近自然,好感度低时可以保持适当距离。"
+ }
+
+
+__all__ = [
+ "FUNCTION_DEFINITIONS",
+ "handle_tool_call",
+ "format_state_for_context",
+ "create_state_system_message"
+]
\ No newline at end of file
diff --git a/services/dialog-engine/tests/unit/test_internal_state_store.py b/services/dialog-engine/tests/unit/test_internal_state_store.py
new file mode 100644
index 0000000..5d7c149
--- /dev/null
+++ b/services/dialog-engine/tests/unit/test_internal_state_store.py
@@ -0,0 +1,193 @@
+"""Tests for InternalStateStore."""
+
+import os
+import tempfile
+import pytest
+import asyncio
+from dialog_engine.internal_state_store import InternalStateStore
+
+
+@pytest.fixture
+def temp_db_path():
+ """Create a temporary database file path."""
+ with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
+ db_path = f.name
+ yield db_path
+ # Cleanup
+ if os.path.exists(db_path):
+ os.unlink(db_path)
+
+
+@pytest.fixture
+def state_store(temp_db_path):
+ """Create an InternalStateStore instance with temporary database."""
+ store = InternalStateStore(db_path=temp_db_path)
+ yield store
+ # No explicit cleanup needed as temp file will be deleted
+
+
+@pytest.mark.asyncio
+async def test_update_and_get_state(state_store):
+ """Test updating and retrieving a single state."""
+ session_id = "test_session"
+ state_key = "emotion"
+ state_value = 75.5
+
+ # Update state
+ await state_store.update_state(session_id, state_key, state_value)
+
+ # Get state
+ retrieved_value = await state_store.get_state(session_id, state_key)
+ assert retrieved_value == state_value
+
+
+@pytest.mark.asyncio
+async def test_get_nonexistent_state(state_store):
+ """Test getting a state that doesn't exist."""
+ session_id = "nonexistent_session"
+ state_key = "nonexistent_state"
+
+ retrieved_value = await state_store.get_state(session_id, state_key)
+ assert retrieved_value is None
+
+
+@pytest.mark.asyncio
+async def test_list_states(state_store):
+ """Test listing all states for a session."""
+ session_id = "test_session"
+ states = {
+ "emotion": 75.5,
+ "affinity": 60.0,
+ "energy": 80.2
+ }
+
+ # Update multiple states
+ for key, value in states.items():
+ await state_store.update_state(session_id, key, value)
+
+ # List all states
+ retrieved_states = await state_store.list_states(session_id)
+ assert retrieved_states == states
+
+
+@pytest.mark.asyncio
+async def test_list_states_empty(state_store):
+ """Test listing states for a session with no states."""
+ session_id = "empty_session"
+
+ retrieved_states = await state_store.list_states(session_id)
+ assert retrieved_states == {}
+
+
+@pytest.mark.asyncio
+async def test_update_existing_state(state_store):
+ """Test updating an existing state value."""
+ session_id = "test_session"
+ state_key = "emotion"
+ initial_value = 50.0
+ updated_value = 80.0
+
+ # Set initial value
+ await state_store.update_state(session_id, state_key, initial_value)
+ assert await state_store.get_state(session_id, state_key) == initial_value
+
+ # Update to new value
+ await state_store.update_state(session_id, state_key, updated_value)
+ assert await state_store.get_state(session_id, state_key) == updated_value
+
+
+@pytest.mark.asyncio
+async def test_delete_state(state_store):
+ """Test deleting a specific state."""
+ session_id = "test_session"
+ state_key = "emotion"
+ state_value = 75.5
+
+ # Set initial state
+ await state_store.update_state(session_id, state_key, state_value)
+ assert await state_store.get_state(session_id, state_key) == state_value
+
+ # Delete state
+ deleted = await state_store.delete_state(session_id, state_key)
+ assert deleted is True
+ assert await state_store.get_state(session_id, state_key) is None
+
+
+@pytest.mark.asyncio
+async def test_delete_nonexistent_state(state_store):
+ """Test deleting a state that doesn't exist."""
+ session_id = "test_session"
+ state_key = "nonexistent_state"
+
+ # Try to delete nonexistent state
+ deleted = await state_store.delete_state(session_id, state_key)
+ assert deleted is False
+
+
+@pytest.mark.asyncio
+async def test_clear_session(state_store):
+ """Test clearing all states for a session."""
+ session_id = "test_session"
+ states = {
+ "emotion": 75.5,
+ "affinity": 60.0,
+ "energy": 80.2
+ }
+
+ # Set up states
+ for key, value in states.items():
+ await state_store.update_state(session_id, key, value)
+ assert await state_store.list_states(session_id) == states
+
+ # Clear session
+ deleted_count = await state_store.clear_session(session_id)
+ assert deleted_count == len(states)
+ assert await state_store.list_states(session_id) == {}
+
+
+@pytest.mark.asyncio
+async def test_clear_empty_session(state_store):
+ """Test clearing a session with no states."""
+ session_id = "empty_session"
+
+ deleted_count = await state_store.clear_session(session_id)
+ assert deleted_count == 0
+
+
+@pytest.mark.asyncio
+async def test_multiple_sessions_isolation(state_store):
+ """Test that different sessions don't interfere with each other."""
+ session1_id = "session1"
+ session2_id = "session2"
+
+ # Set different states for different sessions
+ await state_store.update_state(session1_id, "emotion", 70.0)
+ await state_store.update_state(session2_id, "emotion", 80.0)
+ await state_store.update_state(session1_id, "affinity", 60.0)
+
+ # Verify isolation
+ session1_states = await state_store.list_states(session1_id)
+ session2_states = await state_store.list_states(session2_id)
+
+ assert session1_states == {"emotion": 70.0, "affinity": 60.0}
+ assert session2_states == {"emotion": 80.0}
+
+
+@pytest.mark.asyncio
+async def test_invalid_db_path():
+ """Test behavior with invalid database path."""
+ # Store with None path should not crash
+ store = InternalStateStore(db_path=None)
+
+ # Operations should safely return defaults
+ assert await store.get_state("session", "key") is None
+ assert await store.list_states("session") == {}
+ assert await store.delete_state("session", "key") is False
+ assert await store.clear_session("session") == 0
+
+ # Update should not crash
+ await store.update_state("session", "key", 50.0) # Should not raise exception
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
\ No newline at end of file
diff --git a/services/dialog-engine/tests/unit/test_llm_functions.py b/services/dialog-engine/tests/unit/test_llm_functions.py
new file mode 100644
index 0000000..f004067
--- /dev/null
+++ b/services/dialog-engine/tests/unit/test_llm_functions.py
@@ -0,0 +1,231 @@
+"""Tests for LLM function calling functionality."""
+
+import pytest
+import json
+from unittest.mock import AsyncMock, MagicMock
+from dialog_engine.llm_functions import (
+ FUNCTION_DEFINITIONS,
+ handle_tool_call,
+ format_state_for_context,
+ create_state_system_message
+)
+
+
+@pytest.mark.asyncio
+async def test_handle_tool_call_update_internal_state_success():
+ """Test successful internal state update via tool call."""
+ # Setup
+ mock_state_store = AsyncMock()
+ mock_state_store.get_state.return_value = 75.5
+
+ tool_call = {
+ "name": "update_internal_state",
+ "arguments": json.dumps({
+ "state_key": "emotion",
+ "value": 75.5
+ })
+ }
+
+ session_id = "test_session"
+
+ # Execute
+ result = await handle_tool_call(tool_call, session_id, mock_state_store)
+
+ # Verify
+ assert result["success"] is True
+ assert result["state_key"] == "emotion"
+ assert result["new_value"] == 75.5
+ assert result["session_id"] == session_id
+ mock_state_store.update_state.assert_called_once_with(session_id, "emotion", 75.5)
+ mock_state_store.get_state.assert_called_once_with(session_id, "emotion")
+
+
+@pytest.mark.asyncio
+async def test_handle_tool_call_invalid_json():
+ """Test handling of invalid JSON in tool call arguments."""
+ mock_state_store = AsyncMock()
+
+ tool_call = {
+ "name": "update_internal_state",
+ "arguments": "invalid json"
+ }
+
+ result = await handle_tool_call(tool_call, "test_session", mock_state_store)
+
+ assert result["success"] is False
+ assert "Invalid arguments" in result["error"]
+
+
+@pytest.mark.asyncio
+async def test_handle_tool_call_missing_state_key():
+ """Test handling of missing state_key parameter."""
+ mock_state_store = AsyncMock()
+
+ tool_call = {
+ "name": "update_internal_state",
+ "arguments": json.dumps({
+ "value": 75.5
+ })
+ }
+
+ result = await handle_tool_call(tool_call, "test_session", mock_state_store)
+
+ assert result["success"] is False
+ assert "Missing required parameter: state_key" in result["error"]
+
+
+@pytest.mark.asyncio
+async def test_handle_tool_call_missing_value():
+ """Test handling of missing value parameter."""
+ mock_state_store = AsyncMock()
+
+ tool_call = {
+ "name": "update_internal_state",
+ "arguments": json.dumps({
+ "state_key": "emotion"
+ })
+ }
+
+ result = await handle_tool_call(tool_call, "test_session", mock_state_store)
+
+ assert result["success"] is False
+ assert "Missing required parameter: value" in result["error"]
+
+
+@pytest.mark.asyncio
+async def test_handle_tool_call_invalid_value_type():
+ """Test handling of invalid value type."""
+ mock_state_store = AsyncMock()
+
+ tool_call = {
+ "name": "update_internal_state",
+ "arguments": json.dumps({
+ "state_key": "emotion",
+ "value": "not_a_number"
+ })
+ }
+
+ result = await handle_tool_call(tool_call, "test_session", mock_state_store)
+
+ assert result["success"] is False
+ assert "must be a number" in result["error"]
+
+
+@pytest.mark.asyncio
+async def test_handle_tool_call_unknown_function():
+ """Test handling of unknown function name."""
+ mock_state_store = AsyncMock()
+
+ tool_call = {
+ "name": "unknown_function",
+ "arguments": json.dumps({
+ "state_key": "emotion",
+ "value": 75.5
+ })
+ }
+
+ result = await handle_tool_call(tool_call, "test_session", mock_state_store)
+
+ assert result["success"] is False
+ assert "Unknown function: unknown_function" in result["error"]
+
+
+@pytest.mark.asyncio
+async def test_handle_tool_call_store_exception():
+ """Test handling of exceptions from state store."""
+ mock_state_store = AsyncMock()
+ mock_state_store.update_state.side_effect = Exception("Database error")
+
+ tool_call = {
+ "name": "update_internal_state",
+ "arguments": json.dumps({
+ "state_key": "emotion",
+ "value": 75.5
+ })
+ }
+
+ result = await handle_tool_call(tool_call, "test_session", mock_state_store)
+
+ assert result["success"] is False
+ assert "Internal error" in result["error"]
+
+
+def test_format_state_for_context_empty():
+ """Test formatting empty states."""
+ result = format_state_for_context({})
+ assert result == "暂无内部状态数据"
+
+
+def test_format_state_for_context_basic():
+ """Test formatting basic states."""
+ states = {
+ "emotion": 75.5,
+ "affinity": 60.0
+ }
+ result = format_state_for_context(states)
+ assert "情绪: 75.5" in result
+ assert "好感度: 60.0" in result
+ assert "当前内部状态:" in result
+
+
+def test_format_state_for_context_unknown_keys():
+ """Test formatting states with unknown keys."""
+ states = {
+ "unknown_state": 50.0,
+ "another_unknown": 80.0
+ }
+ result = format_state_for_context(states)
+ assert "unknown_state: 50.0" in result
+ assert "another_unknown: 80.0" in result
+
+
+def test_create_state_system_message():
+ """Test creating system message with states."""
+ states = {
+ "emotion": 75.5,
+ "affinity": 60.0
+ }
+
+ result = create_state_system_message(states)
+
+ assert result["role"] == "system"
+ assert "情绪: 75.5" in result["content"]
+ assert "好感度: 60.0" in result["content"]
+ assert "语气和表达方式" in result["content"]
+
+
+def test_create_state_system_message_empty():
+ """Test creating system message with empty states."""
+ result = create_state_system_message({})
+
+ assert result["role"] == "system"
+ assert "暂无内部状态数据" in result["content"]
+
+
+def test_function_definitions_structure():
+ """Test that function definitions have the correct structure."""
+ assert len(FUNCTION_DEFINITIONS) == 1
+
+ func_def = FUNCTION_DEFINITIONS[0]
+ assert func_def["name"] == "update_internal_state"
+ assert "description" in func_def
+ assert "parameters" in func_def
+
+ params = func_def["parameters"]
+ assert params["type"] == "object"
+ assert "properties" in params
+ assert "required" in params
+
+ properties = params["properties"]
+ assert "state_key" in properties
+ assert "value" in properties
+ assert properties["state_key"]["type"] == "string"
+ assert properties["value"]["type"] == "number"
+
+ required = params["required"]
+ assert "state_key" in required
+ assert "value" in required
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
\ No newline at end of file