You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

python.rs 15 kB

1 year ago

  1. #![allow(clippy::borrow_deref_ref)] // clippy warns about code generated by #[pymethods]
  2. use super::{OperatorEvent, StopReason};
  3. use dora_core::{
  4. config::{NodeId, OperatorId},
  5. descriptor::{source_is_url, Descriptor, PythonSource},
  6. };
  7. use dora_download::download_file;
  8. use dora_node_api::{merged::MergedEvent, Event, Parameter};
  9. use dora_operator_api_python::PyEvent;
  10. use dora_operator_api_types::DoraStatus;
  11. use eyre::{bail, eyre, Context, Result};
  12. use pyo3::ffi::c_str;
  13. use pyo3::{
  14. pyclass,
  15. types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods, PyTracebackMethods},
  16. Py, PyAny, Python,
  17. };
  18. use std::{
  19. panic::{catch_unwind, AssertUnwindSafe},
  20. path::Path,
  21. };
  22. use tokio::sync::{mpsc::Sender, oneshot};
  23. use tracing::{error, field, span, warn};
  24. fn traceback(err: pyo3::PyErr) -> eyre::Report {
  25. let traceback = Python::with_gil(|py| err.traceback(py).and_then(|t| t.format().ok()));
  26. if let Some(traceback) = traceback {
  27. eyre::eyre!("{traceback}\n{err}")
  28. } else {
  29. eyre::eyre!("{err}")
  30. }
  31. }
  32. #[tracing::instrument(skip(events_tx, incoming_events), level = "trace")]
  33. pub fn run(
  34. node_id: &NodeId,
  35. operator_id: &OperatorId,
  36. python_source: &PythonSource,
  37. events_tx: Sender<OperatorEvent>,
  38. incoming_events: flume::Receiver<Event>,
  39. init_done: oneshot::Sender<Result<()>>,
  40. dataflow_descriptor: &Descriptor,
  41. ) -> eyre::Result<()> {
  42. let path = if source_is_url(&python_source.source) {
  43. let target_path = Path::new("build");
  44. // try to download the shared library
  45. let rt = tokio::runtime::Builder::new_current_thread()
  46. .enable_all()
  47. .build()?;
  48. rt.block_on(download_file(&python_source.source, target_path))
  49. .wrap_err("failed to download Python operator")?
  50. } else {
  51. Path::new(&python_source.source).to_owned()
  52. };
  53. if !path.exists() {
  54. bail!("No python file exists at {}", path.display());
  55. }
  56. let path = path
  57. .canonicalize()
  58. .wrap_err_with(|| format!("no file found at `{}`", path.display()))?;
  59. let module_name = path
  60. .file_stem()
  61. .ok_or_else(|| eyre!("module path has no file stem"))?
  62. .to_str()
  63. .ok_or_else(|| eyre!("module file stem is not valid utf8"))?;
  64. let path_parent = path.parent();
  65. let send_output = SendOutputCallback {
  66. events_tx: events_tx.clone(),
  67. };
  68. let init_operator = move |py: Python| {
  69. if let Some(parent_path) = path_parent {
  70. let parent_path = parent_path
  71. .to_str()
  72. .ok_or_else(|| eyre!("module path is not valid utf8"))?;
  73. let sys = py.import("sys").wrap_err("failed to import `sys` module")?;
  74. let sys_path = sys
  75. .getattr("path")
  76. .wrap_err("failed to import `sys.path` module")?;
  77. let sys_path_append = sys_path
  78. .getattr("append")
  79. .wrap_err("`sys.path.append` was not found")?;
  80. sys_path_append
  81. .call1((parent_path,))
  82. .wrap_err("failed to append module path to python search path")?;
  83. }
  84. let module = py.import(module_name).map_err(traceback)?;
  85. let operator_class = module
  86. .getattr("Operator")
  87. .wrap_err("no `Operator` class found in module")?;
  88. let locals = [("Operator", operator_class)]
  89. .into_py_dict(py)
  90. .context("Failed to create py_dict")?;
  91. let operator = py
  92. .eval(c_str!("Operator()"), None, Some(&locals))
  93. .map_err(traceback)?;
  94. operator.setattr(
  95. "dataflow_descriptor",
  96. pythonize::pythonize(py, dataflow_descriptor)?,
  97. )?;
  98. Result::<_, eyre::Report>::Ok(Py::from(operator))
  99. };
  100. let python_runner = move || {
  101. let mut operator =
  102. match Python::with_gil(init_operator).wrap_err("failed to init python operator") {
  103. Ok(op) => {
  104. let _ = init_done.send(Ok(()));
  105. op
  106. }
  107. Err(err) => {
  108. let _ = init_done.send(Err(err));
  109. bail!("Could not init python operator")
  110. }
  111. };
  112. let mut reload = false;
  113. let reason = loop {
  114. #[allow(unused_mut)]
  115. let Ok(mut event) = incoming_events.recv() else {
  116. break StopReason::InputsClosed;
  117. };
  118. if let Event::Reload { .. } = event {
  119. reload = true;
  120. // Reloading method
  121. #[allow(clippy::blocks_in_conditions)]
  122. match Python::with_gil(|py| -> Result<Py<PyAny>> {
  123. // Saving current state
  124. let current_state = operator
  125. .getattr(py, "__dict__")
  126. .wrap_err("Could not retrieve current operator state")?;
  127. let current_state =
  128. current_state.downcast_bound::<PyDict>(py).map_err(|err| {
  129. eyre!("could not extract operator state as a PyDict. Err: {}", err)
  130. })?;
  131. // Reload module
  132. let module = py
  133. .import(module_name)
  134. .map_err(traceback)
  135. .wrap_err(format!("Could not retrieve {module_name} while reloading"))?;
  136. let importlib = py
  137. .import("importlib")
  138. .wrap_err("failed to import `importlib` module")?;
  139. let module = importlib
  140. .call_method("reload", (module,), None)
  141. .wrap_err(format!("Could not reload {module_name} while reloading"))?;
  142. let reloaded_operator_class = module
  143. .getattr("Operator")
  144. .wrap_err("no `Operator` class found in module")?;
  145. // Create a new reloaded operator
  146. let locals = [("Operator", reloaded_operator_class)]
  147. .into_py_dict(py)
  148. .context("Failed to create py_dict")?;
  149. let operator: Py<pyo3::PyAny> = py
  150. .eval(c_str!("Operator()"), None, Some(&locals))
  151. .map_err(traceback)
  152. .wrap_err("Could not initialize reloaded operator")?
  153. .into();
  154. // Replace initialized state with current state
  155. operator
  156. .getattr(py, "__dict__")
  157. .wrap_err("Could not retrieve new operator state")?
  158. .downcast_bound::<PyDict>(py)
  159. .map_err(|err| {
  160. eyre!("could not extract new operator state as a PyDict. Err: {err}")
  161. })?
  162. .update(current_state.as_mapping())
  163. .wrap_err("could not restore operator state")?;
  164. Ok(operator)
  165. }) {
  166. Ok(reloaded_operator) => {
  167. operator = reloaded_operator;
  168. }
  169. Err(err) => {
  170. error!("Failed to reload operator.\n {err}");
  171. }
  172. }
  173. }
  174. let status = Python::with_gil(|py| -> Result<i32> {
  175. let span = span!(tracing::Level::TRACE, "on_event", input_id = field::Empty);
  176. let _ = span.enter();
  177. // Add metadata context if we have a tracer and
  178. // incoming input has some metadata.
  179. #[cfg(feature = "telemetry")]
  180. if let Event::Input {
  181. id: input_id,
  182. metadata,
  183. ..
  184. } = &mut event
  185. {
  186. use dora_tracing::telemetry::{deserialize_context, serialize_context};
  187. use tracing_opentelemetry::OpenTelemetrySpanExt;
  188. span.record("input_id", input_id.as_str());
  189. let otel = metadata.open_telemetry_context();
  190. let cx = deserialize_context(&otel);
  191. span.set_parent(cx);
  192. let cx = span.context();
  193. let string_cx = serialize_context(&cx);
  194. metadata.parameters.insert(
  195. "open_telemetry_context".to_string(),
  196. Parameter::String(string_cx),
  197. );
  198. }
  199. let py_event = PyEvent {
  200. event: MergedEvent::Dora(event),
  201. }
  202. .to_py_dict(py)
  203. .context("Could not convert event to pydict bound")?;
  204. let status_enum = operator
  205. .call_method1(py, "on_event", (py_event, send_output.clone()))
  206. .map_err(traceback);
  207. match status_enum {
  208. Ok(status_enum) => {
  209. let status_val = Python::with_gil(|py| status_enum.getattr(py, "value"))
  210. .wrap_err("on_event must have enum return value")?;
  211. Python::with_gil(|py| status_val.extract(py))
  212. .wrap_err("on_event has invalid return value")
  213. }
  214. Err(err) => {
  215. if reload {
  216. // Allow error in hot reloading environment to help development.
  217. warn!("{err}");
  218. Ok(DoraStatus::Continue as i32)
  219. } else {
  220. Err(err)
  221. }
  222. }
  223. }
  224. })?;
  225. match status {
  226. s if s == DoraStatus::Continue as i32 => {} // ok
  227. s if s == DoraStatus::Stop as i32 => break StopReason::ExplicitStop,
  228. s if s == DoraStatus::StopAll as i32 => break StopReason::ExplicitStopAll,
  229. other => bail!("on_event returned invalid status {other}"),
  230. }
  231. };
  232. // Dropping the operator using Python garbage collector.
  233. // Locking the GIL for immediate release.
  234. Python::with_gil(|_py| {
  235. drop(operator);
  236. });
  237. Result::<_, eyre::Report>::Ok(reason)
  238. };
  239. let closure = AssertUnwindSafe(|| {
  240. python_runner().wrap_err_with(|| format!("error in Python module at {}", path.display()))
  241. });
  242. match catch_unwind(closure) {
  243. Ok(Ok(reason)) => {
  244. let _ = events_tx.blocking_send(OperatorEvent::Finished { reason });
  245. }
  246. Ok(Err(err)) => {
  247. let _ = events_tx.blocking_send(OperatorEvent::Error(err));
  248. }
  249. Err(panic) => {
  250. let _ = events_tx.blocking_send(OperatorEvent::Panic(panic));
  251. }
  252. }
  253. Ok(())
  254. }
  255. #[pyclass]
  256. #[derive(Clone)]
  257. struct SendOutputCallback {
  258. events_tx: Sender<OperatorEvent>,
  259. }
  260. #[allow(unsafe_op_in_unsafe_fn)]
  261. mod callback_impl {
  262. use crate::operator::OperatorEvent;
  263. use super::SendOutputCallback;
  264. use aligned_vec::{AVec, ConstAlign};
  265. use arrow::{array::ArrayData, pyarrow::FromPyArrow};
  266. use dora_core::metadata::ArrowTypeInfoExt;
  267. use dora_message::metadata::ArrowTypeInfo;
  268. use dora_node_api::{
  269. arrow_utils::{copy_array_into_sample, required_data_size},
  270. ZERO_COPY_THRESHOLD,
  271. };
  272. use dora_operator_api_python::pydict_to_metadata;
  273. use dora_tracing::telemetry::deserialize_context;
  274. use eyre::{eyre, Context, Result};
  275. use pyo3::{
  276. pymethods,
  277. types::{PyBytes, PyBytesMethods, PyDict},
  278. Bound, PyObject, Python,
  279. };
  280. use tokio::sync::oneshot;
  281. use tracing::{field, span};
  282. use tracing_opentelemetry::OpenTelemetrySpanExt;
  283. /// Send an output from the operator:
  284. /// - the first argument is the `output_id` as defined in your dataflow.
  285. /// - the second argument is the data as either bytes or pyarrow.Array for zero copy.
  286. /// - the third argument is dora metadata if you want to link the tracing from one input into an output.
  287. /// `e.g.: send_output("bbox", pa.array([100], type=pa.uint8()), dora_event["metadata"])`
  288. #[pymethods]
  289. impl SendOutputCallback {
  290. #[pyo3(signature = (output, data, metadata=None))]
  291. fn __call__(
  292. &mut self,
  293. output: &str,
  294. data: PyObject,
  295. metadata: Option<Bound<'_, PyDict>>,
  296. py: Python,
  297. ) -> Result<()> {
  298. let parameters = pydict_to_metadata(metadata).wrap_err("failed to parse metadata")?;
  299. let span = span!(
  300. tracing::Level::TRACE,
  301. "send_output",
  302. output_id = field::Empty
  303. );
  304. span.record("output_id", output);
  305. let otel = if let Some(dora_node_api::Parameter::String(otel)) =
  306. parameters.get("open_telemetry_context")
  307. {
  308. otel.to_string()
  309. } else {
  310. "".to_string()
  311. };
  312. let cx = deserialize_context(&otel);
  313. span.set_parent(cx);
  314. let _ = span.enter();
  315. let allocate_sample = |data_len| {
  316. if data_len > ZERO_COPY_THRESHOLD {
  317. let (tx, rx) = oneshot::channel();
  318. self.events_tx
  319. .blocking_send(OperatorEvent::AllocateOutputSample {
  320. len: data_len,
  321. sample: tx,
  322. })
  323. .map_err(|_| eyre!("failed to send output to runtime"))?;
  324. rx.blocking_recv()
  325. .wrap_err("failed to request output sample")?
  326. .wrap_err("failed to allocate output sample")
  327. } else {
  328. let avec: AVec<u8, ConstAlign<128>> = AVec::__from_elem(128, 0, data_len);
  329. Ok(avec.into())
  330. }
  331. };
  332. let (sample, type_info) = if let Ok(py_bytes) = data.downcast_bound::<PyBytes>(py) {
  333. let data = py_bytes.as_bytes();
  334. let mut sample = allocate_sample(data.len())?;
  335. sample.copy_from_slice(data);
  336. (sample, ArrowTypeInfo::byte_array(data.len()))
  337. } else if let Ok(arrow_array) = ArrayData::from_pyarrow_bound(data.bind(py)) {
  338. let total_len = required_data_size(&arrow_array);
  339. let mut sample = allocate_sample(total_len)?;
  340. let type_info = copy_array_into_sample(&mut sample, &arrow_array);
  341. (sample, type_info)
  342. } else {
  343. eyre::bail!("invalid `data` type, must by `PyBytes` or arrow array")
  344. };
  345. py.allow_threads(|| {
  346. let event = OperatorEvent::Output {
  347. output_id: output.to_owned().into(),
  348. type_info,
  349. parameters,
  350. data: Some(sample),
  351. };
  352. self.events_tx
  353. .blocking_send(event)
  354. .map_err(|_| eyre!("failed to send output to runtime"))
  355. })?;
  356. Ok(())
  357. }
  358. }
  359. }