| @@ -1165,9 +1165,9 @@ dependencies = [ | |||||
| [[package]] | [[package]] | ||||
| name = "avif-serialize" | name = "avif-serialize" | ||||
| version = "0.8.3" | |||||
| version = "0.8.4" | |||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "98922d6a4cfbcb08820c69d8eeccc05bb1f29bfa06b4f5b1dbfe9a868bd7608e" | |||||
| checksum = "19135c0c7a60bfee564dbe44ab5ce0557c6bf3884e5291a50be76a15640c4fbd" | |||||
| dependencies = [ | dependencies = [ | ||||
| "arrayvec", | "arrayvec", | ||||
| ] | ] | ||||
| @@ -144,7 +144,7 @@ pub struct DoraEvent(Option<Event>); | |||||
| fn event_type(event: &DoraEvent) -> ffi::DoraEventType { | fn event_type(event: &DoraEvent) -> ffi::DoraEventType { | ||||
| match &event.0 { | match &event.0 { | ||||
| Some(event) => match event { | Some(event) => match event { | ||||
| Event::Stop => ffi::DoraEventType::Stop, | |||||
| Event::Stop(_) => ffi::DoraEventType::Stop, | |||||
| Event::Input { .. } => ffi::DoraEventType::Input, | Event::Input { .. } => ffi::DoraEventType::Input, | ||||
| Event::InputClosed { .. } => ffi::DoraEventType::InputClosed, | Event::InputClosed { .. } => ffi::DoraEventType::InputClosed, | ||||
| Event::Error(_) => ffi::DoraEventType::Error, | Event::Error(_) => ffi::DoraEventType::Error, | ||||
| @@ -91,7 +91,7 @@ pub unsafe extern "C" fn dora_next_event(context: *mut c_void) -> *mut c_void { | |||||
| pub unsafe extern "C" fn read_dora_event_type(event: *const ()) -> EventType { | pub unsafe extern "C" fn read_dora_event_type(event: *const ()) -> EventType { | ||||
| let event: &Event = unsafe { &*event.cast() }; | let event: &Event = unsafe { &*event.cast() }; | ||||
| match event { | match event { | ||||
| Event::Stop => EventType::Stop, | |||||
| Event::Stop(_) => EventType::Stop, | |||||
| Event::Input { .. } => EventType::Input, | Event::Input { .. } => EventType::Input, | ||||
| Event::InputClosed { .. } => EventType::InputClosed, | Event::InputClosed { .. } => EventType::InputClosed, | ||||
| Event::Error(_) => EventType::Error, | Event::Error(_) => EventType::Error, | ||||
| @@ -6,7 +6,7 @@ use std::{ | |||||
| use arrow::pyarrow::ToPyArrow; | use arrow::pyarrow::ToPyArrow; | ||||
| use dora_node_api::{ | use dora_node_api::{ | ||||
| merged::{MergeExternalSend, MergedEvent}, | merged::{MergeExternalSend, MergedEvent}, | ||||
| DoraNode, Event, EventStream, Metadata, MetadataParameters, Parameter, | |||||
| DoraNode, Event, EventStream, Metadata, MetadataParameters, Parameter, StopCause, | |||||
| }; | }; | ||||
| use eyre::{Context, Result}; | use eyre::{Context, Result}; | ||||
| use futures::{Stream, StreamExt}; | use futures::{Stream, StreamExt}; | ||||
| @@ -146,7 +146,7 @@ impl PyEvent { | |||||
| fn ty(event: &Event) -> &str { | fn ty(event: &Event) -> &str { | ||||
| match event { | match event { | ||||
| Event::Stop => "STOP", | |||||
| Event::Stop(_) => "STOP", | |||||
| Event::Input { .. } => "INPUT", | Event::Input { .. } => "INPUT", | ||||
| Event::InputClosed { .. } => "INPUT_CLOSED", | Event::InputClosed { .. } => "INPUT_CLOSED", | ||||
| Event::Error(_) => "ERROR", | Event::Error(_) => "ERROR", | ||||
| @@ -158,6 +158,11 @@ impl PyEvent { | |||||
| match event { | match event { | ||||
| Event::Input { id, .. } => Some(id), | Event::Input { id, .. } => Some(id), | ||||
| Event::InputClosed { id } => Some(id), | Event::InputClosed { id } => Some(id), | ||||
| Event::Stop(cause) => match cause { | |||||
| StopCause::Manual => Some("MANUAL"), | |||||
| StopCause::AllInputsClosed => Some("ALL_INPUTS_CLOSED"), | |||||
| &_ => None, | |||||
| }, | |||||
| _ => None, | _ => None, | ||||
| } | } | ||||
| } | } | ||||
| @@ -10,7 +10,7 @@ use shared_memory_extended::{Shmem, ShmemConf}; | |||||
| #[derive(Debug)] | #[derive(Debug)] | ||||
| #[non_exhaustive] | #[non_exhaustive] | ||||
| pub enum Event { | pub enum Event { | ||||
| Stop, | |||||
| Stop(StopCause), | |||||
| Reload { | Reload { | ||||
| operator_id: Option<OperatorId>, | operator_id: Option<OperatorId>, | ||||
| }, | }, | ||||
| @@ -25,6 +25,13 @@ pub enum Event { | |||||
| Error(String), | Error(String), | ||||
| } | } | ||||
/// Reason attached to an `Event::Stop`.
///
/// The enum is `#[non_exhaustive]`: further causes may be added later, so
/// downstream `match` expressions need a wildcard arm.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum StopCause {
    /// Produced when the event stream receives an explicit `NodeEvent::Stop`.
    Manual,
    /// Produced when the event stream receives `NodeEvent::AllInputsClosed`,
    /// i.e. no further input events will arrive.
    AllInputsClosed,
}
| pub enum RawData { | pub enum RawData { | ||||
| Empty, | Empty, | ||||
| Vec(AVec<u8, ConstAlign<128>>), | Vec(AVec<u8, ConstAlign<128>>), | ||||
| @@ -11,7 +11,7 @@ use dora_message::{ | |||||
| node_to_daemon::{DaemonRequest, Timestamped}, | node_to_daemon::{DaemonRequest, Timestamped}, | ||||
| DataflowId, | DataflowId, | ||||
| }; | }; | ||||
| pub use event::{Event, MappedInputData, RawData}; | |||||
| pub use event::{Event, MappedInputData, RawData, StopCause}; | |||||
| use futures::{ | use futures::{ | ||||
| future::{select, Either}, | future::{select, Either}, | ||||
| Stream, StreamExt, | Stream, StreamExt, | ||||
| @@ -199,7 +199,7 @@ impl EventStream { | |||||
| fn convert_event_item(item: EventItem) -> Event { | fn convert_event_item(item: EventItem) -> Event { | ||||
| match item { | match item { | ||||
| EventItem::NodeEvent { event, ack_channel } => match event { | EventItem::NodeEvent { event, ack_channel } => match event { | ||||
| NodeEvent::Stop => Event::Stop, | |||||
| NodeEvent::Stop => Event::Stop(event::StopCause::Manual), | |||||
| NodeEvent::Reload { operator_id } => Event::Reload { operator_id }, | NodeEvent::Reload { operator_id } => Event::Reload { operator_id }, | ||||
| NodeEvent::InputClosed { id } => Event::InputClosed { id }, | NodeEvent::InputClosed { id } => Event::InputClosed { id }, | ||||
| NodeEvent::Input { id, metadata, data } => { | NodeEvent::Input { id, metadata, data } => { | ||||
| @@ -234,13 +234,7 @@ impl EventStream { | |||||
| Err(err) => Event::Error(format!("{err:?}")), | Err(err) => Event::Error(format!("{err:?}")), | ||||
| } | } | ||||
| } | } | ||||
| NodeEvent::AllInputsClosed => { | |||||
| let err = eyre!( | |||||
| "received `AllInputsClosed` event, which should be handled by background task" | |||||
| ); | |||||
| tracing::error!("{err:?}"); | |||||
| Event::Error(err.wrap_err("internal error").to_string()) | |||||
| } | |||||
| NodeEvent::AllInputsClosed => Event::Stop(event::StopCause::AllInputsClosed), | |||||
| }, | }, | ||||
| EventItem::FatalError(err) => { | EventItem::FatalError(err) => { | ||||
| @@ -92,6 +92,7 @@ fn event_stream_loop( | |||||
| clock: Arc<uhlc::HLC>, | clock: Arc<uhlc::HLC>, | ||||
| ) { | ) { | ||||
| let mut tx = Some(tx); | let mut tx = Some(tx); | ||||
| let mut close_tx = false; | |||||
| let mut pending_drop_tokens: Vec<(DropToken, flume::Receiver<()>, Instant, u64)> = Vec::new(); | let mut pending_drop_tokens: Vec<(DropToken, flume::Receiver<()>, Instant, u64)> = Vec::new(); | ||||
| let mut drop_tokens = Vec::new(); | let mut drop_tokens = Vec::new(); | ||||
| @@ -135,10 +136,8 @@ fn event_stream_loop( | |||||
| data: Some(data), .. | data: Some(data), .. | ||||
| } => data.drop_token(), | } => data.drop_token(), | ||||
| NodeEvent::AllInputsClosed => { | NodeEvent::AllInputsClosed => { | ||||
| // close the event stream | |||||
| tx = None; | |||||
| // skip this internal event | |||||
| continue; | |||||
| close_tx = true; | |||||
| None | |||||
| } | } | ||||
| _ => None, | _ => None, | ||||
| }; | }; | ||||
| @@ -166,6 +165,10 @@ fn event_stream_loop( | |||||
| } else { | } else { | ||||
| tracing::warn!("dropping event because event `tx` was already closed: `{inner:?}`"); | tracing::warn!("dropping event because event `tx` was already closed: `{inner:?}`"); | ||||
| } | } | ||||
| if close_tx { | |||||
| tx = None; | |||||
| }; | |||||
| } | } | ||||
| }; | }; | ||||
| if let Err(err) = result { | if let Err(err) = result { | ||||
| @@ -20,7 +20,7 @@ pub use dora_message::{ | |||||
| metadata::{Metadata, MetadataParameters, Parameter}, | metadata::{Metadata, MetadataParameters, Parameter}, | ||||
| DataflowId, | DataflowId, | ||||
| }; | }; | ||||
| pub use event_stream::{merged, Event, EventStream, MappedInputData, RawData}; | |||||
| pub use event_stream::{merged, Event, EventStream, MappedInputData, RawData, StopCause}; | |||||
| pub use flume::Receiver; | pub use flume::Receiver; | ||||
| pub use node::{arrow_utils, DataSample, DoraNode, ZERO_COPY_THRESHOLD}; | pub use node::{arrow_utils, DataSample, DoraNode, ZERO_COPY_THRESHOLD}; | ||||
| @@ -540,7 +540,7 @@ pub async fn spawn_node( | |||||
| // If log is an output, we're sending the logs to the dataflow | // If log is an output, we're sending the logs to the dataflow | ||||
| if let Some(stdout_output_name) = &send_stdout_to { | if let Some(stdout_output_name) = &send_stdout_to { | ||||
| // Convert logs to DataMessage | // Convert logs to DataMessage | ||||
| let array = message.into_arrow(); | |||||
| let array = message.as_str().into_arrow(); | |||||
| let array: ArrayData = array.into(); | let array: ArrayData = array.into(); | ||||
| let total_len = required_data_size(&array); | let total_len = required_data_size(&array); | ||||
| @@ -232,10 +232,10 @@ async fn run( | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| RuntimeEvent::Event(Event::Stop) => { | |||||
| RuntimeEvent::Event(Event::Stop(cause)) => { | |||||
| // forward stop event to all operators and close the event channels | // forward stop event to all operators and close the event channels | ||||
| for (_, channel) in operator_channels.drain() { | for (_, channel) in operator_channels.drain() { | ||||
| let _ = channel.send_async(Event::Stop).await; | |||||
| let _ = channel.send_async(Event::Stop(cause.clone())).await; | |||||
| } | } | ||||
| } | } | ||||
| RuntimeEvent::Event(Event::Reload { | RuntimeEvent::Event(Event::Reload { | ||||
| @@ -182,7 +182,7 @@ impl<'lib> SharedLibraryOperator<'lib> { | |||||
| } | } | ||||
| let mut operator_event = match event { | let mut operator_event = match event { | ||||
| Event::Stop => dora_operator_api_types::RawEvent { | |||||
| Event::Stop(_) => dora_operator_api_types::RawEvent { | |||||
| input: None, | input: None, | ||||
| input_closed: None, | input_closed: None, | ||||
| stop: true, | stop: true, | ||||
| @@ -26,7 +26,7 @@ fn main() -> eyre::Result<()> { | |||||
| } | } | ||||
| other => eprintln!("Ignoring unexpected input `{other}`"), | other => eprintln!("Ignoring unexpected input `{other}`"), | ||||
| }, | }, | ||||
| Event::Stop => println!("Received manual stop"), | |||||
| Event::Stop(_) => println!("Received stop"), | |||||
| other => eprintln!("Received unexpected input: {other:?}"), | other => eprintln!("Received unexpected input: {other:?}"), | ||||
| } | } | ||||
| } | } | ||||
| @@ -24,8 +24,8 @@ fn main() -> eyre::Result<()> { | |||||
| } | } | ||||
| other => eprintln!("Ignoring unexpected input `{other}`"), | other => eprintln!("Ignoring unexpected input `{other}`"), | ||||
| }, | }, | ||||
| Event::Stop => { | |||||
| println!("Received manual stop"); | |||||
| Event::Stop(_) => { | |||||
| println!("Received stop"); | |||||
| } | } | ||||
| Event::InputClosed { id } => { | Event::InputClosed { id } => { | ||||
| println!("Input `{id}` was closed"); | println!("Input `{id}` was closed"); | ||||
| @@ -3,14 +3,14 @@ nodes: | |||||
| build: cargo build -p dora-openai-proxy-server --release | build: cargo build -p dora-openai-proxy-server --release | ||||
| path: ../../target/release/dora-openai-proxy-server | path: ../../target/release/dora-openai-proxy-server | ||||
| outputs: | outputs: | ||||
| - chat_completion_request | |||||
| - text | |||||
| inputs: | inputs: | ||||
| completion_reply: dora-echo/echo | |||||
| text: dora-echo/echo | |||||
| - id: dora-echo | - id: dora-echo | ||||
| build: pip install -e ../../node-hub/dora-echo | build: pip install -e ../../node-hub/dora-echo | ||||
| path: dora-echo | path: dora-echo | ||||
| inputs: | inputs: | ||||
| echo: dora-openai-server/chat_completion_request | |||||
| echo: dora-openai-server/text | |||||
| outputs: | outputs: | ||||
| - echo | - echo | ||||
| @@ -32,11 +32,69 @@ def test_chat_completion(user_input): | |||||
| print(f"Error in chat completion: {e}") | print(f"Error in chat completion: {e}") | ||||
# Question sent alongside the image when the caller supplies no prompt.
_DEFAULT_IMAGE_PROMPT = "What is in this image?"


def _chat_about_image(image_url, prompt):
    """Send one user message made of `prompt` plus an image and print the reply.

    Args:
        image_url: Value for the `image_url.url` field; either a regular URL
            or a `data:` URI carrying the base64-encoded image.
        prompt: Text part of the user message.

    Any exception raised by the client is caught and printed, so that one
    failing request does not abort the remaining checks of the script.
    """
    try:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {"url": image_url},
                        },
                    ],
                }
            ],
        )
        print("Chat Completion Response:")
        print(response.choices[0].message.content)
    except Exception as e:
        print(f"Error in chat completion: {e}")


def test_chat_completion_image_url(user_input):
    """Request a chat completion for an image referenced by an HTTP URL.

    Args:
        user_input: Prompt to send with the image. An empty or missing value
            falls back to the default question "What is in this image?".
    """
    _chat_about_image(
        "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
        user_input or _DEFAULT_IMAGE_PROMPT,
    )


def test_chat_completion_image_base64(user_input):
    """Request a chat completion for an image embedded as a base64 `data:` URI.

    Args:
        user_input: Prompt to send with the image. An empty or missing value
            falls back to the default question "What is in this image?".
    """
    _chat_about_image(
        "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBapySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnxBwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXrCDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQDry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPsgxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96CutRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOMOVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWquaZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYSUb3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6EhOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oWVeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmHrwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66PfyuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UNz8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII=",
        user_input or _DEFAULT_IMAGE_PROMPT,
    )
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| print("Testing API endpoints...") | print("Testing API endpoints...") | ||||
| test_list_models() | |||||
| # test_list_models() | |||||
| print("\n" + "=" * 50 + "\n") | print("\n" + "=" * 50 + "\n") | ||||
| chat_input = input("Enter a message for chat completion: ") | chat_input = input("Enter a message for chat completion: ") | ||||
| test_chat_completion(chat_input) | test_chat_completion(chat_input) | ||||
| print("\n" + "=" * 50 + "\n") | |||||
| test_chat_completion_image_url(chat_input) | |||||
| print("\n" + "=" * 50 + "\n") | |||||
| test_chat_completion_image_base64(chat_input) | |||||
| print("\n" + "=" * 50 + "\n") | print("\n" + "=" * 50 + "\n") | ||||
| @@ -0,0 +1,16 @@ | |||||
# Dataflow wiring the OpenAI-compatible proxy server to the Qwen2.5-VL node.
# The two nodes form a loop: the server's `text` output feeds the model node,
# and the model node's `text` output is routed back as the server's input.
nodes:
  - id: dora-openai-server
    build: cargo build -p dora-openai-proxy-server --release
    path: ../../target/release/dora-openai-proxy-server
    outputs:
      - text
    inputs:
      text: dora-qwen2.5-vl/text

  - id: dora-qwen2.5-vl
    build: pip install -e ../../node-hub/dora-qwen2-5-vl
    path: dora-qwen2-5-vl
    inputs:
      text: dora-openai-server/text
    outputs:
      - text
| @@ -26,7 +26,7 @@ fn main() -> eyre::Result<()> { | |||||
| } | } | ||||
| other => eprintln!("Ignoring unexpected input `{other}`"), | other => eprintln!("Ignoring unexpected input `{other}`"), | ||||
| }, | }, | ||||
| Event::Stop => println!("Received manual stop"), | |||||
| Event::Stop(_) => println!("Received stop"), | |||||
| other => eprintln!("Received unexpected input: {other:?}"), | other => eprintln!("Received unexpected input: {other:?}"), | ||||
| } | } | ||||
| } | } | ||||
| @@ -25,8 +25,8 @@ fn main() -> eyre::Result<()> { | |||||
| } | } | ||||
| other => eprintln!("Ignoring unexpected input `{other}`"), | other => eprintln!("Ignoring unexpected input `{other}`"), | ||||
| }, | }, | ||||
| Event::Stop => { | |||||
| println!("Received manual stop"); | |||||
| Event::Stop(_) => { | |||||
| println!("Received stop"); | |||||
| } | } | ||||
| Event::InputClosed { id } => { | Event::InputClosed { id } => { | ||||
| println!("Input `{id}` was closed"); | println!("Input `{id}` was closed"); | ||||
| @@ -24,8 +24,8 @@ fn main() -> eyre::Result<()> { | |||||
| } | } | ||||
| other => eprintln!("Ignoring unexpected input `{other}`"), | other => eprintln!("Ignoring unexpected input `{other}`"), | ||||
| }, | }, | ||||
| Event::Stop => { | |||||
| println!("Received manual stop"); | |||||
| Event::Stop(_) => { | |||||
| println!("Received stop"); | |||||
| } | } | ||||
| Event::InputClosed { id } => { | Event::InputClosed { id } => { | ||||
| println!("Input `{id}` was closed"); | println!("Input `{id}` was closed"); | ||||
| @@ -29,7 +29,7 @@ fn main() -> eyre::Result<()> { | |||||
| } | } | ||||
| other => eprintln!("ignoring unexpected input {other}"), | other => eprintln!("ignoring unexpected input {other}"), | ||||
| }, | }, | ||||
| Event::Stop => {} | |||||
| Event::Stop(_) => {} | |||||
| Event::InputClosed { id } => { | Event::InputClosed { id } => { | ||||
| println!("input `{id}` was closed"); | println!("input `{id}` was closed"); | ||||
| if *id == "random" { | if *id == "random" { | ||||
| @@ -119,7 +119,7 @@ fn main() -> eyre::Result<()> { | |||||
| } | } | ||||
| other => eprintln!("Ignoring unexpected input `{other}`"), | other => eprintln!("Ignoring unexpected input `{other}`"), | ||||
| }, | }, | ||||
| Event::Stop => println!("Received manual stop"), | |||||
| Event::Stop(_) => println!("Received stop"), | |||||
| other => eprintln!("Received unexpected input: {other:?}"), | other => eprintln!("Received unexpected input: {other:?}"), | ||||
| }, | }, | ||||
| MergedEvent::External(pose) => { | MergedEvent::External(pose) => { | ||||
| @@ -0,0 +1,54 @@ | |||||
# Dataflow: camera -> dora-vggt -> two dora-rav1e encoders -> image_saver.py,
# with the un-encoded dora-vggt outputs also routed to dora-rerun.
nodes:
  - id: camera
    build: pip install opencv-video-capture
    path: opencv-video-capture
    inputs:
      tick: dora/timer/millis/100
    outputs:
      - image
    env:
      CAPTURE_PATH: 1

  - id: dora-vggt
    build: pip install -e ../../node-hub/dora-vggt
    path: dora-vggt
    inputs:
      image: camera/image
    outputs:
      - depth
      - image
    env:
      DEPTH_ENCODING: mono16

  - id: rav1e-depth
    path: dora-rav1e
    build: cargo build -p dora-rav1e --release
    inputs:
      depth: dora-vggt/depth
    outputs:
      - depth
    env:
      ENCODING: avif

  - id: rav1e-image
    path: dora-rav1e
    build: cargo build -p dora-rav1e --release
    inputs:
      image: dora-vggt/image
    outputs:
      - image
    env:
      ENCODING: avif

  - id: bench
    path: image_saver.py
    inputs:
      # NOTE(review): this input is named `camera_depth` but is fed from the
      # image encoder (`rav1e-image/image`) - confirm the name is intended.
      camera_depth: rav1e-image/image
      vggt_depth: rav1e-depth/depth

  - id: plot
    build: pip install dora-rerun
    path: dora-rerun
    inputs:
      camera/image: dora-vggt/image
      camera/depth: dora-vggt/depth
| @@ -0,0 +1,34 @@ | |||||
"""Dump dora inputs to disk, using one input as the trigger.

Every `INPUT` event whose id does not contain `LEAD_TOPIC` is cached (latest
value wins). Every `INPUT` event whose id contains `LEAD_TOPIC` is written to
`out/<id>_<i>.<encoding>` together with the cached payload of every other
input seen so far, after which the snapshot counter `i` is incremented.
"""

import os

from dora import Node

OUTPUT_DIR = "out"
LEAD_TOPIC = "vggt_depth"

node = Node()

# Latest payload and metadata of each non-lead input, keyed by input id.
index_dict = {}
# Number of snapshots written so far; used as the file-name suffix.
i = 0

# `open(..., "wb")` does not create missing directories, so make sure the
# target directory exists before the first snapshot is written.
os.makedirs(OUTPUT_DIR, exist_ok=True)

for event in node:
    if event["type"] != "INPUT":
        continue

    if LEAD_TOPIC in event["id"]:
        storage = event["value"]
        encoding = event["metadata"]["encoding"]

        # Save the lead input itself.
        filename = os.path.join(OUTPUT_DIR, f"{event['id']}_{i}.{encoding}")
        with open(filename, "wb") as f:
            f.write(storage.to_numpy())

        # Save the most recent value of every other input under the same index.
        for key, value in index_dict.items():
            filename = os.path.join(
                OUTPUT_DIR, f"{key}_{i}.{value['metadata']['encoding']}"
            )
            with open(filename, "wb") as f:
                f.write(value["value"])
        i += 1
    else:
        # Store the event in the index dictionary until the next lead input.
        index_dict[event["id"]] = {
            "type": event["type"],
            "value": event["value"].to_numpy(),
            "metadata": event["metadata"],
        }
| @@ -81,6 +81,20 @@ impl IntoArrow for NaiveTime { | |||||
| } | } | ||||
| } | } | ||||
| impl IntoArrow for String { | |||||
| type A = StringArray; | |||||
| fn into_arrow(self) -> Self::A { | |||||
| std::iter::once(Some(self)).collect() | |||||
| } | |||||
| } | |||||
| impl IntoArrow for Vec<String> { | |||||
| type A = StringArray; | |||||
| fn into_arrow(self) -> Self::A { | |||||
| StringArray::from(self) | |||||
| } | |||||
| } | |||||
| impl IntoArrow for NaiveDateTime { | impl IntoArrow for NaiveDateTime { | ||||
| type A = arrow::array::TimestampNanosecondArray; | type A = arrow::array::TimestampNanosecondArray; | ||||
| fn into_arrow(self) -> Self::A { | fn into_arrow(self) -> Self::A { | ||||
| @@ -11,6 +11,8 @@ def main(): | |||||
| node = Node() | node = Node() | ||||
| always_none = node.next(timeout=0.001) is None | always_none = node.next(timeout=0.001) is None | ||||
| always_none = node.next(timeout=0.001) is None | |||||
| print("Always None:", always_none) | |||||
| with keyboard.Events() as events: | with keyboard.Events() as events: | ||||
| while True: | while True: | ||||
| if not always_none: | if not always_none: | ||||
| @@ -19,6 +19,7 @@ def main(): | |||||
| start_recording_time = tm.time() | start_recording_time = tm.time() | ||||
| node = Node() | node = Node() | ||||
| always_none = node.next(timeout=0.001) is None | |||||
| always_none = node.next(timeout=0.001) is None | always_none = node.next(timeout=0.001) is None | ||||
| finished = False | finished = False | ||||
| @@ -41,13 +41,13 @@ async fn main() -> eyre::Result<()> { | |||||
| node.send_output( | node.send_output( | ||||
| mistral_output.clone(), | mistral_output.clone(), | ||||
| metadata.parameters, | metadata.parameters, | ||||
| output.into_arrow(), | |||||
| output.as_str().into_arrow(), | |||||
| )?; | )?; | ||||
| } | } | ||||
| other => eprintln!("Received input `{other}`"), | other => eprintln!("Received input `{other}`"), | ||||
| }, | }, | ||||
| Event::Stop => { | |||||
| println!("Received manual stop") | |||||
| Event::Stop(_) => { | |||||
| println!("Received command"); | |||||
| } | } | ||||
| Event::InputClosed { id } => { | Event::InputClosed { id } => { | ||||
| println!("input `{id}` was closed"); | println!("input `{id}` was closed"); | ||||
| @@ -62,29 +62,118 @@ if ADAPTER_PATH != "": | |||||
| processor = AutoProcessor.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True) | processor = AutoProcessor.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True) | ||||
| def generate(frames: dict, question, history, past_key_values=None, image_id=None): | |||||
| def generate( | |||||
| frames: dict, texts: list[str], history, past_key_values=None, image_id=None | |||||
| ): | |||||
| """Generate the response to the question given the image using Qwen2 model.""" | """Generate the response to the question given the image using Qwen2 model.""" | ||||
| if image_id is not None: | if image_id is not None: | ||||
| images = [frames[image_id]] | images = [frames[image_id]] | ||||
| else: | else: | ||||
| images = list(frames.values()) | images = list(frames.values()) | ||||
| messages = [ | |||||
| { | |||||
| "role": "user", | |||||
| "content": [ | |||||
| messages = [] | |||||
| for text in texts: | |||||
| if text.startswith("<|system|>\n"): | |||||
| messages.append( | |||||
| { | { | ||||
| "type": "image", | |||||
| "image": image, | |||||
| "resized_height": image.size[1] * IMAGE_RESIZE_RATIO, | |||||
| "resized_width": image.size[0] * IMAGE_RESIZE_RATIO, | |||||
| "role": "system", | |||||
| "content": [ | |||||
| {"type": "text", "text": text.replace("<|system|>\n", "")}, | |||||
| ], | |||||
| } | } | ||||
| for image in images | |||||
| ] | |||||
| + [ | |||||
| {"type": "text", "text": question}, | |||||
| ], | |||||
| }, | |||||
| ] | |||||
| ) | |||||
| elif text.startswith("<|assistant|>\n"): | |||||
| messages.append( | |||||
| { | |||||
| "role": "assistant", | |||||
| "content": [ | |||||
| {"type": "text", "text": text.replace("<|assistant|>\n", "")}, | |||||
| ], | |||||
| } | |||||
| ) | |||||
| elif text.startswith("<|tool|>\n"): | |||||
| messages.append( | |||||
| { | |||||
| "role": "tool", | |||||
| "content": [ | |||||
| {"type": "text", "text": text.replace("<|tool|>\n", "")}, | |||||
| ], | |||||
| } | |||||
| ) | |||||
| elif text.startswith("<|user|>\n<|im_start|>\n"): | |||||
| messages.append( | |||||
| { | |||||
| "role": "user", | |||||
| "content": [ | |||||
| { | |||||
| "type": "text", | |||||
| "text": text.replace("<|user|>\n<|im_start|>\n", ""), | |||||
| }, | |||||
| ], | |||||
| } | |||||
| ) | |||||
| elif text.startswith("<|user|>\n<|vision_start|>\n"): | |||||
| # Handle the case where the text starts with <|user|>\n<|vision_start|> | |||||
| image_url = text.replace("<|user|>\n<|vision_start|>\n", "") | |||||
| # If the last message was from the user, append the image URL to it | |||||
| if messages[-1]["role"] == "user": | |||||
| messages[-1]["content"].append( | |||||
| { | |||||
| "type": "image", | |||||
| "image": image_url, | |||||
| } | |||||
| ) | |||||
| else: | |||||
| messages.append( | |||||
| { | |||||
| "role": "user", | |||||
| "content": [ | |||||
| { | |||||
| "type": "image", | |||||
| "image": image_url, | |||||
| }, | |||||
| ], | |||||
| } | |||||
| ) | |||||
| else: | |||||
| messages.append( | |||||
| { | |||||
| "role": "user", | |||||
| "content": [ | |||||
| {"type": "text", "text": text}, | |||||
| ], | |||||
| } | |||||
| ) | |||||
| # If the last message was from the user, append the image URL to it | |||||
| if messages[-1]["role"] == "user": | |||||
| messages[-1]["content"] += [ | |||||
| { | |||||
| "type": "image", | |||||
| "image": image, | |||||
| "resized_height": image.size[1] * IMAGE_RESIZE_RATIO, | |||||
| "resized_width": image.size[0] * IMAGE_RESIZE_RATIO, | |||||
| } | |||||
| for image in images | |||||
| ] | |||||
| else: | |||||
| messages.append( | |||||
| { | |||||
| "role": "user", | |||||
| "content": [ | |||||
| { | |||||
| "type": "image", | |||||
| "image": image, | |||||
| "resized_height": image.size[1] * IMAGE_RESIZE_RATIO, | |||||
| "resized_width": image.size[0] * IMAGE_RESIZE_RATIO, | |||||
| } | |||||
| for image in images | |||||
| ], | |||||
| } | |||||
| ) | |||||
| tmp_history = history + messages | tmp_history = history + messages | ||||
| # Preparation for inference | # Preparation for inference | ||||
| text = processor.apply_chat_template( | text = processor.apply_chat_template( | ||||
| @@ -120,19 +209,13 @@ def generate(frames: dict, question, history, past_key_values=None, image_id=Non | |||||
| clean_up_tokenization_spaces=False, | clean_up_tokenization_spaces=False, | ||||
| ) | ) | ||||
| if HISTORY: | if HISTORY: | ||||
| history += [ | |||||
| { | |||||
| "role": "user", | |||||
| "content": [ | |||||
| {"type": "text", "text": question}, | |||||
| ], | |||||
| }, | |||||
| history = tmp_history + [ | |||||
| { | { | ||||
| "role": "assistant", | "role": "assistant", | ||||
| "content": [ | "content": [ | ||||
| {"type": "text", "text": output_text[0]}, | {"type": "text", "text": output_text[0]}, | ||||
| ], | ], | ||||
| }, | |||||
| } | |||||
| ] | ] | ||||
| return output_text[0], history, past_key_values | return output_text[0], history, past_key_values | ||||
| @@ -207,24 +290,22 @@ def main(): | |||||
| elif "text" in event_id: | elif "text" in event_id: | ||||
| if len(event["value"]) > 0: | if len(event["value"]) > 0: | ||||
| text = event["value"][0].as_py() | |||||
| texts = event["value"].to_pylist() | |||||
| image_id = event["metadata"].get("image_id", None) | image_id = event["metadata"].get("image_id", None) | ||||
| else: | else: | ||||
| text = cached_text | |||||
| words = text.split() | |||||
| texts = cached_text | |||||
| words = texts[-1].split() | |||||
| if len(ACTIVATION_WORDS) > 0 and all( | if len(ACTIVATION_WORDS) > 0 and all( | ||||
| word not in ACTIVATION_WORDS for word in words | word not in ACTIVATION_WORDS for word in words | ||||
| ): | ): | ||||
| continue | continue | ||||
| cached_text = text | |||||
| cached_text = texts | |||||
| if len(frames.keys()) == 0: | |||||
| continue | |||||
| # set the max number of tiles in `max_num` | # set the max number of tiles in `max_num` | ||||
| response, history, past_key_values = generate( | response, history, past_key_values = generate( | ||||
| frames, | frames, | ||||
| text, | |||||
| texts, | |||||
| history, | history, | ||||
| past_key_values, | past_key_values, | ||||
| image_id, | image_id, | ||||
| @@ -25,7 +25,7 @@ pyo3 = { workspace = true, features = [ | |||||
| "eyre", | "eyre", | ||||
| "generate-import-lib", | "generate-import-lib", | ||||
| ], optional = true } | ], optional = true } | ||||
| avif-serialize = "0.8.3" | |||||
| avif-serialize = "0.8.4" | |||||
| [lib] | [lib] | ||||
| @@ -336,7 +336,7 @@ pub fn lib_main() -> Result<()> { | |||||
| if let Some(buffer) = data.as_primitive_opt::<UInt16Type>() { | if let Some(buffer) = data.as_primitive_opt::<UInt16Type>() { | ||||
| let mut buffer = buffer.values().to_vec(); | let mut buffer = buffer.values().to_vec(); | ||||
| if std::env::var("FILL_ZEROS") | if std::env::var("FILL_ZEROS") | ||||
| .map(|s| s != "false") | |||||
| .map(|s| s.to_lowercase() != "false") | |||||
| .unwrap_or(true) | .unwrap_or(true) | ||||
| { | { | ||||
| fill_zeros_toward_center_y_plane_in_place(&mut buffer, width, height); | fill_zeros_toward_center_y_plane_in_place(&mut buffer, width, height); | ||||
| @@ -370,7 +370,28 @@ pub fn lib_main() -> Result<()> { | |||||
| let data = pkt.data; | let data = pkt.data; | ||||
| match output_encoding.as_str() { | match output_encoding.as_str() { | ||||
| "avif" => { | "avif" => { | ||||
| warn!("avif encoding not supported for mono16"); | |||||
| metadata.parameters.insert( | |||||
| "encoding".to_string(), | |||||
| Parameter::String("avif".to_string()), | |||||
| ); | |||||
| let data = avif_serialize::Aviffy::new() | |||||
| .full_color_range(false) | |||||
| .set_seq_profile(0) | |||||
| .set_monochrome(true) | |||||
| .to_vec( | |||||
| &data, | |||||
| None, | |||||
| enc.width as u32, | |||||
| enc.height as u32, | |||||
| enc.bit_depth as u8, | |||||
| ); | |||||
| let arrow = data.into_arrow(); | |||||
| node.send_output(id, metadata.parameters.clone(), arrow) | |||||
| .context("could not send output") | |||||
| .unwrap(); | |||||
| } | } | ||||
| _ => { | _ => { | ||||
| metadata.parameters.insert( | metadata.parameters.insert( | ||||
| @@ -1,8 +1,9 @@ | |||||
| """TODO: Add docstring.""" | """TODO: Add docstring.""" | ||||
| import io | import io | ||||
| from collections import deque as Deque | |||||
| import os | import os | ||||
| from collections import deque as Deque | |||||
| import cv2 | import cv2 | ||||
| import numpy as np | import numpy as np | ||||
| import pyarrow as pa | import pyarrow as pa | ||||
| @@ -10,22 +11,24 @@ import torch | |||||
| from dora import Node | from dora import Node | ||||
| from PIL import Image | from PIL import Image | ||||
| from vggt.models.vggt import VGGT | from vggt.models.vggt import VGGT | ||||
| from vggt.utils.geometry import unproject_depth_map_to_point_map | |||||
| from vggt.utils.load_fn import load_and_preprocess_images | from vggt.utils.load_fn import load_and_preprocess_images | ||||
| from vggt.utils.pose_enc import pose_encoding_to_extri_intri | from vggt.utils.pose_enc import pose_encoding_to_extri_intri | ||||
| from vggt.utils.geometry import unproject_depth_map_to_point_map | |||||
| CAMERA_HEIGHT = os.getenv("CAMERA_HEIGHT", "0.01") | CAMERA_HEIGHT = os.getenv("CAMERA_HEIGHT", "0.01") | ||||
| # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+) | # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+) | ||||
| dtype = torch.bfloat16 | dtype = torch.bfloat16 | ||||
| # Check if cuda is available and set the device accordingly | |||||
| device = "cuda" if torch.cuda.is_available() else "cpu" | |||||
| # Initialize the model and load the pretrained weights. | # Initialize the model and load the pretrained weights. | ||||
| # This will automatically download the model weights the first time it's run, which may take a while. | # This will automatically download the model weights the first time it's run, which may take a while. | ||||
| model = VGGT.from_pretrained("facebook/VGGT-1B").to("cuda") | |||||
| model = VGGT.from_pretrained("facebook/VGGT-1B").to(device) | |||||
| model.eval() | model.eval() | ||||
| DEPTH_ENCODING = os.environ.get("DEPTH_ENCODING", "float64") | |||||
| def main(): | def main(): | ||||
| @@ -35,7 +38,6 @@ def main(): | |||||
| for event in node: | for event in node: | ||||
| if event["type"] == "INPUT": | if event["type"] == "INPUT": | ||||
| if "image" in event["id"]: | if "image" in event["id"]: | ||||
| storage = event["value"] | storage = event["value"] | ||||
| metadata = event["metadata"] | metadata = event["metadata"] | ||||
| @@ -83,7 +85,7 @@ def main(): | |||||
| raw_images.append(buffer) | raw_images.append(buffer) | ||||
| with torch.no_grad(): | with torch.no_grad(): | ||||
| images = load_and_preprocess_images(raw_images).to("cuda") | |||||
| images = load_and_preprocess_images(raw_images).to(device) | |||||
| images = images[None] # add batch dimension | images = images[None] # add batch dimension | ||||
| aggregated_tokens_list, ps_idx = model.aggregator(images) | aggregated_tokens_list, ps_idx = model.aggregator(images) | ||||
| @@ -99,25 +101,34 @@ def main(): | |||||
| aggregated_tokens_list, images, ps_idx | aggregated_tokens_list, images, ps_idx | ||||
| ) | ) | ||||
| # Construct 3D Points from Depth Maps and Cameras | # Construct 3D Points from Depth Maps and Cameras | ||||
| # which usually leads to more accurate 3D points than point map branch | # which usually leads to more accurate 3D points than point map branch | ||||
| point_map_by_unprojection = unproject_depth_map_to_point_map(depth_map.squeeze(0), | |||||
| extrinsic.squeeze(0), | |||||
| intrinsic.squeeze(0)) | |||||
| point_map_by_unprojection = unproject_depth_map_to_point_map( | |||||
| depth_map.squeeze(0), extrinsic.squeeze(0), intrinsic.squeeze(0) | |||||
| ) | |||||
| # Get the last quartile of the 2nd axis | # Get the last quartile of the 2nd axis | ||||
| z_value = point_map_by_unprojection[0, :, :, 2] | z_value = point_map_by_unprojection[0, :, :, 2] | ||||
| z_first_quartile = np.quantile(z_value, 0.15) | z_first_quartile = np.quantile(z_value, 0.15) | ||||
| scale_factor = float(CAMERA_HEIGHT) / z_first_quartile | scale_factor = float(CAMERA_HEIGHT) / z_first_quartile | ||||
| print(f"Scale factor: {scale_factor}, with height: {CAMERA_HEIGHT} and max depth: {point_map_by_unprojection[0, :, :, 2].min()}") | |||||
| print(f" 0. all min and max depth values: {point_map_by_unprojection[0, :, :, 0].min()} / {point_map_by_unprojection[0, :, :, 0].max()}") | |||||
| print(f" 1. all min and max depth values: {point_map_by_unprojection[0, :, :, 1].min()} / {point_map_by_unprojection[0, :, :, 1].max()}") | |||||
| print(f" 2. all min and max depth values: {point_map_by_unprojection[0, :, :, 2].min()} / {point_map_by_unprojection[0, :, :, 2].max()}") | |||||
| print( | |||||
| f"Scale factor: {scale_factor}, with height: {CAMERA_HEIGHT} and max depth: {point_map_by_unprojection[0, :, :, 2].min()}" | |||||
| ) | |||||
| print( | |||||
| f" 0. all min and max depth values: {point_map_by_unprojection[0, :, :, 0].min()} / {point_map_by_unprojection[0, :, :, 0].max()}" | |||||
| ) | |||||
| print( | |||||
| f" 1. all min and max depth values: {point_map_by_unprojection[0, :, :, 1].min()} / {point_map_by_unprojection[0, :, :, 1].max()}" | |||||
| ) | |||||
| print( | |||||
| f" 2. all min and max depth values: {point_map_by_unprojection[0, :, :, 2].min()} / {point_map_by_unprojection[0, :, :, 2].max()}" | |||||
| ) | |||||
| print(f" first quartile of z values: {z_first_quartile}") | print(f" first quartile of z values: {z_first_quartile}") | ||||
| depth_map[depth_conf < 1.0] = 0.0 # Set low confidence pixels to 0 | depth_map[depth_conf < 1.0] = 0.0 # Set low confidence pixels to 0 | ||||
| depth_map = depth_map * scale_factor # Scale depth map to the desired height | |||||
| depth_map = ( | |||||
| depth_map * scale_factor | |||||
| ) # Scale depth map to the desired height | |||||
| depth_map = depth_map.to(torch.float64) | depth_map = depth_map.to(torch.float64) | ||||
| intrinsic = intrinsic[-1][-1] | intrinsic = intrinsic[-1][-1] | ||||
| @@ -127,20 +138,24 @@ def main(): | |||||
| r_1 = intrinsic[1, 2] | r_1 = intrinsic[1, 2] | ||||
| depth_map = depth_map[-1][-1].cpu().numpy() | depth_map = depth_map[-1][-1].cpu().numpy() | ||||
| # Warning: Make sure to add my_output_id and my_input_id within the dataflow. | # Warning: Make sure to add my_output_id and my_input_id within the dataflow. | ||||
| if DEPTH_ENCODING == "mono16": | |||||
| depth_map = (depth_map * 1000).astype(np.uint16) | |||||
| node.send_output( | node.send_output( | ||||
| output_id="depth", | output_id="depth", | ||||
| data=pa.array(depth_map.ravel()), | data=pa.array(depth_map.ravel()), | ||||
| metadata={ | metadata={ | ||||
| "width": depth_map.shape[1], | "width": depth_map.shape[1], | ||||
| "height": depth_map.shape[0], | "height": depth_map.shape[0], | ||||
| "focal": [ | |||||
| int(f_0), | |||||
| int(f_1), | |||||
| ], | |||||
| "resolution": [ | |||||
| int(r_0), | |||||
| int(r_1), | |||||
| ], | |||||
| "encoding": DEPTH_ENCODING, | |||||
| "focal": [ | |||||
| int(f_0), | |||||
| int(f_1), | |||||
| ], | |||||
| "resolution": [ | |||||
| int(r_0), | |||||
| int(r_1), | |||||
| ], | |||||
| }, | }, | ||||
| ) | ) | ||||
| @@ -1,4 +1,10 @@ | |||||
| use dora_node_api::{self, dora_core::config::DataId, merged::MergeExternalSend, DoraNode, Event}; | |||||
| use dora_node_api::{ | |||||
| self, | |||||
| arrow::array::{AsArray, StringArray}, | |||||
| dora_core::config::DataId, | |||||
| merged::MergeExternalSend, | |||||
| DoraNode, Event, | |||||
| }; | |||||
| use eyre::{Context, ContextCompat}; | use eyre::{Context, ContextCompat}; | ||||
| use futures::{ | use futures::{ | ||||
| @@ -14,7 +20,7 @@ use hyper::{ | |||||
| }; | }; | ||||
| use message::{ | use message::{ | ||||
| ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage, | ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage, | ||||
| ChatCompletionRequest, ChatCompletionRequestMessage, Usage, | |||||
| ChatCompletionRequest, Usage, | |||||
| }; | }; | ||||
| use std::{ | use std::{ | ||||
| collections::VecDeque, | collections::VecDeque, | ||||
| @@ -71,7 +77,7 @@ async fn main() -> eyre::Result<()> { | |||||
| let merged = events.merge_external_send(server_events); | let merged = events.merge_external_send(server_events); | ||||
| let events = futures::executor::block_on_stream(merged); | let events = futures::executor::block_on_stream(merged); | ||||
| let output_id = DataId::from("chat_completion_request".to_owned()); | |||||
| let output_id = DataId::from("text".to_owned()); | |||||
| let mut reply_channels = VecDeque::new(); | let mut reply_channels = VecDeque::new(); | ||||
| for event in events { | for event in events { | ||||
| @@ -82,45 +88,15 @@ async fn main() -> eyre::Result<()> { | |||||
| break; | break; | ||||
| } | } | ||||
| ServerEvent::ChatCompletionRequest { request, reply } => { | ServerEvent::ChatCompletionRequest { request, reply } => { | ||||
| let message = request | |||||
| .messages | |||||
| .into_iter() | |||||
| .find_map(|m| match m { | |||||
| ChatCompletionRequestMessage::User(message) => Some(message), | |||||
| _ => None, | |||||
| }) | |||||
| .context("no user message found"); | |||||
| match message { | |||||
| Ok(message) => match message.content() { | |||||
| message::ChatCompletionUserMessageContent::Text(content) => { | |||||
| node.send_output_bytes( | |||||
| output_id.clone(), | |||||
| Default::default(), | |||||
| content.len(), | |||||
| content.as_bytes(), | |||||
| ) | |||||
| .context("failed to send dora output")?; | |||||
| reply_channels.push_back(( | |||||
| reply, | |||||
| content.as_bytes().len() as u64, | |||||
| request.model, | |||||
| )); | |||||
| } | |||||
| message::ChatCompletionUserMessageContent::Parts(_) => { | |||||
| if reply | |||||
| .send(Err(eyre::eyre!("unsupported message content"))) | |||||
| .is_err() | |||||
| { | |||||
| tracing::warn!("failed to send chat completion reply because channel closed early"); | |||||
| }; | |||||
| } | |||||
| }, | |||||
| Err(err) => { | |||||
| if reply.send(Err(err)).is_err() { | |||||
| tracing::warn!("failed to send chat completion reply error because channel closed early"); | |||||
| } | |||||
| } | |||||
| } | |||||
| let texts = request.to_texts(); | |||||
| node.send_output( | |||||
| output_id.clone(), | |||||
| Default::default(), | |||||
| StringArray::from(texts), | |||||
| ) | |||||
| .context("failed to send dora output")?; | |||||
| reply_channels.push_back((reply, 0 as u64, request.model)); | |||||
| } | } | ||||
| }, | }, | ||||
| dora_node_api::merged::MergedEvent::Dora(event) => match event { | dora_node_api::merged::MergedEvent::Dora(event) => match event { | ||||
| @@ -130,46 +106,56 @@ async fn main() -> eyre::Result<()> { | |||||
| metadata: _, | metadata: _, | ||||
| } => { | } => { | ||||
| match id.as_str() { | match id.as_str() { | ||||
| "completion_reply" => { | |||||
| "text" => { | |||||
| let (reply_channel, prompt_tokens, model) = | let (reply_channel, prompt_tokens, model) = | ||||
| reply_channels.pop_front().context("no reply channel")?; | reply_channels.pop_front().context("no reply channel")?; | ||||
| let data = TryFrom::try_from(&data) | |||||
| .with_context(|| format!("invalid reply data: {data:?}")) | |||||
| .map(|s: &[u8]| ChatCompletionObject { | |||||
| id: format!("completion-{}", uuid::Uuid::new_v4()), | |||||
| object: "chat.completion".to_string(), | |||||
| created: chrono::Utc::now().timestamp() as u64, | |||||
| model: model.unwrap_or_default(), | |||||
| choices: vec![ChatCompletionObjectChoice { | |||||
| index: 0, | |||||
| message: ChatCompletionObjectMessage { | |||||
| role: message::ChatCompletionRole::Assistant, | |||||
| content: Some(String::from_utf8_lossy(s).to_string()), | |||||
| tool_calls: Vec::new(), | |||||
| function_call: None, | |||||
| }, | |||||
| finish_reason: message::FinishReason::stop, | |||||
| logprobs: None, | |||||
| }], | |||||
| usage: Usage { | |||||
| prompt_tokens, | |||||
| completion_tokens: s.len() as u64, | |||||
| total_tokens: prompt_tokens + s.len() as u64, | |||||
| let data = data.as_string::<i32>(); | |||||
| let string = data.iter().fold("".to_string(), |mut acc, s| { | |||||
| if let Some(s) = s { | |||||
| acc.push_str("\n"); | |||||
| acc.push_str(s); | |||||
| } | |||||
| acc | |||||
| }); | |||||
| let data = ChatCompletionObject { | |||||
| id: format!("completion-{}", uuid::Uuid::new_v4()), | |||||
| object: "chat.completion".to_string(), | |||||
| created: chrono::Utc::now().timestamp() as u64, | |||||
| model: model.unwrap_or_default(), | |||||
| choices: vec![ChatCompletionObjectChoice { | |||||
| index: 0, | |||||
| message: ChatCompletionObjectMessage { | |||||
| role: message::ChatCompletionRole::Assistant, | |||||
| content: Some(string.to_string()), | |||||
| tool_calls: Vec::new(), | |||||
| function_call: None, | |||||
| }, | }, | ||||
| }); | |||||
| if reply_channel.send(data).is_err() { | |||||
| finish_reason: message::FinishReason::stop, | |||||
| logprobs: None, | |||||
| }], | |||||
| usage: Usage { | |||||
| prompt_tokens, | |||||
| completion_tokens: string.len() as u64, | |||||
| total_tokens: prompt_tokens + string.len() as u64, | |||||
| }, | |||||
| }; | |||||
| if reply_channel.send(Ok(data)).is_err() { | |||||
| tracing::warn!("failed to send chat completion reply because channel closed early"); | tracing::warn!("failed to send chat completion reply because channel closed early"); | ||||
| } | } | ||||
| } | } | ||||
| _ => eyre::bail!("unexpected input id: {}", id), | _ => eyre::bail!("unexpected input id: {}", id), | ||||
| }; | }; | ||||
| } | } | ||||
| Event::Stop => { | |||||
| Event::Stop(_) => { | |||||
| break; | break; | ||||
| } | } | ||||
| Event::InputClosed { id, .. } => { | |||||
| info!("Input channel closed for id: {}", id); | |||||
| } | |||||
| event => { | event => { | ||||
| println!("Event: {event:#?}") | |||||
| eyre::bail!("unexpected event: {:#?}", event) | |||||
| } | } | ||||
| }, | }, | ||||
| } | } | ||||
| @@ -230,6 +230,15 @@ impl<'de> Deserialize<'de> for ChatCompletionRequest { | |||||
| } | } | ||||
| } | } | ||||
| impl ChatCompletionRequest { | |||||
| pub fn to_texts(&self) -> Vec<String> { | |||||
| self.messages | |||||
| .iter() | |||||
| .flat_map(|message| message.to_texts()) | |||||
| .collect() | |||||
| } | |||||
| } | |||||
| /// Message for comprising the conversation. | /// Message for comprising the conversation. | ||||
| #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] | #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] | ||||
| #[serde(tag = "role", rename_all = "lowercase")] | #[serde(tag = "role", rename_all = "lowercase")] | ||||
| @@ -308,6 +317,22 @@ impl ChatCompletionRequestMessage { | |||||
| ChatCompletionRequestMessage::Tool(_) => None, | ChatCompletionRequestMessage::Tool(_) => None, | ||||
| } | } | ||||
| } | } | ||||
| /// The contents of the message. | |||||
| pub fn to_texts(&self) -> Vec<String> { | |||||
| match self { | |||||
| ChatCompletionRequestMessage::System(message) => { | |||||
| vec![String::from("<|system|>\n") + &message.content] | |||||
| } | |||||
| ChatCompletionRequestMessage::User(message) => message.content.to_texts(), | |||||
| ChatCompletionRequestMessage::Assistant(message) => { | |||||
| vec![String::from("<|assistant|>\n") + &message.content.clone().unwrap_or_default()] | |||||
| } | |||||
| ChatCompletionRequestMessage::Tool(message) => { | |||||
| vec![String::from("<|tool|>\n") + &message.content.clone()] | |||||
| } | |||||
| } | |||||
| } | |||||
| } | } | ||||
| /// Sampling methods used for chat completion requests. | /// Sampling methods used for chat completion requests. | ||||
| @@ -587,6 +612,25 @@ impl ChatCompletionUserMessageContent { | |||||
| ChatCompletionUserMessageContent::Parts(_) => "parts", | ChatCompletionUserMessageContent::Parts(_) => "parts", | ||||
| } | } | ||||
| } | } | ||||
| pub fn to_texts(&self) -> Vec<String> { | |||||
| match self { | |||||
| ChatCompletionUserMessageContent::Text(text) => { | |||||
| vec![String::from("user: ") + &text.clone()] | |||||
| } | |||||
| ChatCompletionUserMessageContent::Parts(parts) => parts | |||||
| .iter() | |||||
| .map(|part| match part { | |||||
| ContentPart::Text(text_part) => { | |||||
| String::from("<|user|>\n<|im_start|>\n") + &text_part.text.clone() | |||||
| } | |||||
| ContentPart::Image(image) => { | |||||
| String::from("<|user|>\n<|vision_start|>\n") + &image.image().url.clone() | |||||
| } | |||||
| }) | |||||
| .collect(), | |||||
| } | |||||
| } | |||||
| } | } | ||||
| /// Define the content part of a user message. | /// Define the content part of a user message. | ||||