| Author | SHA1 | Message | Date |
|---|---|---|---|
|
|
877eb0e188 | Fix openai server | 7 months ago |
|
|
4beb1c2398 | Fix CI and add info about closed inputs | 7 months ago |
|
|
1c1a91c206 | fix undefined question in history | 7 months ago |
|
|
b6f3c66df1 | Adding vision to openai server | 7 months ago |
|
|
b2474de9a3 | Expose all input closed message as a stop message | 7 months ago |
| @@ -234,13 +234,7 @@ impl EventStream { | |||||
| Err(err) => Event::Error(format!("{err:?}")), | Err(err) => Event::Error(format!("{err:?}")), | ||||
| } | } | ||||
| } | } | ||||
| NodeEvent::AllInputsClosed => { | |||||
| let err = eyre!( | |||||
| "received `AllInputsClosed` event, which should be handled by background task" | |||||
| ); | |||||
| tracing::error!("{err:?}"); | |||||
| Event::Error(err.wrap_err("internal error").to_string()) | |||||
| } | |||||
| NodeEvent::AllInputsClosed => Event::Stop, | |||||
| }, | }, | ||||
| EventItem::FatalError(err) => { | EventItem::FatalError(err) => { | ||||
| @@ -92,6 +92,7 @@ fn event_stream_loop( | |||||
| clock: Arc<uhlc::HLC>, | clock: Arc<uhlc::HLC>, | ||||
| ) { | ) { | ||||
| let mut tx = Some(tx); | let mut tx = Some(tx); | ||||
| let mut close_tx = false; | |||||
| let mut pending_drop_tokens: Vec<(DropToken, flume::Receiver<()>, Instant, u64)> = Vec::new(); | let mut pending_drop_tokens: Vec<(DropToken, flume::Receiver<()>, Instant, u64)> = Vec::new(); | ||||
| let mut drop_tokens = Vec::new(); | let mut drop_tokens = Vec::new(); | ||||
| @@ -135,10 +136,8 @@ fn event_stream_loop( | |||||
| data: Some(data), .. | data: Some(data), .. | ||||
| } => data.drop_token(), | } => data.drop_token(), | ||||
| NodeEvent::AllInputsClosed => { | NodeEvent::AllInputsClosed => { | ||||
| // close the event stream | |||||
| tx = None; | |||||
| // skip this internal event | |||||
| continue; | |||||
| close_tx = true; | |||||
| None | |||||
| } | } | ||||
| _ => None, | _ => None, | ||||
| }; | }; | ||||
| @@ -166,6 +165,10 @@ fn event_stream_loop( | |||||
| } else { | } else { | ||||
| tracing::warn!("dropping event because event `tx` was already closed: `{inner:?}`"); | tracing::warn!("dropping event because event `tx` was already closed: `{inner:?}`"); | ||||
| } | } | ||||
| if close_tx { | |||||
| tx = None; | |||||
| }; | |||||
| } | } | ||||
| }; | }; | ||||
| if let Err(err) = result { | if let Err(err) = result { | ||||
| @@ -540,7 +540,7 @@ pub async fn spawn_node( | |||||
| // If log is an output, we're sending the logs to the dataflow | // If log is an output, we're sending the logs to the dataflow | ||||
| if let Some(stdout_output_name) = &send_stdout_to { | if let Some(stdout_output_name) = &send_stdout_to { | ||||
| // Convert logs to DataMessage | // Convert logs to DataMessage | ||||
| let array = message.into_arrow(); | |||||
| let array = message.clone().into_arrow(); | |||||
| let array: ArrayData = array.into(); | let array: ArrayData = array.into(); | ||||
| let total_len = required_data_size(&array); | let total_len = required_data_size(&array); | ||||
| @@ -3,14 +3,14 @@ nodes: | |||||
| build: cargo build -p dora-openai-proxy-server --release | build: cargo build -p dora-openai-proxy-server --release | ||||
| path: ../../target/release/dora-openai-proxy-server | path: ../../target/release/dora-openai-proxy-server | ||||
| outputs: | outputs: | ||||
| - chat_completion_request | |||||
| - text | |||||
| inputs: | inputs: | ||||
| completion_reply: dora-echo/echo | |||||
| text: dora-echo/echo | |||||
| - id: dora-echo | - id: dora-echo | ||||
| build: pip install -e ../../node-hub/dora-echo | build: pip install -e ../../node-hub/dora-echo | ||||
| path: dora-echo | path: dora-echo | ||||
| inputs: | inputs: | ||||
| echo: dora-openai-server/chat_completion_request | |||||
| echo: dora-openai-server/text | |||||
| outputs: | outputs: | ||||
| - echo | - echo | ||||
| @@ -32,11 +32,69 @@ def test_chat_completion(user_input): | |||||
| print(f"Error in chat completion: {e}") | print(f"Error in chat completion: {e}") | ||||
| def test_chat_completion_image_url(user_input): | |||||
| """TODO: Add docstring.""" | |||||
| try: | |||||
| response = client.chat.completions.create( | |||||
| model="gpt-3.5-turbo", | |||||
| messages=[ | |||||
| { | |||||
| "role": "user", | |||||
| "content": [ | |||||
| {"type": "text", "text": "What is in this image?"}, | |||||
| { | |||||
| "type": "image_url", | |||||
| "image_url": { | |||||
| "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" | |||||
| }, | |||||
| }, | |||||
| ], | |||||
| } | |||||
| ], | |||||
| ) | |||||
| print("Chat Completion Response:") | |||||
| print(response.choices[0].message.content) | |||||
| except Exception as e: | |||||
| print(f"Error in chat completion: {e}") | |||||
| def test_chat_completion_image_base64(user_input): | |||||
| """TODO: Add docstring.""" | |||||
| try: | |||||
| response = client.chat.completions.create( | |||||
| model="gpt-3.5-turbo", | |||||
| messages=[ | |||||
| { | |||||
| "role": "user", | |||||
| "content": [ | |||||
| {"type": "text", "text": "What is in this image?"}, | |||||
| { | |||||
| "type": "image_url", | |||||
| "image_url": { | |||||
| "url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBapySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnxBwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXrCDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQDry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPsgxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96CutRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOMOVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWquaZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYSUb3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6EhOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oWVeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmHrwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66PfyuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UNz8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII=" | |||||
| }, | |||||
| }, | |||||
| ], | |||||
| } | |||||
| ], | |||||
| ) | |||||
| print("Chat Completion Response:") | |||||
| print(response.choices[0].message.content) | |||||
| except Exception as e: | |||||
| print(f"Error in chat completion: {e}") | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| print("Testing API endpoints...") | print("Testing API endpoints...") | ||||
| test_list_models() | |||||
| # test_list_models() | |||||
| print("\n" + "=" * 50 + "\n") | print("\n" + "=" * 50 + "\n") | ||||
| chat_input = input("Enter a message for chat completion: ") | chat_input = input("Enter a message for chat completion: ") | ||||
| test_chat_completion(chat_input) | test_chat_completion(chat_input) | ||||
| print("\n" + "=" * 50 + "\n") | |||||
| test_chat_completion_image_url(chat_input) | |||||
| print("\n" + "=" * 50 + "\n") | |||||
| test_chat_completion_image_base64(chat_input) | |||||
| print("\n" + "=" * 50 + "\n") | print("\n" + "=" * 50 + "\n") | ||||
| @@ -0,0 +1,16 @@ | |||||
| nodes: | |||||
| - id: dora-openai-server | |||||
| build: cargo build -p dora-openai-proxy-server --release | |||||
| path: ../../target/release/dora-openai-proxy-server | |||||
| outputs: | |||||
| - text | |||||
| inputs: | |||||
| text: dora-qwen2.5-vl/text | |||||
| - id: dora-qwen2.5-vl | |||||
| build: pip install -e ../../node-hub/dora-qwen2-5-vl | |||||
| path: dora-qwen2-5-vl | |||||
| inputs: | |||||
| text: dora-openai-server/text | |||||
| outputs: | |||||
| - text | |||||
| @@ -57,6 +57,20 @@ impl IntoArrow for &str { | |||||
| } | } | ||||
| } | } | ||||
| impl IntoArrow for String { | |||||
| type A = StringArray; | |||||
| fn into_arrow(self) -> Self::A { | |||||
| std::iter::once(Some(self)).collect() | |||||
| } | |||||
| } | |||||
| impl IntoArrow for Vec<String> { | |||||
| type A = StringArray; | |||||
| fn into_arrow(self) -> Self::A { | |||||
| StringArray::from(self) | |||||
| } | |||||
| } | |||||
| impl IntoArrow for () { | impl IntoArrow for () { | ||||
| type A = arrow::array::NullArray; | type A = arrow::array::NullArray; | ||||
| @@ -41,7 +41,7 @@ async fn main() -> eyre::Result<()> { | |||||
| node.send_output( | node.send_output( | ||||
| mistral_output.clone(), | mistral_output.clone(), | ||||
| metadata.parameters, | metadata.parameters, | ||||
| output.into_arrow(), | |||||
| output.as_str().into_arrow(), | |||||
| )?; | )?; | ||||
| } | } | ||||
| other => eprintln!("Received input `{other}`"), | other => eprintln!("Received input `{other}`"), | ||||
| @@ -1,140 +1,389 @@ | |||||
| """TODO: Add docstring.""" | |||||
| """FastAPI server with OpenAI compatibility and DORA integration, | |||||
| sending text and image data on separate DORA topics. | |||||
| """ | |||||
| import ast | |||||
| import asyncio | import asyncio | ||||
| from typing import List, Optional | |||||
| import base64 | |||||
| import time # For timestamps | |||||
| import uuid # For generating unique request IDs | |||||
| from typing import Any, List, Literal, Optional, Union | |||||
| import pyarrow as pa | import pyarrow as pa | ||||
| import uvicorn | import uvicorn | ||||
| from dora import Node | from dora import Node | ||||
| from fastapi import FastAPI | |||||
| from fastapi import FastAPI, HTTPException | |||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||
| DORA_RESPONSE_TIMEOUT = 10 | |||||
| app = FastAPI() | |||||
| # --- DORA Configuration --- | |||||
| DORA_RESPONSE_TIMEOUT_SECONDS = 20 | |||||
| DORA_TEXT_OUTPUT_TOPIC = "user_text_input" | |||||
| DORA_IMAGE_OUTPUT_TOPIC = "user_image_input" | |||||
| DORA_RESPONSE_INPUT_TOPIC = "chat_completion_result" # Topic FastAPI listens on | |||||
| app = FastAPI( | |||||
| title="DORA OpenAI-Compatible Demo Server (Separate Topics)", | |||||
| description="Sends text and image data on different DORA topics and awaits a consolidated response.", | |||||
| ) | |||||
| # --- Pydantic Models --- | |||||
| class ImageUrl(BaseModel): | |||||
| url: str | |||||
| detail: Optional[str] = "auto" | |||||
| class ContentPartText(BaseModel): | |||||
| type: Literal["text"] | |||||
| text: str | |||||
| class ContentPartImage(BaseModel): | |||||
| type: Literal["image_url"] | |||||
| image_url: ImageUrl | |||||
| class ChatCompletionMessage(BaseModel): | |||||
| """TODO: Add docstring.""" | |||||
| ContentPart = Union[ContentPartText, ContentPartImage] | |||||
| class ChatCompletionMessage(BaseModel): | |||||
| role: str | role: str | ||||
| content: str | |||||
| content: Union[str, List[ContentPart]] | |||||
| class ChatCompletionRequest(BaseModel): | class ChatCompletionRequest(BaseModel): | ||||
| """TODO: Add docstring.""" | |||||
| model: str | model: str | ||||
| messages: List[ChatCompletionMessage] | messages: List[ChatCompletionMessage] | ||||
| temperature: Optional[float] = 1.0 | temperature: Optional[float] = 1.0 | ||||
| max_tokens: Optional[int] = 100 | max_tokens: Optional[int] = 100 | ||||
| class ChatCompletionResponse(BaseModel): | |||||
| """TODO: Add docstring.""" | |||||
| class ChatCompletionChoiceMessage(BaseModel): | |||||
| role: str | |||||
| content: str | |||||
| class ChatCompletionChoice(BaseModel): | |||||
| index: int | |||||
| message: ChatCompletionChoiceMessage | |||||
| finish_reason: str | |||||
| logprobs: Optional[Any] = None | |||||
| class Usage(BaseModel): | |||||
| prompt_tokens: int | |||||
| completion_tokens: int | |||||
| total_tokens: int | |||||
| class ChatCompletionResponse(BaseModel): | |||||
| id: str | id: str | ||||
| object: str | |||||
| object: str = "chat.completion" | |||||
| created: int | created: int | ||||
| model: str | model: str | ||||
| choices: List[dict] | |||||
| usage: dict | |||||
| choices: List[ChatCompletionChoice] | |||||
| usage: Usage | |||||
| system_fingerprint: Optional[str] = None | |||||
| # --- DORA Node Initialization --- | |||||
| # This dictionary will hold unmatched responses if we implement more robust concurrent handling. | |||||
| # For now, it's a placeholder for future improvement. | |||||
| # unmatched_dora_responses = {} | |||||
| node = Node() # provide the name to connect to the dataflow if dynamic node | |||||
| try: | |||||
| node = Node() | |||||
| print("FastAPI Server: DORA Node initialized.") | |||||
| except Exception as e: | |||||
| print( | |||||
| f"FastAPI Server: Failed to initialize DORA Node. Running in standalone API mode. Error: {e}" | |||||
| ) | |||||
| node = None | |||||
| @app.post("/v1/chat/completions") | |||||
| @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) | |||||
| async def create_chat_completion(request: ChatCompletionRequest): | async def create_chat_completion(request: ChatCompletionRequest): | ||||
| """TODO: Add docstring.""" | |||||
| data = next( | |||||
| (msg.content for msg in request.messages if msg.role == "user"), | |||||
| "No user message found.", | |||||
| ) | |||||
| internal_request_id = str(uuid.uuid4()) | |||||
| openai_chat_id = f"chatcmpl-{internal_request_id}" | |||||
| current_timestamp = int(time.time()) | |||||
| # Convert user_message to Arrow array | |||||
| # user_message_array = pa.array([user_message]) | |||||
| # Publish user message to dora-echo | |||||
| # node.send_output("user_query", user_message_array) | |||||
| print(f"FastAPI Server: Processing request_id: {internal_request_id}") | |||||
| try: | |||||
| data = ast.literal_eval(data) | |||||
| except ValueError: | |||||
| print("Passing input as string") | |||||
| except SyntaxError: | |||||
| print("Passing input as string") | |||||
| if isinstance(data, list): | |||||
| data = pa.array(data) # initialize pyarrow array | |||||
| elif isinstance(data, str) or isinstance(data, int) or isinstance(data, float) or isinstance(data, dict): | |||||
| data = pa.array([data]) | |||||
| user_text_parts = [] | |||||
| user_image_bytes: Optional[bytes] = None | |||||
| user_image_content_type: Optional[str] = None | |||||
| data_sent_to_dora = False | |||||
| for message in reversed(request.messages): | |||||
| if message.role == "user": | |||||
| if isinstance(message.content, str): | |||||
| user_text_parts.append(message.content) | |||||
| elif isinstance(message.content, list): | |||||
| for part in message.content: | |||||
| if part.type == "text": | |||||
| user_text_parts.append(part.text) | |||||
| elif part.type == "image_url": | |||||
| if user_image_bytes: # Use only the first image | |||||
| print( | |||||
| f"FastAPI Server (Req {internal_request_id}): Warning - Multiple images found, using the first one." | |||||
| ) | |||||
| continue | |||||
| image_url_data = part.image_url.url | |||||
| if image_url_data.startswith("data:image"): | |||||
| try: | |||||
| header, encoded_data = image_url_data.split(",", 1) | |||||
| user_image_content_type = header.split(":")[1].split( | |||||
| ";" | |||||
| )[0] | |||||
| user_image_bytes = base64.b64decode(encoded_data) | |||||
| print( | |||||
| f"FastAPI Server (Req {internal_request_id}): Decoded image {user_image_content_type}, size: {len(user_image_bytes)} bytes" | |||||
| ) | |||||
| except Exception as e: | |||||
| print( | |||||
| f"FastAPI Server (Req {internal_request_id}): Error decoding base64 image: {e}" | |||||
| ) | |||||
| raise HTTPException( | |||||
| status_code=400, | |||||
| detail=f"Invalid base64 image data: {e}", | |||||
| ) | |||||
| else: | |||||
| print( | |||||
| f"FastAPI Server (Req {internal_request_id}): Warning - Remote image URL '{image_url_data}' ignored. Only data URIs supported." | |||||
| ) | |||||
| # Consider if you want to break after the first user message or aggregate all | |||||
| # break | |||||
| final_user_text = "\n".join(reversed(user_text_parts)) if user_text_parts else "" | |||||
| prompt_tokens = len(final_user_text) | |||||
| if node: | |||||
| if final_user_text: | |||||
| text_payload = {"request_id": internal_request_id, "text": final_user_text} | |||||
| arrow_text_data = pa.array([text_payload]) | |||||
| node.send_output(DORA_TEXT_OUTPUT_TOPIC, arrow_text_data) | |||||
| print( | |||||
| f"FastAPI Server (Req {internal_request_id}): Sent text to DORA topic '{DORA_TEXT_OUTPUT_TOPIC}'." | |||||
| ) | |||||
| data_sent_to_dora = True | |||||
| if user_image_bytes: | |||||
| image_payload = { | |||||
| "request_id": internal_request_id, | |||||
| "image_bytes": user_image_bytes, | |||||
| "image_content_type": user_image_content_type | |||||
| or "application/octet-stream", | |||||
| } | |||||
| arrow_image_data = pa.array([image_payload]) | |||||
| node.send_output(DORA_IMAGE_OUTPUT_TOPIC, arrow_image_data) | |||||
| print( | |||||
| f"FastAPI Server (Req {internal_request_id}): Sent image to DORA topic '{DORA_IMAGE_OUTPUT_TOPIC}'." | |||||
| ) | |||||
| prompt_tokens += len(user_image_bytes) # Crude image token approximation | |||||
| data_sent_to_dora = True | |||||
| response_str: str | |||||
| if not data_sent_to_dora: | |||||
| if node is None: | |||||
| response_str = "DORA node not available. Cannot process request." | |||||
| else: | |||||
| response_str = "No user text or image found to send to DORA." | |||||
| print(f"FastAPI Server (Req {internal_request_id}): {response_str}") | |||||
| else: | else: | ||||
| data = pa.array(data) # initialize pyarrow array | |||||
| node.send_output("v1/chat/completions", data) | |||||
| # Wait for response from dora-echo | |||||
| while True: | |||||
| event = node.next(timeout=DORA_RESPONSE_TIMEOUT) | |||||
| if event["type"] == "ERROR": | |||||
| response_str = "No response received. Err: " + event["value"][0].as_py() | |||||
| break | |||||
| if event["type"] == "INPUT" and event["id"] == "v1/chat/completions": | |||||
| response = event["value"] | |||||
| response_str = response[0].as_py() if response else "No response received" | |||||
| break | |||||
| print( | |||||
| f"FastAPI Server (Req {internal_request_id}): Waiting for response from DORA on topic '{DORA_RESPONSE_INPUT_TOPIC}'..." | |||||
| ) | |||||
| response_str = f"Timeout: No response from DORA for request_id {internal_request_id} within {DORA_RESPONSE_TIMEOUT_SECONDS}s." | |||||
| # WARNING: This blocking `node.next()` loop is not ideal for highly concurrent requests | |||||
| # in a single FastAPI worker process, as one request might block others or consume | |||||
| # a response meant for another if `request_id` matching isn't perfect or fast enough. | |||||
| # A more robust solution would involve a dedicated listener task and async Futures/Queues. | |||||
| start_wait_time = time.monotonic() | |||||
| while time.monotonic() - start_wait_time < DORA_RESPONSE_TIMEOUT_SECONDS: | |||||
| remaining_timeout = DORA_RESPONSE_TIMEOUT_SECONDS - ( | |||||
| time.monotonic() - start_wait_time | |||||
| ) | |||||
| if remaining_timeout <= 0: | |||||
| break | |||||
| event = node.next( | |||||
| timeout=min(1.0, remaining_timeout) | |||||
| ) # Poll with a smaller timeout | |||||
| if event is None: # Timeout for this poll iteration | |||||
| continue | |||||
| if event["type"] == "INPUT" and event["id"] == DORA_RESPONSE_INPUT_TOPIC: | |||||
| response_value_arrow = event["value"] | |||||
| if response_value_arrow and len(response_value_arrow) > 0: | |||||
| dora_response_data = response_value_arrow[ | |||||
| 0 | |||||
| ].as_py() # Expecting a dict | |||||
| if isinstance(dora_response_data, dict): | |||||
| resp_request_id = dora_response_data.get("request_id") | |||||
| if resp_request_id == internal_request_id: | |||||
| response_str = dora_response_data.get( | |||||
| "response_text", | |||||
| f"DORA response for {internal_request_id} missing 'response_text'.", | |||||
| ) | |||||
| print( | |||||
| f"FastAPI Server (Req {internal_request_id}): Received correlated DORA response." | |||||
| ) | |||||
| break # Correct response received | |||||
| # This response is for another request. Ideally, store it. | |||||
| print( | |||||
| f"FastAPI Server (Req {internal_request_id}): Received DORA response for different request_id '{resp_request_id}'. Discarding and waiting. THIS IS A CONCURRENCY ISSUE." | |||||
| ) | |||||
| # unmatched_dora_responses[resp_request_id] = dora_response_data # Example of storing | |||||
| else: | |||||
| response_str = f"Unrecognized DORA response format for {internal_request_id}: {str(dora_response_data)[:100]}" | |||||
| break | |||||
| else: | |||||
| response_str = ( | |||||
| f"Empty response payload from DORA for {internal_request_id}." | |||||
| ) | |||||
| break | |||||
| elif event["type"] == "ERROR": | |||||
| response_str = f"Error event from DORA for {internal_request_id}: {event.get('value', event.get('error', 'Unknown DORA Error'))}" | |||||
| print(response_str) | |||||
| break | |||||
| else: # Outer while loop timed out | |||||
| print( | |||||
| f"FastAPI Server (Req {internal_request_id}): Overall timeout waiting for DORA response." | |||||
| ) | |||||
| completion_tokens = len(response_str) | |||||
| total_tokens = prompt_tokens + completion_tokens | |||||
| return ChatCompletionResponse( | return ChatCompletionResponse( | ||||
| id="chatcmpl-1234", | |||||
| object="chat.completion", | |||||
| created=1234567890, | |||||
| id=openai_chat_id, | |||||
| created=current_timestamp, | |||||
| model=request.model, | model=request.model, | ||||
| choices=[ | choices=[ | ||||
| { | |||||
| "index": 0, | |||||
| "message": {"role": "assistant", "content": response_str}, | |||||
| "finish_reason": "stop", | |||||
| }, | |||||
| ChatCompletionChoice( | |||||
| index=0, | |||||
| message=ChatCompletionChoiceMessage( | |||||
| role="assistant", content=response_str | |||||
| ), | |||||
| finish_reason="stop", | |||||
| ) | |||||
| ], | ], | ||||
| usage={ | |||||
| "prompt_tokens": len(data), | |||||
| "completion_tokens": len(response_str), | |||||
| "total_tokens": len(data) + len(response_str), | |||||
| }, | |||||
| usage=Usage( | |||||
| prompt_tokens=prompt_tokens, | |||||
| completion_tokens=completion_tokens, | |||||
| total_tokens=total_tokens, | |||||
| ), | |||||
| ) | ) | ||||
| @app.get("/v1/models") | @app.get("/v1/models") | ||||
| async def list_models(): | async def list_models(): | ||||
| """TODO: Add docstring.""" | |||||
| return { | return { | ||||
| "object": "list", | "object": "list", | ||||
| "data": [ | "data": [ | ||||
| { | { | ||||
| "id": "gpt-3.5-turbo", | |||||
| "id": "dora-multi-stream-vision", | |||||
| "object": "model", | "object": "model", | ||||
| "created": 1677610602, | |||||
| "owned_by": "openai", | |||||
| "created": int(time.time()), | |||||
| "owned_by": "dora-ai", | |||||
| "permission": [], | |||||
| "root": "dora-multi-stream-vision", | |||||
| "parent": None, | |||||
| }, | }, | ||||
| ], | ], | ||||
| } | } | ||||
| async def run_fastapi(): | |||||
| """TODO: Add docstring.""" | |||||
| async def run_fastapi_server_task(): | |||||
| config = uvicorn.Config(app, host="0.0.0.0", port=8000, log_level="info") | config = uvicorn.Config(app, host="0.0.0.0", port=8000, log_level="info") | ||||
| server = uvicorn.Server(config) | server = uvicorn.Server(config) | ||||
| print("FastAPI Server: Uvicorn server starting.") | |||||
| await server.serve() | |||||
| print("FastAPI Server: Uvicorn server stopped.") | |||||
| server = asyncio.gather(server.serve()) | |||||
| while True: | |||||
| await asyncio.sleep(1) | |||||
| event = node.next(0.001) | |||||
| if event["type"] == "STOP": | |||||
| break | |||||
| async def run_dora_main_loop_task(): | |||||
| if not node: | |||||
| print("FastAPI Server: DORA node not initialized, DORA main loop skipped.") | |||||
| return | |||||
| print("FastAPI Server: DORA main loop listener started (for STOP event).") | |||||
| try: | |||||
| while True: | |||||
| # This loop is primarily for the main "STOP" event for the FastAPI node itself. | |||||
| # Individual request/response cycles are handled within the endpoint. | |||||
| event = node.next(timeout=1.0) # Check for STOP periodically | |||||
| if event is None: | |||||
| await asyncio.sleep(0.01) # Yield control if no event | |||||
| continue | |||||
| if event["type"] == "STOP": | |||||
| print( | |||||
| "FastAPI Server: DORA STOP event received. Requesting server shutdown." | |||||
| ) | |||||
| # Attempt to gracefully shut down Uvicorn | |||||
| # This is tricky; uvicorn's server.shutdown() or server.should_exit might be better | |||||
| # For simplicity, we cancel the server task. | |||||
| for task in asyncio.all_tasks(): | |||||
| # Identify the server task more reliably if possible | |||||
| if ( | |||||
| task.get_coro().__name__ == "serve" | |||||
| and hasattr(task.get_coro(), "cr_frame") | |||||
| and isinstance( | |||||
| task.get_coro().cr_frame.f_locals.get("self"), | |||||
| uvicorn.Server, | |||||
| ) | |||||
| ): | |||||
| task.cancel() | |||||
| print( | |||||
| "FastAPI Server: Uvicorn server task cancellation requested." | |||||
| ) | |||||
| break | |||||
| # Handle other unexpected general inputs/errors for the FastAPI node if necessary | |||||
| # elif event["type"] == "INPUT": | |||||
| # print(f"FastAPI Server (DORA Main Loop): Unexpected DORA input on ID '{event['id']}'") | |||||
| except asyncio.CancelledError: | |||||
| print("FastAPI Server: DORA main loop task cancelled.") | |||||
| except Exception as e: | |||||
| print(f"FastAPI Server: Error in DORA main loop: {e}") | |||||
| finally: | |||||
| print("FastAPI Server: DORA main loop listener finished.") | |||||
| async def main_async_runner(): | |||||
| server_task = asyncio.create_task(run_fastapi_server_task()) | |||||
| # Only run the DORA main loop if the node was initialized. | |||||
| # This loop is mainly for the STOP event. | |||||
| dora_listener_task = None | |||||
| if node: | |||||
| dora_listener_task = asyncio.create_task(run_dora_main_loop_task()) | |||||
| tasks_to_wait_for = [server_task, dora_listener_task] | |||||
| else: | |||||
| tasks_to_wait_for = [server_task] | |||||
| done, pending = await asyncio.wait( | |||||
| tasks_to_wait_for, | |||||
| return_when=asyncio.FIRST_COMPLETED, | |||||
| ) | |||||
| for task in pending: | |||||
| print(f"FastAPI Server: Cancelling pending task: {task.get_name()}") | |||||
| task.cancel() | |||||
| if pending: | |||||
| await asyncio.gather(*pending, return_exceptions=True) | |||||
| print("FastAPI Server: Application shutdown complete.") | |||||
| def main(): | def main(): | ||||
| """TODO: Add docstring.""" | |||||
| asyncio.run(run_fastapi()) | |||||
| print("FastAPI Server: Starting application...") | |||||
| try: | |||||
| asyncio.run(main_async_runner()) | |||||
| except KeyboardInterrupt: | |||||
| print("FastAPI Server: Keyboard interrupt received. Shutting down.") | |||||
| finally: | |||||
| print("FastAPI Server: Exited main function.") | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| asyncio.run(run_fastapi()) | |||||
| main() | |||||
| @@ -2,8 +2,8 @@ | |||||
| name = "dora-openai-server" | name = "dora-openai-server" | ||||
| version = "0.3.11" | version = "0.3.11" | ||||
| authors = [ | authors = [ | ||||
| { name = "Haixuan Xavier Tao", email = "tao.xavier@outlook.com" }, | |||||
| { name = "Enzo Le Van", email = "dev@enzo-le-van.fr" }, | |||||
| { name = "Haixuan Xavier Tao", email = "tao.xavier@outlook.com" }, | |||||
| { name = "Enzo Le Van", email = "dev@enzo-le-van.fr" }, | |||||
| ] | ] | ||||
| description = "Dora OpenAI API Server" | description = "Dora OpenAI API Server" | ||||
| license = { text = "MIT" } | license = { text = "MIT" } | ||||
| @@ -11,14 +11,13 @@ readme = "README.md" | |||||
| requires-python = ">=3.8" | requires-python = ">=3.8" | ||||
| dependencies = [ | dependencies = [ | ||||
| "dora-rs >= 0.3.9", | |||||
| "numpy < 2.0.0", | |||||
| "pyarrow >= 5.0.0", | |||||
| "fastapi >= 0.115", | |||||
| "asyncio >= 3.4", | |||||
| "uvicorn >= 0.31", | |||||
| "pydantic >= 2.9", | |||||
| "dora-rs >= 0.3.9", | |||||
| "numpy < 2.0.0", | |||||
| "pyarrow >= 5.0.0", | |||||
| "fastapi >= 0.115", | |||||
| "asyncio >= 3.4", | |||||
| "uvicorn >= 0.31", | |||||
| "pydantic >= 2.9", | |||||
| ] | ] | ||||
| [dependency-groups] | [dependency-groups] | ||||
| @@ -29,7 +28,6 @@ dora-openai-server = "dora_openai_server.main:main" | |||||
| [tool.ruff.lint] | [tool.ruff.lint] | ||||
| extend-select = [ | extend-select = [ | ||||
| "D", # pydocstyle | |||||
| "UP", # Ruff's UP rule | "UP", # Ruff's UP rule | ||||
| "PERF", # Ruff's PERF rule | "PERF", # Ruff's PERF rule | ||||
| "RET", # Ruff's RET rule | "RET", # Ruff's RET rule | ||||
| @@ -62,29 +62,118 @@ if ADAPTER_PATH != "": | |||||
| processor = AutoProcessor.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True) | processor = AutoProcessor.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True) | ||||
| def generate(frames: dict, question, history, past_key_values=None, image_id=None): | |||||
| def generate( | |||||
| frames: dict, texts: list[str], history, past_key_values=None, image_id=None | |||||
| ): | |||||
| """Generate the response to the question given the image using Qwen2 model.""" | """Generate the response to the question given the image using Qwen2 model.""" | ||||
| if image_id is not None: | if image_id is not None: | ||||
| images = [frames[image_id]] | images = [frames[image_id]] | ||||
| else: | else: | ||||
| images = list(frames.values()) | images = list(frames.values()) | ||||
| messages = [ | |||||
| { | |||||
| "role": "user", | |||||
| "content": [ | |||||
| messages = [] | |||||
| for text in texts: | |||||
| if text.startswith("<|system|>\n"): | |||||
| messages.append( | |||||
| { | { | ||||
| "type": "image", | |||||
| "image": image, | |||||
| "resized_height": image.size[1] * IMAGE_RESIZE_RATIO, | |||||
| "resized_width": image.size[0] * IMAGE_RESIZE_RATIO, | |||||
| "role": "system", | |||||
| "content": [ | |||||
| {"type": "text", "text": text.replace("<|system|>\n", "")}, | |||||
| ], | |||||
| } | } | ||||
| for image in images | |||||
| ] | |||||
| + [ | |||||
| {"type": "text", "text": question}, | |||||
| ], | |||||
| }, | |||||
| ] | |||||
| ) | |||||
| elif text.startswith("<|assistant|>\n"): | |||||
| messages.append( | |||||
| { | |||||
| "role": "assistant", | |||||
| "content": [ | |||||
| {"type": "text", "text": text.replace("<|assistant|>\n", "")}, | |||||
| ], | |||||
| } | |||||
| ) | |||||
| elif text.startswith("<|tool|>\n"): | |||||
| messages.append( | |||||
| { | |||||
| "role": "tool", | |||||
| "content": [ | |||||
| {"type": "text", "text": text.replace("<|tool|>\n", "")}, | |||||
| ], | |||||
| } | |||||
| ) | |||||
| elif text.startswith("<|user|>\n<|im_start|>\n"): | |||||
| messages.append( | |||||
| { | |||||
| "role": "user", | |||||
| "content": [ | |||||
| { | |||||
| "type": "text", | |||||
| "text": text.replace("<|user|>\n<|im_start|>\n", ""), | |||||
| }, | |||||
| ], | |||||
| } | |||||
| ) | |||||
| elif text.startswith("<|user|>\n<|vision_start|>\n"): | |||||
| # Handle the case where the text starts with <|user|>\n<|vision_start|> | |||||
| image_url = text.replace("<|user|>\n<|vision_start|>\n", "") | |||||
| # If the last message was from the user, append the image URL to it | |||||
| if messages[-1]["role"] == "user": | |||||
| messages[-1]["content"].append( | |||||
| { | |||||
| "type": "image", | |||||
| "image": image_url, | |||||
| } | |||||
| ) | |||||
| else: | |||||
| messages.append( | |||||
| { | |||||
| "role": "user", | |||||
| "content": [ | |||||
| { | |||||
| "type": "image", | |||||
| "image": image_url, | |||||
| }, | |||||
| ], | |||||
| } | |||||
| ) | |||||
| else: | |||||
| messages.append( | |||||
| { | |||||
| "role": "user", | |||||
| "content": [ | |||||
| {"type": "text", "text": text}, | |||||
| ], | |||||
| } | |||||
| ) | |||||
| # If the last message was from the user, append the image URL to it | |||||
| if messages[-1]["role"] == "user": | |||||
| messages[-1]["content"] += [ | |||||
| { | |||||
| "type": "image", | |||||
| "image": image, | |||||
| "resized_height": image.size[1] * IMAGE_RESIZE_RATIO, | |||||
| "resized_width": image.size[0] * IMAGE_RESIZE_RATIO, | |||||
| } | |||||
| for image in images | |||||
| ] | |||||
| else: | |||||
| messages.append( | |||||
| { | |||||
| "role": "user", | |||||
| "content": [ | |||||
| { | |||||
| "type": "image", | |||||
| "image": image, | |||||
| "resized_height": image.size[1] * IMAGE_RESIZE_RATIO, | |||||
| "resized_width": image.size[0] * IMAGE_RESIZE_RATIO, | |||||
| } | |||||
| for image in images | |||||
| ], | |||||
| } | |||||
| ) | |||||
| tmp_history = history + messages | tmp_history = history + messages | ||||
| # Preparation for inference | # Preparation for inference | ||||
| text = processor.apply_chat_template( | text = processor.apply_chat_template( | ||||
| @@ -120,19 +209,13 @@ def generate(frames: dict, question, history, past_key_values=None, image_id=Non | |||||
| clean_up_tokenization_spaces=False, | clean_up_tokenization_spaces=False, | ||||
| ) | ) | ||||
| if HISTORY: | if HISTORY: | ||||
| history += [ | |||||
| { | |||||
| "role": "user", | |||||
| "content": [ | |||||
| {"type": "text", "text": question}, | |||||
| ], | |||||
| }, | |||||
| history = tmp_history + [ | |||||
| { | { | ||||
| "role": "assistant", | "role": "assistant", | ||||
| "content": [ | "content": [ | ||||
| {"type": "text", "text": output_text[0]}, | {"type": "text", "text": output_text[0]}, | ||||
| ], | ], | ||||
| }, | |||||
| } | |||||
| ] | ] | ||||
| return output_text[0], history, past_key_values | return output_text[0], history, past_key_values | ||||
| @@ -207,24 +290,22 @@ def main(): | |||||
| elif "text" in event_id: | elif "text" in event_id: | ||||
| if len(event["value"]) > 0: | if len(event["value"]) > 0: | ||||
| text = event["value"][0].as_py() | |||||
| texts = event["value"].to_pylist() | |||||
| image_id = event["metadata"].get("image_id", None) | image_id = event["metadata"].get("image_id", None) | ||||
| else: | else: | ||||
| text = cached_text | |||||
| words = text.split() | |||||
| texts = cached_text | |||||
| words = texts[-1].split() | |||||
| if len(ACTIVATION_WORDS) > 0 and all( | if len(ACTIVATION_WORDS) > 0 and all( | ||||
| word not in ACTIVATION_WORDS for word in words | word not in ACTIVATION_WORDS for word in words | ||||
| ): | ): | ||||
| continue | continue | ||||
| cached_text = text | |||||
| cached_text = texts | |||||
| if len(frames.keys()) == 0: | |||||
| continue | |||||
| # set the max number of tiles in `max_num` | # set the max number of tiles in `max_num` | ||||
| response, history, past_key_values = generate( | response, history, past_key_values = generate( | ||||
| frames, | frames, | ||||
| text, | |||||
| texts, | |||||
| history, | history, | ||||
| past_key_values, | past_key_values, | ||||
| image_id, | image_id, | ||||
| @@ -1,4 +1,10 @@ | |||||
| use dora_node_api::{self, dora_core::config::DataId, merged::MergeExternalSend, DoraNode, Event}; | |||||
| use dora_node_api::{ | |||||
| self, | |||||
| arrow::array::{AsArray, StringArray}, | |||||
| dora_core::config::DataId, | |||||
| merged::MergeExternalSend, | |||||
| DoraNode, Event, | |||||
| }; | |||||
| use eyre::{Context, ContextCompat}; | use eyre::{Context, ContextCompat}; | ||||
| use futures::{ | use futures::{ | ||||
| @@ -14,7 +20,7 @@ use hyper::{ | |||||
| }; | }; | ||||
| use message::{ | use message::{ | ||||
| ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage, | ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage, | ||||
| ChatCompletionRequest, ChatCompletionRequestMessage, Usage, | |||||
| ChatCompletionRequest, Usage, | |||||
| }; | }; | ||||
| use std::{ | use std::{ | ||||
| collections::VecDeque, | collections::VecDeque, | ||||
| @@ -71,7 +77,7 @@ async fn main() -> eyre::Result<()> { | |||||
| let merged = events.merge_external_send(server_events); | let merged = events.merge_external_send(server_events); | ||||
| let events = futures::executor::block_on_stream(merged); | let events = futures::executor::block_on_stream(merged); | ||||
| let output_id = DataId::from("chat_completion_request".to_owned()); | |||||
| let output_id = DataId::from("text".to_owned()); | |||||
| let mut reply_channels = VecDeque::new(); | let mut reply_channels = VecDeque::new(); | ||||
| for event in events { | for event in events { | ||||
| @@ -82,45 +88,15 @@ async fn main() -> eyre::Result<()> { | |||||
| break; | break; | ||||
| } | } | ||||
| ServerEvent::ChatCompletionRequest { request, reply } => { | ServerEvent::ChatCompletionRequest { request, reply } => { | ||||
| let message = request | |||||
| .messages | |||||
| .into_iter() | |||||
| .find_map(|m| match m { | |||||
| ChatCompletionRequestMessage::User(message) => Some(message), | |||||
| _ => None, | |||||
| }) | |||||
| .context("no user message found"); | |||||
| match message { | |||||
| Ok(message) => match message.content() { | |||||
| message::ChatCompletionUserMessageContent::Text(content) => { | |||||
| node.send_output_bytes( | |||||
| output_id.clone(), | |||||
| Default::default(), | |||||
| content.len(), | |||||
| content.as_bytes(), | |||||
| ) | |||||
| .context("failed to send dora output")?; | |||||
| reply_channels.push_back(( | |||||
| reply, | |||||
| content.as_bytes().len() as u64, | |||||
| request.model, | |||||
| )); | |||||
| } | |||||
| message::ChatCompletionUserMessageContent::Parts(_) => { | |||||
| if reply | |||||
| .send(Err(eyre::eyre!("unsupported message content"))) | |||||
| .is_err() | |||||
| { | |||||
| tracing::warn!("failed to send chat completion reply because channel closed early"); | |||||
| }; | |||||
| } | |||||
| }, | |||||
| Err(err) => { | |||||
| if reply.send(Err(err)).is_err() { | |||||
| tracing::warn!("failed to send chat completion reply error because channel closed early"); | |||||
| } | |||||
| } | |||||
| } | |||||
| let texts = request.to_texts(); | |||||
| node.send_output( | |||||
| output_id.clone(), | |||||
| Default::default(), | |||||
| StringArray::from(texts), | |||||
| ) | |||||
| .context("failed to send dora output")?; | |||||
| reply_channels.push_back((reply, 0 as u64, request.model)); | |||||
| } | } | ||||
| }, | }, | ||||
| dora_node_api::merged::MergedEvent::Dora(event) => match event { | dora_node_api::merged::MergedEvent::Dora(event) => match event { | ||||
| @@ -130,35 +106,42 @@ async fn main() -> eyre::Result<()> { | |||||
| metadata: _, | metadata: _, | ||||
| } => { | } => { | ||||
| match id.as_str() { | match id.as_str() { | ||||
| "completion_reply" => { | |||||
| "text" => { | |||||
| let (reply_channel, prompt_tokens, model) = | let (reply_channel, prompt_tokens, model) = | ||||
| reply_channels.pop_front().context("no reply channel")?; | reply_channels.pop_front().context("no reply channel")?; | ||||
| let data = TryFrom::try_from(&data) | |||||
| .with_context(|| format!("invalid reply data: {data:?}")) | |||||
| .map(|s: &[u8]| ChatCompletionObject { | |||||
| id: format!("completion-{}", uuid::Uuid::new_v4()), | |||||
| object: "chat.completion".to_string(), | |||||
| created: chrono::Utc::now().timestamp() as u64, | |||||
| model: model.unwrap_or_default(), | |||||
| choices: vec![ChatCompletionObjectChoice { | |||||
| index: 0, | |||||
| message: ChatCompletionObjectMessage { | |||||
| role: message::ChatCompletionRole::Assistant, | |||||
| content: Some(String::from_utf8_lossy(s).to_string()), | |||||
| tool_calls: Vec::new(), | |||||
| function_call: None, | |||||
| }, | |||||
| finish_reason: message::FinishReason::stop, | |||||
| logprobs: None, | |||||
| }], | |||||
| usage: Usage { | |||||
| prompt_tokens, | |||||
| completion_tokens: s.len() as u64, | |||||
| total_tokens: prompt_tokens + s.len() as u64, | |||||
| let data = data.as_string::<i32>(); | |||||
| let string = data.iter().fold("".to_string(), |mut acc, s| { | |||||
| if let Some(s) = s { | |||||
| acc.push_str("\n"); | |||||
| acc.push_str(s); | |||||
| } | |||||
| acc | |||||
| }); | |||||
| let data = ChatCompletionObject { | |||||
| id: format!("completion-{}", uuid::Uuid::new_v4()), | |||||
| object: "chat.completion".to_string(), | |||||
| created: chrono::Utc::now().timestamp() as u64, | |||||
| model: model.unwrap_or_default(), | |||||
| choices: vec![ChatCompletionObjectChoice { | |||||
| index: 0, | |||||
| message: ChatCompletionObjectMessage { | |||||
| role: message::ChatCompletionRole::Assistant, | |||||
| content: Some(string.to_string()), | |||||
| tool_calls: Vec::new(), | |||||
| function_call: None, | |||||
| }, | }, | ||||
| }); | |||||
| if reply_channel.send(data).is_err() { | |||||
| finish_reason: message::FinishReason::stop, | |||||
| logprobs: None, | |||||
| }], | |||||
| usage: Usage { | |||||
| prompt_tokens, | |||||
| completion_tokens: string.len() as u64, | |||||
| total_tokens: prompt_tokens + string.len() as u64, | |||||
| }, | |||||
| }; | |||||
| if reply_channel.send(Ok(data)).is_err() { | |||||
| tracing::warn!("failed to send chat completion reply because channel closed early"); | tracing::warn!("failed to send chat completion reply because channel closed early"); | ||||
| } | } | ||||
| } | } | ||||
| @@ -168,8 +151,11 @@ async fn main() -> eyre::Result<()> { | |||||
| Event::Stop => { | Event::Stop => { | ||||
| break; | break; | ||||
| } | } | ||||
| Event::InputClosed { id, .. } => { | |||||
| info!("Input channel closed for id: {}", id); | |||||
| } | |||||
| event => { | event => { | ||||
| println!("Event: {event:#?}") | |||||
| eyre::bail!("unexpected event: {:#?}", event) | |||||
| } | } | ||||
| }, | }, | ||||
| } | } | ||||
| @@ -230,6 +230,15 @@ impl<'de> Deserialize<'de> for ChatCompletionRequest { | |||||
| } | } | ||||
| } | } | ||||
| impl ChatCompletionRequest { | |||||
| pub fn to_texts(&self) -> Vec<String> { | |||||
| self.messages | |||||
| .iter() | |||||
| .flat_map(|message| message.to_texts()) | |||||
| .collect() | |||||
| } | |||||
| } | |||||
| /// Message for comprising the conversation. | /// Message for comprising the conversation. | ||||
| #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] | #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] | ||||
| #[serde(tag = "role", rename_all = "lowercase")] | #[serde(tag = "role", rename_all = "lowercase")] | ||||
| @@ -308,6 +317,22 @@ impl ChatCompletionRequestMessage { | |||||
| ChatCompletionRequestMessage::Tool(_) => None, | ChatCompletionRequestMessage::Tool(_) => None, | ||||
| } | } | ||||
| } | } | ||||
| /// The contents of the message. | |||||
| pub fn to_texts(&self) -> Vec<String> { | |||||
| match self { | |||||
| ChatCompletionRequestMessage::System(message) => { | |||||
| vec![String::from("<|system|>\n") + &message.content] | |||||
| } | |||||
| ChatCompletionRequestMessage::User(message) => message.content.to_texts(), | |||||
| ChatCompletionRequestMessage::Assistant(message) => { | |||||
| vec![String::from("<|assistant|>\n") + &message.content.clone().unwrap_or_default()] | |||||
| } | |||||
| ChatCompletionRequestMessage::Tool(message) => { | |||||
| vec![String::from("<|tool|>\n") + &message.content.clone()] | |||||
| } | |||||
| } | |||||
| } | |||||
| } | } | ||||
| /// Sampling methods used for chat completion requests. | /// Sampling methods used for chat completion requests. | ||||
| @@ -587,6 +612,25 @@ impl ChatCompletionUserMessageContent { | |||||
| ChatCompletionUserMessageContent::Parts(_) => "parts", | ChatCompletionUserMessageContent::Parts(_) => "parts", | ||||
| } | } | ||||
| } | } | ||||
| pub fn to_texts(&self) -> Vec<String> { | |||||
| match self { | |||||
| ChatCompletionUserMessageContent::Text(text) => { | |||||
| vec![String::from("user: ") + &text.clone()] | |||||
| } | |||||
| ChatCompletionUserMessageContent::Parts(parts) => parts | |||||
| .iter() | |||||
| .map(|part| match part { | |||||
| ContentPart::Text(text_part) => { | |||||
| String::from("<|user|>\n<|im_start|>\n") + &text_part.text.clone() | |||||
| } | |||||
| ContentPart::Image(image) => { | |||||
| String::from("<|user|>\n<|vision_start|>\n") + &image.image().url.clone() | |||||
| } | |||||
| }) | |||||
| .collect(), | |||||
| } | |||||
| } | |||||
| } | } | ||||
| /// Define the content part of a user message. | /// Define the content part of a user message. | ||||