Browse Source

feat:对话的增删改查,对话功能实现

main
gjl 1 year ago
parent
commit
4d9f6759e3
11 changed files with 530 additions and 43 deletions
  1. +5
    -5
      src/mindpilot/app/agent/agent_api.py
  2. +4
    -4
      src/mindpilot/app/agent/utils.py
  3. +2
    -0
      src/mindpilot/app/api/api_server.py
  4. +30
    -0
      src/mindpilot/app/api/conversation_routes.py
  5. +157
    -19
      src/mindpilot/app/chat/chat.py
  6. +0
    -0
      src/mindpilot/app/conversation/__init__.py
  7. +251
    -0
      src/mindpilot/app/conversation/conversation_api.py
  8. +38
    -0
      src/mindpilot/app/conversation/message.py
  9. +7
    -14
      src/mindpilot/app/model_configs/model_config_api.py
  10. +30
    -1
      src/mindpilot/app/model_configs/utils.py
  11. +6
    -0
      src/mindpilot/app/utils/system_utils.py

+ 5
- 5
src/mindpilot/app/agent/agent_api.py View File

@@ -14,7 +14,7 @@ def create_agent(
kb_name: List[str] = Body([], examples=[["ChatGPT KB"]]),
avatar: str = Body("", description="头像图片的Base64编码")
) -> BaseResponse:
conn = sqlite3.connect('agents.db')
conn = sqlite3.connect('mindpilot.db')
cursor = conn.cursor()

cursor.execute('''
@@ -79,7 +79,7 @@ def create_agent(
def delete_agent(
agent_id: int = Body(..., examples=["1"])
) -> BaseResponse:
conn = sqlite3.connect('agents.db')
conn = sqlite3.connect('mindpilot.db')
cursor = conn.cursor()

if agent_id is None:
@@ -108,7 +108,7 @@ def update_agent(
kb_name: List[str] = Body([], examples=[["ChatGPT KB"]]),
avatar: str = Body("", description="头像图片的Base64编码")
) -> BaseResponse:
conn = sqlite3.connect('agents.db')
conn = sqlite3.connect('mindpilot.db')
cursor = conn.cursor()

if agent_id is None:
@@ -139,7 +139,7 @@ def update_agent(


def list_agent() -> ListResponse:
conn = sqlite3.connect('agents.db')
conn = sqlite3.connect('mindpilot.db')
cursor = conn.cursor()

cursor.execute('''
@@ -186,7 +186,7 @@ def list_agent() -> ListResponse:
def get_agent(
agent_id: int = Query(..., examples=["1"]),
):
conn = sqlite3.connect('agents.db')
conn = sqlite3.connect('mindpilot.db')
cursor = conn.cursor()

if agent_id is None:


+ 4
- 4
src/mindpilot/app/agent/utils.py View File

@@ -1,8 +1,8 @@
import sqlite3


def get_agent_from_id(agent_id:int):
conn = sqlite3.connect('agents.db')
def get_agent_from_id(agent_id: int):
conn = sqlite3.connect('mindpilot.db')
cursor = conn.cursor()

cursor.execute('SELECT * FROM agents WHERE id = ?', (agent_id,))
@@ -13,7 +13,7 @@ def get_agent_from_id(agent_id:int):
return None

agent_dict = {
"id": agent[0],
"agent_id": agent[0],
"agent_name": agent[1],
"agent_abstract": agent[2],
"agent_info": agent[3],
@@ -24,4 +24,4 @@ def get_agent_from_id(agent_id:int):
"avatar": agent[8]
}

return agent_dict
return agent_dict

+ 2
- 0
src/mindpilot/app/api/api_server.py View File

@@ -6,6 +6,7 @@ from .chat_routes import chat_router
from .tool_routes import tool_router
from .agent_routes import agent_router
from .config_routes import config_router
from .conversation_routes import conversation_router


def create_app(run_mode: str = None):
@@ -26,5 +27,6 @@ def create_app(run_mode: str = None):
app.include_router(tool_router)
app.include_router(agent_router)
app.include_router(config_router)
app.include_router(conversation_router)

return app

+ 30
- 0
src/mindpilot/app/api/conversation_routes.py View File

@@ -0,0 +1,30 @@
from __future__ import annotations
from fastapi import APIRouter, Request
from ..conversation.conversation_api import add_conversation, list_conversations, get_conversation, delete_conversation, send_messages

conversation_router = APIRouter(prefix="/api/conversation", tags=["对话接口"])

conversation_router.post(
"",
summary="创建conversation",
)(add_conversation)

conversation_router.get(
"",
summary="获取conversation列表",
)(list_conversations)

conversation_router.get(
"{conversation_id}",
summary="获取单个对话详情"
)(get_conversation)

conversation_router.delete(
"{conversation_id}",
summary="删除对话"
)(delete_conversation)

conversation_router.post(
"{conversation_id}/messages",
summary="发送消息"
)(send_messages)

+ 157
- 19
src/mindpilot/app/chat/chat.py View File

@@ -27,29 +27,22 @@ def create_models_from_config(configs, callbacks, stream):
platform = configs["platform"]
base_url = configs["base_url"]
api_key = configs["api_key"]
is_openai = configs["is_openai"]
llm_model = configs["llm_model"]

model_name, params = next(iter(llm_model.items()))
callbacks = callbacks if params.get("callbacks", False) else None

if is_openai:
model_instance = get_ChatOpenAI(
model_name=model_name,
base_url=base_url,
api_key=api_key,
temperature=params.get("temperature", 0.8),
max_tokens=params.get("max_tokens", 4096),
callbacks=callbacks,
streaming=stream,
)
model = model_instance
prompt = OPENAI_PROMPT
else:
# TODO 其他不兼容OPENAI API格式的平台
model = None
prompt = None
pass
model_instance = get_ChatOpenAI(
model_name=model_name,
base_url=base_url,
api_key=api_key,
temperature=params.get("temperature", 0.8),
max_tokens=params.get("max_tokens", 4096),
callbacks=callbacks,
streaming=stream,
)
model = model_instance
prompt = OPENAI_PROMPT

return model, prompt

@@ -177,7 +170,6 @@ async def chat(
last_tool = {}
async for chunk in callback.aiter():
data = json.loads(chunk)
# print("data:{}".format(data))
data["tool_calls"] = []
data["message_type"] = MsgType.TEXT

@@ -245,6 +237,7 @@ async def chat(

async for chunk in chat_iterator():
data = json.loads(chunk)
print(data)
if text := data["choices"][0]["delta"]["content"]:
ret.content += text
if data["status"] == AgentStatus.tool_end:
@@ -253,3 +246,148 @@ async def chat(
ret.created = data["created"]

return ret.model_dump()


async def chat_online(
content: str,
history: List[History],
chat_model_config: dict,
tool_config: List[str],
agent_id: int
):
async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]:
callback = AgentExecutorAsyncIteratorCallbackHandler()
callbacks = [callback]

model, prompt = create_models_from_config(
callbacks=callbacks, configs=chat_model_config, stream=False
)

all_tools = get_tool().values()
tool_configs = tool_config

agent_enable = True
if agent_id != -1:
if agent_id != 0:
agent_dict = get_agent_from_id(agent_id)
agent_name = agent_dict["agent_name"]
agent_abstract = agent_dict["agent_abstract"]
agent_info = agent_dict["agent_info"]
if not tool_config:
tool_configs = agent_dict["tool_config"]
agent_prompt_pre = "Your name is " + agent_name + "." + agent_abstract + ". Below is your detailed information:" + agent_info + "."
agent_prompt_after = "DO NOT forget " + agent_prompt_pre
prompt = agent_prompt_pre + prompt + agent_prompt_after
# TODO 处理知识库
else:
prompt = prompt # 默认Agent提示模板
else:
agent_enable = False

tool_configs = tool_configs or TOOL_CONFIG
tools = [tool for tool in all_tools if tool.name in tool_configs]
tools = [t.copy(update={"callbacks": callbacks}) for t in tools]

full_chain = create_models_chains(
prompts=prompt,
models=model,
tools=tools,
callbacks=callbacks,
history=history,
agent_enable=agent_enable
)

_history = [History.from_data(h) for h in history]
chat_history = [h.to_msg_tuple() for h in _history]

history_message = convert_to_messages(chat_history)

task = asyncio.create_task(
wrap_done(
full_chain.ainvoke(
{
"input": content,
"chat_history": history_message,
}
),
callback.done,
)
)

last_tool = {}
async for chunk in callback.aiter():
data = json.loads(chunk)
# print("data:{}".format(data))
data["tool_calls"] = []
data["message_type"] = MsgType.TEXT

if data["status"] == AgentStatus.tool_start:
last_tool = {
"index": 0,
"id": data["run_id"],
"type": "function",
"function": {
"name": data["tool"],
"arguments": data["tool_input"],
},
"tool_output": None,
"is_error": False,
}
data["tool_calls"].append(last_tool)
if data["status"] in [AgentStatus.tool_end]:
last_tool.update(
tool_output=data["tool_output"],
is_error=data.get("is_error", False),
)
data["tool_calls"] = [last_tool]
last_tool = {}
try:
tool_output = json.loads(data["tool_output"])
if message_type := tool_output.get("message_type"):
data["message_type"] = message_type
except:
...
elif data["status"] == AgentStatus.agent_finish:
try:
tool_output = json.loads(data["text"])
if message_type := tool_output.get("message_type"):
data["message_type"] = message_type
except:
...

ret = OpenAIChatOutput(
id=f"chat{uuid.uuid4()}",
object="chat.completion.chunk",
content=data.get("text", ""),
role="assistant",
tool_calls=data["tool_calls"],
model=model.model_name,
status=data["status"],
message_type=data["message_type"],
)
yield ret.model_dump_json()

await task

ret = OpenAIChatOutput(
id=f"chat{uuid.uuid4()}",
object="chat.completion",
content="",
role="assistant",
finish_reason="stop",
tool_calls=[],
status=AgentStatus.agent_finish,
message_type=MsgType.TEXT,
)

async for chunk in chat_iterator():
data = json.loads(chunk)
# print(data)
if text := data["choices"][0]["delta"]["content"]:
ret.content += text
if data["status"] == AgentStatus.tool_end:
ret.tool_calls += data["choices"][0]["delta"]["tool_calls"]
ret.model = data["model"]
ret.created = data["created"]

return ret.model_dump()

+ 0
- 0
src/mindpilot/app/conversation/__init__.py View File


+ 251
- 0
src/mindpilot/app/conversation/conversation_api.py View File

@@ -0,0 +1,251 @@
import json
from typing import List

from fastapi import Body
from uuid import uuid4
from datetime import datetime
import sqlite3
from ..utils.system_utils import BaseResponse, ListResponse, get_mindpilot_db_connection
from .message import init_messages_table, insert_message
from ..model_configs.utils import get_config_from_id
from ..agent.utils import get_agent_from_id
from ..chat.chat import chat_online


def init_conversations_table():
conn = get_mindpilot_db_connection()
conn.execute('''
CREATE TABLE IF NOT EXISTS conversations (
conversation_id TEXT PRIMARY KEY,
title TEXT,
created_at TEXT,
updated_at TEXT,
is_summarized BOOLEAN,
agent_id INTEGER
)
''')
conn.commit()
conn.close()


async def add_conversation(
agent_id: int = Body(0, description="使用agent情况,-1代表不使用agent,0代表使用默认agent"),
):
init_conversations_table()
init_messages_table()
conversation_id = str(uuid4())
created_at = updated_at = datetime.now().isoformat()
is_summarized = False

conn = get_mindpilot_db_connection()
cursor = conn.cursor()
cursor.execute('''
INSERT INTO conversations (conversation_id, title, created_at, updated_at, is_summarized, agent_id)
VALUES (?, ?, ?, ?, ?, ?)
''', (conversation_id, "New Conversation", created_at, updated_at, is_summarized, agent_id))
conn.commit()
conn.close()

response_data = {
"conversation_id": conversation_id,
"title": "New Conversation",
"created_at": datetime.fromisoformat(created_at),
"updated_at": datetime.fromisoformat(updated_at),
"is_summarized": is_summarized,
"agent_id": agent_id,
}
return BaseResponse(code=200, msg="success", data=response_data)


async def list_conversations():
init_conversations_table()
init_messages_table()
conn = get_mindpilot_db_connection()
cursor = conn.cursor()
cursor.execute('''
SELECT conversation_id, title, created_at, updated_at, is_summarized, agent_id
FROM conversations
''')
rows = cursor.fetchall()
conn.close()

conversations = []
for row in rows:
conversation = {
"conversation_id": row['conversation_id'],
"title": row['title'],
"created_at": datetime.fromisoformat(row['created_at']),
"updated_at": datetime.fromisoformat(row['updated_at']),
"is_summarized": row['is_summarized'],
"agent_id": row['agent_id'],
}
conversations.append(conversation)

return ListResponse(code=200, msg="success", data=conversations)


async def get_conversation(conversation_id: str):
init_conversations_table()
init_messages_table()
conn = get_mindpilot_db_connection()
cursor = conn.cursor()

# 获取对话详情
cursor.execute('''
SELECT conversation_id, title, created_at, updated_at, is_summarized, agent_id
FROM conversations
WHERE conversation_id = ?
''', (conversation_id,))
conversation_row = cursor.fetchone()

if not conversation_row:
conn.close()
return BaseResponse(code=404, msg="Conversation not found")

conversation = {
"conversation_id": conversation_row['conversation_id'],
"title": conversation_row['title'],
"created_at": datetime.fromisoformat(conversation_row['created_at']),
"updated_at": datetime.fromisoformat(conversation_row['updated_at']),
"is_summarized": conversation_row['is_summarized'],
"agent_id": conversation_row['agent_id'],
"messages": []
}

# 获取对话的所有消息
cursor.execute('''
SELECT id, agent_status, role, content, files, timestamp
FROM message
WHERE conversation_id = ?
''', (conversation_id,))
message_rows = cursor.fetchall()

for row in message_rows:
message = {
"message_id": row['id'],
"agent_status": row['agent_status'],
"role": row['role'],
"content": row['content'],
"files": json.loads(row['files']),
"timestamp": datetime.fromisoformat(row['timestamp'])
}
conversation['messages'].append(message)

conn.close()

return BaseResponse(code=200, msg="success", data=conversation)


async def delete_conversation(conversation_id: str):
init_conversations_table()
init_messages_table()
conn = get_mindpilot_db_connection()
cursor = conn.cursor()

# 检查对话是否存在
cursor.execute('''
SELECT conversation_id
FROM conversations
WHERE conversation_id = ?
''', (conversation_id,))
conversation_row = cursor.fetchone()

if not conversation_row:
conn.close()
return BaseResponse(code=404, msg="Conversation not found", data={"conversation_id": "-1"})

# 删除对话相关的消息
cursor.execute('''
DELETE FROM message
WHERE conversation_id = ?
''', (conversation_id,))

# 删除对话
cursor.execute('''
DELETE FROM conversations
WHERE conversation_id = ?
''', (conversation_id,))

conn.commit()
conn.close()

return BaseResponse(code=200, msg="success", data={"conversation_id": conversation_id})


async def send_messages(
conversation_id: str,
role: str = Body("", description="消息角色:user/assistant", examples=["user", "assistant"]),
agent_id: int = Body(0, description="使用agent,0为默认,-1为不使用agent", examples=[0]),
config_id: int = Body("0", description="模型配置", examples=[1]),
files: dict = Body({}, description="文件", examples=[{}]),
content: str = Body("", description="消息内容"),
tool_config: List[str] = Body([], description="工具配置", examples=[]),
):
"""
1. 获取历史记录
2. 存放用户输入
3. 获取模型配置
4. 获取agent信息
5. 组织模型输出
6. 存放模型输出
"""
init_conversations_table()
init_messages_table()
conn = get_mindpilot_db_connection()
cursor = conn.cursor()

# 获取历史记录
cursor.execute('''
SELECT role, content, timestamp
FROM message
WHERE conversation_id = ?
ORDER BY timestamp
''', (conversation_id,))
message_rows = cursor.fetchall()

history = []
for row in message_rows:
history.append({
"role": row['role'],
"content": row['content']
})

# 存放用户输入
insert_message(agent_status=0, role=role, content=content, files=json.dumps(files), conversation_id=conversation_id,
tool_calls=json.dumps({}))

# 获取模型配置
chat_model_config = get_config_from_id(config_id=config_id)

# 获取模型输出
ret = await chat_online(content=content, history=history, chat_model_config=chat_model_config,
tool_config=tool_config, agent_id=agent_id)

# 解析模型输出
message_id = str(uuid4())
message_role = ret['choices'][0]['message']['role']
message_content = ret['choices'][0]['message']['content']
tool_calls = ret['choices'][0]['message']['tool_calls']

# 存放模型输出
timestamp = insert_message(agent_status=5, role=message_role, content=message_content, files=json.dumps(files),
conversation_id=conversation_id, tool_calls=json.dumps(tool_calls))

# 构建响应
response_data = {
"messages": [
{
"id": message_id,
"role": message_role,
"agent_status": 5,
"content": message_content,
"tool_calls": tool_calls,
"files": files,
"timestamp": datetime.fromisoformat(timestamp)
}
]
}

conn.close()

return BaseResponse(code=200, msg="success", data=response_data)

+ 38
- 0
src/mindpilot/app/conversation/message.py View File

@@ -0,0 +1,38 @@
from datetime import datetime
import sqlite3
import json
from ..utils.system_utils import get_mindpilot_db_connection


# 初始化数据库和表
def init_messages_table():
conn = get_mindpilot_db_connection()
conn.execute('''
CREATE TABLE IF NOT EXISTS message (
id INTEGER PRIMARY KEY AUTOINCREMENT,
agent_status INTEGER,
role TEXT,
content TEXT,
tool_calls TEXT, --json格式
files TEXT, --json格式
timestamp TEXT,
conversation_id TEXT,
FOREIGN KEY (conversation_id) REFERENCES conversations (conversation_id)
)
''')
conn.commit()
conn.close()


def insert_message(agent_status, role, content, files, conversation_id, tool_calls):
conn = get_mindpilot_db_connection()
cursor = conn.cursor()
timestamp = datetime.now().isoformat() # 获取当前时间戳
cursor.execute('''
INSERT INTO message (agent_status, role, content, tool_calls, files, timestamp, conversation_id)
VALUES (?, ?, ?, ?, ?, ?, ?)
''', (agent_status, role, content, tool_calls, json.dumps(files), timestamp, conversation_id))
conn.commit()
conn.close()

return timestamp

+ 7
- 14
src/mindpilot/app/model_configs/model_config_api.py View File

@@ -1,19 +1,12 @@
import sqlite3
from typing import List
from fastapi import APIRouter, Body, Query
from ..utils.system_utils import BaseResponse, ListResponse


# 初始化数据库连接
def get_db_connection():
conn = sqlite3.connect('model_configs.db')
conn.row_factory = sqlite3.Row
return conn
from ..utils.system_utils import BaseResponse, ListResponse, get_mindpilot_db_connection


# 创建表结构
def create_table():
conn = get_db_connection()
conn = get_mindpilot_db_connection()
conn.execute('''
CREATE TABLE IF NOT EXISTS model_configs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -45,7 +38,7 @@ async def add_model_config(

):
create_table()
conn = get_db_connection()
conn = get_mindpilot_db_connection()
cursor = conn.cursor()

cursor.execute('''
@@ -83,7 +76,7 @@ async def add_model_config(

async def list_model_configs():
create_table()
conn = get_db_connection()
conn = get_mindpilot_db_connection()
cursor = conn.cursor()

cursor.execute('SELECT * FROM model_configs')
@@ -116,7 +109,7 @@ async def get_model_config(

):
create_table()
conn = get_db_connection()
conn = get_mindpilot_db_connection()
cursor = conn.cursor()

cursor.execute('SELECT * FROM model_configs WHERE id = ?', (config_id,))
@@ -159,7 +152,7 @@ async def update_model_config(
}]),
):
create_table()
conn = get_db_connection()
conn = get_mindpilot_db_connection()
cursor = conn.cursor()

cursor.execute('SELECT * FROM model_configs WHERE id = ?', (config_id,))
@@ -206,7 +199,7 @@ async def update_model_config(

async def delete_model_config(config_id):
create_table()
conn = get_db_connection()
conn = get_mindpilot_db_connection()
cursor = conn.cursor()

# 首先检查配置是否存在


+ 30
- 1
src/mindpilot/app/model_configs/utils.py View File

@@ -1 +1,30 @@
# TODO 获取配置
import sqlite3


def get_config_from_id(config_id: int):
conn = sqlite3.connect('mindpilot.db')
cursor = conn.cursor()

cursor.execute('SELECT * FROM model_configs WHERE id = ?', (config_id,))
config = cursor.fetchone()
conn.close()

if not config:
return None

config_dict = {
"config_id": config[0],
"config_name": config[1],
"platform": config[2],
"base_url": config[3],
"api_key": config[4],
"llm_model": {
config[5]: {
"temperature": config[8],
"max_tokens": config[7],
"callbacks": config[6],
}
}
}

return config_dict

+ 6
- 0
src/mindpilot/app/utils/system_utils.py View File

@@ -3,6 +3,7 @@ import logging
import multiprocessing as mp
import os
import socket
import sqlite3
import sys
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from pathlib import Path
@@ -207,3 +208,8 @@ class ListResponse(BaseResponse):
"data": ["doc1.docx", "doc2.pdf", "doc3.txt"],
}
}

def get_mindpilot_db_connection():
conn = sqlite3.connect('mindpilot.db')
conn.row_factory = sqlite3.Row
return conn

Loading…
Cancel
Save