You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

python.rs 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. use super::OperatorEvent;
  2. use dora_node_api::{communication::Publisher, config::DataId};
  3. use eyre::{bail, eyre, Context};
  4. use pyo3::{pyclass, types::IntoPyDict, types::PyBytes, Py, Python};
  5. use std::{
  6. collections::HashMap,
  7. panic::{catch_unwind, AssertUnwindSafe},
  8. path::Path,
  9. sync::Arc,
  10. thread,
  11. };
  12. use tokio::sync::mpsc::Sender;
  13. fn traceback(err: pyo3::PyErr) -> eyre::Report {
  14. Python::with_gil(|py| {
  15. eyre::Report::msg(format!(
  16. "{}\n{err}",
  17. err.traceback(py)
  18. .expect("PyError should have a traceback")
  19. .format()
  20. .expect("Traceback could not be formatted")
  21. ))
  22. })
  23. }
  24. pub fn spawn(
  25. path: &Path,
  26. events_tx: Sender<OperatorEvent>,
  27. inputs: flume::Receiver<dora_node_api::Input>,
  28. publishers: HashMap<DataId, Box<dyn Publisher>>,
  29. ) -> eyre::Result<()> {
  30. if !path.exists() {
  31. bail!("No python file exists at {}", path.display());
  32. }
  33. let path = path
  34. .canonicalize()
  35. .wrap_err_with(|| format!("no file found at `{}`", path.display()))?;
  36. let path_cloned = path.clone();
  37. let send_output = SendOutputCallback {
  38. publishers: Arc::new(publishers),
  39. };
  40. let init_operator = move |py: Python| {
  41. if let Some(parent_path) = path.parent() {
  42. let parent_path = parent_path
  43. .to_str()
  44. .ok_or_else(|| eyre!("module path is not valid utf8"))?;
  45. let sys = py.import("sys").wrap_err("failed to import `sys` module")?;
  46. let sys_path = sys
  47. .getattr("path")
  48. .wrap_err("failed to import `sys.path` module")?;
  49. let sys_path_append = sys_path
  50. .getattr("append")
  51. .wrap_err("`sys.path.append` was not found")?;
  52. sys_path_append
  53. .call1((parent_path,))
  54. .wrap_err("failed to append module path to python search path")?;
  55. }
  56. let module_name = path
  57. .file_stem()
  58. .ok_or_else(|| eyre!("module path has no file stem"))?
  59. .to_str()
  60. .ok_or_else(|| eyre!("module file stem is not valid utf8"))?;
  61. let module = py.import(module_name).map_err(traceback)?;
  62. let operator_class = module
  63. .getattr("Operator")
  64. .wrap_err("no `Operator` class found in module")?;
  65. let locals = [("Operator", operator_class)].into_py_dict(py);
  66. let operator = py
  67. .eval("Operator()", None, Some(locals))
  68. .map_err(traceback)?;
  69. Result::<_, eyre::Report>::Ok(Py::from(operator))
  70. };
  71. let python_runner = move || {
  72. let operator =
  73. Python::with_gil(init_operator).wrap_err("failed to init python operator")?;
  74. while let Ok(input) = inputs.recv() {
  75. let status_enum = Python::with_gil(|py| {
  76. operator
  77. .call_method1(
  78. py,
  79. "on_input",
  80. (
  81. input.id.to_string(),
  82. PyBytes::new(py, &input.data),
  83. send_output.clone(),
  84. ),
  85. )
  86. .map_err(traceback)
  87. })?;
  88. let status_val = Python::with_gil(|py| status_enum.getattr(py, "value"))
  89. .wrap_err("on_input must have enum return value")?;
  90. let status: i32 = Python::with_gil(|py| status_val.extract(py))
  91. .wrap_err("on_input has invalid return value")?;
  92. match status {
  93. 0 => {} // ok
  94. 1 => break, // stop
  95. other => bail!("on_input returned invalid status {other}"),
  96. }
  97. }
  98. Python::with_gil(|py| {
  99. let operator = operator.as_ref(py);
  100. if operator
  101. .hasattr("drop_operator")
  102. .wrap_err("failed to look for drop_operator")?
  103. {
  104. operator.call_method0("drop_operator")?;
  105. }
  106. Result::<_, eyre::Report>::Ok(())
  107. })?;
  108. Result::<_, eyre::Report>::Ok(())
  109. };
  110. thread::spawn(move || {
  111. let closure = AssertUnwindSafe(|| {
  112. python_runner()
  113. .wrap_err_with(|| format!("error in Python module at {}", path_cloned.display()))
  114. });
  115. match catch_unwind(closure) {
  116. Ok(Ok(())) => {
  117. let _ = events_tx.blocking_send(OperatorEvent::Finished);
  118. }
  119. Ok(Err(err)) => {
  120. let _ = events_tx.blocking_send(OperatorEvent::Error(err));
  121. }
  122. Err(panic) => {
  123. let _ = events_tx.blocking_send(OperatorEvent::Panic(panic));
  124. }
  125. }
  126. });
  127. Ok(())
  128. }
  129. #[pyclass]
  130. #[derive(Clone)]
  131. struct SendOutputCallback {
  132. publishers: Arc<HashMap<DataId, Box<dyn Publisher>>>,
  133. }
  134. #[allow(unsafe_op_in_unsafe_fn)]
  135. mod callback_impl {
  136. use super::SendOutputCallback;
  137. use eyre::{eyre, Context};
  138. use pyo3::{pymethods, PyResult};
  139. #[pymethods]
  140. impl SendOutputCallback {
  141. fn __call__(&mut self, output: &str, data: &[u8]) -> PyResult<()> {
  142. match self.publishers.get(output) {
  143. Some(publisher) => publisher
  144. .publish(data)
  145. .map_err(|err| eyre::eyre!(err))
  146. .context("publish failed"),
  147. None => Err(eyre!(
  148. "unexpected output {output} (not defined in dataflow config)"
  149. )),
  150. }
  151. .map_err(|err| err.into())
  152. }
  153. }
  154. }

DORA (Dataflow-Oriented Robotic Architecture) is middleware designed to streamline and simplify the creation of AI-based robotic applications. It offers low latency, composable, and distributed datafl