From 07063e4ff96bcd48561bf210ac8e2c6e53c4d359 Mon Sep 17 00:00:00 2001 From: Philipp Oppermann Date: Thu, 23 Mar 2023 17:48:31 +0100 Subject: [PATCH] dora-runtime: Only subscribe to daemon once all operators are ready --- binaries/runtime/src/lib.rs | 63 ++++++++++++++------- binaries/runtime/src/operator/mod.rs | 5 +- binaries/runtime/src/operator/python.rs | 5 +- binaries/runtime/src/operator/shared_lib.rs | 9 ++- 4 files changed, 55 insertions(+), 27 deletions(-) diff --git a/binaries/runtime/src/lib.rs b/binaries/runtime/src/lib.rs index 37ae8791..70df334f 100644 --- a/binaries/runtime/src/lib.rs +++ b/binaries/runtime/src/lib.rs @@ -2,7 +2,7 @@ use dora_core::{ config::{DataId, OperatorId}, - daemon_messages::RuntimeConfig, + daemon_messages::{NodeConfig, RuntimeConfig}, descriptor::OperatorConfig, }; use dora_node_api::DoraNode; @@ -17,7 +17,10 @@ use std::{ collections::{BTreeMap, BTreeSet, HashMap}, mem, }; -use tokio::{runtime::Builder, sync::mpsc}; +use tokio::{ + runtime::Builder, + sync::{mpsc, oneshot}, +}; use tokio_stream::wrappers::ReceiverStream; mod operator; @@ -35,7 +38,6 @@ pub fn main() -> eyre::Result<()> { operators, } = config; let node_id = config.node_id.clone(); - let (node, daemon_events) = DoraNode::init(config)?; let operator_definition = if operators.is_empty() { bail!("no operators"); @@ -52,21 +54,7 @@ pub fn main() -> eyre::Result<()> { id: operator_id.clone(), event, }); - let daemon_events = Box::pin(futures::stream::unfold(daemon_events, |mut stream| async { - let event = stream.recv_async().await.map(|event| match event { - dora_node_api::Event::Stop => Event::Stop, - dora_node_api::Event::Input { id, metadata, data } => Event::Input { - id, - metadata, - data: data.map(|data| data.to_owned()), - }, - dora_node_api::Event::InputClosed { id } => Event::InputClosed(id), - dora_node_api::Event::Error(err) => Event::Error(err), - _ => todo!(), - }); - event.map(|event| (event, stream)) - })); - let events = (operator_events, daemon_events).merge(); + let tokio_runtime = Builder::new_current_thread() .enable_all() .build() @@ -85,8 +73,15 @@ pub fn main() -> eyre::Result<()> { )] .into_iter() .collect(); + let (init_done_tx, init_done) = oneshot::channel(); let main_task = std::thread::spawn(move || -> Result<()> { - tokio_runtime.block_on(run(node, operator_config, events, operator_channels)) + tokio_runtime.block_on(run( + operator_config, + config, + operator_events, + operator_channels, + init_done, + )) }); let operator_id = operator_definition.id.clone(); @@ -95,6 +90,7 @@ pub fn main() -> eyre::Result<()> { operator_definition, incoming_events, operator_events_tx, + init_done_tx, ) .wrap_err_with(|| format!("failed to run operator {operator_id}"))?; @@ -115,12 +111,13 @@ fn queue_sizes(config: &OperatorConfig) -> std::collections::BTreeMap, - mut events: impl Stream + Unpin, + config: NodeConfig, + operator_events: impl Stream + Unpin, mut operator_channels: HashMap>, + init_done: oneshot::Receiver<()>, ) -> eyre::Result<()> { #[cfg(feature = "metrics")] let _started = { @@ -134,6 +131,28 @@ async fn run( _started }; + init_done + .await + .wrap_err("the `init_done` channel was closed unexpectedly")?; + tracing::info!("All operators are ready, starting runtime"); + + let (mut node, daemon_events) = DoraNode::init(config)?; + let daemon_events = Box::pin(futures::stream::unfold(daemon_events, |mut stream| async { + let event = stream.recv_async().await.map(|event| match event { + dora_node_api::Event::Stop => Event::Stop, + dora_node_api::Event::Input { id, metadata, data } => Event::Input { + id, + metadata, + data: data.map(|data| data.to_owned()), + }, + dora_node_api::Event::InputClosed { id } => Event::InputClosed(id), + dora_node_api::Event::Error(err) => Event::Error(err), + _ => todo!(), + }); + event.map(|event| (event, stream)) + })); + let mut events = (operator_events, daemon_events).merge(); + let mut open_operator_inputs: HashMap<_, BTreeSet<_>> = operators .iter() .map(|(id, config)| (id, config.inputs.keys().collect())) diff --git a/binaries/runtime/src/operator/mod.rs b/binaries/runtime/src/operator/mod.rs index ef1fbf61..90291d9c 100644 --- a/binaries/runtime/src/operator/mod.rs +++ b/binaries/runtime/src/operator/mod.rs @@ -12,7 +12,7 @@ use pyo3::{ IntoPy, PyObject, Python, }; use std::any::Any; -use tokio::sync::mpsc::Sender; +use tokio::sync::{mpsc::Sender, oneshot}; #[cfg(not(feature = "telemetry"))] type Tracer = (); @@ -26,6 +26,7 @@ pub fn run_operator( operator_definition: OperatorDefinition, incoming_events: flume::Receiver, events_tx: Sender, + init_done: oneshot::Sender<()>, ) -> eyre::Result<()> { #[cfg(feature = "telemetry")] let tracer = dora_tracing::telemetry::init_tracing( @@ -45,6 +46,7 @@ pub fn run_operator( events_tx, incoming_events, tracer, + init_done, ) .wrap_err_with(|| { format!( @@ -61,6 +63,7 @@ pub fn run_operator( events_tx, incoming_events, tracer, + init_done, ) .wrap_err_with(|| { format!( diff --git a/binaries/runtime/src/operator/python.rs b/binaries/runtime/src/operator/python.rs index ecdc74f2..ee21d213 100644 --- a/binaries/runtime/src/operator/python.rs +++ b/binaries/runtime/src/operator/python.rs @@ -14,7 +14,7 @@ use std::{ panic::{catch_unwind, AssertUnwindSafe}, path::Path, }; -use tokio::sync::mpsc::Sender; +use tokio::sync::{mpsc::Sender, oneshot}; fn traceback(err: pyo3::PyErr) -> eyre::Report { let traceback = Python::with_gil(|py| err.traceback(py).and_then(|t| t.format().ok())); @@ -33,6 +33,7 @@ pub fn run( events_tx: Sender, incoming_events: flume::Receiver, tracer: Tracer, + init_done: oneshot::Sender<()>, ) -> eyre::Result<()> { let path = if source_is_url(source) { let target_path = Path::new("build") @@ -99,6 +100,8 @@ pub fn run( let operator = Python::with_gil(init_operator).wrap_err("failed to init python operator")?; + let _ = init_done.send(()); + let reason = loop { let Ok(mut event) = incoming_events.recv() else { break StopReason::InputsClosed }; diff --git a/binaries/runtime/src/operator/shared_lib.rs b/binaries/runtime/src/operator/shared_lib.rs index e2221ccd..2e1c463e 100644 --- a/binaries/runtime/src/operator/shared_lib.rs +++ b/binaries/runtime/src/operator/shared_lib.rs @@ -19,7 +19,7 @@ use std::{ path::Path, sync::Arc, }; -use tokio::sync::mpsc::Sender; +use tokio::sync::{mpsc::Sender, oneshot}; pub fn run( node_id: &NodeId, @@ -28,6 +28,7 @@ pub fn run( events_tx: Sender, incoming_events: flume::Receiver, tracer: Tracer, + init_done: oneshot::Sender<()>, ) -> eyre::Result<()> { let path = if source_is_url(source) { let target_path = adjust_shared_library_path( @@ -60,7 +61,7 @@ pub fn run( events_tx: events_tx.clone(), }; - operator.run(tracer) + operator.run(tracer, init_done) }); match catch_unwind(closure) { Ok(Ok(reason)) => { @@ -85,7 +86,7 @@ struct SharedLibraryOperator<'lib> { } impl<'lib> SharedLibraryOperator<'lib> { - fn run(self, tracer: Tracer) -> eyre::Result { + fn run(self, tracer: Tracer, init_done: oneshot::Sender<()>) -> eyre::Result { let operator_context = { let DoraInitResult { result, @@ -101,6 +102,8 @@ impl<'lib> SharedLibraryOperator<'lib> { } }; + let _ = init_done.send(()); + let send_output_closure = Arc::new(move |output: Output| { let Output { id: output_id,