|
|
|
@@ -10,12 +10,21 @@ import pyarrow as pa |
|
|
|
|
|
|
|
for event in node: |
|
|
|
if event["type"] == "INPUT": |
|
|
|
actions = event["value"].to_numpy().reshape((64, 14)) |
|
|
|
|
|
|
|
# Skip action to only keep 8 spread action |
|
|
|
actions = actions[[0, 8, 16, 24, 32, 40, 48, 56], :] |
|
|
|
actions = event["value"].to_numpy().copy().reshape((64, 14)) |
|
|
|
|
|
|
|
for action in actions: |
|
|
|
gripper_left = action[6] |
|
|
|
gripper_right = action[13] |
|
|
|
if gripper_right < 0.45: |
|
|
|
action[13] = 0.3 |
|
|
|
else: |
|
|
|
action[13] = 0.6 |
|
|
|
|
|
|
|
if gripper_left < 0.45: |
|
|
|
action[6] = 0.3 |
|
|
|
else: |
|
|
|
action[6] = 0.6 |
|
|
|
|
|
|
|
node.send_output("jointstate_left", pa.array(action[:7], type=pa.float32())) |
|
|
|
node.send_output( |
|
|
|
"jointstate_right", pa.array(action[7:], type=pa.float32()) |
|
|
|
|