| @@ -258,6 +258,7 @@ set(SUB_COMP | |||
| runtime/device | |||
| runtime/graph_scheduler | |||
| runtime/hardware | |||
| runtime/recovery | |||
| runtime/pynative | |||
| plugin/device/ascend/hal/device | |||
| plugin/device/ascend/hal/hardware | |||
| @@ -153,6 +153,35 @@ void FileIOUtils::CreateDir(const std::string &dir_path, mode_t mode) { | |||
| MS_LOG(EXCEPTION) << "Failed to create directory " << dir_path << ". Errno = " << errno; | |||
| } | |||
| } | |||
| void FileIOUtils::CreateDirRecursive(const std::string &dir_path, mode_t mode) { | |||
| size_t dir_path_len = dir_path.length(); | |||
| if (dir_path_len > PATH_MAX) { | |||
| MS_LOG(EXCEPTION) << "Directory path is too long: " << dir_path; | |||
| } | |||
| char tmp_dir_path[PATH_MAX] = {0}; | |||
| for (size_t i = 0; i < dir_path_len; ++i) { | |||
| tmp_dir_path[i] = dir_path[i]; | |||
| if (tmp_dir_path[i] == '/' || dir_path == tmp_dir_path) { | |||
| if (access(tmp_dir_path, F_OK) == 0) { | |||
| continue; | |||
| } | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| int32_t ret = mkdir(tmp_dir_path); | |||
| #else | |||
| int32_t ret = mkdir(tmp_dir_path, mode); | |||
| if (ret == 0) { | |||
| ChangeFileMode(tmp_dir_path, mode); | |||
| } | |||
| #endif | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "Failed to create directory recursion: " << dir_path << ". Errno = " << errno; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } // namespace storage | |||
| } // namespace distributed | |||
| } // namespace mindspore | |||
| @@ -42,6 +42,9 @@ class FileIOUtils { | |||
| // Create directory. | |||
| static void CreateDir(const std::string &dir_path, mode_t mode = S_IRWXU | S_IRWXG | S_IRWXO); | |||
| // Create directory recursively. | |||
| static void CreateDirRecursive(const std::string &dir_path, mode_t mode = S_IRWXU | S_IRWXG | S_IRWXO); | |||
| }; | |||
| } // namespace storage | |||
| } // namespace distributed | |||
| @@ -23,9 +23,7 @@ namespace distributed { | |||
| namespace storage { | |||
| bool JsonUtils::Initialize() { | |||
| if (!FileIOUtils::IsFileOrDirExist(file_name_)) { | |||
| std::ofstream output_file(file_name_); | |||
| output_file.close(); | |||
| ChangeFileMode(file_name_, S_IRWXU | S_IRWXG | S_IRWXO); | |||
| FileIOUtils::CreateFile(file_name_); | |||
| return true; | |||
| } | |||
| @@ -41,6 +39,13 @@ bool JsonUtils::Initialize() { | |||
| } | |||
| return true; | |||
| } | |||
| bool JsonUtils::Exists(const std::string &key) const { | |||
| if (!js_.contains(key)) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace storage | |||
| } // namespace distributed | |||
| } // namespace mindspore | |||
| @@ -44,6 +44,9 @@ class JsonUtils { | |||
| template <typename T> | |||
| void Insert(const std::string &key, const T &value); | |||
| // Check whether key exists in json or not. | |||
| bool Exists(const std::string &key) const; | |||
| private: | |||
| // Json object. | |||
| nlohmann::json js_; | |||
| @@ -40,6 +40,7 @@ | |||
| #include "ps/util.h" | |||
| #endif | |||
| #include "ps/ps_context.h" | |||
| #include "runtime/recovery/recovery_context.h" | |||
| #include "pybind_api/gil_scoped_long_running.h" | |||
| @@ -57,6 +58,7 @@ using ParallelContext = mindspore::parallel::ParallelContext; | |||
| using CostModelContext = mindspore::parallel::CostModelContext; | |||
| using mindspore::MsCtxParam; | |||
| using PSContext = mindspore::ps::PSContext; | |||
| using RecoveryContext = mindspore::runtime::recovery::RecoveryContext; | |||
| // Interface with python | |||
| PYBIND11_MODULE(_c_expression, m) { | |||
| @@ -511,6 +513,21 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| (void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data."); | |||
| (void)m.def("_is_cipher_file", &mindspore::pipeline::PyIsCipherFile, "Determine whether the file is encrypted"); | |||
| (void)py::class_<RecoveryContext, std::shared_ptr<RecoveryContext>>(m, "RecoveryContext") | |||
| .def_static("get_instance", &RecoveryContext::GetInstance, "Get recovery context instance.") | |||
| .def("enable_recovery", &RecoveryContext::enable_recovery, "Get whether enable recovery.") | |||
| .def("latest_ckpt_file", &RecoveryContext::latest_ckpt_file, "Get latest checkpoint file path.") | |||
| .def("latest_ckpt_epoch", &RecoveryContext::latest_ckpt_epoch, "Get the epoch of latest checkpoint.") | |||
| .def("latest_ckpt_step", &RecoveryContext::latest_ckpt_step, "Get the step of latest checkpoint.") | |||
| .def("set_need_reset", &RecoveryContext::set_need_reset, | |||
| "Set whether should call reset minddata and load ckpt for disaster recovery.") | |||
| .def("need_reset", &RecoveryContext::need_reset, | |||
| "Get whether should call reset minddata and load ckpt for disaster recovery.") | |||
| .def("recovery_path", &RecoveryContext::recovery_path, | |||
| "Get the recovery path used to save that need to be persisted.") | |||
| .def("ckpt_path", &RecoveryContext::GetCkptPath, "Get the recovery path used to save checkpoint.") | |||
| .def("set_ckpt_path", &RecoveryContext::SetCkptPath, "Set the recovery path used to save checkpoint."); | |||
| #ifndef _WIN32 | |||
| (void)m.def("_export_bprop_mindir", &mindspore::ad::KPrim::ExportBpropMindir, | |||
| "Export the backpropagation function to mindir file."); | |||
| @@ -0,0 +1,9 @@ | |||
| file(GLOB_RECURSE RECOVERY_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| "recovery_context.cc") | |||
| if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-delete-abstract-non-virtual-dtor") | |||
| endif() | |||
| set_property(SOURCE ${RECOVERY_SRC_LIST} PROPERTY SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) | |||
| add_library(_mindspore_runtime_recovery_obj OBJECT ${RECOVERY_SRC_LIST}) | |||
| @@ -0,0 +1,387 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/recovery/recovery_context.h" | |||
| #include <dirent.h> | |||
| #include <algorithm> | |||
| #include <utility> | |||
| #include "nlohmann/json.hpp" | |||
| #include "ps/ps_context.h" | |||
| #include "ps/constants.h" | |||
| #include "utils/file_utils.h" | |||
| #include "distributed/constants.h" | |||
| #include "distributed/cluster/topology/common.h" | |||
| #include "distributed/init.h" | |||
| #include "runtime/hardware/device_context.h" | |||
| #include "utils/convert_utils_base.h" | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| namespace recovery { | |||
| constexpr char kEnvEnableRecovery[] = "MS_ENABLE_RECOVERY"; | |||
| constexpr char kEnvRecoveryPath[] = "MS_RECOVERY_PATH"; | |||
| constexpr char kEnvRecoveryInterval[] = "MS_RECOVERY_INTERVAL"; | |||
| constexpr char kCkptSuffix[] = ".ckpt"; | |||
| constexpr char kCkptPath[] = "ckpt_path"; | |||
| constexpr char kJsonSuffix[] = ".json"; | |||
| constexpr char kConfigJson[] = "/config.json"; | |||
| const uint32_t kSendBufferLen = 2; | |||
| namespace { | |||
| std::pair<int, int> ParseCkptEpochStep(const std::string &checkpoint) { | |||
| size_t suffix_pos = checkpoint.rfind('.'); | |||
| if (suffix_pos == std::string::npos || checkpoint.substr(suffix_pos) != kCkptSuffix) { | |||
| MS_LOG(WARNING) << "The file : " << checkpoint << "is not a checkpoint"; | |||
| return {}; | |||
| } | |||
| size_t epoch_begin_pos = checkpoint.rfind('-'); | |||
| size_t step_begin_pos = checkpoint.rfind('_'); | |||
| if (epoch_begin_pos == std::string::npos || step_begin_pos == std::string::npos) { | |||
| MS_LOG(EXCEPTION) << "The checkpoint file name is not valid: " << checkpoint; | |||
| } | |||
| return std::make_pair(std::stoi(checkpoint.substr(epoch_begin_pos + 1, (step_begin_pos - epoch_begin_pos) - 1)), | |||
| std::stoi(checkpoint.substr(step_begin_pos + 1, (suffix_pos - step_begin_pos) - 1))); | |||
| } | |||
| void RemoveAllCkptFiles(const std::string &directory, const std::vector<std::string> &files_list) { | |||
| for (size_t i = 0; i < files_list.size(); i++) { | |||
| const auto &ckpt_name = files_list[i]; | |||
| const auto &ckpt_file = directory + "/" + ckpt_name; | |||
| (void)remove(ckpt_file.c_str()); | |||
| } | |||
| } | |||
| } // namespace | |||
| void RecoveryContext::Initialize() { | |||
| if (initialized_) { | |||
| return; | |||
| } | |||
| // 1. Read environment variable. | |||
| enable_recovery_ = (common::GetEnv(kEnvEnableRecovery) == std::string("1")); | |||
| if (!enable_recovery_) { | |||
| return; | |||
| } | |||
| recovery_path_ = common::GetEnv(kEnvRecoveryPath); | |||
| if (recovery_path_.empty()) { | |||
| MS_LOG(EXCEPTION) << "The recovery path is empty, please export MS_RECOVERY_PATH correctly."; | |||
| } | |||
| auto env_recovery_interval = common::GetEnv(kEnvRecoveryInterval); | |||
| if (!env_recovery_interval.empty()) { | |||
| recovery_interval_ = std::stoi(env_recovery_interval); | |||
| } | |||
| node_role_ = common::GetEnv(distributed::kEnvRole); | |||
| if (distributed::kValidRoleName.count(node_role_) == 0) { | |||
| MS_LOG(EXCEPTION) << "Role name '" << node_role_ << "' is invalid. "; | |||
| } | |||
| // 2. Create config json file. | |||
| if (node_role_ == distributed::kEnvRoleOfScheduler) { | |||
| if (!FileIOUtils::IsFileOrDirExist(recovery_path_)) { | |||
| FileIOUtils::CreateDirRecursive(recovery_path_); | |||
| } | |||
| auto ret = FileUtils::GetRealPath(recovery_path_.c_str()); | |||
| if (!ret.has_value()) { | |||
| MS_LOG(EXCEPTION) << "Cannot get real path of persistent storage path: " << recovery_path_; | |||
| } | |||
| recovery_path_ = ret.value(); | |||
| if (!FileIOUtils::IsFileOrDirExist(recovery_path_ + kConfigJson)) { | |||
| nlohmann::json config_js; | |||
| config_js[std::string(ps::kStoreType)] = 1; | |||
| config_js[std::string(ps::kStoreFilePath)] = recovery_path_ + "/" + ps::kStoreFilePath + kJsonSuffix; | |||
| config_js[std::string(ps::kSchedulerStoreFilePath)] = | |||
| recovery_path_ + "/" + ps::kSchedulerStoreFilePath + kJsonSuffix; | |||
| nlohmann::json recovery_js; | |||
| recovery_js[std::string(ps::kKeyRecovery)] = config_js; | |||
| std::ofstream config_file(recovery_path_ + kConfigJson); | |||
| config_file << recovery_js.dump(); | |||
| config_file.close(); | |||
| } | |||
| } | |||
| // 3. Worker or Server need to wait the recovery config json file to be created. | |||
| while (!FileIOUtils::IsFileOrDirExist(recovery_path_ + kConfigJson)) { | |||
| // Wait duration: 200ms. | |||
| const int kWaitDuration = 200; | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(kWaitDuration)); | |||
| } | |||
| // 4. Set config content to PSContext. | |||
| ps::PSContext::instance()->set_config_file_path(recovery_path_ + kConfigJson); | |||
| ps::PSContext::instance()->set_node_id(common::GetEnv(distributed::cluster::topology::kEnvNodeId)); | |||
| initialized_ = true; | |||
| } | |||
| bool RecoveryContext::ReInitializeCollective() { | |||
| auto ret = distributed::Initialize(); | |||
| if (ret) { | |||
| recovery_status_ = RecoveryStatus::kUnKnownError; | |||
| set_need_reset(true); | |||
| set_need_sync_weight_to_device(true); | |||
| return true; | |||
| } | |||
| if (recovery_status_ == RecoveryStatus::kBroadcastUniqueIDFailed || | |||
| recovery_status_ == RecoveryStatus::kAllGatherHostNameFailed) { | |||
| MS_LOG(WARNING) << "Prepare to initialize NCCL failed, retrying."; | |||
| // Retry duration: 30s. | |||
| const int kRetryDuration = 30; | |||
| std::this_thread::sleep_for(std::chrono::seconds(kRetryDuration)); | |||
| return ReInitializeCollective(); | |||
| } else if (recovery_status_ == RecoveryStatus::kInitNcclFailed) { | |||
| MS_LOG(EXCEPTION) << "Initialize NCCL failed."; | |||
| } | |||
| MS_LOG(EXCEPTION) << "ReInitialize collective failed."; | |||
| return false; | |||
| } | |||
| void RecoveryContext::ObtainGlobalLatestCkptInfo() { | |||
| // 1. Obtain the step corresponding to the local latest checkpoint. | |||
| ObtainLocalLatestCkptInfo(); | |||
| // For standalone training. | |||
| if (global_rank_size_ == 0) { | |||
| return; | |||
| } | |||
| // 2. AllGather the latest checkpoint info of all nodes. | |||
| device::DeviceContextKey host_key = {"CPU", 0}; | |||
| device::DeviceContext *host_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(host_key); | |||
| MS_EXCEPTION_IF_NULL(host_context); | |||
| device::CollectiveCommunicationLib *host_comm_lib_instance = host_context->collective_comm_lib(); | |||
| MS_EXCEPTION_IF_NULL(host_comm_lib_instance); | |||
| if (global_rank_id_ >= global_rank_size_) { | |||
| MS_LOG(EXCEPTION) << "The global rank id " << global_rank_id_ << " should be less than global rank size " | |||
| << global_rank_size_; | |||
| } | |||
| const uint32_t kRecvBufferLen = kSendBufferLen * global_rank_size_; | |||
| int send_buffer[kSendBufferLen] = {latest_ckpt_epoch_, latest_ckpt_step_}; | |||
| int recv_buffer[kRecvBufferLen] = {0}; | |||
| recv_buffer[kSendBufferLen * global_rank_id_] = latest_ckpt_epoch_; | |||
| recv_buffer[kSendBufferLen * global_rank_id_ + 1] = latest_ckpt_step_; | |||
| const std::string &host_global_group_name = host_comm_lib_instance->global_group_name(); | |||
| if (!host_comm_lib_instance->AllGather(send_buffer, recv_buffer, kSendBufferLen, TypeId::kNumberTypeInt, | |||
| host_global_group_name)) { | |||
| MS_LOG(EXCEPTION) << "AllGather latest ckpt step failed"; | |||
| } | |||
| // 3. Check whether save checkpoint successfully on every workers. | |||
| uint32_t save_ckpt_success_num = 0; | |||
| uint32_t save_ckpt_failed_num = 0; | |||
| for (uint32_t i = 0; i < kRecvBufferLen; i += kSendBufferLen) { | |||
| if (recv_buffer[i] < 0) { | |||
| save_ckpt_failed_num++; | |||
| } else { | |||
| save_ckpt_success_num++; | |||
| } | |||
| } | |||
| if (save_ckpt_success_num > 0 && save_ckpt_failed_num > 0) { | |||
| RemoveAllCkptFiles(GetCkptPath(), ckpt_files_); | |||
| MS_LOG(EXCEPTION) << "Can not find checkpoint for same step, the workers quits and training should start over."; | |||
| } | |||
| if (save_ckpt_success_num == 0 && save_ckpt_failed_num == global_rank_size_) { | |||
| return; | |||
| } | |||
| // 4. Parse latest epoch and step info. | |||
| ParseLatestCkptInfo(recv_buffer, kRecvBufferLen); | |||
| // 5. Remove useless ckpt | |||
| for (int i = SizeToInt(ckpt_files_.size()) - 1; i >= 0; i--) { | |||
| const auto &last_ckpt_name = ckpt_files_[i]; | |||
| const auto &last_ckpt_file = GetCkptPath() + "/" + last_ckpt_name; | |||
| if (last_ckpt_file != latest_ckpt_file_) { | |||
| (void)remove(last_ckpt_file.c_str()); | |||
| } else { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| void RecoveryContext::ObtainLocalLatestCkptInfo() { | |||
| std::string ckpt_save_dir = GetCkptPath(); | |||
| if (ckpt_save_dir.empty()) { | |||
| MS_LOG(INFO) << "The ckpt file path is empty"; | |||
| return; | |||
| } | |||
| DIR *dir = opendir(ckpt_save_dir.c_str()); | |||
| if (dir == nullptr) { | |||
| MS_LOG(EXCEPTION) << "The file path [" << ckpt_save_dir << "] is not exist"; | |||
| return; | |||
| } | |||
| if (!ckpt_files_.empty()) { | |||
| ckpt_files_.clear(); | |||
| } | |||
| struct dirent *entry; | |||
| while ((entry = readdir(dir)) != nullptr) { | |||
| std::string file_name = entry->d_name; | |||
| size_t suffix_pos = file_name.rfind('.'); | |||
| if (suffix_pos == std::string::npos || file_name.substr(suffix_pos) != kCkptSuffix) { | |||
| continue; | |||
| } | |||
| ckpt_files_.push_back(file_name); | |||
| } | |||
| (void)closedir(dir); | |||
| if (ckpt_files_.empty()) { | |||
| MS_LOG(INFO) << "There is no checkpoint file in dir: " << ckpt_save_dir; | |||
| return; | |||
| } | |||
| sort(ckpt_files_.begin(), ckpt_files_.end(), [](const std::string &a, const std::string &b) { | |||
| auto ckpt_epoch_step_a = ParseCkptEpochStep(a); | |||
| auto ckpt_epoch_step_b = ParseCkptEpochStep(b); | |||
| if (ckpt_epoch_step_a.first < ckpt_epoch_step_b.first) { | |||
| return true; | |||
| } else if (ckpt_epoch_step_a.first == ckpt_epoch_step_b.first) { | |||
| return ckpt_epoch_step_a.second < ckpt_epoch_step_b.second; | |||
| } else { | |||
| return false; | |||
| } | |||
| }); | |||
| const auto &latest_ckpt_name = ckpt_files_.back(); | |||
| latest_ckpt_file_ = ckpt_save_dir + "/" + latest_ckpt_name; | |||
| auto ckpt_epoch_step = ParseCkptEpochStep(latest_ckpt_name); | |||
| latest_ckpt_epoch_ = ckpt_epoch_step.first; | |||
| latest_ckpt_step_ = ckpt_epoch_step.second; | |||
| } | |||
| void RecoveryContext::ParseLatestCkptInfo(const int *recv_buffer, const uint32_t buffer_len) { | |||
| std::vector<std::pair<int, int>> ckpts_epoch_step; | |||
| for (uint32_t i = 0; i < buffer_len; i += kSendBufferLen) { | |||
| ckpts_epoch_step.emplace_back(recv_buffer[i], recv_buffer[i + 1]); | |||
| } | |||
| sort(ckpts_epoch_step.begin(), ckpts_epoch_step.end(), | |||
| [](const std::pair<int, int> &a, const std::pair<int, int> &b) { | |||
| if (a.first < b.first) { | |||
| return true; | |||
| } else if (a.first == b.first) { | |||
| return a.second < b.second; | |||
| } else { | |||
| return false; | |||
| } | |||
| }); | |||
| const std::pair<int, int> &latest_epoch_step = ckpts_epoch_step.front(); | |||
| latest_ckpt_epoch_ = latest_epoch_step.first; | |||
| latest_ckpt_step_ = latest_epoch_step.second; | |||
| const std::string latest_epoch_step_suffix = | |||
| std::to_string(latest_epoch_step.first) + "_" + std::to_string(latest_epoch_step.second) + kCkptSuffix; | |||
| auto iter = std::find_if(ckpt_files_.rbegin(), ckpt_files_.rend(), [&](const std::string &file_name) { | |||
| if (file_name.size() <= latest_epoch_step_suffix.size()) { | |||
| return false; | |||
| } | |||
| return file_name.rfind(latest_epoch_step_suffix) == (file_name.size() - latest_epoch_step_suffix.size()); | |||
| }); | |||
| if (iter == ckpt_files_.rend()) { | |||
| RemoveAllCkptFiles(GetCkptPath(), ckpt_files_); | |||
| MS_LOG(EXCEPTION) << "Can not find checkpoint for same step, the workers quits and training should start over."; | |||
| } | |||
| latest_ckpt_file_ = GetCkptPath() + "/" + *iter; | |||
| } | |||
| void RecoveryContext::CreatePersistentFile() { | |||
| if (node_role_ == distributed::kEnvRoleOfScheduler) { | |||
| return; | |||
| } | |||
| if (persistent_json_ != nullptr) { | |||
| return; | |||
| } | |||
| // Need to get real path of recovry path for worker or server. | |||
| auto ret = FileUtils::GetRealPath(recovery_path_.c_str()); | |||
| if (!ret.has_value()) { | |||
| MS_LOG(EXCEPTION) << "Cannot get real path of persistent storage path: " << recovery_path_; | |||
| } | |||
| recovery_path_ = ret.value(); | |||
| // The directory used to save ckpt is persisted to json file. | |||
| std::string persistent_file_path = | |||
| recovery_path_ + "/" + node_role_ + "_" + std::to_string(global_rank_id_) + "_persistent.json"; | |||
| persistent_json_ = std::make_unique<JsonUtils>(persistent_file_path); | |||
| if (!persistent_json_->Initialize()) { | |||
| MS_LOG(EXCEPTION) << "Initialize json failed, file path: " << persistent_file_path; | |||
| } | |||
| } | |||
| void RecoveryContext::SetCkptPath(const std::string &path) { | |||
| if (node_role_ == distributed::kEnvRoleOfScheduler) { | |||
| return; | |||
| } | |||
| if (!FileIOUtils::IsFileOrDirExist(path)) { | |||
| FileIOUtils::CreateDirRecursive(path); | |||
| } | |||
| auto ret = FileUtils::GetRealPath(path.c_str()); | |||
| if (!ret.has_value()) { | |||
| MS_LOG(EXCEPTION) << "Cannot get real path for save checkpoint, path: " << path; | |||
| } | |||
| if (persistent_json_ == nullptr) { | |||
| CreatePersistentFile(); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(persistent_json_); | |||
| persistent_json_->Insert(kCkptPath, ret.value()); | |||
| } | |||
| std::string RecoveryContext::GetCkptPath() { | |||
| if (node_role_ == distributed::kEnvRoleOfScheduler) { | |||
| return std::string(); | |||
| } | |||
| if (persistent_json_ == nullptr) { | |||
| CreatePersistentFile(); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(persistent_json_); | |||
| if (!persistent_json_->Exists(kCkptPath)) { | |||
| return std::string(); | |||
| } | |||
| return persistent_json_->Get<std::string>(kCkptPath); | |||
| } | |||
| } // namespace recovery | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,172 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_RECOVERY_RECOVERY_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_RECOVERY_RECOVERY_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "utils/ms_utils.h" | |||
| #include "distributed/persistent/storage/file_io_utils.h" | |||
| #include "distributed/persistent/storage/json_utils.h" | |||
| #include "runtime/collective/collective_communication_lib.h" | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| namespace recovery { | |||
| using distributed::storage::FileIOUtils; | |||
| using distributed::storage::JsonUtils; | |||
| enum class RecoveryStatus { kUnKnownError, kAllGatherHostNameFailed, kBroadcastUniqueIDFailed, kInitNcclFailed }; | |||
| // Used to save disaster recovery-related state quantities and provide disaster recovery-related | |||
| // functions, such as reinitializing collective communication, etc. | |||
| class RecoveryContext { | |||
| public: | |||
| static std::shared_ptr<RecoveryContext> &GetInstance() { | |||
| static std::shared_ptr<RecoveryContext> instance = nullptr; | |||
| if (instance == nullptr) { | |||
| instance.reset(new (std::nothrow) RecoveryContext()); | |||
| instance->Initialize(); | |||
| } | |||
| return instance; | |||
| } | |||
| ~RecoveryContext() = default; | |||
| // Reinitializing collective communication. | |||
| bool ReInitializeCollective(); | |||
| // Obtain the global step corresponding to the global latest checkpoint in each training process. Since there may be | |||
| // some processes that fails to save the checkpoint, it is necessary for AllGather to save the latest step of the | |||
| // successful checkpoint in each training process, and then take the minimum value as the final reset position. | |||
| void ObtainGlobalLatestCkptInfo(); | |||
| // Get whether enable recovery or not. | |||
| bool enable_recovery() const { return enable_recovery_; } | |||
| // Get the persistent directory. | |||
| const std::string &recovery_path() const { return recovery_path_; } | |||
| // Get interval to persist model. | |||
| int recovery_interval() const { return recovery_interval_; } | |||
| // Get the error status of recovery. | |||
| RecoveryStatus recovery_status() const { return recovery_status_; } | |||
| // Set the error status of recovery. | |||
| void set_recovery_status(RecoveryStatus recovery_status) { recovery_status_ = recovery_status; } | |||
| // Set the path used to save checkpoint. | |||
| void SetCkptPath(const std::string &path); | |||
| // Get the path used to save checkpoint. | |||
| std::string GetCkptPath(); | |||
| // Get the latest checkpoint in this node. | |||
| std::string latest_ckpt_file() const { return latest_ckpt_file_; } | |||
| // Get the epoch of latest checkpoint in this node. | |||
| int latest_ckpt_epoch() const { return latest_ckpt_epoch_; } | |||
| // Get the step of latest checkpoint in this node. | |||
| int latest_ckpt_step() const { return latest_ckpt_step_; } | |||
| // Set whether need to reset training process or not, if true, all training process need to rollback the same step of | |||
| // latest checkpoint, including loading checkpoint and reset the minddata. | |||
| void set_need_reset(bool need_reset) { need_reset_ = need_reset; } | |||
| // Get whether need to reset training process or not. | |||
| bool need_reset() const { return need_reset_; } | |||
| // Set whether need to sync the weight of model to device. | |||
| void set_need_sync_weight_to_device(bool need_sync_weight_to_device) { | |||
| need_sync_weight_to_device_ = need_sync_weight_to_device; | |||
| } | |||
| // Get whether need to sync the weight of model to device or not. | |||
| bool need_sync_weight_to_device() const { return need_sync_weight_to_device_; } | |||
| // Set global rank id. | |||
| void set_global_rank_id(uint32_t global_rank_id) { global_rank_id_ = global_rank_id; } | |||
| // Set global rank size. | |||
| void set_global_rank_size(uint32_t global_rank_size) { global_rank_size_ = global_rank_size; } | |||
| // Set whether need reinitialize collective communication. | |||
| void set_need_reinit_collective(bool need_reinit_collective) { need_reinit_collective_ = need_reinit_collective; } | |||
| // Get whether need reinitialize collective communication. | |||
| bool need_reinit_collective() const { return need_reinit_collective_.load(); } | |||
| private: | |||
| RecoveryContext() = default; | |||
| DISABLE_COPY_AND_ASSIGN(RecoveryContext); | |||
| // Initialize recovery context. | |||
| void Initialize(); | |||
| // Create persitent json file, used to persist recovery config. | |||
| void CreatePersistentFile(); | |||
| // Obtain the step corresponding to the local latest checkpoint in each training process. | |||
| void ObtainLocalLatestCkptInfo(); | |||
| // Parse latest epoch and step info from all latest checkpoints info allgather from other workers. | |||
| void ParseLatestCkptInfo(const int *recv_buffer, const uint32_t buffer_len); | |||
| // Whether enable recovery or not, set by environment variable 'MS_ENABLE_RECOVERY'. | |||
| bool enable_recovery_{false}; | |||
| // The persistent directory, set by environment variable 'MS_RECOVERY_PATH'. | |||
| std::string recovery_path_; | |||
| // The interval to persist model, default value: 30 second. set by environment variable 'MS_RECOVERY_INTERVAL'. | |||
| int recovery_interval_{30}; | |||
| // Local checkpoint file list. | |||
| std::vector<std::string> ckpt_files_; | |||
| // The file name of latest checkpoint. | |||
| std::string latest_ckpt_file_; | |||
| // The epoch of latest checkpoint. | |||
| int latest_ckpt_epoch_{-1}; | |||
| // The step of latest checkpoint. | |||
| int latest_ckpt_step_{-1}; | |||
| // Node role in cluster, could be 'MS_WORKER', 'MS_SERVER' or 'MS_SCHED'. | |||
| std::string node_role_; | |||
| // The global rank id of this process. Normally this range is 0 to `global_rank_size_ - 1`. | |||
| uint32_t global_rank_id_{0}; | |||
| // The global rank size. | |||
| uint32_t global_rank_size_{0}; | |||
| // Whether need to reset training process or not. | |||
| bool need_reset_{false}; | |||
| // Whether need to sync the weight of model to device, this value needs to be set to true when python layer | |||
| // performs load checkpoint. | |||
| bool need_sync_weight_to_device_{false}; | |||
| // Whether need reinitialize collective communication, this value should be set to true once a training process | |||
| // exits unexpectedly is detected. | |||
| std::atomic_bool need_reinit_collective_{false}; | |||
| // Whether the recovery context is already initialized. | |||
| bool initialized_{false}; | |||
| // The error status of recovery. | |||
| RecoveryStatus recovery_status_{RecoveryStatus::kUnKnownError}; | |||
| // The persitent json file util, used to persist recovery config. | |||
| std::unique_ptr<JsonUtils> persistent_json_; | |||
| }; | |||
| } // namespace recovery | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_RECOVERY_RECOVERY_H_ | |||
| @@ -0,0 +1,109 @@ | |||
| # Copyright 2022 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Context for recovery""" | |||
| from mindspore._checkparam import Validator | |||
| from mindspore._c_expression import RecoveryContext | |||
| _recovery_context = None | |||
| def recovery_context(): | |||
| """ | |||
| Get the global _recovery_context, if it is not created, create a new one. | |||
| Returns: | |||
| _recovery_context, the global recovery context. | |||
| """ | |||
| global _recovery_context | |||
| if _recovery_context is None: | |||
| _recovery_context = RecoveryContext.get_instance() | |||
| return _recovery_context | |||
| _set_recovery_context_func_map = { | |||
| "ckpt_path": recovery_context().set_ckpt_path, | |||
| "need_reset": recovery_context().set_need_reset | |||
| } | |||
| _get_recovery_context_func_map = { | |||
| "enable_recovery": recovery_context().enable_recovery, | |||
| "latest_ckpt_file": recovery_context().latest_ckpt_file, | |||
| "latest_ckpt_epoch": recovery_context().latest_ckpt_epoch, | |||
| "latest_ckpt_step": recovery_context().latest_ckpt_step, | |||
| "need_reset": recovery_context().need_reset, | |||
| "recovery_path": recovery_context().recovery_path, | |||
| "ckpt_path": recovery_context().ckpt_path | |||
| } | |||
| _check_bool_keys = ["need_reset"] | |||
| def _set_recovery_context(**kwargs): | |||
| """ | |||
| Set recovery context value. | |||
| Note: | |||
| Some other environment variables should also be set for recovery. | |||
| These environment variables are listed below: | |||
| MS_ENABLE_RECOVERY # Enable recovery | |||
| MS_RECOVERY_PATH # The persistent path for recovery | |||
| MS_RECOVERY_INTERVAL # The persistent interval for recovery | |||
| Args: | |||
| ckpt_path (string): Set the recovery path used to save checkpoint. Default: ''. | |||
| need_reset (bool): Set whether should call reset minddata and load ckpt for disaster recovery. Default: False. | |||
| Raises: | |||
| ValueError: If input key is not the attribute in recovery context. | |||
| """ | |||
| for key, value in kwargs.items(): | |||
| if key not in _set_recovery_context_func_map: | |||
| raise ValueError("Set recovery context keyword %s is not recognized!" % key) | |||
| _check_value(key, value) | |||
| set_func = _set_recovery_context_func_map[key] | |||
| set_func(value) | |||
| def _get_recovery_context(attr_key): | |||
| """ | |||
| Get recovery context attribute value according to the key. | |||
| Args: | |||
| attr_key (str): The key of the attribute. | |||
| Returns: | |||
| Returns attribute value according to the key. | |||
| Raises: | |||
| ValueError: If input key is not attribute in revovery context. | |||
| """ | |||
| if attr_key not in _get_recovery_context_func_map: | |||
| raise ValueError("Get recovery context keyword %s is not recognized!" % attr_key) | |||
| get_func = _get_recovery_context_func_map[attr_key] | |||
| value = get_func() | |||
| return value | |||
| def _check_value(key, value): | |||
| """ | |||
| Validate the value for recovery context keys. | |||
| """ | |||
| if key in _check_bool_keys: | |||
| Validator.check_bool(value, key) | |||