Browse Source

fix issues I51DN2, I51DK0 and open hybrid train mode

r1.7
twc 3 years ago
parent
commit
485b5259ab
24 changed files with 165 additions and 168 deletions
  1. +1
    -0
      mindspore/ccsrc/fl/server/common.h
  2. +14
    -27
      mindspore/ccsrc/fl/server/distributed_count_service.cc
  3. +5
    -6
      mindspore/ccsrc/fl/server/distributed_count_service.h
  4. +2
    -5
      mindspore/ccsrc/fl/server/distributed_metadata_store.cc
  5. +2
    -2
      mindspore/ccsrc/fl/server/distributed_metadata_store.h
  6. +2
    -3
      mindspore/ccsrc/fl/server/kernel/round/get_list_sign_kernel.cc
  7. +2
    -3
      mindspore/ccsrc/fl/server/kernel/round/push_list_sign_kernel.cc
  8. +1
    -2
      mindspore/ccsrc/fl/server/kernel/round/push_metrics_kernel.cc
  9. +1
    -2
      mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc
  10. +2
    -5
      mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc
  11. +5
    -8
      mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc
  12. +1
    -0
      mindspore/ccsrc/ps/constants.h
  13. +3
    -1
      mindspore/ccsrc/ps/core/cluster_config.h
  14. +14
    -0
      mindspore/ccsrc/ps/core/comm_util.h
  15. +1
    -1
      mindspore/ccsrc/ps/core/file_configuration.cc
  16. +2
    -0
      mindspore/ccsrc/ps/core/follower_scaler.cc
  17. +5
    -5
      mindspore/ccsrc/ps/core/instance_manager.cc
  18. +0
    -3
      mindspore/ccsrc/ps/core/node.h
  19. +13
    -32
      mindspore/ccsrc/ps/core/node_manager.cc
  20. +0
    -4
      mindspore/ccsrc/ps/core/node_manager.h
  21. +78
    -39
      mindspore/ccsrc/ps/core/scheduler_node.cc
  22. +2
    -0
      mindspore/ccsrc/ps/core/scheduler_node.h
  23. +6
    -18
      mindspore/ccsrc/ps/core/scheduler_recovery.cc
  24. +3
    -2
      mindspore/ccsrc/ps/ps_context.cc

+ 1
- 0
mindspore/ccsrc/fl/server/common.h View File

@@ -293,6 +293,7 @@ inline T JsonGetKeyWithException(const nlohmann::json &json, const std::string &
// Definitions for Federated Learning.

constexpr auto kNetworkError = "Cluster networking failed.";
constexpr auto KTriggerCounterEventError = "Cluster trigger counter event failed.";

// The result code used for round kernels.
enum class ResultCode {


+ 14
- 27
mindspore/ccsrc/fl/server/distributed_count_service.cc View File

@@ -86,7 +86,7 @@ bool DistributedCountService::ReInitCounter(const std::string &name, size_t glob
return true;
}

bool DistributedCountService::Count(const std::string &name, const std::string &id, std::string *reason) {
bool DistributedCountService::Count(const std::string &name, const std::string &id) {
MS_LOG(DEBUG) << "Rank " << local_rank_ << " reports count for " << name << " of " << id;
if (local_rank_ == counting_server_rank_) {
if (global_threshold_count_.count(name) == 0) {
@@ -107,9 +107,9 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
MS_LOG(INFO) << "Global current count for " << name << " is: " << global_current_count_[name].size() << "/"
<< global_threshold_count_[name];
}
if (!TriggerCounterEvent(name, reason)) {
if (!TriggerCounterEvent(name)) {
MS_LOG(WARNING) << "Leader server trigger count event failed.";
Iteration::GetInstance().NotifyNext(false, *reason);
Iteration::GetInstance().NotifyNext(false, KTriggerCounterEventError);
return false;
}
} else {
@@ -121,10 +121,8 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, ps::core::TcpUserCommand::kCount,
&report_cnt_rsp_msg)) {
MS_LOG(WARNING) << "Sending reporting count message to leader server failed for " << name;
if (reason != nullptr) {
*reason = kNetworkError;
}
MS_LOG(WARNING) << "Sending reporting count " + name + " message to leader server failed for fl id " << id;
Iteration::GetInstance().NotifyNext(false, kNetworkError);
return false;
}

@@ -133,10 +131,6 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
(void)count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size()));
if (!count_rsp.result()) {
MS_LOG(WARNING) << "Reporting count failed:" << count_rsp.reason();
// If the error is caused by the network issue, return the reason.
if (reason != nullptr && count_rsp.reason().find(kNetworkError) != std::string::npos) {
*reason = kNetworkError;
}
return false;
}
}
@@ -263,9 +257,8 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core:
MS_LOG(WARNING) << "Sending response failed.";
return;
}
std::string reason = "success";
if (!TriggerCounterEvent(name, &reason)) {
Iteration::GetInstance().NotifyNext(false, reason);
if (!TriggerCounterEvent(name)) {
Iteration::GetInstance().NotifyNext(false, KTriggerCounterEventError);
}
}

@@ -322,7 +315,7 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core:
return;
}

bool DistributedCountService::TriggerCounterEvent(const std::string &name, std::string *reason) {
bool DistributedCountService::TriggerCounterEvent(const std::string &name) {
if (global_current_count_.count(name) == 0 || global_threshold_count_.count(name) == 0) {
MS_LOG(WARNING) << "The counter of " << name << " is not registered.";
return false;
@@ -332,19 +325,19 @@ bool DistributedCountService::TriggerCounterEvent(const std::string &name, std::
<< ", threshold count is " << global_threshold_count_[name];
// The threshold count may be 1 so the first and last count event should be both activated.
if (global_current_count_[name].size() == 1) {
if (!TriggerFirstCountEvent(name, reason)) {
if (!TriggerFirstCountEvent(name)) {
return false;
}
}
if (global_current_count_[name].size() == global_threshold_count_[name]) {
if (!TriggerLastCountEvent(name, reason)) {
if (!TriggerLastCountEvent(name)) {
return false;
}
}
return true;
}

bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, std::string *reason) {
bool DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
MS_LOG(DEBUG) << "Activating first count event for " << name;
CounterEvent first_count_event;
first_count_event.set_type(CounterEventType::FIRST_CNT);
@@ -354,10 +347,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st
for (uint32_t i = 1; i < server_num_; i++) {
MS_LOG(DEBUG) << "Start sending first count event message to server " << i;
if (!communicator_->SendPbRequest(first_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
MS_LOG(WARNING) << "Activating first count event to server " << i << " failed.";
if (reason != nullptr) {
*reason = kNetworkError;
}
MS_LOG(WARNING) << "Send activating first count event to server " << i << " failed.";
return false;
}
}
@@ -374,7 +364,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st
return true;
}

bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std::string *reason) {
bool DistributedCountService::TriggerLastCountEvent(const std::string &name) {
MS_LOG(DEBUG) << "Activating last count event for " << name;
CounterEvent last_count_event;
last_count_event.set_type(CounterEventType::LAST_CNT);
@@ -384,10 +374,7 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std
for (uint32_t i = 1; i < server_num_; i++) {
MS_LOG(DEBUG) << "Start sending last count event message to server " << i;
if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
MS_LOG(WARNING) << "Activating last count event to server " << i << " failed.";
if (reason != nullptr) {
*reason = kNetworkError;
}
MS_LOG(WARNING) << "Send activating last count event to server " << i << " failed.";
return false;
}
}


+ 5
- 6
mindspore/ccsrc/fl/server/distributed_count_service.h View File

@@ -67,9 +67,8 @@ class DistributedCountService {
// Reinitialize counter due to the change of threshold count.
bool ReInitCounter(const std::string &name, size_t global_threshold_count);

// Report a count to the counting server. Parameter 'id' is in case of repeated counting. Parameter 'reason' is the
// reason why counting failed.
bool Count(const std::string &name, const std::string &id, std::string *reason = nullptr);
// Report a count to the counting server. Parameter 'id' is in case of repeated counting.
bool Count(const std::string &name, const std::string &id);

// Query whether the count reaches the threshold count for the name. If the count is the same as the threshold count,
// this method returns true.
@@ -103,9 +102,9 @@ class DistributedCountService {
void HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message);

// Call the callbacks when the first/last count event is triggered.
bool TriggerCounterEvent(const std::string &name, std::string *reason = nullptr);
bool TriggerFirstCountEvent(const std::string &name, std::string *reason = nullptr);
bool TriggerLastCountEvent(const std::string &name, std::string *reason = nullptr);
bool TriggerCounterEvent(const std::string &name);
bool TriggerFirstCountEvent(const std::string &name);
bool TriggerLastCountEvent(const std::string &name);

// Members for the communication between counting server and other servers.
std::shared_ptr<ps::core::ServerNode> server_node_;


+ 2
- 5
mindspore/ccsrc/fl/server/distributed_metadata_store.cc View File

@@ -86,7 +86,7 @@ void DistributedMetadataStore::ResetMetadata(const std::string &name) {
return;
}

bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta, std::string *reason) {
bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) {
if (router_ == nullptr) {
MS_LOG(WARNING) << "The consistent hash ring is not initialized yet.";
return false;
@@ -107,10 +107,7 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM
if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, ps::core::TcpUserCommand::kUpdateMetadata,
&update_meta_rsp_msg)) {
MS_LOG(WARNING) << "Sending updating metadata message to server " << stored_rank << " failed.";
if (reason != nullptr) {
*reason = kNetworkError;
}
Iteration::GetInstance().NotifyNext(false, *reason);
Iteration::GetInstance().NotifyNext(false, kNetworkError);
return false;
}



+ 2
- 2
mindspore/ccsrc/fl/server/distributed_metadata_store.h View File

@@ -55,8 +55,8 @@ class DistributedMetadataStore {
// Reset the metadata value for the name.
void ResetMetadata(const std::string &name);

// Update the metadata for the name. Parameter 'reason' is the reason why updating meta data failed.
bool UpdateMetadata(const std::string &name, const PBMetadata &meta, std::string *reason = nullptr);
// Update the metadata for the name.
bool UpdateMetadata(const std::string &name, const PBMetadata &meta);

// Get the metadata for the name.
PBMetadata GetMetadata(const std::string &name);


+ 2
- 3
mindspore/ccsrc/fl/server/kernel/round/get_list_sign_kernel.cc View File

@@ -153,9 +153,8 @@ bool GetListSignKernel::Launch(const uint8_t *req_data, size_t len,
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
std::string count_reason = "";
if (!DistributedCountService::GetInstance().Count(name_, fl_id, &count_reason)) {
std::string reason = "Counting for get list sign request failed. Please retry later. " + count_reason;
if (!DistributedCountService::GetInstance().Count(name_, fl_id)) {
std::string reason = "Counting for get list sign request failed for fl id " + fl_id + ". Please retry later. ";
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()),
iter_num, list_signs);
MS_LOG(ERROR) << reason;


+ 2
- 3
mindspore/ccsrc/fl/server/kernel/round/push_list_sign_kernel.cc View File

@@ -128,9 +128,8 @@ bool PushListSignKernel::LaunchForPushListSign(const schema::SendClientListSign
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
std::string count_reason = "";
if (!DistributedCountService::GetInstance().Count(name_, fl_id, &count_reason)) {
std::string reason = "Counting for push list sign request failed. Please retry later. " + count_reason;
if (!DistributedCountService::GetInstance().Count(name_, fl_id)) {
std::string reason = "Counting for push list sign request failed for fl id " + fl_id + ". Please retry later.";
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()),
iter_num);
MS_LOG(ERROR) << reason;


+ 1
- 2
mindspore/ccsrc/fl/server/kernel/round/push_metrics_kernel.cc View File

@@ -85,8 +85,7 @@ ResultCode PushMetricsKernel::PushMetrics(const std::shared_ptr<FBBuilder> &fbb,
Iteration::GetInstance().set_loss(loss);
Iteration::GetInstance().set_accuracy(accuracy);

std::string count_reason = "";
if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_), &count_reason)) {
if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_))) {
std::string reason = "Count for push metrics request failed.";
BuildPushMetricsRsp(fbb, schema::ResponseCode_SystemError);
MS_LOG(ERROR) << reason;


+ 1
- 2
mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc View File

@@ -110,8 +110,7 @@ ResultCode PushWeightKernel::PushWeight(const std::shared_ptr<FBBuilder> &fbb,
}
MS_LOG(INFO) << "Pushing weight for iteration " << current_iter << " succeeds.";

std::string count_reason = "";
if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_), &count_reason)) {
if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_))) {
std::string reason = "Count for push weight request failed.";
BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter);
MS_LOG(ERROR) << reason;


+ 2
- 5
mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc View File

@@ -107,8 +107,7 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len,
}
PBMetadata metadata;
*metadata.mutable_device_meta() = device_meta;
std::string update_reason = "";
if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata, &update_reason)) {
if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata)) {
std::string reason = "Updating device metadata failed for fl id " + device_meta.fl_id();
BuildStartFLJobRsp(
fbb, schema::ResponseCode_OutOfTime, reason, false,
@@ -116,7 +115,6 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len,
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return false;
}

// If calling ReportCount before ReadyForStartFLJob, the result will be inconsistent if the device is not selected.
result_code = CountForStartFLJob(fbb, start_fl_job_req);
if (result_code != ResultCode::kSuccess) {
@@ -299,8 +297,7 @@ ResultCode StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder>
MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req, ResultCode::kFail);
MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req->fl_id(), ResultCode::kFail);

std::string count_reason = "";
if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str(), &count_reason)) {
if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str())) {
std::string reason =
"Counting start fl job request failed for fl id " + start_fl_job_req->fl_id()->str() + ", Please retry later.";
BuildStartFLJobRsp(


+ 5
- 8
mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc View File

@@ -365,8 +365,7 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda
fl_id.set_fl_id(update_model_fl_id);
PBMetadata comm_value;
*comm_value.mutable_fl_id() = fl_id;
std::string update_reason = "";
if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value, &update_reason)) {
if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value)) {
std::string reason = "Updating metadata of UpdateModelClientList failed for fl id " + update_model_fl_id;
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
@@ -526,9 +525,8 @@ std::map<std::string, UploadData> UpdateModelKernel::DecodeFeatureMap(
}

ResultCode UpdateModelKernel::CountForAggregation(const std::string &req_fl_id) {
std::string count_reason = "";
if (!DistributedCountService::GetInstance().Count(kCountForAggregation, req_fl_id, &count_reason)) {
MS_LOG(ERROR) << "Counting for aggregation failed. reason: " + count_reason;
if (!DistributedCountService::GetInstance().Count(kCountForAggregation, req_fl_id)) {
MS_LOG(ERROR) << "Counting for aggregation failed for fl id " << req_fl_id;
return ResultCode::kFail;
}
return ResultCode::kSuccess;
@@ -538,10 +536,9 @@ ResultCode UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilde
const schema::RequestUpdateModel *update_model_req) {
MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kFail);
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kFail);
std::string count_reason = "";
if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str(), &count_reason)) {
if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str())) {
std::string reason = "Counting for update model request failed for fl id " + update_model_req->fl_id()->str() +
", Please retry later. " + count_reason;
", Please retry later.";
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));


+ 1
- 0
mindspore/ccsrc/ps/constants.h View File

@@ -142,6 +142,7 @@ constexpr char kRecoveryTotalNodeNum[] = "total_node_num";
constexpr char kRecoveryNextWorkerRankId[] = "next_worker_rank_id";
constexpr char kRecoveryNextServerRankId[] = "next_server_rank_id";
constexpr char kRecoveryRegisteredNodesInfos[] = "node_ids";
constexpr char kRecoveryClusterState[] = "cluster_state";

constexpr char kServerCertPath[] = "server_cert_path";
constexpr char kServerPassword[] = "server_password";


+ 3
- 1
mindspore/ccsrc/ps/core/cluster_config.h View File

@@ -46,7 +46,8 @@ struct ClusterConfig {
scheduler_timeout(30),
initial_total_node_num(0),
initial_next_worker_rank_id(0),
initial_next_server_rank_id(0) {}
initial_next_server_rank_id(0),
initial_cluster_state(ClusterState::CLUSTER_STARTING) {}
// Configure through environment variables:MS_WORKER_NUM
uint32_t initial_worker_num;
// Configure through environment variables:MS_SERVER_NUM
@@ -72,6 +73,7 @@ struct ClusterConfig {
uint32_t initial_total_node_num;
uint32_t initial_next_worker_rank_id;
uint32_t initial_next_server_rank_id;
ClusterState initial_cluster_state;
};
} // namespace core
} // namespace ps


+ 14
- 0
mindspore/ccsrc/ps/core/comm_util.h View File

@@ -58,6 +58,7 @@
#include <fstream>
#include <iostream>
#include <vector>
#include <map>
#include <algorithm>

#include "proto/comm.pb.h"
@@ -99,6 +100,19 @@ const std::vector<std::string> kClusterState = {
"CLUSTER_SCALE_OUT_ROLLBACK", // When the cluster is scale out rollback.
};

const std::map<std::string, ClusterState> kClusterStateMap = {
{"CLUSTER_STARTING", ClusterState::CLUSTER_STARTING},
{"CLUSTER_READY", ClusterState::CLUSTER_READY},
{"CLUSTER_EXIT", ClusterState::CLUSTER_EXIT},
{"NODE_TIMEOUT", ClusterState::NODE_TIMEOUT},
{"CLUSTER_SCALE_OUT", ClusterState::CLUSTER_SCALE_OUT},
{"CLUSTER_SCALE_IN", ClusterState::CLUSTER_SCALE_IN},
{"CLUSTER_NEW_INSTANCE", ClusterState::CLUSTER_NEW_INSTANCE},
{"CLUSTER_ENABLE_FLS", ClusterState::CLUSTER_ENABLE_FLS},
{"CLUSTER_DISABLE_FLS", ClusterState::CLUSTER_DISABLE_FLS},
{"CLUSTER_SCHEDULER_RECOVERY", ClusterState::CLUSTER_SCHEDULER_RECOVERY},
{"CLUSTER_SCALE_OUT_ROLLBACK", ClusterState::CLUSTER_SCALE_OUT_ROLLBACK}};

class CommUtil {
public:
static bool CheckIpWithRegex(const std::string &ip);


+ 1
- 1
mindspore/ccsrc/ps/core/file_configuration.cc View File

@@ -117,7 +117,6 @@ void FileConfiguration::PersistNodes(const core::ClusterConfig &clusterConfig) c
res["node_id"] = node_info.node_id_;
res["rank_id"] = std::to_string(node_info.rank_id_);
res["role"] = CommUtil::NodeRoleToString(node_info.node_role_);
res["alive"] = CommUtil::BoolToString(node_info.is_alive);
persist_js["node_ids"].push_back(res);
}

@@ -138,6 +137,7 @@ void FileConfiguration::PersistFile(const core::ClusterConfig &clusterConfig) co
persist_js[kRecoveryServerNum] = clusterConfig.initial_server_num;
persist_js[kRecoverySchedulerIp] = clusterConfig.scheduler_host;
persist_js[kRecoverySchedulerPort] = clusterConfig.scheduler_port;
persist_js[kRecoveryClusterState] = CommUtil::ClusterStateToString(clusterConfig.initial_cluster_state);

std::ofstream output_file(file_path_);
output_file << persist_js.dump();


+ 2
- 0
mindspore/ccsrc/ps/core/follower_scaler.cc View File

@@ -225,6 +225,8 @@ std::string FollowerScaler::GetNodeScaleStateStr() {
return "kWaiting";
case NodeScaleState::kScaling:
return "kScaling";
case NodeScaleState::kRollback:
return "kRollback";
default:
MS_LOG(EXCEPTION) << "scale_state is not supported.";
}


+ 5
- 5
mindspore/ccsrc/ps/core/instance_manager.cc View File

@@ -35,7 +35,7 @@ void InstanceManager::NewInstanceAsync(const std::shared_ptr<TcpClient> &client,
MS_LOG(WARNING) << "Send new instance timeout!";
}

MS_LOG(INFO) << "The scheduler is sending new instance to workers and servers!";
MS_LOG(INFO) << "The scheduler is sending new instance to " << node_info.node_id_;
}

void InstanceManager::QueryInstanceAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &,
@@ -55,7 +55,7 @@ void InstanceManager::QueryInstanceAsync(const std::shared_ptr<TcpClient> &clien
MS_LOG(WARNING) << "Send query instance timeout!";
}

MS_LOG(INFO) << "The scheduler is sending query instance to workers and servers!";
MS_LOG(INFO) << "The scheduler is sending query instance to " << node_info.node_id_;
}

void InstanceManager::EnableFLSAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &,
@@ -75,7 +75,7 @@ void InstanceManager::EnableFLSAsync(const std::shared_ptr<TcpClient> &client, c
MS_LOG(WARNING) << "Send query instance timeout!";
}

MS_LOG(INFO) << "The scheduler is sending query instance to workers and servers!";
MS_LOG(INFO) << "The scheduler is sending enable FLS to " << node_info.node_id_;
}

void InstanceManager::DisableFLSAsync(const std::shared_ptr<TcpClient> &client, const NodeManager &,
@@ -95,7 +95,7 @@ void InstanceManager::DisableFLSAsync(const std::shared_ptr<TcpClient> &client,
MS_LOG(WARNING) << "Send query instance timeout!";
}

MS_LOG(INFO) << "The scheduler is sending query instance to workers and servers!";
MS_LOG(INFO) << "The scheduler is sending disable FLS to " << node_info.node_id_;
}

void InstanceManager::QueryNodeScaleState(const std::shared_ptr<TcpClient> &client, const NodeManager &,
@@ -115,7 +115,7 @@ void InstanceManager::QueryNodeScaleState(const std::shared_ptr<TcpClient> &clie
MS_LOG(WARNING) << "Send query node scale state timeout!";
}

MS_LOG(INFO) << "The scheduler is sending query node scale state to workers and servers!";
MS_LOG(INFO) << "The scheduler is sending query node scale state to " << node_info.node_id_;
}
} // namespace core
} // namespace ps


+ 0
- 3
mindspore/ccsrc/ps/core/node.h View File

@@ -54,7 +54,6 @@ class BACKEND_EXPORT Node {
is_already_stopped_(true),
is_already_finished_(false),
next_request_id_(0),
current_node_state_(NodeState::NODE_STARTING),
current_cluster_state_(ClusterState::CLUSTER_STARTING) {}
virtual ~Node() = default;

@@ -113,8 +112,6 @@ class BACKEND_EXPORT Node {
std::mutex message_tracker_mutex_;
std::condition_variable message_tracker_cond_;

// Worker and server receive the node state and cluster state from the scheduler.
NodeState current_node_state_;
ClusterState current_cluster_state_;

// Configuration file,The format is as follows


+ 13
- 32
mindspore/ccsrc/ps/core/node_manager.cc View File

@@ -70,8 +70,7 @@ uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage &register_message
registered_nodes_info_[node_id] = recovery_node_infos[node_id];
MS_LOG(INFO) << "The node id: " << node_id << " is recovery successful!"
<< ", ip: " << recovery_node_infos[node_id].ip_ << ", port: " << recovery_node_infos[node_id].port_
<< ", rank id: " << rank_id << ", alive: " << recovery_node_infos[node_id].is_alive
<< ", fl iteration num: " << new_fl_iteration_num
<< ", rank id: " << rank_id << ", fl iteration num: " << new_fl_iteration_num
<< ", the node_role:" << CommUtil::NodeRoleToString(recovery_node_infos[node_id].node_role_);
return rank_id;
}
@@ -235,12 +234,15 @@ void NodeManager::UpdateCluster(bool is_cluster_ready) {
timeout_nodes_info_[it->first] = registered_nodes_info_[it->first];
registered_nodes_info_[it->first].is_alive = false;
}
} else {
if (registered_nodes_info_.count(it->first) && !registered_nodes_info_[it->first].is_alive) {
MS_LOG(WARNING) << registered_nodes_info_[it->first].node_id_ << " is alive.";
registered_nodes_info_[it->first].is_alive = true;
}
}
}

if (!timeout_nodes_info_.empty()) {
UpdateClusterState(ClusterState::NODE_TIMEOUT);

auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_RECOVERY)) {
@@ -249,25 +251,14 @@ void NodeManager::UpdateCluster(bool is_cluster_ready) {
finish_nodes_id_.insert(iter->first);
}
}
if (onPersist_) {
onPersist_();
if (cluster_state_ != ClusterState::CLUSTER_DISABLE_FLS) {
UpdateClusterState(ClusterState::NODE_TIMEOUT);
}
} else if (SizeToUint(heartbeats_.size()) == total_node_num_) {
if (cluster_state_ == ClusterState::NODE_TIMEOUT) {
for (auto it = registered_nodes_info_.begin(); it != registered_nodes_info_.end(); ++it) {
if (registered_nodes_info_.count(it->first) && !it->second.is_alive) {
MS_LOG(WARNING) << it->second.node_id_ << " is alive.";
it->second.is_alive = true;
}
}
if (onPersist_) {
onPersist_();
}
if (is_cluster_ready) {
UpdateClusterState(ClusterState::CLUSTER_READY);
} else {
UpdateClusterState(ClusterState::CLUSTER_STARTING);
}
} else if (SizeToUint(heartbeats_.size()) == total_node_num_ && cluster_state_ == ClusterState::NODE_TIMEOUT) {
if (is_cluster_ready) {
UpdateClusterState(ClusterState::CLUSTER_READY);
} else {
UpdateClusterState(ClusterState::CLUSTER_STARTING);
}
}

@@ -324,11 +315,6 @@ void NodeManager::UpdateNodesInfo() {
nodes_info_ = registered_nodes_info_;
}

void NodeManager::UpdateNodeState(const NodeState &state) {
std::lock_guard<std::mutex> lk(node_mutex_);
node_state_ = state;
}

void NodeManager::UpdateClusterState(const ClusterState &state) {
std::lock_guard<std::mutex> lk(cluster_mutex_);
std::string state_str = CommUtil::ClusterStateToString(state);
@@ -340,11 +326,6 @@ void NodeManager::UpdateClusterState(const ClusterState &state) {
cluster_state_ = state;
}

NodeState NodeManager::GetNodeState() {
std::lock_guard<std::mutex> lk(node_mutex_);
return node_state_;
}

ClusterState NodeManager::GetClusterState() {
std::lock_guard<std::mutex> lk(cluster_mutex_);
return cluster_state_;


+ 0
- 4
mindspore/ccsrc/ps/core/node_manager.h View File

@@ -49,7 +49,6 @@ class NodeManager {
next_worker_rank_id_(0),
next_server_rank_id_(0),
meta_data_(nullptr),
node_state_(NodeState::NODE_STARTING),
cluster_state_(ClusterState::CLUSTER_STARTING) {}
virtual ~NodeManager() = default;
using OnPersist = std::function<void()>;
@@ -104,9 +103,7 @@ class NodeManager {
uint32_t next_worker_rank_id() const;
uint32_t next_server_rank_id() const;

void UpdateNodeState(const NodeState &state);
void UpdateClusterState(const ClusterState &state);
NodeState GetNodeState();
ClusterState GetClusterState();

// When the scheduler receives the scale out or scale in message, the metadata needs to be reset, because all nodes
@@ -179,7 +176,6 @@ class NodeManager {
// Cluster metadata information can be dynamically changed
std::unique_ptr<ClusterMetadata> meta_data_;

NodeState node_state_;
ClusterState cluster_state_;

std::deque<uint32_t> recovery_worker_rank_id_;


+ 78
- 39
mindspore/ccsrc/ps/core/scheduler_node.cc View File

@@ -57,7 +57,6 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
MS_LOG(ERROR) << "Start Scheduler node timeout!";
return false;
}
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);

StartUpdatePersistentCommandTimer();
MS_LOG(INFO) << "[Scheduler start]: 4. Successfully start scheduler, there are " << node_manager_.worker_num()
@@ -85,7 +84,26 @@ void SchedulerNode::RunRecovery() {
MS_LOG(WARNING) << "There is no registered nodes in scheduler!";
return;
}
MS_LOG(INFO) << "The scheduler start run recovery!";
MS_LOG(INFO) << "The scheduler start run recovery!"
<< " The worker num:" << clusterConfig.initial_worker_num
<< ", the server num:" << clusterConfig.initial_server_num
<< ", the scheduler ip:" << clusterConfig.scheduler_host
<< ", the scheduler port:" << clusterConfig.scheduler_port
<< ", the initial total node num:" << clusterConfig.initial_total_node_num
<< ", the initial next worker rank id:" << clusterConfig.initial_next_worker_rank_id
<< ", the initial next server rank id:" << clusterConfig.initial_next_server_rank_id
<< ", the initial cluster state:" << kClusterState.at(clusterConfig.initial_cluster_state);

if (!clusterConfig.initial_registered_nodes_infos.empty()) {
for (const auto kvs : clusterConfig.initial_registered_nodes_infos) {
MS_LOG(INFO) << "The ip:" << kvs.second.ip_ << ", the port:" << kvs.second.port_
<< ", the node_id:" << kvs.second.node_id_
<< ", the node_role:" << CommUtil::NodeRoleToString(kvs.second.node_role_)
<< ", the rank_id_:" << kvs.second.rank_id_
<< ", the is_alive:" << CommUtil::BoolToString(kvs.second.is_alive);
}
}

uint32_t worker_num = clusterConfig.initial_worker_num;
uint32_t server_num = clusterConfig.initial_server_num;

@@ -94,6 +112,11 @@ void SchedulerNode::RunRecovery() {
node_manager_.set_next_worker_rank_id(clusterConfig.initial_next_worker_rank_id);
node_manager_.set_next_server_rank_id(clusterConfig.initial_next_server_rank_id);
node_manager_.set_total_node_num(clusterConfig.initial_total_node_num);
if (clusterConfig.initial_cluster_state == ClusterState::CLUSTER_DISABLE_FLS) {
MS_LOG(WARNING) << "Scheduler recover and update cluster state from recovery file, cluster state is "
<< CommUtil::ClusterStateToString(clusterConfig.initial_cluster_state);
node_manager_.UpdateClusterState(clusterConfig.initial_cluster_state);
}

for (const auto &kvs : initial_node_infos) {
auto &node_id = kvs.first;
@@ -352,44 +375,56 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server,
"will exit later.";
return;
}
if (!BuildingNetwork()) {
MS_LOG(ERROR) << "Building network failed! Cluster will exit later.";
}
}
}

if (node_manager_.GetClusterState() == ClusterState::CLUSTER_SCALE_IN) {
auto nodes = node_manager_.nodes_info();
for (const auto &id : scale_in_node_ids_) {
MS_LOG(INFO) << "The scheduler send metadata to scale in node:" << id;
if (nodes.count(id)) {
auto scale_in_client = GetOrCreateClient(nodes[id]);
SendMetadata(scale_in_client, nodes[id].rank_id_);
node_manager_.UpdateHeartbeat(id);
}
if (connected_nodes_.count(id)) {
MS_LOG(INFO) << "remove scale in node id: " << id << " connection.";
connected_nodes_.erase(id);
}
bool SchedulerNode::BuildingNetwork() {
if (node_manager_.GetClusterState() == ClusterState::CLUSTER_SCALE_IN) {
auto nodes = node_manager_.nodes_info();
for (const auto &id : scale_in_node_ids_) {
MS_LOG(INFO) << "The scheduler send metadata to scale in node:" << id;
if (nodes.count(id)) {
auto scale_in_client = GetOrCreateClient(nodes[id]);
SendMetadata(scale_in_client, nodes[id].rank_id_);
node_manager_.UpdateHeartbeat(id);
}
if (connected_nodes_.count(id)) {
MS_LOG(INFO) << "remove scale in node id: " << id << " connection.";
connected_nodes_.erase(id);
}
}
node_manager_.UpdateNodesInfo();
auto node_infos = node_manager_.nodes_info();
bool res = SendPrepareBuildingNetwork(node_infos);
if (!res) {
MS_LOG(ERROR) << "Prepare for building network failed! Cluster will exit later.";
return;
}
is_ready_ = true;
MS_LOG(INFO) << "Prepare for building network success. There are " << node_manager_.worker_num() << " workers and "
<< node_manager_.server_num()
<< " servers registered to scheduer, so the scheduler send meta data to worker/server.";
}
node_manager_.UpdateNodesInfo();
auto node_infos = node_manager_.nodes_info();
bool res = SendPrepareBuildingNetwork(node_infos);
if (!res) {
MS_LOG(ERROR) << "Prepare for building network failed!";
return false;
}
is_ready_ = true;
MS_LOG(INFO) << "Prepare for building network success. There are " << node_manager_.worker_num() << " workers and "
<< node_manager_.server_num()
<< " servers registered to scheduer, so the scheduler send meta data to worker/server.";

for (const auto &kvs : node_infos) {
auto client = GetOrCreateClient(kvs.second);
MS_EXCEPTION_IF_NULL(client);
SendMetadata(client, kvs.second.rank_id_);
node_manager_.UpdateHeartbeat(kvs.first);
}
for (const auto &kvs : node_infos) {
auto client = GetOrCreateClient(kvs.second);
MS_EXCEPTION_IF_NULL(client);
SendMetadata(client, kvs.second.rank_id_);
node_manager_.UpdateHeartbeat(kvs.first);
}

if (node_manager_.GetClusterState() == ClusterState::CLUSTER_DISABLE_FLS) {
MS_LOG(WARNING)
<< "Cluster state is CLUSTER_DISABLE_FLS, do not need to change to CLUSTER_READY when building network.";
} else {
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
PersistMetaData();
wait_start_cond_.notify_all();
}
PersistMetaData();
wait_start_cond_.notify_all();
return true;
}

void SchedulerNode::ProcessFinish(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
@@ -897,7 +932,7 @@ bool SchedulerNode::QueryNodeScaleState(const std::shared_ptr<HttpMessageHandler
auto client = GetOrCreateClient(kvs.second);
MS_EXCEPTION_IF_NULL(client);
MS_EXCEPTION_IF_NULL(instance_manager_);
instance_manager_->QueryNodeScaleState(client, node_manager_, request_id, node_info_);
instance_manager_->QueryNodeScaleState(client, node_manager_, request_id, kvs.second);
}
}

@@ -1179,7 +1214,7 @@ void SchedulerNode::ProcessNewInstance(const std::shared_ptr<HttpMessageHandler>
auto client = GetOrCreateClient(kvs.second);
MS_EXCEPTION_IF_NULL(client);
MS_EXCEPTION_IF_NULL(instance_manager_);
instance_manager_->NewInstanceAsync(client, node_manager_, body, request_id, node_info_);
instance_manager_->NewInstanceAsync(client, node_manager_, body, request_id, kvs.second);
}
}
bool res = Wait(request_id);
@@ -1246,7 +1281,7 @@ void SchedulerNode::ProcessQueryInstance(const std::shared_ptr<HttpMessageHandle
auto client = GetOrCreateClient(kvs.second);
MS_EXCEPTION_IF_NULL(client);
MS_EXCEPTION_IF_NULL(instance_manager_);
instance_manager_->QueryInstanceAsync(client, node_manager_, request_id, node_info_);
instance_manager_->QueryInstanceAsync(client, node_manager_, request_id, kvs.second);
}
}
bool res = Wait(request_id);
@@ -1308,7 +1343,7 @@ void SchedulerNode::ProcessEnableFLS(const std::shared_ptr<HttpMessageHandler> &
auto client = GetOrCreateClient(kvs.second);
MS_EXCEPTION_IF_NULL(client);
MS_EXCEPTION_IF_NULL(instance_manager_);
instance_manager_->EnableFLSAsync(client, node_manager_, request_id, node_info_);
instance_manager_->EnableFLSAsync(client, node_manager_, request_id, kvs.second);
}
}
bool res = Wait(request_id);
@@ -1334,6 +1369,7 @@ void SchedulerNode::ProcessEnableFLS(const std::shared_ptr<HttpMessageHandler> &
js["code"] = kSuccessCode;
js["result"] = true;
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
PersistMetaData();
} else {
js["message"] = "start enabling FL-Server failed.";
js["code"] = kErrorCode;
@@ -1379,7 +1415,7 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler>
auto client = GetOrCreateClient(kvs.second);
MS_EXCEPTION_IF_NULL(client);
MS_EXCEPTION_IF_NULL(instance_manager_);
instance_manager_->DisableFLSAsync(client, node_manager_, request_id, node_info_);
instance_manager_->DisableFLSAsync(client, node_manager_, request_id, kvs.second);
}
}
bool res = Wait(request_id);
@@ -1404,6 +1440,7 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler>
js["code"] = kSuccessCode;
js["result"] = true;
node_manager_.UpdateClusterState(ClusterState::CLUSTER_DISABLE_FLS);
PersistMetaData();
} else {
js["message"] = "start disabling FL-Server failed.";
js["code"] = kErrorCode;
@@ -1565,6 +1602,7 @@ void SchedulerNode::PersistMetaData() {
return;
}
if (!is_ready_) {
MS_LOG(WARNING) << "Cluster is not building network successful, do not persist meta data";
return;
}
if (config_->Exists(kKeyRecovery)) {
@@ -1576,6 +1614,7 @@ void SchedulerNode::PersistMetaData() {
clusterConfig.initial_next_server_rank_id = node_manager_.next_server_rank_id();
clusterConfig.initial_registered_nodes_infos.clear();
clusterConfig.initial_registered_nodes_infos = node_manager_.registered_nodes_info();
clusterConfig.initial_cluster_state = node_manager_.GetClusterState();

scheduler_recovery_->Persist(clusterConfig);
scheduler_recovery_->PersistNodesInfo(clusterConfig);


+ 2
- 0
mindspore/ccsrc/ps/core/scheduler_node.h View File

@@ -217,6 +217,8 @@ class BACKEND_EXPORT SchedulerNode : public Node {
void GeneralResponse(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, bool is_success, const std::string &error);

bool BuildingNetwork();

std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> scheduler_thread_;
std::unique_ptr<std::thread> update_state_thread_;


+ 6
- 18
mindspore/ccsrc/ps/core/scheduler_recovery.cc View File

@@ -15,6 +15,7 @@
*/

#include "ps/core/scheduler_recovery.h"
#include "ps/core/comm_util.h"

namespace mindspore {
namespace ps {
@@ -63,11 +64,6 @@ bool SchedulerRecovery::Recover() {
MS_LOG(EXCEPTION) << kRecoverySchedulerPort << " is not contained in " << recovery_storage_->file_path();
}

MS_LOG(INFO) << "The worker num:" << clusterConfig.initial_worker_num
<< ", the server num:" << clusterConfig.initial_server_num
<< ", the scheduler ip:" << clusterConfig.scheduler_host
<< ", the scheduler port:" << clusterConfig.scheduler_port;

MS_ERROR_IF_NULL_W_RET_VAL(scheduler_recovery_storage_, false);
// 5. recover total node num
if (scheduler_recovery_storage_->Exists(kRecoveryTotalNodeNum)) {
@@ -110,7 +106,6 @@ bool SchedulerRecovery::Recover() {
node_info.port_ = static_cast<uint16_t>(std::strtol(port.c_str(), nullptr, kBase));
node_info.node_id_ = elem.at("node_id");
node_info.rank_id_ = UlongToUint(std::strtoul(rank_id.c_str(), nullptr, kBase));
node_info.is_alive = CommUtil::StringToBool(elem.at("alive"));
node_info.node_role_ = CommUtil::StringToNodeRole(elem.at("role"));

nodes_infos[node_info.node_id_] = node_info;
@@ -127,18 +122,11 @@ bool SchedulerRecovery::Recover() {
MS_LOG(EXCEPTION) << kRecoveryRegisteredNodesInfos << " is not contained in " << recovery_storage_->file_path();
}

MS_LOG(INFO) << ", the initial total node num:" << clusterConfig.initial_total_node_num
<< ", the initial next worker rank id:" << clusterConfig.initial_next_worker_rank_id
<< ", the initial next server rank id:" << clusterConfig.initial_next_server_rank_id;

if (!clusterConfig.initial_registered_nodes_infos.empty()) {
for (const auto kvs : clusterConfig.initial_registered_nodes_infos) {
MS_LOG(INFO) << "The ip:" << kvs.second.ip_ << ", the port:" << kvs.second.port_
<< ", the node_id:" << kvs.second.node_id_
<< ", the node_role:" << CommUtil::NodeRoleToString(kvs.second.node_role_)
<< ", the rank_id_:" << kvs.second.rank_id_
<< ", the is_alive:" << CommUtil::BoolToString(kvs.second.is_alive);
}
// 9. recover cluster state
if (recovery_storage_->Exists(kRecoveryClusterState)) {
clusterConfig.initial_cluster_state = kClusterStateMap.at(recovery_storage_->GetString(kRecoveryClusterState, ""));
} else {
MS_LOG(EXCEPTION) << kRecoveryClusterState << " is not contained in " << recovery_storage_->file_path();
}
return true;
}


+ 3
- 2
mindspore/ccsrc/ps/ps_context.cc View File

@@ -209,8 +209,9 @@ void PSContext::set_rank_id(uint32_t rank_id) const {
}

void PSContext::set_server_mode(const std::string &server_mode) {
if (server_mode != kServerModePS && server_mode != kServerModeFL) {
MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL;
if (server_mode != kServerModePS && server_mode != kServerModeFL && server_mode != kServerModeHybrid) {
MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
<< " or " << kServerModeHybrid;
return;
}
MS_LOG(INFO) << "Server mode: " << server_mode << " is used for Server and Worker. Scheduler will ignore it.";


Loading…
Cancel
Save