diff --git a/binaries/runtime/src/operator/mod.rs b/binaries/runtime/src/operator/mod.rs
index 8103e3c6..ded825df 100644
--- a/binaries/runtime/src/operator/mod.rs
+++ b/binaries/runtime/src/operator/mod.rs
@@ -3,9 +3,14 @@ use dora_core::{
     descriptor::{OperatorDefinition, OperatorSource},
     message::{Metadata, MetadataParameters},
 };
+use dora_operator_api_python::metadata_to_pydict;
 use eyre::Context;
 #[cfg(feature = "tracing")]
 use opentelemetry::sdk::trace::Tracer;
+use pyo3::{
+    types::{PyBytes, PyDict},
+    IntoPy, PyObject, Python,
+};
 use std::any::Any;
 use tokio::sync::mpsc::{Receiver, Sender};
 
@@ -95,6 +100,47 @@ pub enum IncomingEvent {
     },
 }
 
+/// Converts an event into the Python dictionary that is passed to the operator.
+///
+/// The dictionary always has a `type` key (`"STOP"` or `"INPUT"`). Input events also
+/// carry `id`, `data` (empty bytes if the event has no data), and `metadata`.
+///
+/// Panics if one of the dictionary items cannot be set.
+impl IntoPy<PyObject> for IncomingEvent {
+    fn into_py(self, py: Python) -> PyObject {
+        let dict = PyDict::new(py);
+
+        let ty = match self {
+            Self::Stop => "STOP",
+            Self::Input {
+                input_id,
+                metadata,
+                data,
+            } => {
+                dict.set_item("id", input_id.to_string())
+                    .wrap_err("failed to add input ID")
+                    .unwrap();
+                dict.set_item(
+                    "data",
+                    PyBytes::new(py, data.as_deref().unwrap_or_default()),
+                )
+                .wrap_err("failed to add input data")
+                .unwrap();
+                dict.set_item("metadata", metadata_to_pydict(&metadata, py))
+                    .wrap_err("failed to add input metadata")
+                    .unwrap();
+                "INPUT"
+            }
+        };
+
+        dict.set_item("type", ty)
+            .wrap_err("could not make type a python dictionary item")
+            .unwrap();
+
+        dict.into()
+    }
+}
+
 #[derive(Debug)]
 pub enum StopReason {
     InputsClosed,
diff --git a/binaries/runtime/src/operator/python.rs b/binaries/runtime/src/operator/python.rs
index be749d1e..5f8655cc 100644
--- a/binaries/runtime/src/operator/python.rs
+++ b/binaries/runtime/src/operator/python.rs
@@ -6,15 +6,9 @@ use dora_core::{
     descriptor::source_is_url,
 };
 use dora_download::download_file;
-use dora_operator_api_python::metadata_to_pydict;
 use dora_operator_api_types::DoraStatus;
 use eyre::{bail, eyre, Context, Result};
-use pyo3::{
-    pyclass,
-    types::IntoPyDict,
-    types::{PyBytes, PyDict},
-    Py, Python,
-};
+use pyo3::{pyclass, types::IntoPyDict, IntoPy, Py, Python};
 use std::{
     borrow::Cow,
     panic::{catch_unwind, AssertUnwindSafe},
@@ -23,15 +17,13 @@ use std::{
 use tokio::sync::mpsc::{Receiver, Sender};
 
+/// Turns a Python error into an `eyre::Report`, appending the traceback when one can be formatted.
 fn traceback(err: pyo3::PyErr) -> eyre::Report {
-    Python::with_gil(|py| {
-        eyre::Report::msg(format!(
-            "{}\n{err}",
-            err.traceback(py)
-                .expect("PyError should have a traceback")
-                .format()
-                .expect("Traceback could not be formatted")
-        ))
-    })
+    let traceback = Python::with_gil(|py| err.traceback(py).and_then(|t| t.format().ok()));
+    if let Some(traceback) = traceback {
+        eyre::eyre!("{err}:\n{traceback}")
+    } else {
+        eyre::eyre!("{err}")
+    }
 }
 
 #[tracing::instrument(skip(events_tx, incoming_events, tracer))]
@@ -109,72 +101,64 @@ pub fn run(
     Python::with_gil(init_operator).wrap_err("failed to init python operator")?;
 
     let reason = loop {
-        let Some(event) = incoming_events.blocking_recv() else { break StopReason::InputsClosed };
-
-        match event {
-            IncomingEvent::Input {
-                input_id,
-                mut metadata,
-                data,
-            } => {
-                #[cfg(feature = "tracing")]
-                let (_child_cx, string_cx) = {
-                    use dora_tracing::{deserialize_context, serialize_context};
-                    use opentelemetry::trace::TraceContextExt;
-                    use opentelemetry::{trace::Tracer, Context as OtelContext};
-
-                    let cx = deserialize_context(&metadata.parameters.open_telemetry_context);
-                    let span = tracer.start_with_context(format!("{}", input_id), &cx);
-
-                    let child_cx = OtelContext::current_with_span(span);
-                    let string_cx = serialize_context(&child_cx);
-                    (child_cx, string_cx)
-                };
-
-                #[cfg(not(feature = "tracing"))]
-                let string_cx = {
-                    let () = tracer;
-                    "".to_string()
-                };
-                metadata.parameters.open_telemetry_context = Cow::Owned(string_cx);
-
-                let status = Python::with_gil(|py| -> Result<i32> {
-                    // We need to create a new scoped `GILPool` because the dora-runtime
-                    // is currently started through a `start_runtime` wrapper function,
-                    // which is annotated with `#[pyfunction]`. This attribute creates an
-                    // initial `GILPool` that lasts for the entire lifetime of the `dora-runtime`.
-                    // However, we want the `PyBytes` created below to be freed earlier.
-                    // creating a new scoped `GILPool` tied to this closure, will free `PyBytes`
-                    // at the end of the closure.
-                    // See https://github.com/PyO3/pyo3/pull/2864 and
-                    // https://github.com/PyO3/pyo3/issues/2853 for more details.
-                    let pool = unsafe { py.new_pool() };
-                    let py = pool.python();
-                    let input_dict = PyDict::new(py);
-
-                    input_dict.set_item("id", input_id.as_str())?;
-                    if let Some(data) = data {
-                        let bytes = PyBytes::new(py, &data);
-                        input_dict.set_item("data", bytes)?;
-                    }
-                    input_dict.set_item("metadata", metadata_to_pydict(&metadata, py))?;
-
-                    let status_enum = operator
-                        .call_method1(py, "on_input", (input_dict, send_output.clone()))
-                        .map_err(traceback)?;
-                    let status_val = Python::with_gil(|py| status_enum.getattr(py, "value"))
-                        .wrap_err("on_input must have enum return value")?;
-                    Python::with_gil(|py| status_val.extract(py))
-                        .wrap_err("on_input has invalid return value")
-                })?;
-                match status {
-                    s if s == DoraStatus::Continue as i32 => {} // ok
-                    s if s == DoraStatus::Stop as i32 => break StopReason::ExplicitStop,
-                    s if s == DoraStatus::StopAll as i32 => break StopReason::ExplicitStopAll,
-                    other => bail!("on_input returned invalid status {other}"),
-                }
-            }
-            IncomingEvent::Stop => {}
+        let Some(mut event) = incoming_events.blocking_recv() else { break StopReason::InputsClosed };
+
+        if let IncomingEvent::Input {
+            input_id, metadata, ..
+        } = &mut event
+        {
+            #[cfg(feature = "tracing")]
+            let (_child_cx, string_cx) = {
+                use dora_tracing::{deserialize_context, serialize_context};
+                use opentelemetry::trace::TraceContextExt;
+                use opentelemetry::{trace::Tracer, Context as OtelContext};
+
+                let cx = deserialize_context(&metadata.parameters.open_telemetry_context);
+                let span = tracer.start_with_context(format!("{}", input_id), &cx);
+
+                let child_cx = OtelContext::current_with_span(span);
+                let string_cx = serialize_context(&child_cx);
+                (child_cx, string_cx)
+            };
+
+            #[cfg(not(feature = "tracing"))]
+            let string_cx = {
+                let _ = input_id;
+                let () = tracer;
+                "".to_string()
+            };
+            metadata.parameters.open_telemetry_context = Cow::Owned(string_cx);
+        }
+        let status = Python::with_gil(|py| -> Result<i32> {
+            // We need to create a new scoped `GILPool` because the dora-runtime
+            // is currently started through a `start_runtime` wrapper function,
+            // which is annotated with `#[pyfunction]`. This attribute creates an
+            // initial `GILPool` that lasts for the entire lifetime of the `dora-runtime`.
+            // However, we want the `PyBytes` created below to be freed earlier.
+            // creating a new scoped `GILPool` tied to this closure, will free `PyBytes`
+            // at the end of the closure.
+            // See https://github.com/PyO3/pyo3/pull/2864 and
+            // https://github.com/PyO3/pyo3/issues/2853 for more details.
+            let pool = unsafe { py.new_pool() };
+            let py = pool.python();
+            let input_dict = event.into_py(py);
+
+            let status_enum = operator
+                .call_method1(py, "on_event", (input_dict, send_output.clone()))
+                .map_err(traceback)?;
+            // The GIL is already held inside this closure, so reuse `py` instead of re-acquiring it.
+            let status_val = status_enum
+                .getattr(py, "value")
+                .wrap_err("on_event must have enum return value")?;
+            status_val
+                .extract(py)
+                .wrap_err("on_event has invalid return value")
+        })?;
+        match status {
+            s if s == DoraStatus::Continue as i32 => {} // ok
+            s if s == DoraStatus::Stop as i32 => break StopReason::ExplicitStop,
+            s if s == DoraStatus::StopAll as i32 => break StopReason::ExplicitStopAll,
+            other => bail!("on_event returned invalid status {other}"),
         }
     };
 
diff --git a/examples/python-operator-dataflow/object_detection.py b/examples/python-operator-dataflow/object_detection.py
index 098ec4d1..fd103f86 100755
--- a/examples/python-operator-dataflow/object_detection.py
+++ b/examples/python-operator-dataflow/object_detection.py
@@ -22,6 +22,17 @@ class Operator:
     def __init__(self):
         self.model = torch.hub.load("ultralytics/yolov5", "yolov5n")
 
+    def on_event(
+        self,
+        dora_event: dict,
+        send_output: Callable[[str, bytes], None],
+    ) -> DoraStatus:
+        if dora_event["type"] == "INPUT":
+            return self.on_input(dora_event, send_output)
+        # The runtime reads `.value` from the returned status, so every event type
+        # (e.g. "STOP") must produce a `DoraStatus`, not an implicit `None`.
+        return DoraStatus.CONTINUE
+
     def on_input(
         self,
         dora_input: dict,
diff --git a/examples/python-operator-dataflow/plot.py b/examples/python-operator-dataflow/plot.py
index 57a2a293..6c95eae8 100755
--- a/examples/python-operator-dataflow/plot.py
+++ b/examples/python-operator-dataflow/plot.py
@@ -26,6 +26,17 @@ class Operator:
         self.image = []
         self.bboxs = []
 
+    def on_event(
+        self,
+        dora_event: dict,
+        send_output: Callable[[str, bytes], None],
+    ) -> DoraStatus:
+        if dora_event["type"] == "INPUT":
+            return self.on_input(dora_event, send_output)
+        # The runtime reads `.value` from the returned status, so every event type
+        # (e.g. "STOP") must produce a `DoraStatus`, not an implicit `None`.
+        return DoraStatus.CONTINUE
+
     def on_input(
         self,
         dora_input: dict,