diff --git a/Cargo.lock b/Cargo.lock index aa91ce42..6dca9cb0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1420,6 +1420,7 @@ dependencies = [ "dora-tracing", "eyre", "flume", + "futures", "once_cell", "serde", "serde_json", diff --git a/apis/rust/node/Cargo.toml b/apis/rust/node/Cargo.toml index fc83aea2..c242d44c 100644 --- a/apis/rust/node/Cargo.toml +++ b/apis/rust/node/Cargo.toml @@ -27,6 +27,7 @@ bincode = "1.3.3" shared_memory = "0.12.0" dora-tracing = { workspace = true, optional = true } arrow = "35.0.0" +futures = "0.3.28" [dev-dependencies] tokio = { version = "1.24.2", features = ["rt"] } diff --git a/apis/rust/node/src/event_stream/mod.rs b/apis/rust/node/src/event_stream/mod.rs index 7c247d1d..104722f8 100644 --- a/apis/rust/node/src/event_stream/mod.rs +++ b/apis/rust/node/src/event_stream/mod.rs @@ -1,6 +1,7 @@ use std::sync::Arc; pub use event::{Data, Event, MappedInputData}; +use futures::{Stream, StreamExt}; use self::thread::{EventItem, EventStreamThreadHandle}; use crate::daemon_connection::DaemonChannel; @@ -18,7 +19,7 @@ mod thread; pub struct EventStream { node_id: NodeId, - receiver: flume::Receiver, + receiver: flume::r#async::RecvStream<'static, EventItem>, _thread_handle: EventStreamThreadHandle, close_channel: DaemonChannel, clock: Arc, @@ -90,7 +91,7 @@ impl EventStream { Ok(EventStream { node_id: node_id.clone(), - receiver: rx, + receiver: rx.into_stream(), _thread_handle: thread_handle, close_channel, clock, @@ -99,25 +100,15 @@ impl EventStream { /// wait for the next event on the events stream. pub fn recv(&mut self) -> Option { - let event = self.receiver.recv(); - self.recv_common(event) + futures::executor::block_on(self.recv_async()) } pub async fn recv_async(&mut self) -> Option { - let event = self.receiver.recv_async().await; - self.recv_common(event) + self.receiver.next().await.map(Self::convert_event_item) } - #[tracing::instrument(skip(self), fields(%self.node_id))] - fn recv_common(&mut self, event: Result) -> Option { - let event = match event { - Ok(event) => event, - Err(flume::RecvError::Disconnected) => { - tracing::trace!("event channel disconnected"); - return None; - } - }; - let event = match event { + fn convert_event_item(item: EventItem) -> Event { + match item { EventItem::NodeEvent { event, ack_channel } => match event { NodeEvent::Stop => Event::Stop, NodeEvent::Reload { operator_id } => Event::Reload { operator_id }, @@ -156,9 +147,20 @@ impl EventStream { EventItem::FatalError(err) => { Event::Error(format!("fatal event stream error: {err:?}")) } - }; + } + } +} + +impl Stream for EventStream { + type Item = Event; - Some(event) + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.receiver + .poll_next_unpin(cx) + .map(|item| item.map(Self::convert_event_item)) } }