Browse Source

feat:完善对话功能

main
gjl 1 year ago
parent
commit
67d529f769
5 changed files with 106 additions and 62 deletions
  1. +1
    -1
      src/mindpilot/app/callback_handler/agent_callback_handler.py
  2. +6
    -20
      src/mindpilot/app/chat/chat.py
  3. +1
    -1
      src/mindpilot/app/configs/prompt_config.py
  4. +66
    -34
      src/mindpilot/app/conversation/conversation_api.py
  5. +32
    -6
      src/mindpilot/app/conversation/message.py

+ 1
- 1
src/mindpilot/app/callback_handler/agent_callback_handler.py View File

@@ -43,7 +43,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.queue.put_nowait(dumps(data))

async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
special_tokens = ["\n\nAction:", "\n\nObservation:", "<|observation|>", "\n\nThought:"]
special_tokens = ["\n\nAction:", "\n\nObservation:", "<|observation|>", "\n\nThought:", "\nThought:", "\nAction:"]
for stoken in special_tokens:
if stoken in token:
before_action = token.split(stoken)[0]


+ 6
- 20
src/mindpilot/app/chat/chat.py View File

@@ -317,7 +317,6 @@ async def chat_online(
last_tool = {}
async for chunk in callback.aiter():
data = json.loads(chunk)
# print("data:{}".format(data))
data["tool_calls"] = []
data["message_type"] = MsgType.TEXT

@@ -369,25 +368,12 @@ async def chat_online(

await task

ret = OpenAIChatOutput(
id=f"chat{uuid.uuid4()}",
object="chat.completion",
content="",
role="assistant",
finish_reason="stop",
tool_calls=[],
status=AgentStatus.agent_finish,
message_type=MsgType.TEXT,
)
ret = []

async for chunk in chat_iterator():
data = json.loads(chunk)
# print(data)
if text := data["choices"][0]["delta"]["content"]:
ret.content += text
if data["status"] == AgentStatus.tool_end:
ret.tool_calls += data["choices"][0]["delta"]["tool_calls"]
ret.model = data["model"]
ret.created = data["created"]

return ret.model_dump()
print(data)
if data["status"] != AgentStatus.llm_start and data["status"] != AgentStatus.llm_new_token:
ret.append(data)

return ret

+ 1
- 1
src/mindpilot/app/configs/prompt_config.py View File

@@ -50,6 +50,6 @@ Action:
```

Begin! Reminder to ALWAYS respond with a valid JSON blob of a single action. Use tools if necessary. Try to reply in Chinese as much as possible.
Don't forget the Question, Thought, and Observation sections.You MUST strictly follow the above process to output, first output the Question section ONCE, then repeat the Thought section, Action section, Observation section N times until you receive the Final Answer.
Don't forget the Question, Thought, and Observation sections.You MUST strictly follow the above process to output, The Question can only be output once throughout the entire process, then repeat the Thought section, Action section, Observation section N times until you receive the Final Answer.
Please provide as much output content as possible for the Final Answer.
'''

+ 66
- 34
src/mindpilot/app/conversation/conversation_api.py View File

@@ -1,4 +1,5 @@
import json
import re
from typing import List

from fastapi import Body
@@ -6,7 +7,7 @@ from uuid import uuid4
from datetime import datetime
import sqlite3
from ..utils.system_utils import BaseResponse, ListResponse, get_mindpilot_db_connection
from .message import init_messages_table, insert_message
from .message import init_messages_table, insert_message, split_message_content
from ..model_configs.utils import get_config_from_id
from ..agent.utils import get_agent_from_id
from ..chat.chat import chat_online
@@ -181,14 +182,6 @@ async def send_messages(
content: str = Body("", description="消息内容"),
tool_config: List[str] = Body([], description="工具配置", examples=[]),
):
"""
1. 获取历史记录
2. 存放用户输入
3. 获取模型配置
4. 获取agent信息
5. 组织模型输出
6. 存放模型输出
"""
init_conversations_table()
init_messages_table()
conn = get_mindpilot_db_connection()
@@ -210,9 +203,22 @@ async def send_messages(
"content": row['content']
})

if len(history) == 0:
# TODO 总结标题
pass

# print(history)

# 存放用户输入
insert_message(agent_status=0, role=role, content=content, files=json.dumps(files), conversation_id=conversation_id,
tool_calls=json.dumps({}))
_, timestamp_user = insert_message(agent_status=0, role=role, content=content, files=json.dumps(files),
conversation_id=conversation_id)

cursor.execute('''
UPDATE conversations
SET updated_at = ?
WHERE conversation_id = ?
''', (timestamp_user, conversation_id))
conn.commit()

# 获取模型配置
chat_model_config = get_config_from_id(config_id=config_id)
@@ -221,31 +227,57 @@ async def send_messages(
ret = await chat_online(content=content, history=history, chat_model_config=chat_model_config,
tool_config=tool_config, agent_id=agent_id)

# 解析模型输出
message_id = str(uuid4())
message_role = ret['choices'][0]['message']['role']
message_content = ret['choices'][0]['message']['content']
tool_calls = ret['choices'][0]['message']['tool_calls']

# 存放模型输出
timestamp = insert_message(agent_status=5, role=message_role, content=message_content, files=json.dumps(files),
conversation_id=conversation_id, tool_calls=json.dumps(tool_calls))

# 构建响应
response_data = {
"messages": [
{
"id": message_id,
"role": message_role,
"agent_status": 5,
response_messages = []
for message in ret:
if message['status'] == 7:
message_role = message['choices'][0]['role']
message_content = "Observation:\n" + message['choices'][0]['delta']['tool_calls'][0]['tool_output']
message_id, timestamp_message = insert_message(agent_status=7, role=message_role, content=message_content,
files=json.dumps({}), conversation_id=conversation_id)

cursor.execute('''
UPDATE conversations
SET updated_at = ?
WHERE conversation_id = ?
''', (timestamp_message, conversation_id))
conn.commit()

message_dict = {
"message_id": message_id,
"agent_status": 7,
"content": message_content,
"tool_calls": tool_calls,
"files": files,
"timestamp": datetime.fromisoformat(timestamp)
"files": [],
"timestamp": timestamp_message
}
]
}
response_messages.append(message_dict)

if message['status'] == 3:
message_role = message['choices'][0]['role']
message_content = message['choices'][0]['delta']['content']
message_list = split_message_content(message_content)
for m in message_list:
message_id, timestamp_message = insert_message(agent_status=3, role=message_role, content=m,
files=json.dumps({}), conversation_id=conversation_id)

cursor.execute('''
UPDATE conversations
SET updated_at = ?
WHERE conversation_id = ?
''', (timestamp_message, conversation_id))
conn.commit()

message_dict = {
"message_id": message_id,
"agent_status": 3,
"content": m,
"files": [],
"timestamp": timestamp_message
}

response_messages.append(message_dict)

# TODO 这里考虑处理一下message['status']是4但之前一个message['status']不是3的,即agent无法解析的内容

conn.close()

return BaseResponse(code=200, msg="success", data=response_data)
return BaseResponse(code=200, msg="success", data=response_messages)

+ 32
- 6
src/mindpilot/app/conversation/message.py View File

@@ -1,3 +1,4 @@
import re
from datetime import datetime
import sqlite3
import json
@@ -13,7 +14,6 @@ def init_messages_table():
agent_status INTEGER,
role TEXT,
content TEXT,
tool_calls TEXT, --json格式
files TEXT, --json格式
timestamp TEXT,
conversation_id TEXT,
@@ -24,15 +24,41 @@ def init_messages_table():
conn.close()


def insert_message(agent_status, role, content, files, conversation_id, tool_calls):
def insert_message(agent_status, role, content, files, conversation_id):
conn = get_mindpilot_db_connection()
cursor = conn.cursor()
timestamp = datetime.now().isoformat() # 获取当前时间戳
cursor.execute('''
INSERT INTO message (agent_status, role, content, tool_calls, files, timestamp, conversation_id)
VALUES (?, ?, ?, ?, ?, ?, ?)
''', (agent_status, role, content, tool_calls, json.dumps(files), timestamp, conversation_id))
INSERT INTO message (agent_status, role, content, files, timestamp, conversation_id)
VALUES (?, ?, ?, ?, ?, ?)
''', (agent_status, role, content, json.dumps(files), timestamp, conversation_id))
conn.commit()

# 获取插入行的 id
message_id = cursor.lastrowid

conn.close()

return timestamp
return message_id, timestamp


def split_message_content(message_content):
# 定义正则表达式匹配模式
pattern = r'(Question:|Thought:|Action:|Observation:)(.*?)(?=(Question:|Thought:|Action:|Observation:|$))'

# 使用正则表达式查找所有匹配的部分
matches = re.findall(pattern, message_content, re.DOTALL)

# 如果没有匹配到任何关键词,直接返回原文
if not matches:
return [message_content.strip()]

# 创建一个列表来按顺序存储结果
result = []

# 遍历匹配结果,并将它们存入列表
for match in matches:
section = match[0].strip() + match[1].strip() # 保留关键字和内容
result.append(section)

return result

Loading…
Cancel
Save