| @@ -3,14 +3,14 @@ nodes: | |||
| build: cargo build -p dora-openai-proxy-server --release | |||
| path: ../../target/release/dora-openai-proxy-server | |||
| outputs: | |||
| - chat_completion_request | |||
| - text | |||
| inputs: | |||
| completion_reply: dora-echo/echo | |||
| text: dora-echo/echo | |||
| - id: dora-echo | |||
| build: pip install -e ../../node-hub/dora-echo | |||
| path: dora-echo | |||
| inputs: | |||
| echo: dora-openai-server/chat_completion_request | |||
| echo: dora-openai-server/text | |||
| outputs: | |||
| - echo | |||
| @@ -32,11 +32,69 @@ def test_chat_completion(user_input): | |||
| print(f"Error in chat completion: {e}") | |||
| def test_chat_completion_image_url(user_input): | |||
| """TODO: Add docstring.""" | |||
| try: | |||
| response = client.chat.completions.create( | |||
| model="gpt-3.5-turbo", | |||
| messages=[ | |||
| { | |||
| "role": "user", | |||
| "content": [ | |||
| {"type": "text", "text": "What is in this image?"}, | |||
| { | |||
| "type": "image_url", | |||
| "image_url": { | |||
| "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" | |||
| }, | |||
| }, | |||
| ], | |||
| } | |||
| ], | |||
| ) | |||
| print("Chat Completion Response:") | |||
| print(response.choices[0].message.content) | |||
| except Exception as e: | |||
| print(f"Error in chat completion: {e}") | |||
| def test_chat_completion_image_base64(user_input): | |||
| """TODO: Add docstring.""" | |||
| try: | |||
| response = client.chat.completions.create( | |||
| model="gpt-3.5-turbo", | |||
| messages=[ | |||
| { | |||
| "role": "user", | |||
| "content": [ | |||
| {"type": "text", "text": "What is in this image?"}, | |||
| { | |||
| "type": "image_url", | |||
| "image_url": { | |||
| "url": "" | |||
| }, | |||
| }, | |||
| ], | |||
| } | |||
| ], | |||
| ) | |||
| print("Chat Completion Response:") | |||
| print(response.choices[0].message.content) | |||
| except Exception as e: | |||
| print(f"Error in chat completion: {e}") | |||
| if __name__ == "__main__": | |||
| print("Testing API endpoints...") | |||
| test_list_models() | |||
| # test_list_models() | |||
| print("\n" + "=" * 50 + "\n") | |||
| chat_input = input("Enter a message for chat completion: ") | |||
| test_chat_completion(chat_input) | |||
| print("\n" + "=" * 50 + "\n") | |||
| test_chat_completion_image_url(chat_input) | |||
| print("\n" + "=" * 50 + "\n") | |||
| test_chat_completion_image_base64(chat_input) | |||
| print("\n" + "=" * 50 + "\n") | |||
| @@ -0,0 +1,16 @@ | |||
| nodes: | |||
| - id: dora-openai-server | |||
| build: cargo build -p dora-openai-proxy-server --release | |||
| path: ../../target/release/dora-openai-proxy-server | |||
| outputs: | |||
| - text | |||
| inputs: | |||
| text: dora-qwen2.5-vl/text | |||
| - id: dora-qwen2.5-vl | |||
| build: pip install -e ../../node-hub/dora-qwen2-5-vl | |||
| path: dora-qwen2-5-vl | |||
| inputs: | |||
| text: dora-openai-server/text | |||
| outputs: | |||
| - text | |||
| @@ -81,6 +81,20 @@ impl IntoArrow for NaiveTime { | |||
| } | |||
| } | |||
| impl IntoArrow for String { | |||
| type A = StringArray; | |||
| fn into_arrow(self) -> Self::A { | |||
| std::iter::once(Some(self)).collect() | |||
| } | |||
| } | |||
| impl IntoArrow for Vec<String> { | |||
| type A = StringArray; | |||
| fn into_arrow(self) -> Self::A { | |||
| StringArray::from(self) | |||
| } | |||
| } | |||
| impl IntoArrow for NaiveDateTime { | |||
| type A = arrow::array::TimestampNanosecondArray; | |||
| fn into_arrow(self) -> Self::A { | |||
| @@ -41,7 +41,7 @@ async fn main() -> eyre::Result<()> { | |||
| node.send_output( | |||
| mistral_output.clone(), | |||
| metadata.parameters, | |||
| output.into_arrow(), | |||
| output.as_str().into_arrow(), | |||
| )?; | |||
| } | |||
| other => eprintln!("Received input `{other}`"), | |||
| @@ -62,29 +62,118 @@ if ADAPTER_PATH != "": | |||
| processor = AutoProcessor.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True) | |||
| def generate(frames: dict, question, history, past_key_values=None, image_id=None): | |||
| def generate( | |||
| frames: dict, texts: list[str], history, past_key_values=None, image_id=None | |||
| ): | |||
| """Generate the response to the question given the image using Qwen2 model.""" | |||
| if image_id is not None: | |||
| images = [frames[image_id]] | |||
| else: | |||
| images = list(frames.values()) | |||
| messages = [ | |||
| { | |||
| "role": "user", | |||
| "content": [ | |||
| messages = [] | |||
| for text in texts: | |||
| if text.startswith("<|system|>\n"): | |||
| messages.append( | |||
| { | |||
| "type": "image", | |||
| "image": image, | |||
| "resized_height": image.size[1] * IMAGE_RESIZE_RATIO, | |||
| "resized_width": image.size[0] * IMAGE_RESIZE_RATIO, | |||
| "role": "system", | |||
| "content": [ | |||
| {"type": "text", "text": text.replace("<|system|>\n", "")}, | |||
| ], | |||
| } | |||
| for image in images | |||
| ] | |||
| + [ | |||
| {"type": "text", "text": question}, | |||
| ], | |||
| }, | |||
| ] | |||
| ) | |||
| elif text.startswith("<|assistant|>\n"): | |||
| messages.append( | |||
| { | |||
| "role": "assistant", | |||
| "content": [ | |||
| {"type": "text", "text": text.replace("<|assistant|>\n", "")}, | |||
| ], | |||
| } | |||
| ) | |||
| elif text.startswith("<|tool|>\n"): | |||
| messages.append( | |||
| { | |||
| "role": "tool", | |||
| "content": [ | |||
| {"type": "text", "text": text.replace("<|tool|>\n", "")}, | |||
| ], | |||
| } | |||
| ) | |||
| elif text.startswith("<|user|>\n<|im_start|>\n"): | |||
| messages.append( | |||
| { | |||
| "role": "user", | |||
| "content": [ | |||
| { | |||
| "type": "text", | |||
| "text": text.replace("<|user|>\n<|im_start|>\n", ""), | |||
| }, | |||
| ], | |||
| } | |||
| ) | |||
| elif text.startswith("<|user|>\n<|vision_start|>\n"): | |||
| # Handle the case where the text starts with <|user|>\n<|vision_start|> | |||
| image_url = text.replace("<|user|>\n<|vision_start|>\n", "") | |||
| # If the last message was from the user, append the image URL to it | |||
| if messages[-1]["role"] == "user": | |||
| messages[-1]["content"].append( | |||
| { | |||
| "type": "image", | |||
| "image": image_url, | |||
| } | |||
| ) | |||
| else: | |||
| messages.append( | |||
| { | |||
| "role": "user", | |||
| "content": [ | |||
| { | |||
| "type": "image", | |||
| "image": image_url, | |||
| }, | |||
| ], | |||
| } | |||
| ) | |||
| else: | |||
| messages.append( | |||
| { | |||
| "role": "user", | |||
| "content": [ | |||
| {"type": "text", "text": text}, | |||
| ], | |||
| } | |||
| ) | |||
| # If the last message was from the user, append the image URL to it | |||
| if messages[-1]["role"] == "user": | |||
| messages[-1]["content"] += [ | |||
| { | |||
| "type": "image", | |||
| "image": image, | |||
| "resized_height": image.size[1] * IMAGE_RESIZE_RATIO, | |||
| "resized_width": image.size[0] * IMAGE_RESIZE_RATIO, | |||
| } | |||
| for image in images | |||
| ] | |||
| else: | |||
| messages.append( | |||
| { | |||
| "role": "user", | |||
| "content": [ | |||
| { | |||
| "type": "image", | |||
| "image": image, | |||
| "resized_height": image.size[1] * IMAGE_RESIZE_RATIO, | |||
| "resized_width": image.size[0] * IMAGE_RESIZE_RATIO, | |||
| } | |||
| for image in images | |||
| ], | |||
| } | |||
| ) | |||
| tmp_history = history + messages | |||
| # Preparation for inference | |||
| text = processor.apply_chat_template( | |||
| @@ -120,19 +209,13 @@ def generate(frames: dict, question, history, past_key_values=None, image_id=Non | |||
| clean_up_tokenization_spaces=False, | |||
| ) | |||
| if HISTORY: | |||
| history += [ | |||
| { | |||
| "role": "user", | |||
| "content": [ | |||
| {"type": "text", "text": question}, | |||
| ], | |||
| }, | |||
| history = tmp_history + [ | |||
| { | |||
| "role": "assistant", | |||
| "content": [ | |||
| {"type": "text", "text": output_text[0]}, | |||
| ], | |||
| }, | |||
| } | |||
| ] | |||
| return output_text[0], history, past_key_values | |||
| @@ -207,24 +290,22 @@ def main(): | |||
| elif "text" in event_id: | |||
| if len(event["value"]) > 0: | |||
| text = event["value"][0].as_py() | |||
| texts = event["value"].to_pylist() | |||
| image_id = event["metadata"].get("image_id", None) | |||
| else: | |||
| text = cached_text | |||
| words = text.split() | |||
| texts = cached_text | |||
| words = texts[-1].split() | |||
| if len(ACTIVATION_WORDS) > 0 and all( | |||
| word not in ACTIVATION_WORDS for word in words | |||
| ): | |||
| continue | |||
| cached_text = text | |||
| cached_text = texts | |||
| if len(frames.keys()) == 0: | |||
| continue | |||
| # set the max number of tiles in `max_num` | |||
| response, history, past_key_values = generate( | |||
| frames, | |||
| text, | |||
| texts, | |||
| history, | |||
| past_key_values, | |||
| image_id, | |||
| @@ -1,4 +1,10 @@ | |||
| use dora_node_api::{self, dora_core::config::DataId, merged::MergeExternalSend, DoraNode, Event}; | |||
| use dora_node_api::{ | |||
| self, | |||
| arrow::array::{AsArray, StringArray}, | |||
| dora_core::config::DataId, | |||
| merged::MergeExternalSend, | |||
| DoraNode, Event, | |||
| }; | |||
| use eyre::{Context, ContextCompat}; | |||
| use futures::{ | |||
| @@ -14,7 +20,7 @@ use hyper::{ | |||
| }; | |||
| use message::{ | |||
| ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage, | |||
| ChatCompletionRequest, ChatCompletionRequestMessage, Usage, | |||
| ChatCompletionRequest, Usage, | |||
| }; | |||
| use std::{ | |||
| collections::VecDeque, | |||
| @@ -71,7 +77,7 @@ async fn main() -> eyre::Result<()> { | |||
| let merged = events.merge_external_send(server_events); | |||
| let events = futures::executor::block_on_stream(merged); | |||
| let output_id = DataId::from("chat_completion_request".to_owned()); | |||
| let output_id = DataId::from("text".to_owned()); | |||
| let mut reply_channels = VecDeque::new(); | |||
| for event in events { | |||
| @@ -82,45 +88,15 @@ async fn main() -> eyre::Result<()> { | |||
| break; | |||
| } | |||
| ServerEvent::ChatCompletionRequest { request, reply } => { | |||
| let message = request | |||
| .messages | |||
| .into_iter() | |||
| .find_map(|m| match m { | |||
| ChatCompletionRequestMessage::User(message) => Some(message), | |||
| _ => None, | |||
| }) | |||
| .context("no user message found"); | |||
| match message { | |||
| Ok(message) => match message.content() { | |||
| message::ChatCompletionUserMessageContent::Text(content) => { | |||
| node.send_output_bytes( | |||
| output_id.clone(), | |||
| Default::default(), | |||
| content.len(), | |||
| content.as_bytes(), | |||
| ) | |||
| .context("failed to send dora output")?; | |||
| reply_channels.push_back(( | |||
| reply, | |||
| content.as_bytes().len() as u64, | |||
| request.model, | |||
| )); | |||
| } | |||
| message::ChatCompletionUserMessageContent::Parts(_) => { | |||
| if reply | |||
| .send(Err(eyre::eyre!("unsupported message content"))) | |||
| .is_err() | |||
| { | |||
| tracing::warn!("failed to send chat completion reply because channel closed early"); | |||
| }; | |||
| } | |||
| }, | |||
| Err(err) => { | |||
| if reply.send(Err(err)).is_err() { | |||
| tracing::warn!("failed to send chat completion reply error because channel closed early"); | |||
| } | |||
| } | |||
| } | |||
| let texts = request.to_texts(); | |||
| node.send_output( | |||
| output_id.clone(), | |||
| Default::default(), | |||
| StringArray::from(texts), | |||
| ) | |||
| .context("failed to send dora output")?; | |||
| reply_channels.push_back((reply, 0 as u64, request.model)); | |||
| } | |||
| }, | |||
| dora_node_api::merged::MergedEvent::Dora(event) => match event { | |||
| @@ -130,35 +106,42 @@ async fn main() -> eyre::Result<()> { | |||
| metadata: _, | |||
| } => { | |||
| match id.as_str() { | |||
| "completion_reply" => { | |||
| "text" => { | |||
| let (reply_channel, prompt_tokens, model) = | |||
| reply_channels.pop_front().context("no reply channel")?; | |||
| let data = TryFrom::try_from(&data) | |||
| .with_context(|| format!("invalid reply data: {data:?}")) | |||
| .map(|s: &[u8]| ChatCompletionObject { | |||
| id: format!("completion-{}", uuid::Uuid::new_v4()), | |||
| object: "chat.completion".to_string(), | |||
| created: chrono::Utc::now().timestamp() as u64, | |||
| model: model.unwrap_or_default(), | |||
| choices: vec![ChatCompletionObjectChoice { | |||
| index: 0, | |||
| message: ChatCompletionObjectMessage { | |||
| role: message::ChatCompletionRole::Assistant, | |||
| content: Some(String::from_utf8_lossy(s).to_string()), | |||
| tool_calls: Vec::new(), | |||
| function_call: None, | |||
| }, | |||
| finish_reason: message::FinishReason::stop, | |||
| logprobs: None, | |||
| }], | |||
| usage: Usage { | |||
| prompt_tokens, | |||
| completion_tokens: s.len() as u64, | |||
| total_tokens: prompt_tokens + s.len() as u64, | |||
| let data = data.as_string::<i32>(); | |||
| let string = data.iter().fold("".to_string(), |mut acc, s| { | |||
| if let Some(s) = s { | |||
| acc.push_str("\n"); | |||
| acc.push_str(s); | |||
| } | |||
| acc | |||
| }); | |||
| let data = ChatCompletionObject { | |||
| id: format!("completion-{}", uuid::Uuid::new_v4()), | |||
| object: "chat.completion".to_string(), | |||
| created: chrono::Utc::now().timestamp() as u64, | |||
| model: model.unwrap_or_default(), | |||
| choices: vec![ChatCompletionObjectChoice { | |||
| index: 0, | |||
| message: ChatCompletionObjectMessage { | |||
| role: message::ChatCompletionRole::Assistant, | |||
| content: Some(string.to_string()), | |||
| tool_calls: Vec::new(), | |||
| function_call: None, | |||
| }, | |||
| }); | |||
| if reply_channel.send(data).is_err() { | |||
| finish_reason: message::FinishReason::stop, | |||
| logprobs: None, | |||
| }], | |||
| usage: Usage { | |||
| prompt_tokens, | |||
| completion_tokens: string.len() as u64, | |||
| total_tokens: prompt_tokens + string.len() as u64, | |||
| }, | |||
| }; | |||
| if reply_channel.send(Ok(data)).is_err() { | |||
| tracing::warn!("failed to send chat completion reply because channel closed early"); | |||
| } | |||
| } | |||
| @@ -168,8 +151,11 @@ async fn main() -> eyre::Result<()> { | |||
| Event::Stop(_) => { | |||
| break; | |||
| } | |||
| Event::InputClosed { id, .. } => { | |||
| info!("Input channel closed for id: {}", id); | |||
| } | |||
| event => { | |||
| println!("Event: {event:#?}") | |||
| eyre::bail!("unexpected event: {:#?}", event) | |||
| } | |||
| }, | |||
| } | |||
| @@ -230,6 +230,15 @@ impl<'de> Deserialize<'de> for ChatCompletionRequest { | |||
| } | |||
| } | |||
| impl ChatCompletionRequest { | |||
| pub fn to_texts(&self) -> Vec<String> { | |||
| self.messages | |||
| .iter() | |||
| .flat_map(|message| message.to_texts()) | |||
| .collect() | |||
| } | |||
| } | |||
| /// Message for comprising the conversation. | |||
| #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] | |||
| #[serde(tag = "role", rename_all = "lowercase")] | |||
| @@ -308,6 +317,22 @@ impl ChatCompletionRequestMessage { | |||
| ChatCompletionRequestMessage::Tool(_) => None, | |||
| } | |||
| } | |||
| /// The contents of the message. | |||
| pub fn to_texts(&self) -> Vec<String> { | |||
| match self { | |||
| ChatCompletionRequestMessage::System(message) => { | |||
| vec![String::from("<|system|>\n") + &message.content] | |||
| } | |||
| ChatCompletionRequestMessage::User(message) => message.content.to_texts(), | |||
| ChatCompletionRequestMessage::Assistant(message) => { | |||
| vec![String::from("<|assistant|>\n") + &message.content.clone().unwrap_or_default()] | |||
| } | |||
| ChatCompletionRequestMessage::Tool(message) => { | |||
| vec![String::from("<|tool|>\n") + &message.content.clone()] | |||
| } | |||
| } | |||
| } | |||
| } | |||
| /// Sampling methods used for chat completion requests. | |||
| @@ -587,6 +612,25 @@ impl ChatCompletionUserMessageContent { | |||
| ChatCompletionUserMessageContent::Parts(_) => "parts", | |||
| } | |||
| } | |||
| pub fn to_texts(&self) -> Vec<String> { | |||
| match self { | |||
| ChatCompletionUserMessageContent::Text(text) => { | |||
| vec![String::from("user: ") + &text.clone()] | |||
| } | |||
| ChatCompletionUserMessageContent::Parts(parts) => parts | |||
| .iter() | |||
| .map(|part| match part { | |||
| ContentPart::Text(text_part) => { | |||
| String::from("<|user|>\n<|im_start|>\n") + &text_part.text.clone() | |||
| } | |||
| ContentPart::Image(image) => { | |||
| String::from("<|user|>\n<|vision_start|>\n") + &image.image().url.clone() | |||
| } | |||
| }) | |||
| .collect(), | |||
| } | |||
| } | |||
| } | |||
| /// Define the content part of a user message. | |||