diff --git a/Cargo.lock b/Cargo.lock index de10853f..f074414b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1560,6 +1560,7 @@ dependencies = [ "flume", "futures", "futures-concurrency", + "futures-timer", "once_cell", "serde", "serde_json", @@ -2199,6 +2200,12 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" +[[package]] +name = "futures-timer" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" + [[package]] name = "futures-util" version = "0.3.28" diff --git a/apis/python/node/src/lib.rs b/apis/python/node/src/lib.rs index 19f61264..611c0a19 100644 --- a/apis/python/node/src/lib.rs +++ b/apis/python/node/src/lib.rs @@ -57,12 +57,13 @@ impl Node { /// case "image": /// ``` #[allow(clippy::should_implement_trait)] - pub fn next(&mut self, py: Python) -> PyResult> { - self.__next__(py) + pub fn next(&mut self, py: Python, timeout: Option) -> PyResult> { + let event = py.allow_threads(|| self.events.recv(timeout)); + Ok(event) } pub fn __next__(&mut self, py: Python) -> PyResult> { - let event = py.allow_threads(|| self.events.recv()); + let event = py.allow_threads(|| self.events.recv(None)); Ok(event) } @@ -156,9 +157,12 @@ enum Events { } impl Events { - fn recv(&mut self) -> Option { + fn recv(&mut self, timeout: Option) -> Option { match self { - Events::Dora(events) => events.recv().map(PyEvent::from), + Events::Dora(events) => match timeout { + Some(timeout) => events.recv_timeout(timeout).map(PyEvent::from), + None => events.recv().map(PyEvent::from), + }, Events::Merged(events) => futures::executor::block_on(events.next()).map(PyEvent::from), } } diff --git a/apis/rust/node/Cargo.toml b/apis/rust/node/Cargo.toml index d13f4583..f5fe00bb 100644 --- a/apis/rust/node/Cargo.toml +++ b/apis/rust/node/Cargo.toml @@ -30,6 +30,7 @@ arrow = { workspace = true } arrow-schema = { workspace = true } futures = "0.3.28" futures-concurrency = "7.3.0" +futures-timer = "3.0.2" dora-arrow-convert = { workspace = true } aligned-vec = "0.5.0" diff --git a/apis/rust/node/src/event_stream/mod.rs b/apis/rust/node/src/event_stream/mod.rs index 48e7ecaa..cf403f30 100644 --- a/apis/rust/node/src/event_stream/mod.rs +++ b/apis/rust/node/src/event_stream/mod.rs @@ -1,7 +1,11 @@ -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; pub use event::{Event, MappedInputData, RawData}; -use futures::{Stream, StreamExt}; +use futures::{ + future::{select, Either}, + Stream, StreamExt, +}; +use futures_timer::Delay; use self::{ event::SharedMemoryData, @@ -104,11 +108,28 @@ impl EventStream { /// wait for the next event on the events stream. pub fn recv(&mut self) -> Option { - futures::executor::block_on(self.recv_async()) + futures::executor::block_on(self.recv_async(None)) + } + + /// wait for the next event on the events stream until timeout + pub fn recv_timeout(&mut self, secs: f32) -> Option { + futures::executor::block_on(self.recv_async(Some(secs))) } - pub async fn recv_async(&mut self) -> Option { - self.receiver.next().await.map(Self::convert_event_item) + pub async fn recv_async(&mut self, secs: Option) -> Option { + let receive_event = self.receiver.next(); + match secs { + None => receive_event.await, + Some(secs) => { + match select(Delay::new(Duration::from_secs_f32(secs)), receive_event).await { + Either::Left((_elapsed, _)) => { + Some(EventItem::TimedoutError(eyre!("Receiver timed out"))) + } + Either::Right((event, _)) => event, + } + } + } + .map(Self::convert_event_item) } fn convert_event_item(item: EventItem) -> Event { @@ -161,6 +182,9 @@ impl EventStream { EventItem::FatalError(err) => { Event::Error(format!("fatal event stream error: {err:?}")) } + EventItem::TimedoutError(err) => { + Event::Error(format!("Timeout event stream error: {err:?}")) + } } } } diff --git a/apis/rust/node/src/event_stream/thread.rs b/apis/rust/node/src/event_stream/thread.rs index 7ee13427..05480d5c 100644 --- a/apis/rust/node/src/event_stream/thread.rs +++ b/apis/rust/node/src/event_stream/thread.rs @@ -30,6 +30,7 @@ pub enum EventItem { ack_channel: flume::Sender<()>, }, FatalError(eyre::Report), + TimedoutError(eyre::Report), } pub struct EventStreamThreadHandle {