diff --git a/services/dialog-engine/README.md b/services/dialog-engine/README.md index 5fd21a7..5c7c382 100644 --- a/services/dialog-engine/README.md +++ b/services/dialog-engine/README.md @@ -7,6 +7,7 @@ FastAPI service that powers synchronous chat + audio flows. It now accepts raw a - `POST /chat/stream` – existing text SSE endpoint. - `POST /chat/audio` – accepts base64 audio payloads, runs ASR, returns JSON transcript/reply. - `POST /chat/audio/stream` – SSE stream that emits `asr-partial`, `asr-final`, `text-delta`, and `done` events. +- `POST /chat/vision` – accepts base64-encoded images plus optional prompts for multimodal reasoning. - `POST /tts/mock` – helper for synchronous TTS testing (requires `SYNC_TTS_STREAMING=true`). ### Example (Sync Audio) @@ -45,6 +46,7 @@ Use any SSE client (curl `-N`, Postman, or VS Code REST client) to hit `/chat/au | `ASR_WHISPER_CACHE_DIR` | Optional model cache path | unset | | `SYNC_TTS_STREAMING` | Enable `/tts/mock` audio push | `false` | | `ENABLE_ASYNC_EXT` | Enables outbox + analytics events | `false` | +| `VISION_MAX_BYTES` | Max accepted image payload size in bytes | `4194304` | | `OUTPUT_INGEST_WS_URL` | Output handler WS endpoint | `ws://localhost:8002/ws/ingest/tts` | ## Dependencies diff --git a/services/dialog-engine/src/dialog_engine/app.py b/services/dialog-engine/src/dialog_engine/app.py index 470de4a..003372a 100644 --- a/services/dialog-engine/src/dialog_engine/app.py +++ b/services/dialog-engine/src/dialog_engine/app.py @@ -24,6 +24,7 @@ logger = logging.getLogger(__name__) SYNC_TTS_STREAMING = os.getenv("SYNC_TTS_STREAMING", "false").lower() in {"1", "true", "yes", "on"} ENABLE_ASYNC_EXT = os.getenv("ENABLE_ASYNC_EXT", "false").lower() in {"1", "true", "yes", "on"} +VISION_MAX_BYTES = int(os.getenv("VISION_MAX_BYTES", 4 * 1024 * 1024)) _flush_task = None try: @@ -431,6 +432,86 @@ async def event_generator() -> AsyncGenerator[bytes, None]: return StreamingResponse(event_generator(), media_type="text/event-stream", headers=headers) +@app.post("/chat/vision") +async def chat_vision(request: Request) -> JSONResponse: + try: + body = await request.json() + except Exception: + raise HTTPException(status_code=400, detail="invalid json") + + session_id = str(body.get("sessionId") or "default") + raw_image = body.get("image") + if not isinstance(raw_image, str) or not raw_image.strip(): + raise HTTPException(status_code=400, detail="image required") + + try: + image_bytes = base64.b64decode(raw_image, validate=True) + except (binascii.Error, TypeError): + raise HTTPException(status_code=400, detail="invalid image encoding") + + if not image_bytes: + raise HTTPException(status_code=400, detail="image required") + if len(image_bytes) > VISION_MAX_BYTES: + raise HTTPException(status_code=413, detail="image payload too large") + + prompt_raw = body.get("prompt") + prompt = prompt_raw.strip() if isinstance(prompt_raw, str) else None + mime_type_raw = body.get("mimeType") + mime_type = ( + mime_type_raw.strip() + if isinstance(mime_type_raw, str) and mime_type_raw.strip() + else "image/png" + ) + + meta_raw = body.get("meta") + meta = dict(meta_raw) if isinstance(meta_raw, dict) else {} + meta.setdefault("input_mode", "image") + + image_b64 = base64.b64encode(image_bytes).decode("ascii") + + user_turn = "[图片输入]" + if prompt: + user_turn = f"[图片输入]\n提示: {prompt}" + await chat_service.remember_turn(session_id=session_id, role="user", content=user_turn) + + try: + result = await chat_service.describe_image( + session_id=session_id, + image_b64=image_b64, + prompt=prompt, + mime_type=mime_type, + meta=meta, + ) + except HTTPException: + raise + except Exception as exc: # pragma: no cover - guard downstream failures + logger.exception("chat.vision.failed", extra={"sessionId": session_id}) + raise HTTPException(status_code=502, detail="vision_failed") from exc + + reply_text = str(result.get("reply", "")) + prompt_text = str(result.get("prompt") or (prompt or "")) + stats = result.get("stats") or {} + + await chat_service.remember_turn(session_id=session_id, role="assistant", content=reply_text) + + response_payload = { + "sessionId": session_id, + "prompt": prompt_text, + "reply": reply_text, + "stats": stats, + } + + _emit_async_events( + session_id=session_id, + body=body, + transcript=user_turn, + reply_text=reply_text, + stats=stats, + ) + + return JSONResponse(response_payload) + + @app.post("/tts/mock") async def tts_mock(request: Request, background: BackgroundTasks): """M2: Trigger a mock TTS stream to Output's ingest WS for testing. diff --git a/services/dialog-engine/src/dialog_engine/chat_service.py b/services/dialog-engine/src/dialog_engine/chat_service.py index 1bc3fda..75746eb 100644 --- a/services/dialog-engine/src/dialog_engine/chat_service.py +++ b/services/dialog-engine/src/dialog_engine/chat_service.py @@ -85,6 +85,74 @@ async def stream_reply( ): yield delta + async def describe_image( + self, + session_id: str, + *, + image_b64: str, + prompt: str | None, + mime_type: str | None, + meta: Dict[str, Any] | None = None, + ) -> Dict[str, Any]: + meta = meta or {} + raw_prompt = (prompt or "").strip() + prompt_text = raw_prompt or "请描述这张图片。" + lang = str(meta.get("lang") or "zh") + + self._reset_metrics() + context_turns: List[MemoryTurn] = [] + ltm_snippets: List[str] = [] + + if self._settings.llm.enabled: + context_turns = await self._fetch_short_term_context(session_id=session_id) + ltm_snippets = await self._fetch_ltm_snippets( + session_id=session_id, + user_text=prompt_text, + meta=meta, + ) + self._log_context_info(len(context_turns), len(ltm_snippets)) + try: + reply_text = await self._generate_vision_reply( + session_id=session_id, + prompt_text=prompt_text, + image_b64=image_b64, + mime_type=mime_type or "image/png", + meta=meta, + context=context_turns, + ltm_snippets=ltm_snippets, + ) + self.last_source = "llm" + self.last_error = None + self.last_ttft_ms = None + self.last_token_count = self._estimate_tokens(reply_text) + stats = { + "chat": { + "source": self.last_source, + "tokens": self.last_token_count, + "ttft_ms": self.last_ttft_ms, + } + } + return {"reply": reply_text, "prompt": prompt_text, "stats": stats} + except LLMNotConfiguredError as exc: + self.last_error = "llm_not_configured" + self._log_llm_fallback(reason=str(exc)) + except Exception as exc: # pragma: no cover - defensive catch + self.last_error = exc.__class__.__name__ + self._log_llm_fallback(reason=repr(exc)) + + reply_text = self._craft_image_reply(raw_prompt, lang) + self.last_source = "mock" + self.last_ttft_ms = None + self.last_token_count = self._estimate_tokens(reply_text) + stats = { + "chat": { + "source": self.last_source, + "tokens": self.last_token_count, + "ttft_ms": self.last_ttft_ms, + } + } + return {"reply": reply_text, "prompt": prompt_text, "stats": stats} + async def _stream_llm( self, *, @@ -118,6 +186,44 @@ async def _stream_mock( await asyncio.sleep(0.02 + random.random() * 0.03) yield word + (" " if not word.endswith("\n") else "") + async def _generate_vision_reply( + self, + *, + session_id: str, + prompt_text: str, + image_b64: str, + mime_type: str, + meta: Dict[str, Any], + context: List[MemoryTurn], + ltm_snippets: List[str], + ) -> str: + client = await self._ensure_llm_client() + messages = self._compose_messages( + user_text=prompt_text, + meta=meta, + context=context, + ltm_snippets=ltm_snippets, + ) + content: list[Dict[str, Any]] = [] + if prompt_text: + content.append({"type": "text", "text": prompt_text}) + content.append( + { + "type": "image_url", + "image_url": { + "url": f"data:{mime_type};base64,{image_b64}", + }, + } + ) + if messages: + messages[-1] = {"role": "user", "content": content} + else: # pragma: no cover - defensive path + messages = [{"role": "user", "content": content}] + extra_options: Dict[str, Any] = { + "extra_headers": {"x-session-id": session_id}, + } + return await client.generate_vision_reply(messages, extra_options=extra_options) + async def remember_turn(self, session_id: str, *, role: str, content: str) -> None: if not content or not content.strip(): return @@ -277,3 +383,27 @@ def _craft_reply(self, user_text: str, lang: str) -> str: f"You said: '{user_text.strip()}'. That sounds interesting! I'm here to chat whenever you like. " "Feel free to share what you're up to!" ) + + def _craft_image_reply(self, prompt_text: str, lang: str) -> str: + display_prompt = prompt_text.strip() if prompt_text else "" + if lang.lower().startswith("zh"): + if display_prompt: + return ( + f"这张图片听起来很有意思!虽然我暂时看不到实际画面," + f"但根据你的提示「{display_prompt}」我可以和你一起展开想象。" + "要不要再告诉我一些细节?" + ) + return ( + "这张图片看起来很有意思!虽然我暂时无法直接看到内容," + "但如果你描述更多细节,我会和你一起展开想象。" + ) + if display_prompt: + return ( + "That picture sounds fascinating! I can't see it directly right now, " + f"but with your hint \"{display_prompt}\" we can imagine it together. " + "Feel free to share more details!" + ) + return ( + "That picture sounds fascinating! I can't view it directly, but if you describe a few more details " + "we can imagine it together." + ) diff --git a/services/dialog-engine/src/dialog_engine/llm_client.py b/services/dialog-engine/src/dialog_engine/llm_client.py index c7ef099..57a3c0c 100644 --- a/services/dialog-engine/src/dialog_engine/llm_client.py +++ b/services/dialog-engine/src/dialog_engine/llm_client.py @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) -ChatMessage = Dict[str, str] +ChatMessage = Dict[str, Any] class LLMNotConfiguredError(RuntimeError): @@ -153,5 +153,73 @@ async def stream_chat( raise RuntimeError("LLM streaming failed after retries") from last_error + async def generate_vision_reply( + self, + messages: Sequence[ChatMessage], + *, + model: Optional[str] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_p: Optional[float] = None, + timeout: Optional[float] = None, + extra_options: Optional[Dict[str, Any]] = None, + ) -> str: + """Issue a non-streaming chat completion for multimodal prompts.""" + + cfg = self._llm_cfg + params: Dict[str, Any] = { + "model": model or cfg.model, + "messages": list(messages), + "temperature": temperature if temperature is not None else cfg.temperature, + "max_tokens": max_tokens if max_tokens is not None else cfg.max_tokens, + "top_p": top_p if top_p is not None else cfg.top_p, + "frequency_penalty": cfg.frequency_penalty, + "presence_penalty": cfg.presence_penalty, + "timeout": timeout if timeout is not None else cfg.timeout, + } + if extra_options: + params.update(extra_options) + + attempt = 0 + last_error: Exception | None = None + total_attempts = cfg.retry_limit + 1 + + while attempt < total_attempts: + attempt += 1 + try: + resp = await self._client.chat.completions.create(**params) + logger.info( + "llm.vision.complete", + extra={ + "model": params.get("model"), + "attempt": attempt, + "max_tokens": params.get("max_tokens"), + }, + ) + choices = getattr(resp, "choices", []) + for choice in choices: + message = getattr(choice, "message", None) + if not message: + continue + content = getattr(message, "content", None) + if isinstance(content, str) and content.strip(): + return content + logger.warning( + "llm.vision.empty", + extra={"model": params.get("model"), "attempt": attempt}, + ) + raise RuntimeError("vision_no_content") + except Exception as exc: # pragma: no cover - defensive catch + last_error = exc + logger.warning( + "llm.vision.error", + extra={"attempt": attempt, "model": params.get("model"), "error": repr(exc)}, + ) + if attempt >= total_attempts: + break + await asyncio.sleep(cfg.retry_backoff_seconds * attempt) + + raise RuntimeError("LLM vision completion failed after retries") from last_error + __all__ = ["OpenAIChatClient", "LLMNotConfiguredError", "LLMStreamEmptyError", "ChatMessage"] diff --git a/services/dialog-engine/tests/unit/test_chat_service.py b/services/dialog-engine/tests/unit/test_chat_service.py index af07195..1400be5 100644 --- a/services/dialog-engine/tests/unit/test_chat_service.py +++ b/services/dialog-engine/tests/unit/test_chat_service.py @@ -1,4 +1,5 @@ import asyncio +import base64 from typing import Iterable, List, Optional import pytest @@ -67,9 +68,11 @@ def _make_settings( class _StubLLMClient: - def __init__(self, responses: Iterable[str]) -> None: + def __init__(self, responses: Iterable[str], *, vision_reply: str = "Vision response") -> None: self._responses = list(responses) self.calls: List[list[dict[str, str]]] = [] + self.vision_calls: List[list[dict[str, object]]] = [] + self._vision_reply = vision_reply async def stream_chat(self, messages, **kwargs): self.calls.append(list(messages)) @@ -77,6 +80,11 @@ async def stream_chat(self, messages, **kwargs): await asyncio.sleep(0) yield token + async def generate_vision_reply(self, messages, **kwargs): + self.vision_calls.append(list(messages)) + await asyncio.sleep(0) + return self._vision_reply + class _FailingLLMClient: async def stream_chat(self, messages, **kwargs): @@ -84,6 +92,9 @@ async def stream_chat(self, messages, **kwargs): yield "" # pragma: no cover - ensure object is async generator raise RuntimeError("boom") + async def generate_vision_reply(self, messages, **kwargs): + raise RuntimeError("vision boom") + class _EmptyLLMClient: async def stream_chat(self, messages, **kwargs): @@ -91,6 +102,9 @@ async def stream_chat(self, messages, **kwargs): yield "" # pragma: no cover raise LLMStreamEmptyError("no content", tool_calls=[{"name": "dummy"}]) + async def generate_vision_reply(self, messages, **kwargs): + raise LLMStreamEmptyError("no content") + class _StubMemoryStore: def __init__(self, turns: Iterable[MemoryTurn]) -> None: @@ -247,3 +261,49 @@ async def test_stream_reply_llm_logs_context_counts(caplog): assert records assert records[0].stm_turns == 2 assert records[0].ltm_snippets == 1 + + +@pytest.mark.asyncio +async def test_describe_image_with_llm(): + stub = _StubLLMClient(["ignored"], vision_reply="这是一只可爱的猫咪。") + service = ChatService( + settings=_make_settings(enabled=True), + llm_client_factory=lambda: stub, + ) + + image_b64 = base64.b64encode(b"fake-bytes").decode("ascii") + result = await service.describe_image( + "sess-vision", + image_b64=image_b64, + prompt="请描述这张图片", + mime_type="image/png", + meta={"lang": "zh"}, + ) + + assert result["reply"].startswith("这是一只可爱的猫咪") + assert result["prompt"] == "请描述这张图片" + assert service.last_source == "llm" + assert stub.vision_calls + last_message = stub.vision_calls[0][-1] + assert last_message["role"] == "user" + content_items = last_message["content"] + assert any(item.get("type") == "image_url" for item in content_items) + + +@pytest.mark.asyncio +async def test_describe_image_mock_fallback_when_llm_disabled(): + service = ChatService(settings=_make_settings(enabled=False)) + image_b64 = base64.b64encode(b"fake-bytes").decode("ascii") + + result = await service.describe_image( + "sess-vision-mock", + image_b64=image_b64, + prompt=None, + mime_type="image/jpeg", + meta={"lang": "en"}, + ) + + assert "imagine" in result["reply"].lower() + assert result["prompt"] == "请描述这张图片。" + assert service.last_source == "mock" + assert service.last_token_count > 0 diff --git a/services/input-handler-python/main.py b/services/input-handler-python/main.py index 06b780a..f4444ef 100644 --- a/services/input-handler-python/main.py +++ b/services/input-handler-python/main.py @@ -28,6 +28,7 @@ DIALOG_ENGINE_URL = os.getenv("DIALOG_ENGINE_URL", "http://localhost:8100") TEXT_STREAM_ENDPOINT = "/chat/stream" AUDIO_ENDPOINT = "/chat/audio" +VISION_ENDPOINT = "/chat/vision" HTTP_TIMEOUT = httpx.Timeout(60.0, connect=5.0, read=60.0, write=10.0) # 临时文件存储目录 @@ -115,10 +116,7 @@ async def _handle_upload(self, websocket: WebSocket, task_id: str): if data.get("action") == "data_chunk": # 元数据消息,记录类型信息 - self.metadata[task_id] = { - "type": data["type"], - "chunk_id": data["chunk_id"] - } + self.metadata[task_id] = dict(data) if data["chunk_id"] != expected_chunk_id: await websocket.send_text( f"Chunk ID mismatch: expected {expected_chunk_id}, got {data['chunk_id']}" @@ -173,6 +171,23 @@ async def _process_upload(self, websocket: WebSocket, task_id: str): with open(input_file, "wb") as f: f.write(combined_data) logger.info(f"Saved audio input for task {task_id}, size: {len(combined_data)} bytes") + + elif data_type == "image": + meta = self.metadata.get(task_id, {}) + if isinstance(meta, dict): + mime_type = meta.get("mime_type") or meta.get("content_type") + else: + mime_type = None + file_suffix = self._infer_image_suffix(mime_type) + input_file = task_dir / f"input{file_suffix}" + with open(input_file, "wb") as f: + f.write(combined_data) + logger.info( + "Saved image input for task %s, size: %d bytes, mime: %s", + task_id, + len(combined_data), + mime_type or "unknown", + ) # 发送处理确认 await websocket.send_text(json.dumps({ @@ -186,6 +201,22 @@ async def _process_upload(self, websocket: WebSocket, task_id: str): asyncio.create_task(self._handle_text_task(task_id, content or "")) elif data_type == "audio": asyncio.create_task(self._handle_audio_task(task_id, input_file)) + elif data_type == "image": + meta = self.metadata.get(task_id, {}) + prompt = meta.get("prompt") if isinstance(meta, dict) else None + extra_meta = meta.get("meta") if isinstance(meta, dict) else None + mime_type = ( + meta.get("mime_type") or meta.get("content_type") + ) if isinstance(meta, dict) else None + asyncio.create_task( + self._handle_image_task( + task_id, + input_file, + prompt, + mime_type, + extra_meta if isinstance(extra_meta, dict) else None, + ) + ) else: logger.warning(f"Unsupported data type '{data_type}' for task {task_id}") @@ -235,6 +266,38 @@ async def _handle_audio_task(self, task_id: str, audio_file: Path) -> None: logger.error(f"Dialog-engine audio handling failed for task {task_id}: {exc}") await self._publish_error(task_id, str(exc) or "dialog_engine_failed") + async def _handle_image_task( + self, + task_id: str, + image_file: Path, + prompt: Optional[str], + mime_type: Optional[str], + meta: Optional[Dict[str, Any]], + ) -> None: + try: + result = await self._invoke_dialog_engine_image( + task_id, + image_file, + prompt=prompt, + mime_type=mime_type, + meta=meta, + ) + payload: Dict[str, Any] = { + "status": "success", + "sessionId": task_id, + "text": result.get("reply", ""), + "transcript": result.get("prompt", ""), + "stats": result.get("stats"), + "source": "dialog-engine", + "input_mode": "image", + } + if meta: + payload["meta"] = meta + await self._publish_response(task_id, payload) + except Exception as exc: + logger.error(f"Dialog-engine image handling failed for task {task_id}: {exc}") + await self._publish_error(task_id, str(exc) or "dialog_engine_failed") + async def _publish_response(self, task_id: str, payload: Dict[str, Any]) -> None: if not redis_client: logger.error("Redis client not available; cannot publish response") @@ -339,6 +402,45 @@ async def _invoke_dialog_engine_audio(self, task_id: str, audio_file: Path) -> D detail = exc.response.text raise RuntimeError(f"dialog_engine_audio_failed:{detail}") from exc + async def _invoke_dialog_engine_image( + self, + task_id: str, + image_file: Path, + *, + prompt: Optional[str], + mime_type: Optional[str], + meta: Optional[Dict[str, Any]], + ) -> Dict[str, Any]: + url = f"{DIALOG_ENGINE_URL.rstrip('/')}{VISION_ENDPOINT}" + try: + image_bytes = image_file.read_bytes() + except Exception as exc: + raise RuntimeError(f"read_image_failed:{exc}") from exc + if not image_bytes: + raise RuntimeError("image_payload_empty") + image_b64 = base64.b64encode(image_bytes).decode("ascii") + body: Dict[str, Any] = { + "sessionId": task_id, + "image": image_b64, + } + if prompt: + body["prompt"] = prompt + if mime_type: + body["mimeType"] = mime_type + if meta: + body["meta"] = meta + try: + async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client: + resp = await client.post(url, json=body) + resp.raise_for_status() + return resp.json() + except httpx.HTTPStatusError as exc: + try: + detail = exc.response.json() + except ValueError: + detail = exc.response.text + raise RuntimeError(f"dialog_engine_image_failed:{detail}") from exc + @staticmethod def _infer_content_type(suffix: str) -> str: mapping = { @@ -348,6 +450,19 @@ def _infer_content_type(suffix: str) -> str: ".m4a": "audio/mp4", } return mapping.get(suffix, "audio/wav") + + @staticmethod + def _infer_image_suffix(mime_type: Optional[str]) -> str: + mapping = { + "image/png": ".png", + "image/jpeg": ".jpg", + "image/jpg": ".jpg", + "image/webp": ".webp", + "image/gif": ".gif", + } + if not mime_type: + return ".png" + return mapping.get(mime_type.lower(), ".png") def _cleanup_task_data(self, task_id: str): """清理任务相关的临时数据""" @@ -376,8 +491,8 @@ async def get():
专用于处理用户输入的WebSocket服务