# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The parser for parsing hccl files."""
import csv
import json
import os
import stat
from enum import Enum

import numpy as np

from mindspore.profiler.common.exceptions.exceptions import \
    ProfilerPathErrorException, ProfilerFileNotFoundException, \
    ProfilerDirNotFoundException, ProfilerRawFileException
from mindspore import log as logger
from mindspore.profiler.common.validator.validate_path import \
    validate_and_normalize_path


class CommunicationInfo(Enum):
    """
    Communication related enumeration types.

    Enum:
        RDMA: Communication link between servers in cluster training.
        SDMA: Communication link inside server in cluster training.
        LOCAL: The operation of this card has no transmission process.
        RDMASEND: Communication operator of RDMA link.
        REDUCE_INLINE: Communication operator of SDMA link.
        MEMCPY: Communication operator of SDMA link.
        NOTIFY_RECORD: Communication operator of SDMA link.
        NOTIFY_WAIT: Operator of LOCAL.
    """
    RDMA = 'RDMA'
    SDMA = 'SDMA'
    LOCAL = 'LOCAL'
    RDMASEND = 'RDMASend'
    REDUCE_INLINE = 'Reduce Inline'
    MEMCPY = 'Memcpy'
    NOTIFY_RECORD = 'Notify Record'
    NOTIFY_WAIT = 'Notify Wait'


class HcclParser:
    """
    The parser for parsing hccl file.

    Args:
        source_dir (str): The hccl source dir. It is scanned for one sub-directory per communication operator.
        device_id (str): The device ID.
        rank_id (str): The rank ID. It is used to build the step trace file name and the output file name.
        output_path (str): The directory of the parsed file. The step trace detail file is also read from
            this directory.

    Raises:
        ProfilerPathErrorException: If the hccl file path or the output path is invalid.
        ProfilerFileNotFoundException: If the hccl file or the step trace file does not exist.
        ProfilerDirNotFoundException: If the source dir or the output dir does not exist.
    """
    # Output file name template, formatted with the rank id.
    _parsed_hccl_file_name = 'hccl_raw_{}.csv'
    # Header row of the output CSV file.
    _col_names = ['step_num', 'communication_cost', 'wait_cost', 'link_info', 'communication_operator_cost']

    def __init__(self, source_dir, device_id, rank_id, output_path):
        self._dev_id = device_id
        self._rank_id = rank_id
        self._source_dir = source_dir
        self._save_path = self._get_save_path(output_path)
        # index0: communication operator names, index1: step timestamps info.
        self._step_trace_info = self._get_step_trace_info(output_path)
        self._communication_operator_name_mapping_info = self._get_communication_operator_name_mapping_info()

    def parse(self):
        """Parse communication info and save it to the output CSV file."""
        self._parse_and_save(self._source_dir)

    def _parse_communication_cost(self, operators_cost_info, info, operators_dict):
        """
        Collect, for one step, the cost entry of every operator that ran in that step.

        Args:
            operators_cost_info (dict): Operator name -> list of per-step cost entries.
            info (list): The step level entry, index0 is the step num.
            operators_dict (dict): Output dict, filled in place with operator name -> cost entry of this step.
        """
        for key, value in operators_cost_info.items():
            for item in value:
                # index0:step_num
                if info[0] == item[0]:
                    operators_dict[key] = item

    def _parse_and_save(self, dir_path):
        """
        Parse and save communication info.

        One row is written per step, followed by one row of average values whose first column is '-'.
        When no communication info is found, only the header row is written.

        Args:
            dir_path (str): The hccl source dir.
        """
        operators_cost_info = self._get_communication_operators_cost_info(dir_path)
        communication_info_cache = [item for value in operators_cost_info.values() for item in value]
        communication_info_cache = self._merge_communication_info_by_step_num(communication_info_cache)
        for info in communication_info_cache:
            operators_dict = dict()
            self._parse_communication_cost(operators_cost_info, info, operators_dict)
            info.append(operators_dict)

        # Calculate device communication average.
        device_communication_average_value = self._calculate_communication_average_value(communication_info_cache)
        # Calculate operator communication average.
        operators_average_value = dict()
        for key, value in operators_cost_info.items():
            average_value = self._calculate_communication_average_value(value)
            # The symbol '-' is used to indicate that the line is average information.
            average_value.insert(0, '-')
            operators_average_value[key] = average_value
        device_communication_average_value.append(operators_average_value)
        # The symbol '-' is used to indicate that the line is average information.
        device_communication_average_value.insert(0, '-')

        with open(self._save_path, 'w', newline='') as save_file:
            csv_writer = csv.writer(save_file)
            csv_writer.writerow(self._col_names)
            for item in communication_info_cache:
                # item[3]:link_info which is a dictionary that needs to be encoded before it is written to a CSV file.
                # item[4]:it is a dictionary that needs to be encoded before it is written to a CSV file.
                item[3] = json.dumps(item[3])
                item[4] = json.dumps(item[4])
                csv_writer.writerow(item)
            if communication_info_cache:
                # device_communication_average_value[3]: average value for link info
                # device_communication_average_value[4]: average value for operator info
                device_communication_average_value[3] = json.dumps(device_communication_average_value[3])
                device_communication_average_value[4] = json.dumps(device_communication_average_value[4])
                csv_writer.writerow(device_communication_average_value)
            else:
                # Without any step data the average list only holds ['-', {}], so indexing [3] would fail.
                logger.warning('No communication info is found, the average information is not written.')
        os.chmod(self._save_path, stat.S_IREAD | stat.S_IWRITE)

    def _get_save_path(self, output_path):
        """
        Get the save path.

        Args:
            output_path (str): The output dir.

        Returns:
            str, the save path.
        """
        output_path = self._validate_dir_path(output_path)
        return os.path.join(
            output_path, self._parsed_hccl_file_name.format(self._rank_id)
        )

    def _get_step_trace_info(self, source_dir):
        """
        Get the start and end timestamps in a step and communication operators names.

        Args:
            source_dir (str): The directory that contains the step trace detail file.

        Returns:
            list, index0 is the list of communication operator column names and index1 is a list of
            [step_num, start_point, end_point] entries.

        Raises:
            ProfilerPathErrorException: If the file path is invalid.
            ProfilerFileNotFoundException: If the step trace file does not exist.
        """
        file_path = os.path.join(
            source_dir,
            f'step_trace_raw_{self._rank_id}_detail_time.csv'
        )
        try:
            file_path = validate_and_normalize_path(file_path)
        except RuntimeError as err:
            logger.warning('file path is invalid.')
            raise ProfilerPathErrorException('file path is invalid.') from err
        if not os.path.isfile(file_path):
            logger.warning('The step trace file <%s> not found.', file_path)
            raise ProfilerFileNotFoundException(file_path)

        with open(file_path, 'r') as src_file:
            csv_reader = csv.reader(src_file)
            # The first row of step trace file is like: step_num, start_point,...,communication_operator_name.
            # The position number of the first communication operator name is 9.
            # An empty file yields no header row, which is treated as "no communication operators".
            communication_operators_names = next(csv_reader, [])[9:]

            # index_0:step_num, index_1:start_point, index_2:end_point
            # The unit of time stamp is 10ns. To convert it to μs, you need to divide it by 100.
            # Blank lines are parsed as empty rows by csv.reader, so they are skipped explicitly.
            step_timestamps_info = [[info[0], float(info[1]) / 100, float(info[2]) / 100]
                                    for info in csv_reader if info and info[0].isdigit()]

        return [communication_operators_names, step_timestamps_info]

    def _get_communication_operator_name_mapping_info(self):
        """
        Get the name of communication operators mapping between hccl and step trace.

        Returns:
            dict, operator type in hccl -> list of (operator name in hccl, operator name in step trace).
        """
        dir_path = self._validate_dir_path(self._source_dir)
        # The name of the operator in hccl is like:operatorName_{Ordered_number}_xx_xx.
        operators_names_in_hccl = [entry.name for entry in os.scandir(dir_path) if entry.is_dir()]
        operators_names_in_hccl_set = {i.split('_')[0] for i in operators_names_in_hccl}
        op_names_in_hccl_dic = dict()
        for item in operators_names_in_hccl_set:
            # Operators of the same type are ordered by their ordered number.
            op_names_in_hccl_dic[item] = sorted([i for i in operators_names_in_hccl if i.split('_')[0] == item],
                                                key=lambda x: int(x.split('_')[1]))

        # The op_info in step trace is like: [op_name,op_name_start_point,op_name_end_point]
        # The name of the operator in step trace can be obtained every three.
        # The name of the operator in step trace is like: stream_xx_xx_operatorName-opxx.
        operators_names_in_step_trace = self._step_trace_info[0][::3]
        op_names_in_step_trace_set = {i.split('_')[3].split('-')[0] for i in operators_names_in_step_trace}
        op_names_in_step_trace_dic = dict()
        for item in op_names_in_step_trace_set:
            op_names_in_step_trace_dic[item] = [i for i in operators_names_in_step_trace
                                                if i.split('_')[3].split('-')[0] == item]

        communication_operator_mapping_info = dict()
        for hccl_key, hccl_value in op_names_in_hccl_dic.items():
            for step_trace_key, step_trace_value in op_names_in_step_trace_dic.items():
                # Operator types are matched case-insensitively, then paired position by position.
                if hccl_key.lower() == step_trace_key.lower():
                    communication_operator_mapping_info[hccl_key] = list(zip(hccl_value, step_trace_value))

        logger.info("Communication operator name mapping info is %s", communication_operator_mapping_info)

        return communication_operator_mapping_info

    def _calculate_the_step_by_timestamp(self, timestamp):
        """
        Calculate the step according to the timestamp.

        A timestamp before the first step maps to step "1" and a timestamp after the last step maps to
        the last step. A timestamp that is inside the overall range but in no [start_point, end_point)
        interval (for example, exactly at the last end point or in a gap between two steps) maps to the
        most recent step that started at or before it.

        Args:
            timestamp (float): The timestamp of the communication operator.

        Returns:
            str, the step num.
        """
        # index0:communication_operator_name, index1:step_timestamps_info
        step_timestamps_info = self._step_trace_info[1]
        # index_0:step_num, index_1:start_point, index_2:end_point
        if timestamp < step_timestamps_info[0][1]:
            return "1"
        if step_timestamps_info[-1][2] < timestamp:
            return step_timestamps_info[-1][0]

        matched_step = None
        # Here the first start_point is <= timestamp, so the fallback is always rebound in the loop.
        latest_started_step = step_timestamps_info[0][0]
        for item in step_timestamps_info:
            if item[1] <= timestamp:
                latest_started_step = item[0]
                if timestamp < item[2]:
                    matched_step = item[0]
        return matched_step if matched_step is not None else latest_started_step

    def _get_communication_operators_cost_info(self, dir_path):
        """
        Obtain time-consuming information of all communication operators.

        Args:
            dir_path (str): The hccl source dir, which contains one sub-directory per operator.

        Returns:
            dict, operator name -> list of per-step cost entries. The step trace name is used as the
            operator name when a mapping exists, otherwise the hccl directory name is used.
        """
        operators_cost_info = dict()
        dir_path = self._validate_dir_path(dir_path)
        operators_dir = [entry.name for entry in os.scandir(dir_path) if entry.is_dir()]
        operator_dir_path = [os.path.join(dir_path, operator_dir) for operator_dir in operators_dir]
        for operator_dir in operator_dir_path:
            operator_cost = self._calculate_communication_operator_cost(operator_dir)
            operator_name = os.path.basename(operator_dir)
            op_mapping_info = self._communication_operator_name_mapping_info.get(operator_name.split('_')[0], [])
            # index1: operator name in step trace.
            op_mapping_name = [item[1] for item in op_mapping_info if item[0] == operator_name]
            if not op_mapping_name:
                logger.warning("The mapping relationship between op name in hccl and op name in step trace "
                               "cannot be found. Use op name in hccl to show the name of the communication operator.")
            else:
                operator_name = op_mapping_name[0]
            operators_cost_info[operator_name] = operator_cost
        return operators_cost_info

    def _calculate_communication_operator_cost(self, dir_path):
        """
        Calculate communication operator cost. Such as allReduce_1,allReduce_2.

        Args:
            dir_path (str): The directory of one operator, each file in it is one execution round.

        Returns:
            list, the cost entries of the operator merged by step num.
        """
        dir_path = self._validate_dir_path(dir_path)
        files = [entry.name for entry in os.scandir(dir_path) if entry.is_file()]
        files_path = [os.path.join(dir_path, file) for file in files]
        operator_cost = list(map(self._calculate_communication_operator_iter_cost, files_path))
        # Add the same step_num merge.
        steps_operator_cost = self._merge_communication_info_by_step_num(operator_cost)
        return steps_operator_cost

    def _merge_communication_info_by_step_num(self, communication_info: list):
        """
        According to step num to merge communication info.

        Args:
            communication_info (list): Entries like [step_num, communication_cost, wait_cost, link_info].

        Returns:
            list, one merged entry per step num, sorted by step num in ascending numeric order.
        """
        steps_communication_info = list()
        info_set = set()
        for item in communication_info:
            # index0:step_num,index1:communication_cost,index2:communication_wait_cost,index3:link_info
            if item[0].isdigit():
                info_set.add(int(item[0]))
        info_set = sorted(info_set)
        for item in info_set:
            item = str(item)
            step_communication_info = [info for info in communication_info if info[0] == item]
            step_communication_cost = sum([i[1] for i in step_communication_info])
            step_communication_wait_cost = sum([i[2] for i in step_communication_info])
            step_communication_link = self._calculate_link_value([i[3] for i in step_communication_info], "total")
            steps_communication_info.append([item, step_communication_cost,
                                             step_communication_wait_cost, step_communication_link])
        return steps_communication_info

    def _calculate_communication_operator_iter_cost(self, file_path):
        """
        Calculate the time-consuming of communication operator in one execution round.

        Args:
            file_path (str): The json file of one execution round, which holds a "traceEvents" list.

        Returns:
            list, [step_id, communication_cost, wait_cost, link_info]. The communication cost and wait cost
            come from the mainstream only, the link info comes from all streams.

        Raises:
            ProfilerRawFileException: If the file is not valid json or holds no trace events.
        """

        def _inner_calculate_communication_operator_iter_cost(events):
            """Return [communication_cost, wait_cost, src_dst_link_info] of the given events."""
            total_notify_wait = HcclParser._calculate_notify_wait_time(events)
            # Divide information by src dst rank_id.
            src_dst_dict = self._divide_communication_info_by_src_dst_rank(events)
            src_dst_link_info = self._calculate_src_dst_link_info(src_dst_dict)
            communication_cost, communication_wait = self._calculate_device_communication_cost(src_dst_link_info)
            # The wait time that is already counted as communication time is excluded.
            total_notify_wait -= communication_wait
            return [communication_cost, total_notify_wait, src_dst_link_info]

        file_path = self._validate_file_path(file_path)
        with open(file_path, 'r') as src_file:
            try:
                operator_info = json.load(src_file)
            except (json.JSONDecodeError, TypeError) as err:
                logger.warning(err)
                raise ProfilerRawFileException('Fail to parse operator file.') from err

        trace_events = operator_info.get("traceEvents")
        if not trace_events:
            # Indexing a missing or empty event list would fail with an unrelated TypeError/IndexError.
            logger.warning('The operator file <%s> has no trace events.', file_path)
            raise ProfilerRawFileException('Fail to parse operator file.')
        operator_timestamp = trace_events[0].get("ts", 0)
        step_id = self._calculate_the_step_by_timestamp(operator_timestamp)
        # Statistics of communication operators in all streams.
        total_communication_operator_iter_cost = \
            _inner_calculate_communication_operator_iter_cost(trace_events)
        # Statistics of communication operators in mainstream.
        threads_dict = self._divide_communication_info_by_thread(trace_events)
        # The largest value is mainstream.
        major_thread = max(threads_dict)
        major_thread_trace_events = threads_dict.get(major_thread)
        mainstream_communication_operator_iter_cost = \
            _inner_calculate_communication_operator_iter_cost(major_thread_trace_events)
        # index0:communication_cost,index1:communication_wait_cost,index2:link_info
        return [step_id, mainstream_communication_operator_iter_cost[0],
                mainstream_communication_operator_iter_cost[1],
                total_communication_operator_iter_cost[2]]

    @staticmethod
    def _divide_communication_info_by_thread(trace_events: list):
        """
        Divide information by thread.

        Args:
            trace_events (list): The trace events.

        Returns:
            dict, thread id ("tid") -> list of events of the thread, in the original order.
        """
        threads_dict = dict()
        for item in trace_events:
            threads_dict.setdefault(item.get("tid"), []).append(item)
        return threads_dict

    def _divide_communication_info_by_src_dst_rank(self, trace_event: list):
        """
        Divide information by src rank id and dst rank id.

        Events without a src rank or a dst rank are skipped. The "args" of LOCAL events are updated
        in place.

        Args:
            trace_event (list): The trace events.

        Returns:
            dict, "{src}-{dst}" key -> list of events.
        """
        src_dst_dict = dict()
        for item in trace_event:
            src_rank = item.get("args").get("src rank")
            dst_rank = item.get("args").get("dst rank")
            if src_rank is None or dst_rank is None:
                continue

            # When the SDMA operation is in the card,
            # the source card or destination card is 0xffffffff, and it needs to be converted to localrank.
            if int(src_rank) == int('0xffffffff', 16):
                src_rank = dst_rank

            if int(dst_rank) == int('0xffffffff', 16):
                dst_rank = src_rank

            if item.get("args").get("transport type") == CommunicationInfo.LOCAL.value:
                # For LOCAL events the ranks are swapped, both in the event and in the key.
                item["args"]["src rank"] = dst_rank
                item["args"]["dst rank"] = src_rank
                src_dst_key = str(dst_rank) + '-' + str(src_rank)
            else:
                src_dst_key = str(src_rank) + '-' + str(dst_rank)

            src_dst_dict.setdefault(src_dst_key, []).append(item)
        return src_dst_dict

    def _divide_communication_info_by_link_type(self, trace_event: list):
        """
        Divide information by link type.

        Only RDMA and SDMA keys are created. LOCAL events are appended to the RDMA list, and only when
        that list already exists.

        Args:
            trace_event (list): The trace events of one src-dst pair.

        Returns:
            dict, link type -> list of events.
        """
        link_type_dict = dict()
        for item in trace_event:
            link_type_key = item.get("args").get("transport type")
            if link_type_key is None:
                continue
            if link_type_key in (CommunicationInfo.RDMA.value, CommunicationInfo.SDMA.value):
                task_type = item.get("args").get("task type")
                # Filter out the Notify Record operator in SDMA, because it does not transmit the actual amount of data.
                if task_type == CommunicationInfo.NOTIFY_RECORD.value:
                    continue
                if link_type_dict.get(link_type_key):
                    link_type_dict[link_type_key].append(item)
                else:
                    link_type_dict[link_type_key] = [item]
            if link_type_key == CommunicationInfo.LOCAL.value:
                if link_type_dict.get(CommunicationInfo.RDMA.value):
                    link_type_dict[CommunicationInfo.RDMA.value].append(item)
        return link_type_dict

    def _calculate_device_communication_cost(self, src_dst_link_info: dict):
        """
        Calculate the total communication time and the total wait time of all links.

        Args:
            src_dst_link_info (dict): "{src}-{dst}" key -> {link type: link value list}.

        Returns:
            tuple, (total_communication_time, total_wait_time).
        """
        total_communication_time = 0
        total_wait_time = 0
        for src_dst_value in src_dst_link_info.values():
            for link_type_value in src_dst_value.values():
                # time_cost:0,size_cost:1,brand_width:2,wait_time:3
                total_communication_time += link_type_value[0]
                # Only link values with four items (built for RDMA) carry a wait time.
                if len(link_type_value) > 3:
                    total_wait_time += link_type_value[3]
        return total_communication_time, total_wait_time

    def _parse_link_cost(self, result_dict, key, link_type_dict):
        """
        Parse link cost.

        Args:
            result_dict (dict): Output dict, result_dict[key] is filled in place with link type -> link value.
            key (str): The "{src}-{dst}" key.
            link_type_dict (dict): Link type -> list of events.
        """
        for link_type_key, link_type_value in link_type_dict.items():
            if link_type_key == CommunicationInfo.RDMA.value:
                # Divide information by thread.
                threads_dict = self._divide_communication_info_by_thread(link_type_value)
                rdma_infos = [self._calculate_adma_link_info(thread_value)
                              for thread_value in threads_dict.values()]
                # The values of all threads are added up element by element.
                rdma_total_cost = np.sum(rdma_infos, axis=0).tolist()
                result_dict[key][link_type_key] = rdma_total_cost
            if link_type_key == CommunicationInfo.SDMA.value:
                sdma_total_cost = self._calculate_sdma_link_info(link_type_value)
                result_dict[key][link_type_key] = sdma_total_cost

    def _calculate_src_dst_link_info(self, src_dst_dict: dict):
        """
        Calculate src dst link info.

        Args:
            src_dst_dict (dict): "{src}-{dst}" key -> list of events.

        Returns:
            dict, "{src}-{dst}" key -> {link type: link value list}. Keys without RDMA or SDMA events
            are left out.
        """
        result_dict = dict()
        for key, value in src_dst_dict.items():
            # Divide information by link type.
            link_type_dict = self._divide_communication_info_by_link_type(value)
            if not link_type_dict:
                continue
            result_dict[key] = dict()
            self._parse_link_cost(result_dict, key, link_type_dict)
        return result_dict

    @staticmethod
    def _parse_size(size):
        """Convert the "size" of an event to int. A non-int value is parsed as hex, an empty value is 0."""
        if not size:
            return 0
        return size if isinstance(size, int) else int(size, 16)

    @staticmethod
    def _calculate_adma_link_info(trace_event: list):
        """
        Calculate RDMA link info.

        When the link is RDMA, it is necessary to match three consecutive operators RDMASend, RDMASend
        and Notify Wait, and take the sum of the time of the three operators as one communication time.

        Args:
            trace_event (list): The RDMA events of one thread.

        Returns:
            list, [communication_time, communication_size, bandwidth, wait_time].
        """
        rdma_communication_time = 0
        rdma_communication_size = 0
        rdma_communication_wait_time = 0
        start_index = 0
        end_index = len(trace_event) - 1
        while start_index < end_index:
            first_task_type = trace_event[start_index].get("args").get("task type")
            # start_index < end_index - 1 guarantees that the two following events exist.
            if first_task_type == CommunicationInfo.RDMASEND.value and start_index < end_index - 1:
                second_task_type = trace_event[start_index + 1].get("args").get("task type")
                third_task_type = trace_event[start_index + 2].get("args").get("task type")
                if second_task_type == CommunicationInfo.RDMASEND.value and \
                        third_task_type == CommunicationInfo.NOTIFY_WAIT.value:
                    rdma_send_cost = trace_event[start_index].get("dur", 0)
                    notify_record_cost = trace_event[start_index + 1].get("dur", 0)
                    notify_wait_cost = trace_event[start_index + 2].get("dur", 0)
                    rdma_communication_time += rdma_send_cost + notify_record_cost + notify_wait_cost
                    rdma_communication_wait_time += notify_wait_cost
                    rdma_size = HcclParser._parse_size(trace_event[start_index].get("args").get("size"))
                    notify_record_size = HcclParser._parse_size(
                        trace_event[start_index + 1].get("args").get("size"))
                    rdma_communication_size += rdma_size + notify_record_size
                    # Skip the two operators that are already consumed by this match.
                    start_index += 2
            start_index += 1

        # The unit of rdma_communication_wait_time is ms.
        # The unit of rdma_bandwidth is KB/s.
        # The unit of rdma_communication_size is k_byte and The unit of rdma_communication_time is ms.
        rdma_communication_wait_time = rdma_communication_wait_time / 1e3
        rdma_communication_size = rdma_communication_size / 1e3
        rdma_communication_time = rdma_communication_time / 1e3
        # The time is checked as well, a non-zero size with zero duration must not divide by zero.
        rdma_bandwidth = rdma_communication_size / (rdma_communication_time / 1e3) \
            if rdma_communication_size and rdma_communication_time else 0

        return [rdma_communication_time, rdma_communication_size, rdma_bandwidth, rdma_communication_wait_time]

    def _calculate_sdma_link_info(self, trace_event: list):
        """
        Calculate SDMA link info.

        When the link is SDMA, the communication time of the primary link is the sum of the execution time
        of Reduce inline and Memcpy operators.

        Args:
            trace_event (list): The SDMA events.

        Returns:
            list, [communication_time, communication_size, bandwidth].
        """
        sdma_communication_time = 0
        sdma_communication_size = 0

        for item in trace_event:
            task_type = item.get("args").get("task type")
            if task_type in (CommunicationInfo.REDUCE_INLINE.value, CommunicationInfo.MEMCPY.value):
                sdma_communication_time += item.get("dur", 0)
                sdma_communication_size += self._parse_size(item.get("args").get("size"))

        # The unit of sdma_bandwidth is KB/s.
        # The unit of sdma_communication_size is k_byte and The unit of sdma_communication_time is ms.
        sdma_communication_time = sdma_communication_time / 1e3
        sdma_communication_size = sdma_communication_size / 1e3
        # The time is checked as well, a non-zero size with zero duration must not divide by zero.
        sdma_bandwidth = sdma_communication_size / (sdma_communication_time / 1e3) \
            if sdma_communication_size and sdma_communication_time else 0
        return [sdma_communication_time, sdma_communication_size, sdma_bandwidth]

    @staticmethod
    def _calculate_notify_wait_time(trace_event: list):
        """
        Calculate notify wait time.

        Args:
            trace_event (list): The trace events.

        Returns:
            float, the sum of "dur" of all Notify Wait events, divided by 1e3.
        """
        total_notify_wait_time = 0
        for item in trace_event:
            task_type = item.get("args").get("task type")
            if task_type == CommunicationInfo.NOTIFY_WAIT.value:
                total_notify_wait_time += item.get("dur", 0)
        # The unit of total_notify_wait_time is ms.
        total_notify_wait_time = total_notify_wait_time / 1e3
        return total_notify_wait_time

    def _calculate_communication_average_value(self, communication_info: list):
        """
        Calculate communication average value.

        Args:
            communication_info (list): Entries like [step_num, communication_cost, wait_cost, link_info, ...].

        Returns:
            list, [communication_cost_average, wait_cost_average, link_average_info], or an empty list
            when communication_info is empty.
        """
        communication_info_size = len(communication_info)
        if communication_info_size == 0:
            return []
        # index1: communication_cost,index2:wait_cost,index3:link_info
        communication_cost_average = sum([i[1] for i in communication_info]) / communication_info_size
        wait_cost_average = sum([i[2] for i in communication_info]) / communication_info_size
        link_info = [i[3] for i in communication_info]
        calculate_type = 'average'
        link_average_info = HcclParser._calculate_link_value(link_info, calculate_type)
        return [communication_cost_average, wait_cost_average, link_average_info]

    @staticmethod
    def _parser_link_dict(result_dict, src_dst_key, src_dst_value):
        """
        Parser link info to dict.

        Args:
            result_dict (dict): Output dict, src-dst key -> {link type: list of link value lists}.
            src_dst_key (str): The "{src}-{dst}" key.
            src_dst_value (dict): Link type -> link value list, appended to result_dict in place.
        """
        if src_dst_key not in result_dict.keys():
            result_dict[src_dst_key] = dict()
        for link_key, link_value in src_dst_value.items():
            if link_key not in result_dict[src_dst_key].keys():
                result_dict[src_dst_key][link_key] = list()
            result_dict[src_dst_key][link_key].append(link_value)

    @staticmethod
    def _calculate_link_value(link_info: list, calculate_type):
        """
        Calculate link average or total value.

        Args:
            link_info (list): A list of dicts like {src-dst key: {link type: link value list}}.
            calculate_type (str): 'average' or 'total'.

        Returns:
            dict, src-dst key -> {link type: element-wise average or total of the link value lists}.
        """
        result_dict = dict()
        for item in link_info:
            for src_dst_key, src_dst_value in item.items():
                HcclParser._parser_link_dict(result_dict, src_dst_key, src_dst_value)
        for src_dst_key, src_dst_value in result_dict.items():
            for link_key in src_dst_value:
                if calculate_type == 'average':
                    result_dict[src_dst_key][link_key] = np.mean(result_dict[src_dst_key][link_key], axis=0).tolist()
                if calculate_type == 'total':
                    result_dict[src_dst_key][link_key] = np.sum(result_dict[src_dst_key][link_key], axis=0).tolist()

        return result_dict

    def _validate_file_path(self, file_path):
        """
        Validate file path.

        Args:
            file_path (str): The file path.

        Returns:
            str, the normalized file path.

        Raises:
            ProfilerPathErrorException: If the file path is invalid.
            ProfilerFileNotFoundException: If the file does not exist.
        """
        try:
            file_path = validate_and_normalize_path(file_path)
        except RuntimeError as err:
            logger.warning('file path is invalid.')
            raise ProfilerPathErrorException('file path is invalid.') from err
        if not os.path.isfile(file_path):
            logger.warning('The file <%s> not found.', file_path)
            raise ProfilerFileNotFoundException(file_path)
        return file_path

    def _validate_dir_path(self, dir_path):
        """
        Validate dir path.

        Args:
            dir_path (str): The dir path.

        Returns:
            str, the normalized dir path.

        Raises:
            ProfilerPathErrorException: If the dir path is invalid.
            ProfilerDirNotFoundException: If the dir does not exist.
        """
        try:
            dir_path = validate_and_normalize_path(dir_path)
        except RuntimeError as err:
            logger.warning('dir path is invalid.')
            raise ProfilerPathErrorException('dir path is invalid.') from err
        if not os.path.isdir(dir_path):
            logger.warning('The dir <%s> not found.', dir_path)
            raise ProfilerDirNotFoundException(dir_path)
        return dir_path