diff --git a/apis/python/operator/src/lib.rs b/apis/python/operator/src/lib.rs index bb7054dc..91451fe9 100644 --- a/apis/python/operator/src/lib.rs +++ b/apis/python/operator/src/lib.rs @@ -1,8 +1,9 @@ -use std::{borrow::Cow, sync::Arc}; +use std::sync::Arc; -use arrow::pyarrow::PyArrowConvert; +use arrow::pyarrow::{FromPyArrow, ToPyArrow}; +use arrow_schema::DataType; use dora_node_api::{Data, Event, Metadata, MetadataParameters}; -use eyre::{Context, Result}; +use eyre::{bail, Context, ContextCompat, Result}; use pyo3::{ exceptions::PyLookupError, prelude::*, @@ -62,9 +63,22 @@ impl PyEvent { /// Returns the payload of an input event as an arrow array (if any). fn value(&self, py: Python<'_>) -> PyResult> { if let Some(data) = &self.data { + let data_type = match &self.event { + Event::Input { metadata, .. } => match &metadata + .schema + .fields() + .first() + .context("no field in schema")? + .data_type() + { + DataType::List(field) => field.data_type().clone(), + _ => todo!(), + }, + _ => DataType::UInt8, + }; let array = data .clone() - .into_arrow_array() + .into_arrow_array(data_type) .map_err(|err| arrow::pyarrow::PyArrowException::new_err(err.to_string()))?; // TODO: Does this call leak data? let array_data = array.to_pyarrow(py)?; @@ -117,7 +131,7 @@ pub fn pydict_to_metadata(dict: Option<&PyDict>) -> Result { let otel_context: &str = value .extract() .context("parsing open telemetry context failed")?; - default_metadata.open_telemetry_context = Cow::Borrowed(otel_context); + default_metadata.open_telemetry_context = otel_context.to_string(); } _ => (), } @@ -141,14 +155,11 @@ pub fn python_output_len(data: &PyObject, py: Python) -> eyre::Result { if let Ok(py_bytes) = data.downcast::(py) { py_bytes.len().wrap_err("failed to get length of PyBytes") } else if let Ok(arrow_array) = arrow::array::ArrayData::from_pyarrow(data.as_ref(py)) { - if arrow_array.data_type() != &arrow::datatypes::DataType::UInt8 { - eyre::bail!("only arrow arrays with data type `UInt8` are supported"); - } if arrow_array.buffers().len() != 1 { eyre::bail!("output arrow array must contain a single buffer"); } - Ok(arrow_array.len()) + Ok(arrow_array.buffers()[0].len()) } else { eyre::bail!("invalid `data` type, must by `PyBytes` or arrow array") } @@ -163,15 +174,15 @@ pub fn process_python_output( let data = py_bytes.as_bytes(); callback(data) } else if let Ok(arrow_array) = arrow::array::ArrayData::from_pyarrow(data.as_ref(py)) { - if arrow_array.data_type() != &arrow::datatypes::DataType::UInt8 { - eyre::bail!("only arrow arrays with data type `UInt8` are supported"); - } if arrow_array.buffers().len() != 1 { eyre::bail!("output arrow array must contain a single buffer"); } - let len = arrow_array.len(); - let slice = &arrow_array.buffer(0)[..len]; + let buffers = arrow_array.buffers(); + if buffers.len() != 1 { + bail!("Arrow array must contain a single buffer"); + } + let slice = buffers[0]; callback(slice) } else { diff --git a/binaries/runtime/src/operator/python.rs b/binaries/runtime/src/operator/python.rs index 7e211236..acfaf6c8 100644 --- a/binaries/runtime/src/operator/python.rs +++ b/binaries/runtime/src/operator/python.rs @@ -310,17 +310,18 @@ mod callback_impl { Ok(()) })?; - let metadata = pydict_to_metadata(metadata) + let parameters = pydict_to_metadata(metadata) .wrap_err("failed to parse metadata")? .into_owned(); + let data_type = process_python_type(&data, py)?; + py.allow_threads(|| { let event = OperatorEvent::Output { output_id: output.to_owned().into(), - metadata, + data_type, + parameters, data: Some(sample), }; - - py.allow_threads(|| { self.events_tx .blocking_send(event) .map_err(|_| eyre!("failed to send output to runtime"))