From 923db6d8356a755ec3ed9ce03d03a4140425e2d5 Mon Sep 17 00:00:00 2001 From: qweqwe21321 <1135233347@qq.com> Date: Tue, 19 May 2026 19:55:56 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20async/await=E7=BC=BA=E5=A4=B1,=20?= =?UTF-8?q?=E6=96=AD=E8=A8=80=E6=AE=8B=E7=95=99,=20=E6=9C=AA=E5=AE=9A?= =?UTF-8?q?=E4=B9=89=E5=8F=98=E9=87=8F,=20exit()=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aag/computing_engine/mcp_client.py | 259 +++--- .../text_2_graph/text_2_graph.py | 378 +++++--- aag/dataload_main.py | 2 +- aag/main.py | 245 +++-- aag/reasoner/model_deployment.py | 867 ++++++++++-------- web/frontend/route/sockets_chat.py | 242 +++-- 6 files changed, 1218 insertions(+), 775 deletions(-) diff --git a/aag/computing_engine/mcp_client.py b/aag/computing_engine/mcp_client.py index 8f038be..528f2d7 100644 --- a/aag/computing_engine/mcp_client.py +++ b/aag/computing_engine/mcp_client.py @@ -26,7 +26,12 @@ class GraphMCPClient: """Graph computation MCP client with schema loaded from files.""" - def __init__(self, server_command: str = "python", server_args: List[str] = None, schema_dir: Optional[str] = None): + def __init__( + self, + server_command: str = "python", + server_args: List[str] = None, + schema_dir: Optional[str] = None, + ): self.server_command = server_command self.server_args = server_args or ["networkx_server/mcp_server.py"] self.schema_dir = Path(schema_dir) if schema_dir else None @@ -34,46 +39,41 @@ def __init__(self, server_command: str = "python", server_args: List[str] = None self.available_tools = {} self.input_schemas_cache = {} self.output_schemas_cache = {} - self.code_executor = DynamicCodeExecutor( - timeout=120, - auto_install=True - ) - + self.code_executor = DynamicCodeExecutor(timeout=120, auto_install=True) + async def connect(self): """Connect to the MCP server.""" try: server_params = StdioServerParameters( - command=self.server_command, - args=self.server_args + command=self.server_command, args=self.server_args ) - + self.stdio_client = stdio_client(server_params) self.read_stream, self.write_stream = await self.stdio_client.__aenter__() self.session = ClientSession(self.read_stream, self.write_stream) - - + await self.session.__aenter__() - await self.session.initialize() + await self.session.initialize() await self._load_available_tools() - + logger.info("✅ Connected to MCP server") return True - + except Exception as e: logger.error(f"❌ Failed to connect to MCP server: {e}") return False - + async def disconnect(self): """Disconnect from the MCP server.""" try: if self.session: await self.session.__aexit__(None, None, None) - if hasattr(self, 'stdio_client'): + if hasattr(self, "stdio_client"): await self.stdio_client.__aexit__(None, None, None) logger.info("👋 Disconnected from MCP server") except Exception as e: logger.error(f"❌ Error while disconnecting: {e}") - + def _load_schemas(self, schema_type: str) -> Dict[str, Dict[str, Any]]: """ Load schemas from file (input or output). @@ -89,17 +89,26 @@ def _load_schemas(self, schema_type: str) -> Dict[str, Dict[str, Any]]: elif self.server_args: if len(self.server_args) >= 2 and self.server_args[0] == "-m": module_name = self.server_args[1] - module_path = module_name.replace('.', '/') - module_dir = '/'.join(module_path.split('/')[:-1]) - schema_file = Path(__file__).parent.parent.parent / module_dir / f"generated_{schema_type}_schemas.json" + module_path = module_name.replace(".", "/") + module_dir = "/".join(module_path.split("/")[:-1]) + schema_file = ( + Path(__file__).parent.parent.parent + / module_dir + / f"generated_{schema_type}_schemas.json" + ) else: - schema_file = Path(self.server_args[0]).parent / f"generated_{schema_type}_schemas.json" + schema_file = ( + Path(self.server_args[0]).parent + / f"generated_{schema_type}_schemas.json" + ) else: - schema_file = Path(__file__).parent / f"generated_{schema_type}_schemas.json" - + schema_file = ( + Path(__file__).parent / f"generated_{schema_type}_schemas.json" + ) + if schema_file.exists(): try: - with open(schema_file, 'r', encoding='utf-8') as f: + with open(schema_file, "r", encoding="utf-8") as f: schemas = json.load(f) logger.info(f"✅ Loaded {schema_type} schemas for {len(schemas)} tools") logger.info(f" File: {schema_file.absolute()}") @@ -109,42 +118,46 @@ def _load_schemas(self, schema_type: str) -> Dict[str, Dict[str, Any]]: else: logger.warning(f"⚠️ {schema_type} schemas file not found: {schema_file}") logger.info(" Hint: run MCP Server first to generate schemas") - + return {} - + async def _load_available_tools(self): """Load and cache available tools (using preloaded schemas).""" try: tools_response = await self.session.list_tools() self.input_schemas_cache = self._load_schemas("input") self.output_schemas_cache = self._load_schemas("output") - + for tool in tools_response.tools: tool_name = tool.name input_schema = self.input_schemas_cache.get(tool_name) output_schema = self.output_schemas_cache.get(tool_name) if not input_schema: input_schema = tool.inputSchema - logger.debug(f"⚠️ Tool '{tool_name}' using server-provided input_schema") - + logger.debug( + f"⚠️ Tool '{tool_name}' using server-provided input_schema" + ) + self.available_tools[tool_name] = { - 'name': tool_name, - 'description': tool.description, - 'input_schema': input_schema, - 'output_schema': output_schema + "name": tool_name, + "description": tool.description, + "input_schema": input_schema, + "output_schema": output_schema, } - + logger.info(f"✅ Loaded {len(self.available_tools)} available tools") except Exception as e: logger.error(f"❌ Failed to load tool list: {e}") - + async def list_tools(self) -> List[Dict[str, Any]]: """List all available tools.""" if not self.available_tools: await self._load_available_tools() return list(self.available_tools.values()) - - def get_tool_schema(self, tool_name: str, schema_type: str = 'input') -> Optional[Dict[str, Any]]: + + def get_tool_schema( + self, tool_name: str, schema_type: str = "input" + ) -> Optional[Dict[str, Any]]: """ Get schema for a tool. @@ -156,17 +169,17 @@ def get_tool_schema(self, tool_name: str, schema_type: str = 'input') -> Optiona if not tool: logger.warning(f"⚠️ Tool '{tool_name}' not found") return None - + schema_key = f"{schema_type}_schema" schema = tool.get(schema_key) - + if not schema: logger.debug(f"⚠️ Tool '{tool_name}' has no {schema_type}_schema") - if schema_type == 'output': + if schema_type == "output": return self._get_default_output_schema() - + return schema - + def _get_default_output_schema(self) -> Dict[str, Any]: """Return default output schema.""" return { @@ -176,41 +189,53 @@ def _get_default_output_schema(self) -> Dict[str, Any]: "success": {"type": "boolean"}, "result": {"description": "Algorithm result (type varies)"}, "error": {"type": ["string", "null"]}, - "summary": {"type": ["string", "null"]} + "summary": {"type": ["string", "null"]}, }, - "required": ["algorithm", "success"] + "required": ["algorithm", "success"], } - - def validate_arguments(self, tool_name: str, arguments: Dict[str, Any]) -> tuple[bool, Optional[str]]: + + def validate_arguments( + self, tool_name: str, arguments: Dict[str, Any] + ) -> tuple[bool, Optional[str]]: """Validate tool arguments against schema.""" - schema = self.get_tool_schema(tool_name, 'input') + schema = self.get_tool_schema(tool_name, "input") if not schema: return False, f"Tool '{tool_name}' not found or missing input schema" - - properties = schema.get('parameters') or schema.get('properties') or {} - required = schema.get('required', []) + + properties = schema.get("parameters") or schema.get("properties") or {} + required = schema.get("required", []) for req_param in required: if req_param not in arguments: return False, f"Missing required parameter: {req_param}" - + for param_name in arguments.keys(): - if param_name not in properties and param_name != "__post_processing_code__": + if ( + param_name not in properties + and param_name != "__post_processing_code__" + ): logger.warning(f"⚠️ Parameter '{param_name}' not in schema") - + return True, None - - async def call_tool(self, tool_name: str, arguments: Dict[str, Any], post_processing_code: Optional[str] = None, global_graph: Optional[GraphData] = None, validate: bool = True) -> Dict[str, Any]: + + async def call_tool( + self, + tool_name: str, + arguments: Dict[str, Any], + post_processing_code: Optional[str] = None, + global_graph: Optional[GraphData] = None, + validate: bool = True, + ) -> Dict[str, Any]: """ Call MCP tool (supports optional post-processing code) - + Args: tool_name: Tool name arguments: Tool arguments post_processing_code: Optional post-processing code validate: Whether to validate arguments - + Returns: Tool execution result dictionary """ @@ -221,41 +246,55 @@ async def call_tool(self, tool_name: str, arguments: Dict[str, Any], post_proces # Convert empty containers to None (avoid empty container traps in libraries like NetworkX) if isinstance(value, (dict, list)) and len(value) == 0: cleaned_arguments[key] = None - logger.warning(f"⚠️ Parameter '{key}' is an empty container, converted to None") + logger.warning( + f"⚠️ Parameter '{key}' is an empty container, converted to None" + ) else: cleaned_arguments[key] = value # ========== Argument validation ========== if validate: - is_valid, error_msg = self.validate_arguments(tool_name, arguments) + is_valid, error_msg = self.validate_arguments( + tool_name, cleaned_arguments + ) if not is_valid: logger.error(f"❌ Argument validation failed: {error_msg}") return { "success": False, "error": f"Argument validation failed: {error_msg}", - "summary": "Invalid argument format" + "summary": "Invalid argument format", } - + # ========== Step 1: Call original tool ========== logger.info(f"📤 Calling tool {tool_name}...") - result = await self.session.call_tool(tool_name, arguments=arguments) + result = await self.session.call_tool( + tool_name, arguments=cleaned_arguments + ) if result.content and len(result.content) > 0: content = result.content[0] - if hasattr(content, 'text'): + if hasattr(content, "text"): try: original_response = json.loads(content.text) # logger.info(f"✅ Tool execution completed: {original_response.get('summary', '')}") except json.JSONDecodeError as e: raw_content = content.text if content.text else "Empty response" - logger.error(f"JSON parsing failed: {e} - Raw response: {raw_content}") + logger.error( + f"JSON parsing failed: {e} - Raw response: {raw_content}" + ) return { - "success": False, + "success": False, "error": raw_content, - "summary": "Server returned non-JSON response" + "summary": "Server returned non-JSON response", } else: - return {"success": False, "error": "Tool call failed: Unable to parse tool call result"} + return { + "success": False, + "error": "Tool call failed: Unable to parse tool call result", + } else: - return {"success": False, "error": "Tool call failed: Unable to parse tool call result"} + return { + "success": False, + "error": "Tool call failed: Unable to parse tool call result", + } # ========== Step 2: If post-processing code exists, call post-processing tool ========== if post_processing_code: @@ -263,106 +302,110 @@ async def call_tool(self, tool_name: str, arguments: Dict[str, Any], post_proces try: # ✅ Key change: Use local executor for processing processed_data = self.code_executor.execute( - post_processing_code, + post_processing_code, original_response.get("result"), # Only pass the result part - global_graph=global_graph + global_graph=global_graph, ) - + # Update return result original_response["result"] = processed_data - original_response["summary"] = original_response.get("summary", "") + " (Local post-processing applied)" - + original_response["summary"] = ( + original_response.get("summary", "") + + " (Local post-processing applied)" + ) + logger.info(f"✅ Post-processing execution completed") # logger.info(f"Key results extracted: {processed_data}") logger.info(f"post_processing_status:{original_response}") return original_response - + except Exception as post_error: logger.error(f"Post-processing execution failed: {post_error}") # logger.warning("⚠️ Post-processing failed, returning original result") # return original_response - return {"success": False, "error": f"❌ Post-processing code execution failed: {post_error}"} - + return { + "success": False, + "error": f"❌ Post-processing code execution failed: {post_error}", + } + return original_response - + except Exception as e: logger.error(f"❌ Tool call failed: {e}") - return { - "success": False, - "error": str(e), - "summary": "Tool call error" - } - + return {"success": False, "error": str(e), "summary": "Tool call error"} + def print_tool_info(self, tool_name: str): """Print detailed tool info.""" tool = self.available_tools.get(tool_name) if not tool: print(f"❌ Tool '{tool_name}' not found") return - - print(f"\n{'='*60}") + + print(f"\n{'=' * 60}") print(f"Tool: {tool['name']}") print(f"Description: {tool['description']}") - + print(f"\n--- Input Schema ---") - input_schema = tool.get('input_schema') + input_schema = tool.get("input_schema") if input_schema: print(json.dumps(input_schema, indent=2, ensure_ascii=False)) else: print("⚠️ No Input Schema") - + print(f"\n--- Output Schema ---") - output_schema = tool.get('output_schema') + output_schema = tool.get("output_schema") if output_schema: print(json.dumps(output_schema, indent=2, ensure_ascii=False)) else: print("⚠️ No Output Schema") - - print(f"{'='*60}\n") - + print(f"{'=' * 60}\n") + + # TODO: load_data_from_csv() 放在 MCP 客户端类中违反单一职责原则。 + # 该方法应迁移至 aag.data_pipeline 模块(如 CsvDataLoader)。 + # 暂不移动以避免破坏现有接口,迁移时需同步更新所有调用方。 def load_data_from_csv(self, accounts_file: str, transactions_file: str) -> tuple: """Load graph data from CSV files (with type handling).""" vertices = [] edges = [] - + try: logger.info(f"📖 Reading accounts: {accounts_file}") - with open(accounts_file, 'r', encoding='utf-8') as f: + with open(accounts_file, "r", encoding="utf-8") as f: reader = csv.DictReader(f) for row in reader: vertex = { - 'vid': row['acct_id'], - 'properties': {k: v for k, v in row.items() if k != 'acct_id'} + "vid": row["acct_id"], + "properties": {k: v for k, v in row.items() if k != "acct_id"}, } vertices.append(vertex) - + logger.info(f"📖 Reading transactions: {transactions_file}") - with open(transactions_file, 'r', encoding='utf-8') as f: + with open(transactions_file, "r", encoding="utf-8") as f: reader = csv.DictReader(f) for row in reader: properties = {} for k, v in row.items(): - if k not in ['orig_acct', 'bene_acct']: + if k not in ["orig_acct", "bene_acct"]: try: - if '.' in v: + if "." in v: properties[k] = float(v) else: properties[k] = int(v) except (ValueError, TypeError): properties[k] = v - + edge = { - 'src': row['orig_acct'], - 'dst': row['bene_acct'], - 'rank': 0, - 'properties': properties + "src": row["orig_acct"], + "dst": row["bene_acct"], + "rank": 0, + "properties": properties, } edges.append(edge) - + logger.info(f"✅ Loaded {len(vertices)} vertices and {len(edges)} edges") return vertices, edges - + except Exception as e: logger.error(f"❌ CSV data load failed: {e}") raise @@ -378,4 +421,4 @@ async def create_client_and_connect() -> GraphMCPClient: if __name__ == "__main__": - print("GraphMCPClient – schema file version") \ No newline at end of file + print("GraphMCPClient – schema file version") diff --git a/aag/data_pipeline/data_transformer/text_2_graph/text_2_graph.py b/aag/data_pipeline/data_transformer/text_2_graph/text_2_graph.py index 6a44103..eaa8f7d 100644 --- a/aag/data_pipeline/data_transformer/text_2_graph/text_2_graph.py +++ b/aag/data_pipeline/data_transformer/text_2_graph/text_2_graph.py @@ -151,12 +151,21 @@ def _llm_completion_text(raw: Any) -> str: """ - class Text2Graph: - def __init__(self, type: str, file_path: str, graph_name: str, llm_name: str, api_key: str, chunk_size: int = 512, base_url: str = None, thread_count: int = 1): + def __init__( + self, + type: str, + file_path: str, + graph_name: str, + llm_name: str, + api_key: str, + chunk_size: int = 512, + base_url: str = None, + thread_count: int = 1, + ): """ 初始化文本到图转换器 - + Args: file_path: 输入的文本文件路径 """ @@ -169,16 +178,22 @@ def __init__(self, type: str, file_path: str, graph_name: str, llm_name: str, ap if not file_exist(self.file_path): raise FileNotFoundError(f"文本文件未找到: {self.file_path}") - + # sentences = re.split(r'(?<=[。!?\n])', self.read_file_path().strip()) # 按句子切分 markdown_text = self.read_file_path_markitdown() self.markdown_text_chars = len(markdown_text) self.markdown_text_bytes = len(markdown_text.encode("utf-8")) - self.source_file_size_bytes = os.path.getsize(self.file_path) if os.path.exists(self.file_path) else None - self.markdown_file_size_bytes = os.path.getsize(self.md_file_path) if os.path.exists(self.md_file_path) else None - sentences = re.split(r'(?<=[。!?\n])', markdown_text.strip()) # 按句子切分 + self.source_file_size_bytes = ( + os.path.getsize(self.file_path) if os.path.exists(self.file_path) else None + ) + self.markdown_file_size_bytes = ( + os.path.getsize(self.md_file_path) + if os.path.exists(self.md_file_path) + else None + ) + sentences = re.split(r"(?<=[。!?\n])", markdown_text.strip()) # 按句子切分 self.text_chunks = [] - current_chunk = '' + current_chunk = "" for sentence in sentences: if len(current_chunk) + len(sentence) <= chunk_size: @@ -189,12 +204,16 @@ def __init__(self, type: str, file_path: str, graph_name: str, llm_name: str, ap if current_chunk: self.text_chunks.append(current_chunk) - - print(f"文本已加载,共 {len(self.text_chunks)} 个块, 每块约 {chunk_size} 字符。") + + print( + f"文本已加载,共 {len(self.text_chunks)} 个块, 每块约 {chunk_size} 字符。" + ) if type == "ollama": - self.llm = OllamaEnv(llm_mode_name = llm_name) + self.llm = OllamaEnv(llm_mode_name=llm_name) elif type == "openai": - self.llm = OpenAIEnv(api_key = api_key, model_name = llm_name, base_url = base_url) + self.llm = OpenAIEnv( + api_key=api_key, model_name=llm_name, base_url=base_url + ) print(f"使用 OpenAI 模型: {llm_name}") else: raise ValueError(f"不支持的类型: {type}") @@ -228,7 +247,9 @@ def _merge_parse_metrics(dst: Dict[str, Any], src: Dict[str, Any]) -> None: ): dst[k] = (dst.get(k, 0) or 0) + (src.get(k, 0) or 0) - def _invoke_llm_with_metrics(self, prompt: str) -> Tuple[str, Dict[str, Optional[int]], int]: + def _invoke_llm_with_metrics( + self, prompt: str + ) -> Tuple[str, Dict[str, Optional[int]], int]: """ 统一执行一次 LLM 调用并返回文本、token usage 与耗时(毫秒)。 OpenAI 模型可拿到 usage;其他模型 usage 为空值。 @@ -245,7 +266,11 @@ def _invoke_llm_with_metrics(self, prompt: str) -> Tuple[str, Dict[str, Optional model=self.llm.model, messages=[{"role": "user", "content": prompt}], ) - text = (resp.choices[0].message.content or "") if getattr(resp, "choices", None) else "" + text = ( + (resp.choices[0].message.content or "") + if getattr(resp, "choices", None) + else "" + ) u = getattr(resp, "usage", None) if u is not None: usage["prompt_tokens"] = getattr(u, "prompt_tokens", None) @@ -271,41 +296,55 @@ def _generate_with_metrics(self, prompt: str) -> Tuple[str, Dict[str, Any]]: messages=[{"role": "user", "content": prompt}], ) usage_obj = getattr(resp, "usage", None) - text = (resp.choices[0].message.content or "") if getattr(resp, "choices", None) else "" + text = ( + (resp.choices[0].message.content or "") + if getattr(resp, "choices", None) + else "" + ) else: raw = self.llm.generate_response(query=prompt) text = _llm_completion_text(raw) elapsed_ms = int((time.perf_counter() - t0) * 1000) metric = { "elapsed_ms": elapsed_ms, - "prompt_tokens": getattr(usage_obj, "prompt_tokens", None) if usage_obj is not None else None, - "completion_tokens": getattr(usage_obj, "completion_tokens", None) if usage_obj is not None else None, - "total_tokens": getattr(usage_obj, "total_tokens", None) if usage_obj is not None else None, + "prompt_tokens": getattr(usage_obj, "prompt_tokens", None) + if usage_obj is not None + else None, + "completion_tokens": getattr(usage_obj, "completion_tokens", None) + if usage_obj is not None + else None, + "total_tokens": getattr(usage_obj, "total_tokens", None) + if usage_obj is not None + else None, } return text, metric def read_file_path(self) -> str: import os + ext = os.path.splitext(self.file_path)[1].lower() text = "" if ext == ".txt": - with open(self.file_path, 'r', encoding='utf-8') as f: + with open(self.file_path, "r", encoding="utf-8") as f: text = f.read() - + elif ext == ".docx": from docx import Document + doc = Document(self.file_path) text = "\n".join([p.text for p in doc.paragraphs]) - + elif ext == ".doc": import mammoth + with open(self.file_path, "rb") as f: result = mammoth.extract_raw_text(f) text = result.value - + elif ext == ".pdf": import fitz # PyMuPDF + # 打开 PDF doc = fitz.open(self.file_path) # 提取每页文本 @@ -317,26 +356,28 @@ def read_file_path(self) -> str: # 拼接所有页 text = "\n".join(text_list) # 清理多余空白和换行,保证句子连续 - text = re.sub(r'\r\n|\r', '\n', text) # 统一换行 - text = re.sub(r'\n+', '\n', text) # 多个换行合并为一个 - text = re.sub(r'[ \t]+', ' ', text) # 多空格缩成一个空格 - text = text.strip() # 去掉首尾空白 - + text = re.sub(r"\r\n|\r", "\n", text) # 统一换行 + text = re.sub(r"\n+", "\n", text) # 多个换行合并为一个 + text = re.sub(r"[ \t]+", " ", text) # 多空格缩成一个空格 + text = text.strip() # 去掉首尾空白 + else: raise ValueError(f"不支持的文件类型: {ext}") return text - + def read_file_path_markitdown(self) -> str: """ 将用户上传的文件转为 Markdown,并在 schema YAML 中追加一条 text dataset。 """ import os + base_path, _ = os.path.splitext(self.file_path) self.md_file_path = f"{base_path}.md" # 1️⃣ 转成 Markdown from markitdown import MarkItDown + md = MarkItDown() result = md.convert(self.file_path) markdown_text = result.text_content @@ -346,23 +387,24 @@ def read_file_path_markitdown(self) -> str: f.write(markdown_text) return markdown_text - - def extract_graph(self, MAX_RETRIES = 5): + + def extract_graph(self, MAX_RETRIES=5): """ 从文本数据中提取知识图谱 - + Returns: 三元组列表 """ print("开始从文本中提取知识图谱...") - + triplets = [] entities = [] for idx, chunk in enumerate(tqdm(self.text_chunks)): - for attempt in range(1, MAX_RETRIES + 1): - raw = self.llm.generate_response(query=prompt_template_str.format(context=chunk)) + raw = self.llm.generate_response( + query=prompt_template_str.format(context=chunk) + ) text = _llm_completion_text(raw) if not text: print(f"⚠️ 块 {idx} 第 {attempt} 次尝试未获得响应") @@ -380,7 +422,7 @@ def extract_graph(self, MAX_RETRIES = 5): # 5 次都不成功,跳过该块 print(f"❌ 块 {idx} 超过 {MAX_RETRIES} 次尝试仍失败,跳过") continue - + # 注释掉是应为提取的实体和三元组不一样,有可能对应不上,我们只用提取到的三元组就OK了 # entity_label = {} # if isinstance(response["entities"], List): @@ -405,9 +447,9 @@ def extract_graph(self, MAX_RETRIES = 5): for triplet in response.get("triplets", []): # 1️⃣ 必须是列表/元组且长度为3 if not isinstance(triplet, (list, tuple)) or len(triplet) != 3: - #print(f"⚠️ 无效三元组,跳过: {triplet}") + # print(f"⚠️ 无效三元组,跳过: {triplet}") continue - + # 2️⃣ 遍历每个元素,处理字符串大小写 processed_triplet = [] for phrase in triplet: @@ -415,34 +457,35 @@ def extract_graph(self, MAX_RETRIES = 5): processed_triplet.append(phrase.strip().capitalize()) else: processed_triplet.append(phrase) # 保留原样(可能是数字或空) - + new_triplets.append(processed_triplet) triplets.extend(new_triplets) - #print(f"提取完成,共获得 {len(triplets)} 个三元组。") + # print(f"提取完成,共获得 {len(triplets)} 个三元组。") return triplets def save_triplet(self, triplets: List[List[str]]): import os, csv + dir_path = os.path.dirname(self.file_path) base_name = os.path.splitext(os.path.basename(self.file_path))[0] entities_csv_path = os.path.join(dir_path, f"{base_name}_accounts.csv") triplets_csv_path = os.path.join(dir_path, f"{base_name}_transactions.csv") - + valid_triplets = [] for t in triplets: if len(t) != 3: - #print(f"⚠️ 无效三元组,跳过: {t}") + # print(f"⚠️ 无效三元组,跳过: {t}") continue head, rel, tail = t if not head or not tail or not rel: - #print(f"⚠️ 三元组元素为空,跳过: {t}") + # print(f"⚠️ 三元组元素为空,跳过: {t}") continue valid_triplets.append([head, rel, tail]) if not valid_triplets: - #print("⚠️ 没有合法三元组,退出。") + # print("⚠️ 没有合法三元组,退出。") return entities = {} @@ -465,8 +508,8 @@ def save_triplet(self, triplets: List[List[str]]): for head, rel, tail in valid_triplets: writer.writerow([entities[head], entities[tail], rel]) - #print(f"✅ 实体 CSV 保存到: {entities_csv_path}") - #print(f"✅ 边 CSV 保存到: {triplets_csv_path}") + # print(f"✅ 实体 CSV 保存到: {entities_csv_path}") + # print(f"✅ 边 CSV 保存到: {triplets_csv_path}") schema_path = os.path.join(dir_path, f"{base_name}_graph_schemas.yaml") graph_name = f"{base_name}_Graph" @@ -485,7 +528,7 @@ def save_triplet(self, triplets: List[List[str]]): "id_field": "acct_id", "label_field": "dsply_nm", "path": entities_csv_path, - "type": "account" + "type": "account", } ], "edge": [ @@ -497,14 +540,14 @@ def save_triplet(self, triplets: List[List[str]]): "source_field": "tran_id", "target_field": "orig_acct", "type": "transfer", - "weight_field": None + "weight_field": None, } ], "graph": { "directed": "true", "heterogeneous": "false", "multigraph": "false", - "weighted": "false" + "weighted": "false", }, "graph_store_info": { "backend": "nebula_graph", @@ -512,13 +555,14 @@ def save_triplet(self, triplets: List[List[str]]): "space_name": graph_name, "status": "success", "version": "null", - "vertex_count": len(entities) - } - } + "vertex_count": len(entities), + }, + }, } ] } import yaml + with open(schema_path, "w", encoding="utf-8") as f: yaml.dump(schema_dict, f, sort_keys=False, allow_unicode=True) @@ -530,34 +574,38 @@ def extract_graph_by_openie(self): from openie import StanfordOpenIE except ImportError: import subprocess, sys + package_name = "stanford-openie" - #print(f"⚙️ 检测到未安装依赖 '{package_name}',正在自动安装...") + # print(f"⚙️ 检测到未安装依赖 '{package_name}',正在自动安装...") try: - subprocess.check_call([sys.executable, "-m", "pip", "install", package_name]) - #print(f"✅ 成功安装 {package_name}") + subprocess.check_call( + [sys.executable, "-m", "pip", "install", package_name] + ) + # print(f"✅ 成功安装 {package_name}") from openie import StanfordOpenIE # 再次导入 except subprocess.CalledProcessError as e: - #print(f"❌ 安装 {package_name} 失败,请手动安装:pip install {package_name}") + # print(f"❌ 安装 {package_name} 失败,请手动安装:pip install {package_name}") raise e - + triplets = [] for idx, chunk in enumerate(tqdm(self.text_chunks)): properties = { - 'openie.affinity_probability_cap': 2 / 3, + "openie.affinity_probability_cap": 2 / 3, } with StanfordOpenIE(properties=properties) as client: for triple in client.annotate(chunk): - head = triple['subject'].strip().capitalize() - relation = triple['relation'].strip().capitalize() - tail = triple['object'].strip().capitalize() + head = triple["subject"].strip().capitalize() + relation = triple["relation"].strip().capitalize() + tail = triple["object"].strip().capitalize() triplet = [head, relation, tail] triplets.append(triplet) - #print(f"提取完成,共获得 {len(triplets)} 个三元组。") + # print(f"提取完成,共获得 {len(triplets)} 个三元组。") print(triplets) - assert False, "stop here" return triplets - def extract_graph_and_entity_by_LLM(self, each_dataset, file_name, each_dataset_schema_file_path, MAX_RETRIES = 5): + def extract_graph_and_entity_by_LLM( + self, each_dataset, file_name, each_dataset_schema_file_path, MAX_RETRIES=5 + ): """ 从文本数据中提取三元组与实体类型 @@ -579,10 +627,11 @@ def extract_graph_and_entity_by_LLM(self, each_dataset, file_name, each_dataset_ token_observed = False for idx, chunk in enumerate(tqdm(self.text_chunks)): - for attempt in range(1, MAX_RETRIES + 1): total_calls += 1 - response, call_metric = self._generate_with_metrics(prompt_template_str.format(context=chunk)) + response, call_metric = self._generate_with_metrics( + prompt_template_str.format(context=chunk) + ) total_llm_elapsed_ms += call_metric.get("elapsed_ms", 0) or 0 pt = call_metric.get("prompt_tokens") ct = call_metric.get("completion_tokens") @@ -597,57 +646,60 @@ def extract_graph_and_entity_by_LLM(self, each_dataset, file_name, each_dataset_ total_tokens_sum += tt token_observed = True if not response: - #print(f"⚠️ 块 {idx} 第 {attempt} 次无响应") + # print(f"⚠️ 块 {idx} 第 {attempt} 次无响应") if attempt > 1: total_retries += 1 continue - + cleaned = response.strip() # remove markdown mark - cleaned = re.sub(r"^```[a-zA-Z]*\n?|```$", "", cleaned, flags=re.MULTILINE).strip() + cleaned = re.sub( + r"^```[a-zA-Z]*\n?|```$", "", cleaned, flags=re.MULTILINE + ).strip() # match {} content - match = re.search(r'\{[\s\S]*\}', cleaned) + match = re.search(r"\{[\s\S]*\}", cleaned) if not match: - #print(f"❌ 未找到 JSON 内容,原始响应:\n{cleaned}") + # print(f"❌ 未找到 JSON 内容,原始响应:\n{cleaned}") if attempt > 1: total_retries += 1 continue json_str = match.group(0) - + try: response_parsed = json.loads(json_str) except json.JSONDecodeError as e: # JSON 解析失败,记录日志并跳过当前条目 - #print(f"[Warning] JSON parsing failed at index {idx} of {attempt}: {e}") + # print(f"[Warning] JSON parsing failed at index {idx} of {attempt}: {e}") if attempt > 1: total_retries += 1 continue if not isinstance(response_parsed, dict): - #print(f"⚠️ 块 {idx} 第 {attempt} 次 JSON 解析失败") + # print(f"⚠️ 块 {idx} 第 {attempt} 次 JSON 解析失败") if attempt > 1: total_retries += 1 continue response = response_parsed break else: - #print(f"❌ 块 {idx} 超过 {MAX_RETRIES} 次尝试失败,跳过") + # print(f"❌ 块 {idx} 超过 {MAX_RETRIES} 次尝试失败,跳过") continue - - + # ====================== # 1️⃣ 解析 entities 部分 # ====================== entities_from_response = {} - + if isinstance(response.get("entities"), dict): for entity_type, names in response["entities"].items(): for name in names: if isinstance(name, str) and name.strip(): - entities_from_response[name.strip()] = entity_type.strip().capitalize() + entities_from_response[name.strip()] = ( + entity_type.strip().capitalize() + ) # ====================== # 2️⃣ 解析 triplets 部分 @@ -655,7 +707,7 @@ def extract_graph_and_entity_by_LLM(self, each_dataset, file_name, each_dataset_ new_triplets = [] for triplet in response.get("triplets", []): if not isinstance(triplet, (list, tuple)) or len(triplet) != 3: - #print(f"⚠️ 无效三元组,跳过: {triplet}") + # print(f"⚠️ 无效三元组,跳过: {triplet}") continue # 去掉空字符串,只保留非空字符串和数字/其他类型 clean_triplet = [] @@ -671,11 +723,7 @@ def extract_graph_and_entity_by_LLM(self, each_dataset, file_name, each_dataset_ # 长度必须为3才保留 if len(clean_triplet) == 3: new_triplets.append(clean_triplet) - if clean_triplet[0] == "New york": - print(clean_triplet) - print(len(clean_triplet)) else: - #print(f"⚠️ 三元组长度不足3,跳过: {triplet}") pass triplets.extend(new_triplets) @@ -690,10 +738,14 @@ def extract_graph_and_entity_by_LLM(self, each_dataset, file_name, each_dataset_ entity2type[ent] = entities_from_response[ent] else: # 调用大模型补全实体类型 - entity_type, type_metric = self._ask_entity_type(ent, chunk, MAX_RETRIES) + entity_type, type_metric = self._ask_entity_type( + ent, chunk, MAX_RETRIES + ) total_calls += type_metric.get("call_count", 0) total_retries += type_metric.get("retry_count", 0) - total_llm_elapsed_ms += type_metric.get("llm_elapsed_ms", 0) or 0 + total_llm_elapsed_ms += ( + type_metric.get("llm_elapsed_ms", 0) or 0 + ) tpt = type_metric.get("prompt_tokens") tct = type_metric.get("completion_tokens") ttt = type_metric.get("total_tokens") @@ -707,9 +759,10 @@ def extract_graph_and_entity_by_LLM(self, each_dataset, file_name, each_dataset_ total_tokens_sum += ttt token_observed = True entity2type[ent] = entity_type - + each_dataset[file_name]["parsing_rate"] = (idx + 1) / len(self.text_chunks) import yaml + output_file = [] for key in each_dataset: output_file.append(each_dataset[key]) @@ -724,7 +777,9 @@ def extract_graph_and_entity_by_LLM(self, each_dataset, file_name, each_dataset_ # print(f"✅ 提取完成,共 {len(triplets)} 个三元组,{len(entity2id)} 个实体") elapsed_ms = int((time.perf_counter() - t_total) * 1000) - text_bytes = os.path.getsize(self.file_path) if os.path.exists(self.file_path) else None + text_bytes = ( + os.path.getsize(self.file_path) if os.path.exists(self.file_path) else None + ) text_chars = sum(len(c) for c in self.text_chunks) parse_metrics = { "provider": self.provider, @@ -745,7 +800,9 @@ def extract_graph_and_entity_by_LLM(self, each_dataset, file_name, each_dataset_ } return triplets, entity2id, entity2type, parse_metrics - def extract_graph_and_entity_by_LLM_with_mutiple(self, each_dataset, file_name, each_dataset_schema_file_path, MAX_RETRIES = 5): + def extract_graph_and_entity_by_LLM_with_mutiple( + self, each_dataset, file_name, each_dataset_schema_file_path, MAX_RETRIES=5 + ): """ 从文本数据中提取三元组与实体类型(多线程)。 @@ -758,28 +815,37 @@ def extract_graph_and_entity_by_LLM_with_mutiple(self, each_dataset, file_name, import yaml t_total = time.perf_counter() - print(f"开始使用 LLM 从文本中提取知识图谱(多线程,最多 {self.thread_count} 个 chunk 并行)...") + print( + f"开始使用 LLM 从文本中提取知识图谱(多线程,最多 {self.thread_count} 个 chunk 并行)..." + ) n_chunks = len(self.text_chunks) if n_chunks == 0: - return [], {}, {}, { - "provider": self.provider, - "model": self.model_name, - "chunk_count": 0, - "llm_call_count": 0, - "llm_elapsed_ms": 0, - "retry_count": 0, - "elapsed_ms": 0, - "text_bytes": self.source_file_size_bytes, - "text_chars": 0, - "markdown_text_bytes": self.markdown_text_bytes, - "markdown_text_chars": self.markdown_text_chars, - "markdown_file_size_bytes": self.markdown_file_size_bytes, - "prompt_tokens": None, - "completion_tokens": None, - "total_tokens": None, - } + return ( + [], + {}, + {}, + { + "provider": self.provider, + "model": self.model_name, + "chunk_count": 0, + "llm_call_count": 0, + "llm_elapsed_ms": 0, + "retry_count": 0, + "elapsed_ms": 0, + "text_bytes": self.source_file_size_bytes, + "text_chars": 0, + "markdown_text_bytes": self.markdown_text_bytes, + "markdown_text_chars": self.markdown_text_chars, + "markdown_file_size_bytes": self.markdown_file_size_bytes, + "prompt_tokens": None, + "completion_tokens": None, + "total_tokens": None, + }, + ) - def process_one_chunk(idx: int, chunk: str) -> Tuple[int, List[List[str]], Dict[str, str], Dict[str, Any]]: + def process_one_chunk( + idx: int, chunk: str + ) -> Tuple[int, List[List[str]], Dict[str, str], Dict[str, Any]]: new_triplets: List[List[str]] = [] entity2type_local: Dict[str, str] = {} local_calls = 0 @@ -792,7 +858,9 @@ def process_one_chunk(idx: int, chunk: str) -> Tuple[int, List[List[str]], Dict[ for attempt in range(1, MAX_RETRIES + 1): local_calls += 1 - response, call_metric = self._generate_with_metrics(prompt_template_str.format(context=chunk)) + response, call_metric = self._generate_with_metrics( + prompt_template_str.format(context=chunk) + ) local_elapsed_ms += call_metric.get("elapsed_ms", 0) or 0 pt = call_metric.get("prompt_tokens") ct = call_metric.get("completion_tokens") @@ -813,8 +881,10 @@ def process_one_chunk(idx: int, chunk: str) -> Tuple[int, List[List[str]], Dict[ continue cleaned = response.strip() - cleaned = re.sub(r"^```[a-zA-Z]*\n?|```$", "", cleaned, flags=re.MULTILINE).strip() - match = re.search(r'\{[\s\S]*\}', cleaned) + cleaned = re.sub( + r"^```[a-zA-Z]*\n?|```$", "", cleaned, flags=re.MULTILINE + ).strip() + match = re.search(r"\{[\s\S]*\}", cleaned) if not match: if attempt > 1: local_retries += 1 @@ -835,21 +905,34 @@ def process_one_chunk(idx: int, chunk: str) -> Tuple[int, List[List[str]], Dict[ response = response_parsed break else: - return idx, [], {}, { - "llm_call_count": local_calls, - "llm_elapsed_ms": local_elapsed_ms, - "retry_count": local_retries, - "prompt_tokens": local_prompt_tokens if local_token_observed else None, - "completion_tokens": local_completion_tokens if local_token_observed else None, - "total_tokens": local_total_tokens if local_token_observed else None, - } + return ( + idx, + [], + {}, + { + "llm_call_count": local_calls, + "llm_elapsed_ms": local_elapsed_ms, + "retry_count": local_retries, + "prompt_tokens": local_prompt_tokens + if local_token_observed + else None, + "completion_tokens": local_completion_tokens + if local_token_observed + else None, + "total_tokens": local_total_tokens + if local_token_observed + else None, + }, + ) entities_from_response = {} if isinstance(response.get("entities"), dict): for entity_type, names in response["entities"].items(): for name in names: if isinstance(name, str) and name.strip(): - entities_from_response[name.strip()] = entity_type.strip().capitalize() + entities_from_response[name.strip()] = ( + entity_type.strip().capitalize() + ) for triplet in response.get("triplets", []): if not isinstance(triplet, (list, tuple)) or len(triplet) != 3: @@ -872,10 +955,14 @@ def process_one_chunk(idx: int, chunk: str) -> Tuple[int, List[List[str]], Dict[ if ent in entities_from_response: entity2type_local[ent] = entities_from_response[ent] else: - entity_type, type_metric = self._ask_entity_type(ent, chunk, MAX_RETRIES) + entity_type, type_metric = self._ask_entity_type( + ent, chunk, MAX_RETRIES + ) local_calls += type_metric.get("call_count", 0) local_retries += type_metric.get("retry_count", 0) - local_elapsed_ms += type_metric.get("llm_elapsed_ms", 0) or 0 + local_elapsed_ms += ( + type_metric.get("llm_elapsed_ms", 0) or 0 + ) tpt = type_metric.get("prompt_tokens") tct = type_metric.get("completion_tokens") ttt = type_metric.get("total_tokens") @@ -890,14 +977,25 @@ def process_one_chunk(idx: int, chunk: str) -> Tuple[int, List[List[str]], Dict[ local_token_observed = True entity2type_local[ent] = entity_type - return idx, new_triplets, entity2type_local, { - "llm_call_count": local_calls, - "llm_elapsed_ms": local_elapsed_ms, - "retry_count": local_retries, - "prompt_tokens": local_prompt_tokens if local_token_observed else None, - "completion_tokens": local_completion_tokens if local_token_observed else None, - "total_tokens": local_total_tokens if local_token_observed else None, - } + return ( + idx, + new_triplets, + entity2type_local, + { + "llm_call_count": local_calls, + "llm_elapsed_ms": local_elapsed_ms, + "retry_count": local_retries, + "prompt_tokens": local_prompt_tokens + if local_token_observed + else None, + "completion_tokens": local_completion_tokens + if local_token_observed + else None, + "total_tokens": local_total_tokens + if local_token_observed + else None, + }, + ) results: Dict[int, Tuple[List[List[str]], Dict[str, str]]] = {} yaml_lock = threading.Lock() @@ -915,7 +1013,9 @@ def process_one_chunk(idx: int, chunk: str) -> Tuple[int, List[List[str]], Dict[ executor.submit(process_one_chunk, idx, chunk): idx for idx, chunk in enumerate(self.text_chunks) } - for fut in tqdm(as_completed(future_to_idx), total=n_chunks, desc="LLM chunks"): + for fut in tqdm( + as_completed(future_to_idx), total=n_chunks, desc="LLM chunks" + ): idx, new_triplets, entity2type_local, metric = fut.result() results[idx] = (new_triplets, entity2type_local) @@ -940,7 +1040,9 @@ def process_one_chunk(idx: int, chunk: str) -> Tuple[int, List[List[str]], Dict[ each_dataset[file_name]["parsing_rate"] = completed / n_chunks output_file = [each_dataset[k] for k in each_dataset] final_schema = {"datasets": output_file} - with open(each_dataset_schema_file_path, "w", encoding="utf-8") as f: + with open( + each_dataset_schema_file_path, "w", encoding="utf-8" + ) as f: yaml.dump(final_schema, f, sort_keys=False, allow_unicode=True) triplets: List[List[str]] = [] @@ -977,7 +1079,9 @@ def process_one_chunk(idx: int, chunk: str) -> Tuple[int, List[List[str]], Dict[ } return triplets, entity2id, entity2type, parse_metrics - def _ask_entity_type(self, entity: str, chunk: str, MAX_RETRIES = 5) -> Tuple[str, Dict[str, Any]]: + def _ask_entity_type( + self, entity: str, chunk: str, MAX_RETRIES=5 + ) -> Tuple[str, Dict[str, Any]]: """向 LLM 查询实体类型,并返回该步骤的调用指标。""" prompt = prompt_entity_type.format(entity=entity, text_context=chunk) call_count = 0 @@ -1010,7 +1114,9 @@ def _ask_entity_type(self, entity: str, chunk: str, MAX_RETRIES = 5) -> Tuple[st "retry_count": retry_count, "llm_elapsed_ms": llm_elapsed_ms, "prompt_tokens": prompt_tokens_sum if token_observed else None, - "completion_tokens": completion_tokens_sum if token_observed else None, + "completion_tokens": completion_tokens_sum + if token_observed + else None, "total_tokens": total_tokens_sum if token_observed else None, } return response.strip().capitalize(), metric @@ -1045,7 +1151,9 @@ def save_graph_with_entity( os.makedirs(process_dir, exist_ok=True) entities_csv_path = os.path.join(process_dir, f"{self.graph_name}_accounts.csv") - triplets_csv_path = os.path.join(process_dir, f"{self.graph_name}_transactions.csv") + triplets_csv_path = os.path.join( + process_dir, f"{self.graph_name}_transactions.csv" + ) # 过滤非法三元组,并确保端点实体在 entity2id 中有映射 valid_triplets: List[List[str]] = [] @@ -1152,4 +1260,4 @@ def save_graph_with_entity( with open(graph_schema_file, "w", encoding="utf-8") as f: yaml.dump({"datasets": existing}, f, sort_keys=False, allow_unicode=True) - return new_schema \ No newline at end of file + return new_schema diff --git a/aag/dataload_main.py b/aag/dataload_main.py index 763a317..8060637 100644 --- a/aag/dataload_main.py +++ b/aag/dataload_main.py @@ -69,4 +69,4 @@ def main() -> int: if __name__ == "__main__": - exit(main()) + sys.exit(main()) diff --git a/aag/main.py b/aag/main.py index 1e8ed88..16c1538 100644 --- a/aag/main.py +++ b/aag/main.py @@ -18,9 +18,7 @@ logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(funcName)s(): %(message)s", - handlers=[ - logging.StreamHandler(sys.stdout) - ] + handlers=[logging.StreamHandler(sys.stdout)], ) logger = logging.getLogger(__name__) @@ -36,10 +34,10 @@ def get_user_prompt(current_mode: str) -> str: """ 根据当前模式生成用户输入提示符 - + Args: current_mode: 当前模式 "normal" | "interact" | "expert" - + Returns: 格式化的提示符字符串 """ @@ -54,9 +52,9 @@ def print_dag_info(dag_info: Dict[str, Any]) -> None: """打印DAG信息(格式化输出)""" if not dag_info: return - + print("\n📊 --- DAG 信息 ---") - + # 打印子查询计划 if "subquery_plan" in dag_info: plan = dag_info["subquery_plan"] @@ -64,68 +62,85 @@ def print_dag_info(dag_info: Dict[str, Any]) -> None: print("\n子查询列表:") for i, subq in enumerate(plan["subqueries"], 1): print(f" {i}. [{subq.get('id', '?')}] {subq.get('query', '')}") - deps = subq.get('depends_on', []) + deps = subq.get("depends_on", []) if deps: print(f" 依赖: {', '.join(deps)}") - + # 打印步骤信息 if "steps" in dag_info: print("\n步骤详情:") for step_id, step_info in dag_info["steps"].items(): print(f" [{step_id}] {step_info.get('question', '')}") - if step_info.get('algorithm'): + if step_info.get("algorithm"): print(f" 算法: {step_info['algorithm']}") - if step_info.get('task_type'): + if step_info.get("task_type"): print(f" 类型: {step_info['task_type']}") - + # 打印拓扑顺序 if "topological_order" in dag_info: order = dag_info["topological_order"] print(f"\n执行顺序: {' → '.join(order)}") - + print("-" * 74) def parse_arguments(): """解析命令行参数""" - parser = argparse.ArgumentParser(description="Analytics Augmented Generation Engine - 端到端分析增强生成框架") - + parser = argparse.ArgumentParser( + description="Analytics Augmented Generation Engine - 端到端分析增强生成框架" + ) + # 基础配置 - parser.add_argument("--mode", choices=["interactive", "batch", "process"], - default="interactive", help="运行模式") + parser.add_argument( + "--mode", + choices=["interactive", "batch", "process"], + default="interactive", + help="运行模式", + ) parser.add_argument("--config", type=str, help="配置文件路径") - + # 数据库配置 - parser.add_argument("--graph-space", type=str, default="graph_space", - help="图数据库空间名称") - parser.add_argument("--vector-collection", type=str, default="vector_collection", - help="向量数据库集合名称") - + parser.add_argument( + "--graph-space", type=str, default="graph_space", help="图数据库空间名称" + ) + parser.add_argument( + "--vector-collection", + type=str, + default="vector_collection", + help="向量数据库集合名称", + ) + # 模型配置 - parser.add_argument("--llm-model", type=str, default="llama3.1:70b", - help="LLM模型名称") - parser.add_argument("--embedding-model", type=str, default="BAAI/bge-large-en-v1.5", - help="嵌入模型名称") - parser.add_argument("--llm-type", choices=["ollama", "openai"], default="ollama", - help="LLM类型") - + parser.add_argument( + "--llm-model", type=str, default="llama3.1:70b", help="LLM模型名称" + ) + parser.add_argument( + "--embedding-model", + type=str, + default="BAAI/bge-large-en-v1.5", + help="嵌入模型名称", + ) + parser.add_argument( + "--llm-type", choices=["ollama", "openai"], default="ollama", help="LLM类型" + ) + # 设备配置 - parser.add_argument("--llm-device", type=str, default="cuda:0", - help="LLM设备") - parser.add_argument("--embed-device", type=str, default="cuda:0", - help="嵌入模型设备") - + parser.add_argument("--llm-device", type=str, default="cuda:0", help="LLM设备") + parser.add_argument( + "--embed-device", type=str, default="cuda:0", help="嵌入模型设备" + ) + # RAG配置 - parser.add_argument("--graph-k-hop", type=int, default=2, - help="图遍历跳数") - parser.add_argument("--vector-k-similarity", type=int, default=5, - help="向量相似度检索数量") - + parser.add_argument("--graph-k-hop", type=int, default=2, help="图遍历跳数") + parser.add_argument( + "--vector-k-similarity", type=int, default=5, help="向量相似度检索数量" + ) + # 输入输出 parser.add_argument("--input-file", type=str, help="输入文件路径") parser.add_argument("--output-file", type=str, help="输出文件路径") parser.add_argument("--questions", nargs="+", help="问题列表") - + return parser.parse_args() @@ -144,11 +159,11 @@ async def interactive_mode(engine: AAGEngine): try: question = input(f"\n{get_user_prompt(current_mode)}").strip() - if question.lower() in ['quit', 'exit', 'q']: + if question.lower() in ["quit", "exit", "q"]: print("🐾 AAG小助手退下喵~ 再见! (ฅ'ω'ฅ)") break - elif question.lower().startswith('mode '): + elif question.lower().startswith("mode "): mode_arg = question[5:].strip().lower() if mode_arg in ["normal", "interact", "expert"]: old_mode = current_mode @@ -161,14 +176,18 @@ async def interactive_mode(engine: AAGEngine): elif current_mode == "interact": print(" 提示:生成DAG后可 modify/start 交互调整") else: - print(" 提示:输入自然语言专家指令,系统会构建DAG并校验算法边界") + print( + " 提示:输入自然语言专家指令,系统会构建DAG并校验算法边界" + ) print("-" * 74) else: - print("⚠️ 无效模式,请使用 'mode normal' / 'mode interact' / 'mode expert'") + print( + "⚠️ 无效模式,请使用 'mode normal' / 'mode interact' / 'mode expert'" + ) print("-" * 74) continue - elif question.lower() == 'stats': + elif question.lower() == "stats": response = engine.get_performance_summary() print("\n📊 --- 性能统计 ---") for key, value in response.items(): @@ -176,31 +195,33 @@ async def interactive_mode(engine: AAGEngine): print("-" * 74) continue - elif question.lower() in ['datasets', 'list', 'list datasets']: + elif question.lower() in ["datasets", "list", "list datasets"]: try: ds_map = engine.list_datasets() print("\n📁 --- 可用数据集 ---") for dtype, names in ds_map.items(): - print(f"{dtype} ({len(names)}): {', '.join(names) if names else '(empty)'}") + print( + f"{dtype} ({len(names)}): {', '.join(names) if names else '(empty)'}" + ) print("-" * 74) except Exception as e: print(f"⚠️ 列出数据集失败: {e}") print("-" * 74) continue - elif question.lower().startswith('use '): + elif question.lower().startswith("use "): cmd = question[4:].strip() dtype = None name = cmd - if ':' in cmd: - parts = cmd.split(':', 1) + if ":" in cmd: + parts = cmd.split(":", 1) if len(parts) == 2: dtype, name = parts[0].strip(), parts[1].strip() else: toks = cmd.split() if len(toks) >= 2: - name = ' '.join(toks[:-1]).strip() + name = " ".join(toks[:-1]).strip() dtype = toks[-1].strip() try: @@ -209,7 +230,10 @@ async def interactive_mode(engine: AAGEngine): scope = dtype if dtype else "graph/table/text" print(f"❌ 未找到数据集: '{name}' (搜索范围: {scope})") else: - print(f"✅ 已选择数据集: {name}" + (f" ({dtype})" if dtype else "")) + print( + f"✅ 已选择数据集: {name}" + + (f" ({dtype})" if dtype else "") + ) dag_built = False print("-" * 74) except Exception as e: @@ -217,13 +241,17 @@ async def interactive_mode(engine: AAGEngine): print("-" * 74) continue - elif question.lower() in ['help', 'h']: + elif question.lower() in ["help", "h"]: print("\n📌 === 帮助菜单 ===") print("通用命令:") print(" 📊 stats 显示性能统计") print(" 📁 datasets | list 列出所有可用数据集") - print(" 🗂 use 选定数据集 (自动推断类型)") - print(" 🗂 use 指定类型 (graph/table/text)") + print( + " 🗂 use 选定数据集 (自动推断类型)" + ) + print( + " 🗂 use 指定类型 (graph/table/text)" + ) print(" 🗂 use : dtype:name 形式选择") print(" 🔄 mode normal|interact|expert 切换执行模式") print(" ❓ help | h 显示帮助") @@ -236,13 +264,17 @@ async def interactive_mode(engine: AAGEngine): if current_mode == "expert": print("\n专家模式输入示例:") - print(" 自然语言: 先找节点23所在社区,再在社区里用pagerank找前10个关键节点") + print( + " 自然语言: 先找节点23所在社区,再在社区里用pagerank找前10个关键节点" + ) print(f"\n当前模式: {mode_label[current_mode]}模式") print("\n问题前缀示例:") print(" normal: 找出节点45的社区") print(" interact: 找出节点45的社区") - print(" expert: 先找节点23所在社区,再在社区里用pagerank找前10关键节点") + print( + " expert: 先找节点23所在社区,再在社区里用pagerank找前10关键节点" + ) print("-" * 74) continue @@ -270,7 +302,9 @@ async def interactive_mode(engine: AAGEngine): current_mode = question_mode dag_built = False - if current_mode == "interact" and actual_question.lower().startswith("modify "): + if current_mode == "interact" and actual_question.lower().startswith( + "modify " + ): if not dag_built: print("⚠️ 请先输入问题生成DAG") continue @@ -286,7 +320,7 @@ async def interactive_mode(engine: AAGEngine): print(f"❌ {result['error']}") else: print(f"✅ {result.get('message', 'DAG已更新')}") - print_dag_info(result.get('dag_info', {})) + print_dag_info(result.get("dag_info", {})) print("\n请选择下一步操作:") print(" 🔧 modify 修改DAG") print(" ▶️ start 开始分析") @@ -296,7 +330,11 @@ async def interactive_mode(engine: AAGEngine): print("-" * 74) continue - if current_mode in {"interact", "expert"} and actual_question.lower() in ['start', 'analyze', '开始分析']: + if current_mode in {"interact", "expert"} and actual_question.lower() in [ + "start", + "analyze", + "开始分析", + ]: if not dag_built: print("⚠️ 请先输入问题生成DAG") continue @@ -326,7 +364,7 @@ async def interactive_mode(engine: AAGEngine): dag_built = False else: print(f"✅ {result.get('message', 'DAG已生成')}") - print_dag_info(result.get('dag_info', {})) + print_dag_info(result.get("dag_info", {})) print("\n请选择下一步操作:") print(" 🔧 modify 修改DAG") print(" ▶️ start 开始分析") @@ -344,19 +382,23 @@ async def interactive_mode(engine: AAGEngine): dag_built = False else: print(f"✅ {result.get('message', '专家DAG处理完成')}") - print_dag_info(result.get('dag_info', {})) + print_dag_info(result.get("dag_info", {})) validation = result.get("algorithm_validation", {}) unsupported = validation.get("unsupported_algorithms", []) if unsupported: print("\n⚠️ 以下算法不在算法库中:") for item in unsupported: - print(f" - {item.get('query_id')}: {item.get('requested_algorithm')}") + print( + f" - {item.get('query_id')}: {item.get('requested_algorithm')}" + ) suggestions = item.get("suggestions") or [] if suggestions: print(f" 建议: {', '.join(suggestions)}") - instruction_adjustments = validation.get("instruction_algorithm_adjustments", []) + instruction_adjustments = validation.get( + "instruction_algorithm_adjustments", [] + ) if instruction_adjustments: print("\nℹ️ 专家指令算法替换说明:") for item in instruction_adjustments: @@ -389,35 +431,38 @@ async def interactive_mode(engine: AAGEngine): print(f"⚠️ 处理分析时出错: {e}") -def batch_mode(engine: AAGEngine, questions: List[str], output_file: Optional[str] = None): - """批处理模式""" +async def batch_mode( + engine: AAGEngine, questions: List[str], output_file: Optional[str] = None +): + """批处理模式(异步)""" print(f"=== AAG Pipeline 批处理模式 ===") print(f"处理 {len(questions)} 个问题") - + results = [] - + for i, question in enumerate(questions, 1): print(f"\n处理问题 {i}/{len(questions)}: {question}") - + try: - result = engine.run(question) + result = await engine.run(question) results.append(result) - - + except Exception as e: print(f"✗ 处理失败: {e}") - results.append({ - "question": question, - "raise error": f"处理失败: {e}", - }) - + results.append( + { + "question": question, + "raise error": f"处理失败: {e}", + } + ) + # 输出结果 if output_file: import json - with open(output_file, 'w', encoding='utf-8') as f: + + with open(output_file, "w", encoding="utf-8") as f: json.dump(results, f, ensure_ascii=False, indent=2) print(f"\n结果已保存到: {output_file}") - async def main(): @@ -429,7 +474,9 @@ async def main(): raise FileNotFoundError(f"默认配置文件未找到: {config_path}") print(f"{CYAN}\n{'=' * 74}{RESET}") - print(f"{BOLD} 🧠 欢迎使用 AAG 智能分析系统 (Analytics Augmented Generation Engine) {RESET}") + print( + f"{BOLD} 🧠 欢迎使用 AAG 智能分析系统 (Analytics Augmented Generation Engine) {RESET}" + ) print(f"{CYAN}{'=' * 74}{RESET}") print(f" 配置文件: {config_path}") print(f" 启动时间 : {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") @@ -452,7 +499,9 @@ async def main(): if mode == "interactive": print(f"{BOLD}💬 当前运行模式:交互模式 (Interactive Mode){RESET}") print("-" * 74) - print(" 输入问题按 Enter 分析;命令:stats | datasets | use [dtype] | mode | help | quit") + print( + " 输入问题按 Enter 分析;命令:stats | datasets | use [dtype] | mode | help | quit" + ) print(" 支持 normal / interact / expert 三种模式,使用 'mode ' 切换") print(f"{CYAN}{'=' * 74}{RESET}") # interactive_mode(engine) @@ -467,9 +516,11 @@ async def main(): questions = config.get("questions") output_file = config.get("output_file", "results.json") if not questions: - print(f"{YELLOW}错误:批处理模式需要在配置文件中提供 'questions' 字段{RESET}") + print( + f"{YELLOW}错误:批处理模式需要在配置文件中提供 'questions' 字段{RESET}" + ) return 1 - batch_mode(engine, questions, output_file) + await batch_mode(engine, questions, output_file) else: print(f"{YELLOW}❌ 未知的运行模式: {mode}{RESET}") @@ -486,17 +537,17 @@ async def main(): return 0 -def main_delay(): - """主函数""" +async def main_delay(): + """主函数(异步)""" args = parse_arguments() - + try: # 创建配置 - 优先从配置文件读取,否则使用命令行参数 if args.config: print(f"从配置文件加载配置: {args.config}") config = load_config_from_yaml(args.config) else: - # TODO(chaoyi): 数过多, llm 现在不支持通过命令行适配 可选的llm,后期需要修正 + # TODO(chaoyi): 数过多, llm 现在不支持通过命令行适配 可选的llm,后期需要修正 print("使用命令行参数创建配置") config = create_engine_config( graph_space_name=args.graph_space, @@ -507,30 +558,30 @@ def main_delay(): ollama_device=args.llm_device, embed_device=args.embed_device, graph_k_hop=args.graph_k_hop, - vector_k_similarity=args.vector_k_similarity + vector_k_similarity=args.vector_k_similarity, ) - + # 初始化Engine print("正在初始化 AAG Engine...") engine = AAGEngine(config) print("✓ Engine 初始化完成") - + # 根据模式运行 if args.mode == "interactive": - interactive_mode(engine) + await interactive_mode(engine) elif args.mode == "batch": if not args.questions: print("批处理模式需要提供问题列表,使用 --questions 参数") return - batch_mode(engine, args.questions, args.output_file) - + await batch_mode(engine, args.questions, args.output_file) + # 清理资源 - engine.shutdown() - + await engine.shutdown() + except Exception as e: print(f"运行出错: {e}") return 1 - + return 0 diff --git a/aag/reasoner/model_deployment.py b/aag/reasoner/model_deployment.py index bb139c9..9744200 100644 --- a/aag/reasoner/model_deployment.py +++ b/aag/reasoner/model_deployment.py @@ -4,7 +4,7 @@ import openai import json import re -import requests # deepseek +import requests # deepseek from typing import Dict, Literal, Any, List, Optional from llama_index.core import Settings from llama_index.core.utils import print_text @@ -22,12 +22,10 @@ from aag.error_recovery.enhancer import enhance_prompt - - EMBEDD_DIMS = { "BAAI/bge-large-en-v1.5": 1024, "BAAI/bge-base-en-v1.5": 768, - "BAAI/bge-small-en-v1.5": 384 + "BAAI/bge-small-en-v1.5": 384, } DAG_REVISION_SYSTEM_PROMPT = ( @@ -40,21 +38,21 @@ "4. Output must strictly follow the JSON schema below, with NO explanations and NO extra text.\n\n" "### Required Output JSON Schema:\n" "{\n" - " \"subqueries\": [\n" + ' "subqueries": [\n' " {\n" - " \"id\": \"qX\",\n" - " \"query\": \"text\",\n" - " \"depends_on\": []\n" + ' "id": "qX",\n' + ' "query": "text",\n' + ' "depends_on": []\n' " }\n" " ]\n" "}\n\n" "### Example Input (Before Modification):\n" "{\n" - " \"subqueries\": [\n" - " {\"id\": \"q1\", \"query\": \"Check if user Anna is a high-risk user.\", \"depends_on\": []},\n" - " {\"id\": \"q2\", \"query\": \"List all potential money laundering pathways around Anna.\", \"depends_on\": [\"q1\"]},\n" - " {\"id\": \"q3\", \"query\": \"Estimate the amount of cash that may have been illegally transferred out in relation to Anna.\", \"depends_on\": [\"q2\"]},\n" - " {\"id\": \"q4\", \"query\": \"Find the account with the largest transaction amount in the suspicious paths.\", \"depends_on\": [\"q2\"]}\n" + ' "subqueries": [\n' + ' {"id": "q1", "query": "Check if user Anna is a high-risk user.", "depends_on": []},\n' + ' {"id": "q2", "query": "List all potential money laundering pathways around Anna.", "depends_on": ["q1"]},\n' + ' {"id": "q3", "query": "Estimate the amount of cash that may have been illegally transferred out in relation to Anna.", "depends_on": ["q2"]},\n' + ' {"id": "q4", "query": "Find the account with the largest transaction amount in the suspicious paths.", "depends_on": ["q2"]}\n' " ]\n" "}\n\n" "### Example User Edit Instruction:\n" @@ -62,18 +60,20 @@ "Modify node 4 so that it becomes: identify the account with the largest transaction amount within Anna’s community.\n\n" "### Example Output:\n" "{\n" - " \"subqueries\": [\n" - " {\"id\": \"q1\", \"query\": \"Check if user Anna is a high-risk user.\", \"depends_on\": []},\n" - " {\"id\": \"q2\", \"query\": \"Identify the potential fraud community in which Anna resides to narrow the scope of subsequent risk monitoring.\", \"depends_on\": [\"q1\"]},\n" - " {\"id\": \"q3\", \"query\": \"List all potential money laundering pathways within the high-risk community where Anna is located.\", \"depends_on\": [\"q2\"]},\n" - " {\"id\": \"q4\", \"query\": \"Estimate the amount of cash that may have been illegally transferred out in relation to Anna.\", \"depends_on\": [\"q3\"]},\n" - " {\"id\": \"q5\", \"query\": \"Identify the account with the largest transaction amount within Anna's fraud community.\", \"depends_on\": [\"q2\"]}\n" + ' "subqueries": [\n' + ' {"id": "q1", "query": "Check if user Anna is a high-risk user.", "depends_on": []},\n' + ' {"id": "q2", "query": "Identify the potential fraud community in which Anna resides to narrow the scope of subsequent risk monitoring.", "depends_on": ["q1"]},\n' + ' {"id": "q3", "query": "List all potential money laundering pathways within the high-risk community where Anna is located.", "depends_on": ["q2"]},\n' + ' {"id": "q4", "query": "Estimate the amount of cash that may have been illegally transferred out in relation to Anna.", "depends_on": ["q3"]},\n' + ' {"id": "q5", "query": "Identify the account with the largest transaction amount within Anna\'s fraud community.", "depends_on": ["q2"]}\n' " ]\n" "}" ) -def build_dag_revision_user_prompt(current_plan: Dict[str, Any], user_request: str) -> str: +def build_dag_revision_user_prompt( + current_plan: Dict[str, Any], user_request: str +) -> str: plan_text = json.dumps(current_plan, ensure_ascii=False, indent=2) normalized_request = user_request.strip() return ( @@ -82,8 +82,8 @@ def build_dag_revision_user_prompt(current_plan: Dict[str, Any], user_request: s "User request:\n" f"{normalized_request}\n\n" "Update the plan so it satisfies the request. Rules:\n" - "1. Output JSON only with the shape {\"subqueries\": [...]}.\n" - "2. Each entry needs \"id\", \"query\", and \"depends_on\" (list) fields.\n" + '1. Output JSON only with the shape {"subqueries": [...]}.\n' + '2. Each entry needs "id", "query", and "depends_on" (list) fields.\n' "3. Preserve existing ids when editing content; introduce new ids only for new steps.\n" "4. Keep dependencies acyclic and align them with the described changes.\n" "5. If the user removes or inserts nodes between two ids, reflect that explicitly.\n" @@ -92,32 +92,31 @@ def build_dag_revision_user_prompt(current_plan: Dict[str, Any], user_request: s class EmbeddingEnv: - - def __init__(self, - embed_name="BAAI/bge-large-en-v1.5", - embed_batch_size=20, - device="cuda:0"): + def __init__( + self, embed_name="BAAI/bge-large-en-v1.5", embed_batch_size=20, device="cuda:0" + ): self.embed_name = embed_name self.embed_batch_size = embed_batch_size assert embed_name in EMBEDD_DIMS self.dim = EMBEDD_DIMS[embed_name] - if 'BAAI' in embed_name: + if "BAAI" in embed_name: print(f"use huggingface embedding {embed_name}") self.embed_model = HuggingFaceEmbedding( - model_name=embed_name, - embed_batch_size=embed_batch_size, - device=device) + model_name=embed_name, embed_batch_size=embed_batch_size, device=device + ) else: print(f"use openai embedding {embed_name}") self.embed_model = OpenAIEmbedding( - model=embed_name, embed_batch_size=embed_batch_size) + model=embed_name, embed_batch_size=embed_batch_size + ) Settings.embed_model = self.embed_model print_text( f"EmbeddingEnv: embed_name {embed_name}, embed_batch_size {self.embed_batch_size}, dim {self.dim}\n", - color='red') + color="red", + ) def __str__(self): return f"{self.embed_name} {self.embed_batch_size}" @@ -157,41 +156,47 @@ def calculate_similarity(self, query1, query2): class OllamaEnv: - ROLE_MAP = { "system": MessageRole.SYSTEM, "user": MessageRole.USER, "assistant": MessageRole.ASSISTANT, } - def __init__(self, - llm_mode_name="llama3.1:70b", - llm_embed_name="BAAI/bge-large-en-v1.5", - chunk_size=512, - chunk_overlap=20, - embed_batch_size=20, - device='cuda:2', - timeout=150000, - port=11434, - verbose=False): - + def __init__( + self, + llm_mode_name="llama3.1:70b", + llm_embed_name="BAAI/bge-large-en-v1.5", + chunk_size=512, + chunk_overlap=20, + embed_batch_size=20, + device="cuda:2", + timeout=150000, + port=11434, + verbose=False, + ): base_url = f"http://localhost:{port}" - Settings.llm = Ollama(model=llm_mode_name, request_timeout=timeout, - temperature=0.0, base_url=base_url) # , device=device + Settings.llm = Ollama( + model=llm_mode_name, + request_timeout=timeout, + temperature=0.0, + base_url=base_url, + ) # , device=device self.verbose = verbose - if 'BAAI' in llm_embed_name: + if "BAAI" in llm_embed_name: Settings.embed_model = HuggingFaceEmbedding( model_name=llm_embed_name, embed_batch_size=embed_batch_size, - device=device) + device=device, + ) else: print(f"use openai embedding {llm_embed_name}") Settings.embed_model = OpenAIEmbedding( - model=llm_embed_name, embed_batch_size=embed_batch_size) + model=llm_embed_name, embed_batch_size=embed_batch_size + ) if llm_embed_name not in EMBEDD_DIMS.keys(): - raise NotImplementedError('embed model not support!') + raise NotImplementedError("embed model not support!") self.llm = Settings.llm self.embed_model = Settings.embed_model @@ -199,15 +204,15 @@ def __init__(self, Settings.chunk_size = chunk_size Settings.chunk_overlap = chunk_overlap print( - f"llm_mode_name: {llm_mode_name}, llm_embed_name: {llm_embed_name}, chunk_size: {Settings.chunk_size}, chunk_overlap: {chunk_overlap}") - + f"llm_mode_name: {llm_mode_name}, llm_embed_name: {llm_embed_name}, chunk_size: {Settings.chunk_size}, chunk_overlap: {chunk_overlap}" + ) def chat(self, messages: list) -> str: if messages and isinstance(messages[0], dict): messages = [ ChatMessage( role=self.ROLE_MAP.get(m["role"], MessageRole.USER), - content=m.get("content", "") + content=m.get("content", ""), ) for m in messages ] @@ -219,15 +224,20 @@ def chat(self, messages: list) -> str: def generate_response(self, query: str): response = self.llm.complete(query) return response - - def execute_prompt(self, full_prompt: str, parse_json: bool = True, response_format: Optional[Dict] = None) -> Any: + + def execute_prompt( + self, + full_prompt: str, + parse_json: bool = True, + response_format: Optional[Dict] = None, + ) -> Any: """Execute a formatted prompt and return the response. - + Args: full_prompt: The fully formatted prompt string parse_json: Whether to parse the response as JSON response_format: Optional response format dict (for OpenAI compatibility) - + Returns: Parsed JSON dict if parse_json=True, otherwise raw response text """ @@ -235,20 +245,17 @@ def execute_prompt(self, full_prompt: str, parse_json: bool = True, response_for if parse_json: return extract_json_from_response(response.text) return response.text - + def check_data_dependency( - self, - q1_question: str, - q1_algorithm: str, - q2_question: str, - q2_algorithm: str) -> bool: + self, q1_question: str, q1_algorithm: str, q2_question: str, q2_algorithm: str + ) -> bool: """Determine whether Q2 depends on the result of Q1 using the LLM.""" try: full_prompt = check_data_dependency_prompt.format( q1_question=q1_question, q1_algorithm=q1_algorithm, q2_question=q2_question, - q2_algorithm=q2_algorithm + q2_algorithm=q2_algorithm, ) result = self.execute_prompt(full_prompt, parse_json=True) depends = result.get("q2_depends_on_q1") @@ -263,15 +270,11 @@ def check_data_dependency( except Exception as e: print(f"Error determining data dependency with Ollama: {e}") return False - + def plan_subqueries(self, decompose: bool, query: str) -> dict: - if decompose == False: + if decompose == False: # Do not decompose, treat as a single question - return{"subqueries": [{ - "id": "q1", - "query": query, - "depends_on": [] - }]} + return {"subqueries": [{"id": "q1", "query": query, "depends_on": []}]} full_prompt = plan_subqueries_prompt.format(query=query) return self.execute_prompt(full_prompt, parse_json=True) @@ -280,69 +283,83 @@ def classify_question_type(self, question: str) -> dict: full_prompt = classify_question_type_prompt.format(question=question) return self.execute_prompt(full_prompt, parse_json=True) - - def revise_subquery_plan(self, current_plan: Dict[str, Any], user_request: str) -> Dict[str, Any]: + def revise_subquery_plan( + self, current_plan: Dict[str, Any], user_request: str + ) -> Dict[str, Any]: full_prompt = revise_subquery_plan_prompt.format( current_plan=json.dumps(current_plan, ensure_ascii=False), - user_request=user_request + user_request=user_request, ) return self.execute_prompt(full_prompt, parse_json=True) def select_task_type(self, question: str, task_type_list: list) -> dict: full_prompt = select_task_type_prompt.format( - question=question, - task_type_list=task_type_list + question=question, task_type_list=task_type_list ) return self.execute_prompt(full_prompt, parse_json=True) - - def select_algorithm(self, question: str, algorithm_list: list, graph_schema: Optional[Dict[str, Any]] = None) -> dict: + def select_algorithm( + self, + question: str, + algorithm_list: list, + graph_schema: Optional[Dict[str, Any]] = None, + ) -> dict: schema_context = "" if graph_schema: schema_context = f""" Current Graph Dataset Schema: -- Dataset: {graph_schema.get('dataset_name', 'Unknown')} -- Graph Type: {'Directed' if graph_schema.get('graph_properties', {}).get('directed') else 'Undirected'}, {'Heterogeneous' if graph_schema.get('graph_properties', {}).get('heterogeneous') else 'Homogeneous'}, {'Multigraph' if graph_schema.get('graph_properties', {}).get('multigraph') else 'Simple'}, {'Weighted' if graph_schema.get('graph_properties', {}).get('weighted') else 'Unweighted'} -- Vertex Types: {', '.join(graph_schema.get('vertex_types', []))} -- Edge Types: {', '.join(graph_schema.get('edge_types', []))} -- Vertex Configurations: {json.dumps(graph_schema.get('vertex_configs', []), ensure_ascii=False, indent=2)} -- Edge Configurations: {json.dumps(graph_schema.get('edge_configs', []), ensure_ascii=False, indent=2)} +- Dataset: {graph_schema.get("dataset_name", "Unknown")} +- Graph Type: {"Directed" if graph_schema.get("graph_properties", {}).get("directed") else "Undirected"}, {"Heterogeneous" if graph_schema.get("graph_properties", {}).get("heterogeneous") else "Homogeneous"}, {"Multigraph" if graph_schema.get("graph_properties", {}).get("multigraph") else "Simple"}, {"Weighted" if graph_schema.get("graph_properties", {}).get("weighted") else "Unweighted"} +- Vertex Types: {", ".join(graph_schema.get("vertex_types", []))} +- Edge Types: {", ".join(graph_schema.get("edge_types", []))} +- Vertex Configurations: {json.dumps(graph_schema.get("vertex_configs", []), ensure_ascii=False, indent=2)} +- Edge Configurations: {json.dumps(graph_schema.get("edge_configs", []), ensure_ascii=False, indent=2)} Please consider this schema when selecting the algorithm to ensure compatibility. """ - - response = self.llm.complete(select_algorithm_prompt.format( - question=question, - algorithm_list=algorithm_list - ) + schema_context) + + response = self.llm.complete( + select_algorithm_prompt.format( + question=question, algorithm_list=algorithm_list + ) + + schema_context + ) return extract_json_from_response(response.text) - - - def extract_parameters_with_postprocess(self, question: str, tool_description: str) -> dict: + + def extract_parameters_with_postprocess( + self, question: str, tool_description: str + ) -> dict: full_prompt = extract_parameters_with_postprocess_promt.format( - question=question, - tool_description=tool_description + question=question, tool_description=tool_description ) return self.execute_prompt(full_prompt, parse_json=True) - - def extract_parameters_with_postprocess_new(self, question: str, tool_description: str, vertex_schema: Dict[str, str], edge_schema: Dict[str, str]) -> dict: + + def extract_parameters_with_postprocess_new( + self, + question: str, + tool_description: str, + vertex_schema: Dict[str, str], + edge_schema: Dict[str, str], + ) -> dict: """Extract parameters and generate post-processing code with vertex and edge schema information.""" - response = self.llm.complete(extract_parameters_with_postprocess_promt_new.format( - question=question, - tool_description=tool_description, - vertex_schema=json.dumps(vertex_schema, indent=2), - edge_schema=json.dumps(edge_schema, indent=2) - )) + response = self.llm.complete( + extract_parameters_with_postprocess_promt_new.format( + question=question, + tool_description=tool_description, + vertex_schema=json.dumps(vertex_schema, indent=2), + edge_schema=json.dumps(edge_schema, indent=2), + ) + ) return extract_json_from_response(response.text) - + def merge_parameters_from_dependencies( - self, - question: str, - tool_description: str, - vertex_schema: Dict[str, str], + self, + question: str, + tool_description: str, + vertex_schema: Dict[str, str], edge_schema: Dict[str, str], - dependency_parameters: Dict[str, Any] + dependency_parameters: Dict[str, Any], ) -> dict: """Merge dependency parameters with extracted parameters and generate post-processing code.""" full_prompt = merge_parameters_with_dependencies_prompt.format( @@ -350,12 +367,13 @@ def merge_parameters_from_dependencies( tool_description=tool_description, dependency_parameters=json.dumps(dependency_parameters, indent=2), vertex_schema=json.dumps(vertex_schema, indent=2), - edge_schema=json.dumps(edge_schema, indent=2) + edge_schema=json.dumps(edge_schema, indent=2), ) return self.execute_prompt(full_prompt, parse_json=True) - - def generate_answer_from_algorithm_result(self, question: str, tool_description: str, tool_result: Dict[str, Any]) -> str: + def generate_answer_from_algorithm_result( + self, question: str, tool_description: str, tool_result: Dict[str, Any] + ) -> str: prompt = f""" You are a professional data analyst responsible for interpreting graph algorithm results. Your task is to analyze the computation output of a given tool based on: - The **natural-language user question** (*question*) @@ -447,68 +465,113 @@ def generate_answer_from_algorithm_result(self, question: str, tool_description: if not response_text: return "Unable to generate answer from the algorithm result." return response_text - - def analyze_dependency_type_and_locate_dependency_data(self, current_question:str, task_type:str, current_algo_desc:str, parent_question: str, parent_outputs_meta:list)-> dict: + + def analyze_dependency_type_and_locate_dependency_data( + self, + current_question: str, + task_type: str, + current_algo_desc: str, + parent_question: str, + parent_outputs_meta: list, + ) -> dict: full_prompt = analyze_dependency_type_and_locate_dependency_data_prompt.format( current_question=current_question, task_type=task_type, current_algo_desc=current_algo_desc, parent_question=parent_question, - parent_outputs_meta=parent_outputs_meta + parent_outputs_meta=parent_outputs_meta, ) return self.execute_prompt(full_prompt, parse_json=True) - - def map_parameters(self, current_question: str, current_algo_desc: str, dependency_items: List[Dict[str, Any]]) -> Dict[str, Any]: + + def map_parameters( + self, + current_question: str, + current_algo_desc: str, + dependency_items: List[Dict[str, Any]], + ) -> Dict[str, Any]: full_prompt = map_parameters_prompt.format( current_question=current_question, algo_desc=current_algo_desc, - dependency_items=json.dumps(dependency_items, ensure_ascii=False, indent=2) + dependency_items=json.dumps(dependency_items, ensure_ascii=False, indent=2), ) return self.execute_prompt(full_prompt, parse_json=True) - - def generate_graph_conversion_code(self, current_question: str, dependency_items: List[Dict[str, Any]])-> Dict[str, Any]: + + def generate_graph_conversion_code( + self, current_question: str, dependency_items: List[Dict[str, Any]] + ) -> Dict[str, Any]: full_prompt = generate_graph_conversion_code_prompt.format( current_question=current_question, - dependency_items=json.dumps(dependency_items, ensure_ascii=False, indent=2) + dependency_items=json.dumps(dependency_items, ensure_ascii=False, indent=2), ) return self.execute_prompt(full_prompt, parse_json=True) - + def generate_numeric_analysis_code( - self, - question: str, - dependency_items: List[Dict[str, Any]], - vertex_schema: Dict[str, str], - edge_schema: Dict[str, str] + self, + question: str, + dependency_items: List[Dict[str, Any]], + vertex_schema: Dict[str, str], + edge_schema: Dict[str, str], ) -> Dict[str, Any]: full_prompt = generate_numeric_analysis_code_prompt.format( question=question, - dependency_data_items=json.dumps(dependency_items, ensure_ascii=False, indent=2), + dependency_data_items=json.dumps( + dependency_items, ensure_ascii=False, indent=2 + ), vertex_schema=json.dumps(vertex_schema, ensure_ascii=False, indent=2), - edge_schema=json.dumps(edge_schema, ensure_ascii=False, indent=2) + edge_schema=json.dumps(edge_schema, ensure_ascii=False, indent=2), ) return self.execute_prompt(full_prompt, parse_json=True) - + def nl_query_classify_type(self, question: str, query_templates: dict) -> str: """Classify the query type for natural language query engine.""" full_prompt = nl_query_classify_type_prompt.format( question=question, - query_templates=json.dumps(query_templates, ensure_ascii=False, indent=2) + query_templates=json.dumps(query_templates, ensure_ascii=False, indent=2), ) response_text = self.execute_prompt(full_prompt, parse_json=False) return response_text.strip().strip('"').strip("'") - - def nl_query_extract_params(self, question: str, query_type: str, template: dict, - schema_info: str, query_modifiers: dict) -> dict: + + def nl_query_extract_params( + self, + question: str, + query_type: str, + template: dict, + schema_info: str, + query_modifiers: dict, + ) -> dict: """Extract parameters for natural language query engine.""" full_prompt = nl_query_extract_params_prompt.format( schema_info=schema_info, query_type=query_type, - template_description=template['description'], - template_method=template['method'], - required_params=template['required_params'], - optional_params=template.get('optional_params', []), + template_description=template["description"], + template_method=template["method"], + required_params=template["required_params"], + optional_params=template.get("optional_params", []), query_modifiers=json.dumps(query_modifiers, ensure_ascii=False, indent=2), - question=question + question=question, + ) + return self.execute_prompt(full_prompt, parse_json=True) + + # add gjq + def nl_query_validate_cypher( + self, + cypher: str, + question: str, + query_type: str, + params: dict, + schema_info: str, + template_info: str, + template_cypher_example: str, + ) -> dict: + """校验自然语言查询引擎的 Cypher 语句""" + full_prompt = nl_query_validate_cypher_prompt.format( + schema_info=schema_info, + template_info=template_info, + template_cypher_example=template_cypher_example, + question=question, + query_type=query_type, + params=json.dumps(params, ensure_ascii=False, indent=2), + cypher=cypher, ) return self.execute_prompt(full_prompt, parse_json=True) @@ -525,38 +588,23 @@ def _proxy_env_lines() -> List[str]: """用于排查 ConnectTimeout / TLS:是否走了 HTTP 代理。""" lines = [] for key in ( - "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "NO_PROXY", - "http_proxy", "https_proxy", "all_proxy", "no_proxy", + "HTTP_PROXY", + "HTTPS_PROXY", + "ALL_PROXY", + "NO_PROXY", + "http_proxy", + "https_proxy", + "all_proxy", + "no_proxy", ): val = os.environ.get(key) if val: lines.append(f"{key}={val!r}") return lines - # add gjq - def nl_query_validate_cypher(self, cypher: str, question: str, query_type: str, - params: dict, schema_info: str, template_info: str, - template_cypher_example: str) -> dict: - """Validate Cypher statement for natural language query engine.""" - full_prompt = nl_query_validate_cypher_prompt.format( - schema_info=schema_info, - template_info=template_info, - template_cypher_example=template_cypher_example, - question=question, - query_type=query_type, - params=json.dumps(params, ensure_ascii=False, indent=2), - cypher=cypher - ) - return self.execute_prompt(full_prompt, parse_json=True) - class OpenAIEnv: - - def __init__(self, - base_url, - api_key, - model_name, - temperature: float = 0.0): + def __init__(self, base_url, api_key, model_name, temperature: float = 0.0): self.base_url = base_url self.api_key = api_key self.model = model_name @@ -572,20 +620,28 @@ def __init__(self, ) px = _proxy_env_lines() if px: - print("[OpenAIEnv] proxy env (requests may use HTTP CONNECT / TLS via proxy):", flush=True) + print( + "[OpenAIEnv] proxy env (requests may use HTTP CONNECT / TLS via proxy):", + flush=True, + ) for line in px: print(f" {line}", flush=True) else: print("[OpenAIEnv] no HTTP(S)_PROXY / ALL_PROXY in environment", flush=True) - def execute_prompt(self, full_prompt: str, parse_json: bool = True, response_format: Optional[Dict] = None) -> Any: + def execute_prompt( + self, + full_prompt: str, + parse_json: bool = True, + response_format: Optional[Dict] = None, + ) -> Any: """Execute a formatted prompt and return the response. - + Args: full_prompt: The fully formatted prompt string parse_json: Whether to parse the response as JSON response_format: Optional response format dict (for OpenAI compatibility) - + Returns: Parsed JSON dict if parse_json=True, otherwise raw response text """ @@ -595,7 +651,7 @@ def execute_prompt(self, full_prompt: str, parse_json: bool = True, response_for "messages": messages, "temperature": self.temperature, } - + if response_format: request_kwargs["response_format"] = response_format elif parse_json: @@ -617,7 +673,10 @@ def execute_prompt(self, full_prompt: str, parse_json: bool = True, response_for flush=True, ) if getattr(e, "__cause__", None) is not None: - print(f"[OpenAIEnv] cause: {type(e.__cause__).__name__}: {e.__cause__}", flush=True) + print( + f"[OpenAIEnv] cause: {type(e.__cause__).__name__}: {e.__cause__}", + flush=True, + ) raise elapsed = time.perf_counter() - t0 response_text = response.choices[0].message.content @@ -632,20 +691,19 @@ def execute_prompt(self, full_prompt: str, parse_json: bool = True, response_for return response_text def check_data_dependency( - self, - q1_question: str, - q1_algorithm: str, - q2_question: str, - q2_algorithm: str) -> bool: + self, q1_question: str, q1_algorithm: str, q2_question: str, q2_algorithm: str + ) -> bool: """Use the llm to assess whether Q2 depends on Q1.""" try: full_prompt = check_data_dependency_prompt.format( q1_question=q1_question, q1_algorithm=q1_algorithm, q2_question=q2_question, - q2_algorithm=q2_algorithm + q2_algorithm=q2_algorithm, + ) + result = self.execute_prompt( + full_prompt, parse_json=True, response_format={"type": "json_object"} ) - result = self.execute_prompt(full_prompt, parse_json=True, response_format={"type": "json_object"}) depends = result.get("q2_depends_on_q1") if isinstance(depends, bool): return depends @@ -688,7 +746,10 @@ def generate_response(self, query: str): flush=True, ) if getattr(e, "__cause__", None) is not None: - print(f"[OpenAIEnv] cause: {type(e.__cause__).__name__}: {e.__cause__}", flush=True) + print( + f"[OpenAIEnv] cause: {type(e.__cause__).__name__}: {e.__cause__}", + flush=True, + ) # 常见:代理 TLS 握手超时、连接超时、读超时 msg_l = str(e).lower() if "timeout" in msg_l or "timed out" in msg_l or ename.endswith("Timeout"): @@ -697,26 +758,20 @@ def generate_response(self, query: str): flush=True, ) return None - - + def plan_subqueries(self, decompose: bool, query: str) -> dict: - if decompose == False: + if decompose == False: # Do not decompose, treat as a single question - return{"subqueries": [{ - "id": "q1", - "query": query, - "depends_on": [] - }]} + return {"subqueries": [{"id": "q1", "query": query, "depends_on": []}]} full_prompt = plan_subqueries_prompt.format(query=query) return self.execute_prompt(full_prompt, parse_json=True) def revise_subquery_plan( - self, - current_plan: Dict[str, Any], - user_request: str) -> Dict[str, Any]: + self, current_plan: Dict[str, Any], user_request: str + ) -> Dict[str, Any]: full_prompt = revise_subquery_plan_prompt.format( - current_plan=json.dumps(current_plan, ensure_ascii=False), - user_request=user_request + current_plan=json.dumps(current_plan, ensure_ascii=False), + user_request=user_request, ) return self.execute_prompt(full_prompt, parse_json=True) @@ -724,66 +779,81 @@ def classify_question_type(self, question: str) -> dict: """Classify whether a question requires graph algorithm or numeric analysis.""" full_prompt = classify_question_type_prompt.format(question=question) return self.execute_prompt(full_prompt, parse_json=True) - + def select_task_type(self, question: str, task_type_list: list) -> dict: full_prompt = select_task_type_prompt.format( - question=question, - task_type_list=task_type_list + question=question, task_type_list=task_type_list ) return self.execute_prompt(full_prompt, parse_json=True) - - def select_algorithm(self, question: str, algorithm_list: list, graph_schema: Optional[Dict[str, Any]] = None) -> dict: + + def select_algorithm( + self, + question: str, + algorithm_list: list, + graph_schema: Optional[Dict[str, Any]] = None, + ) -> dict: schema_context = "" if graph_schema: schema_context = f""" Current Graph Dataset Schema: -- Dataset: {graph_schema.get('dataset_name', 'Unknown')} -- Graph Type: {'Directed' if graph_schema.get('graph_properties', {}).get('directed') else 'Undirected'}, {'Heterogeneous' if graph_schema.get('graph_properties', {}).get('heterogeneous') else 'Homogeneous'}, {'Multigraph' if graph_schema.get('graph_properties', {}).get('multigraph') else 'Simple'}, {'Weighted' if graph_schema.get('graph_properties', {}).get('weighted') else 'Unweighted'} -- Vertex Types: {', '.join(graph_schema.get('vertex_types', []))} -- Edge Types: {', '.join(graph_schema.get('edge_types', []))} -- Vertex Configurations: {json.dumps(graph_schema.get('vertex_configs', []), ensure_ascii=False, indent=2)} -- Edge Configurations: {json.dumps(graph_schema.get('edge_configs', []), ensure_ascii=False, indent=2)} +- Dataset: {graph_schema.get("dataset_name", "Unknown")} +- Graph Type: {"Directed" if graph_schema.get("graph_properties", {}).get("directed") else "Undirected"}, {"Heterogeneous" if graph_schema.get("graph_properties", {}).get("heterogeneous") else "Homogeneous"}, {"Multigraph" if graph_schema.get("graph_properties", {}).get("multigraph") else "Simple"}, {"Weighted" if graph_schema.get("graph_properties", {}).get("weighted") else "Unweighted"} +- Vertex Types: {", ".join(graph_schema.get("vertex_types", []))} +- Edge Types: {", ".join(graph_schema.get("edge_types", []))} +- Vertex Configurations: {json.dumps(graph_schema.get("vertex_configs", []), ensure_ascii=False, indent=2)} +- Edge Configurations: {json.dumps(graph_schema.get("edge_configs", []), ensure_ascii=False, indent=2)} Please consider this schema when selecting the algorithm to ensure compatibility. """ - + response = self.client.chat.completions.create( model=self.model, - messages=[{"role": "user", "content": select_algorithm_prompt.format( - question=question, - algorithm_list=algorithm_list - ) + schema_context}], + messages=[ + { + "role": "user", + "content": select_algorithm_prompt.format( + question=question, algorithm_list=algorithm_list + ) + + schema_context, + } + ], temperature=self.temperature, ) response_text = response.choices[0].message.content return parse_openai_json_response(response_text, "select_algorithm") - - def extract_parameters_with_postprocess(self, question: str, tool_description: str) -> dict: + + def extract_parameters_with_postprocess( + self, question: str, tool_description: str + ) -> dict: full_prompt = extract_parameters_with_postprocess_promt.format( - question=question, - tool_description=tool_description + question=question, tool_description=tool_description ) return self.execute_prompt(full_prompt, parse_json=True) - - def extract_parameters_with_postprocess_new(self, question: str, tool_description: str, vertex_schema: Dict[str, str], edge_schema: Dict[str, str]) -> dict: + + def extract_parameters_with_postprocess_new( + self, + question: str, + tool_description: str, + vertex_schema: Dict[str, str], + edge_schema: Dict[str, str], + ) -> dict: """Extract parameters and generate post-processing code with vertex and edge schema information.""" full_prompt = extract_parameters_with_postprocess_promt_new.format( question=question, tool_description=tool_description, vertex_schema=json.dumps(vertex_schema, indent=2), - edge_schema=json.dumps(edge_schema, indent=2) + edge_schema=json.dumps(edge_schema, indent=2), ) return self.execute_prompt(full_prompt, parse_json=True) - def merge_parameters_from_dependencies( - self, - question: str, - tool_description: str, - vertex_schema: Dict[str, str], + self, + question: str, + tool_description: str, + vertex_schema: Dict[str, str], edge_schema: Dict[str, str], - dependency_parameters: Dict[str, Any] + dependency_parameters: Dict[str, Any], ) -> dict: """Merge dependency parameters with extracted parameters and generate post-processing code.""" full_prompt = merge_parameters_with_dependencies_prompt.format( @@ -791,11 +861,13 @@ def merge_parameters_from_dependencies( tool_description=tool_description, dependency_parameters=json.dumps(dependency_parameters, indent=2), vertex_schema=json.dumps(vertex_schema, indent=2), - edge_schema=json.dumps(edge_schema, indent=2) + edge_schema=json.dumps(edge_schema, indent=2), ) return self.execute_prompt(full_prompt, parse_json=True) - def generate_answer_from_algorithm_result(self, question: str, tool_description: str, tool_result: Dict[str, Any]) -> str: + def generate_answer_from_algorithm_result( + self, question: str, tool_description: str, tool_result: Dict[str, Any] + ) -> str: # prompt = f""" # You are a professional data analyst responsible for interpreting graph algorithm results. # Your task is to analyze the computation output of a given tool based on: @@ -828,7 +900,7 @@ def generate_answer_from_algorithm_result(self, question: str, tool_description: # Do NOT return JSON or code — only natural language. # """ - # + # prompt = f""" You are a professional data analyst responsible for interpreting graph algorithm results. Your task is to analyze the computation output of a given tool based on: - The **natural-language user question** (*question*) @@ -925,7 +997,7 @@ def generate_answer_from_algorithm_result(self, question: str, tool_description: if not response_text: return "Unable to generate answer from the algorithm result." return response_text - + def chat(self, messages: list) -> str: response = self.client.chat.completions.create( model=self.model, @@ -937,91 +1009,119 @@ def chat(self, messages: list) -> str: return "Unable to generate answer." return response_text - def analyze_dependency_type_and_locate_dependency_data(self, current_question:str, task_type:str, current_algo_desc:str, parent_question: str, parent_outputs_meta:list)-> dict: + def analyze_dependency_type_and_locate_dependency_data( + self, + current_question: str, + task_type: str, + current_algo_desc: str, + parent_question: str, + parent_outputs_meta: list, + ) -> dict: full_prompt = analyze_dependency_type_and_locate_dependency_data_prompt.format( current_question=current_question, task_type=task_type, current_algo_desc=current_algo_desc, parent_question=parent_question, - parent_outputs_meta=parent_outputs_meta + parent_outputs_meta=parent_outputs_meta, ) return self.execute_prompt(full_prompt, parse_json=True) - - def map_parameters(self, current_question: str, current_algo_desc: str, dependency_items: List[Dict[str, Any]]) -> Dict[str, Any]: + def map_parameters( + self, + current_question: str, + current_algo_desc: str, + dependency_items: List[Dict[str, Any]], + ) -> Dict[str, Any]: full_prompt = map_parameters_prompt.format( current_question=current_question, algo_desc=current_algo_desc, - dependency_items=json.dumps(dependency_items, ensure_ascii=False, indent=2) + dependency_items=json.dumps(dependency_items, ensure_ascii=False, indent=2), ) return self.execute_prompt(full_prompt, parse_json=True) - def generate_graph_conversion_code(self, current_question: str, dependency_items: List[Dict[str, Any]])-> Dict[str, Any]: + def generate_graph_conversion_code( + self, current_question: str, dependency_items: List[Dict[str, Any]] + ) -> Dict[str, Any]: full_prompt = generate_graph_conversion_code_prompt.format( current_question=current_question, - dependency_items=json.dumps(dependency_items, ensure_ascii=False, indent=2) + dependency_items=json.dumps(dependency_items, ensure_ascii=False, indent=2), ) return self.execute_prompt(full_prompt, parse_json=True) - + def generate_numeric_analysis_code( - self, - question: str, - dependency_items: List[Dict[str, Any]], - vertex_schema: Dict[str, str], - edge_schema: Dict[str, str] + self, + question: str, + dependency_items: List[Dict[str, Any]], + vertex_schema: Dict[str, str], + edge_schema: Dict[str, str], ) -> Dict[str, Any]: full_prompt = generate_numeric_analysis_code_prompt.format( question=question, - dependency_data_items=json.dumps(dependency_items, ensure_ascii=False, indent=2), + dependency_data_items=json.dumps( + dependency_items, ensure_ascii=False, indent=2 + ), vertex_schema=json.dumps(vertex_schema, ensure_ascii=False, indent=2), - edge_schema=json.dumps(edge_schema, ensure_ascii=False, indent=2) + edge_schema=json.dumps(edge_schema, ensure_ascii=False, indent=2), ) return self.execute_prompt(full_prompt, parse_json=True) - + # add gjq def nl_query_classify_type(self, question: str, query_templates: dict) -> str: """Classify the query type for natural language query engine.""" full_prompt = nl_query_classify_type_prompt.format( question=question, - query_templates=json.dumps(query_templates, ensure_ascii=False, indent=2) + query_templates=json.dumps(query_templates, ensure_ascii=False, indent=2), ) response_text = self.execute_prompt(full_prompt, parse_json=False) return response_text.strip().strip('"').strip("'") - + # add gjq - def nl_query_extract_params(self, question: str, query_type: str, template: dict, - schema_info: str, query_modifiers: dict) -> dict: + def nl_query_extract_params( + self, + question: str, + query_type: str, + template: dict, + schema_info: str, + query_modifiers: dict, + ) -> dict: """Extract parameters for natural language query engine.""" full_prompt = nl_query_extract_params_prompt.format( schema_info=schema_info, query_type=query_type, - template_description=template['description'], - template_method=template['method'], - required_params=template['required_params'], - optional_params=template.get('optional_params', []), + template_description=template["description"], + template_method=template["method"], + required_params=template["required_params"], + optional_params=template.get("optional_params", []), query_modifiers=json.dumps(query_modifiers, ensure_ascii=False, indent=2), - question=question + question=question, ) return self.execute_prompt(full_prompt, parse_json=True) + # add gjq - def nl_query_validate_cypher(self, cypher: str, question: str, query_type: str, - params: dict, schema_info: str, template_info: str, - template_cypher_example: str) -> dict: + def nl_query_validate_cypher( + self, + cypher: str, + question: str, + query_type: str, + params: dict, + schema_info: str, + template_info: str, + template_cypher_example: str, + ) -> dict: """Validate Cypher statement for natural language query engine.""" - full_prompt=nl_query_validate_cypher_prompt.format( - schema_info=schema_info, - template_info=template_info, - template_cypher_example=template_cypher_example, - question=question, - query_type=query_type, - params=json.dumps(params, ensure_ascii=False, indent=2), - cypher=cypher - ) + full_prompt = nl_query_validate_cypher_prompt.format( + schema_info=schema_info, + template_info=template_info, + template_cypher_example=template_cypher_example, + question=question, + query_type=query_type, + params=json.dumps(params, ensure_ascii=False, indent=2), + cypher=cypher, + ) return self.execute_prompt(full_prompt, parse_json=True) class Reasoner: - def __init__(self, config: ReasonerConfig): """Initialize Reasoner by provider selection. @@ -1030,7 +1130,9 @@ def __init__(self, config: ReasonerConfig): fallback_embed_model: used only for OllamaEnv when no embedding is provided from outside. """ if config is None or config.llm is None: - raise ValueError("Reasoner requires a valid ReasonerConfig with llm settings") + raise ValueError( + "Reasoner requires a valid ReasonerConfig with llm settings" + ) self.config = config provider = (config.llm.provider or "ollama").lower() @@ -1058,7 +1160,9 @@ def __init__(self, config: ReasonerConfig): # Allow environment variable fallback api_key = os.environ.get("OPENAI_API_KEY") if not api_key: - raise ValueError("OpenAI provider requires an API key via config.reasoner.llm.openai.api_key or env OPENAI_API_KEY") + raise ValueError( + "OpenAI provider requires an API key via config.reasoner.llm.openai.api_key or env OPENAI_API_KEY" + ) self.env = OpenAIEnv( base_url=base_url, api_key=api_key, @@ -1068,15 +1172,10 @@ def __init__(self, config: ReasonerConfig): else: raise ValueError(f"Unsupported provider: {provider}") - def plan_subqueries(self, decompose: bool, query: str) -> dict: - if decompose == False: + if decompose == False: # Do not decompose, treat as a single question - return{"subqueries": [{ - "id": "q1", - "query": query, - "depends_on": [] - }]} + return {"subqueries": [{"id": "q1", "query": query, "depends_on": []}]} full_prompt = plan_subqueries_prompt.format(query=query) return self.env.execute_prompt(full_prompt, parse_json=True) @@ -1084,32 +1183,35 @@ def plan_expert_subqueries_with_algorithms( self, expert_instruction: str, algorithm_library_info: str = "", - dataset_info: Optional[str] = None + dataset_info: Optional[str] = None, ) -> Dict[str, Any]: """ 将专家自然语言指令转为带 algorithm 字段的 subqueries。 """ full_prompt = expert_subqueries_with_algorithms_prompt_zh.format( expert_instruction=expert_instruction, - algorithm_library_info=algorithm_library_info or "Algorithm library not available", - dataset_info=dataset_info or "N/A" + algorithm_library_info=algorithm_library_info + or "Algorithm library not available", + dataset_info=dataset_info or "N/A", ) return self.env.execute_prompt(full_prompt, parse_json=True) - - def revise_subquery_plan(self, current_plan: Dict[str, Any], user_request: str) -> Dict[str, Any]: + + def revise_subquery_plan( + self, current_plan: Dict[str, Any], user_request: str + ) -> Dict[str, Any]: full_prompt = revise_subquery_plan_prompt.format( current_plan=json.dumps(current_plan, ensure_ascii=False), - user_request=user_request + user_request=user_request, ) return self.env.execute_prompt(full_prompt, parse_json=True) - + def refine_subqueries(self, current_dag: Dict[str, Any]) -> Dict[str, Any]: """ 根据找到的算法信息优化DAG,确保子查询严格按照任务类型边界划分 - + Args: current_dag: 当前的DAG结构,包含subqueries列表 - + Returns: 优化后的DAG结构 """ @@ -1117,31 +1219,38 @@ def refine_subqueries(self, current_dag: Dict[str, Any]) -> Dict[str, Any]: current_dag=json.dumps(current_dag, ensure_ascii=False, indent=2) ) return self.env.execute_prompt(full_prompt, parse_json=True) - + def select_task_type(self, question: str, task_type_list: list) -> dict: full_prompt = select_task_type_prompt.format( - question=question, - task_type_list=task_type_list + question=question, task_type_list=task_type_list ) return self.env.execute_prompt(full_prompt, parse_json=True) - - def select_algorithm(self, question: str, algorithm_list: list, graph_schema: Optional[Dict[str, Any]] = None) -> dict: + + def select_algorithm( + self, + question: str, + algorithm_list: list, + graph_schema: Optional[Dict[str, Any]] = None, + ) -> dict: return self.env.select_algorithm(question, algorithm_list, graph_schema) - - def extract_parameters_with_postprocess(self, question: str, tool_description: str) -> dict: + + def extract_parameters_with_postprocess( + self, question: str, tool_description: str + ) -> dict: full_prompt = extract_parameters_with_postprocess_promt.format( - question=question, - tool_description=tool_description + question=question, tool_description=tool_description ) return self.env.execute_prompt(full_prompt, parse_json=True) - - def check_data_dependency(self, q1_question: str, q1_algorithm: str, q2_question: str, q2_algorithm: str) -> bool: + + def check_data_dependency( + self, q1_question: str, q1_algorithm: str, q2_question: str, q2_algorithm: str + ) -> bool: try: full_prompt = check_data_dependency_prompt.format( q1_question=q1_question, q1_algorithm=q1_algorithm, q2_question=q2_question, - q2_algorithm=q2_algorithm + q2_algorithm=q2_algorithm, ) result = self.env.execute_prompt(full_prompt, parse_json=True) depends = result.get("q2_depends_on_q1") @@ -1157,16 +1266,27 @@ def check_data_dependency(self, q1_question: str, q1_algorithm: str, q2_questio print(f"Error determining data dependency: {e}") return False - def generate_answer_from_algorithm_result(self, question: str, tool_description: str, tool_result: Dict[str, Any]): - return self.env.generate_answer_from_algorithm_result(question, tool_description, tool_result) + def generate_answer_from_algorithm_result( + self, question: str, tool_description: str, tool_result: Dict[str, Any] + ): + return self.env.generate_answer_from_algorithm_result( + question, tool_description, tool_result + ) - def analyze_dependency_type_and_locate_dependency_data(self, current_question:str, task_type:str, current_algo_desc:str, parent_question: str, parent_outputs_meta:list) -> dict: + def analyze_dependency_type_and_locate_dependency_data( + self, + current_question: str, + task_type: str, + current_algo_desc: str, + parent_question: str, + parent_outputs_meta: list, + ) -> dict: full_prompt = analyze_dependency_type_and_locate_dependency_data_prompt.format( current_question=current_question, task_type=task_type, current_algo_desc=current_algo_desc, parent_question=parent_question, - parent_outputs_meta=parent_outputs_meta + parent_outputs_meta=parent_outputs_meta, ) return self.env.execute_prompt(full_prompt, parse_json=True) @@ -1174,33 +1294,42 @@ def classify_question_type(self, question: str) -> Dict[str, Any]: full_prompt = classify_question_type_prompt.format(question=question) return self.env.execute_prompt(full_prompt, parse_json=True) - def map_parameters(self, current_question: str, current_algo_desc: str, dependency_items: List[Dict[str, Any]]) -> Dict[str, Any]: + def map_parameters( + self, + current_question: str, + current_algo_desc: str, + dependency_items: List[Dict[str, Any]], + ) -> Dict[str, Any]: full_prompt = map_parameters_prompt.format( current_question=current_question, algo_desc=current_algo_desc, - dependency_items=json.dumps(dependency_items, ensure_ascii=False, indent=2) + dependency_items=json.dumps(dependency_items, ensure_ascii=False, indent=2), ) return self.env.execute_prompt(full_prompt, parse_json=True) - def generate_graph_conversion_code(self, current_question: str, dependency_items: List[Dict[str, Any]])-> Dict[str, Any]: + def generate_graph_conversion_code( + self, current_question: str, dependency_items: List[Dict[str, Any]] + ) -> Dict[str, Any]: full_prompt = generate_graph_conversion_code_prompt.format( current_question=current_question, - dependency_items=json.dumps(dependency_items, ensure_ascii=False, indent=2) + dependency_items=json.dumps(dependency_items, ensure_ascii=False, indent=2), ) return self.env.execute_prompt(full_prompt, parse_json=True) def generate_numeric_analysis_code( - self, - question: str, - dependency_items: List[Dict[str, Any]], - vertex_schema: Dict[str, str], - edge_schema: Dict[str, str] + self, + question: str, + dependency_items: List[Dict[str, Any]], + vertex_schema: Dict[str, str], + edge_schema: Dict[str, str], ) -> Dict[str, Any]: full_prompt = generate_numeric_analysis_code_prompt.format( question=question, - dependency_data_items=json.dumps(dependency_items, ensure_ascii=False, indent=2), + dependency_data_items=json.dumps( + dependency_items, ensure_ascii=False, indent=2 + ), vertex_schema=json.dumps(vertex_schema, ensure_ascii=False, indent=2), - edge_schema=json.dumps(edge_schema, ensure_ascii=False, indent=2) + edge_schema=json.dumps(edge_schema, ensure_ascii=False, indent=2), ) return self.env.execute_prompt(full_prompt, parse_json=True) @@ -1217,37 +1346,38 @@ def get_quetion_response(self, question, graph_result, language="en"): def generate_response(self, prompt: str): if hasattr(self.env, "generate_response"): return self.env.generate_response(prompt) - raise NotImplementedError("Underlying environment does not support generate_response/complete") + raise NotImplementedError( + "Underlying environment does not support generate_response/complete" + ) def extract_parameters_with_postprocess_new( - self, - question: str, - tool_description: str, - vertex_schema: Dict[str, str], + self, + question: str, + tool_description: str, + vertex_schema: Dict[str, str], edge_schema: Dict[str, str], - error_history: Optional[List[Dict[str, Any]]] = None, - trace: Optional[Dict[str, Any]] = None, + error_history: Optional[List[Dict[str, Any]]] = None, + trace: Optional[Dict[str, Any]] = None, ) -> dict: full_prompt = extract_parameters_with_postprocess_promt_new.format( question=question, tool_description=tool_description, vertex_schema=json.dumps(vertex_schema, ensure_ascii=False, indent=2), - edge_schema=json.dumps(edge_schema, ensure_ascii=False, indent=2) + edge_schema=json.dumps(edge_schema, ensure_ascii=False, indent=2), ) if error_history: full_prompt = enhance_prompt( - base_prompt=full_prompt, - error_history=error_history + base_prompt=full_prompt, error_history=error_history ) return self.env.execute_prompt(full_prompt, parse_json=True) def merge_parameters_from_dependencies( - self, - question: str, - tool_description: str, - vertex_schema: Dict[str, str], + self, + question: str, + tool_description: str, + vertex_schema: Dict[str, str], edge_schema: Dict[str, str], dependency_parameters: Dict[str, Any], error_history: Optional[List[Dict[str, Any]]] = None, @@ -1259,16 +1389,14 @@ def merge_parameters_from_dependencies( tool_description=tool_description, dependency_parameters=json.dumps(dependency_parameters, indent=2), vertex_schema=json.dumps(vertex_schema, indent=2), - edge_schema=json.dumps(edge_schema, indent=2) + edge_schema=json.dumps(edge_schema, indent=2), ) if error_history: full_prompt = enhance_prompt( - base_prompt=full_prompt, - error_history=error_history + base_prompt=full_prompt, error_history=error_history ) - return self.env.execute_prompt(full_prompt, parse_json=True) def chat(self, messages: list): @@ -1279,49 +1407,60 @@ def chat(self, messages: list): def general_query_response(self, query): messages = [ {"role": "system", "content": general_query_prompt}, - { - "role": "user", - "content": query - }, + {"role": "user", "content": query}, ] return self.chat(messages) - + def nl_query_classify_type(self, question: str, query_templates: dict) -> str: """Classify the query type for natural language query engine.""" full_prompt = nl_query_classify_type_prompt.format( question=question, - query_templates=json.dumps(query_templates, ensure_ascii=False, indent=2) + query_templates=json.dumps(query_templates, ensure_ascii=False, indent=2), ) response_text = self.env.execute_prompt(full_prompt, parse_json=False) return response_text.strip().strip('"').strip("'") - - def nl_query_extract_params(self, question: str, query_type: str, template: dict, - schema_info: str, query_modifiers: dict) -> dict: + + def nl_query_extract_params( + self, + question: str, + query_type: str, + template: dict, + schema_info: str, + query_modifiers: dict, + ) -> dict: """Extract parameters for natural language query engine.""" full_prompt = nl_query_extract_params_prompt.format( schema_info=schema_info, query_type=query_type, - template_description=template['description'], - template_method=template['method'], - required_params=template['required_params'], - optional_params=template.get('optional_params', []), + template_description=template["description"], + template_method=template["method"], + required_params=template["required_params"], + optional_params=template.get("optional_params", []), query_modifiers=json.dumps(query_modifiers, ensure_ascii=False, indent=2), - question=question + question=question, ) return self.env.execute_prompt(full_prompt, parse_json=True) - def nl_query_validate_cypher(self, cypher: str, question: str, query_type: str, - params: dict, schema_info: str, template_info: str, - template_cypher_example: str) -> dict: + + def nl_query_validate_cypher( + self, + cypher: str, + question: str, + query_type: str, + params: dict, + schema_info: str, + template_info: str, + template_cypher_example: str, + ) -> dict: """Validate Cypher statement for natural language query engine.""" - full_prompt=nl_query_validate_cypher_prompt.format( - schema_info=schema_info, - template_info=template_info, - template_cypher_example=template_cypher_example, - question=question, - query_type=query_type, - params=json.dumps(params, ensure_ascii=False, indent=2), - cypher=cypher - ) + full_prompt = nl_query_validate_cypher_prompt.format( + schema_info=schema_info, + template_info=template_info, + template_cypher_example=template_cypher_example, + question=question, + query_type=query_type, + params=json.dumps(params, ensure_ascii=False, indent=2), + cypher=cypher, + ) return self.execute_prompt(full_prompt, parse_json=True) def rewrite_query( @@ -1329,38 +1468,44 @@ def rewrite_query( original_query: str, algorithm_library_info: str, dataset_info: Optional[str] = None, - use_chinese: bool = True + use_chinese: bool = True, ) -> Dict[str, Any]: """ Rewrite a vague user query into a concrete, executable query. - + Args: original_query: The user's original vague question algorithm_library_info: Information about available task types and algorithms dataset_info: Optional information about the current graph dataset use_chinese: Whether to use Chinese prompt (default: True) - + Returns: Dict containing: - rewritten_query: The concrete, executable query - reasoning: Explanation of changes made - mapped_concepts: List of concept mappings """ - dataset_context = dataset_info if dataset_info else ("暂无数据集信息" if use_chinese else "No dataset information available") - + dataset_context = ( + dataset_info + if dataset_info + else ( + "暂无数据集信息" if use_chinese else "No dataset information available" + ) + ) + # Choose prompt based on language preference if use_chinese: - from aag.reasoner.prompt_template.llm_prompt_zh import rewrite_query_prompt_zh + from aag.reasoner.prompt_template.llm_prompt_zh import ( + rewrite_query_prompt_zh, + ) + prompt_template = rewrite_query_prompt_zh else: prompt_template = rewrite_query_prompt - + full_prompt = prompt_template.format( original_query=original_query, algorithm_library_info=algorithm_library_info, - dataset_info=dataset_context + dataset_info=dataset_context, ) return self.env.execute_prompt(full_prompt, parse_json=True) - - - diff --git a/web/frontend/route/sockets_chat.py b/web/frontend/route/sockets_chat.py index 9207365..f2d2055 100644 --- a/web/frontend/route/sockets_chat.py +++ b/web/frontend/route/sockets_chat.py @@ -3,15 +3,67 @@ import re import logging import asyncio -from flask_socketio import emit +from flask import request +from flask_socketio import emit, disconnect from . import socketio + # Add project path to import AAG services -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) +sys.path.append( + os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + ) +) from aag.api.async_runtime import get_background_loop, get_chat_service - +from aag.api.services.chat_service import CHAT_FRIENDLY_ERROR_MSG logger = logging.getLogger(__name__) +# WebSocket 连接跟踪:记录活跃连接的 session ID +_active_connections: set = set() + + +@socketio.on("connect") +def on_connect(): + """WebSocket 连接认证:验证 token 后才允许连接""" + expected_token = os.getenv("YIGRAPH_WS_TOKEN", "") + + # 未设置环境变量则跳过验证(开发环境兼容) + if not expected_token: + logger.warning( + "YIGRAPH_WS_TOKEN 未设置,WebSocket 连接跳过认证(仅建议开发环境使用)" + ) + _active_connections.add(request.sid) + logger.info( + f"WebSocket 客户端已连接(无认证): {request.sid}, 活跃连接数: {len(_active_connections)}" + ) + return + + # 从请求参数或 Header 中提取 token + token = request.args.get("token", "") + if not token: + token = request.headers.get("X-WS-Token", "") + + if token != expected_token: + logger.warning(f"WebSocket 认证失败,拒绝连接: {request.sid}") + # 认证失败:拒绝连接并返回错误信息 + disconnect() + return False + + _active_connections.add(request.sid) + logger.info( + f"WebSocket 客户端已连接: {request.sid}, 活跃连接数: {len(_active_connections)}" + ) + + +@socketio.on("disconnect") +def on_disconnect(): + """WebSocket 断开连接:清理资源并记录""" + sid = request.sid + _active_connections.discard(sid) + logger.info( + f"WebSocket 客户端已断开: {sid}, 剩余活跃连接数: {len(_active_connections)}" + ) + def split_into_sentences(text: str, max_len: int = 80): if not text: @@ -20,15 +72,17 @@ def split_into_sentences(text: str, max_len: int = 80): # Step 1: split by punctuation (CJK and Western) # (?<=[。!?\n]) after CJK punctuation or newline # (?<=[.!?]\s) after period/question/exclamation + space - sentences = re.split(r'(?<=[。!?\n])|(?<=[.!?]\s)', text) - + sentences = re.split(r"(?<=[。!?\n])|(?<=[.!?]\s)", text) + # Filter empty strings and strip whitespace sentences = [s.strip() for s in sentences if s.strip()] - + # If no sentence boundaries found (e.g. long run-on text), split by paragraphs if len(sentences) <= 1: # Split by blank lines, then by single newlines - paragraphs = [p.strip() for p in text.replace('\n\n', '\n').split('\n') if p.strip()] + paragraphs = [ + p.strip() for p in text.replace("\n\n", "\n").split("\n") if p.strip() + ] if len(paragraphs) > 1: sentences = paragraphs @@ -36,39 +90,39 @@ def split_into_sentences(text: str, max_len: int = 80): def smart_split_long_sentence(sentence: str, max_len: int): if len(sentence) <= max_len: return [sentence] - + result = [] start = 0 text_len = len(sentence) - + while start < text_len: end = start + max_len - + # Already at end, append remainder if end >= text_len: result.append(sentence[start:].strip()) break - + # Look backward from end for first space or CJK punctuation as split point split_pos = -1 for i in range(end, max(start, end - 20), -1): - if sentence[i] in ' \t\n。!?,!?': + if sentence[i] in " \t\n。!?,!?": split_pos = i break - + # No break character found, force split at max_len if split_pos == -1: split_pos = end - - chunk = sentence[start:split_pos + 1].strip() + + chunk = sentence[start : split_pos + 1].strip() if chunk: result.append(chunk) start = split_pos + 1 - + # Guard against infinite loop if start >= text_len: break - + return result final = [] @@ -77,9 +131,10 @@ def smart_split_long_sentence(sentence: str, max_len: int): final.extend(smart_split_long_sentence(sent, max_len)) else: final.append(sent) - + return final if final else [text.strip()] + def smart_split_markdown(text: str, max_len: int = 80): """ Markdown-aware smart split. Preserves code blocks and key syntax; @@ -91,52 +146,54 @@ def smart_split_markdown(text: str, max_len: int = 80): # --- Layer 1: preserve code blocks --- # Split into [plain, code block, plain, code block, ...] # (```[\s\S]*?```) multi-line code, (`[^`\n]+`) inline code - parts = re.split(r'(```[\s\S]*?```|`[^`\n]+`)', text) + parts = re.split(r"(```[\s\S]*?```|`[^`\n]+`)", text) atoms = [] - + for part in parts: if not part: continue - + # 1. Code block (starts with `): treat as single atom - if part.startswith('`'): + if part.startswith("`"): # Very long block: split by newlines for streaming - if len(part) > max_len and '\n' in part: - code_lines = part.split('\n') + if len(part) > max_len and "\n" in part: + code_lines = part.split("\n") for idx, line in enumerate(code_lines): - suffix = '\n' if idx < len(code_lines) - 1 else '' + suffix = "\n" if idx < len(code_lines) - 1 else "" atoms.append(line + suffix) else: atoms.append(part) - + # 2. Plain text: fine-grained split else: # Split by paragraph/newline first (important in Markdown) - lines = part.split('\n') + lines = part.split("\n") for i, line in enumerate(lines): - suffix = '\n' if i < len(lines) - 1 else '' + suffix = "\n" if i < len(lines) - 1 else "" full_line = line + suffix - + if not line.strip(): atoms.append(full_line) continue # Split by sentence punctuation within line - sub_parts = re.split(r'([。!?]|(?<=[.!?])\s)', line) - + sub_parts = re.split(r"([。!?]|(?<=[.!?])\s)", line) + current_sent = "" for sub in sub_parts: current_sent += sub - if sub in ['。', '!', '?'] or (sub.strip() == '' and len(current_sent) > 0): + if sub in ["。", "!", "?"] or ( + sub.strip() == "" and len(current_sent) > 0 + ): atoms.append(current_sent) current_sent = "" - elif len(current_sent) > 0 and current_sent[-1] in '.!?': - pass - + elif len(current_sent) > 0 and current_sent[-1] in ".!?": + pass + if current_sent: atoms.append(current_sent) - + if suffix: if atoms: atoms[-1] += suffix @@ -162,7 +219,7 @@ def smart_split_markdown(text: str, max_len: int = 80): return final_chunks -@socketio.on('chat_request') +@socketio.on("chat_request") def handle_chat_request(data): """WebSocket chat handler: receive user message, push streaming results.""" # 1. Parse parameters @@ -170,26 +227,36 @@ def handle_chat_request(data): user_message = str(data.get("message") or "").strip() selected_model = str(data.get("model") or "") dag_confirm = str(data.get("dag_confirm") or "").strip() - is_dag_modification = str(data.get("is_dag_modification", "false")).lower() == "true" + is_dag_modification = ( + str(data.get("is_dag_modification", "false")).lower() == "true" + ) dag_id = str(data.get("dag_id") or "") modifications = data.get("modifications", "") # may be str or other expert_mode = data.get("expert_mode", False) dataset = str(data.get("dataset") or "").strip() # dataset name from frontend _dtype = data.get("dataset_type") or data.get("file_type") - dataset_type = str(_dtype).strip() if _dtype else None # "text" | "graph" | None + dataset_type = ( + str(_dtype).strip() if _dtype else None + ) # "text" | "graph" | None custom_mode = data.get("mode") except Exception as e: logger.error(f"Failed to parse parameters: {e}") - emit('chat_response', {"error": "Invalid request format. Please check parameters."}) + emit( + "chat_response", + {"error": "Invalid request format. Please check parameters."}, + ) return # 2. Validation if not user_message and not dag_confirm and not is_dag_modification: - emit('chat_response', {"error": "Message content cannot be empty."}) + emit("chat_response", {"error": "Message content cannot be empty."}) return if not dataset: - emit('chat_response', {"error": "Dataset is empty. Please specify a dataset first."}) + emit( + "chat_response", + {"error": "Dataset is empty. Please specify a dataset first."}, + ) return try: @@ -199,15 +266,19 @@ def handle_chat_request(data): # Callback to send streaming data (called from background event loop thread) def send_response(data_chunk): """Send response data to frontend.""" - socketio.emit('chat_response', data_chunk) + socketio.emit("chat_response", data_chunk) # Determine mode if custom_mode == "interact": mode = "interact" else: mode = "expert" if expert_mode else "normal" - logger.info(f"WS request: model={selected_model}, dataset={dataset}, message={user_message[:20]}..., expertMode={expert_mode}, mode={mode}") - print(f"WS request: model={selected_model}, dataset={dataset}, message={user_message[:20]}..., expertMode={expert_mode}, mode={mode}") + logger.info( + f"WS request: model={selected_model}, dataset={dataset}, message={user_message[:20]}..., expertMode={expert_mode}, mode={mode}" + ) + print( + f"WS request: model={selected_model}, dataset={dataset}, message={user_message[:20]}..., expertMode={expert_mode}, mode={mode}" + ) async def process_request(): try: @@ -219,41 +290,53 @@ async def process_request(): if result.get("success"): result_text = result.get("result", "") - paragraphs = [p.strip() for p in result_text.split('\n') if p.strip()] + paragraphs = [ + p.strip() for p in result_text.split("\n") if p.strip() + ] for i, para in enumerate(paragraphs): - send_response({ - 'type': 'result', - 'contentType': 'text', - 'content': para - }) - await asyncio.sleep(0.5 if i < len(paragraphs)-1 else 0.3) + send_response( + { + "type": "result", + "contentType": "text", + "content": para, + } + ) + await asyncio.sleep(0.5 if i < len(paragraphs) - 1 else 0.3) else: - send_response({ - 'error': result.get("error", "Analysis execution failed.") - }) - send_response({'type': 'stream_end'}) + send_response( + {"error": result.get("error", "Analysis execution failed.")} + ) + send_response({"type": "stream_end"}) return if is_dag_modification or (dag_confirm == "no" and modifications): engine = chat_service.engine_service.get_engine() engine.specific_dataset(dataset, dataset_type) modification_request = modifications or user_message - logger.info(f"DAG modification request received: {modification_request}") + logger.info( + f"DAG modification request received: {modification_request}" + ) - result = await chat_service.process_dag_modification(modification_request) + result = await chat_service.process_dag_modification( + modification_request + ) if result.get("success"): - dag_content = chat_service._convert_dag_to_frontend_format(result) - send_response({ - 'type': 'result', - 'contentType': 'dag', - 'content': dag_content - }) + dag_content = chat_service._convert_dag_to_frontend_format( + result + ) + send_response( + { + "type": "result", + "contentType": "dag", + "content": dag_content, + } + ) else: - send_response({ - 'error': result.get("error", "DAG modification failed.") - }) - send_response({'type': 'stream_end'}) + send_response( + {"error": result.get("error", "DAG modification failed.")} + ) + send_response({"type": "stream_end"}) return # Normal chat request — streaming (stream_end is sent internally) @@ -265,12 +348,18 @@ async def process_request(): dataset_type=dataset_type, mode=mode, expert_mode=expert_mode, - callback=send_response + callback=send_response, ) except Exception as exc: logger.error(f"Background processing failed: {exc}", exc_info=True) - send_response({'type': 'result', 'contentType': 'text', 'content': CHAT_FRIENDLY_ERROR_MSG}) - send_response({'type': 'stream_end'}) + send_response( + { + "type": "result", + "contentType": "text", + "content": CHAT_FRIENDLY_ERROR_MSG, + } + ) + send_response({"type": "stream_end"}) future = asyncio.run_coroutine_threadsafe(process_request(), loop) future.result() @@ -280,5 +369,12 @@ async def process_request(): except Exception as e: error_msg = f"Processing failed: {str(e)}" logger.error(error_msg, exc_info=True) - emit('chat_response', {'type': 'result', 'contentType': 'text', 'content': CHAT_FRIENDLY_ERROR_MSG}) - emit('chat_response', {'type': 'stream_end'}) + emit( + "chat_response", + { + "type": "result", + "contentType": "text", + "content": CHAT_FRIENDLY_ERROR_MSG, + }, + ) + emit("chat_response", {"type": "stream_end"})