diff --git a/libraries/extensions/ros2-bridge/msg-gen/src/lib.rs b/libraries/extensions/ros2-bridge/msg-gen/src/lib.rs index c64f809b..e3ce3677 100644 --- a/libraries/extensions/ros2-bridge/msg-gen/src/lib.rs +++ b/libraries/extensions/ros2-bridge/msg-gen/src/lib.rs @@ -29,6 +29,12 @@ where let mut service_impls = Vec::new(); let mut service_creation_defs = Vec::new(); let mut service_creation_impls = Vec::new(); + + let mut action_defs = Vec::new(); + let mut action_impls = Vec::new(); + let mut action_creation_defs = Vec::new(); + let mut action_creation_impls = Vec::new(); + let mut aliases = Vec::new(); for package in &packages { for message in &package.messages { @@ -54,6 +60,20 @@ where } } + for action in &package.actions { + let (def, imp) = action.struct_token_stream(&package.name, create_cxx_bridge); + action_defs.push(def); + action_impls.push(imp); + if create_cxx_bridge { + let (action_creation_def, action_creation_impl) = + action.cxx_action_creation_functions(&package.name); + let action_creation_def = quote! { #action_creation_def }; + let action_creation_impl = quote! { #action_creation_impl }; + action_creation_defs.push(action_creation_def); + action_creation_impls.push(action_creation_impl); + } + } + aliases.push(package.aliases_token_stream()); } @@ -73,9 +93,11 @@ where fn init_ros2_context() -> Result>; fn new_node(self: &Ros2Context, name_space: &str, base_name: &str) -> Result>; fn qos_default() -> Ros2QosPolicies; + fn actionqos_default() -> Ros2ActionClientQosPolicies; #(#message_topic_defs)* #(#service_creation_defs)* + #(#action_creation_defs)* } #[derive(Debug, Clone)] @@ -89,6 +111,15 @@ where pub keep_last: i32, } + #[derive(Debug, Clone)] + pub struct Ros2ActionClientQosPolicies { + pub goal_service: Ros2QosPolicies, + pub result_service: Ros2QosPolicies, + pub cancel_service: Ros2QosPolicies, + pub feedback_subscription: Ros2QosPolicies, + pub status_subscription: Ros2QosPolicies, + } + /// DDS 2.2.3.4 DURABILITY #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Ros2Durability { @@ -150,6 +181,16 @@ where ffi::Ros2QosPolicies::new(None, None, None, None, None, None, None) } + fn actionqos_default() -> ffi::Ros2ActionClientQosPolicies { + ffi::Ros2ActionClientQosPolicies::new( + Some(qos_default()), + Some(qos_default()), + Some(qos_default()), + Some(qos_default()), + Some(qos_default()) + ) + } + impl ffi::Ros2QosPolicies { pub fn new( durability: Option, @@ -229,6 +270,36 @@ where } } } + + impl ffi::Ros2ActionClientQosPolicies { + pub fn new( + goal_service: Option, + result_service: Option, + cancel_service: Option, + feedback_subscription: Option, + status_subscription: Option, + ) -> Self { + Self { + goal_service: goal_service.unwrap_or_else(|| ffi::Ros2QosPolicies::new(None, None, None, None, None, None, None)), + result_service: result_service.unwrap_or_else(|| ffi::Ros2QosPolicies::new(None, None, None, None, None, None, None)), + cancel_service: cancel_service.unwrap_or_else(|| ffi::Ros2QosPolicies::new(None, None, None, None, None, None, None)), + feedback_subscription: feedback_subscription.unwrap_or_else(|| ffi::Ros2QosPolicies::new(None, None, None, None, None, None, None)), + status_subscription: status_subscription.unwrap_or_else(|| ffi::Ros2QosPolicies::new(None, None, None, None, None, None, None)), + } + } + } + + impl From for ros2_client::action::ActionClientQosPolicies { + fn from(value: ffi::Ros2ActionClientQosPolicies) -> Self { + ros2_client::action::ActionClientQosPolicies { + goal_service: value.goal_service.into(), + result_service: value.result_service.into(), + cancel_service: value.cancel_service.into(), + feedback_subscription: value.feedback_subscription.into(), + status_subscription: value.status_subscription.into(), + } + } + } }, ) } else { @@ -253,6 +324,7 @@ where #(#shared_type_defs)* #(#service_defs)* + #(#action_defs)* } @@ -273,9 +345,11 @@ where #cxx_bridge_impls #(#message_topic_impls)* #(#service_creation_impls)* + #(#action_creation_impls)* #(#service_impls)* + #(#action_impls)* #(#aliases)* } diff --git a/libraries/extensions/ros2-bridge/msg-gen/src/types/action.rs b/libraries/extensions/ros2-bridge/msg-gen/src/types/action.rs index 9338abf8..b4b272bf 100644 --- a/libraries/extensions/ros2-bridge/msg-gen/src/types/action.rs +++ b/libraries/extensions/ros2-bridge/msg-gen/src/types/action.rs @@ -1,5 +1,6 @@ use heck::SnakeCase; use quote::{format_ident, quote, ToTokens}; +use syn::Ident; use super::{primitives::*, Member, Message, Service}; @@ -19,6 +20,78 @@ pub struct Action { } impl Action { + pub fn struct_token_stream( + &self, + package_name: &str, + gen_cxx_bridge: bool, + ) -> (impl ToTokens, impl ToTokens) { + let (goal_def, goal_impl) = self.goal.struct_token_stream(package_name, gen_cxx_bridge); + let (result_def, result_impl) = self + .result + .struct_token_stream(package_name, gen_cxx_bridge); + let (feedback_def, feedback_impl) = self + .feedback + .struct_token_stream(package_name, gen_cxx_bridge); + + let def = quote! { + #goal_def + #result_def + #feedback_def + }; + + let impls = quote! { + #goal_impl + #result_impl + #feedback_impl + }; + + (def, impls) + } + + pub fn alias_token_stream(&self, package_name: &Ident) -> impl ToTokens { + let action_type = format_ident!("{}", self.name); + let goal_type_raw = format_ident!("{package_name}__{}_Goal", self.name); + let result_type_raw = format_ident!("{package_name}__{}_Result", self.name); + let feedback_type_raw = format_ident!("{package_name}__{}_Feedback", self.name); + + let goal_type = format_ident!("{}Goal", self.name); + let result_type = format_ident!("{}Result", self.name); + let feedback_type = format_ident!("{}Feedback", self.name); + + let goal_type_name = goal_type.to_string(); + let result_type_name = result_type.to_string(); + let feedback_type_name = feedback_type.to_string(); + + quote! { + #[allow(non_camel_case_types)] + #[derive(std::fmt::Debug)] + pub struct #action_type; + + impl crate::ros2_client::ActionTypes for #action_type { + type GoalType = #goal_type; + type ResultType = #result_type; + type FeedbackType = #feedback_type; + + fn goal_type_name(&self) -> &str { + #goal_type_name + } + + fn result_type_name(&self) -> &str { + #result_type_name + } + + fn feedback_type_name(&self) -> &str { + #feedback_type_name + } + } + + pub use super::super::ffi::#goal_type_raw as #goal_type; + pub use super::super::ffi::#result_type_raw as #result_type; + pub use super::super::ffi::#feedback_type_raw as #feedback_type; + + } + } + fn send_goal_srv(&self) -> Service { let common = format!("{}_SendGoal", self.name); @@ -132,6 +205,195 @@ impl Action { } } + pub fn cxx_action_creation_functions( + &self, + package_name: &str, + ) -> (impl ToTokens, impl ToTokens) { + let client_name = format_ident!("Actionclient__{package_name}__{}", self.name); + let cxx_client_name = format_ident!("Actionclient_{}", self.name); + let create_client = format_ident!("new_ActionClient__{package_name}__{}", self.name); + let cxx_create_client = format!("create_action_client_{package_name}_{}", self.name); + + let package = format_ident!("{package_name}"); + let self_name = format_ident!("{}", self.name); + let self_name_str = &self.name; + + let send_goal = format_ident!("send_goal__{package_name}__{}", self.name); + let cxx_send_goal = format!("send_goal"); + + let matches = format_ident!("matches__{package_name}__{}", self.name); + let cxx_matches = format_ident!("matches"); + let downcast = format_ident!("action_downcast__{package_name}__{}", self.name); + let cxx_downcast = format_ident!("downcast"); + + let goal_type_raw = format_ident!("{package_name}__{}_Goal", self.name); + let result_type_raw = format_ident!("{package_name}__{}_Result", self.name); + + let result_type_raw_str = result_type_raw.to_string(); + + let def = quote! { + #[namespace = #package_name] + #[cxx_name = #cxx_client_name] + type #client_name; + + #[cxx_name = #cxx_create_client] + fn #create_client(self: &mut Ros2Node, name_space: &str, base_name: &str, qos:Ros2ActionClientQosPolicies, events: &mut CombinedEvents) -> Result>; + + #[namespace = #package_name] + #[cxx_name = #cxx_send_goal] + fn #send_goal(self: &mut #client_name, request: #goal_type_raw) -> Result<()>; + + #[namespace = #package_name] + #[cxx_name = #cxx_matches] + fn #matches(self: &mut #client_name, event: &CombinedEvent) -> bool; + + #[namespace = #package_name] + #[cxx_name = #cxx_downcast] + fn #downcast(self: &mut #client_name, event: CombinedEvent) -> Result<#result_type_raw>; + }; + + let imp = quote! { + impl Ros2Node { + #[allow(non_snake_case)] + pub fn #create_client(&mut self, name_space: &str, base_name: &str, qos: ffi::Ros2ActionClientQosPolicies, events: &mut crate::ffi::CombinedEvents) -> eyre::Result> { + use futures::StreamExt as _; + + let client = self.node.create_action_client::< #package :: action :: #self_name >( + ros2_client::ServiceMapping::Enhanced, + &ros2_client::Name::new(name_space, base_name).unwrap(), + &ros2_client::ActionTypeName::new(#package_name, #self_name_str), + qos.into(), + ).map_err(|e| eyre::eyre!("{e:?}"))?; + let (response_tx, response_rx) = flume::bounded(1); + let stream = response_rx.into_stream().map(|v: eyre::Result<_>| Box::new(v) as Box); + let id = events.events.merge(Box::pin(stream)); + + Ok(Box::new(#client_name { + client: std::sync::Arc::new(client), + response_tx: std::sync::Arc::new(response_tx), + executor: self.executor.clone(), + stream_id: id, + })) + } + } + + #[allow(non_camel_case_types)] + pub struct #client_name { + client: std::sync::Arc>, + response_tx: std::sync::Arc>>, + executor: std::sync::Arc, + stream_id: u32, + } + + impl #client_name { + + #[allow(non_snake_case)] + fn #send_goal(&mut self, request: ffi::#goal_type_raw) -> eyre::Result<()> { + use eyre::WrapErr; + use futures::task::SpawnExt as _; + use futures::stream::StreamExt; + use futures::executor::block_on; + use std::sync::Arc; + + let client_arc = Arc::new(self.client.clone()); + + let client_ref = Arc::clone(&client_arc); + let send_goal = async move { + client_ref.async_send_goal(request.clone()).await + }; + + let handle = self.executor.spawn_with_handle(send_goal) + .map_err(|e| eyre::eyre!("{e:?}"))?; + + let (goal_id, send_goal_response) = block_on(handle) + .map_err(|e| eyre::eyre!("{e:?}"))?; + + if !send_goal_response.accepted { + return Err(eyre::eyre!("Goal was rejected by the server.")); + } + + let feedback_handle = { + let client_ref = Arc::clone(&client_arc); + async move { + let feedback_stream = client_ref.feedback_stream(goal_id); + feedback_stream.for_each(|feedback| async { + match feedback { + Ok(feedback) => println!("Received feedback: {:?}", feedback), + Err(e) => eprintln!("Error while receive feedback: {:?}", e), + } + }).await; + } + }; + + self.executor.spawn(feedback_handle).context("failed to spawn feedback task")?; + + + let status_handle = { + let client_ref = Arc::clone(&client_arc); + async move { + let status_stream = client_ref.status_stream(goal_id); + status_stream.for_each(|status| async { + match status { + Ok(status) => println!("Status update: {:?}", status), + Err(e) => eprintln!("Error receiving status update: {:?}", e), + } + }).await; + } + }; + + self.executor.spawn(status_handle).context("failed to spawn status task")?; + + let request_result_handle = { + let client_ref = Arc::clone(&client_arc); + let response_tx = self.response_tx.clone(); + async move { + match client_ref.async_request_result(goal_id).await { + Ok((_status, result)) => { + let response = Ok(result); + if response_tx.send_async(response).await.is_err() { + tracing::warn!("failed to send action result"); + } + }, + Err(e) => { + tracing::error!("Failed to receive response for request {goal_id:?}: {:?}", e); + } + } + } + }; + + self.executor.spawn(request_result_handle).context("failed to spawn response task").map_err(|e| eyre::eyre!("{e:?}"))?; + Ok(()) + } + + #[allow(non_snake_case)] + fn #matches(&self, event: &crate::ffi::CombinedEvent) -> bool { + match &event.event.as_ref().0 { + Some(crate::MergedEvent::External(event)) if event.id == self.stream_id => true, + _ => false + } + } + + #[allow(non_snake_case)] + fn #downcast(&self, event: crate::ffi::CombinedEvent) -> eyre::Result { + use eyre::WrapErr; + + match (*event.event).0 { + Some(crate::MergedEvent::External(event)) if event.id == self.stream_id => { + let result = event.event.downcast::>() + .map_err(|_| eyre::eyre!("downcast to {} failed", #result_type_raw_str))?; + + let data = result.with_context(|| format!("failed to receive {} response", #self_name_str)) + .map_err(|e| eyre::eyre!("{e:?}"))?; + Ok(data) + }, + _ => eyre::bail!("not a {} response event", #self_name_str), + } + } + } + }; + (def, imp) + } + pub fn token_stream_with_mod(&self) -> impl ToTokens { let mod_name = format_ident!("_{}", self.name.to_snake_case()); let inner = self.token_stream(); diff --git a/libraries/extensions/ros2-bridge/msg-gen/src/types/package.rs b/libraries/extensions/ros2-bridge/msg-gen/src/types/package.rs index 47f3fa19..2e3fe5b1 100644 --- a/libraries/extensions/ros2-bridge/msg-gen/src/types/package.rs +++ b/libraries/extensions/ros2-bridge/msg-gen/src/types/package.rs @@ -98,6 +98,25 @@ impl Package { } } + fn action_aliases(&self, package_name: &Ident) -> impl ToTokens { + if self.actions.is_empty() { + quote! { + //empty msg + } + } else { + let items = self + .actions + .iter() + .map(|v| v.alias_token_stream(package_name)); + + quote! { + pub mod action { + #(#items)* + } // action + } + } + } + fn actions_block(&self) -> impl ToTokens { if self.actions.is_empty() { quote! { @@ -117,11 +136,13 @@ impl Package { let package_name = Ident::new(&self.name, Span::call_site()); let aliases = self.message_aliases(&package_name); let service_aliases = self.service_aliases(&package_name); + let action_aliases = self.action_aliases(&package_name); quote! { pub mod #package_name { #aliases #service_aliases + #action_aliases } } }