| @@ -1,6 +1,6 @@ | |||
| from __future__ import annotations | |||
| from fastapi import APIRouter, Request | |||
| from ..conversation.conversation_api import add_conversation, list_conversations, get_conversation, delete_conversation, send_messages | |||
| from ..conversation.conversation_api import add_conversation, list_conversations, get_conversation, delete_conversation, send_messages, debug_messages | |||
| conversation_router = APIRouter(prefix="/api/conversation", tags=["对话接口"]) | |||
| @@ -28,3 +28,8 @@ conversation_router.post( | |||
| "/{conversation_id}/messages", | |||
| summary="发送消息" | |||
| )(send_messages) | |||
| conversation_router.post( | |||
| "/debug", | |||
| summary="调试对话", | |||
| )(debug_messages) | |||
| @@ -1,7 +1,7 @@ | |||
| import asyncio | |||
| import json | |||
| import uuid | |||
| from typing import AsyncIterable, List | |||
| from typing import AsyncIterable, List, Dict, Any | |||
| from fastapi import Body | |||
| from langchain.chains import LLMChain | |||
| @@ -378,3 +378,127 @@ async def chat_online( | |||
| ret.append(data) | |||
| return ret | |||
| async def debug_chat_online( | |||
| content: str, | |||
| history: List[History], | |||
| chat_model_config: dict, | |||
| tool_config: List[str], | |||
| agent_config: Dict[str, Any] | |||
| ): | |||
| async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]: | |||
| callback = AgentExecutorAsyncIteratorCallbackHandler() | |||
| callbacks = [callback] | |||
| model, prompt = create_models_from_config( | |||
| callbacks=callbacks, configs=chat_model_config, stream=False | |||
| ) | |||
| all_tools = get_tool().values() | |||
| tool_configs = tool_config | |||
| agent_enable = agent_config["agent_enable"] | |||
| if agent_enable: | |||
| agent_name = agent_config["agent_name"] | |||
| agent_abstract = agent_config["agent_abstract"] | |||
| agent_info = agent_config["agent_info"] | |||
| if not tool_config: | |||
| tool_configs = agent_config["tool_config"] | |||
| agent_prompt_pre = "Your name is " + agent_name + "." + agent_abstract + ". Below is your detailed information:" + agent_info + "." | |||
| agent_prompt_after = "DO NOT forget " + agent_prompt_pre | |||
| prompt = agent_prompt_pre + prompt + agent_prompt_after | |||
| # TODO 处理知识库 | |||
| tool_configs = tool_configs or TOOL_CONFIG | |||
| tools = [tool for tool in all_tools if tool.name in tool_configs] | |||
| tools = [t.copy(update={"callbacks": callbacks}) for t in tools] | |||
| full_chain = create_models_chains( | |||
| prompts=prompt, | |||
| models=model, | |||
| tools=tools, | |||
| callbacks=callbacks, | |||
| history=history, | |||
| agent_enable=agent_enable | |||
| ) | |||
| _history = [History.from_data(h) for h in history] | |||
| chat_history = [h.to_msg_tuple() for h in _history] | |||
| history_message = convert_to_messages(chat_history) | |||
| task = asyncio.create_task( | |||
| wrap_done( | |||
| full_chain.ainvoke( | |||
| { | |||
| "input": content, | |||
| "chat_history": history_message, | |||
| } | |||
| ), | |||
| callback.done, | |||
| ) | |||
| ) | |||
| last_tool = {} | |||
| async for chunk in callback.aiter(): | |||
| data = json.loads(chunk) | |||
| data["tool_calls"] = [] | |||
| data["message_type"] = MsgType.TEXT | |||
| if data["status"] == AgentStatus.tool_start: | |||
| last_tool = { | |||
| "index": 0, | |||
| "id": data["run_id"], | |||
| "type": "function", | |||
| "function": { | |||
| "name": data["tool"], | |||
| "arguments": data["tool_input"], | |||
| }, | |||
| "tool_output": None, | |||
| "is_error": False, | |||
| } | |||
| data["tool_calls"].append(last_tool) | |||
| if data["status"] in [AgentStatus.tool_end]: | |||
| last_tool.update( | |||
| tool_output=data["tool_output"], | |||
| is_error=data.get("is_error", False), | |||
| ) | |||
| data["tool_calls"] = [last_tool] | |||
| last_tool = {} | |||
| try: | |||
| tool_output = json.loads(data["tool_output"]) | |||
| if message_type := tool_output.get("message_type"): | |||
| data["message_type"] = message_type | |||
| except: | |||
| ... | |||
| elif data["status"] == AgentStatus.agent_finish: | |||
| try: | |||
| tool_output = json.loads(data["text"]) | |||
| if message_type := tool_output.get("message_type"): | |||
| data["message_type"] = message_type | |||
| except: | |||
| ... | |||
| ret = OpenAIChatOutput( | |||
| id=f"chat{uuid.uuid4()}", | |||
| object="chat.completion.chunk", | |||
| content=data.get("text", ""), | |||
| role="assistant", | |||
| tool_calls=data["tool_calls"], | |||
| model=model.model_name, | |||
| status=data["status"], | |||
| message_type=data["message_type"], | |||
| ) | |||
| yield ret.model_dump_json() | |||
| await task | |||
| ret = [] | |||
| async for chunk in chat_iterator(): | |||
| data = json.loads(chunk) | |||
| if data["status"] != AgentStatus.llm_start and data["status"] != AgentStatus.llm_new_token: | |||
| ret.append(data) | |||
| return ret | |||
| @@ -6,11 +6,13 @@ from fastapi import Body | |||
| from uuid import uuid4 | |||
| from datetime import datetime | |||
| import sqlite3 | |||
| from ..chat.utils import History | |||
| from ..utils.system_utils import BaseResponse, ListResponse, get_mindpilot_db_connection | |||
| from .message import init_messages_table, insert_message, split_message_content | |||
| from ..model_configs.utils import get_config_from_id | |||
| from ..agent.utils import get_agent_from_id | |||
| from ..chat.chat import chat_online | |||
| from ..chat.chat import chat_online, debug_chat_online | |||
| def init_conversations_table(): | |||
| @@ -185,6 +187,7 @@ async def send_messages( | |||
| max_tokens: int = Body(..., description="模型输出最大长度", examples=[4096]), | |||
| ): | |||
| # TODO 缺少知识库部分 | |||
| init_conversations_table() | |||
| init_messages_table() | |||
| conn = get_mindpilot_db_connection() | |||
| @@ -316,3 +319,83 @@ async def send_messages( | |||
| conn.close() | |||
| return BaseResponse(code=200, msg="success", data=response_messages) | |||
| async def debug_messages( | |||
| query: str = Body(..., description="用户输入", examples=[""]), | |||
| history: List[History] = Body( | |||
| [], | |||
| description="历史对话", | |||
| examples=[ | |||
| [ | |||
| {"role": "user", "content": "你好"}, | |||
| {"role": "assistant", "content": "您好,我是智能Agent桌面助手MindPilot,请问有什么可以帮您?"}, | |||
| ] | |||
| ], | |||
| ), | |||
| config_id: int = Body("0", description="模型配置", examples=[1]), | |||
| agent_config: dict = Body(..., description="agent配置", examples=[ | |||
| { | |||
| "agent_name": "调试助手", | |||
| "agent_abstract": "", | |||
| "agent_info": "这是一个用于调试的AI助手", | |||
| "agent_enable": True, | |||
| "temperature": 0.7, | |||
| "max_tokens": 150, | |||
| "tool_config": ["search_internet", "calculator"] | |||
| } | |||
| ]) | |||
| ): | |||
| # TODO 缺少知识库部分 | |||
| if not agent_config["agent_enable"]: | |||
| temp_agent_name = agent_config["agent_name"] | |||
| temp_agent_abstract = agent_config["agent_abstract"] | |||
| temp_agent_info = agent_config["agent_info"] | |||
| agent_prompt = "Your name is " + temp_agent_name + "." + temp_agent_abstract + ". Below is your detailed information:" + temp_agent_info + "." | |||
| history.append({"role": "user", "content": agent_prompt}) | |||
| # 获取模型配置 | |||
| chat_model_config = get_config_from_id(config_id=config_id) | |||
| model_key = next(iter(chat_model_config["llm_model"])) | |||
| chat_model_config["llm_model"][model_key]["temperature"] = agent_config['temperature'] | |||
| chat_model_config["llm_model"][model_key]["max_tokens"] = agent_config['max_tokens'] | |||
| # 获取模型输出 | |||
| ret = await debug_chat_online(content=query, history=history, chat_model_config=chat_model_config, | |||
| tool_config=agent_config['tool_config'], agent_config=agent_config) | |||
| response_messages = [] | |||
| for message in ret: | |||
| if message['status'] == 7: | |||
| message_role = message['choices'][0]['role'] | |||
| message_content = "Observation:\n" + message['choices'][0]['delta']['tool_calls'][0]['tool_output'] | |||
| timestamp = datetime.now().isoformat() | |||
| message_dict = { | |||
| "message_id": 0, | |||
| "agent_status": 7, | |||
| "text": message_content, | |||
| "files": [], | |||
| "timestamp": timestamp | |||
| } | |||
| response_messages.append(message_dict) | |||
| if message['status'] == 3: | |||
| message_role = message['choices'][0]['role'] | |||
| message_content = message['choices'][0]['delta']['content'] | |||
| message_list = split_message_content(message_content) | |||
| for m in message_list: | |||
| timestamp = datetime.now().isoformat() | |||
| message_dict = { | |||
| "message_id": 0, | |||
| "agent_status": 3, | |||
| "text": m, | |||
| "files": [], | |||
| "timestamp": timestamp | |||
| } | |||
| response_messages.append(message_dict) | |||
| # TODO 这里考虑处理一下message['status']是4但之前一个message['status']不是3的,即agent无法解析的内容 | |||
| return BaseResponse(code=200, msg="success", data=response_messages) | |||