Browse Source

feat:agent部分完善

main
gjl 1 year ago
parent
commit
d1b53ff9dd
3 changed files with 48 additions and 6 deletions
  1. +27
    -0
      src/mindpilot/app/agent/utils.py
  2. +18
    -5
      src/mindpilot/app/chat/chat.py
  3. +3
    -1
      src/mindpilot/app/utils/system_utils.py

+ 27
- 0
src/mindpilot/app/agent/utils.py View File

@@ -0,0 +1,27 @@
import sqlite3


def get_agent_from_id(agent_id:int):
conn = sqlite3.connect('agents.db')
cursor = conn.cursor()

cursor.execute('SELECT * FROM agents WHERE id = ?', (agent_id,))
agent = cursor.fetchone()
conn.close()

if not agent:
return None

agent_dict = {
"id": agent[0],
"agent_name": agent[1],
"agent_abstract": agent[2],
"agent_info": agent[3],
"temperature": agent[4],
"max_tokens": agent[5],
"tool_config": agent[6].split(',') if agent[6] else [],
"kb_name": agent[7].split(',') if agent[7] else [],
"avatar": agent[8]
}

return agent_dict

+ 18
- 5
src/mindpilot/app/chat/chat.py View File

@@ -18,6 +18,7 @@ from ..callback_handler.agent_callback_handler import (
from ..chat.utils import History
from ..configs import MODEL_CONFIG, TOOL_CONFIG, OPENAI_PROMPT
from ..utils.system_utils import get_ChatOpenAI, get_tool, wrap_done, MsgType
from ..agent.utils import get_agent_from_id


def create_models_from_config(configs, callbacks, stream):
@@ -62,7 +63,7 @@ def create_models_chains(
[i.to_msg_template() for i in history]
)
else:
chat_prompt = None
chat_prompt = ChatPromptTemplate.from_messages([])

llm = models
llm.callbacks = callbacks
@@ -118,17 +119,29 @@ async def chat(
model, prompt = create_models_from_config(
callbacks=callbacks, configs=chat_model_config, stream=stream
)

all_tools = get_tool().values()
tool_configs = tool_config

if agent_enable:
if agent_id != -1:
# TODO 从数据库中获取Agent相关配置
pass
agent_dict = get_agent_from_id(agent_id)
agent_name = agent_dict["agent_name"]
agent_abstract = agent_dict["agent_abstract"]
agent_info = agent_dict["agent_info"]
if not tool_config:
tool_configs = agent_dict["tool_config"]
agent_prompt_pre = "Your name is " + agent_name + "." + agent_abstract + ". Below is your detailed information:" + agent_info + "."
agent_prompt_after = "DO NOT forget " + agent_prompt_pre
prompt = agent_prompt_pre + prompt + agent_prompt_after
# TODO 处理知识库
else:
prompt = prompt # 默认Agent提示模板

all_tools = get_tool().values()
tool_configs = tool_config or TOOL_CONFIG
tool_configs = tool_configs or TOOL_CONFIG
tools = [tool for tool in all_tools if tool.name in tool_configs]
tools = [t.copy(update={"callbacks": callbacks}) for t in tools]

full_chain = create_models_chains(
prompts=prompt,
models=model,


+ 3
- 1
src/mindpilot/app/utils/system_utils.py View File

@@ -181,6 +181,7 @@ def get_tool_config(name: str = None) -> Dict:
else:
return TOOL_CONFIG.get(name, {})


class BaseResponse(BaseModel):
code: int = Field(200, description="API status code")
msg: str = Field("success", description="API status message")
@@ -194,6 +195,7 @@ class BaseResponse(BaseModel):
}
}


class ListResponse(BaseResponse):
data: List[Any] = Field(..., description="List of data")

@@ -204,4 +206,4 @@ class ListResponse(BaseResponse):
"msg": "success",
"data": ["doc1.docx", "doc2.pdf", "doc3.txt"],
}
}
}

Loading…
Cancel
Save