|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386 |
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Memory Usage Parser."""
- from collections import OrderedDict
- import json
- import os
- import stat
-
- from google.protobuf.text_format import ParseError
-
- from mindspore import log as logger
- from mindspore.profiler.common.exceptions.exceptions import ProfilerIOException, \
- ProfilerFileNotFoundException, ProfilerRawFileException
- from mindspore.profiler.common.validator.validate_path import validate_and_normalize_path
- from mindspore.profiler.parser.container import MemoryGraph as Graph
- from mindspore.profiler.parser.container import MemoryNode as Node
- from mindspore.profiler.parser.container import MemoryTensor as Tensor
- import mindspore._c_expression as c_expression
-
# The generated protobuf bindings are imported only when security mode is
# disabled; otherwise MemoryProto stays None and users of it must check for
# that before instantiating a message.
if not c_expression.security.enable_security():
    from mindspore.train.memory_profiling_pb2 import MemoryProto
else:
    MemoryProto = None

# Number of bytes in one gibibyte (1024 ** 3); used to scale raw byte counts.
GIGABYTES = 1024 * 1024 * 1024
-
-
class MemoryUsageParser:
    """MemoryUsageParser to parse memory raw data.

    The parser loads ``memory_usage_<device_id>.pb`` from the profiling
    directory, aggregates the statistics of every graph it contains and can
    write a summary file and a details file (both JSON) into the same
    directory.

    Args:
        profiling_dir (str): Directory that holds the raw protobuf file and
            receives the generated JSON files.
        device_id: Identifier substituted into the input and output file names.
    """

    def __init__(self, profiling_dir, device_id):
        self._profiling_dir = profiling_dir
        self._device_id = device_id
        # File name templates; '{}' is filled with the device id.
        self._proto_file_path = 'memory_usage_{}.pb'
        self._summary_filename = 'memory_usage_summary_{}.json'
        self._details_filename = 'memory_usage_details_{}.json'
        # Parsed graphs keyed by graph id (only graphs that produced data).
        self._graphs_dict = {}
        # Largest peak reported by any graph parsed so far.
        self._peak_mem = 0
        self._mem_summary = {
            'capacity': 0,
            'allocations': 0,
            'deallocations': 0,
            'peak_mem': 0,
            'static_mem': 0
        }
        self._framework = {}
        self._points = {}

    def _get_file_path(self):
        """Get the proto file path.

        Returns:
            str, the validated and normalized path of the memory proto file.

        Raises:
            ProfilerFileNotFoundException: If the file does not exist.
        """
        file_path = os.path.join(
            self._profiling_dir,
            self._proto_file_path.format(self._device_id)
        )
        file_path = validate_and_normalize_path(file_path)

        if not os.path.exists(file_path):
            logger.warning('The memory file does not exist! Please ignore the warning '
                           'if you are running heterogeneous training.')
            raise ProfilerFileNotFoundException(msg=file_path)

        return file_path

    def init_memory_usage_info(self, aicore_detail_data, points):
        """Init memory usage information.

        Reads the raw protobuf file and, when security mode is disabled,
        parses it and fills the per-graph details and the memory summary.
        When security mode is enabled the file is read but not parsed.

        Args:
            aicore_detail_data (Iterable): Operator records; see
                `_process_framework_info` for the fields that are used.
            points (dict): Mapping that may contain 'fp_start' and 'bp_end';
                it is forwarded to every graph parser.

        Raises:
            ProfilerFileNotFoundException: If the proto file does not exist.
            ProfilerIOException: If the proto file cannot be read.
            ProfilerRawFileException: If the protobuf bindings are missing or
                the file content cannot be decoded.
        """
        logger.info("Start to load memory usage data from pb file")
        file_path = self._get_file_path()
        self._framework = self._process_framework_info(aicore_detail_data)
        self._points = points

        # Open memory protobuf file.
        try:
            with open(file_path, 'rb') as f:
                content = f.read()
        except OSError as err:  # IOError is an alias of OSError
            logger.critical('Failed to read memory file: %s', err)
            raise ProfilerIOException from err

        # Parse memory raw data from file.
        if not c_expression.security.enable_security():
            if not MemoryProto:
                raise ProfilerRawFileException("Can not find memory profiling pb file.")

            # ParseFromString() signals malformed binary input with
            # message.DecodeError; text_format.ParseError only covers the text
            # format. Catch both so a corrupt file is always logged and
            # reported as ProfilerRawFileException.
            from google.protobuf.message import DecodeError

            memory_proto = MemoryProto()
            try:
                memory_proto.ParseFromString(content)
            except (ParseError, DecodeError) as err:
                msg = "Fail to parse memory proto file."
                logger.critical("Cannot parse the memory file. Please check the file schema.\n%s", err)
                raise ProfilerRawFileException(msg) from err

            # Parse memory details based on graphs in the network.
            graphs = memory_proto.graph_mem
            self._parse_graph_memory(graphs)
            # Update memory summary information.
            self._mem_summary['capacity'] = memory_proto.total_mem / GIGABYTES
            self._mem_summary['peak_mem'] = self._peak_mem

        logger.info('Finished processing memory usage data.')

    def _parse_graph_memory(self, graphs):
        """Parse memory usage based on subgraphs.

        Every graph is handed to a GraphMemoryParser. Graphs that yield a
        non-empty result are stored by graph id, and the counters of every
        processed graph are folded into the global summary.

        Args:
            graphs (Iterable): Graph messages taken from the memory proto.
        """
        for graph_proto in graphs:
            graph_id = graph_proto.graph_id
            if graph_id is None:
                logger.info('Graph id is missing, skipped the graph.')
                continue

            graph_parser = GraphMemoryParser(graph_proto, self._points, self._framework)
            graph = graph_parser.parse_graph()
            if graph:
                self._graphs_dict[graph_id] = graph

            # update global memory usage data
            self._peak_mem = max(self._peak_mem, graph_parser.peak_mem)
            self._mem_summary['static_mem'] += graph_parser.static_mem
            self._mem_summary['allocations'] += graph_parser.allocations
            self._mem_summary['deallocations'] += graph_parser.deallocations

    def _write_memory_files(self, filename, content):
        """Write the summary and top breakdowns of memory usage.

        Dumps `content` as JSON into `filename` inside the profiling
        directory and then restricts the file to owner read/write.

        Args:
            filename (str): Name of the file to create.
            content (dict): JSON-serializable data to store.

        Raises:
            ProfilerIOException: If the file cannot be written or its
                permissions cannot be changed.
        """
        file_path = os.path.join(self._profiling_dir, filename)
        file_path = validate_and_normalize_path(file_path)

        try:
            with open(file_path, 'w') as json_file:
                json.dump(content, json_file)
            os.chmod(file_path, stat.S_IREAD | stat.S_IWRITE)
        except OSError as err:  # IOError is an alias of OSError
            logger.critical('Fail to write memory file.\n%s', err)
            raise ProfilerIOException from err

    def write_memory_files(self):
        """Write memory files.

        Writes the memory summary and the per-graph details to their JSON
        files in the profiling directory.
        """
        logger.info('Start recording memory data into files...')
        # write memory summary to json file
        summary_filename = self._summary_filename.format(self._device_id)
        self._write_memory_files(summary_filename, self._mem_summary)

        # write memory details to json file
        details_filename = self._details_filename.format(self._device_id)
        self._write_memory_files(details_filename, self._graphs_dict)
        logger.info('Successfully write memory data into files.')

    @staticmethod
    def _process_framework_info(aicore_detail_data):
        """Process framework info.

        Args:
            aicore_detail_data (Iterable): Indexable records in which item 0
                is the operator name, item 4 the operator full name and
                item 5 the operator info.

        Returns:
            dict, operator name mapped to a dict with the keys 'fullname',
            'name' and 'args'. A later record with the same operator name
            replaces an earlier one.
        """
        framework_info_dict = {}
        for framework_obj in aicore_detail_data:
            op_name = framework_obj[0]
            op_full_name = framework_obj[4]
            op_info = framework_obj[5]
            framework_info_dict[op_name] = {
                'fullname': op_full_name,
                'name': op_name,
                'args': op_info
            }

        return framework_info_dict
-
-
class GraphMemoryParser:
    """Parse memory usage data for each graph.

    Args:
        graph_proto: Graph message; its `tensor_mems` and `node_mems` fields
            are read during parsing.
        points (dict): Mapping that may contain 'fp_start' and 'bp_end'
            operator names; may be empty or None.
        framework (dict): Framework info keyed by node name; each value is a
            dict holding 'fullname' and 'args'.
    """

    def __init__(self, graph_proto, points, framework):
        self.graph = None
        # Nodes keyed by node name and tensors keyed by tensor id.
        self.nodes = OrderedDict()
        self.tensors = OrderedDict()
        self._framework = framework
        self._points = points
        self._graph_proto = graph_proto
        self.peak_mem = 0
        self.static_mem = 0
        self.allocations = 0
        self.deallocations = 0
        # One slot per node: memory in use at that node, and the ids of the
        # tensors alive at that node.
        self._mem_change = []
        self.breakdowns = []
        self._lifetime = []

    def parse_graph(self):
        """Parse memory usage data for subgraphs.

        Returns:
            dict, the dictionary form of the parsed graph, or an empty dict
            when the graph holds no tensors.
        """
        graph_dict = {}
        self.graph = Graph(self._graph_proto)
        # process tensors in the graph
        tensors_proto = self._graph_proto.tensor_mems
        if not tensors_proto:
            logger.info('No tensor in graph %s, skipped.', self.graph.graph_id)
            return graph_dict
        self._parse_tensors(tensors_proto)

        # calculate memory usage of the graph by number of nodes and details of tensors
        nodes_proto = self._graph_proto.node_mems
        node_num = len(nodes_proto)
        # init memory usage list with static memory
        self._mem_change = [self.graph.static_mem for _ in range(node_num)]
        self._lifetime = [[] for _ in range(node_num)]
        self._calc_mem_change()  # update self._mem_change and self._lifetime
        self.graph.lines = self._mem_change

        # process nodes in graph
        self.graph.nodes = self._parse_nodes(nodes_proto)

        self._process_memory_breakdowns()
        self.graph.breakdowns = self.breakdowns

        # update fp_start and bp_end
        point_id = self._locate_fp_bp_id()
        self.graph.fp_start = point_id.get('fp_start')
        self.graph.bp_end = point_id.get('bp_end')

        graph_dict = self.graph.to_dict()

        self.static_mem = self.graph.static_mem
        self.allocations = len(self.tensors)
        self.deallocations = len(self.tensors)
        # A graph may carry tensors but no nodes; the default keeps max()
        # from raising ValueError on the empty list and leaves the peak as is.
        self.peak_mem = max(max(self._mem_change, default=self.peak_mem), self.peak_mem)

        return graph_dict

    def _parse_tensors(self, tensors_proto):
        """Parse tensors and store them by tensor id."""
        for tensor_proto in tensors_proto:
            tensor = Tensor(tensor_proto)
            self.tensors[tensor.tensor_id] = tensor

    def _parse_nodes(self, nodes_proto):
        """Parse nodes.

        Args:
            nodes_proto (Iterable): Node messages of the graph, in the same
                order as the slots of `self._mem_change`.

        Returns:
            list[dict], the dictionary form of every node, in input order.
        """
        nodes_list = []
        for index, node_proto in enumerate(nodes_proto):
            node = Node(node_proto)
            # Calculate memory size allocated for this node
            tensor_ids = set(node.output_ids + node.workspace_ids)
            node.size = self._calc_node_memory(tensor_ids)
            node.allocations = len(tensor_ids)
            node.deallocations = len(tensor_ids)

            # calculate the allocated/deallocated memory size on the node
            if index == 0:
                node.mem_change = self._mem_change[index] - self.graph.static_mem
            else:
                node.mem_change = self._mem_change[index] - self._mem_change[index - 1]

            self._update_nodes(node)
            self._update_tensor_source(node)
            self.nodes[node.name] = node
            nodes_list.append(node.to_dict())

        return nodes_list

    def _update_nodes(self, node):
        """Update nodes.

        Drops workspace ids that duplicate output ids and, when framework
        info exists for the node name, copies the full name and sorts the
        framework args into the node's inputs and outputs.
        """
        # Remove duplicate tensors
        self._remove_duplicate_tensors(node)
        name = node.name
        if self._framework and name in self._framework:
            node_frame = self._framework[name]
            node.fullname = node_frame.get('fullname')
            info = node_frame.get('args')
            for key, value in info.items():
                # Args whose key mentions 'input' are inputs; all others are
                # treated as outputs.
                if 'input' in key:
                    node.inputs.append(value)
                else:
                    node.outputs.append(value)

    def _update_tensor_source(self, node):
        """Update source node for tensors."""
        for t_id in node.output_ids:
            tensor = self.tensors.get(t_id)
            tensor.source_node = node.name

    @staticmethod
    def _remove_duplicate_tensors(node):
        """Remove, in place, workspace tensor ids that are also output ids."""
        if node.workspace_ids:
            # Slice assignment keeps the same list object, as the original
            # element-by-element deletion did.
            node.workspace_ids[:] = [
                t_id for t_id in node.workspace_ids if t_id not in node.output_ids
            ]

    def _calc_node_memory(self, tensor_ids):
        """Calculate the allocated memory for the node.

        Args:
            tensor_ids (Iterable): Ids of the tensors owned by the node.

        Returns:
            The sum of the sizes of the given tensors (0 when there are none).

        Raises:
            KeyError: If an id is not a known tensor.
        """
        return sum(self.tensors[t_id].size for t_id in tensor_ids)

    def _calc_mem_change(self):
        """Calculate the memory change for the subgraph.

        For every tensor, adds its size to each node slot it is alive in and
        records its id in the lifetime list of those slots. A tensor whose
        lifetime is missing or lies outside the node range is logged and
        contributes nothing.
        """
        node_num = len(self._mem_change)
        for tensor_id, tensor in self.tensors.items():
            life_long = tensor.life_long
            life_start = tensor.life_start
            life_end = tensor.life_end
            size = tensor.size

            # Update memory change for the entire graph.
            # If a tensor's lifetime cannot be fully located, it will be ignored as 0 change.
            # The range checks keep indices inside [0, node_num): without
            # them an oversized end raises IndexError after a partial update
            # and a negative start silently wraps to the end of the list.
            if life_long == 'LifeLongGraphAll':  # lifetime is from graph start to graph end
                tensor.life_start = 0
                tensor.life_end = node_num
                self._update_mem_change(size, 0, node_num, tensor_id)
            elif life_long == 'LifeLongGraphStart':  # lifetime is from graph start to tensor end
                if life_end is not None and 0 <= life_end < node_num:
                    tensor.life_start = 0
                    self._update_mem_change(size, 0, life_end + 1, tensor_id)
                else:
                    logger.info('Cannot locate lifetime end for tensor: %s', tensor_id)
            elif life_long == 'LifeLongGraphEnd':  # lifetime is from tensor start to graph end
                if life_start is not None and 0 <= life_start <= node_num:
                    tensor.life_end = node_num
                    self._update_mem_change(size, life_start, node_num, tensor_id)
                else:
                    logger.info('Cannot locate lifetime start for tensor: %s', tensor_id)
            elif life_long == 'LifeLongNone':  # lifetime is from tensor start to tensor end
                if life_start is not None and life_end is not None \
                        and 0 <= life_start <= life_end < node_num:
                    self._update_mem_change(size, life_start, life_end + 1, tensor_id)
                else:
                    logger.info('Cannot locate lifetime start or end for tensor: %s', tensor_id)

    def _update_mem_change(self, size, start, end, tensor_id):
        """Update memory change for the subgraph.

        Adds `size` to every node slot in ``[start, end)`` and appends
        `tensor_id` to the lifetime list of each of those slots.
        """
        for i in range(start, end):
            self._mem_change[i] += size
            # Update tensor lifetime list.
            self._lifetime[i].append(tensor_id)

    def _locate_fp_bp_id(self):
        """Locate the node id of fp_start and bp_end in graph.

        Returns:
            dict, with the keys 'fp_start' and 'bp_end'; a value is the node
            id of the matching node, or None when no node matches.
        """
        point_id = {
            'fp_start': None,
            'bp_end': None
        }
        fp_start = self._points.get('fp_start') if self._points else None
        bp_end = self._points.get('bp_end') if self._points else None
        # Nodes are keyed by the last path component of the operator name.
        fp_name = fp_start.split('/')[-1] if fp_start else ""
        bp_name = bp_end.split('/')[-1] if bp_end else ""
        if fp_name in self.nodes:
            point_id['fp_start'] = self.nodes[fp_name].node_id
        if bp_name in self.nodes:
            point_id['bp_end'] = self.nodes[bp_name].node_id

        return point_id

    def _process_memory_breakdowns(self):
        """Process memory breakdowns for each node.

        Builds, for every node slot, the list of dictionaries describing the
        tensors alive at that slot. Tensors without a source node are left
        out.
        """
        # Size by the lifetime list (one slot per node) rather than by
        # self.nodes: the latter is keyed by node name, so duplicate names
        # would make it shorter than the index range used below.
        self.breakdowns = [[] for _ in range(len(self._lifetime))]
        for index, breakdown in enumerate(self._lifetime):
            for t_id in breakdown:
                tensor = self.tensors.get(t_id)
                source_node = tensor.source_node
                if not source_node:
                    continue
                node = self.nodes.get(source_node)
                tensor_dict = self._get_tensor_dict(node, tensor, t_id)
                self.breakdowns[index].append(tensor_dict)

    @staticmethod
    def _get_tensor_dict(node, tensor, t_id):
        """Update node outputs to assemble memory breakdowns.

        When `t_id` is one of the node's output ids, the tensor is named
        ``<node name>:<output index>``, marked as 'output' and given the
        shape, data type and format of the matching framework output (None
        for each when the node has no output info at that index).

        Returns:
            dict, the dictionary form of the (possibly updated) tensor.
        """
        for i, output_id in enumerate(node.output_ids):
            if t_id == output_id:
                output = node.outputs[i] if i < len(node.outputs) else {}
                tensor.name = node.name + ':' + str(i)
                tensor.shape = output.get('shape')
                tensor.dtype = output.get('data_type')
                tensor.format = output.get('format')
                tensor.type = 'output'

        return tensor.to_dict()
|