Browse Source

!31306 Add recovery context

Merge pull request !31306 from zyli2020/master
r1.7
i-robot Gitee 4 years ago
parent
commit
ff2a023f8a
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 738 additions and 3 deletions
  1. +1
    -0
      mindspore/ccsrc/CMakeLists.txt
  2. +29
    -0
      mindspore/ccsrc/distributed/persistent/storage/file_io_utils.cc
  3. +3
    -0
      mindspore/ccsrc/distributed/persistent/storage/file_io_utils.h
  4. +8
    -3
      mindspore/ccsrc/distributed/persistent/storage/json_utils.cc
  5. +3
    -0
      mindspore/ccsrc/distributed/persistent/storage/json_utils.h
  6. +17
    -0
      mindspore/ccsrc/pipeline/jit/init.cc
  7. +9
    -0
      mindspore/ccsrc/runtime/recovery/CMakeLists.txt
  8. +387
    -0
      mindspore/ccsrc/runtime/recovery/recovery_context.cc
  9. +172
    -0
      mindspore/ccsrc/runtime/recovery/recovery_context.h
  10. +109
    -0
      mindspore/python/mindspore/parallel/_recovery_context.py

+ 1
- 0
mindspore/ccsrc/CMakeLists.txt View File

@@ -258,6 +258,7 @@ set(SUB_COMP
runtime/device
runtime/graph_scheduler
runtime/hardware
runtime/recovery
runtime/pynative
plugin/device/ascend/hal/device
plugin/device/ascend/hal/hardware


+ 29
- 0
mindspore/ccsrc/distributed/persistent/storage/file_io_utils.cc View File

@@ -153,6 +153,35 @@ void FileIOUtils::CreateDir(const std::string &dir_path, mode_t mode) {
MS_LOG(EXCEPTION) << "Failed to create directory " << dir_path << ". Errno = " << errno;
}
}

void FileIOUtils::CreateDirRecursive(const std::string &dir_path, mode_t mode) {
size_t dir_path_len = dir_path.length();
if (dir_path_len > PATH_MAX) {
MS_LOG(EXCEPTION) << "Directory path is too long: " << dir_path;
}

char tmp_dir_path[PATH_MAX] = {0};
for (size_t i = 0; i < dir_path_len; ++i) {
tmp_dir_path[i] = dir_path[i];
if (tmp_dir_path[i] == '/' || dir_path == tmp_dir_path) {
if (access(tmp_dir_path, F_OK) == 0) {
continue;
}

#if defined(_WIN32) || defined(_WIN64)
int32_t ret = mkdir(tmp_dir_path);
#else
int32_t ret = mkdir(tmp_dir_path, mode);
if (ret == 0) {
ChangeFileMode(tmp_dir_path, mode);
}
#endif
if (ret != 0) {
MS_LOG(EXCEPTION) << "Failed to create directory recursion: " << dir_path << ". Errno = " << errno;
}
}
}
}
} // namespace storage
} // namespace distributed
} // namespace mindspore

+ 3
- 0
mindspore/ccsrc/distributed/persistent/storage/file_io_utils.h View File

@@ -42,6 +42,9 @@ class FileIOUtils {

// Create directory.
static void CreateDir(const std::string &dir_path, mode_t mode = S_IRWXU | S_IRWXG | S_IRWXO);

// Create directory recursively.
static void CreateDirRecursive(const std::string &dir_path, mode_t mode = S_IRWXU | S_IRWXG | S_IRWXO);
};
} // namespace storage
} // namespace distributed


+ 8
- 3
mindspore/ccsrc/distributed/persistent/storage/json_utils.cc View File

@@ -23,9 +23,7 @@ namespace distributed {
namespace storage {
bool JsonUtils::Initialize() {
if (!FileIOUtils::IsFileOrDirExist(file_name_)) {
std::ofstream output_file(file_name_);
output_file.close();
ChangeFileMode(file_name_, S_IRWXU | S_IRWXG | S_IRWXO);
FileIOUtils::CreateFile(file_name_);
return true;
}

@@ -41,6 +39,13 @@ bool JsonUtils::Initialize() {
}
return true;
}

bool JsonUtils::Exists(const std::string &key) const {
if (!js_.contains(key)) {
return false;
}
return true;
}
} // namespace storage
} // namespace distributed
} // namespace mindspore

+ 3
- 0
mindspore/ccsrc/distributed/persistent/storage/json_utils.h View File

@@ -44,6 +44,9 @@ class JsonUtils {
template <typename T>
void Insert(const std::string &key, const T &value);

// Check whether key exists in json or not.
bool Exists(const std::string &key) const;

private:
// Json object.
nlohmann::json js_;


+ 17
- 0
mindspore/ccsrc/pipeline/jit/init.cc View File

@@ -40,6 +40,7 @@
#include "ps/util.h"
#endif
#include "ps/ps_context.h"
#include "runtime/recovery/recovery_context.h"

#include "pybind_api/gil_scoped_long_running.h"

@@ -57,6 +58,7 @@ using ParallelContext = mindspore::parallel::ParallelContext;
using CostModelContext = mindspore::parallel::CostModelContext;
using mindspore::MsCtxParam;
using PSContext = mindspore::ps::PSContext;
using RecoveryContext = mindspore::runtime::recovery::RecoveryContext;

// Interface with python
PYBIND11_MODULE(_c_expression, m) {
@@ -511,6 +513,21 @@ PYBIND11_MODULE(_c_expression, m) {
(void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data.");
(void)m.def("_is_cipher_file", &mindspore::pipeline::PyIsCipherFile, "Determine whether the file is encrypted");

(void)py::class_<RecoveryContext, std::shared_ptr<RecoveryContext>>(m, "RecoveryContext")
.def_static("get_instance", &RecoveryContext::GetInstance, "Get recovery context instance.")
.def("enable_recovery", &RecoveryContext::enable_recovery, "Get whether enable recovery.")
.def("latest_ckpt_file", &RecoveryContext::latest_ckpt_file, "Get latest checkpoint file path.")
.def("latest_ckpt_epoch", &RecoveryContext::latest_ckpt_epoch, "Get the epoch of latest checkpoint.")
.def("latest_ckpt_step", &RecoveryContext::latest_ckpt_step, "Get the step of latest checkpoint.")
.def("set_need_reset", &RecoveryContext::set_need_reset,
"Set whether should call reset minddata and load ckpt for disaster recovery.")
.def("need_reset", &RecoveryContext::need_reset,
"Get whether should call reset minddata and load ckpt for disaster recovery.")
.def("recovery_path", &RecoveryContext::recovery_path,
"Get the recovery path used to save that need to be persisted.")
.def("ckpt_path", &RecoveryContext::GetCkptPath, "Get the recovery path used to save checkpoint.")
.def("set_ckpt_path", &RecoveryContext::SetCkptPath, "Set the recovery path used to save checkpoint.");

#ifndef _WIN32
(void)m.def("_export_bprop_mindir", &mindspore::ad::KPrim::ExportBpropMindir,
"Export the backpropagation function to mindir file.");


+ 9
- 0
mindspore/ccsrc/runtime/recovery/CMakeLists.txt View File

@@ -0,0 +1,9 @@
file(GLOB_RECURSE RECOVERY_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"recovery_context.cc")

if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-delete-abstract-non-virtual-dtor")
endif()

set_property(SOURCE ${RECOVERY_SRC_LIST} PROPERTY SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
add_library(_mindspore_runtime_recovery_obj OBJECT ${RECOVERY_SRC_LIST})

+ 387
- 0
mindspore/ccsrc/runtime/recovery/recovery_context.cc View File

@@ -0,0 +1,387 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "runtime/recovery/recovery_context.h"

#include <dirent.h>
#include <algorithm>
#include <utility>

#include "nlohmann/json.hpp"
#include "ps/ps_context.h"
#include "ps/constants.h"
#include "utils/file_utils.h"
#include "distributed/constants.h"
#include "distributed/cluster/topology/common.h"
#include "distributed/init.h"
#include "runtime/hardware/device_context.h"
#include "utils/convert_utils_base.h"

namespace mindspore {
namespace runtime {
namespace recovery {
constexpr char kEnvEnableRecovery[] = "MS_ENABLE_RECOVERY";
constexpr char kEnvRecoveryPath[] = "MS_RECOVERY_PATH";
constexpr char kEnvRecoveryInterval[] = "MS_RECOVERY_INTERVAL";

constexpr char kCkptSuffix[] = ".ckpt";
constexpr char kCkptPath[] = "ckpt_path";
constexpr char kJsonSuffix[] = ".json";
constexpr char kConfigJson[] = "/config.json";

const uint32_t kSendBufferLen = 2;

namespace {
std::pair<int, int> ParseCkptEpochStep(const std::string &checkpoint) {
size_t suffix_pos = checkpoint.rfind('.');
if (suffix_pos == std::string::npos || checkpoint.substr(suffix_pos) != kCkptSuffix) {
MS_LOG(WARNING) << "The file : " << checkpoint << "is not a checkpoint";
return {};
}

size_t epoch_begin_pos = checkpoint.rfind('-');
size_t step_begin_pos = checkpoint.rfind('_');
if (epoch_begin_pos == std::string::npos || step_begin_pos == std::string::npos) {
MS_LOG(EXCEPTION) << "The checkpoint file name is not valid: " << checkpoint;
}

return std::make_pair(std::stoi(checkpoint.substr(epoch_begin_pos + 1, (step_begin_pos - epoch_begin_pos) - 1)),
std::stoi(checkpoint.substr(step_begin_pos + 1, (suffix_pos - step_begin_pos) - 1)));
}

void RemoveAllCkptFiles(const std::string &directory, const std::vector<std::string> &files_list) {
for (size_t i = 0; i < files_list.size(); i++) {
const auto &ckpt_name = files_list[i];
const auto &ckpt_file = directory + "/" + ckpt_name;
(void)remove(ckpt_file.c_str());
}
}
} // namespace

void RecoveryContext::Initialize() {
if (initialized_) {
return;
}

// 1. Read environment variable.
enable_recovery_ = (common::GetEnv(kEnvEnableRecovery) == std::string("1"));
if (!enable_recovery_) {
return;
}

recovery_path_ = common::GetEnv(kEnvRecoveryPath);
if (recovery_path_.empty()) {
MS_LOG(EXCEPTION) << "The recovery path is empty, please export MS_RECOVERY_PATH correctly.";
}

auto env_recovery_interval = common::GetEnv(kEnvRecoveryInterval);
if (!env_recovery_interval.empty()) {
recovery_interval_ = std::stoi(env_recovery_interval);
}

node_role_ = common::GetEnv(distributed::kEnvRole);
if (distributed::kValidRoleName.count(node_role_) == 0) {
MS_LOG(EXCEPTION) << "Role name '" << node_role_ << "' is invalid. ";
}

// 2. Create config json file.
if (node_role_ == distributed::kEnvRoleOfScheduler) {
if (!FileIOUtils::IsFileOrDirExist(recovery_path_)) {
FileIOUtils::CreateDirRecursive(recovery_path_);
}

auto ret = FileUtils::GetRealPath(recovery_path_.c_str());
if (!ret.has_value()) {
MS_LOG(EXCEPTION) << "Cannot get real path of persistent storage path: " << recovery_path_;
}
recovery_path_ = ret.value();
if (!FileIOUtils::IsFileOrDirExist(recovery_path_ + kConfigJson)) {
nlohmann::json config_js;
config_js[std::string(ps::kStoreType)] = 1;
config_js[std::string(ps::kStoreFilePath)] = recovery_path_ + "/" + ps::kStoreFilePath + kJsonSuffix;
config_js[std::string(ps::kSchedulerStoreFilePath)] =
recovery_path_ + "/" + ps::kSchedulerStoreFilePath + kJsonSuffix;

nlohmann::json recovery_js;
recovery_js[std::string(ps::kKeyRecovery)] = config_js;
std::ofstream config_file(recovery_path_ + kConfigJson);
config_file << recovery_js.dump();
config_file.close();
}
}

// 3. Worker or Server need to wait the recovery config json file to be created.
while (!FileIOUtils::IsFileOrDirExist(recovery_path_ + kConfigJson)) {
// Wait duration: 200ms.
const int kWaitDuration = 200;
std::this_thread::sleep_for(std::chrono::milliseconds(kWaitDuration));
}

// 4. Set config content to PSContext.
ps::PSContext::instance()->set_config_file_path(recovery_path_ + kConfigJson);
ps::PSContext::instance()->set_node_id(common::GetEnv(distributed::cluster::topology::kEnvNodeId));

initialized_ = true;
}

bool RecoveryContext::ReInitializeCollective() {
auto ret = distributed::Initialize();
if (ret) {
recovery_status_ = RecoveryStatus::kUnKnownError;
set_need_reset(true);
set_need_sync_weight_to_device(true);
return true;
}

if (recovery_status_ == RecoveryStatus::kBroadcastUniqueIDFailed ||
recovery_status_ == RecoveryStatus::kAllGatherHostNameFailed) {
MS_LOG(WARNING) << "Prepare to initialize NCCL failed, retrying.";
// Retry duration: 30s.
const int kRetryDuration = 30;
std::this_thread::sleep_for(std::chrono::seconds(kRetryDuration));
return ReInitializeCollective();
} else if (recovery_status_ == RecoveryStatus::kInitNcclFailed) {
MS_LOG(EXCEPTION) << "Initialize NCCL failed.";
}

MS_LOG(EXCEPTION) << "ReInitialize collective failed.";
return false;
}

void RecoveryContext::ObtainGlobalLatestCkptInfo() {
// 1. Obtain the step corresponding to the local latest checkpoint.
ObtainLocalLatestCkptInfo();

// For standalone training.
if (global_rank_size_ == 0) {
return;
}

// 2. AllGather the latest checkpoint info of all nodes.
device::DeviceContextKey host_key = {"CPU", 0};
device::DeviceContext *host_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(host_key);
MS_EXCEPTION_IF_NULL(host_context);
device::CollectiveCommunicationLib *host_comm_lib_instance = host_context->collective_comm_lib();
MS_EXCEPTION_IF_NULL(host_comm_lib_instance);

if (global_rank_id_ >= global_rank_size_) {
MS_LOG(EXCEPTION) << "The global rank id " << global_rank_id_ << " should be less than global rank size "
<< global_rank_size_;
}

const uint32_t kRecvBufferLen = kSendBufferLen * global_rank_size_;

int send_buffer[kSendBufferLen] = {latest_ckpt_epoch_, latest_ckpt_step_};
int recv_buffer[kRecvBufferLen] = {0};
recv_buffer[kSendBufferLen * global_rank_id_] = latest_ckpt_epoch_;
recv_buffer[kSendBufferLen * global_rank_id_ + 1] = latest_ckpt_step_;

const std::string &host_global_group_name = host_comm_lib_instance->global_group_name();
if (!host_comm_lib_instance->AllGather(send_buffer, recv_buffer, kSendBufferLen, TypeId::kNumberTypeInt,
host_global_group_name)) {
MS_LOG(EXCEPTION) << "AllGather latest ckpt step failed";
}

// 3. Check whether save checkpoint successfully on every workers.
uint32_t save_ckpt_success_num = 0;
uint32_t save_ckpt_failed_num = 0;
for (uint32_t i = 0; i < kRecvBufferLen; i += kSendBufferLen) {
if (recv_buffer[i] < 0) {
save_ckpt_failed_num++;
} else {
save_ckpt_success_num++;
}
}

if (save_ckpt_success_num > 0 && save_ckpt_failed_num > 0) {
RemoveAllCkptFiles(GetCkptPath(), ckpt_files_);
MS_LOG(EXCEPTION) << "Can not find checkpoint for same step, the workers quits and training should start over.";
}
if (save_ckpt_success_num == 0 && save_ckpt_failed_num == global_rank_size_) {
return;
}

// 4. Parse latest epoch and step info.
ParseLatestCkptInfo(recv_buffer, kRecvBufferLen);

// 5. Remove useless ckpt
for (int i = SizeToInt(ckpt_files_.size()) - 1; i >= 0; i--) {
const auto &last_ckpt_name = ckpt_files_[i];
const auto &last_ckpt_file = GetCkptPath() + "/" + last_ckpt_name;
if (last_ckpt_file != latest_ckpt_file_) {
(void)remove(last_ckpt_file.c_str());
} else {
break;
}
}
}

void RecoveryContext::ObtainLocalLatestCkptInfo() {
std::string ckpt_save_dir = GetCkptPath();
if (ckpt_save_dir.empty()) {
MS_LOG(INFO) << "The ckpt file path is empty";
return;
}

DIR *dir = opendir(ckpt_save_dir.c_str());
if (dir == nullptr) {
MS_LOG(EXCEPTION) << "The file path [" << ckpt_save_dir << "] is not exist";
return;
}

if (!ckpt_files_.empty()) {
ckpt_files_.clear();
}

struct dirent *entry;
while ((entry = readdir(dir)) != nullptr) {
std::string file_name = entry->d_name;
size_t suffix_pos = file_name.rfind('.');
if (suffix_pos == std::string::npos || file_name.substr(suffix_pos) != kCkptSuffix) {
continue;
}

ckpt_files_.push_back(file_name);
}
(void)closedir(dir);

if (ckpt_files_.empty()) {
MS_LOG(INFO) << "There is no checkpoint file in dir: " << ckpt_save_dir;
return;
}

sort(ckpt_files_.begin(), ckpt_files_.end(), [](const std::string &a, const std::string &b) {
auto ckpt_epoch_step_a = ParseCkptEpochStep(a);
auto ckpt_epoch_step_b = ParseCkptEpochStep(b);
if (ckpt_epoch_step_a.first < ckpt_epoch_step_b.first) {
return true;
} else if (ckpt_epoch_step_a.first == ckpt_epoch_step_b.first) {
return ckpt_epoch_step_a.second < ckpt_epoch_step_b.second;
} else {
return false;
}
});

const auto &latest_ckpt_name = ckpt_files_.back();
latest_ckpt_file_ = ckpt_save_dir + "/" + latest_ckpt_name;

auto ckpt_epoch_step = ParseCkptEpochStep(latest_ckpt_name);
latest_ckpt_epoch_ = ckpt_epoch_step.first;
latest_ckpt_step_ = ckpt_epoch_step.second;
}

void RecoveryContext::ParseLatestCkptInfo(const int *recv_buffer, const uint32_t buffer_len) {
std::vector<std::pair<int, int>> ckpts_epoch_step;
for (uint32_t i = 0; i < buffer_len; i += kSendBufferLen) {
ckpts_epoch_step.emplace_back(recv_buffer[i], recv_buffer[i + 1]);
}
sort(ckpts_epoch_step.begin(), ckpts_epoch_step.end(),
[](const std::pair<int, int> &a, const std::pair<int, int> &b) {
if (a.first < b.first) {
return true;
} else if (a.first == b.first) {
return a.second < b.second;
} else {
return false;
}
});

const std::pair<int, int> &latest_epoch_step = ckpts_epoch_step.front();
latest_ckpt_epoch_ = latest_epoch_step.first;
latest_ckpt_step_ = latest_epoch_step.second;

const std::string latest_epoch_step_suffix =
std::to_string(latest_epoch_step.first) + "_" + std::to_string(latest_epoch_step.second) + kCkptSuffix;
auto iter = std::find_if(ckpt_files_.rbegin(), ckpt_files_.rend(), [&](const std::string &file_name) {
if (file_name.size() <= latest_epoch_step_suffix.size()) {
return false;
}
return file_name.rfind(latest_epoch_step_suffix) == (file_name.size() - latest_epoch_step_suffix.size());
});
if (iter == ckpt_files_.rend()) {
RemoveAllCkptFiles(GetCkptPath(), ckpt_files_);
MS_LOG(EXCEPTION) << "Can not find checkpoint for same step, the workers quits and training should start over.";
}

latest_ckpt_file_ = GetCkptPath() + "/" + *iter;
}

void RecoveryContext::CreatePersistentFile() {
if (node_role_ == distributed::kEnvRoleOfScheduler) {
return;
}

if (persistent_json_ != nullptr) {
return;
}

// Need to get real path of recovry path for worker or server.
auto ret = FileUtils::GetRealPath(recovery_path_.c_str());
if (!ret.has_value()) {
MS_LOG(EXCEPTION) << "Cannot get real path of persistent storage path: " << recovery_path_;
}
recovery_path_ = ret.value();

// The directory used to save ckpt is persisted to json file.
std::string persistent_file_path =
recovery_path_ + "/" + node_role_ + "_" + std::to_string(global_rank_id_) + "_persistent.json";
persistent_json_ = std::make_unique<JsonUtils>(persistent_file_path);
if (!persistent_json_->Initialize()) {
MS_LOG(EXCEPTION) << "Initialize json failed, file path: " << persistent_file_path;
}
}

void RecoveryContext::SetCkptPath(const std::string &path) {
if (node_role_ == distributed::kEnvRoleOfScheduler) {
return;
}

if (!FileIOUtils::IsFileOrDirExist(path)) {
FileIOUtils::CreateDirRecursive(path);
}

auto ret = FileUtils::GetRealPath(path.c_str());
if (!ret.has_value()) {
MS_LOG(EXCEPTION) << "Cannot get real path for save checkpoint, path: " << path;
}

if (persistent_json_ == nullptr) {
CreatePersistentFile();
}

MS_EXCEPTION_IF_NULL(persistent_json_);
persistent_json_->Insert(kCkptPath, ret.value());
}

std::string RecoveryContext::GetCkptPath() {
if (node_role_ == distributed::kEnvRoleOfScheduler) {
return std::string();
}

if (persistent_json_ == nullptr) {
CreatePersistentFile();
}

MS_EXCEPTION_IF_NULL(persistent_json_);
if (!persistent_json_->Exists(kCkptPath)) {
return std::string();
}

return persistent_json_->Get<std::string>(kCkptPath);
}
} // namespace recovery
} // namespace runtime
} // namespace mindspore

+ 172
- 0
mindspore/ccsrc/runtime/recovery/recovery_context.h View File

@@ -0,0 +1,172 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_RUNTIME_RECOVERY_RECOVERY_H_
#define MINDSPORE_CCSRC_RUNTIME_RECOVERY_RECOVERY_H_

#include <vector>
#include <string>
#include <memory>
#include "utils/ms_utils.h"
#include "distributed/persistent/storage/file_io_utils.h"
#include "distributed/persistent/storage/json_utils.h"
#include "runtime/collective/collective_communication_lib.h"

namespace mindspore {
namespace runtime {
namespace recovery {
using distributed::storage::FileIOUtils;
using distributed::storage::JsonUtils;

enum class RecoveryStatus { kUnKnownError, kAllGatherHostNameFailed, kBroadcastUniqueIDFailed, kInitNcclFailed };

// Used to save disaster recovery-related state quantities and provide disaster recovery-related
// functions, such as reinitializing collective communication, etc.
class RecoveryContext {
public:
static std::shared_ptr<RecoveryContext> &GetInstance() {
static std::shared_ptr<RecoveryContext> instance = nullptr;
if (instance == nullptr) {
instance.reset(new (std::nothrow) RecoveryContext());
instance->Initialize();
}
return instance;
}
~RecoveryContext() = default;

// Reinitializing collective communication.
bool ReInitializeCollective();

// Obtain the global step corresponding to the global latest checkpoint in each training process. Since there may be
// some processes that fails to save the checkpoint, it is necessary for AllGather to save the latest step of the
// successful checkpoint in each training process, and then take the minimum value as the final reset position.
void ObtainGlobalLatestCkptInfo();

// Get whether enable recovery or not.
bool enable_recovery() const { return enable_recovery_; }

// Get the persistent directory.
const std::string &recovery_path() const { return recovery_path_; }

// Get interval to persist model.
int recovery_interval() const { return recovery_interval_; }

// Get the error status of recovery.
RecoveryStatus recovery_status() const { return recovery_status_; }
// Set the error status of recovery.
void set_recovery_status(RecoveryStatus recovery_status) { recovery_status_ = recovery_status; }

// Set the path used to save checkpoint.
void SetCkptPath(const std::string &path);
// Get the path used to save checkpoint.
std::string GetCkptPath();

// Get the latest checkpoint in this node.
std::string latest_ckpt_file() const { return latest_ckpt_file_; }
// Get the epoch of latest checkpoint in this node.
int latest_ckpt_epoch() const { return latest_ckpt_epoch_; }
// Get the step of latest checkpoint in this node.
int latest_ckpt_step() const { return latest_ckpt_step_; }

// Set whether need to reset training process or not, if true, all training process need to rollback the same step of
// latest checkpoint, including loading checkpoint and reset the minddata.
void set_need_reset(bool need_reset) { need_reset_ = need_reset; }
// Get whether need to reset training process or not.
bool need_reset() const { return need_reset_; }

// Set whether need to sync the weight of model to device.
void set_need_sync_weight_to_device(bool need_sync_weight_to_device) {
need_sync_weight_to_device_ = need_sync_weight_to_device;
}
// Get whether need to sync the weight of model to device or not.
bool need_sync_weight_to_device() const { return need_sync_weight_to_device_; }

// Set global rank id.
void set_global_rank_id(uint32_t global_rank_id) { global_rank_id_ = global_rank_id; }
// Set global rank size.
void set_global_rank_size(uint32_t global_rank_size) { global_rank_size_ = global_rank_size; }

// Set whether need reinitialize collective communication.
void set_need_reinit_collective(bool need_reinit_collective) { need_reinit_collective_ = need_reinit_collective; }
// Get whether need reinitialize collective communication.
bool need_reinit_collective() const { return need_reinit_collective_.load(); }

private:
RecoveryContext() = default;
DISABLE_COPY_AND_ASSIGN(RecoveryContext);

// Initialize recovery context.
void Initialize();

// Create persitent json file, used to persist recovery config.
void CreatePersistentFile();

// Obtain the step corresponding to the local latest checkpoint in each training process.
void ObtainLocalLatestCkptInfo();

// Parse latest epoch and step info from all latest checkpoints info allgather from other workers.
void ParseLatestCkptInfo(const int *recv_buffer, const uint32_t buffer_len);

// Whether enable recovery or not, set by environment variable 'MS_ENABLE_RECOVERY'.
bool enable_recovery_{false};

// The persistent directory, set by environment variable 'MS_RECOVERY_PATH'.
std::string recovery_path_;

// The interval to persist model, default value: 30 second. set by environment variable 'MS_RECOVERY_INTERVAL'.
int recovery_interval_{30};

// Local checkpoint file list.
std::vector<std::string> ckpt_files_;
// The file name of latest checkpoint.
std::string latest_ckpt_file_;
// The epoch of latest checkpoint.
int latest_ckpt_epoch_{-1};
// The step of latest checkpoint.
int latest_ckpt_step_{-1};

// Node role in cluster, could be 'MS_WORKER', 'MS_SERVER' or 'MS_SCHED'.
std::string node_role_;

// The global rank id of this process. Normally this range is 0 to `global_rank_size_ - 1`.
uint32_t global_rank_id_{0};
// The global rank size.
uint32_t global_rank_size_{0};

// Whether need to reset training process or not.
bool need_reset_{false};

// Whether need to sync the weight of model to device, this value needs to be set to true when python layer
// performs load checkpoint.
bool need_sync_weight_to_device_{false};

// Whether need reinitialize collective communication, this value should be set to true once a training process
// exits unexpectedly is detected.
std::atomic_bool need_reinit_collective_{false};

// Whether the recovery context is already initialized.
bool initialized_{false};

// The error status of recovery.
RecoveryStatus recovery_status_{RecoveryStatus::kUnKnownError};

// The persitent json file util, used to persist recovery config.
std::unique_ptr<JsonUtils> persistent_json_;
};
} // namespace recovery
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_RECOVERY_RECOVERY_H_

+ 109
- 0
mindspore/python/mindspore/parallel/_recovery_context.py View File

@@ -0,0 +1,109 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Context for recovery"""

from mindspore._checkparam import Validator
from mindspore._c_expression import RecoveryContext

_recovery_context = None


def recovery_context():
"""
Get the global _recovery_context, if it is not created, create a new one.

Returns:
_recovery_context, the global recovery context.
"""

global _recovery_context
if _recovery_context is None:
_recovery_context = RecoveryContext.get_instance()
return _recovery_context

_set_recovery_context_func_map = {
"ckpt_path": recovery_context().set_ckpt_path,
"need_reset": recovery_context().set_need_reset
}

_get_recovery_context_func_map = {
"enable_recovery": recovery_context().enable_recovery,
"latest_ckpt_file": recovery_context().latest_ckpt_file,
"latest_ckpt_epoch": recovery_context().latest_ckpt_epoch,
"latest_ckpt_step": recovery_context().latest_ckpt_step,
"need_reset": recovery_context().need_reset,
"recovery_path": recovery_context().recovery_path,
"ckpt_path": recovery_context().ckpt_path
}

_check_bool_keys = ["need_reset"]


def _set_recovery_context(**kwargs):
"""
Set recovery context value.

Note:
Some other environment variables should also be set for recovery.
These environment variables are listed below:

MS_ENABLE_RECOVERY # Enable recovery
MS_RECOVERY_PATH # The persistent path for recovery
MS_RECOVERY_INTERVAL # The persistent interval for recovery

Args:
ckpt_path (string): Set the recovery path used to save checkpoint. Default: ''.
need_reset (bool): Set whether should call reset minddata and load ckpt for disaster recovery. Default: False.

Raises:
ValueError: If input key is not the attribute in recovery context.
"""

for key, value in kwargs.items():
if key not in _set_recovery_context_func_map:
raise ValueError("Set recovery context keyword %s is not recognized!" % key)
_check_value(key, value)
set_func = _set_recovery_context_func_map[key]
set_func(value)


def _get_recovery_context(attr_key):
"""
Get recovery context attribute value according to the key.

Args:
attr_key (str): The key of the attribute.

Returns:
Returns attribute value according to the key.

Raises:
ValueError: If input key is not attribute in revovery context.
"""

if attr_key not in _get_recovery_context_func_map:
raise ValueError("Get recovery context keyword %s is not recognized!" % attr_key)
get_func = _get_recovery_context_func_map[attr_key]
value = get_func()
return value


def _check_value(key, value):
"""
Validate the value for recovery context keys.
"""

if key in _check_bool_keys:
Validator.check_bool(value, key)

Loading…
Cancel
Save