You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

lib.rs 8.8 kB

3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. #![allow(clippy::borrow_deref_ref)] // clippy warns about code generated by #[pymethods]
  2. use std::time::Duration;
  3. use arrow::pyarrow::{FromPyArrow, ToPyArrow};
  4. use dora_node_api::dora_core::config::NodeId;
  5. use dora_node_api::merged::{MergeExternalSend, MergedEvent};
  6. use dora_node_api::{DoraNode, EventStream};
  7. use dora_operator_api_python::{pydict_to_metadata, PyEvent};
  8. use dora_ros2_bridge_python::Ros2Subscription;
  9. use eyre::Context;
  10. use futures::{Stream, StreamExt};
  11. use pyo3::prelude::*;
  12. use pyo3::types::{PyBytes, PyDict};
  13. use pyo3_special_method_derive::Dir;
  14. /// The custom node API lets you integrate `dora` into your application.
  15. /// It allows you to retrieve input and send output in any fashion you want.
  16. ///
  17. /// Use with:
  18. ///
  19. /// ```python
  20. /// from dora import Node
  21. ///
  22. /// node = Node()
  23. /// ```
  24. ///
  25. /// :type node_id: str, optional
  26. #[pyclass]
  27. #[derive(Dir)]
  28. pub struct Node {
  29. events: Events,
  30. pub node: DoraNode,
  31. }
  32. #[pymethods]
  33. impl Node {
  34. #[new]
  35. pub fn new(node_id: Option<String>) -> eyre::Result<Self> {
  36. let (node, events) = if let Some(node_id) = node_id {
  37. DoraNode::init_flexible(NodeId::from(node_id))
  38. .context("Could not setup node from node id. Make sure to have a running dataflow with this dynamic node")?
  39. } else {
  40. DoraNode::init_from_env().context("Couldn not initiate node from environment variable. For dynamic node, please add a node id in the initialization function.")?
  41. };
  42. Ok(Node {
  43. events: Events::Dora(events),
  44. node,
  45. })
  46. }
  47. /// `.next()` gives you the next input that the node has received.
  48. /// It blocks until the next event becomes available.
  49. /// You can use timeout in seconds to return if no input is available.
  50. /// It will return `None` when all senders has been dropped.
  51. ///
  52. /// ```python
  53. /// event = node.next()
  54. /// ```
  55. ///
  56. /// You can also iterate over the event stream with a loop
  57. ///
  58. /// ```python
  59. /// for event in node:
  60. /// match event["type"]:
  61. /// case "INPUT":
  62. /// match event["id"]:
  63. /// case "image":
  64. /// ```
  65. ///
  66. /// :type timeout: float, optional
  67. /// :rtype: dict
  68. #[allow(clippy::should_implement_trait)]
  69. pub fn next(&mut self, py: Python, timeout: Option<f32>) -> PyResult<Option<Py<PyDict>>> {
  70. let event = py.allow_threads(|| self.events.recv(timeout.map(Duration::from_secs_f32)));
  71. if let Some(event) = event {
  72. let dict = event
  73. .to_py_dict(py)
  74. .context("Could not convert event into a dict")?;
  75. Ok(Some(dict))
  76. } else {
  77. Ok(None)
  78. }
  79. }
  80. /// You can iterate over the event stream with a loop
  81. ///
  82. /// ```python
  83. /// for event in node:
  84. /// match event["type"]:
  85. /// case "INPUT":
  86. /// match event["id"]:
  87. /// case "image":
  88. /// ```
  89. ///
  90. /// Default behaviour is to timeout after 2 seconds.
  91. ///
  92. /// :rtype: dict
  93. pub fn __next__(&mut self, py: Python) -> PyResult<Option<Py<PyDict>>> {
  94. self.next(py, None)
  95. }
  96. /// You can iterate over the event stream with a loop
  97. ///
  98. /// ```python
  99. /// for event in node:
  100. /// match event["type"]:
  101. /// case "INPUT":
  102. /// match event["id"]:
  103. /// case "image":
  104. /// ```
  105. ///
  106. /// :rtype: dict
  107. fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
  108. slf
  109. }
  110. /// `send_output` send data from the node.
  111. ///
  112. /// ```python
  113. /// Args:
  114. /// output_id: str,
  115. /// data: pyarrow.Array,
  116. /// metadata: Option[Dict],
  117. /// ```
  118. ///
  119. /// ex:
  120. ///
  121. /// ```python
  122. /// node.send_output("string", b"string", {"open_telemetry_context": "7632e76"})
  123. /// ```
  124. ///
  125. /// :type output_id: str
  126. /// :type data: pyarrow.Array
  127. /// :type metadata: dict, optional
  128. /// :rtype: None
  129. pub fn send_output(
  130. &mut self,
  131. output_id: String,
  132. data: PyObject,
  133. metadata: Option<Bound<'_, PyDict>>,
  134. py: Python,
  135. ) -> eyre::Result<()> {
  136. let parameters = pydict_to_metadata(metadata)?;
  137. if let Ok(py_bytes) = data.downcast_bound::<PyBytes>(py) {
  138. let data = py_bytes.as_bytes();
  139. self.node
  140. .send_output_bytes(output_id.into(), parameters, data.len(), data)
  141. .wrap_err("failed to send output")?;
  142. } else if let Ok(arrow_array) = arrow::array::ArrayData::from_pyarrow_bound(data.bind(py)) {
  143. self.node.send_output(
  144. output_id.into(),
  145. parameters,
  146. arrow::array::make_array(arrow_array),
  147. )?;
  148. } else {
  149. eyre::bail!("invalid `data` type, must by `PyBytes` or arrow array")
  150. }
  151. Ok(())
  152. }
  153. /// Returns the full dataflow descriptor that this node is part of.
  154. ///
  155. /// This method returns the parsed dataflow YAML file.
  156. ///
  157. /// :rtype: dict
  158. pub fn dataflow_descriptor(&self, py: Python) -> pythonize::Result<PyObject> {
  159. pythonize::pythonize(py, self.node.dataflow_descriptor())
  160. }
  161. /// Returns the dataflow id.
  162. ///
  163. /// :rtype: str
  164. pub fn dataflow_id(&self) -> String {
  165. self.node.dataflow_id().to_string()
  166. }
  167. /// Merge an external event stream with dora main loop.
  168. /// This currently only work with ROS2.
  169. ///
  170. /// :type subscription: dora.Ros2Subscription
  171. /// :rtype: None
  172. pub fn merge_external_events(
  173. &mut self,
  174. subscription: &mut Ros2Subscription,
  175. ) -> eyre::Result<()> {
  176. let subscription = subscription.into_stream()?;
  177. let stream = futures::stream::poll_fn(move |cx| {
  178. let s = subscription.as_stream().map(|item| {
  179. match item.context("failed to read ROS2 message") {
  180. Ok((value, _info)) => Python::with_gil(|py| {
  181. value
  182. .to_pyarrow(py)
  183. .context("failed to convert value to pyarrow")
  184. .unwrap_or_else(|err| PyErr::from(err).to_object(py))
  185. }),
  186. Err(err) => Python::with_gil(|py| PyErr::from(err).to_object(py)),
  187. }
  188. });
  189. futures::pin_mut!(s);
  190. s.poll_next_unpin(cx)
  191. });
  192. // take out the event stream and temporarily replace it with a dummy
  193. let events = std::mem::replace(
  194. &mut self.events,
  195. Events::Merged(Box::new(futures::stream::empty())),
  196. );
  197. // update self.events with the merged stream
  198. self.events = Events::Merged(events.merge_external_send(Box::pin(stream)));
  199. Ok(())
  200. }
  201. }
  202. enum Events {
  203. Dora(EventStream),
  204. Merged(Box<dyn Stream<Item = MergedEvent<PyObject>> + Unpin + Send>),
  205. }
  206. impl Events {
  207. fn recv(&mut self, timeout: Option<Duration>) -> Option<PyEvent> {
  208. match self {
  209. Events::Dora(events) => match timeout {
  210. Some(timeout) => events.recv_timeout(timeout).map(PyEvent::from),
  211. None => events.recv().map(PyEvent::from),
  212. },
  213. Events::Merged(events) => futures::executor::block_on(events.next()).map(PyEvent::from),
  214. }
  215. }
  216. }
  217. impl<'a> MergeExternalSend<'a, PyObject> for Events {
  218. type Item = MergedEvent<PyObject>;
  219. fn merge_external_send(
  220. self,
  221. external_events: impl Stream<Item = PyObject> + Unpin + Send + 'a,
  222. ) -> Box<dyn Stream<Item = Self::Item> + Unpin + Send + 'a> {
  223. match self {
  224. Events::Dora(events) => events.merge_external_send(external_events),
  225. Events::Merged(events) => {
  226. let merged = events.merge_external_send(external_events);
  227. Box::new(merged.map(|event| match event {
  228. MergedEvent::Dora(e) => MergedEvent::Dora(e),
  229. MergedEvent::External(e) => MergedEvent::External(e.flatten()),
  230. }))
  231. }
  232. }
  233. }
  234. }
  235. impl Node {
  236. pub fn id(&self) -> String {
  237. self.node.id().to_string()
  238. }
  239. }
  240. /// Start a runtime for Operators
  241. ///
  242. /// :rtype: None
  243. #[pyfunction]
  244. pub fn start_runtime() -> eyre::Result<()> {
  245. dora_runtime::main().wrap_err("Dora Runtime raised an error.")
  246. }
  247. #[pymodule]
  248. fn dora(_py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
  249. dora_ros2_bridge_python::create_dora_ros2_bridge_module(&m)?;
  250. m.add_function(wrap_pyfunction!(start_runtime, &m)?)?;
  251. m.add_class::<Node>()?;
  252. m.setattr("__version__", env!("CARGO_PKG_VERSION"))?;
  253. m.setattr("__author__", "Dora-rs Authors")?;
  254. Ok(())
  255. }