diff --git a/apis/rust/node/src/lib.rs b/apis/rust/node/src/lib.rs index 21aa79ef..168611c6 100644 --- a/apis/rust/node/src/lib.rs +++ b/apis/rust/node/src/lib.rs @@ -1,7 +1,7 @@ use communication::CommunicationLayer; use config::{CommunicationConfig, DataId, NodeId, NodeRunConfig}; use eyre::WrapErr; -use futures::{stream::FuturesUnordered, StreamExt}; +use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; use futures_concurrency::Merge; use std::collections::HashSet; @@ -84,7 +84,12 @@ impl DoraNode { .wrap_err_with(|| format!("failed to subscribe on {topic}"))?; stop_messages.push(sub.into_future()); } - let finished = Box::pin(stop_messages.all(|_| async { true })); + let node_id = self.id.clone(); + let finished = Box::pin( + stop_messages + .all(|_| async { true }) + .map(move |_| println!("all inputs finished for node {node_id}")), + ); Ok(streams.merge().take_until(finished)) } @@ -121,8 +126,12 @@ impl Drop for DoraNode { .communication .publish_sync(&topic, &[]) .wrap_err_with(|| format!("failed to send stop message for source `{self_id}`")); - if let Err(err) = result { - tracing::error!("{err}") + match result { + Ok(()) => println!("sent stop message for {self_id}"), + Err(err) => { + println!("error sending stop message for {self_id}: {err:?}"); + tracing::error!("{err:?}") + } } } } diff --git a/apis/rust/operator/src/lib.rs b/apis/rust/operator/src/lib.rs index f9415fb0..efcf6ac4 100644 --- a/apis/rust/operator/src/lib.rs +++ b/apis/rust/operator/src/lib.rs @@ -14,7 +14,13 @@ pub trait DoraOperator: Default { id: &str, data: &[u8], output_sender: &mut DoraOutputSender, - ) -> Result<(), ()>; + ) -> Result; +} + +#[repr(isize)] +pub enum DoraStatus { + Continue = 0, + Stop = 1, } pub struct DoraOutputSender { diff --git a/apis/rust/operator/src/raw.rs b/apis/rust/operator/src/raw.rs index 95361efb..c270537f 100644 --- a/apis/rust/operator/src/raw.rs +++ b/apis/rust/operator/src/raw.rs @@ -45,7 +45,7 @@ pub unsafe fn dora_on_input( let 
operator: &mut O = unsafe { &mut *operator_context.cast() }; match operator.on_input(id, data, &mut output_sender) { - Ok(()) => 0, + Ok(status) => status as isize, Err(_) => -1, } } diff --git a/binaries/coordinator/examples/mini-dataflow.yml b/binaries/coordinator/examples/mini-dataflow.yml index 6723bf9b..087a6500 100644 --- a/binaries/coordinator/examples/mini-dataflow.yml +++ b/binaries/coordinator/examples/mini-dataflow.yml @@ -60,9 +60,9 @@ nodes: - id: python-operator operator: - python: ../../examples/python-operator/op.py + python: ../../examples/python-operator/op2.py inputs: time: timer/time - dora_time: dora/timer/millis/500 + dora_time: dora/timer/millis/50 outputs: - counter diff --git a/binaries/coordinator/examples/nodes/rust/rate_limit.rs b/binaries/coordinator/examples/nodes/rust/rate_limit.rs index 4b581f13..3b3c2cea 100644 --- a/binaries/coordinator/examples/nodes/rust/rate_limit.rs +++ b/binaries/coordinator/examples/nodes/rust/rate_limit.rs @@ -46,5 +46,7 @@ async fn main() -> eyre::Result<()> { } } + println!("rate limit finished"); + Ok(()) } diff --git a/binaries/coordinator/examples/nodes/rust/sink_logger.rs b/binaries/coordinator/examples/nodes/rust/sink_logger.rs index 8eb7d40c..cb499ba1 100644 --- a/binaries/coordinator/examples/nodes/rust/sink_logger.rs +++ b/binaries/coordinator/examples/nodes/rust/sink_logger.rs @@ -12,7 +12,7 @@ async fn main() -> eyre::Result<()> { let mut last_timestamp = None; loop { - let timeout = Duration::from_secs(2); + let timeout = Duration::from_secs(5); let input = match tokio::time::timeout(timeout, inputs.next()).await { Ok(Some(input)) => input, Ok(None) => break, diff --git a/binaries/coordinator/src/main.rs b/binaries/coordinator/src/main.rs index 094c5ff3..c3a972a2 100644 --- a/binaries/coordinator/src/main.rs +++ b/binaries/coordinator/src/main.rs @@ -100,11 +100,10 @@ async fn run_dataflow(dataflow_path: PathBuf, runtime: &Path) -> eyre::Result<() } for interval in dora_timers { - let 
communication = communication.clone(); - let task = tokio::spawn(async move { - let communication = communication::init(&communication) - .await - .wrap_err("failed to init communication layer")?; + let communication = communication::init(&communication) + .await + .wrap_err("failed to init communication layer")?; + tokio::spawn(async move { let topic = { let duration = format_duration(interval); format!("dora/timer/{duration}") @@ -112,13 +111,9 @@ async fn run_dataflow(dataflow_path: PathBuf, runtime: &Path) -> eyre::Result<() let mut stream = IntervalStream::new(tokio::time::interval(interval)); while (stream.next().await).is_some() { let publish = communication.publish(&topic, &[]); - publish - .await - .wrap_err("failed to publish timer tick message")?; + publish.await.expect("failed to publish timer tick message"); } - Ok(()) }); - tasks.push(task); } while let Some(task_result) = tasks.next().await { diff --git a/binaries/runtime/src/main.rs b/binaries/runtime/src/main.rs index 27f25b23..ce938613 100644 --- a/binaries/runtime/src/main.rs +++ b/binaries/runtime/src/main.rs @@ -15,7 +15,7 @@ use futures::{ use futures_concurrency::Merge; use operator::{Operator, OperatorEvent}; use std::{ - collections::{BTreeMap, HashMap}, + collections::{BTreeMap, BTreeSet, HashMap}, mem, pin::Pin, }; @@ -45,6 +45,7 @@ async fn main() -> eyre::Result<()> { }; let mut operator_map = BTreeMap::new(); + let mut stopped_operators = BTreeSet::new(); let mut operator_events = StreamMap::new(); let mut operator_events_tx = HashMap::new(); for operator_config in &operators { @@ -72,15 +73,19 @@ async fn main() -> eyre::Result<()> { match event { Event::External(event) => match event { SubscribeEvent::Input(input) => { - let operator = - operator_map - .get_mut(&input.target_operator) - .ok_or_else(|| { - eyre!( + let operator = match operator_map.get_mut(&input.target_operator) { + Some(op) => op, + None => { + if stopped_operators.contains(&input.target_operator) { + continue; // 
operator was stopped already -> ignore input + } else { + bail!( "received input for unexpected operator `{}`", input.target_operator - ) - })?; + ); + } + } + }; operator .handle_input(input.id.clone(), input.data) @@ -92,14 +97,18 @@ async fn main() -> eyre::Result<()> { })?; } SubscribeEvent::InputsStopped { target_operator } => { - let events_tx = operator_events_tx.get(&target_operator).ok_or_else(|| { - eyre!("failed to get events_tx for operator {target_operator}") - })?; - - let events_tx = events_tx.clone(); - tokio::spawn(async move { - let _ = events_tx.send(OperatorEvent::EndOfInput).await; - }); + println!("all inputs finished for operator {node_id}/{target_operator}"); + match operator_map.get_mut(&target_operator) { + Some(op) => op.close_input_stream(), + None => { + if !stopped_operators.contains(&target_operator) { + bail!( + "received InputsStopped event for unknown operator `{}`", + target_operator + ); + } + } + } } }, Event::Operator { id, event } => { @@ -119,9 +128,10 @@ async fn main() -> eyre::Result<()> { bail!(err.wrap_err(format!("operator {id} failed"))) } OperatorEvent::Panic(payload) => std::panic::resume_unwind(payload), - OperatorEvent::EndOfInput => { + OperatorEvent::Finished => { if operator_map.remove(&id).is_some() { println!("operator {node_id}/{id} finished"); + stopped_operators.insert(id.clone()); // send stopped message publish( &node_id, diff --git a/binaries/runtime/src/operator/mod.rs b/binaries/runtime/src/operator/mod.rs index 580a239f..852334da 100644 --- a/binaries/runtime/src/operator/mod.rs +++ b/binaries/runtime/src/operator/mod.rs @@ -8,7 +8,7 @@ mod python; mod shared_lib; pub struct Operator { - operator_task: Sender, + operator_task: Option>, definition: OperatorDefinition, } @@ -41,13 +41,20 @@ impl Operator { } } Ok(Self { - operator_task, + operator_task: Some(operator_task), definition: operator_definition, }) } pub fn handle_input(&mut self, id: DataId, value: Vec) -> eyre::Result<()> { 
self.operator_task + .as_mut() + .ok_or_else(|| { + eyre!( + "input channel for {} was already closed", + self.definition.id + ) + })? .try_send(OperatorInput { id, value }) .map_err(|err| match err { tokio::sync::mpsc::error::TrySendError::Closed(_) => eyre!("operator crashed"), @@ -55,6 +62,10 @@ impl Operator { }) } + pub fn close_input_stream(&mut self) { + self.operator_task = None; + } + /// Get a reference to the operator's definition. #[must_use] pub fn definition(&self) -> &OperatorDefinition { @@ -66,7 +77,7 @@ pub enum OperatorEvent { Output { id: DataId, value: Vec }, Error(eyre::Error), Panic(Box), - EndOfInput, + Finished, } pub struct OperatorInput { diff --git a/binaries/runtime/src/operator/python.rs b/binaries/runtime/src/operator/python.rs index 1d2184f7..b4f3ba29 100644 --- a/binaries/runtime/src/operator/python.rs +++ b/binaries/runtime/src/operator/python.rs @@ -76,7 +76,7 @@ pub fn spawn( Python::with_gil(init_operator).wrap_err("failed to init python operator")?; while let Some(input) = inputs.blocking_recv() { - Python::with_gil(|py| -> eyre::Result<_> { + let status_enum = Python::with_gil(|py| { operator .call_method1( py, @@ -89,6 +89,15 @@ pub fn spawn( ) .map_err(traceback) })?; + let status_val = Python::with_gil(|py| status_enum.getattr(py, "value")) + .wrap_err("on_input must have enum return value")?; + let status: i32 = Python::with_gil(|py| status_val.extract(py)) + .wrap_err("on_input has invalid return value")?; + match status { + 0 => {} // ok + 1 => break, // stop + other => bail!("on_input returned invalid status {other}"), + } } Python::with_gil(|py| { @@ -112,7 +121,9 @@ pub fn spawn( }); match catch_unwind(closure) { - Ok(Ok(())) => {} + Ok(Ok(())) => { + let _ = events_tx.blocking_send(OperatorEvent::Finished); + } Ok(Err(err)) => { let _ = events_tx.blocking_send(OperatorEvent::Error(err)); } diff --git a/binaries/runtime/src/operator/shared_lib.rs b/binaries/runtime/src/operator/shared_lib.rs index fc338fff..1c84d248 
100644 --- a/binaries/runtime/src/operator/shared_lib.rs +++ b/binaries/runtime/src/operator/shared_lib.rs @@ -32,7 +32,9 @@ pub fn spawn( operator.run() }); match catch_unwind(closure) { - Ok(Ok(())) => {} + Ok(Ok(())) => { + let _ = events_tx.blocking_send(OperatorEvent::Finished); + } Ok(Err(err)) => { let _ = events_tx.blocking_send(OperatorEvent::Error(err)); } @@ -95,8 +97,11 @@ impl<'lib> SharedLibraryOperator<'lib> { operator_context.raw, ) }; - if result != 0 { - bail!("on_input failed with error code {result}"); + match result { + 0 => {} // DoraStatus::Continue + 1 => break, // DoraStatus::Stop + -1 => bail!("on_input failed"), + other => bail!("on_input finished with unexpected exit code {other}"), } } Ok(()) diff --git a/examples/example-operator/src/lib.rs b/examples/example-operator/src/lib.rs index 052fb4fc..e0b1a4ba 100644 --- a/examples/example-operator/src/lib.rs +++ b/examples/example-operator/src/lib.rs @@ -1,6 +1,6 @@ #![warn(unsafe_op_in_unsafe_fn)] -use dora_operator_api::{register_operator, DoraOperator, DoraOutputSender}; +use dora_operator_api::{register_operator, DoraOperator, DoraOutputSender, DoraStatus}; register_operator!(ExampleOperator); @@ -15,7 +15,7 @@ impl DoraOperator for ExampleOperator { id: &str, data: &[u8], output_sender: &mut DoraOutputSender, - ) -> Result<(), ()> { + ) -> Result { match id { "time" => { let parsed = std::str::from_utf8(data).map_err(|_| ())?; @@ -35,6 +35,6 @@ impl DoraOperator for ExampleOperator { } other => eprintln!("ignoring unexpected input {other}"), } - Ok(()) + Ok(DoraStatus::Continue) } } diff --git a/examples/python-operator/op.py b/examples/python-operator/op.py index afcfaa91..bffe8508 100644 --- a/examples/python-operator/op.py +++ b/examples/python-operator/op.py @@ -1,5 +1,9 @@ from typing import Callable +from enum import Enum +class DoraStatus(Enum): + CONTINUE = 0 + STOP = 1 class Operator: """ @@ -26,5 +30,7 @@ class Operator: """ val_len = len(value) print(f"PYTHON received input 
{input_id}; value length: {val_len}") - send_output("counter", self.counter.to_bytes(1, "little")) - self.counter = (self.counter + 1) % 256 + send_output("counter", (self.counter % 256).to_bytes(1, "little")) + self.counter = self.counter + 1 + + return DoraStatus.CONTINUE diff --git a/examples/python-operator/op2.py b/examples/python-operator/op2.py new file mode 100644 index 00000000..0053b1a4 --- /dev/null +++ b/examples/python-operator/op2.py @@ -0,0 +1,39 @@ +from typing import Callable +from enum import Enum + +class DoraStatus(Enum): + OK = 0 + STOP = 1 + +class Operator: + """ + Example operator incrementing a counter every time it's been called. + + The current value of the counter is sent back to dora on `counter`. + """ + + def __init__(self, counter=0): + self.counter = counter + + def on_input( + self, + input_id: str, + value: bytes, + send_output: Callable[[str, bytes], None], + ): + """Handle input by incrementing count by one. + + Args: + input_id (str): Id of the input declared in the yaml configuration + value (bytes): Bytes message of the input + send_output (Callable[[str, bytes]]): Function enabling sending output back to dora. + """ + val_len = len(value) + print(f"PYTHON received input {input_id}; value length: {val_len}") + send_output("counter", (self.counter % 256).to_bytes(1, "little")) + self.counter = self.counter + 1 + + if self.counter > 500: + return DoraStatus.STOP + else: + return DoraStatus.OK