#!/usr/bin/env python3
# coding: utf-8
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""parsing_profiling_data"""
import logging
import os
import re
import shutil
import struct
import subprocess
from akg import tvm

logger = logging.getLogger(__name__)

OUTPUT_FORMAT_DATA = "./output_format_data_hwts.txt"
# Value returned by HWTSLogParser.execute when no kernel cycles were accumulated.
max_time_consume = 9999999999

# Matches the digits that follow "slice_" in a log file name.
_SLICE_ID_PATTERN = re.compile(r'(?<=slice_)\d+')


def get_log_slice_id(file_name):
    """Return the integer slice index encoded as ``slice_<N>`` in *file_name*.

    A name without a ``slice_<N>`` part is treated as slice 0, so an unsliced
    file can be sorted alongside sliced ones instead of raising IndexError.
    """
    match = _SLICE_ID_PATTERN.search(file_name)
    return int(match.group()) if match else 0


def get_file_join_name(input_path=None, file_name=None):
    """Function for getting join name from input path.

    Every entry of *input_path* whose name contains *file_name* (ignoring
    ``.done`` and ``.join`` files) is collected and ordered by slice index.
    A single match is returned directly; several matches are concatenated,
    in slice order, into ``<input_path>/<file_name>.join`` and that path is
    returned.

    Args:
        input_path (str): Directory to search.
        file_name (str): Substring identifying the wanted files.

    Returns:
        str, path of the single or joined file, or '' when *input_path* is
        not an existing directory or nothing matched.
    """
    if not input_path or not file_name or not os.path.isdir(input_path):
        return ''

    name_list = sorted(
        (entry for entry in os.listdir(input_path)
         if file_name in entry and not entry.endswith(('.done', '.join'))),
        key=get_log_slice_id,
    )
    if not name_list:
        return ''
    if len(name_list) == 1:
        return os.path.join(input_path, name_list[0])

    file_join_name = os.path.join(input_path, '%s.join' % file_name)
    # Drop any join file left over from an earlier run before rebuilding it.
    if os.path.exists(file_join_name):
        os.remove(file_join_name)
    with open(file_join_name, 'wb') as bin_data:
        for name in name_list:
            with open(os.path.join(input_path, name), 'rb') as part:
                # Stream each slice instead of loading it fully into memory.
                shutil.copyfileobj(part, bin_data)
    return file_join_name


def fwrite_format(output_data_path=OUTPUT_FORMAT_DATA, data_source=None, is_start=False):
    """Append *data_source* to the text file *output_data_path*.

    Args:
        output_data_path (str): Destination file.
        data_source (str | list | tuple): A string is written followed by a
            newline. A list/tuple is written one item per line; an item that
            is itself a list/tuple is written as its elements joined by spaces.
        is_start (bool): If True, delete *output_data_path* first so the
            output starts from an empty file.
    """
    # Reset the file actually being written (previously this always removed
    # OUTPUT_FORMAT_DATA, regardless of output_data_path).
    if is_start and os.path.exists(output_data_path):
        os.remove(output_data_path)
    with open(output_data_path, 'a+') as f:
        if isinstance(data_source, (list, tuple)):
            for raw_data in data_source:
                if isinstance(raw_data, (list, tuple)):
                    raw_data = " ".join(map(str, raw_data))
                f.write(raw_data + "\n")
        else:
            f.write(data_source + "\n")


def validate_and_normalize_path(
        path,
        check_absolute_path=False,
        allow_parent_dir=True,
):
    """
    Validate a local (unix-style) path and return its normalized form.

    Args:
        path (str): Path to be normalized.
        check_absolute_path (bool): If True, require the path to start with "/".
        allow_parent_dir (bool): If False, reject paths that contain a ".."
            component.

    Returns:
        str, the path resolved by ``os.path.realpath``.

    Raises:
        RuntimeError: If the path is empty, contains a disallowed "..", is not
            absolute when required, or cannot be resolved.
    """
    if not path:
        raise RuntimeError("The path is invalid!")

    path_str = str(path)
    if not allow_parent_dir and ".." in path_str.split("/"):
        raise RuntimeError("The parent path is not allowed!")

    if check_absolute_path and not path_str.startswith("/"):
        raise RuntimeError("The path is invalid!")

    try:
        normalized_path = os.path.realpath(path)
    except ValueError as err:
        raise RuntimeError("The path is invalid!") from err

    return normalized_path


class HWTSLogParser:
    """
    The Parser for hwts log files.

    Args:
        input_path (str): The profiling job path. Such as:
            '/var/log/npu/profiling/JOBAIFGJEJFEDCBAEADIFJAAAAAAAAAA'.
        output_filename (str): The output data path and name. Such as:
            './output_format_data_hwts_0.txt'. Only written when is_print is
            True; OUTPUT_FORMAT_DATA is used when it is not given.
        is_print (bool): Whether to dump every decoded record to output_filename.
    """

    _source_file_target_old = 'hwts.log.data.45.dev.profiler_default_tag'
    _source_file_target = 'hwts.data'
    _dst_file_title = 'title:45 HWTS data'
    _dst_file_column_title = 'Type cnt Core_ID Block_ID Task_ID Cycle_counter Stream_ID'

    def __init__(self, input_path, output_filename=None, is_print=False):
        self._input_path = input_path
        self._output_filename = output_filename
        self._source_flie_name = self._get_source_file()
        self._is_print = is_print

    def _get_source_file(self):
        """Get hwts log file name, which was created by ada service.

        Searches ``input_path`` and then ``input_path/data``. In each
        directory the current file name is tried before the legacy one.

        Raises:
            RuntimeError: If no hwts log file is found.
        """
        search_dirs = (self._input_path, os.path.join(self._input_path, "data"))
        for search_dir in search_dirs:
            for target in (self._source_file_target, self._source_file_target_old):
                file_name = get_file_join_name(search_dir, target)
                if file_name:
                    return file_name
        msg = "Fail to find hwts log file, under profiling directory"
        raise RuntimeError(msg)

    def execute(self):
        """
        Execute the parser, get result data, and write it to the output file.

        Each 64-byte record is decoded. For records whose
        ``<stream_id>_<task_id>`` label equals the label returned by the
        ``ascend_get_kernel_label`` global function, the syscnt distance
        from each "Start of task" to the next "End of task" is accumulated.

        Returns:
            int, the accumulated cycle count, or ``max_time_consume`` when no
            cycles were accumulated.
        """
        # struct layouts of the 56-byte payload after the 8-byte header:
        # [0] log types 0-2, [1] log type 3, [2] log type 4.
        content_format = ['QIIIIIIIIIIII', 'QIIQIIIIIIII', 'IIIIQIIIIIIII']
        log_type = ['Start of task', 'End of task', 'Start of block', 'End of block', 'Block PMU']
        record_size = 64
        result_lines = []
        self._source_flie_name = validate_and_normalize_path(self._source_flie_name)
        last_syscnt = 0
        cycles = 0
        kernel_label = tvm.get_global_func("ascend_get_kernel_label")()
        with open(self._source_flie_name, 'rb') as hwts_data:
            while True:
                line = hwts_data.read(record_size)
                if not line:
                    break
                if not line.strip():
                    continue
                if len(line) < record_size:
                    # A partial record can only be the tail of the file;
                    # unpacking it would raise struct.error.
                    logger.warning("Profiling: ignoring truncated hwts log record of %d bytes",
                                   len(line))
                    break

                header = struct.unpack('BBHHH', line[0:8])
                flags = header[0]
                # First header byte: high 4 bits are the count, bit 3 the
                # is_warn_res0_ov flag, low 3 bits the log (ms) type.
                cnt = flags >> 4
                is_warn_res0_ov = (flags >> 3) & 1
                type_id = flags & 0b111
                core_id = header[1]
                blk_id, task_id = header[3], header[4]

                payload = line[8:]
                if type_id <= 2:  # log type 0,1,2
                    result = struct.unpack(content_format[0], payload)
                    syscnt, stream_id = result[0], result[1]
                elif type_id == 3:  # log type 3
                    result = struct.unpack(content_format[1], payload)
                    syscnt, stream_id = result[0], result[1]
                elif type_id == 4:  # log type 4
                    result = struct.unpack(content_format[2], payload)
                    stream_id = result[2]
                    syscnt = result[4] if is_warn_res0_ov == 0 else None
                else:
                    logger.info("Profiling: invalid hwts log record type %s",
                                format(type_id, '03b'))
                    continue

                record_type = log_type[type_id]
                if kernel_label == '%s_%s' % (stream_id, task_id):
                    if record_type == "Start of task":
                        last_syscnt = syscnt
                    elif record_type == "End of task":
                        cycles += syscnt - last_syscnt
                if self._is_print:
                    result_lines.append("%-14s %-4s %-8s %-9s %-8s %-15s %s\n"
                                        % (record_type, cnt, core_id, blk_id, task_id,
                                           syscnt, stream_id))

        if self._is_print:
            output_filename = self._output_filename or OUTPUT_FORMAT_DATA
            fwrite_format(output_filename, data_source=self._dst_file_title, is_start=True)
            fwrite_format(output_filename, data_source=self._dst_file_column_title)
            fwrite_format(output_filename, data_source="".join(result_lines))
        return cycles if cycles != 0 else max_time_consume