| @@ -20,8 +20,8 @@ namespace mindspore { | |||
| namespace kernel { | |||
| namespace ps { | |||
| void PServerKernel::Shard(std::vector<size_t> *shape, int axis) { | |||
| (*shape)[axis] = | |||
| LongToSize(Util::LocalShard(SizeToLong((*shape)[axis]), SizeToLong(rank_id_), SizeToLong(pserver_num_))); | |||
| (*shape)[IntToSize(axis)] = | |||
| LongToSize(Util::LocalShard(SizeToLong((*shape)[IntToSize(axis)]), SizeToLong(rank_id_), SizeToLong(pserver_num_))); | |||
| } | |||
| } // namespace ps | |||
| } // namespace kernel | |||
| @@ -350,7 +350,7 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra | |||
| size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type()); | |||
| } | |||
| if (AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) { | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| const std::string &param_name = input_node->fullname_with_scope(); | |||
| if (ps::ps_cache_instance.IsHashTable(param_name)) { | |||
| continue; | |||
| @@ -34,7 +34,7 @@ | |||
| #include "debug/anf_ir_dump.h" | |||
| #include "debug/dump_proto.h" | |||
| #include "debug/data_dump/dump_json_parser.h" | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| #include "ps/util.h" | |||
| #include "ps/ps_context.h" | |||
| #endif | |||
| @@ -77,7 +77,7 @@ void CPUSession::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderPos | |||
| void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && ps::PSContext::instance()->is_ps_mode()) { | |||
| @@ -193,7 +193,7 @@ void CPUSession::PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_grap | |||
| MS_LOG(INFO) << "Bind input output address"; | |||
| runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| InitPSParamAndOptim(kernel_graph, inputs); | |||
| #endif | |||
| } | |||
| @@ -22,7 +22,7 @@ | |||
| #include "utils/comm_manager.h" | |||
| #include "utils/scoped_long_running.h" | |||
| #include "pybind_api/ir/tensor_py.h" | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #endif | |||
| @@ -44,7 +44,7 @@ | |||
| #include "debug/common.h" | |||
| #include "utils/trace_base.h" | |||
| #include "frontend/parallel/context.h" | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #include "ps/constants.h" | |||
| #include "ps/util.h" | |||
| @@ -2483,7 +2483,7 @@ void SessionBasic::DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||
| void SessionBasic::UnifyMindIR(const KernelGraphPtr &graph) { opt::CommonUnifyMindIROptimization(graph); } | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) { | |||
| if (!ps::PSContext::instance()->is_worker()) { | |||
| return; | |||
| @@ -157,10 +157,10 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl | |||
| } | |||
| // reconstruct secrets | |||
| bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::string &next_req_time, | |||
| const schema::SendReconstructSecret *reconstruct_secret_req, | |||
| std::shared_ptr<fl::server::FBBuilder> reconstruct_secret_resp_builder, | |||
| const std::vector<std::string> &client_list) { | |||
| bool CipherReconStruct::ReconstructSecrets( | |||
| const int cur_iterator, const std::string &next_req_time, const schema::SendReconstructSecret *reconstruct_secret_req, | |||
| const std::shared_ptr<fl::server::FBBuilder> &reconstruct_secret_resp_builder, | |||
| const std::vector<std::string> &client_list) { | |||
| MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets START"; | |||
| clock_t start_time = clock(); | |||
| if (reconstruct_secret_req == nullptr || reconstruct_secret_resp_builder == nullptr) { | |||
| @@ -285,7 +285,7 @@ void CipherReconStruct::ClearReconstructSecrets() { | |||
| MS_LOG(INFO) << "CipherReconStruct::ClearReconstructSecrets Success"; | |||
| } | |||
| void CipherReconStruct::BuildReconstructSecretsRsp(std::shared_ptr<fl::server::FBBuilder> fbb, | |||
| void CipherReconStruct::BuildReconstructSecretsRsp(const std::shared_ptr<fl::server::FBBuilder> &fbb, | |||
| const schema::ResponseCode retcode, const std::string &reason, | |||
| const int iteration, const std::string &next_req_time) { | |||
| auto fbs_reason = fbb->CreateString(reason); | |||
| @@ -44,11 +44,11 @@ class CipherReconStruct { | |||
| // reconstruct secret mask | |||
| bool ReconstructSecrets(const int cur_iterator, const std::string &next_req_time, | |||
| const schema::SendReconstructSecret *reconstruct_secret_req, | |||
| std::shared_ptr<fl::server::FBBuilder> reconstruct_secret_resp_builder, | |||
| const std::shared_ptr<fl::server::FBBuilder> &reconstruct_secret_resp_builder, | |||
| const std::vector<std::string> &client_list); | |||
| // build response code of reconstruct secret. | |||
| void BuildReconstructSecretsRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode, | |||
| void BuildReconstructSecretsRsp(const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode, | |||
| const std::string &reason, const int iteration, const std::string &next_req_time); | |||
| // clear the shared memory. | |||
| @@ -21,7 +21,7 @@ | |||
| namespace mindspore { | |||
| namespace armour { | |||
| bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestShareSecrets *share_secrets_req, | |||
| std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder, | |||
| const std::shared_ptr<fl::server::FBBuilder> &share_secrets_resp_builder, | |||
| const string next_req_time) { | |||
| MS_LOG(INFO) << "CipherShares::ShareSecrets START"; | |||
| if (share_secrets_req == nullptr) { | |||
| @@ -95,7 +95,7 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha | |||
| } | |||
| bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req, | |||
| std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder, | |||
| const std::shared_ptr<fl::server::FBBuilder> &get_secrets_resp_builder, | |||
| const std::string &next_req_time) { | |||
| MS_LOG(INFO) << "CipherShares::GetSecrets START"; | |||
| clock_t start_time = clock(); | |||
| @@ -180,7 +180,7 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req, | |||
| } | |||
| void CipherShares::BuildGetSecretsRsp( | |||
| std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder, schema::ResponseCode retcode, int iteration, | |||
| const std::shared_ptr<fl::server::FBBuilder> &get_secrets_resp_builder, schema::ResponseCode retcode, int iteration, | |||
| std::string next_req_time, std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares) { | |||
| int rsp_retcode = retcode; | |||
| int rsp_iteration = iteration; | |||
| @@ -199,7 +199,7 @@ void CipherShares::BuildGetSecretsRsp( | |||
| return; | |||
| } | |||
| void CipherShares::BuildShareSecretsRsp(std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder, | |||
| void CipherShares::BuildShareSecretsRsp(const std::shared_ptr<fl::server::FBBuilder> &share_secrets_resp_builder, | |||
| const schema::ResponseCode retcode, const string &reason, | |||
| const string &next_req_time, const int iteration) { | |||
| auto rsp_reason = share_secrets_resp_builder->CreateString(reason); | |||
| @@ -43,17 +43,19 @@ class CipherShares { | |||
| // handle the client's request of share secrets. | |||
| bool ShareSecrets(const int cur_iterator, const schema::RequestShareSecrets *share_secrets_req, | |||
| std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder, const string next_req_time); | |||
| const std::shared_ptr<fl::server::FBBuilder> &share_secrets_resp_builder, | |||
| const string next_req_time); | |||
| // handle the client's request of get secrets. | |||
| bool GetSecrets(const schema::GetShareSecrets *get_secrets_req, | |||
| std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder, const std::string &next_req_time); | |||
| const std::shared_ptr<fl::server::FBBuilder> &get_secrets_resp_builder, | |||
| const std::string &next_req_time); | |||
| // build response code of share secrets. | |||
| void BuildShareSecretsRsp(std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder, | |||
| void BuildShareSecretsRsp(const std::shared_ptr<fl::server::FBBuilder> &share_secrets_resp_builder, | |||
| const schema::ResponseCode retcode, const string &reason, const string &next_req_time, | |||
| const int iteration); | |||
| // build response code of get secrets. | |||
| void BuildGetSecretsRsp(std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder, | |||
| void BuildGetSecretsRsp(const std::shared_ptr<fl::server::FBBuilder> &get_secrets_resp_builder, | |||
| const schema::ResponseCode retcode, const int iteration, std::string next_req_time, | |||
| std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares); | |||
| // clear the shared memory. | |||
| @@ -26,7 +26,7 @@ bool CipherUnmask::UnMask(const std::map<std::string, AddressPtr> &data) { | |||
| clock_t start_time = clock(); | |||
| std::vector<float> noise; | |||
| cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise); | |||
| (void)cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise); | |||
| if (noise.size() != cipher_init_->featuremap_) { | |||
| MS_LOG(ERROR) << " CipherMgr UnMask ERROR"; | |||
| return false; | |||
| @@ -114,7 +114,6 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size | |||
| std::shared_ptr<std::vector<unsigned char>> recv_str; | |||
| auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str); | |||
| if (!server_node_->CollectiveWait(recv_req_id)) { | |||
| MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed."; | |||
| return false; | |||
| @@ -24,6 +24,7 @@ | |||
| namespace mindspore { | |||
| namespace fl { | |||
| namespace server { | |||
| constexpr uint32_t kDefaultVirtualNodeNum = 32; | |||
| // To support distributed storage and make servers easy to scale-out and scale-in for a large load of metadata in | |||
| // server, we use class ConsistentHashRing to help servers find out which metadata is stored in which server node. | |||
| @@ -104,7 +104,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string & | |||
| } | |||
| CountResponse count_rsp; | |||
| count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size())); | |||
| (void)count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size())); | |||
| if (!count_rsp.result()) { | |||
| MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason(); | |||
| if (reason != nullptr && count_rsp.reason().find(kNetworkError) != std::string::npos) { | |||
| @@ -138,8 +138,8 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) { | |||
| } | |||
| CountReachThresholdResponse count_reach_threshold_rsp; | |||
| count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), | |||
| SizeToInt(query_cnt_enough_rsp_msg->size())); | |||
| (void)count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), | |||
| SizeToInt(query_cnt_enough_rsp_msg->size())); | |||
| return count_reach_threshold_rsp.is_enough(); | |||
| } | |||
| } | |||
| @@ -178,7 +178,7 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core: | |||
| } | |||
| CountRequest report_count_req; | |||
| report_count_req.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| (void)report_count_req.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| const std::string &name = report_count_req.name(); | |||
| const std::string &id = report_count_req.id(); | |||
| @@ -228,7 +228,7 @@ void DistributedCountService::HandleCountReachThresholdRequest( | |||
| } | |||
| CountReachThresholdRequest count_reach_threshold_req; | |||
| count_reach_threshold_req.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| (void)count_reach_threshold_req.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| const std::string &name = count_reach_threshold_req.name(); | |||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||
| @@ -256,7 +256,7 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core: | |||
| communicator_->SendResponse(couter_event_rsp_msg.data(), couter_event_rsp_msg.size(), message); | |||
| CounterEvent counter_event; | |||
| counter_event.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| (void)counter_event.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| const auto &type = counter_event.type(); | |||
| const auto &name = counter_event.name(); | |||
| @@ -141,7 +141,7 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) { | |||
| MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed."; | |||
| return get_metadata_rsp; | |||
| } | |||
| get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), SizeToInt(get_meta_rsp_msg->size())); | |||
| (void)get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), SizeToInt(get_meta_rsp_msg->size())); | |||
| return get_metadata_rsp; | |||
| } | |||
| } | |||
| @@ -165,7 +165,7 @@ bool DistributedMetadataStore::ReInitForScaling() { | |||
| } | |||
| void DistributedMetadataStore::InitHashRing() { | |||
| router_ = std::make_shared<ConsistentHashRing>(32); | |||
| router_ = std::make_shared<ConsistentHashRing>(kDefaultVirtualNodeNum); | |||
| MS_EXCEPTION_IF_NULL(router_); | |||
| for (uint32_t i = 0; i < server_num_; i++) { | |||
| bool ret = router_->Insert(i); | |||
| @@ -184,7 +184,7 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr | |||
| } | |||
| PBMetadataWithName meta_with_name; | |||
| meta_with_name.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| (void)meta_with_name.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| const std::string &name = meta_with_name.name(); | |||
| MS_LOG(INFO) << "Update metadata for " << name; | |||
| @@ -195,7 +195,7 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr | |||
| } else { | |||
| update_meta_rsp_msg = "Success"; | |||
| } | |||
| communicator_->SendResponse(update_meta_rsp_msg.data(), update_meta_rsp_msg.size(), message); | |||
| (void)communicator_->SendResponse(update_meta_rsp_msg.data(), update_meta_rsp_msg.size(), message); | |||
| return; | |||
| } | |||
| @@ -206,14 +206,14 @@ void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps | |||
| } | |||
| GetMetadataRequest get_metadata_req; | |||
| get_metadata_req.ParseFromArray(message->data(), message->len()); | |||
| (void)get_metadata_req.ParseFromArray(message->data(), message->len()); | |||
| const std::string &name = get_metadata_req.name(); | |||
| MS_LOG(INFO) << "Getting metadata for " << name; | |||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||
| PBMetadata stored_meta = metadata_[name]; | |||
| std::string getting_meta_rsp_msg = stored_meta.SerializeAsString(); | |||
| communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message); | |||
| (void)communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message); | |||
| return; | |||
| } | |||
| @@ -67,7 +67,7 @@ bool Executor::HandlePush(const std::string &param_name, const UploadData &uploa | |||
| // Push operation needs to wait until the pulling process is done. | |||
| while (!param_aggr->IsPullingDone()) { | |||
| lock.unlock(); | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(5)); | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime)); | |||
| lock.lock(); | |||
| } | |||
| @@ -192,7 +192,7 @@ AddressPtr Executor::HandlePull(const std::string &param_name) { | |||
| // Pulling must wait until the optimizing process is done. | |||
| while (!param_aggr->IsOptimizingDone()) { | |||
| lock.unlock(); | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(5)); | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime)); | |||
| lock.lock(); | |||
| } | |||
| AddressPtr addr = param_aggr->Pull(); | |||
| @@ -314,7 +314,10 @@ bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) { | |||
| param_names_.push_back(param_name); | |||
| param_aggrs_[param_name] = param_aggr; | |||
| parameter_mutex_[param_name]; | |||
| param_aggr->Init(cnode, aggregation_count_); | |||
| if (!param_aggr->Init(cnode, aggregation_count_)) { | |||
| MS_LOG(EXCEPTION) << "Initializing parameter aggregator failed for " << param_name; | |||
| return false; | |||
| } | |||
| MS_LOG(DEBUG) << "Initializing control flow for param_name " << param_name << " success."; | |||
| } | |||
| return true; | |||
| @@ -33,6 +33,8 @@ | |||
| namespace mindspore { | |||
| namespace fl { | |||
| namespace server { | |||
| constexpr int kThreadSleepTime = 5; | |||
| // Executor is the entrance for server to handle aggregation, optimizing, model querying, etc. It handles | |||
| // logics relevant to kernel launching. | |||
| class Executor { | |||
| @@ -74,10 +74,10 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::Communica | |||
| }); | |||
| // The time window for one iteration, which will be used in some round kernels. | |||
| size_t iteration_time_window = | |||
| std::accumulate(rounds_.begin(), rounds_.end(), 0, [](size_t total, const std::shared_ptr<Round> &round) { | |||
| return round->check_timeout() ? total + round->time_window() : total; | |||
| }); | |||
| size_t iteration_time_window = std::accumulate(rounds_.begin(), rounds_.end(), IntToSize(0), | |||
| [](size_t total, const std::shared_ptr<Round> &round) { | |||
| return round->check_timeout() ? total + round->time_window() : total; | |||
| }); | |||
| LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window); | |||
| MS_LOG(INFO) << "Time window for one iteration is " << iteration_time_window; | |||
| return; | |||
| @@ -162,7 +162,7 @@ bool Iteration::ReInitForScaling(uint32_t server_num, uint32_t server_rank) { | |||
| return true; | |||
| } | |||
| const std::vector<std::shared_ptr<Round>> &Iteration::rounds() { return rounds_; } | |||
| const std::vector<std::shared_ptr<Round>> &Iteration::rounds() const { return rounds_; } | |||
| bool Iteration::is_last_iteration_valid() const { return is_last_iteration_valid_; } | |||
| @@ -182,7 +182,7 @@ bool Iteration::SyncIteration(uint32_t rank) { | |||
| } | |||
| SyncIterationResponse sync_iter_rsp; | |||
| sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), sync_iter_rsp_msg->size()); | |||
| (void)sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), SizeToInt(sync_iter_rsp_msg->size())); | |||
| iteration_num_ = sync_iter_rsp.iteration(); | |||
| MS_LOG(INFO) << "After synchronizing, server " << rank << " current iteration number is " | |||
| << sync_iter_rsp.iteration(); | |||
| @@ -196,14 +196,14 @@ void Iteration::HandleSyncIterationRequest(const std::shared_ptr<ps::core::Messa | |||
| } | |||
| SyncIterationRequest sync_iter_req; | |||
| sync_iter_req.ParseFromArray(message->data(), message->len()); | |||
| (void)sync_iter_req.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| uint32_t rank = sync_iter_req.rank(); | |||
| MS_LOG(INFO) << "Synchronizing iteration request from rank " << rank; | |||
| SyncIterationResponse sync_iter_rsp; | |||
| sync_iter_rsp.set_iteration(iteration_num_); | |||
| std::string sync_iter_rsp_msg = sync_iter_rsp.SerializeAsString(); | |||
| communicator_->SendResponse(sync_iter_rsp_msg.data(), sync_iter_rsp_msg.size(), message); | |||
| (void)communicator_->SendResponse(sync_iter_rsp_msg.data(), sync_iter_rsp_msg.size(), message); | |||
| } | |||
| bool Iteration::IsMoveToNextIterRequestReentrant(uint64_t iteration_num) { | |||
| @@ -238,11 +238,11 @@ void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps | |||
| NotifyLeaderMoveToNextIterResponse notify_leader_to_next_iter_rsp; | |||
| notify_leader_to_next_iter_rsp.set_result("success"); | |||
| communicator_->SendResponse(notify_leader_to_next_iter_rsp.SerializeAsString().data(), | |||
| notify_leader_to_next_iter_rsp.SerializeAsString().size(), message); | |||
| (void)communicator_->SendResponse(notify_leader_to_next_iter_rsp.SerializeAsString().data(), | |||
| notify_leader_to_next_iter_rsp.SerializeAsString().size(), message); | |||
| NotifyLeaderMoveToNextIterRequest notify_leader_to_next_iter_req; | |||
| notify_leader_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| (void)notify_leader_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| const auto &rank = notify_leader_to_next_iter_req.rank(); | |||
| const auto &is_last_iter_valid = notify_leader_to_next_iter_req.is_last_iter_valid(); | |||
| const auto &iter_num = notify_leader_to_next_iter_req.iter_num(); | |||
| @@ -296,7 +296,7 @@ bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, cons | |||
| } | |||
| MS_LOG(INFO) << "Offline server " << rank << " preparing for next iteration success."; | |||
| }); | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(1000)); | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking)); | |||
| return true; | |||
| } | |||
| @@ -306,15 +306,15 @@ void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core:: | |||
| } | |||
| PrepareForNextIterRequest prepare_next_iter_req; | |||
| prepare_next_iter_req.ParseFromArray(message->data(), message->len()); | |||
| (void)prepare_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| const auto &reason = prepare_next_iter_req.reason(); | |||
| MS_LOG(INFO) << "Prepare next iteration for this rank " << server_node_->rank_id() << ", reason: " << reason; | |||
| PrepareForNextIter(); | |||
| PrepareForNextIterResponse prepare_next_iter_rsp; | |||
| prepare_next_iter_rsp.set_result("success"); | |||
| communicator_->SendResponse(prepare_next_iter_rsp.SerializeAsString().data(), | |||
| prepare_next_iter_rsp.SerializeAsString().size(), message); | |||
| (void)communicator_->SendResponse(prepare_next_iter_rsp.SerializeAsString().data(), | |||
| prepare_next_iter_rsp.SerializeAsString().size(), message); | |||
| } | |||
| void Iteration::PrepareForNextIter() { | |||
| @@ -347,11 +347,11 @@ void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::Mess | |||
| MoveToNextIterResponse proceed_to_next_iter_rsp; | |||
| proceed_to_next_iter_rsp.set_result("success"); | |||
| communicator_->SendResponse(proceed_to_next_iter_rsp.SerializeAsString().data(), | |||
| proceed_to_next_iter_rsp.SerializeAsString().size(), message); | |||
| (void)communicator_->SendResponse(proceed_to_next_iter_rsp.SerializeAsString().data(), | |||
| proceed_to_next_iter_rsp.SerializeAsString().size(), message); | |||
| MoveToNextIterRequest proceed_to_next_iter_req; | |||
| proceed_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| (void)proceed_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| const auto &is_last_iter_valid = proceed_to_next_iter_req.is_last_iter_valid(); | |||
| const auto &last_iter_num = proceed_to_next_iter_req.last_iter_num(); | |||
| const auto &reason = proceed_to_next_iter_req.reason(); | |||
| @@ -370,12 +370,12 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) { | |||
| if (is_iteration_valid) { | |||
| // Store the model which is successfully aggregated for this iteration. | |||
| const auto &model = Executor::GetInstance().GetModel(); | |||
| ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); | |||
| (void)ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); | |||
| MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished."; | |||
| } else { | |||
| // Store last iteration's model because this iteration is considered as invalid. | |||
| const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1); | |||
| ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); | |||
| (void)ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); | |||
| MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason; | |||
| } | |||
| @@ -405,7 +405,7 @@ void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::Message | |||
| } | |||
| EndLastIterRequest end_last_iter_req; | |||
| end_last_iter_req.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| (void)end_last_iter_req.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| const auto &last_iter_num = end_last_iter_req.last_iter_num(); | |||
| // If the iteration number is not matched, return error. | |||
| if (last_iter_num != iteration_num_) { | |||
| @@ -413,8 +413,8 @@ void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::Message | |||
| std::to_string(iteration_num_) + ", iteration to be ended is " + std::to_string(last_iter_num); | |||
| EndLastIterResponse end_last_iter_rsp; | |||
| end_last_iter_rsp.set_result(reason); | |||
| communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(), | |||
| end_last_iter_rsp.SerializeAsString().size(), message); | |||
| (void)communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(), | |||
| end_last_iter_rsp.SerializeAsString().size(), message); | |||
| return; | |||
| } | |||
| @@ -422,8 +422,8 @@ void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::Message | |||
| EndLastIterResponse end_last_iter_rsp; | |||
| end_last_iter_rsp.set_result("success"); | |||
| communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(), | |||
| end_last_iter_rsp.SerializeAsString().size(), message); | |||
| (void)communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(), | |||
| end_last_iter_rsp.SerializeAsString().size(), message); | |||
| } | |||
| void Iteration::EndLastIter() { | |||
| @@ -79,7 +79,7 @@ class Iteration { | |||
| // The server number after scaling is required in some rounds. | |||
| bool ReInitForScaling(uint32_t server_num, uint32_t server_rank); | |||
| const std::vector<std::shared_ptr<Round>> &rounds(); | |||
| const std::vector<std::shared_ptr<Round>> &rounds() const; | |||
| bool is_last_iteration_valid() const; | |||
| @@ -70,10 +70,9 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons | |||
| auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp); | |||
| std::map<std::string, AddressPtr> feature_maps; | |||
| size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num(); | |||
| size_t get_model_iter = static_cast<size_t>(get_model_req->iteration()); | |||
| size_t get_model_iter = IntToSize(get_model_req->iteration()); | |||
| const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model(); | |||
| size_t latest_iter_num = iter_to_model.rbegin()->first; | |||
| // If this iteration is not finished yet, return ResponseCode_SucNotReady so that clients could get model later. | |||
| if ((current_iter == get_model_iter && latest_iter_num != current_iter) || current_iter == get_model_iter - 1) { | |||
| std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter) + | |||
| @@ -63,10 +63,11 @@ bool PullWeightKernel::Reset() { | |||
| return true; | |||
| } | |||
| void PullWeightKernel::PullWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPullWeight *pull_weight_req) { | |||
| void PullWeightKernel::PullWeight(const std::shared_ptr<FBBuilder> &fbb, | |||
| const schema::RequestPullWeight *pull_weight_req) { | |||
| std::map<std::string, AddressPtr> feature_maps = {}; | |||
| size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num(); | |||
| size_t pull_weight_iter = static_cast<size_t>(pull_weight_req->iteration()); | |||
| size_t pull_weight_iter = IntToSize(pull_weight_req->iteration()); | |||
| // The iteration from worker should be the same as server's, otherwise return SucNotReady so that worker could retry. | |||
| if (pull_weight_iter != current_iter) { | |||
| std::string reason = "PullWeight iteration " + std::to_string(pull_weight_iter) + | |||
| @@ -110,7 +111,7 @@ void PullWeightKernel::PullWeight(std::shared_ptr<FBBuilder> fbb, const schema:: | |||
| return; | |||
| } | |||
| void PullWeightKernel::BuildPullWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode, | |||
| void PullWeightKernel::BuildPullWeightRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, | |||
| const std::string &reason, size_t iteration, | |||
| const std::map<std::string, AddressPtr> &feature_maps) { | |||
| auto fbs_reason = fbb->CreateString(reason); | |||
| @@ -127,7 +128,7 @@ void PullWeightKernel::BuildPullWeightRsp(std::shared_ptr<FBBuilder> fbb, const | |||
| schema::ResponsePullWeightBuilder rsp_pull_weight_builder(*(fbb.get())); | |||
| rsp_pull_weight_builder.add_retcode(retcode); | |||
| rsp_pull_weight_builder.add_reason(fbs_reason); | |||
| rsp_pull_weight_builder.add_iteration(iteration); | |||
| rsp_pull_weight_builder.add_iteration(SizeToInt(iteration)); | |||
| rsp_pull_weight_builder.add_feature_map(fbs_feature_maps_vector); | |||
| auto rsp_pull_weight = rsp_pull_weight_builder.Finish(); | |||
| fbb->Finish(rsp_pull_weight); | |||
| @@ -42,9 +42,10 @@ class PullWeightKernel : public RoundKernel { | |||
| bool Reset() override; | |||
| private: | |||
| void PullWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPullWeight *pull_weight_req); | |||
| void BuildPullWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode, const std::string &reason, | |||
| size_t iteration, const std::map<std::string, AddressPtr> &feature_maps); | |||
| void PullWeight(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestPullWeight *pull_weight_req); | |||
| void BuildPullWeightRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, | |||
| const std::string &reason, size_t iteration, | |||
| const std::map<std::string, AddressPtr> &feature_maps); | |||
| Executor *executor_; | |||
| @@ -67,12 +67,12 @@ void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageH | |||
| return; | |||
| } | |||
| ResultCode PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, | |||
| ResultCode PushWeightKernel::PushWeight(const std::shared_ptr<FBBuilder> &fbb, | |||
| const schema::RequestPushWeight *push_weight_req) { | |||
| if (fbb == nullptr || push_weight_req == nullptr) { | |||
| return ResultCode::kSuccessAndReturn; | |||
| } | |||
| size_t iteration = static_cast<size_t>(push_weight_req->iteration()); | |||
| size_t iteration = IntToSize(push_weight_req->iteration()); | |||
| size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num(); | |||
| if (iteration != current_iter) { | |||
| std::string reason = "PushWeight iteration number is invalid:" + std::to_string(iteration) + | |||
| @@ -123,13 +123,13 @@ std::map<std::string, Address> PushWeightKernel::ParseFeatureMap(const schema::R | |||
| return upload_feature_map; | |||
| } | |||
| void PushWeightKernel::BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode, | |||
| void PushWeightKernel::BuildPushWeightRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, | |||
| const std::string &reason, size_t iteration) { | |||
| auto fbs_reason = fbb->CreateString(reason); | |||
| schema::ResponsePushWeightBuilder rsp_push_weight_builder(*(fbb.get())); | |||
| rsp_push_weight_builder.add_retcode(retcode); | |||
| rsp_push_weight_builder.add_reason(fbs_reason); | |||
| rsp_push_weight_builder.add_iteration(iteration); | |||
| rsp_push_weight_builder.add_iteration(SizeToInt(iteration)); | |||
| auto rsp_push_weight = rsp_push_weight_builder.Finish(); | |||
| fbb->Finish(rsp_push_weight); | |||
| return; | |||
| @@ -42,10 +42,10 @@ class PushWeightKernel : public RoundKernel { | |||
| void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override; | |||
| private: | |||
| ResultCode PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req); | |||
| ResultCode PushWeight(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestPushWeight *push_weight_req); | |||
| std::map<std::string, Address> ParseFeatureMap(const schema::RequestPushWeight *push_weight_req); | |||
| void BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode, const std::string &reason, | |||
| size_t iteration); | |||
| void BuildPushWeightRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, | |||
| const std::string &reason, size_t iteration); | |||
| Executor *executor_; | |||
| uint32_t local_rank_; | |||
| @@ -34,7 +34,7 @@ RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), e | |||
| // Detect whether there's any data needs to be released every 100 milliseconds. | |||
| if (heap_data_to_release_.empty()) { | |||
| release_lock.unlock(); | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(100)); | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(kReleaseDuration)); | |||
| continue; | |||
| } | |||
| @@ -61,9 +61,9 @@ RoundKernel::~RoundKernel() { | |||
| } | |||
| } | |||
| void RoundKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) { return; } | |||
| void RoundKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) { return; } | |||
| void RoundKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) { return; } | |||
| void RoundKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) { return; } | |||
| void RoundKernel::StopTimer() const { | |||
| if (stop_timer_cb_) { | |||
| @@ -38,6 +38,7 @@ namespace mindspore { | |||
| namespace fl { | |||
| namespace server { | |||
| namespace kernel { | |||
| constexpr uint64_t kReleaseDuration = 100; | |||
| // RoundKernel contains the main logic of server handling messages from workers. One iteration has multiple round | |||
| // kernels to represent the process. They receive and parse messages from the server communication module. After | |||
| // handling these messages, round kernels allocate response data and send it back. | |||
| @@ -118,7 +118,7 @@ bool StartFLJobKernel::Reset() { | |||
| } | |||
| void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) { | |||
| iter_next_req_timestamp_ = CURRENT_TIME_MILLI.count() + iteration_time_window_; | |||
| iter_next_req_timestamp_ = LongToSize(CURRENT_TIME_MILLI.count()) + iteration_time_window_; | |||
| LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_); | |||
| // The first startFLJob request means a new iteration starts running. | |||
| Iteration::GetInstance().SetIterationRunning(); | |||
| @@ -220,9 +220,9 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, | |||
| schema::FLPlanBuilder fl_plan_builder(*(fbb.get())); | |||
| fl_plan_builder.add_fl_name(fbs_fl_name); | |||
| fl_plan_builder.add_server_mode(fbs_server_mode); | |||
| fl_plan_builder.add_iterations(ps::PSContext::instance()->fl_iteration_num()); | |||
| fl_plan_builder.add_epochs(ps::PSContext::instance()->client_epoch_num()); | |||
| fl_plan_builder.add_mini_batch(ps::PSContext::instance()->client_batch_size()); | |||
| fl_plan_builder.add_iterations(SizeToInt(ps::PSContext::instance()->fl_iteration_num())); | |||
| fl_plan_builder.add_epochs(SizeToInt(ps::PSContext::instance()->client_epoch_num())); | |||
| fl_plan_builder.add_mini_batch(SizeToInt(ps::PSContext::instance()->client_batch_size())); | |||
| fl_plan_builder.add_lr(ps::PSContext::instance()->client_learning_rate()); | |||
| #ifdef ENABLE_ARMOUR | |||
| fl_plan_builder.add_cipher(cipher_public_params); | |||
| @@ -90,7 +90,7 @@ bool UpdateModelKernel::Reset() { | |||
| return true; | |||
| } | |||
| void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) { | |||
| void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) { | |||
| if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kUpdateModel) { | |||
| while (!executor_->IsAllWeightAggregationDone()) { | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(5)); | |||
| @@ -120,7 +120,7 @@ ResultCode UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr | |||
| ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req, | |||
| const std::shared_ptr<FBBuilder> &fbb) { | |||
| RETURN_IF_NULL(update_model_req, ResultCode::kSuccessAndReturn); | |||
| size_t iteration = static_cast<size_t>(update_model_req->iteration()); | |||
| size_t iteration = IntToSize(update_model_req->iteration()); | |||
| if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) { | |||
| std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) + | |||
| ", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num()) + | |||
| @@ -281,16 +281,16 @@ bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr< | |||
| KernelParams aggr_params = {}; | |||
| const std::vector<std::string> &input_names = aggr_kernel->input_names(); | |||
| std::transform(input_names.begin(), input_names.end(), std::back_inserter(aggr_params.inputs), | |||
| [&](const std::string &name) { return memory_register->addresses()[name]; }); | |||
| (void)std::transform(input_names.begin(), input_names.end(), std::back_inserter(aggr_params.inputs), | |||
| [&](const std::string &name) { return memory_register->addresses()[name]; }); | |||
| const std::vector<std::string> &workspace_names = aggr_kernel->workspace_names(); | |||
| std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(aggr_params.workspace), | |||
| [&](const std::string &name) { return memory_register->addresses()[name]; }); | |||
| (void)std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(aggr_params.workspace), | |||
| [&](const std::string &name) { return memory_register->addresses()[name]; }); | |||
| const std::vector<std::string> &output_names = aggr_kernel->output_names(); | |||
| std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs), | |||
| [&](const std::string &name) { return memory_register->addresses()[name]; }); | |||
| (void)std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs), | |||
| [&](const std::string &name) { return memory_register->addresses()[name]; }); | |||
| aggr_kernel->SetParameterAddress(aggr_params.inputs, aggr_params.workspace, aggr_params.outputs); | |||
| aggregation_kernel_parameters_.push_back(std::make_pair(aggr_kernel, aggr_params)); | |||
| @@ -304,16 +304,16 @@ bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<ke | |||
| KernelParams optimizer_params = {}; | |||
| const std::vector<std::string> &input_names = optimizer_kernel->input_names(); | |||
| std::transform(input_names.begin(), input_names.end(), std::back_inserter(optimizer_params.inputs), | |||
| [&](const std::string &name) { return memory_register->addresses()[name]; }); | |||
| (void)std::transform(input_names.begin(), input_names.end(), std::back_inserter(optimizer_params.inputs), | |||
| [&](const std::string &name) { return memory_register->addresses()[name]; }); | |||
| const std::vector<std::string> &workspace_names = optimizer_kernel->workspace_names(); | |||
| std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(optimizer_params.workspace), | |||
| [&](const std::string &name) { return memory_register->addresses()[name]; }); | |||
| (void)std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(optimizer_params.workspace), | |||
| [&](const std::string &name) { return memory_register->addresses()[name]; }); | |||
| const std::vector<std::string> &output_names = optimizer_kernel->output_names(); | |||
| std::transform(output_names.begin(), output_names.end(), std::back_inserter(optimizer_params.outputs), | |||
| [&](const std::string &name) { return memory_register->addresses()[name]; }); | |||
| (void)std::transform(output_names.begin(), output_names.end(), std::back_inserter(optimizer_params.outputs), | |||
| [&](const std::string &name) { return memory_register->addresses()[name]; }); | |||
| optimizer_kernel_parameters_.push_back(std::make_pair(optimizer_kernel, optimizer_params)); | |||
| return true; | |||
| @@ -34,8 +34,8 @@ Round::Round(const std::string &name, bool check_timeout, size_t time_window, bo | |||
| threshold_count_(threshold_count), | |||
| server_num_as_threshold_(server_num_as_threshold) {} | |||
| void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, TimeOutCb timeout_cb, | |||
| FinishIterCb finish_iteration_cb) { | |||
| void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, const TimeOutCb &timeout_cb, | |||
| const FinishIterCb &finish_iteration_cb) { | |||
| MS_EXCEPTION_IF_NULL(communicator); | |||
| communicator_ = communicator; | |||
| @@ -50,7 +50,7 @@ void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &commun | |||
| }; | |||
| // Callback for finalizing the server. This can only be called once. | |||
| finalize_cb_ = [&](void) -> void { communicator_->Stop(); }; | |||
| finalize_cb_ = [&](void) -> void { (void)communicator_->Stop(); }; | |||
| if (check_timeout_) { | |||
| iter_timer_ = std::make_shared<IterationTimer>(); | |||
| @@ -116,7 +116,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m | |||
| if (Server::GetInstance().IsSafeMode()) { | |||
| MS_LOG(WARNING) << "The cluster is still in process of scaling, please retry " << name_ << " later."; | |||
| std::string reason = "The cluster is in safemode."; | |||
| communicator_->SendResponse(reason.c_str(), reason.size(), message); | |||
| (void)communicator_->SendResponse(reason.c_str(), reason.size(), message); | |||
| return; | |||
| } | |||
| @@ -128,10 +128,10 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m | |||
| if (output->size == 0) { | |||
| std::string reason = "The output of the round " + name_ + " is empty."; | |||
| MS_LOG(WARNING) << reason; | |||
| communicator_->SendResponse(reason.c_str(), reason.size(), message); | |||
| (void)communicator_->SendResponse(reason.c_str(), reason.size(), message); | |||
| return; | |||
| } | |||
| communicator_->SendResponse(output->addr, output->size, message); | |||
| (void)communicator_->SendResponse(output->addr, output->size, message); | |||
| kernel_->Release(output); | |||
| // Must send response back no matter what value Launch method returns. | |||
| @@ -142,7 +142,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m | |||
| return; | |||
| } | |||
| void Round::Reset() { kernel_->Reset(); } | |||
| void Round::Reset() { (void)kernel_->Reset(); } | |||
| const std::string &Round::name() const { return name_; } | |||
| @@ -37,8 +37,8 @@ class Round { | |||
| bool check_count = false, size_t threshold_count = 8, bool server_num_as_threshold = false); | |||
| ~Round() = default; | |||
| void Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, TimeOutCb timeout_cb, | |||
| FinishIterCb finish_iteration_cb); | |||
| void Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, const TimeOutCb &timeout_cb, | |||
| const FinishIterCb &finish_iteration_cb); | |||
| // Reinitialize count service and round kernel of this round after scaling operations are done. | |||
| bool ReInitForScaling(uint32_t server_num); | |||
| @@ -102,7 +102,7 @@ void Server::CancelSafeMode() { | |||
| safemode_ = false; | |||
| } | |||
| bool Server::IsSafeMode() { return safemode_.load(); } | |||
| bool Server::IsSafeMode() const { return safemode_.load(); } | |||
| void Server::InitServerContext() { | |||
| ps::PSContext::instance()->GenerateResetterRound(); | |||
| @@ -121,7 +121,7 @@ void Server::InitServerContext() { | |||
| void Server::InitCluster() { | |||
| server_node_ = std::make_shared<ps::core::ServerNode>(); | |||
| MS_EXCEPTION_IF_NULL(server_node_); | |||
| task_executor_ = std::make_shared<ps::core::TaskExecutor>(32); | |||
| task_executor_ = std::make_shared<ps::core::TaskExecutor>(kExecutorThreadPoolSize); | |||
| MS_EXCEPTION_IF_NULL(task_executor_); | |||
| if (!InitCommunicatorWithServer()) { | |||
| MS_LOG(EXCEPTION) << "Initializing cross-server communicator failed."; | |||
| @@ -235,9 +235,9 @@ void Server::InitCipher() { | |||
| #ifdef ENABLE_ARMOUR | |||
| cipher_init_ = &armour::CipherInit::GetInstance(); | |||
| int cipher_t = cipher_reconstruct_secrets_down_cnt_; | |||
| int cipher_t = SizeToInt(cipher_reconstruct_secrets_down_cnt_); | |||
| unsigned char cipher_p[SECRET_MAX_LEN] = {0}; | |||
| int cipher_g = 1; | |||
| const int cipher_g = 1; | |||
| unsigned char cipher_prime[PRIME_MAX_LEN] = {0}; | |||
| float dp_eps = ps::PSContext::instance()->dp_eps(); | |||
| float dp_delta = ps::PSContext::instance()->dp_delta(); | |||
| @@ -304,8 +304,8 @@ void Server::RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpC | |||
| MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed."; | |||
| safemode_ = true; | |||
| std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), | |||
| [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); }); | |||
| communicator_with_server_->Stop(); | |||
| [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { (void)communicator->Stop(); }); | |||
| (void)communicator_with_server_->Stop(); | |||
| }); | |||
| communicator->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [&]() { | |||
| @@ -314,8 +314,8 @@ void Server::RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpC | |||
| "network building phase."; | |||
| safemode_ = true; | |||
| std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), | |||
| [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); }); | |||
| communicator_with_server_->Stop(); | |||
| [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { (void)communicator->Stop(); }); | |||
| (void)communicator_with_server_->Stop(); | |||
| }); | |||
| } | |||
| @@ -363,7 +363,10 @@ void Server::StartCommunicator() { | |||
| } | |||
| MS_LOG(INFO) << "Start communicator with server."; | |||
| communicator_with_server_->Start(); | |||
| if (!communicator_with_server_->Start()) { | |||
| MS_LOG(EXCEPTION) << "Starting communicator with server failed."; | |||
| return; | |||
| } | |||
| DistributedMetadataStore::GetInstance().Initialize(server_node_); | |||
| CollectiveOpsImpl::GetInstance().Initialize(server_node_); | |||
| DistributedCountService::GetInstance().Initialize(server_node_, kLeaderServerRank); | |||
| @@ -371,7 +374,11 @@ void Server::StartCommunicator() { | |||
| MS_LOG(INFO) << "Start communicator with worker."; | |||
| std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), | |||
| [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Start(); }); | |||
| [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { | |||
| if (!communicator->Start()) { | |||
| MS_LOG(EXCEPTION) << "Starting communicator with worker failed."; | |||
| } | |||
| }); | |||
| } | |||
| void Server::ProcessBeforeScalingOut() { | |||
| @@ -405,7 +412,7 @@ void Server::ProcessAfterScalingOut() { | |||
| if (!Executor::GetInstance().ReInitForScaling()) { | |||
| MS_LOG(WARNING) << "Executor reinitializing failed."; | |||
| } | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(1000)); | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking)); | |||
| safemode_ = false; | |||
| } | |||
| @@ -418,7 +425,7 @@ void Server::ProcessAfterScalingIn() { | |||
| if (server_node_->rank_id() == UINT32_MAX) { | |||
| MS_LOG(WARNING) << "This server the one to be scaled in. Server exiting."; | |||
| std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), | |||
| [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); }); | |||
| [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { (void)communicator->Stop(); }); | |||
| communicator_with_server_->Stop(); | |||
| return; | |||
| } | |||
| @@ -439,7 +446,7 @@ void Server::ProcessAfterScalingIn() { | |||
| if (!Executor::GetInstance().ReInitForScaling()) { | |||
| MS_LOG(WARNING) << "Executor reinitializing failed."; | |||
| } | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(1000)); | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking)); | |||
| safemode_ = false; | |||
| } | |||
| } // namespace server | |||
| @@ -33,6 +33,9 @@ | |||
| namespace mindspore { | |||
| namespace fl { | |||
| namespace server { | |||
| // The sleeping time of the server thread before the networking is completed. | |||
| constexpr uint32_t kServerSleepTimeForNetworking = 1000; | |||
| // Class Server is the entrance of MindSpore's parameter server training mode and federated learning. | |||
| class Server { | |||
| public: | |||
| @@ -51,7 +54,7 @@ class Server { | |||
| void SwitchToSafeMode(); | |||
| void CancelSafeMode(); | |||
| bool IsSafeMode(); | |||
| bool IsSafeMode() const; | |||
| private: | |||
| Server() | |||
| @@ -162,8 +165,6 @@ class Server { | |||
| uint32_t server_num_; | |||
| uint32_t worker_num_; | |||
| uint16_t fl_server_port_; | |||
| size_t start_fl_job_cnt_; | |||
| size_t update_model_cnt_; | |||
| size_t cipher_initial_client_cnt_; | |||
| size_t cipher_exchange_secrets_cnt_; | |||
| size_t cipher_share_secrets_cnt_; | |||
| @@ -171,9 +172,6 @@ class Server { | |||
| size_t cipher_reconstruct_secrets_up_cnt_; | |||
| size_t cipher_reconstruct_secrets_down_cnt_; | |||
| uint64_t cipher_time_window_; | |||
| float percent_for_update_model_; | |||
| float percent_for_get_model_; | |||
| }; | |||
| } // namespace server | |||
| } // namespace fl | |||
| @@ -70,8 +70,14 @@ void FLWorker::Run() { | |||
| void FLWorker::Finalize() { | |||
| MS_EXCEPTION_IF_NULL(worker_node_); | |||
| worker_node_->Finish(); | |||
| worker_node_->Stop(); | |||
| if (!worker_node_->Finish()) { | |||
| MS_LOG(ERROR) << "Worker node finishing failed."; | |||
| return; | |||
| } | |||
| if (!worker_node_->Stop()) { | |||
| MS_LOG(ERROR) << "Worker node stopping failed."; | |||
| return; | |||
| } | |||
| } | |||
| bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command, | |||
| @@ -201,8 +207,8 @@ void FLWorker::ProcessAfterScalingOut() { | |||
| } | |||
| MS_LOG(INFO) << "Cluster scaling out completed. Reinitialize for worker."; | |||
| server_num_ = worker_node_->server_num(); | |||
| worker_num_ = worker_node_->worker_num(); | |||
| server_num_ = IntToUint(worker_node_->server_num()); | |||
| worker_num_ = IntToUint(worker_node_->worker_num()); | |||
| MS_LOG(INFO) << "After scheduler scaling out, worker number is " << worker_num_ << ", server number is " | |||
| << server_num_ << ". Exit safemode."; | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerSleepTimeForNetworking)); | |||
| @@ -215,8 +221,8 @@ void FLWorker::ProcessAfterScalingIn() { | |||
| } | |||
| MS_LOG(INFO) << "Cluster scaling in completed. Reinitialize for worker."; | |||
| server_num_ = worker_node_->server_num(); | |||
| worker_num_ = worker_node_->worker_num(); | |||
| server_num_ = IntToUint(worker_node_->server_num()); | |||
| worker_num_ = IntToUint(worker_node_->worker_num()); | |||
| MS_LOG(INFO) << "After scheduler scaling in, worker number is " << worker_num_ << ", server number is " << server_num_ | |||
| << ". Exit safemode."; | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerSleepTimeForNetworking)); | |||
| @@ -25,7 +25,7 @@ | |||
| #include "frontend/parallel/device_matrix.h" | |||
| #include "frontend/parallel/graph_util/generate_graph.h" | |||
| #include "frontend/parallel/context.h" | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #include "utils/ms_context.h" | |||
| #endif | |||
| @@ -160,7 +160,7 @@ Status GatherPInfo::GetAttrs() { | |||
| if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) { | |||
| dynamic_shape_indices_ = true; | |||
| } | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); | |||
| bool enable_sparse = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_SPARSE); | |||
| if (ps::PsDataPrefetch::GetInstance().cache_enable() && enable_sparse) { | |||
| @@ -637,7 +637,7 @@ Status GatherPInfo::InferBias() { | |||
| rank = rank % (params_strategy[0] * params_strategy[1]); | |||
| } | |||
| } | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||
| bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound()); | |||
| return SUCCESS; | |||
| @@ -28,7 +28,7 @@ | |||
| #include "frontend/parallel/strategy.h" | |||
| #include "frontend/parallel/context.h" | |||
| #include "frontend/parallel/tensor_layout/tensor_redistribution.h" | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #endif | |||
| @@ -119,7 +119,7 @@ std::vector<StrategyPtr> UniqueInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| return sp_vector; | |||
| } | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||
| GenerateGraph gen_g = GenerateGraph(attrs_); | |||
| if (gen_g.Init(cnode) != SUCCESS) { | |||
| @@ -156,7 +156,7 @@ Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||
| #endif | |||
| ReplaceGraphPtr UniqueInfo::replace_graph(const CNodePtr &cnode) { | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||
| auto inputs = cnode->inputs(); | |||
| if (inputs.empty()) { | |||
| @@ -47,7 +47,7 @@ | |||
| #include "utils/ms_context.h" | |||
| #include "utils/symbolic.h" | |||
| #include "mindspore/core/utils/parallel_node_check.h" | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| #include "ps/util.h" | |||
| #include "ps/ps_context.h" | |||
| #endif | |||
| @@ -44,7 +44,7 @@ | |||
| #include "vm/transform.h" | |||
| #include "parse/python_adapter.h" | |||
| #include "frontend/optimizer/py_pass_manager.h" | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| #include "ps/parameter_server.h" | |||
| #include "ps/scheduler.h" | |||
| #include "ps/worker.h" | |||
| @@ -478,7 +478,7 @@ bool OptInlineAction(const ResourcePtr &res) { | |||
| bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); } | |||
| bool VmOptimizeAction(const ResourcePtr &res) { | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| if (ps::PSContext::instance()->is_ps_mode()) { | |||
| kVmPasses.push_back({"server_communication_op_fusion", ps::Util::FuseServerCommOps}); | |||
| } | |||
| @@ -633,8 +633,8 @@ bool ExecuteAction(const ResourcePtr &res) { | |||
| return true; | |||
| } | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| bool StartPSWorkerAction(const ResourcePtr &res) { | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| bool StartPSWorkerAction(const ResourcePtr &) { | |||
| ps::Worker::GetInstance().Run(); | |||
| return true; | |||
| } | |||
| @@ -695,7 +695,7 @@ bool StartServerAction(const ResourcePtr &res) { | |||
| return true; | |||
| } | |||
| bool StartPSSchedulerAction(const ResourcePtr &res) { | |||
| bool StartPSSchedulerAction(const ResourcePtr &) { | |||
| ps::Scheduler::GetInstance().Run(); | |||
| return true; | |||
| } | |||
| @@ -861,7 +861,7 @@ std::vector<ActionItem> VmPipeline() { | |||
| actions.emplace_back(std::make_pair("remove_monad_from_random_op", RemoveRandomOpMonadAction)); | |||
| actions.emplace_back(std::make_pair("validate", ValidateAction)); | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| if (ps::PSContext::instance()->is_worker()) { | |||
| std::string server_mode = ps::PSContext::instance()->server_mode(); | |||
| if (server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) { | |||
| @@ -889,7 +889,7 @@ std::vector<ActionItem> BackendPipeline() { | |||
| return actions; | |||
| } | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| std::vector<ActionItem> ServerPipeline() { | |||
| auto actions = CommonPipeline(); | |||
| actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); | |||
| @@ -34,7 +34,7 @@ | |||
| #else | |||
| #include "runtime/device/gpu/distribution/collective_fake_init.h" | |||
| #endif | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| #include "ps/util.h" | |||
| #endif | |||
| #include "ps/ps_context.h" | |||
| @@ -47,7 +47,7 @@ | |||
| #include "frontend/optimizer/irpass/gradient_eliminate.h" | |||
| #include "frontend/optimizer/irpass/parameter_eliminate.h" | |||
| #include "frontend/optimizer/irpass/updatestate_eliminate.h" | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| #include "ps/util.h" | |||
| #include "ps/ps_context.h" | |||
| #endif | |||
| @@ -211,7 +211,7 @@ namespace { | |||
| bool ReAutoMonadWrapper(const FuncGraphPtr &root, const opt::OptimizerPtr &) { return ReAutoMonad(root); } | |||
| bool parallel_mode() { | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) { | |||
| return false; | |||
| } | |||
| @@ -556,7 +556,7 @@ bool AddRecomputationPass(const ResourcePtr &res) { | |||
| } | |||
| bool AddCacheEmbeddingPass(const ResourcePtr &res) { | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| if (ps::PSContext::instance()->is_ps_mode()) { | |||
| return true; | |||
| } | |||
| @@ -148,7 +148,7 @@ std::shared_ptr<CommunicatorBase> ServerNode::GetOrCreateHttpComm(const std::str | |||
| } | |||
| std::shared_ptr<CommunicatorBase> ServerNode::GetOrCreateTcpComm(const std::string &scheduler_ip, | |||
| std::int16_t scheduler_port, uint32_t worker_num, | |||
| uint16_t scheduler_port, uint32_t worker_num, | |||
| uint32_t server_num, | |||
| const std::shared_ptr<TaskExecutor> &task_executor) { | |||
| std::lock_guard<std::mutex> lock(communicator_mutex_); | |||
| @@ -61,7 +61,7 @@ class ServerNode : public AbstractNode { | |||
| std::shared_ptr<CommunicatorBase> GetOrCreateHttpComm(const std::string &ip, uint16_t port, | |||
| const std::shared_ptr<TaskExecutor> &task_executor); | |||
| std::shared_ptr<CommunicatorBase> GetOrCreateTcpComm(const std::string &scheduler_ip, std::int16_t scheduler_port, | |||
| std::shared_ptr<CommunicatorBase> GetOrCreateTcpComm(const std::string &scheduler_ip, uint16_t scheduler_port, | |||
| uint32_t worker_num, uint32_t server_num, | |||
| const std::shared_ptr<TaskExecutor> &task_executor); | |||
| @@ -18,7 +18,7 @@ | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/ms_utils.h" | |||
| #include "backend/kernel_compiler/kernel.h" | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||
| #endif | |||
| @@ -63,7 +63,7 @@ void PSContext::SetPSEnable(bool enabled) { | |||
| } | |||
| bool PSContext::is_ps_mode() const { | |||
| if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) { | |||
| if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) { | |||
| return true; | |||
| } | |||
| return ps_enabled_; | |||
| @@ -74,7 +74,7 @@ void PSContext::Reset() { | |||
| is_worker_ = false; | |||
| is_pserver_ = false; | |||
| is_sched_ = false; | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||
| ps_cache_instance.Finalize(); | |||
| set_cache_enable(false); | |||
| @@ -83,7 +83,7 @@ void PSContext::Reset() { | |||
| } | |||
| std::string PSContext::ms_role() const { | |||
| if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) { | |||
| if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) { | |||
| return role_; | |||
| } | |||
| if (is_worker_) { | |||
| @@ -98,21 +98,21 @@ std::string PSContext::ms_role() const { | |||
| } | |||
| bool PSContext::is_worker() const { | |||
| if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) { | |||
| if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) { | |||
| return role_ == kEnvRoleOfWorker; | |||
| } | |||
| return is_worker_; | |||
| } | |||
| bool PSContext::is_server() const { | |||
| if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) { | |||
| if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) { | |||
| return role_ == kEnvRoleOfServer; | |||
| } | |||
| return is_pserver_; | |||
| } | |||
| bool PSContext::is_scheduler() const { | |||
| if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) { | |||
| if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) { | |||
| return role_ == kEnvRoleOfScheduler; | |||
| } | |||
| return is_sched_; | |||
| @@ -130,44 +130,44 @@ uint32_t PSContext::ps_rank_id() const { return rank_id_; } | |||
| void PSContext::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size, | |||
| size_t vocab_size) const { | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size); | |||
| #endif | |||
| } | |||
| void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name, | |||
| size_t cache_vocab_size, size_t embedding_size) const { | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size); | |||
| #endif | |||
| } | |||
| void PSContext::InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed) const { | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed); | |||
| #endif | |||
| } | |||
| void PSContext::InsertAccumuInitInfo(const std::string &param_name, float init_val) const { | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| ps_cache_instance.InsertAccumuInitInfo(param_name, init_val); | |||
| #endif | |||
| } | |||
| void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const { | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| ps_cache_instance.CloneHashTable(dest_param_name, src_param_name); | |||
| #endif | |||
| } | |||
| void PSContext::set_cache_enable(bool cache_enable) const { | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| PsDataPrefetch::GetInstance().set_cache_enable(cache_enable); | |||
| #endif | |||
| } | |||
| void PSContext::set_rank_id(uint32_t rank_id) const { | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| ps_cache_instance.set_rank_id(rank_id); | |||
| #endif | |||
| } | |||
| @@ -358,7 +358,7 @@ void PSContext::set_cipher_time_window(uint64_t cipher_time_window) { | |||
| uint64_t PSContext::cipher_time_window() const { return cipher_time_window_; } | |||
| void PSContext::set_reconstruct_secrets_threshold(uint64_t reconstruct_secrets_threshold) { | |||
| if (reconstruct_secrets_threshold <= 0) { | |||
| if (reconstruct_secrets_threshold == 0) { | |||
| MS_LOG(EXCEPTION) << "reconstruct_secrets_threshold should be positive."; | |||
| return; | |||
| } | |||
| @@ -136,7 +136,8 @@ bool Util::FuseServerCommOps(const pipeline::ResourcePtr &res) { | |||
| return true; | |||
| } | |||
| void Util::DoFusion(FuncGraphPtr func_graph, const std::string &cnode_name, const std::string &fused_cnode_name) { | |||
| void Util::DoFusion(const FuncGraphPtr &func_graph, const std::string &cnode_name, | |||
| const std::string &fused_cnode_name) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return()); | |||
| @@ -56,7 +56,8 @@ class Util { | |||
| static bool FuseServerCommOps(const pipeline::ResourcePtr &res); | |||
| private: | |||
| static void DoFusion(FuncGraphPtr func_graph, const std::string &cnode_name, const std::string &fused_cnode_name); | |||
| static void DoFusion(const FuncGraphPtr &func_graph, const std::string &cnode_name, | |||
| const std::string &fused_cnode_name); | |||
| static kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list); | |||
| static std::unordered_map<std::string, int64_t> optimizer_to_ids; | |||
| @@ -32,7 +32,7 @@ | |||
| #include "utils/utils.h" | |||
| #include "frontend/parallel/context.h" | |||
| #include "debug/env_config_parser.h" | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #endif | |||
| @@ -333,7 +333,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||
| } | |||
| add_need_alloc_nodes(input_node); | |||
| } | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| bool ps_cache_check = false; | |||
| #endif | |||
| for (auto &item : need_alloc_nodes) { | |||
| @@ -346,7 +346,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||
| continue; | |||
| } | |||
| DeviceAddressPtr device_address = nullptr; | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| const std::string &param_name = item->fullname_with_scope(); | |||
| if (ps::ps_cache_instance.IsHashTable(param_name)) { | |||
| MS_LOG(INFO) << "Parameter(" << param_name << ")" | |||
| @@ -1087,7 +1087,7 @@ void KernelRuntime::ClearOutputAddress(const std::vector<AnfNodePtr> &inputs, | |||
| } | |||
| } | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, | |||
| AnfNodePtr *const first_cache_input_index, | |||
| size_t *const first_cache_size) { | |||
| @@ -817,6 +817,7 @@ def reset_ps_context(): | |||
| """ | |||
| _reset_ps_context() | |||
| def set_fl_context(**kwargs): | |||
| """ | |||
| Set federated learning training mode context. | |||
| @@ -726,6 +726,7 @@ class Pull(PrimitiveWithInfer): | |||
| def infer_dtype(self, key_dtype, weight_dtype): | |||
| return mstype.float32 | |||
| class PullWeight(PrimitiveWithInfer): | |||
| """ | |||
| Pull weight by its names from server. | |||
| @@ -751,6 +752,7 @@ class PullWeight(PrimitiveWithInfer): | |||
| def infer_dtype(self, weight, name, index): | |||
| return mstype.float32 | |||
| class PushWeight(PrimitiveWithInfer): | |||
| """ | |||
| Upload weight by its names to server. | |||
| @@ -776,6 +778,7 @@ class PushWeight(PrimitiveWithInfer): | |||
| def infer_dtype(self, weight, ps_key, index): | |||
| return mstype.float32 | |||
| class identity(Primitive): | |||
| """ | |||
| Makes a identify primitive, used for pynative mode. | |||
| @@ -31,6 +31,7 @@ _check_positive_float_keys = ["update_model_ratio", "client_learning_rate"] | |||
| _check_port_keys = ["scheduler_port", "fl_server_port", "scheduler_manage_port"] | |||
| def ps_context(): | |||
| """ | |||
| Get the global _ps_context, if it is not created, create a new one. | |||
| @@ -226,6 +227,7 @@ def _set_cache_enable(cache_enable): | |||
| def _set_rank_id(rank_id): | |||
| ps_context().set_rank_id(rank_id) | |||
| def _check_value(key, value): | |||
| """ | |||
| Validate the value for parameter server context keys. | |||