You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

benchmark_script.py 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
"""Benchmark a dora node: send image, audio and text inputs, time each reply, and append summary statistics to a CSV file."""
  2. import argparse
  3. import ast
  4. # Create an empty csv file with header in the current directory if file does not exist
  5. import csv
  6. import os
  7. import time
  8. from io import BytesIO
  9. import cv2
  10. import librosa
  11. import numpy as np
  12. import pyarrow as pa
  13. import requests
  14. from dora import Node
  15. from PIL import Image
  16. CAT_URL = "https://i.ytimg.com/vi/fzzjgBAaWZw/hqdefault.jpg"
  17. def get_cat_image():
  18. """
  19. Get a cat image as a numpy array.
  20. :return: Cat image as a numpy array.
  21. """
  22. # Fetch the image from the URL
  23. response = requests.get(CAT_URL)
  24. response.raise_for_status()
  25. # Open the image using PIL
  26. image = Image.open(BytesIO(response.content))
  27. # Convert the image to a numpy array
  28. image_array = np.array(image)
  29. cv2.resize(image_array, (640, 480))
  30. # Convert RGB to BGR for
  31. return image_array
  32. AUDIO_URL = "https://github.com/dora-rs/dora-rs.github.io/raw/refs/heads/main/static/Voicy_C3PO%20-Don't%20follow%20me.mp3"
  33. def get_c3po_audio():
  34. """
  35. Download the C-3PO audio and load it into a NumPy array using librosa.
  36. """
  37. # Download the audio file
  38. response = requests.get(AUDIO_URL)
  39. if response.status_code != 200:
  40. raise Exception(
  41. f"Failed to download audio file. Status code: {response.status_code}"
  42. )
  43. # Save the audio file temporarily
  44. temp_audio_file = "temp_audio.mp3"
  45. with open(temp_audio_file, "wb") as f:
  46. f.write(response.content)
  47. # Load the audio file into a NumPy array using librosa
  48. audio_data, sample_rate = librosa.load(temp_audio_file, sr=None)
  49. # Optionally, you can remove the temporary file after loading
  50. os.remove(temp_audio_file)
  51. return audio_data, sample_rate
  52. def write_to_csv(filename, header, row):
  53. """
  54. Create a CSV file with a header if it does not exist, and write a row to it.
  55. If the file exists, append the row to the file.
  56. :param filename: Name of the CSV file.
  57. :param header: List of column names to use as the header.
  58. :param row: List of data to write as a row in the CSV file.
  59. """
  60. file_exists = os.path.exists(filename)
  61. with open(
  62. filename, mode="a" if file_exists else "w", newline="", encoding="utf8"
  63. ) as file:
  64. writer = csv.writer(file)
  65. # Write the header if the file is being created
  66. if not file_exists:
  67. writer.writerow(header)
  68. print(f"File '{filename}' created with header: {header}")
  69. # Write the row
  70. writer.writerow(row)
  71. print(f"Row written to '{filename}': {row}")
  72. def main():
  73. # Handle dynamic nodes, ask for the name of the node in the dataflow, and the same values as the ENV variables.
  74. """TODO: Add docstring."""
  75. parser = argparse.ArgumentParser(description="Simple arrow sender")
  76. parser.add_argument(
  77. "--name",
  78. type=str,
  79. required=False,
  80. help="The name of the node in the dataflow.",
  81. default="pyarrow-sender",
  82. )
  83. parser.add_argument(
  84. "--text",
  85. type=str,
  86. required=False,
  87. help="Arrow Data as string.",
  88. default=None,
  89. )
  90. args = parser.parse_args()
  91. text = os.getenv("TEXT", args.text)
  92. text_truth = os.getenv("TEXT_TRUTH", args.text)
  93. cat = get_cat_image()
  94. audio, sample_rate = get_c3po_audio()
  95. if text is None:
  96. raise ValueError(
  97. "No data provided. Please specify `TEXT` environment argument or as `--text` argument",
  98. )
  99. try:
  100. text = ast.literal_eval(text)
  101. except Exception: # noqa
  102. print("Passing input as string")
  103. if isinstance(text, (str, int, float)):
  104. text = pa.array([text])
  105. else:
  106. text = pa.array(text) # initialize pyarrow array
  107. node = Node(
  108. args.name,
  109. ) # provide the name to connect to the dataflow if dynamic node
  110. name = node.dataflow_descriptor()["nodes"][1]["path"]
  111. durations = []
  112. speed = []
  113. for _ in range(10):
  114. node.send_output(
  115. "image",
  116. pa.array(cat.ravel()),
  117. {"encoding": "rgb8", "width": cat.shape[1], "height": cat.shape[0]},
  118. )
  119. node.send_output(
  120. "audio",
  121. pa.array(audio.ravel()),
  122. {"sample_rate": sample_rate},
  123. )
  124. time.sleep(0.1)
  125. start_time = time.time()
  126. node.send_output("text", text)
  127. event = node.next()
  128. duration = time.time() - start_time
  129. if event is not None and event["type"] == "INPUT":
  130. received_text = event["value"][0].as_py()
  131. tokens = event["metadata"].get("tokens", 6)
  132. assert text_truth in received_text, (
  133. f"Expected '{text_truth}', got {received_text}"
  134. )
  135. durations.append(duration)
  136. speed.append(tokens / duration)
  137. time.sleep(0.1)
  138. durations = np.array(durations)
  139. speed = np.array(speed)
  140. print(
  141. f"\nAverage duration: {sum(durations) / len(durations)}"
  142. + f"\nMax duration: {max(durations)}"
  143. + f"\nMin duration: {min(durations)}"
  144. + f"\nMedian duration: {np.median(durations)}"
  145. + f"\nMedian frequency: {1 / np.median(durations)}"
  146. + f"\nAverage speed: {sum(speed) / len(speed)}"
  147. + f"\nMax speed: {max(speed)}"
  148. + f"\nMin speed: {min(speed)}"
  149. + f"\nMedian speed: {np.median(speed)}"
  150. + f"\nTotal tokens: {tokens}"
  151. )
  152. write_to_csv(
  153. "benchmark.csv",
  154. [
  155. "path",
  156. "date",
  157. "average_duration(s)",
  158. "max_duration(s)",
  159. "min_duration(s)",
  160. "median_duration(s)",
  161. "median_frequency(Hz)",
  162. "average_speed(tok/s)",
  163. "max_speed(tok/s)",
  164. "min_speed(tok/s)",
  165. "median_speed(tok/s)",
  166. "total_tokens",
  167. ],
  168. [
  169. name,
  170. time.strftime("%Y-%m-%d %H:%M:%S"),
  171. sum(durations) / len(durations),
  172. max(durations),
  173. min(durations),
  174. np.median(durations),
  175. 1 / np.median(durations),
  176. sum(speed) / len(speed),
  177. max(speed),
  178. min(speed),
  179. np.median(speed),
  180. tokens,
  181. ],
  182. )
  183. if __name__ == "__main__":
  184. main()