@@ -31,6 +31,7 @@ config_path = (
with open(config_path, "r", encoding="utf-8") as fp:
    config = yaml.safe_load(fp)


def get_policy():
    from dora_rdt_1b.RoboticsDiffusionTransformer.models.rdt_runner import RDTRunner

@@ -92,7 +93,6 @@ def process_image(rgbs_lst, image_processor, vision_encoder):
device = torch.device("cuda:0")
dtype = torch.bfloat16  # recommended


# previous_image_path = "/mnt/hpfs/1ms.ai/dora/node-hub/dora-rdt-1b/dora_rdt_1b/RoboticsDiffusionTransformer/img.jpeg"
# # previous_image = None # if t = 0
# previous_image = Image.open(previous_image_path).convert("RGB") # if t > 0
@@ -157,7 +157,6 @@ def get_states(proprio):
    STATE_VEC_IDX_MAPPING["right_gripper_open"],
]


B, N = 1, 1  # batch size and state history size
states = torch.zeros(
    (B, N, config["model"]["state_token_dim"]), device=device, dtype=dtype
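# --- Hedged illustration, not part of the diff above ---
# A minimal sketch of how the pieces shown here could fit together when
# preparing the model input: the YAML config supplies state_token_dim, and
# STATE_VEC_IDX_MAPPING gives the slot of each proprioceptive value in the
# unified state vector. The helper name, the `indices` argument, and the
# `state_elem_mask` tensor are assumptions for illustration only; `proprio`
# is assumed to be an iterable of scalar joint/gripper values matching
# `indices` in order.
import torch

def fill_states_example(proprio, indices, config, device, dtype=torch.bfloat16):
    B, N = 1, 1  # batch size and state history size, as in the diff
    states = torch.zeros(
        (B, N, config["model"]["state_token_dim"]), device=device, dtype=dtype
    )
    # Mark which dimensions of the unified state vector are actually observed.
    state_elem_mask = torch.zeros_like(states)
    for value, idx in zip(proprio, indices):
        states[:, :, idx] = float(value)
        state_elem_mask[:, :, idx] = 1.0
    return states, state_elem_mask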