You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

lib.rs 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
//! Dora node that logs dataflow inputs (images, depth, text, boxes, masks, joint states, series and point clouds) to a Rerun recording.
  2. use std::{collections::HashMap, env::VarError, path::Path};
  3. use dora_node_api::{
  4. arrow::{
  5. array::{Array, AsArray, Float64Array, StringArray, UInt16Array, UInt8Array},
  6. datatypes::Float32Type,
  7. },
  8. dora_core::config::DataId,
  9. into_vec, DoraNode, Event, Parameter,
  10. };
  11. use eyre::{eyre, Context, Result};
  12. use rerun::{
  13. components::ImageBuffer,
  14. external::{log::warn, re_types::ArrowBuffer},
  15. ImageFormat, Points3D, SpawnOptions,
  16. };
  17. pub mod boxes2d;
  18. pub mod series;
  19. pub mod urdf;
  20. use series::update_series;
  21. use urdf::{init_urdf, update_visualization};
  22. pub fn lib_main() -> Result<()> {
  23. // rerun `serve()` requires to have a running Tokio runtime in the current context.
  24. let rt = tokio::runtime::Builder::new_current_thread()
  25. .build()
  26. .expect("Failed to create tokio runtime");
  27. let _guard = rt.enter();
  28. let (node, mut events) = DoraNode::init_from_env().context("Could not initialize dora node")?;
  29. // Setup an image cache to paint depth images.
  30. let mut image_cache = HashMap::new();
  31. let mut mask_cache: HashMap<DataId, Vec<bool>> = HashMap::new();
  32. let mut color_cache: HashMap<DataId, rerun::Color> = HashMap::new();
  33. let mut options = SpawnOptions::default();
  34. let memory_limit = match std::env::var("RERUN_MEMORY_LIMIT") {
  35. Ok(memory_limit) => memory_limit
  36. .parse::<String>()
  37. .context("Could not parse RERUN_MEMORY_LIMIT value")?,
  38. Err(VarError::NotUnicode(_)) => {
  39. return Err(eyre!("RERUN_MEMORY_LIMIT env variable is not unicode"));
  40. }
  41. Err(VarError::NotPresent) => "25%".to_string(),
  42. };
  43. options.memory_limit = memory_limit;
  44. let rec = match std::env::var("OPERATING_MODE").as_deref() {
  45. Ok("SPAWN") => rerun::RecordingStreamBuilder::new("dora-rerun")
  46. .spawn_opts(&options, None)
  47. .context("Could not spawn rerun visualization")?,
  48. Ok("CONNECT") => {
  49. let opt = std::env::var("RERUN_SERVER_ADDR").unwrap_or("127.0.0.1:9876".to_string());
  50. rerun::RecordingStreamBuilder::new("dora-rerun")
  51. .connect_tcp_opts(std::net::SocketAddr::V4(opt.parse()?), None)
  52. .context("Could not connect to rerun visualization")?
  53. }
  54. Ok("SAVE") => {
  55. let id = node.dataflow_id();
  56. let path = Path::new("out")
  57. .join(id.to_string())
  58. .join(format!("archive-{}.rerun", id));
  59. rerun::RecordingStreamBuilder::new("dora-rerun")
  60. .save(path)
  61. .context("Could not save rerun visualization")?
  62. }
  63. Ok(_) => {
  64. warn!("Invalid operating mode, defaulting to SPAWN mode.");
  65. rerun::RecordingStreamBuilder::new("dora-rerun")
  66. .spawn_opts(&options, None)
  67. .context("Could not spawn rerun visualization")?
  68. }
  69. Err(_) => rerun::RecordingStreamBuilder::new("dora-rerun")
  70. .spawn_opts(&options, None)
  71. .context("Could not spawn rerun visualization")?,
  72. };
  73. let chains = init_urdf(&rec).context("Could not load urdf")?;
  74. match std::env::var("README") {
  75. Ok(readme) => {
  76. readme
  77. .parse::<String>()
  78. .context("Could not parse readme value")?;
  79. rec.log("README", &rerun::TextDocument::new(readme))
  80. .wrap_err("Could not log text")?;
  81. }
  82. Err(VarError::NotUnicode(_)) => {
  83. return Err(eyre!("readme env variable is not unicode"));
  84. }
  85. Err(VarError::NotPresent) => (),
  86. };
  87. let camera_pitch = std::env::var("CAMERA_PITCH")
  88. .unwrap_or("0.0".to_string())
  89. .parse::<f32>()
  90. .unwrap();
  91. while let Some(event) = events.recv() {
  92. if let Event::Input { id, data, metadata } = event {
  93. if id.as_str().contains("image") {
  94. let height =
  95. if let Some(Parameter::Integer(height)) = metadata.parameters.get("height") {
  96. height
  97. } else {
  98. &480
  99. };
  100. let width =
  101. if let Some(Parameter::Integer(width)) = metadata.parameters.get("width") {
  102. width
  103. } else {
  104. &640
  105. };
  106. let encoding = if let Some(Parameter::String(encoding)) =
  107. metadata.parameters.get("encoding")
  108. {
  109. encoding
  110. } else {
  111. "bgr8"
  112. };
  113. if encoding == "bgr8" {
  114. let buffer: &UInt8Array = data.as_any().downcast_ref().unwrap();
  115. let buffer: &[u8] = buffer.values();
  116. // Transpose values from BGR to RGB
  117. let buffer: Vec<u8> =
  118. buffer.chunks(3).flat_map(|x| [x[2], x[1], x[0]]).collect();
  119. image_cache.insert(id.clone(), buffer.clone());
  120. let buffer = ArrowBuffer::from(buffer);
  121. let image_buffer = ImageBuffer::try_from(buffer)
  122. .context("Could not convert buffer to image buffer")?;
  123. // let tensordata = ImageBuffer(buffer);
  124. let image = rerun::Image::new(
  125. image_buffer,
  126. ImageFormat::rgb8([*width as u32, *height as u32]),
  127. );
  128. rec.log(id.as_str(), &image)
  129. .context("could not log image")?;
  130. } else if encoding == "rgb8" {
  131. let buffer: &UInt8Array = data.as_any().downcast_ref().unwrap();
  132. image_cache.insert(id.clone(), buffer.values().to_vec());
  133. let buffer: &[u8] = buffer.values();
  134. let buffer = ArrowBuffer::from(buffer);
  135. let image_buffer = ImageBuffer::try_from(buffer)
  136. .context("Could not convert buffer to image buffer")?;
  137. let image = rerun::Image::new(
  138. image_buffer,
  139. ImageFormat::rgb8([*width as u32, *height as u32]),
  140. );
  141. rec.log(id.as_str(), &image)
  142. .context("could not log image")?;
  143. } else if ["jpeg", "png", "avif"].contains(&encoding) {
  144. let buffer: &UInt8Array = data.as_any().downcast_ref().unwrap();
  145. let buffer: &[u8] = buffer.values();
  146. let image = rerun::EncodedImage::from_file_contents(buffer.to_vec());
  147. rec.log(id.as_str(), &image)
  148. .context("could not log image")?;
  149. };
  150. } else if id.as_str().contains("depth") {
  151. let width =
  152. if let Some(Parameter::Integer(width)) = metadata.parameters.get("width") {
  153. width
  154. } else {
  155. &640
  156. };
  157. let focal_length =
  158. if let Some(Parameter::ListInt(focals)) = metadata.parameters.get("focal") {
  159. focals.to_vec()
  160. } else {
  161. vec![605, 605]
  162. };
  163. let resolution = if let Some(Parameter::ListInt(resolution)) =
  164. metadata.parameters.get("resolution")
  165. {
  166. resolution.to_vec()
  167. } else {
  168. vec![640, 480]
  169. };
  170. let pitch = if let Some(Parameter::Float(pitch)) = metadata.parameters.get("pitch")
  171. {
  172. *pitch as f32
  173. } else {
  174. camera_pitch
  175. };
  176. let cos_theta = pitch.cos();
  177. let sin_theta = pitch.sin();
  178. let points = match data.data_type() {
  179. dora_node_api::arrow::datatypes::DataType::Float64 => {
  180. let buffer: &Float64Array = data.as_any().downcast_ref().unwrap();
  181. let mut points = vec![];
  182. buffer.iter().enumerate().for_each(|(i, z)| {
  183. let u = i as f32 % *width as f32; // Calculate x-coordinate (u)
  184. let v = i as f32 / *width as f32; // Calculate y-coordinate (v)
  185. if let Some(z) = z {
  186. let z = z as f32;
  187. // Skip points that have empty depth or is too far away
  188. if z == 0. || z > 8.0 {
  189. points.push((0., 0., 0.));
  190. return;
  191. }
  192. let y = (u - resolution[0] as f32) * z / focal_length[0] as f32;
  193. let x = (v - resolution[1] as f32) * z / focal_length[1] as f32;
  194. let new_x = sin_theta * z + cos_theta * x;
  195. let new_y = -y;
  196. let new_z = cos_theta * z - sin_theta * x;
  197. points.push((new_x, new_y, new_z));
  198. } else {
  199. points.push((0., 0., 0.));
  200. }
  201. });
  202. Points3D::new(points)
  203. }
  204. dora_node_api::arrow::datatypes::DataType::UInt16 => {
  205. let buffer: &UInt16Array = data.as_any().downcast_ref().unwrap();
  206. let mut points = vec![];
  207. buffer.iter().enumerate().for_each(|(i, z)| {
  208. let u = i as f32 % *width as f32; // Calculate x-coordinate (u)
  209. let v = i as f32 / *width as f32; // Calculate y-coordinate (v)
  210. if let Some(z) = z {
  211. let z = z as f32 / 1000.0; // Convert to meters
  212. // Skip points that have empty depth or is too far away
  213. if z == 0. || z > 8.0 {
  214. points.push((0., 0., 0.));
  215. return;
  216. }
  217. let y = (u - resolution[0] as f32) * z / focal_length[0] as f32;
  218. let x = (v - resolution[1] as f32) * z / focal_length[1] as f32;
  219. let new_x = sin_theta * z + cos_theta * x;
  220. let new_y = -y;
  221. let new_z = cos_theta * z - sin_theta * x;
  222. points.push((new_x, new_y, new_z));
  223. } else {
  224. points.push((0., 0., 0.));
  225. }
  226. });
  227. Points3D::new(points)
  228. }
  229. _ => {
  230. return Err(eyre!("Unsupported depth data type {}", data.data_type()));
  231. }
  232. };
  233. if let Some(color_buffer) = image_cache.get(&id.replace("depth", "image")) {
  234. let colors = if let Some(mask) = mask_cache.get(&id.replace("depth", "masks")) {
  235. let mask_length = color_buffer.len() / 3;
  236. let number_masks = mask.len() / mask_length;
  237. color_buffer
  238. .chunks(3)
  239. .enumerate()
  240. .map(|(e, x)| {
  241. for i in 0..number_masks {
  242. if mask[i * mask_length + e] && (e % 3 == 0) {
  243. if i == 0 {
  244. return rerun::Color::from_rgb(255, x[1], x[2]);
  245. } else if i == 1 {
  246. return rerun::Color::from_rgb(x[0], 255, x[2]);
  247. } else if i == 2 {
  248. return rerun::Color::from_rgb(x[0], x[1], 255);
  249. } else {
  250. return rerun::Color::from_rgb(x[0], 255, x[2]);
  251. }
  252. }
  253. }
  254. rerun::Color::from_rgb(x[0], x[1], x[2])
  255. })
  256. .collect::<Vec<_>>()
  257. } else {
  258. color_buffer
  259. .chunks(3)
  260. .map(|x| rerun::Color::from_rgb(x[0], x[1], x[2]))
  261. .collect::<Vec<_>>()
  262. };
  263. rec.log(id.as_str(), &points.with_colors(colors))
  264. .context("could not log points")?;
  265. }
  266. } else if id.as_str().contains("text") {
  267. let buffer: StringArray = data.to_data().into();
  268. buffer.iter().try_for_each(|string| -> Result<()> {
  269. if let Some(str) = string {
  270. rec.log(id.as_str(), &rerun::TextLog::new(str))
  271. .wrap_err("Could not log text")
  272. } else {
  273. Ok(())
  274. }
  275. })?;
  276. } else if id.as_str().contains("boxes2d") {
  277. boxes2d::update_boxes2d(&rec, id, data, metadata).context("update boxes 2d")?;
  278. } else if id.as_str().contains("masks") {
  279. let masks = if let Some(data) = data.as_primitive_opt::<Float32Type>() {
  280. let data = data
  281. .iter()
  282. .map(|x| if let Some(x) = x { x > 0. } else { false })
  283. .collect::<Vec<_>>();
  284. data
  285. } else if let Some(data) = data.as_boolean_opt() {
  286. let data = data
  287. .iter()
  288. .map(|x| x.unwrap_or_default())
  289. .collect::<Vec<_>>();
  290. data
  291. } else {
  292. println!("Got unexpected data type: {}", data.data_type());
  293. continue;
  294. };
  295. mask_cache.insert(id.clone(), masks.clone());
  296. } else if id.as_str().contains("jointstate") {
  297. let mut positions: Vec<f32> = into_vec(&data)?;
  298. // Match file name
  299. let mut id = id.as_str().replace("jointstate_", "");
  300. id.push_str(".urdf");
  301. if let Some(chain) = chains.get(&id) {
  302. let dof = chain.dof();
  303. // Truncate or pad positions to match the chain's dof
  304. if dof < positions.len() {
  305. positions.truncate(dof);
  306. } else {
  307. for _ in 0..(dof - positions.len()) {
  308. positions.push(0.);
  309. }
  310. }
  311. update_visualization(&rec, chain, &id, &positions)?;
  312. } else {
  313. println!(
  314. "Could not find chain for {}. Only contains: {:#?}",
  315. id,
  316. chains.keys()
  317. );
  318. }
  319. } else if id.as_str().contains("series") {
  320. update_series(&rec, id, data).context("could not plot series")?;
  321. } else if id.as_str().contains("points3d") {
  322. // Get color or assign random color in cache
  323. let color = color_cache.get(&id);
  324. let color = if let Some(color) = color {
  325. color.clone()
  326. } else {
  327. let color =
  328. rerun::Color::from_rgb(rand::random::<u8>(), 180, rand::random::<u8>());
  329. color_cache.insert(id.clone(), color.clone());
  330. color
  331. };
  332. let dataid = id;
  333. // get a random color
  334. if let Ok(buffer) = into_vec::<f32>(&data) {
  335. let mut points = vec![];
  336. let mut colors = vec![];
  337. buffer.chunks(3).for_each(|chunk| {
  338. points.push((chunk[0], chunk[1], chunk[2]));
  339. colors.push(color);
  340. });
  341. let points = Points3D::new(points).with_radii(vec![0.013; colors.len()]);
  342. rec.log(dataid.as_str(), &points.with_colors(colors))
  343. .context("could not log points")?;
  344. }
  345. } else {
  346. println!("Could not find handler for {}", id);
  347. }
  348. }
  349. }
  350. Ok(())
  351. }
  352. #[cfg(feature = "python")]
  353. use pyo3::{
  354. pyfunction, pymodule,
  355. types::{PyModule, PyModuleMethods},
  356. wrap_pyfunction, Bound, PyResult, Python,
  357. };
  358. #[cfg(feature = "python")]
  359. #[pyfunction]
  360. fn py_main(_py: Python) -> eyre::Result<()> {
  361. lib_main()
  362. }
  363. /// A Python module implemented in Rust.
  364. #[cfg(feature = "python")]
  365. #[pymodule]
  366. fn dora_rerun(_py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
  367. m.add_function(wrap_pyfunction!(py_main, &m)?)?;
  368. m.add("__version__", env!("CARGO_PKG_VERSION"))?;
  369. Ok(())
  370. }