@@ -0,0 +1,8 @@
build_id: 2b402c1e-e52e-45e9-86e5-236b33a77369
session_id: 275de19c-e605-4865-bc5f-2f15916bade9
git_sources: {}
local_build:
  node_working_dirs:
    camera: /Users/xaviertao/Documents/work/dora/examples/vggt
    dora-vggt: /Users/xaviertao/Documents/work/dora/examples/vggt
    plot: /Users/xaviertao/Documents/work/dora/examples/vggt
@@ -0,0 +1,26 @@
nodes:
  - id: camera
    build: pip install opencv-video-capture
    path: opencv-video-capture
    inputs:
      tick: dora/timer/millis/100
    outputs:
      - image
    env:
      CAPTURE_PATH: 1

  - id: dora-vggt
    build: pip install -e ../../node-hub/dora-vggt
    path: dora-vggt
    inputs:
      image: camera/image
    outputs:
      - depth
      - image

  - id: plot
    build: pip install dora-rerun
    path: dora-rerun
    inputs:
      camera/image: dora-vggt/image
      camera/depth: dora-vggt/depth
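
For reference, a dataflow like the one above is built and launched with the dora CLI; a minimal sketch, assuming the file is saved as `vggt.yml` (the filename is hypothetical):

```bash
# Build all nodes declared in the dataflow (runs each node's `build` command),
# then start the graph.
dora build vggt.yml
dora run vggt.yml
```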
@@ -1,4 +1,7 @@
-use arrow::array::{Array, Float32Array, Float64Array, Int32Array, Int64Array, UInt32Array};
+use arrow::array::{
+    Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array,
+    UInt32Array, UInt8Array,
+};
 use arrow::datatypes::DataType;
 use eyre::{eyre, ContextCompat, Result};
 use num::NumCast;
@@ -63,7 +66,11 @@ macro_rules! register_array_handlers {
 register_array_handlers! {
     (DataType::Float32, Float32Array, "float32"),
     (DataType::Float64, Float64Array, "float64"),
+    (DataType::Int8, Int8Array, "int8"),
+    (DataType::Int16, Int16Array, "int16"),
     (DataType::Int32, Int32Array, "int32"),
     (DataType::Int64, Int64Array, "int64"),
+    (DataType::UInt8, UInt8Array, "uint8"),
+    (DataType::UInt16, UInt16Array, "uint16"),
     (DataType::UInt32, UInt32Array, "uint32"),
 }
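
For context, `register_array_handlers!` pairs each Arrow `DataType` with its concrete array type so dynamically typed arrays can be downcast and their values converted generically. A minimal sketch of the dispatch pattern such a macro might expand to — the function name is hypothetical, and the real expansion lives in the surrounding file:

```rust
// Sketch: match on the runtime DataType, downcast the dynamic array to its
// concrete type, and convert each value through num::NumCast.
fn to_f64_vec(array: &dyn Array) -> Result<Vec<f64>> {
    match array.data_type() {
        DataType::Float32 => {
            let arr = array
                .as_any()
                .downcast_ref::<Float32Array>()
                .context("downcast to Float32Array failed")?;
            arr.values()
                .iter()
                .map(|v| NumCast::from(*v).context("cast to f64 failed"))
                .collect()
        }
        // ...one arm per (DataType, ArrayType, name) tuple registered above...
        other => Err(eyre!("unsupported data type: {other}")),
    }
}
```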
@@ -0,0 +1,40 @@
# dora-vggt

## Getting started

- Install it with uv:

  ```bash
  uv venv -p 3.11 --seed
  uv pip install -e .
  ```

## Contribution Guide

- Format with [ruff](https://docs.astral.sh/ruff/):

  ```bash
  uv pip install ruff
  uv run ruff check . --fix
  ```

- Lint with ruff:

  ```bash
  uv run ruff check .
  ```

- Test with [pytest](https://github.com/pytest-dev/pytest):

  ```bash
  uv pip install pytest
  uv run pytest .  # Test
  ```

## YAML Specification

## Examples

## License

dora-vggt's code is released under the MIT License.
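
The node consumes an `image` input and emits `depth` and `image` outputs. A minimal sketch of how it can be wired into a dataflow, mirroring the example dataflow in this change (the `camera/image` source is a placeholder for any node that publishes an image):

```yaml
- id: dora-vggt
  build: pip install -e ../../node-hub/dora-vggt
  path: dora-vggt
  inputs:
    image: camera/image # any node that publishes an image
  outputs:
    - depth
    - image
```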
@@ -0,0 +1,13 @@
"""dora-vggt package: exposes the project README as the module docstring."""

import os

# Define the path to the README file relative to the package directory
readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.md")

# Read the content of the README file
try:
    with open(readme_path, encoding="utf-8") as f:
        __doc__ = f.read()
except FileNotFoundError:
    __doc__ = "README file not found."
@@ -0,0 +1,6 @@
"""Module entry point so the package can be run with `python -m dora_vggt`."""

from .main import main

if __name__ == "__main__":
    main()
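
With this entry point in place, the node can also be launched directly as a module; outside a dora dataflow it exits with a `RuntimeError`, as the test at the end of this change exercises:

```bash
python -m dora_vggt
```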
@@ -0,0 +1,142 @@
"""Dora node that runs VGGT to predict depth maps and camera poses from a camera stream."""

import io
from collections import deque

import cv2
import numpy as np
import pyarrow as pa
import torch
from dora import Node
from PIL import Image
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from vggt.utils.pose_enc import pose_encoding_to_extri_intri

# bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+);
# float16 is used here for wider compatibility.
dtype = torch.float16

# Initialize the model and load the pretrained weights.
# This will automatically download the model weights the first time it's run,
# which may take a while.
model = VGGT.from_pretrained("facebook/VGGT-1B")
model.eval()


def main():
    """Run VGGT over a rolling window of frames and publish depth and image outputs."""
    node = Node()
    # Keep the last five frames as JPEG buffers; each inference pass runs
    # over the whole window.
    raw_images = deque(maxlen=5)
    for event in node:
        if event["type"] == "INPUT":
            if event["id"] == "TICK":
                print(
                    f"""Node received:
                id: {event["id"]},
                value: {event["value"]},
                metadata: {event["metadata"]}""",
                )
            if "image" in event["id"]:
                storage = event["value"]
                metadata = event["metadata"]
                encoding = metadata["encoding"]
                width = metadata["width"]
                height = metadata["height"]

                if (
                    encoding == "bgr8"
                    or encoding == "rgb8"
                    or encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]
                ):
                    channels = 3
                    storage_type = np.uint8
                else:
                    raise RuntimeError(f"Unsupported image encoding: {encoding}")

                if encoding == "bgr8":
                    frame = (
                        storage.to_numpy()
                        .astype(storage_type)
                        .reshape((height, width, channels))
                    )
                    frame = frame[:, :, ::-1]  # OpenCV image (BGR to RGB)
                elif encoding == "rgb8":
                    frame = (
                        storage.to_numpy()
                        .astype(storage_type)
                        .reshape((height, width, channels))
                    )
                elif encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]:
                    storage = storage.to_numpy()
                    frame = cv2.imdecode(storage, cv2.IMREAD_COLOR)
                    frame = frame[:, :, ::-1]  # OpenCV image (BGR to RGB)
                else:
                    raise RuntimeError(f"Unsupported image encoding: {encoding}")

                image = Image.fromarray(frame)

                # Save the image to a bytes buffer.
                buffer = io.BytesIO()
                image.save(buffer, format="JPEG")  # or PNG, etc.
                raw_images.append(buffer)

                # Rewind every buffered frame so it can be re-read on this pass.
                for buf in raw_images:
                    buf.seek(0)

                with torch.no_grad():
                    images = load_and_preprocess_images(raw_images)
                    images = images[None]  # add batch dimension
                    aggregated_tokens_list, ps_idx = model.aggregator(images)

                    # Predict cameras.
                    pose_enc = model.camera_head(aggregated_tokens_list)[-1]
                    # Extrinsic and intrinsic matrices, following OpenCV
                    # convention (camera from world).
                    extrinsic, intrinsic = pose_encoding_to_extri_intri(
                        pose_enc, images.shape[-2:]
                    )

                    # Predict depth maps.
                    depth_map, depth_conf = model.depth_head(
                        aggregated_tokens_list, images, ps_idx
                    )

                print(depth_conf.max())
                depth_map[depth_conf < 1.0] = 0.0  # Set low-confidence pixels to 0
                depth_map = depth_map.to(torch.float64)
                # Keep only the latest frame of the window: (H, W, 1).
                depth_map = depth_map[-1][-1].cpu().numpy()

                # The "depth" output id must be declared in the dataflow YAML.
                node.send_output(
                    output_id="depth",
                    data=pa.array(depth_map.ravel()),
                    metadata={
                        "width": depth_map.shape[1],
                        "height": depth_map.shape[0],
                    },
                )

                image = images[-1][-1].cpu().numpy() * 255
                image = image.astype(np.uint8)
                # Move channels to the last dimension (CHW -> HWC).
                image = image.transpose(1, 2, 0)
                print(
                    f"Image shape: {image.shape}, dtype: {image.dtype} "
                    f"and depth map shape: {depth_map.shape}, dtype: {depth_map.dtype}"
                )
                # The "image" output id must be declared in the dataflow YAML.
                node.send_output(
                    output_id="image",
                    data=pa.array(image.ravel()),
                    metadata={
                        "encoding": "rgb8",
                        "width": image.shape[1],
                        "height": image.shape[0],
                    },
                )


if __name__ == "__main__":
    main()
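
Downstream consumers can rebuild the flattened `depth` output from the metadata sent alongside it; a minimal sketch of the receiving side (the event-handling scaffolding is assumed, not part of this change):

```python
import numpy as np

def reshape_depth(event):
    # The node sends depth as a flat float64 array plus width/height metadata,
    # so a consumer restores the (height, width) layout before plotting.
    meta = event["metadata"]
    return event["value"].to_numpy().reshape((meta["height"], meta["width"]))
```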
@@ -0,0 +1,30 @@
[project]
name = "dora-vggt"
version = "0.0.0"
authors = [{ name = "Your Name", email = "email@email.com" }]
description = "dora-vggt"
license = { text = "MIT" }
readme = "README.md"
requires-python = ">=3.10"

dependencies = [
    "dora-rs >= 0.3.9",
    "torch>=2.7.0",
    "torchvision>=0.22.0",
    "vggt",
]

[dependency-groups]
dev = ["pytest >=8.1.1", "ruff >=0.9.1"]

[project.scripts]
dora-vggt = "dora_vggt.main:main"

[tool.ruff.lint]
extend-select = [
    "D",  # pydocstyle
    "UP",
]

[tool.uv.sources]
vggt = { git = "https://github.com/facebookresearch/vggt" }
@@ -0,0 +1,13 @@
"""Test module for dora_vggt package."""

import pytest


def test_import_main():
    """Test importing and running the main function."""
    from dora_vggt.main import main

    # Check that everything is working, and catch Dora RuntimeError
    # as we're not running in a Dora dataflow.
    with pytest.raises(RuntimeError):
        main()