You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

framework_parser.py 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """The parser for parsing framework files."""
  16. import csv
  17. import enum
  18. import json
  19. import os
  20. import re
  21. import stat
  22. from mindspore.profiler.common.exceptions.exceptions import \
  23. ProfilerPathErrorException, ProfilerDirNotFoundException, \
  24. ProfilerFileNotFoundException, ProfilerDeviceIdMismatchException, \
  25. ProfilerRawFileException, ProfilerParamValueErrorException
  26. from mindspore.profiler.common.validator.validate_path import \
  27. validate_and_normalize_path
  28. class VmDataType(enum.IntEnum):
  29. """Definition of vm data type."""
  30. NUMBER_TYPE_BEGIN = 30
  31. NUMBER_TYPE_BOOL = 31
  32. NUMBER_TYPE_INT = 32
  33. NUMBER_TYPE_INT8 = 33
  34. NUMBER_TYPE_INT16 = 34
  35. NUMBER_TYPE_INT32 = 35
  36. NUMBER_TYPE_INT64 = 36
  37. NUMBER_TYPE_UINT = 37
  38. NUMBER_TYPE_UINT8 = 38
  39. NUMBER_TYPE_UINT16 = 39
  40. NUMBER_TYPE_UINT32 = 40
  41. NUMBER_TYPE_UINT64 = 41
  42. NUMBER_TYPE_FLOAT = 42
  43. NUMBER_TYPE_FLOAT16 = 43
  44. NUMBER_TYPE_FLOAT32 = 44
  45. NUMBER_TYPE_FLOAT64 = 45
  46. NUMBER_TYPE_COMPLEX = 46
  47. NUMBER_TYPE_END = 47
  48. @classmethod
  49. def get_data_type_name(cls, num):
  50. """
  51. Get the name of data type by enum number.
  52. Args:
  53. num (int): Enum number.
  54. Returns:
  55. str, the name of data type.
  56. """
  57. data_type = cls._value2member_map_.get(num)
  58. return 'UNKNOWN' if data_type is None else data_type.name
  59. class GeDataType(enum.IntEnum):
  60. """Definition of ge data type."""
  61. DT_FLOAT = 0
  62. DT_FLOAT16 = 1
  63. DT_INT8 = 2
  64. DT_INT16 = 6
  65. DT_UINT16 = 7
  66. DT_UINT8 = 4
  67. DT_INT32 = 3
  68. DT_INT64 = 9
  69. DT_UINT32 = 8
  70. DT_UINT64 = 10
  71. DT_BOOL = 12
  72. DT_DOUBLE = 11
  73. DT_STRING = 13
  74. DT_DUAL_SUB_INT8 = 14
  75. DT_DUAL_SUB_UINT8 = 15
  76. DT_COMPLEX64 = 16
  77. DT_COMPLEX128 = 17
  78. DT_QINT8 = 18
  79. DT_QINT16 = 19
  80. DT_QINT32 = 20
  81. DT_QUINT8 = 21
  82. DT_QUINT16 = 22
  83. DT_RESOURCE = 23
  84. DT_STRING_REF = 24
  85. DT_DUAL = 25
  86. DT_UNDEFINED = 26
  87. @classmethod
  88. def get_data_type_name(cls, num):
  89. """
  90. Get the name of data type by enum number.
  91. Args:
  92. num (int): Enum number.
  93. Returns:
  94. str, the name of data type.
  95. """
  96. data_type = cls._value2member_map_.get(num)
  97. return 'UNKNOWN' if data_type is None else data_type.name
  98. class GeFormat(enum.IntEnum):
  99. """Definition of ge format type."""
  100. FORMAT_NCHW = 0
  101. FORMAT_NHWC = 1
  102. FORMAT_ND = 2
  103. FORMAT_NC1HWC0 = 3
  104. FORMAT_FRACTAL_Z = 4
  105. FORMAT_NC1C0HWPAD = 5
  106. FORMAT_NHWC1C0 = 6
  107. FORMAT_FSR_NCHW = 7
  108. FORMAT_FRACTAL_DECONV = 8
  109. FORMAT_C1HWNC0 = 9
  110. FORMAT_FRACTAL_DECONV_TRANSPOSE = 10
  111. FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11
  112. FORMAT_NC1HWC0_C04 = 12
  113. FORMAT_FRACTAL_Z_C04 = 13
  114. FORMAT_CHWN = 14
  115. FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15
  116. FORMAT_HWCN = 16
  117. FORMAT_NC1KHKWHWC0 = 17
  118. FORMAT_BN_WEIGHT = 18
  119. FORMAT_FILTER_HWCK = 19
  120. FORMAT_HASHTABLE_LOOKUP_LOOKUPS = 20
  121. FORMAT_HASHTABLE_LOOKUP_KEYS = 21
  122. FORMAT_HASHTABLE_LOOKUP_VALUE = 22
  123. FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23
  124. FORMAT_HASHTABLE_LOOKUP_HITS = 24
  125. FORMAT_C1HWNCOC0 = 25
  126. FORMAT_MD = 26
  127. FORMAT_NDHWC = 27
  128. FORMAT_FRACTAL_ZZ = 28
  129. FORMAT_FRACTAL_NZ = 29
  130. FORMAT_NCDHW = 30
  131. FORMAT_DHWCN = 31
  132. FORMAT_NDC1HWC0 = 32
  133. FORMAT_FRACTAL_Z_3D = 33
  134. FORMAT_CN = 34
  135. FORMAT_NC = 35
  136. FORMAT_DHWNC = 36
  137. FORMAT_FRACTAL_Z_3D_TRANSPOSE = 37
  138. FORMAT_RESERVED = 38
  139. FORMAT_ALL = 39
  140. @classmethod
  141. def get_format_name(cls, num):
  142. """
  143. Get the name of format type by enum number.
  144. Args:
  145. num (int): Enum number.
  146. Returns:
  147. str, the name of format type.
  148. """
  149. format_type = cls._value2member_map_.get(num)
  150. return 'UNKNOWN' if format_type is None else format_type.name
  151. class FrameworkParser:
  152. """
  153. Thr parser for parsing framework files.
  154. Args:
  155. profiling_id (str): The profiling ID.
  156. device_id (str): The device ID.
  157. rank_id (str): The rank ID.
  158. output_path (str): The directory of the parsed file. Default: `./`.
  159. """
  160. _regex_framework = r'Framework\.(?P<data_type>.+)\.(?P<device_id>\d).+'
  161. _regex_framework_in_data = r'Framework\.(?P<data_type>.+)\.' \
  162. r'(?P<device_id>\d)\.(?P<profiling_id>[a-zA-Z0-9]+).+'
  163. _col_names = [
  164. 'task_id', 'stream_id', 'block_dim', 'full_op_name', 'op_name',
  165. 'op_type', 'subgraph', 'op_info'
  166. ]
  167. _graph_attr_name = [
  168. 'input_format', 'input_data_type', 'input_shape', 'output_format',
  169. 'output_data_type', 'output_shape'
  170. ]
  171. # if the task id is less than the task id threshold, The combination of
  172. # task id and Stream id represents one operator, else the task id represents
  173. # one operator
  174. _task_id_threshold = 25000
  175. def __init__(self, profiling_id, device_id, rank_id, output_path='./'):
  176. self._raw_data_dir = output_path
  177. self._profiling_path = self._get_raw_profiling_path(profiling_id)
  178. self._backend_type = None
  179. self._framework_path = {'graph': [], 'task': [], 'point': []}
  180. self._search_file(profiling_id, device_id)
  181. self._device_id = device_id
  182. self._save_path = self._get_save_path(rank_id, output_path)
  183. self._task_id_full_op_name_dict = {}
  184. self._task_cache = {}
  185. self._point_info = {}
  186. self._parse_task_files()
  187. self._parse_point_files()
  188. @property
  189. def save_path(self):
  190. """
  191. The property of save path.
  192. Returns:
  193. str, the save path.
  194. """
  195. return self._save_path
  196. @property
  197. def point_info(self):
  198. """
  199. The property of the framework point information.
  200. Returns:
  201. dict, the framework point information.
  202. """
  203. return self._point_info
  204. def to_task_id_full_op_name_dict(self):
  205. """
  206. Get the task id and full operator name dict.
  207. Returns:
  208. dict, the task id and full operator name dict.
  209. """
  210. return self._task_id_full_op_name_dict
  211. def parse(self):
  212. """Parse the framework files."""
  213. self._parse_graph_files_and_save(self._task_cache)
  214. del self._task_cache
  215. def check_op_name(self, op_name, is_prefix=True):
  216. """
  217. Check whether the operator name exists.
  218. Args:
  219. op_name (str): The operator name or operator name prefix.
  220. is_prefix (bool): `True` if the op_name is prefix, else `False`.
  221. Default: True.
  222. Returns:
  223. bool, `True` if the operator name does exist in framework file, else
  224. `False`.
  225. """
  226. if not op_name:
  227. raise ProfilerParamValueErrorException('The op_name should exist.')
  228. for full_op_name in self._task_id_full_op_name_dict.values():
  229. if full_op_name:
  230. if is_prefix and full_op_name.startswith(op_name):
  231. return True
  232. if not is_prefix and op_name == full_op_name:
  233. return True
  234. return False
  235. def _get_raw_profiling_path(self, profiling_id):
  236. """
  237. Get raw profiling path.
  238. Args:
  239. profiling_id (str): The profiling ID.
  240. Returns:
  241. str, the raw profiling path.
  242. Raises:
  243. ProfilerPathErrorException: If the profiling path is invalid.
  244. ProfilerDirNotFoundException: If the profiling dir is not found.
  245. """
  246. profiling_path = os.path.join(self._raw_data_dir, profiling_id)
  247. try:
  248. profiling_path = validate_and_normalize_path(profiling_path)
  249. except RuntimeError:
  250. raise ProfilerPathErrorException('Profiling path is invalid.')
  251. if not os.path.isdir(profiling_path):
  252. raise ProfilerDirNotFoundException(profiling_path)
  253. return profiling_path
  254. def _search_file(self, profiling_id, device_id):
  255. """
  256. Search all framework files in raw profiling path.
  257. Args:
  258. profiling_id (str): The profiling ID.
  259. device_id (str): The device ID.
  260. Raises:
  261. ProfilerFileNotFoundException: If the framework files are not found.
  262. """
  263. # first search in the JOB dir, and if not, search in the sub directory
  264. # in the JOB
  265. self._search_file_from_job_path(device_id, search_in_sub_path=False)
  266. if self._backend_type is None:
  267. self._search_file_from_job_path(device_id, search_in_sub_path=True)
  268. self._search_file_from_data_path(profiling_id, device_id)
  269. if self._backend_type is None:
  270. raise ProfilerFileNotFoundException('Framework')
  271. self._framework_path['graph'].sort()
  272. self._framework_path['task'].sort()
  273. def _search_file_from_job_path(self, device_id, search_in_sub_path=False):
  274. """
  275. Search framework files from job path.
  276. Args:
  277. device_id (str): The device ID.
  278. search_in_sub_path (bool): `True` if search file in profiling dir,
  279. else search in profiling sub dir. Default: False.
  280. Raises:
  281. ProfilerRawFileException: If the framework file type is inconsistent.
  282. ProfilerDeviceIdMismatchException: If the device id is mismatch
  283. with framework in the raw dir.
  284. """
  285. profiling_dir = os.path.join(self._profiling_path, 'data') \
  286. if search_in_sub_path else self._profiling_path
  287. if not os.path.isdir(profiling_dir):
  288. return
  289. files = os.listdir(profiling_dir)
  290. for file in files:
  291. pattern = re.search(self._regex_framework, file)
  292. if not pattern or file.endswith('.done'):
  293. continue
  294. attrs = pattern.groupdict()
  295. device_id_in_path = attrs.get('device_id')
  296. if device_id_in_path != device_id:
  297. raise ProfilerDeviceIdMismatchException()
  298. data_type = attrs.get('data_type')
  299. data_type = data_type.replace("host.", "")
  300. if data_type.startswith('vm_'):
  301. if self._backend_type and self._backend_type != 'vm':
  302. raise ProfilerRawFileException('Backend type is inconsistent.')
  303. self._backend_type = 'vm'
  304. _, data_type = data_type.split('_', 1)
  305. else:
  306. if self._backend_type and self._backend_type != 'ge':
  307. raise ProfilerRawFileException('Backend type is inconsistent.')
  308. self._backend_type = 'ge'
  309. if data_type.startswith('graph_desc_info'):
  310. self._framework_path['graph'].append(
  311. os.path.join(profiling_dir, file)
  312. )
  313. elif data_type.startswith('task_desc_info'):
  314. self._framework_path['task'].append(
  315. os.path.join(profiling_dir, file)
  316. )
  317. elif data_type.startswith('point'):
  318. self._framework_path['point'].append(
  319. os.path.join(profiling_dir, file)
  320. )
  321. def _search_file_from_data_path(self, profiling_id, device_id):
  322. """
  323. Search framework files from data path.
  324. Args:
  325. profiling_id (str): The profiling ID.
  326. device_id (str): The device ID.
  327. Raises:
  328. ProfilerRawFileException: If the framework file type is inconsistent.
  329. ProfilerDeviceIdMismatchException: If the device id is mismatch
  330. with framework in the raw dir.
  331. """
  332. profiling_data_path = os.path.join(
  333. self._raw_data_dir, 'container', device_id, 'data'
  334. )
  335. if not os.path.isdir(profiling_data_path):
  336. return
  337. files = os.listdir(profiling_data_path)
  338. for file in files:
  339. pattern = re.search(self._regex_framework_in_data, file)
  340. if not pattern or file.endswith('.done') or file.endswith('.zip'):
  341. continue
  342. attrs = pattern.groupdict()
  343. profiling_id_in_path = attrs.get('profiling_id')
  344. if profiling_id_in_path != profiling_id:
  345. continue
  346. device_id_in_path = attrs.get('device_id')
  347. if device_id_in_path != device_id:
  348. raise ProfilerDeviceIdMismatchException()
  349. data_type = attrs.get('data_type')
  350. data_type = data_type.replace("host.", "")
  351. if data_type.startswith('vm_'):
  352. if self._backend_type and self._backend_type != 'vm':
  353. raise ProfilerRawFileException('Backend type is inconsistent.')
  354. self._backend_type = 'vm'
  355. _, data_type = data_type.split('_', 1)
  356. else:
  357. if self._backend_type and self._backend_type != 'ge':
  358. raise ProfilerRawFileException('Backend type is inconsistent.')
  359. self._backend_type = 'ge'
  360. if data_type.startswith('graph_desc_info'):
  361. self._framework_path['graph'].append(
  362. os.path.join(profiling_data_path, file)
  363. )
  364. elif data_type.startswith('task_desc_info'):
  365. self._framework_path['task'].append(
  366. os.path.join(profiling_data_path, file)
  367. )
  368. elif data_type.startswith('point'):
  369. self._framework_path['point'].append(
  370. os.path.join(profiling_data_path, file)
  371. )
  372. def _get_save_path(self, rank_id, output_path):
  373. """
  374. Get the save path.
  375. Args:
  376. rank_id (str): The rank ID.
  377. output_path (str): The output dir.
  378. Returns:
  379. str, the save path.
  380. Raises:
  381. ProfilerPathErrorException: If the output path is invalid.
  382. ProfilerDirNotFoundException: If the output dir is not found.
  383. """
  384. try:
  385. output_dir = validate_and_normalize_path(output_path)
  386. except RuntimeError:
  387. raise ProfilerPathErrorException('Output path is invalid.')
  388. if not os.path.isdir(output_dir):
  389. raise ProfilerDirNotFoundException(output_dir)
  390. return os.path.join(
  391. output_dir, '_'.join(['framework', 'raw', rank_id]) + '.csv'
  392. )
  393. def _parse_task_files(self):
  394. """Parse the framework task files."""
  395. for path in self._framework_path['task']:
  396. path = validate_and_normalize_path(path)
  397. with open(path, 'r') as file:
  398. for task_info in file:
  399. infos = task_info.strip('\n').split(' ')
  400. infos = infos[1:] if len(infos) == 5 else infos
  401. # key is op name, values is task id, stream id, block_dim
  402. self._task_cache[infos[0]] = [infos[2], infos[3], infos[1]]
  403. # if the task id is less than the task id threshold, the
  404. # stream id and task id correspond to an operator
  405. task_id = infos[2]
  406. if int(task_id) < self._task_id_threshold:
  407. task_id = '_'.join([infos[3], task_id])
  408. self._task_id_full_op_name_dict[task_id] = infos[0]
  409. def _parse_graph_files_and_save(self, task_cache):
  410. """
  411. Parse the framework graph files and save the framework information.
  412. Args:
  413. task_cache (dict): The task information cache.
  414. """
  415. with open(self._save_path, 'w') as save_file:
  416. csv_writer = csv.writer(save_file)
  417. csv_writer.writerow(self._col_names)
  418. pre_graph_info = None
  419. for path in self._framework_path['graph']:
  420. first_row = True
  421. with open(path, 'r') as graph_file:
  422. for graph_info in graph_file:
  423. if first_row is True:
  424. first_row = False
  425. # The last row of the previous file and the first row of the current file may need
  426. # to be combined to one row
  427. if graph_info.startswith("op_name:") is False:
  428. pre_graph_info = pre_graph_info + graph_info
  429. continue
  430. if pre_graph_info is not None:
  431. self._parse_graph_row_and_save(task_cache, csv_writer, pre_graph_info)
  432. pre_graph_info = graph_info
  433. if pre_graph_info is not None:
  434. self._parse_graph_row_and_save(task_cache, csv_writer, pre_graph_info)
  435. none_list = [None, None, None, None]
  436. for key, value in task_cache.items():
  437. value.append(key)
  438. value.extend(none_list)
  439. csv_writer.writerow(value)
  440. os.chmod(self._save_path, stat.S_IREAD | stat.S_IWRITE)
  441. def _parse_graph_row_and_save(self, task_cache, csv_writer, graph_info):
  442. """
  443. Parse the framework graph row and save the framework information.
  444. Args:
  445. task_cache (dict): The task information cache.
  446. csv_writer (csv): Csv writer.
  447. graph_info (str): Row info of graph.
  448. """
  449. result = self._parse_one_row_graph_info(graph_info)
  450. task_info = task_cache.get(result[0])
  451. if task_info:
  452. task_info.extend(result)
  453. csv_writer.writerow(task_info)
  454. del task_cache[result[0]]
  455. else:
  456. save_info = [None, None, None]
  457. save_info.extend(result)
  458. csv_writer.writerow(save_info)
  459. def _parse_one_row_graph_info(self, row_info):
  460. """
  461. Parse the graph information in one row.
  462. Args:
  463. row_info (str): One row graph information.
  464. Returns:
  465. list[str], the parsed graph information.
  466. """
  467. full_op_name = None
  468. op_name = None
  469. subgraph_name = None
  470. op_type = None
  471. op_info = dict()
  472. cur_op_info_key = None
  473. infos = row_info.strip('\n').split(' ')
  474. for info in infos:
  475. attr_name, attr_value = info.split(':', 1)
  476. if attr_name == 'op_name':
  477. full_op_name = attr_value
  478. subgraph_name = self._get_subgraph_name(full_op_name)
  479. op_name = self._get_op_name(full_op_name, subgraph_name)
  480. elif attr_name == 'op_type':
  481. op_type = attr_value
  482. elif attr_name in ['input_id', 'output_id']:
  483. cur_op_info_key = '{}_{}'.format(
  484. attr_name.split('_')[0], attr_value
  485. )
  486. op_info[cur_op_info_key] = dict()
  487. elif attr_name in self._graph_attr_name:
  488. op_attr = attr_name.split('_', 1)[1]
  489. if op_attr == 'shape':
  490. attr_value = attr_value.strip('"')
  491. if self._backend_type == 'vm':
  492. if op_attr == 'data_type':
  493. attr_value = VmDataType.get_data_type_name(
  494. int(attr_value)
  495. )
  496. else:
  497. if op_attr == 'data_type':
  498. attr_value = GeDataType.get_data_type_name(
  499. int(attr_value)
  500. )
  501. elif op_attr == 'format':
  502. attr_value = GeFormat.get_format_name(int(attr_value))
  503. op_info[cur_op_info_key][op_attr] = attr_value
  504. # the list info are full_op_name, op_name, op_type, subgraph, op_info
  505. return [full_op_name, op_name, op_type, subgraph_name,
  506. json.dumps(op_info)]
  507. def _get_subgraph_name(self, full_op_name):
  508. """
  509. Get subgraph name.
  510. Args:
  511. full_op_name (str): The full operator name.
  512. Returns:
  513. str, the subgraph name.
  514. """
  515. subgraph_name = full_op_name.split('/', 1)[0]
  516. if subgraph_name in ['Default', 'Gradients']:
  517. return subgraph_name
  518. return None
  519. def _get_op_name(self, full_op_name, subgraph_name):
  520. """
  521. Get operator name.
  522. Args:
  523. full_op_name (str): The full operator name.
  524. subgraph_name (str): The subgraph name.
  525. Returns:
  526. str, the operator name.
  527. """
  528. if subgraph_name is None:
  529. return full_op_name
  530. if self._backend_type == 'vm':
  531. return full_op_name.split('/')[-1]
  532. strs = full_op_name.split(subgraph_name + '/')
  533. op_name = None
  534. for name_str in strs:
  535. if not name_str:
  536. continue
  537. if op_name is None:
  538. op_name = name_str.split('/')[-1]
  539. else:
  540. op_name = '+'.join([op_name, name_str.split('/')[-1]])
  541. return op_name
  542. def _parse_point_files(self):
  543. """Parse the framework point files."""
  544. for path in self._framework_path['point']:
  545. path = validate_and_normalize_path(path)
  546. with open(path, 'r') as file:
  547. for point_info in file:
  548. infos = point_info.strip('\n').split(' ')
  549. self._point_info[int(infos[0])] = infos[1]