| @@ -14,11 +14,12 @@ from vggt.utils.load_fn import load_and_preprocess_images | |||
| from vggt.utils.pose_enc import pose_encoding_to_extri_intri | |||
| # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+) | |||
| dtype = torch.float16 | |||
| dtype = torch.bfloat16 | |||
| # Initialize the model and load the pretrained weights. | |||
| # This will automatically download the model weights the first time it's run, which may take a while. | |||
| model = VGGT.from_pretrained("facebook/VGGT-1B") | |||
| model = VGGT.from_pretrained("facebook/VGGT-1B").to("cuda") | |||
| model.eval() | |||
| # Import vecdeque | |||
| @@ -27,17 +28,10 @@ model.eval() | |||
| def main(): | |||
| """TODO: Add docstring.""" | |||
| node = Node() | |||
| raw_images = Deque(maxlen=5) | |||
| raw_images = Deque(maxlen=2) | |||
| for event in node: | |||
| if event["type"] == "INPUT": | |||
| if event["id"] == "TICK": | |||
| print( | |||
| f"""Node received: | |||
| id: {event["id"]}, | |||
| value: {event["value"]}, | |||
| metadata: {event["metadata"]}""", | |||
| ) | |||
| if "image" in event["id"]: | |||
| storage = event["value"] | |||
| @@ -86,7 +80,7 @@ def main(): | |||
| raw_images.append(buffer) | |||
| with torch.no_grad(): | |||
| images = load_and_preprocess_images(raw_images) | |||
| images = load_and_preprocess_images(raw_images).to("cuda") | |||
| images = images[None] # add batch dimension | |||
| aggregated_tokens_list, ps_idx = model.aggregator(images) | |||
| @@ -96,6 +90,11 @@ def main(): | |||
| extrinsic, intrinsic = pose_encoding_to_extri_intri( | |||
| pose_enc, images.shape[-2:] | |||
| ) | |||
| intrinsic = intrinsic[-1][-1] | |||
| f_0 = intrinsic[0, 0] | |||
| f_1 = intrinsic[1, 1] | |||
| r_0 = intrinsic[0, 2] | |||
| r_1 = intrinsic[1, 2] | |||
| # Predict Depth Maps | |||
| depth_map, depth_conf = model.depth_head( | |||
| @@ -114,6 +113,14 @@ def main(): | |||
| metadata={ | |||
| "width": depth_map.shape[1], | |||
| "height": depth_map.shape[0], | |||
| "focal": [ | |||
| int(f_0), | |||
| int(f_1), | |||
| ], | |||
| "resolution": [ | |||
| int(r_0), | |||
| int(r_1), | |||
| ], | |||
| }, | |||
| ) | |||