Browse Source

added inference node, try to fix CI issue

pull/1041/head
ShashwatPatil 6 months ago
parent
commit
8bff965b81
10 changed files with 398 additions and 1 deletions
  1. +0
    -0
      examples/lerobot-dataset/Readme.md
  2. +0
    -0
      examples/lerobot-dataset/dataset_record.yml
  3. +67
    -0
      examples/lerobot-dataset/policy_inference.yml
  4. +7
    -1
      node-hub/dora-dataset-record/pyproject.toml
  5. +40
    -0
      node-hub/dora-policy-inference/README.md
  6. +13
    -0
      node-hub/dora-policy-inference/dora_policy_inference/__init__.py
  7. +6
    -0
      node-hub/dora-policy-inference/dora_policy_inference/__main__.py
  8. +216
    -0
      node-hub/dora-policy-inference/dora_policy_inference/main.py
  9. +36
    -0
      node-hub/dora-policy-inference/pyproject.toml
  10. +13
    -0
      node-hub/dora-policy-inference/tests/test_dora_policy_inference.py

examples/lerobot-dataset-record/Readme.md → examples/lerobot-dataset/Readme.md View File


examples/lerobot-dataset-record/dataflow.yml → examples/lerobot-dataset/dataset_record.yml View File


+ 67
- 0
examples/lerobot-dataset/policy_inference.yml View File

@@ -0,0 +1,67 @@
nodes:
# NOTE: The camera inputs should match the training setup
- id: laptop_cam
build: pip install opencv-video-capture
path: opencv-video-capture
inputs:
tick: dora/timer/millis/33
outputs:
- image
env:
CAPTURE_PATH: "0"
ENCODING: "rgb8"
IMAGE_WIDTH: "640"
IMAGE_HEIGHT: "480"

- id: front_cam
build: pip install opencv-video-capture
path: opencv-video-capture
inputs:
tick: dora/timer/millis/33
outputs:
- image
env:
CAPTURE_PATH: "1" # ← UPDATE: Your external camera device
ENCODING: "rgb8"
IMAGE_WIDTH: "640"
IMAGE_HEIGHT: "480"

- id: so101_follower
build: pip install -e ../../node-hub/dora-rustypot
path: dora-rustypot
inputs:
tick: dora/timer/millis/10
pose: policy_inference/robot_action
outputs:
- pose
env:
PORT: "/dev/ttyACM1" # UPDATE your robot port
IDS: "1,2,3,4,5,6"

- id: policy_inference
build: pip install -e ../../node-hub/dora-policy-inference
path: dora-policy-inference
inputs:
laptop: laptop_cam/image # the camera inputs should match the training setup
front: front_cam/image
robot_state: so101_follower/pose
outputs:
- robot_action
- status
env:
MODEL_PATH: "./outputs/train/your_policy/checkpoints/last/pretrained_model" # Path to your trained model
TASK_DESCRIPTION: "Your task"
ROBOT_TYPE: "so101_follower"
CAMERA_NAMES: "laptop, front"
CAMERA_FRONT_RESOLUTION: "480,640,3"
CAMERA_LAPTOP_RESOLUTION: "480,640,3"
INFERENCE_FPS: "30"

# Visualization in rerun (optional)
- id: plot
build: pip install dora-rerun
path: dora-rerun
inputs:
image_laptop: laptop_cam/image
image_front: front_cam/image
status: policy_inference/status

+ 7
- 1
node-hub/dora-dataset-record/pyproject.toml View File

@@ -27,4 +27,10 @@ ignore = [
]

[tool.uv.sources]
lerobot = {git = "https://github.com/huggingface/lerobot.git", rev = "main"}
lerobot = {git = "https://github.com/huggingface/lerobot.git", rev = "main"}

[tool.uv.pip]
git-config = [
"filter.lfs.smudge=git-lfs smudge --skip -- %f",
"filter.lfs.process=git-lfs filter-process --skip"
]

+ 40
- 0
node-hub/dora-policy-inference/README.md View File

@@ -0,0 +1,40 @@
# dora-policy-inference

## Getting started

- Install it with uv:

```bash
uv venv -p 3.11 --seed
uv pip install -e .
```

## Contribution Guide

- Format with [ruff](https://docs.astral.sh/ruff/):

```bash
uv pip install ruff
uv run ruff check . --fix
```

- Lint with ruff:

```bash
uv run ruff check .
```

- Test with [pytest](https://github.com/pytest-dev/pytest)

```bash
uv pip install pytest
uv run pytest . # Test
```

## YAML Specification

## Examples

## License

dora-policy-inference's code are released under the MIT License

+ 13
- 0
node-hub/dora-policy-inference/dora_policy_inference/__init__.py View File

@@ -0,0 +1,13 @@
"""TODO: Add docstring."""

import os

# Define the path to the README file relative to the package directory
readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.md")

# Read the content of the README file
try:
with open(readme_path, encoding="utf-8") as f:
__doc__ = f.read()
except FileNotFoundError:
__doc__ = "README file not found."

+ 6
- 0
node-hub/dora-policy-inference/dora_policy_inference/__main__.py View File

@@ -0,0 +1,6 @@
"""TODO: Add docstring."""

from .main import main

if __name__ == "__main__":
main()

+ 216
- 0
node-hub/dora-policy-inference/dora_policy_inference/main.py View File

@@ -0,0 +1,216 @@
"""TODO."""

from dora import Node
import pyarrow as pa
import os
import time
import numpy as np
import threading
import queue
import cv2
from typing import Any
from pathlib import Path

from lerobot.configs.policies import PreTrainedConfig
from lerobot.common.policies.factory import get_policy_class
from lerobot.common.utils.control_utils import predict_action
from lerobot.common.utils.utils import get_safe_torch_device

class DoraPolicyInference:
"""Policy inference node for LeRobot policies."""

def __init__(self):
"""Initialize the policy inference node."""
self.message_queue = queue.Queue()

self.data_buffer = {}
self.buffer_lock = threading.Lock()
self.cameras = self._parse_cameras()
self.task_description = os.getenv("TASK_DESCRIPTION", "")

self.policy = None
self.inference_fps = int(os.getenv("INFERENCE_FPS", "30"))
self.last_inference_time = None
self.inference_interval = 1.0 / self.inference_fps

self.shutdown = False
self._load_policy()
self._start_timer()

def _parse_cameras(self) -> dict:
"""Parse camera configuration from environment variables."""
camera_names_str = os.getenv("CAMERA_NAMES", "laptop,front")
camera_names = [name.strip() for name in camera_names_str.split(',')]

cameras = {}
for camera_name in camera_names:
resolution = os.getenv(f"CAMERA_{camera_name.upper()}_RESOLUTION", "480,640,3")
dims = [int(d.strip()) for d in resolution.split(',')]
cameras[camera_name] = dims

return cameras

def _load_policy(self):
"""Load the trained LeRobot policy."""
model_path = Path(os.getenv("MODEL_PATH"))
if not model_path:
raise ValueError("MODEL_PATH environment variable must be set correctly.")

config = PreTrainedConfig.from_pretrained(model_path)
self.device = get_safe_torch_device(config.device) # get device automatically
config.device = self.device.type

# Get the policy class and load pretrained weights
policy_cls = get_policy_class(config.type)
self.policy = policy_cls.from_pretrained(
pretrained_name_or_path=model_path,
config=config
)

self.policy.eval() # Set to evaluation mode
self.policy.to(self.device)

self._output(f"Policy loaded successfully on device: {self.device}")
self._output(f"Policy type: {config.type}")

def _start_timer(self):
"""Start the inference timing thread."""
self.stop_timer = False
self.inference_thread = threading.Thread(target=self._inference_loop, daemon=True)
self.inference_thread.start()
def _inference_loop(self):
"""Inference loop."""
while not self.stop_timer and not self.shutdown:
current_time = time.time()
if (self.last_inference_time is None or
current_time - self.last_inference_time >= self.inference_interval):
self._run_inference()
self.last_inference_time = current_time

time.sleep(0.001) # Small sleep to prevent busy waiting

def _run_inference(self):
"""Run policy inference on current observations."""
with self.buffer_lock:
required_inputs = set(self.cameras.keys()) | {"robot_state"}
available_inputs = set(self.data_buffer.keys())

if not required_inputs.issubset(available_inputs):
return

observation = {}
for camera_name in self.cameras:
if camera_name in self.data_buffer:
image = self._convert_camera_data(
self.data_buffer[camera_name]["data"],
self.data_buffer[camera_name]["metadata"])
observation[f"observation.images.{camera_name}"] = image

state = self._convert_robot_data(self.data_buffer["robot_state"]["data"])
observation["observation.state"] = state

action = predict_action(
observation=observation,
policy=self.policy,
device=self.device,
use_amp=self.policy.config.use_amp,
task=self.task_description).cpu().numpy()

# Convert from degrees to radians
action = np.deg2rad(action)

self._send_action(action)

def _convert_camera_data(self, dora_data, metadata) -> np.ndarray:
"""Convert camera data from Dora format to numpy."""
height, width = metadata.get("height"), metadata.get("width")
encoding = metadata.get("encoding", "rgb8")
image = dora_data.to_numpy().reshape(height, width, 3)

if encoding == "bgr8":
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
elif encoding == "yuv420":
image = cv2.cvtColor(image, cv2.COLOR_YUV2RGB_I420)

return image.astype(np.uint8)

def _convert_robot_data(self, dora_data, convert_to_degrees=True) -> np.ndarray:
"""Convert robot joint data, LeRobot expects angles in degrees and float32."""
joint_array = dora_data.to_numpy()
if convert_to_degrees:
joint_array = np.rad2deg(joint_array)
return joint_array.astype(np.float32)

def _send_action(self, action: np.ndarray):
"""Send action through message queue to be output by main thread."""
action_message = {"action": action}
self.message_queue.put(("action", action_message))

def handle_input(self, input_id: str, data: Any, metadata: Any):
"""Handle incoming data."""
with self.buffer_lock:
self.data_buffer[input_id] = {
"data": data,
"timestamp": time.time(),
"metadata": metadata
}

def get_pending_messages(self):
"""Get all pending messages from the queue."""
messages = []
while not self.message_queue.empty():
messages.append(self.message_queue.get_nowait())

return messages

def _output(self, message: str):
"""Output status message."""
self.message_queue.put(("status", message))
print(message)

def shutdown(self):
"""Shutdown the inference node."""
self._output("Shutting down policy inference...")
self.shutdown = True
self.stop_timer = True

if self.inference_thread.is_alive():
self.inference_thread.join(timeout=5.0)

self._output("Policy inference shutdown complete")


def main():
node = Node()
inference = DoraPolicyInference()

for event in node:
messages = inference.get_pending_messages()
for msg_type, msg_data in messages:
if msg_type == "action":
# Send robot action
node.send_output(
output_id="robot_action",
data=pa.array(msg_data["action"]),
metadata={}
)
elif msg_type == "status":
node.send_output(
output_id="status",
data=pa.array([msg_data]),
metadata={}
)
if event["type"] == "INPUT":
inference.handle_input(
event["id"],
event["value"],
event.get("metadata", {})
)

# inference.shutdown() # Not sure when to shutdown the node, and how to identify task completion


if __name__ == "__main__":
main()

+ 36
- 0
node-hub/dora-policy-inference/pyproject.toml View File

@@ -0,0 +1,36 @@
[project]
name = "dora-policy-inference"
version = "0.0.0"
authors = [{ name = "Shashwat Patil", email = "email@email.com" }]
description = "dora-policy-inference"
license = { text = "MIT" }
readme = "README.md"
requires-python = ">=3.8"

dependencies = ["dora-rs >= 0.3.9","pyarrow","lerobot"]

[dependency-groups]
dev = ["pytest >=8.1.1", "ruff >=0.9.1"]

[project.scripts]
dora-policy-inference = "dora_policy_inference.main:main"

[tool.ruff.lint]
extend-select = [
"D", # pydocstyle
"UP"
]
ignore = [
"D100", # Missing docstring in public module
"D103", # Missing docstring in public function
"D104", # Missing docstring in public package
]

[tool.uv.sources]
lerobot = {git = "https://github.com/huggingface/lerobot.git", rev = "main"}

[tool.uv.pip]
git-config = [
"filter.lfs.smudge=git-lfs smudge --skip -- %f",
"filter.lfs.process=git-lfs filter-process --skip"
]

+ 13
- 0
node-hub/dora-policy-inference/tests/test_dora_policy_inference.py View File

@@ -0,0 +1,13 @@
"""Test module for dora_policy_inference package."""

import pytest


def test_import_main():
"""Test importing and running the main function."""
from dora_policy_inference.main import main

# Check that everything is working, and catch Dora RuntimeError
# as we're not running in a Dora dataflow.
with pytest.raises(RuntimeError):
main()

Loading…
Cancel
Save