diff --git a/README.md b/README.md index 8177132b7..e454aa906 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,11 @@ # 简介 -> chatgpt-on-wechat(简称CoW)项目是基于大模型的智能对话机器人,支持微信公众号、企业微信应用、飞书、钉钉接入,可选择GPT3.5/GPT4.0/Claude/Gemini/LinkAI/ChatGLM/KIMI/文心一言/讯飞星火/通义千问/LinkAI,能处理文本、语音和图片,通过插件访问操作系统和互联网等外部资源,支持基于自有知识库定制企业AI应用。 +> chatgpt-on-wechat(简称CoW)项目是基于大模型的智能对话机器人,支持微信公众号、企业微信应用、飞书、钉钉接入,可选择GPT3.5/GPT4.0/Claude/Gemini/LinkAI/ChatGLM/KIMI/文心一言/讯飞星火/通义千问/LinkAI/Dify,能处理文本、语音和图片,通过插件访问操作系统和互联网等外部资源,支持基于自有知识库定制企业AI应用。 最新版本支持的功能如下: - ✅ **多端部署:** 有多种部署方式可选择且功能完备,目前已支持微信公众号、企业微信应用、飞书、钉钉等部署方式 -- ✅ **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3.5, GPT-4o-mini, GPT-4o, GPT-4, Claude-3.5, Gemini, 文心一言, 讯飞星火, 通义千问,ChatGLM-4,Kimi(月之暗面), MiniMax +- ✅ **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3.5, GPT-4o-mini, GPT-4o, GPT-4, Claude-3.5, Gemini, 文心一言, 讯飞星火, 通义千问,ChatGLM-4,Kimi(月之暗面), MiniMax, Dify - ✅ **语音能力:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai(whisper/tts) 等多种语音模型 - ✅ **图像能力:** 支持图片生成、图片识别、图生图(如照片修复),可选择 Dall-E-3, stable diffusion, replicate, midjourney, CogView-3, vision模型 - ✅ **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结、文档总结和对话、联网搜索等插件 @@ -131,8 +131,8 @@ pip3 install -r requirements-optional.txt ```bash # config.json文件内容示例 -{ - "model": "gpt-3.5-turbo", # 模型名称, 支持 gpt-3.5-turbo, gpt-4, gpt-4-turbo, wenxin, xunfei, glm-4, claude-3-haiku, moonshot +{, + "model": "gpt-3.5-turbo", # 模型名称, 支持 gpt-3.5-turbo, gpt-4, gpt-4-turbo, wenxin, xunfei, glm-4, claude-3-haiku, moonshot,dify "open_ai_api_key": "YOUR API KEY", # 如果使用openAI模型则填入上面创建的 OpenAI API KEY "proxy": "", # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890" "single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 @@ -175,7 +175,7 @@ pip3 install -r requirements-optional.txt **4.其他配置** -+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `gpt-4o-mini`, `gpt-4o`, `gpt-4`, `wenxin` , `claude` , `gemini`, `glm-4`, `xunfei`, `moonshot`等,全部模型名称参考[common/const.py](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/common/const.py)文件 ++ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `gpt-4o-mini`, `gpt-4o`, `gpt-4`, `wenxin` , `claude` , `gemini`, `glm-4`, `xunfei`, `moonshot`, `dify`等,全部模型名称参考[common/const.py](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/common/const.py)文件 + `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat) + `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351) + 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix ` diff --git a/app.py b/app.py index ff2a6c774..a61f0e0ca 100644 --- a/app.py +++ b/app.py @@ -31,12 +31,12 @@ def start_channel(channel_name: str): const.FEISHU, const.DINGTALK]: PluginManager().load_plugins() - if conf().get("use_linkai"): - try: - from common import linkai_client - threading.Thread(target=linkai_client.start, args=(channel,)).start() - except Exception as e: - pass + # if conf().get("use_linkai"): + # try: + # from common import linkai_client + # threading.Thread(target=linkai_client.start, args=(channel,)).start() + # except Exception as e: + # pass channel.startup() @@ -50,7 +50,7 @@ def run(): sigterm_handler_wrap(signal.SIGTERM) # create channel - channel_name = conf().get("channel_type", "wx") + channel_name = conf().get("channel_type", "dingtalk") if "--cmd" in sys.argv: channel_name = "terminal" diff --git a/bot/bot_factory.py b/bot/bot_factory.py index a6ef2415b..7b92ceeff 100644 --- a/bot/bot_factory.py +++ b/bot/bot_factory.py @@ -56,6 +56,10 @@ def create_bot(bot_type): from bot.gemini.google_gemini_bot import GoogleGeminiBot return GoogleGeminiBot() + elif bot_type == const.DIFY: + from bot.dify.dify_bot import DifyBot + return DifyBot() + elif bot_type == const.ZHIPU_AI: from bot.zhipuai.zhipuai_bot import ZHIPUAIBot return ZHIPUAIBot() diff --git a/bot/dify/dify_bot.py b/bot/dify/dify_bot.py new file mode 100644 index 000000000..5262d46ce --- /dev/null +++ b/bot/dify/dify_bot.py @@ -0,0 +1,305 @@ +# encoding:utf-8 +import json +import threading + +import requests + +from bot.bot import Bot +from bot.dify.dify_session import DifySession, DifySessionManager +from bridge.context import ContextType, Context +from bridge.reply import Reply, ReplyType +from common.log import logger +from common import const +from config import conf + +class DifyBot(Bot): + def __init__(self): + super().__init__() + # set the default api_key + self.api_key = conf().get("open_ai_api_key") + if conf().get("open_ai_api_base"): + self.api_base = conf().get("open_ai_api_base") + proxy = conf().get("proxy") + if proxy: + self.proxy = proxy + self.sessions = DifySessionManager(DifySession, model=conf().get("model", const.DIFY)) + + def reply(self, query, context: Context=None): + # acquire reply content + if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: + if context.type == ContextType.IMAGE_CREATE: + query = conf().get('image_create_prefix', ['画'])[0] + query + logger.info("[DIFY] query={}".format(query)) + session_id = context["session_id"] + channel_type = conf().get("channel_type") + user = None + if channel_type == "wx": + user = context["msg"].other_user_nickname if context.get("msg") else "default" + elif channel_type in ["wechatcom_app", "wechatmp", "wechatmp_service", "wechatcom_service", "wework", + "dingtalk"]: + user = context["msg"].other_user_id if context.get("msg") else "default" + else: + return Reply(ReplyType.ERROR, + f"unsupported channel type: {channel_type}, now dify only support wx, wechatcom_app, wechatmp, wechatmp_service, dingtalk channel") + logger.debug(f"[DIFY] dify_user={user}") + user = user if user else "default" # 防止用户名为None,当被邀请进的群未设置群名称时用户名为None + session = self.sessions.get_session(session_id, user) + logger.debug(f"[DIFY] session={session} query={query}") + + reply, err = self._reply(query, session, context) + if err != None: + reply = Reply(ReplyType.TEXT, "我暂时遇到了一些问题,请您稍后重试~") + return reply + else: + reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type)) + return reply + + def _get_api_base_url(self): + return self.api_base + + def _get_headers(self): + return { + 'Authorization': f"Bearer {self.api_key}" + } + + def _get_payload(self, query, session: DifySession, response_mode): + return { + 'inputs': {}, + "query": query, + "response_mode": response_mode, + "conversation_id": session.get_conversation_id(), + "user": session.get_user() + } + + def _reply(self, query: str, session: DifySession, context: Context): + try: + session.count_user_message() # 限制一个conversation中消息数,防止conversation过长 + dify_app_type = conf().get('dify_app_type', 'chatbot') + if dify_app_type == 'chatbot': + return self._handle_chatbot(query, session) + elif dify_app_type == 'agent': + return self._handle_agent(query, session, context) + elif dify_app_type == 'workflow': + return self._handle_workflow(query, session) + else: + return None, "dify_app_type must be agent, chatbot or workflow" + + except Exception as e: + error_info = f"[DIFY] Exception: {e}" + logger.exception(error_info) + return None, error_info + + def _handle_chatbot(self, query: str, session: DifySession): + # TODO: 获取response部分抽取为公共函数 + base_url = self._get_api_base_url() + chat_url = f'{base_url}/chat-messages' + headers = self._get_headers() + response_mode = 'blocking' + payload = self._get_payload(query, session, response_mode) + response = requests.post(chat_url, headers=headers, json=payload) + if response.status_code != 200: + error_info = f"[DIFY] response text={response.text} status_code={response.status_code}" + logger.warn(error_info) + return None, error_info + + # response: + # { + # "event": "message", + # "message_id": "9da23599-e713-473b-982c-4328d4f5c78a", + # "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", + # "mode": "chat", + # "answer": "xxx", + # "metadata": { + # "usage": { + # }, + # "retriever_resources": [] + # }, + # "created_at": 1705407629 + # } + rsp_data = response.json() + logger.debug("[DIFY] usage {}".format(rsp_data.get('metadata', {}).get('usage', 0))) + # TODO: 处理返回的图片文件 + # {"answer": "![image](/files/tools/dbf9cd7c-2110-4383-9ba8-50d9fd1a4815.png?timestamp=1713970391&nonce=0d5badf2e39466042113a4ba9fd9bf83&sign=OVmdCxCEuEYwc9add3YNFFdUpn4VdFKgl84Cg54iLnU=)"} + reply = Reply(ReplyType.TEXT, rsp_data['answer']) + # 设置dify conversation_id, 依靠dify管理上下文 + if session.get_conversation_id() == '': + session.set_conversation_id(rsp_data['conversation_id']) + return reply, None + + def _handle_agent(self, query: str, session: DifySession, context: Context): + # TODO: 获取response抽取为公共函数 + base_url = self._get_api_base_url() + chat_url = f'{base_url}/chat-messages' + headers = self._get_headers() + response_mode = 'streaming' + payload = self._get_payload(query, session, response_mode) + response = requests.post(chat_url, headers=headers, json=payload) + if response.status_code != 200: + error_info = f"[DIFY] response text={response.text} status_code={response.status_code}" + logger.warn(error_info) + return None, error_info + # response: + # data: {"event": "agent_thought", "id": "8dcf3648-fbad-407a-85dd-73a6f43aeb9f", "task_id": "9cf1ddd7-f94b-459b-b942-b77b26c59e9b", "message_id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "position": 1, "thought": "", "observation": "", "tool": "", "tool_input": "", "created_at": 1705639511, "message_files": [], "conversation_id": "c216c595-2d89-438c-b33c-aae5ddddd142"} + # data: {"event": "agent_thought", "id": "8dcf3648-fbad-407a-85dd-73a6f43aeb9f", "task_id": "9cf1ddd7-f94b-459b-b942-b77b26c59e9b", "message_id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "position": 1, "thought": "", "observation": "", "tool": "dalle3", "tool_input": "{\"dalle3\": {\"prompt\": \"cute Japanese anime girl with white hair, blue eyes, bunny girl suit\"}}", "created_at": 1705639511, "message_files": [], "conversation_id": "c216c595-2d89-438c-b33c-aae5ddddd142"} + # data: {"event": "agent_message", "id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "task_id": "9cf1ddd7-f94b-459b-b942-b77b26c59e9b", "message_id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "answer": "I have created an image of a cute Japanese", "created_at": 1705639511, "conversation_id": "c216c595-2d89-438c-b33c-aae5ddddd142"} + # data: {"event": "message_end", "task_id": "9cf1ddd7-f94b-459b-b942-b77b26c59e9b", "id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "message_id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "conversation_id": "c216c595-2d89-438c-b33c-aae5ddddd142", "metadata": {"usage": {"prompt_tokens": 305, "prompt_unit_price": "0.001", "prompt_price_unit": "0.001", "prompt_price": "0.0003050", "completion_tokens": 97, "completion_unit_price": "0.002", "completion_price_unit": "0.001", "completion_price": "0.0001940", "total_tokens": 184, "total_price": "0.0002290", "currency": "USD", "latency": 1.771092874929309}}} + msgs, conversation_id = self._handle_sse_response(response) + channel = context.get("channel") + # TODO: 适配除微信以外的其他channel + is_group = context.get("isgroup", False) + for msg in msgs[:-1]: + if msg['type'] == 'agent_message': + if is_group: + at_prefix = "@" + context["msg"].actual_user_nickname + "\n" + msg['content'] = at_prefix + msg['content'] + reply = Reply(ReplyType.TEXT, msg['content']) + channel.send(reply, context) + elif msg['type'] == 'message_file': + url = self._fill_file_base_url(msg['content']['url']) + reply = Reply(ReplyType.IMAGE_URL, url) + thread = threading.Thread(target=channel.send, args=(reply, context)) + thread.start() + final_msg = msgs[-1] + reply = None + if final_msg['type'] == 'agent_message': + reply = Reply(ReplyType.TEXT, final_msg['content']) + elif final_msg['type'] == 'message_file': + url = self._fill_file_base_url(final_msg['content']['url']) + reply = Reply(ReplyType.IMAGE_URL, url) + # 设置dify conversation_id, 依靠dify管理上下文 + if session.get_conversation_id() == '': + session.set_conversation_id(conversation_id) + return reply, None + + def _handle_workflow(self, query: str, session: DifySession): + base_url = self._get_api_base_url() + workflow_url = f'{base_url}/workflows/run' + headers = self._get_headers() + payload = self._get_workflow_payload(query, session) + response = requests.post(workflow_url, headers=headers, json=payload) + if response.status_code != 200: + error_info = f"[DIFY] response text={response.text} status_code={response.status_code}" + logger.warn(error_info) + return None, error_info + # { + # "log_id": "djflajgkldjgd", + # "task_id": "9da23599-e713-473b-982c-4328d4f5c78a", + # "data": { + # "id": "fdlsjfjejkghjda", + # "workflow_id": "fldjaslkfjlsda", + # "status": "succeeded", + # "outputs": { + # "text": "Nice to meet you." + # }, + # "error": null, + # "elapsed_time": 0.875, + # "total_tokens": 3562, + # "total_steps": 8, + # "created_at": 1705407629, + # "finished_at": 1727807631 + # } + # } + rsp_data = response.json() + reply = Reply(ReplyType.TEXT, rsp_data['data']['outputs']['text']) + return reply, None + + def _fill_file_base_url(self, url: str): + if url.startswith("https://") or url.startswith("http://"): + return url + # 补全文件base url, 默认使用去掉"/v1"的dify api base url + return self._get_file_base_url() + url + + def _get_file_base_url(self) -> str: + return self._get_api_base_url().replace("/v1", "") + + def _get_workflow_payload(self, query, session: DifySession): + return { + 'inputs': { + "query": query + }, + "response_mode": "blocking", + "user": session.get_user() + } + + def _parse_sse_event(self, event_str): + """ + Parses a single SSE event string and returns a dictionary of its data. + """ + event_prefix = "data: " + if not event_str.startswith(event_prefix): + return None + trimmed_event_str = event_str[len(event_prefix):] + + # Check if trimmed_event_str is not empty and is a valid JSON string + if trimmed_event_str: + try: + event = json.loads(trimmed_event_str) + return event + except json.JSONDecodeError: + logger.error(f"Failed to decode JSON from SSE event: {trimmed_event_str}") + return None + else: + logger.warn("Received an empty SSE event.") + return None + + # TODO: 异步返回events + def _handle_sse_response(self, response: requests.Response): + events = [] + for line in response.iter_lines(): + if line: + decoded_line = line.decode('utf-8') + event = self._parse_sse_event(decoded_line) + if event: + events.append(event) + + merged_message = [] + accumulated_agent_message = '' + conversation_id = None + for event in events: + event_name = event['event'] + if event_name == 'agent_message' or event_name == 'message': + accumulated_agent_message += event['answer'] + logger.debug("[DIFY] accumulated_agent_message: {}".format(accumulated_agent_message)) + # 保存conversation_id + if not conversation_id: + conversation_id = event['conversation_id'] + elif event_name == 'agent_thought': + self._append_agent_message(accumulated_agent_message, merged_message) + accumulated_agent_message = '' + logger.debug("[DIFY] agent_thought: {}".format(event)) + elif event_name == 'message_file': + self._append_agent_message(accumulated_agent_message, merged_message) + accumulated_agent_message = '' + self._append_message_file(event, merged_message) + elif event_name == 'message_replace': + # TODO: handle message_replace + pass + elif event_name == 'error': + logger.error("[DIFY] error: {}".format(event)) + raise Exception(event) + elif event_name == 'message_end': + self._append_agent_message(accumulated_agent_message, merged_message) + logger.debug("[DIFY] message_end usage: {}".format(event['metadata']['usage'])) + break + else: + logger.warn("[DIFY] unknown event: {}".format(event)) + + if not conversation_id: + raise Exception("conversation_id not found") + + return merged_message, conversation_id + + def _append_agent_message(self, accumulated_agent_message, merged_message): + if accumulated_agent_message: + merged_message.append({ + 'type': 'agent_message', + 'content': accumulated_agent_message, + }) + + def _append_message_file(self, event: dict, merged_message: list): + if event.get('type') != 'image': + logger.warn("[DIFY] unsupported message file type: {}".format(event)) + merged_message.append({ + 'type': 'message_file', + 'content': event, + }) diff --git a/bot/dify/dify_session.py b/bot/dify/dify_session.py new file mode 100644 index 000000000..7ac2b14b2 --- /dev/null +++ b/bot/dify/dify_session.py @@ -0,0 +1,63 @@ +from common.expired_dict import ExpiredDict +from config import conf + + +class DifySession(object): + def __init__(self, session_id: str, user: str, conversation_id: str=''): + self.__session_id = session_id + self.__user = user + self.__conversation_id = conversation_id + self.__user_message_counter = 0 + + def get_session_id(self): + return self.__session_id + + def get_user(self): + return self.__user + + def get_conversation_id(self): + return self.__conversation_id + + def set_conversation_id(self, conversation_id): + self.__conversation_id = conversation_id + + def count_user_message(self): + if self.__user_message_counter >= conf().get("dify_convsersation_max_messages", 5): + self.__user_message_counter = 0 + # FIXME: dify目前不支持设置历史消息长度,暂时使用超过5条清空会话的策略,缺点是没有滑动窗口,会突然丢失历史消息 + self.__conversation_id = '' + + self.__user_message_counter += 1 + +class DifySessionManager(object): + def __init__(self, sessioncls, **session_kwargs): + if conf().get("expires_in_seconds"): + sessions = ExpiredDict(conf().get("expires_in_seconds")) + else: + sessions = dict() + self.sessions = sessions + self.sessioncls = sessioncls + self.session_kwargs = session_kwargs + + def _build_session(self, session_id: str, user: str): + """ + 如果session_id不在sessions中,创建一个新的session并添加到sessions中 + """ + if session_id is None: + return self.sessioncls(session_id, user) + + if session_id not in self.sessions: + self.sessions[session_id] = self.sessioncls(session_id, user) + session = self.sessions[session_id] + return session + + def get_session(self, session_id, user): + session = self._build_session(session_id, user) + return session + + def clear_session(self, session_id): + if session_id in self.sessions: + del self.sessions[session_id] + + def clear_all_session(self): + self.sessions.clear() diff --git a/bridge/bridge.py b/bridge/bridge.py index b7b3ebf84..6b7c6c96e 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -38,6 +38,8 @@ def __init__(self): self.btype["chat"] = const.QWEN_DASHSCOPE if model_type and model_type.startswith("gemini"): self.btype["chat"] = const.GEMINI + if model_type in [const.DIFY]: + self.btype["chat"] = const.DIFY if model_type in [const.ZHIPU_AI]: self.btype["chat"] = const.ZHIPU_AI if model_type and model_type.startswith("claude-3"): diff --git a/common/const.py b/common/const.py index 68d3795cd..130835802 100644 --- a/common/const.py +++ b/common/const.py @@ -11,10 +11,13 @@ QWEN_DASHSCOPE = "dashscope" # 通义新版sdk和api key + GEMINI = "gemini" # gemini-1.0-pro ZHIPU_AI = "glm-4" MOONSHOT = "moonshot" MiniMax = "minimax" +DIFY = "dify" + # model @@ -41,6 +44,7 @@ TTS_1 = "tts-1" TTS_1_HD = "tts-1-hd" + WEN_XIN = "wenxin" WEN_XIN_4 = "wenxin-4" @@ -65,7 +69,7 @@ "claude", "claude-3-haiku", "claude-3-sonnet", "claude-3-opus", "claude-3-opus-20240229", "claude-3.5-sonnet", "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k", QWEN, QWEN_TURBO, QWEN_PLUS, QWEN_MAX, - LINKAI_35, LINKAI_4_TURBO, LINKAI_4o + LINKAI_35, LINKAI_4_TURBO, LINKAI_4o, DIFY ] # channel diff --git a/config-template.json b/config-template.json index d0268d3b1..cbd2f204a 100644 --- a/config-template.json +++ b/config-template.json @@ -3,6 +3,8 @@ "model": "", "open_ai_api_key": "YOUR API KEY", "claude_api_key": "YOUR API KEY", + "dify_app_type": "chatbot", + "dify_convsersation_max_messages": 5, "text_to_image": "dall-e-2", "voice_to_text": "openai", "text_to_voice": "openai", diff --git a/config.py b/config.py index cad68723d..d7b411ab0 100644 --- a/config.py +++ b/config.py @@ -88,6 +88,9 @@ "dashscope_api_key": "", # Google Gemini Api Key "gemini_api_key": "", + # dify配置 + "dify_app_type": "chatbot", # dify助手类型 chatbot(对应聊天助手)/agent(对应Agent)/workflow(对应工作流),默认为chatbot + "dify_convsersation_max_messages": 5, # dify目前不支持设置历史消息长度,暂时使用超过最大消息数清空会话的策略,缺点是没有滑动窗口,会突然丢失历史消息 # wework的通用配置 "wework_smart": True, # 配置wework是否使用已登录的企业微信,False为多开 # 语音设置 @@ -148,7 +151,7 @@ "dingtalk_client_id": "", # 钉钉机器人Client ID "dingtalk_client_secret": "", # 钉钉机器人Client Secret "dingtalk_card_enabled": False, - + # chatgpt指令自定义触发词 "clear_memory_commands": ["#清除记忆"], # 重置会话指令,必须以#开头 # channel配置 diff --git a/docker/build.latest.sh b/docker/build.latest.sh index 92c356497..b70b54aa0 100644 --- a/docker/build.latest.sh +++ b/docker/build.latest.sh @@ -5,4 +5,4 @@ unset KUBECONFIG cd .. && docker build -f docker/Dockerfile.latest \ -t zhayujie/chatgpt-on-wechat . -docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$(date +%y%m%d) \ No newline at end of file +docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$(date +%y%m%d) diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index d6383bfa8..f214cac97 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -36,6 +36,11 @@ "alias": ["model", "模型"], "desc": "查看和设置全局模型", }, + "set_openai_api_base": { + "alias": ["set_openai_api_base"], + "args": ["api_base"], + "desc": "设置你的OpenAI私有api_base", + }, "set_openai_api_key": { "alias": ["set_openai_api_key"], "args": ["api_key"], @@ -138,7 +143,7 @@ def get_help_text(isadmin, isgroup): help_text = "通用指令\n" for cmd, info in COMMANDS.items(): - if cmd in ["auth", "set_openai_api_key", "reset_openai_api_key", "set_gpt_model", "reset_gpt_model", "gpt_model"]: # 不显示帮助指令 + if not isadmin and cmd in ["auth", "set_openai_api_base", "set_openai_api_key", "reset_openai_api_key", "set_gpt_model", "reset_gpt_model", "gpt_model"]: # 不显示帮助指令 continue if cmd == "id" and conf().get("channel_type", "wx") not in ["wxy", "wechatmp"]: continue @@ -156,7 +161,7 @@ def get_help_text(isadmin, isgroup): if plugins[plugin].enabled and not plugins[plugin].hidden: namecn = plugins[plugin].namecn help_text += "\n%s:" % namecn - help_text += PluginManager().instances[plugin].get_help_text(verbose=False).strip() + help_text += PluginManager().instances[plugin].get_help_text(verbose=True).strip() if ADMIN_COMMANDS and isadmin: help_text += "\n\n管理员指令:\n" @@ -278,6 +283,13 @@ def on_handle_context(self, e_context: EventContext): ok, result = True, "模型设置为: " + str(model) elif cmd == "id": ok, result = True, user + elif cmd == "set_openai_api_base": + if len(args) == 1: + user_data = conf().get_user_data(user) + user_data["open_ai_api_base"] = args[0] + ok, result = True, "你的OpenAI私有api_base已设置为" + args[0] + else: + ok, result = False, "请提供一个api_base" elif cmd == "set_openai_api_key": if len(args) == 1: user_data = conf().get_user_data(user) @@ -313,7 +325,7 @@ def on_handle_context(self, e_context: EventContext): except Exception as e: ok, result = False, "你没有设置私有GPT模型" elif cmd == "reset": - if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.ZHIPU_AI]: + if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.DIFY, const.ZHIPU_AI]: bot.sessions.clear_session(session_id) if Bridge().chat_bots.get(bottype): Bridge().chat_bots.get(bottype).sessions.clear_session(session_id) @@ -339,7 +351,7 @@ def on_handle_context(self, e_context: EventContext): ok, result = True, "配置已重载" elif cmd == "resetall": if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, - const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.ZHIPU_AI, const.MOONSHOT]: + const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.DIFY, const.ZHIPU_AI, const.MOONSHOT]: channel.cancel_all_session() bot.sessions.clear_all_session() ok, result = True, "重置所有会话成功" diff --git a/plugins/midjourney/__init__.py b/plugins/midjourney/__init__.py new file mode 100644 index 000000000..9db7f7120 --- /dev/null +++ b/plugins/midjourney/__init__.py @@ -0,0 +1 @@ +from .midjourney import * diff --git a/plugins/midjourney/config.json.template b/plugins/midjourney/config.json.template new file mode 100644 index 000000000..0ae112a00 --- /dev/null +++ b/plugins/midjourney/config.json.template @@ -0,0 +1,9 @@ +{ + "user_drawing_mode":"relax", + "group_drawing_mode":"relax", + "default_drawing_mode":"relax", + "use_image_create_prefix":true, + "mj_proxy_server":"", + "mj_proxy_api_secret":"", + "mj_trigger_prefix":"/" +} \ No newline at end of file diff --git a/plugins/midjourney/midjourney.py b/plugins/midjourney/midjourney.py new file mode 100644 index 000000000..64e28155e --- /dev/null +++ b/plugins/midjourney/midjourney.py @@ -0,0 +1,514 @@ +# encoding:utf-8 +import base64 +import io +import re +import threading +import time + +import requests +from PIL import Image + +import plugins +from bridge.context import ContextType, Context +from bridge.reply import Reply, ReplyType +from channel import channel_factory +from channel.chat_message import ChatMessage +from common.expired_dict import ExpiredDict +from plugins import * + + +@plugins.register( + name="Midjourney", + desire_priority=98, + hidden=False, + desc="AI drawing plugin of midjourney", + version="1.0", + author="baojingyu", +) +class Midjourney(Plugin): + def __init__(self): + super().__init__() + # 获取当前文件的目录 + curdir = os.path.dirname(__file__) + # 配置文件的路径 + config_path = os.path.join(curdir, "config.json") + # 如果配置文件不存在 + if not os.path.exists(config_path): + # 输出日志信息,配置文件不存在,将使用模板 + logger.info('[Midjourney] 配置文件不存在,将使用config.json.template模板') + # 模板配置文件的路径 + config_path = os.path.join(curdir, "config.json.template") + # 打开并读取配置文件 + with open(config_path, "r", encoding="utf-8") as f: + # 加载 JSON 文件 + self.mj_plugin_config = json.load(f) + # 用户绘图模式 + self.user_drawing_mode = self.mj_plugin_config.get("user_drawing_mode", "relax") + # 群聊绘图模式 + self.group_drawing_mode = self.mj_plugin_config.get("group_drawing_mode", "relax") + # 默认绘图模式 + self.default_drawing_mode = self.mj_plugin_config.get("default_drawing_mode", "relax") + # 使用图像创建前缀,搭配image_create_prefix使用 + self.use_image_create_prefix = self.mj_plugin_config.get("default_drawing_mode", True) + self.mj_trigger_prefix = self.mj_plugin_config.get("mj_trigger_prefix", "/") + # 需要搭建Mindjourney Proxy https://github.com/novicezk/midjourney-proxy/blob/main/README_CN.md + self.mj_proxy_server = self.mj_plugin_config.get("mj_proxy_server") + self.mj_proxy_api_secret = self.mj_plugin_config.get("mj_proxy_api_secret", "") + if not self.mj_proxy_server: + logger.error( + f"[Midjourney] Initialization failed, missing required parameters , config={self.mj_plugin_config}") + # 获取 PluginManager 的单例实例 + plugin_manager = PluginManager() + # 停用Midjourney + plugin_manager.disable_plugin("Midjourney") + return + self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context + self.proxy = conf().get("proxy") + if self.proxy: + self.proxies = { + "http": self.proxy, + "https": self.proxy + } + else: + self.proxies = None + # 根据channel_type 动态创建通道 + self.channel_type = conf().get("channel_type") + self.channel = channel_factory.create_channel(self.channel_type) + self.task_id_dict = ExpiredDict(conf().get("expires_in_seconds",60 * 60)) + self.task_msg_dict = ExpiredDict(conf().get("expires_in_seconds",60 * 60)) + self.cmd_dict = ExpiredDict(conf().get("expires_in_seconds",60 * 60)) + # 批量查询任务结果 + self.batch_size = 10 + self.semaphore = threading.Semaphore(1) + self.lock = threading.Lock() # 用于控制对sessions的访问 + self.thread = threading.Thread(target=self.background_query_task_result) + self.thread.start() + logger.info(f"[Midjourney] inited, config={self.mj_plugin_config}") + + def on_handle_context(self, e_context: EventContext): + if not self.mj_plugin_config: + return + + context = e_context['context'] + if context.type not in [ContextType.TEXT, ContextType.IMAGE, ContextType.IMAGE_CREATE]: + return + msg: ChatMessage = e_context["context"]["msg"] + logger.info(f"[Midjourney] context msg={msg}") + state = "" + # 检查 msg.other_user_id 和 msg.actual_user_nickname 是否为 None,如果是,则将它们替换为空字符串 + other_user_id = msg.other_user_id if msg.other_user_id else "" + actual_user_nickname = msg.actual_user_nickname if msg.actual_user_nickname else "" + if not msg.is_group: + state = "u:" + other_user_id + ":" + actual_user_nickname + else: + state = "r:" + other_user_id + ":" + actual_user_nickname + # Midjourney 作图任务 + self.process_midjourney_task(state, e_context) + + # imagine 命令:处理图片生成请求,并根据优先级添加模式标识。 + # up 命令:处理任务按钮的操作请求。 + # img2img 命令:处理图像到图像的生成请求。 + # describe 命令:处理图像描述请求。 + # shorten 命令:处理文本缩短请求。 + # seed 命令:获取任务图片的 seed 值。 + # query 命令:查询任务的状态。 + def process_midjourney_task(self, state, e_context: EventContext): + content = e_context["context"].content + msg: ChatMessage = e_context["context"]["msg"] + isgroup = msg.is_group + result = None + prompt = "" + try: + # 获取配置中的触发前缀和图片生成前缀列表 + image_create_prefixes = conf().get("image_create_prefix", []) + + # 处理图片生成的前缀 + if e_context["context"].type == ContextType.IMAGE_CREATE and self.mj_plugin_config.get( + "use_image_create_prefix"): + # 创建一个正则模式来匹配所有可能的前缀 + prefix_pattern = '|'.join(map(re.escape, image_create_prefixes)) + # 使用正则表达式只在字符串开头匹配前缀并替换 + content = re.sub(f'^(?:{prefix_pattern})', f"{self.mj_trigger_prefix}imagine ", msg.content, count=1) + logger.debug(f"[Midjourney] ole_content: {msg.content} , new_content: {content}") + + # 处理 imagine 命令 + if content.startswith(f"{self.mj_trigger_prefix}imagine "): + prompt = content[9:] + + # 检查用户是否已经输入了模式标识 + if not any(flag in prompt for flag in ["--relax", "--fast", "--turbo"]): + # 根据优先级添加模式标识 + if not isgroup and is_valid_mode(self.user_drawing_mode): + prompt += f" --{self.user_drawing_mode}" + elif isgroup and is_valid_mode(self.group_drawing_mode): + prompt += f" --{self.group_drawing_mode}" + elif is_valid_mode(self.default_drawing_mode): + prompt += f" --{self.default_drawing_mode}" + + # 处理 imagine 请求 + result = self.handle_imagine(prompt, state) + + # 处理 up 命令 + elif content.startswith(f"{self.mj_trigger_prefix}up "): + arr = content[4:].split() + try: + task_id = arr[0] + index = int(arr[1]) + except Exception as e: + e_context["reply"] = Reply(ReplyType.TEXT, '❌ 您的任务提交失败\nℹ️ 参数错误') + e_context.action = EventAction.BREAK_PASS + return + + # 获取任务 + task = self.get_task(task_id) + if task is None: + e_context["reply"] = Reply(ReplyType.TEXT, '❌ 您的任务提交失败\nℹ️ 任务ID不存在') + e_context.action = EventAction.BREAK_PASS + return + + # 检查按钮序号是否正确 + if index > len(task['buttons']): + e_context["reply"] = Reply(ReplyType.TEXT, '❌ 您的任务提交失败\nℹ️ 按钮序号不正确') + e_context.action = EventAction.BREAK_PASS + return + + # 获取按钮 + button = task['buttons'][index - 1] + if button['label'] == 'Custom Zoom': + e_context["reply"] = Reply(ReplyType.TEXT, '❌ 您的任务提交失败\nℹ️ 暂不支持自定义变焦') + e_context.action = EventAction.BREAK_PASS + return + + # 发送请求 + result = self.post_json('/submit/action', + {'customId': button['customId'], 'taskId': task_id, 'state': state}) + if result.get("code") == 21: + result = self.post_json('/submit/modal', + {'taskId': result.get("result"), 'state': state}) + + # 处理 img2img 命令 + elif content.startswith(f"{self.mj_trigger_prefix}img2img "): + self.cmd_dict[msg.actual_user_id] = content + e_context["reply"] = Reply(ReplyType.TEXT, '请给我发一张图片作为垫图') + e_context.action = EventAction.BREAK_PASS + return + + # 处理 describe 命令 + elif content == f"{self.mj_trigger_prefix}describe": + self.cmd_dict[msg.actual_user_id] = content + e_context["reply"] = Reply(ReplyType.TEXT, '请给我发一张图片用于图生文') + e_context.action = EventAction.BREAK_PASS + return + + # 处理 shorten 命令 + elif content.startswith(f"{self.mj_trigger_prefix}shorten "): + result = self.handle_shorten(content[9:], state) + + # 处理 seed 命令 + elif content.startswith(f"{self.mj_trigger_prefix}seed "): + task_id = content[6:] + result = self.get_task_image_seed(task_id) + if result.get("code") == 1: + e_context["reply"] = Reply(ReplyType.TEXT, '✅ 获取任务图片seed成功\n📨 任务ID: %s\n🔖 seed值: %s' % ( + task_id, result.get("result"))) + else: + e_context["reply"] = Reply(ReplyType.TEXT, '❌ 获取任务图片seed失败\n📨 任务ID: %s\nℹ️ %s' % ( + task_id, result.get("description"))) + e_context.action = EventAction.BREAK_PASS + return + + # 处理图片消息 + elif e_context["context"].type == ContextType.IMAGE: + cmd = self.cmd_dict.get(msg.actual_user_id) + if not cmd: + return + msg.prepare() + self.cmd_dict.pop(msg.actual_user_id) + if f"{self.mj_trigger_prefix}describe" == cmd: + result = self.handle_describe(content, state) + elif cmd.startswith(f"{self.mj_trigger_prefix}img2img "): + result = self.handle_img2img(content, cmd[9:], state) + else: + return + + # 处理 query 命令 + elif content.startswith(f"{self.mj_trigger_prefix}query "): + arr = content[7:].split() + try: + task_id = arr[0] + except Exception as e: + e_context["reply"] = Reply(ReplyType.TEXT, '❌ 您的任务查询失败\nℹ️ 参数错误') + e_context.action = EventAction.BREAK_PASS + return + # 查询任务 + task = self.get_task(task_id) + if task is None: + e_context["reply"] = Reply(ReplyType.TEXT, '❌ 您的任务查询失败\nℹ️ 任务ID不存在') + e_context.action = EventAction.BREAK_PASS + return + self.add_task(task_id, msg) + e_context.action = EventAction.BREAK_PASS + return + else: + return + except Exception as e: + logger.exception("[Midjourney] handle failed: %s" % e) + result = {'code': -9, 'description': '服务异常, 请稍后再试'} + + # 处理请求结果 + code = result.get("code") + if code == 1: + task_id = result.get("result") + self.add_task(task_id, msg) + + # 根据 prompt 中的标识设置模式说明 + mode_description = "" + if "--relax" in prompt: + mode_description = "ℹ️ Relax模式任务的等待时间通常为1-10分钟" + reply_text = f'✅ 您的任务已提交\n🚀 正在快速处理中,请稍后\n📨 任务ID: {task_id}\n{mode_description}' + e_context["reply"] = Reply(ReplyType.TEXT, reply_text) + elif code == 22: + self.add_task(result.get("result"), msg) + e_context["reply"] = Reply(ReplyType.TEXT, f'✅ 您的任务已提交\n⏰ {result.get("description")}') + else: + e_context["reply"] = Reply(ReplyType.TEXT, f'❌ 您的任务提交失败\nℹ️ {result.get("description")}') + e_context.action = EventAction.BREAK_PASS + + def handle_imagine(self, prompt, state): + return self.post_json('/submit/imagine', {'prompt': prompt, 'state': state}) + + def handle_describe(self, img_data, state): + + base64_str = self.image_file_to_base64(img_data) + return self.post_json('/submit/describe', {'base64': base64_str, 'state': state}) + + def handle_shorten(self, prompt, state): + return self.post_json('/submit/shorten', {'prompt': prompt, 'state': state}) + + def handle_img2img(self, img_data, prompt, state): + base64_str = self.image_file_to_base64(img_data) + return self.post_json('/submit/imagine', {'prompt': prompt, 'base64': base64_str, 'state': state}) + + def post_json(self, api_path, data): + return requests.post(url=self.mj_proxy_server + api_path, json=data, + headers={'mj-api-secret': self.mj_proxy_api_secret}).json() + + def get_task(self, task_id): + return requests.get(url=self.mj_proxy_server + '/task/%s/fetch' % task_id, + headers={'mj-api-secret': self.mj_proxy_api_secret}).json() + + def get_task_image_seed(self, task_id): + return requests.get(url=self.mj_proxy_server + '/task/%s/image-seed' % task_id, + headers={'mj-api-secret': self.mj_proxy_api_secret}).json() + + def query_tasks_by_ids(self, task_ids): + return self.post_json('/task/list-by-condition', {'ids': task_ids}) + + def add_task(self, task_id, msg): + # 将任务ID存储到任务ID字典中 + self.task_id_dict[task_id] = 'NOT_START' + # 将任务ID和消息信息关联存储到 task_msg_dict 字典中 + self.task_msg_dict[task_id] = msg + + def background_query_task_result(self): + while True: + with self.lock: + task_ids = list(self.task_id_dict.keys()) + + if task_ids: + num_batches = (len(task_ids) + self.batch_size - 1) // self.batch_size # 计算批次数量 + logger.debug("[Midjourney] background query task result running, size [%s]", len(task_ids)) + for i in range(num_batches): + # 获取当前批次的任务ID列表 + batch = task_ids[i * self.batch_size:(i + 1) * self.batch_size] + + self.handle_task_batch(batch) + + # 等待所有任务处理完成 + for _ in batch: + self.semaphore.acquire() + + # 避免过度占用CPU资源,适当休眠 + time.sleep(0.5) + + def handle_task_batch(self, task_ids): + tasks = self.query_tasks_by_ids(task_ids) # 批量查询任务 + if tasks is not None and len(tasks) > 0: + logger.debug( + f"[Midjourney] background handle task batch running, size {len(task_ids)}, taskIds [{','.join(task_ids)}]", ) + # 将 tasks 转换成键值对结构 + tasks_map = {task['id']: task for task in tasks} + for task_id in task_ids: + task = tasks_map.get(task_id) + self.process_task(task, task_id) + else: + # 如果没有返回任务,释放所有的信号量 + for _ in task_ids: + self.semaphore.release() + + def process_task(self, task, task_id): + if task is None: + self.handle_not_exist_task(task, task_id) + else: + self.handle_exist_task(task, task_id) + + # 只在这里释放批处理信号量 + self.semaphore.release() + + def handle_exist_task(self, task, task_id): + context = Context() + # 获取当前任务ID对应的消息信息 + msg = self.task_msg_dict.get(task_id) + # 在已有的context中存储消息信息 + context.kwargs['msg'] = msg + context.__setitem__("msg", msg) + state = task.get("state",None) + if state is None: + # 检查 msg.other_user_id 和 msg.actual_user_nickname 是否为 None,如果是,则将它们替换为空字符串 + other_user_id = msg.other_user_id if msg.other_user_id else "" + actual_user_nickname = msg.actual_user_nickname if msg.actual_user_nickname else "" + if not msg.is_group: + state = "u:" + other_user_id + ":" + actual_user_nickname + else: + state = "r:" + other_user_id + ":" + actual_user_nickname + + state_array = state.split(':', 2) + reply_prefix = self.extract_state_info(state_array) + context.__setitem__("receiver", reply_prefix) + + + reply = self.generate_reply(task_id, task, context, reply_prefix) + if reply is not None: + self.channel.send(reply, context) + else: + logger.debug( + f"[Midjourney] handle task_id: {task_id} , status :{task['status']} , progress : {task['progress']}") + + def handle_not_exist_task(self, task, task_id): + context = Context() + msg = self.task_msg_dict.get(task_id) + context.kwargs['msg'] = msg + context.__setitem__("msg", msg) + + state = task.get("state",None) + if state is None: + # 检查 msg.other_user_id 和 msg.actual_user_nickname 是否为 None,如果是,则将它们替换为空字符串 + other_user_id = msg.other_user_id if msg.other_user_id else "" + actual_user_nickname = msg.actual_user_nickname if msg.actual_user_nickname else "" + if not msg.is_group: + state = "u:" + other_user_id + ":" + actual_user_nickname + else: + state = "r:" + other_user_id + ":" + actual_user_nickname + state_array = state.split(':', 2) + reply_prefix = self.extract_state_info(state_array) + context.__setitem__("receiver", reply_prefix) + + reply = Reply(ReplyType.TEXT, '❌ 您的任务执行失败\nℹ️ 任务ID不存在\n📨 任务ID: %s' % (task_id)) + + self.channel.send(reply, context) + + logger.debug("[Midjourney] 任务执行失败 , 任务ID不存在: " + task_id) + self.task_id_dict.pop(task_id) + self.task_msg_dict.pop(task_id) + + def extract_state_info(self, state_array=None): + if not state_array: + receiver = state_array[1] if len(state_array) > 1 else None + reply_prefix = '@%s ' % state_array[2] if state_array[0] == 'r' else '' + return reply_prefix + return "" + + def generate_reply(self, task_id, task, context:Context, reply_prefix=''): + status = task['status'] + action = task['action'] + description = task.get('description', 'No description available') + context.__setitem__("promptEn", task['promptEn']) + if status == 'SUCCESS': + logger.debug("[Midjourney] 任务已完成: " + task_id) + self.task_id_dict.pop(task_id) + self.task_msg_dict.pop(task_id) + image_url = task.get('imageUrl', None) + + context.__setitem__("description", description) + context.__setitem__("image_url", image_url) + if action == 'DESCRIBE' or action == 'SHORTEN': + prompt = task['properties']['finalPrompt'] + reply_text = f"✅ 任务已完成\n📨 任务ID: {task_id}\n✨ {description}\n\n{self.get_buttons(task)}\n💡 使用 {self.mj_trigger_prefix}up 任务ID 序号执行动作\n🔖 {self.mj_trigger_prefix}up {task_id} 1" + return Reply(ReplyType.TEXT, reply_text) + elif action == 'UPSCALE': + reply_text = f"✅ 任务已完成\n📨 任务ID: {task_id}\n✨ {description}\n\n{self.get_buttons(task)}\n💡 使用 {self.mj_trigger_prefix}up 任务ID 序号执行动作\n🔖 {self.mj_trigger_prefix}up {task_id} 1" + return Reply(ReplyType.TEXT, reply_text) + else: + # image_storage = self.download_and_compress_image(image_url) + reply_text = f"✅ 任务已完成\n📨 任务ID: {task_id}\n✨ {description}\n\n{self.get_buttons(task)}\n💡 使用 {self.mj_trigger_prefix}up 任务ID 序号执行动作\n🔖 {self.mj_trigger_prefix}up {task_id} 1" + return Reply(ReplyType.TEXT, reply_text) + elif status == 'FAILURE': + self.task_id_dict.pop(task_id) + self.task_msg_dict.pop(task_id) + reply_text = f'❌ 任务执行失败\n📨 任务ID: {task_id}\n📒 失败原因: {task["failReason"]}\n✨ {description}' + return Reply(ReplyType.TEXT, reply_text) + + def image_file_to_base64(self, file_path): + with open(file_path, "rb") as image_file: + img_data = image_file.read() + img_base64 = base64.b64encode(img_data).decode("utf-8") + os.remove(file_path) + return "data:image/png;base64," + img_base64 + + def get_buttons(self, task): + res = '' + index = 1 + for button in task['buttons']: + name = button['emoji'] + button['label'] + if name in ['🎉Imagine all', '❤️']: + continue + res += ' %d - %s\n' % (index, name) + index += 1 + return res + + def download_and_compress_image(self, img_url, max_size=(800, 800)): # 下载并压缩图片 + headers = { + 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36' + } + # 设置代理 + # self.proxies + # , proxies=self.proxies + pic_res = requests.get(img_url, headers=headers, stream=True, timeout=60 * 5) + image_storage = io.BytesIO() + size = 0 + for block in pic_res.iter_content(1024): + size += len(block) + image_storage.write(block) + image_storage.seek(0) + logger.debug(f"[MJ] download image success, size={size}, img_url={img_url}") + # 压缩图片 + initial_image = Image.open(image_storage) + initial_image.thumbnail(max_size) + output = io.BytesIO() + initial_image.save(output, format=initial_image.format) + output.seek(0) + return output + + # 检查模式是否有效 + + def get_help_text(self, verbose=False, **kwargs): + help_text = "这是一个能调用midjourney实现ai绘图的扩展能力。\n" + if not verbose: + return help_text + help_text += "使用说明: \n" + help_text += f"{self.mj_trigger_prefix}imagine 根据给出的提示词绘画;\n" + help_text += f"{self.mj_trigger_prefix}img2img 根据提示词+垫图生成图;\n" + help_text += f"{self.mj_trigger_prefix}up 任务ID 序号执行动作;\n" + help_text += f"{self.mj_trigger_prefix}describe 图片转文字;\n" + help_text += f"{self.mj_trigger_prefix}shorten 提示词分析;\n" + help_text += f"{self.mj_trigger_prefix}seed 获取任务图片的seed值;\n" + help_text += f"{self.mj_trigger_prefix}query 任务ID 查询任务进度;\n" + help_text += f"默认使用🐢 Relax绘图,也可以在提示词末尾使用 `--relax` 或 `--fast` 参数运行单个作业;\n" + image_create_prefixes = conf().get("image_create_prefix", []) + if image_create_prefixes and self.mj_plugin_config.get("use_image_create_prefix",False): + prefixes = ", ".join(image_create_prefixes) + help_text += f"支持图片回复前缀关键字:{prefixes}。\n使用格式:{image_create_prefixes[0]}一棵装饰着金色雪花和金色饰品的圣诞树,周围是地板上的礼物。房间是白色的,有浅色木材的装饰,一侧有一个壁炉,大窗户望向户外花园。一颗星星挂在高约三米的绿色松树顶上。这是一个充满节日庆祝气氛的优雅场景,充满了温暖和欢乐。一张超逼真的照片,以高分辨率2000万像素相机的风格拍摄。\n" + return help_text + +def is_valid_mode(mode): + return mode in ["relax", "fast", "turbo"]