diff --git a/binaries/daemon/src/listener/mod.rs b/binaries/daemon/src/listener/mod.rs index 28864506..149aa57c 100644 --- a/binaries/daemon/src/listener/mod.rs +++ b/binaries/daemon/src/listener/mod.rs @@ -7,6 +7,10 @@ use dora_core::{ }, }; use eyre::{eyre, Context}; +use futures::{ + future::{self, Fuse}, + FutureExt, +}; use shared_memory_server::{ShmemConf, ShmemServer}; use std::{ collections::{BTreeMap, VecDeque}, @@ -95,21 +99,17 @@ pub async fn spawn_listener_loop( } } -struct Listener { +struct Listener { dataflow_id: DataflowId, node_id: NodeId, daemon_tx: mpsc::Sender, subscribed_events: Option>, queue: VecDeque>>, queue_sizes: BTreeMap, - connection: C, } -impl Listener -where - C: Connection, -{ - pub(crate) async fn run( +impl Listener { + pub(crate) async fn run( mut connection: C, daemon_tx: mpsc::Sender, queue_sizes: BTreeMap, @@ -146,13 +146,16 @@ where let mut listener = Listener { dataflow_id, node_id, - connection, daemon_tx, subscribed_events: None, queue_sizes, queue: VecDeque::new(), }; - match listener.run_inner().await.wrap_err("listener failed") { + match listener + .run_inner(connection) + .await + .wrap_err("listener failed") + { Ok(()) => {} Err(err) => tracing::error!("{err:?}"), } @@ -176,34 +179,44 @@ where } } - async fn run_inner(&mut self) -> eyre::Result<()> { + async fn run_inner(&mut self, mut connection: C) -> eyre::Result<()> { loop { - // receive the next node message - let message = match self - .connection - .receive_message() - .await - .wrap_err("failed to receive DaemonRequest") - { - Ok(Some(m)) => m, + let mut next_message = connection.receive_message(); + let message = loop { + let next_event = if let Some(events) = &mut self.subscribed_events { + Box::pin(events.recv()).fuse() + } else { + Fuse::terminated() + }; + let event = match future::select(next_event, next_message).await { + future::Either::Left((event, n)) => { + next_message = n; + event + } + future::Either::Right((message, _)) => break message, + }; + if let Some(event) = event { + self.queue.push_back(Box::new(Some(event))); + self.handle_events().await?; + } + }; + + match message.wrap_err("failed to receive DaemonRequest") { + Ok(Some(message)) => { + self.handle_message(message, &mut connection).await?; + } + Err(err) => { + tracing::warn!("{err:?}"); + } Ok(None) => { tracing::debug!( "channel disconnected: {}/{}", self.dataflow_id, self.node_id ); - break; - } // disconnected - Err(err) => { - tracing::warn!("{err:?}"); - continue; + break; // disconnected } - }; - - // handle incoming events - self.handle_events().await?; - - self.handle_message(message).await?; + } } Ok(()) } @@ -255,19 +268,27 @@ where Ok(()) } - #[tracing::instrument(skip(self), fields(%self.dataflow_id, %self.node_id))] - async fn handle_message(&mut self, message: DaemonRequest) -> eyre::Result<()> { + #[tracing::instrument(skip(self, connection), fields(%self.dataflow_id, %self.node_id))] + async fn handle_message( + &mut self, + message: DaemonRequest, + connection: &mut C, + ) -> eyre::Result<()> { match message { DaemonRequest::Register { .. } => { let reply = DaemonReply::Result(Err("unexpected register message".into())); - self.send_reply(reply) + self.send_reply(reply, connection) .await .wrap_err("failed to send register reply")?; } DaemonRequest::Stopped => { let (reply_sender, reply) = oneshot::channel(); - self.process_daemon_event(DaemonNodeEvent::Stopped { reply_sender }, Some(reply)) - .await? + self.process_daemon_event( + DaemonNodeEvent::Stopped { reply_sender }, + Some(reply), + connection, + ) + .await? } DaemonRequest::CloseOutputs(outputs) => { let (reply_sender, reply) = oneshot::channel(); @@ -277,6 +298,7 @@ where reply_sender, }, Some(reply), + connection, ) .await? } @@ -290,7 +312,7 @@ where metadata, data, }; - self.process_daemon_event(event, None).await?; + self.process_daemon_event(event, None, connection).await?; } DaemonRequest::Subscribe => { let (tx, rx) = mpsc::unbounded_channel(); @@ -301,6 +323,7 @@ where reply_sender, }, Some(reply), + connection, ) .await?; self.subscribed_events = Some(rx); @@ -330,7 +353,7 @@ where DaemonReply::NextEvents(queued_events) }; - self.send_reply(reply) + self.send_reply(reply, connection) .await .wrap_err("failed to send NextEvent reply")?; } @@ -343,18 +366,26 @@ where drop_tokens: Vec, ) -> eyre::Result<()> { if !drop_tokens.is_empty() { - let drop_event = DaemonNodeEvent::ReportDrop { - tokens: drop_tokens, + let event = Event::Node { + dataflow_id: self.dataflow_id, + node_id: self.node_id.clone(), + event: DaemonNodeEvent::ReportDrop { + tokens: drop_tokens, + }, }; - self.process_daemon_event(drop_event, None).await?; + self.daemon_tx + .send(event) + .await + .map_err(|_| eyre!("failed to report drop tokens to daemon"))?; } Ok(()) } - async fn process_daemon_event( + async fn process_daemon_event( &mut self, event: DaemonNodeEvent, reply: Option>, + connection: &mut C, ) -> eyre::Result<()> { // send NodeEvent to daemon main loop let event = Event::Node { @@ -373,12 +404,16 @@ where } else { DaemonReply::Empty }; - self.send_reply(reply).await?; + self.send_reply(reply, connection).await?; Ok(()) } - async fn send_reply(&mut self, reply: DaemonReply) -> eyre::Result<()> { - self.connection + async fn send_reply( + &mut self, + reply: DaemonReply, + connection: &mut C, + ) -> eyre::Result<()> { + connection .send_reply(reply) .await .wrap_err_with(|| format!("failed to send reply to node `{}`", self.node_id))