Allow operators to stop themselvestags/v0.0.0-test.4
| @@ -1,7 +1,7 @@ | |||
| use communication::CommunicationLayer; | |||
| use config::{CommunicationConfig, DataId, NodeId, NodeRunConfig}; | |||
| use eyre::WrapErr; | |||
| use futures::{stream::FuturesUnordered, StreamExt}; | |||
| use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; | |||
| use futures_concurrency::Merge; | |||
| use std::collections::HashSet; | |||
| @@ -84,7 +84,12 @@ impl DoraNode { | |||
| .wrap_err_with(|| format!("failed to subscribe on {topic}"))?; | |||
| stop_messages.push(sub.into_future()); | |||
| } | |||
| let finished = Box::pin(stop_messages.all(|_| async { true })); | |||
| let node_id = self.id.clone(); | |||
| let finished = Box::pin( | |||
| stop_messages | |||
| .all(|_| async { true }) | |||
| .map(move |_| println!("all inputs finished for node {node_id}")), | |||
| ); | |||
| Ok(streams.merge().take_until(finished)) | |||
| } | |||
| @@ -121,8 +126,12 @@ impl Drop for DoraNode { | |||
| .communication | |||
| .publish_sync(&topic, &[]) | |||
| .wrap_err_with(|| format!("failed to send stop message for source `{self_id}`")); | |||
| if let Err(err) = result { | |||
| tracing::error!("{err}") | |||
| match result { | |||
| Ok(()) => println!("sent stop message for {self_id}"), | |||
| Err(err) => { | |||
| println!("error sending stop message for {self_id}: {err:?}"); | |||
| tracing::error!("{err:?}") | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -14,7 +14,13 @@ pub trait DoraOperator: Default { | |||
| id: &str, | |||
| data: &[u8], | |||
| output_sender: &mut DoraOutputSender, | |||
| ) -> Result<(), ()>; | |||
| ) -> Result<DoraStatus, ()>; | |||
| } | |||
| #[repr(isize)] | |||
| pub enum DoraStatus { | |||
| Continue = 0, | |||
| Stop = 1, | |||
| } | |||
| pub struct DoraOutputSender { | |||
| @@ -45,7 +45,7 @@ pub unsafe fn dora_on_input<O: DoraOperator>( | |||
| let operator: &mut O = unsafe { &mut *operator_context.cast() }; | |||
| match operator.on_input(id, data, &mut output_sender) { | |||
| Ok(()) => 0, | |||
| Ok(status) => status as isize, | |||
| Err(_) => -1, | |||
| } | |||
| } | |||
| @@ -60,9 +60,9 @@ nodes: | |||
| - id: python-operator | |||
| operator: | |||
| python: ../../examples/python-operator/op.py | |||
| python: ../../examples/python-operator/op2.py | |||
| inputs: | |||
| time: timer/time | |||
| dora_time: dora/timer/millis/500 | |||
| dora_time: dora/timer/millis/50 | |||
| outputs: | |||
| - counter | |||
| @@ -46,5 +46,7 @@ async fn main() -> eyre::Result<()> { | |||
| } | |||
| } | |||
| println!("rate limit finished"); | |||
| Ok(()) | |||
| } | |||
| @@ -12,7 +12,7 @@ async fn main() -> eyre::Result<()> { | |||
| let mut last_timestamp = None; | |||
| loop { | |||
| let timeout = Duration::from_secs(2); | |||
| let timeout = Duration::from_secs(5); | |||
| let input = match tokio::time::timeout(timeout, inputs.next()).await { | |||
| Ok(Some(input)) => input, | |||
| Ok(None) => break, | |||
| @@ -100,11 +100,10 @@ async fn run_dataflow(dataflow_path: PathBuf, runtime: &Path) -> eyre::Result<() | |||
| } | |||
| for interval in dora_timers { | |||
| let communication = communication.clone(); | |||
| let task = tokio::spawn(async move { | |||
| let communication = communication::init(&communication) | |||
| .await | |||
| .wrap_err("failed to init communication layer")?; | |||
| let communication = communication::init(&communication) | |||
| .await | |||
| .wrap_err("failed to init communication layer")?; | |||
| tokio::spawn(async move { | |||
| let topic = { | |||
| let duration = format_duration(interval); | |||
| format!("dora/timer/{duration}") | |||
| @@ -112,13 +111,9 @@ async fn run_dataflow(dataflow_path: PathBuf, runtime: &Path) -> eyre::Result<() | |||
| let mut stream = IntervalStream::new(tokio::time::interval(interval)); | |||
| while (stream.next().await).is_some() { | |||
| let publish = communication.publish(&topic, &[]); | |||
| publish | |||
| .await | |||
| .wrap_err("failed to publish timer tick message")?; | |||
| publish.await.expect("failed to publish timer tick message"); | |||
| } | |||
| Ok(()) | |||
| }); | |||
| tasks.push(task); | |||
| } | |||
| while let Some(task_result) = tasks.next().await { | |||
| @@ -15,7 +15,7 @@ use futures::{ | |||
| use futures_concurrency::Merge; | |||
| use operator::{Operator, OperatorEvent}; | |||
| use std::{ | |||
| collections::{BTreeMap, HashMap}, | |||
| collections::{BTreeMap, BTreeSet, HashMap}, | |||
| mem, | |||
| pin::Pin, | |||
| }; | |||
| @@ -45,6 +45,7 @@ async fn main() -> eyre::Result<()> { | |||
| }; | |||
| let mut operator_map = BTreeMap::new(); | |||
| let mut stopped_operators = BTreeSet::new(); | |||
| let mut operator_events = StreamMap::new(); | |||
| let mut operator_events_tx = HashMap::new(); | |||
| for operator_config in &operators { | |||
| @@ -72,15 +73,19 @@ async fn main() -> eyre::Result<()> { | |||
| match event { | |||
| Event::External(event) => match event { | |||
| SubscribeEvent::Input(input) => { | |||
| let operator = | |||
| operator_map | |||
| .get_mut(&input.target_operator) | |||
| .ok_or_else(|| { | |||
| eyre!( | |||
| let operator = match operator_map.get_mut(&input.target_operator) { | |||
| Some(op) => op, | |||
| None => { | |||
| if stopped_operators.contains(&input.target_operator) { | |||
| continue; // operator was stopped already -> ignore input | |||
| } else { | |||
| bail!( | |||
| "received input for unexpected operator `{}`", | |||
| input.target_operator | |||
| ) | |||
| })?; | |||
| ); | |||
| } | |||
| } | |||
| }; | |||
| operator | |||
| .handle_input(input.id.clone(), input.data) | |||
| @@ -92,14 +97,18 @@ async fn main() -> eyre::Result<()> { | |||
| })?; | |||
| } | |||
| SubscribeEvent::InputsStopped { target_operator } => { | |||
| let events_tx = operator_events_tx.get(&target_operator).ok_or_else(|| { | |||
| eyre!("failed to get events_tx for operator {target_operator}") | |||
| })?; | |||
| let events_tx = events_tx.clone(); | |||
| tokio::spawn(async move { | |||
| let _ = events_tx.send(OperatorEvent::EndOfInput).await; | |||
| }); | |||
| println!("all inputs finished for operator {node_id}/{target_operator}"); | |||
| match operator_map.get_mut(&target_operator) { | |||
| Some(op) => op.close_input_stream(), | |||
| None => { | |||
| if !stopped_operators.contains(&target_operator) { | |||
| bail!( | |||
| "received InputsStopped event for unknown operator `{}`", | |||
| target_operator | |||
| ); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| }, | |||
| Event::Operator { id, event } => { | |||
| @@ -119,9 +128,10 @@ async fn main() -> eyre::Result<()> { | |||
| bail!(err.wrap_err(format!("operator {id} failed"))) | |||
| } | |||
| OperatorEvent::Panic(payload) => std::panic::resume_unwind(payload), | |||
| OperatorEvent::EndOfInput => { | |||
| OperatorEvent::Finished => { | |||
| if operator_map.remove(&id).is_some() { | |||
| println!("operator {node_id}/{id} finished"); | |||
| stopped_operators.insert(id.clone()); | |||
| // send stopped message | |||
| publish( | |||
| &node_id, | |||
| @@ -8,7 +8,7 @@ mod python; | |||
| mod shared_lib; | |||
| pub struct Operator { | |||
| operator_task: Sender<OperatorInput>, | |||
| operator_task: Option<Sender<OperatorInput>>, | |||
| definition: OperatorDefinition, | |||
| } | |||
| @@ -41,13 +41,20 @@ impl Operator { | |||
| } | |||
| } | |||
| Ok(Self { | |||
| operator_task, | |||
| operator_task: Some(operator_task), | |||
| definition: operator_definition, | |||
| }) | |||
| } | |||
| pub fn handle_input(&mut self, id: DataId, value: Vec<u8>) -> eyre::Result<()> { | |||
| self.operator_task | |||
| .as_mut() | |||
| .ok_or_else(|| { | |||
| eyre!( | |||
| "input channel for {} was already closed", | |||
| self.definition.id | |||
| ) | |||
| })? | |||
| .try_send(OperatorInput { id, value }) | |||
| .map_err(|err| match err { | |||
| tokio::sync::mpsc::error::TrySendError::Closed(_) => eyre!("operator crashed"), | |||
| @@ -55,6 +62,10 @@ impl Operator { | |||
| }) | |||
| } | |||
| pub fn close_input_stream(&mut self) { | |||
| self.operator_task = None; | |||
| } | |||
| /// Get a reference to the operator's definition. | |||
| #[must_use] | |||
| pub fn definition(&self) -> &OperatorDefinition { | |||
| @@ -66,7 +77,7 @@ pub enum OperatorEvent { | |||
| Output { id: DataId, value: Vec<u8> }, | |||
| Error(eyre::Error), | |||
| Panic(Box<dyn Any + Send>), | |||
| EndOfInput, | |||
| Finished, | |||
| } | |||
| pub struct OperatorInput { | |||
| @@ -76,7 +76,7 @@ pub fn spawn( | |||
| Python::with_gil(init_operator).wrap_err("failed to init python operator")?; | |||
| while let Some(input) = inputs.blocking_recv() { | |||
| Python::with_gil(|py| -> eyre::Result<_> { | |||
| let status_enum = Python::with_gil(|py| { | |||
| operator | |||
| .call_method1( | |||
| py, | |||
| @@ -89,6 +89,15 @@ pub fn spawn( | |||
| ) | |||
| .map_err(traceback) | |||
| })?; | |||
| let status_val = Python::with_gil(|py| status_enum.getattr(py, "value")) | |||
| .wrap_err("on_input must have enum return value")?; | |||
| let status: i32 = Python::with_gil(|py| status_val.extract(py)) | |||
| .wrap_err("on_input has invalid return value")?; | |||
| match status { | |||
| 0 => {} // ok | |||
| 1 => break, // stop | |||
| other => bail!("on_input returned invalid status {other}"), | |||
| } | |||
| } | |||
| Python::with_gil(|py| { | |||
| @@ -112,7 +121,9 @@ pub fn spawn( | |||
| }); | |||
| match catch_unwind(closure) { | |||
| Ok(Ok(())) => {} | |||
| Ok(Ok(())) => { | |||
| let _ = events_tx.blocking_send(OperatorEvent::Finished); | |||
| } | |||
| Ok(Err(err)) => { | |||
| let _ = events_tx.blocking_send(OperatorEvent::Error(err)); | |||
| } | |||
| @@ -32,7 +32,9 @@ pub fn spawn( | |||
| operator.run() | |||
| }); | |||
| match catch_unwind(closure) { | |||
| Ok(Ok(())) => {} | |||
| Ok(Ok(())) => { | |||
| let _ = events_tx.blocking_send(OperatorEvent::Finished); | |||
| } | |||
| Ok(Err(err)) => { | |||
| let _ = events_tx.blocking_send(OperatorEvent::Error(err)); | |||
| } | |||
| @@ -95,8 +97,11 @@ impl<'lib> SharedLibraryOperator<'lib> { | |||
| operator_context.raw, | |||
| ) | |||
| }; | |||
| if result != 0 { | |||
| bail!("on_input failed with error code {result}"); | |||
| match result { | |||
| 0 => {} // DoraStatus::Continue | |||
| 1 => break, // DoraStatus::Stop | |||
| -1 => bail!("on_input failed"), | |||
| other => bail!("on_input finished with unexpected exit code {other}"), | |||
| } | |||
| } | |||
| Ok(()) | |||
| @@ -1,6 +1,6 @@ | |||
| #![warn(unsafe_op_in_unsafe_fn)] | |||
| use dora_operator_api::{register_operator, DoraOperator, DoraOutputSender}; | |||
| use dora_operator_api::{register_operator, DoraOperator, DoraOutputSender, DoraStatus}; | |||
| register_operator!(ExampleOperator); | |||
| @@ -15,7 +15,7 @@ impl DoraOperator for ExampleOperator { | |||
| id: &str, | |||
| data: &[u8], | |||
| output_sender: &mut DoraOutputSender, | |||
| ) -> Result<(), ()> { | |||
| ) -> Result<DoraStatus, ()> { | |||
| match id { | |||
| "time" => { | |||
| let parsed = std::str::from_utf8(data).map_err(|_| ())?; | |||
| @@ -35,6 +35,6 @@ impl DoraOperator for ExampleOperator { | |||
| } | |||
| other => eprintln!("ignoring unexpected input {other}"), | |||
| } | |||
| Ok(()) | |||
| Ok(DoraStatus::Continue) | |||
| } | |||
| } | |||
| @@ -1,5 +1,9 @@ | |||
| from typing import Callable | |||
| from enum import Enum | |||
| class DoraStatus(Enum): | |||
| CONTINUE = 0 | |||
| STOP = 1 | |||
| class Operator: | |||
| """ | |||
| @@ -26,5 +30,7 @@ class Operator: | |||
| """ | |||
| val_len = len(value) | |||
| print(f"PYTHON received input {input_id}; value length: {val_len}") | |||
| send_output("counter", self.counter.to_bytes(1, "little")) | |||
| self.counter = (self.counter + 1) % 256 | |||
| send_output("counter", (self.counter % 256).to_bytes(1, "little")) | |||
| self.counter = self.counter + 1 | |||
| return DoraStatus.OK | |||
| @@ -0,0 +1,39 @@ | |||
| from typing import Callable | |||
| from enum import Enum | |||
| class DoraStatus(Enum): | |||
| OK = 0 | |||
| STOP = 1 | |||
| class Operator: | |||
| """ | |||
| Example operator incrementing a counter every times its been called. | |||
| The current value of the counter is sent back to dora on `counter`. | |||
| """ | |||
| def __init__(self, counter=0): | |||
| self.counter = counter | |||
| def on_input( | |||
| self, | |||
| input_id: str, | |||
| value: bytes, | |||
| send_output: Callable[[str, bytes], None], | |||
| ): | |||
| """Handle input by incrementing count by one. | |||
| Args: | |||
| input_id (str): Id of the input declared in the yaml configuration | |||
| value (bytes): Bytes message of the input | |||
| send_output (Callable[[str, bytes]]): Function enabling sending output back to dora. | |||
| """ | |||
| val_len = len(value) | |||
| print(f"PYTHON received input {input_id}; value length: {val_len}") | |||
| send_output("counter", (self.counter % 256).to_bytes(1, "little")) | |||
| self.counter = self.counter + 1 | |||
| if self.counter > 500: | |||
| return DoraStatus.STOP | |||
| else: | |||
| return DoraStatus.OK | |||