diff --git a/src/mindpilot/app/agent/agents_registry.py b/src/mindpilot/app/agent/agents_registry.py index 0fca645..ab1885f 100644 --- a/src/mindpilot/app/agent/agents_registry.py +++ b/src/mindpilot/app/agent/agents_registry.py @@ -22,7 +22,7 @@ def agents_registry( prompt = ChatPromptTemplate.from_messages([SystemMessage(content=prompt)]) else: prompt = hub.pull("hwchase17/structured-chat-agent") # default prompt - print(prompt) + # print(prompt) agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt) agent_executor = AgentExecutor( diff --git a/src/mindpilot/app/api/openai_routes.py b/src/mindpilot/app/api/openai_routes.py deleted file mode 100644 index a510e2f..0000000 --- a/src/mindpilot/app/api/openai_routes.py +++ /dev/null @@ -1,320 +0,0 @@ -from __future__ import annotations - -import asyncio -import base64 -import logging -import os -import shutil -from contextlib import asynccontextmanager -from datetime import datetime -from pathlib import Path -from typing import AsyncGenerator, Dict, Iterable, Tuple - -from fastapi import APIRouter, Request -from fastapi.responses import FileResponse -from openai import AsyncClient -from openai.types.file_object import FileObject -from sse_starlette.sse import EventSourceResponse - -# from chatchat.configs import BASE_TEMP_DIR, log_verbose - - -from .api_schemas import * - -logger = logging.getLogger() - - -DEFAULT_API_CONCURRENCIES = 5 # 默认单个模型最大并发数 -model_semaphores: Dict[ - Tuple[str, str], asyncio.Semaphore -] = {} # key: (model_name, platform) -openai_router = APIRouter(prefix="/v1", tags=["OpenAI 兼容平台整合接口"]) - - -# @asynccontextmanager -# async def get_model_client(model_name: str) -> AsyncGenerator[AsyncClient]: -# """ -# 对重名模型进行调度,依次选择:空闲的模型 -> 当前访问数最少的模型 -# """ -# max_semaphore = 0 -# selected_platform = "" -# model_infos = get_model_info(model_name=model_name, multiple=True) -# for m, c in model_infos.items(): -# key = (m, c["platform_name"]) -# api_concurrencies = c.get("api_concurrencies", DEFAULT_API_CONCURRENCIES) -# if key not in model_semaphores: -# model_semaphores[key] = asyncio.Semaphore(api_concurrencies) -# semaphore = model_semaphores[key] -# if semaphore._value >= api_concurrencies: -# selected_platform = c["platform_name"] -# break -# elif semaphore._value > max_semaphore: -# selected_platform = c["platform_name"] -# -# key = (m, selected_platform) -# semaphore = model_semaphores[key] -# try: -# await semaphore.acquire() -# yield get_OpenAIClient(platform_name=selected_platform, is_async=True) -# except Exception: -# logger.error(f"failed when request to {key}", exc_info=True) -# finally: -# semaphore.release() - - -async def openai_request( - method, body, extra_json: Dict = {}, header: Iterable = [], tail: Iterable = [] -): - """ - helper function to make openai request with extra fields - """ - - async def generator(): - for x in header: - if isinstance(x, str): - x = OpenAIChatOutput(content=x, object="chat.completion.chunk") - elif isinstance(x, dict): - x = OpenAIChatOutput.model_validate(x) - else: - raise RuntimeError(f"unsupported value: {header}") - for k, v in extra_json.items(): - setattr(x, k, v) - yield x.model_dump_json() - - async for chunk in await method(**params): - for k, v in extra_json.items(): - setattr(chunk, k, v) - yield chunk.model_dump_json() - - for x in tail: - if isinstance(x, str): - x = OpenAIChatOutput(content=x, object="chat.completion.chunk") - elif isinstance(x, dict): - x = OpenAIChatOutput.model_validate(x) - else: - raise RuntimeError(f"unsupported value: {tail}") - for k, v in extra_json.items(): - setattr(x, k, v) - yield x.model_dump_json() - - params = body.model_dump(exclude_unset=True) - - if hasattr(body, "stream") and body.stream: - return EventSourceResponse(generator()) - else: - result = await method(**params) - for k, v in extra_json.items(): - setattr(result, k, v) - return result.model_dump() - - -# @openai_router.get("/models") -# async def list_models() -> Dict: -# """ -# 整合所有平台的模型列表。 -# """ -# -# async def task(name: str, config: Dict): -# try: -# client = get_OpenAIClient(name, is_async=True) -# models = await client.models.list() -# return [{**x.model_dump(), "platform_name": name} for x in models.data] -# except Exception: -# logger.error(f"failed request to platform: {name}", exc_info=True) -# return [] -# -# result = [] -# tasks = [ -# asyncio.create_task(task(name, config)) -# for name, config in get_config_platforms().items() -# ] -# for t in asyncio.as_completed(tasks): -# result += await t -# -# return {"object": "list", "data": result} -# -# -# @openai_router.post("/chat/completions") -# async def create_chat_completions( -# request: Request, -# body: OpenAIChatInput, -# ): -# if log_verbose: -# print(body) -# async with get_model_client(body.model) as client: -# result = await openai_request(client.chat.completions.create, body) -# return result -# -# -# @openai_router.post("/completions") -# async def create_completions( -# request: Request, -# body: OpenAIChatInput, -# ): -# async with get_model_client(body.model) as client: -# return await openai_request(client.completions.create, body) -# -# -# @openai_router.post("/embeddings") -# async def create_embeddings( -# request: Request, -# body: OpenAIEmbeddingsInput, -# ): -# params = body.model_dump(exclude_unset=True) -# client = get_OpenAIClient(model_name=body.model) -# return (await client.embeddings.create(**params)).model_dump() -# -# -# @openai_router.post("/images/generations") -# async def create_image_generations( -# request: Request, -# body: OpenAIImageGenerationsInput, -# ): -# async with get_model_client(body.model) as client: -# return await openai_request(client.images.generate, body) -# -# -# @openai_router.post("/images/variations") -# async def create_image_variations( -# request: Request, -# body: OpenAIImageVariationsInput, -# ): -# async with get_model_client(body.model) as client: -# return await openai_request(client.images.create_variation, body) -# -# -# @openai_router.post("/images/edit") -# async def create_image_edit( -# request: Request, -# body: OpenAIImageEditsInput, -# ): -# async with get_model_client(body.model) as client: -# return await openai_request(client.images.edit, body) -# -# -# @openai_router.post("/audio/translations", deprecated="暂不支持") -# async def create_audio_translations( -# request: Request, -# body: OpenAIAudioTranslationsInput, -# ): -# async with get_model_client(body.model) as client: -# return await openai_request(client.audio.translations.create, body) -# -# -# @openai_router.post("/audio/transcriptions", deprecated="暂不支持") -# async def create_audio_transcriptions( -# request: Request, -# body: OpenAIAudioTranscriptionsInput, -# ): -# async with get_model_client(body.model) as client: -# return await openai_request(client.audio.transcriptions.create, body) -# -# -# @openai_router.post("/audio/speech", deprecated="暂不支持") -# async def create_audio_speech( -# request: Request, -# body: OpenAIAudioSpeechInput, -# ): -# async with get_model_client(body.model) as client: -# return await openai_request(client.audio.speech.create, body) -# -# -# def _get_file_id( -# purpose: str, -# created_at: int, -# filename: str, -# ) -> str: -# today = datetime.fromtimestamp(created_at).strftime("%Y-%m-%d") -# return base64.urlsafe_b64encode(f"{purpose}/{today}/{filename}".encode()).decode() -# -# -# def _get_file_info(file_id: str) -> Dict: -# splits = base64.urlsafe_b64decode(file_id).decode().split("/") -# created_at = -1 -# size = -1 -# file_path = _get_file_path(file_id) -# if os.path.isfile(file_path): -# created_at = int(os.path.getmtime(file_path)) -# size = os.path.getsize(file_path) -# -# return { -# "purpose": splits[0], -# "created_at": created_at, -# "filename": splits[2], -# "bytes": size, -# } -# -# -# def _get_file_path(file_id: str) -> str: -# file_id = base64.urlsafe_b64decode(file_id).decode() -# return os.path.join(BASE_TEMP_DIR, "openai_files", file_id) -# -# -# @openai_router.post("/files") -# async def files( -# request: Request, -# file: UploadFile, -# purpose: str = "assistants", -# ) -> Dict: -# created_at = int(datetime.now().timestamp()) -# file_id = _get_file_id( -# purpose=purpose, created_at=created_at, filename=file.filename -# ) -# file_path = _get_file_path(file_id) -# file_dir = os.path.dirname(file_path) -# os.makedirs(file_dir, exist_ok=True) -# with open(file_path, "wb") as fp: -# shutil.copyfileobj(file.file, fp) -# file.file.close() -# -# return dict( -# id=file_id, -# filename=file.filename, -# bytes=file.size, -# created_at=created_at, -# object="file", -# purpose=purpose, -# ) -# -# -# @openai_router.get("/files") -# def list_files(purpose: str) -> Dict[str, List[Dict]]: -# file_ids = [] -# root_path = Path(BASE_TEMP_DIR) / "openai_files" / purpose -# for dir, sub_dirs, files in os.walk(root_path): -# dir = Path(dir).relative_to(root_path).as_posix() -# for file in files: -# file_id = base64.urlsafe_b64encode( -# f"{purpose}/{dir}/{file}".encode() -# ).decode() -# file_ids.append(file_id) -# return { -# "data": [{**_get_file_info(x), "id": x, "object": "file"} for x in file_ids] -# } -# -# -# @openai_router.get("/files/{file_id}") -# def retrieve_file(file_id: str) -> Dict: -# file_info = _get_file_info(file_id) -# return {**file_info, "id": file_id, "object": "file"} -# -# -# @openai_router.get("/files/{file_id}/content") -# def retrieve_file_content(file_id: str) -> Dict: -# file_path = _get_file_path(file_id) -# return FileResponse(file_path) -# -# -# @openai_router.delete("/files/{file_id}") -# def delete_file(file_id: str) -> Dict: -# file_path = _get_file_path(file_id) -# deleted = False -# -# try: -# if os.path.isfile(file_path): -# os.remove(file_path) -# deleted = True -# except: -# ... -# -# return {"id": file_id, "deleted": deleted, "object": "file"} diff --git a/src/mindpilot/app/api/tool_routes.py b/src/mindpilot/app/api/tool_routes.py new file mode 100644 index 0000000..e69de29 diff --git a/src/mindpilot/app/callback_handler/agent_callback_handler.py b/src/mindpilot/app/callback_handler/agent_callback_handler.py index 157f968..5c50bee 100644 --- a/src/mindpilot/app/callback_handler/agent_callback_handler.py +++ b/src/mindpilot/app/callback_handler/agent_callback_handler.py @@ -33,7 +33,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): self.out = True async def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: data = { "status": AgentStatus.llm_start, @@ -43,7 +43,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): self.queue.put_nowait(dumps(data)) async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - special_tokens = ["\nAction:", "\nObservation:", "<|observation|>"] + special_tokens = ["\nAction:", "\nObservation:", "<|observation|>", "\nThought:"] for stoken in special_tokens: if stoken in token: before_action = token.split(stoken)[0] @@ -63,15 +63,15 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): self.queue.put_nowait(dumps(data)) async def on_chat_model_start( - self, - serialized: Dict[str, Any], - messages: List[List], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, + self, + serialized: Dict[str, Any], + messages: List[List], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, ) -> None: data = { "status": AgentStatus.llm_start, @@ -88,7 +88,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): self.queue.put_nowait(dumps(data)) async def on_llm_error( - self, error: Exception | KeyboardInterrupt, **kwargs: Any + self, error: Exception | KeyboardInterrupt, **kwargs: Any ) -> None: data = { "status": AgentStatus.error, @@ -97,15 +97,15 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): self.queue.put_nowait(dumps(data)) async def on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, ) -> None: data = { "run_id": str(run_id), @@ -116,13 +116,13 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): self.queue.put_nowait(dumps(data)) async def on_tool_end( - self, - output: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, + self, + output: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, ) -> None: """Run when tool ends running.""" data = { @@ -134,13 +134,13 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): self.queue.put_nowait(dumps(data)) async def on_tool_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, ) -> None: """Run when tool errors.""" data = { @@ -153,13 +153,13 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): self.queue.put_nowait(dumps(data)) async def on_agent_action( - self, - action: AgentAction, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, + self, + action: AgentAction, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, ) -> None: data = { "status": AgentStatus.agent_action, @@ -170,13 +170,13 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): self.queue.put_nowait(dumps(data)) async def on_agent_finish( - self, - finish: AgentFinish, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - **kwargs: Any, + self, + finish: AgentFinish, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, ) -> None: if "Thought:" in finish.return_values["output"]: finish.return_values["output"] = finish.return_values["output"].replace( @@ -190,13 +190,13 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): self.queue.put_nowait(dumps(data)) async def on_chain_end( - self, - outputs: Dict[str, Any], - *, - run_id: UUID, - parent_run_id: UUID | None = None, - tags: List[str] | None = None, - **kwargs: Any, + self, + outputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: List[str] | None = None, + **kwargs: Any, ) -> None: self.done.set() self.out = True diff --git a/src/mindpilot/app/chat/chat.py b/src/mindpilot/app/chat/chat.py index fbc3b53..240d7d1 100644 --- a/src/mindpilot/app/chat/chat.py +++ b/src/mindpilot/app/chat/chat.py @@ -59,7 +59,6 @@ def create_models_chains( False ) chat_prompt = ChatPromptTemplate.from_messages([input_msg]) - print(chat_prompt) llm = models["llm_model"] llm.callbacks = callbacks diff --git a/src/mindpilot/app/chat/utils.py b/src/mindpilot/app/chat/utils.py index 11a95bb..7476b0a 100644 --- a/src/mindpilot/app/chat/utils.py +++ b/src/mindpilot/app/chat/utils.py @@ -1,10 +1,8 @@ -import logging -from functools import lru_cache from typing import Dict, List, Tuple, Union - from langchain.prompts.chat import ChatMessagePromptTemplate - from ..pydantic_v2 import BaseModel, Field +import logging +from typing import AsyncGenerator, Dict, Iterable, Tuple logger = logging.getLogger() @@ -48,4 +46,4 @@ class History(BaseModel): elif isinstance(h, dict): h = cls(**h) - return h + return h \ No newline at end of file