Browse Source

!20888 Fix code review

Merge pull request !20888 from ZPaC/static
tags/v1.4.0
i-robot Gitee 4 years ago
parent
commit
621c5e5041
22 changed files with 250 additions and 118 deletions
  1. +8
    -0
      mindspore/ccsrc/fl/server/collective_ops_impl.cc
  2. +3
    -9
      mindspore/ccsrc/fl/server/common.h
  3. +3
    -2
      mindspore/ccsrc/fl/server/distributed_count_service.cc
  4. +5
    -2
      mindspore/ccsrc/fl/server/distributed_metadata_store.cc
  5. +12
    -8
      mindspore/ccsrc/fl/server/executor.cc
  6. +20
    -29
      mindspore/ccsrc/fl/server/iteration.cc
  7. +1
    -0
      mindspore/ccsrc/fl/server/kernel/aggregation_kernel_factory.cc
  8. +1
    -0
      mindspore/ccsrc/fl/server/kernel/apply_momentum_kernel.h
  9. +10
    -2
      mindspore/ccsrc/fl/server/kernel/dense_grad_accum_kernel.h
  10. +15
    -1
      mindspore/ccsrc/fl/server/kernel/fed_avg_kernel.h
  11. +1
    -1
      mindspore/ccsrc/fl/server/kernel/optimizer_kernel_factory.cc
  12. +29
    -4
      mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.cc
  13. +8
    -0
      mindspore/ccsrc/fl/server/kernel/round/pull_weight_kernel.cc
  14. +21
    -7
      mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc
  15. +26
    -12
      mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc
  16. +37
    -12
      mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc
  17. +17
    -4
      mindspore/ccsrc/fl/server/memory_register.cc
  18. +5
    -2
      mindspore/ccsrc/fl/server/model_store.cc
  19. +7
    -9
      mindspore/ccsrc/fl/server/parameter_aggregator.cc
  20. +3
    -5
      mindspore/ccsrc/fl/server/round.cc
  21. +2
    -9
      mindspore/ccsrc/fl/server/server.cc
  22. +16
    -0
      mindspore/core/utils/log_adapter.h

+ 8
- 0
mindspore/ccsrc/fl/server/collective_ops_impl.cc View File

@@ -29,6 +29,8 @@ void CollectiveOpsImpl::Initialize(const std::shared_ptr<ps::core::ServerNode> &

template <typename T>
bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false);
MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false);
int ret = memcpy_s(recvbuff, count * sizeof(T), sendbuff, count * sizeof(T));
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
@@ -135,6 +137,12 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size

template <typename T>
bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false);
MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false);
if (sendbuff == nullptr || recvbuff == nullptr) {
MS_LOG(ERROR) << "Input sendbuff or recvbuff for ReduceBroadcastAllReduce is nullptr.";
return false;
}
uint32_t rank_size = server_num_;
MS_LOG(DEBUG) << "Reduce Broadcast AllReduce rank_size:" << rank_size << ", local_rank_:" << local_rank_
<< ", count:" << count;


+ 3
- 9
mindspore/ccsrc/fl/server/common.h View File

@@ -201,12 +201,6 @@ constexpr auto kCtxCipherPrimer = "cipher_primer";
#define CURRENT_TIME_MILLI \
std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch())

#define RETURN_IF_NULL(expr, ret) \
if (expr == nullptr) { \
MS_LOG(ERROR) << #expr << " is nullptr."; \
return ret; \
}

// This method returns the size in bytes of the given TypeId.
inline size_t GetTypeIdByte(const TypeId &type) {
switch (type) {
@@ -224,12 +218,12 @@ inline size_t GetTypeIdByte(const TypeId &type) {
}

inline AddressPtr GenerateParameterNodeAddrPtr(const CNodePtr &kernel_node, size_t param_idx) {
RETURN_IF_NULL(kernel_node, nullptr);
MS_ERROR_IF_NULL_W_RET_VAL(kernel_node, nullptr);
auto param_node =
AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, param_idx), 0).first->cast<ParameterPtr>();
RETURN_IF_NULL(param_node, nullptr);
MS_ERROR_IF_NULL_W_RET_VAL(param_node, nullptr);
auto param_tensor = param_node->default_param()->cast<tensor::TensorPtr>();
RETURN_IF_NULL(param_tensor, nullptr);
MS_ERROR_IF_NULL_W_RET_VAL(param_tensor, nullptr);
AddressPtr addr = std::make_shared<kernel::Address>();
addr->addr = param_tensor->data_c();
addr->size = param_tensor->data().nbytes();


+ 3
- 2
mindspore/ccsrc/fl/server/distributed_count_service.cc View File

@@ -24,8 +24,8 @@ namespace fl {
namespace server {
void DistributedCountService::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node,
uint32_t counting_server_rank) {
MS_EXCEPTION_IF_NULL(server_node);
server_node_ = server_node;
MS_EXCEPTION_IF_NULL(server_node_);
local_rank_ = server_node_->rank_id();
server_num_ = ps::PSContext::instance()->initial_server_num();
counting_server_rank_ = counting_server_rank;
@@ -33,8 +33,8 @@ void DistributedCountService::Initialize(const std::shared_ptr<ps::core::ServerN
}

void DistributedCountService::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
MS_EXCEPTION_IF_NULL(communicator);
communicator_ = communicator;
MS_EXCEPTION_IF_NULL(communicator_);
communicator_->RegisterMsgCallBack(
"count", std::bind(&DistributedCountService::HandleCountRequest, this, std::placeholders::_1));
communicator_->RegisterMsgCallBack(
@@ -107,6 +107,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
(void)count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size()));
if (!count_rsp.result()) {
MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason();
// If the error is caused by the network issue, return the reason.
if (reason != nullptr && count_rsp.reason().find(kNetworkError) != std::string::npos) {
*reason = kNetworkError;
}


+ 5
- 2
mindspore/ccsrc/fl/server/distributed_metadata_store.cc View File

@@ -23,8 +23,8 @@ namespace mindspore {
namespace fl {
namespace server {
void DistributedMetadataStore::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node) {
server_node_ = server_node;
MS_EXCEPTION_IF_NULL(server_node);
server_node_ = server_node;
local_rank_ = server_node_->rank_id();
server_num_ = ps::PSContext::instance()->initial_server_num();
InitHashRing();
@@ -32,8 +32,8 @@ void DistributedMetadataStore::Initialize(const std::shared_ptr<ps::core::Server
}

void DistributedMetadataStore::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
MS_EXCEPTION_IF_NULL(communicator);
communicator_ = communicator;
MS_EXCEPTION_IF_NULL(communicator_);
communicator_->RegisterMsgCallBack(
"updateMetadata", std::bind(&DistributedMetadataStore::HandleUpdateMetadataRequest, this, std::placeholders::_1));
communicator_->RegisterMsgCallBack(
@@ -109,6 +109,7 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM
return false;
}

MS_ERROR_IF_NULL_W_RET_VAL(update_meta_rsp_msg, false);
std::string update_meta_rsp =
std::string(reinterpret_cast<char *>(update_meta_rsp_msg->data()), update_meta_rsp_msg->size());
if (update_meta_rsp != kSuccess) {
@@ -141,6 +142,8 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed.";
return get_metadata_rsp;
}

MS_ERROR_IF_NULL_W_RET_VAL(get_meta_rsp_msg, get_metadata_rsp);
(void)get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), SizeToInt(get_meta_rsp_msg->size()));
return get_metadata_rsp;
}


+ 12
- 8
mindspore/ccsrc/fl/server/executor.cc View File

@@ -63,7 +63,7 @@ bool Executor::HandlePush(const std::string &param_name, const UploadData &uploa
std::mutex &mtx = parameter_mutex_[param_name];
std::unique_lock<std::mutex> lock(mtx);
auto &param_aggr = param_aggrs_[param_name];
MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
// Push operation needs to wait until the pulling process is done.
while (!param_aggr->IsPullingDone()) {
lock.unlock();
@@ -106,7 +106,7 @@ bool Executor::HandleModelUpdate(const std::string &param_name, const UploadData
std::mutex &mtx = parameter_mutex_[param_name];
std::unique_lock<std::mutex> lock(mtx);
auto &param_aggr = param_aggrs_[param_name];
MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
if (!param_aggr->UpdateData(upload_data)) {
MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
return false;
@@ -131,7 +131,7 @@ bool Executor::HandleModelUpdateAsync(const std::map<std::string, UploadData> &f
std::mutex &mtx = parameter_mutex_[param_name];
std::unique_lock<std::mutex> lock(mtx);
auto &param_aggr = param_aggrs_[param_name];
MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
const UploadData &upload_data = trainable_param.second;
if (!param_aggr->UpdateData(upload_data)) {
MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
@@ -156,7 +156,7 @@ bool Executor::HandlePushWeight(const std::map<std::string, Address> &feature_ma
std::mutex &mtx = parameter_mutex_[param_name];
std::unique_lock<std::mutex> lock(mtx);
auto &param_aggr = param_aggrs_[param_name];
MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
AddressPtr old_weight = param_aggr->GetWeight();
if (old_weight == nullptr) {
MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
@@ -188,7 +188,7 @@ AddressPtr Executor::HandlePull(const std::string &param_name) {
std::mutex &mtx = parameter_mutex_[param_name];
std::unique_lock<std::mutex> lock(mtx);
auto &param_aggr = param_aggrs_[param_name];
MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, nullptr);
// Pulling must wait until the optimizing process is done.
while (!param_aggr->IsOptimizingDone()) {
lock.unlock();
@@ -214,7 +214,7 @@ std::map<std::string, AddressPtr> Executor::HandlePullWeight(const std::vector<s
std::mutex &mtx = parameter_mutex_[param_name];
std::unique_lock<std::mutex> lock(mtx);
const auto &param_aggr = param_aggrs_[param_name];
MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, weights);
AddressPtr addr = param_aggr->GetWeight();
if (addr == nullptr) {
MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
@@ -236,7 +236,9 @@ bool Executor::IsWeightAggrDone(const std::vector<std::string> &param_names) {

std::mutex &mtx = parameter_mutex_[name];
std::unique_lock<std::mutex> lock(mtx);
if (!param_aggrs_[name]->IsAggregationDone()) {
auto &param_aggr = param_aggrs_[name];
MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
if (!param_aggr->IsAggregationDone()) {
MS_LOG(DEBUG) << "Update model for " << name << " is not done yet.";
return false;
}
@@ -248,7 +250,9 @@ void Executor::ResetAggregationStatus() {
for (const auto &param_name : param_names_) {
std::mutex &mtx = parameter_mutex_[param_name];
std::unique_lock<std::mutex> lock(mtx);
param_aggrs_[param_name]->ResetAggregationStatus();
auto &param_aggr = param_aggrs_[param_name];
MS_ERROR_IF_NULL_WO_RET_VAL(param_aggr);
param_aggr->ResetAggregationStatus();
}
return;
}


+ 20
- 29
mindspore/ccsrc/fl/server/iteration.cc View File

@@ -90,6 +90,7 @@ void Iteration::MoveToNextIteration(bool is_last_iter_valid, const std::string &
return;
}

MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
if (server_node_->rank_id() == kLeaderServerRank) {
if (!BroadcastPrepareForNextIterRequest(is_last_iter_valid, reason)) {
MS_LOG(ERROR) << "Broadcast prepare for next iteration request failed.";
@@ -114,10 +115,7 @@ void Iteration::MoveToNextIteration(bool is_last_iter_valid, const std::string &

void Iteration::SetIterationRunning() {
MS_LOG(INFO) << "Iteration " << iteration_num_ << " start running.";
if (server_node_ == nullptr) {
MS_LOG(ERROR) << "Server node is empty.";
return;
}
MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
if (server_node_->rank_id() == kLeaderServerRank) {
// This event helps worker/server to be consistent in iteration state.
server_node_->BroadcastEvent(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning));
@@ -127,10 +125,7 @@ void Iteration::SetIterationRunning() {

void Iteration::SetIterationCompleted() {
MS_LOG(INFO) << "Iteration " << iteration_num_ << " completes.";
if (server_node_ == nullptr) {
MS_LOG(ERROR) << "Server node is empty.";
return;
}
MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
if (server_node_->rank_id() == kLeaderServerRank) {
// This event helps worker/server to be consistent in iteration state.
server_node_->BroadcastEvent(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted));
@@ -167,6 +162,7 @@ const std::vector<std::shared_ptr<Round>> &Iteration::rounds() const { return ro
bool Iteration::is_last_iteration_valid() const { return is_last_iteration_valid_; }

bool Iteration::SyncIteration(uint32_t rank) {
MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
SyncIterationRequest sync_iter_req;
sync_iter_req.set_rank(rank);

@@ -190,10 +186,8 @@ bool Iteration::SyncIteration(uint32_t rank) {
}

void Iteration::HandleSyncIterationRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr.";
return;
}
MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);

SyncIterationRequest sync_iter_req;
(void)sync_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
@@ -220,6 +214,7 @@ bool Iteration::IsMoveToNextIterRequestReentrant(uint64_t iteration_num) {
}

bool Iteration::NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason) {
MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
MS_LOG(INFO) << "Notify leader server to control the cluster to proceed to next iteration.";
NotifyLeaderMoveToNextIterRequest notify_leader_to_next_iter_req;
notify_leader_to_next_iter_req.set_rank(server_node_->rank_id());
@@ -235,10 +230,8 @@ bool Iteration::NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const s
}

void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
return;
}

MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
NotifyLeaderMoveToNextIterResponse notify_leader_to_next_iter_rsp;
notify_leader_to_next_iter_rsp.set_result("success");
if (!communicator_->SendResponse(notify_leader_to_next_iter_rsp.SerializeAsString().data(),
@@ -275,8 +268,8 @@ void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps
}

bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason) {
MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
PrepareForNextIter();

MS_LOG(INFO) << "Notify all follower servers to prepare for next iteration.";
PrepareForNextIterRequest prepare_next_iter_req;
prepare_next_iter_req.set_is_last_iter_valid(is_last_iter_valid);
@@ -307,10 +300,8 @@ bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, cons
}

void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
return;
}

MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
PrepareForNextIterRequest prepare_next_iter_req;
(void)prepare_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
const auto &reason = prepare_next_iter_req.reason();
@@ -332,6 +323,7 @@ void Iteration::PrepareForNextIter() {
}

bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason) {
MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
MS_LOG(INFO) << "Notify all follower servers to proceed to next iteration. Set last iteration number "
<< iteration_num_;
MoveToNextIterRequest proceed_to_next_iter_req;
@@ -350,10 +342,8 @@ bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const st
}

void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
return;
}

MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
MoveToNextIterResponse proceed_to_next_iter_rsp;
proceed_to_next_iter_rsp.set_result("success");
if (!communicator_->SendResponse(proceed_to_next_iter_rsp.SerializeAsString().data(),
@@ -397,6 +387,7 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
}

bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
MS_LOG(INFO) << "Notify all follower servers to end last iteration.";
EndLastIterRequest end_last_iter_req;
end_last_iter_req.set_last_iter_num(last_iter_num);
@@ -412,10 +403,8 @@ bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
}

void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
return;
}

MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
EndLastIterRequest end_last_iter_req;
(void)end_last_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
const auto &last_iter_num = end_last_iter_req.last_iter_num();
@@ -456,7 +445,9 @@ void Iteration::EndLastIter() {
ModelStore::GetInstance().Reset();
}

std::unique_lock<std::mutex> lock(pinned_mtx_);
pinned_iter_num_ = 0;
lock.unlock();
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
Server::GetInstance().CancelSafeMode();
SetIterationCompleted();


+ 1
- 0
mindspore/ccsrc/fl/server/kernel/aggregation_kernel_factory.cc View File

@@ -22,6 +22,7 @@ namespace fl {
namespace server {
namespace kernel {
bool AggregationKernelFactory::Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node);
if (kNameToIdxMap.count(cnode_name) == 0) {
MS_LOG(ERROR) << "Can't find index info for kernel " << cnode_name;


+ 1
- 0
mindspore/ccsrc/fl/server/kernel/apply_momentum_kernel.h View File

@@ -36,6 +36,7 @@ class ApplyMomentumKernel : public ApplyMomentumCPUKernel, public OptimizerKerne
~ApplyMomentumKernel() override = default;

void InitKernel(const CNodePtr &cnode) override {
MS_EXCEPTION_IF_NULL(cnode);
ApplyMomentumCPUKernel::InitKernel(cnode);
InitServerKernelInputOutputSize(cnode);
GenerateReuseKernelNodeInfo();


+ 10
- 2
mindspore/ccsrc/fl/server/kernel/dense_grad_accum_kernel.h View File

@@ -29,6 +29,7 @@ namespace mindspore {
namespace fl {
namespace server {
namespace kernel {
constexpr size_t kDenseGradAccumKernelInputsNum = 2;
template <typename T>
class DenseGradAccumKernel : public AggregationKernel {
public:
@@ -53,8 +54,15 @@ class DenseGradAccumKernel : public AggregationKernel {
return;
}

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &) override {
if (inputs.size() != kDenseGradAccumKernelInputsNum) {
MS_LOG(ERROR) << "The inputs number of DenseGradAccumKernel should be 2, but got " << inputs.size();
return false;
}
MS_ERROR_IF_NULL_W_RET_VAL(inputs[0]->addr, false);
MS_ERROR_IF_NULL_W_RET_VAL(inputs[1]->addr, false);

if (accum_count_ == 0) {
int ret = memset_s(inputs[0]->addr, inputs[0]->size, 0x00, inputs[0]->size);
if (ret != 0) {


+ 15
- 1
mindspore/ccsrc/fl/server/kernel/fed_avg_kernel.h View File

@@ -34,6 +34,7 @@ namespace mindspore {
namespace fl {
namespace server {
namespace kernel {
constexpr size_t kFedAvgInputsNum = 4;
// The implementation for the federated average. We do a weighted average of the weights. The weights uploaded by
// FL-clients are already multiplied by their data sizes, so only sum and division are done in this kernel.

@@ -106,6 +107,15 @@ class FedAvgKernel : public AggregationKernel {

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
if (inputs.size() != kFedAvgInputsNum) {
MS_LOG(ERROR) << "The inputs number of FedAvgKernel should be 4, but got " << inputs.size();
return false;
}
MS_ERROR_IF_NULL_W_RET_VAL(inputs[0]->addr, false);
MS_ERROR_IF_NULL_W_RET_VAL(inputs[1]->addr, false);
MS_ERROR_IF_NULL_W_RET_VAL(inputs[2]->addr, false);
MS_ERROR_IF_NULL_W_RET_VAL(inputs[3]->addr, false);

std::unique_lock<std::mutex> lock(weight_mutex_);
// The weight and new_weight values should be multiplied by clients already, so we don't need to do multiplication
// again.
@@ -160,12 +170,16 @@ class FedAvgKernel : public AggregationKernel {
void GenerateReuseKernelNodeInfo() override {
MS_LOG(INFO) << "FedAvg reuse 'weight' of the kernel node.";
// Only the trainable parameter is reused for federated average.
reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, cnode_weight_idx_));
(void)reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, cnode_weight_idx_));
return;
}

// In some cases, the Launch method is not called and the weights involved in AllReduce should be set to 0.
void ClearWeightAndDataSize() {
MS_ERROR_IF_NULL_WO_RET_VAL(weight_addr_);
MS_ERROR_IF_NULL_WO_RET_VAL(data_size_addr_);
MS_ERROR_IF_NULL_WO_RET_VAL(weight_addr_->addr);
MS_ERROR_IF_NULL_WO_RET_VAL(data_size_addr_->addr);
int ret = memset_s(weight_addr_->addr, weight_addr_->size, 0x00, weight_addr_->size);
if (ret != 0) {
MS_LOG(ERROR) << "memset_s error, errorno(" << ret << ")";


+ 1
- 1
mindspore/ccsrc/fl/server/kernel/optimizer_kernel_factory.cc View File

@@ -22,6 +22,7 @@ namespace fl {
namespace server {
namespace kernel {
bool OptimizerKernelFactory::Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node);
if (kNameToIdxMap.count(cnode_name) == 0) {
MS_LOG(ERROR) << "Can't find index info for kernel " << cnode_name;
@@ -61,7 +62,6 @@ bool OptimizerKernelFactory::Matched(const ParamsInfo &params_info, const CNodeP
return false;
}
}

return true;
}
} // namespace kernel


+ 29
- 4
mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.cc View File

@@ -41,11 +41,30 @@ void GetModelKernel::InitKernel(size_t) {

bool GetModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
return true;
}

void *req_data = inputs[0]->addr;
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
return false;
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
return true;
}

flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
if (!verifier.VerifyBuffer<schema::RequestGetModel>()) {
std::string reason = "The schema of RequestGetModel is invalid.";
BuildGetModelRsp(fbb, schema::ResponseCode_RequestError, reason, LocalMetaStore::GetInstance().curr_iter_num(), {},
"");
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}

++retry_count_;
@@ -55,8 +74,10 @@ bool GetModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::ve

const schema::RequestGetModel *get_model_req = flatbuffers::GetRoot<schema::RequestGetModel>(req_data);
if (get_model_req == nullptr) {
MS_LOG(ERROR) << "RequestGetModel is nullptr.";
return false;
std::string reason = "Building flatbuffers schema failed for RequestGetModel.";
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
return true;
}
GetModel(get_model_req, fbb);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
@@ -110,6 +131,10 @@ void GetModelKernel::BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, con
const std::string &reason, const size_t iter,
const std::map<std::string, AddressPtr> &feature_maps,
const std::string &timestamp) {
if (fbb == nullptr) {
MS_LOG(ERROR) << "Input fbb is nullptr.";
return;
}
auto fbs_reason = fbb->CreateString(reason);
auto fbs_timestamp = fbb->CreateString(timestamp);
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;


+ 8
- 0
mindspore/ccsrc/fl/server/kernel/round/pull_weight_kernel.cc View File

@@ -65,6 +65,10 @@ bool PullWeightKernel::Reset() {

void PullWeightKernel::PullWeight(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestPullWeight *pull_weight_req) {
if (fbb == nullptr || pull_weight_req == nullptr) {
MS_LOG(ERROR) << "fbb or pull_weight_req is nullptr.";
return;
}
std::map<std::string, AddressPtr> feature_maps = {};
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
size_t pull_weight_iter = IntToSize(pull_weight_req->iteration());
@@ -114,6 +118,10 @@ void PullWeightKernel::PullWeight(const std::shared_ptr<FBBuilder> &fbb,
void PullWeightKernel::BuildPullWeightRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, size_t iteration,
const std::map<std::string, AddressPtr> &feature_maps) {
if (fbb == nullptr) {
MS_LOG(ERROR) << "fbb is nullptr.";
return;
}
auto fbs_reason = fbb->CreateString(reason);
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
for (auto feature_map : feature_maps) {


+ 21
- 7
mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc View File

@@ -36,8 +36,19 @@ bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
void *req_data = inputs[0]->addr;
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
return false;
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
return true;
}

flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
if (!verifier.VerifyBuffer<schema::RequestPushWeight>()) {
std::string reason = "The schema of RequestPushWeight is invalid.";
BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, LocalMetaStore::GetInstance().curr_iter_num());
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}

const schema::RequestPushWeight *push_weight_req = flatbuffers::GetRoot<schema::RequestPushWeight>(req_data);
@@ -69,9 +80,8 @@ void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageH

ResultCode PushWeightKernel::PushWeight(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestPushWeight *push_weight_req) {
if (fbb == nullptr || push_weight_req == nullptr) {
return ResultCode::kSuccessAndReturn;
}
MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(push_weight_req, ResultCode::kSuccessAndReturn);
size_t iteration = IntToSize(push_weight_req->iteration());
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
if (iteration != current_iter) {
@@ -110,10 +120,10 @@ ResultCode PushWeightKernel::PushWeight(const std::shared_ptr<FBBuilder> &fbb,
}

std::map<std::string, Address> PushWeightKernel::ParseFeatureMap(const schema::RequestPushWeight *push_weight_req) {
RETURN_IF_NULL(push_weight_req, {});
MS_ERROR_IF_NULL_W_RET_VAL(push_weight_req, {});
std::map<std::string, Address> upload_feature_map;
auto fbs_feature_map = push_weight_req->feature_map();
RETURN_IF_NULL(push_weight_req, upload_feature_map);
MS_ERROR_IF_NULL_W_RET_VAL(push_weight_req, upload_feature_map);
for (size_t i = 0; i < fbs_feature_map->size(); i++) {
std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str();
float *weight_data = const_cast<float *>(fbs_feature_map->Get(i)->data()->data());
@@ -125,6 +135,10 @@ std::map<std::string, Address> PushWeightKernel::ParseFeatureMap(const schema::R

void PushWeightKernel::BuildPushWeightRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, size_t iteration) {
if (fbb == nullptr) {
MS_LOG(ERROR) << "Input fbb is nullptr.";
return;
}
auto fbs_reason = fbb->CreateString(reason);
schema::ResponsePushWeightBuilder rsp_push_weight_builder(*(fbb.get()));
rsp_push_weight_builder.add_retcode(retcode);


+ 26
- 12
mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc View File

@@ -53,19 +53,23 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
const std::vector<AddressPtr> &outputs) {
MS_LOG(INFO) << "Launching StartFLJobKernel kernel.";
if (inputs.size() != 1 || outputs.size() != 1) {
MS_LOG(ERROR) << "inputs or outputs size is invalid.";
return false;
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
return true;
}
void *req_data = inputs[0]->addr;
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
return false;
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
return true;
}

flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
if (!verifier.VerifyBuffer<schema::RequestFLJob>()) {
std::string reason = "The schema of startFLJob is invalid.";
std::string reason = "The schema of RequestFLJob is invalid.";
BuildStartFLJobRsp(fbb, schema::ResponseCode_RequestError, reason, false, "");
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
@@ -80,9 +84,15 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::

const schema::RequestFLJob *start_fl_job_req = flatbuffers::GetRoot<schema::RequestFLJob>(req_data);
if (start_fl_job_req == nullptr) {
MS_LOG(ERROR) << "RequestFLJob is nullptr.";
return false;
std::string reason = "Building flatbuffers schema failed for RequestFLJob.";
BuildStartFLJobRsp(
fbb, schema::ResponseCode_RequestError, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
return true;
}

DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req);
result_code = ReadyForStartFLJob(fbb, device_meta);
if (result_code != ResultCode::kSuccess) {
@@ -94,9 +104,9 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
std::string update_reason = "";
if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata, &update_reason)) {
std::string reason = "Updating device metadata failed. " + update_reason;
BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)),
{});
BuildStartFLJobRsp(
fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return update_reason == kNetworkError ? false : true;
}
@@ -141,7 +151,7 @@ ResultCode StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<F
}

DeviceMeta StartFLJobKernel::CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req) {
RETURN_IF_NULL(start_fl_job_req, {});
MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req, {});
std::string fl_name = start_fl_job_req->fl_name()->str();
std::string fl_id = start_fl_job_req->fl_id()->str();
int data_size = start_fl_job_req->data_size();
@@ -172,7 +182,7 @@ ResultCode StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder>

ResultCode StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestFLJob *start_fl_job_req) {
RETURN_IF_NULL(start_fl_job_req, ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(start_fl_job_req, ResultCode::kSuccessAndReturn);
std::string count_reason = "";
if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str(), &count_reason)) {
std::string reason = "Counting start fl job request failed. Please retry later.";
@@ -201,6 +211,10 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
const std::string &reason, const bool is_selected,
const std::string &next_req_time,
std::map<std::string, AddressPtr> feature_maps) {
if (fbb == nullptr) {
MS_LOG(ERROR) << "Input fbb is nullptr.";
return;
}
auto fbs_reason = fbb->CreateString(reason);
auto fbs_next_req_time = fbb->CreateString(next_req_time);
auto fbs_server_mode = fbb->CreateString(ps::PSContext::instance()->server_mode());


+ 37
- 12
mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc View File

@@ -44,18 +44,32 @@ void UpdateModelKernel::InitKernel(size_t threshold_count) {

bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
MS_LOG(INFO) << "Launching UpdateModelKernel kernel.";
if (inputs.size() != 1 || outputs.size() != 1) {
MS_LOG(ERROR) << "inputs or outputs size is invalid.";
return false;
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
return true;
}

void *req_data = inputs[0]->addr;
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
return false;
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
return true;
}

flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
if (!verifier.VerifyBuffer<schema::RequestUpdateModel>()) {
std::string reason = "The schema of RequestUpdateModel is invalid.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}

MS_LOG(INFO) << "Launching UpdateModelKernel kernel.";
ResultCode result_code = ReachThresholdForUpdateModel(fbb);
if (result_code != ResultCode::kSuccess) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
@@ -63,6 +77,14 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
}

const schema::RequestUpdateModel *update_model_req = flatbuffers::GetRoot<schema::RequestUpdateModel>(req_data);
if (update_model_req == nullptr) {
std::string reason = "Building flatbuffers schema failed for RequestUpdateModel.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}

result_code = UpdateModel(update_model_req, fbb);
if (result_code != ResultCode::kSuccess) {
MS_LOG(ERROR) << "Updating model failed.";
@@ -119,7 +141,7 @@ ResultCode UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr

ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
const std::shared_ptr<FBBuilder> &fbb) {
RETURN_IF_NULL(update_model_req, ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn);
size_t iteration = IntToSize(update_model_req->iteration());
if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) +
@@ -149,9 +171,7 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda
auto feature_map = ParseFeatureMap(update_model_req);
if (feature_map.empty()) {
std::string reason = "Feature map is empty.";
BuildUpdateModelRsp(
fbb, schema::ResponseCode_RequestError, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
MS_LOG(ERROR) << reason;
return ResultCode::kSuccessAndReturn;
}
@@ -190,10 +210,10 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda

std::map<std::string, UploadData> UpdateModelKernel::ParseFeatureMap(
const schema::RequestUpdateModel *update_model_req) {
RETURN_IF_NULL(update_model_req, {});
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, {});
std::map<std::string, UploadData> feature_map;
auto fbs_feature_map = update_model_req->feature_map();
RETURN_IF_NULL(fbs_feature_map, feature_map);
MS_ERROR_IF_NULL_W_RET_VAL(fbs_feature_map, feature_map);
for (size_t i = 0; i < fbs_feature_map->size(); i++) {
std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str();
float *weight_data = const_cast<float *>(fbs_feature_map->Get(i)->data()->data());
@@ -208,7 +228,8 @@ std::map<std::string, UploadData> UpdateModelKernel::ParseFeatureMap(

ResultCode UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestUpdateModel *update_model_req) {
RETURN_IF_NULL(update_model_req, ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn);
std::string count_reason = "";
if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str(), &count_reason)) {
std::string reason = "Counting for update model request failed. Please retry later. " + count_reason;
@@ -223,6 +244,10 @@ ResultCode UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilde

void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, const std::string &next_req_time) {
if (fbb == nullptr) {
MS_LOG(ERROR) << "Input fbb is nullptr.";
return;
}
auto fbs_reason = fbb->CreateString(reason);
auto fbs_next_req_time = fbb->CreateString(next_req_time);



+ 17
- 4
mindspore/ccsrc/fl/server/memory_register.cc View File

@@ -21,16 +21,29 @@ namespace mindspore {
namespace fl {
namespace server {
// Registers `address` under `name` in addresses_.
// If `address` is null, an error is logged and nothing is registered.
// try_emplace is used, so (assuming addresses_ is a standard map type -- TODO confirm)
// an entry already stored under `name` is left untouched; the insertion result is
// deliberately discarded.
void MemoryRegister::RegisterAddressPtr(const std::string &name, const AddressPtr &address) {
MS_ERROR_IF_NULL_WO_RET_VAL(address);
(void)addresses_.try_emplace(name, address);
}

void MemoryRegister::StoreFloatArray(std::unique_ptr<float[]> *array) { float_arrays_.push_back(std::move(*array)); }
// Takes over the float buffer held by `*array` by moving it to the back of float_arrays_.
// `array` is a pointer to the caller's unique_ptr; after a successful call the caller's
// unique_ptr has been moved from and no longer owns the buffer.
// If `array` itself is null, an error is logged and nothing is stored.
void MemoryRegister::StoreFloatArray(std::unique_ptr<float[]> *array) {
MS_ERROR_IF_NULL_WO_RET_VAL(array);
float_arrays_.push_back(std::move(*array));
}

void MemoryRegister::StoreInt32Array(std::unique_ptr<int[]> *array) { int32_arrays_.push_back(std::move(*array)); }
// Takes over the int buffer held by `*array` by moving it to the back of int32_arrays_.
// `array` is a pointer to the caller's unique_ptr; after a successful call the caller's
// unique_ptr has been moved from and no longer owns the buffer.
// If `array` itself is null, an error is logged and nothing is stored.
void MemoryRegister::StoreInt32Array(std::unique_ptr<int[]> *array) {
MS_ERROR_IF_NULL_WO_RET_VAL(array);
int32_arrays_.push_back(std::move(*array));
}

void MemoryRegister::StoreUint64Array(std::unique_ptr<size_t[]> *array) { uint64_arrays_.push_back(std::move(*array)); }
// Takes over the size_t buffer held by `*array` by moving it to the back of uint64_arrays_.
// `array` is a pointer to the caller's unique_ptr; after a successful call the caller's
// unique_ptr has been moved from and no longer owns the buffer.
// If `array` itself is null, an error is logged and nothing is stored.
void MemoryRegister::StoreUint64Array(std::unique_ptr<size_t[]> *array) {
MS_ERROR_IF_NULL_WO_RET_VAL(array);
uint64_arrays_.push_back(std::move(*array));
}

void MemoryRegister::StoreCharArray(std::unique_ptr<char[]> *array) { char_arrays_.push_back(std::move(*array)); }
// Takes over the char buffer held by `*array` by moving it to the back of char_arrays_.
// `array` is a pointer to the caller's unique_ptr; after a successful call the caller's
// unique_ptr has been moved from and no longer owns the buffer.
// If `array` itself is null, an error is logged and nothing is stored.
void MemoryRegister::StoreCharArray(std::unique_ptr<char[]> *array) {
MS_ERROR_IF_NULL_WO_RET_VAL(array);
char_arrays_.push_back(std::move(*array));
}
} // namespace server
} // namespace fl
} // namespace mindspore

+ 5
- 2
mindspore/ccsrc/fl/server/model_store.cc View File

@@ -46,7 +46,7 @@ bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::strin
return false;
}

std::shared_ptr<MemoryRegister> memory_register;
std::shared_ptr<MemoryRegister> memory_register = nullptr;
if (iteration_to_model_.size() < max_model_count_) {
// If iteration_to_model_.size() is not max_model_count_, need to assign new memory for the model.
memory_register = AssignNewModelMemory();
@@ -123,10 +123,14 @@ std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() {

// Assign new memory for the model.
std::shared_ptr<MemoryRegister> memory_register = std::make_shared<MemoryRegister>();
MS_ERROR_IF_NULL_W_RET_VAL(memory_register, nullptr);
for (const auto &weight : model) {
const std::string weight_name = weight.first;
size_t weight_size = weight.second->size;
auto weight_data = std::make_unique<char[]>(weight_size);
MS_ERROR_IF_NULL_W_RET_VAL(weight_data, nullptr);
MS_ERROR_IF_NULL_W_RET_VAL(weight.second, nullptr);
MS_ERROR_IF_NULL_W_RET_VAL(weight.second->addr, nullptr);
if (weight_data == nullptr) {
MS_LOG(EXCEPTION) << "Assign memory for weight failed.";
return nullptr;
@@ -139,7 +143,6 @@ std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return nullptr;
}

memory_register->RegisterArray(weight_name, &weight_data, weight_size);
}
return memory_register;


+ 7
- 9
mindspore/ccsrc/fl/server/parameter_aggregator.cc View File

@@ -80,8 +80,7 @@ bool ParameterAggregator::LaunchAggregators() {
for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
KernelParams &params = aggregator_with_params.second;
std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first;
RETURN_IF_NULL(aggr_kernel, false);

MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false);
bool ret = aggr_kernel->Launch(params.inputs, params.workspace, params.outputs);
if (!ret) {
MS_LOG(ERROR) << "Launching aggregation kernel " << typeid(aggr_kernel.get()).name() << " failed.";
@@ -95,8 +94,7 @@ bool ParameterAggregator::LaunchOptimizers() {
for (auto &optimizer_with_params : optimizer_kernel_parameters_) {
KernelParams &params = optimizer_with_params.second;
std::shared_ptr<kernel::OptimizerKernel> optimizer_kernel = optimizer_with_params.first;
RETURN_IF_NULL(optimizer_kernel, false);

MS_ERROR_IF_NULL_W_RET_VAL(optimizer_kernel, false);
bool ret = optimizer_kernel->Launch(params.inputs, params.workspace, params.outputs);
if (!ret) {
MS_LOG(ERROR) << "Launching optimizer kernel " << typeid(optimizer_kernel.get()).name() << " failed.";
@@ -158,7 +156,7 @@ bool ParameterAggregator::IsAggregationDone() const {
// Only consider aggregation done after each aggregation kernel is done.
for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first;
RETURN_IF_NULL(aggr_kernel, false);
MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false);
if (!aggr_kernel->IsAggregationDone()) {
return false;
}
@@ -276,8 +274,8 @@ bool ParameterAggregator::AssignMemory(K server_kernel, const CNodePtr &cnode,

bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> aggr_kernel,
const std::shared_ptr<MemoryRegister> memory_register) {
RETURN_IF_NULL(aggr_kernel, false);
RETURN_IF_NULL(memory_register, false);
MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false);
MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false);
KernelParams aggr_params = {};

const std::vector<std::string> &input_names = aggr_kernel->input_names();
@@ -299,8 +297,8 @@ bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr<

bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> optimizer_kernel,
const std::shared_ptr<MemoryRegister> memory_register) {
RETURN_IF_NULL(optimizer_kernel, false);
RETURN_IF_NULL(memory_register, false);
MS_ERROR_IF_NULL_W_RET_VAL(optimizer_kernel, false);
MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false);
KernelParams optimizer_params = {};

const std::vector<std::string> &input_names = optimizer_kernel->input_names();


+ 3
- 5
mindspore/ccsrc/fl/server/round.cc View File

@@ -107,11 +107,7 @@ void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel)
}

void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr.";
return;
}

MS_ERROR_IF_NULL_WO_RET_VAL(message);
// If the server is still in the process of scaling, refuse the request.
if (Server::GetInstance().IsSafeMode()) {
MS_LOG(WARNING) << "The cluster is still in process of scaling, please retry " << name_ << " later.";
@@ -125,6 +121,8 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m

AddressPtr input = std::make_shared<Address>();
AddressPtr output = std::make_shared<Address>();
MS_ERROR_IF_NULL_WO_RET_VAL(input);
MS_ERROR_IF_NULL_WO_RET_VAL(output);
input->addr = message->data();
input->size = message->len();
bool ret = kernel_->Launch({input}, {}, {output});


+ 2
- 9
mindspore/ccsrc/fl/server/server.cc View File

@@ -137,7 +137,6 @@ void Server::InitCluster() {
bool Server::InitCommunicatorWithServer() {
MS_EXCEPTION_IF_NULL(task_executor_);
MS_EXCEPTION_IF_NULL(server_node_);

communicator_with_server_ =
server_node_->GetOrCreateTcpComm(scheduler_ip_, scheduler_port_, worker_num_, server_num_, task_executor_);
MS_EXCEPTION_IF_NULL(communicator_with_server_);
@@ -395,10 +394,7 @@ void Server::ProcessBeforeScalingIn() {

void Server::ProcessAfterScalingOut() {
std::unique_lock<std::mutex> lock(scaling_mtx_);
if (server_node_ == nullptr) {
return;
}

MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) {
MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
}
@@ -420,10 +416,7 @@ void Server::ProcessAfterScalingOut() {

void Server::ProcessAfterScalingIn() {
std::unique_lock<std::mutex> lock(scaling_mtx_);
if (server_node_ == nullptr) {
return;
}

MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
if (server_node_->rank_id() == UINT32_MAX) {
MS_LOG(WARNING) << "This server the one to be scaled in. Server exiting.";
(void)std::for_each(


+ 16
- 0
mindspore/core/utils/log_adapter.h View File

@@ -233,6 +233,22 @@ class LogWriter {
} \
} while (0)

// Null guard for functions that return a value.
// If `ptr` compares equal to nullptr, logs an ERROR message that names the checked
// expression (stringified with #ptr) and executes `return val;` in the ENCLOSING function.
// When `ptr` is not null the macro has no effect and execution continues.
// `val` is pasted unparenthesized after `return`, so a braced initializer such as `{}`
// can be passed as the return value.
// The do { ... } while (0) wrapper makes the expansion a single statement.
#define MS_ERROR_IF_NULL_W_RET_VAL(ptr, val) \
do { \
if ((ptr) == nullptr) { \
MS_LOG(ERROR) << ": The pointer[" << #ptr << "] is null."; \
return val; \
} \
} while (0)

// Null guard for functions that return void.
// If `ptr` compares equal to nullptr, logs an ERROR message that names the checked
// expression (stringified with #ptr) and executes a bare `return;` in the ENCLOSING function.
// When `ptr` is not null the macro has no effect and execution continues.
// The do { ... } while (0) wrapper makes the expansion a single statement.
#define MS_ERROR_IF_NULL_WO_RET_VAL(ptr) \
do { \
if ((ptr) == nullptr) { \
MS_LOG(ERROR) << ": The pointer[" << #ptr << "] is null."; \
return; \
} \
} while (0)

#ifdef DEBUG
#include <cassert>
#define MS_ASSERT(f) assert(f)


Loading…
Cancel
Save