diff --git a/apis/rust/node/src/daemon.rs b/apis/rust/node/src/daemon.rs
index 502b7239..0c2f09e1 100644
--- a/apis/rust/node/src/daemon.rs
+++ b/apis/rust/node/src/daemon.rs
@@ -6,15 +6,16 @@ use dora_core::{
 use dora_message::Metadata;
 use eyre::{bail, eyre, Context};
 use shared_memory::{Shmem, ShmemConf};
-use std::{marker::PhantomData, time::Duration};
+use std::{marker::PhantomData, thread::JoinHandle, time::Duration};
 
 pub struct DaemonConnection {
     pub control_channel: ControlChannel,
     pub event_stream: EventStream,
+    pub(crate) event_stream_thread: JoinHandle<()>,
 }
 
 impl DaemonConnection {
-    pub fn init(
+    pub(crate) fn init(
         dataflow_id: DataflowId,
         node_id: &NodeId,
         daemon_control_region_id: &str,
@@ -23,12 +24,14 @@ impl DaemonConnection {
         let control_channel = ControlChannel::init(dataflow_id, node_id, daemon_control_region_id)
             .wrap_err("failed to init control stream")?;
 
-        let event_stream = EventStream::init(dataflow_id, node_id, daemon_events_region_id)
-            .wrap_err("failed to init event stream")?;
+        let (event_stream, event_stream_thread) =
+            EventStream::init(dataflow_id, node_id, daemon_events_region_id)
+                .wrap_err("failed to init event stream")?;
 
         Ok(Self {
             control_channel,
             event_stream,
+            event_stream_thread,
         })
     }
 }
@@ -132,7 +135,7 @@ impl EventStream {
         dataflow_id: DataflowId,
         node_id: &NodeId,
         daemon_events_region_id: &str,
-    ) -> eyre::Result<Self> {
+    ) -> eyre::Result<(Self, JoinHandle<()>)> {
         let daemon_events_region = ShmemConf::new()
             .os_id(daemon_events_region_id)
             .open()
@@ -151,7 +154,7 @@ impl EventStream {
         let (tx, rx) = flume::bounded(1);
         let mut drop_tokens = Vec::new();
 
-        std::thread::spawn(move || loop {
+        let thread = std::thread::spawn(move || loop {
             let event: NodeEvent = match channel.request(&ControlRequest::NextEvent {
                 drop_tokens: std::mem::take(&mut drop_tokens),
             }) {
@@ -193,7 +196,7 @@ impl EventStream {
             }
         });
 
-        Ok(EventStream { receiver: rx })
+        Ok((EventStream { receiver: rx }, thread))
     }
 
     pub fn recv(&mut self) -> Option {
diff --git a/apis/rust/node/src/lib.rs b/apis/rust/node/src/lib.rs
index 46b3c1e4..3ce43f6e 100644
--- a/apis/rust/node/src/lib.rs
+++ b/apis/rust/node/src/lib.rs
@@ -1,3 +1,5 @@
+use std::thread::JoinHandle;
+
 use daemon::{ControlChannel, DaemonConnection, EventStream};
 pub use dora_core;
 use dora_core::{
@@ -16,6 +18,7 @@ pub struct DoraNode {
     node_config: NodeRunConfig,
     control_channel: ControlChannel,
     hlc: uhlc::HLC,
+    event_stream_thread: Option<JoinHandle<()>>,
 }
 
 impl DoraNode {
@@ -43,6 +46,7 @@ impl DoraNode {
         let DaemonConnection {
             control_channel,
             event_stream,
+            event_stream_thread,
         } = DaemonConnection::init(
             dataflow_id,
             &node_id,
@@ -56,6 +60,7 @@ impl DoraNode {
             node_config: run_config,
             control_channel,
             hlc: uhlc::HLC::default(),
+            event_stream_thread: Some(event_stream_thread),
         };
         Ok((node, event_stream))
     }
@@ -111,8 +116,15 @@ impl DoraNode {
 impl Drop for DoraNode {
     #[tracing::instrument(skip(self), fields(self.id = %self.id))]
     fn drop(&mut self) {
-        if let Err(err) = self.control_channel.report_stop() {
-            tracing::error!("{err:?}");
+        match self.control_channel.report_stop() {
+            Ok(()) => {
+                if let Some(thread) = self.event_stream_thread.take() {
+                    if let Err(panic) = thread.join() {
+                        std::panic::resume_unwind(panic);
+                    }
+                }
+            }
+            Err(err) => tracing::error!("{err:?}"),
         }
     }
 }