"""Dora node that estimates depth with VGGT.

Buffers the two most recent camera frames, runs the VGGT model to predict
camera parameters and depth maps, and publishes a scaled, metric depth map
plus the preprocessed RGB frame back onto the dataflow.
"""
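
# A minimal dataflow sketch for wiring this node (the node ids, paths, and the
# camera source below are illustrative assumptions, not part of this file):
#
#   nodes:
#     - id: camera
#       path: opencv-video-capture
#       outputs:
#         - image
#     - id: vggt
#       path: ./vggt_depth.py
#       inputs:
#         image: camera/image
#       outputs:
#         - image
#         - depth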

import io
import os
from collections import deque

import cv2
import numpy as np
import pyarrow as pa
import torch
from dora import Node
from PIL import Image
from vggt.models.vggt import VGGT
from vggt.utils.geometry import unproject_depth_map_to_point_map
from vggt.utils.load_fn import load_and_preprocess_images
from vggt.utils.pose_enc import pose_encoding_to_extri_intri

# Camera height used in the debug logging below (meters, matching the scaled
# depth units).
CAMERA_HEIGHT_Y = float(os.getenv("CAMERA_HEIGHT_Y", "0.115"))

# bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+); fall back to
# float16 on older GPUs.
dtype = (
    torch.bfloat16
    if not torch.cuda.is_available()
    or torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)

# Check if CUDA is available and set the device accordingly.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize the model and load the pretrained weights. This will automatically
# download the weights the first time it's run, which may take a while.
model = VGGT.from_pretrained("facebook/VGGT-1B").to(device)
model.eval()

DEPTH_ENCODING = os.environ.get("DEPTH_ENCODING", "float64")
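# "float64" publishes raw depth in meters; "mono16" publishes depth in
# millimeters as uint16 (see the conversion before send_output below).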


def main():
    """Decode incoming camera images, run VGGT, and publish depth and RGB outputs."""
    node = Node()
    raw_images = deque(maxlen=2)

    for event in node:
        if event["type"] == "INPUT":
            if "image" in event["id"]:
                storage = event["value"]
                metadata = event["metadata"]
                encoding = metadata["encoding"]
                width = metadata["width"]
                height = metadata["height"]

                if encoding == "bgr8":
                    frame = (
                        storage.to_numpy()
                        .astype(np.uint8)
                        .reshape((height, width, 3))
                    )
                    frame = frame[:, :, ::-1]  # OpenCV image (BGR to RGB)
                elif encoding == "rgb8":
                    frame = (
                        storage.to_numpy()
                        .astype(np.uint8)
                        .reshape((height, width, 3))
                    )
                elif encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]:
                    frame = cv2.imdecode(storage.to_numpy(), cv2.IMREAD_COLOR)
                    frame = frame[:, :, ::-1]  # OpenCV image (BGR to RGB)
                else:
                    raise RuntimeError(f"Unsupported image encoding: {encoding}")
                image = Image.fromarray(frame)

                # Re-encode the frame as JPEG in an in-memory buffer so that
                # VGGT's loader can treat it like an image file.
                buffer = io.BytesIO()
                image.save(buffer, format="JPEG")

                # Rewind the buffer's file pointer to the beginning.
                buffer.seek(0)
                raw_images.append(buffer)

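                # raw_images keeps at most the two most recent frames (deque
                # maxlen=2), so each inference below runs on a short sliding
                # window of views rather than the full history.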
                with torch.no_grad():
                    images = load_and_preprocess_images(raw_images).to(device)

                    images = images[None]  # add batch dimension
                    # Run the aggregator under mixed precision on GPU, as in
                    # the VGGT examples; the heads run in full precision.
                    with torch.autocast(
                        device_type=device, dtype=dtype, enabled=device == "cuda"
                    ):
                        aggregated_tokens_list, ps_idx = model.aggregator(images)
                    # Predict cameras.
                    pose_enc = model.camera_head(aggregated_tokens_list)[-1]
                    # Extrinsic and intrinsic matrices, following the OpenCV
                    # convention (camera from world).
                    extrinsic, intrinsic = pose_encoding_to_extri_intri(
                        pose_enc, images.shape[-2:]
                    )
                    print(f"Extrinsic: {extrinsic}")
                    print(f"Intrinsic: {intrinsic}")

                    # Predict depth maps.
                    depth_map, depth_conf = model.depth_head(
                        aggregated_tokens_list, images, ps_idx
                    )
                    # Zero out low-confidence pixels.
                    depth_map[depth_conf < 0.6] = 0.0

                    # Construct 3D points from the depth maps and cameras, which
                    # usually yields more accurate points than the point-map branch.
                    point_map_by_unprojection = unproject_depth_map_to_point_map(
                        depth_map.squeeze(0), extrinsic.squeeze(0), intrinsic.squeeze(0)
                    )
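                    # point_map_by_unprojection has shape (S, H, W, 3): per-pixel
                    # (x, y, z) world coordinates for each of the S input frames.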

                    # Z channel of the unprojected point map (kept for debugging;
                    # only the per-axis statistics below are logged).
                    z_value = point_map_by_unprojection[0, :, :, 2]
                    # Fixed scale applied to bring VGGT's depth into meters.
                    scale_factor = 0.51

                    print(
                        f"Event Id: {event['id']} Scale factor: {scale_factor}, "
                        f"with height: {CAMERA_HEIGHT_Y} and max depth: "
                        f"{point_map_by_unprojection[0, :, :, 1].max()}"
                    )
                    print(
                        f"Point map X min/max: "
                        f"{point_map_by_unprojection[0, :, :, 0].min()} / "
                        f"{point_map_by_unprojection[0, :, :, 0].max()}"
                    )
                    print(
                        f"Point map Y min/max: "
                        f"{point_map_by_unprojection[0, :, :, 1].min()} / "
                        f"{point_map_by_unprojection[0, :, :, 1].max()}"
                    )
                    print(
                        f"Point map Z min/max: "
                        f"{point_map_by_unprojection[0, :, :, 2].min()} / "
                        f"{point_map_by_unprojection[0, :, :, 2].max()}"
                    )

                    print(
                        f"Depth map before scaling: min and max: "
                        f"{depth_map.min()} / {depth_map.max()}"
                    )

                    # Scale the depth map to the desired metric depth.
                    depth_map = depth_map * scale_factor
                    print(
                        f"Depth map after scaling, min and max in meters: "
                        f"{depth_map.min()} / {depth_map.max()}. "
                        f"Depth map shape: {depth_map.shape}"
                    )
                    depth_map = depth_map.to(torch.float64)

                    # Intrinsics of the last frame: focal lengths (fx, fy) and
                    # principal point (cx, cy) of the pinhole model.
                    intrinsic = intrinsic[-1][-1]
                    f_0 = intrinsic[0, 0]
                    f_1 = intrinsic[1, 1]
                    r_0 = intrinsic[0, 2]
                    r_1 = intrinsic[1, 2]
                    depth_map = depth_map[-1][-1].cpu().numpy()
                    # Note: the "depth" output below must be declared in the
                    # dataflow configuration for downstream nodes to receive it.
                    if DEPTH_ENCODING == "mono16":
                        # Encode depth in millimeters as 16-bit unsigned integers.
                        depth_map = (depth_map * 1000).astype(np.uint16)

                    node.send_output(
                        output_id=event["id"].replace("image", "depth"),
                        data=pa.array(depth_map.ravel()),
                        metadata={
                            "width": depth_map.shape[1],
                            "height": depth_map.shape[0],
                            "encoding": DEPTH_ENCODING,
                            "focal": [
                                int(f_0),
                                int(f_1),
                            ],
                            "resolution": [
                                int(r_0),
                                int(r_1),
                            ],
                        },
                    )
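                    # The focal and resolution metadata let consumers unproject
                    # the depth map with the standard pinhole model:
                    # X = (u - cx) * Z / fx, Y = (v - cy) * Z / fy.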

                    # Also republish the preprocessed RGB frame that the depth
                    # map was computed from.
                    image = images[-1][-1].cpu().numpy() * 255
                    image = image.astype(np.uint8)
                    # Reorder from (C, H, W) to (H, W, C).
                    image = image.transpose(1, 2, 0)

                    # Note: this "image" output must likewise be declared in the
                    # dataflow configuration.
                    node.send_output(
                        output_id=event["id"],
                        data=pa.array(image.ravel()),
                        metadata={
                            "encoding": "rgb8",
                            "width": image.shape[1],
                            "height": image.shape[0],
                        },
                    )


if __name__ == "__main__":
    main()
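
# A hypothetical downstream consumer could rebuild the depth image from the
# "depth" output roughly like this (a sketch; assumes DEPTH_ENCODING="mono16"
# and that `event` is the consumer's dora input event):
#
#   depth_mm = (
#       event["value"]
#       .to_numpy()
#       .reshape((event["metadata"]["height"], event["metadata"]["width"]))
#   )
#   depth_m = depth_mm.astype(np.float32) / 1000.0  # back to meters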