| @@ -6,12 +6,13 @@ from ..utils.system_utils import BaseResponse, ListResponse | |||||
| def create_agent( | def create_agent( | ||||
| agent_name: str = Body(..., examples=["ChatGPT Agent"]), | agent_name: str = Body(..., examples=["ChatGPT Agent"]), | ||||
| agent_abstract: str = Body("", description="Agent简介。"), | |||||
| agent_abstract: str = Body("", description="Agent简介"), | |||||
| agent_info: str = Body("", description="Agent详细配置信息"), | agent_info: str = Body("", description="Agent详细配置信息"), | ||||
| temperature: float = Body(0.8, description="LLM温度"), | temperature: float = Body(0.8, description="LLM温度"), | ||||
| max_tokens: int = Body(4096, description="模型输出最大长度"), | max_tokens: int = Body(4096, description="模型输出最大长度"), | ||||
| tool_config: List[str] = Body([], description="工具配置", examples=[["search_internet", "weather_check"]]), | tool_config: List[str] = Body([], description="工具配置", examples=[["search_internet", "weather_check"]]), | ||||
| kb_name: List[str] = Body([], examples=[["ChatGPT KB"]]), | kb_name: List[str] = Body([], examples=[["ChatGPT KB"]]), | ||||
| avatar: str = Body("", description="头像图片的Base64编码") | |||||
| ) -> BaseResponse: | ) -> BaseResponse: | ||||
| conn = sqlite3.connect('agents.db') | conn = sqlite3.connect('agents.db') | ||||
| cursor = conn.cursor() | cursor = conn.cursor() | ||||
| @@ -19,13 +20,14 @@ def create_agent( | |||||
| cursor.execute(''' | cursor.execute(''' | ||||
| CREATE TABLE IF NOT EXISTS agents ( | CREATE TABLE IF NOT EXISTS agents ( | ||||
| id INTEGER PRIMARY KEY AUTOINCREMENT, | id INTEGER PRIMARY KEY AUTOINCREMENT, | ||||
| agent_name TEXT NOT NULL UNIQUE , | |||||
| agent_name TEXT NOT NULL UNIQUE, | |||||
| agent_abstract TEXT, | agent_abstract TEXT, | ||||
| agent_info TEXT, | agent_info TEXT, | ||||
| temperature REAL, | temperature REAL, | ||||
| max_tokens INTEGER, | max_tokens INTEGER, | ||||
| tool_config TEXT, | tool_config TEXT, | ||||
| kb_name TEXT | |||||
| kb_name TEXT, | |||||
| avatar TEXT | |||||
| ) | ) | ||||
| ''') | ''') | ||||
| conn.commit() | conn.commit() | ||||
| @@ -41,10 +43,10 @@ def create_agent( | |||||
| # TODO 处理知识库 | # TODO 处理知识库 | ||||
| cursor.execute(''' | cursor.execute(''' | ||||
| INSERT INTO agents (agent_name, agent_abstract, agent_info, temperature, max_tokens, tool_config, kb_name) | |||||
| VALUES (?, ?, ?, ?, ?, ?, ?) | |||||
| INSERT INTO agents (agent_name, agent_abstract, agent_info, temperature, max_tokens, tool_config, kb_name, avatar) | |||||
| VALUES (?, ?, ?, ?, ?, ?, ?, ?) | |||||
| ''', ( | ''', ( | ||||
| agent_name, agent_abstract, agent_info, temperature, max_tokens, ','.join(tool_config), ','.join(kb_name))) | |||||
| agent_name, agent_abstract, agent_info, temperature, max_tokens, ','.join(tool_config), ','.join(kb_name), avatar)) | |||||
| conn.commit() | conn.commit() | ||||
| conn.close() | conn.close() | ||||
| @@ -81,6 +83,7 @@ def update_agent( | |||||
| max_tokens: int = Body(4096, description="模型输出最大长度"), | max_tokens: int = Body(4096, description="模型输出最大长度"), | ||||
| tool_config: List[str] = Body([], description="工具配置", examples=[["search_internet", "weather_check"]]), | tool_config: List[str] = Body([], description="工具配置", examples=[["search_internet", "weather_check"]]), | ||||
| kb_name: List[str] = Body([], examples=[["ChatGPT KB"]]), | kb_name: List[str] = Body([], examples=[["ChatGPT KB"]]), | ||||
| avatar: str = Body("", description="头像图片的Base64编码") | |||||
| ) -> BaseResponse: | ) -> BaseResponse: | ||||
| conn = sqlite3.connect('agents.db') | conn = sqlite3.connect('agents.db') | ||||
| cursor = conn.cursor() | cursor = conn.cursor() | ||||
| @@ -93,14 +96,22 @@ def update_agent( | |||||
| if not existing_agent: | if not existing_agent: | ||||
| return BaseResponse(code=404, msg=f"不存在ID为 {agent_id} 的Agent") | return BaseResponse(code=404, msg=f"不存在ID为 {agent_id} 的Agent") | ||||
| if agent_name is None or agent_name.strip() == "": | |||||
| return BaseResponse(code=404, msg="Agent名称不能为空,请重新填写Agent名称") | |||||
| cursor.execute('SELECT id FROM agents WHERE agent_name = ?', (agent_name,)) | |||||
| existing_agent = cursor.fetchone() | |||||
| if existing_agent: | |||||
| return BaseResponse(code=404, msg=f"已存在同名Agent {agent_name}") | |||||
| #TODO 处理知识库 | #TODO 处理知识库 | ||||
| cursor.execute(''' | cursor.execute(''' | ||||
| UPDATE agents | UPDATE agents | ||||
| SET agent_name = ?, agent_abstract = ?, agent_info = ?, temperature = ?, max_tokens = ?, tool_config = ?, kb_name = ? | |||||
| SET agent_name = ?, agent_abstract = ?, agent_info = ?, temperature = ?, max_tokens = ?, tool_config = ?, kb_name = ?, avatar = ? | |||||
| WHERE id = ? | WHERE id = ? | ||||
| ''', ( | ''', ( | ||||
| agent_name, agent_abstract, agent_info, temperature, max_tokens, ','.join(tool_config), ','.join(kb_name), | |||||
| agent_name, agent_abstract, agent_info, temperature, max_tokens, ','.join(tool_config), ','.join(kb_name), avatar, | |||||
| agent_id)) | agent_id)) | ||||
| conn.commit() | conn.commit() | ||||
| conn.close() | conn.close() | ||||
| @@ -115,13 +126,14 @@ def list_agent() -> ListResponse: | |||||
| cursor.execute(''' | cursor.execute(''' | ||||
| CREATE TABLE IF NOT EXISTS agents ( | CREATE TABLE IF NOT EXISTS agents ( | ||||
| id INTEGER PRIMARY KEY AUTOINCREMENT, | id INTEGER PRIMARY KEY AUTOINCREMENT, | ||||
| agent_name TEXT NOT NULL UNIQUE , | |||||
| agent_name TEXT NOT NULL UNIQUE, | |||||
| agent_abstract TEXT, | agent_abstract TEXT, | ||||
| agent_info TEXT, | agent_info TEXT, | ||||
| temperature REAL, | temperature REAL, | ||||
| max_tokens INTEGER, | max_tokens INTEGER, | ||||
| tool_config TEXT, | tool_config TEXT, | ||||
| kb_name TEXT | |||||
| kb_name TEXT, | |||||
| avatar TEXT | |||||
| ) | ) | ||||
| ''') | ''') | ||||
| conn.commit() | conn.commit() | ||||
| @@ -144,7 +156,8 @@ def list_agent() -> ListResponse: | |||||
| "temperature": agent[4], | "temperature": agent[4], | ||||
| "max_tokens": agent[5], | "max_tokens": agent[5], | ||||
| "tool_config": agent[6].split(',') if agent[6] else [], | "tool_config": agent[6].split(',') if agent[6] else [], | ||||
| "kb_name": agent[7].split(',') if agent[7] else [] | |||||
| "kb_name": agent[7].split(',') if agent[7] else [], | |||||
| "avatar": agent[8] | |||||
| } | } | ||||
| agent_list.append(agent_dict) | agent_list.append(agent_dict) | ||||
| @@ -177,7 +190,8 @@ def get_agent( | |||||
| "temperature": agent[4], | "temperature": agent[4], | ||||
| "max_tokens": agent[5], | "max_tokens": agent[5], | ||||
| "tool_config": agent[6].split(',') if agent[6] else [], | "tool_config": agent[6].split(',') if agent[6] else [], | ||||
| "kb_name": agent[7].split(',') if agent[7] else [] | |||||
| "kb_name": agent[7].split(',') if agent[7] else [], | |||||
| "avatar": agent[8] | |||||
| } | } | ||||
| return ListResponse(code=200, msg=f"获取Agent ID为 {agent_id} 的信息成功", data=[agent_dict]) | return ListResponse(code=200, msg=f"获取Agent ID为 {agent_id} 的信息成功", data=[agent_dict]) | ||||