You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

minddata_pipeline_parser.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
"""The parser for parsing minddata pipeline files."""
  16. import csv
  17. import json
  18. import os
  19. from queue import Queue
  20. from mindspore.profiler.common.exceptions.exceptions import \
  21. ProfilerPathErrorException, ProfilerFileNotFoundException, \
  22. ProfilerDirNotFoundException, ProfilerRawFileException
  23. from mindspore import log as logger
  24. from mindspore.profiler.common.validator.validate_path import \
  25. validate_and_normalize_path
class MinddataPipelineParser:
    """
    The parser for parsing minddata pipeline files.

    Args:
        source_dir (str): The minddata pipeline source dir.
        device_id (str): The device ID.
        output_path (str): The directory of the parsed file. Default: `./`.

    Raises:
        ProfilerPathErrorException: If the minddata pipeline file path or
            the output path is invalid.
        ProfilerFileNotFoundException: If the minddata pipeline file or
            the output dir does not exist.
    """
    # File-name template of the raw JSON profiling file, keyed by device id.
    _raw_pipeline_file_name = 'pipeline_profiling_{}.json'
    # File-name template of the parsed CSV this parser produces, keyed by device id.
    _parsed_pipeline_file_name = 'minddata_pipeline_raw_{}.csv'
    # Header row written as the first line of the output CSV.
    _col_names = [
        'op_id', 'op_type', 'num_workers', 'output_queue_size',
        'output_queue_average_size', 'output_queue_length',
        'output_queue_usage_rate', 'sample_interval', 'parent_id', 'children_id'
    ]
  46. def __init__(self, source_dir, device_id, output_path='./'):
  47. self._device_id = device_id
  48. self._pipeline_path = self._get_pipeline_path(source_dir)
  49. self._save_path = self._get_save_path(output_path)
  50. @property
  51. def save_path(self):
  52. """
  53. The property of save path.
  54. Returns:
  55. str, the save path.
  56. """
  57. return self._save_path
  58. def parse(self):
  59. """
  60. Parse the minddata pipeline files.
  61. Raises:
  62. ProfilerRawFileException: If fails to parse the raw file of
  63. minddata pipeline or the file is empty.
  64. """
  65. with open(self._pipeline_path, 'r') as file:
  66. try:
  67. pipeline_info = json.load(file)
  68. except (json.JSONDecodeError, TypeError) as err:
  69. logger.warning(err)
  70. raise ProfilerRawFileException(
  71. 'Fail to parse minddata pipeline file.'
  72. )
  73. if not pipeline_info:
  74. logger.warning('The minddata pipeline file is empty.')
  75. raise ProfilerRawFileException(
  76. 'The minddata pipeline file is empty.'
  77. )
  78. self._parse_and_save(pipeline_info)
  79. def _get_pipeline_path(self, source_dir):
  80. """
  81. Get the minddata pipeline file path.
  82. Args:
  83. source_dir (str): The minddata pipeline source dir.
  84. Returns:
  85. str, the minddata pipeline file path.
  86. """
  87. pipeline_path = os.path.join(
  88. source_dir,
  89. self._raw_pipeline_file_name.format(self._device_id)
  90. )
  91. try:
  92. pipeline_path = validate_and_normalize_path(pipeline_path)
  93. except RuntimeError:
  94. logger.warning('Minddata pipeline file is invalid.')
  95. raise ProfilerPathErrorException('Minddata pipeline file is invalid.')
  96. if not os.path.isfile(pipeline_path):
  97. logger.warning(
  98. 'The minddata pipeline file <%s> not found.', pipeline_path
  99. )
  100. raise ProfilerFileNotFoundException(pipeline_path)
  101. return pipeline_path
  102. def _get_save_path(self, output_path):
  103. """
  104. Get the save path.
  105. Args:
  106. output_path (str): The output dir.
  107. Returns:
  108. str, the save path.
  109. """
  110. try:
  111. output_dir = validate_and_normalize_path(output_path)
  112. except ValidationError:
  113. logger.warning('Output path is invalid.')
  114. raise ProfilerPathErrorException('Output path is invalid.')
  115. if not os.path.isdir(output_dir):
  116. logger.warning('The output dir <%s> not found.', output_dir)
  117. raise ProfilerDirNotFoundException(output_dir)
  118. return os.path.join(
  119. output_dir, self._parsed_pipeline_file_name.format(self._device_id)
  120. )
  121. def _parse_and_save(self, pipeline_info):
  122. """
  123. Parse and save the parsed minddata pipeline file.
  124. Args:
  125. pipeline_info (dict): The pipeline info reads from the raw file of
  126. the minddata pipeline.
  127. Raises:
  128. ProfilerRawFileException: If the format of minddata pipeline raw
  129. file is wrong.
  130. """
  131. sample_interval = pipeline_info.get('sampling_interval')
  132. op_info = pipeline_info.get('op_info')
  133. if sample_interval is None or not op_info:
  134. raise ProfilerRawFileException(
  135. 'The format of minddata pipeline raw file is wrong.'
  136. )
  137. op_id_info_cache = {}
  138. for item in op_info:
  139. op_id_info_cache[item.get('op_id')] = item
  140. with open(self._save_path, 'w') as save_file:
  141. csv_writer = csv.writer(save_file)
  142. csv_writer.writerow(self._col_names)
  143. self._parse_and_save_op_info(
  144. csv_writer, op_id_info_cache, sample_interval
  145. )
  146. def _parse_and_save_op_info(self, csv_writer, op_id_info_cache,
  147. sample_interval):
  148. """
  149. Parse and save the minddata pipeline operator information.
  150. Args:
  151. csv_writer (csv.writer): The csv writer.
  152. op_id_info_cache (dict): The operator id and information cache.
  153. sample_interval (int): The sample interval.
  154. Raises:
  155. ProfilerRawFileException: If the operator that id is 0 does not exist.
  156. """
  157. queue = Queue()
  158. root_node = op_id_info_cache.get(0)
  159. if not root_node:
  160. raise ProfilerRawFileException(
  161. 'The format of minddata pipeline raw file is wrong, '
  162. 'the operator that id is 0 does not exist.'
  163. )
  164. root_node['parent_id'] = None
  165. queue.put_nowait(root_node)
  166. while not queue.empty():
  167. node = queue.get_nowait()
  168. self._update_child_node(node, op_id_info_cache)
  169. csv_writer.writerow(self._get_op_info(node, sample_interval))
  170. op_id = node.get('op_id')
  171. children_ids = node.get('children')
  172. if not children_ids:
  173. continue
  174. for child_op_id in children_ids:
  175. sub_node = op_id_info_cache.get(child_op_id)
  176. sub_node['parent_id'] = op_id
  177. queue.put_nowait(sub_node)
  178. def _update_child_node(self, node, op_id_info_cache):
  179. """
  180. Updates the child node information of the operator.
  181. Args:
  182. node (dict): The node represents an operator.
  183. op_id_info_cache (dict): The operator id and information cache.
  184. """
  185. child_op_ids = node.get('children')
  186. if not child_op_ids:
  187. return
  188. queue = Queue()
  189. self._cp_list_item_to_queue(child_op_ids, queue)
  190. new_child_op_ids = []
  191. while not queue.empty():
  192. child_op_id = queue.get_nowait()
  193. child_node = op_id_info_cache.get(child_op_id)
  194. if child_node is None:
  195. continue
  196. metrics = child_node.get('metrics')
  197. if not metrics or not metrics.get('output_queue'):
  198. op_ids = child_node.get('children')
  199. if op_ids:
  200. self._cp_list_item_to_queue(op_ids, queue)
  201. else:
  202. new_child_op_ids.append(child_op_id)
  203. node['children'] = new_child_op_ids
  204. def _get_op_info(self, op_node, sample_interval):
  205. """
  206. Get the operator information.
  207. Args:
  208. op_node (dict): The node represents an operator.
  209. sample_interval (int): The sample interval.
  210. Returns:
  211. list[str, int, float], the operator information.
  212. """
  213. queue_size = None
  214. queue_average_size = None
  215. queue_length = None
  216. queue_usage_rate = None
  217. metrics = op_node.get('metrics')
  218. if metrics:
  219. output_queue = metrics.get('output_queue')
  220. if output_queue:
  221. queue_size = output_queue.get('size')
  222. queue_average_size = sum(queue_size) / len(queue_size)
  223. queue_length = output_queue.get('length')
  224. queue_usage_rate = queue_average_size / queue_length
  225. children_id = op_node.get('children')
  226. op_info = [
  227. op_node.get('op_id'),
  228. op_node.get('op_type'),
  229. op_node.get('num_workers'),
  230. queue_size,
  231. queue_average_size,
  232. queue_length,
  233. queue_usage_rate,
  234. sample_interval,
  235. op_node.get('parent_id'),
  236. children_id if children_id else None
  237. ]
  238. return op_info
  239. def _cp_list_item_to_queue(self, inner_list, queue):
  240. """
  241. Copy the contents of a list to a queue.
  242. Args:
  243. inner_list (list): The list.
  244. queue (Queue): The target queue.
  245. """
  246. for item in inner_list:
  247. queue.put_nowait(item)