From ac8f679850ffa011feff034a6ef5514c451e8ee2 Mon Sep 17 00:00:00 2001 From: haixuantao Date: Fri, 13 Jun 2025 18:06:32 +0200 Subject: [PATCH] Add intrinsic parameter into the demo --- node-hub/dora-vggt/dora_vggt/main.py | 29 +++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/node-hub/dora-vggt/dora_vggt/main.py b/node-hub/dora-vggt/dora_vggt/main.py index 32aad2ae..7c0e24c7 100644 --- a/node-hub/dora-vggt/dora_vggt/main.py +++ b/node-hub/dora-vggt/dora_vggt/main.py @@ -14,11 +14,12 @@ from vggt.utils.load_fn import load_and_preprocess_images from vggt.utils.pose_enc import pose_encoding_to_extri_intri # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+) -dtype = torch.float16 + +dtype = torch.bfloat16 # Initialize the model and load the pretrained weights. # This will automatically download the model weights the first time it's run, which may take a while. -model = VGGT.from_pretrained("facebook/VGGT-1B") +model = VGGT.from_pretrained("facebook/VGGT-1B").to("cuda") model.eval() # Import vecdeque @@ -27,17 +28,10 @@ model.eval() def main(): """TODO: Add docstring.""" node = Node() - raw_images = Deque(maxlen=5) + raw_images = Deque(maxlen=2) for event in node: if event["type"] == "INPUT": - if event["id"] == "TICK": - print( - f"""Node received: - id: {event["id"]}, - value: {event["value"]}, - metadata: {event["metadata"]}""", - ) if "image" in event["id"]: storage = event["value"] @@ -86,7 +80,7 @@ def main(): raw_images.append(buffer) with torch.no_grad(): - images = load_and_preprocess_images(raw_images) + images = load_and_preprocess_images(raw_images).to("cuda") images = images[None] # add batch dimension aggregated_tokens_list, ps_idx = model.aggregator(images) @@ -96,6 +90,11 @@ def main(): extrinsic, intrinsic = pose_encoding_to_extri_intri( pose_enc, images.shape[-2:] ) + intrinsic = intrinsic[-1][-1] + f_0 = intrinsic[0, 0] + f_1 = intrinsic[1, 1] + r_0 = intrinsic[0, 2] + r_1 = intrinsic[1, 2] # Predict Depth Maps depth_map, depth_conf = model.depth_head( @@ -114,6 +113,14 @@ def main(): metadata={ "width": depth_map.shape[1], "height": depth_map.shape[0], + "focal": [ + int(f_0), + int(f_1), + ], + "resolution": [ + int(r_0), + int(r_1), + ], }, )