From 4d9f6759e321e7a1d39ba49cddf2cd0803c3803a Mon Sep 17 00:00:00 2001 From: gjl <2802427218@qq.com> Date: Thu, 15 Aug 2024 18:00:16 +0800 Subject: [PATCH] =?UTF-8?q?feat:=E5=AF=B9=E8=AF=9D=E7=9A=84=E5=A2=9E?= =?UTF-8?q?=E5=88=A0=E6=94=B9=E6=9F=A5=EF=BC=8C=E5=AF=B9=E8=AF=9D=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mindpilot/app/agent/agent_api.py | 10 +- src/mindpilot/app/agent/utils.py | 8 +- src/mindpilot/app/api/api_server.py | 2 + src/mindpilot/app/api/conversation_routes.py | 30 +++ src/mindpilot/app/chat/chat.py | 176 ++++++++++-- src/mindpilot/app/conversation/__init__.py | 0 .../app/conversation/conversation_api.py | 251 ++++++++++++++++++ src/mindpilot/app/conversation/message.py | 38 +++ .../app/model_configs/model_config_api.py | 21 +- src/mindpilot/app/model_configs/utils.py | 31 ++- src/mindpilot/app/utils/system_utils.py | 6 + 11 files changed, 530 insertions(+), 43 deletions(-) create mode 100644 src/mindpilot/app/api/conversation_routes.py create mode 100644 src/mindpilot/app/conversation/__init__.py create mode 100644 src/mindpilot/app/conversation/conversation_api.py create mode 100644 src/mindpilot/app/conversation/message.py diff --git a/src/mindpilot/app/agent/agent_api.py b/src/mindpilot/app/agent/agent_api.py index 1a2dffd..de1b4c7 100644 --- a/src/mindpilot/app/agent/agent_api.py +++ b/src/mindpilot/app/agent/agent_api.py @@ -14,7 +14,7 @@ def create_agent( kb_name: List[str] = Body([], examples=[["ChatGPT KB"]]), avatar: str = Body("", description="头像图片的Base64编码") ) -> BaseResponse: - conn = sqlite3.connect('agents.db') + conn = sqlite3.connect('mindpilot.db') cursor = conn.cursor() cursor.execute(''' @@ -79,7 +79,7 @@ def create_agent( def delete_agent( agent_id: int = Body(..., examples=["1"]) ) -> BaseResponse: - conn = sqlite3.connect('agents.db') + conn = sqlite3.connect('mindpilot.db') cursor = conn.cursor() if agent_id is None: @@ -108,7 +108,7 @@ def update_agent( kb_name: List[str] = Body([], examples=[["ChatGPT KB"]]), avatar: str = Body("", description="头像图片的Base64编码") ) -> BaseResponse: - conn = sqlite3.connect('agents.db') + conn = sqlite3.connect('mindpilot.db') cursor = conn.cursor() if agent_id is None: @@ -139,7 +139,7 @@ def update_agent( def list_agent() -> ListResponse: - conn = sqlite3.connect('agents.db') + conn = sqlite3.connect('mindpilot.db') cursor = conn.cursor() cursor.execute(''' @@ -186,7 +186,7 @@ def list_agent() -> ListResponse: def get_agent( agent_id: int = Query(..., examples=["1"]), ): - conn = sqlite3.connect('agents.db') + conn = sqlite3.connect('mindpilot.db') cursor = conn.cursor() if agent_id is None: diff --git a/src/mindpilot/app/agent/utils.py b/src/mindpilot/app/agent/utils.py index 40eeb25..6cc1e9d 100644 --- a/src/mindpilot/app/agent/utils.py +++ b/src/mindpilot/app/agent/utils.py @@ -1,8 +1,8 @@ import sqlite3 -def get_agent_from_id(agent_id:int): - conn = sqlite3.connect('agents.db') +def get_agent_from_id(agent_id: int): + conn = sqlite3.connect('mindpilot.db') cursor = conn.cursor() cursor.execute('SELECT * FROM agents WHERE id = ?', (agent_id,)) @@ -13,7 +13,7 @@ def get_agent_from_id(agent_id:int): return None agent_dict = { - "id": agent[0], + "agent_id": agent[0], "agent_name": agent[1], "agent_abstract": agent[2], "agent_info": agent[3], @@ -24,4 +24,4 @@ def get_agent_from_id(agent_id:int): "avatar": agent[8] } - return agent_dict \ No newline at end of file + return agent_dict diff --git a/src/mindpilot/app/api/api_server.py b/src/mindpilot/app/api/api_server.py index 13b97d5..560d2ea 100644 --- a/src/mindpilot/app/api/api_server.py +++ b/src/mindpilot/app/api/api_server.py @@ -6,6 +6,7 @@ from .chat_routes import chat_router from .tool_routes import tool_router from .agent_routes import agent_router from .config_routes import config_router +from .conversation_routes import conversation_router def create_app(run_mode: str = None): @@ -26,5 +27,6 @@ def create_app(run_mode: str = None): app.include_router(tool_router) app.include_router(agent_router) app.include_router(config_router) + app.include_router(conversation_router) return app diff --git a/src/mindpilot/app/api/conversation_routes.py b/src/mindpilot/app/api/conversation_routes.py new file mode 100644 index 0000000..83228fd --- /dev/null +++ b/src/mindpilot/app/api/conversation_routes.py @@ -0,0 +1,30 @@ +from __future__ import annotations +from fastapi import APIRouter, Request +from ..conversation.conversation_api import add_conversation, list_conversations, get_conversation, delete_conversation, send_messages + +conversation_router = APIRouter(prefix="/api/conversation", tags=["对话接口"]) + +conversation_router.post( + "", + summary="创建conversation", +)(add_conversation) + +conversation_router.get( + "", + summary="获取conversation列表", +)(list_conversations) + +conversation_router.get( + "{conversation_id}", + summary="获取单个对话详情" +)(get_conversation) + +conversation_router.delete( + "{conversation_id}", + summary="删除对话" +)(delete_conversation) + +conversation_router.post( + "{conversation_id}/messages", + summary="发送消息" +)(send_messages) diff --git a/src/mindpilot/app/chat/chat.py b/src/mindpilot/app/chat/chat.py index 9928f92..34ce9df 100644 --- a/src/mindpilot/app/chat/chat.py +++ b/src/mindpilot/app/chat/chat.py @@ -27,29 +27,22 @@ def create_models_from_config(configs, callbacks, stream): platform = configs["platform"] base_url = configs["base_url"] api_key = configs["api_key"] - is_openai = configs["is_openai"] llm_model = configs["llm_model"] model_name, params = next(iter(llm_model.items())) callbacks = callbacks if params.get("callbacks", False) else None - if is_openai: - model_instance = get_ChatOpenAI( - model_name=model_name, - base_url=base_url, - api_key=api_key, - temperature=params.get("temperature", 0.8), - max_tokens=params.get("max_tokens", 4096), - callbacks=callbacks, - streaming=stream, - ) - model = model_instance - prompt = OPENAI_PROMPT - else: - # TODO 其他不兼容OPENAI API格式的平台 - model = None - prompt = None - pass + model_instance = get_ChatOpenAI( + model_name=model_name, + base_url=base_url, + api_key=api_key, + temperature=params.get("temperature", 0.8), + max_tokens=params.get("max_tokens", 4096), + callbacks=callbacks, + streaming=stream, + ) + model = model_instance + prompt = OPENAI_PROMPT return model, prompt @@ -177,7 +170,6 @@ async def chat( last_tool = {} async for chunk in callback.aiter(): data = json.loads(chunk) - # print("data:{}".format(data)) data["tool_calls"] = [] data["message_type"] = MsgType.TEXT @@ -245,6 +237,7 @@ async def chat( async for chunk in chat_iterator(): data = json.loads(chunk) + print(data) if text := data["choices"][0]["delta"]["content"]: ret.content += text if data["status"] == AgentStatus.tool_end: @@ -253,3 +246,148 @@ async def chat( ret.created = data["created"] return ret.model_dump() + + +async def chat_online( + content: str, + history: List[History], + chat_model_config: dict, + tool_config: List[str], + agent_id: int +): + async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]: + callback = AgentExecutorAsyncIteratorCallbackHandler() + callbacks = [callback] + + model, prompt = create_models_from_config( + callbacks=callbacks, configs=chat_model_config, stream=False + ) + + all_tools = get_tool().values() + tool_configs = tool_config + + agent_enable = True + if agent_id != -1: + if agent_id != 0: + agent_dict = get_agent_from_id(agent_id) + agent_name = agent_dict["agent_name"] + agent_abstract = agent_dict["agent_abstract"] + agent_info = agent_dict["agent_info"] + if not tool_config: + tool_configs = agent_dict["tool_config"] + agent_prompt_pre = "Your name is " + agent_name + "." + agent_abstract + ". Below is your detailed information:" + agent_info + "." + agent_prompt_after = "DO NOT forget " + agent_prompt_pre + prompt = agent_prompt_pre + prompt + agent_prompt_after + # TODO 处理知识库 + else: + prompt = prompt # 默认Agent提示模板 + else: + agent_enable = False + + tool_configs = tool_configs or TOOL_CONFIG + tools = [tool for tool in all_tools if tool.name in tool_configs] + tools = [t.copy(update={"callbacks": callbacks}) for t in tools] + + full_chain = create_models_chains( + prompts=prompt, + models=model, + tools=tools, + callbacks=callbacks, + history=history, + agent_enable=agent_enable + ) + + _history = [History.from_data(h) for h in history] + chat_history = [h.to_msg_tuple() for h in _history] + + history_message = convert_to_messages(chat_history) + + task = asyncio.create_task( + wrap_done( + full_chain.ainvoke( + { + "input": content, + "chat_history": history_message, + } + ), + callback.done, + ) + ) + + last_tool = {} + async for chunk in callback.aiter(): + data = json.loads(chunk) + # print("data:{}".format(data)) + data["tool_calls"] = [] + data["message_type"] = MsgType.TEXT + + if data["status"] == AgentStatus.tool_start: + last_tool = { + "index": 0, + "id": data["run_id"], + "type": "function", + "function": { + "name": data["tool"], + "arguments": data["tool_input"], + }, + "tool_output": None, + "is_error": False, + } + data["tool_calls"].append(last_tool) + if data["status"] in [AgentStatus.tool_end]: + last_tool.update( + tool_output=data["tool_output"], + is_error=data.get("is_error", False), + ) + data["tool_calls"] = [last_tool] + last_tool = {} + try: + tool_output = json.loads(data["tool_output"]) + if message_type := tool_output.get("message_type"): + data["message_type"] = message_type + except: + ... + elif data["status"] == AgentStatus.agent_finish: + try: + tool_output = json.loads(data["text"]) + if message_type := tool_output.get("message_type"): + data["message_type"] = message_type + except: + ... + + ret = OpenAIChatOutput( + id=f"chat{uuid.uuid4()}", + object="chat.completion.chunk", + content=data.get("text", ""), + role="assistant", + tool_calls=data["tool_calls"], + model=model.model_name, + status=data["status"], + message_type=data["message_type"], + ) + yield ret.model_dump_json() + + await task + + ret = OpenAIChatOutput( + id=f"chat{uuid.uuid4()}", + object="chat.completion", + content="", + role="assistant", + finish_reason="stop", + tool_calls=[], + status=AgentStatus.agent_finish, + message_type=MsgType.TEXT, + ) + + async for chunk in chat_iterator(): + data = json.loads(chunk) + # print(data) + if text := data["choices"][0]["delta"]["content"]: + ret.content += text + if data["status"] == AgentStatus.tool_end: + ret.tool_calls += data["choices"][0]["delta"]["tool_calls"] + ret.model = data["model"] + ret.created = data["created"] + + return ret.model_dump() diff --git a/src/mindpilot/app/conversation/__init__.py b/src/mindpilot/app/conversation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/mindpilot/app/conversation/conversation_api.py b/src/mindpilot/app/conversation/conversation_api.py new file mode 100644 index 0000000..fd58ea0 --- /dev/null +++ b/src/mindpilot/app/conversation/conversation_api.py @@ -0,0 +1,251 @@ +import json +from typing import List + +from fastapi import Body +from uuid import uuid4 +from datetime import datetime +import sqlite3 +from ..utils.system_utils import BaseResponse, ListResponse, get_mindpilot_db_connection +from .message import init_messages_table, insert_message +from ..model_configs.utils import get_config_from_id +from ..agent.utils import get_agent_from_id +from ..chat.chat import chat_online + + +def init_conversations_table(): + conn = get_mindpilot_db_connection() + conn.execute(''' + CREATE TABLE IF NOT EXISTS conversations ( + conversation_id TEXT PRIMARY KEY, + title TEXT, + created_at TEXT, + updated_at TEXT, + is_summarized BOOLEAN, + agent_id INTEGER + ) + ''') + conn.commit() + conn.close() + + +async def add_conversation( + agent_id: int = Body(0, description="使用agent情况,-1代表不使用agent,0代表使用默认agent"), +): + init_conversations_table() + init_messages_table() + conversation_id = str(uuid4()) + created_at = updated_at = datetime.now().isoformat() + is_summarized = False + + conn = get_mindpilot_db_connection() + cursor = conn.cursor() + cursor.execute(''' + INSERT INTO conversations (conversation_id, title, created_at, updated_at, is_summarized, agent_id) + VALUES (?, ?, ?, ?, ?, ?) + ''', (conversation_id, "New Conversation", created_at, updated_at, is_summarized, agent_id)) + conn.commit() + conn.close() + + response_data = { + "conversation_id": conversation_id, + "title": "New Conversation", + "created_at": datetime.fromisoformat(created_at), + "updated_at": datetime.fromisoformat(updated_at), + "is_summarized": is_summarized, + "agent_id": agent_id, + } + return BaseResponse(code=200, msg="success", data=response_data) + + +async def list_conversations(): + init_conversations_table() + init_messages_table() + conn = get_mindpilot_db_connection() + cursor = conn.cursor() + cursor.execute(''' + SELECT conversation_id, title, created_at, updated_at, is_summarized, agent_id + FROM conversations + ''') + rows = cursor.fetchall() + conn.close() + + conversations = [] + for row in rows: + conversation = { + "conversation_id": row['conversation_id'], + "title": row['title'], + "created_at": datetime.fromisoformat(row['created_at']), + "updated_at": datetime.fromisoformat(row['updated_at']), + "is_summarized": row['is_summarized'], + "agent_id": row['agent_id'], + } + conversations.append(conversation) + + return ListResponse(code=200, msg="success", data=conversations) + + +async def get_conversation(conversation_id: str): + init_conversations_table() + init_messages_table() + conn = get_mindpilot_db_connection() + cursor = conn.cursor() + + # 获取对话详情 + cursor.execute(''' + SELECT conversation_id, title, created_at, updated_at, is_summarized, agent_id + FROM conversations + WHERE conversation_id = ? + ''', (conversation_id,)) + conversation_row = cursor.fetchone() + + if not conversation_row: + conn.close() + return BaseResponse(code=404, msg="Conversation not found") + + conversation = { + "conversation_id": conversation_row['conversation_id'], + "title": conversation_row['title'], + "created_at": datetime.fromisoformat(conversation_row['created_at']), + "updated_at": datetime.fromisoformat(conversation_row['updated_at']), + "is_summarized": conversation_row['is_summarized'], + "agent_id": conversation_row['agent_id'], + "messages": [] + } + + # 获取对话的所有消息 + cursor.execute(''' + SELECT id, agent_status, role, content, files, timestamp + FROM message + WHERE conversation_id = ? + ''', (conversation_id,)) + message_rows = cursor.fetchall() + + for row in message_rows: + message = { + "message_id": row['id'], + "agent_status": row['agent_status'], + "role": row['role'], + "content": row['content'], + "files": json.loads(row['files']), + "timestamp": datetime.fromisoformat(row['timestamp']) + } + conversation['messages'].append(message) + + conn.close() + + return BaseResponse(code=200, msg="success", data=conversation) + + +async def delete_conversation(conversation_id: str): + init_conversations_table() + init_messages_table() + conn = get_mindpilot_db_connection() + cursor = conn.cursor() + + # 检查对话是否存在 + cursor.execute(''' + SELECT conversation_id + FROM conversations + WHERE conversation_id = ? + ''', (conversation_id,)) + conversation_row = cursor.fetchone() + + if not conversation_row: + conn.close() + return BaseResponse(code=404, msg="Conversation not found", data={"conversation_id": "-1"}) + + # 删除对话相关的消息 + cursor.execute(''' + DELETE FROM message + WHERE conversation_id = ? + ''', (conversation_id,)) + + # 删除对话 + cursor.execute(''' + DELETE FROM conversations + WHERE conversation_id = ? + ''', (conversation_id,)) + + conn.commit() + conn.close() + + return BaseResponse(code=200, msg="success", data={"conversation_id": conversation_id}) + + +async def send_messages( + conversation_id: str, + role: str = Body("", description="消息角色:user/assistant", examples=["user", "assistant"]), + agent_id: int = Body(0, description="使用agent,0为默认,-1为不使用agent", examples=[0]), + config_id: int = Body("0", description="模型配置", examples=[1]), + files: dict = Body({}, description="文件", examples=[{}]), + content: str = Body("", description="消息内容"), + tool_config: List[str] = Body([], description="工具配置", examples=[]), +): + """ + 1. 获取历史记录 + 2. 存放用户输入 + 3. 获取模型配置 + 4. 获取agent信息 + 5. 组织模型输出 + 6. 存放模型输出 + """ + init_conversations_table() + init_messages_table() + conn = get_mindpilot_db_connection() + cursor = conn.cursor() + + # 获取历史记录 + cursor.execute(''' + SELECT role, content, timestamp + FROM message + WHERE conversation_id = ? + ORDER BY timestamp + ''', (conversation_id,)) + message_rows = cursor.fetchall() + + history = [] + for row in message_rows: + history.append({ + "role": row['role'], + "content": row['content'] + }) + + # 存放用户输入 + insert_message(agent_status=0, role=role, content=content, files=json.dumps(files), conversation_id=conversation_id, + tool_calls=json.dumps({})) + + # 获取模型配置 + chat_model_config = get_config_from_id(config_id=config_id) + + # 获取模型输出 + ret = await chat_online(content=content, history=history, chat_model_config=chat_model_config, + tool_config=tool_config, agent_id=agent_id) + + # 解析模型输出 + message_id = str(uuid4()) + message_role = ret['choices'][0]['message']['role'] + message_content = ret['choices'][0]['message']['content'] + tool_calls = ret['choices'][0]['message']['tool_calls'] + + # 存放模型输出 + timestamp = insert_message(agent_status=5, role=message_role, content=message_content, files=json.dumps(files), + conversation_id=conversation_id, tool_calls=json.dumps(tool_calls)) + + # 构建响应 + response_data = { + "messages": [ + { + "id": message_id, + "role": message_role, + "agent_status": 5, + "content": message_content, + "tool_calls": tool_calls, + "files": files, + "timestamp": datetime.fromisoformat(timestamp) + } + ] + } + + conn.close() + + return BaseResponse(code=200, msg="success", data=response_data) \ No newline at end of file diff --git a/src/mindpilot/app/conversation/message.py b/src/mindpilot/app/conversation/message.py new file mode 100644 index 0000000..fdda5cf --- /dev/null +++ b/src/mindpilot/app/conversation/message.py @@ -0,0 +1,38 @@ +from datetime import datetime +import sqlite3 +import json +from ..utils.system_utils import get_mindpilot_db_connection + + +# 初始化数据库和表 +def init_messages_table(): + conn = get_mindpilot_db_connection() + conn.execute(''' + CREATE TABLE IF NOT EXISTS message ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + agent_status INTEGER, + role TEXT, + content TEXT, + tool_calls TEXT, --json格式 + files TEXT, --json格式 + timestamp TEXT, + conversation_id TEXT, + FOREIGN KEY (conversation_id) REFERENCES conversations (conversation_id) + ) + ''') + conn.commit() + conn.close() + + +def insert_message(agent_status, role, content, files, conversation_id, tool_calls): + conn = get_mindpilot_db_connection() + cursor = conn.cursor() + timestamp = datetime.now().isoformat() # 获取当前时间戳 + cursor.execute(''' + INSERT INTO message (agent_status, role, content, tool_calls, files, timestamp, conversation_id) + VALUES (?, ?, ?, ?, ?, ?, ?) + ''', (agent_status, role, content, tool_calls, json.dumps(files), timestamp, conversation_id)) + conn.commit() + conn.close() + + return timestamp diff --git a/src/mindpilot/app/model_configs/model_config_api.py b/src/mindpilot/app/model_configs/model_config_api.py index 72e5636..907d905 100644 --- a/src/mindpilot/app/model_configs/model_config_api.py +++ b/src/mindpilot/app/model_configs/model_config_api.py @@ -1,19 +1,12 @@ import sqlite3 from typing import List from fastapi import APIRouter, Body, Query -from ..utils.system_utils import BaseResponse, ListResponse - - -# 初始化数据库连接 -def get_db_connection(): - conn = sqlite3.connect('model_configs.db') - conn.row_factory = sqlite3.Row - return conn +from ..utils.system_utils import BaseResponse, ListResponse, get_mindpilot_db_connection # 创建表结构 def create_table(): - conn = get_db_connection() + conn = get_mindpilot_db_connection() conn.execute(''' CREATE TABLE IF NOT EXISTS model_configs ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -45,7 +38,7 @@ async def add_model_config( ): create_table() - conn = get_db_connection() + conn = get_mindpilot_db_connection() cursor = conn.cursor() cursor.execute(''' @@ -83,7 +76,7 @@ async def add_model_config( async def list_model_configs(): create_table() - conn = get_db_connection() + conn = get_mindpilot_db_connection() cursor = conn.cursor() cursor.execute('SELECT * FROM model_configs') @@ -116,7 +109,7 @@ async def get_model_config( ): create_table() - conn = get_db_connection() + conn = get_mindpilot_db_connection() cursor = conn.cursor() cursor.execute('SELECT * FROM model_configs WHERE id = ?', (config_id,)) @@ -159,7 +152,7 @@ async def update_model_config( }]), ): create_table() - conn = get_db_connection() + conn = get_mindpilot_db_connection() cursor = conn.cursor() cursor.execute('SELECT * FROM model_configs WHERE id = ?', (config_id,)) @@ -206,7 +199,7 @@ async def update_model_config( async def delete_model_config(config_id): create_table() - conn = get_db_connection() + conn = get_mindpilot_db_connection() cursor = conn.cursor() # 首先检查配置是否存在 diff --git a/src/mindpilot/app/model_configs/utils.py b/src/mindpilot/app/model_configs/utils.py index d56a856..ac43375 100644 --- a/src/mindpilot/app/model_configs/utils.py +++ b/src/mindpilot/app/model_configs/utils.py @@ -1 +1,30 @@ -# TODO 获取配置 \ No newline at end of file +import sqlite3 + + +def get_config_from_id(config_id: int): + conn = sqlite3.connect('mindpilot.db') + cursor = conn.cursor() + + cursor.execute('SELECT * FROM model_configs WHERE id = ?', (config_id,)) + config = cursor.fetchone() + conn.close() + + if not config: + return None + + config_dict = { + "config_id": config[0], + "config_name": config[1], + "platform": config[2], + "base_url": config[3], + "api_key": config[4], + "llm_model": { + config[5]: { + "temperature": config[8], + "max_tokens": config[7], + "callbacks": config[6], + } + } + } + + return config_dict diff --git a/src/mindpilot/app/utils/system_utils.py b/src/mindpilot/app/utils/system_utils.py index 69860f2..9f6469e 100644 --- a/src/mindpilot/app/utils/system_utils.py +++ b/src/mindpilot/app/utils/system_utils.py @@ -3,6 +3,7 @@ import logging import multiprocessing as mp import os import socket +import sqlite3 import sys from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed from pathlib import Path @@ -207,3 +208,8 @@ class ListResponse(BaseResponse): "data": ["doc1.docx", "doc2.pdf", "doc3.txt"], } } + +def get_mindpilot_db_connection(): + conn = sqlite3.connect('mindpilot.db') + conn.row_factory = sqlite3.Row + return conn