You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

memory_usage_parser.py 16 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Memory Usage Parser."""
  16. from collections import OrderedDict
  17. import json
  18. import os
  19. import stat
  20. from google.protobuf.text_format import ParseError
  21. from mindspore import log as logger
  22. from mindspore.profiler.common.exceptions.exceptions import ProfilerIOException, \
  23. ProfilerFileNotFoundException, ProfilerRawFileException
  24. from mindspore.profiler.common.validator.validate_path import validate_and_normalize_path
  25. from mindspore.profiler.parser.container import MemoryGraph as Graph
  26. from mindspore.profiler.parser.container import MemoryNode as Node
  27. from mindspore.profiler.parser.container import MemoryTensor as Tensor
  28. import mindspore._c_expression as c_expression
# The memory-profiling protobuf schema is imported only when security mode is
# off; otherwise MemoryProto is left as None and init_memory_usage_info skips
# proto parsing entirely (see the enable_security() check there).
if not c_expression.security.enable_security():
    from mindspore.train.memory_profiling_pb2 import MemoryProto
else:
    MemoryProto = None

# Number of bytes in one gigabyte; used to convert raw byte counts to GB.
GIGABYTES = 1024 * 1024 * 1024
  34. class MemoryUsageParser:
  35. """MemoryUsageParser to parse memory raw data."""
  36. def __init__(self, profiling_dir, device_id):
  37. self._profiling_dir = profiling_dir
  38. self._device_id = device_id
  39. self._proto_file_path = 'memory_usage_{}.pb'
  40. self._summary_filename = 'memory_usage_summary_{}.json'
  41. self._details_filename = 'memory_usage_details_{}.json'
  42. self._graphs_dict = {}
  43. self._peak_mem = 0
  44. self._mem_summary = {
  45. 'capacity': 0,
  46. 'allocations': 0,
  47. 'deallocations': 0,
  48. 'peak_mem': 0,
  49. 'static_mem': 0
  50. }
  51. self._framework = {}
  52. self._points = {}
  53. def _get_file_path(self):
  54. """Get the proto file path."""
  55. file_path = os.path.join(
  56. self._profiling_dir,
  57. self._proto_file_path.format(self._device_id)
  58. )
  59. file_path = validate_and_normalize_path(file_path)
  60. if not os.path.exists(file_path):
  61. logger.warning('The memory file does not exist! Please ignore the warning '
  62. 'if you are running heterogeneous training.')
  63. raise ProfilerFileNotFoundException(msg=file_path)
  64. return file_path
  65. def init_memory_usage_info(self, aicore_detail_data, points):
  66. """Init memory usage information."""
  67. logger.info("Start to load memory usage data from pb file")
  68. file_path = self._get_file_path()
  69. self._framework = self._process_framework_info(aicore_detail_data)
  70. self._points = points
  71. # Open memory protobuf file.
  72. try:
  73. with open(file_path, 'rb') as f:
  74. content = f.read()
  75. except (IOError, OSError) as err:
  76. logger.critical('Failed to read memory file: %s', err)
  77. raise ProfilerIOException
  78. # Parse memory raw data from file.
  79. if not c_expression.security.enable_security():
  80. if not MemoryProto:
  81. raise ProfilerRawFileException("Can not find memory profiling pb file.")
  82. memory_proto = MemoryProto()
  83. try:
  84. memory_proto.ParseFromString(content)
  85. except ParseError as err:
  86. msg = "Fail to parse memory proto file."
  87. logger.critical("Cannot parse the memory file. Please check the file schema.\n%s", err)
  88. raise ProfilerRawFileException(msg)
  89. # Parse memory details based on graphs in the network.
  90. graphs = memory_proto.graph_mem
  91. self._parse_graph_memory(graphs)
  92. # Update memory summary information.
  93. self._mem_summary['capacity'] = memory_proto.total_mem / GIGABYTES
  94. self._mem_summary['peak_mem'] = self._peak_mem
  95. logger.info('Finished processing memory usage data.')
  96. def _parse_graph_memory(self, graphs):
  97. """Parse memory usage based on subgraphs."""
  98. for graph_proto in graphs:
  99. graph_id = graph_proto.graph_id
  100. if graph_id is None:
  101. logger.info('Graph id is missing, skipped the graph.')
  102. continue
  103. graph_parser = GraphMemoryParser(graph_proto, self._points, self._framework)
  104. graph = graph_parser.parse_graph()
  105. if graph:
  106. self._graphs_dict[graph_id] = graph
  107. # update global memory usage data
  108. self._peak_mem = max(self._peak_mem, graph_parser.peak_mem)
  109. self._mem_summary['static_mem'] += graph_parser.static_mem
  110. self._mem_summary['allocations'] += graph_parser.allocations
  111. self._mem_summary['deallocations'] += graph_parser.deallocations
  112. def _write_memory_files(self, filename, content):
  113. """Write the summary and top breakdowns of memory usage."""
  114. file_path = os.path.join(self._profiling_dir, filename)
  115. file_path = validate_and_normalize_path(file_path)
  116. try:
  117. with open(file_path, 'w') as json_file:
  118. json.dump(content, json_file)
  119. os.chmod(file_path, stat.S_IREAD | stat.S_IWRITE)
  120. except (IOError, OSError) as err:
  121. logger.critical('Fail to write memory file.\n%s', err)
  122. raise ProfilerIOException
  123. def write_memory_files(self):
  124. """Write memory files."""
  125. logger.info('Start recording memory data into files...')
  126. # write memory summary to json file
  127. summary_filename = self._summary_filename.format(self._device_id)
  128. self._write_memory_files(summary_filename, self._mem_summary)
  129. # write memory details to json file
  130. details_filename = self._details_filename.format(self._device_id)
  131. self._write_memory_files(details_filename, self._graphs_dict)
  132. logger.info('Successfully write memory data into files.')
  133. @staticmethod
  134. def _process_framework_info(aicore_detail_data):
  135. """Process framework info."""
  136. framework_info_dict = {}
  137. for framework_obj in aicore_detail_data:
  138. op_name = framework_obj[0]
  139. op_full_name = framework_obj[4]
  140. op_info = framework_obj[5]
  141. framework_info_dict[op_name] = {
  142. 'fullname': op_full_name,
  143. 'name': op_name,
  144. 'args': op_info
  145. }
  146. return framework_info_dict
  147. class GraphMemoryParser:
  148. """Parse memory usage data for each graph."""
  149. def __init__(self, graph_proto, points, framework):
  150. self.graph = None
  151. self.nodes = OrderedDict()
  152. self.tensors = OrderedDict()
  153. self._framework = framework
  154. self._points = points
  155. self._graph_proto = graph_proto
  156. self.peak_mem = 0
  157. self.static_mem = 0
  158. self.allocations = 0
  159. self.deallocations = 0
  160. self._mem_change = []
  161. self.breakdowns = []
  162. self._lifetime = []
  163. def parse_graph(self):
  164. """Parse memory usage data for subgraphs."""
  165. graph_dict = {}
  166. self.graph = Graph(self._graph_proto)
  167. # process tensors in the graph
  168. tensors_proto = self._graph_proto.tensor_mems
  169. if not tensors_proto:
  170. logger.info('No tensor in graph %s, skipped.', self.graph.graph_id)
  171. return graph_dict
  172. self._parse_tensors(tensors_proto)
  173. # calculate memory usage of the graph by number of nodes and details of tensors
  174. nodes_proto = self._graph_proto.node_mems
  175. # init memory usage list with static memory
  176. self._mem_change = [self.graph.static_mem for _ in range(len(nodes_proto))]
  177. self._lifetime = [[] for _ in range(len(nodes_proto))]
  178. self._calc_mem_change() # update self._mem_change and self._lifetime
  179. self.graph.lines = self._mem_change
  180. # process nodes in graph
  181. self.graph.nodes = self._parse_nodes(nodes_proto)
  182. self._process_memory_breakdowns()
  183. self.graph.breakdowns = self.breakdowns
  184. # update fp_start and bp_end
  185. point_id = self._locate_fp_bp_id()
  186. self.graph.fp_start = point_id.get('fp_start')
  187. self.graph.bp_end = point_id.get('bp_end')
  188. graph_dict = self.graph.to_dict()
  189. self.static_mem = self.graph.static_mem
  190. self.allocations = len(self.tensors)
  191. self.deallocations = len(self.tensors)
  192. self.peak_mem = max(max(self._mem_change), self.peak_mem)
  193. return graph_dict
  194. def _parse_tensors(self, tensors_proto):
  195. """Parse tensors."""
  196. for tensor_proto in tensors_proto:
  197. tensor = Tensor(tensor_proto)
  198. self.tensors.update({tensor.tensor_id: tensor})
  199. def _parse_nodes(self, nodes_proto):
  200. """Parse nodes."""
  201. nodes_list = []
  202. for index, node_proto in enumerate(nodes_proto):
  203. node = Node(node_proto)
  204. # Calculate memory size allocated for this node
  205. tensor_ids = set(node.output_ids + node.workspace_ids)
  206. node.size = self._calc_node_memory(tensor_ids)
  207. node.allocations = len(tensor_ids)
  208. node.deallocations = len(tensor_ids)
  209. # calculate the allocated/deallocated memory size on the node
  210. if index == 0:
  211. node.mem_change = self._mem_change[index] - self.graph.static_mem
  212. else:
  213. node.mem_change = self._mem_change[index] - self._mem_change[index - 1]
  214. self._update_nodes(node)
  215. self._update_tensor_source(node)
  216. self.nodes[node.name] = node
  217. nodes_list.append(node.to_dict())
  218. return nodes_list
  219. def _update_nodes(self, node):
  220. """Update nodes."""
  221. # Remove duplicate tensors
  222. self._remove_duplicate_tensors(node)
  223. name = node.name
  224. if self._framework and name in self._framework:
  225. node_frame = self._framework[name]
  226. node.fullname = node_frame.get('fullname')
  227. info = node_frame.get('args')
  228. for key, value in info.items():
  229. if 'input' in key:
  230. node.inputs.append(value)
  231. else:
  232. node.outputs.append(value)
  233. def _update_tensor_source(self, node):
  234. """Update source node for tensors."""
  235. for t_id in node.output_ids:
  236. tensor = self.tensors.get(t_id)
  237. tensor.source_node = node.name
  238. @staticmethod
  239. def _remove_duplicate_tensors(node):
  240. """Find conflict tensors in node."""
  241. if node.workspace_ids:
  242. i = 0
  243. while i < len(node.workspace_ids):
  244. t_id = node.workspace_ids[i]
  245. if t_id in node.output_ids:
  246. del node.workspace_ids[i] # remove duplicate tensor
  247. continue
  248. i += 1
  249. def _calc_node_memory(self, tensor_ids):
  250. """Calculate the allocated memory for the node."""
  251. node_mem = 0
  252. for t_id in tensor_ids:
  253. tensor = self.tensors[t_id]
  254. size = tensor.size
  255. node_mem += size
  256. return node_mem
  257. def _calc_mem_change(self):
  258. """Calculate the memory change for the subgraph."""
  259. node_num = len(self._mem_change)
  260. for tensor_id, tensor in self.tensors.items():
  261. life_long = tensor.life_long
  262. life_start = tensor.life_start
  263. life_end = tensor.life_end
  264. size = tensor.size
  265. # Update memory change for the entire graph.
  266. # If a tensor's lifetime cannot be fully located, it will be ignored as 0 change.
  267. if life_long == 'LifeLongGraphAll': # lifetime is from graph start to graph end
  268. tensor.life_start = 0
  269. tensor.life_end = node_num
  270. self._update_mem_change(size, 0, node_num, tensor_id)
  271. elif life_long == 'LifeLongGraphStart': # lifetime is from graph start to tensor end
  272. if life_end is not None and life_end >= 0:
  273. tensor.life_start = 0
  274. self._update_mem_change(size, 0, life_end + 1, tensor_id)
  275. else:
  276. logger.info('Cannot locate lifetime end for tensor: %s', tensor_id)
  277. elif life_long == 'LifeLongGraphEnd': # lifetime is from tensor start to graph end
  278. if life_start is not None and life_start <= node_num:
  279. tensor.life_end = node_num
  280. self._update_mem_change(size, life_start, node_num, tensor_id)
  281. else:
  282. logger.info('Cannot locate lifetime start for tensor: %s', tensor_id)
  283. elif life_long == 'LifeLongNone': # lifetime is from tensor start to tensor end
  284. if life_start is not None and life_end is not None and life_start <= life_end:
  285. self._update_mem_change(size, life_start, life_end + 1, tensor_id)
  286. else:
  287. logger.info('Cannot locate lifetime start or end for tensor: %s', tensor_id)
  288. def _update_mem_change(self, size, start, end, tensor_id):
  289. """Update memory change for the subgraph."""
  290. for i in range(start, end):
  291. self._mem_change[i] += size
  292. # Update tensor lifetime list.
  293. self._lifetime[i].append(tensor_id)
  294. def _locate_fp_bp_id(self):
  295. """Locate the node id of fp_start and bp_end in graph."""
  296. point_id = {
  297. 'fp_start': None,
  298. 'bp_end': None
  299. }
  300. fp_start = self._points.get('fp_start') if self._points else None
  301. bp_end = self._points.get('bp_end') if self._points else None
  302. fp_name = fp_start.split('/')[-1] if fp_start else ""
  303. bp_name = bp_end.split('/')[-1] if bp_end else ""
  304. if fp_name in self.nodes:
  305. point_id['fp_start'] = self.nodes[fp_name].node_id
  306. if bp_name in self.nodes:
  307. point_id['bp_end'] = self.nodes[bp_name].node_id
  308. return point_id
  309. def _process_memory_breakdowns(self):
  310. """Process memory breakdowns for each node."""
  311. self.breakdowns = [[] for _ in range(len(self.nodes))]
  312. for index, breakdown in enumerate(self._lifetime):
  313. for t_id in breakdown:
  314. tensor = self.tensors.get(t_id)
  315. source_node = tensor.source_node
  316. if not source_node:
  317. continue
  318. node = self.nodes.get(source_node)
  319. tensor_dict = self._get_tensor_dict(node, tensor, t_id)
  320. self.breakdowns[index].append(tensor_dict)
  321. @staticmethod
  322. def _get_tensor_dict(node, tensor, t_id):
  323. """Update node outputs to assemble memory breakdowns."""
  324. for i, output_id in enumerate(node.output_ids):
  325. if t_id == output_id:
  326. output = node.outputs[i] if i < len(node.outputs) else {}
  327. tensor.name = node.name + ':' + str(i)
  328. tensor.shape = output.get('shape')
  329. tensor.dtype = output.get('data_type')
  330. tensor.format = output.get('format')
  331. tensor.type = 'output'
  332. return tensor.to_dict()