| @@ -14,7 +14,7 @@ def create_agent( | |||
| kb_name: List[str] = Body([], examples=[["ChatGPT KB"]]), | |||
| avatar: str = Body("", description="头像图片的Base64编码") | |||
| ) -> BaseResponse: | |||
| conn = sqlite3.connect('agents.db') | |||
| conn = sqlite3.connect('mindpilot.db') | |||
| cursor = conn.cursor() | |||
| cursor.execute(''' | |||
| @@ -79,7 +79,7 @@ def create_agent( | |||
| def delete_agent( | |||
| agent_id: int = Body(..., examples=["1"]) | |||
| ) -> BaseResponse: | |||
| conn = sqlite3.connect('agents.db') | |||
| conn = sqlite3.connect('mindpilot.db') | |||
| cursor = conn.cursor() | |||
| if agent_id is None: | |||
| @@ -108,7 +108,7 @@ def update_agent( | |||
| kb_name: List[str] = Body([], examples=[["ChatGPT KB"]]), | |||
| avatar: str = Body("", description="头像图片的Base64编码") | |||
| ) -> BaseResponse: | |||
| conn = sqlite3.connect('agents.db') | |||
| conn = sqlite3.connect('mindpilot.db') | |||
| cursor = conn.cursor() | |||
| if agent_id is None: | |||
| @@ -139,7 +139,7 @@ def update_agent( | |||
| def list_agent() -> ListResponse: | |||
| conn = sqlite3.connect('agents.db') | |||
| conn = sqlite3.connect('mindpilot.db') | |||
| cursor = conn.cursor() | |||
| cursor.execute(''' | |||
| @@ -186,7 +186,7 @@ def list_agent() -> ListResponse: | |||
| def get_agent( | |||
| agent_id: int = Query(..., examples=["1"]), | |||
| ): | |||
| conn = sqlite3.connect('agents.db') | |||
| conn = sqlite3.connect('mindpilot.db') | |||
| cursor = conn.cursor() | |||
| if agent_id is None: | |||
| @@ -1,8 +1,8 @@ | |||
| import sqlite3 | |||
| def get_agent_from_id(agent_id:int): | |||
| conn = sqlite3.connect('agents.db') | |||
| def get_agent_from_id(agent_id: int): | |||
| conn = sqlite3.connect('mindpilot.db') | |||
| cursor = conn.cursor() | |||
| cursor.execute('SELECT * FROM agents WHERE id = ?', (agent_id,)) | |||
| @@ -13,7 +13,7 @@ def get_agent_from_id(agent_id:int): | |||
| return None | |||
| agent_dict = { | |||
| "id": agent[0], | |||
| "agent_id": agent[0], | |||
| "agent_name": agent[1], | |||
| "agent_abstract": agent[2], | |||
| "agent_info": agent[3], | |||
| @@ -24,4 +24,4 @@ def get_agent_from_id(agent_id:int): | |||
| "avatar": agent[8] | |||
| } | |||
| return agent_dict | |||
| return agent_dict | |||
| @@ -6,6 +6,7 @@ from .chat_routes import chat_router | |||
| from .tool_routes import tool_router | |||
| from .agent_routes import agent_router | |||
| from .config_routes import config_router | |||
| from .conversation_routes import conversation_router | |||
| def create_app(run_mode: str = None): | |||
| @@ -26,5 +27,6 @@ def create_app(run_mode: str = None): | |||
| app.include_router(tool_router) | |||
| app.include_router(agent_router) | |||
| app.include_router(config_router) | |||
| app.include_router(conversation_router) | |||
| return app | |||
| @@ -0,0 +1,30 @@ | |||
from __future__ import annotations

from fastapi import APIRouter, Request
from ..conversation.conversation_api import add_conversation, list_conversations, get_conversation, delete_conversation, send_messages

# Conversation endpoints. The handlers live in conversation_api; this
# module only binds them to URL paths under /api/conversation.
conversation_router = APIRouter(prefix="/api/conversation", tags=["对话接口"])

conversation_router.post(
    "",
    summary="创建conversation",
)(add_conversation)

conversation_router.get(
    "",
    summary="获取conversation列表",
)(list_conversations)

# Sub-resource paths must start with "/" — without it FastAPI registers
# the concatenation "/api/conversation{conversation_id}" instead of
# "/api/conversation/{conversation_id}".
conversation_router.get(
    "/{conversation_id}",
    summary="获取单个对话详情"
)(get_conversation)

conversation_router.delete(
    "/{conversation_id}",
    summary="删除对话"
)(delete_conversation)

conversation_router.post(
    "/{conversation_id}/messages",
    summary="发送消息"
)(send_messages)
| @@ -27,29 +27,22 @@ def create_models_from_config(configs, callbacks, stream): | |||
| platform = configs["platform"] | |||
| base_url = configs["base_url"] | |||
| api_key = configs["api_key"] | |||
| is_openai = configs["is_openai"] | |||
| llm_model = configs["llm_model"] | |||
| model_name, params = next(iter(llm_model.items())) | |||
| callbacks = callbacks if params.get("callbacks", False) else None | |||
| if is_openai: | |||
| model_instance = get_ChatOpenAI( | |||
| model_name=model_name, | |||
| base_url=base_url, | |||
| api_key=api_key, | |||
| temperature=params.get("temperature", 0.8), | |||
| max_tokens=params.get("max_tokens", 4096), | |||
| callbacks=callbacks, | |||
| streaming=stream, | |||
| ) | |||
| model = model_instance | |||
| prompt = OPENAI_PROMPT | |||
| else: | |||
| # TODO 其他不兼容OPENAI API格式的平台 | |||
| model = None | |||
| prompt = None | |||
| pass | |||
| model_instance = get_ChatOpenAI( | |||
| model_name=model_name, | |||
| base_url=base_url, | |||
| api_key=api_key, | |||
| temperature=params.get("temperature", 0.8), | |||
| max_tokens=params.get("max_tokens", 4096), | |||
| callbacks=callbacks, | |||
| streaming=stream, | |||
| ) | |||
| model = model_instance | |||
| prompt = OPENAI_PROMPT | |||
| return model, prompt | |||
| @@ -177,7 +170,6 @@ async def chat( | |||
| last_tool = {} | |||
| async for chunk in callback.aiter(): | |||
| data = json.loads(chunk) | |||
| # print("data:{}".format(data)) | |||
| data["tool_calls"] = [] | |||
| data["message_type"] = MsgType.TEXT | |||
| @@ -245,6 +237,7 @@ async def chat( | |||
| async for chunk in chat_iterator(): | |||
| data = json.loads(chunk) | |||
| print(data) | |||
| if text := data["choices"][0]["delta"]["content"]: | |||
| ret.content += text | |||
| if data["status"] == AgentStatus.tool_end: | |||
| @@ -253,3 +246,148 @@ async def chat( | |||
| ret.created = data["created"] | |||
| return ret.model_dump() | |||
async def chat_online(
    content: str,
    history: List[History],
    chat_model_config: dict,
    tool_config: List[str],
    agent_id: int
):
    """Run one non-streaming agent chat turn and return an OpenAI-style
    completion dict (via ``OpenAIChatOutput.model_dump()``).

    content: the new user message.
    history: prior turns, convertible through ``History.from_data``.
    chat_model_config: model configuration (shape produced by
        ``get_config_from_id`` — base_url/api_key/llm_model keys).
    tool_config: tool names to enable; an empty list falls back to the
        agent's own ``tool_config`` (or the global TOOL_CONFIG).
    agent_id: -1 disables the agent path entirely, 0 keeps the default
        prompt, any other id loads that agent's persona from the DB.
    """
    async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]:
        # Callback handler that turns agent-executor events into an async
        # stream of JSON chunks consumed below.
        callback = AgentExecutorAsyncIteratorCallbackHandler()
        callbacks = [callback]
        model, prompt = create_models_from_config(
            callbacks=callbacks, configs=chat_model_config, stream=False
        )
        all_tools = get_tool().values()
        tool_configs = tool_config
        agent_enable = True
        if agent_id != -1:
            if agent_id != 0:
                # Wrap the base prompt with the selected agent's persona.
                agent_dict = get_agent_from_id(agent_id)
                agent_name = agent_dict["agent_name"]
                agent_abstract = agent_dict["agent_abstract"]
                agent_info = agent_dict["agent_info"]
                if not tool_config:
                    # No explicit tools requested: use the agent's own set.
                    tool_configs = agent_dict["tool_config"]
                agent_prompt_pre = "Your name is " + agent_name + "." + agent_abstract + ". Below is your detailed information:" + agent_info + "."
                agent_prompt_after = "DO NOT forget " + agent_prompt_pre
                prompt = agent_prompt_pre + prompt + agent_prompt_after
                # TODO: incorporate the agent's knowledge bases.
            else:
                prompt = prompt  # default agent prompt template
        else:
            agent_enable = False
        # NOTE(review): indentation reconstructed from a mangled source —
        # this fallback is assumed to apply on every path; confirm it is
        # not meant to live inside the `else:` branch above.
        tool_configs = tool_configs or TOOL_CONFIG
        tools = [tool for tool in all_tools if tool.name in tool_configs]
        # Attach our callback handler to each selected tool.
        tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
        full_chain = create_models_chains(
            prompts=prompt,
            models=model,
            tools=tools,
            callbacks=callbacks,
            history=history,
            agent_enable=agent_enable
        )
        _history = [History.from_data(h) for h in history]
        chat_history = [h.to_msg_tuple() for h in _history]
        history_message = convert_to_messages(chat_history)
        # Kick off the chain; wrap_done sets callback.done when it finishes
        # so the aiter() loop below terminates.
        task = asyncio.create_task(
            wrap_done(
                full_chain.ainvoke(
                    {
                        "input": content,
                        "chat_history": history_message,
                    }
                ),
                callback.done,
            )
        )
        last_tool = {}
        async for chunk in callback.aiter():
            data = json.loads(chunk)
            data["tool_calls"] = []
            data["message_type"] = MsgType.TEXT
            if data["status"] == AgentStatus.tool_start:
                # A tool started: remember the call until its tool_end.
                last_tool = {
                    "index": 0,
                    "id": data["run_id"],
                    "type": "function",
                    "function": {
                        "name": data["tool"],
                        "arguments": data["tool_input"],
                    },
                    "tool_output": None,
                    "is_error": False,
                }
                data["tool_calls"].append(last_tool)
            if data["status"] in [AgentStatus.tool_end]:
                # Tool finished: attach its output to the pending call.
                last_tool.update(
                    tool_output=data["tool_output"],
                    is_error=data.get("is_error", False),
                )
                data["tool_calls"] = [last_tool]
                last_tool = {}
                try:
                    # Tools may emit JSON carrying an explicit message_type.
                    tool_output = json.loads(data["tool_output"])
                    if message_type := tool_output.get("message_type"):
                        data["message_type"] = message_type
                except:
                    ...
            elif data["status"] == AgentStatus.agent_finish:
                try:
                    # The final answer may also carry a message_type hint.
                    tool_output = json.loads(data["text"])
                    if message_type := tool_output.get("message_type"):
                        data["message_type"] = message_type
                except:
                    ...
            ret = OpenAIChatOutput(
                id=f"chat{uuid.uuid4()}",
                object="chat.completion.chunk",
                content=data.get("text", ""),
                role="assistant",
                tool_calls=data["tool_calls"],
                model=model.model_name,
                status=data["status"],
                message_type=data["message_type"],
            )
            yield ret.model_dump_json()
        await task

    # Aggregate the chunk stream into a single completion payload.
    ret = OpenAIChatOutput(
        id=f"chat{uuid.uuid4()}",
        object="chat.completion",
        content="",
        role="assistant",
        finish_reason="stop",
        tool_calls=[],
        status=AgentStatus.agent_finish,
        message_type=MsgType.TEXT,
    )
    async for chunk in chat_iterator():
        data = json.loads(chunk)
        if text := data["choices"][0]["delta"]["content"]:
            ret.content += text
        if data["status"] == AgentStatus.tool_end:
            ret.tool_calls += data["choices"][0]["delta"]["tool_calls"]
        ret.model = data["model"]
        ret.created = data["created"]
    return ret.model_dump()
| @@ -0,0 +1,251 @@ | |||
| import json | |||
| from typing import List | |||
| from fastapi import Body | |||
| from uuid import uuid4 | |||
| from datetime import datetime | |||
| import sqlite3 | |||
| from ..utils.system_utils import BaseResponse, ListResponse, get_mindpilot_db_connection | |||
| from .message import init_messages_table, insert_message | |||
| from ..model_configs.utils import get_config_from_id | |||
| from ..agent.utils import get_agent_from_id | |||
| from ..chat.chat import chat_online | |||
def init_conversations_table():
    """Create the ``conversations`` table if it does not already exist."""
    ddl = '''
    CREATE TABLE IF NOT EXISTS conversations (
        conversation_id TEXT PRIMARY KEY,
        title TEXT,
        created_at TEXT,
        updated_at TEXT,
        is_summarized BOOLEAN,
        agent_id INTEGER
    )
    '''
    connection = get_mindpilot_db_connection()
    connection.execute(ddl)
    connection.commit()
    connection.close()
async def add_conversation(
    agent_id: int = Body(0, description="使用agent情况,-1代表不使用agent,0代表使用默认agent"),
):
    """Create a new conversation row and return its metadata.

    agent_id: -1 disables agents, 0 selects the default agent, any other
    value references a concrete agent record.
    """
    init_conversations_table()
    init_messages_table()

    new_id = str(uuid4())
    now_iso = datetime.now().isoformat()
    row = (new_id, "New Conversation", now_iso, now_iso, False, agent_id)

    connection = get_mindpilot_db_connection()
    connection.cursor().execute('''
    INSERT INTO conversations (conversation_id, title, created_at, updated_at, is_summarized, agent_id)
    VALUES (?, ?, ?, ?, ?, ?)
    ''', row)
    connection.commit()
    connection.close()

    return BaseResponse(
        code=200,
        msg="success",
        data={
            "conversation_id": new_id,
            "title": "New Conversation",
            "created_at": datetime.fromisoformat(now_iso),
            "updated_at": datetime.fromisoformat(now_iso),
            "is_summarized": False,
            "agent_id": agent_id,
        },
    )
async def list_conversations():
    """Return the metadata of every stored conversation."""
    init_conversations_table()
    init_messages_table()

    connection = get_mindpilot_db_connection()
    cur = connection.cursor()
    cur.execute('''
    SELECT conversation_id, title, created_at, updated_at, is_summarized, agent_id
    FROM conversations
    ''')
    rows = cur.fetchall()
    connection.close()

    conversations = [
        {
            "conversation_id": row['conversation_id'],
            "title": row['title'],
            "created_at": datetime.fromisoformat(row['created_at']),
            "updated_at": datetime.fromisoformat(row['updated_at']),
            "is_summarized": row['is_summarized'],
            "agent_id": row['agent_id'],
        }
        for row in rows
    ]
    return ListResponse(code=200, msg="success", data=conversations)
async def get_conversation(conversation_id: str):
    """Fetch one conversation plus all of its messages.

    Returns a 404 BaseResponse when the conversation id is unknown.
    """
    init_conversations_table()
    init_messages_table()

    connection = get_mindpilot_db_connection()
    cur = connection.cursor()

    # Conversation header first; bail out early when it does not exist.
    cur.execute('''
    SELECT conversation_id, title, created_at, updated_at, is_summarized, agent_id
    FROM conversations
    WHERE conversation_id = ?
    ''', (conversation_id,))
    header = cur.fetchone()
    if not header:
        connection.close()
        return BaseResponse(code=404, msg="Conversation not found")

    result = {
        "conversation_id": header['conversation_id'],
        "title": header['title'],
        "created_at": datetime.fromisoformat(header['created_at']),
        "updated_at": datetime.fromisoformat(header['updated_at']),
        "is_summarized": header['is_summarized'],
        "agent_id": header['agent_id'],
        "messages": [],
    }

    # Then every message attached to the conversation.
    cur.execute('''
    SELECT id, agent_status, role, content, files, timestamp
    FROM message
    WHERE conversation_id = ?
    ''', (conversation_id,))
    for row in cur.fetchall():
        result["messages"].append({
            "message_id": row['id'],
            "agent_status": row['agent_status'],
            "role": row['role'],
            "content": row['content'],
            "files": json.loads(row['files']),
            "timestamp": datetime.fromisoformat(row['timestamp']),
        })
    connection.close()
    return BaseResponse(code=200, msg="success", data=result)
async def delete_conversation(conversation_id: str):
    """Delete a conversation and every message that belongs to it."""
    init_conversations_table()
    init_messages_table()

    connection = get_mindpilot_db_connection()
    cur = connection.cursor()

    # Verify the conversation exists before touching anything.
    cur.execute('''
    SELECT conversation_id
    FROM conversations
    WHERE conversation_id = ?
    ''', (conversation_id,))
    if not cur.fetchone():
        connection.close()
        return BaseResponse(code=404, msg="Conversation not found", data={"conversation_id": "-1"})

    # Remove the messages first, then the conversation row itself.
    cur.execute('''
    DELETE FROM message
    WHERE conversation_id = ?
    ''', (conversation_id,))
    cur.execute('''
    DELETE FROM conversations
    WHERE conversation_id = ?
    ''', (conversation_id,))
    connection.commit()
    connection.close()
    return BaseResponse(code=200, msg="success", data={"conversation_id": conversation_id})
async def send_messages(
    conversation_id: str,
    role: str = Body("", description="消息角色:user/assistant", examples=["user", "assistant"]),
    agent_id: int = Body(0, description="使用agent,0为默认,-1为不使用agent", examples=[0]),
    config_id: int = Body(0, description="模型配置", examples=[1]),
    files: dict = Body({}, description="文件", examples=[{}]),
    content: str = Body("", description="消息内容"),
    tool_config: List[str] = Body([], description="工具配置", examples=[]),
):
    """Append a user message to a conversation, run the model, then store
    and return the assistant reply.

    Steps: load history -> persist user input -> resolve model config ->
    run the (agent-enabled) chat -> persist and return the model output.
    """
    init_conversations_table()
    init_messages_table()

    # Load the history, then release the connection BEFORE the model call
    # so we don't hold a DB handle open across the await.
    conn = get_mindpilot_db_connection()
    cursor = conn.cursor()
    cursor.execute('''
    SELECT role, content, timestamp
    FROM message
    WHERE conversation_id = ?
    ORDER BY timestamp
    ''', (conversation_id,))
    message_rows = cursor.fetchall()
    conn.close()
    history = [{"role": row['role'], "content": row['content']} for row in message_rows]

    # Persist the user input. insert_message() json-dumps `files` itself,
    # so pass the dict directly — dumping it here double-encoded the value
    # and made get_conversation() return a JSON string instead of a dict.
    insert_message(agent_status=0, role=role, content=content, files=files,
                   conversation_id=conversation_id, tool_calls=json.dumps({}))

    # Resolve the model configuration and run the chat.
    chat_model_config = get_config_from_id(config_id=config_id)
    ret = await chat_online(content=content, history=history, chat_model_config=chat_model_config,
                            tool_config=tool_config, agent_id=agent_id)

    # Unpack the OpenAI-style completion payload.
    # NOTE(review): this id is not the AUTOINCREMENT row id the message is
    # stored under — confirm clients don't rely on it matching the DB.
    message_id = str(uuid4())
    message = ret['choices'][0]['message']
    message_role = message['role']
    message_content = message['content']
    tool_calls = message['tool_calls']

    # Persist the model output; agent_status=5 marks a finished agent turn.
    timestamp = insert_message(agent_status=5, role=message_role, content=message_content, files=files,
                               conversation_id=conversation_id, tool_calls=json.dumps(tool_calls))

    response_data = {
        "messages": [
            {
                "id": message_id,
                "role": message_role,
                "agent_status": 5,
                "content": message_content,
                "tool_calls": tool_calls,
                "files": files,
                "timestamp": datetime.fromisoformat(timestamp)
            }
        ]
    }
    return BaseResponse(code=200, msg="success", data=response_data)
| @@ -0,0 +1,38 @@ | |||
| from datetime import datetime | |||
| import sqlite3 | |||
| import json | |||
| from ..utils.system_utils import get_mindpilot_db_connection | |||
| # 初始化数据库和表 | |||
def init_messages_table():
    """Create the ``message`` table if it does not already exist."""
    ddl = '''
    CREATE TABLE IF NOT EXISTS message (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        agent_status INTEGER,
        role TEXT,
        content TEXT,
        tool_calls TEXT, --json格式
        files TEXT, --json格式
        timestamp TEXT,
        conversation_id TEXT,
        FOREIGN KEY (conversation_id) REFERENCES conversations (conversation_id)
    )
    '''
    connection = get_mindpilot_db_connection()
    connection.execute(ddl)
    connection.commit()
    connection.close()
def insert_message(agent_status, role, content, files, conversation_id, tool_calls):
    """Insert one message row and return its ISO-format timestamp.

    ``files`` is JSON-encoded here; ``tool_calls`` is expected to arrive
    already encoded as a JSON string.
    """
    timestamp = datetime.now().isoformat()
    row = (agent_status, role, content, tool_calls, json.dumps(files), timestamp, conversation_id)
    connection = get_mindpilot_db_connection()
    connection.cursor().execute('''
    INSERT INTO message (agent_status, role, content, tool_calls, files, timestamp, conversation_id)
    VALUES (?, ?, ?, ?, ?, ?, ?)
    ''', row)
    connection.commit()
    connection.close()
    return timestamp
| @@ -1,19 +1,12 @@ | |||
| import sqlite3 | |||
| from typing import List | |||
| from fastapi import APIRouter, Body, Query | |||
| from ..utils.system_utils import BaseResponse, ListResponse | |||
| # 初始化数据库连接 | |||
| def get_db_connection(): | |||
| conn = sqlite3.connect('model_configs.db') | |||
| conn.row_factory = sqlite3.Row | |||
| return conn | |||
| from ..utils.system_utils import BaseResponse, ListResponse, get_mindpilot_db_connection | |||
| # 创建表结构 | |||
| def create_table(): | |||
| conn = get_db_connection() | |||
| conn = get_mindpilot_db_connection() | |||
| conn.execute(''' | |||
| CREATE TABLE IF NOT EXISTS model_configs ( | |||
| id INTEGER PRIMARY KEY AUTOINCREMENT, | |||
| @@ -45,7 +38,7 @@ async def add_model_config( | |||
| ): | |||
| create_table() | |||
| conn = get_db_connection() | |||
| conn = get_mindpilot_db_connection() | |||
| cursor = conn.cursor() | |||
| cursor.execute(''' | |||
| @@ -83,7 +76,7 @@ async def add_model_config( | |||
| async def list_model_configs(): | |||
| create_table() | |||
| conn = get_db_connection() | |||
| conn = get_mindpilot_db_connection() | |||
| cursor = conn.cursor() | |||
| cursor.execute('SELECT * FROM model_configs') | |||
| @@ -116,7 +109,7 @@ async def get_model_config( | |||
| ): | |||
| create_table() | |||
| conn = get_db_connection() | |||
| conn = get_mindpilot_db_connection() | |||
| cursor = conn.cursor() | |||
| cursor.execute('SELECT * FROM model_configs WHERE id = ?', (config_id,)) | |||
| @@ -159,7 +152,7 @@ async def update_model_config( | |||
| }]), | |||
| ): | |||
| create_table() | |||
| conn = get_db_connection() | |||
| conn = get_mindpilot_db_connection() | |||
| cursor = conn.cursor() | |||
| cursor.execute('SELECT * FROM model_configs WHERE id = ?', (config_id,)) | |||
| @@ -206,7 +199,7 @@ async def update_model_config( | |||
| async def delete_model_config(config_id): | |||
| create_table() | |||
| conn = get_db_connection() | |||
| conn = get_mindpilot_db_connection() | |||
| cursor = conn.cursor() | |||
| # 首先检查配置是否存在 | |||
| @@ -1 +1,30 @@ | |||
| # TODO 获取配置 | |||
| import sqlite3 | |||
def get_config_from_id(config_id: int):
    """Look up a model configuration row by primary key.

    Returns a nested config dict, or None when no row matches.
    Assumed column order: (id, name, platform, base_url, api_key,
    model_name, callbacks, max_tokens, temperature) — TODO confirm against
    the model_configs DDL.
    """
    connection = sqlite3.connect('mindpilot.db')
    row = connection.cursor().execute(
        'SELECT * FROM model_configs WHERE id = ?', (config_id,)
    ).fetchone()
    connection.close()
    if not row:
        return None
    model_params = {
        "temperature": row[8],
        "max_tokens": row[7],
        "callbacks": row[6],
    }
    return {
        "config_id": row[0],
        "config_name": row[1],
        "platform": row[2],
        "base_url": row[3],
        "api_key": row[4],
        "llm_model": {row[5]: model_params},
    }
| @@ -3,6 +3,7 @@ import logging | |||
| import multiprocessing as mp | |||
| import os | |||
| import socket | |||
| import sqlite3 | |||
| import sys | |||
| from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed | |||
| from pathlib import Path | |||
| @@ -207,3 +208,8 @@ class ListResponse(BaseResponse): | |||
| "data": ["doc1.docx", "doc2.pdf", "doc3.txt"], | |||
| } | |||
| } | |||
def get_mindpilot_db_connection():
    """Open the shared 'mindpilot.db' SQLite database.

    Rows are returned as ``sqlite3.Row`` so callers can index by column name.
    """
    connection = sqlite3.connect('mindpilot.db')
    connection.row_factory = sqlite3.Row
    return connection