use std::{ borrow::Cow, collections::HashMap, path::{Path, PathBuf}, sync::Arc, }; use ::dora_ros2_bridge::{ros2_client, rustdds}; use arrow::{ array::{make_array, ArrayData}, pyarrow::{FromPyArrow, ToPyArrow}, }; use dora_ros2_bridge_msg_gen::types::Message; use eyre::{eyre, Context, ContextCompat}; use futures::{Stream, StreamExt}; use pyo3::{ prelude::{pyclass, pymethods}, types::{PyDict, PyList, PyModule}, PyAny, PyObject, PyResult, Python, }; use typed::{deserialize::StructDeserializer, TypeInfo, TypedValue}; pub mod qos; pub mod typed; #[pyclass] pub struct Ros2Context { context: ros2_client::Context, messages: Arc>>, } #[pymethods] impl Ros2Context { #[new] pub fn new(ros_paths: Option>) -> eyre::Result { let ament_prefix_path = std::env::var("AMENT_PREFIX_PATH"); let empty = String::new(); let paths: Vec<_> = match &ros_paths { Some(paths) => paths.iter().map(|p| p.as_path()).collect(), None => { let ament_prefix_path_parsed = match &ament_prefix_path { Ok(path) => path, Err(std::env::VarError::NotPresent) => &empty, Err(std::env::VarError::NotUnicode(s)) => { eyre::bail!( "AMENT_PREFIX_PATH is not valid unicode: `{}`", s.to_string_lossy() ); } }; ament_prefix_path_parsed.split(':').map(Path::new).collect() } }; let packages = dora_ros2_bridge_msg_gen::get_packages(&paths) .map_err(|err| eyre!(err)) .context("failed to parse ROS2 message types")?; let mut messages = HashMap::new(); for message in packages.into_iter().flat_map(|p| p.messages.into_iter()) { let entry: &mut HashMap = messages.entry(message.package.clone()).or_default(); entry.insert(message.name.clone(), message); } Ok(Self { context: ros2_client::Context::new()?, messages: Arc::new(messages), }) } /// Create a new ROS2 node pub fn new_node( &self, name: &str, namespace: &str, options: Ros2NodeOptions, ) -> eyre::Result { let name = ros2_client::NodeName::new(namespace, name) .map_err(|err| eyre!("invalid node name: {err}"))?; Ok(Ros2Node { node: self.context.new_node(name, options.into())?, messages: self.messages.clone(), }) } } #[pyclass] pub struct Ros2Node { node: ros2_client::Node, messages: Arc>>, } #[pymethods] impl Ros2Node { pub fn create_topic( &self, name: &str, message_type: String, qos: qos::Ros2QosPolicies, ) -> eyre::Result { let (namespace_name, message_name) = match (message_type.split_once("/"), message_type.split_once("::")) { (Some(msg), None) => msg, (None, Some(msg)) => msg, _ => eyre::bail!("Expected message type in the format `namespace/message` or `namespace::message`, such as `std_msgs/UInt8` but got: {}", message_type), }; let message_type_name = ros2_client::MessageTypeName::new(namespace_name, message_name); let topic_name = ros2_client::Name::parse(name) .map_err(|err| eyre!("failed to parse ROS2 topic name: {err}"))?; let topic = self .node .create_topic(&topic_name, message_type_name, &qos.into())?; let type_info = TypeInfo { package_name: namespace_name.to_owned().into(), message_name: message_name.to_owned().into(), messages: self.messages.clone(), }; Ok(Ros2Topic { topic, type_info }) } pub fn create_publisher( &mut self, topic: &Ros2Topic, qos: Option, ) -> eyre::Result { let publisher = self .node .create_publisher(&topic.topic, qos.map(Into::into))?; Ok(Ros2Publisher { publisher, type_info: topic.type_info.clone(), }) } pub fn create_subscription( &mut self, topic: &Ros2Topic, qos: Option, ) -> eyre::Result { let subscription = self .node .create_subscription(&topic.topic, qos.map(Into::into))?; Ok(Ros2Subscription { subscription: Some(subscription), deserializer: StructDeserializer::new(Cow::Owned(topic.type_info.clone())), }) } } #[derive(Debug, Clone, Default)] #[pyclass] #[non_exhaustive] pub struct Ros2NodeOptions { pub rosout: bool, } #[pymethods] impl Ros2NodeOptions { #[new] pub fn new(rosout: Option) -> Self { Self { rosout: rosout.unwrap_or(false), } } } impl From for ros2_client::NodeOptions { fn from(value: Ros2NodeOptions) -> Self { ros2_client::NodeOptions::new().enable_rosout(value.rosout) } } #[pyclass] #[non_exhaustive] pub struct Ros2Topic { topic: rustdds::Topic, type_info: TypeInfo<'static>, } #[pyclass] #[non_exhaustive] pub struct Ros2Publisher { publisher: ros2_client::Publisher>, type_info: TypeInfo<'static>, } #[pymethods] impl Ros2Publisher { pub fn publish(&self, data: &PyAny) -> eyre::Result<()> { let pyarrow = PyModule::import(data.py(), "pyarrow")?; let data = if data.is_instance_of::() { // convert to arrow struct scalar pyarrow.getattr("scalar")?.call1((data,))? } else { data }; let data = if data.is_instance(pyarrow.getattr("StructScalar")?)? { // convert to arrow array let list = PyList::new(data.py(), [data]); pyarrow.getattr("array")?.call1((list,))? } else { data }; let value = arrow::array::ArrayData::from_pyarrow(data)?; //// add type info to ensure correct serialization (e.g. struct types //// and map types need to be serialized differently) let typed_value = TypedValue { value: &make_array(value), type_info: &self.type_info, }; self.publisher .publish(typed_value) .map_err(|e| e.forget_data()) .context("publish failed")?; Ok(()) } } #[pyclass] #[non_exhaustive] pub struct Ros2Subscription { deserializer: StructDeserializer<'static>, subscription: Option>, } #[pymethods] impl Ros2Subscription { pub fn next(&self, py: Python) -> eyre::Result> { let message = self .subscription .as_ref() .context("subscription was already used")? .take_seed(self.deserializer.clone()) .context("failed to take next message from subscription")?; let Some((value, _info)) = message else { return Ok(None); }; let message = value.to_pyarrow(py)?; // TODO: add `info` Ok(Some(message)) } } impl Ros2Subscription { pub fn into_stream(&mut self) -> eyre::Result { let subscription = self .subscription .take() .context("subscription was already used")?; Ok(Ros2SubscriptionStream { deserializer: self.deserializer.clone(), subscription, }) } } pub struct Ros2SubscriptionStream { deserializer: StructDeserializer<'static>, subscription: ros2_client::Subscription, } impl Ros2SubscriptionStream { pub fn as_stream( &self, ) -> impl Stream> + '_ { self.subscription .async_stream_seed(self.deserializer.clone()) } } impl Stream for Ros2SubscriptionStream { type Item = Result<(ArrayData, ros2_client::MessageInfo), rustdds::dds::ReadError>; fn poll_next( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let s = self.as_stream(); futures::pin_mut!(s); s.poll_next_unpin(cx) } } pub fn create_dora_ros2_bridge_module(m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; Ok(()) }