You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

main.py 7.3 kB

10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. """TODO: Add docstring."""
  2. import argparse
  3. import io
  4. import os
  5. import cv2
  6. import numpy as np
  7. import pyarrow as pa
  8. from dora import Node
  9. from PIL import (
  10. Image,
  11. )
  12. if True:
  13. import pillow_avif # noqa # noqa
  14. RUNNER_CI = True if os.getenv("CI") == "true" else False
  15. class Plot:
  16. """TODO: Add docstring."""
  17. frame: np.array = np.array([])
  18. bboxes: dict = {
  19. "bbox": np.array([]),
  20. "conf": np.array([]),
  21. "labels": np.array([]),
  22. }
  23. text: str = ""
  24. width: np.uint32 = None
  25. height: np.uint32 = None
  26. def plot_frame(plot):
  27. """TODO: Add docstring."""
  28. for bbox in zip(plot.bboxes["bbox"], plot.bboxes["conf"], plot.bboxes["labels"]):
  29. [
  30. [min_x, min_y, max_x, max_y],
  31. confidence,
  32. label,
  33. ] = bbox
  34. cv2.rectangle(
  35. plot.frame,
  36. (int(min_x), int(min_y)),
  37. (int(max_x), int(max_y)),
  38. (0, 255, 0),
  39. 2,
  40. )
  41. cv2.putText(
  42. plot.frame,
  43. f"{label}, {confidence:0.2f}",
  44. (int(max_x) - 120, int(max_y) - 10),
  45. cv2.FONT_HERSHEY_SIMPLEX,
  46. 0.5,
  47. (0, 255, 0),
  48. 1,
  49. 1,
  50. )
  51. cv2.putText(
  52. plot.frame,
  53. plot.text,
  54. (20, 20),
  55. cv2.FONT_HERSHEY_SIMPLEX,
  56. 0.5,
  57. (255, 255, 255),
  58. 1,
  59. 1,
  60. )
  61. if plot.width is not None and plot.height is not None:
  62. plot.frame = cv2.resize(plot.frame, (plot.width, plot.height))
  63. if not RUNNER_CI:
  64. if len(plot.frame.shape) >= 3:
  65. cv2.imshow("Dora Node: opencv-plot", plot.frame)
  66. def yuv420p_to_bgr_opencv(yuv_array, width, height):
  67. """TODO: Add docstring."""
  68. yuv_array = yuv_array[: width * height * 3 // 2]
  69. yuv = yuv_array.reshape((height * 3 // 2, width))
  70. return cv2.cvtColor(yuv, cv2.COLOR_YUV420p2RGB)
  71. def main():
  72. # Handle dynamic nodes, ask for the name of the node in the dataflow, and the same values as the ENV variables.
  73. """TODO: Add docstring."""
  74. parser = argparse.ArgumentParser(
  75. description="OpenCV Plotter: This node is used to plot text and bounding boxes on an image.",
  76. )
  77. parser.add_argument(
  78. "--name",
  79. type=str,
  80. required=False,
  81. help="The name of the node in the dataflow.",
  82. default="opencv-plot",
  83. )
  84. parser.add_argument(
  85. "--plot-width",
  86. type=int,
  87. required=False,
  88. help="The width of the plot.",
  89. default=None,
  90. )
  91. parser.add_argument(
  92. "--plot-height",
  93. type=int,
  94. required=False,
  95. help="The height of the plot.",
  96. default=None,
  97. )
  98. args = parser.parse_args()
  99. plot_width = os.getenv("PLOT_WIDTH", args.plot_width)
  100. plot_height = os.getenv("PLOT_HEIGHT", args.plot_height)
  101. if plot_width is not None:
  102. if isinstance(plot_width, str) and plot_width.isnumeric():
  103. plot_width = int(plot_width)
  104. if plot_height is not None:
  105. if isinstance(plot_height, str) and plot_height.isnumeric():
  106. plot_height = int(plot_height)
  107. node = Node(
  108. args.name,
  109. ) # provide the name to connect to the dataflow if dynamic node
  110. plot = Plot()
  111. plot.width = plot_width
  112. plot.height = plot_height
  113. pa.array([]) # initialize pyarrow array
  114. for event in node:
  115. event_type = event["type"]
  116. if event_type == "INPUT":
  117. event_id = event["id"]
  118. if event_id == "image":
  119. storage = event["value"]
  120. metadata = event["metadata"]
  121. encoding = metadata["encoding"]
  122. width = metadata["width"]
  123. height = metadata["height"]
  124. if encoding == "bgr8":
  125. channels = 3
  126. storage_type = np.uint8
  127. plot.frame = (
  128. storage.to_numpy()
  129. .astype(storage_type)
  130. .reshape((height, width, channels))
  131. .copy() # Copy So that we can add annotation on the image
  132. )
  133. elif encoding == "rgb8":
  134. channels = 3
  135. storage_type = np.uint8
  136. frame = (
  137. storage.to_numpy()
  138. .astype(storage_type)
  139. .reshape((height, width, channels))
  140. )
  141. plot.frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
  142. elif encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]:
  143. channels = 3
  144. storage_type = np.uint8
  145. storage = storage.to_numpy()
  146. plot.frame = cv2.imdecode(storage, cv2.IMREAD_COLOR)
  147. elif encoding == "yuv420":
  148. storage = storage.to_numpy()
  149. # Convert back to BGR results in more saturated image.
  150. channels = 3
  151. storage_type = np.uint8
  152. img_bgr_restored = yuv420p_to_bgr_opencv(storage, width, height)
  153. plot.frame = img_bgr_restored
  154. elif encoding == "avif":
  155. # Convert AVIF to RGB
  156. array = storage.to_numpy()
  157. bytes = array.tobytes()
  158. img = Image.open(io.BytesIO(bytes))
  159. img = img.convert("RGB")
  160. plot.frame = np.array(img)
  161. plot.frame = cv2.cvtColor(plot.frame, cv2.COLOR_RGB2BGR)
  162. else:
  163. raise RuntimeError(f"Unsupported image encoding: {encoding}")
  164. plot_frame(plot)
  165. if not RUNNER_CI:
  166. if cv2.waitKey(1) & 0xFF == ord("q"):
  167. break
  168. elif event_id == "bbox":
  169. arrow_bbox = event["value"][0]
  170. bbox_format = event["metadata"]["format"]
  171. if bbox_format == "xyxy":
  172. bbox = arrow_bbox["bbox"].values.to_numpy().reshape(-1, 4)
  173. elif bbox_format == "xywh":
  174. original_bbox = arrow_bbox["bbox"].values.to_numpy().reshape(-1, 4)
  175. bbox = np.array(
  176. [
  177. (
  178. x - w / 2,
  179. y - h / 2,
  180. x + w / 2,
  181. y + h / 2,
  182. )
  183. for [x, y, w, h] in original_bbox
  184. ],
  185. )
  186. else:
  187. raise RuntimeError(f"Unsupported bbox format: {bbox_format}")
  188. plot.bboxes = {
  189. "bbox": bbox,
  190. "conf": arrow_bbox["conf"].values.to_numpy(),
  191. "labels": arrow_bbox["labels"].values.to_numpy(
  192. zero_copy_only=False,
  193. ),
  194. }
  195. elif event_id == "text":
  196. plot.text = event["value"][0].as_py()
  197. elif event_type == "ERROR":
  198. raise RuntimeError(event["error"])
# Start the plotter only when this file is executed as a script.
if __name__ == "__main__":
    main()