From 485b5259ab17998745cce27a8aaeeeedeb895dfb Mon Sep 17 00:00:00 2001 From: twc Date: Thu, 7 Apr 2022 11:08:28 +0800 Subject: [PATCH] =?UTF-8?q?fix=20issue=20I51DN2=E3=80=81I51DK0=20and=20ope?= =?UTF-8?q?n=20hybrid=20train=20mode?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mindspore/ccsrc/fl/server/common.h | 1 + .../fl/server/distributed_count_service.cc | 41 +++--- .../fl/server/distributed_count_service.h | 11 +- .../fl/server/distributed_metadata_store.cc | 7 +- .../fl/server/distributed_metadata_store.h | 4 +- .../kernel/round/get_list_sign_kernel.cc | 5 +- .../kernel/round/push_list_sign_kernel.cc | 5 +- .../kernel/round/push_metrics_kernel.cc | 3 +- .../server/kernel/round/push_weight_kernel.cc | 3 +- .../kernel/round/start_fl_job_kernel.cc | 7 +- .../kernel/round/update_model_kernel.cc | 13 +- mindspore/ccsrc/ps/constants.h | 1 + mindspore/ccsrc/ps/core/cluster_config.h | 4 +- mindspore/ccsrc/ps/core/comm_util.h | 14 +++ mindspore/ccsrc/ps/core/file_configuration.cc | 2 +- mindspore/ccsrc/ps/core/follower_scaler.cc | 2 + mindspore/ccsrc/ps/core/instance_manager.cc | 10 +- mindspore/ccsrc/ps/core/node.h | 3 - mindspore/ccsrc/ps/core/node_manager.cc | 45 ++----- mindspore/ccsrc/ps/core/node_manager.h | 4 - mindspore/ccsrc/ps/core/scheduler_node.cc | 117 ++++++++++++------ mindspore/ccsrc/ps/core/scheduler_node.h | 2 + mindspore/ccsrc/ps/core/scheduler_recovery.cc | 24 +--- mindspore/ccsrc/ps/ps_context.cc | 5 +- 24 files changed, 165 insertions(+), 168 deletions(-) diff --git a/mindspore/ccsrc/fl/server/common.h b/mindspore/ccsrc/fl/server/common.h index 1d9429078b..f2cb92da2a 100644 --- a/mindspore/ccsrc/fl/server/common.h +++ b/mindspore/ccsrc/fl/server/common.h @@ -293,6 +293,7 @@ inline T JsonGetKeyWithException(const nlohmann::json &json, const std::string & // Definitions for Federated Learning. constexpr auto kNetworkError = "Cluster networking failed."; +constexpr auto KTriggerCounterEventError = "Cluster trigger counter event failed."; // The result code used for round kernels. enum class ResultCode { diff --git a/mindspore/ccsrc/fl/server/distributed_count_service.cc b/mindspore/ccsrc/fl/server/distributed_count_service.cc index bfd66567bb..f976e4c827 100644 --- a/mindspore/ccsrc/fl/server/distributed_count_service.cc +++ b/mindspore/ccsrc/fl/server/distributed_count_service.cc @@ -86,7 +86,7 @@ bool DistributedCountService::ReInitCounter(const std::string &name, size_t glob return true; } -bool DistributedCountService::Count(const std::string &name, const std::string &id, std::string *reason) { +bool DistributedCountService::Count(const std::string &name, const std::string &id) { MS_LOG(DEBUG) << "Rank " << local_rank_ << " reports count for " << name << " of " << id; if (local_rank_ == counting_server_rank_) { if (global_threshold_count_.count(name) == 0) { @@ -107,9 +107,9 @@ bool DistributedCountService::Count(const std::string &name, const std::string & MS_LOG(INFO) << "Global current count for " << name << " is: " << global_current_count_[name].size() << "/" << global_threshold_count_[name]; } - if (!TriggerCounterEvent(name, reason)) { + if (!TriggerCounterEvent(name)) { MS_LOG(WARNING) << "Leader server trigger count event failed."; - Iteration::GetInstance().NotifyNext(false, *reason); + Iteration::GetInstance().NotifyNext(false, KTriggerCounterEventError); return false; } } else { @@ -121,10 +121,8 @@ bool DistributedCountService::Count(const std::string &name, const std::string & std::shared_ptr> report_cnt_rsp_msg = nullptr; if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, ps::core::TcpUserCommand::kCount, &report_cnt_rsp_msg)) { - MS_LOG(WARNING) << "Sending reporting count message to leader server failed for " << name; - if (reason != nullptr) { - *reason = kNetworkError; - } + MS_LOG(WARNING) << "Sending reporting count " + name + " message to leader server failed for fl id " << id; + Iteration::GetInstance().NotifyNext(false, kNetworkError); return false; } @@ -133,10 +131,6 @@ bool DistributedCountService::Count(const std::string &name, const std::string & (void)count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size())); if (!count_rsp.result()) { MS_LOG(WARNING) << "Reporting count failed:" << count_rsp.reason(); - // If the error is caused by the network issue, return the reason. - if (reason != nullptr && count_rsp.reason().find(kNetworkError) != std::string::npos) { - *reason = kNetworkError; - } return false; } } @@ -263,9 +257,8 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptrSendPbRequest(first_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) { - MS_LOG(WARNING) << "Activating first count event to server " << i << " failed."; - if (reason != nullptr) { - *reason = kNetworkError; - } + MS_LOG(WARNING) << "Send activating first count event to server " << i << " failed."; return false; } } @@ -374,7 +364,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st return true; } -bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std::string *reason) { +bool DistributedCountService::TriggerLastCountEvent(const std::string &name) { MS_LOG(DEBUG) << "Activating last count event for " << name; CounterEvent last_count_event; last_count_event.set_type(CounterEventType::LAST_CNT); @@ -384,10 +374,7 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std for (uint32_t i = 1; i < server_num_; i++) { MS_LOG(DEBUG) << "Start sending last count event message to server " << i; if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) { - MS_LOG(WARNING) << "Activating last count event to server " << i << " failed."; - if (reason != nullptr) { - *reason = kNetworkError; - } + MS_LOG(WARNING) << "Send activating last count event to server " << i << " failed."; return false; } } diff --git a/mindspore/ccsrc/fl/server/distributed_count_service.h b/mindspore/ccsrc/fl/server/distributed_count_service.h index c713242130..396f772570 100644 --- a/mindspore/ccsrc/fl/server/distributed_count_service.h +++ b/mindspore/ccsrc/fl/server/distributed_count_service.h @@ -67,9 +67,8 @@ class DistributedCountService { // Reinitialize counter due to the change of threshold count. bool ReInitCounter(const std::string &name, size_t global_threshold_count); - // Report a count to the counting server. Parameter 'id' is in case of repeated counting. Parameter 'reason' is the - // reason why counting failed. - bool Count(const std::string &name, const std::string &id, std::string *reason = nullptr); + // Report a count to the counting server. Parameter 'id' is in case of repeated counting. + bool Count(const std::string &name, const std::string &id); // Query whether the count reaches the threshold count for the name. If the count is the same as the threshold count, // this method returns true. @@ -103,9 +102,9 @@ class DistributedCountService { void HandleCounterEvent(const std::shared_ptr &message); // Call the callbacks when the first/last count event is triggered. - bool TriggerCounterEvent(const std::string &name, std::string *reason = nullptr); - bool TriggerFirstCountEvent(const std::string &name, std::string *reason = nullptr); - bool TriggerLastCountEvent(const std::string &name, std::string *reason = nullptr); + bool TriggerCounterEvent(const std::string &name); + bool TriggerFirstCountEvent(const std::string &name); + bool TriggerLastCountEvent(const std::string &name); // Members for the communication between counting server and other servers. std::shared_ptr server_node_; diff --git a/mindspore/ccsrc/fl/server/distributed_metadata_store.cc b/mindspore/ccsrc/fl/server/distributed_metadata_store.cc index 436aece43b..820d5880f4 100644 --- a/mindspore/ccsrc/fl/server/distributed_metadata_store.cc +++ b/mindspore/ccsrc/fl/server/distributed_metadata_store.cc @@ -86,7 +86,7 @@ void DistributedMetadataStore::ResetMetadata(const std::string &name) { return; } -bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta, std::string *reason) { +bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) { if (router_ == nullptr) { MS_LOG(WARNING) << "The consistent hash ring is not initialized yet."; return false; @@ -107,10 +107,7 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, ps::core::TcpUserCommand::kUpdateMetadata, &update_meta_rsp_msg)) { MS_LOG(WARNING) << "Sending updating metadata message to server " << stored_rank << " failed."; - if (reason != nullptr) { - *reason = kNetworkError; - } - Iteration::GetInstance().NotifyNext(false, *reason); + Iteration::GetInstance().NotifyNext(false, kNetworkError); return false; } diff --git a/mindspore/ccsrc/fl/server/distributed_metadata_store.h b/mindspore/ccsrc/fl/server/distributed_metadata_store.h index ddf6557f81..79dc5b3041 100644 --- a/mindspore/ccsrc/fl/server/distributed_metadata_store.h +++ b/mindspore/ccsrc/fl/server/distributed_metadata_store.h @@ -55,8 +55,8 @@ class DistributedMetadataStore { // Reset the metadata value for the name. void ResetMetadata(const std::string &name); - // Update the metadata for the name. Parameter 'reason' is the reason why updating meta data failed. - bool UpdateMetadata(const std::string &name, const PBMetadata &meta, std::string *reason = nullptr); + // Update the metadata for the name. + bool UpdateMetadata(const std::string &name, const PBMetadata &meta); // Get the metadata for the name. PBMetadata GetMetadata(const std::string &name); diff --git a/mindspore/ccsrc/fl/server/kernel/round/get_list_sign_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/get_list_sign_kernel.cc index d4bdf6ddbe..e737b1cf67 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/get_list_sign_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/get_list_sign_kernel.cc @@ -153,9 +153,8 @@ bool GetListSignKernel::Launch(const uint8_t *req_data, size_t len, SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); return true; } - std::string count_reason = ""; - if (!DistributedCountService::GetInstance().Count(name_, fl_id, &count_reason)) { - std::string reason = "Counting for get list sign request failed. Please retry later. " + count_reason; + if (!DistributedCountService::GetInstance().Count(name_, fl_id)) { + std::string reason = "Counting for get list sign request failed for fl id " + fl_id + ". Please retry later. "; BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs); MS_LOG(ERROR) << reason; diff --git a/mindspore/ccsrc/fl/server/kernel/round/push_list_sign_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/push_list_sign_kernel.cc index 9e8e7d2b6a..1a958584aa 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/push_list_sign_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/push_list_sign_kernel.cc @@ -128,9 +128,8 @@ bool PushListSignKernel::LaunchForPushListSign(const schema::SendClientListSign SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); return true; } - std::string count_reason = ""; - if (!DistributedCountService::GetInstance().Count(name_, fl_id, &count_reason)) { - std::string reason = "Counting for push list sign request failed. Please retry later. " + count_reason; + if (!DistributedCountService::GetInstance().Count(name_, fl_id)) { + std::string reason = "Counting for push list sign request failed for fl id " + fl_id + ". Please retry later."; BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()), iter_num); MS_LOG(ERROR) << reason; diff --git a/mindspore/ccsrc/fl/server/kernel/round/push_metrics_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/push_metrics_kernel.cc index b512d5ad5d..7702e82f9a 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/push_metrics_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/push_metrics_kernel.cc @@ -85,8 +85,7 @@ ResultCode PushMetricsKernel::PushMetrics(const std::shared_ptr &fbb, Iteration::GetInstance().set_loss(loss); Iteration::GetInstance().set_accuracy(accuracy); - std::string count_reason = ""; - if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_), &count_reason)) { + if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_))) { std::string reason = "Count for push metrics request failed."; BuildPushMetricsRsp(fbb, schema::ResponseCode_SystemError); MS_LOG(ERROR) << reason; diff --git a/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc index 83ef56a77b..283af6ee78 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc @@ -110,8 +110,7 @@ ResultCode PushWeightKernel::PushWeight(const std::shared_ptr &fbb, } MS_LOG(INFO) << "Pushing weight for iteration " << current_iter << " succeeds."; - std::string count_reason = ""; - if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_), &count_reason)) { + if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_))) { std::string reason = "Count for push weight request failed."; BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter); MS_LOG(ERROR) << reason; diff --git a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc index bb5915f09e..47bd8050dd 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc @@ -107,8 +107,7 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len, } PBMetadata metadata; *metadata.mutable_device_meta() = device_meta; - std::string update_reason = ""; - if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata, &update_reason)) { + if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata)) { std::string reason = "Updating device metadata failed for fl id " + device_meta.fl_id(); BuildStartFLJobRsp( fbb, schema::ResponseCode_OutOfTime, reason, false, @@ -116,7 +115,6 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len, SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); return false; } - // If calling ReportCount before ReadyForStartFLJob, the result will be inconsistent if the device is not selected. result_code = CountForStartFLJob(fbb, start_fl_job_req); if (result_code != ResultCode::kSuccess) { @@ -299,8 +297,7 @@ ResultCode StartFLJobKernel::CountForStartFLJob(const std::shared_ptr MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req, ResultCode::kFail); MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req->fl_id(), ResultCode::kFail); - std::string count_reason = ""; - if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str(), &count_reason)) { + if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str())) { std::string reason = "Counting start fl job request failed for fl id " + start_fl_job_req->fl_id()->str() + ", Please retry later."; BuildStartFLJobRsp( diff --git a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc index 631c89cedf..1749f5be5c 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc @@ -365,8 +365,7 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda fl_id.set_fl_id(update_model_fl_id); PBMetadata comm_value; *comm_value.mutable_fl_id() = fl_id; - std::string update_reason = ""; - if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value, &update_reason)) { + if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value)) { std::string reason = "Updating metadata of UpdateModelClientList failed for fl id " + update_model_fl_id; BuildUpdateModelRsp( fbb, schema::ResponseCode_OutOfTime, reason, @@ -526,9 +525,8 @@ std::map UpdateModelKernel::DecodeFeatureMap( } ResultCode UpdateModelKernel::CountForAggregation(const std::string &req_fl_id) { - std::string count_reason = ""; - if (!DistributedCountService::GetInstance().Count(kCountForAggregation, req_fl_id, &count_reason)) { - MS_LOG(ERROR) << "Counting for aggregation failed. reason: " + count_reason; + if (!DistributedCountService::GetInstance().Count(kCountForAggregation, req_fl_id)) { + MS_LOG(ERROR) << "Counting for aggregation failed for fl id " << req_fl_id; return ResultCode::kFail; } return ResultCode::kSuccess; @@ -538,10 +536,9 @@ ResultCode UpdateModelKernel::CountForUpdateModel(const std::shared_ptrfl_id()->str(), &count_reason)) { + if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str())) { std::string reason = "Counting for update model request failed for fl id " + update_model_req->fl_id()->str() + - ", Please retry later. " + count_reason; + ", Please retry later."; BuildUpdateModelRsp( fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp))); diff --git a/mindspore/ccsrc/ps/constants.h b/mindspore/ccsrc/ps/constants.h index 06ec18c5e3..b57eb318be 100644 --- a/mindspore/ccsrc/ps/constants.h +++ b/mindspore/ccsrc/ps/constants.h @@ -142,6 +142,7 @@ constexpr char kRecoveryTotalNodeNum[] = "total_node_num"; constexpr char kRecoveryNextWorkerRankId[] = "next_worker_rank_id"; constexpr char kRecoveryNextServerRankId[] = "next_server_rank_id"; constexpr char kRecoveryRegisteredNodesInfos[] = "node_ids"; +constexpr char kRecoveryClusterState[] = "cluster_state"; constexpr char kServerCertPath[] = "server_cert_path"; constexpr char kServerPassword[] = "server_password"; diff --git a/mindspore/ccsrc/ps/core/cluster_config.h b/mindspore/ccsrc/ps/core/cluster_config.h index 7e325c994e..70ec6797a1 100644 --- a/mindspore/ccsrc/ps/core/cluster_config.h +++ b/mindspore/ccsrc/ps/core/cluster_config.h @@ -46,7 +46,8 @@ struct ClusterConfig { scheduler_timeout(30), initial_total_node_num(0), initial_next_worker_rank_id(0), - initial_next_server_rank_id(0) {} + initial_next_server_rank_id(0), + initial_cluster_state(ClusterState::CLUSTER_STARTING) {} // Configure through environment variables:MS_WORKER_NUM uint32_t initial_worker_num; // Configure through environment variables:MS_SERVER_NUM @@ -72,6 +73,7 @@ struct ClusterConfig { uint32_t initial_total_node_num; uint32_t initial_next_worker_rank_id; uint32_t initial_next_server_rank_id; + ClusterState initial_cluster_state; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/comm_util.h b/mindspore/ccsrc/ps/core/comm_util.h index 7831ddad8c..b564725d7a 100644 --- a/mindspore/ccsrc/ps/core/comm_util.h +++ b/mindspore/ccsrc/ps/core/comm_util.h @@ -58,6 +58,7 @@ #include #include #include +#include #include #include "proto/comm.pb.h" @@ -99,6 +100,19 @@ const std::vector kClusterState = { "CLUSTER_SCALE_OUT_ROLLBACK", // When the cluster is scale out rollback. }; +const std::map kClusterStateMap = { + {"CLUSTER_STARTING", ClusterState::CLUSTER_STARTING}, + {"CLUSTER_READY", ClusterState::CLUSTER_READY}, + {"CLUSTER_EXIT", ClusterState::CLUSTER_EXIT}, + {"NODE_TIMEOUT", ClusterState::NODE_TIMEOUT}, + {"CLUSTER_SCALE_OUT", ClusterState::CLUSTER_SCALE_OUT}, + {"CLUSTER_SCALE_IN", ClusterState::CLUSTER_SCALE_IN}, + {"CLUSTER_NEW_INSTANCE", ClusterState::CLUSTER_NEW_INSTANCE}, + {"CLUSTER_ENABLE_FLS", ClusterState::CLUSTER_ENABLE_FLS}, + {"CLUSTER_DISABLE_FLS", ClusterState::CLUSTER_DISABLE_FLS}, + {"CLUSTER_SCHEDULER_RECOVERY", ClusterState::CLUSTER_SCHEDULER_RECOVERY}, + {"CLUSTER_SCALE_OUT_ROLLBACK", ClusterState::CLUSTER_SCALE_OUT_ROLLBACK}}; + class CommUtil { public: static bool CheckIpWithRegex(const std::string &ip); diff --git a/mindspore/ccsrc/ps/core/file_configuration.cc b/mindspore/ccsrc/ps/core/file_configuration.cc index d91e376e4d..7c81fd0abd 100644 --- a/mindspore/ccsrc/ps/core/file_configuration.cc +++ b/mindspore/ccsrc/ps/core/file_configuration.cc @@ -117,7 +117,6 @@ void FileConfiguration::PersistNodes(const core::ClusterConfig &clusterConfig) c res["node_id"] = node_info.node_id_; res["rank_id"] = std::to_string(node_info.rank_id_); res["role"] = CommUtil::NodeRoleToString(node_info.node_role_); - res["alive"] = CommUtil::BoolToString(node_info.is_alive); persist_js["node_ids"].push_back(res); } @@ -138,6 +137,7 @@ void FileConfiguration::PersistFile(const core::ClusterConfig &clusterConfig) co persist_js[kRecoveryServerNum] = clusterConfig.initial_server_num; persist_js[kRecoverySchedulerIp] = clusterConfig.scheduler_host; persist_js[kRecoverySchedulerPort] = clusterConfig.scheduler_port; + persist_js[kRecoveryClusterState] = CommUtil::ClusterStateToString(clusterConfig.initial_cluster_state); std::ofstream output_file(file_path_); output_file << persist_js.dump(); diff --git a/mindspore/ccsrc/ps/core/follower_scaler.cc b/mindspore/ccsrc/ps/core/follower_scaler.cc index d290813abb..46b03bb5af 100644 --- a/mindspore/ccsrc/ps/core/follower_scaler.cc +++ b/mindspore/ccsrc/ps/core/follower_scaler.cc @@ -225,6 +225,8 @@ std::string FollowerScaler::GetNodeScaleStateStr() { return "kWaiting"; case NodeScaleState::kScaling: return "kScaling"; + case NodeScaleState::kRollback: + return "kRollback"; default: MS_LOG(EXCEPTION) << "scale_state is not supported."; } diff --git a/mindspore/ccsrc/ps/core/instance_manager.cc b/mindspore/ccsrc/ps/core/instance_manager.cc index ce0a04b831..e1f4a575a1 100644 --- a/mindspore/ccsrc/ps/core/instance_manager.cc +++ b/mindspore/ccsrc/ps/core/instance_manager.cc @@ -35,7 +35,7 @@ void InstanceManager::NewInstanceAsync(const std::shared_ptr &client, MS_LOG(WARNING) << "Send new instance timeout!"; } - MS_LOG(INFO) << "The scheduler is sending new instance to workers and servers!"; + MS_LOG(INFO) << "The scheduler is sending new instance to " << node_info.node_id_; } void InstanceManager::QueryInstanceAsync(const std::shared_ptr &client, const NodeManager &, @@ -55,7 +55,7 @@ void InstanceManager::QueryInstanceAsync(const std::shared_ptr &clien MS_LOG(WARNING) << "Send query instance timeout!"; } - MS_LOG(INFO) << "The scheduler is sending query instance to workers and servers!"; + MS_LOG(INFO) << "The scheduler is sending query instance to " << node_info.node_id_; } void InstanceManager::EnableFLSAsync(const std::shared_ptr &client, const NodeManager &, @@ -75,7 +75,7 @@ void InstanceManager::EnableFLSAsync(const std::shared_ptr &client, c MS_LOG(WARNING) << "Send query instance timeout!"; } - MS_LOG(INFO) << "The scheduler is sending query instance to workers and servers!"; + MS_LOG(INFO) << "The scheduler is sending enable FLS to " << node_info.node_id_; } void InstanceManager::DisableFLSAsync(const std::shared_ptr &client, const NodeManager &, @@ -95,7 +95,7 @@ void InstanceManager::DisableFLSAsync(const std::shared_ptr &client, MS_LOG(WARNING) << "Send query instance timeout!"; } - MS_LOG(INFO) << "The scheduler is sending query instance to workers and servers!"; + MS_LOG(INFO) << "The scheduler is sending disable FLS to " << node_info.node_id_; } void InstanceManager::QueryNodeScaleState(const std::shared_ptr &client, const NodeManager &, @@ -115,7 +115,7 @@ void InstanceManager::QueryNodeScaleState(const std::shared_ptr &clie MS_LOG(WARNING) << "Send query node scale state timeout!"; } - MS_LOG(INFO) << "The scheduler is sending query node scale state to workers and servers!"; + MS_LOG(INFO) << "The scheduler is sending query node scale state to " << node_info.node_id_; } } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/node.h b/mindspore/ccsrc/ps/core/node.h index adaf011c5e..bc111906ee 100644 --- a/mindspore/ccsrc/ps/core/node.h +++ b/mindspore/ccsrc/ps/core/node.h @@ -54,7 +54,6 @@ class BACKEND_EXPORT Node { is_already_stopped_(true), is_already_finished_(false), next_request_id_(0), - current_node_state_(NodeState::NODE_STARTING), current_cluster_state_(ClusterState::CLUSTER_STARTING) {} virtual ~Node() = default; @@ -113,8 +112,6 @@ class BACKEND_EXPORT Node { std::mutex message_tracker_mutex_; std::condition_variable message_tracker_cond_; - // Worker and server receive the node state and cluster state from the scheduler. - NodeState current_node_state_; ClusterState current_cluster_state_; // Configuration file,The format is as follows diff --git a/mindspore/ccsrc/ps/core/node_manager.cc b/mindspore/ccsrc/ps/core/node_manager.cc index 7095e21dd5..6e9c668d19 100644 --- a/mindspore/ccsrc/ps/core/node_manager.cc +++ b/mindspore/ccsrc/ps/core/node_manager.cc @@ -70,8 +70,7 @@ uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage ®ister_message registered_nodes_info_[node_id] = recovery_node_infos[node_id]; MS_LOG(INFO) << "The node id: " << node_id << " is recovery successful!" << ", ip: " << recovery_node_infos[node_id].ip_ << ", port: " << recovery_node_infos[node_id].port_ - << ", rank id: " << rank_id << ", alive: " << recovery_node_infos[node_id].is_alive - << ", fl iteration num: " << new_fl_iteration_num + << ", rank id: " << rank_id << ", fl iteration num: " << new_fl_iteration_num << ", the node_role:" << CommUtil::NodeRoleToString(recovery_node_infos[node_id].node_role_); return rank_id; } @@ -235,12 +234,15 @@ void NodeManager::UpdateCluster(bool is_cluster_ready) { timeout_nodes_info_[it->first] = registered_nodes_info_[it->first]; registered_nodes_info_[it->first].is_alive = false; } + } else { + if (registered_nodes_info_.count(it->first) && !registered_nodes_info_[it->first].is_alive) { + MS_LOG(WARNING) << registered_nodes_info_[it->first].node_id_ << " is alive."; + registered_nodes_info_[it->first].is_alive = true; + } } } if (!timeout_nodes_info_.empty()) { - UpdateClusterState(ClusterState::NODE_TIMEOUT); - auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); if (!context_ptr->get_param(MS_CTX_ENABLE_RECOVERY)) { @@ -249,25 +251,14 @@ void NodeManager::UpdateCluster(bool is_cluster_ready) { finish_nodes_id_.insert(iter->first); } } - if (onPersist_) { - onPersist_(); + if (cluster_state_ != ClusterState::CLUSTER_DISABLE_FLS) { + UpdateClusterState(ClusterState::NODE_TIMEOUT); } - } else if (SizeToUint(heartbeats_.size()) == total_node_num_) { - if (cluster_state_ == ClusterState::NODE_TIMEOUT) { - for (auto it = registered_nodes_info_.begin(); it != registered_nodes_info_.end(); ++it) { - if (registered_nodes_info_.count(it->first) && !it->second.is_alive) { - MS_LOG(WARNING) << it->second.node_id_ << " is alive."; - it->second.is_alive = true; - } - } - if (onPersist_) { - onPersist_(); - } - if (is_cluster_ready) { - UpdateClusterState(ClusterState::CLUSTER_READY); - } else { - UpdateClusterState(ClusterState::CLUSTER_STARTING); - } + } else if (SizeToUint(heartbeats_.size()) == total_node_num_ && cluster_state_ == ClusterState::NODE_TIMEOUT) { + if (is_cluster_ready) { + UpdateClusterState(ClusterState::CLUSTER_READY); + } else { + UpdateClusterState(ClusterState::CLUSTER_STARTING); } } @@ -324,11 +315,6 @@ void NodeManager::UpdateNodesInfo() { nodes_info_ = registered_nodes_info_; } -void NodeManager::UpdateNodeState(const NodeState &state) { - std::lock_guard lk(node_mutex_); - node_state_ = state; -} - void NodeManager::UpdateClusterState(const ClusterState &state) { std::lock_guard lk(cluster_mutex_); std::string state_str = CommUtil::ClusterStateToString(state); @@ -340,11 +326,6 @@ void NodeManager::UpdateClusterState(const ClusterState &state) { cluster_state_ = state; } -NodeState NodeManager::GetNodeState() { - std::lock_guard lk(node_mutex_); - return node_state_; -} - ClusterState NodeManager::GetClusterState() { std::lock_guard lk(cluster_mutex_); return cluster_state_; diff --git a/mindspore/ccsrc/ps/core/node_manager.h b/mindspore/ccsrc/ps/core/node_manager.h index 523f9761f5..afb2501181 100644 --- a/mindspore/ccsrc/ps/core/node_manager.h +++ b/mindspore/ccsrc/ps/core/node_manager.h @@ -49,7 +49,6 @@ class NodeManager { next_worker_rank_id_(0), next_server_rank_id_(0), meta_data_(nullptr), - node_state_(NodeState::NODE_STARTING), cluster_state_(ClusterState::CLUSTER_STARTING) {} virtual ~NodeManager() = default; using OnPersist = std::function; @@ -104,9 +103,7 @@ class NodeManager { uint32_t next_worker_rank_id() const; uint32_t next_server_rank_id() const; - void UpdateNodeState(const NodeState &state); void UpdateClusterState(const ClusterState &state); - NodeState GetNodeState(); ClusterState GetClusterState(); // When the scheduler receives the scale out or scale in message, the metadata needs to be reset, because all nodes @@ -179,7 +176,6 @@ class NodeManager { // Cluster metadata information can be dynamically changed std::unique_ptr meta_data_; - NodeState node_state_; ClusterState cluster_state_; std::deque recovery_worker_rank_id_; diff --git a/mindspore/ccsrc/ps/core/scheduler_node.cc b/mindspore/ccsrc/ps/core/scheduler_node.cc index 8b825a8bc0..2d88c254f8 100644 --- a/mindspore/ccsrc/ps/core/scheduler_node.cc +++ b/mindspore/ccsrc/ps/core/scheduler_node.cc @@ -57,7 +57,6 @@ bool SchedulerNode::Start(const uint32_t &timeout) { MS_LOG(ERROR) << "Start Scheduler node timeout!"; return false; } - node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY); StartUpdatePersistentCommandTimer(); MS_LOG(INFO) << "[Scheduler start]: 4. Successfully start scheduler, there are " << node_manager_.worker_num() @@ -85,7 +84,26 @@ void SchedulerNode::RunRecovery() { MS_LOG(WARNING) << "There is no registered nodes in scheduler!"; return; } - MS_LOG(INFO) << "The scheduler start run recovery!"; + MS_LOG(INFO) << "The scheduler start run recovery!" + << " The worker num:" << clusterConfig.initial_worker_num + << ", the server num:" << clusterConfig.initial_server_num + << ", the scheduler ip:" << clusterConfig.scheduler_host + << ", the scheduler port:" << clusterConfig.scheduler_port + << ", the initial total node num:" << clusterConfig.initial_total_node_num + << ", the initial next worker rank id:" << clusterConfig.initial_next_worker_rank_id + << ", the initial next server rank id:" << clusterConfig.initial_next_server_rank_id + << ", the initial cluster state:" << kClusterState.at(clusterConfig.initial_cluster_state); + + if (!clusterConfig.initial_registered_nodes_infos.empty()) { + for (const auto kvs : clusterConfig.initial_registered_nodes_infos) { + MS_LOG(INFO) << "The ip:" << kvs.second.ip_ << ", the port:" << kvs.second.port_ + << ", the node_id:" << kvs.second.node_id_ + << ", the node_role:" << CommUtil::NodeRoleToString(kvs.second.node_role_) + << ", the rank_id_:" << kvs.second.rank_id_ + << ", the is_alive:" << CommUtil::BoolToString(kvs.second.is_alive); + } + } + uint32_t worker_num = clusterConfig.initial_worker_num; uint32_t server_num = clusterConfig.initial_server_num; @@ -94,6 +112,11 @@ void SchedulerNode::RunRecovery() { node_manager_.set_next_worker_rank_id(clusterConfig.initial_next_worker_rank_id); node_manager_.set_next_server_rank_id(clusterConfig.initial_next_server_rank_id); node_manager_.set_total_node_num(clusterConfig.initial_total_node_num); + if (clusterConfig.initial_cluster_state == ClusterState::CLUSTER_DISABLE_FLS) { + MS_LOG(WARNING) << "Scheduler recover and update cluster state from recovery file, cluster state is " + << CommUtil::ClusterStateToString(clusterConfig.initial_cluster_state); + node_manager_.UpdateClusterState(clusterConfig.initial_cluster_state); + } for (const auto &kvs : initial_node_infos) { auto &node_id = kvs.first; @@ -352,44 +375,56 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr &server, "will exit later."; return; } + if (!BuildingNetwork()) { + MS_LOG(ERROR) << "Building network failed! Cluster will exit later."; + } + } +} - if (node_manager_.GetClusterState() == ClusterState::CLUSTER_SCALE_IN) { - auto nodes = node_manager_.nodes_info(); - for (const auto &id : scale_in_node_ids_) { - MS_LOG(INFO) << "The scheduler send metadata to scale in node:" << id; - if (nodes.count(id)) { - auto scale_in_client = GetOrCreateClient(nodes[id]); - SendMetadata(scale_in_client, nodes[id].rank_id_); - node_manager_.UpdateHeartbeat(id); - } - if (connected_nodes_.count(id)) { - MS_LOG(INFO) << "remove scale in node id: " << id << " connection."; - connected_nodes_.erase(id); - } +bool SchedulerNode::BuildingNetwork() { + if (node_manager_.GetClusterState() == ClusterState::CLUSTER_SCALE_IN) { + auto nodes = node_manager_.nodes_info(); + for (const auto &id : scale_in_node_ids_) { + MS_LOG(INFO) << "The scheduler send metadata to scale in node:" << id; + if (nodes.count(id)) { + auto scale_in_client = GetOrCreateClient(nodes[id]); + SendMetadata(scale_in_client, nodes[id].rank_id_); + node_manager_.UpdateHeartbeat(id); + } + if (connected_nodes_.count(id)) { + MS_LOG(INFO) << "remove scale in node id: " << id << " connection."; + connected_nodes_.erase(id); } } - node_manager_.UpdateNodesInfo(); - auto node_infos = node_manager_.nodes_info(); - bool res = SendPrepareBuildingNetwork(node_infos); - if (!res) { - MS_LOG(ERROR) << "Prepare for building network failed! Cluster will exit later."; - return; - } - is_ready_ = true; - MS_LOG(INFO) << "Prepare for building network success. There are " << node_manager_.worker_num() << " workers and " - << node_manager_.server_num() - << " servers registered to scheduer, so the scheduler send meta data to worker/server."; + } + node_manager_.UpdateNodesInfo(); + auto node_infos = node_manager_.nodes_info(); + bool res = SendPrepareBuildingNetwork(node_infos); + if (!res) { + MS_LOG(ERROR) << "Prepare for building network failed!"; + return false; + } + is_ready_ = true; + MS_LOG(INFO) << "Prepare for building network success. There are " << node_manager_.worker_num() << " workers and " + << node_manager_.server_num() + << " servers registered to scheduer, so the scheduler send meta data to worker/server."; - for (const auto &kvs : node_infos) { - auto client = GetOrCreateClient(kvs.second); - MS_EXCEPTION_IF_NULL(client); - SendMetadata(client, kvs.second.rank_id_); - node_manager_.UpdateHeartbeat(kvs.first); - } + for (const auto &kvs : node_infos) { + auto client = GetOrCreateClient(kvs.second); + MS_EXCEPTION_IF_NULL(client); + SendMetadata(client, kvs.second.rank_id_); + node_manager_.UpdateHeartbeat(kvs.first); + } + + if (node_manager_.GetClusterState() == ClusterState::CLUSTER_DISABLE_FLS) { + MS_LOG(WARNING) + << "Cluster state is CLUSTER_DISABLE_FLS, do not need to change to CLUSTER_READY when building network."; + } else { node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY); - PersistMetaData(); - wait_start_cond_.notify_all(); } + PersistMetaData(); + wait_start_cond_.notify_all(); + return true; } void SchedulerNode::ProcessFinish(const std::shared_ptr &server, const std::shared_ptr &conn, @@ -897,7 +932,7 @@ bool SchedulerNode::QueryNodeScaleState(const std::shared_ptrQueryNodeScaleState(client, node_manager_, request_id, node_info_); + instance_manager_->QueryNodeScaleState(client, node_manager_, request_id, kvs.second); } } @@ -1179,7 +1214,7 @@ void SchedulerNode::ProcessNewInstance(const std::shared_ptr auto client = GetOrCreateClient(kvs.second); MS_EXCEPTION_IF_NULL(client); MS_EXCEPTION_IF_NULL(instance_manager_); - instance_manager_->NewInstanceAsync(client, node_manager_, body, request_id, node_info_); + instance_manager_->NewInstanceAsync(client, node_manager_, body, request_id, kvs.second); } } bool res = Wait(request_id); @@ -1246,7 +1281,7 @@ void SchedulerNode::ProcessQueryInstance(const std::shared_ptrQueryInstanceAsync(client, node_manager_, request_id, node_info_); + instance_manager_->QueryInstanceAsync(client, node_manager_, request_id, kvs.second); } } bool res = Wait(request_id); @@ -1308,7 +1343,7 @@ void SchedulerNode::ProcessEnableFLS(const std::shared_ptr & auto client = GetOrCreateClient(kvs.second); MS_EXCEPTION_IF_NULL(client); MS_EXCEPTION_IF_NULL(instance_manager_); - instance_manager_->EnableFLSAsync(client, node_manager_, request_id, node_info_); + instance_manager_->EnableFLSAsync(client, node_manager_, request_id, kvs.second); } } bool res = Wait(request_id); @@ -1334,6 +1369,7 @@ void SchedulerNode::ProcessEnableFLS(const std::shared_ptr & js["code"] = kSuccessCode; js["result"] = true; node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY); + PersistMetaData(); } else { js["message"] = "start enabling FL-Server failed."; js["code"] = kErrorCode; @@ -1379,7 +1415,7 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr auto client = GetOrCreateClient(kvs.second); MS_EXCEPTION_IF_NULL(client); MS_EXCEPTION_IF_NULL(instance_manager_); - instance_manager_->DisableFLSAsync(client, node_manager_, request_id, node_info_); + instance_manager_->DisableFLSAsync(client, node_manager_, request_id, kvs.second); } } bool res = Wait(request_id); @@ -1404,6 +1440,7 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr js["code"] = kSuccessCode; js["result"] = true; node_manager_.UpdateClusterState(ClusterState::CLUSTER_DISABLE_FLS); + PersistMetaData(); } else { js["message"] = "start disabling FL-Server failed."; js["code"] = kErrorCode; @@ -1565,6 +1602,7 @@ void SchedulerNode::PersistMetaData() { return; } if (!is_ready_) { + MS_LOG(WARNING) << "Cluster is not building network successful, do not persist meta data"; return; } if (config_->Exists(kKeyRecovery)) { @@ -1576,6 +1614,7 @@ void SchedulerNode::PersistMetaData() { clusterConfig.initial_next_server_rank_id = node_manager_.next_server_rank_id(); clusterConfig.initial_registered_nodes_infos.clear(); clusterConfig.initial_registered_nodes_infos = node_manager_.registered_nodes_info(); + clusterConfig.initial_cluster_state = node_manager_.GetClusterState(); scheduler_recovery_->Persist(clusterConfig); scheduler_recovery_->PersistNodesInfo(clusterConfig); diff --git a/mindspore/ccsrc/ps/core/scheduler_node.h b/mindspore/ccsrc/ps/core/scheduler_node.h index 66f5963919..757c6dc910 100644 --- a/mindspore/ccsrc/ps/core/scheduler_node.h +++ b/mindspore/ccsrc/ps/core/scheduler_node.h @@ -217,6 +217,8 @@ class BACKEND_EXPORT SchedulerNode : public Node { void GeneralResponse(const std::shared_ptr &server, const std::shared_ptr &conn, const std::shared_ptr &meta, bool is_success, const std::string &error); + bool BuildingNetwork(); + std::shared_ptr server_; std::unique_ptr scheduler_thread_; std::unique_ptr update_state_thread_; diff --git a/mindspore/ccsrc/ps/core/scheduler_recovery.cc b/mindspore/ccsrc/ps/core/scheduler_recovery.cc index cc450d5f9b..a484a5a2c8 100644 --- a/mindspore/ccsrc/ps/core/scheduler_recovery.cc +++ b/mindspore/ccsrc/ps/core/scheduler_recovery.cc @@ -15,6 +15,7 @@ */ #include "ps/core/scheduler_recovery.h" +#include "ps/core/comm_util.h" namespace mindspore { namespace ps { @@ -63,11 +64,6 @@ bool SchedulerRecovery::Recover() { MS_LOG(EXCEPTION) << kRecoverySchedulerPort << " is not contained in " << recovery_storage_->file_path(); } - MS_LOG(INFO) << "The worker num:" << clusterConfig.initial_worker_num - << ", the server num:" << clusterConfig.initial_server_num - << ", the scheduler ip:" << clusterConfig.scheduler_host - << ", the scheduler port:" << clusterConfig.scheduler_port; - MS_ERROR_IF_NULL_W_RET_VAL(scheduler_recovery_storage_, false); // 5. recover total node num if (scheduler_recovery_storage_->Exists(kRecoveryTotalNodeNum)) { @@ -110,7 +106,6 @@ bool SchedulerRecovery::Recover() { node_info.port_ = static_cast(std::strtol(port.c_str(), nullptr, kBase)); node_info.node_id_ = elem.at("node_id"); node_info.rank_id_ = UlongToUint(std::strtoul(rank_id.c_str(), nullptr, kBase)); - node_info.is_alive = CommUtil::StringToBool(elem.at("alive")); node_info.node_role_ = CommUtil::StringToNodeRole(elem.at("role")); nodes_infos[node_info.node_id_] = node_info; @@ -127,18 +122,11 @@ bool SchedulerRecovery::Recover() { MS_LOG(EXCEPTION) << kRecoveryRegisteredNodesInfos << " is not contained in " << recovery_storage_->file_path(); } - MS_LOG(INFO) << ", the initial total node num:" << clusterConfig.initial_total_node_num - << ", the initial next worker rank id:" << clusterConfig.initial_next_worker_rank_id - << ", the initial next server rank id:" << clusterConfig.initial_next_server_rank_id; - - if (!clusterConfig.initial_registered_nodes_infos.empty()) { - for (const auto kvs : clusterConfig.initial_registered_nodes_infos) { - MS_LOG(INFO) << "The ip:" << kvs.second.ip_ << ", the port:" << kvs.second.port_ - << ", the node_id:" << kvs.second.node_id_ - << ", the node_role:" << CommUtil::NodeRoleToString(kvs.second.node_role_) - << ", the rank_id_:" << kvs.second.rank_id_ - << ", the is_alive:" << CommUtil::BoolToString(kvs.second.is_alive); - } + // 9. recover cluster state + if (recovery_storage_->Exists(kRecoveryClusterState)) { + clusterConfig.initial_cluster_state = kClusterStateMap.at(recovery_storage_->GetString(kRecoveryClusterState, "")); + } else { + MS_LOG(EXCEPTION) << kRecoveryClusterState << " is not contained in " << recovery_storage_->file_path(); } return true; } diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc index 05e5de6388..8f8fc1d66f 100644 --- a/mindspore/ccsrc/ps/ps_context.cc +++ b/mindspore/ccsrc/ps/ps_context.cc @@ -209,8 +209,9 @@ void PSContext::set_rank_id(uint32_t rank_id) const { } void PSContext::set_server_mode(const std::string &server_mode) { - if (server_mode != kServerModePS && server_mode != kServerModeFL) { - MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL; + if (server_mode != kServerModePS && server_mode != kServerModeFL && server_mode != kServerModeHybrid) { + MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL + << " or " << kServerModeHybrid; return; } MS_LOG(INFO) << "Server mode: " << server_mode << " is used for Server and Worker. Scheduler will ignore it.";