diff --git a/Cargo.lock b/Cargo.lock index a888e51a..3bb2d198 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1179,6 +1179,7 @@ dependencies = [ "dora-operator-api-types", "dora-tracing", "eyre", + "flume", "futures", "futures-concurrency", "libloading", diff --git a/binaries/runtime/Cargo.toml b/binaries/runtime/Cargo.toml index 2f972c7d..542877c7 100644 --- a/binaries/runtime/Cargo.toml +++ b/binaries/runtime/Cargo.toml @@ -30,6 +30,7 @@ pyo3 = { version = "0.16", features = ["eyre", "abi3-py37"] } tracing = "0.1.36" tracing-subscriber = "0.3.15" dora-download = { path = "../../libraries/extensions/download" } +flume = "0.10.14" [features] tracing = ["opentelemetry", "dora-tracing"] diff --git a/binaries/runtime/src/lib.rs b/binaries/runtime/src/lib.rs index cbee3a23..91831f0f 100644 --- a/binaries/runtime/src/lib.rs +++ b/binaries/runtime/src/lib.rs @@ -71,7 +71,7 @@ pub fn main() -> eyre::Result<()> { .wrap_err("Could not build a tokio runtime.")?; let mut operator_channels = HashMap::new(); - let (operator_channel, incoming_events) = mpsc::channel(10); + let (operator_channel, incoming_events) = operator::channel::channel(tokio_runtime.handle()); operator_channels.insert(operator_definition.id.clone(), operator_channel); tracing::info!("spawning main task"); @@ -107,7 +107,7 @@ async fn run( mut node: DoraNode, operators: HashMap, mut events: impl Stream + Unpin, - mut operator_channels: HashMap>, + mut operator_channels: HashMap>, ) -> eyre::Result<()> { #[cfg(feature = "metrics")] let _started = { @@ -198,7 +198,7 @@ async fn run( Event::Stop => { // forward stop event to all operators and close the event channels for (_, channel) in operator_channels.drain() { - let _ = channel.send(operator::IncomingEvent::Stop).await; + let _ = channel.send_async(operator::IncomingEvent::Stop).await; } } Event::Input { id, metadata, data } => { @@ -214,7 +214,7 @@ async fn run( }; if let Err(err) = operator_channel - .send(operator::IncomingEvent::Input { + .send_async(operator::IncomingEvent::Input { input_id: input_id.clone(), metadata, data, diff --git a/binaries/runtime/src/operator/channel.rs b/binaries/runtime/src/operator/channel.rs new file mode 100644 index 00000000..1cf094fb --- /dev/null +++ b/binaries/runtime/src/operator/channel.rs @@ -0,0 +1,92 @@ +use super::IncomingEvent; +use futures::{ + future::{self, FusedFuture}, + FutureExt, +}; +use std::collections::VecDeque; + +pub fn channel( + runtime: &tokio::runtime::Handle, +) -> (flume::Sender, flume::Receiver) { + let (incoming_tx, incoming_rx) = flume::bounded(10); + let (outgoing_tx, outgoing_rx) = flume::bounded(0); + + runtime.spawn(async { + let mut buffer = InputBuffer::new(); + buffer.run(incoming_rx, outgoing_tx).await; + }); + + (incoming_tx, outgoing_rx) +} + +struct InputBuffer { + queue: VecDeque, +} + +impl InputBuffer { + pub fn new() -> Self { + Self { + queue: VecDeque::new(), + } + } + + pub async fn run( + &mut self, + incoming: flume::Receiver, + outgoing: flume::Sender, + ) { + let mut send_out_buf = future::Fuse::terminated(); + loop { + let next_incoming = incoming.recv_async(); + match future::select(next_incoming, send_out_buf).await { + future::Either::Left((event, mut send_out)) => { + match event { + Ok(event) => { + // received a new event -> push it to the queue + self.queue.push_back(event); + + // TODO: drop oldest events when queue becomes too full + + // if outgoing queue is empty, fill it again + if send_out.is_terminated() { + send_out = self.send_next_queued(&outgoing); + } + } + Err(flume::RecvError::Disconnected) => { + // the incoming channel was closed -> exit if we sent out all events already + if send_out.is_terminated() && self.queue.is_empty() { + break; + } + } + } + + // reassign the send_out future, which might be still in progress + send_out_buf = send_out; + } + future::Either::Right((send_result, _)) => match send_result { + Ok(()) => { + send_out_buf = self.send_next_queued(&outgoing); + } + Err(flume::SendError(_)) => break, + }, + }; + } + } + + fn send_next_queued<'a>( + &mut self, + outgoing: &'a flume::Sender, + ) -> future::Fuse> { + if let Some(next) = self.queue.pop_front() { + outgoing.send_async(next).fuse() + } else { + future::Fuse::terminated() + } + } +} + +impl Default for InputBuffer { + fn default() -> Self { + Self::new() + } +} diff --git a/binaries/runtime/src/operator/mod.rs b/binaries/runtime/src/operator/mod.rs index 1810b5b7..df2548e2 100644 --- a/binaries/runtime/src/operator/mod.rs +++ b/binaries/runtime/src/operator/mod.rs @@ -12,18 +12,19 @@ use pyo3::{ IntoPy, PyObject, Python, }; use std::any::Any; -use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::sync::mpsc::Sender; #[cfg(not(feature = "tracing"))] type Tracer = (); +pub mod channel; mod python; mod shared_lib; pub fn run_operator( node_id: &NodeId, operator_definition: OperatorDefinition, - incoming_events: Receiver, + incoming_events: flume::Receiver, events_tx: Sender, ) -> eyre::Result<()> { #[cfg(feature = "tracing")] diff --git a/binaries/runtime/src/operator/python.rs b/binaries/runtime/src/operator/python.rs index 5f8655cc..c5406ce5 100644 --- a/binaries/runtime/src/operator/python.rs +++ b/binaries/runtime/src/operator/python.rs @@ -14,7 +14,7 @@ use std::{ panic::{catch_unwind, AssertUnwindSafe}, path::Path, }; -use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::sync::mpsc::Sender; fn traceback(err: pyo3::PyErr) -> eyre::Report { let traceback = Python::with_gil(|py| err.traceback(py).and_then(|t| t.format().ok())); @@ -31,7 +31,7 @@ pub fn run( operator_id: &OperatorId, source: &str, events_tx: Sender, - mut incoming_events: Receiver, + incoming_events: flume::Receiver, tracer: Tracer, ) -> eyre::Result<()> { let path = if source_is_url(source) { @@ -100,7 +100,7 @@ pub fn run( Python::with_gil(init_operator).wrap_err("failed to init python operator")?; let reason = loop { - let Some(mut event) = incoming_events.blocking_recv() else { break StopReason::InputsClosed }; + let Ok(mut event) = incoming_events.recv() else { break StopReason::InputsClosed }; if let IncomingEvent::Input { input_id, metadata, .. diff --git a/binaries/runtime/src/operator/shared_lib.rs b/binaries/runtime/src/operator/shared_lib.rs index c9401466..9e6e5666 100644 --- a/binaries/runtime/src/operator/shared_lib.rs +++ b/binaries/runtime/src/operator/shared_lib.rs @@ -19,14 +19,14 @@ use std::{ path::Path, sync::Arc, }; -use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::sync::mpsc::Sender; pub fn run( node_id: &NodeId, operator_id: &OperatorId, source: &str, events_tx: Sender, - incoming_events: Receiver, + incoming_events: flume::Receiver, tracer: Tracer, ) -> eyre::Result<()> { let path = if source_is_url(source) { @@ -78,14 +78,14 @@ pub fn run( } struct SharedLibraryOperator<'lib> { - incoming_events: Receiver, + incoming_events: flume::Receiver, events_tx: Sender, bindings: Bindings<'lib>, } impl<'lib> SharedLibraryOperator<'lib> { - fn run(mut self, tracer: Tracer) -> eyre::Result { + fn run(self, tracer: Tracer) -> eyre::Result { let operator_context = { let DoraInitResult { result, @@ -134,7 +134,7 @@ impl<'lib> SharedLibraryOperator<'lib> { }); let reason = loop { - let Some(mut event) = self.incoming_events.blocking_recv() else { + let Ok(mut event) = self.incoming_events.recv() else { break StopReason::InputsClosed };