@@ -234,7 +234,13 @@ impl EventStream {
                     Err(err) => Event::Error(format!("{err:?}")),
                 }
             }
-            NodeEvent::AllInputsClosed => Event::Stop,
+            NodeEvent::AllInputsClosed => {
+                let err = eyre!(
+                    "received `AllInputsClosed` event, which should be handled by background task"
+                );
+                tracing::error!("{err:?}");
+                Event::Error(err.wrap_err("internal error").to_string())
+            }
         },
         EventItem::FatalError(err) => {
@@ -92,7 +92,6 @@ fn event_stream_loop(
     clock: Arc<uhlc::HLC>,
 ) {
     let mut tx = Some(tx);
-    let mut close_tx = false;
     let mut pending_drop_tokens: Vec<(DropToken, flume::Receiver<()>, Instant, u64)> = Vec::new();
     let mut drop_tokens = Vec::new();
@@ -136,8 +135,10 @@ fn event_stream_loop(
                     data: Some(data), ..
                 } => data.drop_token(),
                 NodeEvent::AllInputsClosed => {
-                    close_tx = true;
-                    None
+                    // close the event stream
+                    tx = None;
+                    // skip this internal event
+                    continue;
                 }
                 _ => None,
             };
@@ -165,10 +166,6 @@ fn event_stream_loop(
             } else {
                 tracing::warn!("dropping event because event `tx` was already closed: `{inner:?}`");
             }
-            if close_tx {
-                tx = None;
-            };
         }
     };
     if let Err(err) = result {
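Note on the change above: with the `close_tx` flag gone, dropping the last `Sender` (via `tx = None`) is itself the close signal, so the receiving side simply observes the channel ending. A minimal sketch of that pattern (illustrative only, not code from this PR; assumes the `flume` crate already used by `event_stream_loop`):

```rust
fn main() {
    // mirror the loop's `Option<Sender>`: setting it to `None` drops the sender
    let (tx, rx) = flume::unbounded::<u32>();
    let mut tx = Some(tx);
    if let Some(tx) = &tx {
        tx.send(1).unwrap();
    }
    tx = None; // all senders gone: the channel is closed from the producer side
    assert!(tx.is_none());
    for value in rx.iter() {
        // drains the buffered value, then the iterator ends on disconnect
        println!("got {value}");
    }
    println!("receiver observed the closed channel and exited");
}
```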
@@ -540,7 +540,7 @@ pub async fn spawn_node(
         // If log is an output, we're sending the logs to the dataflow
         if let Some(stdout_output_name) = &send_stdout_to {
             // Convert logs to DataMessage
-            let array = message.clone().into_arrow();
+            let array = message.into_arrow();
             let array: ArrayData = array.into();
             let total_len = required_data_size(&array);
@@ -3,14 +3,14 @@ nodes:
     build: cargo build -p dora-openai-proxy-server --release
     path: ../../target/release/dora-openai-proxy-server
     outputs:
-      - text
+      - chat_completion_request
     inputs:
-      text: dora-echo/echo
+      completion_reply: dora-echo/echo
   - id: dora-echo
     build: pip install -e ../../node-hub/dora-echo
     path: dora-echo
     inputs:
-      echo: dora-openai-server/text
+      echo: dora-openai-server/chat_completion_request
     outputs:
       - echo
@@ -32,69 +32,11 @@ def test_chat_completion(user_input):
         print(f"Error in chat completion: {e}")
-def test_chat_completion_image_url(user_input):
-    """TODO: Add docstring."""
-    try:
-        response = client.chat.completions.create(
-            model="gpt-3.5-turbo",
-            messages=[
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "text", "text": "What is in this image?"},
-                        {
-                            "type": "image_url",
-                            "image_url": {
-                                "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
-                            },
-                        },
-                    ],
-                }
-            ],
-        )
-        print("Chat Completion Response:")
-        print(response.choices[0].message.content)
-    except Exception as e:
-        print(f"Error in chat completion: {e}")
-def test_chat_completion_image_base64(user_input):
-    """TODO: Add docstring."""
-    try:
-        response = client.chat.completions.create(
-            model="gpt-3.5-turbo",
-            messages=[
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "text", "text": "What is in this image?"},
-                        {
-                            "type": "image_url",
-                            "image_url": {
-                                "url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBapySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnxBwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXrCDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQDry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPsgxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96CutRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOMOVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWquaZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYSUb3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6EhOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oWVeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmHrwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66PfyuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UNz8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII="
-                            },
-                        },
-                    ],
-                }
-            ],
-        )
-        print("Chat Completion Response:")
-        print(response.choices[0].message.content)
-    except Exception as e:
-        print(f"Error in chat completion: {e}")
 if __name__ == "__main__":
     print("Testing API endpoints...")
-    # test_list_models()
+    test_list_models()
     print("\n" + "=" * 50 + "\n")
     chat_input = input("Enter a message for chat completion: ")
     test_chat_completion(chat_input)
     print("\n" + "=" * 50 + "\n")
-    test_chat_completion_image_url(chat_input)
-    print("\n" + "=" * 50 + "\n")
-    test_chat_completion_image_base64(chat_input)
-    print("\n" + "=" * 50 + "\n")
@@ -1,16 +0,0 @@
-nodes:
-  - id: dora-openai-server
-    build: cargo build -p dora-openai-proxy-server --release
-    path: ../../target/release/dora-openai-proxy-server
-    outputs:
-      - text
-    inputs:
-      text: dora-qwen2.5-vl/text
-  - id: dora-qwen2.5-vl
-    build: pip install -e ../../node-hub/dora-qwen2-5-vl
-    path: dora-qwen2-5-vl
-    inputs:
-      text: dora-openai-server/text
-    outputs:
-      - text
@@ -57,20 +57,6 @@ impl IntoArrow for &str {
     }
 }
-impl IntoArrow for String {
-    type A = StringArray;
-    fn into_arrow(self) -> Self::A {
-        std::iter::once(Some(self)).collect()
-    }
-}
-impl IntoArrow for Vec<String> {
-    type A = StringArray;
-    fn into_arrow(self) -> Self::A {
-        StringArray::from(self)
-    }
-}
 impl IntoArrow for () {
     type A = arrow::array::NullArray;
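With the `String` and `Vec<String>` impls removed, single strings now go through the remaining `IntoArrow for &str` impl, either explicitly via `as_str()` or through auto-deref (see the two caller fixes above and below). A sketch of the equivalent conversion, assuming the `arrow` crate (not code from this PR):

```rust
use arrow::array::{Array, StringArray};

fn main() {
    let message: String = "hello".to_owned();
    // what `message.into_arrow()` resolves to once only the `&str` impl exists:
    // a one-element StringArray collected from the borrowed str
    let array: StringArray = std::iter::once(Some(message.as_str())).collect();
    assert_eq!(array.len(), 1);
    assert_eq!(array.value(0), "hello");
}
```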
@@ -41,7 +41,7 @@ async fn main() -> eyre::Result<()> {
                 node.send_output(
                     mistral_output.clone(),
                     metadata.parameters,
-                    output.as_str().into_arrow(),
+                    output.into_arrow(),
                 )?;
             }
             other => eprintln!("Received input `{other}`"),
@@ -1,389 +1,140 @@
-"""FastAPI server with OpenAI compatibility and DORA integration,
-sending text and image data on separate DORA topics.
-"""
+"""TODO: Add docstring."""
+import ast
 import asyncio
-import base64
-import time  # For timestamps
-import uuid  # For generating unique request IDs
-from typing import Any, List, Literal, Optional, Union
+from typing import List, Optional
 import pyarrow as pa
 import uvicorn
 from dora import Node
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI
 from pydantic import BaseModel
-# --- DORA Configuration ---
-DORA_RESPONSE_TIMEOUT_SECONDS = 20
-DORA_TEXT_OUTPUT_TOPIC = "user_text_input"
-DORA_IMAGE_OUTPUT_TOPIC = "user_image_input"
-DORA_RESPONSE_INPUT_TOPIC = "chat_completion_result"  # Topic FastAPI listens on
-app = FastAPI(
-    title="DORA OpenAI-Compatible Demo Server (Separate Topics)",
-    description="Sends text and image data on different DORA topics and awaits a consolidated response.",
-)
-# --- Pydantic Models ---
-class ImageUrl(BaseModel):
-    url: str
-    detail: Optional[str] = "auto"
-class ContentPartText(BaseModel):
-    type: Literal["text"]
-    text: str
-class ContentPartImage(BaseModel):
-    type: Literal["image_url"]
-    image_url: ImageUrl
-ContentPart = Union[ContentPartText, ContentPartImage]
+DORA_RESPONSE_TIMEOUT = 10
+app = FastAPI()
 class ChatCompletionMessage(BaseModel):
+    """TODO: Add docstring."""
     role: str
-    content: Union[str, List[ContentPart]]
+    content: str
 class ChatCompletionRequest(BaseModel):
+    """TODO: Add docstring."""
     model: str
     messages: List[ChatCompletionMessage]
     temperature: Optional[float] = 1.0
     max_tokens: Optional[int] = 100
-class ChatCompletionChoiceMessage(BaseModel):
-    role: str
-    content: str
-class ChatCompletionChoice(BaseModel):
-    index: int
-    message: ChatCompletionChoiceMessage
-    finish_reason: str
-    logprobs: Optional[Any] = None
-class Usage(BaseModel):
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
 class ChatCompletionResponse(BaseModel):
+    """TODO: Add docstring."""
     id: str
-    object: str = "chat.completion"
+    object: str
     created: int
     model: str
-    choices: List[ChatCompletionChoice]
-    usage: Usage
-    system_fingerprint: Optional[str] = None
+    choices: List[dict]
+    usage: dict
-# --- DORA Node Initialization ---
-# This dictionary will hold unmatched responses if we implement more robust concurrent handling.
-# For now, it's a placeholder for future improvement.
-# unmatched_dora_responses = {}
-try:
-    node = Node()
-    print("FastAPI Server: DORA Node initialized.")
-except Exception as e:
-    print(
-        f"FastAPI Server: Failed to initialize DORA Node. Running in standalone API mode. Error: {e}"
-    )
-    node = None
+node = Node()  # provide the name to connect to the dataflow if dynamic node
-@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
+@app.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest):
-    internal_request_id = str(uuid.uuid4())
-    openai_chat_id = f"chatcmpl-{internal_request_id}"
-    current_timestamp = int(time.time())
-    print(f"FastAPI Server: Processing request_id: {internal_request_id}")
-    user_text_parts = []
-    user_image_bytes: Optional[bytes] = None
-    user_image_content_type: Optional[str] = None
-    data_sent_to_dora = False
-    for message in reversed(request.messages):
-        if message.role == "user":
-            if isinstance(message.content, str):
-                user_text_parts.append(message.content)
-            elif isinstance(message.content, list):
-                for part in message.content:
-                    if part.type == "text":
-                        user_text_parts.append(part.text)
-                    elif part.type == "image_url":
-                        if user_image_bytes:  # Use only the first image
-                            print(
-                                f"FastAPI Server (Req {internal_request_id}): Warning - Multiple images found, using the first one."
-                            )
-                            continue
-                        image_url_data = part.image_url.url
-                        if image_url_data.startswith("data:image"):
-                            try:
-                                header, encoded_data = image_url_data.split(",", 1)
-                                user_image_content_type = header.split(":")[1].split(
-                                    ";"
-                                )[0]
-                                user_image_bytes = base64.b64decode(encoded_data)
-                                print(
-                                    f"FastAPI Server (Req {internal_request_id}): Decoded image {user_image_content_type}, size: {len(user_image_bytes)} bytes"
-                                )
-                            except Exception as e:
-                                print(
-                                    f"FastAPI Server (Req {internal_request_id}): Error decoding base64 image: {e}"
-                                )
-                                raise HTTPException(
-                                    status_code=400,
-                                    detail=f"Invalid base64 image data: {e}",
-                                )
-                        else:
-                            print(
-                                f"FastAPI Server (Req {internal_request_id}): Warning - Remote image URL '{image_url_data}' ignored. Only data URIs supported."
-                            )
-        # Consider if you want to break after the first user message or aggregate all
-        # break
-    final_user_text = "\n".join(reversed(user_text_parts)) if user_text_parts else ""
-    prompt_tokens = len(final_user_text)
-    if node:
-        if final_user_text:
-            text_payload = {"request_id": internal_request_id, "text": final_user_text}
-            arrow_text_data = pa.array([text_payload])
-            node.send_output(DORA_TEXT_OUTPUT_TOPIC, arrow_text_data)
-            print(
-                f"FastAPI Server (Req {internal_request_id}): Sent text to DORA topic '{DORA_TEXT_OUTPUT_TOPIC}'."
-            )
-            data_sent_to_dora = True
-        if user_image_bytes:
-            image_payload = {
-                "request_id": internal_request_id,
-                "image_bytes": user_image_bytes,
-                "image_content_type": user_image_content_type
-                or "application/octet-stream",
-            }
-            arrow_image_data = pa.array([image_payload])
-            node.send_output(DORA_IMAGE_OUTPUT_TOPIC, arrow_image_data)
-            print(
-                f"FastAPI Server (Req {internal_request_id}): Sent image to DORA topic '{DORA_IMAGE_OUTPUT_TOPIC}'."
-            )
-            prompt_tokens += len(user_image_bytes)  # Crude image token approximation
-            data_sent_to_dora = True
-    response_str: str
-    if not data_sent_to_dora:
-        if node is None:
-            response_str = "DORA node not available. Cannot process request."
-        else:
-            response_str = "No user text or image found to send to DORA."
-        print(f"FastAPI Server (Req {internal_request_id}): {response_str}")
-    else:
-        print(
-            f"FastAPI Server (Req {internal_request_id}): Waiting for response from DORA on topic '{DORA_RESPONSE_INPUT_TOPIC}'..."
-        )
-        response_str = f"Timeout: No response from DORA for request_id {internal_request_id} within {DORA_RESPONSE_TIMEOUT_SECONDS}s."
-        # WARNING: This blocking `node.next()` loop is not ideal for highly concurrent requests
-        # in a single FastAPI worker process, as one request might block others or consume
-        # a response meant for another if `request_id` matching isn't perfect or fast enough.
-        # A more robust solution would involve a dedicated listener task and async Futures/Queues.
-        start_wait_time = time.monotonic()
-        while time.monotonic() - start_wait_time < DORA_RESPONSE_TIMEOUT_SECONDS:
-            remaining_timeout = DORA_RESPONSE_TIMEOUT_SECONDS - (
-                time.monotonic() - start_wait_time
-            )
-            if remaining_timeout <= 0:
-                break
-            event = node.next(
-                timeout=min(1.0, remaining_timeout)
-            )  # Poll with a smaller timeout
-            if event is None:  # Timeout for this poll iteration
-                continue
-            if event["type"] == "INPUT" and event["id"] == DORA_RESPONSE_INPUT_TOPIC:
-                response_value_arrow = event["value"]
-                if response_value_arrow and len(response_value_arrow) > 0:
-                    dora_response_data = response_value_arrow[
-                        0
-                    ].as_py()  # Expecting a dict
-                    if isinstance(dora_response_data, dict):
-                        resp_request_id = dora_response_data.get("request_id")
-                        if resp_request_id == internal_request_id:
-                            response_str = dora_response_data.get(
-                                "response_text",
-                                f"DORA response for {internal_request_id} missing 'response_text'.",
-                            )
-                            print(
-                                f"FastAPI Server (Req {internal_request_id}): Received correlated DORA response."
-                            )
-                            break  # Correct response received
-                        # This response is for another request. Ideally, store it.
-                        print(
-                            f"FastAPI Server (Req {internal_request_id}): Received DORA response for different request_id '{resp_request_id}'. Discarding and waiting. THIS IS A CONCURRENCY ISSUE."
-                        )
-                        # unmatched_dora_responses[resp_request_id] = dora_response_data  # Example of storing
-                    else:
-                        response_str = f"Unrecognized DORA response format for {internal_request_id}: {str(dora_response_data)[:100]}"
-                        break
-                else:
-                    response_str = (
-                        f"Empty response payload from DORA for {internal_request_id}."
-                    )
-                    break
-            elif event["type"] == "ERROR":
-                response_str = f"Error event from DORA for {internal_request_id}: {event.get('value', event.get('error', 'Unknown DORA Error'))}"
-                print(response_str)
-                break
-        else:  # Outer while loop timed out
-            print(
-                f"FastAPI Server (Req {internal_request_id}): Overall timeout waiting for DORA response."
-            )
-    completion_tokens = len(response_str)
-    total_tokens = prompt_tokens + completion_tokens
+    """TODO: Add docstring."""
+    data = next(
+        (msg.content for msg in request.messages if msg.role == "user"),
+        "No user message found.",
+    )
+    # Convert user_message to Arrow array
+    # user_message_array = pa.array([user_message])
+    # Publish user message to dora-echo
+    # node.send_output("user_query", user_message_array)
+    try:
+        data = ast.literal_eval(data)
+    except ValueError:
+        print("Passing input as string")
+    except SyntaxError:
+        print("Passing input as string")
+    if isinstance(data, list):
+        data = pa.array(data)  # initialize pyarrow array
+    elif isinstance(data, str) or isinstance(data, int) or isinstance(data, float) or isinstance(data, dict):
+        data = pa.array([data])
+    else:
+        data = pa.array(data)  # initialize pyarrow array
+    node.send_output("v1/chat/completions", data)
+    # Wait for response from dora-echo
+    while True:
+        event = node.next(timeout=DORA_RESPONSE_TIMEOUT)
+        if event["type"] == "ERROR":
+            response_str = "No response received. Err: " + event["value"][0].as_py()
+            break
+        if event["type"] == "INPUT" and event["id"] == "v1/chat/completions":
+            response = event["value"]
+            response_str = response[0].as_py() if response else "No response received"
+            break
     return ChatCompletionResponse(
-        id=openai_chat_id,
-        created=current_timestamp,
+        id="chatcmpl-1234",
+        object="chat.completion",
+        created=1234567890,
         model=request.model,
         choices=[
-            ChatCompletionChoice(
-                index=0,
-                message=ChatCompletionChoiceMessage(
-                    role="assistant", content=response_str
-                ),
-                finish_reason="stop",
-            )
+            {
+                "index": 0,
+                "message": {"role": "assistant", "content": response_str},
+                "finish_reason": "stop",
+            },
         ],
-        usage=Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=total_tokens,
-        ),
+        usage={
+            "prompt_tokens": len(data),
+            "completion_tokens": len(response_str),
+            "total_tokens": len(data) + len(response_str),
+        },
     )
 @app.get("/v1/models")
 async def list_models():
+    """TODO: Add docstring."""
     return {
         "object": "list",
         "data": [
             {
-                "id": "dora-multi-stream-vision",
+                "id": "gpt-3.5-turbo",
                 "object": "model",
-                "created": int(time.time()),
-                "owned_by": "dora-ai",
-                "permission": [],
-                "root": "dora-multi-stream-vision",
-                "parent": None,
+                "created": 1677610602,
+                "owned_by": "openai",
             },
         ],
     }
-async def run_fastapi_server_task():
+async def run_fastapi():
+    """TODO: Add docstring."""
     config = uvicorn.Config(app, host="0.0.0.0", port=8000, log_level="info")
     server = uvicorn.Server(config)
-    print("FastAPI Server: Uvicorn server starting.")
-    await server.serve()
-    print("FastAPI Server: Uvicorn server stopped.")
-async def run_dora_main_loop_task():
-    if not node:
-        print("FastAPI Server: DORA node not initialized, DORA main loop skipped.")
-        return
-    print("FastAPI Server: DORA main loop listener started (for STOP event).")
-    try:
-        while True:
-            # This loop is primarily for the main "STOP" event for the FastAPI node itself.
-            # Individual request/response cycles are handled within the endpoint.
-            event = node.next(timeout=1.0)  # Check for STOP periodically
-            if event is None:
-                await asyncio.sleep(0.01)  # Yield control if no event
-                continue
-            if event["type"] == "STOP":
-                print(
-                    "FastAPI Server: DORA STOP event received. Requesting server shutdown."
-                )
-                # Attempt to gracefully shut down Uvicorn
-                # This is tricky; uvicorn's server.shutdown() or server.should_exit might be better
-                # For simplicity, we cancel the server task.
-                for task in asyncio.all_tasks():
-                    # Identify the server task more reliably if possible
-                    if (
-                        task.get_coro().__name__ == "serve"
-                        and hasattr(task.get_coro(), "cr_frame")
-                        and isinstance(
-                            task.get_coro().cr_frame.f_locals.get("self"),
-                            uvicorn.Server,
-                        )
-                    ):
-                        task.cancel()
-                        print(
-                            "FastAPI Server: Uvicorn server task cancellation requested."
-                        )
-                        break
-            # Handle other unexpected general inputs/errors for the FastAPI node if necessary
-            # elif event["type"] == "INPUT":
-            #     print(f"FastAPI Server (DORA Main Loop): Unexpected DORA input on ID '{event['id']}'")
-    except asyncio.CancelledError:
-        print("FastAPI Server: DORA main loop task cancelled.")
-    except Exception as e:
-        print(f"FastAPI Server: Error in DORA main loop: {e}")
-    finally:
-        print("FastAPI Server: DORA main loop listener finished.")
-async def main_async_runner():
-    server_task = asyncio.create_task(run_fastapi_server_task())
-    # Only run the DORA main loop if the node was initialized.
-    # This loop is mainly for the STOP event.
-    dora_listener_task = None
-    if node:
-        dora_listener_task = asyncio.create_task(run_dora_main_loop_task())
-        tasks_to_wait_for = [server_task, dora_listener_task]
-    else:
-        tasks_to_wait_for = [server_task]
-    done, pending = await asyncio.wait(
-        tasks_to_wait_for,
-        return_when=asyncio.FIRST_COMPLETED,
-    )
-    for task in pending:
-        print(f"FastAPI Server: Cancelling pending task: {task.get_name()}")
-        task.cancel()
-    if pending:
-        await asyncio.gather(*pending, return_exceptions=True)
-    print("FastAPI Server: Application shutdown complete.")
+    server = asyncio.gather(server.serve())
+    while True:
+        await asyncio.sleep(1)
+        event = node.next(0.001)
+        if event["type"] == "STOP":
+            break
 def main():
-    print("FastAPI Server: Starting application...")
-    try:
-        asyncio.run(main_async_runner())
-    except KeyboardInterrupt:
-        print("FastAPI Server: Keyboard interrupt received. Shutting down.")
-    finally:
-        print("FastAPI Server: Exited main function.")
+    """TODO: Add docstring."""
+    asyncio.run(run_fastapi())
 if __name__ == "__main__":
-    main()
+    asyncio.run(run_fastapi())
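For reference, a hedged sketch of the request body the rewritten endpoint accepts (field names taken from the Pydantic models above; `content` must now be a plain string, since the multi-part content types were dropped). Shown with `serde_json` to keep all examples in one language; this is not part of the PR:

```rust
// Builds the JSON body for POST /v1/chat/completions as modeled above.
fn main() {
    let body = serde_json::json!({
        "model": "gpt-3.5-turbo",
        "messages": [
            { "role": "user", "content": "Hello, dora!" }
        ],
        "temperature": 1.0,
        "max_tokens": 100
    });
    println!("{body}");
}
```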
@@ -2,8 +2,8 @@
 name = "dora-openai-server"
 version = "0.3.11"
 authors = [
-  { name = "Haixuan Xavier Tao", email = "tao.xavier@outlook.com" },
-  { name = "Enzo Le Van", email = "dev@enzo-le-van.fr" },
+    { name = "Haixuan Xavier Tao", email = "tao.xavier@outlook.com" },
+    { name = "Enzo Le Van", email = "dev@enzo-le-van.fr" },
 ]
 description = "Dora OpenAI API Server"
 license = { text = "MIT" }
@@ -11,13 +11,14 @@ readme = "README.md"
 requires-python = ">=3.8"
 dependencies = [
-  "dora-rs >= 0.3.9",
-  "numpy < 2.0.0",
-  "pyarrow >= 5.0.0",
-  "fastapi >= 0.115",
-  "asyncio >= 3.4",
-  "uvicorn >= 0.31",
-  "pydantic >= 2.9",
+    "dora-rs >= 0.3.9",
+    "numpy < 2.0.0",
+    "pyarrow >= 5.0.0",
+    "fastapi >= 0.115",
+    "asyncio >= 3.4",
+    "uvicorn >= 0.31",
+    "pydantic >= 2.9",
 ]
 [dependency-groups]
@@ -28,6 +29,7 @@ dora-openai-server = "dora_openai_server.main:main"
 [tool.ruff.lint]
 extend-select = [
+    "D",    # pydocstyle
     "UP",   # Ruff's UP rule
     "PERF", # Ruff's PERF rule
     "RET",  # Ruff's RET rule
@@ -62,118 +62,29 @@ if ADAPTER_PATH != "":
 processor = AutoProcessor.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
-def generate(
-    frames: dict, texts: list[str], history, past_key_values=None, image_id=None
-):
+def generate(frames: dict, question, history, past_key_values=None, image_id=None):
     """Generate the response to the question given the image using Qwen2 model."""
     if image_id is not None:
         images = [frames[image_id]]
     else:
         images = list(frames.values())
-    messages = []
-    for text in texts:
-        if text.startswith("<|system|>\n"):
-            messages.append(
-                {
-                    "role": "system",
-                    "content": [
-                        {"type": "text", "text": text.replace("<|system|>\n", "")},
-                    ],
-                }
-            )
-        elif text.startswith("<|assistant|>\n"):
-            messages.append(
-                {
-                    "role": "assistant",
-                    "content": [
-                        {"type": "text", "text": text.replace("<|assistant|>\n", "")},
-                    ],
-                }
-            )
-        elif text.startswith("<|tool|>\n"):
-            messages.append(
-                {
-                    "role": "tool",
-                    "content": [
-                        {"type": "text", "text": text.replace("<|tool|>\n", "")},
-                    ],
-                }
-            )
-        elif text.startswith("<|user|>\n<|im_start|>\n"):
-            messages.append(
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": text.replace("<|user|>\n<|im_start|>\n", ""),
-                        },
-                    ],
-                }
-            )
-        elif text.startswith("<|user|>\n<|vision_start|>\n"):
-            # Handle the case where the text starts with <|user|>\n<|vision_start|>
-            image_url = text.replace("<|user|>\n<|vision_start|>\n", "")
-            # If the last message was from the user, append the image URL to it
-            if messages[-1]["role"] == "user":
-                messages[-1]["content"].append(
-                    {
-                        "type": "image",
-                        "image": image_url,
-                    }
-                )
-            else:
-                messages.append(
-                    {
-                        "role": "user",
-                        "content": [
-                            {
-                                "type": "image",
-                                "image": image_url,
-                            },
-                        ],
-                    }
-                )
-        else:
-            messages.append(
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "text", "text": text},
-                    ],
-                }
-            )
-    # If the last message was from the user, append the image URL to it
-    if messages[-1]["role"] == "user":
-        messages[-1]["content"] += [
-            {
-                "type": "image",
-                "image": image,
-                "resized_height": image.size[1] * IMAGE_RESIZE_RATIO,
-                "resized_width": image.size[0] * IMAGE_RESIZE_RATIO,
-            }
-            for image in images
-        ]
-    else:
-        messages.append(
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "image",
-                        "image": image,
-                        "resized_height": image.size[1] * IMAGE_RESIZE_RATIO,
-                        "resized_width": image.size[0] * IMAGE_RESIZE_RATIO,
-                    }
-                    for image in images
-                ],
-            }
-        )
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
+                    "image": image,
+                    "resized_height": image.size[1] * IMAGE_RESIZE_RATIO,
+                    "resized_width": image.size[0] * IMAGE_RESIZE_RATIO,
+                }
+                for image in images
+            ]
+            + [
+                {"type": "text", "text": question},
+            ],
+        },
+    ]
     tmp_history = history + messages
     # Preparation for inference
     text = processor.apply_chat_template(
@@ -209,13 +120,19 @@ def generate(
         clean_up_tokenization_spaces=False,
     )
     if HISTORY:
-        history = tmp_history + [
+        history += [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": question},
+                ],
+            },
             {
                 "role": "assistant",
                 "content": [
                     {"type": "text", "text": output_text[0]},
                 ],
-            }
+            },
         ]
     return output_text[0], history, past_key_values
@@ -290,22 +207,24 @@ def main():
                 elif "text" in event_id:
                     if len(event["value"]) > 0:
-                        texts = event["value"].to_pylist()
+                        text = event["value"][0].as_py()
                         image_id = event["metadata"].get("image_id", None)
                     else:
-                        texts = cached_text
-                    words = texts[-1].split()
+                        text = cached_text
+                    words = text.split()
                     if len(ACTIVATION_WORDS) > 0 and all(
                         word not in ACTIVATION_WORDS for word in words
                     ):
                         continue
-                    cached_text = texts
+                    cached_text = text
                     if len(frames.keys()) == 0:
                         continue
                     # set the max number of tiles in `max_num`
                     response, history, past_key_values = generate(
                         frames,
-                        texts,
+                        text,
                         history,
                         past_key_values,
                         image_id,
@@ -1,10 +1,4 @@
-use dora_node_api::{
-    self,
-    arrow::array::{AsArray, StringArray},
-    dora_core::config::DataId,
-    merged::MergeExternalSend,
-    DoraNode, Event,
-};
+use dora_node_api::{self, dora_core::config::DataId, merged::MergeExternalSend, DoraNode, Event};
 use eyre::{Context, ContextCompat};
 use futures::{
@@ -20,7 +14,7 @@ use hyper::{
 };
 use message::{
     ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage,
-    ChatCompletionRequest, Usage,
+    ChatCompletionRequest, ChatCompletionRequestMessage, Usage,
 };
 use std::{
     collections::VecDeque,
@@ -77,7 +71,7 @@ async fn main() -> eyre::Result<()> {
     let merged = events.merge_external_send(server_events);
     let events = futures::executor::block_on_stream(merged);
-    let output_id = DataId::from("text".to_owned());
+    let output_id = DataId::from("chat_completion_request".to_owned());
     let mut reply_channels = VecDeque::new();
     for event in events {
@@ -88,15 +82,45 @@ async fn main() -> eyre::Result<()> {
                     break;
                 }
                 ServerEvent::ChatCompletionRequest { request, reply } => {
-                    let texts = request.to_texts();
-                    node.send_output(
-                        output_id.clone(),
-                        Default::default(),
-                        StringArray::from(texts),
-                    )
-                    .context("failed to send dora output")?;
-                    reply_channels.push_back((reply, 0 as u64, request.model));
+                    let message = request
+                        .messages
+                        .into_iter()
+                        .find_map(|m| match m {
+                            ChatCompletionRequestMessage::User(message) => Some(message),
+                            _ => None,
+                        })
+                        .context("no user message found");
+                    match message {
+                        Ok(message) => match message.content() {
+                            message::ChatCompletionUserMessageContent::Text(content) => {
+                                node.send_output_bytes(
+                                    output_id.clone(),
+                                    Default::default(),
+                                    content.len(),
+                                    content.as_bytes(),
+                                )
+                                .context("failed to send dora output")?;
+                                reply_channels.push_back((
+                                    reply,
+                                    content.as_bytes().len() as u64,
+                                    request.model,
+                                ));
+                            }
+                            message::ChatCompletionUserMessageContent::Parts(_) => {
+                                if reply
+                                    .send(Err(eyre::eyre!("unsupported message content")))
+                                    .is_err()
+                                {
+                                    tracing::warn!("failed to send chat completion reply because channel closed early");
+                                };
+                            }
+                        },
+                        Err(err) => {
+                            if reply.send(Err(err)).is_err() {
+                                tracing::warn!("failed to send chat completion reply error because channel closed early");
+                            }
+                        }
+                    }
                 }
             },
             dora_node_api::merged::MergedEvent::Dora(event) => match event {
@@ -106,42 +130,35 @@ async fn main() -> eyre::Result<()> {
                     metadata: _,
                 } => {
                     match id.as_str() {
-                        "text" => {
+                        "completion_reply" => {
                             let (reply_channel, prompt_tokens, model) =
                                 reply_channels.pop_front().context("no reply channel")?;
-                            let data = data.as_string::<i32>();
-                            let string = data.iter().fold("".to_string(), |mut acc, s| {
-                                if let Some(s) = s {
-                                    acc.push_str("\n");
-                                    acc.push_str(s);
-                                }
-                                acc
-                            });
-                            let data = ChatCompletionObject {
-                                id: format!("completion-{}", uuid::Uuid::new_v4()),
-                                object: "chat.completion".to_string(),
-                                created: chrono::Utc::now().timestamp() as u64,
-                                model: model.unwrap_or_default(),
-                                choices: vec![ChatCompletionObjectChoice {
-                                    index: 0,
-                                    message: ChatCompletionObjectMessage {
-                                        role: message::ChatCompletionRole::Assistant,
-                                        content: Some(string.to_string()),
-                                        tool_calls: Vec::new(),
-                                        function_call: None,
-                                    },
-                                    finish_reason: message::FinishReason::stop,
-                                    logprobs: None,
-                                }],
-                                usage: Usage {
-                                    prompt_tokens,
-                                    completion_tokens: string.len() as u64,
-                                    total_tokens: prompt_tokens + string.len() as u64,
-                                },
-                            };
-                            if reply_channel.send(Ok(data)).is_err() {
+                            let data = TryFrom::try_from(&data)
+                                .with_context(|| format!("invalid reply data: {data:?}"))
+                                .map(|s: &[u8]| ChatCompletionObject {
+                                    id: format!("completion-{}", uuid::Uuid::new_v4()),
+                                    object: "chat.completion".to_string(),
+                                    created: chrono::Utc::now().timestamp() as u64,
+                                    model: model.unwrap_or_default(),
+                                    choices: vec![ChatCompletionObjectChoice {
+                                        index: 0,
+                                        message: ChatCompletionObjectMessage {
+                                            role: message::ChatCompletionRole::Assistant,
+                                            content: Some(String::from_utf8_lossy(s).to_string()),
+                                            tool_calls: Vec::new(),
+                                            function_call: None,
+                                        },
+                                        finish_reason: message::FinishReason::stop,
+                                        logprobs: None,
+                                    }],
+                                    usage: Usage {
+                                        prompt_tokens,
+                                        completion_tokens: s.len() as u64,
+                                        total_tokens: prompt_tokens + s.len() as u64,
+                                    },
+                                });
+                            if reply_channel.send(data).is_err() {
                                 tracing::warn!("failed to send chat completion reply because channel closed early");
                             }
                         }
@@ -151,11 +168,8 @@ async fn main() -> eyre::Result<()> {
                     Event::Stop => {
                         break;
                     }
-                    Event::InputClosed { id, .. } => {
-                        info!("Input channel closed for id: {}", id);
-                    }
                     event => {
-                        eyre::bail!("unexpected event: {:#?}", event)
+                        println!("Event: {event:#?}")
                     }
                 },
             }
@@ -230,15 +230,6 @@ impl<'de> Deserialize<'de> for ChatCompletionRequest {
     }
 }
-impl ChatCompletionRequest {
-    pub fn to_texts(&self) -> Vec<String> {
-        self.messages
-            .iter()
-            .flat_map(|message| message.to_texts())
-            .collect()
-    }
-}
 /// Message for comprising the conversation.
 #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
 #[serde(tag = "role", rename_all = "lowercase")]
@@ -317,22 +308,6 @@ impl ChatCompletionRequestMessage {
             ChatCompletionRequestMessage::Tool(_) => None,
         }
     }
-    /// The contents of the message.
-    pub fn to_texts(&self) -> Vec<String> {
-        match self {
-            ChatCompletionRequestMessage::System(message) => {
-                vec![String::from("<|system|>\n") + &message.content]
-            }
-            ChatCompletionRequestMessage::User(message) => message.content.to_texts(),
-            ChatCompletionRequestMessage::Assistant(message) => {
-                vec![String::from("<|assistant|>\n") + &message.content.clone().unwrap_or_default()]
-            }
-            ChatCompletionRequestMessage::Tool(message) => {
-                vec![String::from("<|tool|>\n") + &message.content.clone()]
-            }
-        }
-    }
 }
 /// Sampling methods used for chat completion requests.
@@ -612,25 +587,6 @@ impl ChatCompletionUserMessageContent {
             ChatCompletionUserMessageContent::Parts(_) => "parts",
         }
     }
-    pub fn to_texts(&self) -> Vec<String> {
-        match self {
-            ChatCompletionUserMessageContent::Text(text) => {
-                vec![String::from("user: ") + &text.clone()]
-            }
-            ChatCompletionUserMessageContent::Parts(parts) => parts
-                .iter()
-                .map(|part| match part {
-                    ContentPart::Text(text_part) => {
-                        String::from("<|user|>\n<|im_start|>\n") + &text_part.text.clone()
-                    }
-                    ContentPart::Image(image) => {
-                        String::from("<|user|>\n<|vision_start|>\n") + &image.image().url.clone()
-                    }
-                })
-                .collect(),
-        }
-    }
 }
 /// Define the content part of a user message.