You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

main.py 6.8 kB

10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. """TODO: Add docstring."""
  2. import argparse
  3. import os
  4. import cv2
  5. import numpy as np
  6. import pyarrow as pa
  7. from dora import Node
  8. RUNNER_CI = True if os.getenv("CI") == "true" else False
  9. class Plot:
  10. """TODO: Add docstring."""
  11. frame: np.array = np.array([])
  12. bboxes: dict = {
  13. "bbox": np.array([]),
  14. "conf": np.array([]),
  15. "labels": np.array([]),
  16. }
  17. text: str = ""
  18. width: np.uint32 = None
  19. height: np.uint32 = None
  20. def plot_frame(plot):
  21. """TODO: Add docstring."""
  22. for bbox in zip(plot.bboxes["bbox"], plot.bboxes["conf"], plot.bboxes["labels"]):
  23. [
  24. [min_x, min_y, max_x, max_y],
  25. confidence,
  26. label,
  27. ] = bbox
  28. cv2.rectangle(
  29. plot.frame,
  30. (int(min_x), int(min_y)),
  31. (int(max_x), int(max_y)),
  32. (0, 255, 0),
  33. 2,
  34. )
  35. cv2.putText(
  36. plot.frame,
  37. f"{label}, {confidence:0.2f}",
  38. (int(max_x) - 120, int(max_y) - 10),
  39. cv2.FONT_HERSHEY_SIMPLEX,
  40. 0.5,
  41. (0, 255, 0),
  42. 1,
  43. 1,
  44. )
  45. cv2.putText(
  46. plot.frame,
  47. plot.text,
  48. (20, 20),
  49. cv2.FONT_HERSHEY_SIMPLEX,
  50. 0.5,
  51. (255, 255, 255),
  52. 1,
  53. 1,
  54. )
  55. if plot.width is not None and plot.height is not None:
  56. plot.frame = cv2.resize(plot.frame, (plot.width, plot.height))
  57. if not RUNNER_CI:
  58. if len(plot.frame.shape) >= 3:
  59. cv2.imshow("Dora Node: opencv-plot", plot.frame)
  60. def yuv420p_to_bgr_opencv(yuv_array, width, height):
  61. """TODO: Add docstring."""
  62. yuv = yuv_array.reshape((height * 3 // 2, width))
  63. return cv2.cvtColor(yuv, cv2.COLOR_YUV420p2RGB)
  64. def main():
  65. # Handle dynamic nodes, ask for the name of the node in the dataflow, and the same values as the ENV variables.
  66. """TODO: Add docstring."""
  67. parser = argparse.ArgumentParser(
  68. description="OpenCV Plotter: This node is used to plot text and bounding boxes on an image.",
  69. )
  70. parser.add_argument(
  71. "--name",
  72. type=str,
  73. required=False,
  74. help="The name of the node in the dataflow.",
  75. default="opencv-plot",
  76. )
  77. parser.add_argument(
  78. "--plot-width",
  79. type=int,
  80. required=False,
  81. help="The width of the plot.",
  82. default=None,
  83. )
  84. parser.add_argument(
  85. "--plot-height",
  86. type=int,
  87. required=False,
  88. help="The height of the plot.",
  89. default=None,
  90. )
  91. args = parser.parse_args()
  92. plot_width = os.getenv("PLOT_WIDTH", args.plot_width)
  93. plot_height = os.getenv("PLOT_HEIGHT", args.plot_height)
  94. if plot_width is not None:
  95. if isinstance(plot_width, str) and plot_width.isnumeric():
  96. plot_width = int(plot_width)
  97. if plot_height is not None:
  98. if isinstance(plot_height, str) and plot_height.isnumeric():
  99. plot_height = int(plot_height)
  100. node = Node(
  101. args.name,
  102. ) # provide the name to connect to the dataflow if dynamic node
  103. plot = Plot()
  104. plot.width = plot_width
  105. plot.height = plot_height
  106. pa.array([]) # initialize pyarrow array
  107. for event in node:
  108. event_type = event["type"]
  109. if event_type == "INPUT":
  110. event_id = event["id"]
  111. if event_id == "image":
  112. storage = event["value"]
  113. metadata = event["metadata"]
  114. encoding = metadata["encoding"]
  115. width = metadata["width"]
  116. height = metadata["height"]
  117. if encoding == "bgr8":
  118. channels = 3
  119. storage_type = np.uint8
  120. plot.frame = (
  121. storage.to_numpy()
  122. .astype(storage_type)
  123. .reshape((height, width, channels))
  124. .copy() # Copy So that we can add annotation on the image
  125. )
  126. elif encoding == "rgb8":
  127. channels = 3
  128. storage_type = np.uint8
  129. frame = (
  130. storage.to_numpy()
  131. .astype(storage_type)
  132. .reshape((height, width, channels))
  133. )
  134. plot.frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
  135. elif encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]:
  136. channels = 3
  137. storage_type = np.uint8
  138. storage = storage.to_numpy()
  139. plot.frame = cv2.imdecode(storage, cv2.IMREAD_COLOR)
  140. elif encoding == "yuv420":
  141. storage = storage.to_numpy()
  142. # Convert back to BGR results in more saturated image.
  143. channels = 3
  144. storage_type = np.uint8
  145. img_bgr_restored = yuv420p_to_bgr_opencv(storage, width, height)
  146. plot.frame = img_bgr_restored
  147. else:
  148. raise RuntimeError(f"Unsupported image encoding: {encoding}")
  149. plot_frame(plot)
  150. if not RUNNER_CI:
  151. if cv2.waitKey(1) & 0xFF == ord("q"):
  152. break
  153. elif event_id == "bbox":
  154. arrow_bbox = event["value"][0]
  155. bbox_format = event["metadata"]["format"]
  156. if bbox_format == "xyxy":
  157. bbox = arrow_bbox["bbox"].values.to_numpy().reshape(-1, 4)
  158. elif bbox_format == "xywh":
  159. original_bbox = arrow_bbox["bbox"].values.to_numpy().reshape(-1, 4)
  160. bbox = np.array(
  161. [
  162. (
  163. x - w / 2,
  164. y - h / 2,
  165. x + w / 2,
  166. y + h / 2,
  167. )
  168. for [x, y, w, h] in original_bbox
  169. ],
  170. )
  171. else:
  172. raise RuntimeError(f"Unsupported bbox format: {bbox_format}")
  173. plot.bboxes = {
  174. "bbox": bbox,
  175. "conf": arrow_bbox["conf"].values.to_numpy(),
  176. "labels": arrow_bbox["labels"].values.to_numpy(
  177. zero_copy_only=False,
  178. ),
  179. }
  180. elif event_id == "text":
  181. plot.text = event["value"][0].as_py()
  182. elif event_type == "ERROR":
  183. raise RuntimeError(event["error"])
if __name__ == "__main__":
    # Start the node only when this module is run as a script, not on import.
    main()