You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

lib.rs 14 kB

3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. #![allow(clippy::borrow_deref_ref)] // clippy warns about code generated by #[pymethods]
  2. use std::env::current_dir;
  3. use std::path::PathBuf;
  4. use std::sync::Arc;
  5. use std::time::Duration;
  6. use arrow::pyarrow::{FromPyArrow, ToPyArrow};
  7. use dora_daemon::Daemon;
  8. use dora_download::download_file;
  9. use dora_node_api::dora_core::config::NodeId;
  10. use dora_node_api::dora_core::descriptor::source_is_url;
  11. use dora_node_api::merged::{MergeExternalSend, MergedEvent};
  12. use dora_node_api::{DataflowId, DoraNode, EventStream};
  13. use dora_operator_api_python::{pydict_to_metadata, DelayedCleanup, NodeCleanupHandle, PyEvent};
  14. use dora_ros2_bridge_python::Ros2Subscription;
  15. use eyre::Context;
  16. use futures::{Stream, StreamExt};
  17. use pyo3::prelude::*;
  18. use pyo3::types::{PyBytes, PyDict};
  19. /// use pyo3_special_method_derive::{Dict, Dir, Repr, Str};
  20. /// The custom node API lets you integrate `dora` into your application.
  21. /// It allows you to retrieve input and send output in any fashion you want.
  22. ///
  23. /// Use with:
  24. ///
  25. /// ```python
  26. /// from dora import Node
  27. ///
  28. /// node = Node()
  29. /// ```
  30. ///
  31. /// :type node_id: str, optional
  32. #[pyclass]
  33. /// #[derive(Dir, Dict, Str, Repr)]
  34. pub struct Node {
  35. events: Events,
  36. node: DelayedCleanup<DoraNode>,
  37. dataflow_id: DataflowId,
  38. node_id: NodeId,
  39. }
  40. #[pymethods]
  41. impl Node {
  42. #[new]
  43. #[pyo3(signature = (node_id=None))]
  44. pub fn new(node_id: Option<String>) -> eyre::Result<Self> {
  45. let (node, events) = if let Some(node_id) = node_id {
  46. DoraNode::init_flexible(NodeId::from(node_id))
  47. .context("Could not setup node from node id. Make sure to have a running dataflow with this dynamic node")?
  48. } else {
  49. DoraNode::init_from_env().context("Could not initiate node from environment variable. For dynamic node, please add a node id in the initialization function.")?
  50. };
  51. let dataflow_id = *node.dataflow_id();
  52. let node_id = node.id().clone();
  53. let node = DelayedCleanup::new(node);
  54. let events = events;
  55. let cleanup_handle = NodeCleanupHandle {
  56. _handles: Arc::new(node.handle()),
  57. };
  58. Ok(Node {
  59. events: Events {
  60. inner: EventsInner::Dora(events),
  61. _cleanup_handle: cleanup_handle,
  62. },
  63. dataflow_id,
  64. node_id,
  65. node,
  66. })
  67. }
  68. /// `.next()` gives you the next input that the node has received.
  69. /// It blocks until the next event becomes available.
  70. /// You can use timeout in seconds to return if no input is available.
  71. /// It will return `None` when all senders has been dropped.
  72. ///
  73. /// ```python
  74. /// event = node.next()
  75. /// ```
  76. ///
  77. /// You can also iterate over the event stream with a loop
  78. ///
  79. /// ```python
  80. /// for event in node:
  81. /// match event["type"]:
  82. /// case "INPUT":
  83. /// match event["id"]:
  84. /// case "image":
  85. /// ```
  86. ///
  87. /// :type timeout: float, optional
  88. /// :rtype: dict
  89. #[pyo3(signature = (timeout=None))]
  90. #[allow(clippy::should_implement_trait)]
  91. pub fn next(&mut self, py: Python, timeout: Option<f32>) -> PyResult<Option<Py<PyDict>>> {
  92. let event = py.allow_threads(|| self.events.recv(timeout.map(Duration::from_secs_f32)));
  93. if let Some(event) = event {
  94. let dict = event
  95. .to_py_dict(py)
  96. .context("Could not convert event into a dict")?;
  97. Ok(Some(dict))
  98. } else {
  99. Ok(None)
  100. }
  101. }
  102. /// `.recv_async()` gives you the next input that the node has received asynchronously.
  103. /// It does not blocks until the next event becomes available.
  104. /// You can use timeout in seconds to return if no input is available.
  105. /// It will return an Error if the timeout is reached.
  106. /// It will return `None` when all senders has been dropped.
  107. ///
  108. /// warning::
  109. /// This feature is experimental as pyo3 async (rust-python FFI) is still in development.
  110. ///
  111. /// ```python
  112. /// event = await node.recv_async()
  113. /// ```
  114. ///
  115. /// You can also iterate over the event stream with a loop
  116. ///
  117. /// :type timeout: float, optional
  118. /// :rtype: dict
  119. #[pyo3(signature = (timeout=None))]
  120. #[allow(clippy::should_implement_trait)]
  121. pub async fn recv_async(&mut self, timeout: Option<f32>) -> PyResult<Option<Py<PyDict>>> {
  122. let event = self
  123. .events
  124. .recv_async_timeout(timeout.map(Duration::from_secs_f32))
  125. .await;
  126. if let Some(event) = event {
  127. // Get python
  128. Python::with_gil(|py| {
  129. let dict = event
  130. .to_py_dict(py)
  131. .context("Could not convert event into a dict")?;
  132. Ok(Some(dict))
  133. })
  134. } else {
  135. Ok(None)
  136. }
  137. }
  138. /// You can iterate over the event stream with a loop
  139. ///
  140. /// ```python
  141. /// for event in node:
  142. /// match event["type"]:
  143. /// case "INPUT":
  144. /// match event["id"]:
  145. /// case "image":
  146. /// ```
  147. ///
  148. /// Default behaviour is to timeout after 2 seconds.
  149. ///
  150. /// :rtype: dict
  151. pub fn __next__(&mut self, py: Python) -> PyResult<Option<Py<PyDict>>> {
  152. self.next(py, None)
  153. }
  154. /// You can iterate over the event stream with a loop
  155. ///
  156. /// ```python
  157. /// for event in node:
  158. /// match event["type"]:
  159. /// case "INPUT":
  160. /// match event["id"]:
  161. /// case "image":
  162. /// ```
  163. ///
  164. /// :rtype: dict
  165. fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
  166. slf
  167. }
  168. /// `send_output` send data from the node.
  169. ///
  170. /// ```python
  171. /// Args:
  172. /// output_id: str,
  173. /// data: pyarrow.Array,
  174. /// metadata: Option[Dict],
  175. /// ```
  176. ///
  177. /// ex:
  178. ///
  179. /// ```python
  180. /// node.send_output("string", b"string", {"open_telemetry_context": "7632e76"})
  181. /// ```
  182. ///
  183. /// :type output_id: str
  184. /// :type data: pyarrow.Array
  185. /// :type metadata: dict, optional
  186. /// :rtype: None
  187. #[pyo3(signature = (output_id, data, metadata=None))]
  188. pub fn send_output(
  189. &mut self,
  190. output_id: String,
  191. data: PyObject,
  192. metadata: Option<Bound<'_, PyDict>>,
  193. py: Python,
  194. ) -> eyre::Result<()> {
  195. let parameters = pydict_to_metadata(metadata)?;
  196. if let Ok(py_bytes) = data.downcast_bound::<PyBytes>(py) {
  197. let data = py_bytes.as_bytes();
  198. self.node
  199. .get_mut()
  200. .send_output_bytes(output_id.into(), parameters, data.len(), data)
  201. .wrap_err("failed to send output")?;
  202. } else if let Ok(arrow_array) = arrow::array::ArrayData::from_pyarrow_bound(data.bind(py)) {
  203. self.node.get_mut().send_output(
  204. output_id.into(),
  205. parameters,
  206. arrow::array::make_array(arrow_array),
  207. )?;
  208. } else {
  209. eyre::bail!("invalid `data` type, must by `PyBytes` or arrow array")
  210. }
  211. Ok(())
  212. }
  213. /// Returns the full dataflow descriptor that this node is part of.
  214. ///
  215. /// This method returns the parsed dataflow YAML file.
  216. ///
  217. /// :rtype: dict
  218. pub fn dataflow_descriptor(&mut self, py: Python) -> eyre::Result<PyObject> {
  219. Ok(
  220. pythonize::pythonize(py, &self.node.get_mut().dataflow_descriptor())
  221. .map(|x| x.unbind())?,
  222. )
  223. }
  224. /// Returns the dataflow id.
  225. ///
  226. /// :rtype: str
  227. pub fn dataflow_id(&self) -> String {
  228. self.dataflow_id.to_string()
  229. }
  230. /// Merge an external event stream with dora main loop.
  231. /// This currently only work with ROS2.
  232. ///
  233. /// :type subscription: dora.Ros2Subscription
  234. /// :rtype: None
  235. pub fn merge_external_events(
  236. &mut self,
  237. subscription: &mut Ros2Subscription,
  238. ) -> eyre::Result<()> {
  239. let subscription = subscription.into_stream()?;
  240. let stream = futures::stream::poll_fn(move |cx| {
  241. let s = subscription.as_stream().map(|item| {
  242. match item.context("failed to read ROS2 message") {
  243. Ok((value, _info)) => Python::with_gil(|py| {
  244. value
  245. .to_pyarrow(py)
  246. .context("failed to convert value to pyarrow")
  247. .unwrap_or_else(|err| err_to_pyany(err, py))
  248. }),
  249. Err(err) => Python::with_gil(|py| err_to_pyany(err, py)),
  250. }
  251. });
  252. futures::pin_mut!(s);
  253. s.poll_next_unpin(cx)
  254. });
  255. // take out the event stream and temporarily replace it with a dummy
  256. let events = std::mem::replace(
  257. &mut self.events.inner,
  258. EventsInner::Merged(Box::new(futures::stream::empty())),
  259. );
  260. // update self.events with the merged stream
  261. self.events.inner = EventsInner::Merged(events.merge_external_send(Box::pin(stream)));
  262. Ok(())
  263. }
  264. }
  265. fn err_to_pyany(err: eyre::Report, gil: Python<'_>) -> Py<PyAny> {
  266. PyErr::from(err)
  267. .into_pyobject(gil)
  268. .unwrap_or_else(|infallible| match infallible {})
  269. .into_any()
  270. .unbind()
  271. }
  272. struct Events {
  273. inner: EventsInner,
  274. _cleanup_handle: NodeCleanupHandle,
  275. }
  276. impl Events {
  277. fn recv(&mut self, timeout: Option<Duration>) -> Option<PyEvent> {
  278. let event = match &mut self.inner {
  279. EventsInner::Dora(events) => match timeout {
  280. Some(timeout) => events.recv_timeout(timeout).map(MergedEvent::Dora),
  281. None => events.recv().map(MergedEvent::Dora),
  282. },
  283. EventsInner::Merged(events) => futures::executor::block_on(events.next()),
  284. };
  285. event.map(|event| PyEvent { event })
  286. }
  287. async fn recv_async_timeout(&mut self, timeout: Option<Duration>) -> Option<PyEvent> {
  288. let event = match &mut self.inner {
  289. EventsInner::Dora(events) => match timeout {
  290. Some(timeout) => events
  291. .recv_async_timeout(timeout)
  292. .await
  293. .map(MergedEvent::Dora),
  294. None => events.recv_async().await.map(MergedEvent::Dora),
  295. },
  296. EventsInner::Merged(events) => events.next().await,
  297. };
  298. event.map(|event| PyEvent { event })
  299. }
  300. }
  301. enum EventsInner {
  302. Dora(EventStream),
  303. Merged(Box<dyn Stream<Item = MergedEvent<PyObject>> + Unpin + Send + Sync>),
  304. }
  305. impl<'a> MergeExternalSend<'a, PyObject> for EventsInner {
  306. type Item = MergedEvent<PyObject>;
  307. fn merge_external_send(
  308. self,
  309. external_events: impl Stream<Item = PyObject> + Unpin + Send + Sync + 'a,
  310. ) -> Box<dyn Stream<Item = Self::Item> + Unpin + Send + Sync + 'a> {
  311. match self {
  312. EventsInner::Dora(events) => events.merge_external_send(external_events),
  313. EventsInner::Merged(events) => {
  314. let merged = events.merge_external_send(external_events);
  315. Box::new(merged.map(|event| match event {
  316. MergedEvent::Dora(e) => MergedEvent::Dora(e),
  317. MergedEvent::External(e) => MergedEvent::External(e.flatten()),
  318. }))
  319. }
  320. }
  321. }
  322. }
  323. impl Node {
  324. pub fn id(&self) -> String {
  325. self.node_id.to_string()
  326. }
  327. }
  328. /// Start a runtime for Operators
  329. ///
  330. /// :rtype: None
  331. #[pyfunction]
  332. pub fn start_runtime() -> eyre::Result<()> {
  333. dora_runtime::main().wrap_err("Dora Runtime raised an error.")
  334. }
  335. pub fn resolve_dataflow(dataflow: String) -> eyre::Result<PathBuf> {
  336. let dataflow = if source_is_url(&dataflow) {
  337. // try to download the shared library
  338. let target_path = current_dir().context("Could not access the current dir")?;
  339. let rt = tokio::runtime::Builder::new_current_thread()
  340. .enable_all()
  341. .build()
  342. .context("tokio runtime failed")?;
  343. rt.block_on(async { download_file(&dataflow, &target_path).await })
  344. .wrap_err("failed to download dataflow yaml file")?
  345. } else {
  346. PathBuf::from(dataflow)
  347. };
  348. Ok(dataflow)
  349. }
  350. /// Run a Dataflow
  351. ///
  352. /// :rtype: None
  353. #[pyfunction]
  354. #[pyo3(signature = (dataflow_path, uv=None))]
  355. pub fn run(dataflow_path: String, uv: Option<bool>) -> eyre::Result<()> {
  356. let dataflow_path = resolve_dataflow(dataflow_path).context("could not resolve dataflow")?;
  357. let rt = tokio::runtime::Builder::new_multi_thread()
  358. .enable_all()
  359. .build()
  360. .context("tokio runtime failed")?;
  361. let result = rt.block_on(Daemon::run_dataflow(&dataflow_path, uv.unwrap_or_default()))?;
  362. match result.is_ok() {
  363. true => Ok(()),
  364. false => Err(eyre::eyre!(
  365. "Dataflow failed to run with error: {:?}",
  366. result.node_results
  367. )),
  368. }
  369. }
  370. #[pymodule]
  371. fn dora(_py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
  372. dora_ros2_bridge_python::create_dora_ros2_bridge_module(&m)?;
  373. m.add_function(wrap_pyfunction!(start_runtime, &m)?)?;
  374. m.add_function(wrap_pyfunction!(run, &m)?)?;
  375. m.add_class::<Node>()?;
  376. m.setattr("__version__", env!("CARGO_PKG_VERSION"))?;
  377. m.setattr("__author__", "Dora-rs Authors")?;
  378. Ok(())
  379. }