| @@ -16,20 +16,23 @@ CAPTURE_PATH: "0" # Laptop camera is defaults to 0 | |||
| # External camera | |||
| CAPTURE_PATH: "1" # Change this to match your external camera device | |||
| ``` | |||
| ### SO101 Arms | |||
| Identify and set the correct USB ports for both the leader and follower SO101 arms | |||
| ```yaml | |||
| PORT: "/dev/ttyACM0" # change this | |||
| ``` | |||
| ### Dataset Recorder Configuration | |||
| >Edit the following fields: | |||
| > Edit the following fields: | |||
| ```yaml | |||
| REPO_ID: "your_username/so101_dataset" # HuggingFace dataset path | |||
| SINGLE_TASK: "Pick up the red cube and place it in the blue box" # Your task description | |||
| CAMERA_NAMES: "laptop, front" # Name your camera sources depending on your setup | |||
| REPO_ID: "your_username/so101_dataset" # HuggingFace dataset path | |||
| SINGLE_TASK: "Pick up the red cube and place it in the blue box" # Your task description | |||
| CAMERA_NAMES: "laptop, front" # Name your camera sources depending on your setup | |||
| CAMERA_LAPTOP_RESOLUTION: "480,640,3" | |||
| CAMERA_FRONT_RESOLUTION: "480,640,3" | |||
| ``` | |||
| @@ -37,24 +40,26 @@ CAMERA_FRONT_RESOLUTION: "480,640,3" | |||
| You can adjust the following parameters as needed: | |||
| ```yaml | |||
| TOTAL_EPISODES: "2" # Number of episodes | |||
| TOTAL_EPISODES: "2" # Number of episodes | |||
| #you may want to try with 2-3 episodes to test, then atleast 50 episodes for training is recommended | |||
| EPISODE_DURATION_S: "60" # Duration of each episode (in seconds) - depends on complexity of task | |||
| RESET_DURATION_S: "15" # Time to reset the environment between episodes | |||
| FPS: "30" # Should match camera fps | |||
| PUSH_TO_HUB: "false" # Set to "true" to auto-upload dataset to HuggingFace | |||
| ROOT_PATH: "full path where you want to save the dataset" | |||
| EPISODE_DURATION_S: "60" # Duration of each episode (in seconds) - depends on complexity of task | |||
| RESET_DURATION_S: "15" # Time to reset the environment between episodes | |||
| FPS: "30" # Should match camera fps | |||
| PUSH_TO_HUB: "false" # Set to "true" to auto-upload dataset to HuggingFace | |||
| ROOT_PATH: "full path where you want to save the dataset" | |||
| # if not defined then it will be stored at ~/.cache/huggingface/lerobot/repo_id | |||
| ``` | |||
| Once everything is updated in `dataflow.yml`, you are ready to record your dataset. | |||
| ## Start Recording | |||
| Build and Run | |||
| ```bash | |||
| dora build dataflow.yml | |||
| dora run dataflow.yml | |||
| uv venv | |||
| dora build dataflow.yml --uv | |||
| dora run dataflow.yml --uv | |||
| ``` | |||
| ## Recording Process | |||
| @@ -64,11 +69,13 @@ In the rerun window you can see the the info regarding Start of episodes, start | |||
| #### During Recording | |||
| 1. **Episode Active**: | |||
| - Use the **leader robot** to demonstrate/perform the task | |||
| - Move smoothly and naturally | |||
| - Complete the full task within the time limit | |||
| 2. **Reset Phase**: | |||
| - move the leader arm to initial position | |||
| - Reset objects to starting positions | |||
| - Prepare workspace for next demonstration | |||
| @@ -86,8 +93,8 @@ In the rerun window you can see the the info regarding Start of episodes, start | |||
| ## After Recording | |||
| Your dataset will be saved locally. Check the recording was successful: | |||
| >It will be stored in ~/.cache/huggingface/lerobot/repo_id | |||
| > It will be stored in ~/.cache/huggingface/lerobot/repo_id | |||
| # Training and Testing Policies using the recorded dataset | |||
| @@ -95,35 +102,24 @@ After successfully recording your dataset, we will be training imitation learnin | |||
| ## Training Your Policy | |||
| #### Install LeRobot Training Dependencies | |||
| Easiest way to train your policy is to use lerobots training scripts | |||
| ```bash | |||
| # Install training requirements and lerobot repo | |||
| git clone https://github.com/huggingface/lerobot.git | |||
| pip install lerobot[training] | |||
| pip install tensorboard wandb # For monitoring (Optional) | |||
| ``` | |||
| ### Choose Your Policy | |||
| **ACT (Recommended for SO101)** | |||
| - Good for manipulation tasks | |||
| - Handles multi-modal data well | |||
| - Faster training (for me it took 7hrs 🥺 to train on 50 episodes for pick and place, 3050 laptop gpu) | |||
| **Diffusion Policy** | |||
| - Better for complex behaviors | |||
| - More robust to distribution shift | |||
| - Longer training | |||
| ### Start Training | |||
| ```bash | |||
| cd lerobot | |||
| python lerobot/scripts/train.py \ | |||
| uv run dora-dataset-lerobot-train | |||
| --dataset.repo_id=${HF_USER}/your_repo_id \ # provide full path of your dataset | |||
| --policy.type=act \ | |||
| --output_dir=outputs/train/act_so101_test \ | |||
| @@ -134,6 +130,7 @@ python lerobot/scripts/train.py \ | |||
| ``` | |||
| You can monitor your training progress on wandb | |||
| > For more details regarding training check [lerobot](https://huggingface.co/docs/lerobot/en/il_robots#train-a-policy) guide on Imitation learning for SO101 | |||
| ## Policy Inference and Testing | |||
| @@ -150,7 +147,7 @@ Update the camera device IDs to match your setup (same as recording): | |||
| # Laptop camera | |||
| CAPTURE_PATH: "0" # Usually 0 for built-in laptop camera | |||
| # External camera | |||
| # External camera | |||
| CAPTURE_PATH: "1" # Change this to your external camera device ID | |||
| ``` | |||
| @@ -159,7 +156,7 @@ CAPTURE_PATH: "1" # Change this to your external camera device ID | |||
| Set the correct USB port for your follower SO101 arm: | |||
| ```yaml | |||
| PORT: "/dev/ttyACM1" # Update this to match your follower robot port | |||
| PORT: "/dev/ttyACM1" # Update this to match your follower robot port | |||
| ``` | |||
| #### 3. Model Configuration | |||
| @@ -167,7 +164,7 @@ PORT: "/dev/ttyACM1" # Update this to match your follower robot port | |||
| Update the path to your trained model and task description: | |||
| ```yaml | |||
| MODEL_PATH: "./outputs/train/act_so101_test/checkpoints/last/pretrained_model" # Path to your trained model | |||
| MODEL_PATH: "./outputs/train/act_so101_test/checkpoints/last/pretrained_model" # Path to your trained model | |||
| TASK_DESCRIPTION: "Pick up the red cube and place it in the blue box" | |||
| ``` | |||
| @@ -176,9 +173,9 @@ TASK_DESCRIPTION: "Pick up the red cube and place it in the blue box" | |||
| Ensure camera settings match your recording configuration: | |||
| ```yaml | |||
| CAMERA_NAMES: "laptop, front" # Must match training setup | |||
| CAMERA_LAPTOP_RESOLUTION: "480,640,3" # Must match training | |||
| CAMERA_FRONT_RESOLUTION: "480,640,3" # Must match training | |||
| CAMERA_NAMES: "laptop, front" # Must match training setup | |||
| CAMERA_LAPTOP_RESOLUTION: "480,640,3" # Must match training | |||
| CAMERA_FRONT_RESOLUTION: "480,640,3" # Must match training | |||
| ``` | |||
| ### Start Policy Inference | |||
| @@ -188,4 +185,4 @@ Once you've updated the configuration: | |||
| ```bash | |||
| dora build policy_inference.yml | |||
| dora run policy_inference.yml | |||
| ``` | |||
| ``` | |||
| @@ -21,7 +21,7 @@ nodes: | |||
| outputs: | |||
| - image | |||
| env: | |||
| CAPTURE_PATH: "1" # ← UPDATE: Your external camera device | |||
| CAPTURE_PATH: "1" # ← UPDATE: Your external camera device | |||
| ENCODING: "rgb8" | |||
| IMAGE_WIDTH: "640" | |||
| IMAGE_HEIGHT: "480" | |||
| @@ -35,21 +35,21 @@ nodes: | |||
| outputs: | |||
| - pose | |||
| env: | |||
| PORT: "/dev/ttyACM1" # UPDATE your robot port | |||
| PORT: "/dev/ttyACM1" # UPDATE your robot port | |||
| IDS: "1,2,3,4,5,6" | |||
| - id: policy_inference | |||
| build: pip install -e ../../node-hub/dora-policy-inference | |||
| path: dora-policy-inference | |||
| inputs: | |||
| laptop: laptop_cam/image # the camera inputs should match the training setup | |||
| laptop: laptop_cam/image # the camera inputs should match the training setup | |||
| front: front_cam/image | |||
| robot_state: so101_follower/pose | |||
| outputs: | |||
| - robot_action | |||
| - status | |||
| env: | |||
| MODEL_PATH: "./outputs/train/your_policy/checkpoints/last/pretrained_model" # Path to your trained model | |||
| MODEL_PATH: "./outputs/train/act_so101_test/checkpoints/last/pretrained_model/" # Path to your trained model | |||
| TASK_DESCRIPTION: "Your task" | |||
| ROBOT_TYPE: "so101_follower" | |||
| CAMERA_NAMES: "laptop, front" | |||
| @@ -64,4 +64,4 @@ nodes: | |||
| inputs: | |||
| image_laptop: laptop_cam/image | |||
| image_front: front_cam/image | |||
| status: policy_inference/status | |||
| status: policy_inference/status | |||
| @@ -1,15 +1,17 @@ | |||
| """TODO: Add docstring.""" | |||
| from dora import Node | |||
| import pyarrow as pa | |||
| import os | |||
| import time | |||
| import numpy as np | |||
| import threading | |||
| import queue | |||
| import cv2 | |||
| import threading | |||
| import time | |||
| from typing import Any | |||
| from lerobot.common.datasets.lerobot_dataset import LeRobotDataset | |||
| import cv2 | |||
| import numpy as np | |||
| import pyarrow as pa | |||
| from dora import Node | |||
| from lerobot.datasets.lerobot_dataset import LeRobotDataset | |||
| class DoraLeRobotRecorder: | |||
| """Recorder class for LeRobot dataset.""" | |||
| @@ -28,10 +30,16 @@ class DoraLeRobotRecorder: | |||
| self.episode_index = 0 | |||
| self.start_time = None | |||
| self.cameras = self._get_cameras() | |||
| self.total_episodes = int(os.getenv("TOTAL_EPISODES", "10")) # Default to 10 episodes | |||
| self.episode_duration = int(os.getenv("EPISODE_DURATION_S", "60")) # Default to 60 seconds | |||
| self.reset_duration = int(os.getenv("RESET_DURATION_S", "15")) # Default to 15 seconds | |||
| self.fps = int(os.getenv("FPS", "30")) # Default to 30 FPS | |||
| self.total_episodes = int( | |||
| os.getenv("TOTAL_EPISODES", "10") | |||
| ) # Default to 10 episodes | |||
| self.episode_duration = int( | |||
| os.getenv("EPISODE_DURATION_S", "60") | |||
| ) # Default to 60 seconds | |||
| self.reset_duration = int( | |||
| os.getenv("RESET_DURATION_S", "15") | |||
| ) # Default to 15 seconds | |||
| self.fps = int(os.getenv("FPS", "30")) # Default to 30 FPS | |||
| self.recording_started = False | |||
| self.in_reset_phase = False | |||
| @@ -49,7 +57,7 @@ class DoraLeRobotRecorder: | |||
| """Get Camera config.""" | |||
| camera_names_str = os.getenv("CAMERA_NAMES") | |||
| if camera_names_str: | |||
| camera_names = [name.strip() for name in camera_names_str.split(',')] | |||
| camera_names = [name.strip() for name in camera_names_str.split(",")] | |||
| else: | |||
| return {} | |||
| @@ -57,10 +65,12 @@ class DoraLeRobotRecorder: | |||
| for camera_name in camera_names: | |||
| resolution = os.getenv(f"CAMERA_{camera_name.upper()}_RESOLUTION") | |||
| if resolution: | |||
| dims = [int(d.strip()) for d in resolution.split(',')] | |||
| dims = [int(d.strip()) for d in resolution.split(",")] | |||
| cameras[camera_name] = dims | |||
| else: | |||
| print(f"Warning: Set CAMERA_{camera_name.upper()}_RESOLUTION: \"height,width,channels\"") | |||
| print( | |||
| f'Warning: Set CAMERA_{camera_name.upper()}_RESOLUTION: "height,width,channels"' | |||
| ) | |||
| return cameras | |||
| @@ -68,7 +78,7 @@ class DoraLeRobotRecorder: | |||
| """Get robot joints.""" | |||
| joints_str = os.getenv("ROBOT_JOINTS") | |||
| if joints_str: | |||
| return [joint.strip() for joint in joints_str.split(',')] | |||
| return [joint.strip() for joint in joints_str.split(",")] | |||
| else: | |||
| raise ValueError("ROBOT_JOINTS are not set.") | |||
| @@ -76,7 +86,7 @@ class DoraLeRobotRecorder: | |||
| """Get tags for dataset.""" | |||
| tags_str = os.getenv("TAGS") | |||
| if tags_str: | |||
| return [tag.strip() for tag in tags_str.split(',')] | |||
| return [tag.strip() for tag in tags_str.split(",")] | |||
| return [] | |||
| def _setup_dataset(self): | |||
| @@ -85,34 +95,41 @@ class DoraLeRobotRecorder: | |||
| joint_names = self._get_robot_joints() | |||
| features["action"] = { | |||
| "dtype": "float32", | |||
| "dtype": "float32", | |||
| "shape": (len(joint_names),), | |||
| "names": joint_names} | |||
| "names": joint_names, | |||
| } | |||
| features["observation.state"] = { | |||
| "dtype": "float32", | |||
| "shape": (len(joint_names),), | |||
| "names": joint_names} | |||
| "names": joint_names, | |||
| } | |||
| self.use_videos = os.getenv("USE_VIDEOS", "true").lower() == "true" | |||
| for camera_name in self.cameras: | |||
| features[f"observation.images.{camera_name}"] = { | |||
| "dtype": "video" if self.use_videos else "image", | |||
| "shape": self.cameras[camera_name], | |||
| "names": ["height", "width", "channels"]} | |||
| "names": ["height", "width", "channels"], | |||
| } | |||
| self.required_features = set(features.keys()) | |||
| features.update({ | |||
| "timestamp": {"dtype": "float32", "shape": [1]}, | |||
| "frame_index": {"dtype": "int64", "shape": [1]}, | |||
| "episode_index": {"dtype": "int64", "shape": [1]}, | |||
| "index": {"dtype": "int64", "shape": [1]}, | |||
| "task_index": {"dtype": "int64", "shape": [1]}, | |||
| }) | |||
| features.update( | |||
| { | |||
| "timestamp": {"dtype": "float32", "shape": [1]}, | |||
| "frame_index": {"dtype": "int64", "shape": [1]}, | |||
| "episode_index": {"dtype": "int64", "shape": [1]}, | |||
| "index": {"dtype": "int64", "shape": [1]}, | |||
| "task_index": {"dtype": "int64", "shape": [1]}, | |||
| } | |||
| ) | |||
| repo_id = os.getenv("REPO_ID", None) | |||
| if repo_id is None: | |||
| raise ValueError("REPO_ID environment variable must be set to create dataset") | |||
| raise ValueError( | |||
| "REPO_ID environment variable must be set to create dataset" | |||
| ) | |||
| self.dataset = LeRobotDataset.create( | |||
| repo_id=repo_id, | |||
| @@ -122,7 +139,8 @@ class DoraLeRobotRecorder: | |||
| robot_type=os.getenv("ROBOT_TYPE", "your_robot_type"), | |||
| use_videos=self.use_videos, | |||
| image_writer_processes=int(os.getenv("IMAGE_WRITER_PROCESSES", "0")), | |||
| image_writer_threads=int(os.getenv("IMAGE_WRITER_THREADS", "4")) * len(self.cameras), | |||
| image_writer_threads=int(os.getenv("IMAGE_WRITER_THREADS", "4")) | |||
| * len(self.cameras), | |||
| ) | |||
| def _check_episode_timing(self): | |||
| @@ -167,7 +185,9 @@ class DoraLeRobotRecorder: | |||
| """End current episode and save to dataset.""" | |||
| self.episode_active = False | |||
| if self.frame_count > 0: | |||
| self._output(f"Saving episode index {self.episode_index} with {self.frame_count} frames") | |||
| self._output( | |||
| f"Saving episode index {self.episode_index} with {self.frame_count} frames" | |||
| ) | |||
| self.dataset.save_episode() | |||
| self.episode_index += 1 | |||
| else: | |||
| @@ -177,12 +197,16 @@ class DoraLeRobotRecorder: | |||
| """Start the reset phase between episodes.""" | |||
| self.in_reset_phase = True | |||
| self.reset_start_time = time.time() | |||
| self._output(f"Reset phase started - {self.reset_duration}s break before next episode...") | |||
| self._output( | |||
| f"Reset phase started - {self.reset_duration}s break before next episode..." | |||
| ) | |||
| def _start_frame_timer(self): | |||
| """Start the frame timer thread.""" | |||
| self.stop_timer = False | |||
| self.frame_timer_thread = threading.Thread(target=self._frame_timer_loop, daemon=True) | |||
| self.frame_timer_thread = threading.Thread( | |||
| target=self._frame_timer_loop, daemon=True | |||
| ) | |||
| self.frame_timer_thread.start() | |||
| def _frame_timer_loop(self): | |||
| @@ -190,9 +214,13 @@ class DoraLeRobotRecorder: | |||
| while not self.stop_timer and not self.shutdown: | |||
| current_time = time.time() | |||
| if self.episode_active and not self.in_reset_phase and ( | |||
| self.last_frame_time is None or | |||
| current_time - self.last_frame_time >= self.frame_interval | |||
| if ( | |||
| self.episode_active | |||
| and not self.in_reset_phase | |||
| and ( | |||
| self.last_frame_time is None | |||
| or current_time - self.last_frame_time >= self.frame_interval | |||
| ) | |||
| ): | |||
| self._add_frame() | |||
| self.last_frame_time = current_time | |||
| @@ -212,7 +240,7 @@ class DoraLeRobotRecorder: | |||
| self.data_buffer[input_id] = { | |||
| "data": data, | |||
| "timestamp": time.time(), | |||
| "metadata": metadata | |||
| "metadata": metadata, | |||
| } | |||
| should_stop = self._check_episode_timing() | |||
| @@ -242,7 +270,7 @@ class DoraLeRobotRecorder: | |||
| """Convert camera data from 1D pyarrow array to numpy format.""" | |||
| height, width = metadata.get("height"), metadata.get("width") | |||
| encoding = metadata.get("encoding") | |||
| image = dora_data.to_numpy().reshape(height, width, 3) | |||
| image = dora_data.to_numpy().reshape(height, width, 3) | |||
| if encoding == "bgr8": | |||
| image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | |||
| @@ -266,18 +294,24 @@ class DoraLeRobotRecorder: | |||
| for key, value in self.data_buffer.items(): | |||
| if key == "robot_action": | |||
| frame_data["action"] = self.convert_robot_data(self.data_buffer["robot_action"]["data"]) | |||
| frame_data["action"] = self.convert_robot_data( | |||
| self.data_buffer["robot_action"]["data"] | |||
| ) | |||
| if key == "robot_state": | |||
| frame_data["observation.state"] = self.convert_robot_data(self.data_buffer["robot_state"]["data"]) | |||
| if {'height', 'width'} <= value.get('metadata', {}).keys(): | |||
| frame_data["observation.state"] = self.convert_robot_data( | |||
| self.data_buffer["robot_state"]["data"] | |||
| ) | |||
| if {"height", "width"} <= value.get("metadata", {}).keys(): | |||
| camera_name = key | |||
| image = self._convert_camera_data( | |||
| self.data_buffer[camera_name]["data"], | |||
| self.data_buffer[camera_name]["metadata"] | |||
| self.data_buffer[camera_name]["metadata"], | |||
| ) | |||
| frame_data[f"observation.images.{camera_name}"] = image | |||
| missing_keys = self.required_features - set(frame_data.keys()) # Ensure all required features are present | |||
| missing_keys = self.required_features - set( | |||
| frame_data.keys() | |||
| ) # Ensure all required features are present | |||
| if missing_keys: | |||
| print(f"Missing required data in frame: {missing_keys}") | |||
| return | |||
| @@ -285,7 +319,7 @@ class DoraLeRobotRecorder: | |||
| self.dataset.add_frame( | |||
| frame=frame_data, | |||
| task=os.getenv("SINGLE_TASK", "Your task"), | |||
| timestamp=ideal_timestamp | |||
| timestamp=ideal_timestamp, | |||
| ) | |||
| self.frame_count += 1 | |||
| @@ -302,10 +336,12 @@ class DoraLeRobotRecorder: | |||
| self._output("Pushing dataset to hub...") | |||
| self.dataset.push_to_hub( | |||
| tags=self._get_tags(), | |||
| private=os.getenv("PRIVATE", "false").lower() == "true" | |||
| private=os.getenv("PRIVATE", "false").lower() == "true", | |||
| ) | |||
| self._output(f"Dataset recording completed. Total episodes: {self.episode_index}") | |||
| self._output( | |||
| f"Dataset recording completed. Total episodes: {self.episode_index}" | |||
| ) | |||
| def _output(self, message: str): | |||
| """Output message.""" | |||
| @@ -320,6 +356,7 @@ class DoraLeRobotRecorder: | |||
| messages.append(self.message_queue.get_nowait()) | |||
| return messages | |||
| def main(): | |||
| node = Node() | |||
| recorder = DoraLeRobotRecorder() | |||
| @@ -332,18 +369,28 @@ def main(): | |||
| for event in node: | |||
| pending_messages = recorder.get_pending_messages() | |||
| for message in pending_messages: | |||
| node.send_output( | |||
| output_id="text", | |||
| data=pa.array([message]), | |||
| metadata={}) | |||
| node.send_output(output_id="text", data=pa.array([message]), metadata={}) | |||
| if event["type"] == "INPUT": | |||
| should_stop = recorder.handle_input(event["id"], event["value"], event.get("metadata", {})) | |||
| should_stop = recorder.handle_input( | |||
| event["id"], event["value"], event.get("metadata", {}) | |||
| ) | |||
| if should_stop: | |||
| print("All episodes completed, stopping recording...") | |||
| break | |||
| recorder._shutdown() | |||
| def train_main(): | |||
| from lerobot.scripts.train import train | |||
| from lerobot.utils.utils import ( | |||
| init_logging, | |||
| ) | |||
| init_logging() | |||
| train() | |||
| if __name__ == "__main__": | |||
| main() | |||
| main() | |||
| @@ -7,30 +7,25 @@ license = { text = "MIT" } | |||
| readme = "README.md" | |||
| requires-python = ">=3.8" | |||
| dependencies = ["dora-rs >= 0.3.9","pyarrow","lerobot"] | |||
| dependencies = ["dora-rs >= 0.3.9", "pyarrow", "lerobot[train]"] | |||
| [dependency-groups] | |||
| dev = ["pytest >=8.1.1", "ruff >=0.9.1"] | |||
| [project.scripts] | |||
| dora-dataset-record = "dora_dataset_record.main:main" | |||
| dora-dataset-lerobot-train = "dora_dataset_record.main:train_main" | |||
| [tool.ruff.lint] | |||
| extend-select = [ | |||
| "D", # pydocstyle | |||
| "UP" | |||
| "D", # pydocstyle | |||
| "UP", | |||
| ] | |||
| ignore = [ | |||
| "D100", # Missing docstring in public module | |||
| "D103", # Missing docstring in public function | |||
| "D104", # Missing docstring in public package | |||
| "D100", # Missing docstring in public module | |||
| "D103", # Missing docstring in public function | |||
| "D104", # Missing docstring in public package | |||
| ] | |||
| [tool.uv.sources] | |||
| lerobot = {git = "https://github.com/huggingface/lerobot.git", rev = "main"} | |||
| [tool.uv.pip] | |||
| git-config = [ | |||
| "filter.lfs.smudge=git-lfs smudge --skip -- %f", | |||
| "filter.lfs.process=git-lfs filter-process --skip" | |||
| ] | |||
| lerobot = { git = "https://github.com/huggingface/lerobot.git", rev = "main" } | |||
| @@ -1,20 +1,21 @@ | |||
| """TODO.""" | |||
| from dora import Node | |||
| import pyarrow as pa | |||
| import os | |||
| import time | |||
| import numpy as np | |||
| import threading | |||
| import queue | |||
| import cv2 | |||
| from typing import Any | |||
| import threading | |||
| import time | |||
| from pathlib import Path | |||
| from typing import Any | |||
| import cv2 | |||
| import numpy as np | |||
| import pyarrow as pa | |||
| from dora import Node | |||
| from lerobot.configs.policies import PreTrainedConfig | |||
| from lerobot.common.policies.factory import get_policy_class | |||
| from lerobot.common.utils.control_utils import predict_action | |||
| from lerobot.common.utils.utils import get_safe_torch_device | |||
| from lerobot.policies.factory import get_policy_class | |||
| from lerobot.utils.control_utils import predict_action | |||
| from lerobot.utils.utils import get_safe_torch_device | |||
| class DoraPolicyInference: | |||
| """Policy inference node for LeRobot policies.""" | |||
| @@ -25,10 +26,10 @@ class DoraPolicyInference: | |||
| self.data_buffer = {} | |||
| self.buffer_lock = threading.Lock() | |||
| self.cameras = self._parse_cameras() | |||
| self.cameras = self._parse_cameras() | |||
| self.task_description = os.getenv("TASK_DESCRIPTION", "") | |||
| self.policy = None | |||
| self.policy = None | |||
| self.inference_fps = int(os.getenv("INFERENCE_FPS", "30")) | |||
| self.last_inference_time = None | |||
| self.inference_interval = 1.0 / self.inference_fps | |||
| @@ -40,12 +41,14 @@ class DoraPolicyInference: | |||
| def _parse_cameras(self) -> dict: | |||
| """Parse camera configuration from environment variables.""" | |||
| camera_names_str = os.getenv("CAMERA_NAMES", "laptop,front") | |||
| camera_names = [name.strip() for name in camera_names_str.split(',')] | |||
| camera_names = [name.strip() for name in camera_names_str.split(",")] | |||
| cameras = {} | |||
| for camera_name in camera_names: | |||
| resolution = os.getenv(f"CAMERA_{camera_name.upper()}_RESOLUTION", "480,640,3") | |||
| dims = [int(d.strip()) for d in resolution.split(',')] | |||
| resolution = os.getenv( | |||
| f"CAMERA_{camera_name.upper()}_RESOLUTION", "480,640,3" | |||
| ) | |||
| dims = [int(d.strip()) for d in resolution.split(",")] | |||
| cameras[camera_name] = dims | |||
| return cameras | |||
| @@ -57,14 +60,13 @@ class DoraPolicyInference: | |||
| raise ValueError("MODEL_PATH environment variable must be set correctly.") | |||
| config = PreTrainedConfig.from_pretrained(model_path) | |||
| self.device = get_safe_torch_device(config.device) # get device automatically | |||
| self.device = get_safe_torch_device(config.device) # get device automatically | |||
| config.device = self.device.type | |||
| # Get the policy class and load pretrained weights | |||
| policy_cls = get_policy_class(config.type) | |||
| self.policy = policy_cls.from_pretrained( | |||
| pretrained_name_or_path=model_path, | |||
| config=config | |||
| pretrained_name_or_path=model_path, config=config | |||
| ) | |||
| self.policy.eval() # Set to evaluation mode | |||
| @@ -76,15 +78,19 @@ class DoraPolicyInference: | |||
| def _start_timer(self): | |||
| """Start the inference timing thread.""" | |||
| self.stop_timer = False | |||
| self.inference_thread = threading.Thread(target=self._inference_loop, daemon=True) | |||
| self.inference_thread = threading.Thread( | |||
| target=self._inference_loop, daemon=True | |||
| ) | |||
| self.inference_thread.start() | |||
| def _inference_loop(self): | |||
| """Inference loop.""" | |||
| while not self.stop_timer and not self.shutdown: | |||
| current_time = time.time() | |||
| if (self.last_inference_time is None or | |||
| current_time - self.last_inference_time >= self.inference_interval): | |||
| if ( | |||
| self.last_inference_time is None | |||
| or current_time - self.last_inference_time >= self.inference_interval | |||
| ): | |||
| self._run_inference() | |||
| self.last_inference_time = current_time | |||
| @@ -104,18 +110,24 @@ class DoraPolicyInference: | |||
| if camera_name in self.data_buffer: | |||
| image = self._convert_camera_data( | |||
| self.data_buffer[camera_name]["data"], | |||
| self.data_buffer[camera_name]["metadata"]) | |||
| self.data_buffer[camera_name]["metadata"], | |||
| ) | |||
| observation[f"observation.images.{camera_name}"] = image | |||
| state = self._convert_robot_data(self.data_buffer["robot_state"]["data"]) | |||
| observation["observation.state"] = state | |||
| action = predict_action( | |||
| observation=observation, | |||
| policy=self.policy, | |||
| device=self.device, | |||
| use_amp=self.policy.config.use_amp, | |||
| task=self.task_description).cpu().numpy() | |||
| action = ( | |||
| predict_action( | |||
| observation=observation, | |||
| policy=self.policy, | |||
| device=self.device, | |||
| use_amp=self.policy.config.use_amp, | |||
| task=self.task_description, | |||
| ) | |||
| .cpu() | |||
| .numpy() | |||
| ) | |||
| # Convert from degrees to radians | |||
| action = np.deg2rad(action) | |||
| @@ -153,7 +165,7 @@ class DoraPolicyInference: | |||
| self.data_buffer[input_id] = { | |||
| "data": data, | |||
| "timestamp": time.time(), | |||
| "metadata": metadata | |||
| "metadata": metadata, | |||
| } | |||
| def get_pending_messages(self): | |||
| @@ -193,24 +205,20 @@ def main(): | |||
| node.send_output( | |||
| output_id="robot_action", | |||
| data=pa.array(msg_data["action"]), | |||
| metadata={} | |||
| metadata={}, | |||
| ) | |||
| elif msg_type == "status": | |||
| node.send_output( | |||
| output_id="status", | |||
| data=pa.array([msg_data]), | |||
| metadata={} | |||
| output_id="status", data=pa.array([msg_data]), metadata={} | |||
| ) | |||
| if event["type"] == "INPUT": | |||
| inference.handle_input( | |||
| event["id"], | |||
| event["value"], | |||
| event.get("metadata", {}) | |||
| event["id"], event["value"], event.get("metadata", {}) | |||
| ) | |||
| # inference.shutdown() # Not sure when to shutdown the node, and how to identify task completion | |||
| if __name__ == "__main__": | |||
| main() | |||
| main() | |||