Browse Source

Merge remote-tracking branch 'origin/main'

# Conflicts:
#	src/mindpilot/app/api/api_schemas.py
#	src/mindpilot/app/api/openai_routes.py
main
gjl 1 year ago
parent
commit
3c985f3d36
12 changed files with 509 additions and 30 deletions
  1. +0
    -0
      src/__init__.py
  2. +0
    -0
      src/mindpilot/__init__.py
  3. +104
    -4
      src/mindpilot/app/api/api_schemas.py
  4. +320
    -0
      src/mindpilot/app/api/openai_routes.py
  5. +2
    -8
      src/mindpilot/app/chat/chat.py
  6. +2
    -1
      src/mindpilot/app/tools/search_internet.py
  7. +1
    -2
      src/mindpilot/app/tools/weather_check.py
  8. +2
    -1
      src/mindpilot/app/tools/wolfram.py
  9. +0
    -0
      src/mindpilot/app/utils/__init__.py
  10. +61
    -0
      src/mindpilot/app/utils/colorful.py
  11. +11
    -11
      src/mindpilot/app/utils/openai_utils.py
  12. +6
    -3
      src/mindpilot/main.py

+ 0
- 0
src/__init__.py View File


+ 0
- 0
src/mindpilot/__init__.py View File


+ 104
- 4
src/mindpilot/app/api/api_schemas.py View File

@@ -12,8 +12,108 @@ from openai.types.chat import (
completion_create_params,
)

from ..utils.openai_utils import MsgType

# from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE
DEFAULT_LLM_MODEL = None # TODO 配置文件
TEMPERATURE = 0.8
from ..pydantic_v2 import AnyUrl, BaseModel, Field
from ..utils import MsgType



class OpenAIBaseInput(BaseModel):
user: Optional[str] = None
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Optional[Dict] = None
extra_query: Optional[Dict] = None
extra_json: Optional[Dict] = Field(None, alias="extra_body")
timeout: Optional[float] = None

class Config:
extra = "allow"


class OpenAIChatInput(OpenAIBaseInput):
messages: List[ChatCompletionMessageParam]
model: str = DEFAULT_LLM_MODEL
frequency_penalty: Optional[float] = None
function_call: Optional[completion_create_params.FunctionCall] = None
functions: List[completion_create_params.Function] = None
logit_bias: Optional[Dict[str, int]] = None
logprobs: Optional[bool] = None
max_tokens: Optional[int] = None
n: Optional[int] = None
presence_penalty: Optional[float] = None
response_format: completion_create_params.ResponseFormat = None
seed: Optional[int] = None
stop: Union[Optional[str], List[str]] = None
stream: Optional[bool] = None
temperature: Optional[float] = TEMPERATURE
tool_choice: Optional[Union[ChatCompletionToolChoiceOptionParam, str]] = None
tools: List[Union[ChatCompletionToolParam, str]] = None
top_logprobs: Optional[int] = None
top_p: Optional[float] = None


class OpenAIEmbeddingsInput(OpenAIBaseInput):
input: Union[str, List[str]]
model: str
dimensions: Optional[int] = None
encoding_format: Optional[Literal["float", "base64"]] = None


class OpenAIImageBaseInput(OpenAIBaseInput):
model: str
n: int = 1
response_format: Optional[Literal["url", "b64_json"]] = None
size: Optional[
Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]
] = "256x256"


class OpenAIImageGenerationsInput(OpenAIImageBaseInput):
prompt: str
quality: Literal["standard", "hd"] = None
style: Optional[Literal["vivid", "natural"]] = None


class OpenAIImageVariationsInput(OpenAIImageBaseInput):
image: Union[UploadFile, AnyUrl]


class OpenAIImageEditsInput(OpenAIImageVariationsInput):
prompt: str
mask: Union[UploadFile, AnyUrl]


class OpenAIAudioTranslationsInput(OpenAIBaseInput):
file: Union[UploadFile, AnyUrl]
model: str
prompt: Optional[str] = None
response_format: Optional[str] = None
temperature: float = TEMPERATURE


class OpenAIAudioTranscriptionsInput(OpenAIAudioTranslationsInput):
language: Optional[str] = None
timestamp_granularities: Optional[List[Literal["word", "segment"]]] = None


class OpenAIAudioSpeechInput(OpenAIBaseInput):
input: str
model: str
voice: str
response_format: Optional[
Literal["mp3", "opus", "aac", "flac", "pcm", "wav"]
] = None
speed: Optional[float] = None


# class OpenAIFileInput(OpenAIBaseInput):
# file: UploadFile # FileTypes
# purpose: Literal["fine-tune", "assistants"] = "assistants"


class OpenAIBaseOutput(BaseModel):
id: Optional[str] = None
@@ -27,10 +127,10 @@ class OpenAIBaseOutput(BaseModel):
created: int = Field(default_factory=lambda: int(time.time()))
tool_calls: List[Dict] = []

status: Optional[int] = None
status: Optional[int] = None # AgentStatus
message_type: int = MsgType.TEXT
message_id: Optional[str] = None
is_ref: bool = False
message_id: Optional[str] = None # id in database table
is_ref: bool = False # wheather show in seperated expander

class Config:
extra = "allow"


+ 320
- 0
src/mindpilot/app/api/openai_routes.py View File

@@ -0,0 +1,320 @@
from __future__ import annotations

import asyncio
import base64
import logging
import os
import shutil
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
from typing import AsyncGenerator, Dict, Iterable, Tuple

from fastapi import APIRouter, Request
from fastapi.responses import FileResponse
from openai import AsyncClient
from openai.types.file_object import FileObject
from sse_starlette.sse import EventSourceResponse

# from chatchat.configs import BASE_TEMP_DIR, log_verbose


from .api_schemas import *

logger = logging.getLogger()


DEFAULT_API_CONCURRENCIES = 5 # 默认单个模型最大并发数
model_semaphores: Dict[
Tuple[str, str], asyncio.Semaphore
] = {} # key: (model_name, platform)
openai_router = APIRouter(prefix="/v1", tags=["OpenAI 兼容平台整合接口"])


# @asynccontextmanager
# async def get_model_client(model_name: str) -> AsyncGenerator[AsyncClient]:
# """
# 对重名模型进行调度,依次选择:空闲的模型 -> 当前访问数最少的模型
# """
# max_semaphore = 0
# selected_platform = ""
# model_infos = get_model_info(model_name=model_name, multiple=True)
# for m, c in model_infos.items():
# key = (m, c["platform_name"])
# api_concurrencies = c.get("api_concurrencies", DEFAULT_API_CONCURRENCIES)
# if key not in model_semaphores:
# model_semaphores[key] = asyncio.Semaphore(api_concurrencies)
# semaphore = model_semaphores[key]
# if semaphore._value >= api_concurrencies:
# selected_platform = c["platform_name"]
# break
# elif semaphore._value > max_semaphore:
# selected_platform = c["platform_name"]
#
# key = (m, selected_platform)
# semaphore = model_semaphores[key]
# try:
# await semaphore.acquire()
# yield get_OpenAIClient(platform_name=selected_platform, is_async=True)
# except Exception:
# logger.error(f"failed when request to {key}", exc_info=True)
# finally:
# semaphore.release()


async def openai_request(
method, body, extra_json: Dict = {}, header: Iterable = [], tail: Iterable = []
):
"""
helper function to make openai request with extra fields
"""

async def generator():
for x in header:
if isinstance(x, str):
x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
elif isinstance(x, dict):
x = OpenAIChatOutput.model_validate(x)
else:
raise RuntimeError(f"unsupported value: {header}")
for k, v in extra_json.items():
setattr(x, k, v)
yield x.model_dump_json()

async for chunk in await method(**params):
for k, v in extra_json.items():
setattr(chunk, k, v)
yield chunk.model_dump_json()

for x in tail:
if isinstance(x, str):
x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
elif isinstance(x, dict):
x = OpenAIChatOutput.model_validate(x)
else:
raise RuntimeError(f"unsupported value: {tail}")
for k, v in extra_json.items():
setattr(x, k, v)
yield x.model_dump_json()

params = body.model_dump(exclude_unset=True)

if hasattr(body, "stream") and body.stream:
return EventSourceResponse(generator())
else:
result = await method(**params)
for k, v in extra_json.items():
setattr(result, k, v)
return result.model_dump()


# @openai_router.get("/models")
# async def list_models() -> Dict:
# """
# 整合所有平台的模型列表。
# """
#
# async def task(name: str, config: Dict):
# try:
# client = get_OpenAIClient(name, is_async=True)
# models = await client.models.list()
# return [{**x.model_dump(), "platform_name": name} for x in models.data]
# except Exception:
# logger.error(f"failed request to platform: {name}", exc_info=True)
# return []
#
# result = []
# tasks = [
# asyncio.create_task(task(name, config))
# for name, config in get_config_platforms().items()
# ]
# for t in asyncio.as_completed(tasks):
# result += await t
#
# return {"object": "list", "data": result}
#
#
# @openai_router.post("/chat/completions")
# async def create_chat_completions(
# request: Request,
# body: OpenAIChatInput,
# ):
# if log_verbose:
# print(body)
# async with get_model_client(body.model) as client:
# result = await openai_request(client.chat.completions.create, body)
# return result
#
#
# @openai_router.post("/completions")
# async def create_completions(
# request: Request,
# body: OpenAIChatInput,
# ):
# async with get_model_client(body.model) as client:
# return await openai_request(client.completions.create, body)
#
#
# @openai_router.post("/embeddings")
# async def create_embeddings(
# request: Request,
# body: OpenAIEmbeddingsInput,
# ):
# params = body.model_dump(exclude_unset=True)
# client = get_OpenAIClient(model_name=body.model)
# return (await client.embeddings.create(**params)).model_dump()
#
#
# @openai_router.post("/images/generations")
# async def create_image_generations(
# request: Request,
# body: OpenAIImageGenerationsInput,
# ):
# async with get_model_client(body.model) as client:
# return await openai_request(client.images.generate, body)
#
#
# @openai_router.post("/images/variations")
# async def create_image_variations(
# request: Request,
# body: OpenAIImageVariationsInput,
# ):
# async with get_model_client(body.model) as client:
# return await openai_request(client.images.create_variation, body)
#
#
# @openai_router.post("/images/edit")
# async def create_image_edit(
# request: Request,
# body: OpenAIImageEditsInput,
# ):
# async with get_model_client(body.model) as client:
# return await openai_request(client.images.edit, body)
#
#
# @openai_router.post("/audio/translations", deprecated="暂不支持")
# async def create_audio_translations(
# request: Request,
# body: OpenAIAudioTranslationsInput,
# ):
# async with get_model_client(body.model) as client:
# return await openai_request(client.audio.translations.create, body)
#
#
# @openai_router.post("/audio/transcriptions", deprecated="暂不支持")
# async def create_audio_transcriptions(
# request: Request,
# body: OpenAIAudioTranscriptionsInput,
# ):
# async with get_model_client(body.model) as client:
# return await openai_request(client.audio.transcriptions.create, body)
#
#
# @openai_router.post("/audio/speech", deprecated="暂不支持")
# async def create_audio_speech(
# request: Request,
# body: OpenAIAudioSpeechInput,
# ):
# async with get_model_client(body.model) as client:
# return await openai_request(client.audio.speech.create, body)
#
#
# def _get_file_id(
# purpose: str,
# created_at: int,
# filename: str,
# ) -> str:
# today = datetime.fromtimestamp(created_at).strftime("%Y-%m-%d")
# return base64.urlsafe_b64encode(f"{purpose}/{today}/{filename}".encode()).decode()
#
#
# def _get_file_info(file_id: str) -> Dict:
# splits = base64.urlsafe_b64decode(file_id).decode().split("/")
# created_at = -1
# size = -1
# file_path = _get_file_path(file_id)
# if os.path.isfile(file_path):
# created_at = int(os.path.getmtime(file_path))
# size = os.path.getsize(file_path)
#
# return {
# "purpose": splits[0],
# "created_at": created_at,
# "filename": splits[2],
# "bytes": size,
# }
#
#
# def _get_file_path(file_id: str) -> str:
# file_id = base64.urlsafe_b64decode(file_id).decode()
# return os.path.join(BASE_TEMP_DIR, "openai_files", file_id)
#
#
# @openai_router.post("/files")
# async def files(
# request: Request,
# file: UploadFile,
# purpose: str = "assistants",
# ) -> Dict:
# created_at = int(datetime.now().timestamp())
# file_id = _get_file_id(
# purpose=purpose, created_at=created_at, filename=file.filename
# )
# file_path = _get_file_path(file_id)
# file_dir = os.path.dirname(file_path)
# os.makedirs(file_dir, exist_ok=True)
# with open(file_path, "wb") as fp:
# shutil.copyfileobj(file.file, fp)
# file.file.close()
#
# return dict(
# id=file_id,
# filename=file.filename,
# bytes=file.size,
# created_at=created_at,
# object="file",
# purpose=purpose,
# )
#
#
# @openai_router.get("/files")
# def list_files(purpose: str) -> Dict[str, List[Dict]]:
# file_ids = []
# root_path = Path(BASE_TEMP_DIR) / "openai_files" / purpose
# for dir, sub_dirs, files in os.walk(root_path):
# dir = Path(dir).relative_to(root_path).as_posix()
# for file in files:
# file_id = base64.urlsafe_b64encode(
# f"{purpose}/{dir}/{file}".encode()
# ).decode()
# file_ids.append(file_id)
# return {
# "data": [{**_get_file_info(x), "id": x, "object": "file"} for x in file_ids]
# }
#
#
# @openai_router.get("/files/{file_id}")
# def retrieve_file(file_id: str) -> Dict:
# file_info = _get_file_info(file_id)
# return {**file_info, "id": file_id, "object": "file"}
#
#
# @openai_router.get("/files/{file_id}/content")
# def retrieve_file_content(file_id: str) -> Dict:
# file_path = _get_file_path(file_id)
# return FileResponse(file_path)
#
#
# @openai_router.delete("/files/{file_id}")
# def delete_file(file_id: str) -> Dict:
# file_path = _get_file_path(file_id)
# deleted = False
#
# try:
# if os.path.isfile(file_path):
# os.remove(file_path)
# deleted = True
# except:
# ...
#
# return {"id": file_id, "deleted": deleted, "object": "file"}

+ 2
- 8
src/mindpilot/app/chat/chat.py View File

@@ -16,14 +16,8 @@ from ..callback_handler.agent_callback_handler import (
AgentStatus,
)
from ..chat.utils import History
from ..utils import (
MsgType,
get_ChatOpenAI,
get_prompt_template,
get_tool,
wrap_done,
)
from app.configs import MODEL_CONFIG,TOOL_CONFIG
from ..configs import MODEL_CONFIG, TOOL_CONFIG
from ..utils.openai_utils import get_ChatOpenAI, get_prompt_template, get_tool, wrap_done, MsgType


def create_models_from_config(configs, callbacks, stream):


+ 2
- 1
src/mindpilot/app/tools/search_internet.py View File

@@ -6,12 +6,13 @@ from langchain_community.utilities import BingSearchAPIWrapper
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
from markdownify import markdownify
from strsimpy.normalized_levenshtein import NormalizedLevenshtein
from app.utils import get_tool_config

from ..pydantic_v1 import Field
# from chatchat.server.utils import get_tool_config

from .tools_registry import BaseToolOutput, regist_tool
from ..utils.openai_utils import get_tool_config


def bing_search(text, config):


+ 1
- 2
src/mindpilot/app/tools/weather_check.py View File

@@ -6,8 +6,7 @@ import requests
from ..pydantic_v1 import Field

from .tools_registry import BaseToolOutput, regist_tool

from app.utils import get_tool_config
from ..utils.openai_utils import get_tool_config


@regist_tool(title="天气查询")


+ 2
- 1
src/mindpilot/app/tools/wolfram.py View File

@@ -1,8 +1,9 @@
# Langchain 自带的 Wolfram Alpha API 封装

from ..pydantic_v1 import Field
from app.utils import get_tool_config
from .tools_registry import BaseToolOutput, regist_tool
from ..utils.openai_utils import get_tool_config


@regist_tool


+ 0
- 0
src/mindpilot/app/utils/__init__.py View File


+ 61
- 0
src/mindpilot/app/utils/colorful.py View File

@@ -0,0 +1,61 @@
import platform
from sys import stdout

if platform.system()=="Linux":
pass
else:
from colorama import init
init()

# Do you like the elegance of Chinese characters?
def print红(*kw,**kargs):
print("\033[0;31m",*kw,"\033[0m",**kargs)
def print绿(*kw,**kargs):
print("\033[0;32m",*kw,"\033[0m",**kargs)
def print黄(*kw,**kargs):
print("\033[0;33m",*kw,"\033[0m",**kargs)
def print蓝(*kw,**kargs):
print("\033[0;34m",*kw,"\033[0m",**kargs)
def print紫(*kw,**kargs):
print("\033[0;35m",*kw,"\033[0m",**kargs)
def print靛(*kw,**kargs):
print("\033[0;36m",*kw,"\033[0m",**kargs)

def print亮红(*kw,**kargs):
print("\033[1;31m",*kw,"\033[0m",**kargs)
def print亮绿(*kw,**kargs):
print("\033[1;32m",*kw,"\033[0m",**kargs)
def print亮黄(*kw,**kargs):
print("\033[1;33m",*kw,"\033[0m",**kargs)
def print亮蓝(*kw,**kargs):
print("\033[1;34m",*kw,"\033[0m",**kargs)
def print亮紫(*kw,**kargs):
print("\033[1;35m",*kw,"\033[0m",**kargs)
def print亮靛(*kw,**kargs):
print("\033[1;36m",*kw,"\033[0m",**kargs)

# Do you like the elegance of Chinese characters?
def sprint红(*kw):
return "\033[0;31m"+' '.join(kw)+"\033[0m"
def sprint绿(*kw):
return "\033[0;32m"+' '.join(kw)+"\033[0m"
def sprint黄(*kw):
return "\033[0;33m"+' '.join(kw)+"\033[0m"
def sprint蓝(*kw):
return "\033[0;34m"+' '.join(kw)+"\033[0m"
def sprint紫(*kw):
return "\033[0;35m"+' '.join(kw)+"\033[0m"
def sprint靛(*kw):
return "\033[0;36m"+' '.join(kw)+"\033[0m"
def sprint亮红(*kw):
return "\033[1;31m"+' '.join(kw)+"\033[0m"
def sprint亮绿(*kw):
return "\033[1;32m"+' '.join(kw)+"\033[0m"
def sprint亮黄(*kw):
return "\033[1;33m"+' '.join(kw)+"\033[0m"
def sprint亮蓝(*kw):
return "\033[1;34m"+' '.join(kw)+"\033[0m"
def sprint亮紫(*kw):
return "\033[1;35m"+' '.join(kw)+"\033[0m"
def sprint亮靛(*kw):
return "\033[1;36m"+' '.join(kw)+"\033[0m"

src/mindpilot/app/utils.py → src/mindpilot/app/utils/openai_utils.py View File

@@ -148,7 +148,7 @@ def get_prompt_template(type: str, name: str) -> Optional[str]:
type: "llm_chat","knowledge_base_chat","search_engine_chat"的其中一种,如果有新功能,应该进行加入。
"""

from .configs.prompt_config import PROMPT_TEMPLATES
from src.mindpilot.app.configs import PROMPT_TEMPLATES

return PROMPT_TEMPLATES.get(type, {}).get(name)

@@ -156,12 +156,11 @@ def get_prompt_template(type: str, name: str) -> Optional[str]:
def get_tool(name: str = None) -> Union[BaseTool, Dict[str, BaseTool]]:
import importlib

from app import tools
from src.mindpilot.app import tools

importlib.reload(tools)

from app.tools import tools_registry

from src.mindpilot.app.tools import tools_registry

if name is None:
return tools_registry._TOOLS_REGISTRY
@@ -183,10 +182,11 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
# Signal the aiter to stop.
event.set()


def get_OpenAIClient(
platform_name: str = None,
model_name: str = None,
is_async: bool = True,
platform_name: str = None,
model_name: str = None,
is_async: bool = True,
) -> Union[openai.Client, openai.AsyncClient]:
# """
# construct an openai Client for specified platform or model
@@ -204,7 +204,7 @@ def get_OpenAIClient(
# assert platform_info, f"cannot find configured platform: {platform_name}"
# TODO 配置文件
params = {
"base_url":"https://open.bigmodel.cn/api/paas/v4/",
"base_url": "https://open.bigmodel.cn/api/paas/v4/",
"api_key": "8424573178d3681bb2e9bfbc5af24dd5.BKKxdk1d6zzgvfnV"
}
httpx_params = {}
@@ -223,11 +223,11 @@ def get_OpenAIClient(
params["http_client"] = httpx.Client(**httpx_params)
return openai.Client(**params)

def get_tool_config(name: str = None) -> Dict:

from app.configs import TOOL_CONFIG
def get_tool_config(name: str = None) -> Dict:
from src.mindpilot.app.configs import TOOL_CONFIG

if name is None:
return TOOL_CONFIG
else:
return TOOL_CONFIG.get(name, {})
return TOOL_CONFIG.get(name, {})

+ 6
- 3
src/mindpilot/main.py View File

@@ -8,7 +8,9 @@ from contextlib import asynccontextmanager
from multiprocessing import Process
import argparse
from fastapi import FastAPI
from app.configs import HOST,PORT
from app.configs import HOST, PORT
from src.mindpilot.app.utils.colorful import print亮蓝


logger = logging.getLogger()

@@ -37,7 +39,7 @@ def run_api_server(
):
import uvicorn
from app.api.api_server import create_app
from app.utils import set_httpx_config
from src.mindpilot.app.utils.openai_utils import set_httpx_config

set_httpx_config()
app = create_app(run_mode=run_mode)
@@ -115,7 +117,8 @@ def main():
cwd = os.getcwd()
sys.path.append(cwd)
multiprocessing.freeze_support()
print("cwd:" + cwd)
print亮蓝(f"当前工作目录:{cwd}")
print亮蓝(f"OpenAPI 文档地址:http://{HOST}:{PORT}/docs")

if sys.version_info < (3, 10):
loop = asyncio.get_event_loop()


Loading…
Cancel
Save