Browse Source

feat:创建Agent API实现

main
guojialiang 1 year ago
parent
commit
97b98b2746
7 changed files with 86 additions and 12 deletions
  1. +49
    -0
      src/mindpilot/app/agent/agent_api.py
  2. +1
    -7
      src/mindpilot/app/agent/agents_registry.py
  3. +10
    -0
      src/mindpilot/app/api/agent_routes.py
  4. +2
    -0
      src/mindpilot/app/api/api_server.py
  5. +1
    -1
      src/mindpilot/app/api/tool_routes.py
  6. +9
    -3
      src/mindpilot/app/chat/chat.py
  7. +14
    -1
      src/mindpilot/app/utils/system_utils.py

+ 49
- 0
src/mindpilot/app/agent/agent_api.py View File

@@ -0,0 +1,49 @@
from typing import List, Optional
from fastapi import Body, File, UploadFile
import sqlite3
from ..utils.system_utils import BaseResponse


def create_agent(
agent_name: str = Body(..., examples=["ChatGPT Agent"]),
agent_abstract: str = Body("", description="Agent简介。"),
agent_info: str = Body("", description="Agent详细配置信息"),
temperature: float = Body(0.8, description="LLM温度"),
max_tokens: int = Body(4096, description="模型输出最大长度"),
tool_config: List[str] = Body([], description="工具配置", examples=[["search_internet","weather_check"]]),
# kb_files: Optional[List[UploadFile]] = File(None, description="知识库文件"),
) -> BaseResponse:
conn = sqlite3.connect('agents.db')
cursor = conn.cursor()

cursor.execute('''
CREATE TABLE IF NOT EXISTS agents (
id INTEGER PRIMARY KEY AUTOINCREMENT,
agent_name TEXT NOT NULL UNIQUE ,
agent_abstract TEXT,
agent_info TEXT,
temperature REAL,
max_tokens INTEGER,
tool_config TEXT
)
''')
conn.commit()

if agent_name is None or agent_name.strip() == "":
return BaseResponse(code=404, msg="Agent名称不能为空,请重新填写Agent名称")

cursor.execute('SELECT id FROM agents WHERE agent_name = ?', (agent_name,))
existing_agent = cursor.fetchone()
if existing_agent:
return BaseResponse(code=404, msg=f"已存在同名Agent {agent_name}")

cursor.execute('''
INSERT INTO agents (agent_name, agent_abstract, agent_info, temperature, max_tokens, tool_config)
VALUES (?, ?, ?, ?, ?, ?)
''', (agent_name, agent_abstract, agent_info, temperature, max_tokens, ','.join(tool_config)))
conn.commit()
conn.close()

# TODO 处理上传的知识库文件

return BaseResponse(code=200, msg=f"已新增Agent {agent_name}")

+ 1
- 7
src/mindpilot/app/agent/agents_registry.py View File

@@ -36,17 +36,11 @@ def agents_registry(
HumanMessagePromptTemplate(
prompt=PromptTemplate(
input_variables=['agent_scratchpad', 'input'],
template='''
{input}

{agent_scratchpad}
(reminder to respond in a JSON blob no matter what)
'''
template='''{input}\n\n{agent_scratchpad}\n(reminder to respond in a JSON blob no matter what)\n'''
)
)
]
)
# print(prompt)

agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)



+ 10
- 0
src/mindpilot/app/api/agent_routes.py View File

@@ -0,0 +1,10 @@
from __future__ import annotations
from fastapi import APIRouter, Request
from ..agent.agent_api import create_agent

agent_router = APIRouter(prefix="/agent", tags=["Agent配置"])

agent_router.post(
"/create_agent",
summary="创建Agent",
)(create_agent)

+ 2
- 0
src/mindpilot/app/api/api_server.py View File

@@ -4,6 +4,7 @@ from starlette.responses import RedirectResponse

from .chat_routes import chat_router
from .tool_routes import tool_router
from .agent_routes import agent_router


def create_app(run_mode: str = None):
@@ -22,5 +23,6 @@ def create_app(run_mode: str = None):

app.include_router(chat_router)
app.include_router(tool_router)
app.include_router(agent_router)

return app

+ 1
- 1
src/mindpilot/app/api/tool_routes.py View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter
from ..utils.system_utils import get_tool

tool_router = APIRouter(prefix="/tools", tags=["MindPilot对话"])
tool_router = APIRouter(prefix="/tools", tags=["获取工具"])


@tool_router.get("/available_tools", summary="获取可用工具")


+ 9
- 3
src/mindpilot/app/chat/chat.py View File

@@ -46,6 +46,8 @@ def create_models_from_config(configs, callbacks, stream):
prompt = OPENAI_PROMPT
else:
#TODO 其他不兼容OPENAI API格式的平台
model = None
prompt = None
pass

return model, prompt
@@ -104,14 +106,13 @@ async def chat(
"gpt-4o-mini": {
"temperature": 0.8,
"max_tokens": 8192,
"history_len": 10,
"prompt_name": "default",
"callbacks": True,
},
}
}]),
tool_config: List[str] = Body([], description="工具配置", examples=[]),
agent_enable: bool = Body(True, description="是否启用Agent")
agent_enable: bool = Body(True, description="是否启用Agent"),
agent_name: str = Body("default", description="使用的Agent,默认为default")
):
"""Agent 对话"""

@@ -122,6 +123,11 @@ async def chat(
model, prompt = create_models_from_config(
callbacks=callbacks, configs=chat_model_config, stream=stream
)

if agent_name != "default":
#TODO 从数据库中获取Agent相关配置
pass

all_tools = get_tool().values()
tool_configs = tool_config or TOOL_CONFIG
tools = [tool for tool in all_tools if tool.name in tool_configs]


+ 14
- 1
src/mindpilot/app/utils/system_utils.py View File

@@ -35,7 +35,7 @@ from langchain_openai.llms import OpenAI
# TEMPERATURE,
# log_verbose,
# )
# from chatchat.server.pydantic_v2 import BaseModel, Field
from .pydantic_v2 import BaseModel, Field

logger = logging.getLogger()

@@ -180,3 +180,16 @@ def get_tool_config(name: str = None) -> Dict:
return TOOL_CONFIG
else:
return TOOL_CONFIG.get(name, {})

class BaseResponse(BaseModel):
code: int = Field(200, description="API status code")
msg: str = Field("success", description="API status message")
data: Any = Field(None, description="API data")

class Config:
json_schema_extra = {
"example": {
"code": 200,
"msg": "success",
}
}

Loading…
Cancel
Save