Browse Source

Fix naming and make it possible to train model without cloning lerobot repo

pull/1041/head
haixuantao 6 months ago
parent
commit
baaca28a51
5 changed files with 189 additions and 142 deletions
  1. +31
    -34
      examples/lerobot-dataset/Readme.md
  2. +5
    -5
      examples/lerobot-dataset/policy_inference.yml
  3. +98
    -51
      node-hub/dora-dataset-record/dora_dataset_record/main.py
  4. +8
    -13
      node-hub/dora-dataset-record/pyproject.toml
  5. +47
    -39
      node-hub/dora-policy-inference/dora_policy_inference/main.py

+ 31
- 34
examples/lerobot-dataset/Readme.md View File

@@ -16,20 +16,23 @@ CAPTURE_PATH: "0" # Laptop camera is defaults to 0
# External camera
CAPTURE_PATH: "1" # Change this to match your external camera device
```

### SO101 Arms

Identify and set the correct USB ports for both the leader and follower SO101 arms

```yaml
PORT: "/dev/ttyACM0" # change this
```

### Dataset Recorder Configuration

>Edit the following fields:
> Edit the following fields:

```yaml
REPO_ID: "your_username/so101_dataset" # HuggingFace dataset path
SINGLE_TASK: "Pick up the red cube and place it in the blue box" # Your task description
CAMERA_NAMES: "laptop, front" # Name your camera sources depending on your setup
REPO_ID: "your_username/so101_dataset" # HuggingFace dataset path
SINGLE_TASK: "Pick up the red cube and place it in the blue box" # Your task description
CAMERA_NAMES: "laptop, front" # Name your camera sources depending on your setup
CAMERA_LAPTOP_RESOLUTION: "480,640,3"
CAMERA_FRONT_RESOLUTION: "480,640,3"
```
@@ -37,24 +40,26 @@ CAMERA_FRONT_RESOLUTION: "480,640,3"
You can adjust the following parameters as needed:

```yaml
TOTAL_EPISODES: "2" # Number of episodes
TOTAL_EPISODES: "2" # Number of episodes
#you may want to try with 2-3 episodes to test, then atleast 50 episodes for training is recommended
EPISODE_DURATION_S: "60" # Duration of each episode (in seconds) - depends on complexity of task
RESET_DURATION_S: "15" # Time to reset the environment between episodes
FPS: "30" # Should match camera fps
PUSH_TO_HUB: "false" # Set to "true" to auto-upload dataset to HuggingFace
ROOT_PATH: "full path where you want to save the dataset"
EPISODE_DURATION_S: "60" # Duration of each episode (in seconds) - depends on complexity of task
RESET_DURATION_S: "15" # Time to reset the environment between episodes
FPS: "30" # Should match camera fps
PUSH_TO_HUB: "false" # Set to "true" to auto-upload dataset to HuggingFace
ROOT_PATH: "full path where you want to save the dataset"
# if not defined then it will be stored at ~/.cache/huggingface/lerobot/repo_id
```

Once everything is updated in `dataflow.yml`, you are ready to record your dataset.

## Start Recording

Build and Run

```bash
dora build dataflow.yml
dora run dataflow.yml
uv venv
dora build dataflow.yml --uv
dora run dataflow.yml --uv
```

## Recording Process
@@ -64,11 +69,13 @@ In the rerun window you can see the the info regarding Start of episodes, start
#### During Recording

1. **Episode Active**:

- Use the **leader robot** to demonstrate/perform the task
- Move smoothly and naturally
- Complete the full task within the time limit

2. **Reset Phase**:

- move the leader arm to initial position
- Reset objects to starting positions
- Prepare workspace for next demonstration
@@ -86,8 +93,8 @@ In the rerun window you can see the the info regarding Start of episodes, start
## After Recording

Your dataset will be saved locally. Check the recording was successful:
>It will be stored in ~/.cache/huggingface/lerobot/repo_id

> It will be stored in ~/.cache/huggingface/lerobot/repo_id

# Training and Testing Policies using the recorded dataset

@@ -95,35 +102,24 @@ After successfully recording your dataset, we will be training imitation learnin

## Training Your Policy

#### Install LeRobot Training Dependencies

Easiest way to train your policy is to use lerobots training scripts
```bash
# Install training requirements and lerobot repo
git clone https://github.com/huggingface/lerobot.git
pip install lerobot[training]
pip install tensorboard wandb # For monitoring (Optional)
```

### Choose Your Policy

**ACT (Recommended for SO101)**

- Good for manipulation tasks
- Handles multi-modal data well
- Faster training (for me it took 7hrs 🥺 to train on 50 episodes for pick and place, 3050 laptop gpu)

**Diffusion Policy**

- Better for complex behaviors
- More robust to distribution shift
- Longer training


### Start Training

```bash
cd lerobot

python lerobot/scripts/train.py \
uv run dora-dataset-lerobot-train
--dataset.repo_id=${HF_USER}/your_repo_id \ # provide full path of your dataset
--policy.type=act \
--output_dir=outputs/train/act_so101_test \
@@ -134,6 +130,7 @@ python lerobot/scripts/train.py \
```

You can monitor your training progress on wandb

> For more details regarding training check [lerobot](https://huggingface.co/docs/lerobot/en/il_robots#train-a-policy) guide on Imitation learning for SO101

## Policy Inference and Testing
@@ -150,7 +147,7 @@ Update the camera device IDs to match your setup (same as recording):
# Laptop camera
CAPTURE_PATH: "0" # Usually 0 for built-in laptop camera

# External camera
# External camera
CAPTURE_PATH: "1" # Change this to your external camera device ID
```

@@ -159,7 +156,7 @@ CAPTURE_PATH: "1" # Change this to your external camera device ID
Set the correct USB port for your follower SO101 arm:

```yaml
PORT: "/dev/ttyACM1" # Update this to match your follower robot port
PORT: "/dev/ttyACM1" # Update this to match your follower robot port
```

#### 3. Model Configuration
@@ -167,7 +164,7 @@ PORT: "/dev/ttyACM1" # Update this to match your follower robot port
Update the path to your trained model and task description:

```yaml
MODEL_PATH: "./outputs/train/act_so101_test/checkpoints/last/pretrained_model" # Path to your trained model
MODEL_PATH: "./outputs/train/act_so101_test/checkpoints/last/pretrained_model" # Path to your trained model
TASK_DESCRIPTION: "Pick up the red cube and place it in the blue box"
```

@@ -176,9 +173,9 @@ TASK_DESCRIPTION: "Pick up the red cube and place it in the blue box"
Ensure camera settings match your recording configuration:

```yaml
CAMERA_NAMES: "laptop, front" # Must match training setup
CAMERA_LAPTOP_RESOLUTION: "480,640,3" # Must match training
CAMERA_FRONT_RESOLUTION: "480,640,3" # Must match training
CAMERA_NAMES: "laptop, front" # Must match training setup
CAMERA_LAPTOP_RESOLUTION: "480,640,3" # Must match training
CAMERA_FRONT_RESOLUTION: "480,640,3" # Must match training
```

### Start Policy Inference
@@ -188,4 +185,4 @@ Once you've updated the configuration:
```bash
dora build policy_inference.yml
dora run policy_inference.yml
```
```

+ 5
- 5
examples/lerobot-dataset/policy_inference.yml View File

@@ -21,7 +21,7 @@ nodes:
outputs:
- image
env:
CAPTURE_PATH: "1" # ← UPDATE: Your external camera device
CAPTURE_PATH: "1" # ← UPDATE: Your external camera device
ENCODING: "rgb8"
IMAGE_WIDTH: "640"
IMAGE_HEIGHT: "480"
@@ -35,21 +35,21 @@ nodes:
outputs:
- pose
env:
PORT: "/dev/ttyACM1" # UPDATE your robot port
PORT: "/dev/ttyACM1" # UPDATE your robot port
IDS: "1,2,3,4,5,6"

- id: policy_inference
build: pip install -e ../../node-hub/dora-policy-inference
path: dora-policy-inference
inputs:
laptop: laptop_cam/image # the camera inputs should match the training setup
laptop: laptop_cam/image # the camera inputs should match the training setup
front: front_cam/image
robot_state: so101_follower/pose
outputs:
- robot_action
- status
env:
MODEL_PATH: "./outputs/train/your_policy/checkpoints/last/pretrained_model" # Path to your trained model
MODEL_PATH: "./outputs/train/act_so101_test/checkpoints/last/pretrained_model/" # Path to your trained model
TASK_DESCRIPTION: "Your task"
ROBOT_TYPE: "so101_follower"
CAMERA_NAMES: "laptop, front"
@@ -64,4 +64,4 @@ nodes:
inputs:
image_laptop: laptop_cam/image
image_front: front_cam/image
status: policy_inference/status
status: policy_inference/status

+ 98
- 51
node-hub/dora-dataset-record/dora_dataset_record/main.py View File

@@ -1,15 +1,17 @@
"""TODO: Add docstring."""

from dora import Node
import pyarrow as pa
import os
import time
import numpy as np
import threading
import queue
import cv2
import threading
import time
from typing import Any
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

import cv2
import numpy as np
import pyarrow as pa
from dora import Node
from lerobot.datasets.lerobot_dataset import LeRobotDataset


class DoraLeRobotRecorder:
"""Recorder class for LeRobot dataset."""
@@ -28,10 +30,16 @@ class DoraLeRobotRecorder:
self.episode_index = 0
self.start_time = None
self.cameras = self._get_cameras()
self.total_episodes = int(os.getenv("TOTAL_EPISODES", "10")) # Default to 10 episodes
self.episode_duration = int(os.getenv("EPISODE_DURATION_S", "60")) # Default to 60 seconds
self.reset_duration = int(os.getenv("RESET_DURATION_S", "15")) # Default to 15 seconds
self.fps = int(os.getenv("FPS", "30")) # Default to 30 FPS
self.total_episodes = int(
os.getenv("TOTAL_EPISODES", "10")
) # Default to 10 episodes
self.episode_duration = int(
os.getenv("EPISODE_DURATION_S", "60")
) # Default to 60 seconds
self.reset_duration = int(
os.getenv("RESET_DURATION_S", "15")
) # Default to 15 seconds
self.fps = int(os.getenv("FPS", "30")) # Default to 30 FPS

self.recording_started = False
self.in_reset_phase = False
@@ -49,7 +57,7 @@ class DoraLeRobotRecorder:
"""Get Camera config."""
camera_names_str = os.getenv("CAMERA_NAMES")
if camera_names_str:
camera_names = [name.strip() for name in camera_names_str.split(',')]
camera_names = [name.strip() for name in camera_names_str.split(",")]
else:
return {}

@@ -57,10 +65,12 @@ class DoraLeRobotRecorder:
for camera_name in camera_names:
resolution = os.getenv(f"CAMERA_{camera_name.upper()}_RESOLUTION")
if resolution:
dims = [int(d.strip()) for d in resolution.split(',')]
dims = [int(d.strip()) for d in resolution.split(",")]
cameras[camera_name] = dims
else:
print(f"Warning: Set CAMERA_{camera_name.upper()}_RESOLUTION: \"height,width,channels\"")
print(
f'Warning: Set CAMERA_{camera_name.upper()}_RESOLUTION: "height,width,channels"'
)

return cameras

@@ -68,7 +78,7 @@ class DoraLeRobotRecorder:
"""Get robot joints."""
joints_str = os.getenv("ROBOT_JOINTS")
if joints_str:
return [joint.strip() for joint in joints_str.split(',')]
return [joint.strip() for joint in joints_str.split(",")]
else:
raise ValueError("ROBOT_JOINTS are not set.")

@@ -76,7 +86,7 @@ class DoraLeRobotRecorder:
"""Get tags for dataset."""
tags_str = os.getenv("TAGS")
if tags_str:
return [tag.strip() for tag in tags_str.split(',')]
return [tag.strip() for tag in tags_str.split(",")]
return []

def _setup_dataset(self):
@@ -85,34 +95,41 @@ class DoraLeRobotRecorder:

joint_names = self._get_robot_joints()
features["action"] = {
"dtype": "float32",
"dtype": "float32",
"shape": (len(joint_names),),
"names": joint_names}
"names": joint_names,
}
features["observation.state"] = {
"dtype": "float32",
"shape": (len(joint_names),),
"names": joint_names}
"names": joint_names,
}

self.use_videos = os.getenv("USE_VIDEOS", "true").lower() == "true"
for camera_name in self.cameras:
features[f"observation.images.{camera_name}"] = {
"dtype": "video" if self.use_videos else "image",
"shape": self.cameras[camera_name],
"names": ["height", "width", "channels"]}
"names": ["height", "width", "channels"],
}

self.required_features = set(features.keys())

features.update({
"timestamp": {"dtype": "float32", "shape": [1]},
"frame_index": {"dtype": "int64", "shape": [1]},
"episode_index": {"dtype": "int64", "shape": [1]},
"index": {"dtype": "int64", "shape": [1]},
"task_index": {"dtype": "int64", "shape": [1]},
})
features.update(
{
"timestamp": {"dtype": "float32", "shape": [1]},
"frame_index": {"dtype": "int64", "shape": [1]},
"episode_index": {"dtype": "int64", "shape": [1]},
"index": {"dtype": "int64", "shape": [1]},
"task_index": {"dtype": "int64", "shape": [1]},
}
)

repo_id = os.getenv("REPO_ID", None)
if repo_id is None:
raise ValueError("REPO_ID environment variable must be set to create dataset")
raise ValueError(
"REPO_ID environment variable must be set to create dataset"
)

self.dataset = LeRobotDataset.create(
repo_id=repo_id,
@@ -122,7 +139,8 @@ class DoraLeRobotRecorder:
robot_type=os.getenv("ROBOT_TYPE", "your_robot_type"),
use_videos=self.use_videos,
image_writer_processes=int(os.getenv("IMAGE_WRITER_PROCESSES", "0")),
image_writer_threads=int(os.getenv("IMAGE_WRITER_THREADS", "4")) * len(self.cameras),
image_writer_threads=int(os.getenv("IMAGE_WRITER_THREADS", "4"))
* len(self.cameras),
)

def _check_episode_timing(self):
@@ -167,7 +185,9 @@ class DoraLeRobotRecorder:
"""End current episode and save to dataset."""
self.episode_active = False
if self.frame_count > 0:
self._output(f"Saving episode index {self.episode_index} with {self.frame_count} frames")
self._output(
f"Saving episode index {self.episode_index} with {self.frame_count} frames"
)
self.dataset.save_episode()
self.episode_index += 1
else:
@@ -177,12 +197,16 @@ class DoraLeRobotRecorder:
"""Start the reset phase between episodes."""
self.in_reset_phase = True
self.reset_start_time = time.time()
self._output(f"Reset phase started - {self.reset_duration}s break before next episode...")
self._output(
f"Reset phase started - {self.reset_duration}s break before next episode..."
)

def _start_frame_timer(self):
"""Start the frame timer thread."""
self.stop_timer = False
self.frame_timer_thread = threading.Thread(target=self._frame_timer_loop, daemon=True)
self.frame_timer_thread = threading.Thread(
target=self._frame_timer_loop, daemon=True
)
self.frame_timer_thread.start()

def _frame_timer_loop(self):
@@ -190,9 +214,13 @@ class DoraLeRobotRecorder:
while not self.stop_timer and not self.shutdown:
current_time = time.time()

if self.episode_active and not self.in_reset_phase and (
self.last_frame_time is None or
current_time - self.last_frame_time >= self.frame_interval
if (
self.episode_active
and not self.in_reset_phase
and (
self.last_frame_time is None
or current_time - self.last_frame_time >= self.frame_interval
)
):
self._add_frame()
self.last_frame_time = current_time
@@ -212,7 +240,7 @@ class DoraLeRobotRecorder:
self.data_buffer[input_id] = {
"data": data,
"timestamp": time.time(),
"metadata": metadata
"metadata": metadata,
}

should_stop = self._check_episode_timing()
@@ -242,7 +270,7 @@ class DoraLeRobotRecorder:
"""Convert camera data from 1D pyarrow array to numpy format."""
height, width = metadata.get("height"), metadata.get("width")
encoding = metadata.get("encoding")
image = dora_data.to_numpy().reshape(height, width, 3)
image = dora_data.to_numpy().reshape(height, width, 3)

if encoding == "bgr8":
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
@@ -266,18 +294,24 @@ class DoraLeRobotRecorder:

for key, value in self.data_buffer.items():
if key == "robot_action":
frame_data["action"] = self.convert_robot_data(self.data_buffer["robot_action"]["data"])
frame_data["action"] = self.convert_robot_data(
self.data_buffer["robot_action"]["data"]
)
if key == "robot_state":
frame_data["observation.state"] = self.convert_robot_data(self.data_buffer["robot_state"]["data"])
if {'height', 'width'} <= value.get('metadata', {}).keys():
frame_data["observation.state"] = self.convert_robot_data(
self.data_buffer["robot_state"]["data"]
)
if {"height", "width"} <= value.get("metadata", {}).keys():
camera_name = key
image = self._convert_camera_data(
self.data_buffer[camera_name]["data"],
self.data_buffer[camera_name]["metadata"]
self.data_buffer[camera_name]["metadata"],
)
frame_data[f"observation.images.{camera_name}"] = image

missing_keys = self.required_features - set(frame_data.keys()) # Ensure all required features are present
missing_keys = self.required_features - set(
frame_data.keys()
) # Ensure all required features are present
if missing_keys:
print(f"Missing required data in frame: {missing_keys}")
return
@@ -285,7 +319,7 @@ class DoraLeRobotRecorder:
self.dataset.add_frame(
frame=frame_data,
task=os.getenv("SINGLE_TASK", "Your task"),
timestamp=ideal_timestamp
timestamp=ideal_timestamp,
)
self.frame_count += 1

@@ -302,10 +336,12 @@ class DoraLeRobotRecorder:
self._output("Pushing dataset to hub...")
self.dataset.push_to_hub(
tags=self._get_tags(),
private=os.getenv("PRIVATE", "false").lower() == "true"
private=os.getenv("PRIVATE", "false").lower() == "true",
)

self._output(f"Dataset recording completed. Total episodes: {self.episode_index}")
self._output(
f"Dataset recording completed. Total episodes: {self.episode_index}"
)

def _output(self, message: str):
"""Output message."""
@@ -320,6 +356,7 @@ class DoraLeRobotRecorder:
messages.append(self.message_queue.get_nowait())
return messages


def main():
node = Node()
recorder = DoraLeRobotRecorder()
@@ -332,18 +369,28 @@ def main():
for event in node:
pending_messages = recorder.get_pending_messages()
for message in pending_messages:
node.send_output(
output_id="text",
data=pa.array([message]),
metadata={})
node.send_output(output_id="text", data=pa.array([message]), metadata={})

if event["type"] == "INPUT":
should_stop = recorder.handle_input(event["id"], event["value"], event.get("metadata", {}))
should_stop = recorder.handle_input(
event["id"], event["value"], event.get("metadata", {})
)
if should_stop:
print("All episodes completed, stopping recording...")
break

recorder._shutdown()


def train_main():
from lerobot.scripts.train import train
from lerobot.utils.utils import (
init_logging,
)

init_logging()
train()


if __name__ == "__main__":
main()
main()

+ 8
- 13
node-hub/dora-dataset-record/pyproject.toml View File

@@ -7,30 +7,25 @@ license = { text = "MIT" }
readme = "README.md"
requires-python = ">=3.8"

dependencies = ["dora-rs >= 0.3.9","pyarrow","lerobot"]
dependencies = ["dora-rs >= 0.3.9", "pyarrow", "lerobot[train]"]

[dependency-groups]
dev = ["pytest >=8.1.1", "ruff >=0.9.1"]

[project.scripts]
dora-dataset-record = "dora_dataset_record.main:main"
dora-dataset-lerobot-train = "dora_dataset_record.main:train_main"

[tool.ruff.lint]
extend-select = [
"D", # pydocstyle
"UP"
"D", # pydocstyle
"UP",
]
ignore = [
"D100", # Missing docstring in public module
"D103", # Missing docstring in public function
"D104", # Missing docstring in public package
"D100", # Missing docstring in public module
"D103", # Missing docstring in public function
"D104", # Missing docstring in public package
]

[tool.uv.sources]
lerobot = {git = "https://github.com/huggingface/lerobot.git", rev = "main"}

[tool.uv.pip]
git-config = [
"filter.lfs.smudge=git-lfs smudge --skip -- %f",
"filter.lfs.process=git-lfs filter-process --skip"
]
lerobot = { git = "https://github.com/huggingface/lerobot.git", rev = "main" }

+ 47
- 39
node-hub/dora-policy-inference/dora_policy_inference/main.py View File

@@ -1,20 +1,21 @@
"""TODO."""

from dora import Node
import pyarrow as pa
import os
import time
import numpy as np
import threading
import queue
import cv2
from typing import Any
import threading
import time
from pathlib import Path
from typing import Any

import cv2
import numpy as np
import pyarrow as pa
from dora import Node
from lerobot.configs.policies import PreTrainedConfig
from lerobot.common.policies.factory import get_policy_class
from lerobot.common.utils.control_utils import predict_action
from lerobot.common.utils.utils import get_safe_torch_device
from lerobot.policies.factory import get_policy_class
from lerobot.utils.control_utils import predict_action
from lerobot.utils.utils import get_safe_torch_device


class DoraPolicyInference:
"""Policy inference node for LeRobot policies."""
@@ -25,10 +26,10 @@ class DoraPolicyInference:

self.data_buffer = {}
self.buffer_lock = threading.Lock()
self.cameras = self._parse_cameras()
self.cameras = self._parse_cameras()
self.task_description = os.getenv("TASK_DESCRIPTION", "")

self.policy = None
self.policy = None
self.inference_fps = int(os.getenv("INFERENCE_FPS", "30"))
self.last_inference_time = None
self.inference_interval = 1.0 / self.inference_fps
@@ -40,12 +41,14 @@ class DoraPolicyInference:
def _parse_cameras(self) -> dict:
"""Parse camera configuration from environment variables."""
camera_names_str = os.getenv("CAMERA_NAMES", "laptop,front")
camera_names = [name.strip() for name in camera_names_str.split(',')]
camera_names = [name.strip() for name in camera_names_str.split(",")]

cameras = {}
for camera_name in camera_names:
resolution = os.getenv(f"CAMERA_{camera_name.upper()}_RESOLUTION", "480,640,3")
dims = [int(d.strip()) for d in resolution.split(',')]
resolution = os.getenv(
f"CAMERA_{camera_name.upper()}_RESOLUTION", "480,640,3"
)
dims = [int(d.strip()) for d in resolution.split(",")]
cameras[camera_name] = dims

return cameras
@@ -57,14 +60,13 @@ class DoraPolicyInference:
raise ValueError("MODEL_PATH environment variable must be set correctly.")

config = PreTrainedConfig.from_pretrained(model_path)
self.device = get_safe_torch_device(config.device) # get device automatically
self.device = get_safe_torch_device(config.device) # get device automatically
config.device = self.device.type

# Get the policy class and load pretrained weights
policy_cls = get_policy_class(config.type)
self.policy = policy_cls.from_pretrained(
pretrained_name_or_path=model_path,
config=config
pretrained_name_or_path=model_path, config=config
)

self.policy.eval() # Set to evaluation mode
@@ -76,15 +78,19 @@ class DoraPolicyInference:
def _start_timer(self):
"""Start the inference timing thread."""
self.stop_timer = False
self.inference_thread = threading.Thread(target=self._inference_loop, daemon=True)
self.inference_thread = threading.Thread(
target=self._inference_loop, daemon=True
)
self.inference_thread.start()
def _inference_loop(self):
"""Inference loop."""
while not self.stop_timer and not self.shutdown:
current_time = time.time()
if (self.last_inference_time is None or
current_time - self.last_inference_time >= self.inference_interval):
if (
self.last_inference_time is None
or current_time - self.last_inference_time >= self.inference_interval
):
self._run_inference()
self.last_inference_time = current_time

@@ -104,18 +110,24 @@ class DoraPolicyInference:
if camera_name in self.data_buffer:
image = self._convert_camera_data(
self.data_buffer[camera_name]["data"],
self.data_buffer[camera_name]["metadata"])
self.data_buffer[camera_name]["metadata"],
)
observation[f"observation.images.{camera_name}"] = image

state = self._convert_robot_data(self.data_buffer["robot_state"]["data"])
observation["observation.state"] = state

action = predict_action(
observation=observation,
policy=self.policy,
device=self.device,
use_amp=self.policy.config.use_amp,
task=self.task_description).cpu().numpy()
action = (
predict_action(
observation=observation,
policy=self.policy,
device=self.device,
use_amp=self.policy.config.use_amp,
task=self.task_description,
)
.cpu()
.numpy()
)

# Convert from degrees to radians
action = np.deg2rad(action)
@@ -153,7 +165,7 @@ class DoraPolicyInference:
self.data_buffer[input_id] = {
"data": data,
"timestamp": time.time(),
"metadata": metadata
"metadata": metadata,
}

def get_pending_messages(self):
@@ -193,24 +205,20 @@ def main():
node.send_output(
output_id="robot_action",
data=pa.array(msg_data["action"]),
metadata={}
metadata={},
)
elif msg_type == "status":
node.send_output(
output_id="status",
data=pa.array([msg_data]),
metadata={}
output_id="status", data=pa.array([msg_data]), metadata={}
)

if event["type"] == "INPUT":
inference.handle_input(
event["id"],
event["value"],
event.get("metadata", {})
event["id"], event["value"], event.get("metadata", {})
)

# inference.shutdown() # Not sure when to shutdown the node, and how to identify task completion


if __name__ == "__main__":
main()
main()

Loading…
Cancel
Save