You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

lib.rs 12 kB

3 years ago
3 years ago

  1. #![allow(clippy::borrow_deref_ref)] // clippy warns about code generated by #[pymethods]
  2. use std::env::current_dir;
  3. use std::path::PathBuf;
  4. use std::sync::Arc;
  5. use std::time::Duration;
  6. use arrow::pyarrow::{FromPyArrow, ToPyArrow};
  7. use dora_daemon::Daemon;
  8. use dora_download::download_file;
  9. use dora_node_api::dora_core::config::NodeId;
  10. use dora_node_api::dora_core::descriptor::source_is_url;
  11. use dora_node_api::merged::{MergeExternalSend, MergedEvent};
  12. use dora_node_api::{DataflowId, DoraNode, EventStream};
  13. use dora_operator_api_python::{pydict_to_metadata, DelayedCleanup, NodeCleanupHandle, PyEvent};
  14. use dora_ros2_bridge_python::Ros2Subscription;
  15. use eyre::Context;
  16. use futures::{Stream, StreamExt};
  17. use pyo3::prelude::*;
  18. use pyo3::types::{PyBytes, PyDict};
  19. /// use pyo3_special_method_derive::{Dict, Dir, Repr, Str};
  20. /// The custom node API lets you integrate `dora` into your application.
  21. /// It allows you to retrieve input and send output in any fashion you want.
  22. ///
  23. /// Use with:
  24. ///
  25. /// ```python
  26. /// from dora import Node
  27. ///
  28. /// node = Node()
  29. /// ```
  30. ///
  31. /// :type node_id: str, optional
  32. #[pyclass]
  33. /// #[derive(Dir, Dict, Str, Repr)]
  34. pub struct Node {
  35. events: Events,
  36. node: DelayedCleanup<DoraNode>,
  37. dataflow_id: DataflowId,
  38. node_id: NodeId,
  39. }
  40. #[pymethods]
  41. impl Node {
  42. #[new]
  43. #[pyo3(signature = (node_id=None))]
  44. pub fn new(node_id: Option<String>) -> eyre::Result<Self> {
  45. let (node, events) = if let Some(node_id) = node_id {
  46. DoraNode::init_flexible(NodeId::from(node_id))
  47. .context("Could not setup node from node id. Make sure to have a running dataflow with this dynamic node")?
  48. } else {
  49. DoraNode::init_from_env().context("Could not initiate node from environment variable. For dynamic node, please add a node id in the initialization function.")?
  50. };
  51. let dataflow_id = *node.dataflow_id();
  52. let node_id = node.id().clone();
  53. let node = DelayedCleanup::new(node);
  54. let events = DelayedCleanup::new(events);
  55. let cleanup_handle = NodeCleanupHandle {
  56. _handles: Arc::new((node.handle(), events.handle())),
  57. };
  58. Ok(Node {
  59. events: Events {
  60. inner: EventsInner::Dora(events),
  61. cleanup_handle,
  62. },
  63. dataflow_id,
  64. node_id,
  65. node,
  66. })
  67. }
  68. /// `.next()` gives you the next input that the node has received.
  69. /// It blocks until the next event becomes available.
  70. /// You can use timeout in seconds to return if no input is available.
  71. /// It will return `None` when all senders has been dropped.
  72. ///
  73. /// ```python
  74. /// event = node.next()
  75. /// ```
  76. ///
  77. /// You can also iterate over the event stream with a loop
  78. ///
  79. /// ```python
  80. /// for event in node:
  81. /// match event["type"]:
  82. /// case "INPUT":
  83. /// match event["id"]:
  84. /// case "image":
  85. /// ```
  86. ///
  87. /// :type timeout: float, optional
  88. /// :rtype: dict
  89. #[pyo3(signature = (timeout=None))]
  90. #[allow(clippy::should_implement_trait)]
  91. pub fn next(&mut self, py: Python, timeout: Option<f32>) -> PyResult<Option<Py<PyDict>>> {
  92. let event = py.allow_threads(|| self.events.recv(timeout.map(Duration::from_secs_f32)));
  93. if let Some(event) = event {
  94. let dict = event
  95. .to_py_dict(py)
  96. .context("Could not convert event into a dict")?;
  97. Ok(Some(dict))
  98. } else {
  99. Ok(None)
  100. }
  101. }
  102. /// You can iterate over the event stream with a loop
  103. ///
  104. /// ```python
  105. /// for event in node:
  106. /// match event["type"]:
  107. /// case "INPUT":
  108. /// match event["id"]:
  109. /// case "image":
  110. /// ```
  111. ///
  112. /// Default behaviour is to timeout after 2 seconds.
  113. ///
  114. /// :rtype: dict
  115. pub fn __next__(&mut self, py: Python) -> PyResult<Option<Py<PyDict>>> {
  116. self.next(py, None)
  117. }
  118. /// You can iterate over the event stream with a loop
  119. ///
  120. /// ```python
  121. /// for event in node:
  122. /// match event["type"]:
  123. /// case "INPUT":
  124. /// match event["id"]:
  125. /// case "image":
  126. /// ```
  127. ///
  128. /// :rtype: dict
  129. fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
  130. slf
  131. }
  132. /// `send_output` send data from the node.
  133. ///
  134. /// ```python
  135. /// Args:
  136. /// output_id: str,
  137. /// data: pyarrow.Array,
  138. /// metadata: Option[Dict],
  139. /// ```
  140. ///
  141. /// ex:
  142. ///
  143. /// ```python
  144. /// node.send_output("string", b"string", {"open_telemetry_context": "7632e76"})
  145. /// ```
  146. ///
  147. /// :type output_id: str
  148. /// :type data: pyarrow.Array
  149. /// :type metadata: dict, optional
  150. /// :rtype: None
  151. #[pyo3(signature = (output_id, data, metadata=None))]
  152. pub fn send_output(
  153. &mut self,
  154. output_id: String,
  155. data: PyObject,
  156. metadata: Option<Bound<'_, PyDict>>,
  157. py: Python,
  158. ) -> eyre::Result<()> {
  159. let parameters = pydict_to_metadata(metadata)?;
  160. if let Ok(py_bytes) = data.downcast_bound::<PyBytes>(py) {
  161. let data = py_bytes.as_bytes();
  162. self.node
  163. .get_mut()
  164. .send_output_bytes(output_id.into(), parameters, data.len(), data)
  165. .wrap_err("failed to send output")?;
  166. } else if let Ok(arrow_array) = arrow::array::ArrayData::from_pyarrow_bound(data.bind(py)) {
  167. self.node.get_mut().send_output(
  168. output_id.into(),
  169. parameters,
  170. arrow::array::make_array(arrow_array),
  171. )?;
  172. } else {
  173. eyre::bail!("invalid `data` type, must by `PyBytes` or arrow array")
  174. }
  175. Ok(())
  176. }
  177. /// Returns the full dataflow descriptor that this node is part of.
  178. ///
  179. /// This method returns the parsed dataflow YAML file.
  180. ///
  181. /// :rtype: dict
  182. pub fn dataflow_descriptor(&mut self, py: Python) -> eyre::Result<PyObject> {
  183. Ok(
  184. pythonize::pythonize(py, &self.node.get_mut().dataflow_descriptor())
  185. .map(|x| x.unbind())?,
  186. )
  187. }
  188. /// Returns the dataflow id.
  189. ///
  190. /// :rtype: str
  191. pub fn dataflow_id(&self) -> String {
  192. self.dataflow_id.to_string()
  193. }
  194. /// Merge an external event stream with dora main loop.
  195. /// This currently only work with ROS2.
  196. ///
  197. /// :type subscription: dora.Ros2Subscription
  198. /// :rtype: None
  199. pub fn merge_external_events(
  200. &mut self,
  201. subscription: &mut Ros2Subscription,
  202. ) -> eyre::Result<()> {
  203. let subscription = subscription.into_stream()?;
  204. let stream = futures::stream::poll_fn(move |cx| {
  205. let s = subscription.as_stream().map(|item| {
  206. match item.context("failed to read ROS2 message") {
  207. Ok((value, _info)) => Python::with_gil(|py| {
  208. value
  209. .to_pyarrow(py)
  210. .context("failed to convert value to pyarrow")
  211. .unwrap_or_else(|err| err_to_pyany(err, py))
  212. }),
  213. Err(err) => Python::with_gil(|py| err_to_pyany(err, py)),
  214. }
  215. });
  216. futures::pin_mut!(s);
  217. s.poll_next_unpin(cx)
  218. });
  219. // take out the event stream and temporarily replace it with a dummy
  220. let events = std::mem::replace(
  221. &mut self.events.inner,
  222. EventsInner::Merged(Box::new(futures::stream::empty())),
  223. );
  224. // update self.events with the merged stream
  225. self.events.inner = EventsInner::Merged(events.merge_external_send(Box::pin(stream)));
  226. Ok(())
  227. }
  228. }
  229. fn err_to_pyany(err: eyre::Report, gil: Python<'_>) -> Py<PyAny> {
  230. PyErr::from(err)
  231. .into_pyobject(gil)
  232. .unwrap_or_else(|infallible| match infallible {})
  233. .into_any()
  234. .unbind()
  235. }
  236. struct Events {
  237. inner: EventsInner,
  238. cleanup_handle: NodeCleanupHandle,
  239. }
  240. impl Events {
  241. fn recv(&mut self, timeout: Option<Duration>) -> Option<PyEvent> {
  242. let event = match &mut self.inner {
  243. EventsInner::Dora(events) => match timeout {
  244. Some(timeout) => events
  245. .get_mut()
  246. .recv_timeout(timeout)
  247. .map(MergedEvent::Dora),
  248. None => events.get_mut().recv().map(MergedEvent::Dora),
  249. },
  250. EventsInner::Merged(events) => futures::executor::block_on(events.next()),
  251. };
  252. event.map(|event| PyEvent {
  253. event,
  254. _cleanup: Some(self.cleanup_handle.clone()),
  255. })
  256. }
  257. }
  258. enum EventsInner {
  259. Dora(DelayedCleanup<EventStream>),
  260. Merged(Box<dyn Stream<Item = MergedEvent<PyObject>> + Unpin + Send + Sync>),
  261. }
  262. impl<'a> MergeExternalSend<'a, PyObject> for EventsInner {
  263. type Item = MergedEvent<PyObject>;
  264. fn merge_external_send(
  265. self,
  266. external_events: impl Stream<Item = PyObject> + Unpin + Send + Sync + 'a,
  267. ) -> Box<dyn Stream<Item = Self::Item> + Unpin + Send + Sync + 'a> {
  268. match self {
  269. EventsInner::Dora(events) => events.merge_external_send(external_events),
  270. EventsInner::Merged(events) => {
  271. let merged = events.merge_external_send(external_events);
  272. Box::new(merged.map(|event| match event {
  273. MergedEvent::Dora(e) => MergedEvent::Dora(e),
  274. MergedEvent::External(e) => MergedEvent::External(e.flatten()),
  275. }))
  276. }
  277. }
  278. }
  279. }
  280. impl Node {
  281. pub fn id(&self) -> String {
  282. self.node_id.to_string()
  283. }
  284. }
  285. /// Start a runtime for Operators
  286. ///
  287. /// :rtype: None
  288. #[pyfunction]
  289. pub fn start_runtime() -> eyre::Result<()> {
  290. dora_runtime::main().wrap_err("Dora Runtime raised an error.")
  291. }
  292. pub fn resolve_dataflow(dataflow: String) -> eyre::Result<PathBuf> {
  293. let dataflow = if source_is_url(&dataflow) {
  294. // try to download the shared library
  295. let target_path = current_dir().context("Could not access the current dir")?;
  296. let rt = tokio::runtime::Builder::new_current_thread()
  297. .enable_all()
  298. .build()
  299. .context("tokio runtime failed")?;
  300. rt.block_on(async { download_file(&dataflow, &target_path).await })
  301. .wrap_err("failed to download dataflow yaml file")?
  302. } else {
  303. PathBuf::from(dataflow)
  304. };
  305. Ok(dataflow)
  306. }
  307. /// Run a Dataflow
  308. ///
  309. /// :rtype: None
  310. #[pyfunction]
  311. #[pyo3(signature = (dataflow_path, uv=None))]
  312. pub fn run(dataflow_path: String, uv: Option<bool>) -> eyre::Result<()> {
  313. let dataflow_path = resolve_dataflow(dataflow_path).context("could not resolve dataflow")?;
  314. let rt = tokio::runtime::Builder::new_multi_thread()
  315. .enable_all()
  316. .build()
  317. .context("tokio runtime failed")?;
  318. let result = rt.block_on(Daemon::run_dataflow(&dataflow_path, uv.unwrap_or_default()))?;
  319. match result.is_ok() {
  320. true => Ok(()),
  321. false => Err(eyre::eyre!(
  322. "Dataflow failed to run with error: {:?}",
  323. result.node_results
  324. )),
  325. }
  326. }
  327. #[pymodule]
  328. fn dora(_py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
  329. dora_ros2_bridge_python::create_dora_ros2_bridge_module(&m)?;
  330. m.add_function(wrap_pyfunction!(start_runtime, &m)?)?;
  331. m.add_function(wrap_pyfunction!(run, &m)?)?;
  332. m.add_class::<Node>()?;
  333. m.setattr("__version__", env!("CARGO_PKG_VERSION"))?;
  334. m.setattr("__author__", "Dora-rs Authors")?;
  335. Ok(())
  336. }