diff --git a/Cargo.lock b/Cargo.lock index f035e31c..a39587d7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -512,6 +512,7 @@ name = "dora-rs" version = "0.1.0" dependencies = [ "env_logger", + "envy", "eyre", "futures", "pyo3", @@ -536,6 +537,15 @@ dependencies = [ "termcolor", ] +[[package]] +name = "envy" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f47e0157f2cb54f5ae1bd371b30a2ae4311e1c028f575cd4e81de7353215965" +dependencies = [ + "serde", +] + [[package]] name = "event-listener" version = "2.5.2" diff --git a/Cargo.toml b/Cargo.toml index 9c487bbc..fd206673 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,3 +16,4 @@ tokio = { version="1.17.0", features=["full"]} pyo3 = "0.16.1" pyo3-asyncio = { version = "0.16", features = ["tokio-runtime", "attributes"] } futures = "0.3.12" +envy = "0.4.2" \ No newline at end of file diff --git a/src/server.rs b/src/server.rs index 0f469c36..3e857752 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,32 +1,41 @@ use eyre::{eyre, Context}; -use futures::prelude::*; +use futures::future::join_all; +use futures::stream::Next; +use futures::{join, prelude::*}; use pyo3::prelude::*; use std::collections::hash_map::DefaultHasher; use std::collections::{BTreeMap, HashMap}; -use std::env; use std::hash::Hash; use std::hash::Hasher; -use std::time::Duration; +use std::time::{Duration, Instant}; use tokio::time::timeout; use zenoh::config::Config; use zenoh::prelude::SplitBuffer; -static DURATION_MILLIS: u64 = 5; +static DURATION_MILLIS: u64 = 1; +use serde::Deserialize; + +#[derive(Deserialize, Debug)] +struct ConfigVariables { + subscriptions: Vec, + app: String, + function: String, +} #[pyo3_asyncio::tokio::main] pub async fn main() -> PyResult<()> { // Subscribe + let variables = envy::from_env::().unwrap(); + env_logger::init(); let config = Config::default(); let session = zenoh::open(config).await.unwrap(); - let subscriptions = env::var("SRC_LABELS") - .wrap_err("Env variable not set") - .unwrap(); - let subscriptions = subscriptions.split(":"); + // Create a hashmap of all subscriptions. let mut subscribers = HashMap::new(); + let subs = variables.subscriptions.clone(); - for subscription in subscriptions { - subscribers.insert(subscription, session + for subscription in &subs { + subscribers.insert(subscription.clone(), session .subscribe(subscription) .await .map_err(|err| { @@ -35,37 +44,47 @@ pub async fn main() -> PyResult<()> { .unwrap()); } + // Store the latest value of all subscription as well as the output of the function. hash the state to easily check if the state has changed. let mut states = BTreeMap::new(); let mut hasher = DefaultHasher::new(); states.hash(&mut hasher); let mut state_hash = hasher.finish(); - let identity = initialize("app".to_string(), "return_1".to_string()).unwrap(); + let py_function = initialize(variables.app, variables.function).unwrap(); let dur = Duration::from_millis(DURATION_MILLIS); + let mut futures_put = vec![]; loop { - for (subscription, subscriber) in subscribers.iter_mut() { - let result = timeout(dur, subscriber.next()).await; + let now = Instant::now(); + let mut futures = vec![]; + for (_, v) in subscribers.iter_mut() { + futures.push(timeout(dur, v.next())); + } + + let results = join_all(futures).await; + + for (result, subscription) in results.into_iter().zip(&subs) { if let Ok(Some(data)) = result { let value = data.value.payload; let binary = value.contiguous(); - states.insert(subscription.clone(), binary.to_vec()); + states.insert( + subscription.clone().to_string(), + String::from_utf8(binary.to_vec()).unwrap(), + ); } } let mut hasher = DefaultHasher::new(); states.hash(&mut hasher); - - if state_hash == hasher.finish() { + let new_hash = hasher.finish(); + if state_hash == new_hash { continue; - } else { - state_hash = hasher.finish(); } let result = Python::with_gil(|py| { let args = (states.clone().into_py(py),); pyo3_asyncio::tokio::into_future( - identity + py_function .call(py, args, None) .wrap_err("The Python function call did not succeed.") .unwrap() @@ -81,15 +100,16 @@ pub async fn main() -> PyResult<()> { let outputs: HashMap = Python::with_gil(|py| result.extract(py)) .wrap_err("Could not retrieve the python result.") .unwrap(); + for (key, value) in outputs { - session - .put(key, value) - .await - .map_err(|err| { - eyre!("Could not put the output within the chosen key expression topic. Error: {err}") - }) - .unwrap(); + states.insert(key.clone(), value.clone()); + futures_put.push(timeout(dur, session.put(key, value))); } + + let mut hasher = DefaultHasher::new(); + states.hash(&mut hasher); + state_hash = hasher.finish(); + println!("loop {:#?}", now.elapsed()); } }