Merge pull request !31879 from zyli2020/worker_failover_bpr1.7
| @@ -310,7 +310,6 @@ set(BACKEND_SUB_COMP | |||
| runtime/graph_scheduler | |||
| runtime/hardware | |||
| runtime/pynative | |||
| runtime/recovery | |||
| plugin/device/ascend/hal/device | |||
| plugin/device/ascend/hal/hardware | |||
| plugin/device/ascend/hal/hccl_adapter | |||
| @@ -37,7 +37,7 @@ | |||
| #include "runtime/hardware/device_context_manager.h" | |||
| #include "runtime/graph_scheduler/graph_compiler.h" | |||
| #include "runtime/pynative/run_op_helper.h" | |||
| #include "runtime/recovery/recovery_context.h" | |||
| #include "distributed/recovery/recovery_context.h" | |||
| #include "include/common/utils/scoped_long_running.h" | |||
| #ifdef ENABLE_D | |||
| #include "include/common/utils/callbacks_ge.h" | |||
| @@ -953,8 +953,8 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args, | |||
| MS_EXCEPTION_IF_NULL(graph_compiler_); | |||
| graph_compiler_->Summary(graph_compiler_info.graphs_); | |||
| bool need_contruct_output = !(runtime::recovery::RecoveryContext::GetInstance()->enable_recovery() && | |||
| runtime::recovery::RecoveryContext::GetInstance()->need_reset()); | |||
| bool need_contruct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() && | |||
| distributed::recovery::RecoveryContext::GetInstance()->need_reset()); | |||
| if (need_contruct_output) { | |||
| // Update device address for output node of graph. | |||
| // Summary processing will use the output device address, so must be after the summary processing. | |||
| @@ -15,18 +15,24 @@ | |||
| */ | |||
| #include "distributed/collective/collective_manager.h" | |||
| #include <algorithm> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <functional> | |||
| #include <csignal> | |||
| #include <memory> | |||
| #include "utils/ms_context.h" | |||
| #include "runtime/recovery/recovery_context.h" | |||
| #include "distributed/recovery/recovery_context.h" | |||
| namespace mindspore { | |||
| namespace distributed { | |||
| namespace collective { | |||
| using recovery::RecoveryContext; | |||
| CollectiveManager::CollectiveManager() | |||
| : inited_(false), | |||
| finalized_(true), | |||
| need_reinit_(false), | |||
| host_ctx_(nullptr), | |||
| device_ctx_(nullptr), | |||
| host_comm_lib_instance_(nullptr), | |||
| @@ -60,8 +66,75 @@ std::shared_ptr<CollectiveManager> CollectiveManager::instance() { | |||
| return instance; | |||
| } | |||
| namespace { | |||
| // The wrapper to provide a timeout mechanism for executing functions. | |||
| bool ExecuteFuncInThread(const std::function<bool()> &func, const int64_t timeout) { | |||
| bool execute_success = false; | |||
| bool execute_fail = false; | |||
| std::mutex exec_ret_mutex; | |||
| std::condition_variable thread_blocker; | |||
| std::unique_ptr<std::thread> executive_thread = std::make_unique<std::thread>([&] { | |||
| if (!func()) { | |||
| MS_LOG(ERROR) << "Failed to execute function asynchronously"; | |||
| std::unique_lock<std::mutex> lock(exec_ret_mutex); | |||
| execute_fail = true; | |||
| thread_blocker.notify_one(); | |||
| return; | |||
| } | |||
| { | |||
| std::unique_lock<std::mutex> lock(exec_ret_mutex); | |||
| execute_success = true; | |||
| thread_blocker.notify_one(); | |||
| } | |||
| }); | |||
| executive_thread->detach(); | |||
| std::unique_lock<std::mutex> locker(exec_ret_mutex); | |||
| (void)thread_blocker.wait_for(locker, std::chrono::seconds(timeout), [&] { return execute_success || execute_fail; }); | |||
| if (!execute_success && !execute_fail) { | |||
| std::string node_id = common::GetEnv("MS_NODE_ID"); | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| MS_LOG(ERROR) << "Execute function asynchronously timeout, node id: " << node_id << " exit process"; | |||
| (void)kill(getpid(), SIGTERM); | |||
| #endif | |||
| } | |||
| return execute_success; | |||
| } | |||
| // In a disaster recovery scenario, the comparison between the current unique id and the last generated unique id | |||
| // ensures that the acquired unique id is newly generated, and the latest unique id will be persisted. | |||
| bool CheckUniqueIDLatest(const std::string &group_name, size_t root_info_size, const void *root_info) { | |||
| MS_EXCEPTION_IF_NULL(root_info); | |||
| auto persistent_json = RecoveryContext::GetInstance()->persistent_json(); | |||
| MS_EXCEPTION_IF_NULL(persistent_json); | |||
| std::string new_unique_id(static_cast<const char *>(root_info), root_info_size); | |||
| std::vector<int> new_unique_id_integer_seq; | |||
| (void)std::transform(new_unique_id.begin(), new_unique_id.end(), std::back_inserter(new_unique_id_integer_seq), | |||
| [](char c) { return static_cast<int>(c); }); | |||
| const char unique_id_str[] = "_unique_id"; | |||
| std::string unique_id_key = group_name + unique_id_str; | |||
| if (!persistent_json->Exists(unique_id_key)) { | |||
| persistent_json->Insert(unique_id_key, new_unique_id_integer_seq); | |||
| return true; | |||
| } | |||
| std::vector<int> old_unique_id_integer_seq = persistent_json->Get<std::vector<int>>(unique_id_key); | |||
| if (new_unique_id_integer_seq == old_unique_id_integer_seq) { | |||
| return false; | |||
| } | |||
| persistent_json->Insert(unique_id_key, new_unique_id_integer_seq); | |||
| return true; | |||
| } | |||
| } // namespace | |||
| bool CollectiveManager::Initialize() { | |||
| if (inited_ && !runtime::recovery::RecoveryContext::GetInstance()->need_reinit_collective()) { | |||
| if (inited_ && !need_reinit_) { | |||
| return true; | |||
| } | |||
| @@ -98,6 +171,8 @@ bool CollectiveManager::Initialize() { | |||
| MS_LOG(INFO) << "End initializing collective communication for backend: " << device_type_; | |||
| inited_ = true; | |||
| finalized_ = false; | |||
| need_reinit_ = false; | |||
| return true; | |||
| } | |||
| @@ -125,50 +200,36 @@ bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name, | |||
| void *root_info = group->GenerateRootInfo(&root_info_size); | |||
| MS_EXCEPTION_IF_NULL(root_info); | |||
| bool ret = false; | |||
| // Step 4: Broadcast the device root information to all nodes on host side. | |||
| if (!host_comm_lib_instance_->BroadcastUniqueID(group_name, is_root_node, root_info_size, root_info)) { | |||
| MS_LOG(ERROR) << "Broadcast for device root info failed on the host side."; | |||
| return false; | |||
| while (!ret) { | |||
| ret = host_comm_lib_instance_->BroadcastUniqueID(group_name, is_root_node, root_info_size, root_info); | |||
| if (!ret) { | |||
| MS_LOG(ERROR) << "Broadcast for device root info failed on the host side."; | |||
| return false; | |||
| } | |||
| // In disaster recovery scenarios, it is necessary to ensure that the unique id obtained from the Scheduler is a | |||
| // newly generated one. | |||
| if (RecoveryContext::GetInstance()->enable_recovery()) { | |||
| ret = CheckUniqueIDLatest(group_name, root_info_size, root_info); | |||
| } | |||
| } | |||
| // Step 5: Initialize communication group on the device side. | |||
| return InitDeviceCommGroup(group, root_info); | |||
| } | |||
| bool CollectiveManager::InitDeviceCommGroup(const CommunicationGroupPtr &group, void *root_info) { | |||
| bool init_group_success = false; | |||
| bool init_group_fail = false; | |||
| std::condition_variable thread_blocker; | |||
| init_group_thread_ = std::make_unique<std::thread>([&, this] { | |||
| std::function<bool()> init_device_comm_group_func = [&, this]() { | |||
| device_ctx_->Initialize(); | |||
| if (!group->Initialize(root_info)) { | |||
| MS_LOG(ERROR) << "Initialize group on the device side failed."; | |||
| std::unique_lock<std::mutex> lock(init_group_mutex_); | |||
| init_group_fail = true; | |||
| thread_blocker.notify_one(); | |||
| return; | |||
| } | |||
| { | |||
| std::unique_lock<std::mutex> lock(init_group_mutex_); | |||
| init_group_success = true; | |||
| thread_blocker.notify_one(); | |||
| } | |||
| }); | |||
| init_group_thread_->detach(); | |||
| return group->Initialize(root_info); | |||
| }; | |||
| MS_LOG(INFO) << "Begin initialize communication group on the device side."; | |||
| // Timeout limit 180 seconds to wait finishing init device communication group. | |||
| // Timeout limit 180 seconds to wait finish initializing device communication group. | |||
| const int64_t kTimeToWait = 180; | |||
| std::unique_lock<std::mutex> locker(init_group_mutex_); | |||
| (void)thread_blocker.wait_for(locker, std::chrono::seconds(kTimeToWait), | |||
| [&] { return init_group_success || init_group_fail; }); | |||
| if (!init_group_success && runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) { | |||
| runtime::recovery::RecoveryContext::GetInstance()->set_recovery_status( | |||
| runtime::recovery::RecoveryErrCode::kInitNcclFailed); | |||
| MS_LOG(ERROR) << "Initialize group on the device side failed."; | |||
| } | |||
| return init_group_success; | |||
| // Initialize communication group on the device side in thread with timeout limit. | |||
| ret = ExecuteFuncInThread(init_device_comm_group_func, kTimeToWait); | |||
| MS_LOG(INFO) << "End initialize communication group on the device side."; | |||
| return ret; | |||
| } | |||
| bool CollectiveManager::DestroyCommunicationGroup(const std::string &group_name) { | |||
| @@ -197,22 +258,34 @@ uint32_t CollectiveManager::GetGroupSize(const std::string &group_name) { | |||
| } | |||
| bool CollectiveManager::Finalize() { | |||
| if (finalized_) { | |||
| if (!inited_.load() || finalized_.load()) { | |||
| return true; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(host_comm_lib_instance_); | |||
| if (!host_comm_lib_instance_->Finalize()) { | |||
| MS_LOG(WARNING) << "Failed to finalize host communication library."; | |||
| } | |||
| std::function<bool()> finalize_func = [&, this]() { | |||
| MS_EXCEPTION_IF_NULL(host_comm_lib_instance_); | |||
| if (!host_comm_lib_instance_->Finalize()) { | |||
| MS_LOG(WARNING) << "Failed to finalize host communication library."; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(device_comm_lib_instance_); | |||
| if (!device_comm_lib_instance_->Finalize()) { | |||
| MS_LOG(WARNING) << "Failed to finalize device communication library."; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(device_comm_lib_instance_); | |||
| if (!device_comm_lib_instance_->Finalize()) { | |||
| MS_LOG(WARNING) << "Failed to finalize device communication library."; | |||
| } | |||
| finalized_ = true; | |||
| return true; | |||
| finalized_ = true; | |||
| return true; | |||
| }; | |||
| MS_LOG(INFO) << "Begin finalize collective manager."; | |||
| // Timeout limit 5 seconds to wait to finish finalizing device communication group. | |||
| const int64_t kTimeToWait = 5; | |||
| // Finalize collective manager in thread with timeout limit. | |||
| bool ret = ExecuteFuncInThread(finalize_func, kTimeToWait); | |||
| MS_LOG(INFO) << "End finalize collective manager."; | |||
| return ret; | |||
| } | |||
| void CollectiveManager::set_global_rank_id(uint32_t global_rank_id) { global_rank_id_ = global_rank_id; } | |||
| @@ -69,6 +69,11 @@ class BACKEND_EXPORT CollectiveManager { | |||
| uint32_t local_rank_id() const; | |||
| // Set whether need reinitialize collective communication. | |||
| void set_need_reinit(bool need_reinit) { need_reinit_ = need_reinit; } | |||
| // Get whether need reinitialize collective communication. | |||
| bool need_reinit() const { return need_reinit_.load(); } | |||
| private: | |||
| CollectiveManager(); | |||
| @@ -81,16 +86,13 @@ class BACKEND_EXPORT CollectiveManager { | |||
| // Assign the local rank id for this process. | |||
| bool AssignLocalRank(); | |||
| // Initialize communication group on the device side. | |||
| bool InitDeviceCommGroup(const CommunicationGroupPtr &group, void *root_info); | |||
| // Initialize communication group on the device side in thread with timeout limit. | |||
| std::unique_ptr<std::thread> init_group_thread_; | |||
| std::mutex init_group_mutex_; | |||
| std::atomic_bool inited_; | |||
| std::atomic_bool finalized_; | |||
| // Whether collective communication needs to be reinitialized. This value should be set to true once an | |||
| // unexpected exit of a training process is detected. | |||
| std::atomic_bool need_reinit_; | |||
| // The device type read from MindSpore context. | |||
| std::string device_type_; | |||
| @@ -119,8 +121,8 @@ class BACKEND_EXPORT CollectiveManager { | |||
| // Global group ranks. | |||
| std::vector<uint32_t> global_group_ranks_; | |||
| // The global group name on the host side. This is used for Creating global group on host side for AllGather operation | |||
| // of host name while assigning local rank. | |||
| // The global group name on the host side. This is used for Creating global group on host side for AllGather | |||
| // operation of host name while assigning local rank. | |||
| std::string host_global_group_name_; | |||
| }; | |||
| } // namespace collective | |||
| @@ -17,11 +17,11 @@ | |||
| #include "distributed/init.h" | |||
| #include <vector> | |||
| #include <string> | |||
| #include "runtime/recovery/recovery_context.h" | |||
| #include "distributed/recovery/recovery_context.h" | |||
| namespace mindspore { | |||
| namespace distributed { | |||
| using runtime::recovery::RecoveryContext; | |||
| using distributed::recovery::RecoveryContext; | |||
| bool Initialize() { | |||
| if (!InitializeCluster()) { | |||
| @@ -43,7 +43,8 @@ bool Initialize() { | |||
| collective::CollectiveManager::instance()->set_global_rank_size(abstract_node->worker_num()); | |||
| if (RecoveryContext::GetInstance()->enable_recovery()) { | |||
| cluster::ClusterContext::instance()->WaitForClusterReady(); | |||
| RecoveryContext::GetInstance()->set_global_rank_id(abstract_node->rank_id()); | |||
| RecoveryContext::GetInstance()->set_global_rank_size(abstract_node->worker_num()); | |||
| } | |||
| if (!InitializeCollective()) { | |||
| @@ -52,8 +53,6 @@ bool Initialize() { | |||
| } | |||
| if (RecoveryContext::GetInstance()->enable_recovery()) { | |||
| RecoveryContext::GetInstance()->set_global_rank_id(abstract_node->rank_id()); | |||
| RecoveryContext::GetInstance()->set_global_rank_size(abstract_node->worker_num()); | |||
| RecoveryContext::GetInstance()->ObtainGlobalLatestCkptInfo(); | |||
| } | |||
| } | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/recovery/recovery_context.h" | |||
| #include "distributed/recovery/recovery_context.h" | |||
| #include <dirent.h> | |||
| #include <algorithm> | |||
| @@ -26,13 +26,12 @@ | |||
| #include "utils/file_utils.h" | |||
| #include "distributed/constants.h" | |||
| #include "distributed/cluster/topology/common.h" | |||
| #include "distributed/init.h" | |||
| #include "runtime/hardware/device_context.h" | |||
| #include "runtime/hardware/device_context_manager.h" | |||
| #include "utils/convert_utils_base.h" | |||
| #include "utils/ms_context.h" | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| namespace distributed { | |||
| namespace recovery { | |||
| constexpr char kEnvEnableRecovery[] = "MS_ENABLE_RECOVERY"; | |||
| constexpr char kEnvRecoveryPath[] = "MS_RECOVERY_PATH"; | |||
| @@ -142,30 +141,6 @@ void RecoveryContext::Initialize() { | |||
| initialized_ = true; | |||
| } | |||
| bool RecoveryContext::ReInitializeCollective() { | |||
| auto ret = distributed::Initialize(); | |||
| if (ret) { | |||
| recovery_status_ = RecoveryErrCode::kUnKnownError; | |||
| set_need_reset(true); | |||
| set_need_sync_weight_to_device(true); | |||
| return true; | |||
| } | |||
| if (recovery_status_ == RecoveryErrCode::kBroadcastUniqueIDFailed || | |||
| recovery_status_ == RecoveryErrCode::kAllGatherHostNameFailed) { | |||
| MS_LOG(WARNING) << "Prepare to initialize NCCL failed, retrying."; | |||
| // Retry duration: 30s. | |||
| const int kRetryDuration = 30; | |||
| std::this_thread::sleep_for(std::chrono::seconds(kRetryDuration)); | |||
| return ReInitializeCollective(); | |||
| } else if (recovery_status_ == RecoveryErrCode::kInitNcclFailed) { | |||
| MS_LOG(EXCEPTION) << "Initialize NCCL failed."; | |||
| } | |||
| MS_LOG(EXCEPTION) << "ReInitialize collective failed."; | |||
| return false; | |||
| } | |||
| void RecoveryContext::ObtainGlobalLatestCkptInfo() { | |||
| // 1. Obtain the step corresponding to the local latest checkpoint. | |||
| ObtainLocalLatestCkptInfo(); | |||
| @@ -326,6 +301,7 @@ void RecoveryContext::ParseLatestCkptInfo(const int *recv_buffer, const uint32_t | |||
| } | |||
| void RecoveryContext::CreatePersistentFile() { | |||
| std::unique_lock<std::mutex> lock(create_persist_json_mtx_); | |||
| if (node_role_ == distributed::kEnvRoleOfScheduler) { | |||
| return; | |||
| } | |||
| @@ -344,7 +320,7 @@ void RecoveryContext::CreatePersistentFile() { | |||
| // The directory used to save ckpt is persisted to json file. | |||
| std::string persistent_file_path = | |||
| recovery_path_ + "/" + node_role_ + "_" + std::to_string(global_rank_id_) + "_persistent.json"; | |||
| persistent_json_ = std::make_unique<JsonUtils>(persistent_file_path); | |||
| persistent_json_ = std::make_shared<JsonUtils>(persistent_file_path); | |||
| if (!persistent_json_->Initialize()) { | |||
| MS_LOG(EXCEPTION) << "Initialize json failed, file path: " << persistent_file_path; | |||
| } | |||
| @@ -388,6 +364,24 @@ std::string RecoveryContext::GetCkptPath() { | |||
| return persistent_json_->Get<std::string>(kCkptPath); | |||
| } | |||
| const std::shared_ptr<JsonUtils> &RecoveryContext::persistent_json() { | |||
| if (persistent_json_ == nullptr) { | |||
| CreatePersistentFile(); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(persistent_json_); | |||
| return persistent_json_; | |||
| } | |||
| std::string RecoveryContext::latest_ckpt_file() { | |||
| // For standalone training. | |||
| if (enable_recovery_ && global_rank_size_ == 0 && latest_ckpt_file_.empty()) { | |||
| ObtainLocalLatestCkptInfo(); | |||
| } | |||
| return latest_ckpt_file_; | |||
| } | |||
| } // namespace recovery | |||
| } // namespace runtime | |||
| } // namespace distributed | |||
| } // namespace mindspore | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_RECOVERY_RECOVERY_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_RECOVERY_RECOVERY_H_ | |||
| #ifndef MINDSPORE_CCSRC_DISTRIBUTED_RECOVERY_RECOVERY_H_ | |||
| #define MINDSPORE_CCSRC_DISTRIBUTED_RECOVERY_RECOVERY_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| @@ -27,13 +27,11 @@ | |||
| #include "include/backend/visible.h" | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| namespace distributed { | |||
| namespace recovery { | |||
| using distributed::storage::FileIOUtils; | |||
| using distributed::storage::JsonUtils; | |||
| enum class RecoveryErrCode { kUnKnownError, kAllGatherHostNameFailed, kBroadcastUniqueIDFailed, kInitNcclFailed }; | |||
| // Used to save disaster recovery-related state quantities and provide disaster recovery-related | |||
| // functions, such as reinitializing collective communication, etc. | |||
| class BACKEND_EXPORT RecoveryContext { | |||
| @@ -47,14 +45,6 @@ class BACKEND_EXPORT RecoveryContext { | |||
| } | |||
| ~RecoveryContext() = default; | |||
| // Reinitializing collective communication. | |||
| bool ReInitializeCollective(); | |||
| // Obtain the global step corresponding to the global latest checkpoint in each training process. Since there may be | |||
| // some processes that fails to save the checkpoint, it is necessary for AllGather to save the latest step of the | |||
| // successful checkpoint in each training process, and then take the minimum value as the final reset position. | |||
| void ObtainGlobalLatestCkptInfo(); | |||
| // Get whether enable recovery or not. | |||
| bool enable_recovery() const { return enable_recovery_; } | |||
| @@ -64,18 +54,14 @@ class BACKEND_EXPORT RecoveryContext { | |||
| // Get interval to persist model. | |||
| int recovery_interval() const { return recovery_interval_; } | |||
| // Get the error status of recovery. | |||
| RecoveryErrCode recovery_status() const { return recovery_status_; } | |||
| // Set the error status of recovery. | |||
| void set_recovery_status(RecoveryErrCode recovery_status) { recovery_status_ = recovery_status; } | |||
| // Set the path used to save checkpoint. | |||
| void SetCkptPath(const std::string &path); | |||
| // Get the path used to save checkpoint. | |||
| std::string GetCkptPath(); | |||
| // Get the latest checkpoint in this node. | |||
| std::string latest_ckpt_file() const { return latest_ckpt_file_; } | |||
| std::string latest_ckpt_file(); | |||
| // Get the epoch of latest checkpoint in this node. | |||
| int latest_ckpt_epoch() const { return latest_ckpt_epoch_; } | |||
| // Get the step of latest checkpoint in this node. | |||
| @@ -99,10 +85,13 @@ class BACKEND_EXPORT RecoveryContext { | |||
| // Set global rank size. | |||
| void set_global_rank_size(uint32_t global_rank_size) { global_rank_size_ = global_rank_size; } | |||
| // Set whether need reinitialize collective communication. | |||
| void set_need_reinit_collective(bool need_reinit_collective) { need_reinit_collective_ = need_reinit_collective; } | |||
| // Get whether need reinitialize collective communication. | |||
| bool need_reinit_collective() const { return need_reinit_collective_.load(); } | |||
| // Obtain the global step corresponding to the global latest checkpoint in each training process. Since there may be | |||
| // some processes that fails to save the checkpoint, it is necessary for AllGather to save the latest step of the | |||
| // successful checkpoint in each training process, and then take the minimum value as the final reset position. | |||
| void ObtainGlobalLatestCkptInfo(); | |||
| // Get the persistent json file pointer. | |||
| const std::shared_ptr<JsonUtils> &persistent_json(); | |||
| private: | |||
| inline static std::shared_ptr<RecoveryContext> instance_{}; | |||
| @@ -155,20 +144,14 @@ class BACKEND_EXPORT RecoveryContext { | |||
| // performs load checkpoint. | |||
| bool need_sync_weight_to_device_{false}; | |||
| // Whether need reinitialize collective communication, this value should be set to true once a training process | |||
| // exits unexpectedly is detected. | |||
| std::atomic_bool need_reinit_collective_{false}; | |||
| // Whether the recovery context is already initialized. | |||
| bool initialized_{false}; | |||
| // The error status of recovery. | |||
| RecoveryErrCode recovery_status_{RecoveryErrCode::kUnKnownError}; | |||
| std::mutex create_persist_json_mtx_; | |||
| // The persistent json file util, used to persist recovery config. | |||
| std::unique_ptr<JsonUtils> persistent_json_; | |||
| std::shared_ptr<JsonUtils> persistent_json_; | |||
| }; | |||
| } // namespace recovery | |||
| } // namespace runtime | |||
| } // namespace distributed | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_RECOVERY_RECOVERY_H_ | |||
| #endif // MINDSPORE_CCSRC_DISTRIBUTED_RECOVERY_RECOVERY_H_ | |||
| @@ -40,7 +40,7 @@ | |||
| #include "ps/util.h" | |||
| #endif | |||
| #include "ps/ps_context.h" | |||
| #include "runtime/recovery/recovery_context.h" | |||
| #include "distributed/recovery/recovery_context.h" | |||
| #include "pybind_api/gil_scoped_long_running.h" | |||
| @@ -58,7 +58,7 @@ using ParallelContext = mindspore::parallel::ParallelContext; | |||
| using CostModelContext = mindspore::parallel::CostModelContext; | |||
| using mindspore::MsCtxParam; | |||
| using PSContext = mindspore::ps::PSContext; | |||
| using RecoveryContext = mindspore::runtime::recovery::RecoveryContext; | |||
| using RecoveryContext = mindspore::distributed::recovery::RecoveryContext; | |||
| // Interface with python | |||
| PYBIND11_MODULE(_c_expression, m) { | |||
| @@ -61,6 +61,7 @@ | |||
| #include "runtime/device/kernel_runtime_manager.h" | |||
| #include "runtime/pynative/op_executor.h" | |||
| #include "runtime/device/stream_synchronizer.h" | |||
| #include "distributed/collective/collective_manager.h" | |||
| #ifndef ENABLE_SECURITY | |||
| #ifdef ENABLE_D | |||
| @@ -1608,6 +1609,8 @@ void ClearResAtexit() { | |||
| device::StreamSynchronizer::GetInstance()->Finalize(); | |||
| MS_LOG(INFO) << "End Finalize StreamSynchronizer..."; | |||
| (void)distributed::collective::CollectiveManager::instance()->Finalize(); | |||
| PrimitivePy::ClearHookRes(); | |||
| ad::g_k_prims.clear(); | |||
| ad::ClearKPynativeCellStaticRes(); | |||
| @@ -35,7 +35,7 @@ using CommunicationGroupInfo = mindspore::fl::server::CommunicationGroupInfo; | |||
| using ps::core::NodeCommand; | |||
| // The time interval for send info or query info between worker and scheduler. | |||
| constexpr uint32_t kWaitDuration = 3; | |||
| constexpr uint32_t kWaitDuration = 5; | |||
| // The collective communication library for MindSpore self developed communication framework. | |||
| class MsCollectiveCommLib : public CollectiveCommunicationLib { | |||
| @@ -1301,6 +1301,7 @@ void AbstractNode::InitCommandHandler() { | |||
| handlers_[NodeCommand::SEND_EVENT] = nullptr; | |||
| RegisterActorRouteTableRspHandler(); | |||
| RegisterInitCollectCommResphandler(); | |||
| RegisterRecoveryRespHandler(); | |||
| } | |||
| void AbstractNode::RegisterActorRouteTableRspHandler() { | |||
| @@ -228,6 +228,9 @@ class BACKEND_EXPORT AbstractNode : public Node { | |||
| // Register collective communication initialization response methods. | |||
| virtual void RegisterInitCollectCommResphandler() {} | |||
| // Register recovery response methods. | |||
| virtual void RegisterRecoveryRespHandler() {} | |||
| // When initializing the node, the node info should be initialized as well. | |||
| void InitNodeInfo(const NodeRole &role); | |||
| // Initialize worker num and server num by cluster config. | |||
| @@ -139,6 +139,9 @@ bool AbstractPSNode::HandleHeartbeatTimeout() { | |||
| if (!stop_heartbeat_.load()) { | |||
| stop_heartbeat_ = true; | |||
| while (!heartbeat_stopped_.load()) { | |||
| if (is_finish_.load()) { | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Waiting for heartbeat to stop..."; | |||
| // Time interval for waiting the heartbeat to stop. | |||
| @@ -152,6 +155,9 @@ bool AbstractPSNode::HandleHeartbeatTimeout() { | |||
| bool success = false; | |||
| while (!success) { | |||
| if (is_finish_.load()) { | |||
| return; | |||
| } | |||
| MS_LOG(WARNING) << "Trying to reconnect to the scheduler..."; | |||
| success = InitClientToScheduler(); | |||
| if (success) { | |||
| @@ -176,6 +182,11 @@ void AbstractPSNode::RegisterInitCollectCommResphandler() { | |||
| handlers_[NodeCommand::SEND_UNIQUE_ID] = &AbstractPSNode::ProcessReceiveSchedulerResp; | |||
| handlers_[NodeCommand::QUERY_UNIQUE_ID] = &AbstractPSNode::ProcessReceiveSchedulerResp; | |||
| } | |||
| void AbstractPSNode::RegisterRecoveryRespHandler() { | |||
| handlers_[NodeCommand::SEND_FINISH_TRANSFORM] = &AbstractPSNode::ProcessReceiveSchedulerResp; | |||
| handlers_[NodeCommand::QUERY_FINISH_TRANSFORM] = &AbstractPSNode::ProcessReceiveSchedulerResp; | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -38,6 +38,9 @@ class AbstractPSNode : public AbstractNode { | |||
| // Register collective communication initialization response methods. | |||
| void RegisterInitCollectCommResphandler() override; | |||
| // Register recovery response methods. | |||
| void RegisterRecoveryRespHandler() override; | |||
| // Indicate whether the heartbeat thread should be stopped. | |||
| std::atomic<bool> stop_heartbeat_{false}; | |||
| @@ -67,7 +67,7 @@ uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage ®ister_message | |||
| return rank_id; | |||
| } | |||
| } else { | |||
| ReAddNodeIfNotExists(node_id, register_message.ip(), register_message.port()); | |||
| (void)ReAddNodeIfNotExists(node_id, register_message.ip(), register_message.port(), &rank_id); | |||
| } | |||
| return rank_id; | |||
| } | |||
| @@ -206,11 +206,16 @@ std::vector<ServersMeta> NodeManager::FetchAllNodesMeta() { | |||
| return servers_meta_list; | |||
| } | |||
| const std::unordered_map<std::string, NodeInfo> &NodeManager::QueryTimeOutNodesInfo() const { | |||
| return timeout_nodes_info_; | |||
| } | |||
| void NodeManager::UpdateCluster() { | |||
| // 1. update cluster timeout state | |||
| struct timeval current_time {}; | |||
| (void)gettimeofday(¤t_time, nullptr); | |||
| timeout_nodes_info_.clear(); | |||
| std::lock_guard<std::mutex> lock(heartbeat_mutex_); | |||
| for (auto it = heartbeats_.begin(); it != heartbeats_.end(); ++it) { | |||
| if (it->second.tv_sec + PSContext::instance()->cluster_config().heartbeat_timeout < current_time.tv_sec) { | |||
| if (registered_nodes_info_.count(it->first)) { | |||
| @@ -139,6 +139,8 @@ class NodeManager { | |||
| bool IsAllNodesAlive() const; | |||
| const std::unordered_map<std::string, NodeInfo> &QueryTimeOutNodesInfo() const; | |||
| private: | |||
| std::mutex node_mutex_; | |||
| std::mutex cluster_mutex_; | |||
| @@ -57,6 +57,12 @@ enum NodeCommand { | |||
| SEND_UNIQUE_ID = 20; | |||
| // Query unique id used to initialize collective communication. | |||
| QUERY_UNIQUE_ID = 21; | |||
| // Send the ready status to finish transform graph of computed node, | |||
| // used in disaster recovery mode to prevent timeout of waiting for graph transformation. | |||
| SEND_FINISH_TRANSFORM = 22; | |||
| // Query the ready status to finish transform graph of computed node, | |||
| // used in disaster recovery mode to prevent timeout of waiting for graph transformation. | |||
| QUERY_FINISH_TRANSFORM = 23; | |||
| } | |||
| enum NodeRole { | |||
| @@ -298,3 +304,19 @@ message QueryUniqueIDRespMessage { | |||
| // The unique id used to initialize collective communication. | |||
| bytes unique_id = 2; | |||
| } | |||
| message SendFinishTransformMessage { | |||
| // The unique id of the current node: 0, 1, 2, ... | |||
| string node_id = 1; | |||
| // The rank id of the node in the cluster. | |||
| uint32 rank_id = 2; | |||
| // Whether finish transform graph. | |||
| bool is_ready = 3; | |||
| } | |||
| message QueryFinishTransformRespMessage { | |||
| // Whether all computed nodes are ready to run dag. | |||
| bool is_ready = 1; | |||
| // Whether there is any worker timeout. | |||
| bool is_worker_timeout = 2; | |||
| } | |||
| @@ -16,6 +16,7 @@ | |||
| #include <memory> | |||
| #include "ps/core/ps_scheduler_node.h" | |||
| #include "utils/ms_context.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| @@ -58,6 +59,13 @@ void PSSchedulerNode::RegisterInitCollectCommServiceHandler() { | |||
| handlers_[NodeCommand::QUERY_UNIQUE_ID] = static_cast<ResponseHandler>(&PSSchedulerNode::ProcessQueryUniqueID); | |||
| } | |||
| void PSSchedulerNode::RegisterRecoveryServiceHandler() { | |||
| handlers_[NodeCommand::SEND_FINISH_TRANSFORM] = | |||
| static_cast<ResponseHandler>(&PSSchedulerNode::ProcessSendFinishTransform); | |||
| handlers_[NodeCommand::QUERY_FINISH_TRANSFORM] = | |||
| static_cast<ResponseHandler>(&PSSchedulerNode::ProcessQueryFinishTransform); | |||
| } | |||
| void PSSchedulerNode::ProcessSendHostName(const std::shared_ptr<TcpServer> &server, | |||
| const std::shared_ptr<TcpConnection> &conn, | |||
| const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) { | |||
| @@ -71,7 +79,7 @@ void PSSchedulerNode::ProcessSendHostName(const std::shared_ptr<TcpServer> &serv | |||
| std::string node_id = send_host_name_msg.node_id(); | |||
| uint32_t rank_id = send_host_name_msg.rank_id(); | |||
| size_t host_hash_name = send_host_name_msg.host_hash_name(); | |||
| MS_LOG(INFO) << "Received send host name request, node id: " << node_id << ", rank id: " << rank_id; | |||
| MS_LOG(INFO) << "Receive send host name request, node id: " << node_id << ", rank id: " << rank_id; | |||
| bool ret = false; | |||
| std::string error = ""; | |||
| @@ -100,7 +108,7 @@ void PSSchedulerNode::ProcessQueryHostNames(const std::shared_ptr<TcpServer> &se | |||
| query_msg.ParseFromArray(data, SizeToInt(size)); | |||
| std::string node_id = query_msg.node_id(); | |||
| uint32_t rank_id = query_msg.rank_id(); | |||
| MS_LOG(INFO) << "Received query host name request, node id: " << node_id << ", rank id: " << rank_id; | |||
| MS_LOG(INFO) << "Receive query host name request, node id: " << node_id << ", rank id: " << rank_id; | |||
| bool is_success = recv_rank_id_send_host_name_.size() == host_hash_names_.size(); | |||
| QueryHostHashNameRespMessage resp_msg; | |||
| @@ -121,6 +129,7 @@ void PSSchedulerNode::ProcessQueryHostNames(const std::shared_ptr<TcpServer> &se | |||
| if (recv_rank_id_query_host_name_.size() == recv_rank_id_send_host_name_.size()) { | |||
| recv_rank_id_send_host_name_.clear(); | |||
| recv_rank_id_query_host_name_.clear(); | |||
| node_timeout_ = false; | |||
| } | |||
| } | |||
| } | |||
| @@ -138,7 +147,7 @@ void PSSchedulerNode::ProcessSendUniqueID(const std::shared_ptr<TcpServer> &serv | |||
| std::string node_id = send_unique_id_msg.node_id(); | |||
| uint32_t rank_id = send_unique_id_msg.rank_id(); | |||
| std::string group_name = send_unique_id_msg.group_name(); | |||
| MS_LOG(INFO) << "Received send unique id request, group name: " << group_name << ", node id: " << node_id | |||
| MS_LOG(INFO) << "Receive send unique id request, group name: " << group_name << ", node id: " << node_id | |||
| << ", rank id: " << rank_id; | |||
| bool ret = false; | |||
| @@ -169,7 +178,7 @@ void PSSchedulerNode::ProcessQueryUniqueID(const std::shared_ptr<TcpServer> &ser | |||
| std::string node_id = query_msg.node_id(); | |||
| uint32_t rank_id = query_msg.rank_id(); | |||
| std::string group_name = query_msg.group_name(); | |||
| MS_LOG(INFO) << "Received query unique id request, group name: " << group_name << ", node id: " << node_id | |||
| MS_LOG(INFO) << "Receive query unique id request, group name: " << group_name << ", node id: " << node_id | |||
| << ", rank id: " << rank_id; | |||
| auto iter = unique_id_group_.find(group_name); | |||
| @@ -190,6 +199,89 @@ void PSSchedulerNode::ProcessQueryUniqueID(const std::shared_ptr<TcpServer> &ser | |||
| MS_LOG(INFO) << "Respond query unique id request, group name: " << group_name << ", node id: " << node_id | |||
| << ", rank id: " << rank_id; | |||
| } | |||
| void PSSchedulerNode::ProcessSendFinishTransform(const std::shared_ptr<TcpServer> &server, | |||
| const std::shared_ptr<TcpConnection> &conn, | |||
| const std::shared_ptr<MessageMeta> &meta, const void *data, | |||
| size_t size) { | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(server); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(conn); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(meta); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(data); | |||
| SendFinishTransformMessage send_ready_to_run_msg; | |||
| send_ready_to_run_msg.ParseFromArray(data, SizeToInt(size)); | |||
| std::string node_id = send_ready_to_run_msg.node_id(); | |||
| uint32_t rank_id = send_ready_to_run_msg.rank_id(); | |||
| MS_LOG(INFO) << "Receive send finish transform request, node id: " << node_id << ", rank id: " << rank_id; | |||
| bool is_ready = send_ready_to_run_msg.is_ready(); | |||
| if (is_ready) { | |||
| std::unique_lock<std::mutex> lock(nodes_finish_trans_mutex_); | |||
| (void)nodes_finish_trans_.insert(rank_id); | |||
| } | |||
| GeneralResponse(server, conn, meta, true, ""); | |||
| MS_LOG(INFO) << "Respond send finish transform request, node id: " << node_id << ", rank id: " << rank_id; | |||
| } | |||
| void PSSchedulerNode::ProcessQueryFinishTransform(const std::shared_ptr<TcpServer> &server, | |||
| const std::shared_ptr<TcpConnection> &conn, | |||
| const std::shared_ptr<MessageMeta> &meta, const void *data, | |||
| size_t size) { | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(server); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(conn); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(meta); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(data); | |||
| GeneralQueryMessage query_msg; | |||
| query_msg.ParseFromArray(data, SizeToInt(size)); | |||
| std::string node_id = query_msg.node_id(); | |||
| uint32_t rank_id = query_msg.rank_id(); | |||
| MS_LOG(INFO) << "Receive query finish transform request, node id: " << node_id << ", rank id: " << rank_id; | |||
| std::unique_lock<std::mutex> lock(nodes_finish_trans_mutex_); | |||
| bool is_ready = nodes_finish_trans_.size() == worker_num_; | |||
| QueryFinishTransformRespMessage resp_msg; | |||
| resp_msg.set_is_ready(is_ready); | |||
| if (node_timeout_) { | |||
| (void)resp_msg.set_is_worker_timeout(true); | |||
| } else { | |||
| resp_msg.set_is_worker_timeout(false); | |||
| } | |||
| if (!server->SendMessage(conn, meta, Protos::PROTOBUF, resp_msg.SerializeAsString().data(), | |||
| resp_msg.ByteSizeLong())) { | |||
| MS_LOG(ERROR) << "Scheduler failed to respond message."; | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Respond query finish transform request, node id: " << node_id << ", rank id: " << rank_id; | |||
| } | |||
| void PSSchedulerNode::HandleNodeTimeoutForRecovery( | |||
| const std::unordered_map<std::string, NodeInfo> &timeout_nodes_infos) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_RECOVERY)) { | |||
| return; | |||
| } | |||
| if (timeout_nodes_infos.empty()) { | |||
| return; | |||
| } | |||
| std::unique_lock<std::mutex> lock(nodes_finish_trans_mutex_); | |||
| node_timeout_ = true; | |||
| for (const auto &item : timeout_nodes_infos) { | |||
| (void)nodes_finish_trans_.erase(item.second.rank_id_); | |||
| } | |||
| } | |||
| void PSSchedulerNode::HandleNodeRecoverByHeartBeat(uint32_t rank_id) { | |||
| std::unique_lock<std::mutex> lock(nodes_finish_trans_mutex_); | |||
| (void)nodes_finish_trans_.insert(rank_id); | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -18,6 +18,7 @@ | |||
| #define MINDSPORE_CCSRC_PS_CORE_PS_SCHEDULER_NODE_H_ | |||
| #include <map> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <set> | |||
| @@ -47,9 +48,16 @@ class BACKEND_EXPORT PSSchedulerNode : public SchedulerNode { | |||
| // alive node should be rejected. | |||
| bool NeedRejectRegister(const NodeInfo &node_info) override { return node_info.is_alive; } | |||
| bool SendPrepareBuildingNetwork(const std::unordered_map<std::string, NodeInfo> &node_infos) override { | |||
| return true; | |||
| }; | |||
| // Register collective communication initialization service. | |||
| void RegisterInitCollectCommServiceHandler() override; | |||
| // Register recovery service. | |||
| void RegisterRecoveryServiceHandler() override; | |||
| // Process message for sending node's host name. | |||
| void ProcessSendHostName(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn, | |||
| const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size); | |||
| @@ -66,6 +74,20 @@ class BACKEND_EXPORT PSSchedulerNode : public SchedulerNode { | |||
| void ProcessQueryUniqueID(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn, | |||
| const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size); | |||
| // Process message for sending the ready status of finishing transform graph of computed nodes. | |||
| void ProcessSendFinishTransform(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn, | |||
| const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size); | |||
| // Process message for querying the ready status of finishing transform graph of computed nodes. | |||
| void ProcessQueryFinishTransform(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn, | |||
| const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size); | |||
| // Handle node timeout info and update nodes which finish transform graph. | |||
| void HandleNodeTimeoutForRecovery(const std::unordered_map<std::string, NodeInfo> &timeout_nodes_infos) override; | |||
| // Recover finish transform nodes info when nodes recover heartbeat. | |||
| void HandleNodeRecoverByHeartBeat(uint32_t rank_id) override; | |||
| // Record received host hash name from workers. | |||
| std::vector<size_t> host_hash_names_; | |||
| // Record rank id of the nodes which sent host name. | |||
| @@ -77,6 +99,11 @@ class BACKEND_EXPORT PSSchedulerNode : public SchedulerNode { | |||
| std::map<std::string, std::string> unique_id_group_; | |||
| uint32_t worker_num_; | |||
| std::mutex nodes_finish_trans_mutex_; | |||
| // Record the rank ids of nodes who finish transform graph. | |||
| std::set<uint32_t> nodes_finish_trans_; | |||
| std::atomic_bool node_timeout_{false}; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| @@ -39,16 +39,13 @@ bool PSServerNode::Start(const uint32_t &timeout) { | |||
| } | |||
| void PSServerNode::Initialize() { | |||
| InitNodeNum(); | |||
| config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path()); | |||
| MS_EXCEPTION_IF_NULL(config_); | |||
| if (!config_->Initialize()) { | |||
| MS_LOG(INFO) << "The config file is empty, then init node by context."; | |||
| InitNodeNum(); | |||
| } else { | |||
| if (!Recover()) { | |||
| MS_LOG(WARNING) << "Recover the server node is failed."; | |||
| } | |||
| if (config_->Initialize() && !Recover()) { | |||
| MS_LOG(INFO) << "Recover the server node is failed."; | |||
| } | |||
| InitServerHandler(); | |||
| CreateTcpServer(); | |||
| InitNodeInfo(NodeRole::SERVER); | |||
| @@ -119,8 +116,9 @@ void PSServerNode::Register(const std::shared_ptr<TcpClient> &client) { | |||
| MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << " the node id:" << node_info_.node_id_ << " begin to register to the scheduler!"; | |||
| const int kCommTimeoutInSeconds = 20; | |||
| if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(), | |||
| register_message.ByteSizeLong())) { | |||
| register_message.ByteSizeLong(), kCommTimeoutInSeconds)) { | |||
| MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << " the node id:" << node_info_.node_id_ << " register timeout!"; | |||
| } else { | |||
| @@ -79,16 +79,13 @@ bool PSWorkerNode::Finish(const uint32_t &timeout) { | |||
| } | |||
| void PSWorkerNode::Initialize() { | |||
| InitNodeNum(); | |||
| config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path()); | |||
| MS_EXCEPTION_IF_NULL(config_); | |||
| if (!config_->Initialize()) { | |||
| MS_LOG(INFO) << "The config file is empty, then init node by context."; | |||
| InitNodeNum(); | |||
| } else { | |||
| if (!Recover()) { | |||
| MS_LOG(WARNING) << "Recover the worker node is failed."; | |||
| } | |||
| if (config_->Initialize() && !Recover()) { | |||
| MS_LOG(INFO) << "Recover the worker node is failed."; | |||
| } | |||
| InitServerHandler(); | |||
| CreateTcpServer(); | |||
| InitNodeInfo(NodeRole::WORKER); | |||
| @@ -120,8 +117,9 @@ void PSWorkerNode::Register(const std::shared_ptr<TcpClient> &client) { | |||
| MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << " the node id:" << node_info_.node_id_ << " begin to register to the scheduler!"; | |||
| const int kCommTimeoutInSeconds = 20; | |||
| if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(), | |||
| register_message.ByteSizeLong())) { | |||
| register_message.ByteSizeLong(), kCommTimeoutInSeconds)) { | |||
| MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << " the node id:" << node_info_.node_id_ << " register timeout!"; | |||
| } else { | |||
| @@ -176,10 +176,12 @@ void SchedulerNode::ProcessHeartbeat(const std::shared_ptr<TcpServer> &server, | |||
| return; | |||
| } | |||
| uint32_t rank_id = UINT32_MAX; | |||
| // Re-Add the missing node into node manager. | |||
| if (heartbeat_message.has_address() && | |||
| node_manager_.ReAddNodeIfNotExists(node_id, heartbeat_message.ip(), heartbeat_message.port())) { | |||
| node_manager_.ReAddNodeIfNotExists(node_id, heartbeat_message.ip(), heartbeat_message.port(), &rank_id)) { | |||
| SetRegisterConnectionFd(conn, node_id); | |||
| HandleNodeRecoverByHeartBeat(rank_id); | |||
| if (node_manager_.IsAllNodesRegistered()) { | |||
| is_ready_ = true; | |||
| @@ -234,6 +236,7 @@ void SchedulerNode::InitCommandHandler() { | |||
| handlers_[NodeCommand::SEND_EVENT] = &SchedulerNode::ProcessSendEvent; | |||
| RegisterActorRouteTableServiceHandler(); | |||
| RegisterInitCollectCommServiceHandler(); | |||
| RegisterRecoveryServiceHandler(); | |||
| } | |||
| void SchedulerNode::RegisterActorRouteTableServiceHandler() { | |||
| @@ -699,6 +702,7 @@ void SchedulerNode::StartUpdateClusterStateTimer() { | |||
| } | |||
| std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval)); | |||
| node_manager_.UpdateCluster(); | |||
| HandleNodeTimeoutForRecovery(node_manager_.QueryTimeOutNodesInfo()); | |||
| if (node_manager_.GetClusterState() == ClusterState::CLUSTER_EXIT) { | |||
| std::this_thread::sleep_for( | |||
| @@ -92,6 +92,15 @@ class BACKEND_EXPORT SchedulerNode : public Node { | |||
| // Register collective communication initialization service. | |||
| virtual void RegisterInitCollectCommServiceHandler() {} | |||
| // Register recovery service. | |||
| virtual void RegisterRecoveryServiceHandler() {} | |||
| // Handle node timeout info and update nodes which finish transform graph. | |||
| virtual void HandleNodeTimeoutForRecovery(const std::unordered_map<std::string, NodeInfo> &timeout_nodes_infos) {} | |||
| // Recover finish transform nodes info when nodes recover heartbeat. | |||
| virtual void HandleNodeRecoverByHeartBeat(uint32_t rank_id) {} | |||
| const std::shared_ptr<TcpClient> &GetOrCreateClient(const NodeInfo &node_info); | |||
| void ProcessHeartbeat(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn, | |||
| @@ -197,7 +206,7 @@ class BACKEND_EXPORT SchedulerNode : public Node { | |||
| void SetRegisterConnectionFd(const std::shared_ptr<TcpConnection> &conn, const std::string &node_id); | |||
| bool SendPrepareBuildingNetwork(const std::unordered_map<std::string, NodeInfo> &node_infos); | |||
| virtual bool SendPrepareBuildingNetwork(const std::unordered_map<std::string, NodeInfo> &node_infos); | |||
| // Responding peer with the general response message. | |||
| void GeneralResponse(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn, | |||
| @@ -17,16 +17,19 @@ | |||
| #include "runtime/device/stream_synchronizer.h" | |||
| #include "utils/ms_context.h" | |||
| #include "distributed/collective/collective_manager.h" | |||
| #include "runtime/recovery/recovery_context.h" | |||
| #include "distributed/recovery/recovery_context.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| using distributed::collective::CollectiveManager; | |||
| using distributed::recovery::RecoveryContext; | |||
| std::mutex StreamSynchronizer::instance_lock_; | |||
| std::shared_ptr<StreamSynchronizer> StreamSynchronizer::instance_ = nullptr; | |||
| void StreamSynchronizer::Initialize() { | |||
| // Non disaster recovery mode does not need to start thread and timeout mechanisms. | |||
| if (!runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) { | |||
| if (!RecoveryContext::GetInstance()->enable_recovery()) { | |||
| return; | |||
| } | |||
| @@ -56,7 +59,7 @@ bool StreamSynchronizer::SyncStream(const std::string &device_name, uint32_t tim | |||
| MS_EXCEPTION_IF_NULL(device_context); | |||
| // If disable recovery or timeout==0, sync stream directly to improve performance. | |||
| if (!runtime::recovery::RecoveryContext::GetInstance()->enable_recovery() || timeout == 0) { | |||
| if (!RecoveryContext::GetInstance()->enable_recovery() || timeout == 0) { | |||
| device_context->Initialize(); | |||
| return device_context->SyncStream(); | |||
| } | |||
| @@ -68,26 +71,19 @@ bool StreamSynchronizer::SyncStream(const std::string &device_name, uint32_t tim | |||
| device_context_ = device_context; | |||
| do_sync_stream_cv_.notify_one(); | |||
| if (sync_stream_time_out_) { | |||
| // If sync stream timeout has happened, increase the timeout by 4 times. | |||
| const uint32_t kTimeOutScaleFactor = 4; | |||
| timeout *= kTimeOutScaleFactor; | |||
| } | |||
| if (time_out_cv_.wait_for(lock, std::chrono::seconds(timeout)) == std::cv_status::no_timeout) { | |||
| if (!sync_stream_ret_) { | |||
| MS_LOG(ERROR) << "Synchronize stream failed."; | |||
| } | |||
| return sync_stream_ret_; | |||
| } else { | |||
| sync_stream_time_out_ = true; | |||
| runtime::recovery::RecoveryContext::GetInstance()->set_need_reinit_collective(true); | |||
| if (!distributed::collective::CollectiveManager::instance()->Finalize()) { | |||
| CollectiveManager::instance()->set_need_reinit(true); | |||
| if (!CollectiveManager::instance()->Finalize()) { | |||
| MS_LOG(ERROR) << "Finalize collective manager failed."; | |||
| return false; | |||
| } | |||
| time_out_cv_.wait(lock, [this]() { return device_context_ == nullptr; }); | |||
| MS_LOG(WARNING) << "Synchronize stream time out."; | |||
| MS_LOG(WARNING) << "Synchronize stream timeout."; | |||
| return true; | |||
| } | |||
| } | |||
| @@ -71,9 +71,6 @@ class BACKEND_EXPORT StreamSynchronizer { | |||
| // The singleton pointer. | |||
| static std::shared_ptr<StreamSynchronizer> instance_; | |||
| // Record whether the synchronization stream task has timed out. | |||
| bool sync_stream_time_out_{false}; | |||
| // Return value of synchronization stream. | |||
| bool sync_stream_ret_{false}; | |||
| @@ -25,11 +25,11 @@ | |||
| #include "mindrt/include/async/async.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "include/common/utils/convert_utils.h" | |||
| #include "runtime/recovery/recovery_context.h" | |||
| #include "distributed/recovery/recovery_context.h" | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| using recovery::RecoveryContext; | |||
| using distributed::recovery::RecoveryContext; | |||
| namespace { | |||
| void SyncTensorData(const TensorPtr &host_tensor, const DeviceTensorPtr &device_tensor, const AnfNodePtr &node, | |||
| const DeviceContext *device_context, OpContext<DeviceTensor> *const context, | |||
| @@ -21,11 +21,13 @@ | |||
| #include "runtime/graph_scheduler/actor/debug_actor.h" | |||
| #include "mindrt/include/async/async.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "runtime/recovery/recovery_context.h" | |||
| #include "distributed/recovery/recovery_context.h" | |||
| #include "distributed/collective/collective_manager.h" | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| using recovery::RecoveryContext; | |||
| using distributed::collective::CollectiveManager; | |||
| using distributed::recovery::RecoveryContext; | |||
| void KernelActor::Init() { | |||
| // Check device contexts number. | |||
| @@ -243,7 +245,7 @@ void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) { | |||
| PreLaunchKernel(context); | |||
| try { | |||
| if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) { | |||
| if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) { | |||
| // In disaster recovery scenarios, run dag in this step failed, the rest operators of graph do not need launch, | |||
| // especially the collective communication operators. | |||
| MS_LOG(WARNING) << "Collective communication need reinitialize, skip launch kernel: " | |||
| @@ -25,11 +25,13 @@ | |||
| #include "mindrt/include/async/async.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "runtime/device/stream_synchronizer.h" | |||
| #include "runtime/recovery/recovery_context.h" | |||
| #include "distributed/recovery/recovery_context.h" | |||
| #include "distributed/collective/collective_manager.h" | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| using recovery::RecoveryContext; | |||
| using distributed::collective::CollectiveManager; | |||
| using distributed::recovery::RecoveryContext; | |||
| void LoopCountActor::Run(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| @@ -70,8 +72,7 @@ void LoopCountActor::IncreaseLoopCount(OpContext<DeviceTensor> *const context) { | |||
| (void)sync_stream_device_contexts.insert(device_context); | |||
| // Trigger disaster recovery and exit loop early. | |||
| if (RecoveryContext::GetInstance()->enable_recovery() && | |||
| RecoveryContext::GetInstance()->need_reinit_collective()) { | |||
| if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) { | |||
| current_count_ = loop_count_; | |||
| } | |||
| } | |||
| @@ -17,11 +17,13 @@ | |||
| #include "runtime/graph_scheduler/actor/output_actor.h" | |||
| #include "runtime/hardware/device_context_manager.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "runtime/recovery/recovery_context.h" | |||
| #include "distributed/recovery/recovery_context.h" | |||
| #include "distributed/collective/collective_manager.h" | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| using recovery::RecoveryContext; | |||
| using distributed::collective::CollectiveManager; | |||
| using distributed::recovery::RecoveryContext; | |||
| bool IsOutputAddressPersisted(const DeviceTensor *output_device_tensor, const AnfNodePtr &output_node) { | |||
| MS_EXCEPTION_IF_NULL(output_node); | |||
| @@ -105,7 +107,7 @@ void OutputActor::RunOpControl(AID *const, OpContext<DeviceTensor> *const contex | |||
| ++current_count_; | |||
| // Trigger disaster recovery and return empty output. | |||
| if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) { | |||
| if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) { | |||
| FreeOutputNodeMem(); | |||
| ClearOutputCache(); | |||
| SET_OPCONTEXT_SUCCESS_RET((*context)); | |||
| @@ -45,11 +45,19 @@ | |||
| #endif | |||
| #include "profiler/device/profiling.h" | |||
| #include "include/common/debug/common.h" | |||
| #include "runtime/recovery/recovery_context.h" | |||
| #include "distributed/recovery/recovery_context.h" | |||
| #include "distributed/collective/collective_manager.h" | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32) && (!defined _WIN64)) | |||
| #include "distributed/cluster/cluster_context.h" | |||
| #else | |||
| #include "distributed/cluster/dummy_cluster_context.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| using recovery::RecoveryContext; | |||
| using distributed::cluster::ClusterContext; | |||
| using distributed::collective::CollectiveManager; | |||
| using distributed::recovery::RecoveryContext; | |||
| namespace { | |||
| bool IsNeedInsertCopyActor(const DeviceContext *from_device_context, const DeviceContext *to_device_context) { | |||
| MS_EXCEPTION_IF_NULL(from_device_context); | |||
| @@ -196,6 +204,108 @@ void IntHandler(int, siginfo_t *, void *) { | |||
| (void)kill(this_pid, SIGTERM); | |||
| } | |||
| #endif | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32) && (!defined _WIN64)) | |||
| bool SendFinishTransform() { | |||
| auto node = ClusterContext::instance()->node(); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->role() != ps::core::NodeRole::WORKER) { | |||
| return true; | |||
| } | |||
| auto abstract_node = std::dynamic_pointer_cast<ps::core::AbstractNode>(ClusterContext::instance()->node()); | |||
| MS_EXCEPTION_IF_NULL(abstract_node); | |||
| ps::core::SendFinishTransformMessage send_ready_to_run_msg; | |||
| send_ready_to_run_msg.set_node_id(abstract_node->node_id()); | |||
| send_ready_to_run_msg.set_rank_id(abstract_node->rank_id()); | |||
| send_ready_to_run_msg.set_is_ready(true); | |||
| std::shared_ptr<std::vector<unsigned char>> output = nullptr; | |||
| if (!abstract_node->SendToScheduler(send_ready_to_run_msg.SerializeAsString().data(), | |||
| send_ready_to_run_msg.SerializeAsString().size(), | |||
| ps::core::NodeCommand::SEND_FINISH_TRANSFORM, &output)) { | |||
| MS_LOG(WARNING) << "Failed to send finish transform request to scheduler."; | |||
| return false; | |||
| } | |||
| ps::core::GeneralResponseMsg resp_msg; | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| (void)resp_msg.ParseFromArray(output->data(), SizeToInt(output->size())); | |||
| if (!resp_msg.is_success()) { | |||
| MS_LOG(ERROR) << "Send finish transform to scheduler failed."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool QueryFinishTransform() { | |||
| auto node = ClusterContext::instance()->node(); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->role() != ps::core::NodeRole::WORKER) { | |||
| return true; | |||
| } | |||
| auto abstract_node = std::dynamic_pointer_cast<ps::core::AbstractNode>(ClusterContext::instance()->node()); | |||
| MS_EXCEPTION_IF_NULL(abstract_node); | |||
| ps::core::GeneralQueryMessage general_query_msg; | |||
| general_query_msg.set_node_id(abstract_node->node_id()); | |||
| general_query_msg.set_rank_id(abstract_node->rank_id()); | |||
| std::shared_ptr<std::vector<unsigned char>> output = nullptr; | |||
| bool ret = false; | |||
| while (!ret) { | |||
| if (!abstract_node->SendToScheduler(general_query_msg.SerializeAsString().data(), | |||
| general_query_msg.SerializeAsString().size(), | |||
| ps::core::NodeCommand::QUERY_FINISH_TRANSFORM, &output)) { | |||
| MS_LOG(WARNING) << "Failed to send query finish transform request to scheduler."; | |||
| ret = false; | |||
| continue; | |||
| } | |||
| ps::core::QueryFinishTransformRespMessage resp_msg; | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| (void)resp_msg.ParseFromArray(output->data(), SizeToInt(output->size())); | |||
| ret = resp_msg.is_ready(); | |||
| if (!ret) { | |||
| MS_LOG(INFO) << "There is worker which has not finished transform graph"; | |||
| } | |||
| if (resp_msg.is_worker_timeout()) { | |||
| MS_LOG(WARNING) << "There is worker timeout"; | |||
| return false; | |||
| } | |||
| // The time interval for querying the all worker finish transform graphs status to scheduler: 10 seconds. | |||
| const uint32_t kWaitDuration = 10; | |||
| std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration)); | |||
| } | |||
| return ret; | |||
| } | |||
| void DoDisasterRecovery() { | |||
| if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) { | |||
| MS_LOG(INFO) << "Begin reinitialize collective communication for recovery."; | |||
| bool ret = false; | |||
| while (!ret) { | |||
| while (!CollectiveManager::instance()->Initialize()) { | |||
| MS_LOG(WARNING) << "ReInitialize collective communication failed, retrying..."; | |||
| } | |||
| MS_LOG(INFO) << "Finish reinitialize collective communication for recovery."; | |||
| RecoveryContext::GetInstance()->ObtainGlobalLatestCkptInfo(); | |||
| ret = QueryFinishTransform(); | |||
| if (!ret) { | |||
| CollectiveManager::instance()->set_need_reinit(true); | |||
| (void)CollectiveManager::instance()->Finalize(); | |||
| } | |||
| } | |||
| RecoveryContext::GetInstance()->set_need_reset(true); | |||
| RecoveryContext::GetInstance()->set_need_sync_weight_to_device(true); | |||
| } | |||
| } | |||
| #endif | |||
| } // namespace | |||
| GraphScheduler &GraphScheduler::GetInstance() noexcept { | |||
| @@ -407,6 +517,17 @@ ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info | |||
| } | |||
| MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor end."; | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32) && (!defined _WIN64)) | |||
| if (ClusterContext::instance()->initialized() && RecoveryContext::GetInstance()->enable_recovery()) { | |||
| while (!SendFinishTransform()) { | |||
| MS_LOG(WARNING) << "Send finish transform graph failed."; | |||
| // The time interval for sending finish transform graph to scheduler. | |||
| constexpr uint32_t kWaitDuration = 10; | |||
| std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration)); | |||
| } | |||
| } | |||
| #endif | |||
| return actor_set.get(); | |||
| } | |||
| @@ -489,15 +610,9 @@ void GraphScheduler::Run(ActorSet *const actor_set, const std::vector<DeviceCont | |||
| const size_t kSecondsToMilliseconds = 1000; | |||
| SetActorExecutionStrategy(actor_set, strategy, (end_time - start_time) * kSecondsToMilliseconds); | |||
| if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) { | |||
| MS_LOG(INFO) << "Begin reinitialize collective communication for recovery."; | |||
| if (!RecoveryContext::GetInstance()->ReInitializeCollective()) { | |||
| MS_LOG(EXCEPTION) << "Reinitialize collective communication failed."; | |||
| } | |||
| MS_LOG(INFO) << "Finish reinitialize collective communication for recovery."; | |||
| RecoveryContext::GetInstance()->set_need_reinit_collective(false); | |||
| } | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32) && (!defined _WIN64)) | |||
| DoDisasterRecovery(); | |||
| #endif | |||
| } | |||
| void GraphScheduler::SetActorExecutionStrategy(ActorSet *const actor_set, GraphExecutionStrategy strategy, | |||
| @@ -1,9 +0,0 @@ | |||
# Collect the recovery context source relative to this directory.
file(GLOB_RECURSE RECOVERY_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
    "recovery_context.cc")

if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-delete-abstract-non-virtual-dtor")
endif()

# SUBMODULE_ID must be passed through the COMPILE_DEFINITIONS source property; without the property name the
# whole "SUBMODULE_ID=..." token is taken as a property name and no define reaches the compiler.
set_property(SOURCE ${RECOVERY_SRC_LIST} PROPERTY COMPILE_DEFINITIONS
    SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
add_library(_mindspore_runtime_recovery_obj OBJECT ${RECOVERY_SRC_LIST})
| @@ -306,8 +306,7 @@ add_library(backend_static STATIC | |||
| $<TARGET_OBJECTS:_mindspore_runtime_device_obj> | |||
| $<TARGET_OBJECTS:_mindspore_runtime_graph_scheduler_obj> | |||
| $<TARGET_OBJECTS:_mindspore_runtime_hardware_obj> | |||
| $<TARGET_OBJECTS:_mindspore_runtime_pynative_obj> | |||
| $<TARGET_OBJECTS:_mindspore_runtime_recovery_obj>) | |||
| $<TARGET_OBJECTS:_mindspore_runtime_pynative_obj>) | |||
| target_link_libraries(ut_tests PRIVATE mindspore securec -Wl,--start-group proto_input mindspore::protobuf | |||
| backend_static -Wl,--end-group) | |||
| target_link_libraries(ut_tests PRIVATE mindspore::grpc++) | |||