Browse Source

Fix rdt 1b black error linting to exclude rdt git submodule

tags/0.3.8-rc
haixuantao 1 year ago
parent
commit
61dafea5b3
3 changed files with 6 additions and 2 deletions
  1. +1
    -0
      .github/workflows/node-hub-ci-cd.yml
  2. +1
    -2
      node-hub/dora-rdt-1b/dora_rdt_1b/main.py
  3. +4
    -0
      node-hub/dora-rdt-1b/pyproject.toml

+ 1
- 0
.github/workflows/node-hub-ci-cd.yml View File

@@ -44,6 +44,7 @@ jobs:
submodules: true # Make sure to check out the sub-module

- name: Update submodule
if: runner.os == 'Linux'
run: |
git submodule update --init --recursive
git submodule update --remote --recursive


+ 1
- 2
node-hub/dora-rdt-1b/dora_rdt_1b/main.py View File

@@ -31,6 +31,7 @@ config_path = (
with open(config_path, "r", encoding="utf-8") as fp:
config = yaml.safe_load(fp)


def get_policy():
from dora_rdt_1b.RoboticsDiffusionTransformer.models.rdt_runner import RDTRunner

@@ -92,7 +93,6 @@ def process_image(rgbs_lst, image_processor, vision_encoder):
device = torch.device("cuda:0")
dtype = torch.bfloat16 # recommanded


# previous_image_path = "/mnt/hpfs/1ms.ai/dora/node-hub/dora-rdt-1b/dora_rdt_1b/RoboticsDiffusionTransformer/img.jpeg"
# # previous_image = None # if t = 0
# previous_image = Image.fromarray(previous_image_path).convert("RGB") # if t > 0
@@ -157,7 +157,6 @@ def get_states(proprio):
STATE_VEC_IDX_MAPPING["right_gripper_open"],
]


B, N = 1, 1 # batch size and state history size
states = torch.zeros(
(B, N, config["model"]["state_token_dim"]), device=device, dtype=dtype


+ 4
- 0
node-hub/dora-rdt-1b/pyproject.toml View File

@@ -35,6 +35,10 @@ ignore-paths = '^dora_rdt_1b/RoboticsDiffusionTransformer.*$'
pytest = "^8.3.4"
pylint = "^3.3.2"

[tool.black]
extend-exclude = 'dora_rdt_1b/RoboticsDiffusionTransformer'


[tool.poetry.scripts]
dora-rdt-1b = "dora_rdt_1b.main:main"



Loading…
Cancel
Save