| @@ -17,4 +17,5 @@ numexpr~=2.8.7 | |||
| langchain_community==0.2.7 | |||
| markdownify~=0.13.1 | |||
| strsimpy~=0.2.1 | |||
| metaphor_python==0.1.23 | |||
| metaphor_python==0.1.23 | |||
| langchainhub==0.1.20 | |||
| @@ -10,11 +10,11 @@ from langchain_core.tools import BaseTool | |||
| def agents_registry( | |||
| llm: BaseLanguageModel, | |||
| tools: Sequence[BaseTool] = [], | |||
| callbacks: List[BaseCallbackHandler] = [], | |||
| prompt: str = None, | |||
| verbose: bool = False, | |||
| llm: BaseLanguageModel, | |||
| tools: Sequence[BaseTool] = [], | |||
| callbacks: List[BaseCallbackHandler] = [], | |||
| prompt: str = None, | |||
| verbose: bool = False, | |||
| ): | |||
| if prompt is not None: | |||
| prompt = ChatPromptTemplate.from_messages([SystemMessage(content=prompt)]) | |||
| @@ -23,8 +23,7 @@ def agents_registry( | |||
| agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt) | |||
| agent_executor = AgentExecutor( | |||
| agent=agent, tools=tools, verbose=verbose, callbacks=callbacks | |||
| agent=agent, tools=tools, verbose=verbose, callbacks=callbacks, handle_parsing_errors=True | |||
| ) | |||
| return agent_executor | |||
| @@ -5,18 +5,18 @@ from typing import Dict, List | |||
| from fastapi import APIRouter, Request | |||
| from langchain.prompts.prompt import PromptTemplate | |||
| # from chatchat.server.api_server.api_schemas import AgentStatus, MsgType, OpenAIChatInput | |||
| from ..chat import chat | |||
| from app.api.api_schemas import MsgType, OpenAIChatInput | |||
| from ..chat.chat import chat | |||
| # from chatchat.server.chat.file_chat import file_chat | |||
| # from chatchat.server.db.repository import add_message_to_db | |||
| # from chatchat.server.utils import ( | |||
| # get_OpenAIClient, | |||
| # get_prompt_template, | |||
| # get_tool, | |||
| # get_tool_config, | |||
| # ) | |||
| # | |||
| # from .openai_routes import openai_request | |||
| from ..utils import ( | |||
| get_OpenAIClient, | |||
| get_prompt_template, | |||
| get_tool, | |||
| get_tool_config, | |||
| ) | |||
| from .openai_routes import openai_request | |||
| chat_router = APIRouter(prefix="/chat", tags=["MindPilot对话"]) | |||
| @@ -24,148 +24,151 @@ chat_router.post( | |||
| "/chat", | |||
| summary="与llm模型对话(通过LLMChain)", | |||
| )(chat) | |||
| # | |||
| # # 定义全局model信息,用于给Text2Sql中的get_ChatOpenAI提供model_name | |||
| # global_model_name = None | |||
| # | |||
| # | |||
| # @chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口") | |||
| # async def chat_completions( | |||
| # request: Request, | |||
| # body: OpenAIChatInput, | |||
| # ) -> Dict: | |||
| # """ | |||
| # 请求参数与 openai.chat.completions.create 一致,可以通过 extra_body 传入额外参数 | |||
| # tools 和 tool_choice 可以直接传工具名称,会根据项目里包含的 tools 进行转换 | |||
| # 通过不同的参数组合调用不同的 chat 功能: | |||
| # - tool_choice | |||
| # - extra_body 中包含 tool_input: 直接调用 tool_choice(tool_input) | |||
| # - extra_body 中不包含 tool_input: 通过 agent 调用 tool_choice | |||
| # - tools: agent 对话 | |||
| # - 其它:LLM 对话 | |||
| # 以后还要考虑其它的组合(如文件对话) | |||
| # 返回与 openai 兼容的 Dict | |||
| # """ | |||
| # client = get_OpenAIClient(model_name=body.model, is_async=True) | |||
| # extra = {**body.model_extra} or {} | |||
| # for key in list(extra): | |||
| # delattr(body, key) | |||
| # | |||
| # global global_model_name | |||
| # global_model_name = body.model | |||
| # # check tools & tool_choice in request body | |||
| # if isinstance(body.tool_choice, str): | |||
| # if t := get_tool(body.tool_choice): | |||
| # body.tool_choice = {"function": {"name": t.name}, "type": "function"} | |||
| # if isinstance(body.tools, list): | |||
| # for i in range(len(body.tools)): | |||
| # if isinstance(body.tools[i], str): | |||
| # if t := get_tool(body.tools[i]): | |||
| # body.tools[i] = { | |||
| # "type": "function", | |||
| # "function": { | |||
| # "name": t.name, | |||
| # "description": t.description, | |||
| # "parameters": t.args, | |||
| # }, | |||
| # } | |||
| # | |||
| # conversation_id = extra.get("conversation_id") | |||
| # | |||
| # # chat based on result from one choiced tool | |||
| # if body.tool_choice: | |||
| # tool = get_tool(body.tool_choice["function"]["name"]) | |||
| # if not body.tools: | |||
| # body.tools = [ | |||
| # { | |||
| # "type": "function", | |||
| # "function": { | |||
| # "name": tool.name, | |||
| # "description": tool.description, | |||
| # "parameters": tool.args, | |||
| # }, | |||
| # } | |||
| # ] | |||
| # if tool_input := extra.get("tool_input"): | |||
| # message_id = ( | |||
| # add_message_to_db( | |||
| # chat_type="tool_call", | |||
| # query=body.messages[-1]["content"], | |||
| # conversation_id=conversation_id, | |||
| # ) | |||
| # if conversation_id | |||
| # else None | |||
| # ) | |||
| # | |||
| # tool_result = await tool.ainvoke(tool_input) | |||
| # prompt_template = PromptTemplate.from_template( | |||
| # get_prompt_template("llm_model", "rag"), template_format="jinja2" | |||
| # ) | |||
| # body.messages[-1]["content"] = prompt_template.format( | |||
| # context=tool_result, question=body.messages[-1]["content"] | |||
| # ) | |||
| # del body.tools | |||
| # del body.tool_choice | |||
| # extra_json = { | |||
| # "message_id": message_id, | |||
| # "status": None, | |||
| # } | |||
| # header = [ | |||
| # { | |||
| # **extra_json, | |||
| # "content": f"{tool_result}", | |||
| # "tool_output": tool_result.data, | |||
| # "is_ref": True, | |||
| # } | |||
| # ] | |||
| # return await openai_request( | |||
| # client.chat.completions.create, | |||
| # body, | |||
| # extra_json=extra_json, | |||
| # header=header, | |||
| # ) | |||
| # | |||
| # # agent chat with tool calls | |||
| # if body.tools: | |||
| # message_id = ( | |||
| # add_message_to_db( | |||
| # chat_type="agent_chat", | |||
| # query=body.messages[-1]["content"], | |||
| # conversation_id=conversation_id, | |||
| # ) | |||
| # if conversation_id | |||
| # else None | |||
| # ) | |||
| # | |||
| # chat_model_config = {} | |||
| # tool_names = [x["function"]["name"] for x in body.tools] | |||
| # tool_config = {name: get_tool_config(name) for name in tool_names} | |||
| # result = await chat( | |||
| # query=body.messages[-1]["content"], | |||
| # metadata=extra.get("metadata", {}), | |||
| # conversation_id=extra.get("conversation_id", ""), | |||
| # message_id=message_id, | |||
| # history_len=-1, | |||
| # history=body.messages[:-1], | |||
| # stream=body.stream, | |||
| # chat_model_config=extra.get("chat_model_config", chat_model_config), | |||
| # tool_config=extra.get("tool_config", tool_config), | |||
| # ) | |||
| # return result | |||
| # else: # LLM chat directly | |||
| # message_id = ( | |||
| # add_message_to_db( | |||
| # chat_type="llm_chat", | |||
| # query=body.messages[-1]["content"], | |||
| # conversation_id=conversation_id, | |||
| # ) | |||
| # if conversation_id | |||
| # else None | |||
| # ) | |||
| # extra_json = { | |||
| # "message_id": message_id, | |||
| # "status": None, | |||
| # } | |||
| # return await openai_request( | |||
| # client.chat.completions.create, body, extra_json=extra_json | |||
| # ) | |||
| # 定义全局model信息,用于给Text2Sql中的get_ChatOpenAI提供model_name | |||
| global_model_name = None | |||
| @chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口") | |||
| async def chat_completions( | |||
| request: Request, | |||
| body: OpenAIChatInput, | |||
| ) -> Dict: | |||
| """ | |||
| 请求参数与 openai.chat.completions.create 一致,可以通过 extra_body 传入额外参数 | |||
| tools 和 tool_choice 可以直接传工具名称,会根据项目里包含的 tools 进行转换 | |||
| 通过不同的参数组合调用不同的 chat 功能: | |||
| - tool_choice | |||
| - extra_body 中包含 tool_input: 直接调用 tool_choice(tool_input) | |||
| - extra_body 中不包含 tool_input: 通过 agent 调用 tool_choice | |||
| - tools: agent 对话 | |||
| - 其它:LLM 对话 | |||
| 返回与 openai 兼容的 Dict | |||
| """ | |||
| # print(body) | |||
| client = get_OpenAIClient(model_name=body.model, is_async=True) | |||
| extra = {**body.model_extra} or {} | |||
| for key in list(extra): | |||
| delattr(body, key) | |||
| global global_model_name | |||
| global_model_name = body.model | |||
| # check tools & tool_choice in request body | |||
| if isinstance(body.tool_choice, str): | |||
| if t := get_tool(body.tool_choice): | |||
| body.tool_choice = {"function": {"name": t.name}, "type": "function"} | |||
| if isinstance(body.tools, list): | |||
| for i in range(len(body.tools)): | |||
| if isinstance(body.tools[i], str): | |||
| if t := get_tool(body.tools[i]): | |||
| body.tools[i] = { | |||
| "type": "function", | |||
| "function": { | |||
| "name": t.name, | |||
| "description": t.description, | |||
| "parameters": t.args, | |||
| }, | |||
| } | |||
| conversation_id = extra.get("conversation_id") | |||
| # chat based on result from one choiced tool | |||
| if body.tool_choice: | |||
| tool = get_tool(body.tool_choice["function"]["name"]) | |||
| if not body.tools: | |||
| body.tools = [ | |||
| { | |||
| "type": "function", | |||
| "function": { | |||
| "name": tool.name, | |||
| "description": tool.description, | |||
| "parameters": tool.args, | |||
| }, | |||
| } | |||
| ] | |||
| if tool_input := extra.get("tool_input"): | |||
| # message_id = ( | |||
| # add_message_to_db( | |||
| # chat_type="tool_call", | |||
| # query=body.messages[-1]["content"], | |||
| # conversation_id=conversation_id, | |||
| # ) | |||
| # if conversation_id | |||
| # else None | |||
| # ) | |||
| tool_result = await tool.ainvoke(tool_input) | |||
| prompt_template = PromptTemplate.from_template( | |||
| get_prompt_template("llm_model", "rag"), template_format="jinja2" | |||
| ) | |||
| body.messages[-1]["content"] = prompt_template.format( | |||
| context=tool_result, question=body.messages[-1]["content"] | |||
| ) | |||
| del body.tools | |||
| del body.tool_choice | |||
| extra_json = { | |||
| # "message_id": message_id, | |||
| "status": None, | |||
| } | |||
| header = [ | |||
| { | |||
| **extra_json, | |||
| "content": f"{tool_result}", | |||
| "tool_output": tool_result.data, | |||
| "is_ref": True, | |||
| } | |||
| ] | |||
| return await openai_request( | |||
| client.chat.completions.create, | |||
| body, | |||
| extra_json=extra_json, | |||
| header=header, | |||
| ) | |||
| # agent chat with tool calls | |||
| if body.tools: | |||
| # message_id = ( | |||
| # add_message_to_db( | |||
| # chat_type="agent_chat", | |||
| # query=body.messages[-1]["content"], | |||
| # conversation_id=conversation_id, | |||
| # ) | |||
| # if conversation_id | |||
| # else None | |||
| # ) | |||
| chat_model_config = {} | |||
| tool_names = [x["function"]["name"] for x in body.tools] | |||
| tool_config = {name: get_tool_config(name) for name in tool_names} | |||
| # print(tool_config) | |||
| result = await chat( | |||
| query=body.messages[-1]["content"], | |||
| # query="搜索互联网,给出2024年7月1日是中国的什么节日", | |||
| metadata=extra.get("metadata", {}), | |||
| conversation_id=extra.get("conversation_id", ""), | |||
| # message_id=message_id, | |||
| history_len=-1, | |||
| history=body.messages[:-1], | |||
| stream=body.stream, | |||
| chat_model_config=extra.get("chat_model_config", chat_model_config), | |||
| tool_config=extra.get("tool_config", tool_config), | |||
| ) | |||
| return result | |||
| else: # LLM chat directly | |||
| # message_id = ( | |||
| # add_message_to_db( | |||
| # chat_type="llm_chat", | |||
| # query=body.messages[-1]["content"], | |||
| # conversation_id=conversation_id, | |||
| # ) | |||
| # if conversation_id | |||
| # else None | |||
| # ) | |||
| extra_json = { | |||
| # "message_id": message_id, | |||
| "status": None, | |||
| } | |||
| return await openai_request( | |||
| client.chat.completions.create, body, extra_json=extra_json | |||
| ) | |||
| @@ -0,0 +1,320 @@ | |||
| from __future__ import annotations | |||
| import asyncio | |||
| import base64 | |||
| import logging | |||
| import os | |||
| import shutil | |||
| from contextlib import asynccontextmanager | |||
| from datetime import datetime | |||
| from pathlib import Path | |||
| from typing import AsyncGenerator, Dict, Iterable, Tuple | |||
| from fastapi import APIRouter, Request | |||
| from fastapi.responses import FileResponse | |||
| from openai import AsyncClient | |||
| from openai.types.file_object import FileObject | |||
| from sse_starlette.sse import EventSourceResponse | |||
| # from chatchat.configs import BASE_TEMP_DIR, log_verbose | |||
| from ..utils import get_OpenAIClient | |||
| from .api_schemas import * | |||
| logger = logging.getLogger() | |||
| DEFAULT_API_CONCURRENCIES = 5 # 默认单个模型最大并发数 | |||
| model_semaphores: Dict[ | |||
| Tuple[str, str], asyncio.Semaphore | |||
| ] = {} # key: (model_name, platform) | |||
| openai_router = APIRouter(prefix="/v1", tags=["OpenAI 兼容平台整合接口"]) | |||
| # @asynccontextmanager | |||
| # async def get_model_client(model_name: str) -> AsyncGenerator[AsyncClient]: | |||
| # """ | |||
| # 对重名模型进行调度,依次选择:空闲的模型 -> 当前访问数最少的模型 | |||
| # """ | |||
| # max_semaphore = 0 | |||
| # selected_platform = "" | |||
| # model_infos = get_model_info(model_name=model_name, multiple=True) | |||
| # for m, c in model_infos.items(): | |||
| # key = (m, c["platform_name"]) | |||
| # api_concurrencies = c.get("api_concurrencies", DEFAULT_API_CONCURRENCIES) | |||
| # if key not in model_semaphores: | |||
| # model_semaphores[key] = asyncio.Semaphore(api_concurrencies) | |||
| # semaphore = model_semaphores[key] | |||
| # if semaphore._value >= api_concurrencies: | |||
| # selected_platform = c["platform_name"] | |||
| # break | |||
| # elif semaphore._value > max_semaphore: | |||
| # selected_platform = c["platform_name"] | |||
| # | |||
| # key = (m, selected_platform) | |||
| # semaphore = model_semaphores[key] | |||
| # try: | |||
| # await semaphore.acquire() | |||
| # yield get_OpenAIClient(platform_name=selected_platform, is_async=True) | |||
| # except Exception: | |||
| # logger.error(f"failed when request to {key}", exc_info=True) | |||
| # finally: | |||
| # semaphore.release() | |||
| async def openai_request( | |||
| method, body, extra_json: Dict = {}, header: Iterable = [], tail: Iterable = [] | |||
| ): | |||
| """ | |||
| helper function to make openai request with extra fields | |||
| """ | |||
| async def generator(): | |||
| for x in header: | |||
| if isinstance(x, str): | |||
| x = OpenAIChatOutput(content=x, object="chat.completion.chunk") | |||
| elif isinstance(x, dict): | |||
| x = OpenAIChatOutput.model_validate(x) | |||
| else: | |||
| raise RuntimeError(f"unsupported value: {header}") | |||
| for k, v in extra_json.items(): | |||
| setattr(x, k, v) | |||
| yield x.model_dump_json() | |||
| async for chunk in await method(**params): | |||
| for k, v in extra_json.items(): | |||
| setattr(chunk, k, v) | |||
| yield chunk.model_dump_json() | |||
| for x in tail: | |||
| if isinstance(x, str): | |||
| x = OpenAIChatOutput(content=x, object="chat.completion.chunk") | |||
| elif isinstance(x, dict): | |||
| x = OpenAIChatOutput.model_validate(x) | |||
| else: | |||
| raise RuntimeError(f"unsupported value: {tail}") | |||
| for k, v in extra_json.items(): | |||
| setattr(x, k, v) | |||
| yield x.model_dump_json() | |||
| params = body.model_dump(exclude_unset=True) | |||
| if hasattr(body, "stream") and body.stream: | |||
| return EventSourceResponse(generator()) | |||
| else: | |||
| result = await method(**params) | |||
| for k, v in extra_json.items(): | |||
| setattr(result, k, v) | |||
| return result.model_dump() | |||
| # @openai_router.get("/models") | |||
| # async def list_models() -> Dict: | |||
| # """ | |||
| # 整合所有平台的模型列表。 | |||
| # """ | |||
| # | |||
| # async def task(name: str, config: Dict): | |||
| # try: | |||
| # client = get_OpenAIClient(name, is_async=True) | |||
| # models = await client.models.list() | |||
| # return [{**x.model_dump(), "platform_name": name} for x in models.data] | |||
| # except Exception: | |||
| # logger.error(f"failed request to platform: {name}", exc_info=True) | |||
| # return [] | |||
| # | |||
| # result = [] | |||
| # tasks = [ | |||
| # asyncio.create_task(task(name, config)) | |||
| # for name, config in get_config_platforms().items() | |||
| # ] | |||
| # for t in asyncio.as_completed(tasks): | |||
| # result += await t | |||
| # | |||
| # return {"object": "list", "data": result} | |||
| # | |||
| # | |||
| # @openai_router.post("/chat/completions") | |||
| # async def create_chat_completions( | |||
| # request: Request, | |||
| # body: OpenAIChatInput, | |||
| # ): | |||
| # if log_verbose: | |||
| # print(body) | |||
| # async with get_model_client(body.model) as client: | |||
| # result = await openai_request(client.chat.completions.create, body) | |||
| # return result | |||
| # | |||
| # | |||
| # @openai_router.post("/completions") | |||
| # async def create_completions( | |||
| # request: Request, | |||
| # body: OpenAIChatInput, | |||
| # ): | |||
| # async with get_model_client(body.model) as client: | |||
| # return await openai_request(client.completions.create, body) | |||
| # | |||
| # | |||
| # @openai_router.post("/embeddings") | |||
| # async def create_embeddings( | |||
| # request: Request, | |||
| # body: OpenAIEmbeddingsInput, | |||
| # ): | |||
| # params = body.model_dump(exclude_unset=True) | |||
| # client = get_OpenAIClient(model_name=body.model) | |||
| # return (await client.embeddings.create(**params)).model_dump() | |||
| # | |||
| # | |||
| # @openai_router.post("/images/generations") | |||
| # async def create_image_generations( | |||
| # request: Request, | |||
| # body: OpenAIImageGenerationsInput, | |||
| # ): | |||
| # async with get_model_client(body.model) as client: | |||
| # return await openai_request(client.images.generate, body) | |||
| # | |||
| # | |||
| # @openai_router.post("/images/variations") | |||
| # async def create_image_variations( | |||
| # request: Request, | |||
| # body: OpenAIImageVariationsInput, | |||
| # ): | |||
| # async with get_model_client(body.model) as client: | |||
| # return await openai_request(client.images.create_variation, body) | |||
| # | |||
| # | |||
| # @openai_router.post("/images/edit") | |||
| # async def create_image_edit( | |||
| # request: Request, | |||
| # body: OpenAIImageEditsInput, | |||
| # ): | |||
| # async with get_model_client(body.model) as client: | |||
| # return await openai_request(client.images.edit, body) | |||
| # | |||
| # | |||
| # @openai_router.post("/audio/translations", deprecated="暂不支持") | |||
| # async def create_audio_translations( | |||
| # request: Request, | |||
| # body: OpenAIAudioTranslationsInput, | |||
| # ): | |||
| # async with get_model_client(body.model) as client: | |||
| # return await openai_request(client.audio.translations.create, body) | |||
| # | |||
| # | |||
| # @openai_router.post("/audio/transcriptions", deprecated="暂不支持") | |||
| # async def create_audio_transcriptions( | |||
| # request: Request, | |||
| # body: OpenAIAudioTranscriptionsInput, | |||
| # ): | |||
| # async with get_model_client(body.model) as client: | |||
| # return await openai_request(client.audio.transcriptions.create, body) | |||
| # | |||
| # | |||
| # @openai_router.post("/audio/speech", deprecated="暂不支持") | |||
| # async def create_audio_speech( | |||
| # request: Request, | |||
| # body: OpenAIAudioSpeechInput, | |||
| # ): | |||
| # async with get_model_client(body.model) as client: | |||
| # return await openai_request(client.audio.speech.create, body) | |||
| # | |||
| # | |||
| # def _get_file_id( | |||
| # purpose: str, | |||
| # created_at: int, | |||
| # filename: str, | |||
| # ) -> str: | |||
| # today = datetime.fromtimestamp(created_at).strftime("%Y-%m-%d") | |||
| # return base64.urlsafe_b64encode(f"{purpose}/{today}/{filename}".encode()).decode() | |||
| # | |||
| # | |||
| # def _get_file_info(file_id: str) -> Dict: | |||
| # splits = base64.urlsafe_b64decode(file_id).decode().split("/") | |||
| # created_at = -1 | |||
| # size = -1 | |||
| # file_path = _get_file_path(file_id) | |||
| # if os.path.isfile(file_path): | |||
| # created_at = int(os.path.getmtime(file_path)) | |||
| # size = os.path.getsize(file_path) | |||
| # | |||
| # return { | |||
| # "purpose": splits[0], | |||
| # "created_at": created_at, | |||
| # "filename": splits[2], | |||
| # "bytes": size, | |||
| # } | |||
| # | |||
| # | |||
| # def _get_file_path(file_id: str) -> str: | |||
| # file_id = base64.urlsafe_b64decode(file_id).decode() | |||
| # return os.path.join(BASE_TEMP_DIR, "openai_files", file_id) | |||
| # | |||
| # | |||
| # @openai_router.post("/files") | |||
| # async def files( | |||
| # request: Request, | |||
| # file: UploadFile, | |||
| # purpose: str = "assistants", | |||
| # ) -> Dict: | |||
| # created_at = int(datetime.now().timestamp()) | |||
| # file_id = _get_file_id( | |||
| # purpose=purpose, created_at=created_at, filename=file.filename | |||
| # ) | |||
| # file_path = _get_file_path(file_id) | |||
| # file_dir = os.path.dirname(file_path) | |||
| # os.makedirs(file_dir, exist_ok=True) | |||
| # with open(file_path, "wb") as fp: | |||
| # shutil.copyfileobj(file.file, fp) | |||
| # file.file.close() | |||
| # | |||
| # return dict( | |||
| # id=file_id, | |||
| # filename=file.filename, | |||
| # bytes=file.size, | |||
| # created_at=created_at, | |||
| # object="file", | |||
| # purpose=purpose, | |||
| # ) | |||
| # | |||
| # | |||
| # @openai_router.get("/files") | |||
| # def list_files(purpose: str) -> Dict[str, List[Dict]]: | |||
| # file_ids = [] | |||
| # root_path = Path(BASE_TEMP_DIR) / "openai_files" / purpose | |||
| # for dir, sub_dirs, files in os.walk(root_path): | |||
| # dir = Path(dir).relative_to(root_path).as_posix() | |||
| # for file in files: | |||
| # file_id = base64.urlsafe_b64encode( | |||
| # f"{purpose}/{dir}/{file}".encode() | |||
| # ).decode() | |||
| # file_ids.append(file_id) | |||
| # return { | |||
| # "data": [{**_get_file_info(x), "id": x, "object": "file"} for x in file_ids] | |||
| # } | |||
| # | |||
| # | |||
| # @openai_router.get("/files/{file_id}") | |||
| # def retrieve_file(file_id: str) -> Dict: | |||
| # file_info = _get_file_info(file_id) | |||
| # return {**file_info, "id": file_id, "object": "file"} | |||
| # | |||
| # | |||
| # @openai_router.get("/files/{file_id}/content") | |||
| # def retrieve_file_content(file_id: str) -> Dict: | |||
| # file_path = _get_file_path(file_id) | |||
| # return FileResponse(file_path) | |||
| # | |||
| # | |||
| # @openai_router.delete("/files/{file_id}") | |||
| # def delete_file(file_id: str) -> Dict: | |||
| # file_path = _get_file_path(file_id) | |||
| # deleted = False | |||
| # | |||
| # try: | |||
| # if os.path.isfile(file_path): | |||
| # os.remove(file_path) | |||
| # deleted = True | |||
| # except: | |||
| # ... | |||
| # | |||
| # return {"id": file_id, "deleted": deleted, "object": "file"} | |||
| @@ -32,6 +32,43 @@ from ..utils import ( | |||
| def create_models_from_config(configs, callbacks, stream): | |||
| configs = { | |||
| # 意图识别不需要输出,模型后台知道就行 | |||
| "preprocess_model": { | |||
| "glm-4-0520": { | |||
| "temperature": 0.05, | |||
| "max_tokens": 4096, | |||
| "history_len": 100, | |||
| "prompt_name": "default", | |||
| "callbacks": False, | |||
| }, | |||
| }, | |||
| "llm_model": { | |||
| "glm-4-0520": { | |||
| "temperature": 0.9, | |||
| "max_tokens": 4096, | |||
| "history_len": 10, | |||
| "prompt_name": "default", | |||
| "callbacks": True, | |||
| }, | |||
| }, | |||
| "action_model": { | |||
| "glm-4-0520": { | |||
| "temperature": 0.01, | |||
| "max_tokens": 4096, | |||
| "prompt_name": "default", | |||
| "callbacks": True, | |||
| }, | |||
| }, | |||
| "postprocess_model": { | |||
| "glm-4-0520": { | |||
| "temperature": 0.01, | |||
| "max_tokens": 4096, | |||
| "prompt_name": "default", | |||
| "callbacks": True, | |||
| } | |||
| } | |||
| } | |||
| models = {} | |||
| prompts = {} | |||
| for model_type, model_configs in configs.items(): | |||
| @@ -94,7 +131,7 @@ def create_models_chains( | |||
| async def chat( | |||
| query: str = Body(..., description="用户输入"), | |||
| query: str = Body(..., description="用户输入", examples=[""]), | |||
| metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]), | |||
| conversation_id: str = Body("", description="对话框ID"), | |||
| message_id: str = Body(None, description="数据库消息ID"), | |||
| @@ -202,7 +239,7 @@ async def chat( | |||
| model=models["llm_model"].model_name, | |||
| status=data["status"], | |||
| message_type=data["message_type"], | |||
| message_id=message_id, | |||
| # message_id=message_id, | |||
| ) | |||
| yield ret.model_dump_json() | |||
| @@ -17,7 +17,7 @@ def weather_check( | |||
| tool_config = { | |||
| "use": False, | |||
| "api_key": "S8vrB4U_-c5mvAMiK", | |||
| "api_key": "SE7CGiRD5dvls08Ub", | |||
| } | |||
| api_key = tool_config.get("api_key") | |||
| url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={city}&language=zh-Hans&unit=c" | |||
| @@ -142,9 +142,10 @@ def get_ChatOpenAI( | |||
| # openai_api_key="", | |||
| # openai_proxy="", | |||
| # ) | |||
| # TODO 配置文件 | |||
| params.update( | |||
| openai_api_base="", | |||
| openai_api_key="", | |||
| openai_api_base="https://open.bigmodel.cn/api/paas/v4/", | |||
| openai_api_key="8424573178d3681bb2e9bfbc5af24dd5.BKKxdk1d6zzgvfnV", | |||
| openai_proxy="", | |||
| ) | |||
| model = ChatOpenAI(**params) | |||
| @@ -170,11 +171,11 @@ def get_prompt_template(type: str, name: str) -> Optional[str]: | |||
| def get_tool(name: str = None) -> Union[BaseTool, Dict[str, BaseTool]]: | |||
| import importlib | |||
| from ..app import tools | |||
| from app import tools | |||
| importlib.reload(tools) | |||
| from ..app.tools import tools_registry | |||
| from app.tools import tools_registry | |||
| # update_search_local_knowledgebase_tool() | |||
| @@ -197,3 +198,131 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event): | |||
| finally: | |||
| # Signal the aiter to stop. | |||
| event.set() | |||
| def get_OpenAIClient( | |||
| platform_name: str = None, | |||
| model_name: str = None, | |||
| is_async: bool = True, | |||
| ) -> Union[openai.Client, openai.AsyncClient]: | |||
| # """ | |||
| # construct an openai Client for specified platform or model | |||
| # """ | |||
| # if platform_name is None: | |||
| # platform_info = get_model_info( | |||
| # model_name=model_name, platform_name=platform_name | |||
| # ) | |||
| # if platform_info is None: | |||
| # raise RuntimeError( | |||
| # f"cannot find configured platform for model: {model_name}" | |||
| # ) | |||
| # platform_name = platform_info.get("platform_name") | |||
| # platform_info = get_config_platforms().get(platform_name) | |||
| # assert platform_info, f"cannot find configured platform: {platform_name}" | |||
| # TODO 配置文件 | |||
| params = { | |||
| "base_url":"https://open.bigmodel.cn/api/paas/v4/", | |||
| "api_key": "8424573178d3681bb2e9bfbc5af24dd5.BKKxdk1d6zzgvfnV" | |||
| } | |||
| httpx_params = {} | |||
| # if api_proxy := platform_info.get("api_proxy"): | |||
| # httpx_params = { | |||
| # "proxies": api_proxy, | |||
| # "transport": httpx.HTTPTransport(local_address="0.0.0.0"), | |||
| # } | |||
| if is_async: | |||
| if httpx_params: | |||
| params["http_client"] = httpx.AsyncClient(**httpx_params) | |||
| return openai.AsyncClient(**params) | |||
| else: | |||
| if httpx_params: | |||
| params["http_client"] = httpx.Client(**httpx_params) | |||
| return openai.Client(**params) | |||
| def get_tool_config(name: str = None) -> Dict: | |||
| import importlib | |||
| # from chatchat.configs import model_config | |||
| # importlib.reload(model_config) | |||
| # from chatchat.configs import TOOL_CONFIG | |||
| TOOL_CONFIG = { | |||
| # "search_local_knowledgebase": { | |||
| # "use": False, | |||
| # "top_k": 3, | |||
| # "score_threshold": 1.0, | |||
| # "conclude_prompt": { | |||
| # "with_result": '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题",' | |||
| # "不允许在答案中添加编造成分,答案请使用中文。 </指令>\n" | |||
| # "<已知信息>{{ context }}</已知信息>\n" | |||
| # "<问题>{{ question }}</问题>\n", | |||
| # "without_result": "请你根据我的提问回答我的问题:\n" | |||
| # "{{ question }}\n" | |||
| # "请注意,你必须在回答结束后强调,你的回答是根据你的经验回答而不是参考资料回答的。\n", | |||
| # }, | |||
| # }, | |||
| "search_internet": { | |||
| "use": False, | |||
| "search_engine_name": "bing", | |||
| "search_engine_config": { | |||
| "bing": { | |||
| "result_len": 3, | |||
| "bing_search_url": "https://api.bing.microsoft.com/v7.0/search", | |||
| "bing_key": "0f42b09dce16474a81c01562ded071dc", | |||
| }, | |||
| "metaphor": { | |||
| "result_len": 3, | |||
| "metaphor_api_key": "", | |||
| "split_result": False, | |||
| "chunk_size": 500, | |||
| "chunk_overlap": 0, | |||
| }, | |||
| "duckduckgo": {"result_len": 3}, | |||
| }, | |||
| "top_k": 10, | |||
| "verbose": "Origin", | |||
| "conclude_prompt": "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 " | |||
| "</指令>\n<已知信息>{{ context }}</已知信息>\n" | |||
| "<问题>\n" | |||
| "{{ question }}\n" | |||
| "</问题>\n", | |||
| }, | |||
| "arxiv": { | |||
| "use": False, | |||
| }, | |||
| "shell": { | |||
| "use": False, | |||
| }, | |||
| "weather_check": { | |||
| "use": False, | |||
| "api_key": "SE7CGiRD5dvls08Ub", | |||
| }, | |||
| # "search_youtube": { | |||
| # "use": False, | |||
| # }, | |||
| "wolfram": { | |||
| "use": False, | |||
| "appid": "PWKVLW-6ETR93QX6Q", | |||
| }, | |||
| "calculate": { | |||
| "use": False, | |||
| }, | |||
| # "vqa_processor": { | |||
| # "use": False, | |||
| # "model_path": "your model path", | |||
| # "tokenizer_path": "your tokenizer path", | |||
| # "device": "cuda:1", | |||
| # }, | |||
| # "aqa_processor": { | |||
| # "use": False, | |||
| # "model_path": "your model path", | |||
| # "tokenizer_path": "yout tokenizer path", | |||
| # "device": "cuda:2", | |||
| # }, | |||
| # "text2images": { | |||
| # "use": False, | |||
| # }, | |||
| } | |||
| if name is None: | |||
| return TOOL_CONFIG | |||
| else: | |||
| return TOOL_CONFIG.get(name, {}) | |||