|
- from __future__ import annotations
-
- import asyncio
- import base64
- import logging
- import os
- import shutil
- from contextlib import asynccontextmanager
- from datetime import datetime
- from pathlib import Path
- from typing import AsyncGenerator, Dict, Iterable, Tuple
-
- from fastapi import APIRouter, Request
- from fastapi.responses import FileResponse
- from openai import AsyncClient
- from openai.types.file_object import FileObject
- from sse_starlette.sse import EventSourceResponse
-
- # from chatchat.configs import BASE_TEMP_DIR, log_verbose
-
-
- from .api_schemas import *
-
- logger = logging.getLogger()
-
-
- DEFAULT_API_CONCURRENCIES = 5 # 默认单个模型最大并发数
- model_semaphores: Dict[
- Tuple[str, str], asyncio.Semaphore
- ] = {} # key: (model_name, platform)
- openai_router = APIRouter(prefix="/v1", tags=["OpenAI 兼容平台整合接口"])
-
-
- # @asynccontextmanager
- # async def get_model_client(model_name: str) -> AsyncGenerator[AsyncClient]:
- # """
- # 对重名模型进行调度,依次选择:空闲的模型 -> 当前访问数最少的模型
- # """
- # max_semaphore = 0
- # selected_platform = ""
- # model_infos = get_model_info(model_name=model_name, multiple=True)
- # for m, c in model_infos.items():
- # key = (m, c["platform_name"])
- # api_concurrencies = c.get("api_concurrencies", DEFAULT_API_CONCURRENCIES)
- # if key not in model_semaphores:
- # model_semaphores[key] = asyncio.Semaphore(api_concurrencies)
- # semaphore = model_semaphores[key]
- # if semaphore._value >= api_concurrencies:
- # selected_platform = c["platform_name"]
- # break
- # elif semaphore._value > max_semaphore:
- # selected_platform = c["platform_name"]
- #
- # key = (m, selected_platform)
- # semaphore = model_semaphores[key]
- # try:
- # await semaphore.acquire()
- # yield get_OpenAIClient(platform_name=selected_platform, is_async=True)
- # except Exception:
- # logger.error(f"failed when request to {key}", exc_info=True)
- # finally:
- # semaphore.release()
-
-
- async def openai_request(
- method, body, extra_json: Dict = {}, header: Iterable = [], tail: Iterable = []
- ):
- """
- helper function to make openai request with extra fields
- """
-
- async def generator():
- for x in header:
- if isinstance(x, str):
- x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
- elif isinstance(x, dict):
- x = OpenAIChatOutput.model_validate(x)
- else:
- raise RuntimeError(f"unsupported value: {header}")
- for k, v in extra_json.items():
- setattr(x, k, v)
- yield x.model_dump_json()
-
- async for chunk in await method(**params):
- for k, v in extra_json.items():
- setattr(chunk, k, v)
- yield chunk.model_dump_json()
-
- for x in tail:
- if isinstance(x, str):
- x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
- elif isinstance(x, dict):
- x = OpenAIChatOutput.model_validate(x)
- else:
- raise RuntimeError(f"unsupported value: {tail}")
- for k, v in extra_json.items():
- setattr(x, k, v)
- yield x.model_dump_json()
-
- params = body.model_dump(exclude_unset=True)
-
- if hasattr(body, "stream") and body.stream:
- return EventSourceResponse(generator())
- else:
- result = await method(**params)
- for k, v in extra_json.items():
- setattr(result, k, v)
- return result.model_dump()
-
-
- # @openai_router.get("/models")
- # async def list_models() -> Dict:
- # """
- # 整合所有平台的模型列表。
- # """
- #
- # async def task(name: str, config: Dict):
- # try:
- # client = get_OpenAIClient(name, is_async=True)
- # models = await client.models.list()
- # return [{**x.model_dump(), "platform_name": name} for x in models.data]
- # except Exception:
- # logger.error(f"failed request to platform: {name}", exc_info=True)
- # return []
- #
- # result = []
- # tasks = [
- # asyncio.create_task(task(name, config))
- # for name, config in get_config_platforms().items()
- # ]
- # for t in asyncio.as_completed(tasks):
- # result += await t
- #
- # return {"object": "list", "data": result}
- #
- #
- # @openai_router.post("/chat/completions")
- # async def create_chat_completions(
- # request: Request,
- # body: OpenAIChatInput,
- # ):
- # if log_verbose:
- # print(body)
- # async with get_model_client(body.model) as client:
- # result = await openai_request(client.chat.completions.create, body)
- # return result
- #
- #
- # @openai_router.post("/completions")
- # async def create_completions(
- # request: Request,
- # body: OpenAIChatInput,
- # ):
- # async with get_model_client(body.model) as client:
- # return await openai_request(client.completions.create, body)
- #
- #
- # @openai_router.post("/embeddings")
- # async def create_embeddings(
- # request: Request,
- # body: OpenAIEmbeddingsInput,
- # ):
- # params = body.model_dump(exclude_unset=True)
- # client = get_OpenAIClient(model_name=body.model)
- # return (await client.embeddings.create(**params)).model_dump()
- #
- #
- # @openai_router.post("/images/generations")
- # async def create_image_generations(
- # request: Request,
- # body: OpenAIImageGenerationsInput,
- # ):
- # async with get_model_client(body.model) as client:
- # return await openai_request(client.images.generate, body)
- #
- #
- # @openai_router.post("/images/variations")
- # async def create_image_variations(
- # request: Request,
- # body: OpenAIImageVariationsInput,
- # ):
- # async with get_model_client(body.model) as client:
- # return await openai_request(client.images.create_variation, body)
- #
- #
- # @openai_router.post("/images/edit")
- # async def create_image_edit(
- # request: Request,
- # body: OpenAIImageEditsInput,
- # ):
- # async with get_model_client(body.model) as client:
- # return await openai_request(client.images.edit, body)
- #
- #
- # @openai_router.post("/audio/translations", deprecated="暂不支持")
- # async def create_audio_translations(
- # request: Request,
- # body: OpenAIAudioTranslationsInput,
- # ):
- # async with get_model_client(body.model) as client:
- # return await openai_request(client.audio.translations.create, body)
- #
- #
- # @openai_router.post("/audio/transcriptions", deprecated="暂不支持")
- # async def create_audio_transcriptions(
- # request: Request,
- # body: OpenAIAudioTranscriptionsInput,
- # ):
- # async with get_model_client(body.model) as client:
- # return await openai_request(client.audio.transcriptions.create, body)
- #
- #
- # @openai_router.post("/audio/speech", deprecated="暂不支持")
- # async def create_audio_speech(
- # request: Request,
- # body: OpenAIAudioSpeechInput,
- # ):
- # async with get_model_client(body.model) as client:
- # return await openai_request(client.audio.speech.create, body)
- #
- #
- # def _get_file_id(
- # purpose: str,
- # created_at: int,
- # filename: str,
- # ) -> str:
- # today = datetime.fromtimestamp(created_at).strftime("%Y-%m-%d")
- # return base64.urlsafe_b64encode(f"{purpose}/{today}/{filename}".encode()).decode()
- #
- #
- # def _get_file_info(file_id: str) -> Dict:
- # splits = base64.urlsafe_b64decode(file_id).decode().split("/")
- # created_at = -1
- # size = -1
- # file_path = _get_file_path(file_id)
- # if os.path.isfile(file_path):
- # created_at = int(os.path.getmtime(file_path))
- # size = os.path.getsize(file_path)
- #
- # return {
- # "purpose": splits[0],
- # "created_at": created_at,
- # "filename": splits[2],
- # "bytes": size,
- # }
- #
- #
- # def _get_file_path(file_id: str) -> str:
- # file_id = base64.urlsafe_b64decode(file_id).decode()
- # return os.path.join(BASE_TEMP_DIR, "openai_files", file_id)
- #
- #
- # @openai_router.post("/files")
- # async def files(
- # request: Request,
- # file: UploadFile,
- # purpose: str = "assistants",
- # ) -> Dict:
- # created_at = int(datetime.now().timestamp())
- # file_id = _get_file_id(
- # purpose=purpose, created_at=created_at, filename=file.filename
- # )
- # file_path = _get_file_path(file_id)
- # file_dir = os.path.dirname(file_path)
- # os.makedirs(file_dir, exist_ok=True)
- # with open(file_path, "wb") as fp:
- # shutil.copyfileobj(file.file, fp)
- # file.file.close()
- #
- # return dict(
- # id=file_id,
- # filename=file.filename,
- # bytes=file.size,
- # created_at=created_at,
- # object="file",
- # purpose=purpose,
- # )
- #
- #
- # @openai_router.get("/files")
- # def list_files(purpose: str) -> Dict[str, List[Dict]]:
- # file_ids = []
- # root_path = Path(BASE_TEMP_DIR) / "openai_files" / purpose
- # for dir, sub_dirs, files in os.walk(root_path):
- # dir = Path(dir).relative_to(root_path).as_posix()
- # for file in files:
- # file_id = base64.urlsafe_b64encode(
- # f"{purpose}/{dir}/{file}".encode()
- # ).decode()
- # file_ids.append(file_id)
- # return {
- # "data": [{**_get_file_info(x), "id": x, "object": "file"} for x in file_ids]
- # }
- #
- #
- # @openai_router.get("/files/{file_id}")
- # def retrieve_file(file_id: str) -> Dict:
- # file_info = _get_file_info(file_id)
- # return {**file_info, "id": file_id, "object": "file"}
- #
- #
- # @openai_router.get("/files/{file_id}/content")
- # def retrieve_file_content(file_id: str) -> Dict:
- # file_path = _get_file_path(file_id)
- # return FileResponse(file_path)
- #
- #
- # @openai_router.delete("/files/{file_id}")
- # def delete_file(file_id: str) -> Dict:
- # file_path = _get_file_path(file_id)
- # deleted = False
- #
- # try:
- # if os.path.isfile(file_path):
- # os.remove(file_path)
- # deleted = True
- # except:
- # ...
- #
- # return {"id": file_id, "deleted": deleted, "object": "file"}
|