Browse Source

Adding vision to openai server

tags/v0.3.12-rc0
haixuantao haixuantao 7 months ago
parent
commit
671c70e254
8 changed files with 544 additions and 171 deletions
  1. +3
    -3
      examples/openai-server/dataflow-rust.yml
  2. +59
    -1
      examples/openai-server/openai_api_client.py
  3. +16
    -0
      examples/openai-server/qwenvl.yml
  4. +14
    -0
      libraries/arrow-convert/src/into_impls.rs
  5. +261
    -84
      node-hub/dora-openai-server/dora_openai_server/main.py
  6. +96
    -15
      node-hub/dora-qwen2-5-vl/dora_qwen2_5_vl/main.py
  7. +51
    -68
      node-hub/openai-proxy-server/src/main.rs
  8. +44
    -0
      node-hub/openai-proxy-server/src/message.rs

+ 3
- 3
examples/openai-server/dataflow-rust.yml View File

@@ -3,14 +3,14 @@ nodes:
build: cargo build -p dora-openai-proxy-server --release
path: ../../target/release/dora-openai-proxy-server
outputs:
- chat_completion_request
- text
inputs:
completion_reply: dora-echo/echo
text: dora-echo/echo

- id: dora-echo
build: pip install -e ../../node-hub/dora-echo
path: dora-echo
inputs:
echo: dora-openai-server/chat_completion_request
echo: dora-openai-server/text
outputs:
- echo

+ 59
- 1
examples/openai-server/openai_api_client.py View File

@@ -32,11 +32,69 @@ def test_chat_completion(user_input):
print(f"Error in chat completion: {e}")


def test_chat_completion_image_url(user_input):
"""TODO: Add docstring."""
try:
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What is in this image?"},
{
"type": "image_url",
"image_url": {
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
},
},
],
}
],
)
print("Chat Completion Response:")
print(response.choices[0].message.content)
except Exception as e:
print(f"Error in chat completion: {e}")


def test_chat_completion_image_base64(user_input):
"""TODO: Add docstring."""
try:
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What is in this image?"},
{
"type": "image_url",
"image_url": {
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBapySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnxBwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXrCDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQDry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPsgxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96CutRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOMOVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWquaZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYSUb3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6EhOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oWVeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmHrwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66PfyuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UNz8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII="
},
},
],
}
],
)
print("Chat Completion Response:")
print(response.choices[0].message.content)
except Exception as e:
print(f"Error in chat completion: {e}")


if __name__ == "__main__":
print("Testing API endpoints...")
test_list_models()
# test_list_models()
print("\n" + "=" * 50 + "\n")

chat_input = input("Enter a message for chat completion: ")
test_chat_completion(chat_input)

print("\n" + "=" * 50 + "\n")

test_chat_completion_image_url(chat_input)
print("\n" + "=" * 50 + "\n")
test_chat_completion_image_base64(chat_input)
print("\n" + "=" * 50 + "\n")

+ 16
- 0
examples/openai-server/qwenvl.yml View File

@@ -0,0 +1,16 @@
nodes:
- id: dora-openai-server
build: cargo build -p dora-openai-proxy-server --release
path: ../../target/release/dora-openai-proxy-server
outputs:
- text
inputs:
text: dora-qwen2.5-vl/text

- id: dora-qwen2.5-vl
build: pip install -e ../../node-hub/dora-qwen2-5-vl
path: dora-qwen2-5-vl
inputs:
text: dora-openai-server/text
outputs:
- text

+ 14
- 0
libraries/arrow-convert/src/into_impls.rs View File

@@ -57,6 +57,20 @@ impl IntoArrow for &str {
}
}

impl IntoArrow for String {
type A = StringArray;
fn into_arrow(self) -> Self::A {
std::iter::once(Some(self)).collect()
}
}

impl IntoArrow for Vec<String> {
type A = StringArray;
fn into_arrow(self) -> Self::A {
StringArray::from(self)
}
}

impl IntoArrow for () {
type A = arrow::array::NullArray;



+ 261
- 84
node-hub/dora-openai-server/dora_openai_server/main.py View File

@@ -1,140 +1,317 @@
"""TODO: Add docstring."""
"""
FastAPI server with OpenAI compatibility and DORA integration,
sending text and image data on separate DORA topics.
"""

import ast
import asyncio
from typing import List, Optional
import base64
import uuid # For generating unique request IDs
import time # For timestamps
from typing import List, Optional, Union, Dict, Any, Literal

import pyarrow as pa
import uvicorn
from dora import Node
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

DORA_RESPONSE_TIMEOUT = 10
app = FastAPI()
# --- DORA Configuration ---
DORA_RESPONSE_TIMEOUT_SECONDS = 20
DORA_TEXT_OUTPUT_TOPIC = "user_text_input"
DORA_IMAGE_OUTPUT_TOPIC = "user_image_input"
DORA_RESPONSE_INPUT_TOPIC = "chat_completion_result" # Topic FastAPI listens on

app = FastAPI(
title="DORA OpenAI-Compatible Demo Server (Separate Topics)",
description="Sends text and image data on different DORA topics and awaits a consolidated response.",
)

class ChatCompletionMessage(BaseModel):
"""TODO: Add docstring."""
# --- Pydantic Models ---
class ImageUrl(BaseModel):
url: str
detail: Optional[str] = "auto"

role: str
content: str
class ContentPartText(BaseModel):
type: Literal["text"]
text: str

class ContentPartImage(BaseModel):
type: Literal["image_url"]
image_url: ImageUrl

class ChatCompletionRequest(BaseModel):
"""TODO: Add docstring."""
ContentPart = Union[ContentPartText, ContentPartImage]

class ChatCompletionMessage(BaseModel):
role: str
content: Union[str, List[ContentPart]]

class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatCompletionMessage]
temperature: Optional[float] = 1.0
max_tokens: Optional[int] = 100

class ChatCompletionChoiceMessage(BaseModel):
role: str
content: str

class ChatCompletionResponse(BaseModel):
"""TODO: Add docstring."""
class ChatCompletionChoice(BaseModel):
index: int
message: ChatCompletionChoiceMessage
finish_reason: str
logprobs: Optional[Any] = None

class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int

class ChatCompletionResponse(BaseModel):
id: str
object: str
object: str = "chat.completion"
created: int
model: str
choices: List[dict]
usage: dict
choices: List[ChatCompletionChoice]
usage: Usage
system_fingerprint: Optional[str] = None

# --- DORA Node Initialization ---
# This dictionary will hold unmatched responses if we implement more robust concurrent handling.
# For now, it's a placeholder for future improvement.
# unmatched_dora_responses = {}

node = Node() # provide the name to connect to the dataflow if dynamic node
try:
node = Node()
print("FastAPI Server: DORA Node initialized.")
except Exception as e:
print(f"FastAPI Server: Failed to initialize DORA Node. Running in standalone API mode. Error: {e}")
node = None


@app.post("/v1/chat/completions")
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
"""TODO: Add docstring."""
data = next(
(msg.content for msg in request.messages if msg.role == "user"),
"No user message found.",
)
internal_request_id = str(uuid.uuid4())
openai_chat_id = f"chatcmpl-{internal_request_id}"
current_timestamp = int(time.time())

# Convert user_message to Arrow array
# user_message_array = pa.array([user_message])
# Publish user message to dora-echo
# node.send_output("user_query", user_message_array)
print(f"FastAPI Server: Processing request_id: {internal_request_id}")

try:
data = ast.literal_eval(data)
except ValueError:
print("Passing input as string")
except SyntaxError:
print("Passing input as string")
if isinstance(data, list):
data = pa.array(data) # initialize pyarrow array
elif isinstance(data, str) or isinstance(data, int) or isinstance(data, float) or isinstance(data, dict):
data = pa.array([data])
user_text_parts = []
user_image_bytes: Optional[bytes] = None
user_image_content_type: Optional[str] = None
data_sent_to_dora = False

for message in reversed(request.messages):
if message.role == "user":
if isinstance(message.content, str):
user_text_parts.append(message.content)
elif isinstance(message.content, list):
for part in message.content:
if part.type == "text":
user_text_parts.append(part.text)
elif part.type == "image_url":
if user_image_bytes: # Use only the first image
print(f"FastAPI Server (Req {internal_request_id}): Warning - Multiple images found, using the first one.")
continue
image_url_data = part.image_url.url
if image_url_data.startswith("data:image"):
try:
header, encoded_data = image_url_data.split(",", 1)
user_image_content_type = header.split(":")[1].split(";")[0]
user_image_bytes = base64.b64decode(encoded_data)
print(f"FastAPI Server (Req {internal_request_id}): Decoded image {user_image_content_type}, size: {len(user_image_bytes)} bytes")
except Exception as e:
print(f"FastAPI Server (Req {internal_request_id}): Error decoding base64 image: {e}")
raise HTTPException(status_code=400, detail=f"Invalid base64 image data: {e}")
else:
print(f"FastAPI Server (Req {internal_request_id}): Warning - Remote image URL '{image_url_data}' ignored. Only data URIs supported.")
# Consider if you want to break after the first user message or aggregate all
# break

final_user_text = "\n".join(reversed(user_text_parts)) if user_text_parts else ""
prompt_tokens = len(final_user_text)

if node:
if final_user_text:
text_payload = {"request_id": internal_request_id, "text": final_user_text}
arrow_text_data = pa.array([text_payload])
node.send_output(DORA_TEXT_OUTPUT_TOPIC, arrow_text_data)
print(f"FastAPI Server (Req {internal_request_id}): Sent text to DORA topic '{DORA_TEXT_OUTPUT_TOPIC}'.")
data_sent_to_dora = True

if user_image_bytes:
image_payload = {
"request_id": internal_request_id,
"image_bytes": user_image_bytes,
"image_content_type": user_image_content_type or "application/octet-stream"
}
arrow_image_data = pa.array([image_payload])
node.send_output(DORA_IMAGE_OUTPUT_TOPIC, arrow_image_data)
print(f"FastAPI Server (Req {internal_request_id}): Sent image to DORA topic '{DORA_IMAGE_OUTPUT_TOPIC}'.")
prompt_tokens += len(user_image_bytes) # Crude image token approximation
data_sent_to_dora = True

response_str: str
if not data_sent_to_dora:
if node is None:
response_str = "DORA node not available. Cannot process request."
else:
response_str = "No user text or image found to send to DORA."
print(f"FastAPI Server (Req {internal_request_id}): {response_str}")
else:
data = pa.array(data) # initialize pyarrow array
node.send_output("v1/chat/completions", data)

# Wait for response from dora-echo
while True:
event = node.next(timeout=DORA_RESPONSE_TIMEOUT)
if event["type"] == "ERROR":
response_str = "No response received. Err: " + event["value"][0].as_py()
break
if event["type"] == "INPUT" and event["id"] == "v1/chat/completions":
response = event["value"]
response_str = response[0].as_py() if response else "No response received"
break
print(f"FastAPI Server (Req {internal_request_id}): Waiting for response from DORA on topic '{DORA_RESPONSE_INPUT_TOPIC}'...")
response_str = f"Timeout: No response from DORA for request_id {internal_request_id} within {DORA_RESPONSE_TIMEOUT_SECONDS}s."
# WARNING: This blocking `node.next()` loop is not ideal for highly concurrent requests
# in a single FastAPI worker process, as one request might block others or consume
# a response meant for another if `request_id` matching isn't perfect or fast enough.
# A more robust solution would involve a dedicated listener task and async Futures/Queues.
start_wait_time = time.monotonic()
while time.monotonic() - start_wait_time < DORA_RESPONSE_TIMEOUT_SECONDS:
remaining_timeout = DORA_RESPONSE_TIMEOUT_SECONDS - (time.monotonic() - start_wait_time)
if remaining_timeout <= 0: break

event = node.next(timeout=min(1.0, remaining_timeout)) # Poll with a smaller timeout
if event is None: # Timeout for this poll iteration
continue

if event["type"] == "INPUT" and event["id"] == DORA_RESPONSE_INPUT_TOPIC:
response_value_arrow = event["value"]
if response_value_arrow and len(response_value_arrow) > 0:
dora_response_data = response_value_arrow[0].as_py() # Expecting a dict
if isinstance(dora_response_data, dict):
resp_request_id = dora_response_data.get("request_id")
if resp_request_id == internal_request_id:
response_str = dora_response_data.get("response_text", f"DORA response for {internal_request_id} missing 'response_text'.")
print(f"FastAPI Server (Req {internal_request_id}): Received correlated DORA response.")
break # Correct response received
else:
# This response is for another request. Ideally, store it.
print(f"FastAPI Server (Req {internal_request_id}): Received DORA response for different request_id '{resp_request_id}'. Discarding and waiting. THIS IS A CONCURRENCY ISSUE.")
# unmatched_dora_responses[resp_request_id] = dora_response_data # Example of storing
else:
response_str = f"Unrecognized DORA response format for {internal_request_id}: {str(dora_response_data)[:100]}"
break
else:
response_str = f"Empty response payload from DORA for {internal_request_id}."
break
elif event["type"] == "ERROR":
response_str = f"Error event from DORA for {internal_request_id}: {event.get('value', event.get('error', 'Unknown DORA Error'))}"
print(response_str)
break
else: # Outer while loop timed out
print(f"FastAPI Server (Req {internal_request_id}): Overall timeout waiting for DORA response.")


completion_tokens = len(response_str)
total_tokens = prompt_tokens + completion_tokens

return ChatCompletionResponse(
id="chatcmpl-1234",
object="chat.completion",
created=1234567890,
id=openai_chat_id,
created=current_timestamp,
model=request.model,
choices=[
{
"index": 0,
"message": {"role": "assistant", "content": response_str},
"finish_reason": "stop",
},
ChatCompletionChoice(
index=0,
message=ChatCompletionChoiceMessage(role="assistant", content=response_str),
finish_reason="stop",
)
],
usage={
"prompt_tokens": len(data),
"completion_tokens": len(response_str),
"total_tokens": len(data) + len(response_str),
},
usage=Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
),
)


@app.get("/v1/models")
async def list_models():
"""TODO: Add docstring."""
return {
"object": "list",
"data": [
{
"id": "gpt-3.5-turbo",
"object": "model",
"created": 1677610602,
"owned_by": "openai",
"id": "dora-multi-stream-vision",
"object": "model", "created": int(time.time()), "owned_by": "dora-ai",
"permission": [], "root": "dora-multi-stream-vision", "parent": None,
},
],
}


async def run_fastapi():
"""TODO: Add docstring."""
async def run_fastapi_server_task():
config = uvicorn.Config(app, host="0.0.0.0", port=8000, log_level="info")
server = uvicorn.Server(config)
print("FastAPI Server: Uvicorn server starting.")
await server.serve()
print("FastAPI Server: Uvicorn server stopped.")

server = asyncio.gather(server.serve())
while True:
await asyncio.sleep(1)
event = node.next(0.001)
if event["type"] == "STOP":
break
async def run_dora_main_loop_task():
if not node:
print("FastAPI Server: DORA node not initialized, DORA main loop skipped.")
return
print("FastAPI Server: DORA main loop listener started (for STOP event).")
try:
while True:
# This loop is primarily for the main "STOP" event for the FastAPI node itself.
# Individual request/response cycles are handled within the endpoint.
event = node.next(timeout=1.0) # Check for STOP periodically
if event is None:
await asyncio.sleep(0.01) # Yield control if no event
continue
if event["type"] == "STOP":
print("FastAPI Server: DORA STOP event received. Requesting server shutdown.")
# Attempt to gracefully shut down Uvicorn
# This is tricky; uvicorn's server.shutdown() or server.should_exit might be better
# For simplicity, we cancel the server task.
for task in asyncio.all_tasks():
# Identify the server task more reliably if possible
if task.get_coro().__name__ == 'serve' and hasattr(task.get_coro(), 'cr_frame') and \
isinstance(task.get_coro().cr_frame.f_locals.get('self'), uvicorn.Server):
task.cancel()
print("FastAPI Server: Uvicorn server task cancellation requested.")
break
# Handle other unexpected general inputs/errors for the FastAPI node if necessary
# elif event["type"] == "INPUT":
# print(f"FastAPI Server (DORA Main Loop): Unexpected DORA input on ID '{event['id']}'")

except asyncio.CancelledError:
print("FastAPI Server: DORA main loop task cancelled.")
except Exception as e:
print(f"FastAPI Server: Error in DORA main loop: {e}")
finally:
print("FastAPI Server: DORA main loop listener finished.")

def main():
"""TODO: Add docstring."""
asyncio.run(run_fastapi())
async def main_async_runner():
server_task = asyncio.create_task(run_fastapi_server_task())
# Only run the DORA main loop if the node was initialized.
# This loop is mainly for the STOP event.
dora_listener_task = None
if node:
dora_listener_task = asyncio.create_task(run_dora_main_loop_task())
tasks_to_wait_for = [server_task, dora_listener_task]
else:
tasks_to_wait_for = [server_task]

done, pending = await asyncio.wait(
tasks_to_wait_for, return_when=asyncio.FIRST_COMPLETED,
)

for task in pending:
print(f"FastAPI Server: Cancelling pending task: {task.get_name()}")
task.cancel()
if pending:
await asyncio.gather(*pending, return_exceptions=True)
print("FastAPI Server: Application shutdown complete.")

def main():
print("FastAPI Server: Starting application...")
try:
asyncio.run(main_async_runner())
except KeyboardInterrupt:
print("FastAPI Server: Keyboard interrupt received. Shutting down.")
finally:
print("FastAPI Server: Exited main function.")

if __name__ == "__main__":
asyncio.run(run_fastapi())
main()

+ 96
- 15
node-hub/dora-qwen2-5-vl/dora_qwen2_5_vl/main.py View File

@@ -62,14 +62,100 @@ if ADAPTER_PATH != "":
processor = AutoProcessor.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)


def generate(frames: dict, question, history, past_key_values=None, image_id=None):
def generate(frames: dict, texts: list[str], history, past_key_values=None, image_id=None):
"""Generate the response to the question given the image using Qwen2 model."""
if image_id is not None:
images = [frames[image_id]]
else:
images = list(frames.values())
messages = [
{

messages = []

for text in texts:
if text.startswith("<|system|>\n"):
messages.append(
{
"role": "system",
"content": [
{"type": "text", "text": text.replace("<|system|>\n", "")},
],
}
)
elif text.startswith("<|assistant|>\n"):
messages.append(
{
"role": "assistant",
"content": [
{"type": "text", "text": text.replace("<|assistant|>\n", "")},
],
}
)
elif text.startswith("<|tool|>\n"):
messages.append(
{
"role": "tool",
"content": [
{"type": "text", "text": text.replace("<|tool|>\n", "")},
],
}
)
elif text.startswith("<|user|>\n<|im_start|>\n"):
messages.append(
{
"role": "user",
"content": [
{"type": "text", "text": text.replace("<|user|>\n<|im_start|>\n", "")},
],
}
)
elif text.startswith("<|user|>\n<|vision_start|>\n"):
# Handle the case where the text starts with <|user|>\n<|vision_start|>
image_url = text.replace("<|user|>\n<|vision_start|>\n", "")
# If the last message was from the user, append the image URL to it
if messages[-1]["role"] == "user" :
messages[-1]["content"].append(
{
"type": "image",
"image": image_url,
}
)
else:
messages.append(
{
"role": "user",
"content": [
{
"type": "image",
"image": image_url,
},
],
}
)
else:
messages.append(
{
"role": "user",
"content": [
{"type": "text", "text": text},
],
}
)

# If the last message was from the user, append the image URL to it
if messages[-1]["role"] == "user" :
messages[-1]["content"] += [
{
"type": "image",
"image": image,
"resized_height": image.size[1] * IMAGE_RESIZE_RATIO,
"resized_width": image.size[0] * IMAGE_RESIZE_RATIO,
}
for image in images
]
else:
messages.append(
{
"role": "user",
"content": [
{
@@ -80,11 +166,8 @@ def generate(frames: dict, question, history, past_key_values=None, image_id=Non
}
for image in images
]
+ [
{"type": "text", "text": question},
],
},
]
})
tmp_history = history + messages
# Preparation for inference
text = processor.apply_chat_template(
@@ -207,24 +290,22 @@ def main():

elif "text" in event_id:
if len(event["value"]) > 0:
text = event["value"][0].as_py()
texts = event["value"].to_pylist()
image_id = event["metadata"].get("image_id", None)
else:
text = cached_text
words = text.split()
texts = cached_text
words = texts[-1].split()
if len(ACTIVATION_WORDS) > 0 and all(
word not in ACTIVATION_WORDS for word in words
):
continue

cached_text = text
cached_text = texts

if len(frames.keys()) == 0:
continue
# set the max number of tiles in `max_num`
response, history, past_key_values = generate(
frames,
text,
texts,
history,
past_key_values,
image_id,


+ 51
- 68
node-hub/openai-proxy-server/src/main.rs View File

@@ -1,4 +1,10 @@
use dora_node_api::{self, dora_core::config::DataId, merged::MergeExternalSend, DoraNode, Event};
use dora_node_api::{
self,
arrow::array::{AsArray, StringArray},
dora_core::config::DataId,
merged::MergeExternalSend,
DoraNode, Event,
};

use eyre::{Context, ContextCompat};
use futures::{
@@ -14,7 +20,7 @@ use hyper::{
};
use message::{
ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage,
ChatCompletionRequest, ChatCompletionRequestMessage, Usage,
ChatCompletionRequest, Usage,
};
use std::{
collections::VecDeque,
@@ -71,7 +77,7 @@ async fn main() -> eyre::Result<()> {
let merged = events.merge_external_send(server_events);
let events = futures::executor::block_on_stream(merged);

let output_id = DataId::from("chat_completion_request".to_owned());
let output_id = DataId::from("text".to_owned());
let mut reply_channels = VecDeque::new();

for event in events {
@@ -82,45 +88,15 @@ async fn main() -> eyre::Result<()> {
break;
}
ServerEvent::ChatCompletionRequest { request, reply } => {
let message = request
.messages
.into_iter()
.find_map(|m| match m {
ChatCompletionRequestMessage::User(message) => Some(message),
_ => None,
})
.context("no user message found");
match message {
Ok(message) => match message.content() {
message::ChatCompletionUserMessageContent::Text(content) => {
node.send_output_bytes(
output_id.clone(),
Default::default(),
content.len(),
content.as_bytes(),
)
.context("failed to send dora output")?;
reply_channels.push_back((
reply,
content.as_bytes().len() as u64,
request.model,
));
}
message::ChatCompletionUserMessageContent::Parts(_) => {
if reply
.send(Err(eyre::eyre!("unsupported message content")))
.is_err()
{
tracing::warn!("failed to send chat completion reply because channel closed early");
};
}
},
Err(err) => {
if reply.send(Err(err)).is_err() {
tracing::warn!("failed to send chat completion reply error because channel closed early");
}
}
}
let texts = request.to_texts();
node.send_output(
output_id.clone(),
Default::default(),
StringArray::from(texts),
)
.context("failed to send dora output")?;

reply_channels.push_back((reply, 0 as u64, request.model));
}
},
dora_node_api::merged::MergedEvent::Dora(event) => match event {
@@ -130,35 +106,42 @@ async fn main() -> eyre::Result<()> {
metadata: _,
} => {
match id.as_str() {
"completion_reply" => {
"text" => {
let (reply_channel, prompt_tokens, model) =
reply_channels.pop_front().context("no reply channel")?;
let data = TryFrom::try_from(&data)
.with_context(|| format!("invalid reply data: {data:?}"))
.map(|s: &[u8]| ChatCompletionObject {
id: format!("completion-{}", uuid::Uuid::new_v4()),
object: "chat.completion".to_string(),
created: chrono::Utc::now().timestamp() as u64,
model: model.unwrap_or_default(),
choices: vec![ChatCompletionObjectChoice {
index: 0,
message: ChatCompletionObjectMessage {
role: message::ChatCompletionRole::Assistant,
content: Some(String::from_utf8_lossy(s).to_string()),
tool_calls: Vec::new(),
function_call: None,
},
finish_reason: message::FinishReason::stop,
logprobs: None,
}],
usage: Usage {
prompt_tokens,
completion_tokens: s.len() as u64,
total_tokens: prompt_tokens + s.len() as u64,
let data = data.as_string::<i32>();
let string = data.iter().fold("".to_string(), |mut acc, s| {
if let Some(s) = s {
acc.push_str("\n");
acc.push_str(s);
}
acc
});

let data = ChatCompletionObject {
id: format!("completion-{}", uuid::Uuid::new_v4()),
object: "chat.completion".to_string(),
created: chrono::Utc::now().timestamp() as u64,
model: model.unwrap_or_default(),
choices: vec![ChatCompletionObjectChoice {
index: 0,
message: ChatCompletionObjectMessage {
role: message::ChatCompletionRole::Assistant,
content: Some(string.to_string()),
tool_calls: Vec::new(),
function_call: None,
},
});

if reply_channel.send(data).is_err() {
finish_reason: message::FinishReason::stop,
logprobs: None,
}],
usage: Usage {
prompt_tokens,
completion_tokens: string.len() as u64,
total_tokens: prompt_tokens + string.len() as u64,
},
};

if reply_channel.send(Ok(data)).is_err() {
tracing::warn!("failed to send chat completion reply because channel closed early");
}
}


+ 44
- 0
node-hub/openai-proxy-server/src/message.rs View File

@@ -230,6 +230,15 @@ impl<'de> Deserialize<'de> for ChatCompletionRequest {
}
}

impl ChatCompletionRequest {
pub fn to_texts(&self) -> Vec<String> {
self.messages
.iter()
.flat_map(|message| message.to_texts())
.collect()
}
}

/// Message for comprising the conversation.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "role", rename_all = "lowercase")]
@@ -308,6 +317,22 @@ impl ChatCompletionRequestMessage {
ChatCompletionRequestMessage::Tool(_) => None,
}
}

/// The contents of the message.
pub fn to_texts(&self) -> Vec<String> {
match self {
ChatCompletionRequestMessage::System(message) => {
vec![String::from("<|system|>\n") + &message.content]
}
ChatCompletionRequestMessage::User(message) => message.content.to_texts(),
ChatCompletionRequestMessage::Assistant(message) => {
vec![String::from("<|assistant|>\n") + &message.content.clone().unwrap_or_default()]
}
ChatCompletionRequestMessage::Tool(message) => {
vec![String::from("<|tool|>\n") + &message.content.clone()]
}
}
}
}

/// Sampling methods used for chat completion requests.
@@ -587,6 +612,25 @@ impl ChatCompletionUserMessageContent {
ChatCompletionUserMessageContent::Parts(_) => "parts",
}
}

pub fn to_texts(&self) -> Vec<String> {
match self {
ChatCompletionUserMessageContent::Text(text) => {
vec![String::from("user: ") + &text.clone()]
}
ChatCompletionUserMessageContent::Parts(parts) => parts
.iter()
.map(|part| match part {
ContentPart::Text(text_part) => {
String::from("<|user|>\n<|im_start|>\n") + &text_part.text.clone()
}
ContentPart::Image(image) => {
String::from("<|user|>\n<|vision_start|>\n") + &image.image().url.clone()
}
})
.collect(),
}
}
}

/// Define the content part of a user message.


Loading…
Cancel
Save