You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

framework_parser.py 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
"""The parser for parsing framework files."""
  16. import csv
  17. import enum
  18. import json
  19. import os
  20. import re
  21. from marshmallow import ValidationError
  22. from mindinsight.profiler.common.exceptions.exceptions import \
  23. ProfilerPathErrorException, ProfilerDirNotFoundException, \
  24. ProfilerFileNotFoundException, ProfilerDeviceIdMismatchException, \
  25. ProfilerRawFileException
  26. from mindinsight.profiler.common.validator.validate_path import \
  27. validate_and_normalize_path
  28. class VmDataType(enum.IntEnum):
  29. """Definition of vm data type."""
  30. NUMBER_TYPE_BEGIN = 26
  31. NUMBER_TYPE_BOOL = 27
  32. NUMBER_TYPE_INT = 28
  33. NUMBER_TYPE_INT8 = 29
  34. NUMBER_TYPE_INT16 = 30
  35. NUMBER_TYPE_INT32 = 31
  36. NUMBER_TYPE_INT64 = 32
  37. NUMBER_TYPE_UINT = 33
  38. NUMBER_TYPE_UINT8 = 34
  39. NUMBER_TYPE_UINT16 = 35
  40. NUMBER_TYPE_UINT32 = 36
  41. NUMBER_TYPE_UINT64 = 37
  42. NUMBER_TYPE_FLOAT = 38
  43. NUMBER_TYPE_FLOAT16 = 39
  44. NUMBER_TYPE_FLOAT32 = 40
  45. NUMBER_TYPE_FLOAT64 = 41
  46. NUMBER_TYPE_END = 42
  47. @classmethod
  48. def get_data_type_name(cls, num):
  49. """
  50. Get the name of data type by enum number.
  51. Args:
  52. num (int): Enum number.
  53. Returns:
  54. str, the name of data type.
  55. """
  56. data_type = cls._value2member_map_.get(num)
  57. return 'UNKNOWN' if data_type is None else data_type.name
  58. class GeDataType(enum.IntEnum):
  59. """Definition of ge data type."""
  60. DT_FLOAT = 0
  61. DT_FLOAT16 = 1
  62. DT_INT8 = 2
  63. DT_INT16 = 6
  64. DT_UINT16 = 7
  65. DT_UINT8 = 4
  66. DT_INT32 = 3
  67. DT_INT64 = 9
  68. DT_UINT32 = 8
  69. DT_UINT64 = 10
  70. DT_BOOL = 12
  71. DT_DOUBLE = 11
  72. DT_STRING = 13
  73. DT_DUAL_SUB_INT8 = 14
  74. DT_DUAL_SUB_UINT8 = 15
  75. DT_COMPLEX64 = 16
  76. DT_COMPLEX128 = 17
  77. DT_QINT8 = 18
  78. DT_QINT16 = 19
  79. DT_QINT32 = 20
  80. DT_QUINT8 = 21
  81. DT_QUINT16 = 22
  82. DT_RESOURCE = 23
  83. DT_STRING_REF = 24
  84. DT_DUAL = 25
  85. DT_UNDEFINED = 26
  86. @classmethod
  87. def get_data_type_name(cls, num):
  88. """
  89. Get the name of data type by enum number.
  90. Args:
  91. num (int): Enum number.
  92. Returns:
  93. str, the name of data type.
  94. """
  95. data_type = cls._value2member_map_.get(num)
  96. return 'UNKNOWN' if data_type is None else data_type.name
  97. class GeFormat(enum.IntEnum):
  98. """Definition of ge format type."""
  99. FORMAT_NCHW = 0
  100. FORMAT_NHWC = 1
  101. FORMAT_ND = 2
  102. FORMAT_NC1HWC0 = 3
  103. FORMAT_FRACTAL_Z = 4
  104. FORMAT_NC1C0HWPAD = 5
  105. FORMAT_NHWC1C0 = 6
  106. FORMAT_FSR_NCHW = 7
  107. FORMAT_FRACTAL_DECONV = 8
  108. FORMAT_C1HWNC0 = 9
  109. FORMAT_FRACTAL_DECONV_TRANSPOSE = 10
  110. FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11
  111. FORMAT_NC1HWC0_C04 = 12
  112. FORMAT_FRACTAL_Z_C04 = 13
  113. FORMAT_CHWN = 14
  114. FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15
  115. FORMAT_HWCN = 16
  116. FORMAT_NC1KHKWHWC0 = 17
  117. FORMAT_BN_WEIGHT = 18
  118. FORMAT_FILTER_HWCK = 19
  119. FORMAT_HASHTABLE_LOOKUP_LOOKUPS = 20
  120. FORMAT_HASHTABLE_LOOKUP_KEYS = 21
  121. FORMAT_HASHTABLE_LOOKUP_VALUE = 22
  122. FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23
  123. FORMAT_HASHTABLE_LOOKUP_HITS = 24
  124. FORMAT_C1HWNCOC0 = 25
  125. FORMAT_MD = 26
  126. FORMAT_NDHWC = 27
  127. FORMAT_FRACTAL_ZZ = 28
  128. FORMAT_FRACTAL_NZ = 29
  129. FORMAT_NCDHW = 30
  130. FORMAT_DHWCN = 31
  131. FORMAT_NDC1HWC0 = 32
  132. FORMAT_FRACTAL_Z_3D = 33
  133. FORMAT_CN = 34
  134. FORMAT_NC = 35
  135. FORMAT_DHWNC = 36
  136. FORMAT_FRACTAL_Z_3D_TRANSPOSE = 37
  137. FORMAT_RESERVED = 38
  138. FORMAT_ALL = 39
  139. @classmethod
  140. def get_format_name(cls, num):
  141. """
  142. Get the name of format type by enum number.
  143. Args:
  144. num (int): Enum number.
  145. Returns:
  146. str, the name of format type.
  147. """
  148. format_type = cls._value2member_map_.get(num)
  149. return 'UNKNOWN' if format_type is None else format_type.name
  150. class FrameworkParser:
  151. """
  152. Thr parser for parsing framework files.
  153. Args:
  154. profiling_id (str): The profiling ID.
  155. device_id (str): The device ID.
  156. output_path (str): The directory of the parsed file. Default: `./`.
  157. """
  158. _raw_data_dir = '/var/log/npu/profiling'
  159. _regex_framework = r'Framework\.host\.(?P<data_type>.+)\.(?P<device_id>\d).+'
  160. _regex_framework_in_data = r'Framework\.host\.(?P<data_type>.+)\.' \
  161. r'(?P<device_id>\d)\.(?P<profiling_id>[a-zA-Z0-9]+).+'
  162. _col_names = [
  163. 'task_id', 'stream_id', 'block_dim', 'full_op_name', 'op_name',
  164. 'op_type', 'subgraph', 'op_info'
  165. ]
  166. _graph_attr_name = [
  167. 'input_format', 'input_data_type', 'input_shape', 'output_format',
  168. 'output_data_type', 'output_shape'
  169. ]
  170. def __init__(self, profiling_id, device_id, output_path='./'):
  171. self._profiling_path = self._get_raw_profiling_path(profiling_id)
  172. self._backend_type = None
  173. self._framework_path = {'graph': [], 'task': []}
  174. self._search_file(profiling_id, device_id)
  175. self._device_id = device_id
  176. self._save_path = self._get_save_path(device_id, output_path)
  177. self._task_id_full_op_name_dict = {}
  178. self._task_cache = {}
  179. self._parse_task_files()
  180. @property
  181. def save_path(self):
  182. """
  183. The property of save path.
  184. Returns:
  185. str, the save path.
  186. """
  187. return self._save_path
  188. def to_task_id_full_op_name_dict(self):
  189. """
  190. Get the task id and full operator name dict.
  191. Returns:
  192. dict, the task id and full operator name dict.
  193. """
  194. return self._task_id_full_op_name_dict
  195. def parse(self):
  196. """Parse the framework files."""
  197. self._parse_graph_files_and_save(self._task_cache)
  198. del self._task_cache
  199. def _get_raw_profiling_path(self, profiling_id):
  200. """
  201. Get raw profiling path.
  202. Args:
  203. profiling_id (str): The profiling ID.
  204. Returns:
  205. str, the raw profiling path.
  206. Raises:
  207. ProfilerPathErrorException: If the profiling path is invalid.
  208. ProfilerDirNotFoundException: If the profiling dir is not found.
  209. """
  210. profiling_path = os.path.join(self._raw_data_dir, profiling_id)
  211. try:
  212. profiling_path = validate_and_normalize_path(
  213. profiling_path, 'profiler'
  214. )
  215. except ValidationError:
  216. raise ProfilerPathErrorException('Profiling path is invalid.')
  217. if not os.path.isdir(profiling_path):
  218. raise ProfilerDirNotFoundException(profiling_path)
  219. return profiling_path
  220. def _search_file(self, profiling_id, device_id):
  221. """
  222. Search all framework files in raw profiling path.
  223. Args:
  224. profiling_id (str): The profiling ID.
  225. device_id (str): The device ID.
  226. Raises:
  227. ProfilerFileNotFoundException: If the framework files are not found.
  228. """
  229. self._search_file_from_job_path(device_id)
  230. self._search_file_from_data_path(profiling_id, device_id)
  231. if self._backend_type is None:
  232. raise ProfilerFileNotFoundException('Framework')
  233. self._framework_path['graph'].sort()
  234. self._framework_path['task'].sort()
  235. def _search_file_from_job_path(self, device_id):
  236. """
  237. Search framework files from job path.
  238. Args:
  239. device_id (str): The device ID.
  240. Raises:
  241. ProfilerRawFileException: If the framework file type is inconsistent.
  242. ProfilerDeviceIdMismatchException: If the device id is mismatch
  243. with framework in the raw dir.
  244. """
  245. files = os.listdir(self._profiling_path)
  246. for file in files:
  247. pattern = re.search(self._regex_framework, file)
  248. if not pattern or file.endswith('.done'):
  249. continue
  250. attrs = pattern.groupdict()
  251. device_id_in_path = attrs.get('device_id')
  252. if device_id_in_path != device_id:
  253. raise ProfilerDeviceIdMismatchException()
  254. data_type = attrs.get('data_type')
  255. if data_type.startswith('vm.'):
  256. if self._backend_type and self._backend_type != 'vm':
  257. raise ProfilerRawFileException('Backend type is inconsistent.')
  258. self._backend_type = 'vm'
  259. data_type = data_type.split('.')[1]
  260. else:
  261. if self._backend_type and self._backend_type != 'ge':
  262. raise ProfilerRawFileException('Backend type is inconsistent.')
  263. self._backend_type = 'ge'
  264. if data_type.startswith('graph_desc_info'):
  265. self._framework_path['graph'].append(
  266. os.path.join(self._profiling_path, file)
  267. )
  268. elif data_type.startswith('task_desc_info'):
  269. self._framework_path['task'].append(
  270. os.path.join(self._profiling_path, file)
  271. )
  272. def _search_file_from_data_path(self, profiling_id, device_id):
  273. """
  274. Search framework files from data path.
  275. Args:
  276. profiling_id (str): The profiling ID.
  277. device_id (str): The device ID.
  278. Raises:
  279. ProfilerRawFileException: If the framework file type is inconsistent.
  280. ProfilerDeviceIdMismatchException: If the device id is mismatch
  281. with framework in the raw dir.
  282. """
  283. profiling_data_path = os.path.join(
  284. self._raw_data_dir, 'container', device_id, 'data'
  285. )
  286. if not os.path.isdir(profiling_data_path):
  287. return
  288. files = os.listdir(profiling_data_path)
  289. for file in files:
  290. pattern = re.search(self._regex_framework_in_data, file)
  291. if not pattern or file.endswith('.done') or file.endswith('.zip'):
  292. continue
  293. attrs = pattern.groupdict()
  294. profiling_id_in_path = attrs.get('profiling_id')
  295. if profiling_id_in_path != profiling_id:
  296. continue
  297. device_id_in_path = attrs.get('device_id')
  298. if device_id_in_path != device_id:
  299. raise ProfilerDeviceIdMismatchException()
  300. data_type = attrs.get('data_type')
  301. if data_type.startswith('vm.'):
  302. if self._backend_type and self._backend_type != 'vm':
  303. raise ProfilerRawFileException('Backend type is inconsistent.')
  304. self._backend_type = 'vm'
  305. data_type = data_type.split('.')[1]
  306. else:
  307. if self._backend_type and self._backend_type != 'ge':
  308. raise ProfilerRawFileException('Backend type is inconsistent.')
  309. self._backend_type = 'ge'
  310. if data_type.startswith('graph_desc_info'):
  311. self._framework_path['graph'].append(
  312. os.path.join(profiling_data_path, file)
  313. )
  314. elif data_type.startswith('task_desc_info'):
  315. self._framework_path['task'].append(
  316. os.path.join(profiling_data_path, file)
  317. )
  318. def _get_save_path(self, device_id, output_path):
  319. """
  320. Get the save path.
  321. Args:
  322. device_id (str): The device ID.
  323. output_path (str): The output dir.
  324. Returns:
  325. str, the save path.
  326. Raises:
  327. ProfilerPathErrorException: If the output path is invalid.
  328. ProfilerDirNotFoundException: If the output dir is not found.
  329. """
  330. try:
  331. output_dir = validate_and_normalize_path(output_path, 'profiler')
  332. except ValidationError:
  333. raise ProfilerPathErrorException('Output path is invalid.')
  334. if not os.path.isdir(output_dir):
  335. raise ProfilerDirNotFoundException(output_dir)
  336. return os.path.join(
  337. output_dir, '_'.join(['framework', 'raw', device_id]) + '.csv'
  338. )
  339. def _parse_task_files(self):
  340. """Parse the framework task files."""
  341. for path in self._framework_path['task']:
  342. with open(path, 'r') as file:
  343. for task_info in file:
  344. infos = task_info.strip('\n').split(' ')
  345. # key is op name, values is task id, stream id, block_dim
  346. self._task_cache[infos[0]] = [infos[2], infos[3], infos[1]]
  347. self._task_id_full_op_name_dict[infos[2]] = infos[0]
  348. def _parse_graph_files_and_save(self, task_cache):
  349. """
  350. Parse the framework graph files and save the framework information.
  351. Args:
  352. task_cache (dict): The task information cache.
  353. """
  354. with open(self._save_path, 'w') as save_file:
  355. csv_writer = csv.writer(save_file)
  356. csv_writer.writerow(self._col_names)
  357. for path in self._framework_path['graph']:
  358. with open(path, 'r') as graph_file:
  359. for graph_info in graph_file:
  360. result = self._parse_one_row_graph_info(graph_info)
  361. task_info = task_cache.get(result[0])
  362. task_info.extend(result)
  363. csv_writer.writerow(task_info)
  364. del task_cache[result[0]]
  365. none_list = [None, None, None, None]
  366. for key, value in task_cache.items():
  367. value.append(key)
  368. value.extend(none_list)
  369. csv_writer.writerow(value)
  370. def _parse_one_row_graph_info(self, row_info):
  371. """
  372. Parse the graph information in one row.
  373. Args:
  374. row_info (str): One row graph information.
  375. Returns:
  376. list[str], the parsed graph information.
  377. """
  378. full_op_name = None
  379. op_name = None
  380. subgraph_name = None
  381. op_type = None
  382. op_info = dict()
  383. cur_op_info_key = None
  384. infos = row_info.strip('\n').split(' ')
  385. for info in infos:
  386. attr_name, attr_value = info.split(':', 1)
  387. if attr_name == 'op_name':
  388. full_op_name = attr_value
  389. subgraph_name = self._get_subgraph_name(full_op_name)
  390. op_name = self._get_op_name(full_op_name, subgraph_name)
  391. elif attr_name == 'op_type':
  392. op_type = attr_value
  393. elif attr_name in ['input_id', 'output_id']:
  394. cur_op_info_key = '{}_{}'.format(
  395. attr_name.split('_')[0], attr_value
  396. )
  397. op_info[cur_op_info_key] = dict()
  398. elif attr_name in self._graph_attr_name:
  399. op_attr = attr_name.split('_', 1)[1]
  400. if op_attr == 'shape':
  401. attr_value = attr_value.strip('"')
  402. if self._backend_type == 'vm':
  403. if op_attr == 'data_type':
  404. attr_value = VmDataType.get_data_type_name(
  405. int(attr_value)
  406. )
  407. else:
  408. if op_attr == 'data_type':
  409. attr_value = GeDataType.get_data_type_name(
  410. int(attr_value)
  411. )
  412. elif op_attr == 'format':
  413. attr_value = GeFormat.get_format_name(int(attr_value))
  414. op_info[cur_op_info_key][op_attr] = attr_value
  415. # the list info are full_op_name, op_name, op_type, subgraph, op_info
  416. return [full_op_name, op_name, op_type, subgraph_name,
  417. json.dumps(op_info)]
  418. def _get_subgraph_name(self, full_op_name):
  419. """
  420. Get subgraph name.
  421. Args:
  422. full_op_name (str): The full operator name.
  423. Returns:
  424. str, the subgraph name.
  425. """
  426. subgraph_name = full_op_name.split('/', 1)[0]
  427. if subgraph_name in ['Default', 'Gradients']:
  428. return subgraph_name
  429. return None
  430. def _get_op_name(self, full_op_name, subgraph_name):
  431. """
  432. Get operator name.
  433. Args:
  434. full_op_name (str): The full operator name.
  435. subgraph_name (str): The subgraph name.
  436. Returns:
  437. str, the operator name.
  438. """
  439. if subgraph_name is None:
  440. return full_op_name
  441. if self._backend_type == 'vm':
  442. return full_op_name.split('/')[-1]
  443. strs = full_op_name.split(subgraph_name + '/')
  444. op_name = None
  445. for name_str in strs:
  446. if not name_str:
  447. continue
  448. if op_name is None:
  449. op_name = name_str.split('/')[-1]
  450. else:
  451. op_name = '+'.join([op_name, name_str.split('/')[-1]])
  452. return op_name