diff --git a/src/mindpilot/app/api/conversation_routes.py b/src/mindpilot/app/api/conversation_routes.py index 54d948f..38124ea 100644 --- a/src/mindpilot/app/api/conversation_routes.py +++ b/src/mindpilot/app/api/conversation_routes.py @@ -1,6 +1,6 @@ from __future__ import annotations from fastapi import APIRouter, Request -from ..conversation.conversation_api import add_conversation, list_conversations, get_conversation, delete_conversation, send_messages +from ..conversation.conversation_api import add_conversation, list_conversations, get_conversation, delete_conversation, send_messages, debug_messages conversation_router = APIRouter(prefix="/api/conversation", tags=["对话接口"]) @@ -28,3 +28,8 @@ conversation_router.post( "/{conversation_id}/messages", summary="发送消息" )(send_messages) + +conversation_router.post( + "/debug", + summary="调试对话", +)(debug_messages) diff --git a/src/mindpilot/app/chat/chat.py b/src/mindpilot/app/chat/chat.py index f8376bb..80f894a 100644 --- a/src/mindpilot/app/chat/chat.py +++ b/src/mindpilot/app/chat/chat.py @@ -1,7 +1,7 @@ import asyncio import json import uuid -from typing import AsyncIterable, List +from typing import AsyncIterable, List, Dict, Any from fastapi import Body from langchain.chains import LLMChain @@ -378,3 +378,127 @@ async def chat_online( ret.append(data) return ret + + +async def debug_chat_online( + content: str, + history: List[History], + chat_model_config: dict, + tool_config: List[str], + agent_config: Dict[str, Any] +): + async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]: + callback = AgentExecutorAsyncIteratorCallbackHandler() + callbacks = [callback] + + model, prompt = create_models_from_config( + callbacks=callbacks, configs=chat_model_config, stream=False + ) + + all_tools = get_tool().values() + tool_configs = tool_config + + agent_enable = agent_config["agent_enable"] + if agent_enable: + agent_name = agent_config["agent_name"] + agent_abstract = agent_config["agent_abstract"] + agent_info = agent_config["agent_info"] + if not tool_config: + tool_configs = agent_config["tool_config"] + agent_prompt_pre = "Your name is " + agent_name + "." + agent_abstract + ". Below is your detailed information:" + agent_info + "." + agent_prompt_after = "DO NOT forget " + agent_prompt_pre + prompt = agent_prompt_pre + prompt + agent_prompt_after + # TODO 处理知识库 + + tool_configs = tool_configs or TOOL_CONFIG + tools = [tool for tool in all_tools if tool.name in tool_configs] + tools = [t.copy(update={"callbacks": callbacks}) for t in tools] + + full_chain = create_models_chains( + prompts=prompt, + models=model, + tools=tools, + callbacks=callbacks, + history=history, + agent_enable=agent_enable + ) + + _history = [History.from_data(h) for h in history] + chat_history = [h.to_msg_tuple() for h in _history] + + history_message = convert_to_messages(chat_history) + + task = asyncio.create_task( + wrap_done( + full_chain.ainvoke( + { + "input": content, + "chat_history": history_message, + } + ), + callback.done, + ) + ) + + last_tool = {} + async for chunk in callback.aiter(): + data = json.loads(chunk) + data["tool_calls"] = [] + data["message_type"] = MsgType.TEXT + + if data["status"] == AgentStatus.tool_start: + last_tool = { + "index": 0, + "id": data["run_id"], + "type": "function", + "function": { + "name": data["tool"], + "arguments": data["tool_input"], + }, + "tool_output": None, + "is_error": False, + } + data["tool_calls"].append(last_tool) + if data["status"] in [AgentStatus.tool_end]: + last_tool.update( + tool_output=data["tool_output"], + is_error=data.get("is_error", False), + ) + data["tool_calls"] = [last_tool] + last_tool = {} + try: + tool_output = json.loads(data["tool_output"]) + if message_type := tool_output.get("message_type"): + data["message_type"] = message_type + except: + ... + elif data["status"] == AgentStatus.agent_finish: + try: + tool_output = json.loads(data["text"]) + if message_type := tool_output.get("message_type"): + data["message_type"] = message_type + except: + ... + + ret = OpenAIChatOutput( + id=f"chat{uuid.uuid4()}", + object="chat.completion.chunk", + content=data.get("text", ""), + role="assistant", + tool_calls=data["tool_calls"], + model=model.model_name, + status=data["status"], + message_type=data["message_type"], + ) + yield ret.model_dump_json() + + await task + + ret = [] + + async for chunk in chat_iterator(): + data = json.loads(chunk) + if data["status"] != AgentStatus.llm_start and data["status"] != AgentStatus.llm_new_token: + ret.append(data) + + return ret diff --git a/src/mindpilot/app/conversation/conversation_api.py b/src/mindpilot/app/conversation/conversation_api.py index a7fc625..f65798c 100644 --- a/src/mindpilot/app/conversation/conversation_api.py +++ b/src/mindpilot/app/conversation/conversation_api.py @@ -6,11 +6,13 @@ from fastapi import Body from uuid import uuid4 from datetime import datetime import sqlite3 + +from ..chat.utils import History from ..utils.system_utils import BaseResponse, ListResponse, get_mindpilot_db_connection from .message import init_messages_table, insert_message, split_message_content from ..model_configs.utils import get_config_from_id from ..agent.utils import get_agent_from_id -from ..chat.chat import chat_online +from ..chat.chat import chat_online, debug_chat_online def init_conversations_table(): @@ -185,6 +187,7 @@ async def send_messages( max_tokens: int = Body(..., description="模型输出最大长度", examples=[4096]), ): + # TODO 缺少知识库部分 init_conversations_table() init_messages_table() conn = get_mindpilot_db_connection() @@ -316,3 +319,83 @@ async def send_messages( conn.close() return BaseResponse(code=200, msg="success", data=response_messages) + + +async def debug_messages( + query: str = Body(..., description="用户输入", examples=[""]), + history: List[History] = Body( + [], + description="历史对话", + examples=[ + [ + {"role": "user", "content": "你好"}, + {"role": "assistant", "content": "您好,我是智能Agent桌面助手MindPilot,请问有什么可以帮您?"}, + ] + ], + ), + config_id: int = Body("0", description="模型配置", examples=[1]), + agent_config: dict = Body(..., description="agent配置", examples=[ + { + "agent_name": "调试助手", + "agent_abstract": "", + "agent_info": "这是一个用于调试的AI助手", + "agent_enable": True, + "temperature": 0.7, + "max_tokens": 150, + "tool_config": ["search_internet", "calculator"] + } + ]) +): + # TODO 缺少知识库部分 + + if not agent_config["agent_enable"]: + temp_agent_name = agent_config["agent_name"] + temp_agent_abstract = agent_config["agent_abstract"] + temp_agent_info = agent_config["agent_info"] + agent_prompt = "Your name is " + temp_agent_name + "." + temp_agent_abstract + ". Below is your detailed information:" + temp_agent_info + "." + history.append({"role": "user", "content": agent_prompt}) + + # 获取模型配置 + chat_model_config = get_config_from_id(config_id=config_id) + model_key = next(iter(chat_model_config["llm_model"])) + chat_model_config["llm_model"][model_key]["temperature"] = agent_config['temperature'] + chat_model_config["llm_model"][model_key]["max_tokens"] = agent_config['max_tokens'] + + # 获取模型输出 + ret = await debug_chat_online(content=query, history=history, chat_model_config=chat_model_config, + tool_config=agent_config['tool_config'], agent_config=agent_config) + + response_messages = [] + for message in ret: + if message['status'] == 7: + message_role = message['choices'][0]['role'] + message_content = "Observation:\n" + message['choices'][0]['delta']['tool_calls'][0]['tool_output'] + timestamp = datetime.now().isoformat() + message_dict = { + "message_id": 0, + "agent_status": 7, + "text": message_content, + "files": [], + "timestamp": timestamp + } + response_messages.append(message_dict) + + if message['status'] == 3: + message_role = message['choices'][0]['role'] + message_content = message['choices'][0]['delta']['content'] + message_list = split_message_content(message_content) + for m in message_list: + timestamp = datetime.now().isoformat() + message_dict = { + "message_id": 0, + "agent_status": 3, + "text": m, + "files": [], + "timestamp": timestamp + } + + response_messages.append(message_dict) + + # TODO 这里考虑处理一下message['status']是4但之前一个message['status']不是3的,即agent无法解析的内容 + + return BaseResponse(code=200, msg="success", data=response_messages) \ No newline at end of file