Browse Source

feat:调试对话

main
gjl 1 year ago
parent
commit
eeed7d84b6
3 changed files with 215 additions and 3 deletions
  1. +6
    -1
      src/mindpilot/app/api/conversation_routes.py
  2. +125
    -1
      src/mindpilot/app/chat/chat.py
  3. +84
    -1
      src/mindpilot/app/conversation/conversation_api.py

+ 6
- 1
src/mindpilot/app/api/conversation_routes.py View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from fastapi import APIRouter, Request
from ..conversation.conversation_api import add_conversation, list_conversations, get_conversation, delete_conversation, send_messages
from ..conversation.conversation_api import add_conversation, list_conversations, get_conversation, delete_conversation, send_messages, debug_messages

# Router for all conversation endpoints, mounted under /api/conversation.
conversation_router = APIRouter(prefix="/api/conversation", tags=["对话接口"])

@@ -28,3 +28,8 @@ conversation_router.post(
"/{conversation_id}/messages",
summary="发送消息"
)(send_messages)

# Register the debug endpoint: POST /api/conversation/debug -> debug_messages.
conversation_router.post("/debug", summary="调试对话")(debug_messages)

+ 125
- 1
src/mindpilot/app/chat/chat.py View File

@@ -1,7 +1,7 @@
import asyncio
import json
import uuid
from typing import AsyncIterable, List
from typing import AsyncIterable, List, Dict, Any

from fastapi import Body
from langchain.chains import LLMChain
@@ -378,3 +378,127 @@ async def chat_online(
ret.append(data)

return ret


async def debug_chat_online(
    content: str,
    history: List[History],
    chat_model_config: dict,
    tool_config: List[str],
    agent_config: Dict[str, Any],
):
    """Run one non-streamed debug chat turn and collect its outputs.

    Args:
        content: The user input for this turn.
        history: Prior turns; elements may be ``History`` objects or
            role/content dicts (``History.from_data`` accepts both).
        chat_model_config: Model configuration forwarded to
            ``create_models_from_config``.
        tool_config: Explicit tool names to enable. When empty and the agent
            is enabled, the agent's own ``tool_config`` is used; when both
            are empty, the global ``TOOL_CONFIG`` default applies.
        agent_config: Agent settings; must contain ``agent_enable`` and,
            when enabled, ``agent_name`` / ``agent_abstract`` /
            ``agent_info`` / ``tool_config``.

    Returns:
        A list of parsed output dicts, with the per-token streaming events
        (``llm_start`` / ``llm_new_token``) filtered out.
    """

    def _apply_message_type(data: dict, raw) -> None:
        # Best effort: the payload may be a JSON object carrying an explicit
        # "message_type"; keep the TEXT default on any parse failure instead
        # of letting a malformed tool output abort the whole turn.
        try:
            payload = json.loads(raw)
            if message_type := payload.get("message_type"):
                data["message_type"] = message_type
        except (TypeError, ValueError, AttributeError):
            pass

    async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]:
        callback = AgentExecutorAsyncIteratorCallbackHandler()
        callbacks = [callback]

        # stream=False: debug mode collects whole events rather than tokens.
        model, prompt = create_models_from_config(
            callbacks=callbacks, configs=chat_model_config, stream=False
        )

        all_tools = get_tool().values()
        tool_configs = tool_config

        agent_enable = agent_config["agent_enable"]
        if agent_enable:
            agent_name = agent_config["agent_name"]
            agent_abstract = agent_config["agent_abstract"]
            agent_info = agent_config["agent_info"]
            # Fall back to the agent's own tool list when none was given.
            if not tool_config:
                tool_configs = agent_config["tool_config"]
            # Wrap the base prompt with the agent persona on both sides so
            # the model keeps the persona after a long prompt.
            agent_prompt_pre = "Your name is " + agent_name + "." + agent_abstract + ". Below is your detailed information:" + agent_info + "."
            agent_prompt_after = "DO NOT forget " + agent_prompt_pre
            prompt = agent_prompt_pre + prompt + agent_prompt_after
            # TODO: handle knowledge base

        tool_configs = tool_configs or TOOL_CONFIG
        tools = [tool for tool in all_tools if tool.name in tool_configs]
        tools = [t.copy(update={"callbacks": callbacks}) for t in tools]

        full_chain = create_models_chains(
            prompts=prompt,
            models=model,
            tools=tools,
            callbacks=callbacks,
            history=history,
            agent_enable=agent_enable,
        )

        _history = [History.from_data(h) for h in history]
        chat_history = [h.to_msg_tuple() for h in _history]

        history_message = convert_to_messages(chat_history)

        task = asyncio.create_task(
            wrap_done(
                full_chain.ainvoke(
                    {
                        "input": content,
                        "chat_history": history_message,
                    }
                ),
                callback.done,
            )
        )

        # Accumulates the in-flight tool call between tool_start / tool_end.
        last_tool = {}
        async for chunk in callback.aiter():
            data = json.loads(chunk)
            data["tool_calls"] = []
            data["message_type"] = MsgType.TEXT

            if data["status"] == AgentStatus.tool_start:
                last_tool = {
                    "index": 0,
                    "id": data["run_id"],
                    "type": "function",
                    "function": {
                        "name": data["tool"],
                        "arguments": data["tool_input"],
                    },
                    "tool_output": None,
                    "is_error": False,
                }
                data["tool_calls"].append(last_tool)
            if data["status"] in [AgentStatus.tool_end]:
                last_tool.update(
                    tool_output=data["tool_output"],
                    is_error=data.get("is_error", False),
                )
                data["tool_calls"] = [last_tool]
                last_tool = {}
                _apply_message_type(data, data["tool_output"])
            elif data["status"] == AgentStatus.agent_finish:
                _apply_message_type(data, data["text"])

            ret = OpenAIChatOutput(
                id=f"chat{uuid.uuid4()}",
                object="chat.completion.chunk",
                content=data.get("text", ""),
                role="assistant",
                tool_calls=data["tool_calls"],
                model=model.model_name,
                status=data["status"],
                message_type=data["message_type"],
            )
            yield ret.model_dump_json()

        await task

    ret = []

    async for chunk in chat_iterator():
        data = json.loads(chunk)
        # Drop pure streaming events; keep tool and final-answer messages.
        if data["status"] != AgentStatus.llm_start and data["status"] != AgentStatus.llm_new_token:
            ret.append(data)

    return ret

+ 84
- 1
src/mindpilot/app/conversation/conversation_api.py View File

@@ -6,11 +6,13 @@ from fastapi import Body
from uuid import uuid4
from datetime import datetime
import sqlite3

from ..chat.utils import History
from ..utils.system_utils import BaseResponse, ListResponse, get_mindpilot_db_connection
from .message import init_messages_table, insert_message, split_message_content
from ..model_configs.utils import get_config_from_id
from ..agent.utils import get_agent_from_id
from ..chat.chat import chat_online
from ..chat.chat import chat_online, debug_chat_online


def init_conversations_table():
@@ -185,6 +187,7 @@ async def send_messages(
max_tokens: int = Body(..., description="模型输出最大长度", examples=[4096]),

):
# TODO 缺少知识库部分
init_conversations_table()
init_messages_table()
conn = get_mindpilot_db_connection()
@@ -316,3 +319,83 @@ async def send_messages(
conn.close()

return BaseResponse(code=200, msg="success", data=response_messages)


async def debug_messages(
    query: str = Body(..., description="用户输入", examples=[""]),
    history: List[History] = Body(
        [],
        description="历史对话",
        examples=[
            [
                {"role": "user", "content": "你好"},
                {"role": "assistant", "content": "您好,我是智能Agent桌面助手MindPilot,请问有什么可以帮您?"},
            ]
        ],
    ),
    # Fix: the default must be the int 0, not the string "0", to match the
    # declared `int` type (FastAPI would otherwise coerce or reject it).
    config_id: int = Body(0, description="模型配置", examples=[1]),
    agent_config: dict = Body(..., description="agent配置", examples=[
        {
            "agent_name": "调试助手",
            "agent_abstract": "",
            "agent_info": "这是一个用于调试的AI助手",
            "agent_enable": True,
            "temperature": 0.7,
            "max_tokens": 150,
            "tool_config": ["search_internet", "calculator"]
        }
    ])
):
    """Run one debug conversation turn without persisting any messages.

    Applies the debug agent's temperature / max_tokens overrides to the model
    config identified by ``config_id``, forwards the turn to
    ``debug_chat_online``, and flattens the result into plain message dicts.

    Returns:
        BaseResponse whose ``data`` is a list of message dicts
        (message_id, agent_status, text, files, timestamp).
    """
    # TODO: knowledge-base handling is still missing

    def _make_message(agent_status: int, text: str) -> dict:
        # message_id is always 0: debug turns are never written to the DB.
        return {
            "message_id": 0,
            "agent_status": agent_status,
            "text": text,
            "files": [],
            "timestamp": datetime.now().isoformat(),
        }

    # When the agent persona is disabled, inject its description as an extra
    # user message so the model still sees it during debugging.
    if not agent_config["agent_enable"]:
        temp_agent_name = agent_config["agent_name"]
        temp_agent_abstract = agent_config["agent_abstract"]
        temp_agent_info = agent_config["agent_info"]
        agent_prompt = "Your name is " + temp_agent_name + "." + temp_agent_abstract + ". Below is your detailed information:" + temp_agent_info + "."
        history.append({"role": "user", "content": agent_prompt})

    # Fetch the stored model config and apply the debug overrides.
    chat_model_config = get_config_from_id(config_id=config_id)
    model_key = next(iter(chat_model_config["llm_model"]))
    chat_model_config["llm_model"][model_key]["temperature"] = agent_config['temperature']
    chat_model_config["llm_model"][model_key]["max_tokens"] = agent_config['max_tokens']

    # Collect the model / agent outputs for this turn.
    ret = await debug_chat_online(content=query, history=history, chat_model_config=chat_model_config,
                                  tool_config=agent_config['tool_config'], agent_config=agent_config)

    # NOTE(review): statuses 7 and 3 appear to be AgentStatus.tool_end and
    # AgentStatus.agent_finish — confirm against the AgentStatus enum. Also
    # confirm that debug_chat_online's output dicts really expose a
    # 'choices' list as accessed below.
    response_messages = []
    for message in ret:
        if message['status'] == 7:
            # Tool finished: surface its output as an observation message.
            tool_output = message['choices'][0]['delta']['tool_calls'][0]['tool_output']
            response_messages.append(_make_message(7, "Observation:\n" + tool_output))

        if message['status'] == 3:
            # Final assistant answer, possibly split into several parts.
            message_content = message['choices'][0]['delta']['content']
            for part in split_message_content(message_content):
                response_messages.append(_make_message(3, part))

    # TODO: consider handling messages with status 4 whose predecessor is not
    # status 3, i.e. content the agent failed to parse.

    return BaseResponse(code=200, msg="success", data=response_messages)

Loading…
Cancel
Save