From 0074ffd6e8e0746b1f75fab28fc6b3151d0256d5 Mon Sep 17 00:00:00 2001 From: Philipp Oppermann Date: Tue, 13 Feb 2024 17:32:40 +0100 Subject: [PATCH] Refactor C++ ROS2 subscription API to make downcasts work Use the subscriber type for downcasting to ensure that the correct type is used. Also, store an unique ID per subscriber to differentiate subscriptions of same type after merging. --- Cargo.lock | 9 +- apis/c++/node/Cargo.toml | 2 + apis/c++/node/build.rs | 18 ++- apis/c++/node/src/lib.rs | 137 ++++++++---------- .../c++-ros2-dataflow/node-rust-api/main.cc | 37 ++--- .../extensions/ros2-bridge/msg-gen/src/lib.rs | 29 +--- .../ros2-bridge/msg-gen/src/types/message.rs | 81 ++++++----- rust-toolchain.toml | 2 +- 8 files changed, 144 insertions(+), 171 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e9ab1d47..56486dea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1608,6 +1608,7 @@ dependencies = [ "dora-ros2-bridge-msg-gen", "eyre", "futures-lite 2.2.0", + "prettyplease", "rust-format", "serde", "serde-big-array", @@ -4863,9 +4864,9 @@ checksum = "ef703b7cb59335eae2eb93ceb664c0eb7ea6bf567079d843e09420219668e072" [[package]] name = "safer-ffi" -version = "0.1.3" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9c1d19b288ca9898cd421c7b105fb7269918a7f8e9253a991e228981ca421ad" +checksum = "4483c5ab47f222d2c297e73a520c9003e09e2fe1f1b04edcb572e6939f303003" dependencies = [ "inventory 0.1.11", "inventory 0.3.12", @@ -4881,9 +4882,9 @@ dependencies = [ [[package]] name = "safer_ffi-proc_macros" -version = "0.1.3" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2d7a04caa3ca2224f5ea4ddd850e2629c3b36b2b83621f87a8303bf41020110" +checksum = "bf04ebd3786110e64269a74eea58c5564dd92a1e790c0f6f9871d6fe1b8e34db" dependencies = [ "macro_rules_attribute", "prettyplease", diff --git a/apis/c++/node/Cargo.toml b/apis/c++/node/Cargo.toml index 53c0016a..ff27a2e8 100644 --- a/apis/c++/node/Cargo.toml +++ b/apis/c++/node/Cargo.toml @@ -18,6 +18,7 @@ ros2-bridge = [ "dep:dora-ros2-bridge", "dep:dora-ros2-bridge-msg-gen", "dep:rust-format", + "dep:prettyplease", "dep:serde", "dep:serde-big-array", ] @@ -37,3 +38,4 @@ dora-ros2-bridge-msg-gen = { workspace = true, optional = true } rust-format = { version = "0.3.4", features = [ "pretty_please", ], optional = true } +prettyplease = { version = "0.1", features = ["verbatim"], optional = true } diff --git a/apis/c++/node/build.rs b/apis/c++/node/build.rs index d60150d0..3a0e9c4f 100644 --- a/apis/c++/node/build.rs +++ b/apis/c++/node/build.rs @@ -80,6 +80,8 @@ mod ros2 { } pub fn generate_ros2_message_header(source_file: &Path) { + use std::io::Write as _; + let out_dir = source_file.parent().unwrap(); let relative_path = local_relative_path(&source_file) .ancestors() @@ -112,12 +114,18 @@ mod ros2 { std::fs::copy(&header_path, &target_path).unwrap(); println!("cargo:rerun-if-changed={}", header_path.display()); - std::fs::copy( - &code_path, - target_path.with_file_name("dora-ros2-bindings.cc"), - ) - .unwrap(); + + let mut node_header = + std::fs::File::open(target_path.with_file_name("dora-node-api.h")).unwrap(); + let mut code_file = std::fs::File::open(&code_path).unwrap(); println!("cargo:rerun-if-changed={}", code_path.display()); + let mut code_target_file = + std::fs::File::create(target_path.with_file_name("dora-ros2-bindings.cc")).unwrap(); + + // copy both the node header and the code file to prevent import errors + std::io::copy(&mut node_header, &mut code_target_file).unwrap(); + std::io::copy(&mut code_file, &mut code_target_file).unwrap(); + code_target_file.flush().unwrap(); } // copy from cxx-build source diff --git a/apis/c++/node/src/lib.rs b/apis/c++/node/src/lib.rs index d4c755ea..99a3b1cf 100644 --- a/apis/c++/node/src/lib.rs +++ b/apis/c++/node/src/lib.rs @@ -3,13 +3,14 @@ use std::any::Any; use dora_node_api::{ self, arrow::array::{AsArray, BinaryArray}, - merged::MergedEvent, + merged::{MergeExternal, MergedEvent}, Event, EventStream, }; use eyre::bail; #[cfg(feature = "ros2-bridge")] use dora_ros2_bridge::_core; +use futures_lite::{Stream, StreamExt}; #[cxx::bridge] #[allow(clippy::needless_lifetimes)] @@ -37,22 +38,24 @@ mod ffi { error: String, } - extern "C++" { - #[allow(dead_code)] - type ExternalEvents = crate::ros2::ExternalEvents; - #[allow(dead_code)] - type Ros2Event = crate::ros2::Ros2Event; + pub struct CombinedEvents { + events: Box, + } + + pub struct CombinedEvent { + event: Box, } extern "Rust" { type Events; - type MergedEvents; type OutputSender; type DoraEvent; + type MergedEvents; type MergedDoraEvent; fn init_dora_node() -> Result; + fn dora_events_into_combined(events: Box) -> CombinedEvents; fn next(self: &mut Events) -> Box; fn event_type(event: &Box) -> DoraEventType; fn event_as_input(event: Box) -> Result; @@ -62,13 +65,10 @@ mod ffi { data: &[u8], ) -> DoraResult; - fn merge_events(dora: Box, external: Box) -> Box; - fn next(self: &mut MergedEvents) -> Box; + fn next(self: &mut CombinedEvents) -> CombinedEvent; - fn is_ros2(event: &Box) -> bool; - fn downcast_ros2(event: Box) -> Result>; - fn is_dora(event: &Box) -> bool; - fn downcast_dora(event: Box) -> Result>; + fn is_dora(self: &CombinedEvent) -> bool; + fn downcast_dora(event: CombinedEvent) -> Result>; } } @@ -78,19 +78,6 @@ pub mod ros2 { include!(env!("ROS2_BINDINGS_PATH")); } -/// Dummy placeholder. -#[cfg(not(feature = "ros2-bridge"))] -#[cxx::bridge] -#[allow(clippy::needless_lifetimes)] -mod ros2 { - pub struct ExternalEvents { - dummy: u8, - } - pub struct Ros2Event { - dummy: u8, - } -} - fn init_dora_node() -> eyre::Result { let (node, events) = dora_node_api::DoraNode::init_from_env()?; let events = Events(events); @@ -110,6 +97,16 @@ impl Events { } } +fn dora_events_into_combined(events: Box) -> ffi::CombinedEvents { + let events = events.0.map(MergedEvent::Dora); + ffi::CombinedEvents { + events: Box::new(MergedEvents { + events: Some(Box::new(events)), + next_id: 1, + }), + } +} + pub struct DoraEvent(Option); fn event_type(event: &DoraEvent) -> ffi::DoraEventType { @@ -151,73 +148,57 @@ fn send_output(sender: &mut Box, id: String, data: &[u8]) -> ffi:: ffi::DoraResult { error } } -#[cfg(feature = "ros2-bridge")] -#[allow(clippy::boxed_local)] -pub fn merge_events( - dora_events: Box, - external: Box, -) -> Box { - use dora_node_api::merged::MergeExternal; - - let merge_external = dora_events - .0 - .merge_external(external.events.0.as_event_stream()); - Box::new(MergedEvents(Box::new(futures_lite::stream::block_on( - merge_external, - )))) -} - -/// Dummy -#[cfg(not(feature = "ros2-bridge"))] -#[allow(clippy::boxed_local)] -pub fn merge_events( - dora_events: Box, - _external: Box, -) -> Box { - use dora_node_api::merged::MergeExternal; - - let merge_external = dora_events.0.merge_external(futures_lite::stream::empty()); - Box::new(MergedEvents(Box::new(futures_lite::stream::block_on( - merge_external, - )))) +pub struct MergedEvents { + events: Option> + Unpin>>, + next_id: u32, } -pub struct MergedEvents(Box>> + Unpin>); - impl MergedEvents { - fn next(&mut self) -> Box { - let event = self.0.next(); - Box::new(MergedDoraEvent(event)) + fn next(&mut self) -> MergedDoraEvent { + let event = futures_lite::future::block_on(self.events.as_mut().unwrap().next()); + MergedDoraEvent(event) } -} -pub struct MergedDoraEvent(Option>>); + pub fn merge(&mut self, events: impl Stream> + Unpin + 'static) -> u32 { + let id = self.next_id; + self.next_id += 1; + let events = Box::pin(events.map(move |event| ExternalEvent { event, id })); + + let inner = self.events.take().unwrap(); + let merged: Box + Unpin + 'static> = + Box::new(inner.merge_external(events).map(|event| match event { + MergedEvent::Dora(event) => MergedEvent::Dora(event), + MergedEvent::External(event) => MergedEvent::External(event.flatten()), + })); + self.events = Some(merged); -fn is_ros2(event: &Box) -> bool { - match event.0 { - Some(MergedEvent::External(_)) => true, - _ => false, + id } } -fn downcast_ros2(event: Box) -> eyre::Result> { - match event.0 { - Some(MergedEvent::External(event)) => Ok(Box::new(ros2::Ros2Event { - event: Box::new(ros2::ExternalRos2Event(event)), - })), - _ => eyre::bail!("not an external event"), +impl ffi::CombinedEvents { + fn next(&mut self) -> ffi::CombinedEvent { + ffi::CombinedEvent { + event: Box::new(self.events.next()), + } } } -fn is_dora(event: &Box) -> bool { - match event.0 { - Some(MergedEvent::Dora(_)) => true, - _ => false, +pub struct MergedDoraEvent(Option>); + +pub struct ExternalEvent { + pub event: Box, + pub id: u32, +} + +impl ffi::CombinedEvent { + fn is_dora(&self) -> bool { + matches!(&self.event.0, Some(MergedEvent::Dora(_))) } } -fn downcast_dora(event: Box) -> eyre::Result> { - match event.0 { +fn downcast_dora(event: ffi::CombinedEvent) -> eyre::Result> { + match event.event.0 { Some(MergedEvent::Dora(event)) => Ok(Box::new(DoraEvent(Some(event)))), _ => eyre::bail!("not an external event"), } diff --git a/examples/c++-ros2-dataflow/node-rust-api/main.cc b/examples/c++-ros2-dataflow/node-rust-api/main.cc index 9e7a107a..4d8d0cb7 100644 --- a/examples/c++-ros2-dataflow/node-rust-api/main.cc +++ b/examples/c++-ros2-dataflow/node-rust-api/main.cc @@ -9,6 +9,9 @@ int main() { std::cout << "HELLO FROM C++" << std::endl; + auto dora_node = init_dora_node(); + auto merged_events = dora_events_into_combined(std::move(dora_node.events)); + auto qos = qos_default(); qos.durability = Ros2Durability::Volatile; qos.liveliness = Ros2Liveliness::Automatic; @@ -20,38 +23,19 @@ int main() auto vel_topic = node->create_topic_geometry_msgs_Twist("/turtle1", "cmd_vel", qos); auto vel_publisher = node->create_publisher(vel_topic, qos); auto pose_topic = node->create_topic_turtlesim_Pose("/turtle1", "pose", qos); - auto pose_subscription = node->create_subscription(pose_topic, qos); + auto pose_subscription = node->create_subscription(pose_topic, qos, merged_events); std::random_device dev; std::default_random_engine gen(dev()); std::uniform_real_distribution<> dist(0., 1.); - auto dora_node = init_dora_node(); - - std::cout << "MERGING EVENTS" << std::endl; - auto merged_events = merge_events(std::move(dora_node.events), event_stream(std::move(pose_subscription))); - std::cout << "MERGED EVENTS" << std::endl; - auto received_ticks = 0; for (int i = 0; i < 1000; i++) { - auto event = merged_events->next(); + auto event = merged_events.next(); - if (is_ros2(event)) - { - auto ros2_event = downcast_ros2(std::move(event)); - if (turtlesim::is_Pose(ros2_event)) - { - auto pose = turtlesim::downcast_Pose(std::move(ros2_event)); - std::cout << "Received Pose { x: " << pose->x << ", y: " << pose->y << " }" << std::endl; - } - else - { - std::cout << "received unexpected ros2 input" << std::endl; - } - } - else if (is_dora(event)) + if (event.is_dora()) { auto dora_event = downcast_dora(std::move(event)); @@ -83,6 +67,15 @@ int main() break; } } + else if (pose_subscription->matches(event)) + { + auto pose = pose_subscription->downcast(std::move(event)); + std::cout << "Received pose x:" << pose.x << ", y:" << pose.y << std::endl; + } + else + { + std::cout << "received unexpected event" << std::endl; + } } std::cout << "GOODBYE FROM C++ node (using Rust API)" << std::endl; diff --git a/libraries/extensions/ros2-bridge/msg-gen/src/lib.rs b/libraries/extensions/ros2-bridge/msg-gen/src/lib.rs index be4f0396..5c86823c 100644 --- a/libraries/extensions/ros2-bridge/msg-gen/src/lib.rs +++ b/libraries/extensions/ros2-bridge/msg-gen/src/lib.rs @@ -46,23 +46,19 @@ where ( quote! { #[cxx::bridge] }, quote! { + extern "C++" { + type CombinedEvents = crate::ffi::CombinedEvents; + type CombinedEvent = crate::ffi::CombinedEvent; + } + extern "Rust" { type Ros2Context; type Ros2Node; - type ExternalRos2Events; - type ExternalRos2Event; fn init_ros2_context() -> Result>; fn new_node(self: &Ros2Context, name_space: &str, base_name: &str) -> Result>; fn qos_default() -> Ros2QosPolicies; - #(#message_topic_defs)* - } - - pub struct ExternalEvents { - events: Box, - } - pub struct Ros2Event { - event: Box, + #(#message_topic_defs)* } #[derive(Debug, Clone)] @@ -194,19 +190,6 @@ where } } } - - pub use ffi::ExternalEvents; - pub use ffi::Ros2Event; - - pub struct ExternalRos2Events( - pub Box, - ); - - pub trait AsEventStream { - fn as_event_stream(self: Box) -> Box> + Unpin>; - } - - pub struct ExternalRos2Event(pub Box); }, ) } else { diff --git a/libraries/extensions/ros2-bridge/msg-gen/src/types/message.rs b/libraries/extensions/ros2-bridge/msg-gen/src/types/message.rs index 1ce74e16..b8d81173 100644 --- a/libraries/extensions/ros2-bridge/msg-gen/src/types/message.rs +++ b/libraries/extensions/ros2-bridge/msg-gen/src/types/message.rs @@ -192,22 +192,22 @@ impl Message { let cxx_create_publisher = format_ident!("create_publisher"); let struct_raw_name = format_ident!("{package_name}__{}", self.name); + let struct_raw_name_str = struct_raw_name.to_string(); let self_name = &self.name; let publish = format_ident!("publish__{package_name}__{}", self.name); let cxx_publish = format_ident!("publish"); let subscription_name = format_ident!("Subscription__{package_name}__{}", self.name); + let subscription_name_str = subscription_name.to_string(); let cxx_subscription_name = format_ident!("Subscription_{}", self.name); let create_subscription = format_ident!("new__Subscription__{package_name}__{}", self.name); let cxx_create_subscription = format_ident!("create_subscription"); - let event_stream = format_ident!("event_stream__{package_name}__{}", self.name); - let cxx_event_stream = format_ident!("event_stream"); - let is = format_ident!("is__{package_name}__{}", self.name); - let cxx_is = format_ident!("is_{}", self.name); + let matches = format_ident!("matches__{package_name}__{}", self.name); + let cxx_matches = format_ident!("matches"); let downcast = format_ident!("downcast__{package_name}__{}", self.name); - let cxx_downcast = format_ident!("downcast_{}", self.name); + let cxx_downcast = format_ident!("downcast"); let def = quote! { #[namespace = #package_name] @@ -218,7 +218,7 @@ impl Message { #[cxx_name = #cxx_create_publisher] fn #create_publisher(self: &mut Ros2Node, topic: &Box<#topic_name>, qos: Ros2QosPolicies) -> Result>; #[cxx_name = #cxx_create_subscription] - fn #create_subscription(self: &mut Ros2Node, topic: &Box<#topic_name>, qos: Ros2QosPolicies) -> Result>; + fn #create_subscription(self: &mut Ros2Node, topic: &Box<#topic_name>, qos: Ros2QosPolicies, events: &mut CombinedEvents) -> Result>; #[namespace = #package_name] #[cxx_name = #cxx_publisher_name] @@ -232,15 +232,11 @@ impl Message { type #subscription_name; #[namespace = #package_name] - #[cxx_name = #cxx_event_stream] - fn #event_stream(subscription: Box<#subscription_name>) -> Box; - - #[namespace = #package_name] - #[cxx_name = #cxx_is] - fn #is(event: &Box) -> bool; + #[cxx_name = #cxx_matches] + fn #matches(self: &#subscription_name, event: &CombinedEvent) -> bool; #[namespace = #package_name] #[cxx_name = #cxx_downcast] - fn #downcast(event: Box) -> Result>; + fn #downcast(self: &#subscription_name, event: CombinedEvent) -> Result<#struct_raw_name>; }; let imp = quote! { #[allow(non_camel_case_types)] @@ -262,9 +258,16 @@ impl Message { } #[allow(non_snake_case)] - pub fn #create_subscription(&mut self, topic: &Box<#topic_name>, qos: ffi::Ros2QosPolicies) -> eyre::Result> { - let subscription = self.0.create_subscription(&topic.0, Some(qos.into()))?; - Ok(Box::new(#subscription_name(subscription))) + pub fn #create_subscription(&mut self, topic: &Box<#topic_name>, qos: ffi::Ros2QosPolicies, events: &mut crate::ffi::CombinedEvents) -> eyre::Result> { + let subscription = self.0.create_subscription::(&topic.0, Some(qos.into()))?; + let stream = futures_lite::stream::unfold(subscription, |sub| async { + let item = sub.async_take().await; + let item_boxed: Box = Box::new(item); + Some((item_boxed, sub)) + }); + let id = events.events.merge(Box::pin(stream)); + + Ok(Box::new(#subscription_name { id })) } } @@ -280,30 +283,32 @@ impl Message { } #[allow(non_camel_case_types)] - pub struct #subscription_name(ros2_client::Subscription); - - #[allow(non_snake_case)] - fn #event_stream(subscription: Box<#subscription_name>) -> Box { - Box::new(ExternalEvents { events: Box::new(ExternalRos2Events(subscription)) }) - } - - #[allow(non_snake_case)] - fn #is(event: &Box) -> bool { - event.event.0.is::() - } - #[allow(non_snake_case)] - fn #downcast(event: Box) -> eyre::Result> { - event.event.0.downcast().map_err(|_| eyre::eyre!("downcast failed")) + pub struct #subscription_name { + id: u32, } - impl AsEventStream for #subscription_name { - fn as_event_stream(self: Box) -> Box> + Unpin> { - let stream = futures_lite::stream::unfold(self.0, |sub| async { - let item = sub.async_take().await; - let item_boxed: Box = Box::new(item); - Some((item_boxed, sub)) - }); - Box::new(Box::pin(stream)) + impl #subscription_name { + #[allow(non_snake_case)] + fn #matches(&self, event: &crate::ffi::CombinedEvent) -> bool { + match &event.event.as_ref().0 { + Some(crate::MergedEvent::External(event)) if event.id == self.id => true, + _ => false + } + } + #[allow(non_snake_case)] + fn #downcast(&self, event: crate::ffi::CombinedEvent) -> eyre::Result { + use eyre::WrapErr; + + match (*event.event).0 { + Some(crate::MergedEvent::External(event)) if event.id == self.id => { + let result = event.event.downcast::>() + .map_err(|_| eyre::eyre!("downcast to {} failed", #struct_raw_name_str))?; + + let (data, _info) = result.with_context(|| format!("failed to receive {} event", #subscription_name_str))?; + Ok(data) + }, + _ => eyre::bail!("not a {} event", #subscription_name_str), + } } } }; diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 0538cafa..0a2102d4 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "1.70" +channel = "1.76" components = ["rustfmt", "clippy"]