diff --git a/src/mindpilot/app/agent/utils.py b/src/mindpilot/app/agent/utils.py new file mode 100644 index 0000000..40eeb25 --- /dev/null +++ b/src/mindpilot/app/agent/utils.py @@ -0,0 +1,27 @@ +import sqlite3 + + +def get_agent_from_id(agent_id:int): + conn = sqlite3.connect('agents.db') + cursor = conn.cursor() + + cursor.execute('SELECT * FROM agents WHERE id = ?', (agent_id,)) + agent = cursor.fetchone() + conn.close() + + if not agent: + return None + + agent_dict = { + "id": agent[0], + "agent_name": agent[1], + "agent_abstract": agent[2], + "agent_info": agent[3], + "temperature": agent[4], + "max_tokens": agent[5], + "tool_config": agent[6].split(',') if agent[6] else [], + "kb_name": agent[7].split(',') if agent[7] else [], + "avatar": agent[8] + } + + return agent_dict \ No newline at end of file diff --git a/src/mindpilot/app/chat/chat.py b/src/mindpilot/app/chat/chat.py index f0fbf9b..bef92a2 100644 --- a/src/mindpilot/app/chat/chat.py +++ b/src/mindpilot/app/chat/chat.py @@ -18,6 +18,7 @@ from ..callback_handler.agent_callback_handler import ( from ..chat.utils import History from ..configs import MODEL_CONFIG, TOOL_CONFIG, OPENAI_PROMPT from ..utils.system_utils import get_ChatOpenAI, get_tool, wrap_done, MsgType +from ..agent.utils import get_agent_from_id def create_models_from_config(configs, callbacks, stream): @@ -62,7 +63,7 @@ def create_models_chains( [i.to_msg_template() for i in history] ) else: - chat_prompt = None + chat_prompt = ChatPromptTemplate.from_messages([]) llm = models llm.callbacks = callbacks @@ -118,17 +119,29 @@ async def chat( model, prompt = create_models_from_config( callbacks=callbacks, configs=chat_model_config, stream=stream ) + + all_tools = get_tool().values() + tool_configs = tool_config + if agent_enable: if agent_id != -1: - # TODO 从数据库中获取Agent相关配置 - pass + agent_dict = get_agent_from_id(agent_id) + agent_name = agent_dict["agent_name"] + agent_abstract = agent_dict["agent_abstract"] + agent_info = agent_dict["agent_info"] + if not tool_config: + tool_configs = agent_dict["tool_config"] + agent_prompt_pre = "Your name is " + agent_name + "." + agent_abstract + ". Below is your detailed information:" + agent_info + "." + agent_prompt_after = "DO NOT forget " + agent_prompt_pre + prompt = agent_prompt_pre + prompt + agent_prompt_after + # TODO 处理知识库 else: prompt = prompt # 默认Agent提示模板 - all_tools = get_tool().values() - tool_configs = tool_config or TOOL_CONFIG + tool_configs = tool_configs or TOOL_CONFIG tools = [tool for tool in all_tools if tool.name in tool_configs] tools = [t.copy(update={"callbacks": callbacks}) for t in tools] + full_chain = create_models_chains( prompts=prompt, models=model, diff --git a/src/mindpilot/app/utils/system_utils.py b/src/mindpilot/app/utils/system_utils.py index ac1b12d..7acf06e 100644 --- a/src/mindpilot/app/utils/system_utils.py +++ b/src/mindpilot/app/utils/system_utils.py @@ -181,6 +181,7 @@ def get_tool_config(name: str = None) -> Dict: else: return TOOL_CONFIG.get(name, {}) + class BaseResponse(BaseModel): code: int = Field(200, description="API status code") msg: str = Field("success", description="API status message") @@ -194,6 +195,7 @@ class BaseResponse(BaseModel): } } + class ListResponse(BaseResponse): data: List[Any] = Field(..., description="List of data") @@ -204,4 +206,4 @@ class ListResponse(BaseResponse): "msg": "success", "data": ["doc1.docx", "doc2.pdf", "doc3.txt"], } - } \ No newline at end of file + }