|
|
|
@@ -1,10 +1,10 @@ |
|
|
|
import os |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
import pytest |
|
|
|
import torch |
|
|
|
import numpy as np |
|
|
|
from PIL import Image |
|
|
|
from torchvision import transforms |
|
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
CI = os.environ.get("CI") |
|
|
|
|
|
|
|
@@ -20,8 +20,8 @@ def test_import_main(): |
|
|
|
# Check that everything is working, and catch dora Runtime Exception as we're not running in a dora dataflow. |
|
|
|
# with pytest.raises(RuntimeError): |
|
|
|
# main() |
|
|
|
import dora_rdt_1b.RoboticsDiffusionTransformer as _ |
|
|
|
import dora_rdt_1b as _ |
|
|
|
import dora_rdt_1b.RoboticsDiffusionTransformer as _ # noqa |
|
|
|
import dora_rdt_1b as _ # noqa |
|
|
|
|
|
|
|
|
|
|
|
def test_download_policy(): |
|
|
|
@@ -44,7 +44,6 @@ def test_download_vision_model(): |
|
|
|
|
|
|
|
|
|
|
|
def test_download_language_embeddings(): |
|
|
|
|
|
|
|
## in the future we should add this test within CI |
|
|
|
if CI: |
|
|
|
return |
|
|
|
@@ -55,7 +54,6 @@ def test_download_language_embeddings(): |
|
|
|
|
|
|
|
|
|
|
|
def test_load_dummy_image(): |
|
|
|
|
|
|
|
from dora_rdt_1b.main import config |
|
|
|
|
|
|
|
# Load pretrained model (in HF style) |
|
|
|
@@ -187,8 +185,9 @@ def test_dummy_states(): |
|
|
|
] |
|
|
|
|
|
|
|
state_elem_mask[:, STATE_INDICES] = True |
|
|
|
states, state_elem_mask = states.to(DEVICE, dtype=DTYPE), state_elem_mask.to( |
|
|
|
DEVICE, dtype=DTYPE |
|
|
|
states, state_elem_mask = ( |
|
|
|
states.to(DEVICE, dtype=DTYPE), |
|
|
|
state_elem_mask.to(DEVICE, dtype=DTYPE), |
|
|
|
) |
|
|
|
states = states[:, -1:, :] # only use the last state |
|
|
|
pytest.states = states |
|
|
|
|