diff --git a/requirements.txt b/requirements.txt index 875698b..ade43f2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,4 +17,5 @@ numexpr~=2.8.7 langchain_community==0.2.7 markdownify~=0.13.1 strsimpy~=0.2.1 -metaphor_python==0.1.23 \ No newline at end of file +metaphor_python==0.1.23 +langchainhub==0.1.20 \ No newline at end of file diff --git a/src/mindpilot/app/agent/agents_registry.py b/src/mindpilot/app/agent/agents_registry.py index dd15a6b..3a29ff8 100644 --- a/src/mindpilot/app/agent/agents_registry.py +++ b/src/mindpilot/app/agent/agents_registry.py @@ -10,11 +10,11 @@ from langchain_core.tools import BaseTool def agents_registry( - llm: BaseLanguageModel, - tools: Sequence[BaseTool] = [], - callbacks: List[BaseCallbackHandler] = [], - prompt: str = None, - verbose: bool = False, + llm: BaseLanguageModel, + tools: Sequence[BaseTool] = [], + callbacks: List[BaseCallbackHandler] = [], + prompt: str = None, + verbose: bool = False, ): if prompt is not None: prompt = ChatPromptTemplate.from_messages([SystemMessage(content=prompt)]) @@ -23,8 +23,7 @@ def agents_registry( agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt) agent_executor = AgentExecutor( - agent=agent, tools=tools, verbose=verbose, callbacks=callbacks + agent=agent, tools=tools, verbose=verbose, callbacks=callbacks, handle_parsing_errors=True ) return agent_executor - diff --git a/src/mindpilot/app/api/chat_routes.py b/src/mindpilot/app/api/chat_routes.py index 94d9fe5..4a61ae4 100644 --- a/src/mindpilot/app/api/chat_routes.py +++ b/src/mindpilot/app/api/chat_routes.py @@ -5,18 +5,18 @@ from typing import Dict, List from fastapi import APIRouter, Request from langchain.prompts.prompt import PromptTemplate -# from chatchat.server.api_server.api_schemas import AgentStatus, MsgType, OpenAIChatInput -from ..chat import chat +from app.api.api_schemas import MsgType, OpenAIChatInput +from ..chat.chat import chat # from chatchat.server.chat.file_chat import file_chat # from chatchat.server.db.repository import add_message_to_db -# from chatchat.server.utils import ( -# get_OpenAIClient, -# get_prompt_template, -# get_tool, -# get_tool_config, -# ) -# -# from .openai_routes import openai_request +from ..utils import ( + get_OpenAIClient, + get_prompt_template, + get_tool, + get_tool_config, +) + +from .openai_routes import openai_request chat_router = APIRouter(prefix="/chat", tags=["MindPilot对话"]) @@ -24,148 +24,151 @@ chat_router.post( "/chat", summary="与llm模型对话(通过LLMChain)", )(chat) -# -# # 定义全局model信息,用于给Text2Sql中的get_ChatOpenAI提供model_name -# global_model_name = None -# -# -# @chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口") -# async def chat_completions( -# request: Request, -# body: OpenAIChatInput, -# ) -> Dict: -# """ -# 请求参数与 openai.chat.completions.create 一致,可以通过 extra_body 传入额外参数 -# tools 和 tool_choice 可以直接传工具名称,会根据项目里包含的 tools 进行转换 -# 通过不同的参数组合调用不同的 chat 功能: -# - tool_choice -# - extra_body 中包含 tool_input: 直接调用 tool_choice(tool_input) -# - extra_body 中不包含 tool_input: 通过 agent 调用 tool_choice -# - tools: agent 对话 -# - 其它:LLM 对话 -# 以后还要考虑其它的组合(如文件对话) -# 返回与 openai 兼容的 Dict -# """ -# client = get_OpenAIClient(model_name=body.model, is_async=True) -# extra = {**body.model_extra} or {} -# for key in list(extra): -# delattr(body, key) -# -# global global_model_name -# global_model_name = body.model -# # check tools & tool_choice in request body -# if isinstance(body.tool_choice, str): -# if t := get_tool(body.tool_choice): -# body.tool_choice = {"function": {"name": t.name}, "type": "function"} -# if isinstance(body.tools, list): -# for i in range(len(body.tools)): -# if isinstance(body.tools[i], str): -# if t := get_tool(body.tools[i]): -# body.tools[i] = { -# "type": "function", -# "function": { -# "name": t.name, -# "description": t.description, -# "parameters": t.args, -# }, -# } -# -# conversation_id = extra.get("conversation_id") -# -# # chat based on result from one choiced tool -# if body.tool_choice: -# tool = get_tool(body.tool_choice["function"]["name"]) -# if not body.tools: -# body.tools = [ -# { -# "type": "function", -# "function": { -# "name": tool.name, -# "description": tool.description, -# "parameters": tool.args, -# }, -# } -# ] -# if tool_input := extra.get("tool_input"): -# message_id = ( -# add_message_to_db( -# chat_type="tool_call", -# query=body.messages[-1]["content"], -# conversation_id=conversation_id, -# ) -# if conversation_id -# else None -# ) -# -# tool_result = await tool.ainvoke(tool_input) -# prompt_template = PromptTemplate.from_template( -# get_prompt_template("llm_model", "rag"), template_format="jinja2" -# ) -# body.messages[-1]["content"] = prompt_template.format( -# context=tool_result, question=body.messages[-1]["content"] -# ) -# del body.tools -# del body.tool_choice -# extra_json = { -# "message_id": message_id, -# "status": None, -# } -# header = [ -# { -# **extra_json, -# "content": f"{tool_result}", -# "tool_output": tool_result.data, -# "is_ref": True, -# } -# ] -# return await openai_request( -# client.chat.completions.create, -# body, -# extra_json=extra_json, -# header=header, -# ) -# -# # agent chat with tool calls -# if body.tools: -# message_id = ( -# add_message_to_db( -# chat_type="agent_chat", -# query=body.messages[-1]["content"], -# conversation_id=conversation_id, -# ) -# if conversation_id -# else None -# ) -# -# chat_model_config = {} -# tool_names = [x["function"]["name"] for x in body.tools] -# tool_config = {name: get_tool_config(name) for name in tool_names} -# result = await chat( -# query=body.messages[-1]["content"], -# metadata=extra.get("metadata", {}), -# conversation_id=extra.get("conversation_id", ""), -# message_id=message_id, -# history_len=-1, -# history=body.messages[:-1], -# stream=body.stream, -# chat_model_config=extra.get("chat_model_config", chat_model_config), -# tool_config=extra.get("tool_config", tool_config), -# ) -# return result -# else: # LLM chat directly -# message_id = ( -# add_message_to_db( -# chat_type="llm_chat", -# query=body.messages[-1]["content"], -# conversation_id=conversation_id, -# ) -# if conversation_id -# else None -# ) -# extra_json = { -# "message_id": message_id, -# "status": None, -# } -# return await openai_request( -# client.chat.completions.create, body, extra_json=extra_json -# ) + +# 定义全局model信息,用于给Text2Sql中的get_ChatOpenAI提供model_name +global_model_name = None + + +@chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口") +async def chat_completions( + request: Request, + body: OpenAIChatInput, +) -> Dict: + """ + 请求参数与 openai.chat.completions.create 一致,可以通过 extra_body 传入额外参数 + tools 和 tool_choice 可以直接传工具名称,会根据项目里包含的 tools 进行转换 + 通过不同的参数组合调用不同的 chat 功能: + - tool_choice + - extra_body 中包含 tool_input: 直接调用 tool_choice(tool_input) + - extra_body 中不包含 tool_input: 通过 agent 调用 tool_choice + - tools: agent 对话 + - 其它:LLM 对话 + 返回与 openai 兼容的 Dict + """ + # print(body) + client = get_OpenAIClient(model_name=body.model, is_async=True) + extra = {**body.model_extra} or {} + for key in list(extra): + delattr(body, key) + + global global_model_name + global_model_name = body.model + # check tools & tool_choice in request body + if isinstance(body.tool_choice, str): + if t := get_tool(body.tool_choice): + body.tool_choice = {"function": {"name": t.name}, "type": "function"} + if isinstance(body.tools, list): + for i in range(len(body.tools)): + if isinstance(body.tools[i], str): + if t := get_tool(body.tools[i]): + body.tools[i] = { + "type": "function", + "function": { + "name": t.name, + "description": t.description, + "parameters": t.args, + }, + } + + conversation_id = extra.get("conversation_id") + + # chat based on result from one choiced tool + if body.tool_choice: + tool = get_tool(body.tool_choice["function"]["name"]) + if not body.tools: + body.tools = [ + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.args, + }, + } + ] + if tool_input := extra.get("tool_input"): + # message_id = ( + # add_message_to_db( + # chat_type="tool_call", + # query=body.messages[-1]["content"], + # conversation_id=conversation_id, + # ) + # if conversation_id + # else None + # ) + + tool_result = await tool.ainvoke(tool_input) + prompt_template = PromptTemplate.from_template( + get_prompt_template("llm_model", "rag"), template_format="jinja2" + ) + body.messages[-1]["content"] = prompt_template.format( + context=tool_result, question=body.messages[-1]["content"] + ) + del body.tools + del body.tool_choice + extra_json = { + # "message_id": message_id, + "status": None, + } + header = [ + { + **extra_json, + "content": f"{tool_result}", + "tool_output": tool_result.data, + "is_ref": True, + } + ] + return await openai_request( + client.chat.completions.create, + body, + extra_json=extra_json, + header=header, + ) + + # agent chat with tool calls + if body.tools: + # message_id = ( + # add_message_to_db( + # chat_type="agent_chat", + # query=body.messages[-1]["content"], + # conversation_id=conversation_id, + # ) + # if conversation_id + # else None + # ) + + chat_model_config = {} + tool_names = [x["function"]["name"] for x in body.tools] + tool_config = {name: get_tool_config(name) for name in tool_names} + # print(tool_config) + result = await chat( + query=body.messages[-1]["content"], + # query="搜索互联网,给出2024年7月1日是中国的什么节日", + metadata=extra.get("metadata", {}), + conversation_id=extra.get("conversation_id", ""), + # message_id=message_id, + history_len=-1, + history=body.messages[:-1], + stream=body.stream, + chat_model_config=extra.get("chat_model_config", chat_model_config), + tool_config=extra.get("tool_config", tool_config), + ) + return result + else: # LLM chat directly + # message_id = ( + # add_message_to_db( + # chat_type="llm_chat", + # query=body.messages[-1]["content"], + # conversation_id=conversation_id, + # ) + # if conversation_id + # else None + # ) + extra_json = { + # "message_id": message_id, + "status": None, + } + return await openai_request( + client.chat.completions.create, body, extra_json=extra_json + ) + diff --git a/src/mindpilot/app/api/openai_routes.py b/src/mindpilot/app/api/openai_routes.py new file mode 100644 index 0000000..4a7617b --- /dev/null +++ b/src/mindpilot/app/api/openai_routes.py @@ -0,0 +1,320 @@ +from __future__ import annotations + +import asyncio +import base64 +import logging +import os +import shutil +from contextlib import asynccontextmanager +from datetime import datetime +from pathlib import Path +from typing import AsyncGenerator, Dict, Iterable, Tuple + +from fastapi import APIRouter, Request +from fastapi.responses import FileResponse +from openai import AsyncClient +from openai.types.file_object import FileObject +from sse_starlette.sse import EventSourceResponse + +# from chatchat.configs import BASE_TEMP_DIR, log_verbose +from ..utils import get_OpenAIClient + +from .api_schemas import * + +logger = logging.getLogger() + + +DEFAULT_API_CONCURRENCIES = 5 # 默认单个模型最大并发数 +model_semaphores: Dict[ + Tuple[str, str], asyncio.Semaphore +] = {} # key: (model_name, platform) +openai_router = APIRouter(prefix="/v1", tags=["OpenAI 兼容平台整合接口"]) + + +# @asynccontextmanager +# async def get_model_client(model_name: str) -> AsyncGenerator[AsyncClient]: +# """ +# 对重名模型进行调度,依次选择:空闲的模型 -> 当前访问数最少的模型 +# """ +# max_semaphore = 0 +# selected_platform = "" +# model_infos = get_model_info(model_name=model_name, multiple=True) +# for m, c in model_infos.items(): +# key = (m, c["platform_name"]) +# api_concurrencies = c.get("api_concurrencies", DEFAULT_API_CONCURRENCIES) +# if key not in model_semaphores: +# model_semaphores[key] = asyncio.Semaphore(api_concurrencies) +# semaphore = model_semaphores[key] +# if semaphore._value >= api_concurrencies: +# selected_platform = c["platform_name"] +# break +# elif semaphore._value > max_semaphore: +# selected_platform = c["platform_name"] +# +# key = (m, selected_platform) +# semaphore = model_semaphores[key] +# try: +# await semaphore.acquire() +# yield get_OpenAIClient(platform_name=selected_platform, is_async=True) +# except Exception: +# logger.error(f"failed when request to {key}", exc_info=True) +# finally: +# semaphore.release() + + +async def openai_request( + method, body, extra_json: Dict = {}, header: Iterable = [], tail: Iterable = [] +): + """ + helper function to make openai request with extra fields + """ + + async def generator(): + for x in header: + if isinstance(x, str): + x = OpenAIChatOutput(content=x, object="chat.completion.chunk") + elif isinstance(x, dict): + x = OpenAIChatOutput.model_validate(x) + else: + raise RuntimeError(f"unsupported value: {header}") + for k, v in extra_json.items(): + setattr(x, k, v) + yield x.model_dump_json() + + async for chunk in await method(**params): + for k, v in extra_json.items(): + setattr(chunk, k, v) + yield chunk.model_dump_json() + + for x in tail: + if isinstance(x, str): + x = OpenAIChatOutput(content=x, object="chat.completion.chunk") + elif isinstance(x, dict): + x = OpenAIChatOutput.model_validate(x) + else: + raise RuntimeError(f"unsupported value: {tail}") + for k, v in extra_json.items(): + setattr(x, k, v) + yield x.model_dump_json() + + params = body.model_dump(exclude_unset=True) + + if hasattr(body, "stream") and body.stream: + return EventSourceResponse(generator()) + else: + result = await method(**params) + for k, v in extra_json.items(): + setattr(result, k, v) + return result.model_dump() + + +# @openai_router.get("/models") +# async def list_models() -> Dict: +# """ +# 整合所有平台的模型列表。 +# """ +# +# async def task(name: str, config: Dict): +# try: +# client = get_OpenAIClient(name, is_async=True) +# models = await client.models.list() +# return [{**x.model_dump(), "platform_name": name} for x in models.data] +# except Exception: +# logger.error(f"failed request to platform: {name}", exc_info=True) +# return [] +# +# result = [] +# tasks = [ +# asyncio.create_task(task(name, config)) +# for name, config in get_config_platforms().items() +# ] +# for t in asyncio.as_completed(tasks): +# result += await t +# +# return {"object": "list", "data": result} +# +# +# @openai_router.post("/chat/completions") +# async def create_chat_completions( +# request: Request, +# body: OpenAIChatInput, +# ): +# if log_verbose: +# print(body) +# async with get_model_client(body.model) as client: +# result = await openai_request(client.chat.completions.create, body) +# return result +# +# +# @openai_router.post("/completions") +# async def create_completions( +# request: Request, +# body: OpenAIChatInput, +# ): +# async with get_model_client(body.model) as client: +# return await openai_request(client.completions.create, body) +# +# +# @openai_router.post("/embeddings") +# async def create_embeddings( +# request: Request, +# body: OpenAIEmbeddingsInput, +# ): +# params = body.model_dump(exclude_unset=True) +# client = get_OpenAIClient(model_name=body.model) +# return (await client.embeddings.create(**params)).model_dump() +# +# +# @openai_router.post("/images/generations") +# async def create_image_generations( +# request: Request, +# body: OpenAIImageGenerationsInput, +# ): +# async with get_model_client(body.model) as client: +# return await openai_request(client.images.generate, body) +# +# +# @openai_router.post("/images/variations") +# async def create_image_variations( +# request: Request, +# body: OpenAIImageVariationsInput, +# ): +# async with get_model_client(body.model) as client: +# return await openai_request(client.images.create_variation, body) +# +# +# @openai_router.post("/images/edit") +# async def create_image_edit( +# request: Request, +# body: OpenAIImageEditsInput, +# ): +# async with get_model_client(body.model) as client: +# return await openai_request(client.images.edit, body) +# +# +# @openai_router.post("/audio/translations", deprecated="暂不支持") +# async def create_audio_translations( +# request: Request, +# body: OpenAIAudioTranslationsInput, +# ): +# async with get_model_client(body.model) as client: +# return await openai_request(client.audio.translations.create, body) +# +# +# @openai_router.post("/audio/transcriptions", deprecated="暂不支持") +# async def create_audio_transcriptions( +# request: Request, +# body: OpenAIAudioTranscriptionsInput, +# ): +# async with get_model_client(body.model) as client: +# return await openai_request(client.audio.transcriptions.create, body) +# +# +# @openai_router.post("/audio/speech", deprecated="暂不支持") +# async def create_audio_speech( +# request: Request, +# body: OpenAIAudioSpeechInput, +# ): +# async with get_model_client(body.model) as client: +# return await openai_request(client.audio.speech.create, body) +# +# +# def _get_file_id( +# purpose: str, +# created_at: int, +# filename: str, +# ) -> str: +# today = datetime.fromtimestamp(created_at).strftime("%Y-%m-%d") +# return base64.urlsafe_b64encode(f"{purpose}/{today}/{filename}".encode()).decode() +# +# +# def _get_file_info(file_id: str) -> Dict: +# splits = base64.urlsafe_b64decode(file_id).decode().split("/") +# created_at = -1 +# size = -1 +# file_path = _get_file_path(file_id) +# if os.path.isfile(file_path): +# created_at = int(os.path.getmtime(file_path)) +# size = os.path.getsize(file_path) +# +# return { +# "purpose": splits[0], +# "created_at": created_at, +# "filename": splits[2], +# "bytes": size, +# } +# +# +# def _get_file_path(file_id: str) -> str: +# file_id = base64.urlsafe_b64decode(file_id).decode() +# return os.path.join(BASE_TEMP_DIR, "openai_files", file_id) +# +# +# @openai_router.post("/files") +# async def files( +# request: Request, +# file: UploadFile, +# purpose: str = "assistants", +# ) -> Dict: +# created_at = int(datetime.now().timestamp()) +# file_id = _get_file_id( +# purpose=purpose, created_at=created_at, filename=file.filename +# ) +# file_path = _get_file_path(file_id) +# file_dir = os.path.dirname(file_path) +# os.makedirs(file_dir, exist_ok=True) +# with open(file_path, "wb") as fp: +# shutil.copyfileobj(file.file, fp) +# file.file.close() +# +# return dict( +# id=file_id, +# filename=file.filename, +# bytes=file.size, +# created_at=created_at, +# object="file", +# purpose=purpose, +# ) +# +# +# @openai_router.get("/files") +# def list_files(purpose: str) -> Dict[str, List[Dict]]: +# file_ids = [] +# root_path = Path(BASE_TEMP_DIR) / "openai_files" / purpose +# for dir, sub_dirs, files in os.walk(root_path): +# dir = Path(dir).relative_to(root_path).as_posix() +# for file in files: +# file_id = base64.urlsafe_b64encode( +# f"{purpose}/{dir}/{file}".encode() +# ).decode() +# file_ids.append(file_id) +# return { +# "data": [{**_get_file_info(x), "id": x, "object": "file"} for x in file_ids] +# } +# +# +# @openai_router.get("/files/{file_id}") +# def retrieve_file(file_id: str) -> Dict: +# file_info = _get_file_info(file_id) +# return {**file_info, "id": file_id, "object": "file"} +# +# +# @openai_router.get("/files/{file_id}/content") +# def retrieve_file_content(file_id: str) -> Dict: +# file_path = _get_file_path(file_id) +# return FileResponse(file_path) +# +# +# @openai_router.delete("/files/{file_id}") +# def delete_file(file_id: str) -> Dict: +# file_path = _get_file_path(file_id) +# deleted = False +# +# try: +# if os.path.isfile(file_path): +# os.remove(file_path) +# deleted = True +# except: +# ... +# +# return {"id": file_id, "deleted": deleted, "object": "file"} diff --git a/src/mindpilot/app/chat/chat.py b/src/mindpilot/app/chat/chat.py index 046e2da..fc010af 100644 --- a/src/mindpilot/app/chat/chat.py +++ b/src/mindpilot/app/chat/chat.py @@ -32,6 +32,43 @@ from ..utils import ( def create_models_from_config(configs, callbacks, stream): + configs = { + # 意图识别不需要输出,模型后台知道就行 + "preprocess_model": { + "glm-4-0520": { + "temperature": 0.05, + "max_tokens": 4096, + "history_len": 100, + "prompt_name": "default", + "callbacks": False, + }, + }, + "llm_model": { + "glm-4-0520": { + "temperature": 0.9, + "max_tokens": 4096, + "history_len": 10, + "prompt_name": "default", + "callbacks": True, + }, + }, + "action_model": { + "glm-4-0520": { + "temperature": 0.01, + "max_tokens": 4096, + "prompt_name": "default", + "callbacks": True, + }, + }, + "postprocess_model": { + "glm-4-0520": { + "temperature": 0.01, + "max_tokens": 4096, + "prompt_name": "default", + "callbacks": True, + } + } + } models = {} prompts = {} for model_type, model_configs in configs.items(): @@ -94,7 +131,7 @@ def create_models_chains( async def chat( - query: str = Body(..., description="用户输入"), + query: str = Body(..., description="用户输入", examples=[""]), metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]), conversation_id: str = Body("", description="对话框ID"), message_id: str = Body(None, description="数据库消息ID"), @@ -202,7 +239,7 @@ async def chat( model=models["llm_model"].model_name, status=data["status"], message_type=data["message_type"], - message_id=message_id, + # message_id=message_id, ) yield ret.model_dump_json() diff --git a/src/mindpilot/app/tools/weather_check.py b/src/mindpilot/app/tools/weather_check.py index b98ca46..b711c2b 100644 --- a/src/mindpilot/app/tools/weather_check.py +++ b/src/mindpilot/app/tools/weather_check.py @@ -17,7 +17,7 @@ def weather_check( tool_config = { "use": False, - "api_key": "S8vrB4U_-c5mvAMiK", + "api_key": "SE7CGiRD5dvls08Ub", } api_key = tool_config.get("api_key") url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={city}&language=zh-Hans&unit=c" diff --git a/src/mindpilot/app/utils.py b/src/mindpilot/app/utils.py index 70be1b2..0a6cc47 100644 --- a/src/mindpilot/app/utils.py +++ b/src/mindpilot/app/utils.py @@ -142,9 +142,10 @@ def get_ChatOpenAI( # openai_api_key="", # openai_proxy="", # ) + # TODO 配置文件 params.update( - openai_api_base="", - openai_api_key="", + openai_api_base="https://open.bigmodel.cn/api/paas/v4/", + openai_api_key="8424573178d3681bb2e9bfbc5af24dd5.BKKxdk1d6zzgvfnV", openai_proxy="", ) model = ChatOpenAI(**params) @@ -170,11 +171,11 @@ def get_prompt_template(type: str, name: str) -> Optional[str]: def get_tool(name: str = None) -> Union[BaseTool, Dict[str, BaseTool]]: import importlib - from ..app import tools + from app import tools importlib.reload(tools) - from ..app.tools import tools_registry + from app.tools import tools_registry # update_search_local_knowledgebase_tool() @@ -197,3 +198,131 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event): finally: # Signal the aiter to stop. event.set() + +def get_OpenAIClient( + platform_name: str = None, + model_name: str = None, + is_async: bool = True, +) -> Union[openai.Client, openai.AsyncClient]: + # """ + # construct an openai Client for specified platform or model + # """ + # if platform_name is None: + # platform_info = get_model_info( + # model_name=model_name, platform_name=platform_name + # ) + # if platform_info is None: + # raise RuntimeError( + # f"cannot find configured platform for model: {model_name}" + # ) + # platform_name = platform_info.get("platform_name") + # platform_info = get_config_platforms().get(platform_name) + # assert platform_info, f"cannot find configured platform: {platform_name}" + # TODO 配置文件 + params = { + "base_url":"https://open.bigmodel.cn/api/paas/v4/", + "api_key": "8424573178d3681bb2e9bfbc5af24dd5.BKKxdk1d6zzgvfnV" + } + httpx_params = {} + # if api_proxy := platform_info.get("api_proxy"): + # httpx_params = { + # "proxies": api_proxy, + # "transport": httpx.HTTPTransport(local_address="0.0.0.0"), + # } + + if is_async: + if httpx_params: + params["http_client"] = httpx.AsyncClient(**httpx_params) + return openai.AsyncClient(**params) + else: + if httpx_params: + params["http_client"] = httpx.Client(**httpx_params) + return openai.Client(**params) + +def get_tool_config(name: str = None) -> Dict: + import importlib + + # from chatchat.configs import model_config + # importlib.reload(model_config) + # from chatchat.configs import TOOL_CONFIG + TOOL_CONFIG = { + # "search_local_knowledgebase": { + # "use": False, + # "top_k": 3, + # "score_threshold": 1.0, + # "conclude_prompt": { + # "with_result": '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题",' + # "不允许在答案中添加编造成分,答案请使用中文。 \n" + # "<已知信息>{{ context }}\n" + # "<问题>{{ question }}\n", + # "without_result": "请你根据我的提问回答我的问题:\n" + # "{{ question }}\n" + # "请注意,你必须在回答结束后强调,你的回答是根据你的经验回答而不是参考资料回答的。\n", + # }, + # }, + "search_internet": { + "use": False, + "search_engine_name": "bing", + "search_engine_config": { + "bing": { + "result_len": 3, + "bing_search_url": "https://api.bing.microsoft.com/v7.0/search", + "bing_key": "0f42b09dce16474a81c01562ded071dc", + }, + "metaphor": { + "result_len": 3, + "metaphor_api_key": "", + "split_result": False, + "chunk_size": 500, + "chunk_overlap": 0, + }, + "duckduckgo": {"result_len": 3}, + }, + "top_k": 10, + "verbose": "Origin", + "conclude_prompt": "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 " + "\n<已知信息>{{ context }}\n" + "<问题>\n" + "{{ question }}\n" + "\n", + }, + "arxiv": { + "use": False, + }, + "shell": { + "use": False, + }, + "weather_check": { + "use": False, + "api_key": "SE7CGiRD5dvls08Ub", + }, + # "search_youtube": { + # "use": False, + # }, + "wolfram": { + "use": False, + "appid": "PWKVLW-6ETR93QX6Q", + }, + "calculate": { + "use": False, + }, + # "vqa_processor": { + # "use": False, + # "model_path": "your model path", + # "tokenizer_path": "your tokenizer path", + # "device": "cuda:1", + # }, + # "aqa_processor": { + # "use": False, + # "model_path": "your model path", + # "tokenizer_path": "yout tokenizer path", + # "device": "cuda:2", + # }, + # "text2images": { + # "use": False, + # }, + } + if name is None: + return TOOL_CONFIG + else: + return TOOL_CONFIG.get(name, {}) \ No newline at end of file