|
|
|
@@ -21,7 +21,7 @@ import json |
|
|
|
from enum import Enum |
|
|
|
|
|
|
|
from mindspore import log as logger, context |
|
|
|
from mindspore.communication.management import release, get_rank |
|
|
|
from mindspore.communication.management import GlobalComm, release, get_rank |
|
|
|
import mindspore._c_expression as c_expression |
|
|
|
from mindspore.profiler.common.exceptions.exceptions import ProfilerFileNotFoundException, \ |
|
|
|
ProfilerIOException, ProfilerException, ProfilerRawFileException |
|
|
|
@@ -96,7 +96,7 @@ class Profiler: |
|
|
|
self._gpu_profiler = GPUProfiler.get_instance() |
|
|
|
self._gpu_profiler.init(self._output_path) |
|
|
|
self._gpu_profiler.step_profiling_enable(True) |
|
|
|
if context.get_auto_parallel_context('device_num') > 1: |
|
|
|
if GlobalComm.WORLD_COMM_GROUP == "nccl_world_group": |
|
|
|
self._dev_id = str(get_rank()) |
|
|
|
os.environ['DEVICE_ID'] = self._dev_id |
|
|
|
|
|
|
|
@@ -254,7 +254,7 @@ class Profiler: |
|
|
|
|
|
|
|
def _gpu_analyse(self): |
|
|
|
"""Collect and analyse gpu performance data""" |
|
|
|
if context.get_auto_parallel_context('device_num') > 1 and self._dev_id != str(get_rank()): |
|
|
|
if GlobalComm.WORLD_COMM_GROUP == "nccl_world_group" and self._dev_id != str(get_rank()): |
|
|
|
self._dev_id = str(get_rank()) |
|
|
|
logger.error('Please check the Profiler object initialized after mindspore.context.set_auto_parallel_' |
|
|
|
'context() and mindspore.communication.management.init(). Profiler should be initialized' |
|
|
|
|