
Adding `rdt-1b` node

improve pytest of rdt-1b

Add main into rdt-1b

add small cloud fix for rdt 1b

Small rdt-1b main fix

Small improvement on rdt 1b

Small fixes to dora-rdt-1b main

Add piper example

Add environment variables for configuring the vision and language model parameters

add python feature flag to dora-rerun

Fix play inference

fixing replay issue

make data dir dependent on the date
tags/0.3.8-rc
haixuanTao (haixuantao), 1 year ago
commit 5d87bd1beb
17 changed files with 1459 additions and 5 deletions
1. .gitmodules (+3, -0)
2. examples/piper/README.md (+61, -0)
3. examples/piper/arms_camera.yml (+74, -0)
4. examples/piper/arms_only.yml (+33, -0)
5. examples/piper/dummy_inference_2.py (+126, -0)
6. examples/piper/play_dummy_inference.yml (+38, -5)
7. examples/piper/post_process_action.py (+24, -0)
8. examples/piper/rdt_1b.yml (+132, -0)
9. examples/piper/record.py (+231, -0)
10. examples/piper/record.yml (+115, -0)
11. node-hub/dora-rdt-1b/README.md (+3, -0)
12. node-hub/dora-rdt-1b/dora_rdt_1b/RoboticsDiffusionTransformer (+1, -0)
13. node-hub/dora-rdt-1b/dora_rdt_1b/__init__.py (+19, -0)
14. node-hub/dora-rdt-1b/dora_rdt_1b/main.py (+324, -0)
15. node-hub/dora-rdt-1b/pyproject.toml (+36, -0)
16. node-hub/dora-rdt-1b/tests/conftest.py (+12, -0)
17. node-hub/dora-rdt-1b/tests/test_dora_rdt_1b.py (+227, -0)

.gitmodules (+3, -0)

@@ -0,0 +1,3 @@
[submodule "node-hub/dora-rdt-1b/dora_rdt_1b/RoboticsDiffusionTransformer"]
path = node-hub/dora-rdt-1b/dora_rdt_1b/RoboticsDiffusionTransformer
url = https://github.com/thu-ml/RoboticsDiffusionTransformer

examples/piper/README.md (+61, -0)

@@ -0,0 +1,61 @@
# Getting Started with Tracer + Piper

## Installation (to do once)

Build the dataflow and install the camera and mobile-base SDKs from source:

```bash
dora build rdt_1b.yml

# Install pyorbbecsdk from source

git clone https://github.com/orbbec/pyorbbecsdk
cd pyorbbecsdk
pip3 install -r requirements.txt
mkdir build
cd build
cmake -Dpybind11_DIR=`pybind11-config --cmakedir` ..
make -j4
make install
cd ..
pip3 install wheel
python3 setup.py bdist_wheel
pip3 install dist/*.whl

export PYTHONPATH=$PYTHONPATH:$(pwd)/install/lib/ # Make sure to save this in your .bashrc


# Install ugv_sdk_py from source
git clone https://github.com/westonrobot/ugv_sdk
cd ugv_sdk
python setup.py build_ext --inplace

export PYTHONPATH=$PYTHONPATH:$(pwd) # Make sure to save this in your .bashrc
```
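
To sanity-check the installation, a quick import test should succeed (a minimal sketch, assuming the `PYTHONPATH` exports above are in your shell and the modules are named `pyorbbecsdk` and `ugv_sdk_py` as built from source):

```python
# Quick check that both source-built SDKs are importable.
import pyorbbecsdk  # Orbbec camera bindings
import ugv_sdk_py   # Tracer mobile-base bindings

print("pyorbbecsdk and ugv_sdk_py are importable")
```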

### Your bashrc should contain something like this

```bash
export PYTHONPATH=$PYTHONPATH:/home/agilex/pyorbbecsdk/install/lib/:/home/agilex/ugv_sdk
```

## Setup (every boot of the computer)

```bash
# Run on Agilex provided computer
source /home/agilex/cobot_magic/Piper_ros_private-ros-noetic/can_config.sh
```

## Run

### For recording an episode

```bash
dora run record.yml
```

### For inference

```bash
dora run rdt_1b.yml
```

examples/piper/arms_camera.yml (+74, -0)

@@ -0,0 +1,74 @@
nodes:
- id: piper_left
path: /home/agilex/1ms.ai/piper_sdk/dora_piper.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/20
outputs:
- jointstate
env:
CAN_BUS: can_left

- id: camera_left
path: /home/agilex/1ms.ai/pyorbbecsdk/examples/color_viewer.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/50
outputs:
- image
env:
DEVICE_INDEX: 0
ENCODING: jpeg

- id: camera_center
path: /home/agilex/1ms.ai/pyorbbecsdk/examples/color_viewer.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/50
outputs:
- image
env:
DEVICE_INDEX: 1
ENCODING: jpeg

- id: camera_right
path: /home/agilex/1ms.ai/pyorbbecsdk/examples/color_viewer.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/50
outputs:
- image
env:
DEVICE_INDEX: 2
ENCODING: jpeg
# To find the device indices, list the cameras with OpenCV:
# import cv2; [cv2.VideoCapture(i) for i in range(12)]

- id: piper_right
path: /home/agilex/1ms.ai/piper_sdk/dora_piper.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/20
outputs:
- jointstate
env:
CAN_BUS: can_right

- id: rerun
path: dora-rerun1
inputs:
jointstate_piper_left: piper_left/jointstate
jointstate_piper_right: piper_right/jointstate
image_camera_left: camera_left/image
image_camera_center: camera_center/image
image_camera_right: camera_right/image
env:
piper_left_urdf: assets/piper_left.urdf
piper_right_urdf: assets/piper_right.urdf
piper_left_transform: 0 0.2 0
piper_right_transform: 0 -0.2 0

examples/piper/arms_only.yml (+33, -0)

@@ -0,0 +1,33 @@
nodes:
- id: piper_left
path: /home/agilex/1ms.ai/piper_sdk/dora_piper.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/20
outputs:
- jointstate
env:
CAN_BUS: can_left

- id: piper_right
path: /home/agilex/1ms.ai/piper_sdk/dora_piper.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/20
outputs:
- jointstate
env:
CAN_BUS: can_right

- id: rerun
path: dora-rerun
inputs:
jointstate_piper_left: piper_left/jointstate
jointstate_piper_right: piper_right/jointstate
env:
piper_left_urdf: /home/peter/Documents/work/dora/examples/piper/assets/piper_left.urdf
piper_right_urdf: /home/peter/Documents/work/dora/examples/piper/assets/piper_right.urdf
piper_left_transform: 0 0.3 0
piper_right_transform: 0 -0.3 0

examples/piper/dummy_inference_2.py (+126, -0)

@@ -0,0 +1,126 @@
import time

import h5py
import numpy as np
import pyarrow as pa
from dora import Node

f = h5py.File("data/episode_0.hdf5", "r")

data = f["action"][:]


STATE_VEC_IDX_MAPPING = {
# [0, 10): right arm joint positions
**{"arm_joint_{}_pos".format(i): i for i in range(10)},
**{"right_arm_joint_{}_pos".format(i): i for i in range(10)},
# [10, 15): right gripper joint positions
**{"gripper_joint_{}_pos".format(i): i + 10 for i in range(5)},
**{"right_gripper_joint_{}_pos".format(i): i + 10 for i in range(5)},
"gripper_open": 10, # alias of right_gripper_joint_0_pos
"right_gripper_open": 10,
# [15, 25): right arm joint velocities
**{"arm_joint_{}_vel".format(i): i + 15 for i in range(10)},
**{"right_arm_joint_{}_vel".format(i): i + 15 for i in range(10)},
# [25, 30): right gripper joint velocities
**{"gripper_joint_{}_vel".format(i): i + 25 for i in range(5)},
**{"right_gripper_joint_{}_vel".format(i): i + 25 for i in range(5)},
"gripper_open_vel": 25, # alias of right_gripper_joint_0_vel
"right_gripper_open_vel": 25,
# [30, 33): right end effector positions
"eef_pos_x": 30,
"right_eef_pos_x": 30,
"eef_pos_y": 31,
"right_eef_pos_y": 31,
"eef_pos_z": 32,
"right_eef_pos_z": 32,
# [33, 39): right end effector 6D pose
"eef_angle_0": 33,
"right_eef_angle_0": 33,
"eef_angle_1": 34,
"right_eef_angle_1": 34,
"eef_angle_2": 35,
"right_eef_angle_2": 35,
"eef_angle_3": 36,
"right_eef_angle_3": 36,
"eef_angle_4": 37,
"right_eef_angle_4": 37,
"eef_angle_5": 38,
"right_eef_angle_5": 38,
# [39, 42): right end effector velocities
"eef_vel_x": 39,
"right_eef_vel_x": 39,
"eef_vel_y": 40,
"right_eef_vel_y": 40,
"eef_vel_z": 41,
"right_eef_vel_z": 41,
# [42, 45): right end effector angular velocities
"eef_angular_vel_roll": 42,
"right_eef_angular_vel_roll": 42,
"eef_angular_vel_pitch": 43,
"right_eef_angular_vel_pitch": 43,
"eef_angular_vel_yaw": 44,
"right_eef_angular_vel_yaw": 44,
# [45, 50): reserved
# [50, 60): left arm joint positions
**{"left_arm_joint_{}_pos".format(i): i + 50 for i in range(10)},
# [60, 65): left gripper joint positions
**{"left_gripper_joint_{}_pos".format(i): i + 60 for i in range(5)},
"left_gripper_open": 60, # alias of left_gripper_joint_0_pos
# [65, 75): left arm joint velocities
**{"left_arm_joint_{}_vel".format(i): i + 65 for i in range(10)},
# [75, 80): left gripper joint velocities
**{"left_gripper_joint_{}_vel".format(i): i + 75 for i in range(5)},
"left_gripper_open_vel": 75, # alias of left_gripper_joint_0_vel
# [80, 83): left end effector positions
"left_eef_pos_x": 80,
"left_eef_pos_y": 81,
"left_eef_pos_z": 82,
# [83, 89): left end effector 6D pose
"left_eef_angle_0": 83,
"left_eef_angle_1": 84,
"left_eef_angle_2": 85,
"left_eef_angle_3": 86,
"left_eef_angle_4": 87,
"left_eef_angle_5": 88,
# [89, 92): left end effector velocities
"left_eef_vel_x": 89,
"left_eef_vel_y": 90,
"left_eef_vel_z": 91,
# [92, 95): left end effector angular velocities
"left_eef_angular_vel_roll": 92,
"left_eef_angular_vel_pitch": 93,
"left_eef_angular_vel_yaw": 94,
# [95, 100): reserved
# [100, 102): base linear velocities
"base_vel_x": 100,
"base_vel_y": 101,
# [102, 103): base angular velocities
"base_angular_vel": 102,
# [103, 128): reserved
}


node = Node()
LEFT_UNI_STATE_INDICES = [
STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"] for i in range(6)
] + [STATE_VEC_IDX_MAPPING["left_gripper_open"]]
RIGHT_UNI_STATE_INDICES = [
STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)
] + [STATE_VEC_IDX_MAPPING["right_gripper_open"]]
MOBILE_BASE_UNI_STATE_INDICES = [STATE_VEC_IDX_MAPPING["base_vel_x"]] + [
STATE_VEC_IDX_MAPPING["base_angular_vel"]
]
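# Worked example: with the mapping above these index lists resolve to
#   LEFT_UNI_STATE_INDICES        == [50, 51, 52, 53, 54, 55, 60]
#   RIGHT_UNI_STATE_INDICES       == [0, 1, 2, 3, 4, 5, 10]
#   MOBILE_BASE_UNI_STATE_INDICES == [100, 102]
# i.e. 6 joint positions + gripper per arm, plus base linear and angular velocity.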

for joint in data:
node.send_output(
"jointstate_left", pa.array(joint[LEFT_UNI_STATE_INDICES], type=pa.float32())
)
node.send_output(
"jointstate_right", pa.array(joint[RIGHT_UNI_STATE_INDICES], type=pa.float32())
)
node.send_output(
"mobile_base", pa.array(joint[MOBILE_BASE_UNI_STATE_INDICES], type=pa.float32())
)
time.sleep(0.05)

examples/piper/play_dummy_inference.yml (+38, -5)

@@ -1,13 +1,12 @@
nodes:
- id: piper
- path: dummy_inference.py
+ path: dummy_inference_2.py
inputs:
tick: dora/timer/millis/20
outputs:
- jointstate_left
- jointstate_right
- env:
- CAN_BUS: can_left
+ - mobile_base

- id: rerun
build: |
@@ -24,10 +23,44 @@ nodes:
pip install git+https://github.com/rerun-io/rerun-loader-python-example-urdf.git
path: dora-rerun
inputs:
- jointstate_piper_left: piper/jointstate_left
- jointstate_piper_right: piper/jointstate_right
+ jointstate_piper_left: piper_left/jointstate
+ jointstate_piper_right: piper_right/jointstate
jointstate_piper_left_pred: piper/jointstate_left
jointstate_piper_right_pred: piper/jointstate_right
series_piper_left: piper_left/jointstate
series_piper_right: piper_right/jointstate
series_piper_left_pred: piper/jointstate_left
series_piper_right_pred: piper/jointstate_right
env:
piper_left_urdf: piper_left.urdf # Make sure to download meshes from https://github.com/agilexrobotics/Piper_ros/tree/4f22c61f96b8fb3ef3f937b99b63edb697caadf0/src/piper_description/meshes and put them in the assets folder
piper_right_urdf: piper_right.urdf # Make sure to download meshes from https://github.com/agilexrobotics/Piper_ros/tree/4f22c61f96b8fb3ef3f937b99b63edb697caadf0/src/piper_description/meshes and put them in the assets folder
piper_left_transform: 0 0.2 0
piper_right_transform: 0 -0.2 0
piper_left_pred_urdf: assets/piper_left_pred.urdf
piper_right_pred_urdf: assets/piper_right_pred.urdf
piper_left_pred_transform: 0 0.2 0
piper_right_pred_transform: 0 -0.2 0

- id: piper_left
path: /home/agilex/1ms.ai/piper_sdk/dora_piper.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/500
action: piper/jointstate_left
outputs:
- jointstate
env:
CAN_BUS: can_left

- id: piper_right
path: /home/agilex/1ms.ai/piper_sdk/dora_piper.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/1000
action: piper/jointstate_right
outputs:
- jointstate
env:
CAN_BUS: can_right

examples/piper/post_process_action.py (+24, -0)

@@ -0,0 +1,24 @@
import time

import numpy as np
import pyarrow as pa
from dora import Node

node = Node()

for event in node:
if event["type"] == "INPUT":
actions = event["value"].to_numpy().reshape((64, 14))

# Subsample the 64-step chunk, keeping 8 evenly spaced actions
actions = actions[[0, 8, 16, 24, 32, 40, 48, 56], :]

for action in actions:
node.send_output("jointstate_left", pa.array(action[:7], type=pa.float32()))
node.send_output(
"jointstate_right", pa.array(action[7:], type=pa.float32())
)
time.sleep(0.005)
print(actions)

examples/piper/rdt_1b.yml (+132, -0)

@@ -0,0 +1,132 @@
nodes:
- id: piper_left
path: /home/agilex/1ms.ai/piper_sdk/dora_piper.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/500
action: post_process_rdt_1b/jointstate_left
outputs:
- jointstate
env:
CAN_BUS: can_left

- id: piper_right
path: /home/agilex/1ms.ai/piper_sdk/dora_piper.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/1000
action: post_process_rdt_1b/jointstate_right
outputs:
- jointstate
env:
CAN_BUS: can_right

- id: camera_left
path: /home/agilex/1ms.ai/pyorbbecsdk/examples/color_viewer.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/1000
outputs:
- image
env:
DEVICE_INDEX: 0
ENCODING: jpeg

- id: camera_center
path: /home/agilex/1ms.ai/pyorbbecsdk/examples/color_viewer.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/1000
outputs:
- image
env:
DEVICE_INDEX: 1
ENCODING: jpeg

- id: camera_right
path: /home/agilex/1ms.ai/pyorbbecsdk/examples/color_viewer.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/1000
outputs:
- image
env:
DEVICE_INDEX: 2
ENCODING: jpeg
# To find the device indices, list the cameras with OpenCV:
# import cv2; [cv2.VideoCapture(i) for i in range(12)]

- id: rerun
path: dora-rerun
inputs:
jointstate_piper_left: piper_left/jointstate
jointstate_piper_right: piper_right/jointstate
jointstate_piper_left_pred: post_process_rdt_1b/jointstate_left
jointstate_piper_right_pred: post_process_rdt_1b/jointstate_right
series_piper_left: piper_left/jointstate
series_piper_right: piper_right/jointstate
series_piper_left_pred: post_process_rdt_1b/jointstate_left
series_piper_right_pred: post_process_rdt_1b/jointstate_right
image_left: camera_left/image
image_center: camera_center/image
image_right: camera_right/image
env:
piper_left_urdf: /home/peter/Documents/work/dora/examples/piper/assets/piper_left.urdf
piper_right_urdf: /home/peter/Documents/work/dora/examples/piper/assets/piper_right.urdf
piper_left_transform: 0 0.2 0
piper_right_transform: 0 -0.2 0
piper_left_pred_urdf: /home/peter/Documents/work/dora/examples/piper/assets/piper_left_pred.urdf
piper_right_pred_urdf: /home/peter/Documents/work/dora/examples/piper/assets/piper_right_pred.urdf
piper_left_pred_transform: 0 0.2 0
piper_right_pred_transform: 0 -0.2 0

- id: rdt_1b
path: dora-rdt_1b
inputs:
jointstate_left:
source: piper_left/jointstate
queue_size: 1
jointstate_right:
source: piper_right/jointstate
queue_size: 1
image_left:
source: camera_left/image
queue_size: 1
image_center:
source: camera_center/image
queue_size: 1
image_right:
source: camera_right/image
queue_size: 1
tick:
source: dora/timer/secs/1
queue_size: 1
outputs:
- action
env:
ROBOTIC_MODEL_NAME_OR_PATH: /home/peter/Documents/work/dora/examples/piper/checkpoints/checkpoint-450
VISION_MODEL_NAME_OR_PATH: /home/peter/.cache/huggingface/hub/models--google--siglip-so400m-patch14-384/snapshots/9fdffc58afc957d1a03a25b10dba0329ab15c2a3
LANGUAGE_EMBEDDING_PATH: lang_embed.pt

- id: post_process_rdt_1b
path: post_process_action.py
inputs:
action: rdt_1b/action
outputs:
- jointstate_left
- jointstate_right

- id: mobile_base
path: /home/agilex/1ms.ai/ugv_sdk/tracer_node.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/40
# action: dummy/mobile_base
outputs:
- velocity

examples/piper/record.py (+231, -0)

@@ -0,0 +1,231 @@
import h5py

import os
import datetime

from dora import Node
import numpy as np

STATE_VEC_IDX_MAPPING = {
# [0, 10): right arm joint positions
**{"arm_joint_{}_pos".format(i): i for i in range(10)},
**{"right_arm_joint_{}_pos".format(i): i for i in range(10)},
# [10, 15): right gripper joint positions
**{"gripper_joint_{}_pos".format(i): i + 10 for i in range(5)},
**{"right_gripper_joint_{}_pos".format(i): i + 10 for i in range(5)},
"gripper_open": 10, # alias of right_gripper_joint_0_pos
"right_gripper_open": 10,
# [15, 25): right arm joint velocities
**{"arm_joint_{}_vel".format(i): i + 15 for i in range(10)},
**{"right_arm_joint_{}_vel".format(i): i + 15 for i in range(10)},
# [25, 30): right gripper joint velocities
**{"gripper_joint_{}_vel".format(i): i + 25 for i in range(5)},
**{"right_gripper_joint_{}_vel".format(i): i + 25 for i in range(5)},
"gripper_open_vel": 25, # alias of right_gripper_joint_0_vel
"right_gripper_open_vel": 25,
# [30, 33): right end effector positions
"eef_pos_x": 30,
"right_eef_pos_x": 30,
"eef_pos_y": 31,
"right_eef_pos_y": 31,
"eef_pos_z": 32,
"right_eef_pos_z": 32,
# [33, 39): right end effector 6D pose
"eef_angle_0": 33,
"right_eef_angle_0": 33,
"eef_angle_1": 34,
"right_eef_angle_1": 34,
"eef_angle_2": 35,
"right_eef_angle_2": 35,
"eef_angle_3": 36,
"right_eef_angle_3": 36,
"eef_angle_4": 37,
"right_eef_angle_4": 37,
"eef_angle_5": 38,
"right_eef_angle_5": 38,
# [39, 42): right end effector velocities
"eef_vel_x": 39,
"right_eef_vel_x": 39,
"eef_vel_y": 40,
"right_eef_vel_y": 40,
"eef_vel_z": 41,
"right_eef_vel_z": 41,
# [42, 45): right end effector angular velocities
"eef_angular_vel_roll": 42,
"right_eef_angular_vel_roll": 42,
"eef_angular_vel_pitch": 43,
"right_eef_angular_vel_pitch": 43,
"eef_angular_vel_yaw": 44,
"right_eef_angular_vel_yaw": 44,
# [45, 50): reserved
# [50, 60): left arm joint positions
**{"left_arm_joint_{}_pos".format(i): i + 50 for i in range(10)},
# [60, 65): left gripper joint positions
**{"left_gripper_joint_{}_pos".format(i): i + 60 for i in range(5)},
"left_gripper_open": 60, # alias of left_gripper_joint_0_pos
# [65, 75): left arm joint velocities
**{"left_arm_joint_{}_vel".format(i): i + 65 for i in range(10)},
# [75, 80): left gripper joint velocities
**{"left_gripper_joint_{}_vel".format(i): i + 75 for i in range(5)},
"left_gripper_open_vel": 75, # alias of left_gripper_joint_0_vel
# [80, 83): left end effector positions
"left_eef_pos_x": 80,
"left_eef_pos_y": 81,
"left_eef_pos_z": 82,
# [83, 89): left end effector 6D pose
"left_eef_angle_0": 83,
"left_eef_angle_1": 84,
"left_eef_angle_2": 85,
"left_eef_angle_3": 86,
"left_eef_angle_4": 87,
"left_eef_angle_5": 88,
# [89, 92): left end effector velocities
"left_eef_vel_x": 89,
"left_eef_vel_y": 90,
"left_eef_vel_z": 91,
# [92, 95): left end effector angular velocities
"left_eef_angular_vel_roll": 92,
"left_eef_angular_vel_pitch": 93,
"left_eef_angular_vel_yaw": 94,
# [95, 100): reserved
# [100, 102): base linear velocities
"base_vel_x": 100,
"base_vel_y": 101,
# [102, 103): base angular velocities
"base_angular_vel": 102,
# [103, 128): reserved
}
STATE_VEC_LEN = 128


now = datetime.datetime.now()

# Episodes are stored in a directory named after the current date
DATA_DIR = now.strftime("%Y.%m.%d")
os.makedirs(DATA_DIR, exist_ok=True)


def save_data(data_dict, dataset_path, data_size):
with h5py.File(dataset_path + ".hdf5", "w", rdcc_nbytes=1024**2 * 2) as root:
root.attrs["sim"] = False
root.attrs["compress"] = False

obs = root.create_group("observations")
variable_length = h5py.vlen_dtype(np.dtype("uint8"))
image = obs.create_group("images")
_ = image.create_dataset(
"cam_high",
(data_size,),
dtype=variable_length,
)
_ = image.create_dataset(
"cam_left_wrist",
(data_size,),
dtype=variable_length,
)
_ = image.create_dataset(
"cam_right_wrist",
(data_size,),
dtype=variable_length,
)

_ = obs.create_dataset("qpos", (data_size, 128))
_ = root.create_dataset("action", (data_size, 128))

# data_dict write into h5py.File
for name, array in data_dict.items():
print(name)
if "images" in name:
image[name][...] = array
else:
root[name][...] = array


data_dict = {
"/observations/qpos": [],
"/observations/images/cam_high": [],
"/observations/images/cam_left_wrist": [],
"/observations/images/cam_right_wrist": [],
"/action": [],
}


node = Node()

LEAD_CAMERA = "/observations/images/cam_high"

tmp_dict = {}

i = 0

start = False
for event in node:
if event["type"] == "INPUT":
if "save" in event["id"]:
char = event["value"][0].as_py()
if char == "p":
if not start:
continue

save_data(
data_dict,
f"{DATA_DIR}/episode_{i}",
len(data_dict["/observations/qpos"]),
)

# Reset dict
data_dict = {
"/observations/qpos": [],
"/observations/images/cam_high": [],
"/observations/images/cam_left_wrist": [],
"/observations/images/cam_right_wrist": [],
"/action": [],
}
i += 1
start = False
elif char == "s":
start = True

elif "image" in event["id"]:
tmp_dict[event["id"]] = event["value"].to_numpy()
elif "qpos" in event["id"]:
tmp_dict[event["id"]] = event["value"].to_numpy()
elif "base_vel" in event["id"]:
tmp_dict[event["id"]] = event["value"].to_numpy()

# Check if tmp dict is full
if len(tmp_dict) != 6:
continue
elif event["id"] == LEAD_CAMERA and start == True:
values = np.concatenate(
[
tmp_dict["/observations/qpos_left"],
tmp_dict["/observations/qpos_right"],
tmp_dict["/observations/base_vel"],
]
)
UNI_STATE_INDICES = (
[STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"] for i in range(6)]
+ [STATE_VEC_IDX_MAPPING["left_gripper_open"]]
+ [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)]
+ [STATE_VEC_IDX_MAPPING["right_gripper_open"]]
+ [STATE_VEC_IDX_MAPPING["base_vel_x"]]
+ [STATE_VEC_IDX_MAPPING["base_angular_vel"]]
)
universal_vec = np.zeros(STATE_VEC_LEN)
universal_vec[UNI_STATE_INDICES] = values
data_dict["/observations/qpos"].append(universal_vec)
# We reproduce obs and action
data_dict["/action"].append(universal_vec)
data_dict["/observations/images/cam_high"].append(
tmp_dict["/observations/images/cam_high"]
)
data_dict["/observations/images/cam_left_wrist"].append(
tmp_dict["/observations/images/cam_left_wrist"]
)
data_dict["/observations/images/cam_right_wrist"].append(
tmp_dict["/observations/images/cam_right_wrist"]
)
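
For reference, a recorded episode can be inspected offline. A minimal sketch (assuming the layout written by `save_data` above, with JPEG-encoded images as configured in `record.yml`; the file name is an example):

```python
import cv2
import h5py

with h5py.File("2024.11.20/episode_0.hdf5", "r") as f:
    qpos = f["observations/qpos"][:]              # (N, 128) unified state vectors
    actions = f["action"][:]                      # (N, 128), mirrors qpos in this recorder
    jpegs = f["observations/images/cam_high"][:]  # N variable-length uint8 JPEG buffers

# Decode the first center-camera frame from its JPEG bytes.
frame = cv2.imdecode(jpegs[0], cv2.IMREAD_COLOR)
print(qpos.shape, actions.shape, frame.shape)
```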

examples/piper/record.yml (+115, -0)

@@ -0,0 +1,115 @@
nodes:
- id: piper_left
path: /home/agilex/1ms.ai/piper_sdk/dora_piper.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/40
outputs:
- jointstate
env:
CAN_BUS: can_left
TEACH_MODE: True

- id: piper_right
path: /home/agilex/1ms.ai/piper_sdk/dora_piper.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/40
outputs:
- jointstate
env:
CAN_BUS: can_right
TEACH_MODE: True

- id: mobile_base
path: /home/agilex/1ms.ai/ugv_sdk/tracer_node.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/40
outputs:
- velocity

- id: camera_left
path: /home/agilex/1ms.ai/pyorbbecsdk/examples/color_viewer.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/40
outputs:
- image
env:
DEVICE_INDEX: 0
ENCODING: jpeg

- id: camera_center
path: /home/agilex/1ms.ai/pyorbbecsdk/examples/color_viewer.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/40
outputs:
- image
env:
DEVICE_INDEX: 1
ENCODING: jpeg

- id: camera_right
path: /home/agilex/1ms.ai/pyorbbecsdk/examples/color_viewer.py
_unstable_deploy:
machine: piper
inputs:
tick: dora/timer/millis/40
outputs:
- image
env:
DEVICE_INDEX: 2
ENCODING: jpeg
# To find the device indices, list the cameras with OpenCV:
# import cv2; [cv2.VideoCapture(i) for i in range(12)]

- id: rerun
path: dora-rerun
inputs:
jointstate_piper_left: piper_left/jointstate
jointstate_piper_right: piper_right/jointstate
series_base_vel: mobile_base/velocity
image_left: camera_left/image
image_center: camera_center/image
image_right: camera_right/image
env:
piper_left_urdf: /home/peter/Documents/work/dora/examples/piper/assets/piper_left.urdf
piper_right_urdf: /home/peter/Documents/work/dora/examples/piper/assets/piper_right.urdf
piper_left_transform: 0 0.2 0
piper_right_transform: 0 -0.2 0
piper_left_pred_urdf: /home/peter/Documents/work/dora/examples/piper/assets/piper_left_pred.urdf
piper_right_pred_urdf: /home/peter/Documents/work/dora/examples/piper/assets/piper_right_pred.urdf
piper_left_pred_transform: 0 0.2 0
piper_right_pred_transform: 0 -0.2 0

- id: keyboard
build: pip install dora-keyboard
path: dora-keyboard
inputs:
tick: dora/timer/millis/1000
outputs:
- char

- id: recorder
path: record.py
inputs:
/observations/qpos_left:
source: piper_left/jointstate
/observations/qpos_right:
source: piper_right/jointstate
/observations/base_vel:
source: mobile_base/velocity
/observations/images/cam_left_wrist:
source: camera_left/image
/observations/images/cam_high:
source: camera_center/image
/observations/images/cam_right_wrist:
source: camera_right/image
save: keyboard/char

node-hub/dora-rdt-1b/README.md (+3, -0)

@@ -0,0 +1,3 @@
# Dora RDT-1B node

Experimental node for running the RDT-1B VLA (vision-language-action) model.

node-hub/dora-rdt-1b/dora_rdt_1b/RoboticsDiffusionTransformer (+1, -0)

@@ -0,0 +1 @@
Subproject commit b2889e65cfe62571ced3ce88f00e7d80b41fee69

node-hub/dora-rdt-1b/dora_rdt_1b/__init__.py (+19, -0)

@@ -0,0 +1,19 @@
import os
import sys
from pathlib import Path

# Define the path to the README file relative to the package directory
readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.md")

# Read the content of the README file
try:
with open(readme_path, "r", encoding="utf-8") as f:
__doc__ = f.read()
except FileNotFoundError:
__doc__ = "README file not found."


# Make the vendored RoboticsDiffusionTransformer submodule importable

submodule_path = Path(__file__).resolve().parent / "RoboticsDiffusionTransformer"
sys.path.insert(0, str(submodule_path))
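
Once `dora_rdt_1b` is imported, the vendored RDT code can be used as a regular subpackage and its internal absolute imports resolve via the `sys.path` entry above; a minimal sketch:

```python
import dora_rdt_1b  # adds the submodule directory to sys.path

from dora_rdt_1b.RoboticsDiffusionTransformer.configs.state_vec import (
    STATE_VEC_IDX_MAPPING,
)

# Index of the left gripper in the 128-dim unified state vector.
print(STATE_VEC_IDX_MAPPING["left_gripper_open"])  # 60
```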

node-hub/dora-rdt-1b/dora_rdt_1b/main.py (+324, -0)

@@ -0,0 +1,324 @@
# install dependencies as shown in the README here https://github.com/alik-git/RoboticsDiffusionTransformer?tab=readme-ov-file#installation
import yaml
import torch
import numpy as np
from PIL import Image
from torchvision import transforms

from dora_rdt_1b.RoboticsDiffusionTransformer.configs.state_vec import (
STATE_VEC_IDX_MAPPING,
)
from dora_rdt_1b.RoboticsDiffusionTransformer.models.multimodal_encoder.siglip_encoder import (
SiglipVisionTower,
)
from dora_rdt_1b.RoboticsDiffusionTransformer.models.rdt_runner import RDTRunner
from dora import Node
import cv2
import pyarrow as pa
import os
from pathlib import Path

ROBOTIC_DEFAULT_PATH = "robotics-diffusion-transformer/rdt-1b"
ROBOTIC_MODEL_NAME_OR_PATH = os.getenv("ROBOTIC_MODEL_NAME_OR_PATH", ROBOTIC_DEFAULT_PATH)
LANGUAGE_EMBEDDING_PATH = os.getenv("LANGUAGE_EMBEDDING_PATH", "lang_embed.pt")

VISION_DEFAULT_PATH = "google/siglip-so400m-patch14-384"
VISION_MODEL_NAME_OR_PATH = os.getenv("VISION_MODEL_NAME_OR_PATH", VISION_DEFAULT_PATH)


def get_policy():
from dora_rdt_1b.RoboticsDiffusionTransformer.models.rdt_runner import RDTRunner

pretrained_model_name_or_path = ROBOTIC_MODEL_NAME_OR_PATH
rdt = RDTRunner.from_pretrained(pretrained_model_name_or_path)
device = torch.device("cuda:0")
dtype = torch.bfloat16 # recommended
rdt.to(device, dtype=dtype)
rdt.eval()
return rdt


def get_vision_model():
from dora_rdt_1b.RoboticsDiffusionTransformer.models.multimodal_encoder.siglip_encoder import (
SiglipVisionTower,
)

# Load vision encoder
vision_encoder = SiglipVisionTower(
vision_tower=VISION_MODEL_NAME_OR_PATH,
args=None,
)
device = torch.device("cuda:0")
dtype = torch.bfloat16 # recommended
vision_encoder.to(device, dtype=dtype)
vision_encoder.eval()
image_processor = vision_encoder.image_processor
return vision_encoder, image_processor


def get_language_embeddings():
device = torch.device("cuda:0")
dtype = torch.bfloat16 # recommended

lang_embeddings = torch.load(
LANGUAGE_EMBEDDING_PATH,
map_location=device,
)

return lang_embeddings.unsqueeze(
0
) # Size: (B, L_lang, D) or None, language condition tokens (variable length), dimension D is assumed to be the same as the hidden size.


def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result


def process_image(rgbs_lst, image_processor, vision_encoder):
device = torch.device("cuda:0")
dtype = torch.bfloat16 # recommended

file_path = Path(__file__).parent

config_path = (
file_path / "RoboticsDiffusionTransformer/configs/base.yaml"
) # default config

with open(config_path, "r") as fp:
config = yaml.safe_load(fp)

# previous_image_path = "/mnt/hpfs/1ms.ai/dora/node-hub/dora-rdt-1b/dora_rdt_1b/RoboticsDiffusionTransformer/img.jpeg"
# # previous_image = None # if t = 0
# previous_image = Image.fromarray(previous_image_path).convert("RGB") # if t > 0

# current_image_path = "/mnt/hpfs/1ms.ai/dora/node-hub/dora-rdt-1b/dora_rdt_1b/RoboticsDiffusionTransformer/img.jpeg"
# current_image = Image.fromarray(current_image_path).convert("RGB")

# here we assume you only have an exterior image (e.g., 3rd-person view) and no state information
# the images should be arranged in the sequence [exterior_image, right_wrist_image, left_wrist_image] * image_history_size (e.g., 2)
# rgbs_lst = [[previous_image, None, None], [current_image, None, None]]
# if you have a right_wrist_image, then it should be
# rgbs_lst = [
# [previous_image, previous_right_wrist_image, None],
# [current_image, current_right_wrist_image, None]
# ]

# image pre-processing
# The background image used for padding

image_tensor_list = []
for step in range(config["common"]["img_history_size"]):
rgbs = rgbs_lst[step]
for rgb in rgbs:
assert rgb, "You should not have None image"
image = rgb

if config["dataset"].get("image_aspect_ratio", "pad") == "pad":
background_color = tuple(
int(x * 255) for x in image_processor.image_mean
)
image = expand2square(image, background_color)
image = image_processor.preprocess(image, return_tensors="pt")[
"pixel_values"
][0]
image_tensor_list.append(image)

image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)
# encode images
image_embeds = vision_encoder(image_tensor).detach()
return image_embeds.reshape(-1, vision_encoder.hidden_size).unsqueeze(0)


def get_states(proprio):
device = torch.device("cuda:0")
dtype = torch.bfloat16 # recommended

# suppose you control in 7DOF joint position
STATE_INDICES = [
STATE_VEC_IDX_MAPPING["left_arm_joint_0_pos"],
STATE_VEC_IDX_MAPPING["left_arm_joint_1_pos"],
STATE_VEC_IDX_MAPPING["left_arm_joint_2_pos"],
STATE_VEC_IDX_MAPPING["left_arm_joint_3_pos"],
STATE_VEC_IDX_MAPPING["left_arm_joint_4_pos"],
STATE_VEC_IDX_MAPPING["left_arm_joint_5_pos"],
STATE_VEC_IDX_MAPPING["left_gripper_open"],
STATE_VEC_IDX_MAPPING["right_arm_joint_0_pos"],
STATE_VEC_IDX_MAPPING["right_arm_joint_1_pos"],
STATE_VEC_IDX_MAPPING["right_arm_joint_2_pos"],
STATE_VEC_IDX_MAPPING["right_arm_joint_3_pos"],
STATE_VEC_IDX_MAPPING["right_arm_joint_4_pos"],
STATE_VEC_IDX_MAPPING["right_arm_joint_5_pos"],
STATE_VEC_IDX_MAPPING["right_gripper_open"],
]

file_path = Path(__file__).parent

config_path = (
file_path / "RoboticsDiffusionTransformer/configs/base.yaml"
) # default config
with open(config_path, "r") as fp:
config = yaml.safe_load(fp)

B, N = 1, 1 # batch size and state history size
states = torch.zeros(
(B, N, config["model"]["state_token_dim"]), device=device, dtype=dtype
)
# suppose you do not have proprio
# it's kind of tricky; I strongly suggest adding proprio as input and further fine-tuning
proprio = torch.tensor(proprio, device=device, dtype=dtype).reshape(
(1, 1, -1)
) # B, N = 1, 1 # batch size and state history size

# if you have proprio, you can do like this
# format like this: [arm_joint_0_pos, arm_joint_1_pos, arm_joint_2_pos, arm_joint_3_pos, arm_joint_4_pos, arm_joint_5_pos, arm_joint_6_pos, gripper_open]
# proprio = torch.tensor([0, 1, 2, 3, 4, 5, 6, 0.5]).reshape((1, 1, -1))
states[:, :, STATE_INDICES] = proprio

state_elem_mask = torch.zeros(
(1, config["model"]["state_token_dim"]), device=device, dtype=torch.bool
)

state_elem_mask[:, STATE_INDICES] = True
states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(
device, dtype=dtype
)
states = states[:, -1:, :] # only use the last state
return states, state_elem_mask, STATE_INDICES


def main():

device = torch.device("cuda:0")
rdt = get_policy()
lang_embeddings = get_language_embeddings()
vision_encoder, image_processor = get_vision_model()

## for image
# image_embeds = process_image(rgb_lst, image_processor, vision_encoder)
## for states
# states, state_elem_mask, STATE_INDICES = get_states(states)
node = Node()
frames = {}
joints = {}
with torch.no_grad():

for event in node:
event_type = event["type"]
if event_type == "INPUT":

event_id = event["id"]

if "image" in event_id:
storage = event["value"]
metadata = event["metadata"]
encoding = metadata["encoding"]

if encoding == "bgr8":
channels = 3
storage_type = np.uint8
elif encoding == "rgb8":
channels = 3
storage_type = np.uint8
elif encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]:
channels = 3
storage_type = np.uint8
else:
raise RuntimeError(f"Unsupported image encoding: {encoding}")

if encoding == "bgr8":
width = metadata["width"]
height = metadata["height"]
frame = (
storage.to_numpy()
.astype(storage_type)
.reshape((height, width, channels))
)
frame = frame[:, :, ::-1] # OpenCV image (BGR to RGB)
elif encoding == "rgb8":
width = metadata["width"]
height = metadata["height"]
frame = (
storage.to_numpy()
.astype(storage_type)
.reshape((height, width, channels))
)
elif encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]:
storage = storage.to_numpy()
frame = cv2.imdecode(storage, cv2.IMREAD_COLOR)
frame = frame[:, :, ::-1] # OpenCV image (BGR to RGB)
else:
raise RuntimeError(f"Unsupported image encoding: {encoding}")
frames[f"last_{event_id}"] = frames.get(
event_id, Image.fromarray(frame)
)
frames[event_id] = Image.fromarray(frame)
elif "jointstate" in event_id:
joints[event_id] = event["value"].to_numpy()

elif "tick" == event_id:
## Wait for all images
if len(frames.keys()) < 6:
continue
if len(joints.keys()) < 2:
continue

## Embed images
rgbs_lst = [
[
frames["last_image_center"],
frames["last_image_right"],
frames["last_image_left"],
],
[
frames["image_center"],
frames["image_right"],
frames["image_left"],
],
]
image_embeds = process_image(
rgbs_lst, image_processor, vision_encoder
)

## Embed states
proprio = np.concatenate(
[
joints["jointstate_left"],
joints["jointstate_right"],
]
)
states, state_elem_mask, state_indices = get_states(proprio=proprio)

actions = rdt.predict_action(
lang_tokens=lang_embeddings,
lang_attn_mask=torch.ones(
lang_embeddings.shape[:2], dtype=torch.bool, device=device
),
img_tokens=image_embeds,
state_tokens=states, # proprioceptive state tokens built in get_states()
action_mask=state_elem_mask.unsqueeze(1), # mask of the populated state dimensions
ctrl_freqs=torch.tensor([25.0], device=device), # control frequency (25 Hz assumed)
) # (1, chunk_size, 128)

# select the meaningful action dimensions via STATE_INDICES
action = actions[
:, :, state_indices
] # (1, chunk_size, len(STATE_INDICES)) = (1, chunk_size, 7 + 7)
action = action.detach().float().to("cpu").numpy()
node.send_output("action", pa.array(action.ravel()))

node-hub/dora-rdt-1b/pyproject.toml (+36, -0)

@@ -0,0 +1,36 @@
[tool.poetry]
name = "dora-rdt-1b"
version = "0.3.6-rc0"
authors = ["Haixuan Xavier Tao <tao.xavier@outlook.com>"]
description = "Dora Node for VLM"
readme = "README.md"

packages = [{ include = "dora_rdt_1b" }]

[tool.poetry.dependencies]
python = "^3.7"
dora-rs = "^0.3.6"
numpy = "< 2.0.0"
torch = "^2.4.0"
torchvision = "^0.19"
transformers = "^4.45"
qwen-vl-utils = "^0.0.2"
accelerate = "^0.33"
opencv-python = ">= 4.1.1"
modelscope = "^1.18.1"
packaging = "24.0"
wandb = "0.17.0"
diffusers = "0.27.2"
timm = "1.0.3"
sentencepiece = "0.2.0"
h5py = "3.11.0"
imgaug = "0.4.0"
# flash_attn = "^2.6.1" # Install using: pip install -U flash-attn --no-build-isolation


[tool.poetry.scripts]
dora-rdt-1b = "dora_rdt_1b.main:main"

[build-system]
requires = ["poetry-core>=1.8.0"]
build-backend = "poetry.core.masonry.api"

node-hub/dora-rdt-1b/tests/conftest.py (+12, -0)

@@ -0,0 +1,12 @@
import pytest


def pytest_configure():
pytest.rdt = None
pytest.lang_embeddings = None
pytest.image_processor = None
pytest.vision_encoder = None
pytest.image_embeds = None
pytest.state_elem_mask = None
pytest.states = None
pytest.STATE_INDICES = None

node-hub/dora-rdt-1b/tests/test_dora_rdt_1b.py (+227, -0)

@@ -0,0 +1,227 @@
import pytest
import torch
import yaml
import numpy as np
from PIL import Image
from torchvision import transforms


def test_import_main():
# Importing dora_rdt_1b.main would start the dora node and raise a runtime
# error outside of a dataflow, so only check that the packages import:
# from dora_rdt_1b.main import main
# with pytest.raises(RuntimeError):
#     main()
import dora_rdt_1b
import dora_rdt_1b.RoboticsDiffusionTransformer


def test_download_policy():
from dora_rdt_1b.RoboticsDiffusionTransformer.models.rdt_runner import RDTRunner

pretrained_model_name_or_path = "robotics-diffusion-transformer/rdt-1b"
rdt = RDTRunner.from_pretrained(pretrained_model_name_or_path)
device = torch.device("cuda:0")
dtype = torch.bfloat16 # recommended
rdt.to(device, dtype=dtype)
rdt.eval()
pytest.rdt = rdt


def test_download_vision_model():
from dora_rdt_1b.RoboticsDiffusionTransformer.models.multimodal_encoder.siglip_encoder import (
SiglipVisionTower,
)

# Load vision encoder
vision_encoder = SiglipVisionTower(
vision_tower="google/siglip-so400m-patch14-384", args=None
)
device = torch.device("cuda:0")
dtype = torch.bfloat16 # recommended
vision_encoder.to(device, dtype=dtype)
vision_encoder.eval()
image_processor = vision_encoder.image_processor
pytest.vision_encoder = vision_encoder
pytest.image_processor = image_processor


def test_download_language_embeddings():
device = torch.device("cuda:0")
dtype = torch.bfloat16 # recommended
lang_embeddings = torch.load(
"/mnt/hpfs/1ms.ai/dora/node-hub/dora-rdt-1b/dora_rdt_1b/RoboticsDiffusionTransformer/outs/handover_pan.pt",
map_location=device,
)
pytest.lang_embeddings = lang_embeddings["embeddings"]


def test_load_dummy_image():
device = torch.device("cuda:0")
dtype = torch.bfloat16 # recommended
config_path = "/mnt/hpfs/1ms.ai/dora/node-hub/dora-rdt-1b/dora_rdt_1b/RoboticsDiffusionTransformer/configs/base.yaml" # default config
with open(config_path, "r") as fp:
config = yaml.safe_load(fp)

# Load pretrained model (in HF style)
image_processor = pytest.image_processor
vision_encoder = pytest.vision_encoder

previous_image_path = "/mnt/hpfs/1ms.ai/dora/node-hub/dora-rdt-1b/dora_rdt_1b/RoboticsDiffusionTransformer/img.jpeg"
# previous_image = None # if t = 0
previous_image = Image.open(previous_image_path).convert("RGB") # if t > 0

current_image_path = "/mnt/hpfs/1ms.ai/dora/node-hub/dora-rdt-1b/dora_rdt_1b/RoboticsDiffusionTransformer/img.jpeg"
current_image = Image.open(current_image_path).convert("RGB")

# here we assume you only have an exterior image (e.g., 3rd-person view) and no state information
# the images should be arranged in the sequence [exterior_image, right_wrist_image, left_wrist_image] * image_history_size (e.g., 2)
rgbs_lst = [[previous_image, None, None], [current_image, None, None]]
# if you have a right_wrist_image, then it should be
# rgbs_lst = [
# [previous_image, previous_right_wrist_image, None],
# [current_image, current_right_wrist_image, None]
# ]

# image pre-processing
# The background image used for padding
background_color = np.array(
[int(x * 255) for x in image_processor.image_mean], dtype=np.uint8
).reshape(1, 1, 3)
background_image = (
np.ones(
(image_processor.size["height"], image_processor.size["width"], 3),
dtype=np.uint8,
)
* background_color
)

image_tensor_list = []
for step in range(config["common"]["img_history_size"]):
rgbs = rgbs_lst[step % len(rgbs_lst)]
for rgb in rgbs:
if rgb is None:
# Replace it with the background image
image = Image.fromarray(background_image)
else:
image = rgb

if config["dataset"].get("auto_adjust_image_brightness", False):
pixel_values = list(image.getdata())
average_brightness = sum(sum(pixel) for pixel in pixel_values) / (
len(pixel_values) * 255.0 * 3
)
if average_brightness <= 0.15:
image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)

if config["dataset"].get("image_aspect_ratio", "pad") == "pad":

def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(
pil_img.mode, (width, width), background_color
)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(
pil_img.mode, (height, height), background_color
)
result.paste(pil_img, ((height - width) // 2, 0))
return result

image = expand2square(
image, tuple(int(x * 255) for x in image_processor.image_mean)
)
image = image_processor.preprocess(image, return_tensors="pt")[
"pixel_values"
][0]
image_tensor_list.append(image)

image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)
# encode images
image_embeds = vision_encoder(image_tensor).detach()
pytest.image_embeds = image_embeds.reshape(
-1, vision_encoder.hidden_size
).unsqueeze(0)


def test_dummy_states():
device = torch.device("cuda:0")
dtype = torch.bfloat16 # recommended
config_path = "/mnt/hpfs/1ms.ai/dora/node-hub/dora-rdt-1b/dora_rdt_1b/RoboticsDiffusionTransformer/configs/base.yaml" # default config
with open(config_path, "r") as fp:
config = yaml.safe_load(fp)

# suppose you do not have proprio
# it's kind of tricky; I strongly suggest adding proprio as input and further fine-tuning
B, N = 1, 1 # batch size and state history size
states = torch.zeros(
(B, N, config["model"]["state_token_dim"]), device=device, dtype=dtype
)

# if you have proprio, you can do like this
# format like this: [arm_joint_0_pos, arm_joint_1_pos, arm_joint_2_pos, arm_joint_3_pos, arm_joint_4_pos, arm_joint_5_pos, arm_joint_6_pos, gripper_open]
# proprio = torch.tensor([0, 1, 2, 3, 4, 5, 6, 0.5]).reshape((1, 1, -1))
# states[:, :, STATE_INDICES] = proprio

state_elem_mask = torch.zeros(
(B, config["model"]["state_token_dim"]), device=device, dtype=torch.bool
)
from dora_rdt_1b.RoboticsDiffusionTransformer.configs.state_vec import (
STATE_VEC_IDX_MAPPING,
)

# suppose you control in 7DOF joint position
STATE_INDICES = [
STATE_VEC_IDX_MAPPING["arm_joint_0_pos"],
STATE_VEC_IDX_MAPPING["arm_joint_1_pos"],
STATE_VEC_IDX_MAPPING["arm_joint_2_pos"],
STATE_VEC_IDX_MAPPING["arm_joint_3_pos"],
STATE_VEC_IDX_MAPPING["arm_joint_4_pos"],
STATE_VEC_IDX_MAPPING["arm_joint_5_pos"],
STATE_VEC_IDX_MAPPING["arm_joint_6_pos"],
STATE_VEC_IDX_MAPPING["gripper_open"],
]

state_elem_mask[:, STATE_INDICES] = True
states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(
device, dtype=dtype
)
states = states[:, -1:, :] # only use the last state
pytest.states = states
pytest.state_elem_mask = state_elem_mask
pytest.STATE_INDICES = STATE_INDICES


def test_dummy_input(request):

rdt = pytest.rdt
lang_embeddings = pytest.lang_embeddings
image_embeds = pytest.image_embeds
state_elem_mask = pytest.state_elem_mask
states = pytest.states
STATE_INDICES = pytest.STATE_INDICES

device = torch.device("cuda:0")

actions = rdt.predict_action(
lang_tokens=lang_embeddings,
lang_attn_mask=torch.ones(
lang_embeddings.shape[:2], dtype=torch.bool, device=device
),
img_tokens=image_embeds,
state_tokens=states, # proprioceptive state tokens built in test_dummy_states
action_mask=state_elem_mask.unsqueeze(1), # mask of the populated state dimensions
ctrl_freqs=torch.tensor([25.0], device=device), # control frequency (25 Hz assumed)
) # (1, chunk_size, 128)

# select the meaningful action dimensions via STATE_INDICES
action = actions[
:, :, STATE_INDICES
] # (1, chunk_size, len(STATE_INDICES)) = (1, chunk_size, 7 + 1)
print(action)
