You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

main.py 4.7 kB

9 months ago
9 months ago
9 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
"""Dora node that detects human pose landmarks with MediaPipe.

Image inputs are converted to RGB and run through MediaPipe Pose; the detected
landmarks are published as flattened pixel coordinates on ``points2d``. Once a
depth input has been received, the landmarks are also projected to 3D with the
latest depth frame and camera parameters and published on ``points3d``.
"""
  2. import cv2
  3. import mediapipe as mp
  4. import numpy as np
  5. import pyarrow as pa
  6. from dora import Node
  7. # Initialiser MediaPipe Pose
  8. mp_pose = mp.solutions.pose
  9. pose = mp_pose.Pose()
  10. mp_draw = mp.solutions.drawing_utils
  11. def get_3d_coordinates(landmark, depth_frame, w, h, resolution, focal_length):
  12. """Convert 2D landmark coordinates to 3D coordinates."""
  13. cx, cy = int(landmark.x * w), int(landmark.y * h)
  14. if 0 < cx < w and 0 < cy < h:
  15. depth = depth_frame[cy, cx] / 1_000.0
  16. if depth > 0:
  17. fx, fy = focal_length
  18. ppx, ppy = resolution
  19. x = (cy - ppy) * depth / fy
  20. y = (cx - ppx) * depth / fx
  21. # Convert to right-handed coordinate system
  22. return [x, -y, depth]
  23. return [0, 0, 0]
  24. def get_image(event: dict) -> np.ndarray:
  25. """Convert the image from the event to a numpy array.
  26. Args:
  27. event (dict): The event containing the image data.
  28. """
  29. storage = event["value"]
  30. metadata = event["metadata"]
  31. encoding = metadata["encoding"]
  32. width = metadata["width"]
  33. height = metadata["height"]
  34. if (
  35. encoding == "bgr8"
  36. or encoding == "rgb8"
  37. or encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]
  38. ):
  39. channels = 3
  40. storage_type = np.uint8
  41. else:
  42. raise RuntimeError(f"Unsupported image encoding: {encoding}")
  43. if encoding == "bgr8":
  44. frame = (
  45. storage.to_numpy().astype(storage_type).reshape((height, width, channels))
  46. )
  47. frame = frame[:, :, ::-1] # OpenCV image (BGR to RGB)
  48. elif encoding == "rgb8":
  49. frame = (
  50. storage.to_numpy().astype(storage_type).reshape((height, width, channels))
  51. )
  52. elif encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]:
  53. storage = storage.to_numpy()
  54. frame = cv2.imdecode(storage, cv2.IMREAD_COLOR)
  55. frame = frame[:, :, ::-1] # OpenCV image (BGR to RGB)
  56. else:
  57. raise RuntimeError(f"Unsupported image encoding: {encoding}")
  58. return frame
  59. def main():
  60. """TODO: Add docstring."""
  61. node = Node()
  62. depth = None
  63. focal_length = None
  64. resolution = None
  65. for event in node:
  66. if event["type"] == "INPUT":
  67. event_id = event["id"]
  68. if "image" in event_id:
  69. rgb_image = get_image(event)
  70. width = rgb_image.shape[1]
  71. height = rgb_image.shape[0]
  72. pose_results = pose.process(rgb_image)
  73. if pose_results.pose_landmarks:
  74. values = pose_results.pose_landmarks.landmark
  75. values = np.array(
  76. [
  77. [landmark.x * width, landmark.y * height]
  78. for landmark in pose_results.pose_landmarks.landmark
  79. ]
  80. )
  81. # Warning: Make sure to add my_output_id and my_input_id within the dataflow.
  82. node.send_output(
  83. output_id="points2d",
  84. data=pa.array(values.ravel()),
  85. metadata={},
  86. )
  87. if depth is not None:
  88. values = np.array(
  89. [
  90. get_3d_coordinates(
  91. landmark,
  92. depth,
  93. width,
  94. height,
  95. resolution,
  96. focal_length,
  97. )
  98. for landmark in pose_results.pose_landmarks.landmark
  99. ]
  100. )
  101. # Warning: Make sure to add my_output_id and my_input_id within the dataflow.
  102. node.send_output(
  103. output_id="points3d",
  104. data=pa.array(values.ravel()),
  105. metadata={},
  106. )
  107. else:
  108. print("No pose landmarks detected.")
  109. elif "depth" in event_id:
  110. metadata = event["metadata"]
  111. _encoding = metadata["encoding"]
  112. width = metadata["width"]
  113. height = metadata["height"]
  114. focal_length = metadata["focal_length"]
  115. resolution = metadata["resolution"]
  116. depth = event["value"].to_numpy().reshape((height, width))
  117. if __name__ == "__main__":
  118. main()