Browse Source

wait until dumping finished

r1.7
TinaMengtingZhang 4 years ago
parent
commit
438e261bbc
4 changed files with 37 additions and 2 deletions
  1. +21
    -0
      mindspore/ccsrc/debug/debugger/debugger.cc
  2. +2
    -0
      mindspore/ccsrc/debug/debugger/debugger.h
  3. +6
    -0
      mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_kernel_runtime.cc
  4. +8
    -2
      tests/st/dump/test_data_dump.py

+ 21
- 0
mindspore/ccsrc/debug/debugger/debugger.cc View File

@@ -1783,6 +1783,27 @@ std::shared_ptr<DumpDataBuilder> Debugger::LoadDumpDataBuilder(const std::string
}

void Debugger::ClearDumpDataBuilder(const std::string &node_name) { (void)dump_data_construct_map_.erase(node_name); }

/*
* Feature group: Dump.
* Target device group: Ascend.
* Runtime category: MindRT.
* Description: This function is used for A+M dump to make sure training processing ends after tensor data have been
* dumped to disk completely. Check if dump_data_construct_map_ is empty to see if no dump task is alive. If not, sleep
* for 500ms and check again.
*/
void Debugger::WaitForWriteFileFinished() {
const int kRetryTimeInMilliseconds = 500;
const int kMaxRecheckCount = 10;
int recheck_cnt = 0;
while (recheck_cnt < kMaxRecheckCount && !dump_data_construct_map_.empty()) {
MS_LOG(INFO) << "Sleep for " << std::to_string(kRetryTimeInMilliseconds)
<< " ms to wait for dumping files to finish. Retry count: " << std::to_string(recheck_cnt + 1) << "/"
<< std::to_string(kMaxRecheckCount);
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryTimeInMilliseconds));
recheck_cnt++;
}
}
#endif

} // namespace mindspore

+ 2
- 0
mindspore/ccsrc/debug/debugger/debugger.h View File

@@ -199,6 +199,8 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
std::shared_ptr<DumpDataBuilder> LoadDumpDataBuilder(const std::string &node_name);

void ClearDumpDataBuilder(const std::string &node_name);

void WaitForWriteFileFinished();
#endif

private:


+ 6
- 0
mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_kernel_runtime.cc View File

@@ -249,6 +249,12 @@ bool AscendKernelRuntime::NeedDestroyHccl() {
#ifndef ENABLE_SECURITY
void AsyncDataDumpUninit() {
if (DumpJsonParser::GetInstance().async_dump_enabled()) {
#if ENABLE_D
// When it is A+M dump mode, wait until file save is finished.
if (DumpJsonParser::GetInstance().FileFormatIsNpy()) {
Debugger::GetInstance()->WaitForWriteFileFinished();
}
#endif
if (AdxDataDumpServerUnInit() != 0) {
MS_LOG(ERROR) << "Adx data dump server uninit failed";
}


+ 8
- 2
tests/st/dump/test_data_dump.py View File

@@ -433,6 +433,12 @@ def check_data_dump(dump_file_path):
expect = np.array([[8, 10, 12], [14, 16, 18]], np.float32)
assert np.array_equal(output, expect)


def run_train():
add = Net()
add(Tensor(x), Tensor(y))


def run_saved_data_dump_test(scenario, saved_data):
"""Run e2e dump on scenario, testing statistic dump"""
if sys.platform != 'linux':
@@ -445,8 +451,8 @@ def run_saved_data_dump_test(scenario, saved_data):
dump_file_path = os.path.join(dump_path, 'rank_0', 'Net', '0', '0')
if os.path.isdir(dump_path):
shutil.rmtree(dump_path)
add = Net()
add(Tensor(x), Tensor(y))
exec_network_cmd = 'cd {0}; python -c "from test_data_dump import run_train; run_train()"'.format(os.getcwd())
_ = os.system(exec_network_cmd)
for _ in range(3):
if not os.path.exists(dump_file_path):
time.sleep(2)


Loading…
Cancel
Save