diff --git a/mindspore/profiler/parser/integrator.py b/mindspore/profiler/parser/integrator.py
index 3de7d7e65b..3c019a2b22 100644
--- a/mindspore/profiler/parser/integrator.py
+++ b/mindspore/profiler/parser/integrator.py
@@ -498,14 +498,15 @@ class BaseTimelineGenerator:
     """
     __col_names__ = ['op_name', 'stream_id', 'start_time', 'duration']
     _output_timeline_data_file_path = 'output_timeline_data_{}.txt'
-    _min_cycle_counter_file_path = 'min_cycle_counter_{}.txt'
     _timeline_meta = []
     _timeline_summary = {
         'total_time': 0,
         'num_of_streams': 0,
         'num_of_ops': 0,
-        'op_exe_times': 0
+        'op_exe_times': 0,
+        'max_scope_name_num': 0,
     }
+    _op_name_idx, _tid_idx, _start_time_idx, _duration_idx = 0, 1, 2, 3

     def _load_timeline_data(self):
         """Load timeline data from file."""
@@ -576,67 +577,117 @@ class BaseTimelineGenerator:
             else:
                 stream_count_dict[stream_id] += 1

-    def get_min_cycle_counter(self):
-        """
-        Get minimum cycle counter.
-
-        Returns:
-            float, the minimum value of the cycle counter.
-        """
-        file_path = os.path.join(
-            self._profiling_dir,
-            self._min_cycle_counter_file_path.format(self._device_id)
-        )
-
-        file_path = validate_and_normalize_path(file_path)
-
-        if os.path.exists(file_path):
-            try:
-                with open(file_path, 'r') as f_obj:
-                    min_cycle_counter = f_obj.read()
-                    min_cycle_counter = float(min_cycle_counter) \
-                        if not min_cycle_counter == 'inf' else 0
-            except (IOError, OSError) as err:
-                logger.error('Error occurred when read minimum cycle counter: %s', err)
-                raise ProfilerIOException
-        else:
-            min_cycle_counter = 0
-            logger.info("No min cycle counter recorded.")
-
-        return min_cycle_counter
-
-    def _add_framework_info(self, framework_obj_list):
-        """
-        Add framework info into timeline metadata.
-
-        Args:
-            framework_obj_list (list): The framework metadata.
-        """
-        logger.debug('Start adding framework info into timeline...')
-        # Get the framework info that will be written into timeline.
-        framework_info_dict = {}
-        for framework_obj in framework_obj_list:
-            op_name = framework_obj[0]
-            op_type = framework_obj[1]
-            op_full_name = framework_obj[4]
-            op_info = framework_obj[5]
-            framework_info_dict[op_full_name] = {
-                'name': op_name,
-                'args': {
-                    'type': op_type,
-                    'fullname': op_full_name
-                }
-            }
-            framework_info_dict[op_full_name]['args'].update(op_info)
-
-        # Insert framework info into timeline.
-        for timeline_item in self._timeline_meta:
-            op_full_name = timeline_item.get('name')
-            framework_item = framework_info_dict.get(op_full_name)
-            if framework_item:
-                timeline_item['name'] = framework_item.get('name')
-                timeline_item['args'] = framework_item.get('args')
-        logger.debug('Finished adding framework info into timeline...')
+    def _get_max_scope_name_num(self, timeline_list):
+        """Get the max number of scope levels among all operators."""
+        max_scope_name_num = 0
+        for time_item in timeline_list:
+            cur_scope_name_num = len(time_item[self._op_name_idx].split('/')) - 1
+            max_scope_name_num = max(cur_scope_name_num, max_scope_name_num)
+
+        return max_scope_name_num
+
+    def _get_scope_name_time_list(self, timeline_list, subgraph, factor_start_time_to_duration=1):
+        """Produce the timeline of hierarchical scope names."""
+        # The key of scope_name_start_duration_dict is the scope name; the value is a dict which stores the
+        # start and end index of the time_item in timeline_list.
+        scope_name_start_duration_dict = {}
+        scope_name_time_list = []
+        op_full_name_idx, scope_name_idx, invalid_idx = 0, 0, -1
+        tid = "Name Scope"
+        for idx, time_item in enumerate(timeline_list):
+            scope_name_list = time_item[op_full_name_idx].split('/')[:-1]
+            # Skip the Default/InitDataSetQueue operator.
+            if time_item[op_full_name_idx].startswith("Default/InitDataSetQueue"):
+                scope_name_list = []
+            # Process only the scope names of the given subgraph (Default/Gradients/recompute_Default).
+            if scope_name_list and scope_name_list[0] != subgraph:
+                scope_name_list = []
+            # Add the level to the scope name, used to distinguish the same name at different scope levels.
+            scope_name_list = [f"{scope_level}-{scope_name}"
+                               for scope_level, scope_name in enumerate(scope_name_list)]
+
+            # Update the start and end index of the time_item according to the current scope_name.
+            for scope_name in scope_name_list:
+                init_start_end_idx_dict = {'start_item_idx': idx, 'end_item_idx': idx}
+                if scope_name not in scope_name_start_duration_dict:
+                    scope_name_start_duration_dict[scope_name] = init_start_end_idx_dict
+                if scope_name_start_duration_dict[scope_name]['start_item_idx'] == invalid_idx:
+                    scope_name_start_duration_dict[scope_name] = init_start_end_idx_dict
+                else:
+                    scope_name_start_duration_dict[scope_name]['end_item_idx'] = idx
+            # If a key (scope name) in scope_name_start_duration_dict does not appear in scope_name_list,
+            # it means this scope has ended and its item is appended to scope_name_time_list.
+            for key, val in scope_name_start_duration_dict.items():
+                if val['start_item_idx'] == invalid_idx:
+                    continue
+                if (key not in scope_name_list) \
+                        or idx == (len(timeline_list) - 1) \
+                        or time_item[op_full_name_idx] == self._step_end_op_name:
+                    start_time = timeline_list[val['start_item_idx']][self._start_time_idx]
+                    duration = (float(timeline_list[val['end_item_idx']][self._start_time_idx]) - float(start_time)) * \
+                        factor_start_time_to_duration + float(timeline_list[val['end_item_idx']][self._duration_idx])
+                    scope_name_time_item = [key, tid, start_time, duration]
+                    scope_name_time_list.append(scope_name_time_item)
+                    scope_name_start_duration_dict[key]['start_item_idx'] = invalid_idx
+
+        # x[scope_name_idx] is a scope name like "0-Default".
+        # If two elements in scope_name_time_list have the same start time,
+        # the earlier element in the list will be displayed at a higher line on the UI page.
+        scope_name_time_list.sort(key=lambda x: (float(x[self._start_time_idx]),
+                                                 x[scope_name_idx].split('-')[0]))
+
+        return scope_name_time_list
+
+    def _set_step_start_and_end_op_name(self, timeline_list):
+        """Set the start and end operator full name of each step."""
+        if not timeline_list:
+            return
+        start_op_idx = 0
+        if timeline_list[0][self._op_name_idx].startswith("Default/InitDataSetQueue"):
+            start_op_idx = 1
+        self._step_start_op_name = timeline_list[start_op_idx][self._op_name_idx]
+        self._step_end_op_name = self._step_start_op_name
+        if len(timeline_list) > (start_op_idx + 1):
+            for time_item in timeline_list[start_op_idx + 1:]:
+                if time_item[self._op_name_idx] != self._step_start_op_name:
+                    self._step_end_op_name = time_item[self._op_name_idx]
+                else:
+                    break
+
+    def _get_step_time_list(self, timeline_list, factor_start_time_to_duration=1):
+        """Produce the time of each step."""
+        # Record the time of each step.
+        step_time_list = []
+        step_num = 1
+        tid = "Steps"
+        cur_step_start_time, cur_step_duration_time = 0, 0
+        for time_item in timeline_list:
+            if time_item[self._op_name_idx] == self._step_start_op_name:
+                cur_step_start_time = time_item[self._start_time_idx]
+            if time_item[self._op_name_idx] == self._step_end_op_name:
+                cur_step_duration_time = (float(time_item[self._start_time_idx]) - float(cur_step_start_time)) * \
+                    factor_start_time_to_duration + float(time_item[self._duration_idx])
+                step_time_item = [str(step_num), tid, float(cur_step_start_time), cur_step_duration_time]
+                step_time_list.append(step_time_item)
+                step_num += 1
+
+        return step_time_list
+
+    def _adjust_timeline_arrange(self, timeline_list):
+        """Place the step time and scope name items at the front of the timeline list."""
+        first_step_time_idx = 0
+        first_scope_time_idx = 0
+        for idx, time_item in enumerate(timeline_list):
+            if time_item[self._tid_idx] == "Steps" and first_step_time_idx == 0:
+                first_step_time_idx = idx
+            if time_item[self._tid_idx] == "Name Scope" and first_scope_time_idx == 0:
+                first_scope_time_idx = idx
+            if first_scope_time_idx and first_step_time_idx:
+                break
+        first_scope_time_item = timeline_list.pop(first_scope_time_idx)
+        timeline_list.insert(0, first_scope_time_item)
+        first_step_time_item = timeline_list.pop(first_step_time_idx)
+        timeline_list.insert(0, first_step_time_item)


 class GpuTimelineGenerator(BaseTimelineGenerator):
     """Generate gpu Timeline data from file."""
@@ -651,12 +702,7 @@ class GpuTimelineGenerator(BaseTimelineGenerator):
         self._profiling_dir = profiling_dir
         self._device_id = device_id
         self._timeline_meta = []
-        self._timeline_summary = {
-            'total_time': 0,
-            'num_of_streams': 0,
-            'num_of_ops': 0,
-            'op_exe_times': 0
-        }
+        self._max_scope_name_num = 0

     def _get_and_validate_path(self, file_name):
         """Generate op or activity file path from file name, and validate this path."""
@@ -683,10 +729,12 @@ class GpuTimelineGenerator(BaseTimelineGenerator):
             timeline_dict['ts'] = (op_meta.start_time - min_cycle_counter) / factor
             dur = op_meta.duration
             timeline_dict['dur'] = dur
-            if op_meta.pid is None:
-                timeline_dict['pid'] = int(self._device_id)
-            else:  # AllReduce and AI CPU pid
-                timeline_dict['pid'] = op_meta.pid
+            timeline_dict['pid'] = int(self._device_id)
+            if op_meta.stream_id == "Name Scope":
+                # remove the level of scope name which has a format like "0-conv2-Conv2d".
+                timeline_dict['name'] = "-".join(op_meta.op_name.split('-')[1:])
+                timeline_dict['scope_level'] = op_meta.op_name.split('-')[0]
+
             if len(timeline) > 4:
                 # len(timeline) > 4 refers to activity data, else op data.
                 # Add args for activity data
@@ -694,6 +742,7 @@ class GpuTimelineGenerator(BaseTimelineGenerator):
                 for ix, value in enumerate(timeline[4:]):
                     args_dict[self._activity_keys_list[ix]] = value
                 timeline_dict['args'] = args_dict
+                timeline_dict['tid'] = f"Stream #{timeline_dict['tid']}"
             else:
                 # Update total time of operator execution.
                 self._timeline_summary['total_time'] += dur / factor
@@ -710,14 +759,43 @@ class GpuTimelineGenerator(BaseTimelineGenerator):
         activity_args_file_path = self._get_and_validate_path(
             self._output_gpu_activity_info_file_path)

-        timeline_list = self._load_op_data(op_file_path) + \
-            self._load_activity_data(activity_file_path, activity_args_file_path)
+        timeline_list = self._load_op_data(op_file_path)
+
+        # Add host cpu op timeline.
         cpu_timeline_generator = CpuTimelineGenerator(self._profiling_dir, self._device_id)
         cpu_timeline_list = cpu_timeline_generator.load_cpu_op_data()
         if cpu_timeline_list:
             self._clock_synchronize_to_gpu(cpu_timeline_list)
             timeline_list.extend(cpu_timeline_list)
             timeline_list.sort(key=lambda x: float(x[2]))
+        self._max_scope_name_num = self._get_max_scope_name_num(timeline_list)
+        self._timeline_summary['max_scope_name_num'] = self._max_scope_name_num
+
+        # Generate step time.
+        factor_start_time_uint_to_duration = 1e-3
+        self._set_step_start_and_end_op_name(timeline_list)
+        step_time_list = self._get_step_time_list(timeline_list, factor_start_time_uint_to_duration)
+        # Add Scope Name.
+        default_scope_name_time_list = self._get_scope_name_time_list(timeline_list, "Default",
+                                                                      factor_start_time_uint_to_duration)
+        gradient_scope_name_time_list = self._get_scope_name_time_list(timeline_list, "Gradients",
+                                                                       factor_start_time_uint_to_duration)
+        recompute_scope_name_time_list = self._get_scope_name_time_list(timeline_list, "recompute_Default",
+                                                                        factor_start_time_uint_to_duration)
+        timeline_list.extend(default_scope_name_time_list)
+        timeline_list.extend(gradient_scope_name_time_list)
+        timeline_list.extend(recompute_scope_name_time_list)
+        timeline_list.extend(step_time_list)
+
+        timeline_list.sort(key=lambda x: (float(x[self._start_time_idx]), x[self._tid_idx]))
+
+        # Add cuda activity timeline.
+        activity_timeline_list = self._load_activity_data(activity_file_path, activity_args_file_path)
+        timeline_list.extend(activity_timeline_list)
+        timeline_list.sort(key=lambda x: float(x[2]))
+
+        # In order to show the steps at the top of the timeline, place the step time at the front of the list.
+        self._adjust_timeline_arrange(timeline_list)

         return timeline_list

@@ -726,18 +804,19 @@ class GpuTimelineGenerator(BaseTimelineGenerator):
         start_time_file_path = os.path.join(self._profiling_dir, f"start_time_{self._device_id}.txt")

         try:
-            with open(start_time_file_path) as f:
-                lines = f.readlines()
+            with open(start_time_file_path) as f_obj:
+                lines = f_obj.readlines()
+                # lines[0] stores the host monotonic time at the start of training.
                 host_monotonic_start_time = int(lines[0].strip().split(':')[-1])
+                # lines[1] stores the gpu time at the start of training.
gpu_start_time = int(lines[1].strip().split(':')[-1]) except (IOError, OSError) as err: logger.error(f'Error occurred when read {start_time_file_path}: {err}') raise ProfilerIOException time_diff = gpu_start_time - host_monotonic_start_time - start_time = 2 for idx, time_item in enumerate(timeline_list): - timeline_list[idx][start_time] = int(time_item[start_time]) + time_diff + timeline_list[idx][self._start_time_idx] = int(time_item[self._start_time_idx]) + time_diff def _load_op_data(self, op_file_path): """Load operator data from file""" @@ -825,6 +904,7 @@ class AscendTimelineGenerator(BaseTimelineGenerator): def __init__(self, profiling_dir, device_id): self._profiling_dir = profiling_dir self._device_id = device_id + self._max_scope_name_num = 0 def _load_timeline_data(self): """Load timeline data from file.""" @@ -843,6 +923,7 @@ class AscendTimelineGenerator(BaseTimelineGenerator): for line in f_obj: if not line.startswith('op_name'): line_list = line.strip('\n').split(',') + line_list[self._tid_idx] = f"Stream #{line_list[self._tid_idx]}" timeline_list.append(line_list) except (IOError, OSError) as err: logger.error('Error occurred when read timeline intermediate file: %s', err) @@ -862,6 +943,11 @@ class AscendTimelineGenerator(BaseTimelineGenerator): timeline_dict['ts'] = (op_meta.start_time - min_cycle_counter) * factor dur = op_meta.duration * factor timeline_dict['dur'] = dur + if op_meta.stream_id == "Name Scope": + # remove the level of scope name which has a format like "0-conv2-Conv2d". + timeline_dict['name'] = "-".join(op_meta.op_name.split('-')[1:]) + timeline_dict['scope_level'] = op_meta.op_name.split('-')[0] + if op_meta.pid is None: timeline_dict['pid'] = int(self._device_id) # Update total time of operator execution. @@ -888,16 +974,33 @@ class AscendTimelineGenerator(BaseTimelineGenerator): cpu_timeline_generator = CpuTimelineGenerator(self._profiling_dir, self._device_id) cpu_timeline_list = cpu_timeline_generator.get_timeline_data() if cpu_timeline_list: - self._clock_synchronize_to_host(timeline_list, source_path) + self._clock_synchronize_to_device(cpu_timeline_list, source_path) timeline_list.extend(cpu_timeline_list) - timeline_list.sort(key=lambda x: float(x[2])) + timeline_list.sort(key=lambda x: float(x[self._start_time_idx])) + self._max_scope_name_num = self._get_max_scope_name_num(timeline_list) self._timeline_summary['op_exe_times'] = len(timeline_list) + self._timeline_summary['max_scope_name_num'] = self._max_scope_name_num + + # Generate step time. + self._set_step_start_and_end_op_name(timeline_list) + step_time_list = self._get_step_time_list(timeline_list) + + # Add Scope Name. + default_scope_name_time_list = self._get_scope_name_time_list(timeline_list, "Default") + gradient_scope_name_time_list = self._get_scope_name_time_list(timeline_list, "Gradients") + recompute_scope_name_time_list = self._get_scope_name_time_list(timeline_list, "recompute_Default") + timeline_list.extend(step_time_list) + timeline_list.extend(default_scope_name_time_list) + timeline_list.extend(recompute_scope_name_time_list) + timeline_list.extend(gradient_scope_name_time_list) + + timeline_list.sort(key=lambda x: (float(x[self._start_time_idx]), x[self._tid_idx])) # Add AllReduce info to timeline temp list and sort by start time. if all_reduce_info: logger.debug('AllReduce info found. 
             timeline_list.extend(all_reduce_info)
-            timeline_list.sort(key=lambda x: float(x[2]))
+            timeline_list.sort(key=lambda x: float(x[self._start_time_idx]))

         # Add AI CPU data into timeline temp list and sort by start time.
         aicpu_data = aicpu_info.get('info')
@@ -909,6 +1012,9 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
             self._timeline_summary['num_of_ops'] += aicpu_info.get('num_of_ops', 0)
             self._timeline_summary['total_time'] += aicpu_info.get('total_time', 0)

+        # In order to show the steps at the top of the timeline, place the step time at the front of the list.
+        self._adjust_timeline_arrange(timeline_list)
+
         # Init a dict for counting the num of streams.
         stream_count_dict = {}
         for timeline in timeline_list:
@@ -927,37 +1033,71 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
         # Update timeline summary info
         self._timeline_summary['num_of_streams'] += len(stream_count_dict.keys())

-    def _clock_synchronize_to_host(self, timeline_list, source_path):
-        """Synchronize the timestamp from device to host."""
+    def _clock_synchronize_to_device(self, timeline_list, source_path):
+        """Synchronize the timestamp from host to device."""
         host_start_file_path = os.path.join(source_path, f"host_start.log.{self._device_id}")
         dev_start_file_path = os.path.join(source_path, f"dev_start.log.{self._device_id}")

         try:
-            with open(host_start_file_path) as f:
-                lines = f.readlines()
+            with open(host_start_file_path) as f_obj:
+                lines = f_obj.readlines()
+                # lines[2] stores the host monotonic_raw time at the start of training.
                 host_monotonic = int(lines[2].strip().split(':')[1])
         except (IOError, OSError) as err:
             logger.error('Error occurred when read host_start.log: %s', err)
             raise ProfilerIOException

         try:
-            with open(dev_start_file_path) as f:
-                lines = f.readlines()
+            with open(dev_start_file_path) as f_obj:
+                lines = f_obj.readlines()
+                # lines[2] stores the device cycle counter at the start of training.
                 dev_cntvct = int(lines[2].strip().split(':')[1])
         except (IOError, OSError) as err:
             logger.error('Error occurred when read dev_start.log: %s', err)
             raise ProfilerIOException

-        factor_ns_to_ms = 1e6
-        factor_ms_to_ten_ns = 1e5
+        factor_ns_to_ms = 1e-6
         factor_ten_ns_to_ns = 10
-        start_time = 2
+        factor_ms_to_ns = 1e6
         for idx, time_item in enumerate(timeline_list):
-            cycle_counter = int(float(time_item[start_time]) * factor_ms_to_ten_ns)
-            host_monotonic_time = host_monotonic + (cycle_counter - dev_cntvct) * factor_ten_ns_to_ns
-            timeline_list[idx][start_time] = host_monotonic_time / factor_ns_to_ms
+            host_time = int(float(time_item[self._start_time_idx]) * factor_ms_to_ns)
+            device_time = dev_cntvct * factor_ten_ns_to_ns + (host_time - host_monotonic)
+            timeline_list[idx][self._start_time_idx] = device_time * factor_ns_to_ms
+
+    def _add_framework_info(self, framework_obj_list):
+        """
+        Add framework info into timeline metadata.
+
+        Args:
+            framework_obj_list (list): The framework metadata.
+        """
+        logger.debug('Start adding framework info into timeline...')
+        # Get the framework info that will be written into timeline.
+        framework_info_dict = {}
+        for framework_obj in framework_obj_list:
+            op_name = framework_obj[0]
+            op_type = framework_obj[1]
+            op_full_name = framework_obj[4]
+            op_info = framework_obj[5]
+            framework_info_dict[op_full_name] = {
+                'name': op_name,
+                'args': {
+                    'type': op_type,
+                    'fullname': op_full_name
+                }
+            }
+            framework_info_dict[op_full_name]['args'].update(op_info)
+
+        # Insert framework info into timeline.
+        for timeline_item in self._timeline_meta:
+            op_full_name = timeline_item.get('name')
+            framework_item = framework_info_dict.get(op_full_name)
+            if framework_item:
+                timeline_item['name'] = framework_item.get('name')
+                timeline_item['args'] = framework_item.get('args')
+        logger.debug('Finished adding framework info into timeline...')


 class CpuTimelineGenerator(GpuTimelineGenerator):
-    """Generate gpu Timeline data from file."""
+    """Generate cpu Timeline data from file."""
     _output_op_execute_time_file_path = "cpu_op_execute_timestamp_{}.txt"

     def _get_and_validate_path(self, file_name):
@@ -979,6 +1119,9 @@ class CpuTimelineGenerator(GpuTimelineGenerator):
             logger.info("No cpu operator info.")
             return timeline_list
         timeline_list = self._load_op_data(op_file_path)
+        factor_ms_to_us = 1e-3
+        for time_item in timeline_list:
+            time_item[self._duration_idx] = float(time_item[self._duration_idx]) / factor_ms_to_us

         return timeline_list
@@ -986,11 +1129,9 @@ class CpuTimelineGenerator(GpuTimelineGenerator):
         """Get timeline data from file."""
         timeline_list = self.load_cpu_op_data()
         factor_ns_to_ms = 1e6
-        start_time = 2
-        duration = 3
-        for idx, time_item in enumerate(timeline_list):
-            time_item[start_time] = float(time_item[start_time]) / factor_ns_to_ms
-            time_item[duration] = float(time_item[duration])
-            timeline_list[idx] = time_item
+        factor_us_to_ms = 1e3
+        for time_item in timeline_list:
+            time_item[self._start_time_idx] = float(time_item[self._start_time_idx]) / factor_ns_to_ms
+            time_item[self._duration_idx] = float(time_item[self._duration_idx]) / factor_us_to_ms

         return timeline_list
diff --git a/mindspore/profiler/profiling.py b/mindspore/profiler/profiling.py
index 06c3b351dd..968aa0d57a 100644
--- a/mindspore/profiler/profiling.py
+++ b/mindspore/profiler/profiling.py
@@ -454,7 +454,8 @@ class Profiler:
             path, training_device_id, self._dev_id)
         if not job_id:
-            msg = "Fail to get profiling job, please check whether job dir was generated"
+            msg = "Fail to get profiling job, please check whether the job dir was generated, " \
+                  "or whether the device id from the job dir does not match the device_id in the current process."
             raise RuntimeError(msg)

         return job_id
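Note on the scope-name rows introduced in this patch: they are derived purely from operator full names (for example "Default/network/conv1/Conv2D-op1"), using the patch's [op_name, tid, start_time, duration] column order. The snippet below is a minimal, standalone Python sketch of that rollup idea for readers of the patch; it is not the patch's code. It merges every occurrence of a scope into a single span per scope level, whereas _get_scope_name_time_list closes a scope and emits a new row whenever the scope stops appearing in the sorted op stream (and at step boundaries). The op names and times in the example are made up.

def rollup_scope_rows(timeline, subgraph="Default"):
    """Aggregate ops sharing a scope prefix into one [scope, tid, start, duration] row per scope level."""
    spans = {}  # "level-scope" -> [earliest_start, latest_end]
    for op_full_name, _tid, start, dur in timeline:
        scope_names = op_full_name.split('/')[:-1]
        if not scope_names or scope_names[0] != subgraph:
            continue  # same filtering idea as the patch: handle one subgraph at a time
        begin, end = float(start), float(start) + float(dur)
        for level, scope in enumerate(scope_names):
            key = f"{level}-{scope}"  # level prefix disambiguates equal names at different depths
            if key not in spans:
                spans[key] = [begin, end]
            else:
                spans[key][0] = min(spans[key][0], begin)
                spans[key][1] = max(spans[key][1], end)
    return [[key, "Name Scope", s, e - s] for key, (s, e) in spans.items()]


ops = [
    ["Default/network/conv1/Conv2D-op1", "Stream #0", 0.0, 2.0],
    ["Default/network/conv1/BiasAdd-op2", "Stream #0", 2.5, 1.0],
    ["Default/network/relu/ReLU-op3", "Stream #0", 4.0, 0.5],
]
for row in rollup_scope_rows(ops):
    print(row)
# ['0-Default', 'Name Scope', 0.0, 4.5]
# ['1-network', 'Name Scope', 0.0, 4.5]
# ['2-conv1', 'Name Scope', 0.0, 3.5]
# ['2-relu', 'Name Scope', 4.0, 0.5]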