From 8a29d277afd3db4cdaf0bcc15b804e690ff4d84a Mon Sep 17 00:00:00 2001 From: gjl <2802427218@qq.com> Date: Mon, 19 Aug 2024 14:37:30 +0800 Subject: [PATCH] =?UTF-8?q?fix:=E8=A7=A3=E5=86=B3agent=20bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mindpilot/app/chat/chat.py | 5 ++-- .../app/conversation/conversation_api.py | 25 ++++++++++++++++--- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/mindpilot/app/chat/chat.py b/src/mindpilot/app/chat/chat.py index b21a84c..f8376bb 100644 --- a/src/mindpilot/app/chat/chat.py +++ b/src/mindpilot/app/chat/chat.py @@ -17,7 +17,7 @@ from ..callback_handler.agent_callback_handler import ( ) from ..chat.utils import History from ..configs import MODEL_CONFIG, TOOL_CONFIG, OPENAI_PROMPT, PROMPT_TEMPLATES -from ..utils.system_utils import get_ChatOpenAI, get_tool, wrap_done, MsgType +from ..utils.system_utils import get_ChatOpenAI, get_tool, wrap_done, MsgType, get_mindpilot_db_connection from ..agent.utils import get_agent_from_id @@ -253,7 +253,8 @@ async def chat_online( history: List[History], chat_model_config: dict, tool_config: List[str], - agent_id: int + agent_id: int, + conversation_id: str ): async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]: callback = AgentExecutorAsyncIteratorCallbackHandler() diff --git a/src/mindpilot/app/conversation/conversation_api.py b/src/mindpilot/app/conversation/conversation_api.py index bde91fe..a7fc625 100644 --- a/src/mindpilot/app/conversation/conversation_api.py +++ b/src/mindpilot/app/conversation/conversation_api.py @@ -225,10 +225,27 @@ async def send_messages( chat_model_config["llm_model"][model_key]["temperature"] = temperature chat_model_config["llm_model"][model_key]["max_tokens"] = max_tokens - if len(history) == 0: - summery_prompt = "下面是用户的问题,请总结为不超过八个字的标题。\n" + "用户:" + text + '输出格式为:{"title":"总结的标题"},除了这个json,不允许输出其他内容。' + is_summery = False + + if agent_id == -1: + cursor.execute('''SELECT agent_id FROM conversations WHERE conversation_id = ?''', (conversation_id,)) + temp_agent_id = cursor.fetchone()[0] + if temp_agent_id != -1: + temp_agent = get_agent_from_id(temp_agent_id) + temp_agent_name = temp_agent["agent_name"] + temp_agent_abstract = temp_agent["agent_abstract"] + temp_agent_info = temp_agent["agent_info"] + agent_prompt = "Your name is " + temp_agent_name + "." + temp_agent_abstract + ". Below is your detailed information:" + temp_agent_info + "." + history.append({"role": "user", "content": agent_prompt}) + is_summery = True + + if len(history) == 0 or is_summery == True: + if len(history) == 0: + summery_prompt = "下面是用户的问题,请总结为不超过八个字的标题。\n" + "用户:" + text + '输出格式为:{"title":"总结的标题"},除了这个json,不允许输出其他内容。' + if is_summery == True: + summery_prompt = "下面是用户的问题,请总结为不超过八个字的标题。\n" + "用户:" + agent_prompt + text + '输出格式为:{"title":"总结的标题"},除了这个json,不允许输出其他内容。' summery = await chat_online(content=summery_prompt, history=[], chat_model_config=chat_model_config, - agent_id=-1, tool_config=tool_config) + agent_id=-1, tool_config=tool_config, conversation_id=conversation_id) summery_content = summery[0]['choices'][0]['delta']['content'] try: summery_content = json.loads(summery_content)["title"] @@ -243,7 +260,7 @@ async def send_messages( # 获取模型输出 ret = await chat_online(content=text, history=history, chat_model_config=chat_model_config, - tool_config=tool_config, agent_id=agent_id) + tool_config=tool_config, agent_id=agent_id, conversation_id=conversation_id) response_messages = [] for message in ret: