|
|
|
@@ -18,7 +18,7 @@ import time |
|
|
|
from enum import Enum |
|
|
|
|
|
|
|
from mindspore import log as logger, context |
|
|
|
from mindspore.communication.management import release, init, get_rank |
|
|
|
from mindspore.communication.management import release, get_rank |
|
|
|
from mindspore.profiler.common.exceptions.exceptions import ProfilerFileNotFoundException, \ |
|
|
|
ProfilerIOException, ProfilerException |
|
|
|
from mindspore.profiler.common.util import get_file_names, fwrite_format |
|
|
|
@@ -96,8 +96,8 @@ class Profiler: |
|
|
|
self._gpu_profiler = GPUProfiler.get_instance() |
|
|
|
self._gpu_profiler.init(self._output_path) |
|
|
|
self._gpu_profiler.step_profiling_enable(True) |
|
|
|
init() |
|
|
|
self._dev_id = get_rank() |
|
|
|
if context.get_auto_parallel_context('device_num') > 1: |
|
|
|
self._dev_id = get_rank() |
|
|
|
os.environ['DEVICE_ID'] = str(self._dev_id) |
|
|
|
|
|
|
|
if kwargs: |
|
|
|
|