From 7353c79e9ebe106355b90106c02f3e6d439a406c Mon Sep 17 00:00:00 2001 From: Philipp Oppermann Date: Fri, 29 Apr 2022 18:35:05 +0200 Subject: [PATCH] Add support for shared library operators The library must provide a `dora_on_input` function, which the runtime invokes when it receives input for the operator. In addition to the input id and value, we pass a callback function for sending output. By operating on single `(id, value)` pairs instead of sets of them, the operators stay flexible. They can define their own logic for input rules and send outputs as soon as they become available. The callback function design also limits allocations since output values can be stack-allocated this way. --- Cargo.lock | 1 + common/src/descriptor/mod.rs | 4 +- runtime/Cargo.toml | 1 + runtime/src/main.rs | 56 ++++++++++- runtime/src/operator/mod.rs | 32 ++++-- runtime/src/operator/shared_lib.rs | 156 +++++++++++++++++++++++++++++ 6 files changed, 236 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 63d16c5e..4ce5a5bf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -626,6 +626,7 @@ dependencies = [ "eyre", "futures", "futures-concurrency", + "libloading", "serde_yaml", "tokio", "tokio-stream", diff --git a/common/src/descriptor/mod.rs b/common/src/descriptor/mod.rs index 269e2ff1..cd1c06b6 100644 --- a/common/src/descriptor/mod.rs +++ b/common/src/descriptor/mod.rs @@ -42,7 +42,7 @@ pub enum NodeKind { Custom(CustomNode), } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct OperatorConfig { pub id: OperatorId, pub name: Option, @@ -57,7 +57,7 @@ pub struct OperatorConfig { pub source: OperatorSource, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] #[serde(rename_all = "snake_case")] pub enum OperatorSource { SharedLibrary(PathBuf), diff --git a/runtime/Cargo.toml b/runtime/Cargo.toml index 2f0ddf31..4f4e02bb 100644 --- a/runtime/Cargo.toml +++ 
b/runtime/Cargo.toml @@ -12,6 +12,7 @@ dora-common = { version = "0.1.0", path = "../common" } eyre = "0.6.8" futures = "0.3.21" futures-concurrency = "2.0.3" +libloading = "0.7.3" serde_yaml = "0.8.23" tokio = { version = "1.17.0", features = ["full"] } tokio-stream = "0.1.8" diff --git a/runtime/src/main.rs b/runtime/src/main.rs index 1e972af4..7f4a94d2 100644 --- a/runtime/src/main.rs +++ b/runtime/src/main.rs @@ -1,3 +1,5 @@ +#![warn(unsafe_op_in_unsafe_fn)] + use clap::StructOpt; use dora_api::{ self, @@ -59,7 +61,7 @@ async fn main() -> eyre::Result<()> { let mut operator_events = StreamMap::new(); for operator_config in &operators { let (events_tx, events) = mpsc::channel(1); - let operator = Operator::init(operator_config, events_tx.clone()) + let operator = Operator::init(operator_config.clone(), events_tx.clone()) .await .wrap_err_with(|| format!("failed to init operator {}", operator_config.id))?; operator_map.insert(&operator_config.id, operator); @@ -70,9 +72,9 @@ async fn main() -> eyre::Result<()> { .await .map_err(BoxError) .wrap_err("failed to create zenoh session")?; - let mut communication: Box = Box::new(zenoh); + let communication: Box = Box::new(zenoh); - let inputs = subscribe(communication.as_mut(), &dataflow.communication, &operators) + let inputs = subscribe(communication.as_ref(), &dataflow.communication, &operators) .await .context("failed to subscribe")?; @@ -108,7 +110,32 @@ async fn main() -> eyre::Result<()> { ) })?; } - Event::Operator { id, event } => match event {}, + Event::Operator { id, event } => { + let operator = operator_map + .get(&id) + .ok_or_else(|| eyre!("received event from unknown operator {id}"))?; + match event { + OperatorEvent::Output { id: data_id, value } => { + if !operator.config().outputs.contains(&data_id) { + eyre::bail!("unknown output {data_id} for operator {id}"); + } + publish( + &args.node_id, + id, + data_id, + &value, + communication.as_ref(), + &dataflow.communication, + ) + .await + 
.context("failed to publish operator output")?; + } + OperatorEvent::Error(err) => { + bail!(err.wrap_err(format!("operator {id} failed"))) + } + OperatorEvent::Panic(payload) => std::panic::resume_unwind(payload), + } + } } } @@ -116,7 +143,7 @@ async fn main() -> eyre::Result<()> { } async fn subscribe<'a>( - communication: &'a mut dyn CommunicationLayer, + communication: &'a dyn CommunicationLayer, communication_config: &CommunicationConfig, operators: &'a [OperatorConfig], ) -> eyre::Result + 'a> { @@ -170,6 +197,25 @@ async fn subscribe<'a>( Ok(streams.merge().take_until(finished)) } +async fn publish( + self_id: &NodeId, + operator_id: OperatorId, + output_id: DataId, + value: &[u8], + communication: &dyn CommunicationLayer, + communication_config: &CommunicationConfig, +) -> eyre::Result<()> { + let prefix = &communication_config.zenoh_prefix; + + let topic = format!("{prefix}/{self_id}/{operator_id}/{output_id}"); + communication + .publish(&topic, value) + .await + .wrap_err_with(|| format!("failed to send data for output {output_id}"))?; + + Ok(()) +} + enum Event { Input(OperatorInput), Operator { diff --git a/runtime/src/operator/mod.rs b/runtime/src/operator/mod.rs index 054f75e5..2e54afba 100644 --- a/runtime/src/operator/mod.rs +++ b/runtime/src/operator/mod.rs @@ -1,26 +1,31 @@ use dora_api::config::DataId; use dora_common::descriptor::{OperatorConfig, OperatorSource}; -use eyre::eyre; +use eyre::{eyre, Context}; +use std::any::Any; use tokio::sync::mpsc::{self, Sender}; mod shared_lib; pub struct Operator { operator_task: Sender, + config: OperatorConfig, } impl Operator { pub async fn init( - operator_config: &OperatorConfig, + operator_config: OperatorConfig, events_tx: Sender, ) -> eyre::Result { let (operator_task, operator_rx) = mpsc::channel(10); match &operator_config.source { OperatorSource::SharedLibrary(path) => { - let todo = - "init shared library operator at `path` with `events_tx` and `operator_rx`"; - eprintln!("WARNING: shared 
library operators are not supported yet"); + shared_lib::spawn(path, events_tx, operator_rx).wrap_err_with(|| { + format!( + "failed ot spawn shared library operator for {}", + operator_config.id + ) + })?; } OperatorSource::Python(path) => { eprintln!("WARNING: Python operators are not supported yet"); @@ -29,7 +34,10 @@ impl Operator { eprintln!("WARNING: WASM operators are not supported yet"); } } - Ok(Self { operator_task }) + Ok(Self { + operator_task, + config: operator_config, + }) } pub fn handle_input(&mut self, id: DataId, value: Vec) -> eyre::Result<()> { @@ -40,9 +48,19 @@ impl Operator { tokio::sync::mpsc::error::TrySendError::Full(_) => eyre!("operator queue full"), }) } + + /// Get a reference to the operator's config. + #[must_use] + pub fn config(&self) -> &OperatorConfig { + &self.config + } } -pub enum OperatorEvent {} +pub enum OperatorEvent { + Output { id: DataId, value: Vec }, + Error(eyre::Error), + Panic(Box), +} pub struct OperatorInput { id: DataId, diff --git a/runtime/src/operator/shared_lib.rs b/runtime/src/operator/shared_lib.rs index e69de29b..a22bfbc3 100644 --- a/runtime/src/operator/shared_lib.rs +++ b/runtime/src/operator/shared_lib.rs @@ -0,0 +1,156 @@ +use super::{OperatorEvent, OperatorInput}; +use eyre::{bail, Context}; +use libloading::Symbol; +use std::{ + ffi::c_void, + panic::{catch_unwind, AssertUnwindSafe}, + path::Path, + slice, thread, +}; +use tokio::sync::mpsc::{Receiver, Sender}; + +pub fn spawn( + path: &Path, + events_tx: Sender, + inputs: Receiver, +) -> eyre::Result<()> { + let library = unsafe { + libloading::Library::new(path) + .wrap_err_with(|| format!("failed to load shared library at `{}`", path.display()))? 
+ }; + + thread::spawn(move || { + let closure = AssertUnwindSafe(|| { + let bindings = Bindings::init(&library)?; + + let operator = SharedLibraryOperator { + events_tx: events_tx.clone(), + inputs, + bindings, + }; + + operator.run() + }); + match catch_unwind(closure) { + Ok(Ok(())) => {} + Ok(Err(err)) => { + let _ = events_tx.blocking_send(OperatorEvent::Error(err)); + } + Err(panic) => { + let _ = events_tx.blocking_send(OperatorEvent::Panic(panic)); + } + } + }); + + Ok(()) +} + +struct SharedLibraryOperator<'lib> { + events_tx: Sender, + inputs: Receiver, + + bindings: Bindings<'lib>, +} + +impl<'lib> SharedLibraryOperator<'lib> { + fn run(mut self) -> eyre::Result<()> { + while let Some(input) = self.inputs.blocking_recv() { + let id_start = input.id.as_bytes().as_ptr(); + let id_len = input.id.as_bytes().len(); + let data_start = input.value.as_slice().as_ptr(); + let data_len = input.value.len(); + + println!("Received input {}", input.id); + let output = |id: &str, data: &[u8]| -> isize { + let result = self.events_tx.blocking_send(OperatorEvent::Output { + id: id.to_owned().into(), + value: data.to_owned(), + }); + match result { + Ok(()) => 0, + Err(_) => -1, + } + }; + let (output_fn, output_ctx) = wrap_closure(&output); + + let result = unsafe { + (self.bindings.on_input)( + id_start, id_len, data_start, data_len, output_fn, output_ctx, + ) + }; + if result != 0 { + bail!("on_input failed with exit code {result}"); + } + } + Ok(()) + } +} + +/// Wrap a closure with an FFI-compatible trampoline function. +/// +/// Returns a C compatible trampoline function and a data pointer that +/// must be passed as the `context` argument when invoking the trampoline function. +fn wrap_closure(closure: &F) -> (OutputFn, *const c_void) +where + F: Fn(&str, &[u8]) -> isize, +{ + /// Rust closures are just compiler-generated structs with a `call` method. 
This + /// trampoline function is generic over the closure type, which means that the + /// compiler's monomorphization step creates a different copy of that function + /// for each closure type. + /// + /// The trampoline function expects the pointer to the corresponding closure + /// struct as `context` argument. It casts that pointer back to a closure + /// struct pointer and invokes its call method. + unsafe extern "C" fn trampoline isize>( + id_start: *const u8, + id_len: usize, + data_start: *const u8, + data_len: usize, + context: *const c_void, + ) -> isize { + let id_raw = unsafe { slice::from_raw_parts(id_start, id_len) }; + let data = unsafe { slice::from_raw_parts(data_start, data_len) }; + let id = match std::str::from_utf8(id_raw) { + Ok(s) => s, + Err(_) => return -1, + }; + unsafe { (*(context as *const F))(id, data) } + } + + (trampoline::, closure as *const F as *const c_void) +} + +struct Bindings<'lib> { + on_input: Symbol<'lib, OnInputFn>, +} + +impl<'lib> Bindings<'lib> { + fn init(library: &'lib libloading::Library) -> Result { + let bindings = unsafe { + Bindings { + on_input: library + .get(b"dora_on_input") + .wrap_err("failed to get `dora_on_input`")?, + } + }; + Ok(bindings) + } +} + +type OnInputFn = unsafe extern "C" fn( + id_start: *const u8, + id_len: usize, + data_start: *const u8, + data_len: usize, + output: OutputFn, + output_context: *const c_void, +) -> isize; + +type OutputFn = unsafe extern "C" fn( + id_start: *const u8, + id_len: usize, + data_start: *const u8, + data_len: usize, + output_context: *const c_void, +) -> isize;