|
- from __future__ import annotations
-
- import json
- import time
- from typing import Dict, List, Literal, Optional, Union
-
- from fastapi import UploadFile
- from openai.types.chat import (
- ChatCompletionMessageParam,
- ChatCompletionToolChoiceOptionParam,
- ChatCompletionToolParam,
- completion_create_params,
- )
-
- from ..utils.system_utils import MsgType
-
# from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE
# TODO: load these from the configuration file instead of hard-coding them.
# Placeholder default model name; None until configuration is wired up.
# Used as the default of OpenAIChatInput.model.
DEFAULT_LLM_MODEL: Optional[str] = None
# Default sampling temperature used by OpenAIChatInput.temperature and
# OpenAIAudioTranslationsInput.temperature.
TEMPERATURE: float = 0.8
- from ..pydantic_v2 import AnyUrl, BaseModel, Field
-
-
class OpenAIBaseInput(BaseModel):
    """Fields shared by every OpenAI-style request model in this module.

    ``extra = "allow"`` keeps any parameter that is not declared here (or on a
    subclass) on the instance, exposed through ``model_extra``, instead of
    rejecting it.
    """

    user: Optional[str] = None
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Optional[Dict] = None
    extra_query: Optional[Dict] = None
    # Accepted in input under the alias "extra_body" (the OpenAI client's
    # keyword) and, thanks to populate_by_name below, also as "extra_json".
    extra_json: Optional[Dict] = Field(None, alias="extra_body")
    timeout: Optional[float] = None

    class Config:
        extra = "allow"
        # Without this, passing the field by its Python name
        # (``OpenAIBaseInput(extra_json={...})``) would be captured as an
        # undeclared extra because of ``extra = "allow"``, leaving the real
        # ``extra_json`` field silently set to None.
        populate_by_name = True
-
-
class OpenAIChatInput(OpenAIBaseInput):
    """Request body for an OpenAI-style chat-completions request.

    Field names mirror the parameters of the OpenAI chat completions API.
    Every field except ``messages`` may be omitted or sent as an explicit
    ``None`` (JSON ``null``). All of them are typed ``Optional`` so that
    clients which send ``null`` for unset parameters validate instead of
    being rejected.
    """

    messages: List[ChatCompletionMessageParam]
    # Optional because the module default (DEFAULT_LLM_MODEL) may itself be
    # None; an explicit null then means the same as omitting the field.
    model: Optional[str] = DEFAULT_LLM_MODEL
    frequency_penalty: Optional[float] = None
    function_call: Optional[completion_create_params.FunctionCall] = None
    functions: Optional[List[completion_create_params.Function]] = None
    logit_bias: Optional[Dict[str, int]] = None
    logprobs: Optional[bool] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[float] = None
    response_format: Optional[completion_create_params.ResponseFormat] = None
    seed: Optional[int] = None
    stop: Union[Optional[str], List[str]] = None
    stream: Optional[bool] = None
    # Defaults to the module-level TEMPERATURE constant rather than None.
    temperature: Optional[float] = TEMPERATURE
    tool_choice: Optional[Union[ChatCompletionToolChoiceOptionParam, str]] = None
    # Tool specs may be full tool dicts or plain strings.
    tools: Optional[List[Union[ChatCompletionToolParam, str]]] = None
    top_logprobs: Optional[int] = None
    top_p: Optional[float] = None
-
-
class OpenAIEmbeddingsInput(OpenAIBaseInput):
    """Input parameters for an OpenAI-style embeddings request.

    Field names mirror the OpenAI embeddings API; ``input`` and ``model``
    are required, the rest are optional.
    """

    # A single string or a list of strings.
    input: Union[str, List[str]]
    model: str
    dimensions: Optional[int] = None
    # Restricted to "float" or "base64" when given.
    encoding_format: Optional[Literal["float", "base64"]] = None
-
-
class OpenAIImageBaseInput(OpenAIBaseInput):
    """Parameters common to the OpenAI-style image request models in this module."""

    model: str
    # Number of images; defaults to 1.
    n: int = 1
    # Either "url" or "b64_json" when given.
    response_format: Optional[Literal["url", "b64_json"]] = None
    # Defaults to the smallest listed size, "256x256".
    size: Optional[
        Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]
    ] = "256x256"
-
-
class OpenAIImageGenerationsInput(OpenAIImageBaseInput):
    """Input for an OpenAI-style image-generation request.

    Adds the text ``prompt`` plus optional ``quality`` and ``style`` to the
    shared image parameters. Both optional fields accept an explicit ``None``
    (JSON ``null``) as well as being omitted.
    """

    prompt: str
    # Typed Optional so that its None default (and an explicit null) validates.
    quality: Optional[Literal["standard", "hd"]] = None
    style: Optional[Literal["vivid", "natural"]] = None
-
-
class OpenAIImageVariationsInput(OpenAIImageBaseInput):
    """Input for an OpenAI-style image-variations request."""

    # The source image, supplied either as an uploaded file or as a URL.
    image: Union[UploadFile, AnyUrl]
-
-
class OpenAIImageEditsInput(OpenAIImageVariationsInput):
    """Input for an OpenAI-style image-edit request.

    Inherits ``image`` and the shared image parameters from
    OpenAIImageVariationsInput, and adds a required ``prompt`` and ``mask``.
    """

    prompt: str
    # Supplied either as an uploaded file or as a URL.
    # NOTE(review): OpenAI's images.edit API treats the mask as optional;
    # confirm that requiring it here is intentional.
    mask: Union[UploadFile, AnyUrl]
-
-
class OpenAIAudioTranslationsInput(OpenAIBaseInput):
    """Input for an OpenAI-style audio-translation request.

    Also the base class of OpenAIAudioTranscriptionsInput.
    """

    # The audio to process, supplied either as an uploaded file or as a URL.
    file: Union[UploadFile, AnyUrl]
    model: str
    prompt: Optional[str] = None
    # Free-form string; not restricted to a fixed set of formats here.
    response_format: Optional[str] = None
    # Defaults to the module-level TEMPERATURE constant (never None by default).
    temperature: float = TEMPERATURE
-
-
class OpenAIAudioTranscriptionsInput(OpenAIAudioTranslationsInput):
    """Input for an OpenAI-style audio-transcription request.

    Same fields as the translation input, plus ``language`` and
    ``timestamp_granularities``.
    """

    language: Optional[str] = None
    # Any combination of "word" and "segment" when given.
    timestamp_granularities: Optional[List[Literal["word", "segment"]]] = None
-
-
class OpenAIAudioSpeechInput(OpenAIBaseInput):
    """Input for an OpenAI-style audio-speech (text-to-speech) request."""

    # Presumably the text to synthesize, as in the OpenAI audio.speech API.
    input: str
    model: str
    # Free-form voice name; not validated against a fixed list here.
    voice: str
    response_format: Optional[
        Literal["mp3", "opus", "aac", "flac", "pcm", "wav"]
    ] = None
    speed: Optional[float] = None
-
-
- # class OpenAIFileInput(OpenAIBaseInput):
- # file: UploadFile # FileTypes
- # purpose: Literal["fine-tune", "assistants"] = "assistants"
-
-
class OpenAIBaseOutput(BaseModel):
    """One chat message or streaming chunk, rendered in an OpenAI-like layout.

    Besides the OpenAI-style fields, it carries project-specific metadata
    (``status``, ``message_type``, ``message_id``, ``is_ref``). Undeclared
    keyword arguments are kept (``extra = "allow"``) and merged into the top
    level of the ``model_dump`` output.

    ``object`` selects the layout produced by ``model_dump``:
    ``"chat.completion.chunk"`` (the default) puts the payload under
    ``choices[0]["delta"]``, while ``"chat.completion"`` puts it under
    ``choices[0]["message"]``.
    """

    id: Optional[str] = None
    content: Optional[str] = None
    model: Optional[str] = None
    object: Literal[
        "chat.completion", "chat.completion.chunk"
    ] = "chat.completion.chunk"
    role: Literal["assistant"] = "assistant"
    finish_reason: Optional[str] = None
    # Unix time in whole seconds, taken when the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    # default_factory gives every instance its own fresh list.
    tool_calls: List[Dict] = Field(default_factory=list)

    status: Optional[int] = None  # AgentStatus
    message_type: int = MsgType.TEXT
    message_id: Optional[str] = None  # id in database table
    is_ref: bool = False  # whether to show in a separate expander

    class Config:
        extra = "allow"

    def model_dump(self, **kwargs) -> dict:
        """Return this message as an OpenAI-style dict.

        Top-level keys are id, object, model, created, status, message_type,
        message_id and is_ref, followed by any extra (undeclared) fields. The
        extras are merged last, so they override same-named keys. A
        single-element ``choices`` list is added whose shape depends on
        ``object``. If ``object`` holds neither known value, no ``choices``
        key is added.

        Args:
            **kwargs: Accepted only for signature compatibility with
                ``BaseModel.model_dump``. Callers written against pydantic
                (e.g. FastAPI's encoder passing ``mode=`` / ``exclude=``, or
                the deprecated ``.dict()``) would otherwise raise TypeError.
                They are ignored; this custom layout is always returned.

        Returns:
            A new dict built from the current field values.
        """
        result = {
            "id": self.id,
            "object": self.object,
            "model": self.model,
            "created": self.created,
            "status": self.status,
            "message_type": self.message_type,
            "message_id": self.message_id,
            "is_ref": self.is_ref,
            **(self.model_extra or {}),
        }

        # NOTE(review): OpenAI's documented format nests ``role`` inside
        # ``delta`` and puts ``finish_reason`` on the choice itself. This
        # layout is kept as-is because existing consumers may rely on it.
        if self.object == "chat.completion.chunk":
            result["choices"] = [
                {
                    "delta": {
                        "content": self.content,
                        "tool_calls": self.tool_calls,
                    },
                    "role": self.role,
                }
            ]
        elif self.object == "chat.completion":
            result["choices"] = [
                {
                    "message": {
                        "role": self.role,
                        "content": self.content,
                        "finish_reason": self.finish_reason,
                        "tool_calls": self.tool_calls,
                    }
                }
            ]
        return result

    def model_dump_json(self, *, indent: Optional[int] = None, **kwargs) -> str:
        """Serialize ``model_dump()`` to JSON, keeping non-ASCII text unescaped.

        Args:
            indent: Forwarded to ``json.dumps``. None (the default) produces
                single-line output.
            **kwargs: Accepted only for signature compatibility with
                ``BaseModel.model_dump_json`` and ignored.

        Returns:
            The JSON string.
        """
        return json.dumps(self.model_dump(), ensure_ascii=False, indent=indent)
-
class OpenAIChatOutput(OpenAIBaseOutput):
    """Chat response message or chunk.

    Adds nothing to OpenAIBaseOutput; it exists as a distinct, chat-specific
    name for the same fields and ``model_dump`` layout.
    """
|