You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

main.py 2.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import os
  2. import argparse
  3. import numpy as np
  4. import pyarrow as pa
  5. from dora import Node
  6. from ultralytics import YOLO
  7. def main():
  8. # Handle dynamic nodes, ask for the name of the node in the dataflow, and the same values as the ENV variables.
  9. parser = argparse.ArgumentParser(
  10. description="UltraLytics YOLO: This node is used to perform object detection using the UltraLytics YOLO model."
  11. )
  12. parser.add_argument(
  13. "--name",
  14. type=str,
  15. required=False,
  16. help="The name of the node in the dataflow.",
  17. default="ultralytics-yolo",
  18. )
  19. parser.add_argument(
  20. "--model",
  21. type=str,
  22. required=False,
  23. help="The name of the model file (e.g. yolov8n.pt).",
  24. default="yolov8n.pt",
  25. )
  26. args = parser.parse_args()
  27. model_path = os.getenv("MODEL", args.model)
  28. model = YOLO(model_path)
  29. node = Node(args.name)
  30. pa.array([]) # initialize pyarrow array
  31. for event in node:
  32. event_type = event["type"]
  33. if event_type == "INPUT":
  34. event_id = event["id"]
  35. if event_id == "image":
  36. arrow_image = event["value"][0]
  37. encoding = arrow_image["encoding"].as_py()
  38. if encoding == "bgr8":
  39. channels = 3
  40. storage_type = np.uint8
  41. else:
  42. raise Exception(f"Unsupported image encoding: {encoding}")
  43. image = {
  44. "width": np.uint32(arrow_image["width"].as_py()),
  45. "height": np.uint32(arrow_image["height"].as_py()),
  46. "encoding": encoding,
  47. "channels": channels,
  48. "data": arrow_image["data"].values.to_numpy().astype(storage_type),
  49. }
  50. frame = image["data"].reshape(
  51. (image["height"], image["width"], image["channels"])
  52. )
  53. if encoding == "bgr8":
  54. frame = frame[:, :, ::-1] # OpenCV image (BGR to RGB)
  55. results = model(frame, verbose=False) # includes NMS
  56. bboxes = np.array(results[0].boxes.xyxy.cpu())
  57. conf = np.array(results[0].boxes.conf.cpu())
  58. labels = np.array(results[0].boxes.cls.cpu())
  59. names = [model.names.get(label) for label in labels]
  60. bbox = {
  61. "bbox": bboxes.ravel(),
  62. "conf": conf,
  63. "names": names,
  64. }
  65. node.send_output(
  66. "bbox",
  67. pa.array([bbox]),
  68. event["metadata"],
  69. )
  70. elif event_type == "ERROR":
  71. raise Exception(event["error"])
  72. if __name__ == "__main__":
  73. main()