diff --git a/apis/rust/node/src/event_stream/mod.rs b/apis/rust/node/src/event_stream/mod.rs index 15c40e33..af8c42e6 100644 --- a/apis/rust/node/src/event_stream/mod.rs +++ b/apis/rust/node/src/event_stream/mod.rs @@ -234,13 +234,7 @@ impl EventStream { Err(err) => Event::Error(format!("{err:?}")), } } - NodeEvent::AllInputsClosed => { - let err = eyre!( - "received `AllInputsClosed` event, which should be handled by background task" - ); - tracing::error!("{err:?}"); - Event::Error(err.wrap_err("internal error").to_string()) - } + NodeEvent::AllInputsClosed => Event::Stop, }, EventItem::FatalError(err) => { diff --git a/apis/rust/node/src/event_stream/thread.rs b/apis/rust/node/src/event_stream/thread.rs index 5e982f74..a9dbba27 100644 --- a/apis/rust/node/src/event_stream/thread.rs +++ b/apis/rust/node/src/event_stream/thread.rs @@ -92,6 +92,7 @@ fn event_stream_loop( clock: Arc, ) { let mut tx = Some(tx); + let mut close_tx = false; let mut pending_drop_tokens: Vec<(DropToken, flume::Receiver<()>, Instant, u64)> = Vec::new(); let mut drop_tokens = Vec::new(); @@ -135,10 +136,8 @@ fn event_stream_loop( data: Some(data), .. } => data.drop_token(), NodeEvent::AllInputsClosed => { - // close the event stream - tx = None; - // skip this internal event - continue; + close_tx = true; + None } _ => None, }; @@ -166,6 +165,10 @@ fn event_stream_loop( } else { tracing::warn!("dropping event because event `tx` was already closed: `{inner:?}`"); } + + if close_tx { + tx = None; + }; } }; if let Err(err) = result { diff --git a/binaries/daemon/src/spawn.rs b/binaries/daemon/src/spawn.rs index 9087a4ec..1e5b5bf7 100644 --- a/binaries/daemon/src/spawn.rs +++ b/binaries/daemon/src/spawn.rs @@ -540,7 +540,7 @@ pub async fn spawn_node( // If log is an output, we're sending the logs to the dataflow if let Some(stdout_output_name) = &send_stdout_to { // Convert logs to DataMessage - let array = message.into_arrow(); + let array = message.clone().into_arrow(); let array: ArrayData = array.into(); let total_len = required_data_size(&array); diff --git a/examples/openai-server/dataflow-rust.yml b/examples/openai-server/dataflow-rust.yml index 8c6a1d8d..85668b5a 100644 --- a/examples/openai-server/dataflow-rust.yml +++ b/examples/openai-server/dataflow-rust.yml @@ -3,14 +3,14 @@ nodes: build: cargo build -p dora-openai-proxy-server --release path: ../../target/release/dora-openai-proxy-server outputs: - - chat_completion_request + - text inputs: - completion_reply: dora-echo/echo + text: dora-echo/echo - id: dora-echo build: pip install -e ../../node-hub/dora-echo path: dora-echo inputs: - echo: dora-openai-server/chat_completion_request + echo: dora-openai-server/text outputs: - echo diff --git a/examples/openai-server/openai_api_client.py b/examples/openai-server/openai_api_client.py index 0a88d5b1..1d81307b 100644 --- a/examples/openai-server/openai_api_client.py +++ b/examples/openai-server/openai_api_client.py @@ -32,11 +32,69 @@ def test_chat_completion(user_input): print(f"Error in chat completion: {e}") +def test_chat_completion_image_url(user_input): + """TODO: Add docstring.""" + try: + response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + }, + }, + ], + } + ], + ) + 
print("Chat Completion Response:") + print(response.choices[0].message.content) + except Exception as e: + print(f"Error in chat completion: {e}") + + +def test_chat_completion_image_base64(user_input): + """TODO: Add docstring.""" + try: + response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBapySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnxBwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXrCDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQDry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPsgxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96CutRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOMOVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWquaZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYSUb3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6EhOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oWVeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmHrwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66PfyuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UNz8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII=" + }, + }, + ], + } + ], + ) + print("Chat Completion Response:") + print(response.choices[0].message.content) + except Exception as e: + print(f"Error in chat completion: {e}") + + if __name__ == "__main__": print("Testing API endpoints...") - test_list_models() + # test_list_models() print("\n" + "=" * 50 + "\n") chat_input = input("Enter a message for chat completion: ") test_chat_completion(chat_input) + + print("\n" + "=" * 50 + "\n") + + test_chat_completion_image_url(chat_input) + print("\n" + "=" * 50 + "\n") + test_chat_completion_image_base64(chat_input) print("\n" + "=" * 50 + "\n") diff --git a/examples/openai-server/qwenvl.yml b/examples/openai-server/qwenvl.yml new file mode 100644 index 00000000..b737b3be --- /dev/null +++ b/examples/openai-server/qwenvl.yml @@ -0,0 +1,16 @@ +nodes: + - id: dora-openai-server + build: cargo build -p dora-openai-proxy-server --release + path: ../../target/release/dora-openai-proxy-server + outputs: + - text + inputs: + text: dora-qwen2.5-vl/text + + - id: dora-qwen2.5-vl + build: pip install -e ../../node-hub/dora-qwen2-5-vl + path: dora-qwen2-5-vl + inputs: + text: dora-openai-server/text + outputs: + - text diff --git a/libraries/arrow-convert/src/into_impls.rs b/libraries/arrow-convert/src/into_impls.rs index a8434694..8d8a7dd1 100644 --- a/libraries/arrow-convert/src/into_impls.rs +++ b/libraries/arrow-convert/src/into_impls.rs @@ -57,6 +57,20 @@ impl IntoArrow for &str { } } +impl IntoArrow for String { + type A = StringArray; + fn into_arrow(self) -> Self::A { + std::iter::once(Some(self)).collect() + } +} + +impl IntoArrow for Vec { + type A = StringArray; + fn into_arrow(self) -> Self::A { + StringArray::from(self) + } +} + impl IntoArrow for () { type A = arrow::array::NullArray; diff --git 
a/node-hub/dora-mistral-rs/src/main.rs b/node-hub/dora-mistral-rs/src/main.rs index bb451e1e..a6beae37 100644 --- a/node-hub/dora-mistral-rs/src/main.rs +++ b/node-hub/dora-mistral-rs/src/main.rs @@ -41,7 +41,7 @@ async fn main() -> eyre::Result<()> { node.send_output( mistral_output.clone(), metadata.parameters, - output.into_arrow(), + output.as_str().into_arrow(), )?; } other => eprintln!("Received input `{other}`"), diff --git a/node-hub/dora-openai-server/dora_openai_server/main.py b/node-hub/dora-openai-server/dora_openai_server/main.py index aa4c25b8..e1713392 100644 --- a/node-hub/dora-openai-server/dora_openai_server/main.py +++ b/node-hub/dora-openai-server/dora_openai_server/main.py @@ -1,140 +1,389 @@ -"""TODO: Add docstring.""" +"""FastAPI server with OpenAI compatibility and DORA integration, +sending text and image data on separate DORA topics. +""" -import ast import asyncio -from typing import List, Optional +import base64 +import time # For timestamps +import uuid # For generating unique request IDs +from typing import Any, List, Literal, Optional, Union import pyarrow as pa import uvicorn from dora import Node -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException from pydantic import BaseModel -DORA_RESPONSE_TIMEOUT = 10 -app = FastAPI() +# --- DORA Configuration --- +DORA_RESPONSE_TIMEOUT_SECONDS = 20 +DORA_TEXT_OUTPUT_TOPIC = "user_text_input" +DORA_IMAGE_OUTPUT_TOPIC = "user_image_input" +DORA_RESPONSE_INPUT_TOPIC = "chat_completion_result" # Topic FastAPI listens on +app = FastAPI( + title="DORA OpenAI-Compatible Demo Server (Separate Topics)", + description="Sends text and image data on different DORA topics and awaits a consolidated response.", +) + + +# --- Pydantic Models --- +class ImageUrl(BaseModel): + url: str + detail: Optional[str] = "auto" + + +class ContentPartText(BaseModel): + type: Literal["text"] + text: str + + +class ContentPartImage(BaseModel): + type: Literal["image_url"] + image_url: ImageUrl -class ChatCompletionMessage(BaseModel): - """TODO: Add docstring.""" +ContentPart = Union[ContentPartText, ContentPartImage] + + +class ChatCompletionMessage(BaseModel): role: str - content: str + content: Union[str, List[ContentPart]] class ChatCompletionRequest(BaseModel): - """TODO: Add docstring.""" - model: str messages: List[ChatCompletionMessage] temperature: Optional[float] = 1.0 max_tokens: Optional[int] = 100 -class ChatCompletionResponse(BaseModel): - """TODO: Add docstring.""" +class ChatCompletionChoiceMessage(BaseModel): + role: str + content: str + + +class ChatCompletionChoice(BaseModel): + index: int + message: ChatCompletionChoiceMessage + finish_reason: str + logprobs: Optional[Any] = None + + +class Usage(BaseModel): + prompt_tokens: int + completion_tokens: int + total_tokens: int + +class ChatCompletionResponse(BaseModel): id: str - object: str + object: str = "chat.completion" created: int model: str - choices: List[dict] - usage: dict + choices: List[ChatCompletionChoice] + usage: Usage + system_fingerprint: Optional[str] = None + +# --- DORA Node Initialization --- +# This dictionary will hold unmatched responses if we implement more robust concurrent handling. +# For now, it's a placeholder for future improvement. +# unmatched_dora_responses = {} -node = Node() # provide the name to connect to the dataflow if dynamic node +try: + node = Node() + print("FastAPI Server: DORA Node initialized.") +except Exception as e: + print( + f"FastAPI Server: Failed to initialize DORA Node. Running in standalone API mode. 
Error: {e}" + ) + node = None -@app.post("/v1/chat/completions") +@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) async def create_chat_completion(request: ChatCompletionRequest): - """TODO: Add docstring.""" - data = next( - (msg.content for msg in request.messages if msg.role == "user"), - "No user message found.", - ) + internal_request_id = str(uuid.uuid4()) + openai_chat_id = f"chatcmpl-{internal_request_id}" + current_timestamp = int(time.time()) - # Convert user_message to Arrow array - # user_message_array = pa.array([user_message]) - # Publish user message to dora-echo - # node.send_output("user_query", user_message_array) + print(f"FastAPI Server: Processing request_id: {internal_request_id}") - try: - data = ast.literal_eval(data) - except ValueError: - print("Passing input as string") - except SyntaxError: - print("Passing input as string") - if isinstance(data, list): - data = pa.array(data) # initialize pyarrow array - elif isinstance(data, str) or isinstance(data, int) or isinstance(data, float) or isinstance(data, dict): - data = pa.array([data]) + user_text_parts = [] + user_image_bytes: Optional[bytes] = None + user_image_content_type: Optional[str] = None + data_sent_to_dora = False + + for message in reversed(request.messages): + if message.role == "user": + if isinstance(message.content, str): + user_text_parts.append(message.content) + elif isinstance(message.content, list): + for part in message.content: + if part.type == "text": + user_text_parts.append(part.text) + elif part.type == "image_url": + if user_image_bytes: # Use only the first image + print( + f"FastAPI Server (Req {internal_request_id}): Warning - Multiple images found, using the first one." + ) + continue + image_url_data = part.image_url.url + if image_url_data.startswith("data:image"): + try: + header, encoded_data = image_url_data.split(",", 1) + user_image_content_type = header.split(":")[1].split( + ";" + )[0] + user_image_bytes = base64.b64decode(encoded_data) + print( + f"FastAPI Server (Req {internal_request_id}): Decoded image {user_image_content_type}, size: {len(user_image_bytes)} bytes" + ) + except Exception as e: + print( + f"FastAPI Server (Req {internal_request_id}): Error decoding base64 image: {e}" + ) + raise HTTPException( + status_code=400, + detail=f"Invalid base64 image data: {e}", + ) + else: + print( + f"FastAPI Server (Req {internal_request_id}): Warning - Remote image URL '{image_url_data}' ignored. Only data URIs supported." + ) + # Consider if you want to break after the first user message or aggregate all + # break + + final_user_text = "\n".join(reversed(user_text_parts)) if user_text_parts else "" + prompt_tokens = len(final_user_text) + + if node: + if final_user_text: + text_payload = {"request_id": internal_request_id, "text": final_user_text} + arrow_text_data = pa.array([text_payload]) + node.send_output(DORA_TEXT_OUTPUT_TOPIC, arrow_text_data) + print( + f"FastAPI Server (Req {internal_request_id}): Sent text to DORA topic '{DORA_TEXT_OUTPUT_TOPIC}'." + ) + data_sent_to_dora = True + + if user_image_bytes: + image_payload = { + "request_id": internal_request_id, + "image_bytes": user_image_bytes, + "image_content_type": user_image_content_type + or "application/octet-stream", + } + arrow_image_data = pa.array([image_payload]) + node.send_output(DORA_IMAGE_OUTPUT_TOPIC, arrow_image_data) + print( + f"FastAPI Server (Req {internal_request_id}): Sent image to DORA topic '{DORA_IMAGE_OUTPUT_TOPIC}'." 
+ ) + prompt_tokens += len(user_image_bytes) # Crude image token approximation + data_sent_to_dora = True + + response_str: str + if not data_sent_to_dora: + if node is None: + response_str = "DORA node not available. Cannot process request." + else: + response_str = "No user text or image found to send to DORA." + print(f"FastAPI Server (Req {internal_request_id}): {response_str}") else: - data = pa.array(data) # initialize pyarrow array - node.send_output("v1/chat/completions", data) - - # Wait for response from dora-echo - while True: - event = node.next(timeout=DORA_RESPONSE_TIMEOUT) - if event["type"] == "ERROR": - response_str = "No response received. Err: " + event["value"][0].as_py() - break - if event["type"] == "INPUT" and event["id"] == "v1/chat/completions": - response = event["value"] - response_str = response[0].as_py() if response else "No response received" - break + print( + f"FastAPI Server (Req {internal_request_id}): Waiting for response from DORA on topic '{DORA_RESPONSE_INPUT_TOPIC}'..." + ) + response_str = f"Timeout: No response from DORA for request_id {internal_request_id} within {DORA_RESPONSE_TIMEOUT_SECONDS}s." + + # WARNING: This blocking `node.next()` loop is not ideal for highly concurrent requests + # in a single FastAPI worker process, as one request might block others or consume + # a response meant for another if `request_id` matching isn't perfect or fast enough. + # A more robust solution would involve a dedicated listener task and async Futures/Queues. + start_wait_time = time.monotonic() + while time.monotonic() - start_wait_time < DORA_RESPONSE_TIMEOUT_SECONDS: + remaining_timeout = DORA_RESPONSE_TIMEOUT_SECONDS - ( + time.monotonic() - start_wait_time + ) + if remaining_timeout <= 0: + break + + event = node.next( + timeout=min(1.0, remaining_timeout) + ) # Poll with a smaller timeout + + if event is None: # Timeout for this poll iteration + continue + + if event["type"] == "INPUT" and event["id"] == DORA_RESPONSE_INPUT_TOPIC: + response_value_arrow = event["value"] + if response_value_arrow and len(response_value_arrow) > 0: + dora_response_data = response_value_arrow[ + 0 + ].as_py() # Expecting a dict + if isinstance(dora_response_data, dict): + resp_request_id = dora_response_data.get("request_id") + if resp_request_id == internal_request_id: + response_str = dora_response_data.get( + "response_text", + f"DORA response for {internal_request_id} missing 'response_text'.", + ) + print( + f"FastAPI Server (Req {internal_request_id}): Received correlated DORA response." + ) + break # Correct response received + # This response is for another request. Ideally, store it. + print( + f"FastAPI Server (Req {internal_request_id}): Received DORA response for different request_id '{resp_request_id}'. Discarding and waiting. THIS IS A CONCURRENCY ISSUE." + ) + # unmatched_dora_responses[resp_request_id] = dora_response_data # Example of storing + else: + response_str = f"Unrecognized DORA response format for {internal_request_id}: {str(dora_response_data)[:100]}" + break + else: + response_str = ( + f"Empty response payload from DORA for {internal_request_id}." + ) + break + elif event["type"] == "ERROR": + response_str = f"Error event from DORA for {internal_request_id}: {event.get('value', event.get('error', 'Unknown DORA Error'))}" + print(response_str) + break + else: # Outer while loop timed out + print( + f"FastAPI Server (Req {internal_request_id}): Overall timeout waiting for DORA response." 
+ ) + + completion_tokens = len(response_str) + total_tokens = prompt_tokens + completion_tokens return ChatCompletionResponse( - id="chatcmpl-1234", - object="chat.completion", - created=1234567890, + id=openai_chat_id, + created=current_timestamp, model=request.model, choices=[ - { - "index": 0, - "message": {"role": "assistant", "content": response_str}, - "finish_reason": "stop", - }, + ChatCompletionChoice( + index=0, + message=ChatCompletionChoiceMessage( + role="assistant", content=response_str + ), + finish_reason="stop", + ) ], - usage={ - "prompt_tokens": len(data), - "completion_tokens": len(response_str), - "total_tokens": len(data) + len(response_str), - }, + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ), ) @app.get("/v1/models") async def list_models(): - """TODO: Add docstring.""" return { "object": "list", "data": [ { - "id": "gpt-3.5-turbo", + "id": "dora-multi-stream-vision", "object": "model", - "created": 1677610602, - "owned_by": "openai", + "created": int(time.time()), + "owned_by": "dora-ai", + "permission": [], + "root": "dora-multi-stream-vision", + "parent": None, }, ], } -async def run_fastapi(): - """TODO: Add docstring.""" +async def run_fastapi_server_task(): config = uvicorn.Config(app, host="0.0.0.0", port=8000, log_level="info") server = uvicorn.Server(config) + print("FastAPI Server: Uvicorn server starting.") + await server.serve() + print("FastAPI Server: Uvicorn server stopped.") - server = asyncio.gather(server.serve()) - while True: - await asyncio.sleep(1) - event = node.next(0.001) - if event["type"] == "STOP": - break + +async def run_dora_main_loop_task(): + if not node: + print("FastAPI Server: DORA node not initialized, DORA main loop skipped.") + return + print("FastAPI Server: DORA main loop listener started (for STOP event).") + try: + while True: + # This loop is primarily for the main "STOP" event for the FastAPI node itself. + # Individual request/response cycles are handled within the endpoint. + event = node.next(timeout=1.0) # Check for STOP periodically + if event is None: + await asyncio.sleep(0.01) # Yield control if no event + continue + if event["type"] == "STOP": + print( + "FastAPI Server: DORA STOP event received. Requesting server shutdown." + ) + # Attempt to gracefully shut down Uvicorn + # This is tricky; uvicorn's server.shutdown() or server.should_exit might be better + # For simplicity, we cancel the server task. + for task in asyncio.all_tasks(): + # Identify the server task more reliably if possible + if ( + task.get_coro().__name__ == "serve" + and hasattr(task.get_coro(), "cr_frame") + and isinstance( + task.get_coro().cr_frame.f_locals.get("self"), + uvicorn.Server, + ) + ): + task.cancel() + print( + "FastAPI Server: Uvicorn server task cancellation requested." + ) + break + # Handle other unexpected general inputs/errors for the FastAPI node if necessary + # elif event["type"] == "INPUT": + # print(f"FastAPI Server (DORA Main Loop): Unexpected DORA input on ID '{event['id']}'") + + except asyncio.CancelledError: + print("FastAPI Server: DORA main loop task cancelled.") + except Exception as e: + print(f"FastAPI Server: Error in DORA main loop: {e}") + finally: + print("FastAPI Server: DORA main loop listener finished.") + + +async def main_async_runner(): + server_task = asyncio.create_task(run_fastapi_server_task()) + + # Only run the DORA main loop if the node was initialized. + # This loop is mainly for the STOP event. 
+ dora_listener_task = None + if node: + dora_listener_task = asyncio.create_task(run_dora_main_loop_task()) + tasks_to_wait_for = [server_task, dora_listener_task] + else: + tasks_to_wait_for = [server_task] + + done, pending = await asyncio.wait( + tasks_to_wait_for, + return_when=asyncio.FIRST_COMPLETED, + ) + + for task in pending: + print(f"FastAPI Server: Cancelling pending task: {task.get_name()}") + task.cancel() + + if pending: + await asyncio.gather(*pending, return_exceptions=True) + print("FastAPI Server: Application shutdown complete.") def main(): - """TODO: Add docstring.""" - asyncio.run(run_fastapi()) + print("FastAPI Server: Starting application...") + try: + asyncio.run(main_async_runner()) + except KeyboardInterrupt: + print("FastAPI Server: Keyboard interrupt received. Shutting down.") + finally: + print("FastAPI Server: Exited main function.") if __name__ == "__main__": - asyncio.run(run_fastapi()) + main() diff --git a/node-hub/dora-openai-server/pyproject.toml b/node-hub/dora-openai-server/pyproject.toml index 8b29cec9..6ec73daf 100644 --- a/node-hub/dora-openai-server/pyproject.toml +++ b/node-hub/dora-openai-server/pyproject.toml @@ -2,8 +2,8 @@ name = "dora-openai-server" version = "0.3.11" authors = [ - { name = "Haixuan Xavier Tao", email = "tao.xavier@outlook.com" }, - { name = "Enzo Le Van", email = "dev@enzo-le-van.fr" }, + { name = "Haixuan Xavier Tao", email = "tao.xavier@outlook.com" }, + { name = "Enzo Le Van", email = "dev@enzo-le-van.fr" }, ] description = "Dora OpenAI API Server" license = { text = "MIT" } @@ -11,14 +11,13 @@ readme = "README.md" requires-python = ">=3.8" dependencies = [ - "dora-rs >= 0.3.9", - "numpy < 2.0.0", - "pyarrow >= 5.0.0", - - "fastapi >= 0.115", - "asyncio >= 3.4", - "uvicorn >= 0.31", - "pydantic >= 2.9", + "dora-rs >= 0.3.9", + "numpy < 2.0.0", + "pyarrow >= 5.0.0", + "fastapi >= 0.115", + "asyncio >= 3.4", + "uvicorn >= 0.31", + "pydantic >= 2.9", ] [dependency-groups] @@ -29,7 +28,6 @@ dora-openai-server = "dora_openai_server.main:main" [tool.ruff.lint] extend-select = [ - "D", # pydocstyle "UP", # Ruff's UP rule "PERF", # Ruff's PERF rule "RET", # Ruff's RET rule diff --git a/node-hub/dora-qwen2-5-vl/dora_qwen2_5_vl/main.py b/node-hub/dora-qwen2-5-vl/dora_qwen2_5_vl/main.py index 3125858c..898b444d 100644 --- a/node-hub/dora-qwen2-5-vl/dora_qwen2_5_vl/main.py +++ b/node-hub/dora-qwen2-5-vl/dora_qwen2_5_vl/main.py @@ -62,29 +62,118 @@ if ADAPTER_PATH != "": processor = AutoProcessor.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True) -def generate(frames: dict, question, history, past_key_values=None, image_id=None): +def generate( + frames: dict, texts: list[str], history, past_key_values=None, image_id=None +): """Generate the response to the question given the image using Qwen2 model.""" if image_id is not None: images = [frames[image_id]] else: images = list(frames.values()) - messages = [ - { - "role": "user", - "content": [ + + messages = [] + + for text in texts: + if text.startswith("<|system|>\n"): + messages.append( { - "type": "image", - "image": image, - "resized_height": image.size[1] * IMAGE_RESIZE_RATIO, - "resized_width": image.size[0] * IMAGE_RESIZE_RATIO, + "role": "system", + "content": [ + {"type": "text", "text": text.replace("<|system|>\n", "")}, + ], } - for image in images - ] - + [ - {"type": "text", "text": question}, - ], - }, - ] + ) + elif text.startswith("<|assistant|>\n"): + messages.append( + { + "role": "assistant", + "content": [ + {"type": "text", "text": 
text.replace("<|assistant|>\n", "")}, + ], + } + ) + elif text.startswith("<|tool|>\n"): + messages.append( + { + "role": "tool", + "content": [ + {"type": "text", "text": text.replace("<|tool|>\n", "")}, + ], + } + ) + elif text.startswith("<|user|>\n<|im_start|>\n"): + messages.append( + { + "role": "user", + "content": [ + { + "type": "text", + "text": text.replace("<|user|>\n<|im_start|>\n", ""), + }, + ], + } + ) + elif text.startswith("<|user|>\n<|vision_start|>\n"): + # Handle the case where the text starts with <|user|>\n<|vision_start|> + image_url = text.replace("<|user|>\n<|vision_start|>\n", "") + + # If the last message was from the user, append the image URL to it + if messages[-1]["role"] == "user": + messages[-1]["content"].append( + { + "type": "image", + "image": image_url, + } + ) + else: + messages.append( + { + "role": "user", + "content": [ + { + "type": "image", + "image": image_url, + }, + ], + } + ) + else: + messages.append( + { + "role": "user", + "content": [ + {"type": "text", "text": text}, + ], + } + ) + + # If the last message was from the user, append the image URL to it + if messages[-1]["role"] == "user": + messages[-1]["content"] += [ + { + "type": "image", + "image": image, + "resized_height": image.size[1] * IMAGE_RESIZE_RATIO, + "resized_width": image.size[0] * IMAGE_RESIZE_RATIO, + } + for image in images + ] + else: + messages.append( + { + "role": "user", + "content": [ + { + "type": "image", + "image": image, + "resized_height": image.size[1] * IMAGE_RESIZE_RATIO, + "resized_width": image.size[0] * IMAGE_RESIZE_RATIO, + } + for image in images + ], + } + ) + tmp_history = history + messages # Preparation for inference text = processor.apply_chat_template( @@ -120,19 +209,13 @@ def generate(frames: dict, question, history, past_key_values=None, image_id=Non clean_up_tokenization_spaces=False, ) if HISTORY: - history += [ - { - "role": "user", - "content": [ - {"type": "text", "text": question}, - ], - }, + history = tmp_history + [ { "role": "assistant", "content": [ {"type": "text", "text": output_text[0]}, ], - }, + } ] return output_text[0], history, past_key_values @@ -207,24 +290,22 @@ def main(): elif "text" in event_id: if len(event["value"]) > 0: - text = event["value"][0].as_py() + texts = event["value"].to_pylist() image_id = event["metadata"].get("image_id", None) else: - text = cached_text - words = text.split() + texts = cached_text + words = texts[-1].split() if len(ACTIVATION_WORDS) > 0 and all( word not in ACTIVATION_WORDS for word in words ): continue - cached_text = text + cached_text = texts - if len(frames.keys()) == 0: - continue # set the max number of tiles in `max_num` response, history, past_key_values = generate( frames, - text, + texts, history, past_key_values, image_id, diff --git a/node-hub/openai-proxy-server/src/main.rs b/node-hub/openai-proxy-server/src/main.rs index c0714886..5d0cc4a2 100644 --- a/node-hub/openai-proxy-server/src/main.rs +++ b/node-hub/openai-proxy-server/src/main.rs @@ -1,4 +1,10 @@ -use dora_node_api::{self, dora_core::config::DataId, merged::MergeExternalSend, DoraNode, Event}; +use dora_node_api::{ + self, + arrow::array::{AsArray, StringArray}, + dora_core::config::DataId, + merged::MergeExternalSend, + DoraNode, Event, +}; use eyre::{Context, ContextCompat}; use futures::{ @@ -14,7 +20,7 @@ use hyper::{ }; use message::{ ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage, - ChatCompletionRequest, ChatCompletionRequestMessage, Usage, + ChatCompletionRequest, Usage, 
 };
 use std::{
     collections::VecDeque,
@@ -71,7 +77,7 @@ async fn main() -> eyre::Result<()> {
     let merged = events.merge_external_send(server_events);
     let events = futures::executor::block_on_stream(merged);
-    let output_id = DataId::from("chat_completion_request".to_owned());
+    let output_id = DataId::from("text".to_owned());

     let mut reply_channels = VecDeque::new();

     for event in events {
@@ -82,45 +88,15 @@
                     break;
                 }
                 ServerEvent::ChatCompletionRequest { request, reply } => {
-                    let message = request
-                        .messages
-                        .into_iter()
-                        .find_map(|m| match m {
-                            ChatCompletionRequestMessage::User(message) => Some(message),
-                            _ => None,
-                        })
-                        .context("no user message found");
-                    match message {
-                        Ok(message) => match message.content() {
-                            message::ChatCompletionUserMessageContent::Text(content) => {
-                                node.send_output_bytes(
-                                    output_id.clone(),
-                                    Default::default(),
-                                    content.len(),
-                                    content.as_bytes(),
-                                )
-                                .context("failed to send dora output")?;
-                                reply_channels.push_back((
-                                    reply,
-                                    content.as_bytes().len() as u64,
-                                    request.model,
-                                ));
-                            }
-                            message::ChatCompletionUserMessageContent::Parts(_) => {
-                                if reply
-                                    .send(Err(eyre::eyre!("unsupported message content")))
-                                    .is_err()
-                                {
-                                    tracing::warn!("failed to send chat completion reply because channel closed early");
-                                };
-                            }
-                        },
-                        Err(err) => {
-                            if reply.send(Err(err)).is_err() {
-                                tracing::warn!("failed to send chat completion reply error because channel closed early");
-                            }
-                        }
-                    }
+                    let texts = request.to_texts();
+                    node.send_output(
+                        output_id.clone(),
+                        Default::default(),
+                        StringArray::from(texts),
+                    )
+                    .context("failed to send dora output")?;
+
+                    reply_channels.push_back((reply, 0 as u64, request.model));
                 }
             },
             dora_node_api::merged::MergedEvent::Dora(event) => match event {
@@ -130,35 +106,42 @@ async fn main() -> eyre::Result<()> {
                     metadata: _,
                 } => {
                     match id.as_str() {
-                        "completion_reply" => {
+                        "text" => {
                            let (reply_channel, prompt_tokens, model) =
                                reply_channels.pop_front().context("no reply channel")?;
-                            let data = TryFrom::try_from(&data)
-                                .with_context(|| format!("invalid reply data: {data:?}"))
-                                .map(|s: &[u8]| ChatCompletionObject {
-                                    id: format!("completion-{}", uuid::Uuid::new_v4()),
-                                    object: "chat.completion".to_string(),
-                                    created: chrono::Utc::now().timestamp() as u64,
-                                    model: model.unwrap_or_default(),
-                                    choices: vec![ChatCompletionObjectChoice {
-                                        index: 0,
-                                        message: ChatCompletionObjectMessage {
-                                            role: message::ChatCompletionRole::Assistant,
-                                            content: Some(String::from_utf8_lossy(s).to_string()),
-                                            tool_calls: Vec::new(),
-                                            function_call: None,
-                                        },
-                                        finish_reason: message::FinishReason::stop,
-                                        logprobs: None,
-                                    }],
-                                    usage: Usage {
-                                        prompt_tokens,
-                                        completion_tokens: s.len() as u64,
-                                        total_tokens: prompt_tokens + s.len() as u64,
+                            let data = data.as_string::<i32>();
+                            let string = data.iter().fold("".to_string(), |mut acc, s| {
+                                if let Some(s) = s {
+                                    acc.push_str("\n");
+                                    acc.push_str(s);
+                                }
+                                acc
+                            });
+
+                            let data = ChatCompletionObject {
+                                id: format!("completion-{}", uuid::Uuid::new_v4()),
+                                object: "chat.completion".to_string(),
+                                created: chrono::Utc::now().timestamp() as u64,
+                                model: model.unwrap_or_default(),
+                                choices: vec![ChatCompletionObjectChoice {
+                                    index: 0,
+                                    message: ChatCompletionObjectMessage {
+                                        role: message::ChatCompletionRole::Assistant,
+                                        content: Some(string.to_string()),
+                                        tool_calls: Vec::new(),
+                                        function_call: None,
                                     },
-                                });
-
-                            if reply_channel.send(data).is_err() {
+                                    finish_reason: message::FinishReason::stop,
+                                    logprobs: None,
+                                }],
+                                usage: Usage {
+                                    prompt_tokens,
+                                    completion_tokens: string.len() as u64,
+                                    total_tokens: prompt_tokens + string.len() as u64,
+                                },
+                            };
+
+                            if reply_channel.send(Ok(data)).is_err() {
                                 tracing::warn!("failed to send chat completion reply because channel closed early");
                             }
                         }
@@ -168,8 +151,11 @@ async fn main() -> eyre::Result<()> {
             Event::Stop => {
                 break;
             }
+            Event::InputClosed { id, .. } => {
+                info!("Input channel closed for id: {}", id);
+            }
             event => {
-                println!("Event: {event:#?}")
+                eyre::bail!("unexpected event: {:#?}", event)
             }
         },
     }
diff --git a/node-hub/openai-proxy-server/src/message.rs b/node-hub/openai-proxy-server/src/message.rs
index dff7e101..4c9eb99f 100644
--- a/node-hub/openai-proxy-server/src/message.rs
+++ b/node-hub/openai-proxy-server/src/message.rs
@@ -230,6 +230,15 @@ impl<'de> Deserialize<'de> for ChatCompletionRequest {
     }
 }

+impl ChatCompletionRequest {
+    pub fn to_texts(&self) -> Vec<String> {
+        self.messages
+            .iter()
+            .flat_map(|message| message.to_texts())
+            .collect()
+    }
+}
+
 /// Message for comprising the conversation.
 #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
 #[serde(tag = "role", rename_all = "lowercase")]
@@ -308,6 +317,22 @@ impl ChatCompletionRequestMessage {
             ChatCompletionRequestMessage::Tool(_) => None,
         }
     }
+
+    /// The contents of the message.
+    pub fn to_texts(&self) -> Vec<String> {
+        match self {
+            ChatCompletionRequestMessage::System(message) => {
+                vec![String::from("<|system|>\n") + &message.content]
+            }
+            ChatCompletionRequestMessage::User(message) => message.content.to_texts(),
+            ChatCompletionRequestMessage::Assistant(message) => {
+                vec![String::from("<|assistant|>\n") + &message.content.clone().unwrap_or_default()]
+            }
+            ChatCompletionRequestMessage::Tool(message) => {
+                vec![String::from("<|tool|>\n") + &message.content.clone()]
+            }
+        }
+    }
 }

 /// Sampling methods used for chat completion requests.
@@ -587,6 +612,25 @@ impl ChatCompletionUserMessageContent {
             ChatCompletionUserMessageContent::Parts(_) => "parts",
         }
     }
+
+    pub fn to_texts(&self) -> Vec<String> {
+        match self {
+            ChatCompletionUserMessageContent::Text(text) => {
+                vec![String::from("user: ") + &text.clone()]
+            }
+            ChatCompletionUserMessageContent::Parts(parts) => parts
+                .iter()
+                .map(|part| match part {
+                    ContentPart::Text(text_part) => {
+                        String::from("<|user|>\n<|im_start|>\n") + &text_part.text.clone()
+                    }
+                    ContentPart::Image(image) => {
+                        String::from("<|user|>\n<|vision_start|>\n") + &image.image().url.clone()
+                    }
+                })
+                .collect(),
+        }
+    }
 }

 /// Define the content part of a user message.
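
Note on the wire format introduced above: `ChatCompletionRequest::to_texts()` flattens every chat message into a role-prefixed string (`<|system|>\n`, `<|assistant|>\n`, `<|tool|>\n`, `<|user|>\n<|im_start|>\n` for user text parts, `<|user|>\n<|vision_start|>\n` for image URLs, and a plain `user: ` prefix for simple string content), and the whole request is sent as one Arrow StringArray on the `text` output. The sketch below is illustrative only, not part of the diff: it shows how a downstream Python node might split those strings back into role/content pairs, mirroring the prefix handling added in dora-qwen2-5-vl/main.py (the helper name `split_role` is made up).

# Illustrative sketch: parse the role-prefixed strings emitted by to_texts().
# The prefixes are exactly the ones produced in message.rs above.
ROLE_PREFIXES = [
    ("<|user|>\n<|vision_start|>\n", "user_image_url"),
    ("<|user|>\n<|im_start|>\n", "user_text"),
    ("<|system|>\n", "system"),
    ("<|assistant|>\n", "assistant"),
    ("<|tool|>\n", "tool"),
]

def split_role(text):
    """Return a (role, content) pair for one element of the StringArray."""
    for prefix, role in ROLE_PREFIXES:
        if text.startswith(prefix):
            return role, text[len(prefix):]
    # Plain user strings arrive as "user: ..." (see ChatCompletionUserMessageContent::to_texts).
    return "user_text", text

# Example: a text part plus an image URL from one user message would yield
# [("user_text", "What is in this image?"), ("user_image_url", "https://...")]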