diff --git a/src/lib.rs b/src/lib.rs index f14aea1e..76efc867 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,2 @@ pub mod descriptor; -pub mod python_binding; -pub mod server; +pub mod python; diff --git a/src/main.rs b/src/main.rs index a78cb422..93097efb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,5 @@ use dora_rs::descriptor::Descriptor; -use eyre::{Context, Result}; -use pyo3::prepare_freethreaded_python; +use eyre::Context; use std::{fs::File, path::PathBuf}; use structopt::StructOpt; @@ -10,34 +9,30 @@ enum Command { #[structopt(about = "Print Graph")] Graph { file: PathBuf }, #[structopt(about = "Run Python server")] - StartPython, + StartPython(dora_rs::python::server::PythonCommand), } -fn main() -> Result<()> { +fn main() -> eyre::Result<()> { + env_logger::init(); + let command = Command::from_args(); match command { Command::Graph { file } => { - let descriptor_file = File::open(&file) - .context("failed to open given file") - .unwrap(); + let descriptor_file = File::open(&file).context("failed to open given file")?; let descriptor: Descriptor = serde_yaml::from_reader(descriptor_file) - .context("failed to parse given descriptor") - .unwrap(); + .context("failed to parse given descriptor")?; let visualized = descriptor .visualize_as_mermaid() - .context("failed to visualize descriptor") - .unwrap(); + .context("failed to visualize descriptor")?; println!("{visualized}"); println!( "Paste the above output on https://mermaid.live/ or in a \ ```mermaid code block on GitHub to display it." ); } - Command::StartPython => { - prepare_freethreaded_python(); - - dora_rs::server::main(); + Command::StartPython(command) => { + dora_rs::python::server::run(command).context("python server failed")?; } } diff --git a/src/python_binding.rs b/src/python/binding.rs similarity index 72% rename from src/python_binding.rs rename to src/python/binding.rs index 9be8d0ff..f1f7f373 100644 --- a/src/python_binding.rs +++ b/src/python/binding.rs @@ -1,24 +1,17 @@ -use eyre::{eyre, Context}; +use eyre::Context; use pyo3::prelude::*; -use serde::Deserialize; use std::collections::{BTreeMap, HashMap}; -#[derive(Deserialize, Debug)] -struct PythonVariables { - app: String, - function: String, -} - -pub fn init() -> eyre::Result> { - let variables = envy::from_env::().unwrap(); +pub fn init(app: &str, function: &str) -> eyre::Result> { + pyo3::prepare_freethreaded_python(); Ok(Python::with_gil(|py| { let file = py - .import(&variables.app) + .import(app) .wrap_err("The import file was not found. Check your PYTHONPATH env variable.") .unwrap(); // convert Function into a PyObject let identity = file - .getattr(variables.function) + .getattr(function) .wrap_err("The Function was not found in the imported file.") .unwrap(); identity.to_object(py) diff --git a/src/python/mod.rs b/src/python/mod.rs new file mode 100644 index 00000000..06dcd6cf --- /dev/null +++ b/src/python/mod.rs @@ -0,0 +1,2 @@ +pub mod binding; +pub mod server; diff --git a/src/server.rs b/src/python/server.rs similarity index 78% rename from src/server.rs rename to src/python/server.rs index 6faa98dc..f361aab8 100644 --- a/src/server.rs +++ b/src/python/server.rs @@ -1,3 +1,4 @@ +use super::binding; use eyre::eyre; use eyre::WrapErr; use futures::future::join_all; @@ -9,32 +10,28 @@ use std::collections::{BTreeMap, HashMap}; use std::hash::Hash; use std::hash::Hasher; use std::time::{Duration, Instant}; +use structopt::StructOpt; use tokio::time::timeout; use zenoh::config::Config; use zenoh::prelude::SplitBuffer; -use crate::python_binding::{call, init}; - static DURATION_MILLIS: u64 = 5; -#[derive(Deserialize, Debug)] -struct ConfigVariables { - subscriptions: Vec, + +#[derive(Deserialize, Debug, Clone, StructOpt)] +pub struct PythonCommand { + pub subscriptions: Vec, + pub app: String, + pub function: String, } #[tokio::main] -pub async fn main() -> PyResult<()> { +pub async fn run(variables: PythonCommand) -> PyResult<()> { // Subscribe - let variables = envy::from_env::().unwrap(); - - env_logger::init(); - let config = Config::default(); - let session = zenoh::open(config).await.unwrap(); + let session = zenoh::open(Config::default()).await.unwrap(); // Create a hashmap of all subscriptions. let mut subscribers = HashMap::new(); - let subs = variables.subscriptions.clone(); - - for subscription in &subs { + for subscription in &variables.subscriptions { subscribers.insert(subscription.clone(), session .subscribe(subscription) .await @@ -48,14 +45,13 @@ pub async fn main() -> PyResult<()> { let mut states = BTreeMap::new(); let mut states_hash = hash(&states); - let py_function = init() + let py_function = binding::init(&variables.app, &variables.function) .wrap_err("Failed to init the Python Function") .unwrap(); let duration = Duration::from_millis(DURATION_MILLIS); let mut futures_put = vec![]; loop { - let now = Instant::now(); let mut futures = vec![]; for (_, v) in subscribers.iter_mut() { futures.push(timeout(duration, v.next())); @@ -63,7 +59,7 @@ pub async fn main() -> PyResult<()> { let results = join_all(futures).await; - for (result, subscription) in results.into_iter().zip(&subs) { + for (result, subscription) in results.into_iter().zip(&variables.subscriptions.clone()) { if let Ok(Some(data)) = result { let value = data.value.payload; let binary = value.contiguous(); @@ -82,7 +78,7 @@ pub async fn main() -> PyResult<()> { let now = Instant::now(); - let outputs = call(&py_function, states.clone()).await.unwrap(); + let outputs = binding::call(&py_function, states.clone()).await.unwrap(); println!("call python {:#?}", now.elapsed()); for (key, value) in outputs {