Browse Source

!31879 Training support recovery

Merge pull request !31879 from zyli2020/worker_failover_bp
r1.7
i-robot Gitee 4 years ago
parent
commit
c6da65ecfa
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
32 changed files with 538 additions and 207 deletions
  1. +0
    -1
      mindspore/ccsrc/CMakeLists.txt
  2. +3
    -3
      mindspore/ccsrc/backend/graph_compiler/backend.cc
  3. +123
    -50
      mindspore/ccsrc/distributed/collective/collective_manager.cc
  4. +11
    -9
      mindspore/ccsrc/distributed/collective/collective_manager.h
  5. +4
    -5
      mindspore/ccsrc/distributed/init.cc
  6. +24
    -30
      mindspore/ccsrc/distributed/recovery/recovery_context.cc
  7. +16
    -33
      mindspore/ccsrc/distributed/recovery/recovery_context.h
  8. +2
    -2
      mindspore/ccsrc/pipeline/jit/init.cc
  9. +3
    -0
      mindspore/ccsrc/pipeline/jit/pipeline.cc
  10. +1
    -1
      mindspore/ccsrc/plugin/device/cpu/hal/hardware/ms_collective_comm_lib.h
  11. +1
    -0
      mindspore/ccsrc/ps/core/abstract_node.cc
  12. +3
    -0
      mindspore/ccsrc/ps/core/abstract_node.h
  13. +11
    -0
      mindspore/ccsrc/ps/core/abstract_ps_node.cc
  14. +3
    -0
      mindspore/ccsrc/ps/core/abstract_ps_node.h
  15. +6
    -1
      mindspore/ccsrc/ps/core/node_manager.cc
  16. +2
    -0
      mindspore/ccsrc/ps/core/node_manager.h
  17. +22
    -0
      mindspore/ccsrc/ps/core/protos/comm.proto
  18. +96
    -4
      mindspore/ccsrc/ps/core/ps_scheduler_node.cc
  19. +27
    -0
      mindspore/ccsrc/ps/core/ps_scheduler_node.h
  20. +6
    -8
      mindspore/ccsrc/ps/core/ps_server_node.cc
  21. +6
    -8
      mindspore/ccsrc/ps/core/ps_worker_node.cc
  22. +5
    -1
      mindspore/ccsrc/ps/core/scheduler_node.cc
  23. +10
    -1
      mindspore/ccsrc/ps/core/scheduler_node.h
  24. +9
    -13
      mindspore/ccsrc/runtime/device/stream_synchronizer.cc
  25. +0
    -3
      mindspore/ccsrc/runtime/device/stream_synchronizer.h
  26. +2
    -2
      mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.cc
  27. +5
    -3
      mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.cc
  28. +5
    -4
      mindspore/ccsrc/runtime/graph_scheduler/actor/loop_count_actor.cc
  29. +5
    -3
      mindspore/ccsrc/runtime/graph_scheduler/actor/output_actor.cc
  30. +126
    -11
      mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc
  31. +0
    -9
      mindspore/ccsrc/runtime/recovery/CMakeLists.txt
  32. +1
    -2
      tests/ut/cpp/CMakeLists.txt

+ 0
- 1
mindspore/ccsrc/CMakeLists.txt View File

@@ -310,7 +310,6 @@ set(BACKEND_SUB_COMP
runtime/graph_scheduler
runtime/hardware
runtime/pynative
runtime/recovery
plugin/device/ascend/hal/device
plugin/device/ascend/hal/hardware
plugin/device/ascend/hal/hccl_adapter


+ 3
- 3
mindspore/ccsrc/backend/graph_compiler/backend.cc View File

@@ -37,7 +37,7 @@
#include "runtime/hardware/device_context_manager.h"
#include "runtime/graph_scheduler/graph_compiler.h"
#include "runtime/pynative/run_op_helper.h"
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"
#include "include/common/utils/scoped_long_running.h"
#ifdef ENABLE_D
#include "include/common/utils/callbacks_ge.h"
@@ -953,8 +953,8 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
MS_EXCEPTION_IF_NULL(graph_compiler_);
graph_compiler_->Summary(graph_compiler_info.graphs_);

bool need_contruct_output = !(runtime::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
runtime::recovery::RecoveryContext::GetInstance()->need_reset());
bool need_contruct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
distributed::recovery::RecoveryContext::GetInstance()->need_reset());
if (need_contruct_output) {
// Update device address for output node of graph.
// Summary processing will use the output device address, so must be after the summary processing.


+ 123
- 50
mindspore/ccsrc/distributed/collective/collective_manager.cc View File

@@ -15,18 +15,24 @@
*/

#include "distributed/collective/collective_manager.h"
#include <algorithm>
#include <string>
#include <vector>
#include <functional>
#include <csignal>
#include <memory>
#include "utils/ms_context.h"
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"

namespace mindspore {
namespace distributed {
namespace collective {
using recovery::RecoveryContext;

CollectiveManager::CollectiveManager()
: inited_(false),
finalized_(true),
need_reinit_(false),
host_ctx_(nullptr),
device_ctx_(nullptr),
host_comm_lib_instance_(nullptr),
@@ -60,8 +66,75 @@ std::shared_ptr<CollectiveManager> CollectiveManager::instance() {
return instance;
}

namespace {
// The wrapper to provide a timeout mechanism for executing functions.
bool ExecuteFuncInThread(const std::function<bool()> &func, const int64_t timeout) {
bool execute_success = false;
bool execute_fail = false;
std::mutex exec_ret_mutex;
std::condition_variable thread_blocker;

std::unique_ptr<std::thread> executive_thread = std::make_unique<std::thread>([&] {
if (!func()) {
MS_LOG(ERROR) << "Failed to execute function asynchronously";
std::unique_lock<std::mutex> lock(exec_ret_mutex);
execute_fail = true;
thread_blocker.notify_one();
return;
}

{
std::unique_lock<std::mutex> lock(exec_ret_mutex);
execute_success = true;
thread_blocker.notify_one();
}
});
executive_thread->detach();

std::unique_lock<std::mutex> locker(exec_ret_mutex);
(void)thread_blocker.wait_for(locker, std::chrono::seconds(timeout), [&] { return execute_success || execute_fail; });

if (!execute_success && !execute_fail) {
std::string node_id = common::GetEnv("MS_NODE_ID");
#if !defined(_WIN32) && !defined(_WIN64)
MS_LOG(ERROR) << "Execute function asynchronously timeout, node id: " << node_id << " exit process";
(void)kill(getpid(), SIGTERM);
#endif
}
return execute_success;
}

// In a disaster recovery scenario, the comparison between the current unique id and the last generated unique id
// ensures that the acquired unique id is newly generated, and the latest unique id will be persisted.
bool CheckUniqueIDLatest(const std::string &group_name, size_t root_info_size, const void *root_info) {
MS_EXCEPTION_IF_NULL(root_info);
auto persistent_json = RecoveryContext::GetInstance()->persistent_json();
MS_EXCEPTION_IF_NULL(persistent_json);

std::string new_unique_id(static_cast<const char *>(root_info), root_info_size);
std::vector<int> new_unique_id_integer_seq;
(void)std::transform(new_unique_id.begin(), new_unique_id.end(), std::back_inserter(new_unique_id_integer_seq),
[](char c) { return static_cast<int>(c); });

const char unique_id_str[] = "_unique_id";
std::string unique_id_key = group_name + unique_id_str;
if (!persistent_json->Exists(unique_id_key)) {
persistent_json->Insert(unique_id_key, new_unique_id_integer_seq);
return true;
}

std::vector<int> old_unique_id_integer_seq = persistent_json->Get<std::vector<int>>(unique_id_key);
if (new_unique_id_integer_seq == old_unique_id_integer_seq) {
return false;
}

persistent_json->Insert(unique_id_key, new_unique_id_integer_seq);
return true;
}
} // namespace

bool CollectiveManager::Initialize() {
if (inited_ && !runtime::recovery::RecoveryContext::GetInstance()->need_reinit_collective()) {
if (inited_ && !need_reinit_) {
return true;
}

@@ -98,6 +171,8 @@ bool CollectiveManager::Initialize() {
MS_LOG(INFO) << "End initializing collective communication for backend: " << device_type_;
inited_ = true;
finalized_ = false;
need_reinit_ = false;

return true;
}

@@ -125,50 +200,36 @@ bool CollectiveManager::CreateCommunicationGroup(const std::string &group_name,
void *root_info = group->GenerateRootInfo(&root_info_size);
MS_EXCEPTION_IF_NULL(root_info);

bool ret = false;
// Step 4: Broadcast the device root information to all nodes on host side.
if (!host_comm_lib_instance_->BroadcastUniqueID(group_name, is_root_node, root_info_size, root_info)) {
MS_LOG(ERROR) << "Broadcast for device root info failed on the host side.";
return false;
while (!ret) {
ret = host_comm_lib_instance_->BroadcastUniqueID(group_name, is_root_node, root_info_size, root_info);
if (!ret) {
MS_LOG(ERROR) << "Broadcast for device root info failed on the host side.";
return false;
}

// In disaster recovery scenarios, it is necessary to ensure that the unique id obtained from the Scheduler is a
// newly generated one.
if (RecoveryContext::GetInstance()->enable_recovery()) {
ret = CheckUniqueIDLatest(group_name, root_info_size, root_info);
}
}

// Step 5: Initialize communication group on the device side.
return InitDeviceCommGroup(group, root_info);
}

bool CollectiveManager::InitDeviceCommGroup(const CommunicationGroupPtr &group, void *root_info) {
bool init_group_success = false;
bool init_group_fail = false;
std::condition_variable thread_blocker;
init_group_thread_ = std::make_unique<std::thread>([&, this] {
std::function<bool()> init_device_comm_group_func = [&, this]() {
device_ctx_->Initialize();
if (!group->Initialize(root_info)) {
MS_LOG(ERROR) << "Initialize group on the device side failed.";
std::unique_lock<std::mutex> lock(init_group_mutex_);
init_group_fail = true;
thread_blocker.notify_one();
return;
}

{
std::unique_lock<std::mutex> lock(init_group_mutex_);
init_group_success = true;
thread_blocker.notify_one();
}
});
init_group_thread_->detach();
return group->Initialize(root_info);
};
MS_LOG(INFO) << "Begin initialize communication group on the device side.";

// Timeout limit 180 seconds to wait finishing init device communication group.
// Timeout limit 180 seconds to wait finish initializing device communication group.
const int64_t kTimeToWait = 180;
std::unique_lock<std::mutex> locker(init_group_mutex_);
(void)thread_blocker.wait_for(locker, std::chrono::seconds(kTimeToWait),
[&] { return init_group_success || init_group_fail; });

if (!init_group_success && runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) {
runtime::recovery::RecoveryContext::GetInstance()->set_recovery_status(
runtime::recovery::RecoveryErrCode::kInitNcclFailed);
MS_LOG(ERROR) << "Initialize group on the device side failed.";
}
return init_group_success;
// Initialize communication group on the device side in thread with timeout limit.
ret = ExecuteFuncInThread(init_device_comm_group_func, kTimeToWait);

MS_LOG(INFO) << "End initialize communication group on the device side.";
return ret;
}

bool CollectiveManager::DestroyCommunicationGroup(const std::string &group_name) {
@@ -197,22 +258,34 @@ uint32_t CollectiveManager::GetGroupSize(const std::string &group_name) {
}

bool CollectiveManager::Finalize() {
if (finalized_) {
if (!inited_.load() || finalized_.load()) {
return true;
}

MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
if (!host_comm_lib_instance_->Finalize()) {
MS_LOG(WARNING) << "Failed to finalize host communication library.";
}
std::function<bool()> finalize_func = [&, this]() {
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);
if (!host_comm_lib_instance_->Finalize()) {
MS_LOG(WARNING) << "Failed to finalize host communication library.";
}

MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
if (!device_comm_lib_instance_->Finalize()) {
MS_LOG(WARNING) << "Failed to finalize device communication library.";
}
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);
if (!device_comm_lib_instance_->Finalize()) {
MS_LOG(WARNING) << "Failed to finalize device communication library.";
}

finalized_ = true;
return true;
finalized_ = true;
return true;
};

MS_LOG(INFO) << "Begin finalize collective manager.";

// Timeout limit 5 seconds to wait to finish finalizing device communication group.
const int64_t kTimeToWait = 5;
// Finalize collective manager in thread with timeout limit.
bool ret = ExecuteFuncInThread(finalize_func, kTimeToWait);

MS_LOG(INFO) << "End finalize collective manager.";
return ret;
}

void CollectiveManager::set_global_rank_id(uint32_t global_rank_id) { global_rank_id_ = global_rank_id; }


+ 11
- 9
mindspore/ccsrc/distributed/collective/collective_manager.h View File

@@ -69,6 +69,11 @@ class BACKEND_EXPORT CollectiveManager {

uint32_t local_rank_id() const;

// Set whether need reinitialize collective communication.
void set_need_reinit(bool need_reinit) { need_reinit_ = need_reinit; }
// Get whether need reinitialize collective communication.
bool need_reinit() const { return need_reinit_.load(); }

private:
CollectiveManager();

@@ -81,16 +86,13 @@ class BACKEND_EXPORT CollectiveManager {
// Assign the local rank id for this process.
bool AssignLocalRank();

// Initialize communication group on the device side.
bool InitDeviceCommGroup(const CommunicationGroupPtr &group, void *root_info);

// Initialize communication group on the device side in thread with timeout limit.
std::unique_ptr<std::thread> init_group_thread_;
std::mutex init_group_mutex_;

std::atomic_bool inited_;
std::atomic_bool finalized_;

// Whether need reinitialize collective communication, this value should be set to true once a training process
// exits unexpectedly is detected.
std::atomic_bool need_reinit_;

// The device type read from MindSpore context.
std::string device_type_;

@@ -119,8 +121,8 @@ class BACKEND_EXPORT CollectiveManager {
// Global group ranks.
std::vector<uint32_t> global_group_ranks_;

// The global group name on the host side. This is used for Creating global group on host side for AllGather operation
// of host name while assigning local rank.
// The global group name on the host side. This is used for Creating global group on host side for AllGather
// operation of host name while assigning local rank.
std::string host_global_group_name_;
};
} // namespace collective


+ 4
- 5
mindspore/ccsrc/distributed/init.cc View File

@@ -17,11 +17,11 @@
#include "distributed/init.h"
#include <vector>
#include <string>
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"

namespace mindspore {
namespace distributed {
using runtime::recovery::RecoveryContext;
using distributed::recovery::RecoveryContext;

bool Initialize() {
if (!InitializeCluster()) {
@@ -43,7 +43,8 @@ bool Initialize() {
collective::CollectiveManager::instance()->set_global_rank_size(abstract_node->worker_num());

if (RecoveryContext::GetInstance()->enable_recovery()) {
cluster::ClusterContext::instance()->WaitForClusterReady();
RecoveryContext::GetInstance()->set_global_rank_id(abstract_node->rank_id());
RecoveryContext::GetInstance()->set_global_rank_size(abstract_node->worker_num());
}

if (!InitializeCollective()) {
@@ -52,8 +53,6 @@ bool Initialize() {
}

if (RecoveryContext::GetInstance()->enable_recovery()) {
RecoveryContext::GetInstance()->set_global_rank_id(abstract_node->rank_id());
RecoveryContext::GetInstance()->set_global_rank_size(abstract_node->worker_num());
RecoveryContext::GetInstance()->ObtainGlobalLatestCkptInfo();
}
}


mindspore/ccsrc/runtime/recovery/recovery_context.cc → mindspore/ccsrc/distributed/recovery/recovery_context.cc View File

@@ -14,7 +14,7 @@
* limitations under the License.
*/

#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"

#include <dirent.h>
#include <algorithm>
@@ -26,13 +26,12 @@
#include "utils/file_utils.h"
#include "distributed/constants.h"
#include "distributed/cluster/topology/common.h"
#include "distributed/init.h"
#include "runtime/hardware/device_context.h"
#include "runtime/hardware/device_context_manager.h"
#include "utils/convert_utils_base.h"
#include "utils/ms_context.h"

namespace mindspore {
namespace runtime {
namespace distributed {
namespace recovery {
constexpr char kEnvEnableRecovery[] = "MS_ENABLE_RECOVERY";
constexpr char kEnvRecoveryPath[] = "MS_RECOVERY_PATH";
@@ -142,30 +141,6 @@ void RecoveryContext::Initialize() {
initialized_ = true;
}

bool RecoveryContext::ReInitializeCollective() {
auto ret = distributed::Initialize();
if (ret) {
recovery_status_ = RecoveryErrCode::kUnKnownError;
set_need_reset(true);
set_need_sync_weight_to_device(true);
return true;
}

if (recovery_status_ == RecoveryErrCode::kBroadcastUniqueIDFailed ||
recovery_status_ == RecoveryErrCode::kAllGatherHostNameFailed) {
MS_LOG(WARNING) << "Prepare to initialize NCCL failed, retrying.";
// Retry duration: 30s.
const int kRetryDuration = 30;
std::this_thread::sleep_for(std::chrono::seconds(kRetryDuration));
return ReInitializeCollective();
} else if (recovery_status_ == RecoveryErrCode::kInitNcclFailed) {
MS_LOG(EXCEPTION) << "Initialize NCCL failed.";
}

MS_LOG(EXCEPTION) << "ReInitialize collective failed.";
return false;
}

void RecoveryContext::ObtainGlobalLatestCkptInfo() {
// 1. Obtain the step corresponding to the local latest checkpoint.
ObtainLocalLatestCkptInfo();
@@ -326,6 +301,7 @@ void RecoveryContext::ParseLatestCkptInfo(const int *recv_buffer, const uint32_t
}

void RecoveryContext::CreatePersistentFile() {
std::unique_lock<std::mutex> lock(create_persist_json_mtx_);
if (node_role_ == distributed::kEnvRoleOfScheduler) {
return;
}
@@ -344,7 +320,7 @@ void RecoveryContext::CreatePersistentFile() {
// The directory used to save ckpt is persisted to json file.
std::string persistent_file_path =
recovery_path_ + "/" + node_role_ + "_" + std::to_string(global_rank_id_) + "_persistent.json";
persistent_json_ = std::make_unique<JsonUtils>(persistent_file_path);
persistent_json_ = std::make_shared<JsonUtils>(persistent_file_path);
if (!persistent_json_->Initialize()) {
MS_LOG(EXCEPTION) << "Initialize json failed, file path: " << persistent_file_path;
}
@@ -388,6 +364,24 @@ std::string RecoveryContext::GetCkptPath() {

return persistent_json_->Get<std::string>(kCkptPath);
}

const std::shared_ptr<JsonUtils> &RecoveryContext::persistent_json() {
if (persistent_json_ == nullptr) {
CreatePersistentFile();
}

MS_EXCEPTION_IF_NULL(persistent_json_);
return persistent_json_;
}

std::string RecoveryContext::latest_ckpt_file() {
// For standalone training.
if (enable_recovery_ && global_rank_size_ == 0 && latest_ckpt_file_.empty()) {
ObtainLocalLatestCkptInfo();
}

return latest_ckpt_file_;
}
} // namespace recovery
} // namespace runtime
} // namespace distributed
} // namespace mindspore

mindspore/ccsrc/runtime/recovery/recovery_context.h → mindspore/ccsrc/distributed/recovery/recovery_context.h View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_RUNTIME_RECOVERY_RECOVERY_H_
#define MINDSPORE_CCSRC_RUNTIME_RECOVERY_RECOVERY_H_
#ifndef MINDSPORE_CCSRC_DISTRIBUTED_RECOVERY_RECOVERY_H_
#define MINDSPORE_CCSRC_DISTRIBUTED_RECOVERY_RECOVERY_H_

#include <vector>
#include <string>
@@ -27,13 +27,11 @@
#include "include/backend/visible.h"

namespace mindspore {
namespace runtime {
namespace distributed {
namespace recovery {
using distributed::storage::FileIOUtils;
using distributed::storage::JsonUtils;

enum class RecoveryErrCode { kUnKnownError, kAllGatherHostNameFailed, kBroadcastUniqueIDFailed, kInitNcclFailed };

// Used to save disaster recovery-related state quantities and provide disaster recovery-related
// functions, such as reinitializing collective communication, etc.
class BACKEND_EXPORT RecoveryContext {
@@ -47,14 +45,6 @@ class BACKEND_EXPORT RecoveryContext {
}
~RecoveryContext() = default;

// Reinitializing collective communication.
bool ReInitializeCollective();

// Obtain the global step corresponding to the global latest checkpoint in each training process. Since there may be
// some processes that fails to save the checkpoint, it is necessary for AllGather to save the latest step of the
// successful checkpoint in each training process, and then take the minimum value as the final reset position.
void ObtainGlobalLatestCkptInfo();

// Get whether enable recovery or not.
bool enable_recovery() const { return enable_recovery_; }

@@ -64,18 +54,14 @@ class BACKEND_EXPORT RecoveryContext {
// Get interval to persist model.
int recovery_interval() const { return recovery_interval_; }

// Get the error status of recovery.
RecoveryErrCode recovery_status() const { return recovery_status_; }
// Set the error status of recovery.
void set_recovery_status(RecoveryErrCode recovery_status) { recovery_status_ = recovery_status; }

// Set the path used to save checkpoint.
void SetCkptPath(const std::string &path);
// Get the path used to save checkpoint.
std::string GetCkptPath();

// Get the latest checkpoint in this node.
std::string latest_ckpt_file() const { return latest_ckpt_file_; }
std::string latest_ckpt_file();

// Get the epoch of latest checkpoint in this node.
int latest_ckpt_epoch() const { return latest_ckpt_epoch_; }
// Get the step of latest checkpoint in this node.
@@ -99,10 +85,13 @@ class BACKEND_EXPORT RecoveryContext {
// Set global rank size.
void set_global_rank_size(uint32_t global_rank_size) { global_rank_size_ = global_rank_size; }

// Set whether need reinitialize collective communication.
void set_need_reinit_collective(bool need_reinit_collective) { need_reinit_collective_ = need_reinit_collective; }
// Get whether need reinitialize collective communication.
bool need_reinit_collective() const { return need_reinit_collective_.load(); }
// Obtain the global step corresponding to the global latest checkpoint in each training process. Since there may be
// some processes that fails to save the checkpoint, it is necessary for AllGather to save the latest step of the
// successful checkpoint in each training process, and then take the minimum value as the final reset position.
void ObtainGlobalLatestCkptInfo();

// Get the persistent json file pointer.
const std::shared_ptr<JsonUtils> &persistent_json();

private:
inline static std::shared_ptr<RecoveryContext> instance_{};
@@ -155,20 +144,14 @@ class BACKEND_EXPORT RecoveryContext {
// performs load checkpoint.
bool need_sync_weight_to_device_{false};

// Whether need reinitialize collective communication, this value should be set to true once a training process
// exits unexpectedly is detected.
std::atomic_bool need_reinit_collective_{false};

// Whether the recovery context is already initialized.
bool initialized_{false};

// The error status of recovery.
RecoveryErrCode recovery_status_{RecoveryErrCode::kUnKnownError};

std::mutex create_persist_json_mtx_;
// The persitent json file util, used to persist recovery config.
std::unique_ptr<JsonUtils> persistent_json_;
std::shared_ptr<JsonUtils> persistent_json_;
};
} // namespace recovery
} // namespace runtime
} // namespace distributed
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_RECOVERY_RECOVERY_H_
#endif // MINDSPORE_CCSRC_DISTRIBUTED_RECOVERY_RECOVERY_H_

+ 2
- 2
mindspore/ccsrc/pipeline/jit/init.cc View File

@@ -40,7 +40,7 @@
#include "ps/util.h"
#endif
#include "ps/ps_context.h"
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"

#include "pybind_api/gil_scoped_long_running.h"

@@ -58,7 +58,7 @@ using ParallelContext = mindspore::parallel::ParallelContext;
using CostModelContext = mindspore::parallel::CostModelContext;
using mindspore::MsCtxParam;
using PSContext = mindspore::ps::PSContext;
using RecoveryContext = mindspore::runtime::recovery::RecoveryContext;
using RecoveryContext = mindspore::distributed::recovery::RecoveryContext;

// Interface with python
PYBIND11_MODULE(_c_expression, m) {


+ 3
- 0
mindspore/ccsrc/pipeline/jit/pipeline.cc View File

@@ -61,6 +61,7 @@
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/pynative/op_executor.h"
#include "runtime/device/stream_synchronizer.h"
#include "distributed/collective/collective_manager.h"

#ifndef ENABLE_SECURITY
#ifdef ENABLE_D
@@ -1608,6 +1609,8 @@ void ClearResAtexit() {
device::StreamSynchronizer::GetInstance()->Finalize();
MS_LOG(INFO) << "End Finalize StreamSynchronizer...";

(void)distributed::collective::CollectiveManager::instance()->Finalize();

PrimitivePy::ClearHookRes();
ad::g_k_prims.clear();
ad::ClearKPynativeCellStaticRes();


+ 1
- 1
mindspore/ccsrc/plugin/device/cpu/hal/hardware/ms_collective_comm_lib.h View File

@@ -35,7 +35,7 @@ using CommunicationGroupInfo = mindspore::fl::server::CommunicationGroupInfo;
using ps::core::NodeCommand;

// The time interval for send info or query info between worker and scheduler.
constexpr uint32_t kWaitDuration = 3;
constexpr uint32_t kWaitDuration = 5;

// The collective communication library for MindSpore self developed communication framework.
class MsCollectiveCommLib : public CollectiveCommunicationLib {


+ 1
- 0
mindspore/ccsrc/ps/core/abstract_node.cc View File

@@ -1301,6 +1301,7 @@ void AbstractNode::InitCommandHandler() {
handlers_[NodeCommand::SEND_EVENT] = nullptr;
RegisterActorRouteTableRspHandler();
RegisterInitCollectCommResphandler();
RegisterRecoveryRespHandler();
}

void AbstractNode::RegisterActorRouteTableRspHandler() {


+ 3
- 0
mindspore/ccsrc/ps/core/abstract_node.h View File

@@ -228,6 +228,9 @@ class BACKEND_EXPORT AbstractNode : public Node {
// Register collective communication initialization response methods.
virtual void RegisterInitCollectCommResphandler() {}

// Register recovery response methods.
virtual void RegisterRecoveryRespHandler() {}

// when initializing the node, should initializing the node info.
void InitNodeInfo(const NodeRole &role);
// Initialize worker num and server num by cluster config.


+ 11
- 0
mindspore/ccsrc/ps/core/abstract_ps_node.cc View File

@@ -139,6 +139,9 @@ bool AbstractPSNode::HandleHeartbeatTimeout() {
if (!stop_heartbeat_.load()) {
stop_heartbeat_ = true;
while (!heartbeat_stopped_.load()) {
if (is_finish_.load()) {
return;
}
MS_LOG(INFO) << "Waiting for heartbeat to stop...";

// Time interval for waiting the heartbeat to stop.
@@ -152,6 +155,9 @@ bool AbstractPSNode::HandleHeartbeatTimeout() {

bool success = false;
while (!success) {
if (is_finish_.load()) {
return;
}
MS_LOG(WARNING) << "Trying to reconnect to the scheduler...";
success = InitClientToScheduler();
if (success) {
@@ -176,6 +182,11 @@ void AbstractPSNode::RegisterInitCollectCommResphandler() {
handlers_[NodeCommand::SEND_UNIQUE_ID] = &AbstractPSNode::ProcessReceiveSchedulerResp;
handlers_[NodeCommand::QUERY_UNIQUE_ID] = &AbstractPSNode::ProcessReceiveSchedulerResp;
}

void AbstractPSNode::RegisterRecoveryRespHandler() {
handlers_[NodeCommand::SEND_FINISH_TRANSFORM] = &AbstractPSNode::ProcessReceiveSchedulerResp;
handlers_[NodeCommand::QUERY_FINISH_TRANSFORM] = &AbstractPSNode::ProcessReceiveSchedulerResp;
}
} // namespace core
} // namespace ps
} // namespace mindspore

+ 3
- 0
mindspore/ccsrc/ps/core/abstract_ps_node.h View File

@@ -38,6 +38,9 @@ class AbstractPSNode : public AbstractNode {
// Register collective communication initialization response methods.
void RegisterInitCollectCommResphandler() override;

// Register recovery response methods.
void RegisterRecoveryRespHandler() override;

// Indicate whether the heartbeat thread should be stopped.
std::atomic<bool> stop_heartbeat_{false};



+ 6
- 1
mindspore/ccsrc/ps/core/node_manager.cc View File

@@ -67,7 +67,7 @@ uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage &register_message
return rank_id;
}
} else {
ReAddNodeIfNotExists(node_id, register_message.ip(), register_message.port());
(void)ReAddNodeIfNotExists(node_id, register_message.ip(), register_message.port(), &rank_id);
}
return rank_id;
}
@@ -206,11 +206,16 @@ std::vector<ServersMeta> NodeManager::FetchAllNodesMeta() {
return servers_meta_list;
}

const std::unordered_map<std::string, NodeInfo> &NodeManager::QueryTimeOutNodesInfo() const {
return timeout_nodes_info_;
}

void NodeManager::UpdateCluster() {
// 1. update cluster timeout state
struct timeval current_time {};
(void)gettimeofday(&current_time, nullptr);
timeout_nodes_info_.clear();
std::lock_guard<std::mutex> lock(heartbeat_mutex_);
for (auto it = heartbeats_.begin(); it != heartbeats_.end(); ++it) {
if (it->second.tv_sec + PSContext::instance()->cluster_config().heartbeat_timeout < current_time.tv_sec) {
if (registered_nodes_info_.count(it->first)) {


+ 2
- 0
mindspore/ccsrc/ps/core/node_manager.h View File

@@ -139,6 +139,8 @@ class NodeManager {

bool IsAllNodesAlive() const;

const std::unordered_map<std::string, NodeInfo> &QueryTimeOutNodesInfo() const;

private:
std::mutex node_mutex_;
std::mutex cluster_mutex_;


+ 22
- 0
mindspore/ccsrc/ps/core/protos/comm.proto View File

@@ -57,6 +57,12 @@ enum NodeCommand {
SEND_UNIQUE_ID = 20;
// Query unique id used to initialize collective communication.
QUERY_UNIQUE_ID = 21;
// Send the ready status to finish transform graph of computed node,
// used in disaster recovery mode to prevent timeout of waiting for graph transformation.
SEND_FINISH_TRANSFORM = 22;
// Query the ready status to finish transform graph of computed node,
// used in disaster recovery mode to prevent timeout of waiting for graph transformation.
QUERY_FINISH_TRANSFORM = 23;
}

enum NodeRole {
@@ -298,3 +304,19 @@ message QueryUniqueIDRespMessage {
// The unique id used to initialize collective communication.
bytes unique_id = 2;
}

message SendFinishTransformMessage {
// the current Node unique id:0,1,2...
string node_id = 1;
// The rank id of the node in the cluster.
uint32 rank_id = 2;
// Whether finish transform graph.
bool is_ready = 3;
}

message QueryFinishTransformRespMessage {
// Whether all computed nodes are ready to run dag.
bool is_ready = 1;
// Whether there is any worker timeout.
bool is_worker_timeout = 2;
}

+ 96
- 4
mindspore/ccsrc/ps/core/ps_scheduler_node.cc View File

@@ -16,6 +16,7 @@

#include <memory>
#include "ps/core/ps_scheduler_node.h"
#include "utils/ms_context.h"

namespace mindspore {
namespace ps {
@@ -58,6 +59,13 @@ void PSSchedulerNode::RegisterInitCollectCommServiceHandler() {
handlers_[NodeCommand::QUERY_UNIQUE_ID] = static_cast<ResponseHandler>(&PSSchedulerNode::ProcessQueryUniqueID);
}

void PSSchedulerNode::RegisterRecoveryServiceHandler() {
handlers_[NodeCommand::SEND_FINISH_TRANSFORM] =
static_cast<ResponseHandler>(&PSSchedulerNode::ProcessSendFinishTransform);
handlers_[NodeCommand::QUERY_FINISH_TRANSFORM] =
static_cast<ResponseHandler>(&PSSchedulerNode::ProcessQueryFinishTransform);
}

void PSSchedulerNode::ProcessSendHostName(const std::shared_ptr<TcpServer> &server,
const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
@@ -71,7 +79,7 @@ void PSSchedulerNode::ProcessSendHostName(const std::shared_ptr<TcpServer> &serv
std::string node_id = send_host_name_msg.node_id();
uint32_t rank_id = send_host_name_msg.rank_id();
size_t host_hash_name = send_host_name_msg.host_hash_name();
MS_LOG(INFO) << "Received send host name request, node id: " << node_id << ", rank id: " << rank_id;
MS_LOG(INFO) << "Receive send host name request, node id: " << node_id << ", rank id: " << rank_id;

bool ret = false;
std::string error = "";
@@ -100,7 +108,7 @@ void PSSchedulerNode::ProcessQueryHostNames(const std::shared_ptr<TcpServer> &se
query_msg.ParseFromArray(data, SizeToInt(size));
std::string node_id = query_msg.node_id();
uint32_t rank_id = query_msg.rank_id();
MS_LOG(INFO) << "Received query host name request, node id: " << node_id << ", rank id: " << rank_id;
MS_LOG(INFO) << "Receive query host name request, node id: " << node_id << ", rank id: " << rank_id;

bool is_success = recv_rank_id_send_host_name_.size() == host_hash_names_.size();
QueryHostHashNameRespMessage resp_msg;
@@ -121,6 +129,7 @@ void PSSchedulerNode::ProcessQueryHostNames(const std::shared_ptr<TcpServer> &se
if (recv_rank_id_query_host_name_.size() == recv_rank_id_send_host_name_.size()) {
recv_rank_id_send_host_name_.clear();
recv_rank_id_query_host_name_.clear();
node_timeout_ = false;
}
}
}
@@ -138,7 +147,7 @@ void PSSchedulerNode::ProcessSendUniqueID(const std::shared_ptr<TcpServer> &serv
std::string node_id = send_unique_id_msg.node_id();
uint32_t rank_id = send_unique_id_msg.rank_id();
std::string group_name = send_unique_id_msg.group_name();
MS_LOG(INFO) << "Received send unique id request, group name: " << group_name << ", node id: " << node_id
MS_LOG(INFO) << "Receive send unique id request, group name: " << group_name << ", node id: " << node_id
<< ", rank id: " << rank_id;

bool ret = false;
@@ -169,7 +178,7 @@ void PSSchedulerNode::ProcessQueryUniqueID(const std::shared_ptr<TcpServer> &ser
std::string node_id = query_msg.node_id();
uint32_t rank_id = query_msg.rank_id();
std::string group_name = query_msg.group_name();
MS_LOG(INFO) << "Received query unique id request, group name: " << group_name << ", node id: " << node_id
MS_LOG(INFO) << "Receive query unique id request, group name: " << group_name << ", node id: " << node_id
<< ", rank id: " << rank_id;

auto iter = unique_id_group_.find(group_name);
@@ -190,6 +199,89 @@ void PSSchedulerNode::ProcessQueryUniqueID(const std::shared_ptr<TcpServer> &ser
MS_LOG(INFO) << "Respond query unique id request, group name: " << group_name << ", node id: " << node_id
<< ", rank id: " << rank_id;
}

void PSSchedulerNode::ProcessSendFinishTransform(const std::shared_ptr<TcpServer> &server,
const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data,
size_t size) {
MS_ERROR_IF_NULL_WO_RET_VAL(server);
MS_ERROR_IF_NULL_WO_RET_VAL(conn);
MS_ERROR_IF_NULL_WO_RET_VAL(meta);
MS_ERROR_IF_NULL_WO_RET_VAL(data);

SendFinishTransformMessage send_ready_to_run_msg;
send_ready_to_run_msg.ParseFromArray(data, SizeToInt(size));
std::string node_id = send_ready_to_run_msg.node_id();
uint32_t rank_id = send_ready_to_run_msg.rank_id();
MS_LOG(INFO) << "Receive send finish transform request, node id: " << node_id << ", rank id: " << rank_id;
bool is_ready = send_ready_to_run_msg.is_ready();
if (is_ready) {
std::unique_lock<std::mutex> lock(nodes_finish_trans_mutex_);
(void)nodes_finish_trans_.insert(rank_id);
}

GeneralResponse(server, conn, meta, true, "");
MS_LOG(INFO) << "Respond send finish transform request, node id: " << node_id << ", rank id: " << rank_id;
}

void PSSchedulerNode::ProcessQueryFinishTransform(const std::shared_ptr<TcpServer> &server,
const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data,
size_t size) {
MS_ERROR_IF_NULL_WO_RET_VAL(server);
MS_ERROR_IF_NULL_WO_RET_VAL(conn);
MS_ERROR_IF_NULL_WO_RET_VAL(meta);
MS_ERROR_IF_NULL_WO_RET_VAL(data);

GeneralQueryMessage query_msg;
query_msg.ParseFromArray(data, SizeToInt(size));
std::string node_id = query_msg.node_id();
uint32_t rank_id = query_msg.rank_id();
MS_LOG(INFO) << "Receive query finish transform request, node id: " << node_id << ", rank id: " << rank_id;

std::unique_lock<std::mutex> lock(nodes_finish_trans_mutex_);
bool is_ready = nodes_finish_trans_.size() == worker_num_;

QueryFinishTransformRespMessage resp_msg;
resp_msg.set_is_ready(is_ready);

if (node_timeout_) {
(void)resp_msg.set_is_worker_timeout(true);
} else {
resp_msg.set_is_worker_timeout(false);
}

if (!server->SendMessage(conn, meta, Protos::PROTOBUF, resp_msg.SerializeAsString().data(),
resp_msg.ByteSizeLong())) {
MS_LOG(ERROR) << "Scheduler failed to respond message.";
return;
}
MS_LOG(INFO) << "Respond query finish transform request, node id: " << node_id << ", rank id: " << rank_id;
}

void PSSchedulerNode::HandleNodeTimeoutForRecovery(
const std::unordered_map<std::string, NodeInfo> &timeout_nodes_infos) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_RECOVERY)) {
return;
}

if (timeout_nodes_infos.empty()) {
return;
}

std::unique_lock<std::mutex> lock(nodes_finish_trans_mutex_);
node_timeout_ = true;
for (const auto &item : timeout_nodes_infos) {
(void)nodes_finish_trans_.erase(item.second.rank_id_);
}
}

void PSSchedulerNode::HandleNodeRecoverByHeartBeat(uint32_t rank_id) {
std::unique_lock<std::mutex> lock(nodes_finish_trans_mutex_);
(void)nodes_finish_trans_.insert(rank_id);
}
} // namespace core
} // namespace ps
} // namespace mindspore

+ 27
- 0
mindspore/ccsrc/ps/core/ps_scheduler_node.h View File

@@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_PS_CORE_PS_SCHEDULER_NODE_H_

#include <map>
#include <unordered_map>
#include <memory>
#include <vector>
#include <set>
@@ -47,9 +48,16 @@ class BACKEND_EXPORT PSSchedulerNode : public SchedulerNode {
// alive node should be rejected.
bool NeedRejectRegister(const NodeInfo &node_info) override { return node_info.is_alive; }

bool SendPrepareBuildingNetwork(const std::unordered_map<std::string, NodeInfo> &node_infos) override {
return true;
};

// Register collective communication initialization service.
void RegisterInitCollectCommServiceHandler() override;

// Register recovery service.
void RegisterRecoveryServiceHandler() override;

// Process message for sending node's host name.
void ProcessSendHostName(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
@@ -66,6 +74,20 @@ class BACKEND_EXPORT PSSchedulerNode : public SchedulerNode {
void ProcessQueryUniqueID(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);

// Process message for sending the ready status to finish transform graph of computed node,
void ProcessSendFinishTransform(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);

// Process message for querying the ready status to finish transform graph of computed node,
void ProcessQueryFinishTransform(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);

// Handle node timeout info and update nodes which finish transform graph.
void HandleNodeTimeoutForRecovery(const std::unordered_map<std::string, NodeInfo> &timeout_nodes_infos) override;

// Recover finish transform nodes info when nodes recover heartbeat.
void HandleNodeRecoverByHeartBeat(uint32_t rank_id) override;

// Record received host hash name from workers.
std::vector<size_t> host_hash_names_;
// Record rank id of the nodes which sended host name.
@@ -77,6 +99,11 @@ class BACKEND_EXPORT PSSchedulerNode : public SchedulerNode {
std::map<std::string, std::string> unique_id_group_;

uint32_t worker_num_;

std::mutex nodes_finish_trans_mutex_;
// Record the rank ids of nodes who finish transform graph.
std::set<uint32_t> nodes_finish_trans_;
std::atomic_bool node_timeout_{false};
};
} // namespace core
} // namespace ps


+ 6
- 8
mindspore/ccsrc/ps/core/ps_server_node.cc View File

@@ -39,16 +39,13 @@ bool PSServerNode::Start(const uint32_t &timeout) {
}

void PSServerNode::Initialize() {
InitNodeNum();
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
MS_EXCEPTION_IF_NULL(config_);
if (!config_->Initialize()) {
MS_LOG(INFO) << "The config file is empty, then init node by context.";
InitNodeNum();
} else {
if (!Recover()) {
MS_LOG(WARNING) << "Recover the server node is failed.";
}
if (config_->Initialize() && !Recover()) {
MS_LOG(INFO) << "Recover the server node is failed.";
}

InitServerHandler();
CreateTcpServer();
InitNodeInfo(NodeRole::SERVER);
@@ -119,8 +116,9 @@ void PSServerNode::Register(const std::shared_ptr<TcpClient> &client) {
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " begin to register to the scheduler!";

const int kCommTimeoutInSeconds = 20;
if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(),
register_message.ByteSizeLong())) {
register_message.ByteSizeLong(), kCommTimeoutInSeconds)) {
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " register timeout!";
} else {


+ 6
- 8
mindspore/ccsrc/ps/core/ps_worker_node.cc View File

@@ -79,16 +79,13 @@ bool PSWorkerNode::Finish(const uint32_t &timeout) {
}

void PSWorkerNode::Initialize() {
InitNodeNum();
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
MS_EXCEPTION_IF_NULL(config_);
if (!config_->Initialize()) {
MS_LOG(INFO) << "The config file is empty, then init node by context.";
InitNodeNum();
} else {
if (!Recover()) {
MS_LOG(WARNING) << "Recover the worker node is failed.";
}
if (config_->Initialize() && !Recover()) {
MS_LOG(INFO) << "Recover the worker node is failed.";
}

InitServerHandler();
CreateTcpServer();
InitNodeInfo(NodeRole::WORKER);
@@ -120,8 +117,9 @@ void PSWorkerNode::Register(const std::shared_ptr<TcpClient> &client) {
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " begin to register to the scheduler!";

const int kCommTimeoutInSeconds = 20;
if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(),
register_message.ByteSizeLong())) {
register_message.ByteSizeLong(), kCommTimeoutInSeconds)) {
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " register timeout!";
} else {


+ 5
- 1
mindspore/ccsrc/ps/core/scheduler_node.cc View File

@@ -176,10 +176,12 @@ void SchedulerNode::ProcessHeartbeat(const std::shared_ptr<TcpServer> &server,
return;
}

uint32_t rank_id = UINT32_MAX;
// Re-Add the missing node into node manager.
if (heartbeat_message.has_address() &&
node_manager_.ReAddNodeIfNotExists(node_id, heartbeat_message.ip(), heartbeat_message.port())) {
node_manager_.ReAddNodeIfNotExists(node_id, heartbeat_message.ip(), heartbeat_message.port(), &rank_id)) {
SetRegisterConnectionFd(conn, node_id);
HandleNodeRecoverByHeartBeat(rank_id);

if (node_manager_.IsAllNodesRegistered()) {
is_ready_ = true;
@@ -234,6 +236,7 @@ void SchedulerNode::InitCommandHandler() {
handlers_[NodeCommand::SEND_EVENT] = &SchedulerNode::ProcessSendEvent;
RegisterActorRouteTableServiceHandler();
RegisterInitCollectCommServiceHandler();
RegisterRecoveryServiceHandler();
}

void SchedulerNode::RegisterActorRouteTableServiceHandler() {
@@ -699,6 +702,7 @@ void SchedulerNode::StartUpdateClusterStateTimer() {
}
std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval));
node_manager_.UpdateCluster();
HandleNodeTimeoutForRecovery(node_manager_.QueryTimeOutNodesInfo());

if (node_manager_.GetClusterState() == ClusterState::CLUSTER_EXIT) {
std::this_thread::sleep_for(


+ 10
- 1
mindspore/ccsrc/ps/core/scheduler_node.h View File

@@ -92,6 +92,15 @@ class BACKEND_EXPORT SchedulerNode : public Node {
// Register collective communication initialization service.
virtual void RegisterInitCollectCommServiceHandler() {}

// Register recovery service.
virtual void RegisterRecoveryServiceHandler() {}

// Handle node timeout info and update nodes which finish transform graph.
virtual void HandleNodeTimeoutForRecovery(const std::unordered_map<std::string, NodeInfo> &timeout_nodes_infos) {}

// Recover finish transform nodes info when nodes recover heartbeat.
virtual void HandleNodeRecoverByHeartBeat(uint32_t rank_id) {}

const std::shared_ptr<TcpClient> &GetOrCreateClient(const NodeInfo &node_info);

void ProcessHeartbeat(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
@@ -197,7 +206,7 @@ class BACKEND_EXPORT SchedulerNode : public Node {

void SetRegisterConnectionFd(const std::shared_ptr<TcpConnection> &conn, const std::string &node_id);

bool SendPrepareBuildingNetwork(const std::unordered_map<std::string, NodeInfo> &node_infos);
virtual bool SendPrepareBuildingNetwork(const std::unordered_map<std::string, NodeInfo> &node_infos);

// Responding peer with the general response message.
void GeneralResponse(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,


+ 9
- 13
mindspore/ccsrc/runtime/device/stream_synchronizer.cc View File

@@ -17,16 +17,19 @@
#include "runtime/device/stream_synchronizer.h"
#include "utils/ms_context.h"
#include "distributed/collective/collective_manager.h"
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"

namespace mindspore {
namespace device {
using distributed::collective::CollectiveManager;
using distributed::recovery::RecoveryContext;

std::mutex StreamSynchronizer::instance_lock_;
std::shared_ptr<StreamSynchronizer> StreamSynchronizer::instance_ = nullptr;

void StreamSynchronizer::Initialize() {
// Non disaster recovery mode does not need to start thread and timeout mechanisms.
if (!runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) {
if (!RecoveryContext::GetInstance()->enable_recovery()) {
return;
}

@@ -56,7 +59,7 @@ bool StreamSynchronizer::SyncStream(const std::string &device_name, uint32_t tim
MS_EXCEPTION_IF_NULL(device_context);

// If disable recovery or timeout==0, sync stream directly to improve performance.
if (!runtime::recovery::RecoveryContext::GetInstance()->enable_recovery() || timeout == 0) {
if (!RecoveryContext::GetInstance()->enable_recovery() || timeout == 0) {
device_context->Initialize();
return device_context->SyncStream();
}
@@ -68,26 +71,19 @@ bool StreamSynchronizer::SyncStream(const std::string &device_name, uint32_t tim
device_context_ = device_context;
do_sync_stream_cv_.notify_one();

if (sync_stream_time_out_) {
// If sync stream timeout has happened, increase the timeout by 4 times.
const uint32_t kTimeOutScaleFactor = 4;
timeout *= kTimeOutScaleFactor;
}

if (time_out_cv_.wait_for(lock, std::chrono::seconds(timeout)) == std::cv_status::no_timeout) {
if (!sync_stream_ret_) {
MS_LOG(ERROR) << "Synchronize stream failed.";
}
return sync_stream_ret_;
} else {
sync_stream_time_out_ = true;
runtime::recovery::RecoveryContext::GetInstance()->set_need_reinit_collective(true);
if (!distributed::collective::CollectiveManager::instance()->Finalize()) {
CollectiveManager::instance()->set_need_reinit(true);
if (!CollectiveManager::instance()->Finalize()) {
MS_LOG(ERROR) << "Finalize collective manager failed.";
return false;
}
time_out_cv_.wait(lock, [this]() { return device_context_ == nullptr; });
MS_LOG(WARNING) << "Synchronize stream time out.";
MS_LOG(WARNING) << "Synchronize stream timeout.";
return true;
}
}


+ 0
- 3
mindspore/ccsrc/runtime/device/stream_synchronizer.h View File

@@ -71,9 +71,6 @@ class BACKEND_EXPORT StreamSynchronizer {
// The singleton pointer.
static std::shared_ptr<StreamSynchronizer> instance_;

// Record whether the synchronization stream task has timed out.
bool sync_stream_time_out_{false};

// Return value of synchronization stream.
bool sync_stream_ret_{false};



+ 2
- 2
mindspore/ccsrc/runtime/graph_scheduler/actor/data_prepare_actor.cc View File

@@ -25,11 +25,11 @@
#include "mindrt/include/async/async.h"
#include "utils/log_adapter.h"
#include "include/common/utils/convert_utils.h"
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"

namespace mindspore {
namespace runtime {
using recovery::RecoveryContext;
using distributed::recovery::RecoveryContext;
namespace {
void SyncTensorData(const TensorPtr &host_tensor, const DeviceTensorPtr &device_tensor, const AnfNodePtr &node,
const DeviceContext *device_context, OpContext<DeviceTensor> *const context,


+ 5
- 3
mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.cc View File

@@ -21,11 +21,13 @@
#include "runtime/graph_scheduler/actor/debug_actor.h"
#include "mindrt/include/async/async.h"
#include "utils/log_adapter.h"
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"
#include "distributed/collective/collective_manager.h"

namespace mindspore {
namespace runtime {
using recovery::RecoveryContext;
using distributed::collective::CollectiveManager;
using distributed::recovery::RecoveryContext;

void KernelActor::Init() {
// Check device contexts number.
@@ -243,7 +245,7 @@ void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
PreLaunchKernel(context);

try {
if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) {
if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) {
// In disaster recovery scenarios, run dag in this step failed, the rest operators of graph do not need launch,
// especially the collective communication operators.
MS_LOG(WARNING) << "Collective communication need reinitialize, skip launch kernel: "


+ 5
- 4
mindspore/ccsrc/runtime/graph_scheduler/actor/loop_count_actor.cc View File

@@ -25,11 +25,13 @@
#include "mindrt/include/async/async.h"
#include "utils/log_adapter.h"
#include "runtime/device/stream_synchronizer.h"
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"
#include "distributed/collective/collective_manager.h"

namespace mindspore {
namespace runtime {
using recovery::RecoveryContext;
using distributed::collective::CollectiveManager;
using distributed::recovery::RecoveryContext;

void LoopCountActor::Run(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
@@ -70,8 +72,7 @@ void LoopCountActor::IncreaseLoopCount(OpContext<DeviceTensor> *const context) {
(void)sync_stream_device_contexts.insert(device_context);

// Trigger disaster recovery and exit loop early.
if (RecoveryContext::GetInstance()->enable_recovery() &&
RecoveryContext::GetInstance()->need_reinit_collective()) {
if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) {
current_count_ = loop_count_;
}
}


+ 5
- 3
mindspore/ccsrc/runtime/graph_scheduler/actor/output_actor.cc View File

@@ -17,11 +17,13 @@
#include "runtime/graph_scheduler/actor/output_actor.h"
#include "runtime/hardware/device_context_manager.h"
#include "utils/log_adapter.h"
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"
#include "distributed/collective/collective_manager.h"

namespace mindspore {
namespace runtime {
using recovery::RecoveryContext;
using distributed::collective::CollectiveManager;
using distributed::recovery::RecoveryContext;

bool IsOutputAddressPersisted(const DeviceTensor *output_device_tensor, const AnfNodePtr &output_node) {
MS_EXCEPTION_IF_NULL(output_node);
@@ -105,7 +107,7 @@ void OutputActor::RunOpControl(AID *const, OpContext<DeviceTensor> *const contex
++current_count_;

// Trigger disaster recovery and return empty output.
if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) {
if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) {
FreeOutputNodeMem();
ClearOutputCache();
SET_OPCONTEXT_SUCCESS_RET((*context));


+ 126
- 11
mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc View File

@@ -45,11 +45,19 @@
#endif
#include "profiler/device/profiling.h"
#include "include/common/debug/common.h"
#include "runtime/recovery/recovery_context.h"
#include "distributed/recovery/recovery_context.h"
#include "distributed/collective/collective_manager.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32) && (!defined _WIN64))
#include "distributed/cluster/cluster_context.h"
#else
#include "distributed/cluster/dummy_cluster_context.h"
#endif

namespace mindspore {
namespace runtime {
using recovery::RecoveryContext;
using distributed::cluster::ClusterContext;
using distributed::collective::CollectiveManager;
using distributed::recovery::RecoveryContext;
namespace {
bool IsNeedInsertCopyActor(const DeviceContext *from_device_context, const DeviceContext *to_device_context) {
MS_EXCEPTION_IF_NULL(from_device_context);
@@ -196,6 +204,108 @@ void IntHandler(int, siginfo_t *, void *) {
(void)kill(this_pid, SIGTERM);
}
#endif

#if ((defined ENABLE_CPU) && (!defined _WIN32) && (!defined _WIN64))
bool SendFinishTransform() {
auto node = ClusterContext::instance()->node();
MS_EXCEPTION_IF_NULL(node);
if (node->role() != ps::core::NodeRole::WORKER) {
return true;
}

auto abstract_node = std::dynamic_pointer_cast<ps::core::AbstractNode>(ClusterContext::instance()->node());
MS_EXCEPTION_IF_NULL(abstract_node);

ps::core::SendFinishTransformMessage send_ready_to_run_msg;
send_ready_to_run_msg.set_node_id(abstract_node->node_id());
send_ready_to_run_msg.set_rank_id(abstract_node->rank_id());
send_ready_to_run_msg.set_is_ready(true);
std::shared_ptr<std::vector<unsigned char>> output = nullptr;
if (!abstract_node->SendToScheduler(send_ready_to_run_msg.SerializeAsString().data(),
send_ready_to_run_msg.SerializeAsString().size(),
ps::core::NodeCommand::SEND_FINISH_TRANSFORM, &output)) {
MS_LOG(WARNING) << "Failed to send finish transform request to scheduler.";
return false;
}

ps::core::GeneralResponseMsg resp_msg;
MS_EXCEPTION_IF_NULL(output);
(void)resp_msg.ParseFromArray(output->data(), SizeToInt(output->size()));
if (!resp_msg.is_success()) {
MS_LOG(ERROR) << "Send finish transform to scheduler failed.";
return false;
}
return true;
}

bool QueryFinishTransform() {
auto node = ClusterContext::instance()->node();
MS_EXCEPTION_IF_NULL(node);
if (node->role() != ps::core::NodeRole::WORKER) {
return true;
}

auto abstract_node = std::dynamic_pointer_cast<ps::core::AbstractNode>(ClusterContext::instance()->node());
MS_EXCEPTION_IF_NULL(abstract_node);

ps::core::GeneralQueryMessage general_query_msg;
general_query_msg.set_node_id(abstract_node->node_id());
general_query_msg.set_rank_id(abstract_node->rank_id());
std::shared_ptr<std::vector<unsigned char>> output = nullptr;
bool ret = false;
while (!ret) {
if (!abstract_node->SendToScheduler(general_query_msg.SerializeAsString().data(),
general_query_msg.SerializeAsString().size(),
ps::core::NodeCommand::QUERY_FINISH_TRANSFORM, &output)) {
MS_LOG(WARNING) << "Failed to send query finish transform request to scheduler.";
ret = false;
continue;
}

ps::core::QueryFinishTransformRespMessage resp_msg;
MS_EXCEPTION_IF_NULL(output);
(void)resp_msg.ParseFromArray(output->data(), SizeToInt(output->size()));
ret = resp_msg.is_ready();
if (!ret) {
MS_LOG(INFO) << "There is worker which has not finished transform graph";
}

if (resp_msg.is_worker_timeout()) {
MS_LOG(WARNING) << "There is worker timeout";
return false;
}
// The time interval for querying the all worker finish transform graphs status to scheduler: 10 seconds.
const uint32_t kWaitDuration = 10;
std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration));
}

return ret;
}

void DoDisasterRecovery() {
if (RecoveryContext::GetInstance()->enable_recovery() && CollectiveManager::instance()->need_reinit()) {
MS_LOG(INFO) << "Begin reinitialize collective communication for recovery.";
bool ret = false;
while (!ret) {
while (!CollectiveManager::instance()->Initialize()) {
MS_LOG(WARNING) << "ReInitialize collective communication failed, retrying...";
}
MS_LOG(INFO) << "Finish reinitialize collective communication for recovery.";

RecoveryContext::GetInstance()->ObtainGlobalLatestCkptInfo();

ret = QueryFinishTransform();
if (!ret) {
CollectiveManager::instance()->set_need_reinit(true);
(void)CollectiveManager::instance()->Finalize();
}
}

RecoveryContext::GetInstance()->set_need_reset(true);
RecoveryContext::GetInstance()->set_need_sync_weight_to_device(true);
}
}
#endif
} // namespace

GraphScheduler &GraphScheduler::GetInstance() noexcept {
@@ -407,6 +517,17 @@ ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info
}
MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor end.";

#if ((defined ENABLE_CPU) && (!defined _WIN32) && (!defined _WIN64))
if (ClusterContext::instance()->initialized() && RecoveryContext::GetInstance()->enable_recovery()) {
while (!SendFinishTransform()) {
MS_LOG(WARNING) << "Send finish transform graph failed.";
// The time interval for sending finish transform graph to scheduler.
constexpr uint32_t kWaitDuration = 10;
std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration));
}
}
#endif

return actor_set.get();
}

@@ -489,15 +610,9 @@ void GraphScheduler::Run(ActorSet *const actor_set, const std::vector<DeviceCont
const size_t kSecondsToMilliseconds = 1000;
SetActorExecutionStrategy(actor_set, strategy, (end_time - start_time) * kSecondsToMilliseconds);

if (RecoveryContext::GetInstance()->enable_recovery() && RecoveryContext::GetInstance()->need_reinit_collective()) {
MS_LOG(INFO) << "Begin reinitialize collective communication for recovery.";
if (!RecoveryContext::GetInstance()->ReInitializeCollective()) {
MS_LOG(EXCEPTION) << "Reinitialize collective communication failed.";
}
MS_LOG(INFO) << "Finish reinitialize collective communication for recovery.";

RecoveryContext::GetInstance()->set_need_reinit_collective(false);
}
#if ((defined ENABLE_CPU) && (!defined _WIN32) && (!defined _WIN64))
DoDisasterRecovery();
#endif
}

void GraphScheduler::SetActorExecutionStrategy(ActorSet *const actor_set, GraphExecutionStrategy strategy,


+ 0
- 9
mindspore/ccsrc/runtime/recovery/CMakeLists.txt View File

@@ -1,9 +0,0 @@
file(GLOB_RECURSE RECOVERY_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"recovery_context.cc")

if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-delete-abstract-non-virtual-dtor")
endif()

set_property(SOURCE ${RECOVERY_SRC_LIST} PROPERTY SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
add_library(_mindspore_runtime_recovery_obj OBJECT ${RECOVERY_SRC_LIST})

+ 1
- 2
tests/ut/cpp/CMakeLists.txt View File

@@ -306,8 +306,7 @@ add_library(backend_static STATIC
$<TARGET_OBJECTS:_mindspore_runtime_device_obj>
$<TARGET_OBJECTS:_mindspore_runtime_graph_scheduler_obj>
$<TARGET_OBJECTS:_mindspore_runtime_hardware_obj>
$<TARGET_OBJECTS:_mindspore_runtime_pynative_obj>
$<TARGET_OBJECTS:_mindspore_runtime_recovery_obj>)
$<TARGET_OBJECTS:_mindspore_runtime_pynative_obj>)
target_link_libraries(ut_tests PRIVATE mindspore securec -Wl,--start-group proto_input mindspore::protobuf
backend_static -Wl,--end-group)
target_link_libraries(ut_tests PRIVATE mindspore::grpc++)

Loading…
Cancel
Save