From baaca28a51bb674541dfb1b908de90ab8ecc4478 Mon Sep 17 00:00:00 2001 From: haixuantao Date: Mon, 7 Jul 2025 15:12:54 +0200 Subject: [PATCH] Fix naming and make it possible to train model without cloning lerobot repo --- examples/lerobot-dataset/Readme.md | 65 ++++---- examples/lerobot-dataset/policy_inference.yml | 10 +- .../dora_dataset_record/main.py | 149 ++++++++++++------ node-hub/dora-dataset-record/pyproject.toml | 21 +-- .../dora_policy_inference/main.py | 86 +++++----- 5 files changed, 189 insertions(+), 142 deletions(-) diff --git a/examples/lerobot-dataset/Readme.md b/examples/lerobot-dataset/Readme.md index c3f95245..a35217fd 100644 --- a/examples/lerobot-dataset/Readme.md +++ b/examples/lerobot-dataset/Readme.md @@ -16,20 +16,23 @@ CAPTURE_PATH: "0" # Laptop camera is defaults to 0 # External camera CAPTURE_PATH: "1" # Change this to match your external camera device ``` + ### SO101 Arms Identify and set the correct USB ports for both the leader and follower SO101 arms + ```yaml PORT: "/dev/ttyACM0" # change this ``` + ### Dataset Recorder Configuration ->Edit the following fields: +> Edit the following fields: ```yaml -REPO_ID: "your_username/so101_dataset" # HuggingFace dataset path -SINGLE_TASK: "Pick up the red cube and place it in the blue box" # Your task description -CAMERA_NAMES: "laptop, front" # Name your camera sources depending on your setup +REPO_ID: "your_username/so101_dataset" # HuggingFace dataset path +SINGLE_TASK: "Pick up the red cube and place it in the blue box" # Your task description +CAMERA_NAMES: "laptop, front" # Name your camera sources depending on your setup CAMERA_LAPTOP_RESOLUTION: "480,640,3" CAMERA_FRONT_RESOLUTION: "480,640,3" ``` @@ -37,24 +40,26 @@ CAMERA_FRONT_RESOLUTION: "480,640,3" You can adjust the following parameters as needed: ```yaml -TOTAL_EPISODES: "2" # Number of episodes +TOTAL_EPISODES: "2" # Number of episodes #you may want to try with 2-3 episodes to test, then atleast 50 episodes for training is recommended -EPISODE_DURATION_S: "60" # Duration of each episode (in seconds) - depends on complexity of task -RESET_DURATION_S: "15" # Time to reset the environment between episodes -FPS: "30" # Should match camera fps -PUSH_TO_HUB: "false" # Set to "true" to auto-upload dataset to HuggingFace -ROOT_PATH: "full path where you want to save the dataset" +EPISODE_DURATION_S: "60" # Duration of each episode (in seconds) - depends on complexity of task +RESET_DURATION_S: "15" # Time to reset the environment between episodes +FPS: "30" # Should match camera fps +PUSH_TO_HUB: "false" # Set to "true" to auto-upload dataset to HuggingFace +ROOT_PATH: "full path where you want to save the dataset" # if not defined then it will be stored at ~/.cache/huggingface/lerobot/repo_id ``` Once everything is updated in `dataflow.yml`, you are ready to record your dataset. ## Start Recording + Build and Run ```bash -dora build dataflow.yml -dora run dataflow.yml +uv venv +dora build dataflow.yml --uv +dora run dataflow.yml --uv ``` ## Recording Process @@ -64,11 +69,13 @@ In the rerun window you can see the the info regarding Start of episodes, start #### During Recording 1. **Episode Active**: + - Use the **leader robot** to demonstrate/perform the task - Move smoothly and naturally - Complete the full task within the time limit 2. **Reset Phase**: + - move the leader arm to initial position - Reset objects to starting positions - Prepare workspace for next demonstration @@ -86,8 +93,8 @@ In the rerun window you can see the the info regarding Start of episodes, start ## After Recording Your dataset will be saved locally. Check the recording was successful: ->It will be stored in ~/.cache/huggingface/lerobot/repo_id +> It will be stored in ~/.cache/huggingface/lerobot/repo_id # Training and Testing Policies using the recorded dataset @@ -95,35 +102,24 @@ After successfully recording your dataset, we will be training imitation learnin ## Training Your Policy -#### Install LeRobot Training Dependencies - -Easiest way to train your policy is to use lerobots training scripts -```bash -# Install training requirements and lerobot repo -git clone https://github.com/huggingface/lerobot.git -pip install lerobot[training] -pip install tensorboard wandb # For monitoring (Optional) -``` - ### Choose Your Policy **ACT (Recommended for SO101)** + - Good for manipulation tasks - Handles multi-modal data well - Faster training (for me it took 7hrs 🥺 to train on 50 episodes for pick and place, 3050 laptop gpu) **Diffusion Policy** + - Better for complex behaviors - More robust to distribution shift - Longer training - ### Start Training ```bash -cd lerobot - -python lerobot/scripts/train.py \ +uv run dora-dataset-lerobot-train --dataset.repo_id=${HF_USER}/your_repo_id \ # provide full path of your dataset --policy.type=act \ --output_dir=outputs/train/act_so101_test \ @@ -134,6 +130,7 @@ python lerobot/scripts/train.py \ ``` You can monitor your training progress on wandb + > For more details regarding training check [lerobot](https://huggingface.co/docs/lerobot/en/il_robots#train-a-policy) guide on Imitation learning for SO101 ## Policy Inference and Testing @@ -150,7 +147,7 @@ Update the camera device IDs to match your setup (same as recording): # Laptop camera CAPTURE_PATH: "0" # Usually 0 for built-in laptop camera -# External camera +# External camera CAPTURE_PATH: "1" # Change this to your external camera device ID ``` @@ -159,7 +156,7 @@ CAPTURE_PATH: "1" # Change this to your external camera device ID Set the correct USB port for your follower SO101 arm: ```yaml -PORT: "/dev/ttyACM1" # Update this to match your follower robot port +PORT: "/dev/ttyACM1" # Update this to match your follower robot port ``` #### 3. Model Configuration @@ -167,7 +164,7 @@ PORT: "/dev/ttyACM1" # Update this to match your follower robot port Update the path to your trained model and task description: ```yaml -MODEL_PATH: "./outputs/train/act_so101_test/checkpoints/last/pretrained_model" # Path to your trained model +MODEL_PATH: "./outputs/train/act_so101_test/checkpoints/last/pretrained_model" # Path to your trained model TASK_DESCRIPTION: "Pick up the red cube and place it in the blue box" ``` @@ -176,9 +173,9 @@ TASK_DESCRIPTION: "Pick up the red cube and place it in the blue box" Ensure camera settings match your recording configuration: ```yaml -CAMERA_NAMES: "laptop, front" # Must match training setup -CAMERA_LAPTOP_RESOLUTION: "480,640,3" # Must match training -CAMERA_FRONT_RESOLUTION: "480,640,3" # Must match training +CAMERA_NAMES: "laptop, front" # Must match training setup +CAMERA_LAPTOP_RESOLUTION: "480,640,3" # Must match training +CAMERA_FRONT_RESOLUTION: "480,640,3" # Must match training ``` ### Start Policy Inference @@ -188,4 +185,4 @@ Once you've updated the configuration: ```bash dora build policy_inference.yml dora run policy_inference.yml -``` \ No newline at end of file +``` diff --git a/examples/lerobot-dataset/policy_inference.yml b/examples/lerobot-dataset/policy_inference.yml index 15cc7fb1..a5d6396b 100644 --- a/examples/lerobot-dataset/policy_inference.yml +++ b/examples/lerobot-dataset/policy_inference.yml @@ -21,7 +21,7 @@ nodes: outputs: - image env: - CAPTURE_PATH: "1" # ← UPDATE: Your external camera device + CAPTURE_PATH: "1" # ← UPDATE: Your external camera device ENCODING: "rgb8" IMAGE_WIDTH: "640" IMAGE_HEIGHT: "480" @@ -35,21 +35,21 @@ nodes: outputs: - pose env: - PORT: "/dev/ttyACM1" # UPDATE your robot port + PORT: "/dev/ttyACM1" # UPDATE your robot port IDS: "1,2,3,4,5,6" - id: policy_inference build: pip install -e ../../node-hub/dora-policy-inference path: dora-policy-inference inputs: - laptop: laptop_cam/image # the camera inputs should match the training setup + laptop: laptop_cam/image # the camera inputs should match the training setup front: front_cam/image robot_state: so101_follower/pose outputs: - robot_action - status env: - MODEL_PATH: "./outputs/train/your_policy/checkpoints/last/pretrained_model" # Path to your trained model + MODEL_PATH: "./outputs/train/act_so101_test/checkpoints/last/pretrained_model/" # Path to your trained model TASK_DESCRIPTION: "Your task" ROBOT_TYPE: "so101_follower" CAMERA_NAMES: "laptop, front" @@ -64,4 +64,4 @@ nodes: inputs: image_laptop: laptop_cam/image image_front: front_cam/image - status: policy_inference/status \ No newline at end of file + status: policy_inference/status diff --git a/node-hub/dora-dataset-record/dora_dataset_record/main.py b/node-hub/dora-dataset-record/dora_dataset_record/main.py index e03f39ce..9f41ce03 100644 --- a/node-hub/dora-dataset-record/dora_dataset_record/main.py +++ b/node-hub/dora-dataset-record/dora_dataset_record/main.py @@ -1,15 +1,17 @@ """TODO: Add docstring.""" -from dora import Node -import pyarrow as pa import os -import time -import numpy as np -import threading import queue -import cv2 +import threading +import time from typing import Any -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + +import cv2 +import numpy as np +import pyarrow as pa +from dora import Node +from lerobot.datasets.lerobot_dataset import LeRobotDataset + class DoraLeRobotRecorder: """Recorder class for LeRobot dataset.""" @@ -28,10 +30,16 @@ class DoraLeRobotRecorder: self.episode_index = 0 self.start_time = None self.cameras = self._get_cameras() - self.total_episodes = int(os.getenv("TOTAL_EPISODES", "10")) # Default to 10 episodes - self.episode_duration = int(os.getenv("EPISODE_DURATION_S", "60")) # Default to 60 seconds - self.reset_duration = int(os.getenv("RESET_DURATION_S", "15")) # Default to 15 seconds - self.fps = int(os.getenv("FPS", "30")) # Default to 30 FPS + self.total_episodes = int( + os.getenv("TOTAL_EPISODES", "10") + ) # Default to 10 episodes + self.episode_duration = int( + os.getenv("EPISODE_DURATION_S", "60") + ) # Default to 60 seconds + self.reset_duration = int( + os.getenv("RESET_DURATION_S", "15") + ) # Default to 15 seconds + self.fps = int(os.getenv("FPS", "30")) # Default to 30 FPS self.recording_started = False self.in_reset_phase = False @@ -49,7 +57,7 @@ class DoraLeRobotRecorder: """Get Camera config.""" camera_names_str = os.getenv("CAMERA_NAMES") if camera_names_str: - camera_names = [name.strip() for name in camera_names_str.split(',')] + camera_names = [name.strip() for name in camera_names_str.split(",")] else: return {} @@ -57,10 +65,12 @@ class DoraLeRobotRecorder: for camera_name in camera_names: resolution = os.getenv(f"CAMERA_{camera_name.upper()}_RESOLUTION") if resolution: - dims = [int(d.strip()) for d in resolution.split(',')] + dims = [int(d.strip()) for d in resolution.split(",")] cameras[camera_name] = dims else: - print(f"Warning: Set CAMERA_{camera_name.upper()}_RESOLUTION: \"height,width,channels\"") + print( + f'Warning: Set CAMERA_{camera_name.upper()}_RESOLUTION: "height,width,channels"' + ) return cameras @@ -68,7 +78,7 @@ class DoraLeRobotRecorder: """Get robot joints.""" joints_str = os.getenv("ROBOT_JOINTS") if joints_str: - return [joint.strip() for joint in joints_str.split(',')] + return [joint.strip() for joint in joints_str.split(",")] else: raise ValueError("ROBOT_JOINTS are not set.") @@ -76,7 +86,7 @@ class DoraLeRobotRecorder: """Get tags for dataset.""" tags_str = os.getenv("TAGS") if tags_str: - return [tag.strip() for tag in tags_str.split(',')] + return [tag.strip() for tag in tags_str.split(",")] return [] def _setup_dataset(self): @@ -85,34 +95,41 @@ class DoraLeRobotRecorder: joint_names = self._get_robot_joints() features["action"] = { - "dtype": "float32", + "dtype": "float32", "shape": (len(joint_names),), - "names": joint_names} + "names": joint_names, + } features["observation.state"] = { "dtype": "float32", "shape": (len(joint_names),), - "names": joint_names} + "names": joint_names, + } self.use_videos = os.getenv("USE_VIDEOS", "true").lower() == "true" for camera_name in self.cameras: features[f"observation.images.{camera_name}"] = { "dtype": "video" if self.use_videos else "image", "shape": self.cameras[camera_name], - "names": ["height", "width", "channels"]} + "names": ["height", "width", "channels"], + } self.required_features = set(features.keys()) - features.update({ - "timestamp": {"dtype": "float32", "shape": [1]}, - "frame_index": {"dtype": "int64", "shape": [1]}, - "episode_index": {"dtype": "int64", "shape": [1]}, - "index": {"dtype": "int64", "shape": [1]}, - "task_index": {"dtype": "int64", "shape": [1]}, - }) + features.update( + { + "timestamp": {"dtype": "float32", "shape": [1]}, + "frame_index": {"dtype": "int64", "shape": [1]}, + "episode_index": {"dtype": "int64", "shape": [1]}, + "index": {"dtype": "int64", "shape": [1]}, + "task_index": {"dtype": "int64", "shape": [1]}, + } + ) repo_id = os.getenv("REPO_ID", None) if repo_id is None: - raise ValueError("REPO_ID environment variable must be set to create dataset") + raise ValueError( + "REPO_ID environment variable must be set to create dataset" + ) self.dataset = LeRobotDataset.create( repo_id=repo_id, @@ -122,7 +139,8 @@ class DoraLeRobotRecorder: robot_type=os.getenv("ROBOT_TYPE", "your_robot_type"), use_videos=self.use_videos, image_writer_processes=int(os.getenv("IMAGE_WRITER_PROCESSES", "0")), - image_writer_threads=int(os.getenv("IMAGE_WRITER_THREADS", "4")) * len(self.cameras), + image_writer_threads=int(os.getenv("IMAGE_WRITER_THREADS", "4")) + * len(self.cameras), ) def _check_episode_timing(self): @@ -167,7 +185,9 @@ class DoraLeRobotRecorder: """End current episode and save to dataset.""" self.episode_active = False if self.frame_count > 0: - self._output(f"Saving episode index {self.episode_index} with {self.frame_count} frames") + self._output( + f"Saving episode index {self.episode_index} with {self.frame_count} frames" + ) self.dataset.save_episode() self.episode_index += 1 else: @@ -177,12 +197,16 @@ class DoraLeRobotRecorder: """Start the reset phase between episodes.""" self.in_reset_phase = True self.reset_start_time = time.time() - self._output(f"Reset phase started - {self.reset_duration}s break before next episode...") + self._output( + f"Reset phase started - {self.reset_duration}s break before next episode..." + ) def _start_frame_timer(self): """Start the frame timer thread.""" self.stop_timer = False - self.frame_timer_thread = threading.Thread(target=self._frame_timer_loop, daemon=True) + self.frame_timer_thread = threading.Thread( + target=self._frame_timer_loop, daemon=True + ) self.frame_timer_thread.start() def _frame_timer_loop(self): @@ -190,9 +214,13 @@ class DoraLeRobotRecorder: while not self.stop_timer and not self.shutdown: current_time = time.time() - if self.episode_active and not self.in_reset_phase and ( - self.last_frame_time is None or - current_time - self.last_frame_time >= self.frame_interval + if ( + self.episode_active + and not self.in_reset_phase + and ( + self.last_frame_time is None + or current_time - self.last_frame_time >= self.frame_interval + ) ): self._add_frame() self.last_frame_time = current_time @@ -212,7 +240,7 @@ class DoraLeRobotRecorder: self.data_buffer[input_id] = { "data": data, "timestamp": time.time(), - "metadata": metadata + "metadata": metadata, } should_stop = self._check_episode_timing() @@ -242,7 +270,7 @@ class DoraLeRobotRecorder: """Convert camera data from 1D pyarrow array to numpy format.""" height, width = metadata.get("height"), metadata.get("width") encoding = metadata.get("encoding") - image = dora_data.to_numpy().reshape(height, width, 3) + image = dora_data.to_numpy().reshape(height, width, 3) if encoding == "bgr8": image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) @@ -266,18 +294,24 @@ class DoraLeRobotRecorder: for key, value in self.data_buffer.items(): if key == "robot_action": - frame_data["action"] = self.convert_robot_data(self.data_buffer["robot_action"]["data"]) + frame_data["action"] = self.convert_robot_data( + self.data_buffer["robot_action"]["data"] + ) if key == "robot_state": - frame_data["observation.state"] = self.convert_robot_data(self.data_buffer["robot_state"]["data"]) - if {'height', 'width'} <= value.get('metadata', {}).keys(): + frame_data["observation.state"] = self.convert_robot_data( + self.data_buffer["robot_state"]["data"] + ) + if {"height", "width"} <= value.get("metadata", {}).keys(): camera_name = key image = self._convert_camera_data( self.data_buffer[camera_name]["data"], - self.data_buffer[camera_name]["metadata"] + self.data_buffer[camera_name]["metadata"], ) frame_data[f"observation.images.{camera_name}"] = image - missing_keys = self.required_features - set(frame_data.keys()) # Ensure all required features are present + missing_keys = self.required_features - set( + frame_data.keys() + ) # Ensure all required features are present if missing_keys: print(f"Missing required data in frame: {missing_keys}") return @@ -285,7 +319,7 @@ class DoraLeRobotRecorder: self.dataset.add_frame( frame=frame_data, task=os.getenv("SINGLE_TASK", "Your task"), - timestamp=ideal_timestamp + timestamp=ideal_timestamp, ) self.frame_count += 1 @@ -302,10 +336,12 @@ class DoraLeRobotRecorder: self._output("Pushing dataset to hub...") self.dataset.push_to_hub( tags=self._get_tags(), - private=os.getenv("PRIVATE", "false").lower() == "true" + private=os.getenv("PRIVATE", "false").lower() == "true", ) - self._output(f"Dataset recording completed. Total episodes: {self.episode_index}") + self._output( + f"Dataset recording completed. Total episodes: {self.episode_index}" + ) def _output(self, message: str): """Output message.""" @@ -320,6 +356,7 @@ class DoraLeRobotRecorder: messages.append(self.message_queue.get_nowait()) return messages + def main(): node = Node() recorder = DoraLeRobotRecorder() @@ -332,18 +369,28 @@ def main(): for event in node: pending_messages = recorder.get_pending_messages() for message in pending_messages: - node.send_output( - output_id="text", - data=pa.array([message]), - metadata={}) + node.send_output(output_id="text", data=pa.array([message]), metadata={}) if event["type"] == "INPUT": - should_stop = recorder.handle_input(event["id"], event["value"], event.get("metadata", {})) + should_stop = recorder.handle_input( + event["id"], event["value"], event.get("metadata", {}) + ) if should_stop: print("All episodes completed, stopping recording...") break recorder._shutdown() + +def train_main(): + from lerobot.scripts.train import train + from lerobot.utils.utils import ( + init_logging, + ) + + init_logging() + train() + + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/node-hub/dora-dataset-record/pyproject.toml b/node-hub/dora-dataset-record/pyproject.toml index 122895c4..487e8539 100644 --- a/node-hub/dora-dataset-record/pyproject.toml +++ b/node-hub/dora-dataset-record/pyproject.toml @@ -7,30 +7,25 @@ license = { text = "MIT" } readme = "README.md" requires-python = ">=3.8" -dependencies = ["dora-rs >= 0.3.9","pyarrow","lerobot"] +dependencies = ["dora-rs >= 0.3.9", "pyarrow", "lerobot[train]"] [dependency-groups] dev = ["pytest >=8.1.1", "ruff >=0.9.1"] [project.scripts] dora-dataset-record = "dora_dataset_record.main:main" +dora-dataset-lerobot-train = "dora_dataset_record.main:train_main" [tool.ruff.lint] extend-select = [ - "D", # pydocstyle - "UP" + "D", # pydocstyle + "UP", ] ignore = [ - "D100", # Missing docstring in public module - "D103", # Missing docstring in public function - "D104", # Missing docstring in public package + "D100", # Missing docstring in public module + "D103", # Missing docstring in public function + "D104", # Missing docstring in public package ] [tool.uv.sources] -lerobot = {git = "https://github.com/huggingface/lerobot.git", rev = "main"} - -[tool.uv.pip] -git-config = [ - "filter.lfs.smudge=git-lfs smudge --skip -- %f", - "filter.lfs.process=git-lfs filter-process --skip" -] \ No newline at end of file +lerobot = { git = "https://github.com/huggingface/lerobot.git", rev = "main" } diff --git a/node-hub/dora-policy-inference/dora_policy_inference/main.py b/node-hub/dora-policy-inference/dora_policy_inference/main.py index 69c06467..00445c0c 100644 --- a/node-hub/dora-policy-inference/dora_policy_inference/main.py +++ b/node-hub/dora-policy-inference/dora_policy_inference/main.py @@ -1,20 +1,21 @@ """TODO.""" -from dora import Node -import pyarrow as pa import os -import time -import numpy as np -import threading import queue -import cv2 -from typing import Any +import threading +import time from pathlib import Path +from typing import Any +import cv2 +import numpy as np +import pyarrow as pa +from dora import Node from lerobot.configs.policies import PreTrainedConfig -from lerobot.common.policies.factory import get_policy_class -from lerobot.common.utils.control_utils import predict_action -from lerobot.common.utils.utils import get_safe_torch_device +from lerobot.policies.factory import get_policy_class +from lerobot.utils.control_utils import predict_action +from lerobot.utils.utils import get_safe_torch_device + class DoraPolicyInference: """Policy inference node for LeRobot policies.""" @@ -25,10 +26,10 @@ class DoraPolicyInference: self.data_buffer = {} self.buffer_lock = threading.Lock() - self.cameras = self._parse_cameras() + self.cameras = self._parse_cameras() self.task_description = os.getenv("TASK_DESCRIPTION", "") - self.policy = None + self.policy = None self.inference_fps = int(os.getenv("INFERENCE_FPS", "30")) self.last_inference_time = None self.inference_interval = 1.0 / self.inference_fps @@ -40,12 +41,14 @@ class DoraPolicyInference: def _parse_cameras(self) -> dict: """Parse camera configuration from environment variables.""" camera_names_str = os.getenv("CAMERA_NAMES", "laptop,front") - camera_names = [name.strip() for name in camera_names_str.split(',')] + camera_names = [name.strip() for name in camera_names_str.split(",")] cameras = {} for camera_name in camera_names: - resolution = os.getenv(f"CAMERA_{camera_name.upper()}_RESOLUTION", "480,640,3") - dims = [int(d.strip()) for d in resolution.split(',')] + resolution = os.getenv( + f"CAMERA_{camera_name.upper()}_RESOLUTION", "480,640,3" + ) + dims = [int(d.strip()) for d in resolution.split(",")] cameras[camera_name] = dims return cameras @@ -57,14 +60,13 @@ class DoraPolicyInference: raise ValueError("MODEL_PATH environment variable must be set correctly.") config = PreTrainedConfig.from_pretrained(model_path) - self.device = get_safe_torch_device(config.device) # get device automatically + self.device = get_safe_torch_device(config.device) # get device automatically config.device = self.device.type # Get the policy class and load pretrained weights policy_cls = get_policy_class(config.type) self.policy = policy_cls.from_pretrained( - pretrained_name_or_path=model_path, - config=config + pretrained_name_or_path=model_path, config=config ) self.policy.eval() # Set to evaluation mode @@ -76,15 +78,19 @@ class DoraPolicyInference: def _start_timer(self): """Start the inference timing thread.""" self.stop_timer = False - self.inference_thread = threading.Thread(target=self._inference_loop, daemon=True) + self.inference_thread = threading.Thread( + target=self._inference_loop, daemon=True + ) self.inference_thread.start() - + def _inference_loop(self): """Inference loop.""" while not self.stop_timer and not self.shutdown: current_time = time.time() - if (self.last_inference_time is None or - current_time - self.last_inference_time >= self.inference_interval): + if ( + self.last_inference_time is None + or current_time - self.last_inference_time >= self.inference_interval + ): self._run_inference() self.last_inference_time = current_time @@ -104,18 +110,24 @@ class DoraPolicyInference: if camera_name in self.data_buffer: image = self._convert_camera_data( self.data_buffer[camera_name]["data"], - self.data_buffer[camera_name]["metadata"]) + self.data_buffer[camera_name]["metadata"], + ) observation[f"observation.images.{camera_name}"] = image state = self._convert_robot_data(self.data_buffer["robot_state"]["data"]) observation["observation.state"] = state - action = predict_action( - observation=observation, - policy=self.policy, - device=self.device, - use_amp=self.policy.config.use_amp, - task=self.task_description).cpu().numpy() + action = ( + predict_action( + observation=observation, + policy=self.policy, + device=self.device, + use_amp=self.policy.config.use_amp, + task=self.task_description, + ) + .cpu() + .numpy() + ) # Convert from degrees to radians action = np.deg2rad(action) @@ -153,7 +165,7 @@ class DoraPolicyInference: self.data_buffer[input_id] = { "data": data, "timestamp": time.time(), - "metadata": metadata + "metadata": metadata, } def get_pending_messages(self): @@ -193,24 +205,20 @@ def main(): node.send_output( output_id="robot_action", data=pa.array(msg_data["action"]), - metadata={} + metadata={}, ) elif msg_type == "status": node.send_output( - output_id="status", - data=pa.array([msg_data]), - metadata={} + output_id="status", data=pa.array([msg_data]), metadata={} ) - + if event["type"] == "INPUT": inference.handle_input( - event["id"], - event["value"], - event.get("metadata", {}) + event["id"], event["value"], event.get("metadata", {}) ) # inference.shutdown() # Not sure when to shutdown the node, and how to identify task completion if __name__ == "__main__": - main() \ No newline at end of file + main()