diff --git a/src/mindpilot/app/api/api_schemas.py b/src/mindpilot/app/api/api_schemas.py
index 4a0ae41..4c90020 100644
--- a/src/mindpilot/app/api/api_schemas.py
+++ b/src/mindpilot/app/api/api_schemas.py
@@ -12,107 +12,9 @@ from openai.types.chat import (
     completion_create_params,
 )
 
-# from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE
-DEFAULT_LLM_MODEL = None  # TODO: configuration file
-TEMPERATURE = 0.8
 from ..pydantic_v2 import AnyUrl, BaseModel, Field
 from ..utils import MsgType
 
-
-class OpenAIBaseInput(BaseModel):
-    user: Optional[str] = None
-    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
-    # The extra values given here take precedence over values defined on the client or passed to this method.
-    extra_headers: Optional[Dict] = None
-    extra_query: Optional[Dict] = None
-    extra_json: Optional[Dict] = Field(None, alias="extra_body")
-    timeout: Optional[float] = None
-
-    class Config:
-        extra = "allow"
-
-
-class OpenAIChatInput(OpenAIBaseInput):
-    messages: List[ChatCompletionMessageParam]
-    model: str = DEFAULT_LLM_MODEL
-    frequency_penalty: Optional[float] = None
-    function_call: Optional[completion_create_params.FunctionCall] = None
-    functions: List[completion_create_params.Function] = None
-    logit_bias: Optional[Dict[str, int]] = None
-    logprobs: Optional[bool] = None
-    max_tokens: Optional[int] = None
-    n: Optional[int] = None
-    presence_penalty: Optional[float] = None
-    response_format: completion_create_params.ResponseFormat = None
-    seed: Optional[int] = None
-    stop: Union[Optional[str], List[str]] = None
-    stream: Optional[bool] = None
-    temperature: Optional[float] = TEMPERATURE
-    tool_choice: Optional[Union[ChatCompletionToolChoiceOptionParam, str]] = None
-    tools: List[Union[ChatCompletionToolParam, str]] = None
-    top_logprobs: Optional[int] = None
-    top_p: Optional[float] = None
-
-
-class OpenAIEmbeddingsInput(OpenAIBaseInput):
-    input: Union[str, List[str]]
-    model: str
-    dimensions: Optional[int] = None
-    encoding_format: Optional[Literal["float", "base64"]] = None
-
-
-class OpenAIImageBaseInput(OpenAIBaseInput):
-    model: str
-    n: int = 1
-    response_format: Optional[Literal["url", "b64_json"]] = None
-    size: Optional[
-        Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]
-    ] = "256x256"
-
-
-class OpenAIImageGenerationsInput(OpenAIImageBaseInput):
-    prompt: str
-    quality: Literal["standard", "hd"] = None
-    style: Optional[Literal["vivid", "natural"]] = None
-
-
-class OpenAIImageVariationsInput(OpenAIImageBaseInput):
-    image: Union[UploadFile, AnyUrl]
-
-
-class OpenAIImageEditsInput(OpenAIImageVariationsInput):
-    prompt: str
-    mask: Union[UploadFile, AnyUrl]
-
-
-class OpenAIAudioTranslationsInput(OpenAIBaseInput):
-    file: Union[UploadFile, AnyUrl]
-    model: str
-    prompt: Optional[str] = None
-    response_format: Optional[str] = None
-    temperature: float = TEMPERATURE
-
-
-class OpenAIAudioTranscriptionsInput(OpenAIAudioTranslationsInput):
-    language: Optional[str] = None
-    timestamp_granularities: Optional[List[Literal["word", "segment"]]] = None
-
-
-class OpenAIAudioSpeechInput(OpenAIBaseInput):
-    input: str
-    model: str
-    voice: str
-    response_format: Optional[
-        Literal["mp3", "opus", "aac", "flac", "pcm", "wav"]
-    ] = None
-    speed: Optional[float] = None
-
-
-# class OpenAIFileInput(OpenAIBaseInput):
-#     file: UploadFile  # FileTypes
-#     purpose: Literal["fine-tune", "assistants"] = "assistants"
-
-
 class OpenAIBaseOutput(BaseModel):
     id: Optional[str] = None
     content: Optional[str] = None
@@ -125,10 +27,10 @@ class OpenAIBaseOutput(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     tool_calls: List[Dict] = []
-    status: Optional[int] = None  # AgentStatus
+    status: Optional[int] = None
     message_type: int = MsgType.TEXT
-    message_id: Optional[str] = None  # id in database table
-    is_ref: bool = False  # whether to show in a separate expander
+    message_id: Optional[str] = None
+    is_ref: bool = False
 
     class Config:
         extra = "allow"
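Note on the removed schemas: they all hinged on two Pydantic details — `extra = "allow"` lets provider-specific fields pass validation untouched, and `Field(None, alias="extra_body")` accepts the OpenAI client's `extra_body` key while server code reads it as `extra_json`. A minimal runnable sketch of that pattern (the `PassthroughInput` name and plain `pydantic` import are illustrative; the codebase goes through its own `pydantic_v2` shim):

```python
from typing import Dict, Optional

from pydantic import BaseModel, Field


class PassthroughInput(BaseModel):
    # Alias: clients may POST {"extra_body": {...}}; code accesses .extra_json.
    extra_json: Optional[Dict] = Field(None, alias="extra_body")

    class Config:
        extra = "allow"  # keep unknown fields instead of rejecting them


req = PassthroughInput(extra_body={"trace": True}, custom_flag=1)
print(req.extra_json)   # {'trace': True}
print(req.custom_flag)  # 1 -- retained thanks to extra = "allow"
```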
diff --git a/src/mindpilot/app/api/api_server.py b/src/mindpilot/app/api/api_server.py
index 96bbe93..4d4c5fa 100644
--- a/src/mindpilot/app/api/api_server.py
+++ b/src/mindpilot/app/api/api_server.py
@@ -34,11 +34,4 @@ def create_app(run_mode: str = None):
     # app.include_router(openai_router)
     # app.include_router(server_router)
 
-    # # Other endpoints
-    # app.post(
-    #     "/other/completion",
-    #     tags=["Other"],
-    #     summary="Ask the LLM model for a completion (via LLMChain)",
-    # )(completion)
-
     return app
\ No newline at end of file
diff --git a/src/mindpilot/app/api/openai_routes.py b/src/mindpilot/app/api/openai_routes.py
deleted file mode 100644
index 4a7617b..0000000
--- a/src/mindpilot/app/api/openai_routes.py
+++ /dev/null
@@ -1,320 +0,0 @@
-from __future__ import annotations
-
-import asyncio
-import base64
-import logging
-import os
-import shutil
-from contextlib import asynccontextmanager
-from datetime import datetime
-from pathlib import Path
-from typing import AsyncGenerator, Dict, Iterable, Tuple
-
-from fastapi import APIRouter, Request
-from fastapi.responses import FileResponse
-from openai import AsyncClient
-from openai.types.file_object import FileObject
-from sse_starlette.sse import EventSourceResponse
-
-# from chatchat.configs import BASE_TEMP_DIR, log_verbose
-from ..utils import get_OpenAIClient
-
-from .api_schemas import *
-
-logger = logging.getLogger()
-
-
-DEFAULT_API_CONCURRENCIES = 5  # default maximum concurrency per model
-model_semaphores: Dict[
-    Tuple[str, str], asyncio.Semaphore
-] = {}  # key: (model_name, platform)
-openai_router = APIRouter(prefix="/v1", tags=["OpenAI-compatible platform integration endpoints"])
-
-
-# @asynccontextmanager
-# async def get_model_client(model_name: str) -> AsyncGenerator[AsyncClient]:
-#     """
-#     Schedule among models sharing one name, choosing in order: an idle model -> the model with the fewest active requests.
-#     """
-#     max_semaphore = 0
-#     selected_platform = ""
-#     model_infos = get_model_info(model_name=model_name, multiple=True)
-#     for m, c in model_infos.items():
-#         key = (m, c["platform_name"])
-#         api_concurrencies = c.get("api_concurrencies", DEFAULT_API_CONCURRENCIES)
-#         if key not in model_semaphores:
-#             model_semaphores[key] = asyncio.Semaphore(api_concurrencies)
-#         semaphore = model_semaphores[key]
-#         if semaphore._value >= api_concurrencies:
-#             selected_platform = c["platform_name"]
-#             break
-#         elif semaphore._value > max_semaphore:
-#             selected_platform = c["platform_name"]
-#
-#     key = (m, selected_platform)
-#     semaphore = model_semaphores[key]
-#     try:
-#         await semaphore.acquire()
-#         yield get_OpenAIClient(platform_name=selected_platform, is_async=True)
-#     except Exception:
-#         logger.error(f"failed when request to {key}", exc_info=True)
-#     finally:
-#         semaphore.release()
-
-
-async def openai_request(
-    method, body, extra_json: Dict = {}, header: Iterable = [], tail: Iterable = []
-):
-    """
-    helper function to make openai request with extra fields
-    """
-
-    async def generator():
-        for x in header:
-            if isinstance(x, str):
-                x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
-            elif isinstance(x, dict):
-                x = OpenAIChatOutput.model_validate(x)
-            else:
-                raise RuntimeError(f"unsupported value: {header}")
-            for k, v in extra_json.items():
-                setattr(x, k, v)
-            yield x.model_dump_json()
-
-        async for chunk in await method(**params):
-            for k, v in extra_json.items():
-                setattr(chunk, k, v)
-            yield chunk.model_dump_json()
-
-        for x in tail:
-            if isinstance(x, str):
-                x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
-            elif isinstance(x, dict):
-                x = OpenAIChatOutput.model_validate(x)
-            else:
-                raise RuntimeError(f"unsupported value: {tail}")
-            for k, v in extra_json.items():
-                setattr(x, k, v)
-            yield x.model_dump_json()
-
-    params = body.model_dump(exclude_unset=True)
-
-    if hasattr(body, "stream") and body.stream:
-        return EventSourceResponse(generator())
-    else:
-        result = await method(**params)
-        for k, v in extra_json.items():
-            setattr(result, k, v)
-        return result.model_dump()
-
-
-# @openai_router.get("/models")
-# async def list_models() -> Dict:
-#     """
-#     Aggregate the model lists from all platforms.
-#     """
-#
-#     async def task(name: str, config: Dict):
-#         try:
-#             client = get_OpenAIClient(name, is_async=True)
-#             models = await client.models.list()
-#             return [{**x.model_dump(), "platform_name": name} for x in models.data]
-#         except Exception:
-#             logger.error(f"failed request to platform: {name}", exc_info=True)
-#             return []
-#
-#     result = []
-#     tasks = [
-#         asyncio.create_task(task(name, config))
-#         for name, config in get_config_platforms().items()
-#     ]
-#     for t in asyncio.as_completed(tasks):
-#         result += await t
-#
-#     return {"object": "list", "data": result}
-#
-#
-# @openai_router.post("/chat/completions")
-# async def create_chat_completions(
-#     request: Request,
-#     body: OpenAIChatInput,
-# ):
-#     if log_verbose:
-#         print(body)
-#     async with get_model_client(body.model) as client:
-#         result = await openai_request(client.chat.completions.create, body)
-#     return result
-#
-#
-# @openai_router.post("/completions")
-# async def create_completions(
-#     request: Request,
-#     body: OpenAIChatInput,
-# ):
-#     async with get_model_client(body.model) as client:
-#         return await openai_request(client.completions.create, body)
-#
-#
-# @openai_router.post("/embeddings")
-# async def create_embeddings(
-#     request: Request,
-#     body: OpenAIEmbeddingsInput,
-# ):
-#     params = body.model_dump(exclude_unset=True)
-#     client = get_OpenAIClient(model_name=body.model)
-#     return (await client.embeddings.create(**params)).model_dump()
-#
-#
-# @openai_router.post("/images/generations")
-# async def create_image_generations(
-#     request: Request,
-#     body: OpenAIImageGenerationsInput,
-# ):
-#     async with get_model_client(body.model) as client:
-#         return await openai_request(client.images.generate, body)
-#
-#
-# @openai_router.post("/images/variations")
-# async def create_image_variations(
-#     request: Request,
-#     body: OpenAIImageVariationsInput,
-# ):
-#     async with get_model_client(body.model) as client:
-#         return await openai_request(client.images.create_variation, body)
-#
-#
-# @openai_router.post("/images/edit")
-# async def create_image_edit(
-#     request: Request,
-#     body: OpenAIImageEditsInput,
-# ):
-#     async with get_model_client(body.model) as client:
-#         return await openai_request(client.images.edit, body)
-#
-#
-# @openai_router.post("/audio/translations", deprecated="not supported yet")
-# async def create_audio_translations(
-#     request: Request,
-#     body: OpenAIAudioTranslationsInput,
-# ):
-#     async with get_model_client(body.model) as client:
-#         return await openai_request(client.audio.translations.create, body)
-#
-#
-# @openai_router.post("/audio/transcriptions", deprecated="not supported yet")
-# async def create_audio_transcriptions(
-#     request: Request,
-#     body: OpenAIAudioTranscriptionsInput,
-# ):
-#     async with get_model_client(body.model) as client:
-#         return await openai_request(client.audio.transcriptions.create, body)
-#
-#
-# @openai_router.post("/audio/speech", deprecated="not supported yet")
-# async def create_audio_speech(
-#     request: Request,
-#     body: OpenAIAudioSpeechInput,
-# ):
-#     async with get_model_client(body.model) as client:
-#         return await openai_request(client.audio.speech.create, body)
-#
-#
-# def _get_file_id(
-#     purpose: str,
-#     created_at: int,
-#     filename: str,
-# ) -> str:
-#     today = datetime.fromtimestamp(created_at).strftime("%Y-%m-%d")
-#     return base64.urlsafe_b64encode(f"{purpose}/{today}/{filename}".encode()).decode()
-#
-#
-# def _get_file_info(file_id: str) -> Dict:
-#     splits = base64.urlsafe_b64decode(file_id).decode().split("/")
-#     created_at = -1
-#     size = -1
-#     file_path = _get_file_path(file_id)
-#     if os.path.isfile(file_path):
-#         created_at = int(os.path.getmtime(file_path))
-#         size = os.path.getsize(file_path)
-#
-#     return {
-#         "purpose": splits[0],
-#         "created_at": created_at,
-#         "filename": splits[2],
-#         "bytes": size,
-#     }
-#
-#
-# def _get_file_path(file_id: str) -> str:
-#     file_id = base64.urlsafe_b64decode(file_id).decode()
-#     return os.path.join(BASE_TEMP_DIR, "openai_files", file_id)
-#
-#
-# @openai_router.post("/files")
-# async def files(
-#     request: Request,
-#     file: UploadFile,
-#     purpose: str = "assistants",
-# ) -> Dict:
-#     created_at = int(datetime.now().timestamp())
-#     file_id = _get_file_id(
-#         purpose=purpose, created_at=created_at, filename=file.filename
-#     )
-#     file_path = _get_file_path(file_id)
-#     file_dir = os.path.dirname(file_path)
-#     os.makedirs(file_dir, exist_ok=True)
-#     with open(file_path, "wb") as fp:
-#         shutil.copyfileobj(file.file, fp)
-#         file.file.close()
-#
-#     return dict(
-#         id=file_id,
-#         filename=file.filename,
-#         bytes=file.size,
-#         created_at=created_at,
-#         object="file",
-#         purpose=purpose,
-#     )
-#
-#
-# @openai_router.get("/files")
-# def list_files(purpose: str) -> Dict[str, List[Dict]]:
-#     file_ids = []
-#     root_path = Path(BASE_TEMP_DIR) / "openai_files" / purpose
-#     for dir, sub_dirs, files in os.walk(root_path):
-#         dir = Path(dir).relative_to(root_path).as_posix()
-#         for file in files:
-#             file_id = base64.urlsafe_b64encode(
-#                 f"{purpose}/{dir}/{file}".encode()
-#             ).decode()
-#             file_ids.append(file_id)
-#     return {
-#         "data": [{**_get_file_info(x), "id": x, "object": "file"} for x in file_ids]
-#     }
-#
-#
-# @openai_router.get("/files/{file_id}")
-# def retrieve_file(file_id: str) -> Dict:
-#     file_info = _get_file_info(file_id)
-#     return {**file_info, "id": file_id, "object": "file"}
-#
-#
-# @openai_router.get("/files/{file_id}/content")
-# def retrieve_file_content(file_id: str) -> Dict:
-#     file_path = _get_file_path(file_id)
-#     return FileResponse(file_path)
-#
-#
-# @openai_router.delete("/files/{file_id}")
-# def delete_file(file_id: str) -> Dict:
-#     file_path = _get_file_path(file_id)
-#     deleted = False
-#
-#     try:
-#         if os.path.isfile(file_path):
-#             os.remove(file_path)
-#             deleted = True
-#     except:
-#         ...
-#
-#     return {"id": file_id, "deleted": deleted, "object": "file"}
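The docstring of the deleted `get_model_client` above describes a small scheduling policy for models deployed on several platforms: take a platform whose semaphore is completely free (idle), otherwise fall back to the one with the most remaining slots. A minimal sketch of that idea, with hypothetical names (`pick_platform`, `limits`) and detached from any FastAPI context; like the original, it peeks at the semaphore's private `_value`:

```python
import asyncio
from typing import Dict, Optional, Tuple

# One semaphore per (model_name, platform), sized to that deployment's limit.
semaphores: Dict[Tuple[str, str], asyncio.Semaphore] = {}


def pick_platform(model_name: str, limits: Dict[str, int]) -> Optional[str]:
    """limits maps platform_name -> allowed concurrency for this model."""
    best, best_free = None, -1
    for platform, limit in limits.items():
        sem = semaphores.setdefault((model_name, platform), asyncio.Semaphore(limit))
        free = sem._value          # remaining slots
        if free >= limit:          # fully idle platform: take it immediately
            return platform
        if free > best_free:       # otherwise remember the least-loaded so far
            best, best_free = platform, free
    return best


print(pick_platform("my-model", {"platform_a": 5, "platform_b": 3}))  # platform_a
```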
-# -# return {"id": file_id, "deleted": deleted, "object": "file"} diff --git a/src/mindpilot/app/callback_handler/conversation_callback_handler.py b/src/mindpilot/app/callback_handler/conversation_callback_handler.py deleted file mode 100644 index 707c32f..0000000 --- a/src/mindpilot/app/callback_handler/conversation_callback_handler.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Any, Dict, List - -from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import LLMResult - -# from chatchat.server.db.repository import update_message - - -class ConversationCallbackHandler(BaseCallbackHandler): - raise_error: bool = True - - def __init__( - self, conversation_id: str, message_id: str, chat_type: str, query: str - ): - self.conversation_id = conversation_id - self.message_id = message_id - self.chat_type = chat_type - self.query = query - self.start_at = None - - @property - def always_verbose(self) -> bool: - """Whether to call verbose callbacks even if verbose is False.""" - return True - - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> None: - pass - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - answer = response.generations[0][0].text - # update_message(self.message_id, answer) diff --git a/src/mindpilot/app/chat/chat.py b/src/mindpilot/app/chat/chat.py index 16764fc..5043054 100644 --- a/src/mindpilot/app/chat/chat.py +++ b/src/mindpilot/app/chat/chat.py @@ -141,6 +141,7 @@ async def chat( last_tool = {} async for chunk in callback.aiter(): data = json.loads(chunk) + # print("data:{}".format(data)) data["tool_calls"] = [] data["message_type"] = MsgType.TEXT diff --git a/src/mindpilot/app/tools/search_internet.py b/src/mindpilot/app/tools/search_internet.py index f25d0a7..2343df3 100644 --- a/src/mindpilot/app/tools/search_internet.py +++ b/src/mindpilot/app/tools/search_internet.py @@ -98,6 +98,7 @@ def search_engine(query: str, config: dict): results = search_engine_use( text=query, config=config["search_engine_config"][config["search_engine_name"]] ) + print(results) docs = search_result2docs(results) context = "" docs = [