|
|
@@ -1,9 +1,14 @@ |
|
|
|
|
|
import os |
|
|
|
|
|
from collections import deque |
|
|
|
|
|
|
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
import numpy as np |
|
|
import pyarrow as pa |
|
|
import pyarrow as pa |
|
|
from dora import Node |
|
|
|
|
|
import cv2 |
|
|
|
|
|
import torch |
|
|
import torch |
|
|
from collections import deque |
|
|
|
|
|
|
|
|
from dora import Node |
|
|
|
|
|
|
|
|
|
|
|
# Opt-in switch read once at import time from the INTERACTIVE_MODE environment
# variable. Only the value "true" (case-insensitive, surrounding whitespace
# ignored) enables it; unset or any other value leaves it disabled.
INTERACTIVE_MODE = os.getenv("INTERACTIVE_MODE", "false").strip().lower() == "true"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VideoTrackingNode: |
|
|
class VideoTrackingNode: |
|
|
def __init__(self): |
|
|
def __init__(self): |
|
|
@@ -12,10 +17,12 @@ class VideoTrackingNode: |
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
self.model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_online") |
|
|
self.model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_online") |
|
|
self.model = self.model.to(self.device) |
|
|
self.model = self.model.to(self.device) |
|
|
|
|
|
self.model.eval() |
|
|
self.model.step = 8 |
|
|
self.model.step = 8 |
|
|
self.buffer_size = self.model.step * 2 |
|
|
|
|
|
|
|
|
self.buffer_size = self.model.step * 2 |
|
|
self.window_frames = deque(maxlen=self.buffer_size) |
|
|
self.window_frames = deque(maxlen=self.buffer_size) |
|
|
self.is_first_step = True |
|
|
self.is_first_step = True |
|
|
|
|
|
self.accept_new_points = True |
|
|
self.clicked_points = [] |
|
|
self.clicked_points = [] |
|
|
self.input_points = [] |
|
|
self.input_points = [] |
|
|
|
|
|
|
|
|
@@ -29,14 +36,12 @@ class VideoTrackingNode: |
|
|
"""Process frame for tracking""" |
|
|
"""Process frame for tracking""" |
|
|
if len(self.window_frames) == self.buffer_size: |
|
|
if len(self.window_frames) == self.buffer_size: |
|
|
all_points = self.input_points + self.clicked_points |
|
|
all_points = self.input_points + self.clicked_points |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not all_points: |
|
|
if not all_points: |
|
|
print("No points to track") |
|
|
print("No points to track") |
|
|
return None, None |
|
|
return None, None |
|
|
|
|
|
|
|
|
video_chunk = torch.tensor( |
|
|
video_chunk = torch.tensor( |
|
|
np.stack(list(self.window_frames)), |
|
|
|
|
|
device=self.device |
|
|
|
|
|
|
|
|
np.stack(list(self.window_frames)), device=self.device |
|
|
).float() |
|
|
).float() |
|
|
video_chunk = video_chunk / 255.0 |
|
|
video_chunk = video_chunk / 255.0 |
|
|
# Reshape to [B,T,C,H,W] |
|
|
# Reshape to [B,T,C,H,W] |
|
|
@@ -50,11 +55,12 @@ class VideoTrackingNode: |
|
|
is_first_step=self.is_first_step, |
|
|
is_first_step=self.is_first_step, |
|
|
grid_size=0, |
|
|
grid_size=0, |
|
|
queries=queries, |
|
|
queries=queries, |
|
|
add_support_grid=False |
|
|
|
|
|
|
|
|
add_support_grid=False, |
|
|
) |
|
|
) |
|
|
self.is_first_step = False |
|
|
self.is_first_step = False |
|
|
|
|
|
|
|
|
if pred_tracks is not None and pred_visibility is not None: |
|
|
if pred_tracks is not None and pred_visibility is not None: |
|
|
|
|
|
self.accept_new_points = True |
|
|
tracks = pred_tracks[0, -1].cpu().numpy() |
|
|
tracks = pred_tracks[0, -1].cpu().numpy() |
|
|
visibility = pred_visibility[0, -1].cpu().numpy() |
|
|
visibility = pred_visibility[0, -1].cpu().numpy() |
|
|
visible_tracks = [] |
|
|
visible_tracks = [] |
|
|
@@ -66,84 +72,131 @@ class VideoTrackingNode: |
|
|
frame_viz = frame.copy() |
|
|
frame_viz = frame.copy() |
|
|
num_input_stream = len(self.input_points) |
|
|
num_input_stream = len(self.input_points) |
|
|
# Draw input points in red |
|
|
# Draw input points in red |
|
|
for i, (pt, vis) in enumerate(zip(tracks[:num_input_stream], visibility[:num_input_stream])): |
|
|
|
|
|
|
|
|
for i, (pt, vis) in enumerate( |
|
|
|
|
|
zip(tracks[:num_input_stream], visibility[:num_input_stream]) |
|
|
|
|
|
): |
|
|
if vis > 0.5: |
|
|
if vis > 0.5: |
|
|
x, y = int(pt[0]), int(pt[1]) |
|
|
x, y = int(pt[0]), int(pt[1]) |
|
|
cv2.circle(frame_viz, (x, y), radius=3, |
|
|
|
|
|
color=(0, 255, 0), thickness=-1) |
|
|
|
|
|
cv2.putText(frame_viz, f"I{i}", (x + 5, y - 5), |
|
|
|
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cv2.circle( |
|
|
|
|
|
frame_viz, (x, y), radius=3, color=(0, 255, 0), thickness=-1 |
|
|
|
|
|
) |
|
|
|
|
|
cv2.putText( |
|
|
|
|
|
frame_viz, |
|
|
|
|
|
f"I{i}", |
|
|
|
|
|
(x + 5, y - 5), |
|
|
|
|
|
cv2.FONT_HERSHEY_SIMPLEX, |
|
|
|
|
|
0.5, |
|
|
|
|
|
(0, 255, 0), |
|
|
|
|
|
1, |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
# Draw clicked points in red |
|
|
# Draw clicked points in red |
|
|
for i, (pt, vis) in enumerate(zip(tracks[num_input_stream:], visibility[num_input_stream:])): |
|
|
|
|
|
|
|
|
for i, (pt, vis) in enumerate( |
|
|
|
|
|
zip(tracks[num_input_stream:], visibility[num_input_stream:]) |
|
|
|
|
|
): |
|
|
if vis > 0.5: |
|
|
if vis > 0.5: |
|
|
x, y = int(pt[0]), int(pt[1]) |
|
|
x, y = int(pt[0]), int(pt[1]) |
|
|
cv2.circle(frame_viz, (x, y), radius=3, |
|
|
|
|
|
color=(0, 0, 255), thickness=-1) |
|
|
|
|
|
cv2.putText(frame_viz, f"C{i}", (x + 5, y - 5), |
|
|
|
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cv2.circle( |
|
|
|
|
|
frame_viz, (x, y), radius=3, color=(0, 0, 255), thickness=-1 |
|
|
|
|
|
) |
|
|
|
|
|
cv2.putText( |
|
|
|
|
|
frame_viz, |
|
|
|
|
|
f"C{i}", |
|
|
|
|
|
(x + 5, y - 5), |
|
|
|
|
|
cv2.FONT_HERSHEY_SIMPLEX, |
|
|
|
|
|
0.5, |
|
|
|
|
|
(0, 0, 255), |
|
|
|
|
|
1, |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
# Send tracked points |
|
|
# Send tracked points |
|
|
if len(visible_tracks) > 0: |
|
|
if len(visible_tracks) > 0: |
|
|
self.node.send_output( |
|
|
self.node.send_output( |
|
|
"tracked_points", |
|
|
|
|
|
|
|
|
"points", |
|
|
pa.array(visible_tracks.ravel()), |
|
|
pa.array(visible_tracks.ravel()), |
|
|
{ |
|
|
{ |
|
|
"num_points": len(visible_tracks), |
|
|
"num_points": len(visible_tracks), |
|
|
"dtype": "float32", |
|
|
"dtype": "float32", |
|
|
"shape": (len(visible_tracks), 2) |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
"shape": (len(visible_tracks), 2), |
|
|
|
|
|
}, |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return frame, frame_viz |
|
|
return frame, frame_viz |
|
|
|
|
|
|
|
|
return None, None |
|
|
return None, None |
|
|
|
|
|
|
|
|
def run(self):
    """Main run loop.

    Iterates over the events produced by ``self.node`` and handles ``INPUT``
    events according to their id:

    * ``image``: the payload is reshaped to ``(height, width, 3)`` using the
      event metadata, appended to ``self.window_frames`` and handed to
      ``process_tracking``. When that returns frames, the annotated frame is
      sent as ``tracked_image`` with the incoming metadata.
    * ``points``: the flat payload is read as ``(x, y)`` pairs that replace
      ``self.input_points``.
    * ``boxes2d``: the payload is read as rows of four values, unpacked as
      ``x_min, y_min, x_max, y_max``; each row is replaced by its integer
      centre and the result replaces ``self.input_points``.

    ``points`` and ``boxes2d`` events are ignored while
    ``self.accept_new_points`` is false. When ``INTERACTIVE_MODE`` is set,
    every image is also shown in an OpenCV window wired to
    ``self.mouse_callback``; the window is destroyed when the loop ends.
    """
    window_name = "Interactive Feed to track point"
    if INTERACTIVE_MODE:
        cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
        cv2.setMouseCallback(window_name, self.mouse_callback)

    try:
        for event in self.node:
            if event["type"] != "INPUT":
                continue
            event_id = event["id"]

            if event_id == "image":
                metadata = event["metadata"]
                frame = (
                    event["value"]
                    .to_numpy()
                    .reshape((metadata["height"], metadata["width"], 3))
                )
                # Add frame to tracking window
                self.window_frames.append(frame)
                original_frame, tracked_frame = self.process_tracking(frame)
                if original_frame is not None and tracked_frame is not None:
                    self.node.send_output(
                        "tracked_image", pa.array(tracked_frame.ravel()), metadata
                    )

                if INTERACTIVE_MODE:
                    display_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                    cv2.imshow(window_name, display_frame)
                    # waitKey lets HighGUI process window and mouse events.
                    cv2.waitKey(1)

            elif event_id in ("points", "boxes2d"):
                if not self.accept_new_points:
                    continue

                value = event["value"]
                if event_id == "points":
                    new_points = value.to_numpy().reshape((-1, 2)).tolist()
                else:
                    if isinstance(value, pa.StructArray):
                        # NOTE(review): pyarrow documents StructArray.field();
                        # confirm `.get` exists on the array type dora delivers.
                        # Only the box geometry is used here; other struct
                        # fields are ignored.
                        boxes2d = (
                            value.get("bbox").values.to_numpy().reshape((-1, 4))
                        )
                    else:
                        boxes2d = value.to_numpy().reshape((-1, 4))
                    new_points = [
                        [int((x_min + x_max) / 2), int((y_min + y_max) / 2)]
                        for x_min, y_min, x_max, y_max in boxes2d
                    ]

                self.input_points = new_points
                self.is_first_step = True
                # process_tracking re-enables accept_new_points only after it
                # has produced a prediction, and it returns early when there
                # are no points. Locking on an empty update could therefore
                # block all later updates, so lock only when points arrived.
                self.accept_new_points = not new_points
    finally:
        if INTERACTIVE_MODE:
            cv2.destroyAllWindows()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Create a VideoTrackingNode and run its event loop until it returns."""
    tracker = VideoTrackingNode()
    tracker.run()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Start the node only when this module is executed as a script.
if __name__ == "__main__":
    main()