"""TODO: Add docstring.""" import argparse import io import os import cv2 import numpy as np import pyarrow as pa from dora import Node from PIL import ( Image, ) if True: import pillow_avif # noqa # noqa RUNNER_CI = True if os.getenv("CI") == "true" else False class Plot: """TODO: Add docstring.""" frame: np.array = np.array([]) bboxes: dict = { "bbox": np.array([]), "conf": np.array([]), "labels": np.array([]), } text: str = "" width: np.uint32 = None height: np.uint32 = None def plot_frame(plot): """TODO: Add docstring.""" for bbox in zip(plot.bboxes["bbox"], plot.bboxes["conf"], plot.bboxes["labels"]): [ [min_x, min_y, max_x, max_y], confidence, label, ] = bbox cv2.rectangle( plot.frame, (int(min_x), int(min_y)), (int(max_x), int(max_y)), (0, 255, 0), 2, ) cv2.putText( plot.frame, f"{label}, {confidence:0.2f}", (int(max_x) - 120, int(max_y) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1, 1, ) cv2.putText( plot.frame, plot.text, (20, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, 1, ) if plot.width is not None and plot.height is not None: plot.frame = cv2.resize(plot.frame, (plot.width, plot.height)) if not RUNNER_CI: if len(plot.frame.shape) >= 3: cv2.imshow("Dora Node: opencv-plot", plot.frame) def yuv420p_to_bgr_opencv(yuv_array, width, height): """TODO: Add docstring.""" yuv_array = yuv_array[: width * height * 3 // 2] yuv = yuv_array.reshape((height * 3 // 2, width)) return cv2.cvtColor(yuv, cv2.COLOR_YUV420p2RGB) def main(): # Handle dynamic nodes, ask for the name of the node in the dataflow, and the same values as the ENV variables. """TODO: Add docstring.""" parser = argparse.ArgumentParser( description="OpenCV Plotter: This node is used to plot text and bounding boxes on an image.", ) parser.add_argument( "--name", type=str, required=False, help="The name of the node in the dataflow.", default="opencv-plot", ) parser.add_argument( "--plot-width", type=int, required=False, help="The width of the plot.", default=None, ) parser.add_argument( "--plot-height", type=int, required=False, help="The height of the plot.", default=None, ) args = parser.parse_args() plot_width = os.getenv("PLOT_WIDTH", args.plot_width) plot_height = os.getenv("PLOT_HEIGHT", args.plot_height) if plot_width is not None: if isinstance(plot_width, str) and plot_width.isnumeric(): plot_width = int(plot_width) if plot_height is not None: if isinstance(plot_height, str) and plot_height.isnumeric(): plot_height = int(plot_height) node = Node( args.name, ) # provide the name to connect to the dataflow if dynamic node plot = Plot() plot.width = plot_width plot.height = plot_height pa.array([]) # initialize pyarrow array for event in node: event_type = event["type"] if event_type == "INPUT": event_id = event["id"] if event_id == "image": storage = event["value"] metadata = event["metadata"] encoding = metadata["encoding"].lower() width = metadata["width"] height = metadata["height"] if encoding == "bgr8": channels = 3 storage_type = np.uint8 plot.frame = ( storage.to_numpy() .astype(storage_type) .reshape((height, width, channels)) .copy() # Copy So that we can add annotation on the image ) elif encoding == "rgb8": channels = 3 storage_type = np.uint8 frame = ( storage.to_numpy() .astype(storage_type) .reshape((height, width, channels)) ) plot.frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) elif encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]: channels = 3 storage_type = np.uint8 storage = storage.to_numpy() plot.frame = cv2.imdecode(storage, cv2.IMREAD_COLOR) elif encoding == "yuv420": storage = storage.to_numpy() # Convert back to BGR results in more saturated image. channels = 3 storage_type = np.uint8 img_bgr_restored = yuv420p_to_bgr_opencv(storage, width, height) plot.frame = img_bgr_restored elif encoding == "avif": # Convert AVIF to RGB array = storage.to_numpy() bytes = array.tobytes() img = Image.open(io.BytesIO(bytes)) img = img.convert("RGB") plot.frame = np.array(img) plot.frame = cv2.cvtColor(plot.frame, cv2.COLOR_RGB2BGR) else: raise RuntimeError(f"Unsupported image encoding: {encoding}") plot_frame(plot) if not RUNNER_CI: if cv2.waitKey(1) & 0xFF == ord("q"): break elif event_id == "bbox": arrow_bbox = event["value"][0] bbox_format = event["metadata"]["format"].lower() if bbox_format == "xyxy": bbox = arrow_bbox["bbox"].values.to_numpy().reshape(-1, 4) elif bbox_format == "xywh": original_bbox = arrow_bbox["bbox"].values.to_numpy().reshape(-1, 4) bbox = np.array( [ ( x - w / 2, y - h / 2, x + w / 2, y + h / 2, ) for [x, y, w, h] in original_bbox ], ) else: raise RuntimeError(f"Unsupported bbox format: {bbox_format}") plot.bboxes = { "bbox": bbox, "conf": arrow_bbox["conf"].values.to_numpy(), "labels": arrow_bbox["labels"].values.to_numpy( zero_copy_only=False, ), } elif event_id == "text": plot.text = event["value"][0].as_py() elif event_type == "ERROR": raise RuntimeError(event["error"]) if __name__ == "__main__": main()