| @@ -8,7 +8,7 @@ on: | |||
| workflow_dispatch: | |||
| env: | |||
| RUST_LOG: trace | |||
| RUST_LOG: INFO | |||
| jobs: | |||
| test: | |||
| @@ -62,16 +62,10 @@ jobs: | |||
| - name: "Check" | |||
| run: cargo check --all | |||
| - name: "Build" | |||
| run: cargo build --all | |||
| - name: "Build (Without Python node as it is build with maturin)" | |||
| run: cargo build --all --exclude dora-node-api-python | |||
| - name: "Test" | |||
| # Remove Windows as there is `pdb` linker issue. | |||
| # See: https://github.com/dora-rs/dora/pull/359#discussion_r1360268497 | |||
| if: runner.os == 'Linux' || runner.os == 'macOS' | |||
| run: cargo test --all | |||
| - name: "Test" | |||
| if: runner.os == 'Windows' | |||
| run: cargo test --all --lib | |||
| run: cargo test --all --exclude dora-ros2-bridge-python | |||
| # Run examples as separate job because otherwise we will exhaust the disk | |||
| # space of the GitHub action runners. | |||
| @@ -138,21 +132,23 @@ jobs: | |||
| # python examples | |||
| - uses: actions/setup-python@v2 | |||
| if: runner.os == 'Linux' || runner.os == 'macOS' | |||
| if: runner.os != 'Windows' | |||
| with: | |||
| python-version: "3.8" | |||
| - uses: actions/setup-python@v2 | |||
| if: runner.os == 'Windows' | |||
| with: | |||
| python-version: "3.10" | |||
| - name: "Python Dataflow example" | |||
| if: runner.os == 'Linux' || runner.os == 'macOS' | |||
| run: cargo run --example python-dataflow | |||
| - name: "Python Operator Dataflow example" | |||
| if: runner.os == 'Linux' || runner.os == 'macOS' | |||
| run: cargo run --example python-operator-dataflow | |||
| # ROS2 bridge examples | |||
| ros2-bridge-examples: | |||
| name: "ROS2 Bridge Examples" | |||
| runs-on: ubuntu-latest | |||
| timeout-minutes: 30 | |||
| timeout-minutes: 45 | |||
| steps: | |||
| - uses: actions/checkout@v3 | |||
| - uses: r7kamura/rust-problem-matchers@v1.1.0 | |||
| @@ -168,6 +164,10 @@ jobs: | |||
| with: | |||
| required-ros-distributions: humble | |||
| - run: 'source /opt/ros/humble/setup.bash && echo AMENT_PREFIX_PATH=${AMENT_PREFIX_PATH} >> "$GITHUB_ENV"' | |||
| - name: "Install pyarrow for testing" | |||
| run: pip install numpy pyarrow | |||
| - name: "Test" | |||
| run: cargo test -p dora-ros2-bridge-python | |||
| - name: "Rust ROS2 Bridge example" | |||
| timeout-minutes: 30 | |||
| env: | |||
| @@ -176,8 +176,13 @@ jobs: | |||
| source /opt/ros/humble/setup.bash && ros2 run turtlesim turtlesim_node & | |||
| cargo run --example rust-ros2-dataflow --features="ros2-examples" | |||
| - uses: actions/setup-python@v2 | |||
| if: runner.os != 'Windows' | |||
| with: | |||
| python-version: "3.8" | |||
| - uses: actions/setup-python@v2 | |||
| if: runner.os == 'Windows' | |||
| with: | |||
| python-version: "3.10" | |||
| - name: "python-ros2-dataflow" | |||
| timeout-minutes: 30 | |||
| env: | |||
| @@ -235,38 +240,36 @@ jobs: | |||
| # fail-fast by using bash shell explictly | |||
| shell: bash | |||
| run: | | |||
| cargo install --path binaries/coordinator --locked | |||
| cargo install --path binaries/daemon --locked | |||
| cargo install --path binaries/cli --locked | |||
| - name: "Test CLI" | |||
| timeout-minutes: 30 | |||
| # fail-fast by using bash shell explictly | |||
| shell: bash | |||
| run: | | |||
| dora-cli up | |||
| dora-cli list | |||
| dora up | |||
| dora list | |||
| # Test Rust template Project | |||
| dora-cli new test_rust_project --internal-create-with-path-dependencies | |||
| dora new test_rust_project --internal-create-with-path-dependencies | |||
| cd test_rust_project | |||
| cargo build --all | |||
| dora-cli start dataflow.yml --name ci-rust-test | |||
| dora start dataflow.yml --name ci-rust-test | |||
| sleep 10 | |||
| dora-cli stop --name ci-rust-test | |||
| dora stop --name ci-rust-test | |||
| cd .. | |||
| # Test Python template Project | |||
| pip3 install maturin | |||
| maturin build -m apis/python/node/Cargo.toml | |||
| pip3 install target/wheels/* | |||
| dora-cli new test_python_project --lang python --internal-create-with-path-dependencies | |||
| dora new test_python_project --lang python --internal-create-with-path-dependencies | |||
| cd test_python_project | |||
| dora-cli start dataflow.yml --name ci-python-test | |||
| dora start dataflow.yml --name ci-python-test | |||
| sleep 10 | |||
| dora-cli stop --name ci-python-test | |||
| dora stop --name ci-python-test | |||
| cd .. | |||
| dora-cli destroy | |||
| dora destroy | |||
| clippy: | |||
| name: "Clippy" | |||
| @@ -15,7 +15,6 @@ jobs: | |||
| strategy: | |||
| matrix: | |||
| platform: [ubuntu-20.04] | |||
| python-version: ["3.7"] | |||
| fail-fast: false | |||
| runs-on: ${{ matrix.platform }} | |||
| @@ -58,10 +57,10 @@ jobs: | |||
| cargo publish -p dora-node-api-c --token ${{ secrets.CARGO_REGISTRY_TOKEN }} | |||
| # Publish binaries crates | |||
| cargo publish -p dora-cli --token ${{ secrets.CARGO_REGISTRY_TOKEN }} | |||
| cargo publish -p dora-coordinator --token ${{ secrets.CARGO_REGISTRY_TOKEN }} | |||
| cargo publish -p dora-runtime --token ${{ secrets.CARGO_REGISTRY_TOKEN }} | |||
| cargo publish -p dora-daemon --token ${{ secrets.CARGO_REGISTRY_TOKEN }} | |||
| cargo publish -p dora-cli --token ${{ secrets.CARGO_REGISTRY_TOKEN }} | |||
| # Publish extension crates | |||
| cargo publish -p dora-record --token ${{ secrets.CARGO_REGISTRY_TOKEN }} | |||
| @@ -72,7 +71,6 @@ jobs: | |||
| strategy: | |||
| matrix: | |||
| platform: [windows-2022] | |||
| python-version: ["3.7"] | |||
| fail-fast: false | |||
| runs-on: ${{ matrix.platform }} | |||
| @@ -83,16 +81,14 @@ jobs: | |||
| - name: "Build binaries" | |||
| timeout-minutes: 60 | |||
| run: "cargo build --release -p dora-coordinator -p dora-cli -p dora-daemon" | |||
| run: "cargo build --release -p dora-cli" | |||
| - name: Create Archive (Windows) | |||
| if: runner.os == 'Windows' | |||
| shell: powershell | |||
| run: | | |||
| New-Item -Path archive -ItemType Directory | |||
| Copy-Item target/release/dora-coordinator.exe -Destination archive | |||
| Copy-Item target/release/dora-daemon.exe -Destination archive | |||
| Copy-Item target/release/dora-cli.exe -Destination archive/dora.exe | |||
| Copy-Item target/release/dora.exe -Destination archive/dora.exe | |||
| Compress-Archive -Path archive\* -DestinationPath archive.zip | |||
| - name: "Upload release asset" | |||
| @@ -111,7 +107,6 @@ jobs: | |||
| strategy: | |||
| matrix: | |||
| platform: [macos-12, ubuntu-20.04] | |||
| python-version: ["3.7"] | |||
| fail-fast: false | |||
| runs-on: ${{ matrix.platform }} | |||
| @@ -122,15 +117,13 @@ jobs: | |||
| - name: "Build binaries" | |||
| timeout-minutes: 60 | |||
| run: "cargo build --release -p dora-coordinator -p dora-cli -p dora-daemon" | |||
| run: "cargo build --release -p dora-cli" | |||
| - name: "Create Archive (Unix)" | |||
| if: runner.os == 'Linux' || runner.os == 'macOS' | |||
| run: | | |||
| mkdir archive | |||
| cp target/release/dora-coordinator archive | |||
| cp target/release/dora-daemon archive | |||
| cp target/release/dora-cli archive/dora | |||
| cp target/release/dora archive/dora | |||
| cd archive | |||
| zip -r ../archive.zip . | |||
| cd .. | |||
| @@ -151,7 +144,6 @@ jobs: | |||
| strategy: | |||
| matrix: | |||
| platform: [ubuntu-20.04] | |||
| python-version: ["3.7"] | |||
| fail-fast: false | |||
| runs-on: ${{ matrix.platform }} | |||
| @@ -166,15 +158,13 @@ jobs: | |||
| with: | |||
| use-cross: true | |||
| command: build | |||
| args: --release --target aarch64-unknown-linux-gnu -p dora-coordinator -p dora-cli -p dora-daemon | |||
| args: --release --target aarch64-unknown-linux-gnu -p dora-cli | |||
| - name: "Archive Linux ARM64" | |||
| if: runner.os == 'Linux' | |||
| run: | | |||
| mkdir archive_aarch64 | |||
| cp target/aarch64-unknown-linux-gnu/release/dora-coordinator archive_aarch64 | |||
| cp target/aarch64-unknown-linux-gnu/release/dora-daemon archive_aarch64 | |||
| cp target/aarch64-unknown-linux-gnu/release/dora-cli archive_aarch64/dora | |||
| cp target/aarch64-unknown-linux-gnu/release/dora archive_aarch64/dora | |||
| cd archive_aarch64 | |||
| zip -r ../archive_aarch64.zip . | |||
| cd .. | |||
| @@ -196,7 +186,6 @@ jobs: | |||
| strategy: | |||
| matrix: | |||
| platform: [macos-12] | |||
| python-version: ["3.7"] | |||
| fail-fast: false | |||
| runs-on: ${{ matrix.platform }} | |||
| @@ -219,9 +208,7 @@ jobs: | |||
| if: runner.os == 'macOS' | |||
| run: | | |||
| mkdir archive_aarch64 | |||
| cp target/aarch64-apple-darwin/release/dora-coordinator archive_aarch64 | |||
| cp target/aarch64-apple-darwin/release/dora-daemon archive_aarch64 | |||
| cp target/aarch64-apple-darwin/release/dora-cli archive_aarch64/dora | |||
| cp target/aarch64-apple-darwin/release/dora archive_aarch64/dora | |||
| cd archive_aarch64 | |||
| zip -r ../archive_aarch64.zip . | |||
| cd .. | |||
| @@ -7,6 +7,7 @@ | |||
| # Remove arrow file from dora-record | |||
| **/*.arrow | |||
| *.pt | |||
| # Removing images. | |||
| *.jpg | |||
| @@ -32,31 +32,31 @@ members = [ | |||
| [workspace.package] | |||
| # Make sure to also bump `apis/node/python/__init__.py` version. | |||
| version = "0.3.0" | |||
| version = "0.3.2" | |||
| description = "`dora` goal is to be a low latency, composable, and distributed data flow." | |||
| documentation = "https://dora.carsmos.ai" | |||
| license = "Apache-2.0" | |||
| [workspace.dependencies] | |||
| dora-node-api = { version = "0.3.0", path = "apis/rust/node", default-features = false } | |||
| dora-node-api-python = { version = "0.3.0", path = "apis/python/node", default-features = false } | |||
| dora-operator-api = { version = "0.3.0", path = "apis/rust/operator", default-features = false } | |||
| dora-operator-api-macros = { version = "0.3.0", path = "apis/rust/operator/macros" } | |||
| dora-operator-api-types = { version = "0.3.0", path = "apis/rust/operator/types" } | |||
| dora-operator-api-python = { version = "0.3.0", path = "apis/python/operator" } | |||
| dora-operator-api-c = { version = "0.3.0", path = "apis/c/operator" } | |||
| dora-node-api-c = { version = "0.3.0", path = "apis/c/node" } | |||
| dora-core = { version = "0.3.0", path = "libraries/core" } | |||
| dora-arrow-convert = { version = "0.3.0", path = "libraries/arrow-convert" } | |||
| dora-tracing = { version = "0.3.0", path = "libraries/extensions/telemetry/tracing" } | |||
| dora-metrics = { version = "0.3.0", path = "libraries/extensions/telemetry/metrics" } | |||
| dora-download = { version = "0.3.0", path = "libraries/extensions/download" } | |||
| shared-memory-server = { version = "0.3.0", path = "libraries/shared-memory-server" } | |||
| communication-layer-request-reply = { version = "0.3.0", path = "libraries/communication-layer/request-reply" } | |||
| dora-message = { version = "0.3.0", path = "libraries/message" } | |||
| dora-runtime = { version = "0.3.0", path = "binaries/runtime" } | |||
| dora-daemon = { version = "0.3.0", path = "binaries/daemon" } | |||
| dora-coordinator = { version = "0.3.0", path = "binaries/coordinator" } | |||
| dora-node-api = { version = "0.3.2", path = "apis/rust/node", default-features = false } | |||
| dora-node-api-python = { version = "0.3.2", path = "apis/python/node", default-features = false } | |||
| dora-operator-api = { version = "0.3.2", path = "apis/rust/operator", default-features = false } | |||
| dora-operator-api-macros = { version = "0.3.2", path = "apis/rust/operator/macros" } | |||
| dora-operator-api-types = { version = "0.3.2", path = "apis/rust/operator/types" } | |||
| dora-operator-api-python = { version = "0.3.2", path = "apis/python/operator" } | |||
| dora-operator-api-c = { version = "0.3.2", path = "apis/c/operator" } | |||
| dora-node-api-c = { version = "0.3.2", path = "apis/c/node" } | |||
| dora-core = { version = "0.3.2", path = "libraries/core" } | |||
| dora-arrow-convert = { version = "0.3.2", path = "libraries/arrow-convert" } | |||
| dora-tracing = { version = "0.3.2", path = "libraries/extensions/telemetry/tracing" } | |||
| dora-metrics = { version = "0.3.2", path = "libraries/extensions/telemetry/metrics" } | |||
| dora-download = { version = "0.3.2", path = "libraries/extensions/download" } | |||
| shared-memory-server = { version = "0.3.2", path = "libraries/shared-memory-server" } | |||
| communication-layer-request-reply = { version = "0.3.2", path = "libraries/communication-layer/request-reply" } | |||
| dora-message = { version = "0.3.2", path = "libraries/message" } | |||
| dora-runtime = { version = "0.3.2", path = "binaries/runtime" } | |||
| dora-daemon = { version = "0.3.2", path = "binaries/daemon" } | |||
| dora-coordinator = { version = "0.3.2", path = "binaries/coordinator" } | |||
| dora-ros2-bridge = { path = "libraries/extensions/ros2-bridge" } | |||
| dora-ros2-bridge-msg-gen = { path = "libraries/extensions/ros2-bridge/msg-gen" } | |||
| dora-ros2-bridge-python = { path = "libraries/extensions/ros2-bridge/python" } | |||
| @@ -81,17 +81,16 @@ ros2-examples = [] | |||
| [dev-dependencies] | |||
| eyre = "0.6.8" | |||
| tokio = "1.24.2" | |||
| dora-daemon = { workspace = true } | |||
| dora-coordinator = { workspace = true } | |||
| dora-core = { workspace = true } | |||
| dora-tracing = { workspace = true } | |||
| dora-download = { workspace = true } | |||
| dunce = "1.0.2" | |||
| serde_yaml = "0.8.23" | |||
| uuid = { version = "1.2.1", features = ["v4", "serde"] } | |||
| tracing = "0.1.36" | |||
| tracing-subscriber = "0.3.15" | |||
| futures = "0.3.25" | |||
| tokio-stream = "0.1.11" | |||
| clap = { version = "4.0.3", features = ["derive"] } | |||
| [[example]] | |||
| name = "c-dataflow" | |||
| @@ -1,5 +1,46 @@ | |||
| # Changelog | |||
| ## v0.3.2 (2024-01-26) | |||
| ## Features | |||
| - Wait until `DestroyResult` is sent before exiting dora-daemon by @phil-opp in https://github.com/dora-rs/dora/pull/413 | |||
| - Reduce dora-rs to a single binary by @haixuanTao in https://github.com/dora-rs/dora/pull/410 | |||
| - Rework python ROS2 (de)serialization using parsed ROS2 messages directly by @phil-opp in https://github.com/dora-rs/dora/pull/415 | |||
| - Fix ros2 array bug by @haixuanTao in https://github.com/dora-rs/dora/pull/412 | |||
| - Test ros2 type info by @haixuanTao in https://github.com/dora-rs/dora/pull/418 | |||
| - Use forward slash as it is default way of defining ros2 topic by @haixuanTao in https://github.com/dora-rs/dora/pull/419 | |||
| ## Minor | |||
| - Bump h2 from 0.3.21 to 0.3.24 by @dependabot in https://github.com/dora-rs/dora/pull/414 | |||
| ## v0.3.1 (2024-01-09) | |||
| ## Features | |||
| - Support legacy python by @haixuanTao in https://github.com/dora-rs/dora/pull/382 | |||
| - Add an error catch in python `on_event` when using hot-reloading by @haixuanTao in https://github.com/dora-rs/dora/pull/372 | |||
| - add cmake example by @XxChang in https://github.com/dora-rs/dora/pull/381 | |||
| - Bump opentelemetry metrics to 0.21 by @haixuanTao in https://github.com/dora-rs/dora/pull/383 | |||
| - Trace send_output as it can be a big source of overhead for large messages by @haixuanTao in https://github.com/dora-rs/dora/pull/384 | |||
| - Adding a timeout method to not block indefinitely next event by @haixuanTao in https://github.com/dora-rs/dora/pull/386 | |||
| - Adding `Vec<u8>` conversion by @haixuanTao in https://github.com/dora-rs/dora/pull/387 | |||
| - Dora cli renaming by @haixuanTao in https://github.com/dora-rs/dora/pull/399 | |||
| - Update `ros2-client` and `rustdds` dependencies to latest fork version by @phil-opp in https://github.com/dora-rs/dora/pull/397 | |||
| ## Fix | |||
| - Fix window path error by @haixuanTao in https://github.com/dora-rs/dora/pull/398 | |||
| - Fix read error in C++ node input by @haixuanTao in https://github.com/dora-rs/dora/pull/406 | |||
| - Bump unsafe-libyaml from 0.2.9 to 0.2.10 by @dependabot in https://github.com/dora-rs/dora/pull/400 | |||
| ## New Contributors | |||
| - @XxChang made their first contribution in https://github.com/dora-rs/dora/pull/381 | |||
| **Full Changelog**: https://github.com/dora-rs/dora/compare/v0.3.0...v0.3.1 | |||
| ## v0.3.0 (2023-11-01) | |||
| ## Features | |||
| @@ -57,9 +57,6 @@ Quickest way: | |||
| ```bash | |||
| cargo install dora-cli | |||
| alias dora='dora-cli' | |||
| cargo install dora-coordinator | |||
| cargo install dora-daemon | |||
| pip install dora-rs ## For Python API | |||
| dora --help | |||
| @@ -72,17 +69,17 @@ For more installation guideline, check out our installation guide here: https:// | |||
| 1. Install the example python dependencies: | |||
| ```bash | |||
| pip install -r https://raw.githubusercontent.com/dora-rs/dora/v0.3.0/examples/python-operator-dataflow/requirements.txt | |||
| pip install -r https://raw.githubusercontent.com/dora-rs/dora/v0.3.2/examples/python-operator-dataflow/requirements.txt | |||
| ``` | |||
| 2. Get some example operators: | |||
| ```bash | |||
| wget https://raw.githubusercontent.com/dora-rs/dora/v0.3.0/examples/python-operator-dataflow/webcam.py | |||
| wget https://raw.githubusercontent.com/dora-rs/dora/v0.3.0/examples/python-operator-dataflow/plot.py | |||
| wget https://raw.githubusercontent.com/dora-rs/dora/v0.3.0/examples/python-operator-dataflow/utils.py | |||
| wget https://raw.githubusercontent.com/dora-rs/dora/v0.3.0/examples/python-operator-dataflow/object_detection.py | |||
| wget https://raw.githubusercontent.com/dora-rs/dora/v0.3.0/examples/python-operator-dataflow/dataflow.yml | |||
| wget https://raw.githubusercontent.com/dora-rs/dora/v0.3.2/examples/python-operator-dataflow/webcam.py | |||
| wget https://raw.githubusercontent.com/dora-rs/dora/v0.3.2/examples/python-operator-dataflow/plot.py | |||
| wget https://raw.githubusercontent.com/dora-rs/dora/v0.3.2/examples/python-operator-dataflow/utils.py | |||
| wget https://raw.githubusercontent.com/dora-rs/dora/v0.3.2/examples/python-operator-dataflow/object_detection.py | |||
| wget https://raw.githubusercontent.com/dora-rs/dora/v0.3.2/examples/python-operator-dataflow/dataflow.yml | |||
| ``` | |||
| 3. Start the dataflow | |||
| @@ -19,7 +19,6 @@ tracing = ["dora-node-api/tracing"] | |||
| [dependencies] | |||
| eyre = "0.6.8" | |||
| flume = "0.10.14" | |||
| tracing = "0.1.33" | |||
| arrow-array = { workspace = true } | |||
| @@ -1,6 +1,6 @@ | |||
| #![deny(unsafe_op_in_unsafe_fn)] | |||
| use arrow_array::BinaryArray; | |||
| use arrow_array::UInt8Array; | |||
| use dora_node_api::{arrow::array::AsArray, DoraNode, Event, EventStream}; | |||
| use eyre::Context; | |||
| use std::{ffi::c_void, ptr, slice}; | |||
| @@ -170,22 +170,24 @@ pub unsafe extern "C" fn read_dora_input_data( | |||
| ) { | |||
| let event: &Event = unsafe { &*event.cast() }; | |||
| match event { | |||
| Event::Input { data, .. } => { | |||
| let data: Option<&BinaryArray> = data.as_binary_opt(); | |||
| if let Some(data) = data { | |||
| let ptr = data.value(0).as_ptr(); | |||
| let len = data.value(0).len(); | |||
| Event::Input { data, metadata, .. } => match metadata.type_info.data_type { | |||
| dora_node_api::arrow::datatypes::DataType::UInt8 => { | |||
| let array: &UInt8Array = data.as_primitive(); | |||
| let ptr = array.values().as_ptr(); | |||
| unsafe { | |||
| *out_ptr = ptr; | |||
| *out_len = len; | |||
| } | |||
| } else { | |||
| unsafe { | |||
| *out_ptr = ptr::null(); | |||
| *out_len = 0; | |||
| *out_len = metadata.type_info.len; | |||
| } | |||
| } | |||
| } | |||
| dora_node_api::arrow::datatypes::DataType::Null => unsafe { | |||
| *out_ptr = ptr::null(); | |||
| *out_len = 0; | |||
| }, | |||
| _ => { | |||
| todo!("dora C++ Node does not yet support higher level type of arrow. Only UInt8. | |||
| The ultimate solution should be based on arrow FFI interface. Feel free to contribute :)") | |||
| } | |||
| }, | |||
| _ => unsafe { | |||
| *out_ptr = ptr::null(); | |||
| *out_len = 0; | |||
| @@ -2,10 +2,10 @@ use std::path::Path; | |||
| fn main() { | |||
| dora_operator_api_types::generate_headers(Path::new("operator_types.h")) | |||
| .expect("failed to create operator_api.h"); | |||
| .expect("failed to create operator_types.h"); | |||
| // don't rebuild on changes (otherwise we rebuild on every run as we're | |||
| // writing the `operator_types.h` file; cargo will still rerun this script | |||
| // when the `dora_operator_api_types` crate changes) | |||
| println!("cargo:rerun-if-changed="); | |||
| println!("cargo:rerun-if-changed=build.rs"); | |||
| } | |||
| @@ -20,7 +20,7 @@ pyo3 = { workspace = true, features = ["eyre", "abi3-py37"] } | |||
| eyre = "0.6" | |||
| serde_yaml = "0.8.23" | |||
| flume = "0.10.14" | |||
| dora-runtime = { workspace = true, features = ["tracing", "python"] } | |||
| dora-runtime = { workspace = true, features = ["tracing", "metrics", "python"] } | |||
| arrow = { workspace = true, features = ["pyarrow"] } | |||
| pythonize = { workspace = true } | |||
| futures = "0.3.28" | |||
| @@ -28,4 +28,4 @@ dora-ros2-bridge-python = { workspace = true } | |||
| [lib] | |||
| name = "dora" | |||
| crate-type = ["lib", "cdylib"] | |||
| crate-type = ["cdylib"] | |||
| @@ -4,9 +4,9 @@ This crate corresponds to the Node API for Dora. | |||
| To build the Python module for development: | |||
| ````bash | |||
| python3 -m venv .env | |||
| ```bash | |||
| python -m venv .env | |||
| source .env/bin/activate | |||
| pip install maturin | |||
| maturin develop | |||
| ```` | |||
| ``` | |||
| @@ -15,7 +15,7 @@ from enum import Enum | |||
| from .dora import * | |||
| __author__ = "Dora-rs Authors" | |||
| __version__ = "0.3.0" | |||
| __version__ = "0.3.2" | |||
| class DoraStatus(Enum): | |||
| @@ -14,20 +14,13 @@ tracing = ["dep:dora-tracing"] | |||
| dora-core = { workspace = true } | |||
| shared-memory-server = { workspace = true } | |||
| eyre = "0.6.7" | |||
| once_cell = "1.13.0" | |||
| serde = { version = "1.0.136", features = ["derive"] } | |||
| serde_yaml = "0.8.23" | |||
| serde_json = "1.0.89" | |||
| thiserror = "1.0.30" | |||
| tracing = "0.1.33" | |||
| flume = "0.10.14" | |||
| uuid = { version = "1.1.2", features = ["v4"] } | |||
| capnp = "0.14.11" | |||
| bincode = "1.3.3" | |||
| shared_memory_extended = "0.13.0" | |||
| dora-tracing = { workspace = true, optional = true } | |||
| arrow = { workspace = true } | |||
| arrow-schema = { workspace = true } | |||
| futures = "0.3.28" | |||
| futures-concurrency = "7.3.0" | |||
| futures-timer = "3.0.2" | |||
| @@ -13,5 +13,5 @@ arrow = { workspace = true, features = ["ffi"] } | |||
| dora-arrow-convert = { workspace = true } | |||
| [dependencies.safer-ffi] | |||
| version = "0.1.3" | |||
| version = "0.1.4" | |||
| features = ["headers", "inventory-0-3-1"] | |||
| @@ -1,4 +1,5 @@ | |||
| #![deny(elided_lifetimes_in_paths)] // required for safer-ffi | |||
| #![allow(improper_ctypes_definitions)] | |||
| pub use arrow; | |||
| use dora_arrow_convert::{ArrowData, IntoArrow}; | |||
| @@ -9,7 +9,7 @@ license.workspace = true | |||
| # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | |||
| [[bin]] | |||
| name = "dora-cli" | |||
| name = "dora" | |||
| path = "src/main.rs" | |||
| [features] | |||
| @@ -34,4 +34,10 @@ notify = "5.1.0" | |||
| ctrlc = "3.2.5" | |||
| tracing = "0.1.36" | |||
| dora-tracing = { workspace = true, optional = true } | |||
| bat = "0.23.0" | |||
| bat = "0.24.0" | |||
| dora-daemon = { workspace = true } | |||
| dora-coordinator = { workspace = true } | |||
| dora-runtime = { workspace = true } | |||
| tokio = { version = "1.20.1", features = ["full"] } | |||
| tokio-stream = { version = "0.1.8", features = ["io-util", "net"] } | |||
| futures = "0.3.21" | |||
| @@ -1,15 +1,22 @@ | |||
| use std::path::PathBuf; | |||
| use std::{net::Ipv4Addr, path::PathBuf}; | |||
| use attach::attach_dataflow; | |||
| use clap::Parser; | |||
| use communication_layer_request_reply::{RequestReplyLayer, TcpLayer, TcpRequestReplyConnection}; | |||
| use dora_coordinator::Event; | |||
| use dora_core::{ | |||
| descriptor::Descriptor, | |||
| topics::{control_socket_addr, ControlRequest, ControlRequestReply, DataflowId}, | |||
| topics::{ | |||
| control_socket_addr, ControlRequest, ControlRequestReply, DataflowId, | |||
| DORA_COORDINATOR_PORT_DEFAULT, | |||
| }, | |||
| }; | |||
| use dora_daemon::Daemon; | |||
| #[cfg(feature = "tracing")] | |||
| use dora_tracing::set_up_tracing; | |||
| use eyre::{bail, Context}; | |||
| use std::net::SocketAddr; | |||
| use tokio::runtime::Builder; | |||
| use uuid::Uuid; | |||
| mod attach; | |||
| @@ -56,10 +63,6 @@ enum Command { | |||
| Up { | |||
| #[clap(long)] | |||
| config: Option<PathBuf>, | |||
| #[clap(long)] | |||
| coordinator_path: Option<PathBuf>, | |||
| #[clap(long)] | |||
| daemon_path: Option<PathBuf>, | |||
| }, | |||
| /// Destroy running coordinator and daemon. If some dataflows are still running, they will be stopped first. | |||
| Destroy { | |||
| @@ -87,11 +90,29 @@ enum Command { | |||
| // Planned for future releases: | |||
| // Dashboard, | |||
| /// Show logs of a given dataflow and node. | |||
| Logs { dataflow: String, node: String }, | |||
| #[command(allow_missing_positional = true)] | |||
| Logs { | |||
| dataflow: Option<String>, | |||
| node: String, | |||
| }, | |||
| // Metrics, | |||
| // Stats, | |||
| // Get, | |||
| // Upgrade, | |||
| /// Run daemon | |||
| Daemon { | |||
| #[clap(long)] | |||
| machine_id: Option<String>, | |||
| #[clap(long)] | |||
| coordinator_addr: Option<SocketAddr>, | |||
| #[clap(long)] | |||
| run_dataflow: Option<PathBuf>, | |||
| }, | |||
| /// Run runtime | |||
| Runtime, | |||
| /// Run coordinator | |||
| Coordinator { port: Option<u16> }, | |||
| } | |||
| #[derive(Debug, clap::Args)] | |||
| @@ -127,10 +148,24 @@ fn main() { | |||
| } | |||
| fn run() -> eyre::Result<()> { | |||
| #[cfg(feature = "tracing")] | |||
| set_up_tracing("dora-cli").context("failed to set up tracing subscriber")?; | |||
| let args = Args::parse(); | |||
| #[cfg(feature = "tracing")] | |||
| match args.command { | |||
| Command::Daemon { .. } => { | |||
| set_up_tracing("dora-daemon").context("failed to set up tracing subscriber")?; | |||
| } | |||
| Command::Runtime => { | |||
| // Do not set the runtime in the cli. | |||
| } | |||
| Command::Coordinator { .. } => { | |||
| set_up_tracing("dora-coordinator").context("failed to set up tracing subscriber")?; | |||
| } | |||
| _ => { | |||
| set_up_tracing("dora-cli").context("failed to set up tracing subscriber")?; | |||
| } | |||
| }; | |||
| match args.command { | |||
| Command::Check { dataflow } => match dataflow { | |||
| Some(dataflow) => { | |||
| @@ -159,21 +194,25 @@ fn run() -> eyre::Result<()> { | |||
| args, | |||
| internal_create_with_path_dependencies, | |||
| } => template::create(args, internal_create_with_path_dependencies)?, | |||
| Command::Up { | |||
| config, | |||
| coordinator_path, | |||
| daemon_path, | |||
| } => up::up( | |||
| config.as_deref(), | |||
| coordinator_path.as_deref(), | |||
| daemon_path.as_deref(), | |||
| )?, | |||
| Command::Up { config } => up::up(config.as_deref())?, | |||
| Command::Logs { dataflow, node } => { | |||
| let uuid = Uuid::parse_str(&dataflow).ok(); | |||
| let name = if uuid.is_some() { None } else { Some(dataflow) }; | |||
| let mut session = | |||
| connect_to_coordinator().wrap_err("failed to connect to dora coordinator")?; | |||
| logs::logs(&mut *session, uuid, name, node)? | |||
| let uuids = query_running_dataflows(&mut *session) | |||
| .wrap_err("failed to query running dataflows")?; | |||
| if let Some(dataflow) = dataflow { | |||
| let uuid = Uuid::parse_str(&dataflow).ok(); | |||
| let name = if uuid.is_some() { None } else { Some(dataflow) }; | |||
| logs::logs(&mut *session, uuid, name, node)? | |||
| } else { | |||
| let uuid = match &uuids[..] { | |||
| [] => bail!("No dataflows are running"), | |||
| [uuid] => uuid.clone(), | |||
| _ => inquire::Select::new("Choose dataflow to show logs:", uuids).prompt()?, | |||
| }; | |||
| logs::logs(&mut *session, Some(uuid.uuid), None, node)? | |||
| } | |||
| } | |||
| Command::Start { | |||
| dataflow, | |||
| @@ -227,7 +266,54 @@ fn run() -> eyre::Result<()> { | |||
| } | |||
| } | |||
| Command::Destroy { config } => up::destroy(config.as_deref())?, | |||
| } | |||
| Command::Coordinator { port } => { | |||
| let rt = Builder::new_multi_thread() | |||
| .enable_all() | |||
| .build() | |||
| .context("tokio runtime failed")?; | |||
| rt.block_on(async { | |||
| let (_port, task) = | |||
| dora_coordinator::start(port, futures::stream::empty::<Event>()).await?; | |||
| task.await | |||
| }) | |||
| .context("failed to run dora-coordinator")? | |||
| } | |||
| Command::Daemon { | |||
| coordinator_addr, | |||
| machine_id, | |||
| run_dataflow, | |||
| } => { | |||
| let rt = Builder::new_multi_thread() | |||
| .enable_all() | |||
| .build() | |||
| .context("tokio runtime failed")?; | |||
| rt.block_on(async { | |||
| match run_dataflow { | |||
| Some(dataflow_path) => { | |||
| tracing::info!("Starting dataflow `{}`", dataflow_path.display()); | |||
| if let Some(coordinator_addr) = coordinator_addr { | |||
| tracing::info!( | |||
| "Not using coordinator addr {} as `run_dataflow` is for local dataflow only. Please use the `start` command for remote coordinator", | |||
| coordinator_addr | |||
| ); | |||
| } | |||
| Daemon::run_dataflow(&dataflow_path).await | |||
| } | |||
| None => { | |||
| let addr = coordinator_addr.unwrap_or_else(|| { | |||
| tracing::info!("Starting in local mode"); | |||
| let localhost = Ipv4Addr::new(127, 0, 0, 1); | |||
| (localhost, DORA_COORDINATOR_PORT_DEFAULT).into() | |||
| }); | |||
| Daemon::run(addr, machine_id.unwrap_or_default()).await | |||
| } | |||
| } | |||
| }) | |||
| .context("failed to run dora-daemon")? | |||
| } | |||
| Command::Runtime => dora_runtime::main().context("Failed to run dora-runtime")?, | |||
| }; | |||
| Ok(()) | |||
| } | |||
| @@ -16,8 +16,7 @@ nodes: | |||
| - id: custom-node_1 | |||
| custom: | |||
| source: python3 | |||
| args: ./node_1/node_1.py | |||
| source: ./node_1/node_1.py | |||
| inputs: | |||
| tick: dora/timer/secs/1 | |||
| input-1: op_1/some-output | |||
| @@ -6,17 +6,13 @@ use std::{fs, path::Path, process::Command, time::Duration}; | |||
| #[derive(Debug, Default, serde::Serialize, serde::Deserialize)] | |||
| struct UpConfig {} | |||
| pub(crate) fn up( | |||
| config_path: Option<&Path>, | |||
| coordinator: Option<&Path>, | |||
| daemon: Option<&Path>, | |||
| ) -> eyre::Result<()> { | |||
| pub(crate) fn up(config_path: Option<&Path>) -> eyre::Result<()> { | |||
| let UpConfig {} = parse_dora_config(config_path)?; | |||
| let mut session = match connect_to_coordinator() { | |||
| Ok(session) => session, | |||
| Err(_) => { | |||
| start_coordinator(coordinator).wrap_err("failed to start dora-coordinator")?; | |||
| start_coordinator().wrap_err("failed to start dora-coordinator")?; | |||
| loop { | |||
| match connect_to_coordinator() { | |||
| @@ -31,7 +27,7 @@ pub(crate) fn up( | |||
| }; | |||
| if !daemon_running(&mut *session)? { | |||
| start_daemon(daemon).wrap_err("failed to start dora-daemon")?; | |||
| start_daemon().wrap_err("failed to start dora-daemon")?; | |||
| } | |||
| Ok(()) | |||
| @@ -70,24 +66,24 @@ fn parse_dora_config(config_path: Option<&Path>) -> Result<UpConfig, eyre::ErrRe | |||
| Ok(config) | |||
| } | |||
| fn start_coordinator(coordinator: Option<&Path>) -> eyre::Result<()> { | |||
| let coordinator = coordinator.unwrap_or_else(|| Path::new("dora-coordinator")); | |||
| let mut cmd = Command::new(coordinator); | |||
| fn start_coordinator() -> eyre::Result<()> { | |||
| let mut cmd = | |||
| Command::new(std::env::current_exe().wrap_err("failed to get current executable path")?); | |||
| cmd.arg("coordinator"); | |||
| cmd.spawn() | |||
| .wrap_err_with(|| format!("failed to run {}", coordinator.display()))?; | |||
| .wrap_err_with(|| format!("failed to run `dora coordinator`"))?; | |||
| println!("started dora coordinator"); | |||
| Ok(()) | |||
| } | |||
| fn start_daemon(daemon: Option<&Path>) -> eyre::Result<()> { | |||
| let daemon = daemon.unwrap_or_else(|| Path::new("dora-daemon")); | |||
| let mut cmd = Command::new(daemon); | |||
| fn start_daemon() -> eyre::Result<()> { | |||
| let mut cmd = | |||
| Command::new(std::env::current_exe().wrap_err("failed to get current executable path")?); | |||
| cmd.arg("daemon"); | |||
| cmd.spawn() | |||
| .wrap_err_with(|| format!("failed to run {}", daemon.display()))?; | |||
| .wrap_err_with(|| format!("failed to run `dora daemon`"))?; | |||
| println!("started dora daemon"); | |||
| @@ -15,20 +15,13 @@ tracing = ["dep:dora-tracing"] | |||
| [dependencies] | |||
| eyre = "0.6.7" | |||
| futures = "0.3.21" | |||
| serde = { version = "1.0.136", features = ["derive"] } | |||
| serde_yaml = "0.8.23" | |||
| tokio = { version = "1.24.2", features = ["full"] } | |||
| tokio-stream = { version = "0.1.8", features = ["io-util", "net"] } | |||
| uuid = { version = "1.2.1" } | |||
| rand = "0.8.5" | |||
| dora-core = { workspace = true } | |||
| tracing = "0.1.36" | |||
| dora-tracing = { workspace = true, optional = true } | |||
| futures-concurrency = "7.1.0" | |||
| zenoh = "0.7.0-rc" | |||
| serde_json = "1.0.86" | |||
| which = "4.3.0" | |||
| thiserror = "1.0.37" | |||
| ctrlc = "3.2.5" | |||
| clap = { version = "4.0.3", features = ["derive"] } | |||
| names = "0.14.0" | |||
| ctrlc = "3.2.5" | |||
| @@ -38,36 +38,23 @@ mod listener; | |||
| mod run; | |||
| mod tcp_utils; | |||
| #[derive(Debug, Clone, clap::Parser)] | |||
| #[clap(about = "Dora coordinator")] | |||
| pub struct Args { | |||
| #[clap(long)] | |||
| pub port: Option<u16>, | |||
| } | |||
| pub async fn run(args: Args) -> eyre::Result<()> { | |||
| let ctrlc_events = set_up_ctrlc_handler()?; | |||
| let (_, task) = start(args, ctrlc_events).await?; | |||
| task.await?; | |||
| Ok(()) | |||
| } | |||
| pub async fn start( | |||
| args: Args, | |||
| port: Option<u16>, | |||
| external_events: impl Stream<Item = Event> + Unpin, | |||
| ) -> Result<(u16, impl Future<Output = eyre::Result<()>>), eyre::ErrReport> { | |||
| let port = args.port.unwrap_or(DORA_COORDINATOR_PORT_DEFAULT); | |||
| let port = port.unwrap_or(DORA_COORDINATOR_PORT_DEFAULT); | |||
| let listener = listener::create_listener(port).await?; | |||
| let port = listener | |||
| .local_addr() | |||
| .wrap_err("failed to get local addr of listener")? | |||
| .port(); | |||
| let mut tasks = FuturesUnordered::new(); | |||
| // Setup ctrl-c handler | |||
| let ctrlc_events = set_up_ctrlc_handler()?; | |||
| let future = async move { | |||
| start_inner(listener, &tasks, external_events).await?; | |||
| start_inner(listener, &tasks, (ctrlc_events, external_events).merge()).await?; | |||
| tracing::debug!("coordinator main loop finished, waiting on spawned tasks"); | |||
| while let Some(join_result) = tasks.next().await { | |||
| @@ -251,8 +238,11 @@ async fn start_inner( | |||
| // notify all machines that run parts of the dataflow | |||
| for machine_id in &dataflow.machines { | |||
| let Some(connection) = daemon_connections.get_mut(machine_id) else { | |||
| tracing::warn!("no daemon connection found for machine `{machine_id}`"); | |||
| let Some(connection) = daemon_connections.get_mut(machine_id) | |||
| else { | |||
| tracing::warn!( | |||
| "no daemon connection found for machine `{machine_id}`" | |||
| ); | |||
| continue; | |||
| }; | |||
| tcp_send(&mut connection.stream, &message) | |||
| @@ -601,28 +591,6 @@ struct DaemonConnection { | |||
| last_heartbeat: Instant, | |||
| } | |||
| fn set_up_ctrlc_handler() -> Result<impl Stream<Item = Event>, eyre::ErrReport> { | |||
| let (ctrlc_tx, ctrlc_rx) = mpsc::channel(1); | |||
| let mut ctrlc_sent = false; | |||
| ctrlc::set_handler(move || { | |||
| if ctrlc_sent { | |||
| tracing::warn!("received second ctrlc signal -> aborting immediately"); | |||
| std::process::abort(); | |||
| } else { | |||
| tracing::info!("received ctrlc signal"); | |||
| if ctrlc_tx.blocking_send(Event::CtrlC).is_err() { | |||
| tracing::error!("failed to report ctrl-c event to dora-coordinator"); | |||
| } | |||
| ctrlc_sent = true; | |||
| } | |||
| }) | |||
| .wrap_err("failed to set ctrl-c handler")?; | |||
| Ok(ReceiverStream::new(ctrlc_rx)) | |||
| } | |||
| async fn handle_destroy( | |||
| running_dataflows: &HashMap<Uuid, RunningDataflow>, | |||
| daemon_connections: &mut HashMap<String, DaemonConnection>, | |||
| @@ -884,7 +852,7 @@ async fn destroy_daemons( | |||
| match serde_json::from_slice(&reply_raw) | |||
| .wrap_err("failed to deserialize destroy reply from daemon")? | |||
| { | |||
| DaemonCoordinatorReply::DestroyResult(result) => result | |||
| DaemonCoordinatorReply::DestroyResult { result, .. } => result | |||
| .map_err(|e| eyre!(e)) | |||
| .wrap_err("failed to destroy dataflow")?, | |||
| other => bail!("unexpected reply after sending `destroy`: {other:?}"), | |||
| @@ -940,3 +908,25 @@ pub enum DaemonEvent { | |||
| listen_socket: SocketAddr, | |||
| }, | |||
| } | |||
| fn set_up_ctrlc_handler() -> Result<impl Stream<Item = Event>, eyre::ErrReport> { | |||
| let (ctrlc_tx, ctrlc_rx) = mpsc::channel(1); | |||
| let mut ctrlc_sent = false; | |||
| ctrlc::set_handler(move || { | |||
| if ctrlc_sent { | |||
| tracing::warn!("received second ctrlc signal -> aborting immediately"); | |||
| std::process::abort(); | |||
| } else { | |||
| tracing::info!("received ctrlc signal"); | |||
| if ctrlc_tx.blocking_send(Event::CtrlC).is_err() { | |||
| tracing::error!("failed to report ctrl-c event to dora-coordinator"); | |||
| } | |||
| ctrlc_sent = true; | |||
| } | |||
| }) | |||
| .wrap_err("failed to set ctrl-c handler")?; | |||
| Ok(ReceiverStream::new(ctrlc_rx)) | |||
| } | |||
| @@ -1,12 +0,0 @@ | |||
| #[cfg(feature = "tracing")] | |||
| use dora_tracing::set_up_tracing; | |||
| #[cfg(feature = "tracing")] | |||
| use eyre::Context; | |||
| #[tokio::main] | |||
| async fn main() -> eyre::Result<()> { | |||
| #[cfg(feature = "tracing")] | |||
| set_up_tracing("dora-coordinator").context("failed to set up tracing subscriber")?; | |||
| dora_coordinator::run(clap::Parser::parse()).await | |||
| } | |||
| @@ -22,20 +22,16 @@ tokio-stream = { version = "0.1.11", features = ["net"] } | |||
| tracing = "0.1.36" | |||
| tracing-opentelemetry = { version = "0.18.0", optional = true } | |||
| futures-concurrency = "7.1.0" | |||
| serde = { version = "1.0.136", features = ["derive"] } | |||
| serde_json = "1.0.86" | |||
| dora-core = { workspace = true } | |||
| dora-runtime = { workspace = true } | |||
| flume = "0.10.14" | |||
| dora-download = { workspace = true } | |||
| dora-tracing = { workspace = true, optional = true } | |||
| serde_yaml = "0.8.23" | |||
| uuid = { version = "1.1.2", features = ["v4"] } | |||
| futures = "0.3.25" | |||
| clap = { version = "4.0.3", features = ["derive"] } | |||
| shared-memory-server = { workspace = true } | |||
| ctrlc = "3.2.5" | |||
| bincode = "1.3.3" | |||
| async-trait = "0.1.64" | |||
| arrow-schema = { workspace = true } | |||
| aligned-vec = "0.5.0" | |||
| ctrlc = "3.2.5" | |||
| @@ -113,6 +113,12 @@ pub async fn register( | |||
| tracing::warn!("failed to send reply to coordinator: {err}"); | |||
| continue; | |||
| }; | |||
| if let DaemonCoordinatorReply::DestroyResult { notify, .. } = reply { | |||
| if let Some(notify) = notify { | |||
| let _ = notify.send(()); | |||
| } | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| }); | |||
| @@ -76,13 +76,11 @@ pub struct Daemon { | |||
| } | |||
| impl Daemon { | |||
| pub async fn run( | |||
| coordinator_addr: SocketAddr, | |||
| machine_id: String, | |||
| external_events: impl Stream<Item = Timestamped<Event>> + Unpin, | |||
| ) -> eyre::Result<()> { | |||
| pub async fn run(coordinator_addr: SocketAddr, machine_id: String) -> eyre::Result<()> { | |||
| let clock = Arc::new(HLC::default()); | |||
| let ctrlc_events = set_up_ctrlc_handler(clock.clone())?; | |||
| // spawn listen loop | |||
| let (events_tx, events_rx) = flume::bounded(10); | |||
| let listen_socket = | |||
| @@ -108,7 +106,7 @@ impl Daemon { | |||
| ); | |||
| Self::run_general( | |||
| (coordinator_events, external_events, daemon_events).merge(), | |||
| (coordinator_events, ctrlc_events, daemon_events).merge(), | |||
| Some(coordinator_addr), | |||
| machine_id, | |||
| None, | |||
| @@ -159,7 +157,7 @@ impl Daemon { | |||
| let run_result = Self::run_general( | |||
| Box::pin(coordinator_events), | |||
| None, | |||
| "".into(), | |||
| "".to_string(), | |||
| Some(exit_when_done), | |||
| clock, | |||
| ); | |||
| @@ -429,10 +427,18 @@ impl Daemon { | |||
| } | |||
| DaemonCoordinatorEvent::Destroy => { | |||
| tracing::info!("received destroy command -> exiting"); | |||
| let reply = DaemonCoordinatorReply::DestroyResult(Ok(())); | |||
| let (notify_tx, notify_rx) = oneshot::channel(); | |||
| let reply = DaemonCoordinatorReply::DestroyResult { | |||
| result: Ok(()), | |||
| notify: Some(notify_tx), | |||
| }; | |||
| let _ = reply_tx | |||
| .send(Some(reply)) | |||
| .map_err(|_| error!("could not send destroy reply from daemon to coordinator")); | |||
| // wait until the reply is sent out | |||
| if notify_rx.await.is_err() { | |||
| tracing::warn!("no confirmation received for DestroyReply"); | |||
| } | |||
| RunStatus::Exit | |||
| } | |||
| DaemonCoordinatorEvent::Heartbeat => { | |||
| @@ -1062,10 +1068,6 @@ impl Daemon { | |||
| } | |||
| } | |||
| pub fn run_dora_runtime() -> eyre::Result<()> { | |||
| dora_runtime::main() | |||
| } | |||
| async fn send_output_to_local_receivers( | |||
| node_id: NodeId, | |||
| output_id: DataId, | |||
| @@ -1519,3 +1521,33 @@ fn send_with_timestamp<T>( | |||
| timestamp: clock.new_timestamp(), | |||
| }) | |||
| } | |||
| fn set_up_ctrlc_handler( | |||
| clock: Arc<HLC>, | |||
| ) -> Result<impl Stream<Item = Timestamped<Event>>, eyre::ErrReport> { | |||
| let (ctrlc_tx, ctrlc_rx) = mpsc::channel(1); | |||
| let mut ctrlc_sent = false; | |||
| ctrlc::set_handler(move || { | |||
| if ctrlc_sent { | |||
| tracing::warn!("received second ctrlc signal -> aborting immediately"); | |||
| std::process::abort(); | |||
| } else { | |||
| tracing::info!("received ctrlc signal"); | |||
| if ctrlc_tx | |||
| .blocking_send(Timestamped { | |||
| inner: Event::CtrlC, | |||
| timestamp: clock.new_timestamp(), | |||
| }) | |||
| .is_err() | |||
| { | |||
| tracing::error!("failed to report ctrl-c event to dora-coordinator"); | |||
| } | |||
| ctrlc_sent = true; | |||
| } | |||
| }) | |||
| .wrap_err("failed to set ctrl-c handler")?; | |||
| Ok(ReceiverStream::new(ctrlc_rx)) | |||
| } | |||
| @@ -1,98 +0,0 @@ | |||
| use dora_core::{ | |||
| daemon_messages::Timestamped, message::uhlc::HLC, topics::DORA_COORDINATOR_PORT_DEFAULT, | |||
| }; | |||
| use dora_daemon::{Daemon, Event}; | |||
| #[cfg(feature = "tracing")] | |||
| use dora_tracing::set_up_tracing; | |||
| #[cfg(feature = "tracing")] | |||
| use eyre::Context; | |||
| use tokio::sync::mpsc; | |||
| use tokio_stream::wrappers::ReceiverStream; | |||
| use std::{ | |||
| net::{Ipv4Addr, SocketAddr}, | |||
| path::PathBuf, | |||
| }; | |||
| #[derive(Debug, Clone, clap::Parser)] | |||
| #[clap(about = "Dora daemon")] | |||
| pub struct Args { | |||
| #[clap(long)] | |||
| pub machine_id: Option<String>, | |||
| #[clap(long)] | |||
| pub coordinator_addr: Option<SocketAddr>, | |||
| #[clap(long)] | |||
| pub run_dataflow: Option<PathBuf>, | |||
| #[clap(long)] | |||
| pub run_dora_runtime: bool, | |||
| } | |||
| #[tokio::main] | |||
| async fn main() -> eyre::Result<()> { | |||
| // the tokio::main proc macro confuses some tools such as rust-analyzer, so | |||
| // directly invoke a "normal" async function | |||
| run().await | |||
| } | |||
| async fn run() -> eyre::Result<()> { | |||
| let Args { | |||
| run_dataflow, | |||
| machine_id, | |||
| coordinator_addr, | |||
| run_dora_runtime, | |||
| } = clap::Parser::parse(); | |||
| if run_dora_runtime { | |||
| return tokio::task::block_in_place(dora_daemon::run_dora_runtime); | |||
| } | |||
| #[cfg(feature = "tracing")] | |||
| set_up_tracing("dora-daemon").wrap_err("failed to set up tracing subscriber")?; | |||
| let ctrl_c_events = { | |||
| let (ctrl_c_tx, ctrl_c_rx) = mpsc::channel(1); | |||
| let mut ctrlc_sent = false; | |||
| ctrlc::set_handler(move || { | |||
| let clock = HLC::default(); | |||
| if ctrlc_sent { | |||
| tracing::warn!("received second ctrlc signal -> aborting immediately"); | |||
| std::process::abort(); | |||
| } else { | |||
| tracing::info!("received ctrlc signal"); | |||
| let event = Timestamped { | |||
| inner: Event::CtrlC, | |||
| timestamp: clock.new_timestamp(), | |||
| }; | |||
| if ctrl_c_tx.blocking_send(event).is_err() { | |||
| tracing::error!("failed to report ctrl-c event to dora-daemon"); | |||
| } | |||
| ctrlc_sent = true; | |||
| } | |||
| }) | |||
| .wrap_err("failed to set ctrl-c handler")?; | |||
| ReceiverStream::new(ctrl_c_rx) | |||
| }; | |||
| match run_dataflow { | |||
| Some(dataflow_path) => { | |||
| tracing::info!("Starting dataflow `{}`", dataflow_path.display()); | |||
| Daemon::run_dataflow(&dataflow_path).await | |||
| } | |||
| None => { | |||
| Daemon::run( | |||
| coordinator_addr.unwrap_or_else(|| { | |||
| tracing::info!("Starting in local mode"); | |||
| let localhost = Ipv4Addr::new(127, 0, 0, 1); | |||
| (localhost, DORA_COORDINATOR_PORT_DEFAULT).into() | |||
| }), | |||
| machine_id.unwrap_or_default(), | |||
| ctrl_c_events, | |||
| ) | |||
| .await | |||
| } | |||
| } | |||
| } | |||
| @@ -299,7 +299,11 @@ impl Listener { | |||
| // iterate over queued events, newest first | |||
| for event in self.queue.iter_mut().rev() { | |||
| let Some(Timestamped { inner: NodeEvent::Input { id, data, .. }, ..}) = event.as_mut() else { | |||
| let Some(Timestamped { | |||
| inner: NodeEvent::Input { id, data, .. }, | |||
| .. | |||
| }) = event.as_mut() | |||
| else { | |||
| continue; | |||
| }; | |||
| match queue_size_remaining.get_mut(id) { | |||
| @@ -128,11 +128,20 @@ impl PendingNodes { | |||
| None => Ok(()), | |||
| } | |||
| } else { | |||
| let node_id_message = if self.exited_before_subscribe.len() == 1 { | |||
| self.exited_before_subscribe | |||
| .iter() | |||
| .next() | |||
| .map(|node_id| node_id.to_string()) | |||
| .unwrap_or("<node_id>".to_string()) | |||
| } else { | |||
| "<node_id>".to_string() | |||
| }; | |||
| Err(format!( | |||
| "Some nodes exited before subscribing to dora: {:?}\n\n\ | |||
| This is typically happens when an initialization error occurs | |||
| in the node or operator. To check the output of the failed | |||
| nodes, run `dora logs {} <node_id>`.", | |||
| nodes, run `dora logs {} {node_id_message}`.", | |||
| self.exited_before_subscribe, self.dataflow_id | |||
| )) | |||
| }; | |||
| @@ -8,6 +8,7 @@ use dora_core::{ | |||
| descriptor::{ | |||
| resolve_path, source_is_url, Descriptor, OperatorSource, ResolvedNode, SHELL_SOURCE, | |||
| }, | |||
| get_python_path, | |||
| message::uhlc::HLC, | |||
| }; | |||
| use dora_download::download_file; | |||
| @@ -81,8 +82,21 @@ pub async fn spawn_node( | |||
| })? | |||
| }; | |||
| tracing::info!("spawning {}", resolved_path.display()); | |||
| let mut cmd = tokio::process::Command::new(&resolved_path); | |||
| // If extension is .py, use python to run the script | |||
| let mut cmd = match resolved_path.extension().map(|ext| ext.to_str()) { | |||
| Some(Some("py")) => { | |||
| let python = get_python_path().context("Could not get python path")?; | |||
| tracing::info!("spawning: {:?} {}", &python, resolved_path.display()); | |||
| let mut cmd = tokio::process::Command::new(&python); | |||
| cmd.arg(&resolved_path); | |||
| cmd | |||
| } | |||
| _ => { | |||
| tracing::info!("spawning: {}", resolved_path.display()); | |||
| tokio::process::Command::new(&resolved_path) | |||
| } | |||
| }; | |||
| if let Some(args) = &n.args { | |||
| cmd.args(args.split_ascii_whitespace()); | |||
| } | |||
| @@ -120,7 +134,7 @@ pub async fn spawn_node( | |||
| format!( | |||
| "failed to run `{}` with args `{}`", | |||
| n.source, | |||
| n.args.as_deref().unwrap_or_default() | |||
| n.args.as_deref().unwrap_or_default(), | |||
| ) | |||
| })? | |||
| } | |||
| @@ -137,7 +151,8 @@ pub async fn spawn_node( | |||
| let mut command = if has_python_operator && !has_other_operator { | |||
| // Use python to spawn runtime if there is a python operator | |||
| let mut command = tokio::process::Command::new("python3"); | |||
| let python = get_python_path().context("Could not find python in daemon")?; | |||
| let mut command = tokio::process::Command::new(python); | |||
| command.args([ | |||
| "-c", | |||
| format!("import dora; dora.start_runtime() # {}", node.id).as_str(), | |||
| @@ -147,7 +162,7 @@ pub async fn spawn_node( | |||
| let mut cmd = tokio::process::Command::new( | |||
| std::env::current_exe().wrap_err("failed to get current executable path")?, | |||
| ); | |||
| cmd.arg("--run-dora-runtime"); | |||
| cmd.arg("runtime"); | |||
| cmd | |||
| } else { | |||
| eyre::bail!("Runtime can not mix Python Operator with other type of operator."); | |||
| @@ -15,8 +15,6 @@ dora-operator-api-types = { workspace = true } | |||
| dora-core = { workspace = true } | |||
| dora-tracing = { workspace = true, optional = true } | |||
| dora-metrics = { workspace = true, optional = true } | |||
| opentelemetry = { version = "0.21.0", features = ["metrics"], optional = true } | |||
| opentelemetry-system-metrics = { version = "0.1.6", optional = true } | |||
| eyre = "0.6.8" | |||
| futures = "0.3.21" | |||
| futures-concurrency = "7.1.0" | |||
| @@ -27,19 +25,16 @@ tokio-stream = "0.1.8" | |||
| # pyo3-abi3 flag allow simpler linking. See: https://pyo3.rs/v0.13.2/building_and_distribution.html | |||
| pyo3 = { workspace = true, features = ["eyre", "abi3-py37"], optional = true } | |||
| tracing = "0.1.36" | |||
| tracing-subscriber = "0.3.15" | |||
| dora-download = { workspace = true } | |||
| flume = "0.10.14" | |||
| clap = { version = "4.0.3", features = ["derive"] } | |||
| tracing-opentelemetry = { version = "0.18.0", optional = true } | |||
| pythonize = { workspace = true, optional = true } | |||
| arrow-schema = { workspace = true } | |||
| arrow = { workspace = true, features = ["ffi"] } | |||
| aligned-vec = "0.5.0" | |||
| [features] | |||
| default = ["tracing", "metrics"] | |||
| tracing = ["dora-tracing"] | |||
| telemetry = ["tracing", "opentelemetry", "tracing-opentelemetry"] | |||
| metrics = ["opentelemetry", "opentelemetry-system-metrics", "dora-metrics"] | |||
| telemetry = ["tracing", "tracing-opentelemetry"] | |||
| metrics = ["dora-metrics"] | |||
| python = ["pyo3", "dora-operator-api-python", "pythonize", "arrow/pyarrow"] | |||
| @@ -5,6 +5,7 @@ use dora_core::{ | |||
| daemon_messages::{NodeConfig, RuntimeConfig}, | |||
| descriptor::OperatorConfig, | |||
| }; | |||
| use dora_metrics::init_meter_provider; | |||
| use dora_node_api::{DoraNode, Event}; | |||
| use eyre::{bail, Context, Result}; | |||
| use futures::{Stream, StreamExt}; | |||
| @@ -14,7 +15,6 @@ use operator::{run_operator, OperatorEvent, StopReason}; | |||
| #[cfg(feature = "tracing")] | |||
| use dora_tracing::set_up_tracing; | |||
| use std::{ | |||
| borrow::Cow, | |||
| collections::{BTreeMap, BTreeSet, HashMap}, | |||
| mem, | |||
| }; | |||
| @@ -123,19 +123,7 @@ async fn run( | |||
| init_done: oneshot::Receiver<Result<()>>, | |||
| ) -> eyre::Result<()> { | |||
| #[cfg(feature = "metrics")] | |||
| let _started = { | |||
| use dora_metrics::init_metrics; | |||
| use opentelemetry::global; | |||
| use opentelemetry_system_metrics::init_process_observer; | |||
| let _started = init_metrics().context("Could not create opentelemetry meter")?; | |||
| let meter = global::meter(Cow::Borrowed(Box::leak( | |||
| config.node_id.to_string().into_boxed_str(), | |||
| ))); | |||
| init_process_observer(meter).context("could not initiale system metrics observer")?; | |||
| _started | |||
| }; | |||
| let _meter_provider = init_meter_provider(config.node_id.to_string()); | |||
| init_done | |||
| .await | |||
| .wrap_err("the `init_done` channel was closed unexpectedly")? | |||
| @@ -190,7 +178,9 @@ async fn run( | |||
| } | |||
| let Some(config) = operators.get(&operator_id) else { | |||
| tracing::warn!("received Finished event for unknown operator `{operator_id}`"); | |||
| tracing::warn!( | |||
| "received Finished event for unknown operator `{operator_id}`" | |||
| ); | |||
| continue; | |||
| }; | |||
| let outputs = config | |||
| @@ -13,6 +13,7 @@ pub mod channel; | |||
| mod python; | |||
| mod shared_lib; | |||
| #[allow(unused_variables)] | |||
| pub fn run_operator( | |||
| node_id: &NodeId, | |||
| operator_definition: OperatorDefinition, | |||
| @@ -123,7 +123,9 @@ pub fn run( | |||
| let mut reload = false; | |||
| let reason = loop { | |||
| #[allow(unused_mut)] | |||
| let Ok(mut event) = incoming_events.recv() else { break StopReason::InputsClosed }; | |||
| let Ok(mut event) = incoming_events.recv() else { | |||
| break StopReason::InputsClosed; | |||
| }; | |||
| if let Event::Reload { .. } = event { | |||
| reload = true; | |||
| @@ -1,23 +1,10 @@ | |||
| use dora_tracing::set_up_tracing; | |||
| use eyre::{bail, Context}; | |||
| use std::path::Path; | |||
| use tracing::metadata::LevelFilter; | |||
| use tracing_subscriber::Layer; | |||
| #[derive(Debug, Clone, clap::Parser)] | |||
| pub struct Args { | |||
| #[clap(long)] | |||
| pub run_dora_runtime: bool, | |||
| } | |||
| #[tokio::main] | |||
| async fn main() -> eyre::Result<()> { | |||
| let Args { run_dora_runtime } = clap::Parser::parse(); | |||
| if run_dora_runtime { | |||
| return tokio::task::block_in_place(dora_daemon::run_dora_runtime); | |||
| } | |||
| set_up_tracing().wrap_err("failed to set up tracing subscriber")?; | |||
| set_up_tracing("benchmark-runner").wrap_err("failed to set up tracing subscriber")?; | |||
| let root = Path::new(env!("CARGO_MANIFEST_DIR")); | |||
| std::env::set_current_dir(root.join(file!()).parent().unwrap()) | |||
| @@ -26,7 +13,7 @@ async fn main() -> eyre::Result<()> { | |||
| let dataflow = Path::new("dataflow.yml"); | |||
| build_dataflow(dataflow).await?; | |||
| dora_daemon::Daemon::run_dataflow(dataflow).await?; | |||
| run_dataflow(dataflow).await?; | |||
| Ok(()) | |||
| } | |||
| @@ -43,13 +30,17 @@ async fn build_dataflow(dataflow: &Path) -> eyre::Result<()> { | |||
| Ok(()) | |||
| } | |||
| fn set_up_tracing() -> eyre::Result<()> { | |||
| use tracing_subscriber::prelude::__tracing_subscriber_SubscriberExt; | |||
| let stdout_log = tracing_subscriber::fmt::layer() | |||
| .pretty() | |||
| .with_filter(LevelFilter::DEBUG); | |||
| let subscriber = tracing_subscriber::Registry::default().with(stdout_log); | |||
| tracing::subscriber::set_global_default(subscriber) | |||
| .context("failed to set tracing global subscriber") | |||
| async fn run_dataflow(dataflow: &Path) -> eyre::Result<()> { | |||
| let cargo = std::env::var("CARGO").unwrap(); | |||
| let mut cmd = tokio::process::Command::new(&cargo); | |||
| cmd.arg("run"); | |||
| cmd.arg("--package").arg("dora-cli"); | |||
| cmd.arg("--") | |||
| .arg("daemon") | |||
| .arg("--run-dataflow") | |||
| .arg(dataflow); | |||
| if !cmd.status().await?.success() { | |||
| bail!("failed to run dataflow"); | |||
| }; | |||
| Ok(()) | |||
| } | |||
| @@ -10,7 +10,7 @@ nodes: | |||
| custom: | |||
| source: build/node_c_api | |||
| inputs: | |||
| tick: dora/timer/millis/300 | |||
| tick: cxx-node-rust-api/counter | |||
| outputs: | |||
| - counter | |||
| @@ -1,25 +1,13 @@ | |||
| use dora_tracing::set_up_tracing; | |||
| use eyre::{bail, Context}; | |||
| use std::{ | |||
| env::consts::{DLL_PREFIX, DLL_SUFFIX, EXE_SUFFIX}, | |||
| path::Path, | |||
| }; | |||
| use tracing::metadata::LevelFilter; | |||
| use tracing_subscriber::Layer; | |||
| #[derive(Debug, Clone, clap::Parser)] | |||
| pub struct Args { | |||
| #[clap(long)] | |||
| pub run_dora_runtime: bool, | |||
| } | |||
| #[tokio::main] | |||
| async fn main() -> eyre::Result<()> { | |||
| let Args { run_dora_runtime } = clap::Parser::parse(); | |||
| if run_dora_runtime { | |||
| return tokio::task::block_in_place(dora_daemon::run_dora_runtime); | |||
| } | |||
| set_up_tracing().wrap_err("failed to set up tracing")?; | |||
| set_up_tracing("c++-dataflow-runner").wrap_err("failed to set up tracing")?; | |||
| if cfg!(windows) { | |||
| tracing::error!( | |||
| @@ -124,7 +112,7 @@ async fn main() -> eyre::Result<()> { | |||
| let dataflow = Path::new("dataflow.yml").to_owned(); | |||
| build_package("dora-runtime").await?; | |||
| dora_daemon::Daemon::run_dataflow(&dataflow).await?; | |||
| run_dataflow(&dataflow).await?; | |||
| Ok(()) | |||
| } | |||
| @@ -140,6 +128,21 @@ async fn build_package(package: &str) -> eyre::Result<()> { | |||
| Ok(()) | |||
| } | |||
| async fn run_dataflow(dataflow: &Path) -> eyre::Result<()> { | |||
| let cargo = std::env::var("CARGO").unwrap(); | |||
| let mut cmd = tokio::process::Command::new(&cargo); | |||
| cmd.arg("run"); | |||
| cmd.arg("--package").arg("dora-cli"); | |||
| cmd.arg("--") | |||
| .arg("daemon") | |||
| .arg("--run-dataflow") | |||
| .arg(dataflow); | |||
| if !cmd.status().await?.success() { | |||
| bail!("failed to run dataflow"); | |||
| }; | |||
| Ok(()) | |||
| } | |||
| async fn build_cxx_node( | |||
| root: &Path, | |||
| paths: &[&Path], | |||
| @@ -294,14 +297,3 @@ async fn build_cxx_operator( | |||
| Ok(()) | |||
| } | |||
| fn set_up_tracing() -> eyre::Result<()> { | |||
| use tracing_subscriber::prelude::__tracing_subscriber_SubscriberExt; | |||
| let stdout_log = tracing_subscriber::fmt::layer() | |||
| .pretty() | |||
| .with_filter(LevelFilter::DEBUG); | |||
| let subscriber = tracing_subscriber::Registry::default().with(stdout_log); | |||
| tracing::subscriber::set_global_default(subscriber) | |||
| .context("failed to set tracing global subscriber") | |||
| } | |||
| @@ -1,26 +1,13 @@ | |||
| use dora_tracing::set_up_tracing; | |||
| use eyre::{bail, Context}; | |||
| use std::{ | |||
| env::consts::{DLL_PREFIX, DLL_SUFFIX, EXE_SUFFIX}, | |||
| path::Path, | |||
| }; | |||
| use tracing::metadata::LevelFilter; | |||
| use tracing_subscriber::Layer; | |||
| #[derive(Debug, Clone, clap::Parser)] | |||
| pub struct Args { | |||
| #[clap(long)] | |||
| pub run_dora_runtime: bool, | |||
| } | |||
| #[tokio::main] | |||
| async fn main() -> eyre::Result<()> { | |||
| let Args { run_dora_runtime } = clap::Parser::parse(); | |||
| if run_dora_runtime { | |||
| return tokio::task::block_in_place(dora_daemon::run_dora_runtime); | |||
| } | |||
| set_up_tracing().wrap_err("failed to set up tracing")?; | |||
| set_up_tracing("c-dataflow-runner").wrap_err("failed to set up tracing")?; | |||
| let root = Path::new(env!("CARGO_MANIFEST_DIR")); | |||
| std::env::set_current_dir(root.join(file!()).parent().unwrap()) | |||
| @@ -36,7 +23,7 @@ async fn main() -> eyre::Result<()> { | |||
| build_c_operator(root).await?; | |||
| let dataflow = Path::new("dataflow.yml").to_owned(); | |||
| dora_daemon::Daemon::run_dataflow(&dataflow).await?; | |||
| run_dataflow(&dataflow).await?; | |||
| Ok(()) | |||
| } | |||
| @@ -52,6 +39,21 @@ async fn build_package(package: &str) -> eyre::Result<()> { | |||
| Ok(()) | |||
| } | |||
| async fn run_dataflow(dataflow: &Path) -> eyre::Result<()> { | |||
| let cargo = std::env::var("CARGO").unwrap(); | |||
| let mut cmd = tokio::process::Command::new(&cargo); | |||
| cmd.arg("run"); | |||
| cmd.arg("--package").arg("dora-cli"); | |||
| cmd.arg("--") | |||
| .arg("daemon") | |||
| .arg("--run-dataflow") | |||
| .arg(dataflow); | |||
| if !cmd.status().await?.success() { | |||
| bail!("failed to run dataflow"); | |||
| }; | |||
| Ok(()) | |||
| } | |||
| async fn build_c_node(root: &Path, name: &str, out_name: &str) -> eyre::Result<()> { | |||
| let mut clang = tokio::process::Command::new("clang"); | |||
| clang.arg(name); | |||
| @@ -182,14 +184,3 @@ async fn build_c_operator(root: &Path) -> eyre::Result<()> { | |||
| Ok(()) | |||
| } | |||
| fn set_up_tracing() -> eyre::Result<()> { | |||
| use tracing_subscriber::prelude::__tracing_subscriber_SubscriberExt; | |||
| let stdout_log = tracing_subscriber::fmt::layer() | |||
| .pretty() | |||
| .with_filter(LevelFilter::DEBUG); | |||
| let subscriber = tracing_subscriber::Registry::default().with(stdout_log); | |||
| tracing::subscriber::set_global_default(subscriber) | |||
| .context("failed to set tracing global subscriber") | |||
| } | |||
| @@ -15,7 +15,7 @@ if(DORA_ROOT_DIR) | |||
| ) | |||
| FetchContent_MakeAvailable(Corrosion) | |||
| list(PREPEND CMAKE_MODULE_PATH ${Corrosion_SOURCE_DIR}/cmake) | |||
| find_package(Rust 1.70 REQUIRED MODULE) | |||
| find_package(Rust 1.72 REQUIRED MODULE) | |||
| corrosion_import_crate(MANIFEST_PATH "${DORA_ROOT_DIR}/Cargo.toml" | |||
| CRATES | |||
| dora-node-api-c | |||
| @@ -1,22 +1,10 @@ | |||
| use dora_tracing::set_up_tracing; | |||
| use eyre::{bail, Context}; | |||
| use std::path::Path; | |||
| use tracing::metadata::LevelFilter; | |||
| use tracing_subscriber::Layer; | |||
| #[derive(Debug, Clone, clap::Parser)] | |||
| pub struct Args { | |||
| #[clap(long)] | |||
| pub run_dora_runtime: bool, | |||
| } | |||
| #[tokio::main] | |||
| async fn main() -> eyre::Result<()> { | |||
| let Args { run_dora_runtime } = clap::Parser::parse(); | |||
| if run_dora_runtime { | |||
| return tokio::task::block_in_place(dora_daemon::run_dora_runtime); | |||
| } | |||
| set_up_tracing().wrap_err("failed to set up tracing")?; | |||
| set_up_tracing("cmake-dataflow-runner").wrap_err("failed to set up tracing")?; | |||
| if cfg!(windows) { | |||
| tracing::error!( | |||
| @@ -52,7 +40,7 @@ async fn main() -> eyre::Result<()> { | |||
| let dataflow = Path::new("dataflow.yml").to_owned(); | |||
| build_package("dora-runtime").await?; | |||
| dora_daemon::Daemon::run_dataflow(&dataflow).await?; | |||
| run_dataflow(&dataflow).await?; | |||
| Ok(()) | |||
| } | |||
| @@ -68,13 +56,17 @@ async fn build_package(package: &str) -> eyre::Result<()> { | |||
| Ok(()) | |||
| } | |||
| fn set_up_tracing() -> eyre::Result<()> { | |||
| use tracing_subscriber::prelude::__tracing_subscriber_SubscriberExt; | |||
| let stdout_log = tracing_subscriber::fmt::layer() | |||
| .pretty() | |||
| .with_filter(LevelFilter::DEBUG); | |||
| let subscriber = tracing_subscriber::Registry::default().with(stdout_log); | |||
| tracing::subscriber::set_global_default(subscriber) | |||
| .context("failed to set tracing global subscriber") | |||
| async fn run_dataflow(dataflow: &Path) -> eyre::Result<()> { | |||
| let cargo = std::env::var("CARGO").unwrap(); | |||
| let mut cmd = tokio::process::Command::new(&cargo); | |||
| cmd.arg("run"); | |||
| cmd.arg("--package").arg("dora-cli"); | |||
| cmd.arg("--") | |||
| .arg("daemon") | |||
| .arg("--run-dataflow") | |||
| .arg(dataflow); | |||
| if !cmd.status().await?.success() { | |||
| bail!("failed to run dataflow"); | |||
| }; | |||
| Ok(()) | |||
| } | |||
| @@ -3,8 +3,9 @@ use dora_core::{ | |||
| descriptor::Descriptor, | |||
| topics::{ControlRequest, ControlRequestReply, DataflowId}, | |||
| }; | |||
| use dora_tracing::set_up_tracing; | |||
| use eyre::{bail, Context}; | |||
| use futures::stream; | |||
| use std::{ | |||
| collections::BTreeSet, | |||
| net::{Ipv4Addr, SocketAddr}, | |||
| @@ -19,24 +20,11 @@ use tokio::{ | |||
| task::JoinSet, | |||
| }; | |||
| use tokio_stream::wrappers::ReceiverStream; | |||
| use tracing::metadata::LevelFilter; | |||
| use tracing_subscriber::Layer; | |||
| use uuid::Uuid; | |||
| #[derive(Debug, Clone, clap::Parser)] | |||
| pub struct Args { | |||
| #[clap(long)] | |||
| pub run_dora_runtime: bool, | |||
| } | |||
| #[tokio::main] | |||
| async fn main() -> eyre::Result<()> { | |||
| let Args { run_dora_runtime } = clap::Parser::parse(); | |||
| if run_dora_runtime { | |||
| return tokio::task::block_in_place(dora_daemon::run_dora_runtime); | |||
| } | |||
| set_up_tracing().wrap_err("failed to set up tracing subscriber")?; | |||
| set_up_tracing("multiple-daemon-runner").wrap_err("failed to set up tracing subscriber")?; | |||
| let root = Path::new(env!("CARGO_MANIFEST_DIR")); | |||
| std::env::set_current_dir(root.join(file!()).parent().unwrap()) | |||
| @@ -46,14 +34,11 @@ async fn main() -> eyre::Result<()> { | |||
| build_dataflow(dataflow).await?; | |||
| let (coordinator_events_tx, coordinator_events_rx) = mpsc::channel(1); | |||
| let (coordinator_port, coordinator) = dora_coordinator::start( | |||
| dora_coordinator::Args { port: Some(0) }, | |||
| ReceiverStream::new(coordinator_events_rx), | |||
| ) | |||
| .await?; | |||
| let (coordinator_port, coordinator) = | |||
| dora_coordinator::start(None, ReceiverStream::new(coordinator_events_rx)).await?; | |||
| let coordinator_addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), coordinator_port); | |||
| let daemon_a = dora_daemon::Daemon::run(coordinator_addr, "A".into(), stream::empty()); | |||
| let daemon_b = dora_daemon::Daemon::run(coordinator_addr, "B".into(), stream::empty()); | |||
| let daemon_a = run_daemon(coordinator_addr.to_string(), "A".into()); | |||
| let daemon_b = run_daemon(coordinator_addr.to_string(), "B".into()); | |||
| tracing::info!("Spawning coordinator and daemons"); | |||
| let mut tasks = JoinSet::new(); | |||
| @@ -61,7 +46,6 @@ async fn main() -> eyre::Result<()> { | |||
| tasks.spawn(daemon_a); | |||
| tasks.spawn(daemon_b); | |||
| // wait until both daemons are connected | |||
| tracing::info!("waiting until daemons are connected to coordinator"); | |||
| let mut retries = 0; | |||
| loop { | |||
| @@ -212,13 +196,19 @@ async fn build_dataflow(dataflow: &Path) -> eyre::Result<()> { | |||
| Ok(()) | |||
| } | |||
| fn set_up_tracing() -> eyre::Result<()> { | |||
| use tracing_subscriber::prelude::__tracing_subscriber_SubscriberExt; | |||
| let stdout_log = tracing_subscriber::fmt::layer() | |||
| .pretty() | |||
| .with_filter(LevelFilter::TRACE); | |||
| let subscriber = tracing_subscriber::Registry::default().with(stdout_log); | |||
| tracing::subscriber::set_global_default(subscriber) | |||
| .context("failed to set tracing global subscriber") | |||
| async fn run_daemon(coordinator: String, machine_id: &str) -> eyre::Result<()> { | |||
| let cargo = std::env::var("CARGO").unwrap(); | |||
| let mut cmd = tokio::process::Command::new(&cargo); | |||
| cmd.arg("run"); | |||
| cmd.arg("--package").arg("dora-cli"); | |||
| cmd.arg("--") | |||
| .arg("daemon") | |||
| .arg("--machine-id") | |||
| .arg(machine_id) | |||
| .arg("--coordinator-addr") | |||
| .arg(coordinator); | |||
| if !cmd.status().await?.success() { | |||
| bail!("failed to run dataflow"); | |||
| }; | |||
| Ok(()) | |||
| } | |||
| @@ -1,14 +1,14 @@ | |||
| #!/usr/bin/env python3 | |||
| # -*- coding: utf-8 -*- | |||
| from typing import Callable | |||
| from dora import Node | |||
| import cv2 | |||
| import numpy as np | |||
| import torch | |||
| from ultralytics import YOLO | |||
| from dora import Node | |||
| import pyarrow as pa | |||
| model = torch.hub.load("ultralytics/yolov5", "yolov5n") | |||
| model = YOLO("yolov8n.pt") | |||
| node = Node() | |||
| @@ -22,9 +22,14 @@ for event in node: | |||
| frame = cv2.imdecode(frame, -1) | |||
| frame = frame[:, :, ::-1] # OpenCV image (BGR to RGB) | |||
| results = model(frame) # includes NMS | |||
| arrays = np.array(results.xyxy[0].cpu()).tobytes() | |||
| # Process results | |||
| boxes = np.array(results[0].boxes.xyxy.cpu()) | |||
| conf = np.array(results[0].boxes.conf.cpu()) | |||
| label = np.array(results[0].boxes.cls.cpu()) | |||
| # concatenate them together | |||
| arrays = np.concatenate((boxes, conf[:, None], label[:, None]), axis=1) | |||
| node.send_output("bbox", arrays, event["metadata"]) | |||
| node.send_output("bbox", pa.array(arrays.ravel()), event["metadata"]) | |||
| else: | |||
| print("[object detection] ignoring unexpected input:", event_id) | |||
| elif event_type == "STOP": | |||
| @@ -2,7 +2,6 @@ | |||
| # -*- coding: utf-8 -*- | |||
| import os | |||
| from typing import Callable | |||
| from dora import Node | |||
| from dora import DoraStatus | |||
| @@ -14,10 +14,8 @@ PyYAML>=5.3.1 | |||
| requests>=2.23.0 | |||
| scipy>=1.4.1 | |||
| thop>=0.1.1 # FLOPs computation | |||
| --extra-index-url https://download.pytorch.org/whl/cpu | |||
| torch>=1.7.0 # see https://pytorch.org/get-started/locally (recommended) | |||
| --extra-index-url https://download.pytorch.org/whl/cpu | |||
| torchvision>=0.8.1 | |||
| torch # see https://pytorch.org/get-started/locally (recommended) | |||
| torchvision | |||
| tqdm>=4.64.0 | |||
| # Logging ------------------------------------- | |||
| @@ -1,79 +1,102 @@ | |||
| use eyre::{ContextCompat, WrapErr}; | |||
| use dora_core::{get_pip_path, get_python_path, run}; | |||
| use dora_download::download_file; | |||
| use dora_tracing::set_up_tracing; | |||
| use eyre::{bail, ContextCompat, WrapErr}; | |||
| use std::path::Path; | |||
| use tracing_subscriber::{ | |||
| filter::{FilterExt, LevelFilter}, | |||
| prelude::*, | |||
| EnvFilter, Registry, | |||
| }; | |||
| #[tokio::main] | |||
| async fn main() -> eyre::Result<()> { | |||
| set_up_tracing()?; | |||
| set_up_tracing("python-dataflow-runner")?; | |||
| let root = Path::new(env!("CARGO_MANIFEST_DIR")); | |||
| std::env::set_current_dir(root.join(file!()).parent().unwrap()) | |||
| .wrap_err("failed to set working dir")?; | |||
| run(&["python3", "-m", "venv", "../.env"], None) | |||
| .await | |||
| .context("failed to create venv")?; | |||
| run( | |||
| get_python_path().context("Could not get python binary")?, | |||
| &["-m", "venv", "../.env"], | |||
| None, | |||
| ) | |||
| .await | |||
| .context("failed to create venv")?; | |||
| let venv = &root.join("examples").join(".env"); | |||
| std::env::set_var( | |||
| "VIRTUAL_ENV", | |||
| venv.to_str().context("venv path not valid unicode")?, | |||
| ); | |||
| let orig_path = std::env::var("PATH")?; | |||
| let venv_bin = venv.join("bin"); | |||
| std::env::set_var( | |||
| "PATH", | |||
| format!( | |||
| "{}:{orig_path}", | |||
| venv_bin.to_str().context("venv path not valid unicode")? | |||
| ), | |||
| ); | |||
| // bin folder is named Scripts on windows. | |||
| // 🤦♂️ See: https://github.com/pypa/virtualenv/commit/993ba1316a83b760370f5a3872b3f5ef4dd904c1 | |||
| let venv_bin = if cfg!(windows) { | |||
| venv.join("Scripts") | |||
| } else { | |||
| venv.join("bin") | |||
| }; | |||
| run(&["pip", "install", "--upgrade", "pip"], None) | |||
| .await | |||
| .context("failed to install pip")?; | |||
| run(&["pip", "install", "-r", "requirements.txt"], None) | |||
| .await | |||
| .context("pip install failed")?; | |||
| if cfg!(windows) { | |||
| std::env::set_var( | |||
| "PATH", | |||
| format!( | |||
| "{};{orig_path}", | |||
| venv_bin.to_str().context("venv path not valid unicode")? | |||
| ), | |||
| ); | |||
| } else { | |||
| std::env::set_var( | |||
| "PATH", | |||
| format!( | |||
| "{}:{orig_path}", | |||
| venv_bin.to_str().context("venv path not valid unicode")? | |||
| ), | |||
| ); | |||
| } | |||
| run( | |||
| get_python_path().context("Could not get pip binary")?, | |||
| &["-m", "pip", "install", "--upgrade", "pip"], | |||
| None, | |||
| ) | |||
| .await | |||
| .context("failed to install pip")?; | |||
| run( | |||
| get_pip_path().context("Could not get pip binary")?, | |||
| &["install", "-r", "requirements.txt"], | |||
| None, | |||
| ) | |||
| .await | |||
| .context("pip install failed")?; | |||
| run( | |||
| &["maturin", "develop"], | |||
| "maturin", | |||
| &["develop"], | |||
| Some(&root.join("apis").join("python").join("node")), | |||
| ) | |||
| .await | |||
| .context("maturin develop failed")?; | |||
| download_file( | |||
| "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt", | |||
| Path::new("yolov8n.pt"), | |||
| ) | |||
| .await | |||
| .context("Could not download weights.")?; | |||
| let dataflow = Path::new("dataflow.yml"); | |||
| dora_daemon::Daemon::run_dataflow(dataflow).await?; | |||
| run_dataflow(dataflow).await?; | |||
| Ok(()) | |||
| } | |||
| async fn run(cmd: &[&str], pwd: Option<&Path>) -> eyre::Result<()> { | |||
| let mut run = tokio::process::Command::new(cmd[0]); | |||
| run.args(&cmd[1..]); | |||
| if let Some(pwd) = pwd { | |||
| run.current_dir(pwd); | |||
| } | |||
| if !run.status().await?.success() { | |||
| eyre::bail!("failed to run {cmd:?}"); | |||
| async fn run_dataflow(dataflow: &Path) -> eyre::Result<()> { | |||
| let cargo = std::env::var("CARGO").unwrap(); | |||
| let mut cmd = tokio::process::Command::new(&cargo); | |||
| cmd.arg("run"); | |||
| cmd.arg("--package").arg("dora-cli"); | |||
| cmd.arg("--") | |||
| .arg("daemon") | |||
| .arg("--run-dataflow") | |||
| .arg(dataflow); | |||
| if !cmd.status().await?.success() { | |||
| bail!("failed to run dataflow"); | |||
| }; | |||
| Ok(()) | |||
| } | |||
| pub fn set_up_tracing() -> eyre::Result<()> { | |||
| // Filter log using `RUST_LOG`. More useful for CLI. | |||
| let filter = EnvFilter::from_default_env().or(LevelFilter::DEBUG); | |||
| let stdout_log = tracing_subscriber::fmt::layer() | |||
| .pretty() | |||
| .with_filter(filter); | |||
| let registry = Registry::default().with(stdout_log); | |||
| tracing::subscriber::set_global_default(registry) | |||
| .context("failed to set tracing global subscriber") | |||
| } | |||
| @@ -21,10 +21,3 @@ nodes: | |||
| inputs: | |||
| image: webcam/image | |||
| bbox: object_detection/bbox | |||
| - id: dora-record | |||
| custom: | |||
| build: cargo build -p dora-record | |||
| source: ../../target/debug/dora-record | |||
| inputs: | |||
| image: webcam/image | |||
| bbox: object_detection/bbox | |||
| @@ -2,12 +2,11 @@ | |||
| # -*- coding: utf-8 -*- | |||
| import cv2 | |||
| import numpy as np | |||
| import pyarrow as pa | |||
| import torch | |||
| from dora import DoraStatus | |||
| from ultralytics import YOLO | |||
| pa.array([]) | |||
| @@ -21,7 +20,7 @@ class Operator: | |||
| """ | |||
| def __init__(self): | |||
| self.model = torch.hub.load("ultralytics/yolov5", "yolov5n") | |||
| self.model = YOLO("yolov8n.pt") | |||
| def on_event( | |||
| self, | |||
| @@ -51,6 +50,12 @@ class Operator: | |||
| frame = dora_input["value"].to_numpy().reshape((CAMERA_HEIGHT, CAMERA_WIDTH, 3)) | |||
| frame = frame[:, :, ::-1] # OpenCV image (BGR to RGB) | |||
| results = self.model(frame) # includes NMS | |||
| arrays = pa.array(np.array(results.xyxy[0].cpu()).ravel()) | |||
| send_output("bbox", arrays, dora_input["metadata"]) | |||
| # Process results | |||
| boxes = np.array(results[0].boxes.xyxy.cpu()) | |||
| conf = np.array(results[0].boxes.conf.cpu()) | |||
| label = np.array(results[0].boxes.cls.cpu()) | |||
| # concatenate them together | |||
| arrays = np.concatenate((boxes, conf[:, None], label[:, None]), axis=1) | |||
| send_output("bbox", pa.array(arrays.ravel()), dora_input["metadata"]) | |||
| return DoraStatus.CONTINUE | |||
| @@ -3,19 +3,20 @@ | |||
| # Base ---------------------------------------- | |||
| ultralytics | |||
| gitpython | |||
| ipython # interactive notebook | |||
| matplotlib>=3.2.2 | |||
| numpy>=1.18.5 | |||
| opencv-python>=4.1.1 | |||
| Pillow>=7.1.2 | |||
| psutil # system resources | |||
| PyYAML>=5.3.1 | |||
| requests>=2.23.0 | |||
| scipy>=1.4.1 | |||
| --extra-index-url https://download.pytorch.org/whl/cpu | |||
| torch>=1.7.0 | |||
| --extra-index-url https://download.pytorch.org/whl/cpu | |||
| torchvision>=0.8.1 | |||
| thop>=0.1.1 # FLOPs computation | |||
| torch # see https://pytorch.org/get-started/locally (recommended) | |||
| torchvision | |||
| tqdm>=4.64.0 | |||
| protobuf<=3.20.1 # https://github.com/ultralytics/yolov5/issues/8012 | |||
| # Logging ------------------------------------- | |||
| tensorboard>=2.4.1 | |||
| @@ -38,13 +39,9 @@ seaborn>=0.11.0 | |||
| # openvino-dev # OpenVINO export | |||
| # Extras -------------------------------------- | |||
| ipython # interactive notebook | |||
| psutil # system utilization | |||
| thop>=0.1.1 # FLOPs computation | |||
| # albumentations>=1.0.3 | |||
| # pycocotools>=2.0 # COCO mAP | |||
| # roboflow | |||
| opencv-python>=4.1.1 | |||
| pyarrow | |||
| maturin | |||
| @@ -1,79 +1,102 @@ | |||
| use eyre::{ContextCompat, WrapErr}; | |||
| use dora_core::{get_pip_path, get_python_path, run}; | |||
| use dora_download::download_file; | |||
| use dora_tracing::set_up_tracing; | |||
| use eyre::{bail, ContextCompat, WrapErr}; | |||
| use std::path::Path; | |||
| use tracing_subscriber::{ | |||
| filter::{FilterExt, LevelFilter}, | |||
| prelude::*, | |||
| EnvFilter, Registry, | |||
| }; | |||
| #[tokio::main] | |||
| async fn main() -> eyre::Result<()> { | |||
| set_up_tracing()?; | |||
| set_up_tracing("python-operator-dataflow-runner")?; | |||
| let root = Path::new(env!("CARGO_MANIFEST_DIR")); | |||
| std::env::set_current_dir(root.join(file!()).parent().unwrap()) | |||
| .wrap_err("failed to set working dir")?; | |||
| run(&["python3", "-m", "venv", "../.env"], None) | |||
| .await | |||
| .context("failed to create venv")?; | |||
| run( | |||
| get_python_path().context("Could not get python binary")?, | |||
| &["-m", "venv", "../.env"], | |||
| None, | |||
| ) | |||
| .await | |||
| .context("failed to create venv")?; | |||
| let venv = &root.join("examples").join(".env"); | |||
| std::env::set_var( | |||
| "VIRTUAL_ENV", | |||
| venv.to_str().context("venv path not valid unicode")?, | |||
| ); | |||
| let orig_path = std::env::var("PATH")?; | |||
| let venv_bin = venv.join("bin"); | |||
| std::env::set_var( | |||
| "PATH", | |||
| format!( | |||
| "{}:{orig_path}", | |||
| venv_bin.to_str().context("venv path not valid unicode")? | |||
| ), | |||
| ); | |||
| // bin folder is named Scripts on windows. | |||
| // 🤦♂️ See: https://github.com/pypa/virtualenv/commit/993ba1316a83b760370f5a3872b3f5ef4dd904c1 | |||
| let venv_bin = if cfg!(windows) { | |||
| venv.join("Scripts") | |||
| } else { | |||
| venv.join("bin") | |||
| }; | |||
| run(&["pip", "install", "--upgrade", "pip"], None) | |||
| .await | |||
| .context("failed to install pip")?; | |||
| run(&["pip", "install", "-r", "requirements.txt"], None) | |||
| .await | |||
| .context("pip install failed")?; | |||
| if cfg!(windows) { | |||
| std::env::set_var( | |||
| "PATH", | |||
| format!( | |||
| "{};{orig_path}", | |||
| venv_bin.to_str().context("venv path not valid unicode")? | |||
| ), | |||
| ); | |||
| } else { | |||
| std::env::set_var( | |||
| "PATH", | |||
| format!( | |||
| "{}:{orig_path}", | |||
| venv_bin.to_str().context("venv path not valid unicode")? | |||
| ), | |||
| ); | |||
| } | |||
| run( | |||
| get_python_path().context("Could not get pip binary")?, | |||
| &["-m", "pip", "install", "--upgrade", "pip"], | |||
| None, | |||
| ) | |||
| .await | |||
| .context("failed to install pip")?; | |||
| run( | |||
| get_pip_path().context("Could not get pip binary")?, | |||
| &["install", "-r", "requirements.txt"], | |||
| None, | |||
| ) | |||
| .await | |||
| .context("pip install failed")?; | |||
| run( | |||
| &["maturin", "develop"], | |||
| "maturin", | |||
| &["develop"], | |||
| Some(&root.join("apis").join("python").join("node")), | |||
| ) | |||
| .await | |||
| .context("maturin develop failed")?; | |||
| download_file( | |||
| "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt", | |||
| Path::new("yolov8n.pt"), | |||
| ) | |||
| .await | |||
| .context("Could not download weights.")?; | |||
| let dataflow = Path::new("dataflow.yml"); | |||
| dora_daemon::Daemon::run_dataflow(dataflow).await?; | |||
| run_dataflow(dataflow).await?; | |||
| Ok(()) | |||
| } | |||
| async fn run(cmd: &[&str], pwd: Option<&Path>) -> eyre::Result<()> { | |||
| let mut run = tokio::process::Command::new(cmd[0]); | |||
| run.args(&cmd[1..]); | |||
| if let Some(pwd) = pwd { | |||
| run.current_dir(pwd); | |||
| } | |||
| if !run.status().await?.success() { | |||
| eyre::bail!("failed to run {cmd:?}"); | |||
| async fn run_dataflow(dataflow: &Path) -> eyre::Result<()> { | |||
| let cargo = std::env::var("CARGO").unwrap(); | |||
| let mut cmd = tokio::process::Command::new(&cargo); | |||
| cmd.arg("run"); | |||
| cmd.arg("--package").arg("dora-cli"); | |||
| cmd.arg("--") | |||
| .arg("daemon") | |||
| .arg("--run-dataflow") | |||
| .arg(dataflow); | |||
| if !cmd.status().await?.success() { | |||
| bail!("failed to run dataflow"); | |||
| }; | |||
| Ok(()) | |||
| } | |||
| pub fn set_up_tracing() -> eyre::Result<()> { | |||
| // Filter log using `RUST_LOG`. More useful for CLI. | |||
| let filter = EnvFilter::from_default_env().or(LevelFilter::DEBUG); | |||
| let stdout_log = tracing_subscriber::fmt::layer() | |||
| .pretty() | |||
| .with_filter(filter); | |||
| let registry = Registry::default().with(stdout_log); | |||
| tracing::subscriber::set_global_default(registry) | |||
| .context("failed to set tracing global subscriber") | |||
| } | |||
| @@ -21,14 +21,12 @@ topic_qos = dora.experimental.ros2_bridge.Ros2QosPolicies( | |||
| # Create a publisher to cmd_vel topic | |||
| turtle_twist_topic = ros2_node.create_topic( | |||
| "/turtle1/cmd_vel", "geometry_msgs::Twist", topic_qos | |||
| "/turtle1/cmd_vel", "geometry_msgs/Twist", topic_qos | |||
| ) | |||
| twist_writer = ros2_node.create_publisher(turtle_twist_topic) | |||
| # Create a listener to pose topic | |||
| turtle_pose_topic = ros2_node.create_topic( | |||
| "/turtle1/pose", "turtlesim::Pose", topic_qos | |||
| ) | |||
| turtle_pose_topic = ros2_node.create_topic("/turtle1/pose", "turtlesim/Pose", topic_qos) | |||
| pose_reader = ros2_node.create_subscription(turtle_pose_topic) | |||
| # Create a dora node | |||
| @@ -39,6 +37,12 @@ dora_node.merge_external_events(pose_reader) | |||
| print("looping", flush=True) | |||
| # take track of minimum and maximum coordinates of turtle | |||
| min_x = 1000 | |||
| max_x = 0 | |||
| min_y = 1000 | |||
| max_y = 0 | |||
| for i in range(500): | |||
| event = dora_node.next() | |||
| if event is None: | |||
| @@ -55,8 +59,10 @@ for i in range(500): | |||
| # ROS2 Event | |||
| elif event_kind == "external": | |||
| pose = event.inner()[0].as_py() | |||
| if i == CHECK_TICK: | |||
| assert ( | |||
| pose["x"] != 5.544444561004639 | |||
| ), "turtle should not be at initial x axis" | |||
| min_x = min([min_x, pose["x"]]) | |||
| max_x = max([max_x, pose["x"]]) | |||
| min_y = min([min_y, pose["y"]]) | |||
| max_y = max([max_y, pose["y"]]) | |||
| dora_node.send_output("turtle_pose", event.inner()) | |||
| assert max_x - min_x > 1 or max_y - min_y > 1, "no turtle movement" | |||
| @@ -1,81 +1,95 @@ | |||
| use eyre::{ContextCompat, WrapErr}; | |||
| use dora_core::{get_pip_path, get_python_path, run}; | |||
| use dora_tracing::set_up_tracing; | |||
| use eyre::{bail, ContextCompat, WrapErr}; | |||
| use std::path::Path; | |||
| use tracing_subscriber::{ | |||
| filter::{FilterExt, LevelFilter}, | |||
| prelude::*, | |||
| EnvFilter, Registry, | |||
| }; | |||
| #[tokio::main] | |||
| async fn main() -> eyre::Result<()> { | |||
| set_up_tracing()?; | |||
| set_up_tracing("python-ros2-dataflow-runner")?; | |||
| let root = Path::new(env!("CARGO_MANIFEST_DIR")); | |||
| std::env::set_current_dir(root.join(file!()).parent().unwrap()) | |||
| .wrap_err("failed to set working dir")?; | |||
| run(&["python3", "-m", "venv", "../.env"], None) | |||
| .await | |||
| .context("failed to create venv")?; | |||
| run( | |||
| get_python_path().context("Could not get python binary")?, | |||
| &["-m", "venv", "../.env"], | |||
| None, | |||
| ) | |||
| .await | |||
| .context("failed to create venv")?; | |||
| let venv = &root.join("examples").join(".env"); | |||
| std::env::set_var( | |||
| "VIRTUAL_ENV", | |||
| venv.to_str() | |||
| .context("venv path not valid unicode")? | |||
| .to_owned(), | |||
| venv.to_str().context("venv path not valid unicode")?, | |||
| ); | |||
| let orig_path = std::env::var("PATH")?; | |||
| let venv_bin = venv.join("bin"); | |||
| std::env::set_var( | |||
| "PATH", | |||
| format!( | |||
| "{}:{orig_path}", | |||
| venv_bin.to_str().context("venv path not valid unicode")? | |||
| ), | |||
| ); | |||
| // bin folder is named Scripts on windows. | |||
| // 🤦♂️ See: https://github.com/pypa/virtualenv/commit/993ba1316a83b760370f5a3872b3f5ef4dd904c1 | |||
| let venv_bin = if cfg!(windows) { | |||
| venv.join("Scripts") | |||
| } else { | |||
| venv.join("bin") | |||
| }; | |||
| if cfg!(windows) { | |||
| std::env::set_var( | |||
| "PATH", | |||
| format!( | |||
| "{};{orig_path}", | |||
| venv_bin.to_str().context("venv path not valid unicode")? | |||
| ), | |||
| ); | |||
| } else { | |||
| std::env::set_var( | |||
| "PATH", | |||
| format!( | |||
| "{}:{orig_path}", | |||
| venv_bin.to_str().context("venv path not valid unicode")? | |||
| ), | |||
| ); | |||
| } | |||
| run(&["pip", "install", "--upgrade", "pip"], None) | |||
| .await | |||
| .context("failed to install pip")?; | |||
| run(&["pip", "install", "-r", "requirements.txt"], None) | |||
| .await | |||
| .context("pip install failed")?; | |||
| run( | |||
| get_python_path().context("Could not get pip binary")?, | |||
| &["-m", "pip", "install", "--upgrade", "pip"], | |||
| None, | |||
| ) | |||
| .await | |||
| .context("failed to install pip")?; | |||
| run( | |||
| get_pip_path().context("Could not get pip binary")?, | |||
| &["install", "-r", "requirements.txt"], | |||
| None, | |||
| ) | |||
| .await | |||
| .context("pip install failed")?; | |||
| run( | |||
| &["maturin", "develop"], | |||
| "maturin", | |||
| &["develop"], | |||
| Some(&root.join("apis").join("python").join("node")), | |||
| ) | |||
| .await | |||
| .context("maturin develop failed")?; | |||
| let dataflow = Path::new("dataflow.yml"); | |||
| dora_daemon::Daemon::run_dataflow(dataflow).await?; | |||
| run_dataflow(dataflow).await?; | |||
| Ok(()) | |||
| } | |||
| async fn run(cmd: &[&str], pwd: Option<&Path>) -> eyre::Result<()> { | |||
| let mut run = tokio::process::Command::new(cmd[0]); | |||
| run.args(&cmd[1..]); | |||
| if let Some(pwd) = pwd { | |||
| run.current_dir(pwd); | |||
| } | |||
| if !run.status().await?.success() { | |||
| eyre::bail!("failed to run {cmd:?}"); | |||
| async fn run_dataflow(dataflow: &Path) -> eyre::Result<()> { | |||
| let cargo = std::env::var("CARGO").unwrap(); | |||
| let mut cmd = tokio::process::Command::new(&cargo); | |||
| cmd.arg("run"); | |||
| cmd.arg("--package").arg("dora-cli"); | |||
| cmd.arg("--") | |||
| .arg("daemon") | |||
| .arg("--run-dataflow") | |||
| .arg(dataflow); | |||
| if !cmd.status().await?.success() { | |||
| bail!("failed to run dataflow"); | |||
| }; | |||
| Ok(()) | |||
| } | |||
| pub fn set_up_tracing() -> eyre::Result<()> { | |||
| // Filter log using `RUST_LOG`. More useful for CLI. | |||
| let filter = EnvFilter::from_default_env().or(LevelFilter::DEBUG); | |||
| let stdout_log = tracing_subscriber::fmt::layer() | |||
| .pretty() | |||
| .with_filter(filter); | |||
| let registry = Registry::default().with(stdout_log); | |||
| tracing::subscriber::set_global_default(registry) | |||
| .context("failed to set tracing global subscriber") | |||
| } | |||
| @@ -1,22 +1,10 @@ | |||
| use dora_tracing::set_up_tracing; | |||
| use eyre::{bail, Context}; | |||
| use std::path::Path; | |||
| use tracing::metadata::LevelFilter; | |||
| use tracing_subscriber::Layer; | |||
| #[derive(Debug, Clone, clap::Parser)] | |||
| pub struct Args { | |||
| #[clap(long)] | |||
| pub run_dora_runtime: bool, | |||
| } | |||
| #[tokio::main] | |||
| async fn main() -> eyre::Result<()> { | |||
| let Args { run_dora_runtime } = clap::Parser::parse(); | |||
| if run_dora_runtime { | |||
| return tokio::task::block_in_place(dora_daemon::run_dora_runtime); | |||
| } | |||
| set_up_tracing().wrap_err("failed to set up tracing")?; | |||
| set_up_tracing("rust-dataflow-url-runner").wrap_err("failed to set up tracing")?; | |||
| let root = Path::new(env!("CARGO_MANIFEST_DIR")); | |||
| std::env::set_current_dir(root.join(file!()).parent().unwrap()) | |||
| @@ -25,7 +13,7 @@ async fn main() -> eyre::Result<()> { | |||
| let dataflow = Path::new("dataflow.yml"); | |||
| build_dataflow(dataflow).await?; | |||
| dora_daemon::Daemon::run_dataflow(dataflow).await?; | |||
| run_dataflow(dataflow).await?; | |||
| Ok(()) | |||
| } | |||
| @@ -42,13 +30,17 @@ async fn build_dataflow(dataflow: &Path) -> eyre::Result<()> { | |||
| Ok(()) | |||
| } | |||
| fn set_up_tracing() -> eyre::Result<()> { | |||
| use tracing_subscriber::prelude::__tracing_subscriber_SubscriberExt; | |||
| let stdout_log = tracing_subscriber::fmt::layer() | |||
| .pretty() | |||
| .with_filter(LevelFilter::DEBUG); | |||
| let subscriber = tracing_subscriber::Registry::default().with(stdout_log); | |||
| tracing::subscriber::set_global_default(subscriber) | |||
| .context("failed to set tracing global subscriber") | |||
| async fn run_dataflow(dataflow: &Path) -> eyre::Result<()> { | |||
| let cargo = std::env::var("CARGO").unwrap(); | |||
| let mut cmd = tokio::process::Command::new(&cargo); | |||
| cmd.arg("run"); | |||
| cmd.arg("--package").arg("dora-cli"); | |||
| cmd.arg("--") | |||
| .arg("daemon") | |||
| .arg("--run-dataflow") | |||
| .arg(dataflow); | |||
| if !cmd.status().await?.success() { | |||
| bail!("failed to run dataflow"); | |||
| }; | |||
| Ok(()) | |||
| } | |||
| @@ -1,23 +1,10 @@ | |||
| use dora_tracing::set_up_tracing; | |||
| use eyre::{bail, Context}; | |||
| use std::path::Path; | |||
| use tracing::metadata::LevelFilter; | |||
| use tracing_subscriber::Layer; | |||
| #[derive(Debug, Clone, clap::Parser)] | |||
| pub struct Args { | |||
| #[clap(long)] | |||
| pub run_dora_runtime: bool, | |||
| } | |||
| #[tokio::main] | |||
| async fn main() -> eyre::Result<()> { | |||
| let Args { run_dora_runtime } = clap::Parser::parse(); | |||
| if run_dora_runtime { | |||
| return tokio::task::block_in_place(dora_daemon::run_dora_runtime); | |||
| } | |||
| set_up_tracing().wrap_err("failed to set up tracing subscriber")?; | |||
| set_up_tracing("rust-dataflow-runner").wrap_err("failed to set up tracing subscriber")?; | |||
| let root = Path::new(env!("CARGO_MANIFEST_DIR")); | |||
| std::env::set_current_dir(root.join(file!()).parent().unwrap()) | |||
| @@ -26,7 +13,7 @@ async fn main() -> eyre::Result<()> { | |||
| let dataflow = Path::new("dataflow.yml"); | |||
| build_dataflow(dataflow).await?; | |||
| dora_daemon::Daemon::run_dataflow(dataflow).await?; | |||
| run_dataflow(dataflow).await?; | |||
| Ok(()) | |||
| } | |||
| @@ -43,13 +30,17 @@ async fn build_dataflow(dataflow: &Path) -> eyre::Result<()> { | |||
| Ok(()) | |||
| } | |||
| fn set_up_tracing() -> eyre::Result<()> { | |||
| use tracing_subscriber::prelude::__tracing_subscriber_SubscriberExt; | |||
| let stdout_log = tracing_subscriber::fmt::layer() | |||
| .pretty() | |||
| .with_filter(LevelFilter::DEBUG); | |||
| let subscriber = tracing_subscriber::Registry::default().with(stdout_log); | |||
| tracing::subscriber::set_global_default(subscriber) | |||
| .context("failed to set tracing global subscriber") | |||
| async fn run_dataflow(dataflow: &Path) -> eyre::Result<()> { | |||
| let cargo = std::env::var("CARGO").unwrap(); | |||
| let mut cmd = tokio::process::Command::new(&cargo); | |||
| cmd.arg("run"); | |||
| cmd.arg("--package").arg("dora-cli"); | |||
| cmd.arg("--") | |||
| .arg("daemon") | |||
| .arg("--run-dataflow") | |||
| .arg(dataflow); | |||
| if !cmd.status().await?.success() { | |||
| bail!("failed to run dataflow"); | |||
| }; | |||
| Ok(()) | |||
| } | |||
| @@ -1,23 +1,10 @@ | |||
| use dora_tracing::set_up_tracing; | |||
| use eyre::{bail, Context}; | |||
| use std::path::Path; | |||
| use tracing::metadata::LevelFilter; | |||
| use tracing_subscriber::Layer; | |||
| #[derive(Debug, Clone, clap::Parser)] | |||
| pub struct Args { | |||
| #[clap(long)] | |||
| pub run_dora_runtime: bool, | |||
| } | |||
| #[tokio::main] | |||
| async fn main() -> eyre::Result<()> { | |||
| let Args { run_dora_runtime } = clap::Parser::parse(); | |||
| if run_dora_runtime { | |||
| return tokio::task::block_in_place(dora_daemon::run_dora_runtime); | |||
| } | |||
| set_up_tracing().wrap_err("failed to set up tracing subscriber")?; | |||
| set_up_tracing("rust-ros2-dataflow-runner").wrap_err("failed to set up tracing subscriber")?; | |||
| let root = Path::new(env!("CARGO_MANIFEST_DIR")); | |||
| std::env::set_current_dir(root.join(file!()).parent().unwrap()) | |||
| @@ -26,7 +13,7 @@ async fn main() -> eyre::Result<()> { | |||
| let dataflow = Path::new("dataflow.yml"); | |||
| build_dataflow(dataflow).await?; | |||
| dora_daemon::Daemon::run_dataflow(dataflow).await?; | |||
| run_dataflow(dataflow).await?; | |||
| Ok(()) | |||
| } | |||
| @@ -43,13 +30,17 @@ async fn build_dataflow(dataflow: &Path) -> eyre::Result<()> { | |||
| Ok(()) | |||
| } | |||
| fn set_up_tracing() -> eyre::Result<()> { | |||
| use tracing_subscriber::prelude::__tracing_subscriber_SubscriberExt; | |||
| let stdout_log = tracing_subscriber::fmt::layer() | |||
| .pretty() | |||
| .with_filter(LevelFilter::DEBUG); | |||
| let subscriber = tracing_subscriber::Registry::default().with(stdout_log); | |||
| tracing::subscriber::set_global_default(subscriber) | |||
| .context("failed to set tracing global subscriber") | |||
| async fn run_dataflow(dataflow: &Path) -> eyre::Result<()> { | |||
| let cargo = std::env::var("CARGO").unwrap(); | |||
| let mut cmd = tokio::process::Command::new(&cargo); | |||
| cmd.arg("run"); | |||
| cmd.arg("--package").arg("dora-cli"); | |||
| cmd.arg("--") | |||
| .arg("daemon") | |||
| .arg("--run-dataflow") | |||
| .arg(dataflow); | |||
| if !cmd.status().await?.success() { | |||
| bail!("failed to run dataflow"); | |||
| }; | |||
| Ok(()) | |||
| } | |||
| @@ -39,7 +39,7 @@ impl TryFrom<&ArrowData> for u8 { | |||
| fn try_from(value: &ArrowData) -> Result<Self, Self::Error> { | |||
| let array = value | |||
| .as_primitive_opt::<arrow::datatypes::UInt8Type>() | |||
| .context("not a primitive array")?; | |||
| .context("not a primitive UInt8Type array")?; | |||
| extract_single_primitive(array) | |||
| } | |||
| } | |||
| @@ -48,7 +48,7 @@ impl TryFrom<&ArrowData> for u16 { | |||
| fn try_from(value: &ArrowData) -> Result<Self, Self::Error> { | |||
| let array = value | |||
| .as_primitive_opt::<arrow::datatypes::UInt16Type>() | |||
| .context("not a primitive array")?; | |||
| .context("not a primitive UInt16Type array")?; | |||
| extract_single_primitive(array) | |||
| } | |||
| } | |||
| @@ -57,7 +57,7 @@ impl TryFrom<&ArrowData> for u32 { | |||
| fn try_from(value: &ArrowData) -> Result<Self, Self::Error> { | |||
| let array = value | |||
| .as_primitive_opt::<arrow::datatypes::UInt32Type>() | |||
| .context("not a primitive array")?; | |||
| .context("not a primitive UInt32Type array")?; | |||
| extract_single_primitive(array) | |||
| } | |||
| } | |||
| @@ -66,7 +66,7 @@ impl TryFrom<&ArrowData> for u64 { | |||
| fn try_from(value: &ArrowData) -> Result<Self, Self::Error> { | |||
| let array = value | |||
| .as_primitive_opt::<arrow::datatypes::UInt64Type>() | |||
| .context("not a primitive array")?; | |||
| .context("not a primitive UInt64Type array")?; | |||
| extract_single_primitive(array) | |||
| } | |||
| } | |||
| @@ -75,7 +75,7 @@ impl TryFrom<&ArrowData> for i8 { | |||
| fn try_from(value: &ArrowData) -> Result<Self, Self::Error> { | |||
| let array = value | |||
| .as_primitive_opt::<arrow::datatypes::Int8Type>() | |||
| .context("not a primitive array")?; | |||
| .context("not a primitive Int8Type array")?; | |||
| extract_single_primitive(array) | |||
| } | |||
| } | |||
| @@ -84,7 +84,7 @@ impl TryFrom<&ArrowData> for i16 { | |||
| fn try_from(value: &ArrowData) -> Result<Self, Self::Error> { | |||
| let array = value | |||
| .as_primitive_opt::<arrow::datatypes::Int16Type>() | |||
| .context("not a primitive array")?; | |||
| .context("not a primitive Int16Type array")?; | |||
| extract_single_primitive(array) | |||
| } | |||
| } | |||
| @@ -93,7 +93,7 @@ impl TryFrom<&ArrowData> for i32 { | |||
| fn try_from(value: &ArrowData) -> Result<Self, Self::Error> { | |||
| let array = value | |||
| .as_primitive_opt::<arrow::datatypes::Int32Type>() | |||
| .context("not a primitive array")?; | |||
| .context("not a primitive Int32Type array")?; | |||
| extract_single_primitive(array) | |||
| } | |||
| } | |||
| @@ -102,7 +102,7 @@ impl TryFrom<&ArrowData> for i64 { | |||
| fn try_from(value: &ArrowData) -> Result<Self, Self::Error> { | |||
| let array = value | |||
| .as_primitive_opt::<arrow::datatypes::Int64Type>() | |||
| .context("not a primitive array")?; | |||
| .context("not a primitive Int64Type array")?; | |||
| extract_single_primitive(array) | |||
| } | |||
| } | |||
| @@ -127,8 +127,9 @@ impl<'a> TryFrom<&'a ArrowData> for &'a str { | |||
| impl<'a> TryFrom<&'a ArrowData> for &'a [u8] { | |||
| type Error = eyre::Report; | |||
| fn try_from(value: &'a ArrowData) -> Result<Self, Self::Error> { | |||
| let array: &PrimitiveArray<arrow::datatypes::UInt8Type> = | |||
| value.as_primitive_opt().wrap_err("not a primitive array")?; | |||
| let array: &PrimitiveArray<arrow::datatypes::UInt8Type> = value | |||
| .as_primitive_opt() | |||
| .wrap_err("not a primitive UInt8Type array")?; | |||
| if array.null_count() != 0 { | |||
| eyre::bail!("array has nulls"); | |||
| } | |||
| @@ -12,12 +12,11 @@ license.workspace = true | |||
| eyre = "0.6.8" | |||
| serde = { version = "1.0.136", features = ["derive"] } | |||
| serde_yaml = "0.9.11" | |||
| serde_bytes = "0.11.12" | |||
| once_cell = "1.13.0" | |||
| which = "4.3.0" | |||
| which = "5.0.0" | |||
| uuid = { version = "1.2.1", features = ["serde"] } | |||
| dora-message = { workspace = true } | |||
| tracing = "0.1" | |||
| serde-with-expand-env = "1.1.0" | |||
| tokio = { version = "1.24.1", features = ["fs"] } | |||
| tokio = { version = "1.24.1", features = ["fs", "process", "sync"] } | |||
| aligned-vec = { version = "0.5.0", features = ["serde"] } | |||
| @@ -247,7 +247,11 @@ pub enum DaemonCoordinatorReply { | |||
| SpawnResult(Result<(), String>), | |||
| ReloadResult(Result<(), String>), | |||
| StopResult(Result<(), String>), | |||
| DestroyResult(Result<(), String>), | |||
| DestroyResult { | |||
| result: Result<(), String>, | |||
| #[serde(skip)] | |||
| notify: Option<tokio::sync::oneshot::Sender<()>>, | |||
| }, | |||
| Logs(Result<Vec<u8>, String>), | |||
| } | |||
| @@ -2,6 +2,7 @@ use crate::{ | |||
| adjust_shared_library_path, | |||
| config::{DataId, Input, InputMapping, OperatorId, UserInputMapping}, | |||
| descriptor::{self, source_is_url, CoreNodeKind, OperatorSource}, | |||
| get_python_path, | |||
| }; | |||
| use eyre::{bail, eyre, Context}; | |||
| @@ -152,7 +153,7 @@ fn check_python_runtime() -> eyre::Result<()> { | |||
| // Check if python dora-rs is installed and match cli version | |||
| let reinstall_command = | |||
| format!("Please reinstall it with: `pip install dora-rs=={VERSION} --force`"); | |||
| let mut command = Command::new("python3"); | |||
| let mut command = Command::new(get_python_path().context("Could not get python binary")?); | |||
| command.args([ | |||
| "-c", | |||
| &format!( | |||
| @@ -1,6 +1,7 @@ | |||
| use eyre::{bail, eyre}; | |||
| use eyre::{bail, eyre, Context}; | |||
| use std::{ | |||
| env::consts::{DLL_PREFIX, DLL_SUFFIX}, | |||
| ffi::OsStr, | |||
| path::Path, | |||
| }; | |||
| @@ -30,3 +31,45 @@ pub fn adjust_shared_library_path(path: &Path) -> Result<std::path::PathBuf, eyr | |||
| let path = path.with_file_name(library_filename); | |||
| Ok(path) | |||
| } | |||
| // Search for python binary. | |||
| // Match `python` for windows and `python3` for other platforms. | |||
| pub fn get_python_path() -> Result<std::path::PathBuf, eyre::ErrReport> { | |||
| let python = if cfg!(windows) { | |||
| which::which("python") | |||
| .context("failed to find `python` or `python3`. Make sure that python is available.")? | |||
| } else { | |||
| which::which("python3") | |||
| .context("failed to find `python` or `python3`. Make sure that python is available.")? | |||
| }; | |||
| Ok(python) | |||
| } | |||
| // Search for pip binary. | |||
| // First search for `pip3` as for ubuntu <20, `pip` can resolves to `python2,7 -m pip` | |||
| // Then search for `pip`, this will resolve for windows to python3 -m pip. | |||
| pub fn get_pip_path() -> Result<std::path::PathBuf, eyre::ErrReport> { | |||
| let python = match which::which("pip3") { | |||
| Ok(python) => python, | |||
| Err(_) => which::which("pip") | |||
| .context("failed to find `pip3` or `pip`. Make sure that python is available.")?, | |||
| }; | |||
| Ok(python) | |||
| } | |||
| // Helper function to run a program | |||
| pub async fn run<S>(program: S, args: &[&str], pwd: Option<&Path>) -> eyre::Result<()> | |||
| where | |||
| S: AsRef<OsStr>, | |||
| { | |||
| let mut run = tokio::process::Command::new(program); | |||
| run.args(args); | |||
| if let Some(pwd) = pwd { | |||
| run.current_dir(pwd); | |||
| } | |||
| if !run.status().await?.success() { | |||
| eyre::bail!("failed to run {args:?}"); | |||
| }; | |||
| Ok(()) | |||
| } | |||
| @@ -11,7 +11,7 @@ license.workspace = true | |||
| [dependencies] | |||
| eyre = "0.6.8" | |||
| reqwest = { version = "0.11.12", default-features = false, features = [ | |||
| "rustls", | |||
| "rustls-tls", | |||
| ] } | |||
| tokio = { version = "1.24.2", features = ["fs"] } | |||
| tracing = "0.1.36" | |||
| @@ -10,6 +10,8 @@ dora-ros2-bridge-msg-gen = { path = "../msg-gen" } | |||
| pyo3 = { workspace = true, features = ["eyre", "abi3-py37", "serde"] } | |||
| eyre = "0.6" | |||
| serde = "1.0.166" | |||
| flume = "0.10.14" | |||
| arrow = { workspace = true, features = ["pyarrow"] } | |||
| futures = "0.3.28" | |||
| [dev-dependencies] | |||
| serde_assert = "0.7.1" | |||
| @@ -1,4 +1,5 @@ | |||
| use std::{ | |||
| borrow::Cow, | |||
| collections::HashMap, | |||
| path::{Path, PathBuf}, | |||
| sync::Arc, | |||
| @@ -6,7 +7,7 @@ use std::{ | |||
| use ::dora_ros2_bridge::{ros2_client, rustdds}; | |||
| use arrow::{ | |||
| array::ArrayData, | |||
| array::{make_array, ArrayData}, | |||
| pyarrow::{FromPyArrow, ToPyArrow}, | |||
| }; | |||
| use dora_ros2_bridge_msg_gen::types::Message; | |||
| @@ -17,7 +18,7 @@ use pyo3::{ | |||
| types::{PyDict, PyList, PyModule}, | |||
| PyAny, PyObject, PyResult, Python, | |||
| }; | |||
| use typed::{deserialize::TypedDeserializer, for_message, TypeInfo, TypedValue}; | |||
| use typed::{deserialize::StructDeserializer, TypeInfo, TypedValue}; | |||
| pub mod qos; | |||
| pub mod typed; | |||
| @@ -52,6 +53,7 @@ impl Ros2Context { | |||
| ament_prefix_path_parsed.split(':').map(Path::new).collect() | |||
| } | |||
| }; | |||
| let packages = dora_ros2_bridge_msg_gen::get_packages(&paths) | |||
| .map_err(|err| eyre!(err)) | |||
| .context("failed to parse ROS2 message types")?; | |||
| @@ -99,22 +101,24 @@ impl Ros2Node { | |||
| message_type: String, | |||
| qos: qos::Ros2QosPolicies, | |||
| ) -> eyre::Result<Ros2Topic> { | |||
| let (namespace_name, message_name) = message_type.split_once("::").with_context(|| { | |||
| format!( | |||
| "message type must be of form `package::type`, is `{}`", | |||
| message_type | |||
| ) | |||
| })?; | |||
| let (namespace_name, message_name) = | |||
| match (message_type.split_once("/"), message_type.split_once("::")) { | |||
| (Some(msg), None) => msg, | |||
| (None, Some(msg)) => msg, | |||
| _ => eyre::bail!("Expected message type in the format `namespace/message` or `namespace::message`, such as `std_msgs/UInt8` but got: {}", message_type), | |||
| }; | |||
| let message_type_name = ros2_client::MessageTypeName::new(namespace_name, message_name); | |||
| let topic_name = ros2_client::Name::parse(name) | |||
| .map_err(|err| eyre!("failed to parse ROS2 topic name: {err}"))?; | |||
| let topic = self | |||
| .node | |||
| .create_topic(&topic_name, message_type_name, &qos.into())?; | |||
| let type_info = | |||
| for_message(&self.messages, namespace_name, message_name).with_context(|| { | |||
| format!("failed to determine type info for message {namespace_name}/{message_name}") | |||
| })?; | |||
| let type_info = TypeInfo { | |||
| package_name: namespace_name.to_owned().into(), | |||
| message_name: message_name.to_owned().into(), | |||
| messages: self.messages.clone(), | |||
| }; | |||
| Ok(Ros2Topic { topic, type_info }) | |||
| } | |||
| @@ -143,7 +147,7 @@ impl Ros2Node { | |||
| .create_subscription(&topic.topic, qos.map(Into::into))?; | |||
| Ok(Ros2Subscription { | |||
| subscription: Some(subscription), | |||
| deserializer: TypedDeserializer::new(topic.type_info.clone()), | |||
| deserializer: StructDeserializer::new(Cow::Owned(topic.type_info.clone())), | |||
| }) | |||
| } | |||
| } | |||
| @@ -175,14 +179,14 @@ impl From<Ros2NodeOptions> for ros2_client::NodeOptions { | |||
| #[non_exhaustive] | |||
| pub struct Ros2Topic { | |||
| topic: rustdds::Topic, | |||
| type_info: TypeInfo, | |||
| type_info: TypeInfo<'static>, | |||
| } | |||
| #[pyclass] | |||
| #[non_exhaustive] | |||
| pub struct Ros2Publisher { | |||
| publisher: ros2_client::Publisher<TypedValue<'static>>, | |||
| type_info: TypeInfo, | |||
| type_info: TypeInfo<'static>, | |||
| } | |||
| #[pymethods] | |||
| @@ -209,7 +213,7 @@ impl Ros2Publisher { | |||
| //// add type info to ensure correct serialization (e.g. struct types | |||
| //// and map types need to be serialized differently) | |||
| let typed_value = TypedValue { | |||
| value: &value, | |||
| value: &make_array(value), | |||
| type_info: &self.type_info, | |||
| }; | |||
| @@ -224,7 +228,7 @@ impl Ros2Publisher { | |||
| #[pyclass] | |||
| #[non_exhaustive] | |||
| pub struct Ros2Subscription { | |||
| deserializer: TypedDeserializer, | |||
| deserializer: StructDeserializer<'static>, | |||
| subscription: Option<ros2_client::Subscription<ArrayData>>, | |||
| } | |||
| @@ -238,7 +242,7 @@ impl Ros2Subscription { | |||
| .take_seed(self.deserializer.clone()) | |||
| .context("failed to take next message from subscription")?; | |||
| let Some((value, _info)) = message else { | |||
| return Ok(None) | |||
| return Ok(None); | |||
| }; | |||
| let message = value.to_pyarrow(py)?; | |||
| @@ -263,7 +267,7 @@ impl Ros2Subscription { | |||
| } | |||
| pub struct Ros2SubscriptionStream { | |||
| deserializer: TypedDeserializer, | |||
| deserializer: StructDeserializer<'static>, | |||
| subscription: ros2_client::Subscription<ArrayData>, | |||
| } | |||
| @@ -1,397 +0,0 @@ | |||
| use super::TypeInfo; | |||
| use arrow::{ | |||
| array::{ | |||
| make_array, Array, ArrayData, BooleanBuilder, Float32Builder, Float64Builder, Int16Builder, | |||
| Int32Builder, Int64Builder, Int8Builder, ListArray, NullArray, StringBuilder, StructArray, | |||
| UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder, | |||
| }, | |||
| buffer::OffsetBuffer, | |||
| compute::concat, | |||
| datatypes::{DataType, Field, Fields}, | |||
| }; | |||
| use core::fmt; | |||
| use std::sync::Arc; | |||
| #[derive(Debug, Clone, PartialEq)] | |||
| pub struct TypedDeserializer { | |||
| type_info: TypeInfo, | |||
| } | |||
| impl TypedDeserializer { | |||
| pub fn new(type_info: TypeInfo) -> Self { | |||
| Self { type_info } | |||
| } | |||
| } | |||
| impl<'de> serde::de::DeserializeSeed<'de> for TypedDeserializer { | |||
| type Value = ArrayData; | |||
| fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error> | |||
| where | |||
| D: serde::Deserializer<'de>, | |||
| { | |||
| let data_type = self.type_info.data_type; | |||
| let value = match data_type.clone() { | |||
| DataType::Struct(fields) => { | |||
| /// Serde requires that struct and field names are known at | |||
| /// compile time with a `'static` lifetime, which is not | |||
| /// possible in this case. Thus, we need to use dummy names | |||
| /// instead. | |||
| /// | |||
| /// The actual names do not really matter because | |||
| /// the CDR format of ROS2 does not encode struct or field | |||
| /// names. | |||
| const DUMMY_STRUCT_NAME: &str = "struct"; | |||
| const DUMMY_FIELDS: &[&str] = &[""; 100]; | |||
| deserializer.deserialize_struct( | |||
| DUMMY_STRUCT_NAME, | |||
| &DUMMY_FIELDS[..fields.len()], | |||
| StructVisitor { | |||
| fields, | |||
| defaults: self.type_info.defaults, | |||
| }, | |||
| ) | |||
| } | |||
| DataType::List(field) => deserializer.deserialize_seq(ListVisitor { | |||
| field, | |||
| defaults: self.type_info.defaults, | |||
| }), | |||
| DataType::UInt8 => deserializer.deserialize_u8(PrimitiveValueVisitor), | |||
| DataType::UInt16 => deserializer.deserialize_u16(PrimitiveValueVisitor), | |||
| DataType::UInt32 => deserializer.deserialize_u32(PrimitiveValueVisitor), | |||
| DataType::UInt64 => deserializer.deserialize_u64(PrimitiveValueVisitor), | |||
| DataType::Int8 => deserializer.deserialize_i8(PrimitiveValueVisitor), | |||
| DataType::Int16 => deserializer.deserialize_i16(PrimitiveValueVisitor), | |||
| DataType::Int32 => deserializer.deserialize_i32(PrimitiveValueVisitor), | |||
| DataType::Int64 => deserializer.deserialize_i64(PrimitiveValueVisitor), | |||
| DataType::Float32 => deserializer.deserialize_f32(PrimitiveValueVisitor), | |||
| DataType::Float64 => deserializer.deserialize_f64(PrimitiveValueVisitor), | |||
| DataType::Utf8 => deserializer.deserialize_str(PrimitiveValueVisitor), | |||
| _ => todo!(), | |||
| }?; | |||
| debug_assert!( | |||
| value.data_type() == &data_type, | |||
| "Datatype does not correspond to default data type.\n Expected: {:#?} \n but got: {:#?}, with value: {:#?}", data_type, value.data_type(), value | |||
| ); | |||
| Ok(value) | |||
| } | |||
| } | |||
| /// Based on https://docs.rs/serde_yaml/0.9.22/src/serde_yaml/value/de.rs.html#14-121 | |||
| struct PrimitiveValueVisitor; | |||
| impl<'de> serde::de::Visitor<'de> for PrimitiveValueVisitor { | |||
| type Value = ArrayData; | |||
| fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { | |||
| formatter.write_str("a primitive value") | |||
| } | |||
| fn visit_bool<E>(self, b: bool) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = BooleanBuilder::new(); | |||
| array.append_value(b); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_i8<E>(self, u: i8) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = Int8Builder::new(); | |||
| array.append_value(u); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_i16<E>(self, u: i16) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = Int16Builder::new(); | |||
| array.append_value(u); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_i32<E>(self, u: i32) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = Int32Builder::new(); | |||
| array.append_value(u); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_i64<E>(self, i: i64) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = Int64Builder::new(); | |||
| array.append_value(i); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_u8<E>(self, u: u8) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = UInt8Builder::new(); | |||
| array.append_value(u); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_u16<E>(self, u: u16) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = UInt16Builder::new(); | |||
| array.append_value(u); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_u32<E>(self, u: u32) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = UInt32Builder::new(); | |||
| array.append_value(u); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_u64<E>(self, u: u64) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = UInt64Builder::new(); | |||
| array.append_value(u); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_f32<E>(self, f: f32) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = Float32Builder::new(); | |||
| array.append_value(f); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_f64<E>(self, f: f64) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = Float64Builder::new(); | |||
| array.append_value(f); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_str<E>(self, s: &str) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = StringBuilder::new(); | |||
| array.append_value(s); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_string<E>(self, s: String) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = StringBuilder::new(); | |||
| array.append_value(s); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_unit<E>(self) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let array = NullArray::new(0); | |||
| Ok(array.into()) | |||
| } | |||
| fn visit_none<E>(self) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let array = NullArray::new(0); | |||
| Ok(array.into()) | |||
| } | |||
| } | |||
| struct StructVisitor { | |||
| fields: Fields, | |||
| defaults: ArrayData, | |||
| } | |||
| impl<'de> serde::de::Visitor<'de> for StructVisitor { | |||
| type Value = ArrayData; | |||
| fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { | |||
| formatter.write_str("a struct encoded as sequence") | |||
| } | |||
| fn visit_seq<A>(self, mut data: A) -> Result<Self::Value, A::Error> | |||
| where | |||
| A: serde::de::SeqAccess<'de>, | |||
| { | |||
| let mut fields = vec![]; | |||
| let defaults: StructArray = self.defaults.clone().into(); | |||
| for field in self.fields.iter() { | |||
| let default = match defaults.column_by_name(field.name()) { | |||
| Some(value) => value.clone(), | |||
| None => { | |||
| return Err(serde::de::Error::custom(format!( | |||
| "missing field {} for deserialization", | |||
| &field.name() | |||
| ))) | |||
| } | |||
| }; | |||
| let value = match data.next_element_seed(TypedDeserializer { | |||
| type_info: TypeInfo { | |||
| data_type: field.data_type().clone(), | |||
| defaults: default.to_data(), | |||
| }, | |||
| })? { | |||
| Some(value) => make_array(value), | |||
| None => default, | |||
| }; | |||
| fields.push(( | |||
| // Recreate a new field as List(UInt8) can be converted to UInt8 | |||
| Arc::new(Field::new(field.name(), value.data_type().clone(), true)), | |||
| value, | |||
| )); | |||
| } | |||
| let struct_array: StructArray = fields.into(); | |||
| Ok(struct_array.into()) | |||
| } | |||
| } | |||
| struct ListVisitor { | |||
| field: Arc<Field>, | |||
| defaults: ArrayData, | |||
| } | |||
| impl<'de> serde::de::Visitor<'de> for ListVisitor { | |||
| type Value = ArrayData; | |||
| fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { | |||
| formatter.write_str("an array encoded as sequence") | |||
| } | |||
| fn visit_seq<A>(self, mut data: A) -> Result<Self::Value, A::Error> | |||
| where | |||
| A: serde::de::SeqAccess<'de>, | |||
| { | |||
| let data = match self.field.data_type().clone() { | |||
| DataType::UInt8 => { | |||
| let mut array = UInt8Builder::new(); | |||
| while let Some(value) = data.next_element::<u8>()? { | |||
| array.append_value(value); | |||
| } | |||
| Ok(array.finish().into()) | |||
| } | |||
| DataType::UInt16 => { | |||
| let mut array = UInt16Builder::new(); | |||
| while let Some(value) = data.next_element::<u16>()? { | |||
| array.append_value(value); | |||
| } | |||
| Ok(array.finish().into()) | |||
| } | |||
| DataType::UInt32 => { | |||
| let mut array = UInt32Builder::new(); | |||
| while let Some(value) = data.next_element::<u32>()? { | |||
| array.append_value(value); | |||
| } | |||
| Ok(array.finish().into()) | |||
| } | |||
| DataType::UInt64 => { | |||
| let mut array = UInt64Builder::new(); | |||
| while let Some(value) = data.next_element::<u64>()? { | |||
| array.append_value(value); | |||
| } | |||
| Ok(array.finish().into()) | |||
| } | |||
| DataType::Int8 => { | |||
| let mut array = Int8Builder::new(); | |||
| while let Some(value) = data.next_element::<i8>()? { | |||
| array.append_value(value); | |||
| } | |||
| Ok(array.finish().into()) | |||
| } | |||
| DataType::Int16 => { | |||
| let mut array = Int16Builder::new(); | |||
| while let Some(value) = data.next_element::<i16>()? { | |||
| array.append_value(value); | |||
| } | |||
| Ok(array.finish().into()) | |||
| } | |||
| DataType::Int32 => { | |||
| let mut array = Int32Builder::new(); | |||
| while let Some(value) = data.next_element::<i32>()? { | |||
| array.append_value(value); | |||
| } | |||
| Ok(array.finish().into()) | |||
| } | |||
| DataType::Int64 => { | |||
| let mut array = Int64Builder::new(); | |||
| while let Some(value) = data.next_element::<i64>()? { | |||
| array.append_value(value); | |||
| } | |||
| Ok(array.finish().into()) | |||
| } | |||
| DataType::Float32 => { | |||
| let mut array = Float32Builder::new(); | |||
| while let Some(value) = data.next_element::<f32>()? { | |||
| array.append_value(value); | |||
| } | |||
| Ok(array.finish().into()) | |||
| } | |||
| DataType::Float64 => { | |||
| let mut array = Float64Builder::new(); | |||
| while let Some(value) = data.next_element::<f64>()? { | |||
| array.append_value(value); | |||
| } | |||
| Ok(array.finish().into()) | |||
| } | |||
| DataType::Utf8 => { | |||
| let mut array = StringBuilder::new(); | |||
| while let Some(value) = data.next_element::<String>()? { | |||
| array.append_value(value); | |||
| } | |||
| Ok(array.finish().into()) | |||
| } | |||
| _ => { | |||
| let mut buffer = vec![]; | |||
| while let Some(value) = data.next_element_seed(TypedDeserializer { | |||
| type_info: TypeInfo { | |||
| data_type: self.field.data_type().clone(), | |||
| defaults: self.defaults.clone(), | |||
| }, | |||
| })? { | |||
| let element = make_array(value); | |||
| buffer.push(element); | |||
| } | |||
| concat( | |||
| buffer | |||
| .iter() | |||
| .map(|data| data.as_ref()) | |||
| .collect::<Vec<_>>() | |||
| .as_slice(), | |||
| ) | |||
| .map(|op| op.to_data()) | |||
| } | |||
| }; | |||
| if let Ok(values) = data { | |||
| let offsets = OffsetBuffer::new(vec![0, values.len() as i32].into()); | |||
| let array = ListArray::new(self.field, offsets, make_array(values), None).to_data(); | |||
| Ok(array) | |||
| } else { | |||
| Ok(self.defaults) // TODO: Better handle deserialization error | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,28 @@ | |||
| use arrow::array::ArrayData; | |||
| use dora_ros2_bridge_msg_gen::types::sequences; | |||
| use crate::typed::TypeInfo; | |||
| use super::sequence::SequenceVisitor; | |||
| pub struct ArrayDeserializer<'a> { | |||
| pub array_type: &'a sequences::Array, | |||
| pub type_info: &'a TypeInfo<'a>, | |||
| } | |||
| impl<'de> serde::de::DeserializeSeed<'de> for ArrayDeserializer<'_> { | |||
| type Value = ArrayData; | |||
| fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error> | |||
| where | |||
| D: serde::Deserializer<'de>, | |||
| { | |||
| deserializer.deserialize_tuple( | |||
| self.array_type.size, | |||
| SequenceVisitor { | |||
| item_type: &self.array_type.value_type, | |||
| type_info: self.type_info, | |||
| }, | |||
| ) | |||
| } | |||
| } | |||
| @@ -0,0 +1,163 @@ | |||
| use super::{TypeInfo, DUMMY_STRUCT_NAME}; | |||
| use arrow::{ | |||
| array::{make_array, ArrayData, StructArray}, | |||
| datatypes::Field, | |||
| }; | |||
| use core::fmt; | |||
| use std::{borrow::Cow, collections::HashMap, fmt::Display, sync::Arc}; | |||
| mod array; | |||
| mod primitive; | |||
| mod sequence; | |||
| mod string; | |||
| #[derive(Debug, Clone)] | |||
| pub struct StructDeserializer<'a> { | |||
| type_info: Cow<'a, TypeInfo<'a>>, | |||
| } | |||
| impl<'a> StructDeserializer<'a> { | |||
| pub fn new(type_info: Cow<'a, TypeInfo<'a>>) -> Self { | |||
| Self { type_info } | |||
| } | |||
| } | |||
| impl<'de> serde::de::DeserializeSeed<'de> for StructDeserializer<'_> { | |||
| type Value = ArrayData; | |||
| fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error> | |||
| where | |||
| D: serde::Deserializer<'de>, | |||
| { | |||
| let empty = HashMap::new(); | |||
| let package_messages = self | |||
| .type_info | |||
| .messages | |||
| .get(self.type_info.package_name.as_ref()) | |||
| .unwrap_or(&empty); | |||
| let message = package_messages | |||
| .get(self.type_info.message_name.as_ref()) | |||
| .ok_or_else(|| { | |||
| error(format!( | |||
| "could not find message type {}::{}", | |||
| self.type_info.package_name, self.type_info.message_name | |||
| )) | |||
| })?; | |||
| let visitor = StructVisitor { | |||
| type_info: self.type_info.as_ref(), | |||
| }; | |||
| deserializer.deserialize_tuple_struct(DUMMY_STRUCT_NAME, message.members.len(), visitor) | |||
| } | |||
| } | |||
| struct StructVisitor<'a> { | |||
| type_info: &'a TypeInfo<'a>, | |||
| } | |||
| impl<'a, 'de> serde::de::Visitor<'de> for StructVisitor<'a> { | |||
| type Value = ArrayData; | |||
| fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { | |||
| formatter.write_str("a struct encoded as TupleStruct") | |||
| } | |||
| fn visit_seq<A>(self, mut data: A) -> Result<Self::Value, A::Error> | |||
| where | |||
| A: serde::de::SeqAccess<'de>, | |||
| { | |||
| let empty = HashMap::new(); | |||
| let package_messages = self | |||
| .type_info | |||
| .messages | |||
| .get(self.type_info.package_name.as_ref()) | |||
| .unwrap_or(&empty); | |||
| let message = package_messages | |||
| .get(self.type_info.message_name.as_ref()) | |||
| .ok_or_else(|| { | |||
| error(format!( | |||
| "could not find message type {}::{}", | |||
| self.type_info.package_name, self.type_info.message_name | |||
| )) | |||
| })?; | |||
| let mut fields = vec![]; | |||
| for member in &message.members { | |||
| let value = match &member.r#type { | |||
| dora_ros2_bridge_msg_gen::types::MemberType::NestableType(t) => match t { | |||
| dora_ros2_bridge_msg_gen::types::primitives::NestableType::BasicType(t) => { | |||
| data.next_element_seed(primitive::PrimitiveDeserializer(t))? | |||
| } | |||
| dora_ros2_bridge_msg_gen::types::primitives::NestableType::NamedType(name) => { | |||
| data.next_element_seed(StructDeserializer { | |||
| type_info: Cow::Owned(TypeInfo { | |||
| package_name: Cow::Borrowed(&self.type_info.package_name), | |||
| message_name: Cow::Borrowed(&name.0), | |||
| messages: self.type_info.messages.clone(), | |||
| }), | |||
| })? | |||
| } | |||
| dora_ros2_bridge_msg_gen::types::primitives::NestableType::NamespacedType( | |||
| reference, | |||
| ) => { | |||
| if reference.namespace != "msg" { | |||
| return Err(error(format!( | |||
| "struct field {} references non-message type {reference:?}", | |||
| member.name | |||
| ))); | |||
| } | |||
| data.next_element_seed(StructDeserializer { | |||
| type_info: Cow::Owned(TypeInfo { | |||
| package_name: Cow::Borrowed(&reference.package), | |||
| message_name: Cow::Borrowed(&reference.name), | |||
| messages: self.type_info.messages.clone(), | |||
| }), | |||
| })? | |||
| } | |||
| dora_ros2_bridge_msg_gen::types::primitives::NestableType::GenericString(t) => { | |||
| match t { | |||
| dora_ros2_bridge_msg_gen::types::primitives::GenericString::String | dora_ros2_bridge_msg_gen::types::primitives::GenericString::BoundedString(_)=> { | |||
| data.next_element_seed(string::StringDeserializer)? | |||
| }, | |||
| dora_ros2_bridge_msg_gen::types::primitives::GenericString::WString => todo!("deserialize WString"), | |||
| dora_ros2_bridge_msg_gen::types::primitives::GenericString::BoundedWString(_) => todo!("deserialize BoundedWString"), | |||
| } | |||
| } | |||
| }, | |||
| dora_ros2_bridge_msg_gen::types::MemberType::Array(a) => { | |||
| data.next_element_seed(array::ArrayDeserializer{ array_type : a, type_info: self.type_info})? | |||
| }, | |||
| dora_ros2_bridge_msg_gen::types::MemberType::Sequence(s) => { | |||
| data.next_element_seed(sequence::SequenceDeserializer{item_type: &s.value_type, type_info: self.type_info})? | |||
| }, | |||
| dora_ros2_bridge_msg_gen::types::MemberType::BoundedSequence(s) => { | |||
| data.next_element_seed(sequence::SequenceDeserializer{ item_type: &s.value_type, type_info: self.type_info})? | |||
| }, | |||
| }; | |||
| let value = value.ok_or_else(|| { | |||
| error(format!( | |||
| "struct member {} not present in message", | |||
| member.name | |||
| )) | |||
| })?; | |||
| fields.push(( | |||
| Arc::new(Field::new(&member.name, value.data_type().clone(), true)), | |||
| make_array(value), | |||
| )); | |||
| } | |||
| let struct_array: StructArray = fields.into(); | |||
| Ok(struct_array.into()) | |||
| } | |||
| } | |||
| fn error<E, T>(e: T) -> E | |||
| where | |||
| T: Display, | |||
| E: serde::de::Error, | |||
| { | |||
| serde::de::Error::custom(e) | |||
| } | |||
| @@ -0,0 +1,155 @@ | |||
| use arrow::array::{ | |||
| ArrayData, BooleanBuilder, Float32Builder, Float64Builder, Int16Builder, Int32Builder, | |||
| Int64Builder, Int8Builder, NullArray, UInt16Builder, UInt32Builder, UInt64Builder, | |||
| UInt8Builder, | |||
| }; | |||
| use core::fmt; | |||
| use dora_ros2_bridge_msg_gen::types::primitives::BasicType; | |||
| pub struct PrimitiveDeserializer<'a>(pub &'a BasicType); | |||
| impl<'de> serde::de::DeserializeSeed<'de> for PrimitiveDeserializer<'_> { | |||
| type Value = ArrayData; | |||
| fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error> | |||
| where | |||
| D: serde::Deserializer<'de>, | |||
| { | |||
| match self.0 { | |||
| BasicType::I8 => deserializer.deserialize_i8(PrimitiveValueVisitor), | |||
| BasicType::I16 => deserializer.deserialize_i16(PrimitiveValueVisitor), | |||
| BasicType::I32 => deserializer.deserialize_i32(PrimitiveValueVisitor), | |||
| BasicType::I64 => deserializer.deserialize_i64(PrimitiveValueVisitor), | |||
| BasicType::U8 | BasicType::Char | BasicType::Byte => { | |||
| deserializer.deserialize_u8(PrimitiveValueVisitor) | |||
| } | |||
| BasicType::U16 => deserializer.deserialize_u16(PrimitiveValueVisitor), | |||
| BasicType::U32 => deserializer.deserialize_u32(PrimitiveValueVisitor), | |||
| BasicType::U64 => deserializer.deserialize_u64(PrimitiveValueVisitor), | |||
| BasicType::F32 => deserializer.deserialize_f32(PrimitiveValueVisitor), | |||
| BasicType::F64 => deserializer.deserialize_f64(PrimitiveValueVisitor), | |||
| BasicType::Bool => deserializer.deserialize_bool(PrimitiveValueVisitor), | |||
| } | |||
| } | |||
| } | |||
| /// Based on https://docs.rs/serde_yaml/0.9.22/src/serde_yaml/value/de.rs.html#14-121 | |||
| struct PrimitiveValueVisitor; | |||
| impl<'de> serde::de::Visitor<'de> for PrimitiveValueVisitor { | |||
| type Value = ArrayData; | |||
| fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { | |||
| formatter.write_str("a primitive value") | |||
| } | |||
| fn visit_bool<E>(self, b: bool) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = BooleanBuilder::new(); | |||
| array.append_value(b); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_i8<E>(self, u: i8) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = Int8Builder::new(); | |||
| array.append_value(u); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_i16<E>(self, u: i16) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = Int16Builder::new(); | |||
| array.append_value(u); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_i32<E>(self, u: i32) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = Int32Builder::new(); | |||
| array.append_value(u); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_i64<E>(self, i: i64) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = Int64Builder::new(); | |||
| array.append_value(i); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_u8<E>(self, u: u8) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = UInt8Builder::new(); | |||
| array.append_value(u); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_u16<E>(self, u: u16) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = UInt16Builder::new(); | |||
| array.append_value(u); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_u32<E>(self, u: u32) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = UInt32Builder::new(); | |||
| array.append_value(u); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_u64<E>(self, u: u64) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = UInt64Builder::new(); | |||
| array.append_value(u); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_f32<E>(self, f: f32) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = Float32Builder::new(); | |||
| array.append_value(f); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_f64<E>(self, f: f64) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = Float64Builder::new(); | |||
| array.append_value(f); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_unit<E>(self) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let array = NullArray::new(0); | |||
| Ok(array.into()) | |||
| } | |||
| fn visit_none<E>(self) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let array = NullArray::new(0); | |||
| Ok(array.into()) | |||
| } | |||
| } | |||
| @@ -0,0 +1,163 @@ | |||
| use arrow::{ | |||
| array::{ | |||
| Array, ArrayData, BooleanBuilder, ListArray, ListBuilder, PrimitiveBuilder, StringBuilder, | |||
| }, | |||
| buffer::OffsetBuffer, | |||
| datatypes::{self, ArrowPrimitiveType, Field}, | |||
| }; | |||
| use core::fmt; | |||
| use dora_ros2_bridge_msg_gen::types::primitives::{self, BasicType, NestableType}; | |||
| use serde::Deserialize; | |||
| use std::{borrow::Cow, ops::Deref, sync::Arc}; | |||
| use crate::typed::TypeInfo; | |||
| use super::{error, StructDeserializer}; | |||
/// `DeserializeSeed` implementation for ROS2 sequence/array fields.
pub struct SequenceDeserializer<'a> {
    /// Element type of the sequence (basic type, string, or nested message).
    pub item_type: &'a NestableType,
    /// Message/package context used to resolve nested message element types.
    pub type_info: &'a TypeInfo<'a>,
}
| impl<'de> serde::de::DeserializeSeed<'de> for SequenceDeserializer<'_> { | |||
| type Value = ArrayData; | |||
| fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error> | |||
| where | |||
| D: serde::Deserializer<'de>, | |||
| { | |||
| deserializer.deserialize_seq(SequenceVisitor { | |||
| item_type: self.item_type, | |||
| type_info: self.type_info, | |||
| }) | |||
| } | |||
| } | |||
/// Serde visitor that collects sequence elements into an Arrow array and wraps
/// the result into a single-entry list.
pub struct SequenceVisitor<'a> {
    /// Element type of the sequence being visited.
    pub item_type: &'a NestableType,
    /// Message/package context used to resolve nested message element types.
    pub type_info: &'a TypeInfo<'a>,
}
impl<'de> serde::de::Visitor<'de> for SequenceVisitor<'_> {
    type Value = ArrayData;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        write!(formatter, "a sequence")
    }

    /// Collects all sequence elements into an Arrow array, dispatching on the
    /// declared element type, and wraps the result into a list of length 1.
    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
    where
        A: serde::de::SeqAccess<'de>,
    {
        match &self.item_type {
            // Primitive elements: build the matching Arrow primitive array.
            NestableType::BasicType(t) => match t {
                BasicType::I8 => deserialize_primitive_seq::<_, datatypes::Int8Type>(seq),
                BasicType::I16 => deserialize_primitive_seq::<_, datatypes::Int16Type>(seq),
                BasicType::I32 => deserialize_primitive_seq::<_, datatypes::Int32Type>(seq),
                BasicType::I64 => deserialize_primitive_seq::<_, datatypes::Int64Type>(seq),
                // u8, char, and byte are all read as unsigned 8-bit values.
                BasicType::U8 | BasicType::Char | BasicType::Byte => {
                    deserialize_primitive_seq::<_, datatypes::UInt8Type>(seq)
                }
                BasicType::U16 => deserialize_primitive_seq::<_, datatypes::UInt16Type>(seq),
                BasicType::U32 => deserialize_primitive_seq::<_, datatypes::UInt32Type>(seq),
                BasicType::U64 => deserialize_primitive_seq::<_, datatypes::UInt64Type>(seq),
                BasicType::F32 => deserialize_primitive_seq::<_, datatypes::Float32Type>(seq),
                BasicType::F64 => deserialize_primitive_seq::<_, datatypes::Float64Type>(seq),
                // Booleans get a dedicated builder instead of the generic
                // primitive path.
                BasicType::Bool => {
                    let mut array = BooleanBuilder::new();
                    while let Some(value) = seq.next_element()? {
                        array.append_value(value);
                    }
                    // wrap array into list of length 1
                    let mut list = ListBuilder::new(array);
                    list.append(true);
                    Ok(list.finish().into())
                }
            },
            // Element type declared in the same package: deserialize each item
            // as a struct of that message type.
            NestableType::NamedType(name) => {
                let deserializer = StructDeserializer {
                    type_info: Cow::Owned(TypeInfo {
                        package_name: Cow::Borrowed(&self.type_info.package_name),
                        message_name: Cow::Borrowed(&name.0),
                        messages: self.type_info.messages.clone(),
                    }),
                };
                deserialize_struct_seq(&mut seq, deserializer)
            }
            // Fully qualified element type; only `msg` namespaces are valid
            // here (e.g. not `srv` or `action`).
            NestableType::NamespacedType(reference) => {
                if reference.namespace != "msg" {
                    return Err(error(format!(
                        "sequence item references non-message type {reference:?}",
                    )));
                }
                let deserializer = StructDeserializer {
                    type_info: Cow::Owned(TypeInfo {
                        package_name: Cow::Borrowed(&reference.package),
                        message_name: Cow::Borrowed(&reference.name),
                        messages: self.type_info.messages.clone(),
                    }),
                };
                deserialize_struct_seq(&mut seq, deserializer)
            }
            // String elements: bounded and unbounded UTF-8 strings share one
            // path; wide strings are not implemented yet.
            NestableType::GenericString(t) => match t {
                primitives::GenericString::String | primitives::GenericString::BoundedString(_) => {
                    let mut array = StringBuilder::new();
                    while let Some(value) = seq.next_element::<String>()? {
                        array.append_value(value);
                    }
                    // wrap array into list of length 1
                    let mut list = ListBuilder::new(array);
                    list.append(true);
                    Ok(list.finish().into())
                }
                primitives::GenericString::WString => todo!("deserialize sequence of WString"),
                primitives::GenericString::BoundedWString(_) => {
                    todo!("deserialize sequence of BoundedWString")
                }
            },
        }
    }
}
| fn deserialize_struct_seq<'de, A>( | |||
| seq: &mut A, | |||
| deserializer: StructDeserializer<'_>, | |||
| ) -> Result<ArrayData, <A as serde::de::SeqAccess<'de>>::Error> | |||
| where | |||
| A: serde::de::SeqAccess<'de>, | |||
| { | |||
| let mut values = Vec::new(); | |||
| while let Some(value) = seq.next_element_seed(deserializer.clone())? { | |||
| values.push(arrow::array::make_array(value)); | |||
| } | |||
| let refs: Vec<_> = values.iter().map(|a| a.deref()).collect(); | |||
| let concatenated = arrow::compute::concat(&refs).map_err(super::error)?; | |||
| let list = ListArray::try_new( | |||
| Arc::new(Field::new("item", concatenated.data_type().clone(), true)), | |||
| OffsetBuffer::from_lengths([concatenated.len()]), | |||
| Arc::new(concatenated), | |||
| None, | |||
| ) | |||
| .map_err(error)?; | |||
| Ok(list.to_data()) | |||
| } | |||
| fn deserialize_primitive_seq<'de, S, T>( | |||
| mut seq: S, | |||
| ) -> Result<ArrayData, <S as serde::de::SeqAccess<'de>>::Error> | |||
| where | |||
| S: serde::de::SeqAccess<'de>, | |||
| T: ArrowPrimitiveType, | |||
| T::Native: Deserialize<'de>, | |||
| { | |||
| let mut array = PrimitiveBuilder::<T>::new(); | |||
| while let Some(value) = seq.next_element::<T::Native>()? { | |||
| array.append_value(value); | |||
| } | |||
| // wrap array into list of length 1 | |||
| let mut list = ListBuilder::new(array); | |||
| list.append(true); | |||
| Ok(list.finish().into()) | |||
| } | |||
| @@ -0,0 +1,44 @@ | |||
| use arrow::array::{ArrayData, StringBuilder}; | |||
| use core::fmt; | |||
/// `DeserializeSeed` implementation that deserializes a single string value
/// into a one-element Arrow string array.
pub struct StringDeserializer;

impl<'de> serde::de::DeserializeSeed<'de> for StringDeserializer {
    type Value = ArrayData;

    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        deserializer.deserialize_str(StringVisitor)
    }
}
| /// Based on https://docs.rs/serde_yaml/0.9.22/src/serde_yaml/value/de.rs.html#14-121 | |||
| struct StringVisitor; | |||
| impl<'de> serde::de::Visitor<'de> for StringVisitor { | |||
| type Value = ArrayData; | |||
| fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { | |||
| formatter.write_str("a string value") | |||
| } | |||
| fn visit_str<E>(self, s: &str) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = StringBuilder::new(); | |||
| array.append_value(s); | |||
| Ok(array.finish().into()) | |||
| } | |||
| fn visit_string<E>(self, s: String) -> Result<Self::Value, E> | |||
| where | |||
| E: serde::de::Error, | |||
| { | |||
| let mut array = StringBuilder::new(); | |||
| array.append_value(s); | |||
| Ok(array.finish().into()) | |||
| } | |||
| } | |||
| @@ -1,277 +1,120 @@ | |||
| use arrow::{ | |||
| array::{ | |||
| make_array, Array, ArrayData, BooleanArray, Float32Array, Float64Array, Int16Array, | |||
| Int32Array, Int64Array, Int8Array, StringArray, StructArray, UInt16Array, UInt32Array, | |||
| UInt64Array, UInt8Array, | |||
| }, | |||
| buffer::Buffer, | |||
| compute::concat, | |||
| datatypes::{DataType, Field}, | |||
| }; | |||
| use dora_ros2_bridge_msg_gen::types::{ | |||
| primitives::{BasicType, NestableType}, | |||
| MemberType, Message, | |||
| }; | |||
| use eyre::{Context, ContextCompat, Result}; | |||
| use std::{collections::HashMap, sync::Arc}; | |||
| use dora_ros2_bridge_msg_gen::types::Message; | |||
| use std::{borrow::Cow, collections::HashMap, sync::Arc}; | |||
| pub use serialize::TypedValue; | |||
| pub mod deserialize; | |||
| pub mod serialize; | |||
| #[derive(Debug, Clone, PartialEq)] | |||
| pub struct TypeInfo { | |||
| data_type: DataType, | |||
| defaults: ArrayData, | |||
/// Identifies a ROS2 message type and carries the message definitions needed
/// to resolve it (including any nested message types).
#[derive(Debug, Clone)]
pub struct TypeInfo<'a> {
    /// ROS2 package that declares the message (e.g. `std_msgs`).
    pub package_name: Cow<'a, str>,
    /// Message name within the package.
    pub message_name: Cow<'a, str>,
    /// All known message definitions, keyed by package name, then message name.
    pub messages: Arc<HashMap<String, HashMap<String, Message>>>,
}
| pub fn for_message( | |||
| messages: &HashMap<String, HashMap<String, Message>>, | |||
| package_name: &str, | |||
| message_name: &str, | |||
| ) -> eyre::Result<TypeInfo> { | |||
| let empty = HashMap::new(); | |||
| let package_messages = messages.get(package_name).unwrap_or(&empty); | |||
| let message = package_messages | |||
| .get(message_name) | |||
| .context("unknown type name")?; | |||
| let default_struct_vec: Vec<(Arc<Field>, Arc<dyn Array>)> = message | |||
| .members | |||
| .iter() | |||
| .map(|m| { | |||
| let default = make_array(default_for_member(m, package_name, messages)?); | |||
| Result::<_, eyre::Report>::Ok(( | |||
| Arc::new(Field::new( | |||
| m.name.clone(), | |||
| default.data_type().clone(), | |||
| true, | |||
| )), | |||
| default, | |||
| )) | |||
| }) | |||
| .collect::<Result<_, _>>()?; | |||
| let default_struct: StructArray = default_struct_vec.into(); | |||
| Ok(TypeInfo { | |||
| data_type: default_struct.data_type().clone(), | |||
| defaults: default_struct.into(), | |||
| }) | |||
| } | |||
| pub fn default_for_member( | |||
| m: &dora_ros2_bridge_msg_gen::types::Member, | |||
| package_name: &str, | |||
| messages: &HashMap<String, HashMap<String, Message>>, | |||
| ) -> eyre::Result<ArrayData> { | |||
| let value = match &m.r#type { | |||
| MemberType::NestableType(t) => match t { | |||
| NestableType::BasicType(_) | NestableType::GenericString(_) => match &m | |||
| .default | |||
| .as_deref() | |||
| { | |||
| Some([]) => eyre::bail!("empty default value not supported"), | |||
| Some([default]) => preset_default_for_basic_type(t, default) | |||
| .with_context(|| format!("failed to parse default value for `{}`", m.name))?, | |||
| Some(_) => eyre::bail!( | |||
| "there should be only a single default value for non-sequence types" | |||
| ), | |||
| None => default_for_nestable_type(t, package_name, messages)?, | |||
| }, | |||
| NestableType::NamedType(_) => { | |||
| if m.default.is_some() { | |||
| eyre::bail!("default values for nested types are not supported") | |||
| } else { | |||
| default_for_nestable_type(t, package_name, messages)? | |||
| } | |||
/// Serde requires that struct and field names are known at
/// compile time with a `'static` lifetime, which is not
/// possible in this case. Thus, we need to use dummy names
/// instead.
///
/// The actual names do not really matter because
/// the CDR format of ROS2 does not encode struct or field
/// names.
const DUMMY_STRUCT_NAME: &str = "struct";
| #[cfg(test)] | |||
| mod tests { | |||
| use std::path::PathBuf; | |||
| use crate::typed::deserialize::StructDeserializer; | |||
| use crate::typed::serialize; | |||
| use crate::typed::TypeInfo; | |||
| use crate::Ros2Context; | |||
| use arrow::array::make_array; | |||
| use arrow::pyarrow::FromPyArrow; | |||
| use arrow::pyarrow::ToPyArrow; | |||
| use pyo3::types::IntoPyDict; | |||
| use pyo3::types::PyDict; | |||
| use pyo3::types::PyList; | |||
| use pyo3::types::PyModule; | |||
| use pyo3::types::PyTuple; | |||
| use pyo3::Python; | |||
| use serde::de::DeserializeSeed; | |||
| use serde::Serialize; | |||
| use serde_assert::Serializer; | |||
| use serialize::TypedValue; | |||
| use eyre::{Context, Result}; | |||
| use serde_assert::Deserializer; | |||
| #[test] | |||
| fn test_python_array_code() -> Result<()> { | |||
| pyo3::prepare_freethreaded_python(); | |||
| let context = Ros2Context::new(None).context("Could not create a context")?; | |||
| let messages = context.messages.clone(); | |||
| let serializer = Serializer::builder().build(); | |||
| Python::with_gil(|py| -> Result<()> { | |||
| let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); //.join("test_utils.py"); // Adjust this path as needed | |||
| // Add the Python module's directory to sys.path | |||
| py.run( | |||
| "import sys; sys.path.append(str(path))", | |||
| Some([("path", path)].into_py_dict(py)), | |||
| None, | |||
| )?; | |||
| let my_module = PyModule::import(py, "test_utils")?; | |||
| let arrays: &PyList = my_module.getattr("TEST_ARRAYS")?.extract()?; | |||
| for array_wrapper in arrays.iter() { | |||
| let arrays: &PyTuple = array_wrapper.extract()?; | |||
| let package_name: String = arrays.get_item(0)?.extract()?; | |||
| let message_name: String = arrays.get_item(1)?.extract()?; | |||
| println!("Checking {}::{}", package_name, message_name); | |||
| let in_pyarrow = arrays.get_item(2)?; | |||
| let array = arrow::array::ArrayData::from_pyarrow(in_pyarrow)?; | |||
| let type_info = TypeInfo { | |||
| package_name: package_name.into(), | |||
| message_name: message_name.clone().into(), | |||
| messages: messages.clone(), | |||
| }; | |||
| let typed_value = TypedValue { | |||
| value: &make_array(array.clone()), | |||
| type_info: &type_info.clone(), | |||
| }; | |||
| let typed_deserializer = | |||
| StructDeserializer::new(std::borrow::Cow::Owned(type_info)); | |||
| let tokens = typed_value.serialize(&serializer)?; | |||
| let mut deserializer = Deserializer::builder(tokens).build(); | |||
| let out_value = typed_deserializer | |||
| .deserialize(&mut deserializer) | |||
| .context("could not deserialize array")?; | |||
| let out_pyarrow = out_value.to_pyarrow(py)?; | |||
| let test_utils = PyModule::import(py, "test_utils")?; | |||
| let context = PyDict::new(py); | |||
| context.set_item("test_utils", test_utils)?; | |||
| context.set_item("in_pyarrow", in_pyarrow)?; | |||
| context.set_item("out_pyarrow", out_pyarrow)?; | |||
| let _ = py | |||
| .eval( | |||
| "test_utils.is_subset(in_pyarrow, out_pyarrow)", | |||
| Some(context), | |||
| None, | |||
| ) | |||
| .context("could not check if it is a subset")?; | |||
| } | |||
| NestableType::NamespacedType(_) => { | |||
| default_for_nestable_type(t, package_name, messages)? | |||
| } | |||
| }, | |||
| MemberType::Array(array) => { | |||
| list_default_values(m, &array.value_type, package_name, messages)? | |||
| } | |||
| MemberType::Sequence(seq) => { | |||
| list_default_values(m, &seq.value_type, package_name, messages)? | |||
| } | |||
| MemberType::BoundedSequence(seq) => { | |||
| list_default_values(m, &seq.value_type, package_name, messages)? | |||
| } | |||
| }; | |||
| Ok(value) | |||
| } | |||
| fn default_for_nestable_type( | |||
| t: &NestableType, | |||
| package_name: &str, | |||
| messages: &HashMap<String, HashMap<String, Message>>, | |||
| ) -> Result<ArrayData> { | |||
| let empty = HashMap::new(); | |||
| let package_messages = messages.get(package_name).unwrap_or(&empty); | |||
| let array = match t { | |||
| NestableType::BasicType(t) => match t { | |||
| BasicType::I8 => Int8Array::from(vec![0]).into(), | |||
| BasicType::I16 => Int16Array::from(vec![0]).into(), | |||
| BasicType::I32 => Int32Array::from(vec![0]).into(), | |||
| BasicType::I64 => Int64Array::from(vec![0]).into(), | |||
| BasicType::U8 => UInt8Array::from(vec![0]).into(), | |||
| BasicType::U16 => UInt16Array::from(vec![0]).into(), | |||
| BasicType::U32 => UInt32Array::from(vec![0]).into(), | |||
| BasicType::U64 => UInt64Array::from(vec![0]).into(), | |||
| BasicType::F32 => Float32Array::from(vec![0.]).into(), | |||
| BasicType::F64 => Float64Array::from(vec![0.]).into(), | |||
| BasicType::Char => StringArray::from(vec![""]).into(), | |||
| BasicType::Byte => UInt8Array::from(vec![0u8] as Vec<u8>).into(), | |||
| BasicType::Bool => BooleanArray::from(vec![false]).into(), | |||
| }, | |||
| NestableType::GenericString(_) => StringArray::from(vec![""]).into(), | |||
| NestableType::NamedType(name) => { | |||
| let referenced_message = package_messages | |||
| .get(&name.0) | |||
| .context("unknown referenced message")?; | |||
| default_for_referenced_message(referenced_message, package_name, messages)? | |||
| } | |||
| NestableType::NamespacedType(t) => { | |||
| let referenced_package_messages = messages.get(&t.package).unwrap_or(&empty); | |||
| let referenced_message = referenced_package_messages | |||
| .get(&t.name) | |||
| .context("unknown referenced message")?; | |||
| default_for_referenced_message(referenced_message, &t.package, messages)? | |||
| } | |||
| }; | |||
| Ok(array) | |||
| } | |||
| fn preset_default_for_basic_type(t: &NestableType, preset: &str) -> Result<ArrayData> { | |||
| Ok(match t { | |||
| NestableType::BasicType(t) => match t { | |||
| BasicType::I8 => Int8Array::from(vec![preset | |||
| .parse::<i8>() | |||
| .context("Could not parse preset default value")?]) | |||
| .into(), | |||
| BasicType::I16 => Int16Array::from(vec![preset | |||
| .parse::<i16>() | |||
| .context("Could not parse preset default value")?]) | |||
| .into(), | |||
| BasicType::I32 => Int32Array::from(vec![preset | |||
| .parse::<i32>() | |||
| .context("Could not parse preset default value")?]) | |||
| .into(), | |||
| BasicType::I64 => Int64Array::from(vec![preset | |||
| .parse::<i64>() | |||
| .context("Could not parse preset default value")?]) | |||
| .into(), | |||
| BasicType::U8 => UInt8Array::from(vec![preset | |||
| .parse::<u8>() | |||
| .context("Could not parse preset default value")?]) | |||
| .into(), | |||
| BasicType::U16 => UInt16Array::from(vec![preset | |||
| .parse::<u16>() | |||
| .context("Could not parse preset default value")?]) | |||
| .into(), | |||
| BasicType::U32 => UInt32Array::from(vec![preset | |||
| .parse::<u32>() | |||
| .context("Could not parse preset default value")?]) | |||
| .into(), | |||
| BasicType::U64 => UInt64Array::from(vec![preset | |||
| .parse::<u64>() | |||
| .context("Could not parse preset default value")?]) | |||
| .into(), | |||
| BasicType::F32 => Float32Array::from(vec![preset | |||
| .parse::<f32>() | |||
| .context("Could not parse preset default value")?]) | |||
| .into(), | |||
| BasicType::F64 => Float64Array::from(vec![preset | |||
| .parse::<f64>() | |||
| .context("Could not parse preset default value")?]) | |||
| .into(), | |||
| BasicType::Char => StringArray::from(vec![preset]).into(), | |||
| BasicType::Byte => UInt8Array::from(preset.as_bytes().to_owned()).into(), | |||
| BasicType::Bool => BooleanArray::from(vec![preset | |||
| .parse::<bool>() | |||
| .context("could not parse preset default value")?]) | |||
| .into(), | |||
| }, | |||
| NestableType::GenericString(_) => StringArray::from(vec![preset]).into(), | |||
| _ => todo!(), | |||
| }) | |||
| } | |||
| fn default_for_referenced_message( | |||
| referenced_message: &Message, | |||
| package_name: &str, | |||
| messages: &HashMap<String, HashMap<String, Message>>, | |||
| ) -> eyre::Result<ArrayData> { | |||
| let fields: Vec<(Arc<Field>, Arc<dyn Array>)> = referenced_message | |||
| .members | |||
| .iter() | |||
| .map(|m| { | |||
| let default = default_for_member(m, package_name, messages)?; | |||
| Result::<_, eyre::Report>::Ok(( | |||
| Arc::new(Field::new( | |||
| m.name.clone(), | |||
| default.data_type().clone(), | |||
| true, | |||
| )), | |||
| make_array(default), | |||
| )) | |||
| Ok(()) | |||
| }) | |||
| .collect::<Result<_, _>>()?; | |||
| let struct_array: StructArray = fields.into(); | |||
| Ok(struct_array.into()) | |||
| } | |||
| fn list_default_values( | |||
| m: &dora_ros2_bridge_msg_gen::types::Member, | |||
| value_type: &NestableType, | |||
| package_name: &str, | |||
| messages: &HashMap<String, HashMap<String, Message>>, | |||
| ) -> Result<ArrayData> { | |||
| let defaults = match &m.default.as_deref() { | |||
| Some([]) => eyre::bail!("empty default value not supported"), | |||
| Some(defaults) => { | |||
| let raw_array: Vec<Arc<dyn Array>> = defaults | |||
| .iter() | |||
| .map(|default| { | |||
| preset_default_for_basic_type(value_type, default) | |||
| .with_context(|| format!("failed to parse default value for `{}`", m.name)) | |||
| .map(make_array) | |||
| }) | |||
| .collect::<Result<_, _>>()?; | |||
| let default_values = concat( | |||
| raw_array | |||
| .iter() | |||
| .map(|data| data.as_ref()) | |||
| .collect::<Vec<_>>() | |||
| .as_slice(), | |||
| ) | |||
| .context("Failed to concatenate default list value")?; | |||
| default_values.to_data() | |||
| } | |||
| None => { | |||
| let default_nested_type = | |||
| default_for_nestable_type(value_type, package_name, messages)?; | |||
| let value_offsets = Buffer::from_slice_ref([0i64, 1]); | |||
| let list_data_type = DataType::List(Arc::new(Field::new( | |||
| &m.name, | |||
| default_nested_type.data_type().clone(), | |||
| true, | |||
| ))); | |||
| // Construct a list array from the above two | |||
| ArrayData::builder(list_data_type) | |||
| .len(1) | |||
| .add_buffer(value_offsets) | |||
| .add_child_data(default_nested_type) | |||
| .build() | |||
| .context("Failed to build default list value")? | |||
| } | |||
| }; | |||
| Ok(defaults) | |||
| } | |||
| } | |||
| @@ -1,158 +0,0 @@ | |||
| use arrow::array::ArrayData; | |||
| use arrow::array::Float32Array; | |||
| use arrow::array::Float64Array; | |||
| use arrow::array::Int16Array; | |||
| use arrow::array::Int32Array; | |||
| use arrow::array::Int64Array; | |||
| use arrow::array::Int8Array; | |||
| use arrow::array::ListArray; | |||
| use arrow::array::StringArray; | |||
| use arrow::array::StructArray; | |||
| use arrow::array::UInt16Array; | |||
| use arrow::array::UInt32Array; | |||
| use arrow::array::UInt64Array; | |||
| use arrow::array::UInt8Array; | |||
| use arrow::datatypes::DataType; | |||
| use serde::ser::SerializeSeq; | |||
| use serde::ser::SerializeStruct; | |||
| use super::TypeInfo; | |||
/// Couples an Arrow value with its type information so that it can be
/// serialized into the ROS2 wire representation.
#[derive(Debug, Clone, PartialEq)]
pub struct TypedValue<'a> {
    /// The Arrow-encoded value to serialize.
    pub value: &'a ArrayData,
    /// Expected data type and per-field defaults for `value`.
    pub type_info: &'a TypeInfo,
}
impl serde::Serialize for TypedValue<'_> {
    /// Serializes the wrapped Arrow array according to its declared
    /// `DataType`.
    ///
    /// Scalar arrays are serialized as plain values taken from element 0,
    /// lists as serde sequences, and structs as serde structs, recursing with
    /// the corresponding child type info and defaults.
    ///
    /// NOTE(review): every scalar arm reads `value(0)`, which panics on an
    /// empty array — assumes scalar arrays always hold at least one element;
    /// confirm against the callers that build `TypedValue`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match &self.type_info.data_type {
            DataType::UInt8 => {
                let uint_array: UInt8Array = self.value.clone().into();
                let number = uint_array.value(0);
                serializer.serialize_u8(number)
            }
            DataType::UInt16 => {
                let uint_array: UInt16Array = self.value.clone().into();
                let number = uint_array.value(0);
                serializer.serialize_u16(number)
            }
            DataType::UInt32 => {
                let uint_array: UInt32Array = self.value.clone().into();
                let number = uint_array.value(0);
                serializer.serialize_u32(number)
            }
            DataType::UInt64 => {
                let uint_array: UInt64Array = self.value.clone().into();
                let number = uint_array.value(0);
                serializer.serialize_u64(number)
            }
            DataType::Int8 => {
                let int_array: Int8Array = self.value.clone().into();
                let number = int_array.value(0);
                serializer.serialize_i8(number)
            }
            DataType::Int16 => {
                let int_array: Int16Array = self.value.clone().into();
                let number = int_array.value(0);
                serializer.serialize_i16(number)
            }
            DataType::Int32 => {
                let int_array: Int32Array = self.value.clone().into();
                let number = int_array.value(0);
                serializer.serialize_i32(number)
            }
            DataType::Int64 => {
                let int_array: Int64Array = self.value.clone().into();
                let number = int_array.value(0);
                serializer.serialize_i64(number)
            }
            DataType::Float32 => {
                let int_array: Float32Array = self.value.clone().into();
                let number = int_array.value(0);
                serializer.serialize_f32(number)
            }
            DataType::Float64 => {
                let int_array: Float64Array = self.value.clone().into();
                let number = int_array.value(0);
                serializer.serialize_f64(number)
            }
            DataType::Utf8 => {
                let int_array: StringArray = self.value.clone().into();
                let string = int_array.value(0);
                serializer.serialize_str(string)
            }
            DataType::List(field) => {
                let list_array: ListArray = self.value.clone().into();
                let values = list_array.values();
                // NOTE(review): the length hint is the total number of child
                // values, while the loop below emits one element per list row.
                // These only agree for a single-row list — confirm against the
                // serializer's use of the hint.
                let mut s = serializer.serialize_seq(Some(values.len()))?;
                for value in list_array.iter() {
                    let value = match value {
                        Some(value) => value.to_data(),
                        None => {
                            return Err(serde::ser::Error::custom(
                                "Value in ListArray is null and not yet supported".to_string(),
                            ))
                        }
                    };
                    // Recurse with the list's item type; defaults are passed
                    // through unchanged.
                    s.serialize_element(&TypedValue {
                        value: &value,
                        type_info: &TypeInfo {
                            data_type: field.data_type().clone(),
                            defaults: self.type_info.defaults.clone(),
                        },
                    })?;
                }
                s.end()
            }
            DataType::Struct(fields) => {
                /// Serde requires that struct and field names are known at
                /// compile time with a `'static` lifetime, which is not
                /// possible in this case. Thus, we need to use dummy names
                /// instead.
                ///
                /// The actual names do not really matter because
                /// the CDR format of ROS2 does not encode struct or field
                /// names.
                const DUMMY_STRUCT_NAME: &str = "struct";
                const DUMMY_FIELD_NAME: &str = "field";

                let struct_array: StructArray = self.value.clone().into();
                let mut s = serializer.serialize_struct(DUMMY_STRUCT_NAME, fields.len())?;
                let defaults: StructArray = self.type_info.defaults.clone().into();
                for field in fields.iter() {
                    // Every expected field must have a default; a missing
                    // default is an error.
                    let default = match defaults.column_by_name(field.name()) {
                        Some(value) => value.to_data(),
                        None => {
                            return Err(serde::ser::Error::custom(format!(
                                "missing field {} for serialization",
                                &field.name()
                            )))
                        }
                    };
                    // Fields absent from the input value fall back to their
                    // default.
                    let field_value = match struct_array.column_by_name(field.name()) {
                        Some(value) => value.to_data(),
                        None => default.clone(),
                    };
                    s.serialize_field(
                        DUMMY_FIELD_NAME,
                        &TypedValue {
                            value: &field_value,
                            type_info: &TypeInfo {
                                data_type: field.data_type().clone(),
                                defaults: default,
                            },
                        },
                    )?;
                }
                s.end()
            }
            // Remaining data types (e.g. binary, large lists) are not
            // supported yet.
            _ => todo!(),
        }
    }
}
| @@ -0,0 +1,259 @@ | |||
| use std::{any::type_name, borrow::Cow, marker::PhantomData, sync::Arc}; | |||
| use arrow::{ | |||
| array::{Array, ArrayRef, AsArray, OffsetSizeTrait, PrimitiveArray}, | |||
| datatypes::{self, ArrowPrimitiveType}, | |||
| }; | |||
| use dora_ros2_bridge_msg_gen::types::{ | |||
| primitives::{BasicType, GenericString, NestableType}, | |||
| sequences, | |||
| }; | |||
| use serde::ser::SerializeTuple; | |||
| use crate::typed::TypeInfo; | |||
| use super::{error, TypedValue}; | |||
/// Serialize an array with known size as tuple.
pub struct ArraySerializeWrapper<'a> {
    /// Fixed-size array definition from the message (element type and size).
    pub array_info: &'a sequences::Array,
    /// Column holding the array data; expected to be a (large) list of length 1.
    pub column: &'a ArrayRef,
    /// Message/package context used to resolve nested message element types.
    pub type_info: &'a TypeInfo<'a>,
}
impl serde::Serialize for ArraySerializeWrapper<'_> {
    /// Serializes the wrapped column as a fixed-length serde tuple.
    ///
    /// The column is expected to be an Arrow list array (`i32` or `i64`
    /// offsets) with exactly one entry, whose inner value is the actual
    /// fixed-size array. Dispatch on the declared element type decides how
    /// the inner values are serialized.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Unwrap the single list entry; try small-offset lists first, then large.
        let entry = if let Some(list) = self.column.as_list_opt::<i32>() {
            // should match the length of the outer struct
            assert_eq!(list.len(), 1);
            list.value(0)
        } else {
            // try as large list
            let list = self
                .column
                .as_list_opt::<i64>()
                .ok_or_else(|| error("value is not compatible with expected array type"))?;
            // should match the length of the outer struct
            assert_eq!(list.len(), 1);
            list.value(0)
        };
        match &self.array_info.value_type {
            // Primitive element types: serialize as a tuple of `size` primitives.
            NestableType::BasicType(t) => match t {
                BasicType::I8 => BasicArrayAsTuple {
                    len: self.array_info.size,
                    value: &entry,
                    ty: PhantomData::<datatypes::Int8Type>,
                }
                .serialize(serializer),
                BasicType::I16 => BasicArrayAsTuple {
                    len: self.array_info.size,
                    value: &entry,
                    ty: PhantomData::<datatypes::Int16Type>,
                }
                .serialize(serializer),
                BasicType::I32 => BasicArrayAsTuple {
                    len: self.array_info.size,
                    value: &entry,
                    ty: PhantomData::<datatypes::Int32Type>,
                }
                .serialize(serializer),
                BasicType::I64 => BasicArrayAsTuple {
                    len: self.array_info.size,
                    value: &entry,
                    ty: PhantomData::<datatypes::Int64Type>,
                }
                .serialize(serializer),
                // char and byte are represented as u8 on the Arrow side
                BasicType::U8 | BasicType::Char | BasicType::Byte => BasicArrayAsTuple {
                    len: self.array_info.size,
                    value: &entry,
                    ty: PhantomData::<datatypes::UInt8Type>,
                }
                .serialize(serializer),
                BasicType::U16 => BasicArrayAsTuple {
                    len: self.array_info.size,
                    value: &entry,
                    ty: PhantomData::<datatypes::UInt16Type>,
                }
                .serialize(serializer),
                BasicType::U32 => BasicArrayAsTuple {
                    len: self.array_info.size,
                    value: &entry,
                    ty: PhantomData::<datatypes::UInt32Type>,
                }
                .serialize(serializer),
                BasicType::U64 => BasicArrayAsTuple {
                    len: self.array_info.size,
                    value: &entry,
                    ty: PhantomData::<datatypes::UInt64Type>,
                }
                .serialize(serializer),
                BasicType::F32 => BasicArrayAsTuple {
                    len: self.array_info.size,
                    value: &entry,
                    ty: PhantomData::<datatypes::Float32Type>,
                }
                .serialize(serializer),
                BasicType::F64 => BasicArrayAsTuple {
                    len: self.array_info.size,
                    value: &entry,
                    ty: PhantomData::<datatypes::Float64Type>,
                }
                .serialize(serializer),
                // booleans use a dedicated wrapper because Arrow stores them bit-packed
                BasicType::Bool => BoolArrayAsTuple {
                    len: self.array_info.size,
                    value: &entry,
                }
                .serialize(serializer),
            },
            // Element type is another message from the same package:
            // serialize each struct row recursively as a TypedValue.
            NestableType::NamedType(name) => {
                let array = entry
                    .as_struct_opt()
                    .ok_or_else(|| error("not a struct array"))?;
                let mut seq = serializer.serialize_tuple(self.array_info.size)?;
                for i in 0..array.len() {
                    let row = array.slice(i, 1);
                    seq.serialize_element(&TypedValue {
                        value: &(Arc::new(row) as ArrayRef),
                        type_info: &crate::typed::TypeInfo {
                            package_name: Cow::Borrowed(&self.type_info.package_name),
                            message_name: Cow::Borrowed(&name.0),
                            messages: self.type_info.messages.clone(),
                        },
                    })?;
                }
                seq.end()
            }
            // Element type is a message from another package (namespaced).
            NestableType::NamespacedType(reference) => {
                // only `msg` namespaces are supported (no services/actions)
                if reference.namespace != "msg" {
                    return Err(error(format!(
                        "sequence references non-message type {reference:?}"
                    )));
                }
                let array = entry
                    .as_struct_opt()
                    .ok_or_else(|| error("not a struct array"))?;
                let mut seq = serializer.serialize_tuple(self.array_info.size)?;
                for i in 0..array.len() {
                    let row = array.slice(i, 1);
                    seq.serialize_element(&TypedValue {
                        value: &(Arc::new(row) as ArrayRef),
                        type_info: &crate::typed::TypeInfo {
                            package_name: Cow::Borrowed(&reference.package),
                            message_name: Cow::Borrowed(&reference.name),
                            messages: self.type_info.messages.clone(),
                        },
                    })?;
                }
                seq.end()
            }
            // String element types: try small-offset string array first, then large.
            NestableType::GenericString(s) => match s {
                GenericString::String | GenericString::BoundedString(_) => {
                    match entry.as_string_opt::<i32>() {
                        Some(array) => {
                            serialize_arrow_string(serializer, array, self.array_info.size)
                        }
                        None => {
                            let array = entry
                                .as_string_opt::<i64>()
                                .ok_or_else(|| error("expected string array"))?;
                            serialize_arrow_string(serializer, array, self.array_info.size)
                        }
                    }
                }
                GenericString::WString => {
                    todo!("serializing WString sequences")
                }
                GenericString::BoundedWString(_) => todo!("serializing BoundedWString sequences"),
            },
        }
    }
}
/// Serializes a primitive array with known size as tuple.
struct BasicArrayAsTuple<'a, T> {
    /// Expected (and serialized) number of elements.
    len: usize,
    /// Arrow array holding the values; must be a `PrimitiveArray<T>`.
    value: &'a ArrayRef,
    /// Marker selecting the Arrow primitive type `T` at compile time.
    ty: PhantomData<T>,
}
impl<T> serde::Serialize for BasicArrayAsTuple<'_, T>
where
    T: ArrowPrimitiveType,
    T::Native: serde::Serialize,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut seq = serializer.serialize_tuple(self.len)?;
        let array: &PrimitiveArray<T> = self
            .value
            .as_primitive_opt()
            .ok_or_else(|| error(format!("not a primitive {} array", type_name::<T>())))?;
        // Fixed-size ROS arrays must match the declared length exactly.
        if array.len() != self.len {
            return Err(error(format!(
                "expected array with length {}, got length {}",
                self.len,
                array.len()
            )));
        }
        for value in array.values() {
            seq.serialize_element(value)?;
        }
        seq.end()
    }
}
/// Serializes a boolean array with known size as tuple.
///
/// Separate from [`BasicArrayAsTuple`] because Arrow stores booleans
/// bit-packed in a `BooleanArray` rather than as a primitive array.
struct BoolArrayAsTuple<'a> {
    /// Expected (and serialized) number of elements.
    len: usize,
    /// Arrow array holding the values; must be a `BooleanArray`.
    value: &'a ArrayRef,
}
impl serde::Serialize for BoolArrayAsTuple<'_> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut seq = serializer.serialize_tuple(self.len)?;
        let array = self
            .value
            .as_boolean_opt()
            .ok_or_else(|| error("not a boolean array"))?;
        // Fixed-size ROS arrays must match the declared length exactly.
        if array.len() != self.len {
            return Err(error(format!(
                "expected array with length {}, got length {}",
                self.len,
                array.len()
            )));
        }
        for value in array.values() {
            seq.serialize_element(&value)?;
        }
        seq.end()
    }
}
| fn serialize_arrow_string<S, O>( | |||
| serializer: S, | |||
| array: &arrow::array::GenericByteArray<datatypes::GenericStringType<O>>, | |||
| array_len: usize, | |||
| ) -> Result<<S as serde::Serializer>::Ok, <S as serde::Serializer>::Error> | |||
| where | |||
| S: serde::Serializer, | |||
| O: OffsetSizeTrait, | |||
| { | |||
| let mut seq = serializer.serialize_tuple(array_len)?; | |||
| for s in array.iter() { | |||
| seq.serialize_element(s.unwrap_or_default())?; | |||
| } | |||
| seq.end() | |||
| } | |||
| @@ -0,0 +1,237 @@ | |||
| use arrow::{ | |||
| array::{ | |||
| make_array, Array, ArrayData, BooleanArray, Float32Array, Float64Array, Int16Array, | |||
| Int32Array, Int64Array, Int8Array, ListArray, StringArray, StructArray, UInt16Array, | |||
| UInt32Array, UInt64Array, UInt8Array, | |||
| }, | |||
| buffer::{OffsetBuffer, ScalarBuffer}, | |||
| compute::concat, | |||
| datatypes::Field, | |||
| }; | |||
| use dora_ros2_bridge_msg_gen::types::{ | |||
| primitives::{BasicType, NestableType}, | |||
| MemberType, Message, | |||
| }; | |||
| use eyre::{Context, ContextCompat, Result}; | |||
| use std::{collections::HashMap, sync::Arc, vec}; | |||
/// Computes the default Arrow value for a single message member.
///
/// Uses the member's declared default from the `.msg` definition when
/// present; otherwise falls back to the type's zero/empty default via
/// `default_for_nestable_type` / `list_default_values`.
pub fn default_for_member(
    m: &dora_ros2_bridge_msg_gen::types::Member,
    package_name: &str,
    messages: &HashMap<String, HashMap<String, Message>>,
) -> eyre::Result<ArrayData> {
    let value = match &m.r#type {
        MemberType::NestableType(t) => match t {
            // Non-sequence scalar/string fields may carry at most one default literal.
            NestableType::BasicType(_) | NestableType::GenericString(_) => match &m
                .default
                .as_deref()
            {
                Some([]) => eyre::bail!("empty default value not supported"),
                Some([default]) => preset_default_for_basic_type(t, default)
                    .with_context(|| format!("failed to parse default value for `{}`", m.name))?,
                Some(_) => eyre::bail!(
                    "there should be only a single default value for non-sequence types"
                ),
                None => default_for_nestable_type(t, package_name, messages, 1)?,
            },
            // Nested message types cannot have declared defaults in ROS 2 IDL.
            NestableType::NamedType(_) => {
                if m.default.is_some() {
                    eyre::bail!("default values for nested types are not supported")
                } else {
                    default_for_nestable_type(t, package_name, messages, 1)?
                }
            }
            NestableType::NamespacedType(_) => {
                default_for_nestable_type(t, package_name, messages, 1)?
            }
        },
        // Fixed-size array: default list sized to the declared length.
        MemberType::Array(array) => list_default_values(
            m,
            &array.value_type,
            package_name,
            messages,
            Some(array.size),
        )?,
        // Unbounded sequence: no fixed size.
        MemberType::Sequence(seq) => {
            list_default_values(m, &seq.value_type, package_name, messages, None)?
        }
        // Bounded sequence: sized to the maximum length.
        MemberType::BoundedSequence(seq) => list_default_values(
            m,
            &seq.value_type,
            package_name,
            messages,
            Some(seq.max_size),
        )?,
    };
    Ok(value)
}
| fn default_for_nestable_type( | |||
| t: &NestableType, | |||
| package_name: &str, | |||
| messages: &HashMap<String, HashMap<String, Message>>, | |||
| size: usize, | |||
| ) -> Result<ArrayData> { | |||
| let empty = HashMap::new(); | |||
| let package_messages = messages.get(package_name).unwrap_or(&empty); | |||
| let array = match t { | |||
| NestableType::BasicType(t) => match t { | |||
| BasicType::I8 => Int8Array::from(vec![0; size]).into(), | |||
| BasicType::I16 => Int16Array::from(vec![0; size]).into(), | |||
| BasicType::I32 => Int32Array::from(vec![0; size]).into(), | |||
| BasicType::I64 => Int64Array::from(vec![0; size]).into(), | |||
| BasicType::U8 => UInt8Array::from(vec![0; size]).into(), | |||
| BasicType::U16 => UInt16Array::from(vec![0; size]).into(), | |||
| BasicType::U32 => UInt32Array::from(vec![0; size]).into(), | |||
| BasicType::U64 => UInt64Array::from(vec![0; size]).into(), | |||
| BasicType::F32 => Float32Array::from(vec![0.; size]).into(), | |||
| BasicType::F64 => Float64Array::from(vec![0.; size]).into(), | |||
| BasicType::Char => StringArray::from(vec![""]).into(), | |||
| BasicType::Byte => UInt8Array::from(vec![0u8; size]).into(), | |||
| BasicType::Bool => BooleanArray::from(vec![false; size]).into(), | |||
| }, | |||
| NestableType::GenericString(_) => StringArray::from(vec![""]).into(), | |||
| NestableType::NamedType(name) => { | |||
| let referenced_message = package_messages | |||
| .get(&name.0) | |||
| .context("unknown referenced message")?; | |||
| default_for_referenced_message(referenced_message, package_name, messages)? | |||
| } | |||
| NestableType::NamespacedType(t) => { | |||
| let referenced_package_messages = messages.get(&t.package).unwrap_or(&empty); | |||
| let referenced_message = referenced_package_messages | |||
| .get(&t.name) | |||
| .context("unknown referenced message")?; | |||
| default_for_referenced_message(referenced_message, &t.package, messages)? | |||
| } | |||
| }; | |||
| Ok(array) | |||
| } | |||
/// Parses a single textual default value (from a `.msg` definition) into a
/// one-element Arrow array of the matching type.
///
/// # Errors
/// Fails when `preset` cannot be parsed as the target numeric/boolean type.
fn preset_default_for_basic_type(t: &NestableType, preset: &str) -> Result<ArrayData> {
    Ok(match t {
        NestableType::BasicType(t) => match t {
            BasicType::I8 => Int8Array::from(vec![preset
                .parse::<i8>()
                .context("Could not parse preset default value")?])
            .into(),
            BasicType::I16 => Int16Array::from(vec![preset
                .parse::<i16>()
                .context("Could not parse preset default value")?])
            .into(),
            BasicType::I32 => Int32Array::from(vec![preset
                .parse::<i32>()
                .context("Could not parse preset default value")?])
            .into(),
            BasicType::I64 => Int64Array::from(vec![preset
                .parse::<i64>()
                .context("Could not parse preset default value")?])
            .into(),
            BasicType::U8 => UInt8Array::from(vec![preset
                .parse::<u8>()
                .context("Could not parse preset default value")?])
            .into(),
            BasicType::U16 => UInt16Array::from(vec![preset
                .parse::<u16>()
                .context("Could not parse preset default value")?])
            .into(),
            BasicType::U32 => UInt32Array::from(vec![preset
                .parse::<u32>()
                .context("Could not parse preset default value")?])
            .into(),
            BasicType::U64 => UInt64Array::from(vec![preset
                .parse::<u64>()
                .context("Could not parse preset default value")?])
            .into(),
            BasicType::F32 => Float32Array::from(vec![preset
                .parse::<f32>()
                .context("Could not parse preset default value")?])
            .into(),
            BasicType::F64 => Float64Array::from(vec![preset
                .parse::<f64>()
                .context("Could not parse preset default value")?])
            .into(),
            // char defaults are kept as strings on the Arrow side
            BasicType::Char => StringArray::from(vec![preset]).into(),
            // byte defaults use the raw UTF-8 bytes of the preset text
            BasicType::Byte => UInt8Array::from(preset.as_bytes().to_owned()).into(),
            BasicType::Bool => BooleanArray::from(vec![preset
                .parse::<bool>()
                .context("could not parse preset default value")?])
            .into(),
        },
        NestableType::GenericString(_) => StringArray::from(vec![preset]).into(),
        // named/namespaced types never carry textual defaults (rejected earlier)
        _ => todo!("preset_default_for_basic_type (other)"),
    })
}
| fn default_for_referenced_message( | |||
| referenced_message: &Message, | |||
| package_name: &str, | |||
| messages: &HashMap<String, HashMap<String, Message>>, | |||
| ) -> eyre::Result<ArrayData> { | |||
| let fields: Vec<(Arc<Field>, Arc<dyn Array>)> = referenced_message | |||
| .members | |||
| .iter() | |||
| .map(|m| { | |||
| let default = default_for_member(m, package_name, messages)?; | |||
| Result::<_, eyre::Report>::Ok(( | |||
| Arc::new(Field::new( | |||
| m.name.clone(), | |||
| default.data_type().clone(), | |||
| true, | |||
| )), | |||
| make_array(default), | |||
| )) | |||
| }) | |||
| .collect::<Result<_, _>>()?; | |||
| let struct_array: StructArray = fields.into(); | |||
| Ok(struct_array.into()) | |||
| } | |||
/// Builds the default Arrow list value for an array/sequence member.
///
/// With declared defaults, each textual default is parsed individually and
/// the results are concatenated. Without defaults, a list of `size`
/// (or 1, for unbounded sequences) type-default elements is produced.
fn list_default_values(
    m: &dora_ros2_bridge_msg_gen::types::Member,
    value_type: &NestableType,
    package_name: &str,
    messages: &HashMap<String, HashMap<String, Message>>,
    size: Option<usize>,
) -> Result<ArrayData> {
    let defaults = match &m.default.as_deref() {
        Some([]) => eyre::bail!("empty default value not supported"),
        Some(defaults) => {
            // Parse each default literal into a one-element array …
            let raw_array: Vec<Arc<dyn Array>> = defaults
                .iter()
                .map(|default| {
                    preset_default_for_basic_type(value_type, default)
                        .with_context(|| format!("failed to parse default value for `{}`", m.name))
                        .map(make_array)
                })
                .collect::<Result<_, _>>()?;
            // … then concatenate them into a single flat array.
            let default_values = concat(
                raw_array
                    .iter()
                    .map(|data| data.as_ref())
                    .collect::<Vec<_>>()
                    .as_slice(),
            )
            .context("Failed to concatenate default list value")?;
            default_values.to_data()
        }
        None => {
            let size = size.unwrap_or(1);
            let default_nested_type =
                default_for_nestable_type(value_type, package_name, messages, size)?;
            // Single list entry covering all `size` child values.
            // NOTE(review): `size as i32` silently truncates for sizes above
            // i32::MAX — unrealistic for message definitions, but worth a
            // `try_from` if this ever becomes reachable with untrusted input.
            let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, size as i32]));
            let field = Arc::new(Field::new(
                "item",
                default_nested_type.data_type().clone(),
                true,
            ));
            let list = ListArray::new(field, offsets, make_array(default_nested_type), None);
            list.to_data()
        }
    };
    Ok(defaults)
}
| @@ -0,0 +1,205 @@ | |||
| use std::{borrow::Cow, collections::HashMap, fmt::Display}; | |||
| use arrow::{ | |||
| array::{Array, ArrayRef, AsArray}, | |||
| error, | |||
| }; | |||
| use dora_ros2_bridge_msg_gen::types::{ | |||
| primitives::{GenericString, NestableType}, | |||
| MemberType, | |||
| }; | |||
| use eyre::Context; | |||
| use serde::ser::SerializeTupleStruct; | |||
| use super::{TypeInfo, DUMMY_STRUCT_NAME}; | |||
| mod array; | |||
| mod defaults; | |||
| mod primitive; | |||
| mod sequence; | |||
/// An Arrow value paired with the ROS 2 message type that describes its layout.
///
/// Serializing a `TypedValue` walks the message definition and emits the
/// fields in declaration order, so the serde output matches the wire format.
#[derive(Debug, Clone)]
pub struct TypedValue<'a> {
    /// Arrow value to serialize; expected to be a single-row struct array.
    pub value: &'a ArrayRef,
    /// Message type description driving the serialization.
    pub type_info: &'a TypeInfo<'a>,
}
impl serde::Serialize for TypedValue<'_> {
    /// Serializes the struct value as a serde tuple struct whose fields follow
    /// the message definition's member order. Fields missing from the input
    /// struct are filled with computed default values.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Look up the message definition for this package/message pair.
        let empty = HashMap::new();
        let package_messages = self
            .type_info
            .messages
            .get(self.type_info.package_name.as_ref())
            .unwrap_or(&empty);
        let message = package_messages
            .get(self.type_info.message_name.as_ref())
            .ok_or_else(|| {
                error(format!(
                    "could not find message type {}::{}",
                    self.type_info.package_name, self.type_info.message_name
                ))
            })?;
        let input = self.value.as_struct_opt().ok_or_else(|| {
            error(format!(
                "expected struct array for message: {}, with following format: {:#?} \n But, got type: {:#?}",
                self.type_info.message_name, message, self.value.data_type()
            ))
        })?;
        // Reject columns that do not correspond to any message member.
        for column_name in input.column_names() {
            if !message.members.iter().any(|m| m.name == column_name) {
                return Err(error(format!(
                    "given struct has unknown field {column_name}"
                )))?;
            }
        }
        if input.is_empty() {
            // TODO: publish default value
            return Err(error("given struct is empty"))?;
        }
        // Exactly one struct instance is serialized per message.
        if input.len() > 1 {
            return Err(error(format!(
                "expected single struct instance, got struct array with {} entries",
                input.len()
            )))?;
        }
        let mut s = serializer.serialize_tuple_struct(DUMMY_STRUCT_NAME, message.members.len())?;
        for field in message.members.iter() {
            // Use the provided column, or compute the member's default when absent.
            let column: Cow<_> = match input.column_by_name(&field.name) {
                Some(input) => Cow::Borrowed(input),
                None => {
                    let default = defaults::default_for_member(
                        field,
                        &self.type_info.package_name,
                        &self.type_info.messages,
                    )
                    .with_context(|| {
                        format!(
                            "failed to calculate default value for field {}.{}",
                            message.name, field.name
                        )
                    })
                    .map_err(|e| error(format!("{e:?}")))?;
                    Cow::Owned(arrow::array::make_array(default))
                }
            };
            self.serialize_field::<S>(field, column, &mut s)
                .map_err(|e| {
                    error(format!(
                        "failed to serialize field {}.{}: {e}",
                        message.name, field.name
                    ))
                })?;
        }
        s.end()
    }
}
impl<'a> TypedValue<'a> {
    /// Serializes a single message member into the in-progress tuple struct,
    /// dispatching on the member's declared type.
    fn serialize_field<S>(
        &self,
        field: &dora_ros2_bridge_msg_gen::types::Member,
        column: Cow<'_, std::sync::Arc<dyn Array>>,
        s: &mut S::SerializeTupleStruct,
    ) -> Result<(), S::Error>
    where
        S: serde::Serializer,
    {
        match &field.r#type {
            MemberType::NestableType(t) => match t {
                // Primitive scalar: delegate to the primitive wrapper.
                NestableType::BasicType(t) => {
                    s.serialize_field(&primitive::SerializeWrapper {
                        t,
                        column: column.as_ref(),
                    })?;
                }
                // Message type from the same package: recurse with a new TypeInfo.
                NestableType::NamedType(name) => {
                    let referenced_value = &TypedValue {
                        value: column.as_ref(),
                        type_info: &TypeInfo {
                            package_name: Cow::Borrowed(&self.type_info.package_name),
                            message_name: Cow::Borrowed(&name.0),
                            messages: self.type_info.messages.clone(),
                        },
                    };
                    s.serialize_field(&referenced_value)?;
                }
                // Message type from another package (namespaced).
                NestableType::NamespacedType(reference) => {
                    // only `msg` namespaces are supported (no services/actions)
                    if reference.namespace != "msg" {
                        return Err(error(format!(
                            "struct field {} references non-message type {reference:?}",
                            field.name
                        )));
                    }
                    let referenced_value: &TypedValue<'_> = &TypedValue {
                        value: column.as_ref(),
                        type_info: &TypeInfo {
                            package_name: Cow::Borrowed(&reference.package),
                            message_name: Cow::Borrowed(&reference.name),
                            messages: self.type_info.messages.clone(),
                        },
                    };
                    s.serialize_field(&referenced_value)?;
                }
                // String scalar: extract the single value, trying both offset widths.
                NestableType::GenericString(t) => match t {
                    GenericString::String | GenericString::BoundedString(_) => {
                        let string = if let Some(string_array) = column.as_string_opt::<i32>() {
                            // should match the length of the outer struct array
                            assert_eq!(string_array.len(), 1);
                            string_array.value(0)
                        } else {
                            // try again with large offset type
                            let string_array = column
                                .as_string_opt::<i64>()
                                .ok_or_else(|| error("expected string array"))?;
                            // should match the length of the outer struct array
                            assert_eq!(string_array.len(), 1);
                            string_array.value(0)
                        };
                        s.serialize_field(string)?;
                    }
                    GenericString::WString => todo!("serializing WString types"),
                    GenericString::BoundedWString(_) => {
                        todo!("serializing BoundedWString types")
                    }
                },
            },
            // Fixed-size array member.
            dora_ros2_bridge_msg_gen::types::MemberType::Array(a) => {
                s.serialize_field(&array::ArraySerializeWrapper {
                    array_info: a,
                    column: column.as_ref(),
                    type_info: self.type_info,
                })?;
            }
            // Unbounded sequence member.
            dora_ros2_bridge_msg_gen::types::MemberType::Sequence(v) => {
                s.serialize_field(&sequence::SequenceSerializeWrapper {
                    item_type: &v.value_type,
                    column: column.as_ref(),
                    type_info: self.type_info,
                })?;
            }
            // Bounded sequence member (serialized the same way as unbounded).
            dora_ros2_bridge_msg_gen::types::MemberType::BoundedSequence(v) => {
                s.serialize_field(&sequence::SequenceSerializeWrapper {
                    item_type: &v.value_type,
                    column: column.as_ref(),
                    type_info: self.type_info,
                })?;
            }
        }
        Ok(())
    }
}
| fn error<E, T>(e: T) -> E | |||
| where | |||
| T: Display, | |||
| E: serde::ser::Error, | |||
| { | |||
| serde::ser::Error::custom(e) | |||
| } | |||
| @@ -0,0 +1,79 @@ | |||
| use arrow::{ | |||
| array::{ArrayRef, AsArray}, | |||
| datatypes::{self, ArrowPrimitiveType}, | |||
| }; | |||
| use dora_ros2_bridge_msg_gen::types::primitives::BasicType; | |||
/// Serializes a single primitive (basic-type) field value.
pub struct SerializeWrapper<'a> {
    /// Basic ROS 2 type of the field being serialized.
    pub t: &'a BasicType,
    /// Arrow column expected to contain exactly one value of that type.
    pub column: &'a ArrayRef,
}
impl serde::Serialize for SerializeWrapper<'_> {
    /// Extracts the single value from the column and serializes it with the
    /// serde method matching the declared basic type.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self.t {
            BasicType::I8 => {
                serializer.serialize_i8(as_single_primitive::<datatypes::Int8Type, _>(self.column)?)
            }
            BasicType::I16 => serializer
                .serialize_i16(as_single_primitive::<datatypes::Int16Type, _>(self.column)?),
            BasicType::I32 => serializer
                .serialize_i32(as_single_primitive::<datatypes::Int32Type, _>(self.column)?),
            BasicType::I64 => serializer
                .serialize_i64(as_single_primitive::<datatypes::Int64Type, _>(self.column)?),
            // char and byte are represented as u8 on the Arrow side
            BasicType::U8 | BasicType::Char | BasicType::Byte => serializer
                .serialize_u8(as_single_primitive::<datatypes::UInt8Type, _>(self.column)?),
            BasicType::U16 => serializer
                .serialize_u16(as_single_primitive::<datatypes::UInt16Type, _>(
                    self.column,
                )?),
            BasicType::U32 => serializer
                .serialize_u32(as_single_primitive::<datatypes::UInt32Type, _>(
                    self.column,
                )?),
            BasicType::U64 => serializer
                .serialize_u64(as_single_primitive::<datatypes::UInt64Type, _>(
                    self.column,
                )?),
            BasicType::F32 => serializer
                .serialize_f32(as_single_primitive::<datatypes::Float32Type, _>(
                    self.column,
                )?),
            BasicType::F64 => serializer
                .serialize_f64(as_single_primitive::<datatypes::Float64Type, _>(
                    self.column,
                )?),
            // booleans need a separate path because Arrow stores them bit-packed
            BasicType::Bool => {
                let array = self.column.as_boolean_opt().ok_or_else(|| {
                    serde::ser::Error::custom(
                        "value is not compatible with expected `BooleanArray` type",
                    )
                })?;
                // should match the length of the outer struct
                assert_eq!(array.len(), 1);
                let field_value = array.value(0);
                serializer.serialize_bool(field_value)
            }
        }
    }
}
| fn as_single_primitive<T, E>(column: &ArrayRef) -> Result<T::Native, E> | |||
| where | |||
| T: ArrowPrimitiveType, | |||
| E: serde::ser::Error, | |||
| { | |||
| let array: &arrow::array::PrimitiveArray<T> = column.as_primitive_opt().ok_or_else(|| { | |||
| serde::ser::Error::custom(format!( | |||
| "value is not compatible with expected `{}` type", | |||
| std::any::type_name::<T::Native>() | |||
| )) | |||
| })?; | |||
| // should match the length of the outer struct | |||
| assert_eq!(array.len(), 1); | |||
| let number = array.value(0); | |||
| Ok(number) | |||
| } | |||
| @@ -0,0 +1,268 @@ | |||
| use std::{any::type_name, borrow::Cow, marker::PhantomData, sync::Arc}; | |||
| use arrow::{ | |||
| array::{Array, ArrayRef, AsArray, OffsetSizeTrait, PrimitiveArray}, | |||
| datatypes::{self, ArrowPrimitiveType, UInt8Type}, | |||
| }; | |||
| use dora_ros2_bridge_msg_gen::types::primitives::{BasicType, GenericString, NestableType}; | |||
| use serde::ser::{SerializeSeq, SerializeTuple}; | |||
| use crate::typed::TypeInfo; | |||
| use super::{error, TypedValue}; | |||
/// Serialize a variable-sized sequence.
pub struct SequenceSerializeWrapper<'a> {
    /// Declared element type of the ROS 2 sequence.
    pub item_type: &'a NestableType,
    /// Arrow column (list or binary array) holding the sequence value.
    pub column: &'a ArrayRef,
    /// Message-type context used to resolve referenced message types.
    pub type_info: &'a TypeInfo<'a>,
}
impl serde::Serialize for SequenceSerializeWrapper<'_> {
    /// Serializes the wrapped column as a variable-length serde sequence.
    ///
    /// Accepts list arrays (`i32`/`i64` offsets) as well as binary arrays
    /// (for byte sequences), always containing exactly one entry that holds
    /// the actual sequence values.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let entry = if let Some(list) = self.column.as_list_opt::<i32>() {
            // should match the length of the outer struct
            assert_eq!(list.len(), 1);
            list.value(0)
        } else if let Some(list) = self.column.as_list_opt::<i64>() {
            // should match the length of the outer struct
            assert_eq!(list.len(), 1);
            list.value(0)
        } else if let Some(list) = self.column.as_binary_opt::<i32>() {
            // should match the length of the outer struct
            assert_eq!(list.len(), 1);
            Arc::new(list.slice(0, 1)) as ArrayRef
        } else if let Some(list) = self.column.as_binary_opt::<i64>() {
            // should match the length of the outer struct
            assert_eq!(list.len(), 1);
            Arc::new(list.slice(0, 1)) as ArrayRef
        } else {
            return Err(error(format!(
                "value is not compatible with expected sequence type: {:?}",
                self.column
            )));
        };
        match &self.item_type {
            // Primitive element types: serialize as a length-prefixed sequence.
            NestableType::BasicType(t) => match t {
                BasicType::I8 => BasicSequence {
                    value: &entry,
                    ty: PhantomData::<datatypes::Int8Type>,
                }
                .serialize(serializer),
                BasicType::I16 => BasicSequence {
                    value: &entry,
                    ty: PhantomData::<datatypes::Int16Type>,
                }
                .serialize(serializer),
                BasicType::I32 => BasicSequence {
                    value: &entry,
                    ty: PhantomData::<datatypes::Int32Type>,
                }
                .serialize(serializer),
                BasicType::I64 => BasicSequence {
                    value: &entry,
                    ty: PhantomData::<datatypes::Int64Type>,
                }
                .serialize(serializer),
                // byte-like sequences also accept Arrow binary arrays
                BasicType::U8 | BasicType::Char | BasicType::Byte => {
                    ByteSequence { value: &entry }.serialize(serializer)
                }
                BasicType::U16 => BasicSequence {
                    value: &entry,
                    ty: PhantomData::<datatypes::UInt16Type>,
                }
                .serialize(serializer),
                BasicType::U32 => BasicSequence {
                    value: &entry,
                    ty: PhantomData::<datatypes::UInt32Type>,
                }
                .serialize(serializer),
                BasicType::U64 => BasicSequence {
                    value: &entry,
                    ty: PhantomData::<datatypes::UInt64Type>,
                }
                .serialize(serializer),
                BasicType::F32 => BasicSequence {
                    value: &entry,
                    ty: PhantomData::<datatypes::Float32Type>,
                }
                .serialize(serializer),
                BasicType::F64 => BasicSequence {
                    value: &entry,
                    ty: PhantomData::<datatypes::Float64Type>,
                }
                .serialize(serializer),
                BasicType::Bool => BoolArray { value: &entry }.serialize(serializer),
            },
            // Element type is another message from the same package:
            // serialize each struct row recursively as a TypedValue.
            NestableType::NamedType(name) => {
                let array = entry
                    .as_struct_opt()
                    .ok_or_else(|| error("not a struct array"))?;
                let mut seq = serializer.serialize_seq(Some(array.len()))?;
                for i in 0..array.len() {
                    let row = array.slice(i, 1);
                    seq.serialize_element(&TypedValue {
                        value: &(Arc::new(row) as ArrayRef),
                        type_info: &crate::typed::TypeInfo {
                            package_name: Cow::Borrowed(&self.type_info.package_name),
                            message_name: Cow::Borrowed(&name.0),
                            messages: self.type_info.messages.clone(),
                        },
                    })?;
                }
                seq.end()
            }
            // Element type is a message from another package (namespaced).
            NestableType::NamespacedType(reference) => {
                // only `msg` namespaces are supported (no services/actions)
                if reference.namespace != "msg" {
                    return Err(error(format!(
                        "sequence references non-message type {reference:?}"
                    )));
                }
                let array = entry
                    .as_struct_opt()
                    .ok_or_else(|| error("not a struct array"))?;
                let mut seq = serializer.serialize_seq(Some(array.len()))?;
                for i in 0..array.len() {
                    let row = array.slice(i, 1);
                    seq.serialize_element(&TypedValue {
                        value: &(Arc::new(row) as ArrayRef),
                        type_info: &crate::typed::TypeInfo {
                            package_name: Cow::Borrowed(&reference.package),
                            message_name: Cow::Borrowed(&reference.name),
                            messages: self.type_info.messages.clone(),
                        },
                    })?;
                }
                seq.end()
            }
            // String element types: try small-offset string array first, then large.
            NestableType::GenericString(s) => match s {
                GenericString::String | GenericString::BoundedString(_) => {
                    match entry.as_string_opt::<i32>() {
                        Some(array) => serialize_arrow_string(serializer, array),
                        None => {
                            let array = entry
                                .as_string_opt::<i64>()
                                .ok_or_else(|| error("expected string array"))?;
                            serialize_arrow_string(serializer, array)
                        }
                    }
                }
                GenericString::WString => {
                    todo!("serializing WString sequences")
                }
                GenericString::BoundedWString(_) => todo!("serializing BoundedWString sequences"),
            },
        }
    }
}
| fn serialize_arrow_string<S, O>( | |||
| serializer: S, | |||
| array: &arrow::array::GenericByteArray<datatypes::GenericStringType<O>>, | |||
| ) -> Result<<S as serde::Serializer>::Ok, <S as serde::Serializer>::Error> | |||
| where | |||
| S: serde::Serializer, | |||
| O: OffsetSizeTrait, | |||
| { | |||
| let mut seq = serializer.serialize_seq(Some(array.len()))?; | |||
| for s in array.iter() { | |||
| seq.serialize_element(s.unwrap_or_default())?; | |||
| } | |||
| seq.end() | |||
| } | |||
/// Serializes a primitive Arrow array as a length-prefixed serde sequence.
struct BasicSequence<'a, T> {
    /// Arrow array holding the values; must be a `PrimitiveArray<T>`.
    value: &'a ArrayRef,
    /// Marker selecting the Arrow primitive type `T` at compile time.
    ty: PhantomData<T>,
}
impl<T> serde::Serialize for BasicSequence<'_, T>
where
    T: ArrowPrimitiveType,
    T::Native: serde::Serialize,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let array: &PrimitiveArray<T> = self
            .value
            .as_primitive_opt()
            .ok_or_else(|| error(format!("not a primitive {} array", type_name::<T>())))?;
        // Variable-length sequence: the array's own length is the element count.
        let mut seq = serializer.serialize_seq(Some(array.len()))?;
        for value in array.values() {
            seq.serialize_element(value)?;
        }
        seq.end()
    }
}
/// Serializes a byte sequence (u8/char/byte element types).
///
/// Accepts either an Arrow binary array (serialized as raw byte slices) or,
/// as a fallback, a `UInt8` primitive array.
struct ByteSequence<'a> {
    /// Arrow array holding the bytes.
    value: &'a ArrayRef,
}
impl serde::Serialize for ByteSequence<'_> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        if let Some(binary) = self.value.as_binary_opt::<i32>() {
            serialize_binary(serializer, binary)
        } else if let Some(binary) = self.value.as_binary_opt::<i64>() {
            serialize_binary(serializer, binary)
        } else {
            // fall back to element-wise serialization of a u8 primitive array
            BasicSequence {
                value: self.value,
                ty: PhantomData::<UInt8Type>,
            }
            .serialize(serializer)
        }
    }
}
| fn serialize_binary<S, O>( | |||
| serializer: S, | |||
| binary: &arrow::array::GenericByteArray<datatypes::GenericBinaryType<O>>, | |||
| ) -> Result<<S as serde::Serializer>::Ok, <S as serde::Serializer>::Error> | |||
| where | |||
| S: serde::Serializer, | |||
| O: OffsetSizeTrait, | |||
| { | |||
| let mut seq = serializer.serialize_seq(Some(binary.len()))?; | |||
| for value in binary.iter() { | |||
| seq.serialize_element(value.unwrap_or_default())?; | |||
| } | |||
| seq.end() | |||
| } | |||
| struct BoolArray<'a> { | |||
| value: &'a ArrayRef, | |||
| } | |||
| impl serde::Serialize for BoolArray<'_> { | |||
| fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> | |||
| where | |||
| S: serde::Serializer, | |||
| { | |||
| let array = self | |||
| .value | |||
| .as_boolean_opt() | |||
| .ok_or_else(|| error("not a boolean array"))?; | |||
| let mut seq = serializer.serialize_tuple(array.len())?; | |||
| for value in array.values() { | |||
| seq.serialize_element(&value)?; | |||
| } | |||
| seq.end() | |||
| } | |||
| } | |||
| @@ -0,0 +1,284 @@ | |||
| import numpy as np | |||
| import pyarrow as pa | |||
def _example_marker():
    """Return one example visualization_msgs/Marker payload as a plain dict.

    The identical ~60-line marker literal was previously duplicated between
    the "Marker" and "MarkerArray" entries of TEST_ARRAYS; it is now written
    once here. A fresh dict is returned on every call so the two test cases
    cannot accidentally share mutable state.
    """
    return {
        "header": {
            # Placeholder value (string field, no numpy equivalent).
            "frame_id": "world",
        },
        "ns": "my_namespace",  # Placeholder value (string field).
        "id": np.int32(1),
        "type": np.int32(0),  # ARROW
        "action": np.int32(0),  # ADD
        "lifetime": {
            "sec": np.int32(1),
            "nanosec": np.uint32(2),
        },
        "pose": {
            "position": {
                "x": np.float64(1.0),
                "y": np.float64(2.0),
                "z": np.float64(3.0),
            },
            "orientation": {
                "x": np.float64(0.0),
                "y": np.float64(0.0),
                "z": np.float64(0.0),
                "w": np.float64(1.0),
            },
        },
        "scale": {
            "x": np.float64(1.0),
            "y": np.float64(1.0),
            "z": np.float64(1.0),
        },
        "color": {
            "r": np.float32(1.0),
            "g": np.float32(0.0),
            "b": np.float32(0.0),
            "a": np.float32(1.0),  # alpha
        },
        "frame_locked": False,
        "points": [
            {
                "x": np.float64(1.0),
                "y": np.float64(1.0),
                "z": np.float64(1.0),
            }
        ],
        "colors": [
            {
                "r": np.float32(1.0),
                "g": np.float32(1.0),
                "b": np.float32(1.0),
                "a": np.float32(1.0),  # alpha
            }
        ],
        "texture_resource": "",
        "uv_coordinates": [{}],
        "text": "",
        "mesh_resource": "",
        "mesh_use_embedded_materials": False,
    }


# (package, message type, example payload) triples used as test fixtures.
TEST_ARRAYS = [
    ("std_msgs", "UInt8", pa.array([{"data": np.uint8(2)}])),
    (
        "std_msgs",
        "String",
        pa.array([{"data": "hello"}]),
    ),
    (
        "std_msgs",
        "UInt8MultiArray",
        pa.array(
            [
                {
                    "data": np.array([1, 2, 3, 4], np.uint8),
                    "layout": {
                        "dim": [
                            {
                                "label": "a",
                                "size": np.uint32(10),
                                "stride": np.uint32(20),
                            }
                        ],
                        "data_offset": np.uint32(30),
                    },
                }
            ]
        ),
    ),
    (
        "std_msgs",
        "Float32MultiArray",
        pa.array(
            [
                {
                    "data": np.array([1, 2, 3, 4], np.float32),
                    "layout": {
                        "dim": [
                            {
                                "label": "a",
                                "size": np.uint32(10),
                                "stride": np.uint32(20),
                            }
                        ],
                        "data_offset": np.uint32(30),
                    },
                }
            ]
        ),
    ),
    (
        "visualization_msgs",
        "Marker",
        pa.array([_example_marker()]),
    ),
    (
        "visualization_msgs",
        "MarkerArray",
        # Same marker payload, wrapped in a MarkerArray's `markers` list.
        pa.array([{"markers": [_example_marker()]}]),
    ),
    (
        "visualization_msgs",
        "ImageMarker",
        pa.array(
            [
                {
                    "header": {
                        "stamp": {
                            "sec": np.int32(123456),  # 32-bit integer
                            "nanosec": np.uint32(789),  # 32-bit unsigned integer
                        },
                        "frame_id": "frame_example",
                    },
                    "ns": "namespace",
                    "id": np.int32(1),  # 32-bit integer
                    "type": np.int32(0),  # 32-bit integer, e.g., CIRCLE = 0
                    "action": np.int32(0),  # 32-bit integer, e.g., ADD = 0
                    # Fix: these were previously commented "32-bit float",
                    # but np.float64 is a 64-bit float.
                    "position": {
                        "x": np.float64(1.0),  # 64-bit float
                        "y": np.float64(2.0),  # 64-bit float
                        "z": np.float64(3.0),  # 64-bit float
                    },
                    "scale": np.float32(1.0),  # 32-bit float
                    "outline_color": {
                        "r": np.float32(255.0),  # 32-bit float
                        "g": np.float32(0.0),  # 32-bit float
                        "b": np.float32(0.0),  # 32-bit float
                        "a": np.float32(1.0),  # 32-bit float
                    },
                    "filled": np.uint8(1),  # 8-bit unsigned integer
                    "fill_color": {
                        "r": np.float32(0.0),  # 32-bit float
                        "g": np.float32(255.0),  # 32-bit float
                        "b": np.float32(0.0),  # 32-bit float
                        "a": np.float32(1.0),  # 32-bit float
                    },
                    "lifetime": {
                        "sec": np.int32(300),  # 32-bit integer
                        "nanosec": np.uint32(0),  # 32-bit unsigned integer
                    },
                    "points": [
                        {
                            "x": np.float64(1.0),  # 64-bit float
                            "y": np.float64(2.0),  # 64-bit float
                            "z": np.float64(3.0),  # 64-bit float
                        },
                        {
                            "x": np.float64(4.0),  # 64-bit float
                            "y": np.float64(5.0),  # 64-bit float
                            "z": np.float64(6.0),  # 64-bit float
                        },
                    ],
                    "outline_colors": [
                        {
                            "r": np.float32(255.0),  # 32-bit float
                            "g": np.float32(0.0),  # 32-bit float
                            "b": np.float32(0.0),  # 32-bit float
                            "a": np.float32(1.0),  # 32-bit float
                        },
                        {
                            "r": np.float32(0.0),  # 32-bit float
                            "g": np.float32(255.0),  # 32-bit float
                            "b": np.float32(0.0),  # 32-bit float
                            "a": np.float32(1.0),  # 32-bit float
                        },
                    ],
                }
            ]
        ),
    ),
]
def is_subset(subset, superset):
    """Recursively check that `subset` is contained in `superset`.

    Containment instead of strict equality avoids false negatives caused by
    default values present only on the superset side. pyarrow arrays are
    compared via their Python-list representation.
    """
    if isinstance(subset, pa.Array):
        return is_subset(subset.to_pylist(), superset.to_pylist())
    if isinstance(subset, dict):
        # Every key in the subset must exist in the superset and match.
        return all(
            key in superset and is_subset(value, superset[key])
            for key, value in subset.items()
        )
    if isinstance(subset, (list, set)):
        # Each subset item must match at least one superset item.
        return all(
            any(is_subset(item, candidate) for candidate in superset)
            for item in subset
        )
    # Anything else is treated as a plain scalar value.
    return subset == superset
| @@ -9,8 +9,8 @@ license.workspace = true | |||
| # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | |||
| [dependencies] | |||
| futures = "0.3.21" | |||
| opentelemetry = { version = "0.21", features = ["metrics"] } | |||
| opentelemetry-otlp = { version = "0.14.0", features = ["tonic", "metrics"] } | |||
| tokio = { version = "1.24.2", features = ["full"] } | |||
opentelemetry_sdk = { version = "0.21", features = ["rt-tokio", "metrics"] }
| eyre = "0.6.12" | |||
| opentelemetry-system-metrics = { version = "0.1.6" } | |||
| @@ -10,10 +10,13 @@ | |||
| //! [`sysinfo`]: https://github.com/GuillaumeGomez/sysinfo | |||
| //! [`opentelemetry-rust`]: https://github.com/open-telemetry/opentelemetry-rust | |||
use std::time::Duration;

use eyre::{Context, Result};
use opentelemetry::metrics::{self, MeterProvider as _};
use opentelemetry_otlp::{ExportConfig, WithExportConfig};
use opentelemetry_sdk::{metrics::MeterProvider, runtime};
| use opentelemetry_system_metrics::init_process_observer; | |||
| /// Init opentelemetry meter | |||
| /// | |||
| /// Use the default Opentelemetry exporter with default config | |||
| @@ -34,5 +37,13 @@ pub fn init_metrics() -> metrics::Result<MeterProvider> { | |||
| .tonic() | |||
| .with_export_config(export_config), | |||
| ) | |||
| .with_period(Duration::from_secs(10)) | |||
| .build() | |||
| } | |||
| pub fn init_meter_provider(meter_id: String) -> Result<MeterProvider> { | |||
| let meter_provider = init_metrics().context("Could not create opentelemetry meter")?; | |||
| let meter = meter_provider.meter(meter_id); | |||
| let _ = init_process_observer(meter).context("could not initiale system metrics observer")?; | |||
| Ok(meter_provider) | |||
| } | |||
| @@ -11,7 +11,6 @@ license.workspace = true | |||
| [features] | |||
| [dependencies] | |||
| tokio = { version = "1.24.2", features = ["full"] } | |||
| tracing-subscriber = { version = "0.3.15", features = ["env-filter"] } | |||
| tracing-opentelemetry = { version = "0.18.0" } | |||
| eyre = "0.6.8" | |||