V0.1.0

main
gjl 1 year ago
commit f8b3d6120b
6 changed files with 5 additions and 461 deletions

1. +3 -101  src/mindpilot/app/api/api_schemas.py
2. +0 -7    src/mindpilot/app/api/api_server.py
3. +0 -320  src/mindpilot/app/api/openai_routes.py
4. +0 -33   src/mindpilot/app/callback_handler/conversation_callback_handler.py
5. +1 -0    src/mindpilot/app/chat/chat.py
6. +1 -0    src/mindpilot/app/tools/search_internet.py

+3 -101  src/mindpilot/app/api/api_schemas.py

@@ -12,107 +12,9 @@ from openai.types.chat import (
    completion_create_params,
)

# from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE
DEFAULT_LLM_MODEL = None  # TODO: move to the config file
TEMPERATURE = 0.8
from ..pydantic_v2 import AnyUrl, BaseModel, Field
from ..utils import MsgType


class OpenAIBaseInput(BaseModel):
    user: Optional[str] = None
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Optional[Dict] = None
    extra_query: Optional[Dict] = None
    extra_json: Optional[Dict] = Field(None, alias="extra_body")
    timeout: Optional[float] = None

    class Config:
        extra = "allow"


class OpenAIChatInput(OpenAIBaseInput):
    messages: List[ChatCompletionMessageParam]
    model: str = DEFAULT_LLM_MODEL
    frequency_penalty: Optional[float] = None
    function_call: Optional[completion_create_params.FunctionCall] = None
    functions: List[completion_create_params.Function] = None
    logit_bias: Optional[Dict[str, int]] = None
    logprobs: Optional[bool] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[float] = None
    response_format: completion_create_params.ResponseFormat = None
    seed: Optional[int] = None
    stop: Union[Optional[str], List[str]] = None
    stream: Optional[bool] = None
    temperature: Optional[float] = TEMPERATURE
    tool_choice: Optional[Union[ChatCompletionToolChoiceOptionParam, str]] = None
    tools: List[Union[ChatCompletionToolParam, str]] = None
    top_logprobs: Optional[int] = None
    top_p: Optional[float] = None


class OpenAIEmbeddingsInput(OpenAIBaseInput):
    input: Union[str, List[str]]
    model: str
    dimensions: Optional[int] = None
    encoding_format: Optional[Literal["float", "base64"]] = None


class OpenAIImageBaseInput(OpenAIBaseInput):
    model: str
    n: int = 1
    response_format: Optional[Literal["url", "b64_json"]] = None
    size: Optional[
        Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]
    ] = "256x256"


class OpenAIImageGenerationsInput(OpenAIImageBaseInput):
    prompt: str
    quality: Literal["standard", "hd"] = None
    style: Optional[Literal["vivid", "natural"]] = None


class OpenAIImageVariationsInput(OpenAIImageBaseInput):
    image: Union[UploadFile, AnyUrl]


class OpenAIImageEditsInput(OpenAIImageVariationsInput):
    prompt: str
    mask: Union[UploadFile, AnyUrl]


class OpenAIAudioTranslationsInput(OpenAIBaseInput):
    file: Union[UploadFile, AnyUrl]
    model: str
    prompt: Optional[str] = None
    response_format: Optional[str] = None
    temperature: float = TEMPERATURE


class OpenAIAudioTranscriptionsInput(OpenAIAudioTranslationsInput):
    language: Optional[str] = None
    timestamp_granularities: Optional[List[Literal["word", "segment"]]] = None


class OpenAIAudioSpeechInput(OpenAIBaseInput):
    input: str
    model: str
    voice: str
    response_format: Optional[
        Literal["mp3", "opus", "aac", "flac", "pcm", "wav"]
    ] = None
    speed: Optional[float] = None


# class OpenAIFileInput(OpenAIBaseInput):
#     file: UploadFile  # FileTypes
#     purpose: Literal["fine-tune", "assistants"] = "assistants"


class OpenAIBaseOutput(BaseModel):
    id: Optional[str] = None
    content: Optional[str] = None
@@ -125,10 +27,10 @@ class OpenAIBaseOutput(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     tool_calls: List[Dict] = []

-    status: Optional[int] = None  # AgentStatus
+    status: Optional[int] = None
     message_type: int = MsgType.TEXT
-    message_id: Optional[str] = None  # id in database table
-    is_ref: bool = False  # whether to show in a separate expander
+    message_id: Optional[str] = None
+    is_ref: bool = False

     class Config:
         extra = "allow"

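A note on the schema design above: `OpenAIBaseInput` sets `extra = "allow"`, so provider-specific parameters that are not declared on the model survive validation, and `openai_request` (in openai_routes.py below) forwards only caller-set fields via `model_dump(exclude_unset=True)`. A minimal sketch of that round trip in plain pydantic v2 — `PassthroughInput` and `top_k` are illustrative names, not project code:

from typing import Optional

from pydantic import BaseModel, ConfigDict


class PassthroughInput(BaseModel):
    # Illustrative stand-in for OpenAIBaseInput: undeclared fields are kept.
    model_config = ConfigDict(extra="allow")

    user: Optional[str] = None
    timeout: Optional[float] = None


body = PassthroughInput(user="u1", top_k=40)  # top_k is undeclared but retained

# exclude_unset=True drops declared-but-unset fields (timeout here), so only
# what the caller actually sent is forwarded to the backend client.
print(body.model_dump(exclude_unset=True))  # {'user': 'u1', 'top_k': 40}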

+0 -7  src/mindpilot/app/api/api_server.py

@@ -34,11 +34,4 @@ def create_app(run_mode: str = None):
    # app.include_router(openai_router)
    # app.include_router(server_router)

    # # Other endpoints
    # app.post(
    #     "/other/completion",
    #     tags=["Other"],
    #     summary="Ask the LLM to complete a prompt (via LLMChain)",
    # )(completion)

    return app

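The block removed above registered a plain function as a route by calling FastAPI's decorator directly: `app.post(...)(completion)`. For reference, a sketch of the pattern; the `completion` handler here is a hypothetical stand-in for the removed LLMChain-backed one:

from fastapi import FastAPI

app = FastAPI()


async def completion(prompt: str) -> dict:
    # Hypothetical stand-in for the removed LLMChain-backed handler.
    return {"text": f"echo: {prompt}"}


# app.post(...) returns a decorator, so it can be applied to an
# already-defined function instead of using the @app.post syntax.
app.post(
    "/other/completion",
    tags=["Other"],
    summary="Ask the LLM to complete a prompt (via LLMChain)",
)(completion)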
+0 -320  src/mindpilot/app/api/openai_routes.py

@@ -1,320 +0,0 @@
from __future__ import annotations

import asyncio
import base64
import logging
import os
import shutil
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
from typing import AsyncGenerator, Dict, Iterable, Tuple

from fastapi import APIRouter, Request
from fastapi.responses import FileResponse
from openai import AsyncClient
from openai.types.file_object import FileObject
from sse_starlette.sse import EventSourceResponse

# from chatchat.configs import BASE_TEMP_DIR, log_verbose
from ..utils import get_OpenAIClient

from .api_schemas import *

logger = logging.getLogger()


DEFAULT_API_CONCURRENCIES = 5  # default max concurrent requests per model
model_semaphores: Dict[
    Tuple[str, str], asyncio.Semaphore
] = {}  # key: (model_name, platform)
openai_router = APIRouter(prefix="/v1", tags=["OpenAI-compatible platform integration API"])


# @asynccontextmanager
# async def get_model_client(model_name: str) -> AsyncGenerator[AsyncClient]:
#     """
#     Schedule among models sharing a name: prefer an idle model, then the one with the fewest active requests.
#     """
#     max_semaphore = 0
#     selected_platform = ""
#     model_infos = get_model_info(model_name=model_name, multiple=True)
#     for m, c in model_infos.items():
#         key = (m, c["platform_name"])
#         api_concurrencies = c.get("api_concurrencies", DEFAULT_API_CONCURRENCIES)
#         if key not in model_semaphores:
#             model_semaphores[key] = asyncio.Semaphore(api_concurrencies)
#         semaphore = model_semaphores[key]
#         if semaphore._value >= api_concurrencies:
#             selected_platform = c["platform_name"]
#             break
#         elif semaphore._value > max_semaphore:
#             selected_platform = c["platform_name"]
#
#     key = (m, selected_platform)
#     semaphore = model_semaphores[key]
#     try:
#         await semaphore.acquire()
#         yield get_OpenAIClient(platform_name=selected_platform, is_async=True)
#     except Exception:
#         logger.error(f"failed when request to {key}", exc_info=True)
#     finally:
#         semaphore.release()


async def openai_request(
    method, body, extra_json: Dict = {}, header: Iterable = [], tail: Iterable = []
):
    """
    helper function to make openai request with extra fields
    """

    async def generator():
        for x in header:
            if isinstance(x, str):
                x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
            elif isinstance(x, dict):
                x = OpenAIChatOutput.model_validate(x)
            else:
                raise RuntimeError(f"unsupported value: {header}")
            for k, v in extra_json.items():
                setattr(x, k, v)
            yield x.model_dump_json()

        async for chunk in await method(**params):
            for k, v in extra_json.items():
                setattr(chunk, k, v)
            yield chunk.model_dump_json()

        for x in tail:
            if isinstance(x, str):
                x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
            elif isinstance(x, dict):
                x = OpenAIChatOutput.model_validate(x)
            else:
                raise RuntimeError(f"unsupported value: {tail}")
            for k, v in extra_json.items():
                setattr(x, k, v)
            yield x.model_dump_json()

    params = body.model_dump(exclude_unset=True)

    if hasattr(body, "stream") and body.stream:
        return EventSourceResponse(generator())
    else:
        result = await method(**params)
        for k, v in extra_json.items():
            setattr(result, k, v)
        return result.model_dump()


# @openai_router.get("/models")
# async def list_models() -> Dict:
#     """
#     Aggregate the model lists of all platforms.
#     """
#
#     async def task(name: str, config: Dict):
#         try:
#             client = get_OpenAIClient(name, is_async=True)
#             models = await client.models.list()
#             return [{**x.model_dump(), "platform_name": name} for x in models.data]
#         except Exception:
#             logger.error(f"failed request to platform: {name}", exc_info=True)
#             return []
#
#     result = []
#     tasks = [
#         asyncio.create_task(task(name, config))
#         for name, config in get_config_platforms().items()
#     ]
#     for t in asyncio.as_completed(tasks):
#         result += await t
#
#     return {"object": "list", "data": result}
#
#
# @openai_router.post("/chat/completions")
# async def create_chat_completions(
#     request: Request,
#     body: OpenAIChatInput,
# ):
#     if log_verbose:
#         print(body)
#     async with get_model_client(body.model) as client:
#         result = await openai_request(client.chat.completions.create, body)
#     return result
#
#
# @openai_router.post("/completions")
# async def create_completions(
#     request: Request,
#     body: OpenAIChatInput,
# ):
#     async with get_model_client(body.model) as client:
#         return await openai_request(client.completions.create, body)
#
#
# @openai_router.post("/embeddings")
# async def create_embeddings(
#     request: Request,
#     body: OpenAIEmbeddingsInput,
# ):
#     params = body.model_dump(exclude_unset=True)
#     client = get_OpenAIClient(model_name=body.model)
#     return (await client.embeddings.create(**params)).model_dump()
#
#
# @openai_router.post("/images/generations")
# async def create_image_generations(
#     request: Request,
#     body: OpenAIImageGenerationsInput,
# ):
#     async with get_model_client(body.model) as client:
#         return await openai_request(client.images.generate, body)
#
#
# @openai_router.post("/images/variations")
# async def create_image_variations(
#     request: Request,
#     body: OpenAIImageVariationsInput,
# ):
#     async with get_model_client(body.model) as client:
#         return await openai_request(client.images.create_variation, body)
#
#
# @openai_router.post("/images/edit")
# async def create_image_edit(
#     request: Request,
#     body: OpenAIImageEditsInput,
# ):
#     async with get_model_client(body.model) as client:
#         return await openai_request(client.images.edit, body)
#
#
# @openai_router.post("/audio/translations", deprecated="not supported yet")
# async def create_audio_translations(
#     request: Request,
#     body: OpenAIAudioTranslationsInput,
# ):
#     async with get_model_client(body.model) as client:
#         return await openai_request(client.audio.translations.create, body)
#
#
# @openai_router.post("/audio/transcriptions", deprecated="not supported yet")
# async def create_audio_transcriptions(
#     request: Request,
#     body: OpenAIAudioTranscriptionsInput,
# ):
#     async with get_model_client(body.model) as client:
#         return await openai_request(client.audio.transcriptions.create, body)
#
#
# @openai_router.post("/audio/speech", deprecated="not supported yet")
# async def create_audio_speech(
#     request: Request,
#     body: OpenAIAudioSpeechInput,
# ):
#     async with get_model_client(body.model) as client:
#         return await openai_request(client.audio.speech.create, body)
#
#
# def _get_file_id(
#     purpose: str,
#     created_at: int,
#     filename: str,
# ) -> str:
#     today = datetime.fromtimestamp(created_at).strftime("%Y-%m-%d")
#     return base64.urlsafe_b64encode(f"{purpose}/{today}/{filename}".encode()).decode()
#
#
# def _get_file_info(file_id: str) -> Dict:
#     splits = base64.urlsafe_b64decode(file_id).decode().split("/")
#     created_at = -1
#     size = -1
#     file_path = _get_file_path(file_id)
#     if os.path.isfile(file_path):
#         created_at = int(os.path.getmtime(file_path))
#         size = os.path.getsize(file_path)
#
#     return {
#         "purpose": splits[0],
#         "created_at": created_at,
#         "filename": splits[2],
#         "bytes": size,
#     }
#
#
# def _get_file_path(file_id: str) -> str:
#     file_id = base64.urlsafe_b64decode(file_id).decode()
#     return os.path.join(BASE_TEMP_DIR, "openai_files", file_id)
#
#
# @openai_router.post("/files")
# async def files(
# request: Request,
# file: UploadFile,
# purpose: str = "assistants",
# ) -> Dict:
# created_at = int(datetime.now().timestamp())
# file_id = _get_file_id(
# purpose=purpose, created_at=created_at, filename=file.filename
# )
# file_path = _get_file_path(file_id)
# file_dir = os.path.dirname(file_path)
# os.makedirs(file_dir, exist_ok=True)
# with open(file_path, "wb") as fp:
# shutil.copyfileobj(file.file, fp)
# file.file.close()
#
# return dict(
# id=file_id,
# filename=file.filename,
# bytes=file.size,
# created_at=created_at,
# object="file",
# purpose=purpose,
# )
#
#
# @openai_router.get("/files")
# def list_files(purpose: str) -> Dict[str, List[Dict]]:
# file_ids = []
# root_path = Path(BASE_TEMP_DIR) / "openai_files" / purpose
# for dir, sub_dirs, files in os.walk(root_path):
# dir = Path(dir).relative_to(root_path).as_posix()
# for file in files:
# file_id = base64.urlsafe_b64encode(
# f"{purpose}/{dir}/{file}".encode()
# ).decode()
# file_ids.append(file_id)
# return {
# "data": [{**_get_file_info(x), "id": x, "object": "file"} for x in file_ids]
# }
#
#
# @openai_router.get("/files/{file_id}")
# def retrieve_file(file_id: str) -> Dict:
# file_info = _get_file_info(file_id)
# return {**file_info, "id": file_id, "object": "file"}
#
#
# @openai_router.get("/files/{file_id}/content")
# def retrieve_file_content(file_id: str) -> Dict:
# file_path = _get_file_path(file_id)
# return FileResponse(file_path)
#
#
# @openai_router.delete("/files/{file_id}")
# def delete_file(file_id: str) -> Dict:
# file_path = _get_file_path(file_id)
# deleted = False
#
# try:
# if os.path.isfile(file_path):
# os.remove(file_path)
# deleted = True
# except:
# ...
#
# return {"id": file_id, "deleted": deleted, "object": "file"}

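The commented-out `get_model_client` above described its policy as: among platforms serving the same model name, take an idle one if available, otherwise the least-loaded one. A self-contained sketch of that selection with the same semaphore bookkeeping — it yields the chosen key instead of an `AsyncClient`, so it does not depend on `get_model_info` or `get_OpenAIClient`:

import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator, Dict, Tuple

DEFAULT_API_CONCURRENCIES = 5
model_semaphores: Dict[Tuple[str, str], asyncio.Semaphore] = {}


@asynccontextmanager
async def acquire_platform(
    model_name: str, model_infos: Dict[str, Dict]
) -> AsyncIterator[Tuple[str, str]]:
    # Prefer a fully idle platform; otherwise the one with the most free slots.
    best_key, best_free = None, -1
    for m, cfg in model_infos.items():
        key = (m, cfg["platform_name"])
        limit = cfg.get("api_concurrencies", DEFAULT_API_CONCURRENCIES)
        sem = model_semaphores.setdefault(key, asyncio.Semaphore(limit))
        free = sem._value  # free slots; mirrors the removed code's use of _value
        if free >= limit:  # fully idle: take it and stop looking
            best_key = key
            break
        if free > best_free:  # otherwise remember the least-loaded candidate
            best_key, best_free = key, free

    sem = model_semaphores[best_key]
    await sem.acquire()
    try:
        yield best_key  # the caller would build a client for best_key[1] here
    finally:
        sem.release()

One difference worth flagging: the removed loop never updated `max_semaphore`, so its least-loaded branch effectively kept the last platform seen with free capacity; the sketch tracks the running best explicitly.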
+0 -33  src/mindpilot/app/callback_handler/conversation_callback_handler.py

@@ -1,33 +0,0 @@
from typing import Any, Dict, List

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult

# from chatchat.server.db.repository import update_message


class ConversationCallbackHandler(BaseCallbackHandler):
    raise_error: bool = True

    def __init__(
        self, conversation_id: str, message_id: str, chat_type: str, query: str
    ):
        self.conversation_id = conversation_id
        self.message_id = message_id
        self.chat_type = chat_type
        self.query = query
        self.start_at = None

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        answer = response.generations[0][0].text
        # update_message(self.message_id, answer)

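The deleted handler never persisted anything yet (its `update_message` call was commented out), but the shape is the usual LangChain pattern: capture the final text in `on_llm_end` and write it to storage keyed by `message_id`. A sketch of how such a handler is wired into a call — `save_answer` and the model wiring are placeholders, not project code:

from typing import Any

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult


def save_answer(message_id: str, answer: str) -> None:
    # Placeholder for a repository call such as update_message(...).
    print(f"[{message_id}] {answer[:60]}")


class PersistAnswerHandler(BaseCallbackHandler):
    def __init__(self, message_id: str):
        self.message_id = message_id

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        # First generation of the first prompt, as in the removed handler.
        answer = response.generations[0][0].text
        save_answer(self.message_id, answer)


# Handlers are passed per call, e.g.:
#   llm.invoke(prompt, config={"callbacks": [PersistAnswerHandler("msg-1")]})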
+1 -0  src/mindpilot/app/chat/chat.py

@@ -141,6 +141,7 @@ async def chat(
    last_tool = {}
    async for chunk in callback.aiter():
        data = json.loads(chunk)
        # print("data:{}".format(data))
        data["tool_calls"] = []
        data["message_type"] = MsgType.TEXT

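For context on this hunk: `callback.aiter()` yields JSON strings, and each chunk is normalized into the fields the output schema expects before being streamed on. A runnable sketch of that consume-and-normalize loop — the chunk shape and the `MsgType.TEXT` value are assumptions based only on what is visible here:

import asyncio
import json
from typing import AsyncIterator


async def fake_aiter() -> AsyncIterator[str]:
    # Stand-in for callback.aiter(): an async stream of JSON-encoded chunks.
    for text in ("Hel", "lo"):
        yield json.dumps({"text": text})
        await asyncio.sleep(0)


async def consume() -> None:
    async for chunk in fake_aiter():
        data = json.loads(chunk)
        # Normalize the fields downstream consumers expect, as in the hunk.
        data["tool_calls"] = []
        data["message_type"] = 1  # assumed value of MsgType.TEXT
        print(data)


asyncio.run(consume())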

+1 -0  src/mindpilot/app/tools/search_internet.py

@@ -98,6 +98,7 @@ def search_engine(query: str, config: dict):
    results = search_engine_use(
        text=query, config=config["search_engine_config"][config["search_engine_name"]]
    )
    print(results)
    docs = search_result2docs(results)
    context = ""
    docs = [

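The hunk above only adds a debug `print(results)`, but the surrounding flow is: query the configured engine, convert the raw results into documents, then fold those into a context string. A sketch of the conversion step — the result fields and document shape are assumptions; the project's `search_result2docs` may differ:

from typing import Dict, List


def search_result2docs(results: List[Dict]) -> List[Dict]:
    # Assumed raw result shape: {"title": ..., "link": ..., "snippet": ...}.
    return [
        {
            "page_content": r.get("snippet", ""),
            "metadata": {"source": r.get("link", ""), "title": r.get("title", "")},
        }
        for r in results
    ]


results = [
    {"title": "FastAPI", "link": "https://fastapi.tiangolo.com", "snippet": "FastAPI framework"}
]
context = ""
for doc in search_result2docs(results):
    context += doc["page_content"] + "\n"
print(context)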
