# from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE
DEFAULT_LLM_MODEL = None  # TODO: read from the configuration file
TEMPERATURE = 0.8


# Fields shared by every OpenAI-compatible request body.
# Unknown fields are accepted (extra = "allow") so callers can forward
# parameters this schema does not declare explicitly.
class OpenAIBaseInput(BaseModel):
    user: Optional[str] = None
    # Use the following arguments if you need to pass additional parameters to
    # the API that aren't available via kwargs. The extra values given here take
    # precedence over values defined on the client or passed to this method.
    extra_headers: Optional[Dict] = None
    extra_query: Optional[Dict] = None
    # Populated from the "extra_body" key of the incoming payload.
    extra_json: Optional[Dict] = Field(None, alias="extra_body")
    timeout: Optional[float] = None

    class Config:
        extra = "allow"


# Request body for chat completions.
# Every field that defaults to None is annotated Optional so that an explicit
# JSON null is accepted instead of failing validation.
class OpenAIChatInput(OpenAIBaseInput):
    messages: List[ChatCompletionMessageParam]
    # NOTE(review): DEFAULT_LLM_MODEL is currently None while the field is
    # typed `str`; callers are expected to always send a model name until the
    # default is wired to configuration.
    model: str = DEFAULT_LLM_MODEL
    frequency_penalty: Optional[float] = None
    function_call: Optional[completion_create_params.FunctionCall] = None
    functions: Optional[List[completion_create_params.Function]] = None
    logit_bias: Optional[Dict[str, int]] = None
    logprobs: Optional[bool] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[float] = None
    response_format: Optional[completion_create_params.ResponseFormat] = None
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stream: Optional[bool] = None
    temperature: Optional[float] = TEMPERATURE
    tool_choice: Optional[Union[ChatCompletionToolChoiceOptionParam, str]] = None
    tools: Optional[List[Union[ChatCompletionToolParam, str]]] = None
    top_logprobs: Optional[int] = None
    top_p: Optional[float] = None


# Request body for the embeddings endpoint.
class OpenAIEmbeddingsInput(OpenAIBaseInput):
    input: Union[str, List[str]]
    model: str
    dimensions: Optional[int] = None
    encoding_format: Optional[Literal["float", "base64"]] = None


# Fields common to all image endpoints.
class OpenAIImageBaseInput(OpenAIBaseInput):
    model: str
    n: int = 1
    response_format: Optional[Literal["url", "b64_json"]] = None
    size: Optional[
        Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]
    ] = "256x256"


# Request body for image generation.
class OpenAIImageGenerationsInput(OpenAIImageBaseInput):
    prompt: str
    quality: Optional[Literal["standard", "hd"]] = None
    style: Optional[Literal["vivid", "natural"]] = None


# Request body for image variations: the source image is an upload or a URL.
class OpenAIImageVariationsInput(OpenAIImageBaseInput):
    image: Union[UploadFile, AnyUrl]


# Request body for image edits: adds a prompt and a mask to the variation input.
class OpenAIImageEditsInput(OpenAIImageVariationsInput):
    prompt: str
    mask: Union[UploadFile, AnyUrl]


# Request body for audio translation.
class OpenAIAudioTranslationsInput(OpenAIBaseInput):
    file: Union[UploadFile, AnyUrl]
    model: str
    prompt: Optional[str] = None
    response_format: Optional[str] = None
    temperature: float = TEMPERATURE


# Request body for audio transcription: translation input plus language hints.
class OpenAIAudioTranscriptionsInput(OpenAIAudioTranslationsInput):
    language: Optional[str] = None
    timestamp_granularities: Optional[List[Literal["word", "segment"]]] = None


# Request body for text-to-speech.
class OpenAIAudioSpeechInput(OpenAIBaseInput):
    input: str
    model: str
    voice: str
    response_format: Optional[
        Literal["mp3", "opus", "aac", "flac", "pcm", "wav"]
    ] = None
    speed: Optional[float] = None
async def openai_request(
    method, body, extra_json: Dict = None, header: Iterable = (), tail: Iterable = ()
):
    """
    Helper function to make an OpenAI request with extra fields.

    Args:
        method: Awaitable client method (e.g. ``client.chat.completions.create``)
            that is called with the dumped fields of ``body`` as keyword arguments.
        body: Request model; only explicitly-set fields are forwarded
            (``model_dump(exclude_unset=True)``).
        extra_json: Attribute name -> value pairs set on every emitted chunk
            (streaming) or on the final result (non-streaming).
        header: Items emitted before the upstream chunks in streaming mode.
            Each item is either a ``str`` (used as chunk content) or a ``dict``
            (validated into an ``OpenAIChatOutput``).
        tail: Same as ``header`` but emitted after the upstream chunks.

    Returns:
        An ``EventSourceResponse`` when ``body.stream`` is truthy, otherwise the
        ``model_dump()`` of the upstream result. ``header`` and ``tail`` are only
        used in streaming mode.

    Raises:
        RuntimeError: while streaming, if a ``header``/``tail`` item is neither
            ``str`` nor ``dict``.
    """
    # Immutable/None defaults replace the shared mutable `{}` / `[]` defaults.
    extra_fields = extra_json or {}

    def _decorate(obj):
        # Attach the caller-supplied extra fields to an outgoing object.
        for key, value in extra_fields.items():
            setattr(obj, key, value)
        return obj

    def _to_chunk(item):
        # Normalize one header/tail item into a decorated OpenAIChatOutput.
        if isinstance(item, str):
            chunk = OpenAIChatOutput(content=item, object="chat.completion.chunk")
        elif isinstance(item, dict):
            chunk = OpenAIChatOutput.model_validate(item)
        else:
            # Report the offending element, not the whole iterable.
            raise RuntimeError(f"unsupported value: {item!r}")
        return _decorate(chunk)

    params = body.model_dump(exclude_unset=True)

    async def generator():
        for item in header:
            yield _to_chunk(item).model_dump_json()

        async for chunk in await method(**params):
            yield _decorate(chunk).model_dump_json()

        for item in tail:
            yield _to_chunk(item).model_dump_json()

    if getattr(body, "stream", False):
        return EventSourceResponse(generator())

    result = await method(**params)
    return _decorate(result).model_dump()
client = get_OpenAIClient(model_name=body.model) +# return (await client.embeddings.create(**params)).model_dump() +# +# +# @openai_router.post("/images/generations") +# async def create_image_generations( +# request: Request, +# body: OpenAIImageGenerationsInput, +# ): +# async with get_model_client(body.model) as client: +# return await openai_request(client.images.generate, body) +# +# +# @openai_router.post("/images/variations") +# async def create_image_variations( +# request: Request, +# body: OpenAIImageVariationsInput, +# ): +# async with get_model_client(body.model) as client: +# return await openai_request(client.images.create_variation, body) +# +# +# @openai_router.post("/images/edit") +# async def create_image_edit( +# request: Request, +# body: OpenAIImageEditsInput, +# ): +# async with get_model_client(body.model) as client: +# return await openai_request(client.images.edit, body) +# +# +# @openai_router.post("/audio/translations", deprecated="暂不支持") +# async def create_audio_translations( +# request: Request, +# body: OpenAIAudioTranslationsInput, +# ): +# async with get_model_client(body.model) as client: +# return await openai_request(client.audio.translations.create, body) +# +# +# @openai_router.post("/audio/transcriptions", deprecated="暂不支持") +# async def create_audio_transcriptions( +# request: Request, +# body: OpenAIAudioTranscriptionsInput, +# ): +# async with get_model_client(body.model) as client: +# return await openai_request(client.audio.transcriptions.create, body) +# +# +# @openai_router.post("/audio/speech", deprecated="暂不支持") +# async def create_audio_speech( +# request: Request, +# body: OpenAIAudioSpeechInput, +# ): +# async with get_model_client(body.model) as client: +# return await openai_request(client.audio.speech.create, body) +# +# +# def _get_file_id( +# purpose: str, +# created_at: int, +# filename: str, +# ) -> str: +# today = datetime.fromtimestamp(created_at).strftime("%Y-%m-%d") +# return 
base64.urlsafe_b64encode(f"{purpose}/{today}/{filename}".encode()).decode() +# +# +# def _get_file_info(file_id: str) -> Dict: +# splits = base64.urlsafe_b64decode(file_id).decode().split("/") +# created_at = -1 +# size = -1 +# file_path = _get_file_path(file_id) +# if os.path.isfile(file_path): +# created_at = int(os.path.getmtime(file_path)) +# size = os.path.getsize(file_path) +# +# return { +# "purpose": splits[0], +# "created_at": created_at, +# "filename": splits[2], +# "bytes": size, +# } +# +# +# def _get_file_path(file_id: str) -> str: +# file_id = base64.urlsafe_b64decode(file_id).decode() +# return os.path.join(BASE_TEMP_DIR, "openai_files", file_id) +# +# +# @openai_router.post("/files") +# async def files( +# request: Request, +# file: UploadFile, +# purpose: str = "assistants", +# ) -> Dict: +# created_at = int(datetime.now().timestamp()) +# file_id = _get_file_id( +# purpose=purpose, created_at=created_at, filename=file.filename +# ) +# file_path = _get_file_path(file_id) +# file_dir = os.path.dirname(file_path) +# os.makedirs(file_dir, exist_ok=True) +# with open(file_path, "wb") as fp: +# shutil.copyfileobj(file.file, fp) +# file.file.close() +# +# return dict( +# id=file_id, +# filename=file.filename, +# bytes=file.size, +# created_at=created_at, +# object="file", +# purpose=purpose, +# ) +# +# +# @openai_router.get("/files") +# def list_files(purpose: str) -> Dict[str, List[Dict]]: +# file_ids = [] +# root_path = Path(BASE_TEMP_DIR) / "openai_files" / purpose +# for dir, sub_dirs, files in os.walk(root_path): +# dir = Path(dir).relative_to(root_path).as_posix() +# for file in files: +# file_id = base64.urlsafe_b64encode( +# f"{purpose}/{dir}/{file}".encode() +# ).decode() +# file_ids.append(file_id) +# return { +# "data": [{**_get_file_info(x), "id": x, "object": "file"} for x in file_ids] +# } +# +# +# @openai_router.get("/files/{file_id}") +# def retrieve_file(file_id: str) -> Dict: +# file_info = _get_file_info(file_id) +# return {**file_info, 
"id": file_id, "object": "file"} +# +# +# @openai_router.get("/files/{file_id}/content") +# def retrieve_file_content(file_id: str) -> Dict: +# file_path = _get_file_path(file_id) +# return FileResponse(file_path) +# +# +# @openai_router.delete("/files/{file_id}") +# def delete_file(file_id: str) -> Dict: +# file_path = _get_file_path(file_id) +# deleted = False +# +# try: +# if os.path.isfile(file_path): +# os.remove(file_path) +# deleted = True +# except: +# ... +# +# return {"id": file_id, "deleted": deleted, "object": "file"} diff --git a/src/mindpilot/app/chat/chat.py b/src/mindpilot/app/chat/chat.py index 5043054..fbc3b53 100644 --- a/src/mindpilot/app/chat/chat.py +++ b/src/mindpilot/app/chat/chat.py @@ -16,14 +16,8 @@ from ..callback_handler.agent_callback_handler import ( AgentStatus, ) from ..chat.utils import History -from ..utils import ( - MsgType, - get_ChatOpenAI, - get_prompt_template, - get_tool, - wrap_done, -) -from app.configs import MODEL_CONFIG,TOOL_CONFIG +from ..configs import MODEL_CONFIG, TOOL_CONFIG +from ..utils.openai_utils import get_ChatOpenAI, get_prompt_template, get_tool, wrap_done, MsgType def create_models_from_config(configs, callbacks, stream): diff --git a/src/mindpilot/app/tools/search_internet.py b/src/mindpilot/app/tools/search_internet.py index 2343df3..1fca235 100644 --- a/src/mindpilot/app/tools/search_internet.py +++ b/src/mindpilot/app/tools/search_internet.py @@ -6,12 +6,13 @@ from langchain_community.utilities import BingSearchAPIWrapper from langchain_community.utilities import DuckDuckGoSearchAPIWrapper from markdownify import markdownify from strsimpy.normalized_levenshtein import NormalizedLevenshtein -from app.utils import get_tool_config + from ..pydantic_v1 import Field # from chatchat.server.utils import get_tool_config from .tools_registry import BaseToolOutput, regist_tool +from ..utils.openai_utils import get_tool_config def bing_search(text, config): diff --git a/src/mindpilot/app/tools/weather_check.py 
b/src/mindpilot/app/tools/weather_check.py index 2fddd66..4f17b2b 100644 --- a/src/mindpilot/app/tools/weather_check.py +++ b/src/mindpilot/app/tools/weather_check.py @@ -6,8 +6,7 @@ import requests from ..pydantic_v1 import Field from .tools_registry import BaseToolOutput, regist_tool - -from app.utils import get_tool_config +from ..utils.openai_utils import get_tool_config @regist_tool(title="天气查询") diff --git a/src/mindpilot/app/tools/wolfram.py b/src/mindpilot/app/tools/wolfram.py index 1627867..3edf5cf 100644 --- a/src/mindpilot/app/tools/wolfram.py +++ b/src/mindpilot/app/tools/wolfram.py @@ -1,8 +1,9 @@ # Langchain 自带的 Wolfram Alpha API 封装 from ..pydantic_v1 import Field -from app.utils import get_tool_config + from .tools_registry import BaseToolOutput, regist_tool +from ..utils.openai_utils import get_tool_config @regist_tool diff --git a/src/mindpilot/app/utils/__init__.py b/src/mindpilot/app/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/mindpilot/app/utils/colorful.py b/src/mindpilot/app/utils/colorful.py new file mode 100644 index 0000000..f0414e5 --- /dev/null +++ b/src/mindpilot/app/utils/colorful.py @@ -0,0 +1,61 @@ +import platform +from sys import stdout + +if platform.system()=="Linux": + pass +else: + from colorama import init + init() + +# Do you like the elegance of Chinese characters? 
def _print_colored(sgr, *kw, **kargs):
    # Print *kw between the ANSI colour sequence "\033[<sgr>m" and the reset
    # sequence; **kargs is passed straight to print() (sep, end, file, flush).
    print(f"\033[{sgr}m", *kw, "\033[0m", **kargs)


def _colored(sgr, *kw):
    # Return the space-joined arguments wrapped in the ANSI colour sequence and
    # the reset sequence. Arguments are converted with str() so non-string
    # values work here just as they do in the print* helpers.
    return f"\033[{sgr}m" + ' '.join(map(str, kw)) + "\033[0m"


# Do you like the elegance of Chinese characters?
# print<colour>: print the arguments in the given colour.
# Normal intensity: red, green, yellow, blue, magenta, cyan.
def print红(*kw, **kargs):
    _print_colored("0;31", *kw, **kargs)


def print绿(*kw, **kargs):
    _print_colored("0;32", *kw, **kargs)


def print黄(*kw, **kargs):
    _print_colored("0;33", *kw, **kargs)


def print蓝(*kw, **kargs):
    _print_colored("0;34", *kw, **kargs)


def print紫(*kw, **kargs):
    _print_colored("0;35", *kw, **kargs)


def print靛(*kw, **kargs):
    _print_colored("0;36", *kw, **kargs)


# Bright (bold) variants of the same six colours.
def print亮红(*kw, **kargs):
    _print_colored("1;31", *kw, **kargs)


def print亮绿(*kw, **kargs):
    _print_colored("1;32", *kw, **kargs)


def print亮黄(*kw, **kargs):
    _print_colored("1;33", *kw, **kargs)


def print亮蓝(*kw, **kargs):
    _print_colored("1;34", *kw, **kargs)


def print亮紫(*kw, **kargs):
    _print_colored("1;35", *kw, **kargs)


def print亮靛(*kw, **kargs):
    _print_colored("1;36", *kw, **kargs)


# Do you like the elegance of Chinese characters?
# sprint<colour>: return the coloured string instead of printing it.
def sprint红(*kw):
    return _colored("0;31", *kw)


def sprint绿(*kw):
    return _colored("0;32", *kw)


def sprint黄(*kw):
    return _colored("0;33", *kw)


def sprint蓝(*kw):
    return _colored("0;34", *kw)


def sprint紫(*kw):
    return _colored("0;35", *kw)


def sprint靛(*kw):
    return _colored("0;36", *kw)


def sprint亮红(*kw):
    return _colored("1;31", *kw)


def sprint亮绿(*kw):
    return _colored("1;32", *kw)


def sprint亮黄(*kw):
    return _colored("1;33", *kw)


def sprint亮蓝(*kw):
    return _colored("1;34", *kw)


def sprint亮紫(*kw):
    return _colored("1;35", *kw)


def sprint亮靛(*kw):
    return _colored("1;36", *kw)
def get_tool_config(name: str = None) -> Dict:
    """
    Return tool configuration from ``TOOL_CONFIG``.

    Args:
        name: Tool name to look up. When ``None`` the whole ``TOOL_CONFIG``
            mapping is returned.

    Returns:
        The full ``TOOL_CONFIG`` mapping, or the entry stored under ``name``
        (an empty dict when there is no such entry).
    """
    # Package-relative import, matching chat.py and the tools modules. An
    # absolute `src.mindpilot.app...` path only resolves when the repository
    # root is on sys.path, and would load a second copy of the configs module
    # next to the `app.configs` one used by main.py.
    # The import stays inside the function, as in the original.
    from ..configs import TOOL_CONFIG

    if name is None:
        return TOOL_CONFIG
    return TOOL_CONFIG.get(name, {})
100644 --- a/src/mindpilot/main.py +++ b/src/mindpilot/main.py @@ -8,7 +8,9 @@ from contextlib import asynccontextmanager from multiprocessing import Process import argparse from fastapi import FastAPI -from app.configs import HOST,PORT +from app.configs import HOST, PORT +from src.mindpilot.app.utils.colorful import print亮蓝 + logger = logging.getLogger() @@ -37,7 +39,7 @@ def run_api_server( ): import uvicorn from app.api.api_server import create_app - from app.utils import set_httpx_config + from src.mindpilot.app.utils.openai_utils import set_httpx_config set_httpx_config() app = create_app(run_mode=run_mode) @@ -115,7 +117,8 @@ def main(): cwd = os.getcwd() sys.path.append(cwd) multiprocessing.freeze_support() - print("cwd:" + cwd) + print亮蓝(f"当前工作目录:{cwd}") + print亮蓝(f"OpenAPI 文档地址:http://{HOST}:{PORT}/docs") if sys.version_info < (3, 10): loop = asyncio.get_event_loop()