From 4e2f6e1b87e5af03447332a11fbcceb88f5f4572 Mon Sep 17 00:00:00 2001 From: guojialiang Date: Sat, 10 Aug 2024 22:40:32 +0800 Subject: [PATCH] =?UTF-8?q?feat:=E5=AE=9E=E7=8E=B0Agent=E7=9A=84=E5=A2=9E?= =?UTF-8?q?=E5=88=A0=E6=94=B9=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mindpilot/app/agent/agent_api.py | 38 +++++++++++++++++++--------- 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/src/mindpilot/app/agent/agent_api.py b/src/mindpilot/app/agent/agent_api.py index 46fcf7e..ad74ee1 100644 --- a/src/mindpilot/app/agent/agent_api.py +++ b/src/mindpilot/app/agent/agent_api.py @@ -6,12 +6,13 @@ from ..utils.system_utils import BaseResponse, ListResponse def create_agent( agent_name: str = Body(..., examples=["ChatGPT Agent"]), - agent_abstract: str = Body("", description="Agent简介。"), + agent_abstract: str = Body("", description="Agent简介"), agent_info: str = Body("", description="Agent详细配置信息"), temperature: float = Body(0.8, description="LLM温度"), max_tokens: int = Body(4096, description="模型输出最大长度"), tool_config: List[str] = Body([], description="工具配置", examples=[["search_internet", "weather_check"]]), kb_name: List[str] = Body([], examples=[["ChatGPT KB"]]), + avatar: str = Body("", description="头像图片的Base64编码") ) -> BaseResponse: conn = sqlite3.connect('agents.db') cursor = conn.cursor() @@ -19,13 +20,14 @@ def create_agent( cursor.execute(''' CREATE TABLE IF NOT EXISTS agents ( id INTEGER PRIMARY KEY AUTOINCREMENT, - agent_name TEXT NOT NULL UNIQUE , + agent_name TEXT NOT NULL UNIQUE, agent_abstract TEXT, agent_info TEXT, temperature REAL, max_tokens INTEGER, tool_config TEXT, - kb_name TEXT + kb_name TEXT, + avatar TEXT ) ''') conn.commit() @@ -41,10 +43,10 @@ def create_agent( # TODO 处理知识库 cursor.execute(''' - INSERT INTO agents (agent_name, agent_abstract, agent_info, temperature, max_tokens, tool_config, kb_name) - VALUES (?, ?, ?, ?, ?, ?, ?) + INSERT INTO agents (agent_name, agent_abstract, agent_info, temperature, max_tokens, tool_config, kb_name, avatar) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) ''', ( - agent_name, agent_abstract, agent_info, temperature, max_tokens, ','.join(tool_config), ','.join(kb_name))) + agent_name, agent_abstract, agent_info, temperature, max_tokens, ','.join(tool_config), ','.join(kb_name), avatar)) conn.commit() conn.close() @@ -81,6 +83,7 @@ def update_agent( max_tokens: int = Body(4096, description="模型输出最大长度"), tool_config: List[str] = Body([], description="工具配置", examples=[["search_internet", "weather_check"]]), kb_name: List[str] = Body([], examples=[["ChatGPT KB"]]), + avatar: str = Body("", description="头像图片的Base64编码") ) -> BaseResponse: conn = sqlite3.connect('agents.db') cursor = conn.cursor() @@ -93,14 +96,22 @@ def update_agent( if not existing_agent: return BaseResponse(code=404, msg=f"不存在ID为 {agent_id} 的Agent") + if agent_name is None or agent_name.strip() == "": + return BaseResponse(code=404, msg="Agent名称不能为空,请重新填写Agent名称") + + cursor.execute('SELECT id FROM agents WHERE agent_name = ?', (agent_name,)) + existing_agent = cursor.fetchone() + if existing_agent: + return BaseResponse(code=404, msg=f"已存在同名Agent {agent_name}") + #TODO 处理知识库 cursor.execute(''' UPDATE agents - SET agent_name = ?, agent_abstract = ?, agent_info = ?, temperature = ?, max_tokens = ?, tool_config = ?, kb_name = ? + SET agent_name = ?, agent_abstract = ?, agent_info = ?, temperature = ?, max_tokens = ?, tool_config = ?, kb_name = ?, avatar = ? WHERE id = ? ''', ( - agent_name, agent_abstract, agent_info, temperature, max_tokens, ','.join(tool_config), ','.join(kb_name), + agent_name, agent_abstract, agent_info, temperature, max_tokens, ','.join(tool_config), ','.join(kb_name), avatar, agent_id)) conn.commit() conn.close() @@ -115,13 +126,14 @@ def list_agent() -> ListResponse: cursor.execute(''' CREATE TABLE IF NOT EXISTS agents ( id INTEGER PRIMARY KEY AUTOINCREMENT, - agent_name TEXT NOT NULL UNIQUE , + agent_name TEXT NOT NULL UNIQUE, agent_abstract TEXT, agent_info TEXT, temperature REAL, max_tokens INTEGER, tool_config TEXT, - kb_name TEXT + kb_name TEXT, + avatar TEXT ) ''') conn.commit() @@ -144,7 +156,8 @@ def list_agent() -> ListResponse: "temperature": agent[4], "max_tokens": agent[5], "tool_config": agent[6].split(',') if agent[6] else [], - "kb_name": agent[7].split(',') if agent[7] else [] + "kb_name": agent[7].split(',') if agent[7] else [], + "avatar": agent[8] } agent_list.append(agent_dict) @@ -177,7 +190,8 @@ def get_agent( "temperature": agent[4], "max_tokens": agent[5], "tool_config": agent[6].split(',') if agent[6] else [], - "kb_name": agent[7].split(',') if agent[7] else [] + "kb_name": agent[7].split(',') if agent[7] else [], + "avatar": agent[8] } return ListResponse(code=200, msg=f"获取Agent ID为 {agent_id} 的信息成功", data=[agent_dict]) \ No newline at end of file