| @@ -29,6 +29,8 @@ void CollectiveOpsImpl::Initialize(const std::shared_ptr<ps::core::ServerNode> & | |||
| template <typename T> | |||
| bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size_t count) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false); | |||
| int ret = memcpy_s(recvbuff, count * sizeof(T), sendbuff, count * sizeof(T)); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||
| @@ -135,6 +137,12 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size | |||
| template <typename T> | |||
| bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false); | |||
| if (sendbuff == nullptr || recvbuff == nullptr) { | |||
| MS_LOG(ERROR) << "Input sendbuff or recvbuff for ReduceBroadcastAllReduce is nullptr."; | |||
| return false; | |||
| } | |||
| uint32_t rank_size = server_num_; | |||
| MS_LOG(DEBUG) << "Reduce Broadcast AllReduce rank_size:" << rank_size << ", local_rank_:" << local_rank_ | |||
| << ", count:" << count; | |||
| @@ -201,12 +201,6 @@ constexpr auto kCtxCipherPrimer = "cipher_primer"; | |||
| // Expands to the time elapsed since the system_clock epoch, as a std::chrono::milliseconds duration. | |||
| #define CURRENT_TIME_MILLI \ | |||
| std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()) | |||
| // Logs an error naming `expr` and returns `ret` from the enclosing function when `expr` compares equal to nullptr. | |||
| // NOTE(review): `expr` is not parenthesized and the body is a bare `if` block rather than do { ... } while (0), | |||
| // so compound arguments or a following `else` may not behave as intended. | |||
| #define RETURN_IF_NULL(expr, ret) \ | |||
| if (expr == nullptr) { \ | |||
| MS_LOG(ERROR) << #expr << " is nullptr."; \ | |||
| return ret; \ | |||
| } | |||
| // This function returns the size in bytes of the given TypeId. | |||
| inline size_t GetTypeIdByte(const TypeId &type) { | |||
| switch (type) { | |||
| @@ -224,12 +218,12 @@ inline size_t GetTypeIdByte(const TypeId &type) { | |||
| } | |||
| inline AddressPtr GenerateParameterNodeAddrPtr(const CNodePtr &kernel_node, size_t param_idx) { | |||
| RETURN_IF_NULL(kernel_node, nullptr); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(kernel_node, nullptr); | |||
| auto param_node = | |||
| AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, param_idx), 0).first->cast<ParameterPtr>(); | |||
| RETURN_IF_NULL(param_node, nullptr); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(param_node, nullptr); | |||
| auto param_tensor = param_node->default_param()->cast<tensor::TensorPtr>(); | |||
| RETURN_IF_NULL(param_tensor, nullptr); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(param_tensor, nullptr); | |||
| AddressPtr addr = std::make_shared<kernel::Address>(); | |||
| addr->addr = param_tensor->data_c(); | |||
| addr->size = param_tensor->data().nbytes(); | |||
| @@ -24,8 +24,8 @@ namespace fl { | |||
| namespace server { | |||
| void DistributedCountService::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node, | |||
| uint32_t counting_server_rank) { | |||
| MS_EXCEPTION_IF_NULL(server_node); | |||
| server_node_ = server_node; | |||
| MS_EXCEPTION_IF_NULL(server_node_); | |||
| local_rank_ = server_node_->rank_id(); | |||
| server_num_ = ps::PSContext::instance()->initial_server_num(); | |||
| counting_server_rank_ = counting_server_rank; | |||
| @@ -33,8 +33,8 @@ void DistributedCountService::Initialize(const std::shared_ptr<ps::core::ServerN | |||
| } | |||
| void DistributedCountService::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) { | |||
| MS_EXCEPTION_IF_NULL(communicator); | |||
| communicator_ = communicator; | |||
| MS_EXCEPTION_IF_NULL(communicator_); | |||
| communicator_->RegisterMsgCallBack( | |||
| "count", std::bind(&DistributedCountService::HandleCountRequest, this, std::placeholders::_1)); | |||
| communicator_->RegisterMsgCallBack( | |||
| @@ -107,6 +107,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string & | |||
| (void)count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size())); | |||
| if (!count_rsp.result()) { | |||
| MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason(); | |||
| // If the error is caused by the network issue, return the reason. | |||
| if (reason != nullptr && count_rsp.reason().find(kNetworkError) != std::string::npos) { | |||
| *reason = kNetworkError; | |||
| } | |||
| @@ -23,8 +23,8 @@ namespace mindspore { | |||
| namespace fl { | |||
| namespace server { | |||
| void DistributedMetadataStore::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node) { | |||
| server_node_ = server_node; | |||
| MS_EXCEPTION_IF_NULL(server_node); | |||
| server_node_ = server_node; | |||
| local_rank_ = server_node_->rank_id(); | |||
| server_num_ = ps::PSContext::instance()->initial_server_num(); | |||
| InitHashRing(); | |||
| @@ -32,8 +32,8 @@ void DistributedMetadataStore::Initialize(const std::shared_ptr<ps::core::Server | |||
| } | |||
| void DistributedMetadataStore::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) { | |||
| MS_EXCEPTION_IF_NULL(communicator); | |||
| communicator_ = communicator; | |||
| MS_EXCEPTION_IF_NULL(communicator_); | |||
| communicator_->RegisterMsgCallBack( | |||
| "updateMetadata", std::bind(&DistributedMetadataStore::HandleUpdateMetadataRequest, this, std::placeholders::_1)); | |||
| communicator_->RegisterMsgCallBack( | |||
| @@ -109,6 +109,7 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM | |||
| return false; | |||
| } | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_meta_rsp_msg, false); | |||
| std::string update_meta_rsp = | |||
| std::string(reinterpret_cast<char *>(update_meta_rsp_msg->data()), update_meta_rsp_msg->size()); | |||
| if (update_meta_rsp != kSuccess) { | |||
| @@ -141,6 +142,8 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) { | |||
| MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed."; | |||
| return get_metadata_rsp; | |||
| } | |||
| MS_ERROR_IF_NULL_W_RET_VAL(get_meta_rsp_msg, get_metadata_rsp); | |||
| (void)get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), SizeToInt(get_meta_rsp_msg->size())); | |||
| return get_metadata_rsp; | |||
| } | |||
| @@ -63,7 +63,7 @@ bool Executor::HandlePush(const std::string ¶m_name, const UploadData &uploa | |||
| std::mutex &mtx = parameter_mutex_[param_name]; | |||
| std::unique_lock<std::mutex> lock(mtx); | |||
| auto ¶m_aggr = param_aggrs_[param_name]; | |||
| MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false); | |||
| // Push operation needs to wait until the pulling process is done. | |||
| while (!param_aggr->IsPullingDone()) { | |||
| lock.unlock(); | |||
| @@ -106,7 +106,7 @@ bool Executor::HandleModelUpdate(const std::string ¶m_name, const UploadData | |||
| std::mutex &mtx = parameter_mutex_[param_name]; | |||
| std::unique_lock<std::mutex> lock(mtx); | |||
| auto ¶m_aggr = param_aggrs_[param_name]; | |||
| MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false); | |||
| if (!param_aggr->UpdateData(upload_data)) { | |||
| MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed."; | |||
| return false; | |||
| @@ -131,7 +131,7 @@ bool Executor::HandleModelUpdateAsync(const std::map<std::string, UploadData> &f | |||
| std::mutex &mtx = parameter_mutex_[param_name]; | |||
| std::unique_lock<std::mutex> lock(mtx); | |||
| auto ¶m_aggr = param_aggrs_[param_name]; | |||
| MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false); | |||
| const UploadData &upload_data = trainable_param.second; | |||
| if (!param_aggr->UpdateData(upload_data)) { | |||
| MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed."; | |||
| @@ -156,7 +156,7 @@ bool Executor::HandlePushWeight(const std::map<std::string, Address> &feature_ma | |||
| std::mutex &mtx = parameter_mutex_[param_name]; | |||
| std::unique_lock<std::mutex> lock(mtx); | |||
| auto ¶m_aggr = param_aggrs_[param_name]; | |||
| MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false); | |||
| AddressPtr old_weight = param_aggr->GetWeight(); | |||
| if (old_weight == nullptr) { | |||
| MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr."; | |||
| @@ -188,7 +188,7 @@ AddressPtr Executor::HandlePull(const std::string ¶m_name) { | |||
| std::mutex &mtx = parameter_mutex_[param_name]; | |||
| std::unique_lock<std::mutex> lock(mtx); | |||
| auto ¶m_aggr = param_aggrs_[param_name]; | |||
| MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, nullptr); | |||
| // Pulling must wait until the optimizing process is done. | |||
| while (!param_aggr->IsOptimizingDone()) { | |||
| lock.unlock(); | |||
| @@ -214,7 +214,7 @@ std::map<std::string, AddressPtr> Executor::HandlePullWeight(const std::vector<s | |||
| std::mutex &mtx = parameter_mutex_[param_name]; | |||
| std::unique_lock<std::mutex> lock(mtx); | |||
| const auto ¶m_aggr = param_aggrs_[param_name]; | |||
| MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, weights); | |||
| AddressPtr addr = param_aggr->GetWeight(); | |||
| if (addr == nullptr) { | |||
| MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr."; | |||
| @@ -236,7 +236,9 @@ bool Executor::IsWeightAggrDone(const std::vector<std::string> ¶m_names) { | |||
| std::mutex &mtx = parameter_mutex_[name]; | |||
| std::unique_lock<std::mutex> lock(mtx); | |||
| if (!param_aggrs_[name]->IsAggregationDone()) { | |||
| auto ¶m_aggr = param_aggrs_[name]; | |||
| MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false); | |||
| if (!param_aggr->IsAggregationDone()) { | |||
| MS_LOG(DEBUG) << "Update model for " << name << " is not done yet."; | |||
| return false; | |||
| } | |||
| @@ -248,7 +250,9 @@ void Executor::ResetAggregationStatus() { | |||
| for (const auto ¶m_name : param_names_) { | |||
| std::mutex &mtx = parameter_mutex_[param_name]; | |||
| std::unique_lock<std::mutex> lock(mtx); | |||
| param_aggrs_[param_name]->ResetAggregationStatus(); | |||
| auto ¶m_aggr = param_aggrs_[param_name]; | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(param_aggr); | |||
| param_aggr->ResetAggregationStatus(); | |||
| } | |||
| return; | |||
| } | |||
| @@ -90,6 +90,7 @@ void Iteration::MoveToNextIteration(bool is_last_iter_valid, const std::string & | |||
| return; | |||
| } | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(server_node_); | |||
| if (server_node_->rank_id() == kLeaderServerRank) { | |||
| if (!BroadcastPrepareForNextIterRequest(is_last_iter_valid, reason)) { | |||
| MS_LOG(ERROR) << "Broadcast prepare for next iteration request failed."; | |||
| @@ -114,10 +115,7 @@ void Iteration::MoveToNextIteration(bool is_last_iter_valid, const std::string & | |||
| void Iteration::SetIterationRunning() { | |||
| MS_LOG(INFO) << "Iteration " << iteration_num_ << " start running."; | |||
| if (server_node_ == nullptr) { | |||
| MS_LOG(ERROR) << "Server node is empty."; | |||
| return; | |||
| } | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(server_node_); | |||
| if (server_node_->rank_id() == kLeaderServerRank) { | |||
| // This event helps worker/server to be consistent in iteration state. | |||
| server_node_->BroadcastEvent(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning)); | |||
| @@ -127,10 +125,7 @@ void Iteration::SetIterationRunning() { | |||
| void Iteration::SetIterationCompleted() { | |||
| MS_LOG(INFO) << "Iteration " << iteration_num_ << " completes."; | |||
| if (server_node_ == nullptr) { | |||
| MS_LOG(ERROR) << "Server node is empty."; | |||
| return; | |||
| } | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(server_node_); | |||
| if (server_node_->rank_id() == kLeaderServerRank) { | |||
| // This event helps worker/server to be consistent in iteration state. | |||
| server_node_->BroadcastEvent(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted)); | |||
| @@ -167,6 +162,7 @@ const std::vector<std::shared_ptr<Round>> &Iteration::rounds() const { return ro | |||
| // Returns the value of the is_last_iteration_valid_ member. | |||
| bool Iteration::is_last_iteration_valid() const { return is_last_iteration_valid_; } | |||
| bool Iteration::SyncIteration(uint32_t rank) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false); | |||
| SyncIterationRequest sync_iter_req; | |||
| sync_iter_req.set_rank(rank); | |||
| @@ -190,10 +186,8 @@ bool Iteration::SyncIteration(uint32_t rank) { | |||
| } | |||
| void Iteration::HandleSyncIterationRequest(const std::shared_ptr<ps::core::MessageHandler> &message) { | |||
| if (message == nullptr) { | |||
| MS_LOG(ERROR) << "Message is nullptr."; | |||
| return; | |||
| } | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(message); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(communicator_); | |||
| SyncIterationRequest sync_iter_req; | |||
| (void)sync_iter_req.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| @@ -220,6 +214,7 @@ bool Iteration::IsMoveToNextIterRequestReentrant(uint64_t iteration_num) { | |||
| } | |||
| bool Iteration::NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false); | |||
| MS_LOG(INFO) << "Notify leader server to control the cluster to proceed to next iteration."; | |||
| NotifyLeaderMoveToNextIterRequest notify_leader_to_next_iter_req; | |||
| notify_leader_to_next_iter_req.set_rank(server_node_->rank_id()); | |||
| @@ -235,10 +230,8 @@ bool Iteration::NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const s | |||
| } | |||
| void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) { | |||
| if (message == nullptr) { | |||
| return; | |||
| } | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(message); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(communicator_); | |||
| NotifyLeaderMoveToNextIterResponse notify_leader_to_next_iter_rsp; | |||
| notify_leader_to_next_iter_rsp.set_result("success"); | |||
| if (!communicator_->SendResponse(notify_leader_to_next_iter_rsp.SerializeAsString().data(), | |||
| @@ -275,8 +268,8 @@ void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps | |||
| } | |||
| bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false); | |||
| PrepareForNextIter(); | |||
| MS_LOG(INFO) << "Notify all follower servers to prepare for next iteration."; | |||
| PrepareForNextIterRequest prepare_next_iter_req; | |||
| prepare_next_iter_req.set_is_last_iter_valid(is_last_iter_valid); | |||
| @@ -307,10 +300,8 @@ bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, cons | |||
| } | |||
| void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) { | |||
| if (message == nullptr) { | |||
| return; | |||
| } | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(message); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(communicator_); | |||
| PrepareForNextIterRequest prepare_next_iter_req; | |||
| (void)prepare_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| const auto &reason = prepare_next_iter_req.reason(); | |||
| @@ -332,6 +323,7 @@ void Iteration::PrepareForNextIter() { | |||
| } | |||
| bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false); | |||
| MS_LOG(INFO) << "Notify all follower servers to proceed to next iteration. Set last iteration number " | |||
| << iteration_num_; | |||
| MoveToNextIterRequest proceed_to_next_iter_req; | |||
| @@ -350,10 +342,8 @@ bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const st | |||
| } | |||
| void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) { | |||
| if (message == nullptr) { | |||
| return; | |||
| } | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(message); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(communicator_); | |||
| MoveToNextIterResponse proceed_to_next_iter_rsp; | |||
| proceed_to_next_iter_rsp.set_result("success"); | |||
| if (!communicator_->SendResponse(proceed_to_next_iter_rsp.SerializeAsString().data(), | |||
| @@ -397,6 +387,7 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) { | |||
| } | |||
| bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false); | |||
| MS_LOG(INFO) << "Notify all follower servers to end last iteration."; | |||
| EndLastIterRequest end_last_iter_req; | |||
| end_last_iter_req.set_last_iter_num(last_iter_num); | |||
| @@ -412,10 +403,8 @@ bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) { | |||
| } | |||
| void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) { | |||
| if (message == nullptr) { | |||
| return; | |||
| } | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(message); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(communicator_); | |||
| EndLastIterRequest end_last_iter_req; | |||
| (void)end_last_iter_req.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| const auto &last_iter_num = end_last_iter_req.last_iter_num(); | |||
| @@ -456,7 +445,9 @@ void Iteration::EndLastIter() { | |||
| ModelStore::GetInstance().Reset(); | |||
| } | |||
| std::unique_lock<std::mutex> lock(pinned_mtx_); | |||
| pinned_iter_num_ = 0; | |||
| lock.unlock(); | |||
| LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); | |||
| Server::GetInstance().CancelSafeMode(); | |||
| SetIterationCompleted(); | |||
| @@ -22,6 +22,7 @@ namespace fl { | |||
| namespace server { | |||
| namespace kernel { | |||
| bool AggregationKernelFactory::Matched(const ParamsInfo ¶ms_info, const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| if (kNameToIdxMap.count(cnode_name) == 0) { | |||
| MS_LOG(ERROR) << "Can't find index info for kernel " << cnode_name; | |||
| @@ -36,6 +36,7 @@ class ApplyMomentumKernel : public ApplyMomentumCPUKernel, public OptimizerKerne | |||
| ~ApplyMomentumKernel() override = default; | |||
| void InitKernel(const CNodePtr &cnode) override { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| ApplyMomentumCPUKernel::InitKernel(cnode); | |||
| InitServerKernelInputOutputSize(cnode); | |||
| GenerateReuseKernelNodeInfo(); | |||
| @@ -29,6 +29,7 @@ namespace mindspore { | |||
| namespace fl { | |||
| namespace server { | |||
| namespace kernel { | |||
| constexpr size_t kDenseGradAccumKernelInputsNum = 2; | |||
| template <typename T> | |||
| class DenseGradAccumKernel : public AggregationKernel { | |||
| public: | |||
| @@ -53,8 +54,15 @@ class DenseGradAccumKernel : public AggregationKernel { | |||
| return; | |||
| } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override { | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &) override { | |||
| if (inputs.size() != kDenseGradAccumKernelInputsNum) { | |||
| MS_LOG(ERROR) << "The inputs number of DenseGradAccumKernel should be 2, but got " << inputs.size(); | |||
| return false; | |||
| } | |||
| MS_ERROR_IF_NULL_W_RET_VAL(inputs[0]->addr, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(inputs[1]->addr, false); | |||
| if (accum_count_ == 0) { | |||
| int ret = memset_s(inputs[0]->addr, inputs[0]->size, 0x00, inputs[0]->size); | |||
| if (ret != 0) { | |||
| @@ -34,6 +34,7 @@ namespace mindspore { | |||
| namespace fl { | |||
| namespace server { | |||
| namespace kernel { | |||
| constexpr size_t kFedAvgInputsNum = 4; | |||
| // The implementation of federated averaging. We compute a weighted average of the weights. The weights uploaded by | |||
| // FL-clients are already multiplied by their data sizes, so only the sum and division are done in this kernel. | |||
| @@ -106,6 +107,15 @@ class FedAvgKernel : public AggregationKernel { | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override { | |||
| if (inputs.size() != kFedAvgInputsNum) { | |||
| MS_LOG(ERROR) << "The inputs number of FedAvgKernel should be 4, but got " << inputs.size(); | |||
| return false; | |||
| } | |||
| MS_ERROR_IF_NULL_W_RET_VAL(inputs[0]->addr, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(inputs[1]->addr, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(inputs[2]->addr, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(inputs[3]->addr, false); | |||
| std::unique_lock<std::mutex> lock(weight_mutex_); | |||
| // The weight and new_weight values should be multiplied by clients already, so we don't need to do multiplication | |||
| // again. | |||
| @@ -160,12 +170,16 @@ class FedAvgKernel : public AggregationKernel { | |||
| // Logs that the weight is reused and records the (kWeight, cnode_weight_idx_) pair in reuse_kernel_node_inputs_info_. | |||
| void GenerateReuseKernelNodeInfo() override { | |||
| MS_LOG(INFO) << "FedAvg reuse 'weight' of the kernel node."; | |||
| // Only the trainable parameter is reused for federated average. | |||
| // NOTE(review): the next two rows look like the before/after forms of one statement in a diff; the second only adds | |||
| // a (void) cast that discards the result of insert() — confirm which one the file actually contains. | |||
| reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, cnode_weight_idx_)); | |||
| (void)reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, cnode_weight_idx_)); | |||
| return; | |||
| } | |||
| // In some cases, the Launch method is not called and the weights involved in AllReduce should be set to 0. | |||
| void ClearWeightAndDataSize() { | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(weight_addr_); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(data_size_addr_); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(weight_addr_->addr); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(data_size_addr_->addr); | |||
| int ret = memset_s(weight_addr_->addr, weight_addr_->size, 0x00, weight_addr_->size); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memset_s error, errorno(" << ret << ")"; | |||
| @@ -22,6 +22,7 @@ namespace fl { | |||
| namespace server { | |||
| namespace kernel { | |||
| bool OptimizerKernelFactory::Matched(const ParamsInfo ¶ms_info, const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| if (kNameToIdxMap.count(cnode_name) == 0) { | |||
| MS_LOG(ERROR) << "Can't find index info for kernel " << cnode_name; | |||
| @@ -61,7 +62,6 @@ bool OptimizerKernelFactory::Matched(const ParamsInfo ¶ms_info, const CNodeP | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| @@ -41,11 +41,30 @@ void GetModelKernel::InitKernel(size_t) { | |||
| bool GetModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| if (inputs.size() != 1 || outputs.size() != 1) { | |||
| std::string reason = "inputs or outputs size is invalid."; | |||
| MS_LOG(ERROR) << reason; | |||
| GenerateOutput(outputs, reason.c_str(), reason.size()); | |||
| return true; | |||
| } | |||
| void *req_data = inputs[0]->addr; | |||
| std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>(); | |||
| if (fbb == nullptr || req_data == nullptr) { | |||
| MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr."; | |||
| return false; | |||
| std::string reason = "FBBuilder builder or req_data is nullptr."; | |||
| MS_LOG(ERROR) << reason; | |||
| GenerateOutput(outputs, reason.c_str(), reason.size()); | |||
| return true; | |||
| } | |||
| flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size); | |||
| if (!verifier.VerifyBuffer<schema::RequestGetModel>()) { | |||
| std::string reason = "The schema of RequestGetModel is invalid."; | |||
| BuildGetModelRsp(fbb, schema::ResponseCode_RequestError, reason, LocalMetaStore::GetInstance().curr_iter_num(), {}, | |||
| ""); | |||
| MS_LOG(ERROR) << reason; | |||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return true; | |||
| } | |||
| ++retry_count_; | |||
| @@ -55,8 +74,10 @@ bool GetModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::ve | |||
| const schema::RequestGetModel *get_model_req = flatbuffers::GetRoot<schema::RequestGetModel>(req_data); | |||
| if (get_model_req == nullptr) { | |||
| MS_LOG(ERROR) << "RequestGetModel is nullptr."; | |||
| return false; | |||
| std::string reason = "Building flatbuffers schema failed for RequestGetModel."; | |||
| MS_LOG(ERROR) << reason; | |||
| GenerateOutput(outputs, reason.c_str(), reason.size()); | |||
| return true; | |||
| } | |||
| GetModel(get_model_req, fbb); | |||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| @@ -110,6 +131,10 @@ void GetModelKernel::BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, con | |||
| const std::string &reason, const size_t iter, | |||
| const std::map<std::string, AddressPtr> &feature_maps, | |||
| const std::string ×tamp) { | |||
| if (fbb == nullptr) { | |||
| MS_LOG(ERROR) << "Input fbb is nullptr."; | |||
| return; | |||
| } | |||
| auto fbs_reason = fbb->CreateString(reason); | |||
| auto fbs_timestamp = fbb->CreateString(timestamp); | |||
| std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps; | |||
| @@ -65,6 +65,10 @@ bool PullWeightKernel::Reset() { | |||
| void PullWeightKernel::PullWeight(const std::shared_ptr<FBBuilder> &fbb, | |||
| const schema::RequestPullWeight *pull_weight_req) { | |||
| if (fbb == nullptr || pull_weight_req == nullptr) { | |||
| MS_LOG(ERROR) << "fbb or pull_weight_req is nullptr."; | |||
| return; | |||
| } | |||
| std::map<std::string, AddressPtr> feature_maps = {}; | |||
| size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num(); | |||
| size_t pull_weight_iter = IntToSize(pull_weight_req->iteration()); | |||
| @@ -114,6 +118,10 @@ void PullWeightKernel::PullWeight(const std::shared_ptr<FBBuilder> &fbb, | |||
| void PullWeightKernel::BuildPullWeightRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, | |||
| const std::string &reason, size_t iteration, | |||
| const std::map<std::string, AddressPtr> &feature_maps) { | |||
| if (fbb == nullptr) { | |||
| MS_LOG(ERROR) << "fbb is nullptr."; | |||
| return; | |||
| } | |||
| auto fbs_reason = fbb->CreateString(reason); | |||
| std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps; | |||
| for (auto feature_map : feature_maps) { | |||
| @@ -36,8 +36,19 @@ bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std:: | |||
| void *req_data = inputs[0]->addr; | |||
| std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>(); | |||
| if (fbb == nullptr || req_data == nullptr) { | |||
| MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr."; | |||
| return false; | |||
| std::string reason = "FBBuilder builder or req_data is nullptr."; | |||
| MS_LOG(ERROR) << reason; | |||
| GenerateOutput(outputs, reason.c_str(), reason.size()); | |||
| return true; | |||
| } | |||
| flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size); | |||
| if (!verifier.VerifyBuffer<schema::RequestPushWeight>()) { | |||
| std::string reason = "The schema of RequestPushWeight is invalid."; | |||
| BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, LocalMetaStore::GetInstance().curr_iter_num()); | |||
| MS_LOG(ERROR) << reason; | |||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return true; | |||
| } | |||
| const schema::RequestPushWeight *push_weight_req = flatbuffers::GetRoot<schema::RequestPushWeight>(req_data); | |||
| @@ -69,9 +80,8 @@ void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageH | |||
| ResultCode PushWeightKernel::PushWeight(const std::shared_ptr<FBBuilder> &fbb, | |||
| const schema::RequestPushWeight *push_weight_req) { | |||
| if (fbb == nullptr || push_weight_req == nullptr) { | |||
| return ResultCode::kSuccessAndReturn; | |||
| } | |||
| MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(push_weight_req, ResultCode::kSuccessAndReturn); | |||
| size_t iteration = IntToSize(push_weight_req->iteration()); | |||
| size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num(); | |||
| if (iteration != current_iter) { | |||
| @@ -110,10 +120,10 @@ ResultCode PushWeightKernel::PushWeight(const std::shared_ptr<FBBuilder> &fbb, | |||
| } | |||
| std::map<std::string, Address> PushWeightKernel::ParseFeatureMap(const schema::RequestPushWeight *push_weight_req) { | |||
| RETURN_IF_NULL(push_weight_req, {}); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(push_weight_req, {}); | |||
| std::map<std::string, Address> upload_feature_map; | |||
| auto fbs_feature_map = push_weight_req->feature_map(); | |||
| RETURN_IF_NULL(push_weight_req, upload_feature_map); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(push_weight_req, upload_feature_map); | |||
| for (size_t i = 0; i < fbs_feature_map->size(); i++) { | |||
| std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str(); | |||
| float *weight_data = const_cast<float *>(fbs_feature_map->Get(i)->data()->data()); | |||
| @@ -125,6 +135,10 @@ std::map<std::string, Address> PushWeightKernel::ParseFeatureMap(const schema::R | |||
| void PushWeightKernel::BuildPushWeightRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, | |||
| const std::string &reason, size_t iteration) { | |||
| if (fbb == nullptr) { | |||
| MS_LOG(ERROR) << "Input fbb is nullptr."; | |||
| return; | |||
| } | |||
| auto fbs_reason = fbb->CreateString(reason); | |||
| schema::ResponsePushWeightBuilder rsp_push_weight_builder(*(fbb.get())); | |||
| rsp_push_weight_builder.add_retcode(retcode); | |||
| @@ -53,19 +53,23 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std:: | |||
| const std::vector<AddressPtr> &outputs) { | |||
| MS_LOG(INFO) << "Launching StartFLJobKernel kernel."; | |||
| if (inputs.size() != 1 || outputs.size() != 1) { | |||
| MS_LOG(ERROR) << "inputs or outputs size is invalid."; | |||
| return false; | |||
| std::string reason = "inputs or outputs size is invalid."; | |||
| MS_LOG(ERROR) << reason; | |||
| GenerateOutput(outputs, reason.c_str(), reason.size()); | |||
| return true; | |||
| } | |||
| void *req_data = inputs[0]->addr; | |||
| std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>(); | |||
| if (fbb == nullptr || req_data == nullptr) { | |||
| MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr."; | |||
| return false; | |||
| std::string reason = "FBBuilder builder or req_data is nullptr."; | |||
| MS_LOG(ERROR) << reason; | |||
| GenerateOutput(outputs, reason.c_str(), reason.size()); | |||
| return true; | |||
| } | |||
| flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size); | |||
| if (!verifier.VerifyBuffer<schema::RequestFLJob>()) { | |||
| std::string reason = "The schema of startFLJob is invalid."; | |||
| std::string reason = "The schema of RequestFLJob is invalid."; | |||
| BuildStartFLJobRsp(fbb, schema::ResponseCode_RequestError, reason, false, ""); | |||
| MS_LOG(ERROR) << reason; | |||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| @@ -80,9 +84,15 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std:: | |||
| const schema::RequestFLJob *start_fl_job_req = flatbuffers::GetRoot<schema::RequestFLJob>(req_data); | |||
| if (start_fl_job_req == nullptr) { | |||
| MS_LOG(ERROR) << "RequestFLJob is nullptr."; | |||
| return false; | |||
| std::string reason = "Building flatbuffers schema failed for RequestFLJob."; | |||
| BuildStartFLJobRsp( | |||
| fbb, schema::ResponseCode_RequestError, reason, false, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| MS_LOG(ERROR) << reason; | |||
| GenerateOutput(outputs, reason.c_str(), reason.size()); | |||
| return true; | |||
| } | |||
| DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req); | |||
| result_code = ReadyForStartFLJob(fbb, device_meta); | |||
| if (result_code != ResultCode::kSuccess) { | |||
| @@ -94,9 +104,9 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std:: | |||
| std::string update_reason = ""; | |||
| if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata, &update_reason)) { | |||
| std::string reason = "Updating device metadata failed. " + update_reason; | |||
| BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), | |||
| {}); | |||
| BuildStartFLJobRsp( | |||
| fbb, schema::ResponseCode_OutOfTime, reason, false, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return update_reason == kNetworkError ? false : true; | |||
| } | |||
| @@ -141,7 +151,7 @@ ResultCode StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<F | |||
| } | |||
| DeviceMeta StartFLJobKernel::CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req) { | |||
| RETURN_IF_NULL(start_fl_job_req, {}); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req, {}); | |||
| std::string fl_name = start_fl_job_req->fl_name()->str(); | |||
| std::string fl_id = start_fl_job_req->fl_id()->str(); | |||
| int data_size = start_fl_job_req->data_size(); | |||
| @@ -172,7 +182,7 @@ ResultCode StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder> | |||
| ResultCode StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, | |||
| const schema::RequestFLJob *start_fl_job_req) { | |||
| RETURN_IF_NULL(start_fl_job_req, ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req, ResultCode::kSuccessAndReturn); | |||
| std::string count_reason = ""; | |||
| if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str(), &count_reason)) { | |||
| std::string reason = "Counting start fl job request failed. Please retry later."; | |||
| @@ -201,6 +211,10 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, | |||
| const std::string &reason, const bool is_selected, | |||
| const std::string &next_req_time, | |||
| std::map<std::string, AddressPtr> feature_maps) { | |||
| if (fbb == nullptr) { | |||
| MS_LOG(ERROR) << "Input fbb is nullptr."; | |||
| return; | |||
| } | |||
| auto fbs_reason = fbb->CreateString(reason); | |||
| auto fbs_next_req_time = fbb->CreateString(next_req_time); | |||
| auto fbs_server_mode = fbb->CreateString(ps::PSContext::instance()->server_mode()); | |||
| @@ -44,18 +44,32 @@ void UpdateModelKernel::InitKernel(size_t threshold_count) { | |||
| bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| MS_LOG(INFO) << "Launching UpdateModelKernel kernel."; | |||
| if (inputs.size() != 1 || outputs.size() != 1) { | |||
| MS_LOG(ERROR) << "inputs or outputs size is invalid."; | |||
| return false; | |||
| std::string reason = "inputs or outputs size is invalid."; | |||
| MS_LOG(ERROR) << reason; | |||
| GenerateOutput(outputs, reason.c_str(), reason.size()); | |||
| return true; | |||
| } | |||
| void *req_data = inputs[0]->addr; | |||
| std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>(); | |||
| if (fbb == nullptr || req_data == nullptr) { | |||
| MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr."; | |||
| return false; | |||
| std::string reason = "FBBuilder builder or req_data is nullptr."; | |||
| MS_LOG(ERROR) << reason; | |||
| GenerateOutput(outputs, reason.c_str(), reason.size()); | |||
| return true; | |||
| } | |||
| flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size); | |||
| if (!verifier.VerifyBuffer<schema::RequestUpdateModel>()) { | |||
| std::string reason = "The schema of RequestUpdateModel is invalid."; | |||
| BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, ""); | |||
| MS_LOG(ERROR) << reason; | |||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return true; | |||
| } | |||
| MS_LOG(INFO) << "Launching UpdateModelKernel kernel."; | |||
| ResultCode result_code = ReachThresholdForUpdateModel(fbb); | |||
| if (result_code != ResultCode::kSuccess) { | |||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| @@ -63,6 +77,14 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std: | |||
| } | |||
| const schema::RequestUpdateModel *update_model_req = flatbuffers::GetRoot<schema::RequestUpdateModel>(req_data); | |||
| if (update_model_req == nullptr) { | |||
| std::string reason = "Building flatbuffers schema failed for RequestUpdateModel."; | |||
| BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, ""); | |||
| MS_LOG(ERROR) << reason; | |||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return true; | |||
| } | |||
| result_code = UpdateModel(update_model_req, fbb); | |||
| if (result_code != ResultCode::kSuccess) { | |||
| MS_LOG(ERROR) << "Updating model failed."; | |||
| @@ -119,7 +141,7 @@ ResultCode UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr | |||
| ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req, | |||
| const std::shared_ptr<FBBuilder> &fbb) { | |||
| RETURN_IF_NULL(update_model_req, ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn); | |||
| size_t iteration = IntToSize(update_model_req->iteration()); | |||
| if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) { | |||
| std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) + | |||
| @@ -149,9 +171,7 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda | |||
| auto feature_map = ParseFeatureMap(update_model_req); | |||
| if (feature_map.empty()) { | |||
| std::string reason = "Feature map is empty."; | |||
| BuildUpdateModelRsp( | |||
| fbb, schema::ResponseCode_RequestError, reason, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, ""); | |||
| MS_LOG(ERROR) << reason; | |||
| return ResultCode::kSuccessAndReturn; | |||
| } | |||
| @@ -190,10 +210,10 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda | |||
| std::map<std::string, UploadData> UpdateModelKernel::ParseFeatureMap( | |||
| const schema::RequestUpdateModel *update_model_req) { | |||
| RETURN_IF_NULL(update_model_req, {}); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, {}); | |||
| std::map<std::string, UploadData> feature_map; | |||
| auto fbs_feature_map = update_model_req->feature_map(); | |||
| RETURN_IF_NULL(fbs_feature_map, feature_map); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(fbs_feature_map, feature_map); | |||
| for (size_t i = 0; i < fbs_feature_map->size(); i++) { | |||
| std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str(); | |||
| float *weight_data = const_cast<float *>(fbs_feature_map->Get(i)->data()->data()); | |||
| @@ -208,7 +228,8 @@ std::map<std::string, UploadData> UpdateModelKernel::ParseFeatureMap( | |||
| ResultCode UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb, | |||
| const schema::RequestUpdateModel *update_model_req) { | |||
| RETURN_IF_NULL(update_model_req, ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kSuccessAndReturn); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn); | |||
| std::string count_reason = ""; | |||
| if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str(), &count_reason)) { | |||
| std::string reason = "Counting for update model request failed. Please retry later. " + count_reason; | |||
| @@ -223,6 +244,10 @@ ResultCode UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilde | |||
| void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, | |||
| const std::string &reason, const std::string &next_req_time) { | |||
| if (fbb == nullptr) { | |||
| MS_LOG(ERROR) << "Input fbb is nullptr."; | |||
| return; | |||
| } | |||
| auto fbs_reason = fbb->CreateString(reason); | |||
| auto fbs_next_req_time = fbb->CreateString(next_req_time); | |||
| @@ -21,16 +21,29 @@ namespace mindspore { | |||
| namespace fl { | |||
| namespace server { | |||
| void MemoryRegister::RegisterAddressPtr(const std::string &name, const AddressPtr &address) { | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(address); | |||
| (void)addresses_.try_emplace(name, address); | |||
| } | |||
| void MemoryRegister::StoreFloatArray(std::unique_ptr<float[]> *array) { float_arrays_.push_back(std::move(*array)); } | |||
| void MemoryRegister::StoreFloatArray(std::unique_ptr<float[]> *array) { | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(array); | |||
| float_arrays_.push_back(std::move(*array)); | |||
| } | |||
| void MemoryRegister::StoreInt32Array(std::unique_ptr<int[]> *array) { int32_arrays_.push_back(std::move(*array)); } | |||
| void MemoryRegister::StoreInt32Array(std::unique_ptr<int[]> *array) { | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(array); | |||
| int32_arrays_.push_back(std::move(*array)); | |||
| } | |||
| void MemoryRegister::StoreUint64Array(std::unique_ptr<size_t[]> *array) { uint64_arrays_.push_back(std::move(*array)); } | |||
| void MemoryRegister::StoreUint64Array(std::unique_ptr<size_t[]> *array) { | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(array); | |||
| uint64_arrays_.push_back(std::move(*array)); | |||
| } | |||
| void MemoryRegister::StoreCharArray(std::unique_ptr<char[]> *array) { char_arrays_.push_back(std::move(*array)); } | |||
| void MemoryRegister::StoreCharArray(std::unique_ptr<char[]> *array) { | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(array); | |||
| char_arrays_.push_back(std::move(*array)); | |||
| } | |||
| } // namespace server | |||
| } // namespace fl | |||
| } // namespace mindspore | |||
| @@ -46,7 +46,7 @@ bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::strin | |||
| return false; | |||
| } | |||
| std::shared_ptr<MemoryRegister> memory_register; | |||
| std::shared_ptr<MemoryRegister> memory_register = nullptr; | |||
| if (iteration_to_model_.size() < max_model_count_) { | |||
| // If iteration_to_model_.size() is not max_model_count_, need to assign new memory for the model. | |||
| memory_register = AssignNewModelMemory(); | |||
| @@ -123,10 +123,14 @@ std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() { | |||
| // Assign new memory for the model. | |||
| std::shared_ptr<MemoryRegister> memory_register = std::make_shared<MemoryRegister>(); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(memory_register, nullptr); | |||
| for (const auto &weight : model) { | |||
| const std::string weight_name = weight.first; | |||
| size_t weight_size = weight.second->size; | |||
| auto weight_data = std::make_unique<char[]>(weight_size); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(weight_data, nullptr); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(weight.second, nullptr); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(weight.second->addr, nullptr); | |||
| if (weight_data == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Assign memory for weight failed."; | |||
| return nullptr; | |||
| @@ -139,7 +143,6 @@ std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() { | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||
| return nullptr; | |||
| } | |||
| memory_register->RegisterArray(weight_name, &weight_data, weight_size); | |||
| } | |||
| return memory_register; | |||
| @@ -80,8 +80,7 @@ bool ParameterAggregator::LaunchAggregators() { | |||
| for (auto &aggregator_with_params : aggregation_kernel_parameters_) { | |||
| KernelParams ¶ms = aggregator_with_params.second; | |||
| std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first; | |||
| RETURN_IF_NULL(aggr_kernel, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false); | |||
| bool ret = aggr_kernel->Launch(params.inputs, params.workspace, params.outputs); | |||
| if (!ret) { | |||
| MS_LOG(ERROR) << "Launching aggregation kernel " << typeid(aggr_kernel.get()).name() << " failed."; | |||
| @@ -95,8 +94,7 @@ bool ParameterAggregator::LaunchOptimizers() { | |||
| for (auto &optimizer_with_params : optimizer_kernel_parameters_) { | |||
| KernelParams ¶ms = optimizer_with_params.second; | |||
| std::shared_ptr<kernel::OptimizerKernel> optimizer_kernel = optimizer_with_params.first; | |||
| RETURN_IF_NULL(optimizer_kernel, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(optimizer_kernel, false); | |||
| bool ret = optimizer_kernel->Launch(params.inputs, params.workspace, params.outputs); | |||
| if (!ret) { | |||
| MS_LOG(ERROR) << "Launching optimizer kernel " << typeid(optimizer_kernel.get()).name() << " failed."; | |||
| @@ -158,7 +156,7 @@ bool ParameterAggregator::IsAggregationDone() const { | |||
| // Only consider aggregation done after each aggregation kernel is done. | |||
| for (auto &aggregator_with_params : aggregation_kernel_parameters_) { | |||
| std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first; | |||
| RETURN_IF_NULL(aggr_kernel, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false); | |||
| if (!aggr_kernel->IsAggregationDone()) { | |||
| return false; | |||
| } | |||
| @@ -276,8 +274,8 @@ bool ParameterAggregator::AssignMemory(K server_kernel, const CNodePtr &cnode, | |||
| bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> aggr_kernel, | |||
| const std::shared_ptr<MemoryRegister> memory_register) { | |||
| RETURN_IF_NULL(aggr_kernel, false); | |||
| RETURN_IF_NULL(memory_register, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false); | |||
| KernelParams aggr_params = {}; | |||
| const std::vector<std::string> &input_names = aggr_kernel->input_names(); | |||
| @@ -299,8 +297,8 @@ bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr< | |||
| bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> optimizer_kernel, | |||
| const std::shared_ptr<MemoryRegister> memory_register) { | |||
| RETURN_IF_NULL(optimizer_kernel, false); | |||
| RETURN_IF_NULL(memory_register, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(optimizer_kernel, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false); | |||
| KernelParams optimizer_params = {}; | |||
| const std::vector<std::string> &input_names = optimizer_kernel->input_names(); | |||
| @@ -107,11 +107,7 @@ void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel) | |||
| } | |||
| void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message) { | |||
| if (message == nullptr) { | |||
| MS_LOG(ERROR) << "Message is nullptr."; | |||
| return; | |||
| } | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(message); | |||
| // If the server is still in the process of scaling, refuse the request. | |||
| if (Server::GetInstance().IsSafeMode()) { | |||
| MS_LOG(WARNING) << "The cluster is still in process of scaling, please retry " << name_ << " later."; | |||
| @@ -125,6 +121,8 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m | |||
| AddressPtr input = std::make_shared<Address>(); | |||
| AddressPtr output = std::make_shared<Address>(); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(input); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(output); | |||
| input->addr = message->data(); | |||
| input->size = message->len(); | |||
| bool ret = kernel_->Launch({input}, {}, {output}); | |||
| @@ -137,7 +137,6 @@ void Server::InitCluster() { | |||
| bool Server::InitCommunicatorWithServer() { | |||
| MS_EXCEPTION_IF_NULL(task_executor_); | |||
| MS_EXCEPTION_IF_NULL(server_node_); | |||
| communicator_with_server_ = | |||
| server_node_->GetOrCreateTcpComm(scheduler_ip_, scheduler_port_, worker_num_, server_num_, task_executor_); | |||
| MS_EXCEPTION_IF_NULL(communicator_with_server_); | |||
| @@ -395,10 +394,7 @@ void Server::ProcessBeforeScalingIn() { | |||
| void Server::ProcessAfterScalingOut() { | |||
| std::unique_lock<std::mutex> lock(scaling_mtx_); | |||
| if (server_node_ == nullptr) { | |||
| return; | |||
| } | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(server_node_); | |||
| if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) { | |||
| MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed."; | |||
| } | |||
| @@ -420,10 +416,7 @@ void Server::ProcessAfterScalingOut() { | |||
| void Server::ProcessAfterScalingIn() { | |||
| std::unique_lock<std::mutex> lock(scaling_mtx_); | |||
| if (server_node_ == nullptr) { | |||
| return; | |||
| } | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(server_node_); | |||
| if (server_node_->rank_id() == UINT32_MAX) { | |||
| MS_LOG(WARNING) << "This server the one to be scaled in. Server exiting."; | |||
| (void)std::for_each( | |||
| @@ -233,6 +233,22 @@ class LogWriter { | |||
| } \ | |||
| } while (0) | |||
| #define MS_ERROR_IF_NULL_W_RET_VAL(ptr, val) \ | |||
| do { \ | |||
| if ((ptr) == nullptr) { \ | |||
| MS_LOG(ERROR) << ": The pointer[" << #ptr << "] is null."; \ | |||
| return val; \ | |||
| } \ | |||
| } while (0) | |||
| #define MS_ERROR_IF_NULL_WO_RET_VAL(ptr) \ | |||
| do { \ | |||
| if ((ptr) == nullptr) { \ | |||
| MS_LOG(ERROR) << ": The pointer[" << #ptr << "] is null."; \ | |||
| return; \ | |||
| } \ | |||
| } while (0) | |||
| #ifdef DEBUG | |||
| #include <cassert> | |||
| #define MS_ASSERT(f) assert(f) | |||