You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

main.py 5.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. import argparse
  2. import os
  3. import cv2
  4. import numpy as np
  5. import pyarrow as pa
  6. from dora import Node
  7. RUNNER_CI = True if os.getenv("CI") == "true" else False
  8. class Plot:
  9. frame: np.array = np.array([])
  10. bboxes: dict = {
  11. "bbox": np.array([]),
  12. "conf": np.array([]),
  13. "labels": np.array([]),
  14. }
  15. text: str = ""
  16. width: np.uint32 = None
  17. height: np.uint32 = None
  18. def plot_frame(plot):
  19. for bbox in zip(plot.bboxes["bbox"], plot.bboxes["conf"], plot.bboxes["labels"]):
  20. [
  21. [min_x, min_y, max_x, max_y],
  22. confidence,
  23. label,
  24. ] = bbox
  25. cv2.rectangle(
  26. plot.frame,
  27. (int(min_x), int(min_y)),
  28. (int(max_x), int(max_y)),
  29. (0, 255, 0),
  30. 2,
  31. )
  32. cv2.putText(
  33. plot.frame,
  34. f"{label}, {confidence:0.2f}",
  35. (int(max_x) - 120, int(max_y) - 10),
  36. cv2.FONT_HERSHEY_SIMPLEX,
  37. 0.5,
  38. (0, 255, 0),
  39. 1,
  40. 1,
  41. )
  42. cv2.putText(
  43. plot.frame,
  44. plot.text,
  45. (20, 20),
  46. cv2.FONT_HERSHEY_SIMPLEX,
  47. 0.5,
  48. (255, 255, 255),
  49. 1,
  50. 1,
  51. )
  52. if plot.width is not None and plot.height is not None:
  53. plot.frame = cv2.resize(plot.frame, (plot.width, plot.height))
  54. if not RUNNER_CI:
  55. if len(plot.frame.shape) >= 3:
  56. cv2.imshow("Dora Node: opencv-plot", plot.frame)
  57. def main():
  58. # Handle dynamic nodes, ask for the name of the node in the dataflow, and the same values as the ENV variables.
  59. parser = argparse.ArgumentParser(
  60. description="OpenCV Plotter: This node is used to plot text and bounding boxes on an image."
  61. )
  62. parser.add_argument(
  63. "--name",
  64. type=str,
  65. required=False,
  66. help="The name of the node in the dataflow.",
  67. default="opencv-plot",
  68. )
  69. parser.add_argument(
  70. "--plot-width",
  71. type=int,
  72. required=False,
  73. help="The width of the plot.",
  74. default=None,
  75. )
  76. parser.add_argument(
  77. "--plot-height",
  78. type=int,
  79. required=False,
  80. help="The height of the plot.",
  81. default=None,
  82. )
  83. args = parser.parse_args()
  84. plot_width = os.getenv("PLOT_WIDTH", args.plot_width)
  85. plot_height = os.getenv("PLOT_HEIGHT", args.plot_height)
  86. if plot_width is not None:
  87. if isinstance(plot_width, str) and plot_width.isnumeric():
  88. plot_width = int(plot_width)
  89. if plot_height is not None:
  90. if isinstance(plot_height, str) and plot_height.isnumeric():
  91. plot_height = int(plot_height)
  92. node = Node(
  93. args.name
  94. ) # provide the name to connect to the dataflow if dynamic node
  95. plot = Plot()
  96. plot.width = plot_width
  97. plot.height = plot_height
  98. pa.array([]) # initialize pyarrow array
  99. for event in node:
  100. event_type = event["type"]
  101. if event_type == "INPUT":
  102. event_id = event["id"]
  103. if event_id == "image":
  104. storage = event["value"]
  105. metadata = event["metadata"]
  106. encoding = metadata["encoding"]
  107. width = metadata["width"]
  108. height = metadata["height"]
  109. if encoding == "bgr8":
  110. channels = 3
  111. storage_type = np.uint8
  112. plot.frame = (
  113. storage.to_numpy()
  114. .astype(storage_type)
  115. .reshape((height, width, channels))
  116. .copy() # Copy So that we can add annotation on the image
  117. )
  118. elif encoding == "rgb8":
  119. channels = 3
  120. storage_type = np.uint8
  121. frame = (
  122. storage.to_numpy()
  123. .astype(storage_type)
  124. .reshape((height, width, channels))
  125. )
  126. plot.frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
  127. else:
  128. raise RuntimeError(f"Unsupported image encoding: {encoding}")
  129. plot_frame(plot)
  130. if not RUNNER_CI:
  131. if cv2.waitKey(1) & 0xFF == ord("q"):
  132. break
  133. elif event_id == "bbox":
  134. arrow_bbox = event["value"][0]
  135. bbox_format = event["metadata"]["format"]
  136. if bbox_format == "xyxy":
  137. bbox = arrow_bbox["bbox"].values.to_numpy().reshape(-1, 4)
  138. elif bbox_format == "xywh":
  139. original_bbox = arrow_bbox["bbox"].values.to_numpy().reshape(-1, 4)
  140. bbox = np.array(
  141. [
  142. (
  143. x - w / 2,
  144. y - h / 2,
  145. x + w / 2,
  146. y + h / 2,
  147. )
  148. for [x, y, w, h] in original_bbox
  149. ]
  150. )
  151. else:
  152. raise RuntimeError(f"Unsupported bbox format: {bbox_format}")
  153. plot.bboxes = {
  154. "bbox": bbox,
  155. "conf": arrow_bbox["conf"].values.to_numpy(),
  156. "labels": arrow_bbox["labels"].values.to_numpy(
  157. zero_copy_only=False
  158. ),
  159. }
  160. elif event_id == "text":
  161. plot.text = event["value"][0].as_py()
  162. elif event_type == "ERROR":
  163. raise RuntimeError(event["error"])
  164. if __name__ == "__main__":
  165. main()