Browse Source

V0.1.0

main
gjl 1 year ago
parent
commit
be354280ca
6 changed files with 60 additions and 383 deletions
  1. +1
    -1
      src/mindpilot/app/agent/agents_registry.py
  2. +0
    -320
      src/mindpilot/app/api/openai_routes.py
  3. +0
    -0
      src/mindpilot/app/api/tool_routes.py
  4. +56
    -56
      src/mindpilot/app/callback_handler/agent_callback_handler.py
  5. +0
    -1
      src/mindpilot/app/chat/chat.py
  6. +3
    -5
      src/mindpilot/app/chat/utils.py

+ 1
- 1
src/mindpilot/app/agent/agents_registry.py View File

@@ -22,7 +22,7 @@ def agents_registry(
prompt = ChatPromptTemplate.from_messages([SystemMessage(content=prompt)])
else:
prompt = hub.pull("hwchase17/structured-chat-agent") # default prompt
print(prompt)
# print(prompt)
agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)

agent_executor = AgentExecutor(


+ 0
- 320
src/mindpilot/app/api/openai_routes.py View File

@@ -1,320 +0,0 @@
from __future__ import annotations

import asyncio
import base64
import logging
import os
import shutil
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
from typing import AsyncGenerator, Dict, Iterable, Tuple

from fastapi import APIRouter, Request
from fastapi.responses import FileResponse
from openai import AsyncClient
from openai.types.file_object import FileObject
from sse_starlette.sse import EventSourceResponse

# from chatchat.configs import BASE_TEMP_DIR, log_verbose


from .api_schemas import *

logger = logging.getLogger()


DEFAULT_API_CONCURRENCIES = 5 # 默认单个模型最大并发数
model_semaphores: Dict[
Tuple[str, str], asyncio.Semaphore
] = {} # key: (model_name, platform)
openai_router = APIRouter(prefix="/v1", tags=["OpenAI 兼容平台整合接口"])


# @asynccontextmanager
# async def get_model_client(model_name: str) -> AsyncGenerator[AsyncClient]:
# """
# 对重名模型进行调度,依次选择:空闲的模型 -> 当前访问数最少的模型
# """
# max_semaphore = 0
# selected_platform = ""
# model_infos = get_model_info(model_name=model_name, multiple=True)
# for m, c in model_infos.items():
# key = (m, c["platform_name"])
# api_concurrencies = c.get("api_concurrencies", DEFAULT_API_CONCURRENCIES)
# if key not in model_semaphores:
# model_semaphores[key] = asyncio.Semaphore(api_concurrencies)
# semaphore = model_semaphores[key]
# if semaphore._value >= api_concurrencies:
# selected_platform = c["platform_name"]
# break
# elif semaphore._value > max_semaphore:
# selected_platform = c["platform_name"]
#
# key = (m, selected_platform)
# semaphore = model_semaphores[key]
# try:
# await semaphore.acquire()
# yield get_OpenAIClient(platform_name=selected_platform, is_async=True)
# except Exception:
# logger.error(f"failed when request to {key}", exc_info=True)
# finally:
# semaphore.release()


async def openai_request(
method, body, extra_json: Dict = {}, header: Iterable = [], tail: Iterable = []
):
"""
helper function to make openai request with extra fields
"""

async def generator():
for x in header:
if isinstance(x, str):
x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
elif isinstance(x, dict):
x = OpenAIChatOutput.model_validate(x)
else:
raise RuntimeError(f"unsupported value: {header}")
for k, v in extra_json.items():
setattr(x, k, v)
yield x.model_dump_json()

async for chunk in await method(**params):
for k, v in extra_json.items():
setattr(chunk, k, v)
yield chunk.model_dump_json()

for x in tail:
if isinstance(x, str):
x = OpenAIChatOutput(content=x, object="chat.completion.chunk")
elif isinstance(x, dict):
x = OpenAIChatOutput.model_validate(x)
else:
raise RuntimeError(f"unsupported value: {tail}")
for k, v in extra_json.items():
setattr(x, k, v)
yield x.model_dump_json()

params = body.model_dump(exclude_unset=True)

if hasattr(body, "stream") and body.stream:
return EventSourceResponse(generator())
else:
result = await method(**params)
for k, v in extra_json.items():
setattr(result, k, v)
return result.model_dump()


# @openai_router.get("/models")
# async def list_models() -> Dict:
# """
# 整合所有平台的模型列表。
# """
#
# async def task(name: str, config: Dict):
# try:
# client = get_OpenAIClient(name, is_async=True)
# models = await client.models.list()
# return [{**x.model_dump(), "platform_name": name} for x in models.data]
# except Exception:
# logger.error(f"failed request to platform: {name}", exc_info=True)
# return []
#
# result = []
# tasks = [
# asyncio.create_task(task(name, config))
# for name, config in get_config_platforms().items()
# ]
# for t in asyncio.as_completed(tasks):
# result += await t
#
# return {"object": "list", "data": result}
#
#
# @openai_router.post("/chat/completions")
# async def create_chat_completions(
# request: Request,
# body: OpenAIChatInput,
# ):
# if log_verbose:
# print(body)
# async with get_model_client(body.model) as client:
# result = await openai_request(client.chat.completions.create, body)
# return result
#
#
# @openai_router.post("/completions")
# async def create_completions(
# request: Request,
# body: OpenAIChatInput,
# ):
# async with get_model_client(body.model) as client:
# return await openai_request(client.completions.create, body)
#
#
# @openai_router.post("/embeddings")
# async def create_embeddings(
# request: Request,
# body: OpenAIEmbeddingsInput,
# ):
# params = body.model_dump(exclude_unset=True)
# client = get_OpenAIClient(model_name=body.model)
# return (await client.embeddings.create(**params)).model_dump()
#
#
# @openai_router.post("/images/generations")
# async def create_image_generations(
# request: Request,
# body: OpenAIImageGenerationsInput,
# ):
# async with get_model_client(body.model) as client:
# return await openai_request(client.images.generate, body)
#
#
# @openai_router.post("/images/variations")
# async def create_image_variations(
# request: Request,
# body: OpenAIImageVariationsInput,
# ):
# async with get_model_client(body.model) as client:
# return await openai_request(client.images.create_variation, body)
#
#
# @openai_router.post("/images/edit")
# async def create_image_edit(
# request: Request,
# body: OpenAIImageEditsInput,
# ):
# async with get_model_client(body.model) as client:
# return await openai_request(client.images.edit, body)
#
#
# @openai_router.post("/audio/translations", deprecated="暂不支持")
# async def create_audio_translations(
# request: Request,
# body: OpenAIAudioTranslationsInput,
# ):
# async with get_model_client(body.model) as client:
# return await openai_request(client.audio.translations.create, body)
#
#
# @openai_router.post("/audio/transcriptions", deprecated="暂不支持")
# async def create_audio_transcriptions(
# request: Request,
# body: OpenAIAudioTranscriptionsInput,
# ):
# async with get_model_client(body.model) as client:
# return await openai_request(client.audio.transcriptions.create, body)
#
#
# @openai_router.post("/audio/speech", deprecated="暂不支持")
# async def create_audio_speech(
# request: Request,
# body: OpenAIAudioSpeechInput,
# ):
# async with get_model_client(body.model) as client:
# return await openai_request(client.audio.speech.create, body)
#
#
# def _get_file_id(
# purpose: str,
# created_at: int,
# filename: str,
# ) -> str:
# today = datetime.fromtimestamp(created_at).strftime("%Y-%m-%d")
# return base64.urlsafe_b64encode(f"{purpose}/{today}/{filename}".encode()).decode()
#
#
# def _get_file_info(file_id: str) -> Dict:
# splits = base64.urlsafe_b64decode(file_id).decode().split("/")
# created_at = -1
# size = -1
# file_path = _get_file_path(file_id)
# if os.path.isfile(file_path):
# created_at = int(os.path.getmtime(file_path))
# size = os.path.getsize(file_path)
#
# return {
# "purpose": splits[0],
# "created_at": created_at,
# "filename": splits[2],
# "bytes": size,
# }
#
#
# def _get_file_path(file_id: str) -> str:
# file_id = base64.urlsafe_b64decode(file_id).decode()
# return os.path.join(BASE_TEMP_DIR, "openai_files", file_id)
#
#
# @openai_router.post("/files")
# async def files(
# request: Request,
# file: UploadFile,
# purpose: str = "assistants",
# ) -> Dict:
# created_at = int(datetime.now().timestamp())
# file_id = _get_file_id(
# purpose=purpose, created_at=created_at, filename=file.filename
# )
# file_path = _get_file_path(file_id)
# file_dir = os.path.dirname(file_path)
# os.makedirs(file_dir, exist_ok=True)
# with open(file_path, "wb") as fp:
# shutil.copyfileobj(file.file, fp)
# file.file.close()
#
# return dict(
# id=file_id,
# filename=file.filename,
# bytes=file.size,
# created_at=created_at,
# object="file",
# purpose=purpose,
# )
#
#
# @openai_router.get("/files")
# def list_files(purpose: str) -> Dict[str, List[Dict]]:
# file_ids = []
# root_path = Path(BASE_TEMP_DIR) / "openai_files" / purpose
# for dir, sub_dirs, files in os.walk(root_path):
# dir = Path(dir).relative_to(root_path).as_posix()
# for file in files:
# file_id = base64.urlsafe_b64encode(
# f"{purpose}/{dir}/{file}".encode()
# ).decode()
# file_ids.append(file_id)
# return {
# "data": [{**_get_file_info(x), "id": x, "object": "file"} for x in file_ids]
# }
#
#
# @openai_router.get("/files/{file_id}")
# def retrieve_file(file_id: str) -> Dict:
# file_info = _get_file_info(file_id)
# return {**file_info, "id": file_id, "object": "file"}
#
#
# @openai_router.get("/files/{file_id}/content")
# def retrieve_file_content(file_id: str) -> Dict:
# file_path = _get_file_path(file_id)
# return FileResponse(file_path)
#
#
# @openai_router.delete("/files/{file_id}")
# def delete_file(file_id: str) -> Dict:
# file_path = _get_file_path(file_id)
# deleted = False
#
# try:
# if os.path.isfile(file_path):
# os.remove(file_path)
# deleted = True
# except:
# ...
#
# return {"id": file_id, "deleted": deleted, "object": "file"}

+ 0
- 0
src/mindpilot/app/api/tool_routes.py View File


+ 56
- 56
src/mindpilot/app/callback_handler/agent_callback_handler.py View File

@@ -33,7 +33,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.out = True

async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
data = {
"status": AgentStatus.llm_start,
@@ -43,7 +43,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.queue.put_nowait(dumps(data))

async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
special_tokens = ["\nAction:", "\nObservation:", "<|observation|>"]
special_tokens = ["\nAction:", "\nObservation:", "<|observation|>", "\nThought:"]
for stoken in special_tokens:
if stoken in token:
before_action = token.split(stoken)[0]
@@ -63,15 +63,15 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.queue.put_nowait(dumps(data))

async def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
self,
serialized: Dict[str, Any],
messages: List[List],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
data = {
"status": AgentStatus.llm_start,
@@ -88,7 +88,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.queue.put_nowait(dumps(data))

async def on_llm_error(
self, error: Exception | KeyboardInterrupt, **kwargs: Any
self, error: Exception | KeyboardInterrupt, **kwargs: Any
) -> None:
data = {
"status": AgentStatus.error,
@@ -97,15 +97,15 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.queue.put_nowait(dumps(data))

async def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
data = {
"run_id": str(run_id),
@@ -116,13 +116,13 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.queue.put_nowait(dumps(data))

async def on_tool_end(
self,
output: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
self,
output: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when tool ends running."""
data = {
@@ -134,13 +134,13 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.queue.put_nowait(dumps(data))

async def on_tool_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when tool errors."""
data = {
@@ -153,13 +153,13 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.queue.put_nowait(dumps(data))

async def on_agent_action(
self,
action: AgentAction,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
self,
action: AgentAction,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
data = {
"status": AgentStatus.agent_action,
@@ -170,13 +170,13 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.queue.put_nowait(dumps(data))

async def on_agent_finish(
self,
finish: AgentFinish,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
self,
finish: AgentFinish,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
if "Thought:" in finish.return_values["output"]:
finish.return_values["output"] = finish.return_values["output"].replace(
@@ -190,13 +190,13 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.queue.put_nowait(dumps(data))

async def on_chain_end(
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: UUID | None = None,
tags: List[str] | None = None,
**kwargs: Any,
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: UUID | None = None,
tags: List[str] | None = None,
**kwargs: Any,
) -> None:
self.done.set()
self.out = True

+ 0
- 1
src/mindpilot/app/chat/chat.py View File

@@ -59,7 +59,6 @@ def create_models_chains(
False
)
chat_prompt = ChatPromptTemplate.from_messages([input_msg])
print(chat_prompt)

llm = models["llm_model"]
llm.callbacks = callbacks


+ 3
- 5
src/mindpilot/app/chat/utils.py View File

@@ -1,10 +1,8 @@
import logging
from functools import lru_cache
from typing import Dict, List, Tuple, Union

from langchain.prompts.chat import ChatMessagePromptTemplate

from ..pydantic_v2 import BaseModel, Field
import logging
from typing import AsyncGenerator, Dict, Iterable, Tuple

logger = logging.getLogger()

@@ -48,4 +46,4 @@ class History(BaseModel):
elif isinstance(h, dict):
h = cls(**h)

return h
return h

Loading…
Cancel
Save