You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

python.rs 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. #![allow(clippy::borrow_deref_ref)] // clippy warns about code generated by #[pymethods]
  2. use super::{OperatorEvent, StopReason};
  3. use dora_core::{
  4. config::{NodeId, OperatorId},
  5. descriptor::source_is_url,
  6. };
  7. use dora_download::download_file;
  8. use dora_node_api::Event;
  9. use dora_operator_api_python::PyEvent;
  10. use dora_operator_api_types::DoraStatus;
  11. use eyre::{bail, eyre, Context, Result};
  12. use pyo3::{
  13. pyclass,
  14. types::{IntoPyDict, PyDict},
  15. Py, PyAny, Python,
  16. };
  17. use std::{
  18. panic::{catch_unwind, AssertUnwindSafe},
  19. path::Path,
  20. };
  21. use tokio::sync::{mpsc::Sender, oneshot};
  22. use tracing::{error, field, span, warn};
  23. fn traceback(err: pyo3::PyErr) -> eyre::Report {
  24. let traceback = Python::with_gil(|py| err.traceback(py).and_then(|t| t.format().ok()));
  25. if let Some(traceback) = traceback {
  26. eyre::eyre!("{err}{traceback}")
  27. } else {
  28. eyre::eyre!("{err}")
  29. }
  30. }
  31. #[tracing::instrument(skip(events_tx, incoming_events), level = "trace")]
  32. pub fn run(
  33. node_id: &NodeId,
  34. operator_id: &OperatorId,
  35. source: &str,
  36. events_tx: Sender<OperatorEvent>,
  37. incoming_events: flume::Receiver<Event>,
  38. init_done: oneshot::Sender<Result<()>>,
  39. ) -> eyre::Result<()> {
  40. let path = if source_is_url(source) {
  41. let target_path = Path::new("build")
  42. .join(node_id.to_string())
  43. .join(format!("{}.py", operator_id));
  44. // try to download the shared library
  45. let rt = tokio::runtime::Builder::new_current_thread()
  46. .enable_all()
  47. .build()?;
  48. rt.block_on(download_file(source, &target_path))
  49. .wrap_err("failed to download Python operator")?;
  50. target_path
  51. } else {
  52. Path::new(source).to_owned()
  53. };
  54. if !path.exists() {
  55. bail!("No python file exists at {}", path.display());
  56. }
  57. let path = path
  58. .canonicalize()
  59. .wrap_err_with(|| format!("no file found at `{}`", path.display()))?;
  60. let module_name = path
  61. .file_stem()
  62. .ok_or_else(|| eyre!("module path has no file stem"))?
  63. .to_str()
  64. .ok_or_else(|| eyre!("module file stem is not valid utf8"))?;
  65. let path_parent = path.parent();
  66. let send_output = SendOutputCallback {
  67. events_tx: events_tx.clone(),
  68. };
  69. let init_operator = move |py: Python| {
  70. if let Some(parent_path) = path_parent {
  71. let parent_path = parent_path
  72. .to_str()
  73. .ok_or_else(|| eyre!("module path is not valid utf8"))?;
  74. let sys = py.import("sys").wrap_err("failed to import `sys` module")?;
  75. let sys_path = sys
  76. .getattr("path")
  77. .wrap_err("failed to import `sys.path` module")?;
  78. let sys_path_append = sys_path
  79. .getattr("append")
  80. .wrap_err("`sys.path.append` was not found")?;
  81. sys_path_append
  82. .call1((parent_path,))
  83. .wrap_err("failed to append module path to python search path")?;
  84. }
  85. let module = py.import(module_name).map_err(traceback)?;
  86. let operator_class = module
  87. .getattr("Operator")
  88. .wrap_err("no `Operator` class found in module")?;
  89. let locals = [("Operator", operator_class)].into_py_dict(py);
  90. let operator = py
  91. .eval("Operator()", None, Some(locals))
  92. .map_err(traceback)?;
  93. Result::<_, eyre::Report>::Ok(Py::from(operator))
  94. };
  95. let python_runner = move || {
  96. let mut operator =
  97. match Python::with_gil(init_operator).wrap_err("failed to init python operator") {
  98. Ok(op) => {
  99. let _ = init_done.send(Ok(()));
  100. op
  101. }
  102. Err(err) => {
  103. let _ = init_done.send(Err(err));
  104. bail!("Could not init python operator")
  105. }
  106. };
  107. let reason = loop {
  108. #[allow(unused_mut)]
  109. let Ok(mut event) = incoming_events.recv() else { break StopReason::InputsClosed };
  110. if let Event::Reload { .. } = event {
  111. // Reloading method
  112. match Python::with_gil(|py| -> Result<Py<PyAny>> {
  113. // Saving current state
  114. let current_state = operator
  115. .getattr(py, "__dict__")
  116. .wrap_err("Could not retrieve current operator state")?;
  117. let current_state = current_state
  118. .extract::<&PyDict>(py)
  119. .wrap_err("could not extract operator state as a PyDict")?;
  120. // Reload module
  121. let module = py
  122. .import(module_name)
  123. .map_err(traceback)
  124. .wrap_err(format!("Could not retrieve {module_name} while reloading"))?;
  125. let importlib = py
  126. .import("importlib")
  127. .wrap_err("failed to import `importlib` module")?;
  128. let module = importlib
  129. .call_method("reload", (module,), None)
  130. .wrap_err(format!("Could not reload {module_name} while reloading"))?;
  131. let reloaded_operator_class = module
  132. .getattr("Operator")
  133. .wrap_err("no `Operator` class found in module")?;
  134. // Create a new reloaded operator
  135. let locals = [("Operator", reloaded_operator_class)].into_py_dict(py);
  136. let operator: Py<pyo3::PyAny> = py
  137. .eval("Operator()", None, Some(locals))
  138. .map_err(traceback)
  139. .wrap_err("Could not initialize reloaded operator")?
  140. .into();
  141. // Replace initialized state with current state
  142. operator
  143. .getattr(py, "__dict__")
  144. .wrap_err("Could not retrieve new operator state")?
  145. .extract::<&PyDict>(py)
  146. .wrap_err("could not extract new operator state as a PyDict")?
  147. .update(current_state.as_mapping())
  148. .wrap_err("could not restore operator state")?;
  149. Ok(operator)
  150. }) {
  151. Ok(reloaded_operator) => {
  152. operator = reloaded_operator;
  153. }
  154. Err(err) => {
  155. error!("Failed to reload operator.\n {err}");
  156. }
  157. }
  158. }
  159. let status = Python::with_gil(|py| -> Result<i32> {
  160. let span = span!(tracing::Level::TRACE, "on_event", input_id = field::Empty);
  161. let _ = span.enter();
  162. // We need to create a new scoped `GILPool` because the dora-runtime
  163. // is currently started through a `start_runtime` wrapper function,
  164. // which is annotated with `#[pyfunction]`. This attribute creates an
  165. // initial `GILPool` that lasts for the entire lifetime of the `dora-runtime`.
  166. // However, we want the `PyBytes` created below to be freed earlier.
  167. // creating a new scoped `GILPool` tied to this closure, will free `PyBytes`
  168. // at the end of the closure.
  169. // See https://github.com/PyO3/pyo3/pull/2864 and
  170. // https://github.com/PyO3/pyo3/issues/2853 for more details.
  171. let pool = unsafe { py.new_pool() };
  172. let py = pool.python();
  173. // Add metadata context if we have a tracer and
  174. // incoming input has some metadata.
  175. #[cfg(feature = "telemetry")]
  176. if let Event::Input {
  177. id: input_id,
  178. metadata,
  179. ..
  180. } = &mut event
  181. {
  182. use dora_tracing::telemetry::{deserialize_context, serialize_context};
  183. use std::borrow::Cow;
  184. use tracing_opentelemetry::OpenTelemetrySpanExt;
  185. span.record("input_id", input_id.as_str());
  186. let cx = deserialize_context(&metadata.parameters.open_telemetry_context);
  187. span.set_parent(cx);
  188. let cx = span.context();
  189. let string_cx = serialize_context(&cx);
  190. metadata.parameters.open_telemetry_context = Cow::Owned(string_cx);
  191. }
  192. let py_event = PyEvent::from(event);
  193. let status_enum = operator
  194. .call_method1(py, "on_event", (py_event, send_output.clone()))
  195. .map_err(traceback)?;
  196. let status_val = Python::with_gil(|py| status_enum.getattr(py, "value"))
  197. .wrap_err("on_event must have enum return value")?;
  198. Python::with_gil(|py| status_val.extract(py))
  199. .wrap_err("on_event has invalid return value")
  200. })?;
  201. match status {
  202. s if s == DoraStatus::Continue as i32 => {} // ok
  203. s if s == DoraStatus::Stop as i32 => break StopReason::ExplicitStop,
  204. s if s == DoraStatus::StopAll as i32 => break StopReason::ExplicitStopAll,
  205. other => bail!("on_event returned invalid status {other}"),
  206. }
  207. };
  208. // Dropping the operator using Python garbage collector.
  209. // Locking the GIL for immediate release.
  210. Python::with_gil(|_py| {
  211. drop(operator);
  212. });
  213. Result::<_, eyre::Report>::Ok(reason)
  214. };
  215. let closure = AssertUnwindSafe(|| {
  216. python_runner().wrap_err_with(|| format!("error in Python module at {}", path.display()))
  217. });
  218. match catch_unwind(closure) {
  219. Ok(Ok(reason)) => {
  220. let _ = events_tx.blocking_send(OperatorEvent::Finished { reason });
  221. }
  222. Ok(Err(err)) => {
  223. let _ = events_tx.blocking_send(OperatorEvent::Error(err));
  224. }
  225. Err(panic) => {
  226. let _ = events_tx.blocking_send(OperatorEvent::Panic(panic));
  227. }
  228. }
  229. Ok(())
  230. }
  /// Python-callable object passed as the second argument to `Operator.on_event`.
  /// Calling it forwards an operator output to the runtime through `events_tx`.
  #[pyclass]
  #[derive(Clone)]
  struct SendOutputCallback {
      /// Channel to the runtime, used both to request output samples
      /// (`OperatorEvent::AllocateOutputSample`) and to emit outputs
      /// (`OperatorEvent::Output`).
      events_tx: Sender<OperatorEvent>,
  }
  236. #[allow(unsafe_op_in_unsafe_fn)]
  237. mod callback_impl {
  238. use crate::operator::OperatorEvent;
  239. use super::SendOutputCallback;
  240. use dora_operator_api_python::{process_python_output, pydict_to_metadata, python_output_len};
  241. use eyre::{eyre, Context, Result};
  242. use pyo3::{pymethods, types::PyDict, PyObject, Python};
  243. use tokio::sync::oneshot;
  244. #[pymethods]
  245. impl SendOutputCallback {
  246. fn __call__(
  247. &mut self,
  248. output: &str,
  249. data: PyObject,
  250. metadata: Option<&PyDict>,
  251. py: Python,
  252. ) -> Result<()> {
  253. let data_len = python_output_len(&data, py)?;
  254. let mut sample = py.allow_threads(|| {
  255. let (tx, rx) = oneshot::channel();
  256. self.events_tx
  257. .blocking_send(OperatorEvent::AllocateOutputSample {
  258. len: data_len,
  259. sample: tx,
  260. })
  261. .map_err(|_| eyre!("failed to send output to runtime"))?;
  262. let sample = rx
  263. .blocking_recv()
  264. .wrap_err("failed to request output sample")?
  265. .wrap_err("failed to allocate output sample")?;
  266. Result::<_, eyre::Report>::Ok(sample)
  267. })?;
  268. process_python_output(&data, py, |data| {
  269. sample.copy_from_slice(data);
  270. Ok(())
  271. })?;
  272. let metadata = pydict_to_metadata(metadata)
  273. .wrap_err("failed to parse metadata")?
  274. .into_owned();
  275. let event = OperatorEvent::Output {
  276. output_id: output.to_owned().into(),
  277. metadata,
  278. data: Some(sample),
  279. };
  280. py.allow_threads(|| {
  281. self.events_tx
  282. .blocking_send(event)
  283. .map_err(|_| eyre!("failed to send output to runtime"))
  284. })?;
  285. Ok(())
  286. }
  287. }
  288. }

DORA (Dataflow-Oriented Robotic Architecture) is middleware designed to streamline and simplify the creation of AI-based robotic applications. It offers low-latency, composable, and distributed dataflow capabilities.