You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

step_trace_analyser.py 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """The StepTraceAnalyser analyser class."""
  16. import csv
  17. import json
  18. import os
  19. from mindinsight.datavisual.utils.tools import to_int
  20. from mindinsight.profiler.analyser.base_analyser import BaseAnalyser
  21. from mindinsight.profiler.common.exceptions.exceptions import ProfilerParamValueErrorException, \
  22. ProfilerFileNotFoundException, StepNumNotSupportedException, ProfilerRawFileException
  23. from mindinsight.profiler.common.log import logger as log
  24. from mindinsight.profiler.common.util import query_latest_trace_time_file, get_field_value, \
  25. get_summary_for_step_trace, to_millisecond
  26. from mindinsight.profiler.common.validator.validate_path import validate_and_normalize_path
class StepTraceAnalyser(BaseAnalyser):
    """The analyser for analyzing training steps."""
    # Display column names; copied into `_display_col_names` in `_load`.
    _col_names = []
    # Dict keys used for the UI time points built by `_construct_time_point`.
    _attr_ui_name = 'name'
    _attr_ui_start = 'start'
    _attr_ui_duration = 'duration'
    # Point info parsed from `step_trace_point_info.json` (see `_load_point_info`).
    _point_info = {}
  34. @property
  35. def summary(self):
  36. """The property of summary info."""
  37. summary = get_summary_for_step_trace(self._data[-1], self.__column__)
  38. summary['total_steps'] = self._size
  39. return summary
    @property
    def point_info(self):
        """The property of point info.

        Returns the dict loaded from `step_trace_point_info.json` by
        `_load_point_info`; stays empty when that file is absent.
        """
        return self._point_info
  44. def query(self, condition=None):
  45. """
  46. Query data according to the condition.
  47. Args:
  48. condition (dict): The search condition, only contains `filter_condition` parameter.
  49. Default: None.
  50. Returns:
  51. dict, the result after filtered, sorted and grouped.
  52. """
  53. if condition is None:
  54. condition = {}
  55. filter_condition = condition.get('filter_condition', {})
  56. log.info("Receive query request. %s", filter_condition)
  57. self._validate_filter_condition(filter_condition)
  58. self._result = {'size': self._size}
  59. self._filter(filter_condition)
  60. return self._result
  61. def query_for_all_reduce(self):
  62. """
  63. Query for all reduce info.
  64. Returns:
  65. list[dict], reduce information. Each item is the reduce info for one step.
  66. The reduce info is format like:
  67. {stream_id: List[Tuple(start_point, end_point, duration, field_name)]}.
  68. """
  69. reduce_infos = []
  70. for row_info in self._data[:-1]:
  71. row_info_dict = self._get_info_dict_from_row_data(row_info, 'systime')
  72. reduce_info = self._sort_reduce_by_time(row_info_dict)
  73. if reduce_info:
  74. reduce_infos.extend(reduce_info)
  75. return reduce_infos
  76. def _load(self):
  77. """Load data according to the parsed AICORE operator types file."""
  78. file_path = query_latest_trace_time_file(self._profiling_dir, self._device_id)
  79. if not file_path:
  80. log.error("Failed to find parsed trace time file.")
  81. raise ProfilerFileNotFoundException('parsed step trace time file')
  82. file_path = validate_and_normalize_path(
  83. file_path, raise_key="Invalid latest_trace_trace_time file path.")
  84. with open(file_path, 'r') as handle:
  85. csv_reader = csv.reader(handle)
  86. self.__column__ = next(csv_reader)
  87. self._data = list(csv_reader)
  88. self._size = len(self._data) - 1
  89. self._display_col_names = self._col_names[:]
  90. self._load_point_info()
  91. def _load_point_info(self):
  92. """Load point info."""
  93. file_path = os.path.join(self._profiling_dir, 'step_trace_point_info.json')
  94. file_path = validate_and_normalize_path(
  95. file_path, raise_key="Invalid step_trace_point_info file path.")
  96. if os.path.isfile(file_path):
  97. with open(file_path, 'r', encoding='utf-8') as file:
  98. try:
  99. self._point_info = json.load(file)
  100. except (json.JSONDecodeError, TypeError) as err:
  101. log.exception(err)
  102. raise ProfilerRawFileException('Fail to parse point info file.')
  103. def _filter(self, filter_condition):
  104. """
  105. Filter the profiling data according to the filter condition.
  106. Args:
  107. filter_condition (dict): The filter condition.
  108. - mode (str): The kind of information. `step` return the info about specific
  109. step. `proc` return the info about specific field in parsed trace file.
  110. - step_id (int): The selected step_id. If not given, it means all steps is required.
  111. If the value is 0, it means average info for all steps except the first is
  112. required.
  113. - proc_name (str): The selected field name.
  114. - time_type (str): The value type. `systime` keeps the original value.
  115. `realtime` transforms the value in millisecond. Default: `realtime`.
  116. """
  117. mode = filter_condition.get('mode', 'step')
  118. if mode == 'step':
  119. self._get_step_details(step_id=filter_condition.get('step_id'),
  120. time_type=filter_condition.get('time_type', 'realtime'))
  121. else:
  122. self._get_proc_details(step_id=filter_condition.get('step_id'),
  123. proc_name=filter_condition.get('proc_name'),
  124. time_type=filter_condition.get('time_type', 'realtime'))
  125. def _construct_time_point(self, name, start, duration):
  126. """Construct time point."""
  127. point = {}
  128. if start >= 0 and duration >= 0:
  129. point = {
  130. self._attr_ui_name: name,
  131. self._attr_ui_start: round(start, 4),
  132. self._attr_ui_duration: round(duration, 4)
  133. }
  134. else:
  135. log.warning("Not invalid point info: "
  136. "name: %s, start: %s, duration: %s", name, start, duration)
  137. return point
  138. def _get_step_details(self, step_id, time_type='realtime'):
  139. """
  140. Get step trace info for selected step and save the result.
  141. Args:
  142. step_id (int): The selected step_id. If the value is 0, it means average info
  143. for all steps except the first is required.
  144. time_type (str): The value type. `systime` keeps the original value.
  145. `realtime` transforms the value in millisecond. Default: `realtime`.
  146. """
  147. if step_id is None:
  148. step_id = 0
  149. row_info = self._data[step_id - 1]
  150. row_info_dict = self._get_info_dict_from_row_data(row_info, time_type)
  151. # first line only contains total time
  152. first_line = [self._construct_time_point('', 0, row_info_dict.get('total', 0))]
  153. # second line contains iteration_interval, fp_and_bp and tail
  154. second_line = self._get_main_proc_points(row_info_dict)
  155. # construct reduces lines
  156. reduce_lines = self._construct_reduce_lines(row_info_dict)
  157. graph = [first_line, second_line]
  158. graph.extend(reduce_lines)
  159. self._result['training_trace_graph'] = graph
  160. def _get_info_dict_from_row_data(self, row_info, time_type):
  161. """
  162. Get step info in dict format.
  163. Args:
  164. row_info (list[str]): Step info, the value is corresponding to `__column__`.
  165. time_type (str): The value type. `systime` keeps the original value.
  166. `realtime` transforms the value in millisecond. Default: `realtime`.
  167. Returns:
  168. dict, step trace information. The key is in `__column__`.
  169. """
  170. row_info_dict = {}
  171. for key, value in zip(self.__column__, row_info):
  172. if key == 'step_num':
  173. continue
  174. value = to_int(value, key)
  175. row_info_dict[key] = to_millisecond(value) if time_type == 'realtime' else value
  176. return row_info_dict
  177. def _get_main_proc_points(self, row_info_dict):
  178. """
  179. Get iteration_interval, fp_and_bp and tail points.
  180. Args:
  181. row_info_dict (dict): Step trace information.
  182. Returns:
  183. list[dict], the list of time points.
  184. """
  185. start_point = row_info_dict.get('start_point', 0)
  186. fp_point = row_info_dict.get('fp_point', 0)
  187. bp_point = row_info_dict.get('bp_point', 0)
  188. points_part = [
  189. self._construct_time_point(
  190. 'iteration_interval', 0, row_info_dict.get('iteration_interval', 0)),
  191. ]
  192. # if fp key exist, inference scene
  193. if 'fp' in row_info_dict.keys():
  194. points = [
  195. self._construct_time_point(
  196. 'fp', fp_point - start_point, row_info_dict.get('fp', 0)),
  197. ]
  198. # training scene
  199. else:
  200. points = [
  201. self._construct_time_point(
  202. 'fp_and_bp', fp_point - start_point, row_info_dict.get('fp_and_bp', 0)),
  203. self._construct_time_point('tail', bp_point - start_point, row_info_dict.get('tail', 0))
  204. ]
  205. points = points_part + points
  206. return points
  207. def _get_reduce_time_in_order(self, row_info_dict):
  208. """
  209. Get reduce time in order.
  210. Args:
  211. row_info_dict (dict): Step trace information.
  212. Returns:
  213. dict, sorted reduce information. The reduce info is format like:
  214. {stream_id: List[Tuple(start_point, end_point, duration, field_name)]}
  215. """
  216. reduce_info = {}
  217. reduce_fields = [field_name for field_name in self.__column__
  218. if field_name.startswith('stream_') and not field_name.endswith('point')]
  219. for reduce_field in reduce_fields:
  220. reduce_start = row_info_dict.get(reduce_field + '_start_point', 0)
  221. reduce_end = row_info_dict.get(reduce_field + '_end_point', 0)
  222. reduce_duration = row_info_dict.get(reduce_field, 0)
  223. if not (reduce_start and reduce_end and reduce_duration):
  224. log.info("Reduce event missing value.")
  225. continue
  226. cur_stream_id = reduce_field.split('_', 2)[1]
  227. cur_stream = reduce_info.get(cur_stream_id)
  228. if not cur_stream:
  229. cur_stream = []
  230. reduce_info[cur_stream_id] = cur_stream
  231. cur_stream.append((reduce_start, reduce_end, reduce_duration, reduce_field))
  232. for _, reduce_events in reduce_info.items():
  233. reduce_events.sort(key=lambda elem: elem[1])
  234. return reduce_info
  235. def _sort_reduce_by_time(self, row_info_dict):
  236. """
  237. Sort reduce info by time.
  238. Args:
  239. row_info_dict (dict): Step trace information.
  240. Returns:
  241. list, including the all reduce info sorted by start time only.
  242. [
  243. [reduce_field, stream_id, reduce_start, reduce_duration],
  244. [...],
  245. [...]
  246. ]
  247. """
  248. factor = 1e5 # convert time unit from 10ns to 1ms
  249. reduce_pid = 10000
  250. reduce_info = []
  251. reduce_fields = [field_name for field_name in self.__column__
  252. if field_name.startswith('stream_') and not field_name.endswith('point')]
  253. for reduce_field in reduce_fields:
  254. reduce_start = row_info_dict.get(reduce_field + '_start_point')
  255. reduce_start = reduce_start / factor \
  256. if reduce_start else 0
  257. reduce_duration = row_info_dict.get(reduce_field)
  258. reduce_duration = reduce_duration / factor if reduce_duration else 0
  259. if not (reduce_start and reduce_duration):
  260. log.info("Reduce event missing value.")
  261. continue
  262. cur_stream_id = reduce_field.split('_', 2)[1]
  263. reduce_meta = [reduce_field, int(cur_stream_id), reduce_start,
  264. reduce_duration, reduce_pid]
  265. reduce_info.append(reduce_meta)
  266. return reduce_info
  267. def _construct_reduce_lines(self, row_info_dict):
  268. """
  269. Construct first line in detailed graph.
  270. Args:
  271. row_info_dict (dict): Step trace information.
  272. Returns:
  273. list, list of reduce information of each stream. Each item is a list of time points.
  274. """
  275. reduce_lines = []
  276. start_point = row_info_dict.get('start_point', 0)
  277. fp_point = row_info_dict.get('fp_point', 0)
  278. end_point = row_info_dict.get('end_point', 0)
  279. reduce_info = self._get_reduce_time_in_order(row_info_dict)
  280. # construct time point for each line
  281. for _, reduce_events in reduce_info.items():
  282. current_line = self._construct_reduce_line(
  283. start_point, end_point, fp_point, reduce_events)
  284. reduce_lines.append(current_line)
  285. return reduce_lines
  286. def _construct_reduce_line(self, start_point, end_point, fp_point, reduce_events):
  287. """
  288. Construct list of time points for reduce line.
  289. Args:
  290. start_point (int): The start point of current step.
  291. end_point (int): The end point of current step.
  292. fp_point (int): The fp point of current step.
  293. reduce_events (list[Tuple]): The reduce information of current step. Each item
  294. contains the start, end duration and name of one reduce event.
  295. Returns:
  296. list[dict], list of time points.
  297. """
  298. current_line = []
  299. previous_start = fp_point
  300. for start, end, duration, field_name in reduce_events:
  301. current_line.extend([
  302. self._construct_time_point(
  303. '', previous_start - start_point, start - previous_start),
  304. self._construct_time_point(
  305. field_name, start - start_point, duration)
  306. ])
  307. previous_start = end
  308. current_line.append(self._construct_time_point(
  309. '', previous_start - start_point, end_point - previous_start))
  310. return current_line
  311. def _get_proc_details(self, proc_name, step_id=None, time_type='realtime'):
  312. """
  313. Get step trace info for selected step and save the result.
  314. Args:
  315. proc_name (str): The selected field name.
  316. step_id (int): The selected step_id. If not given, it means all steps is required.
  317. If the value is 0, it means average info for all steps except the first is
  318. required. Default: None.
  319. time_type (str): The value type. `systime` keeps the original value.
  320. `realtime` transforms the value in millisecond. Default: `realtime`.
  321. """
  322. if proc_name is None:
  323. log.error('`proc_name` is required for query.')
  324. raise ProfilerParamValueErrorException('`proc_name` is required for query.')
  325. if step_id is None:
  326. rows_info = self._data[:-1]
  327. else:
  328. rows_info = [self._data[step_id - 1]]
  329. proc_info = [get_field_value(row_info, proc_name, self.__column__, time_type)
  330. for row_info in rows_info]
  331. self._result['info'] = {proc_name: proc_info}
  332. def _validate_filter_condition(self, filter_condition):
  333. """Validate step trace filter_condition."""
  334. mode = filter_condition.get('mode', 'step')
  335. self._validate_str_param(mode, ['step', 'proc'], 'mode')
  336. step_id = filter_condition.get('step_id')
  337. self._validate_step_id(step_id)
  338. proc_name = filter_condition.get('proc_name')
  339. self._validate_str_param(proc_name, self.__column__, 'proc_name')
  340. time_type = filter_condition.get('time_type', 'realtime')
  341. self._validate_str_param(time_type, ['realtime', 'systime'], 'time_type')
  342. def _validate_step_id(self, step_id):
  343. """Validate step_id."""
  344. if step_id is None or isinstance(step_id, int) and 0 <= step_id <= self._size:
  345. return
  346. log.error("Invalid step_id in request. step_id should be in [0, %d].", self._size)
  347. raise StepNumNotSupportedException([0, self._size])
  348. @staticmethod
  349. def _validate_str_param(proc_name, accept_param, error_name=''):
  350. """Validate proc_name."""
  351. if proc_name is None or isinstance(proc_name, str) and proc_name in accept_param:
  352. return
  353. log.error("Invalid param %s in request. Acceptable value is %s.", error_name, accept_param)
  354. raise ProfilerParamValueErrorException(f"Invalid {error_name}.")