Browse Source

!13129 Fix the bug of gpu profiler get device_id incorrectly

From: @gzhcv
Reviewed-by: @yelihua,@lilongfei15
Signed-off-by: @lilongfei15
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
80fb285c8e
2 changed files with 8 additions and 5 deletions
  1. +5
    -2
      mindspore/ccsrc/profiler/device/gpu/gpu_profiling_utils.cc
  2. +3
    -3
      mindspore/profiler/profiling.py

+ 5
- 2
mindspore/ccsrc/profiler/device/gpu/gpu_profiling_utils.cc View File

@@ -40,6 +40,8 @@ ProfilingTraceInfo ProfilingUtils::GetProfilingTraceFromEnv(NotNull<const sessio
return profiling_trace;
}

ProfilingTraceInfo empty_info;
profiling_trace = empty_info;
SetTraceIterEnd(cnode_exec_order);
SetTraceFpStart(cnode_exec_order);
SetTraceBpEnd(cnode_exec_order);
@@ -172,8 +174,9 @@ std::string ProfilingUtils::GetGraphSecondLastKernelName(const std::vector<CNode

bool ProfilingUtils::IsFirstStep(const uint32_t graph_id) {
auto iter = is_first_step_map_.find(graph_id);
if (iter != is_first_step_map_.end()) {
is_first_step_map_[graph_id] = true;
if (iter == is_first_step_map_.end()) {
is_first_step_map_[graph_id] = false;
return true;
}
return is_first_step_map_[graph_id];
}


+ 3
- 3
mindspore/profiler/profiling.py View File

@@ -21,7 +21,7 @@ import json
from enum import Enum

from mindspore import log as logger, context
from mindspore.communication.management import release, get_rank
from mindspore.communication.management import GlobalComm, release, get_rank
import mindspore._c_expression as c_expression
from mindspore.profiler.common.exceptions.exceptions import ProfilerFileNotFoundException, \
ProfilerIOException, ProfilerException, ProfilerRawFileException
@@ -96,7 +96,7 @@ class Profiler:
self._gpu_profiler = GPUProfiler.get_instance()
self._gpu_profiler.init(self._output_path)
self._gpu_profiler.step_profiling_enable(True)
if context.get_auto_parallel_context('device_num') > 1:
if GlobalComm.WORLD_COMM_GROUP == "nccl_world_group":
self._dev_id = str(get_rank())
os.environ['DEVICE_ID'] = self._dev_id

@@ -254,7 +254,7 @@ class Profiler:

def _gpu_analyse(self):
"""Collect and analyse gpu performance data"""
if context.get_auto_parallel_context('device_num') > 1 and self._dev_id != str(get_rank()):
if GlobalComm.WORLD_COMM_GROUP == "nccl_world_group" and self._dev_id != str(get_rank()):
self._dev_id = str(get_rank())
logger.error('Please check the Profiler object initialized after mindspore.context.set_auto_parallel_'
'context() and mindspore.communication.management.init(). Profiler should be initialized'


Loading…
Cancel
Save