Browse Source

FL, opt allreduce

feature/build-system-rewrite
xuyongfei 4 years ago
parent
commit
ec4b97ff01
27 changed files with 503 additions and 211 deletions
  1. +140
    -66
      mindspore/ccsrc/fl/server/collective_ops_impl.cc
  2. +9
    -3
      mindspore/ccsrc/fl/server/collective_ops_impl.h
  3. +18
    -17
      mindspore/ccsrc/fl/server/distributed_count_service.cc
  4. +21
    -0
      mindspore/ccsrc/fl/server/executor.cc
  5. +2
    -0
      mindspore/ccsrc/fl/server/executor.h
  6. +16
    -7
      mindspore/ccsrc/fl/server/iteration.cc
  7. +2
    -2
      mindspore/ccsrc/fl/server/iteration.h
  8. +1
    -0
      mindspore/ccsrc/fl/server/kernel/aggregation_kernel.h
  9. +1
    -0
      mindspore/ccsrc/fl/server/kernel/dense_grad_accum_kernel.h
  10. +39
    -63
      mindspore/ccsrc/fl/server/kernel/fed_avg_kernel.h
  11. +1
    -1
      mindspore/ccsrc/fl/server/kernel/round/push_metrics_kernel.cc
  12. +1
    -1
      mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc
  13. +1
    -1
      mindspore/ccsrc/fl/server/kernel/round/reconstruct_secrets_kernel.cc
  14. +6
    -8
      mindspore/ccsrc/fl/server/kernel/round/round_kernel.cc
  15. +1
    -3
      mindspore/ccsrc/fl/server/kernel/round/round_kernel.h
  16. +41
    -21
      mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc
  17. +3
    -2
      mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h
  18. +11
    -0
      mindspore/ccsrc/fl/server/parameter_aggregator.cc
  19. +1
    -0
      mindspore/ccsrc/fl/server/parameter_aggregator.h
  20. +0
    -7
      mindspore/ccsrc/fl/server/round.cc
  21. +0
    -1
      mindspore/ccsrc/fl/server/round.h
  22. +9
    -0
      mindspore/ccsrc/fl/server/server.cc
  23. +2
    -0
      mindspore/ccsrc/fl/server/server.h
  24. +139
    -6
      mindspore/ccsrc/ps/core/abstract_node.cc
  25. +22
    -0
      mindspore/ccsrc/ps/core/abstract_node.h
  26. +15
    -2
      mindspore/ccsrc/ps/core/protos/comm.proto
  27. +1
    -0
      mindspore/ccsrc/ps/core/protos/fl.proto

+ 140
- 66
mindspore/ccsrc/fl/server/collective_ops_impl.cc View File

@@ -15,10 +15,19 @@
*/

#include "fl/server/collective_ops_impl.h"
#include "fl/server/local_meta_store.h"
#include "fl/server/iteration.h"

namespace mindspore {
namespace fl {
namespace server {
namespace {
const char *kCollectivePhaseRing = "ring";
const char *kCollectivePhaseGather = "gather";
const char *kCollectivePhaseReduce = "reduce";
const char *kCollectivePhaseBroadcast = "broadcast";
} // namespace

void CollectiveOpsImpl::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node) {
MS_EXCEPTION_IF_NULL(server_node);
server_node_ = server_node;
@@ -28,16 +37,15 @@ void CollectiveOpsImpl::Initialize(const std::shared_ptr<ps::core::ServerNode> &
}

template <typename T>
bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
MS_ERROR_IF_NULL_W_RET_VAL(server_node_, false);
bool CollectiveOpsImpl::RingAllReduce(const std::string &data_name, const void *sendbuff, void *recvbuff,
size_t count) {
MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false);
MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false);

int ret;
if (recvbuff != sendbuff) {
size_t src_size = count * sizeof(T);
size_t dst_size = count * sizeof(T);
ret = memcpy_s(recvbuff, dst_size, sendbuff, src_size);
auto ret = memcpy_s(recvbuff, dst_size, sendbuff, src_size);
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return false;
@@ -68,32 +76,61 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
<< ", chunk_sizes:" << chunk_sizes << ", send_to_rank:" << send_to_rank
<< ", recv_from_rank:" << recv_from_rank;

return RunRingAllReduce<T>(data_name, send_to_rank, recv_from_rank, chunk_sizes, chunk_offset, output_buff);
}

// Implementation of RingAllReduce.
template <typename T>
bool CollectiveOpsImpl::RunRingAllReduce(const std::string &data_name, uint32_t send_to_rank, uint32_t recv_from_rank,
const std::vector<size_t> &chunk_sizes,
const std::vector<size_t> &chunk_offset, T *output_buff) {
MS_ERROR_IF_NULL_W_RET_VAL(server_node_, false);
MS_ERROR_IF_NULL_W_RET_VAL(output_buff, false);
auto curr_iteration_num = LocalMetaStore::GetInstance().curr_iter_num();
ps::core::CollectiveMessageMeta send_meta;
send_meta.set_enable_flag(true);
send_meta.set_send_rank_id(rank_id_);
send_meta.set_recv_rank_id(send_to_rank);
send_meta.set_iteration(curr_iteration_num);
send_meta.set_weight_name(data_name);

ps::core::CollectiveMessageMeta recv_meta;
recv_meta.set_enable_flag(true);
recv_meta.set_send_rank_id(recv_from_rank);
recv_meta.set_recv_rank_id(rank_id_);
recv_meta.set_iteration(curr_iteration_num);
recv_meta.set_weight_name(data_name);

// Ring ReduceScatter.
MS_LOG(DEBUG) << "Start Ring ReduceScatter.";
send_meta.set_phase(kCollectivePhaseRing);
recv_meta.set_phase(kCollectivePhaseRing);

uint32_t rank_size = server_num_;
for (size_t i = 0; i < rank_size - 1; i++) {
// Step 1: Async send data to next rank.
size_t send_chunk_index = (rank_id_ - i + rank_size) % rank_size;
T *send_chunk = output_buff + chunk_offset[send_chunk_index];
send_meta.set_chunk_index(send_chunk_index);
send_meta.set_for_index(i);
auto send_chunk_count = chunk_sizes[send_chunk_index];
auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, send_to_rank, send_chunk,
send_chunk_count * sizeof(T));
auto send_req_id = server_node_->FlCollectiveSendAsync(send_meta, send_chunk, send_chunk_count * sizeof(T));
// Step 2: Async receive data to next rank and wait until it's done.
size_t recv_chunk_index = (rank_id_ - i - 1 + rank_size) % rank_size;
recv_meta.set_chunk_index(recv_chunk_index);
recv_meta.set_for_index(i);
T *recv_chunk = output_buff + chunk_offset[recv_chunk_index];
auto recv_chunk_count = chunk_sizes[recv_chunk_index];
MS_LOG(DEBUG) << "Ring ReduceScatter send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank
<< ", send count:" << send_chunk_count << ", recv count:" << recv_chunk_count << ", iteration:" << i;
<< ", send chunk index:" << send_chunk_index << ", send count:" << send_chunk_count
<< ", recv chunk index:" << recv_chunk_index << ", recv count:" << recv_chunk_count
<< ", for index:" << i;

std::shared_ptr<std::vector<uint8_t>> recv_str;
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
return false;
}
if (recv_chunk_count * sizeof(T) != recv_str->size()) {
MS_LOG(ERROR) << "Expect receive chunk size " << recv_chunk_count * sizeof(T) << " from rank " << recv_from_rank
<< " != real receive chunk size " << recv_str->size() << ", current rank: " << rank_id_
<< ", total data size " << count * sizeof(T);
auto expect_size = recv_chunk_count * sizeof(T);
if (!server_node_->FlCollectiveWait(recv_meta, expect_size, &recv_str, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "FlCollectiveWait failed, send rank id: " << recv_meta.send_rank_id();
return false;
}
auto tmp_recv_chunk = reinterpret_cast<T *>(recv_str->data());
@@ -103,7 +140,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
}
// Step 4: Wait until send is done.
if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
MS_LOG(ERROR) << "Wait response of rank " << send_req_id << " failed.";
return false;
}
}
@@ -111,30 +148,40 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size

// Ring AllGather.
MS_LOG(DEBUG) << "Start Ring AllGather.";
send_meta.set_phase(kCollectivePhaseGather);
recv_meta.set_phase(kCollectivePhaseGather);
for (size_t i = 0; i < rank_size - 1; i++) {
size_t send_chunk_index = (rank_id_ - i + 1 + rank_size) % rank_size;
T *send_chunk = output_buff + chunk_offset[send_chunk_index];
auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, send_to_rank, send_chunk,
chunk_sizes[send_chunk_index] * sizeof(T));
send_meta.set_chunk_index(send_chunk_index);
send_meta.set_for_index(i);
auto send_chunk_count = chunk_sizes[send_chunk_index];
auto send_req_id = server_node_->FlCollectiveSendAsync(send_meta, send_chunk, send_chunk_count * sizeof(T));

size_t recv_chunk_index = (rank_id_ - i + rank_size) % rank_size;
T *recv_chunk = output_buff + chunk_offset[recv_chunk_index];
recv_meta.set_chunk_index(recv_chunk_index);
recv_meta.set_for_index(i);
auto recv_chunk_count = chunk_sizes[recv_chunk_index];
MS_LOG(DEBUG) << "Ring AllGather send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank
<< ", send count:" << chunk_sizes[send_chunk_index]
<< ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;
<< ", send chunk index:" << send_chunk_index << ", send count:" << send_chunk_count
<< ", recv chunk index:" << recv_chunk_index << ", recv count:" << recv_chunk_count
<< ", for index:" << i;

std::shared_ptr<std::vector<unsigned char>> recv_str;
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
auto expect_size = recv_chunk_count * sizeof(T);
if (!server_node_->FlCollectiveWait(recv_meta, expect_size, &recv_str, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "FlCollectiveWait failed, send rank id: " << recv_meta.send_rank_id();
return false;
}
ret = memcpy_s(recv_chunk, chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size());
auto ret = memcpy_s(recv_chunk, expect_size, recv_str->data(), recv_str->size());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"
<< ", dest size is " << recv_chunk_count * sizeof(T) << ", src size is " << recv_str->size();
return false;
}
if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
MS_LOG(ERROR) << "Wait response of rank " << send_req_id << " failed.";
return false;
}
}
@@ -143,7 +190,8 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
}

template <typename T>
bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const std::string &data_name, const void *sendbuff, void *recvbuff,
size_t count) {
MS_ERROR_IF_NULL_W_RET_VAL(server_node_, false);
MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false);
MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false);
@@ -155,37 +203,54 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
size_t dst_size = count * sizeof(T);
int ret = memcpy_s(recvbuff, dst_size, sendbuff, src_size);
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"
<< ", dest size is " << dst_size << ", src size is " << src_size;
return false;
}
T *output_buff = reinterpret_cast<T *>(recvbuff);
// Reduce data to rank 0 process.
auto curr_iteration_num = LocalMetaStore::GetInstance().curr_iter_num();
ps::core::CollectiveMessageMeta send_meta;
send_meta.set_enable_flag(true);
send_meta.set_send_rank_id(rank_id_);
send_meta.set_iteration(curr_iteration_num);
send_meta.set_weight_name(data_name);
send_meta.set_chunk_index(0);
send_meta.set_for_index(0);

ps::core::CollectiveMessageMeta recv_meta;
recv_meta.set_enable_flag(true);
recv_meta.set_recv_rank_id(rank_id_);
recv_meta.set_iteration(curr_iteration_num);
recv_meta.set_weight_name(data_name);
recv_meta.set_chunk_index(0);
recv_meta.set_for_index(0);

send_meta.set_phase(kCollectivePhaseReduce);
recv_meta.set_phase(kCollectivePhaseReduce);

MS_LOG(DEBUG) << "Start Reduce to rank 0 process.";
if (rank_id_ == 0) {
for (uint32_t i = 1; i < rank_size; i++) {
std::shared_ptr<std::vector<unsigned char>> recv_str;
MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i;
auto recv_req_id1 = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, i, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id1, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id1 << " failed.";
return false;
}
if (count * sizeof(T) != recv_str->size()) {
MS_LOG(ERROR) << "Expect receive chunk size " << count * sizeof(T) << " from rank " << i
<< " != real receive chunk size " << recv_str->size() << ", current rank: " << rank_id_
<< ", total data size " << count * sizeof(T);
recv_meta.set_send_rank_id(i);
auto expect_size = count * sizeof(T);
if (!server_node_->FlCollectiveWait(recv_meta, expect_size, &recv_str, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "FlCollectiveWait failed, send rank id: " << recv_meta.send_rank_id();
return false;
}
auto tmp_recv_chunk = reinterpret_cast<T *>(recv_str->data());
auto tmp_recv_chunk = reinterpret_cast<T *>(recv_str->data()); // recv_str size has checked in FlCollectiveWait
for (size_t j = 0; j < count; j++) {
output_buff[j] += tmp_recv_chunk[j];
}
}
} else {
MS_LOG(DEBUG) << "Reduce send data to rank 0 process.";
auto send_req_id1 = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T));
send_meta.set_recv_rank_id(0);
auto send_req_id1 = server_node_->FlCollectiveSendAsync(send_meta, sendbuff, count * sizeof(T));
if (!server_node_->Wait(send_req_id1, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "CollectiveWait " << send_req_id1 << " failed.";
MS_LOG(ERROR) << "Wait response of rank " << send_req_id1 << " failed.";
return false;
}
}
@@ -193,27 +258,31 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec

// Broadcast data to not 0 rank process.
MS_LOG(DEBUG) << "Start broadcast from rank 0 to other processes.";
send_meta.set_phase(kCollectivePhaseBroadcast);
recv_meta.set_phase(kCollectivePhaseBroadcast);
if (rank_id_ == 0) {
for (uint32_t i = 1; i < rank_size; i++) {
MS_LOG(DEBUG) << "Broadcast data to process " << i;
auto send_req_id2 =
server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, i, output_buff, count * sizeof(T));
send_meta.set_recv_rank_id(i);
auto send_req_id2 = server_node_->FlCollectiveSendAsync(send_meta, output_buff, count * sizeof(T));
if (!server_node_->Wait(send_req_id2, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "CollectiveWait " << send_req_id2 << " failed.";
MS_LOG(ERROR) << "Wait response of rank " << send_req_id2 << " failed.";
return false;
}
}
} else {
MS_LOG(DEBUG) << "Broadcast receive from rank 0.";
recv_meta.set_send_rank_id(0);
std::shared_ptr<std::vector<unsigned char>> recv_str;
auto recv_req_id2 = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, 0, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id2, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id2 << " failed.";
auto expect_size = count * sizeof(T);
if (!server_node_->FlCollectiveWait(recv_meta, expect_size, &recv_str, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "FlCollectiveWait failed, send rank id: " << recv_meta.send_rank_id();
return false;
}
ret = memcpy_s(output_buff, count * sizeof(T), recv_str->data(), recv_str->size());
ret = memcpy_s(output_buff, expect_size, recv_str->data(), recv_str->size());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"
<< ", dest size is " << expect_size << ", src size is " << recv_str->size();
return false;
}
}
@@ -249,7 +318,8 @@ bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *const recvbuff
size_t dst_size = send_count * sizeof(T);
int ret = memcpy_s(output_buff + chunk_offset[rank_id_], dst_size, sendbuff, src_size);
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"
<< ", dest size is " << dst_size << ", src size is " << src_size;
return false;
}

@@ -273,7 +343,9 @@ bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *const recvbuff
}
ret = memcpy_s(recv_chunk, chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"
<< ", dest size is " << (chunk_sizes[recv_chunk_index] * sizeof(T)) << ", src size is "
<< recv_str->size();
return false;
}
if (!node_->Wait(send_req_id, kCollectiveCommTimeout)) {
@@ -322,7 +394,8 @@ bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *recvbuff, size_t c
}
int ret = memcpy_s(recvbuff, count * sizeof(T), recv_str->data(), recv_str->size());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"
<< ", dest size is " << (count * sizeof(T)) << ", src size is " << recv_str->size();
return false;
}
}
@@ -331,11 +404,12 @@ bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *recvbuff, size_t c
}

template <typename T>
bool CollectiveOpsImpl::AllReduce(const void *sendbuff, void *recvbuff, size_t count) {
bool CollectiveOpsImpl::AllReduce(const std::string &data_name, const void *sendbuff, void *recvbuff, size_t count) {
// The collective communication API does not support calling Send and Recv concurrently with multiple threads;
std::unique_lock<std::mutex> lock(mtx_);
MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false);
MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false);
MS_ERROR_IF_NULL_W_RET_VAL(server_node_, false);

uint32_t rank_size = server_num_;
if (rank_size == 0) {
@@ -346,11 +420,16 @@ bool CollectiveOpsImpl::AllReduce(const void *sendbuff, void *recvbuff, size_t c
MS_LOG(INFO) << "Rank size is 1. Do nothing.";
return true;
}
auto cur_iteration_num = LocalMetaStore::GetInstance().curr_iter_num();
if (server_node_->HasIterationFailed(cur_iteration_num)) {
MS_LOG(WARNING) << "Detect iteration " << cur_iteration_num << " has failed";
return false;
}

if (count >= rank_size) {
return RingAllReduce<T>(sendbuff, recvbuff, count);
return RingAllReduce<T>(data_name, sendbuff, recvbuff, count);
} else {
return ReduceBroadcastAllReduce<T>(sendbuff, recvbuff, count);
return ReduceBroadcastAllReduce<T>(data_name, sendbuff, recvbuff, count);
}
}

@@ -429,17 +508,12 @@ bool CollectiveOpsImpl::ReInitForScaling() {
return true;
}

template bool CollectiveOpsImpl::RingAllReduce<float>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::RingAllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::RingAllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);

template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<float>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);

template bool CollectiveOpsImpl::AllReduce<float>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::AllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::AllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::AllReduce<float>(const std::string &data_name, const void *sendbuff, void *recvbuff,
size_t count);
template bool CollectiveOpsImpl::AllReduce<size_t>(const std::string &data_name, const void *sendbuff, void *recvbuff,
size_t count);
template bool CollectiveOpsImpl::AllReduce<int>(const std::string &data_name, const void *sendbuff, void *recvbuff,
size_t count);

template bool CollectiveOpsImpl::AllGather<float>(const void *sendbuff, void *recvbuff, size_t send_count,
const std::shared_ptr<ps::core::AbstractNode> &node);


+ 9
- 3
mindspore/ccsrc/fl/server/collective_ops_impl.h View File

@@ -62,7 +62,7 @@ class CollectiveOpsImpl {
void Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node);

template <typename T>
bool AllReduce(const void *sendbuff, void *recvbuff, size_t count);
bool AllReduce(const std::string &data_name, const void *sendbuff, void *recvbuff, size_t count);

template <typename T>
bool AllGather(const void *sendbuff, void *recvbuff, size_t send_count,
@@ -91,11 +91,17 @@ class CollectiveOpsImpl {

// Implementation of RingAllReduce.
template <typename T>
bool RingAllReduce(const void *sendbuff, void *recvbuff, size_t count);
bool RunRingAllReduce(const std::string &data_name, uint32_t send_to_rank, uint32_t recv_from_rank,
const std::vector<size_t> &chunk_sizes, const std::vector<size_t> &chunk_offset,
T *output_buff);

// Implementation of RingAllReduce.
template <typename T>
bool RingAllReduce(const std::string &data_name, const void *sendbuff, void *recvbuff, size_t count);

// Implementation of BroadcastAllReduce.
template <typename T>
bool ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count);
bool ReduceBroadcastAllReduce(const std::string &data_name, const void *sendbuff, void *recvbuff, size_t count);

// Implementation of RingAllGather.
template <typename T>


+ 18
- 17
mindspore/ccsrc/fl/server/distributed_count_service.cc View File

@@ -18,6 +18,8 @@
#include <string>
#include <memory>
#include <vector>
#include "fl/server/iteration.h"
#include "fl/server/server.h"

namespace mindspore {
namespace fl {
@@ -252,20 +254,17 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core:
MS_LOG(INFO) << "Global current count for " << name << " is: " << global_current_count_[name].size() << "/"
<< global_threshold_count_[name];
}
std::string reason = "success";
if (!TriggerCounterEvent(name, &reason)) {
count_rsp.set_result(false);
count_rsp.set_reason(reason);
} else {
count_rsp.set_result(true);
count_rsp.set_reason(reason);
}
count_rsp.set_result(true);
count_rsp.set_reason("success");
if (!communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(),
message)) {
MS_LOG(WARNING) << "Sending response failed.";
return;
}
return;
std::string reason = "success";
if (!TriggerCounterEvent(name, &reason)) {
Iteration::GetInstance().NotifyNext(false, reason);
}
}

void DistributedCountService::HandleCountReachThresholdRequest(
@@ -360,15 +359,16 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st
return false;
}
}
if (counter_handlers_.count(name) == 0) {
auto counter_it = counter_handlers_.find(name);
if (counter_it == counter_handlers_.end() || !counter_it->second.first_count_handler) {
MS_LOG(WARNING) << "The counter handler of " << name << " is not registered.";
return false;
}
// Leader server directly calls the callback.
MS_LOG(DEBUG) << "Leader server call first count handler for " << name;
counter_handlers_[name].first_count_handler(nullptr);
MS_LOG(DEBUG) << "First count handler for " << name << " is successfully called.";
auto count_handler = counter_it->second.first_count_handler;
Server::GetInstance().SubmitTask([count_handler]() { count_handler(nullptr); });
MS_LOG(DEBUG) << "First count handler for " << name << " is successfully submitted.";
return true;
}

@@ -389,15 +389,16 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std
return false;
}
}
if (counter_handlers_.count(name) == 0) {
auto counter_it = counter_handlers_.find(name);
if (counter_it == counter_handlers_.end() || !counter_it->second.last_count_handler) {
MS_LOG(WARNING) << "The counter handler of " << name << " is not registered.";
return false;
}
// Leader server directly calls the callback.
MS_LOG(DEBUG) << "Leader server call last count handler for " << name;
counter_handlers_[name].last_count_handler(nullptr);
MS_LOG(INFO) << "Last count handler for " << name << " is successfully called.";
auto count_handler = counter_it->second.last_count_handler;
Server::GetInstance().SubmitTask([count_handler]() { count_handler(nullptr); });
MS_LOG(INFO) << "Last count handler for " << name << " is successfully submitted.";
return true;
}
} // namespace server


+ 21
- 0
mindspore/ccsrc/fl/server/executor.cc View File

@@ -140,6 +140,27 @@ std::map<std::string, AddressPtr> Executor::HandlePullWeight(const std::vector<s

bool Executor::IsAllWeightAggregationDone() { return IsWeightAggrDone(param_names_); }

bool Executor::RunAllWeightAggregation() {
for (const auto &name : param_names_) {
if (param_aggrs_.count(name) == 0) {
MS_LOG(ERROR) << "Weight " << name << " is invalid in server.";
return false;
}
std::mutex &mtx = parameter_mutex_[name];
std::unique_lock<std::mutex> lock(mtx);
auto &param_aggr = param_aggrs_[name];
MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
if (!param_aggr->requires_aggr()) {
continue;
}
if (!param_aggr->RunAggregation()) {
MS_LOG(WARNING) << "Failed to run aggregation for " << name;
return false;
}
}
return true;
}

bool Executor::IsWeightAggrDone(const std::vector<std::string> &param_names) {
for (const auto &name : param_names) {
if (param_aggrs_.count(name) == 0) {


+ 2
- 0
mindspore/ccsrc/fl/server/executor.h View File

@@ -69,6 +69,8 @@ class Executor {
// Judge whether aggregation processes for all weights/gradients are completed.
bool IsAllWeightAggregationDone();

bool RunAllWeightAggregation();

// Judge whether the aggregation processes for the given param_names are completed.
bool IsWeightAggrDone(const std::vector<std::string> &param_names);



+ 16
- 7
mindspore/ccsrc/fl/server/iteration.cc View File

@@ -121,7 +121,7 @@ void Iteration::MoveToNextIteration(bool is_last_iter_valid, const std::string &

MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
if (server_node_->rank_id() == kLeaderServerRank) {
if (!BroadcastPrepareForNextIterRequest(is_last_iter_valid, reason)) {
if (!BroadcastPrepareForNextIterRequest(iteration_num_, is_last_iter_valid, reason)) {
MS_LOG(ERROR) << "Broadcast prepare for next iteration request failed.";
return;
}
@@ -452,7 +452,7 @@ void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps
return;
}

if (!BroadcastPrepareForNextIterRequest(is_last_iter_valid, reason)) {
if (!BroadcastPrepareForNextIterRequest(iter_num, is_last_iter_valid, reason)) {
MS_LOG(ERROR) << "Broadcast prepare for next iteration request failed.";
return;
}
@@ -466,13 +466,15 @@ void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps
}
}

bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason) {
bool Iteration::BroadcastPrepareForNextIterRequest(size_t last_iteration, bool is_last_iter_valid,
const std::string &reason) {
MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
PrepareForNextIter();
PrepareForNextIter(last_iteration, is_last_iter_valid);
MS_LOG(INFO) << "Notify all follower servers to prepare for next iteration.";
PrepareForNextIterRequest prepare_next_iter_req;
prepare_next_iter_req.set_is_last_iter_valid(is_last_iter_valid);
prepare_next_iter_req.set_reason(reason);
prepare_next_iter_req.set_last_iteration(last_iteration);

std::vector<uint32_t> offline_servers = {};
for (uint32_t i = 1; i < server_node_->server_num(); i++) {
@@ -504,8 +506,12 @@ void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::
PrepareForNextIterRequest prepare_next_iter_req;
(void)prepare_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
const auto &reason = prepare_next_iter_req.reason();
MS_LOG(INFO) << "Prepare next iteration for this rank " << server_node_->rank_id() << ", reason: " << reason;
PrepareForNextIter();
auto is_last_iter_valid = prepare_next_iter_req.is_last_iter_valid();
auto last_iteration = prepare_next_iter_req.last_iteration();
MS_LOG(INFO) << "Prepare next iteration for this rank " << server_node_->rank_id()
<< ", last iteration: " << last_iteration << ", last iteration valid: " << is_last_iter_valid
<< ", reason: " << reason;
PrepareForNextIter(last_iteration, is_last_iter_valid);

PrepareForNextIterResponse prepare_next_iter_rsp;
prepare_next_iter_rsp.set_result("success");
@@ -516,9 +522,12 @@ void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::
}
}

void Iteration::PrepareForNextIter() {
void Iteration::PrepareForNextIter(size_t last_iteration, bool is_last_iter_valid) {
MS_LOG(INFO) << "Prepare for next iteration. Switch the server to safemode.";
Server::GetInstance().SwitchToSafeMode();
if (server_node_) {
server_node_->SetIterationResult(last_iteration, is_last_iter_valid);
}
MS_LOG(INFO) << "Start waiting for rounds to finish.";
WaitAllRoundsFinish();
MS_LOG(INFO) << "End waiting for rounds to finish.";


+ 2
- 2
mindspore/ccsrc/fl/server/iteration.h View File

@@ -189,10 +189,10 @@ class Iteration {
void HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);

// Step 2: leader server broadcasts to all follower servers to prepare for next iteration and switch to safemode..
bool BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason);
bool BroadcastPrepareForNextIterRequest(size_t last_iteration, bool is_last_iter_valid, const std::string &reason);
void HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// The server prepare for the next iteration. This method will switch the server to safemode.
void PrepareForNextIter();
void PrepareForNextIter(size_t last_iteration, bool is_last_iter_valid);

// Step 3: leader server broadcasts to all follower servers to move to next iteration.
bool BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason);


+ 1
- 0
mindspore/ccsrc/fl/server/kernel/aggregation_kernel.h View File

@@ -45,6 +45,7 @@ class AggregationKernelMod : public NativeCpuKernelMod {
const std::vector<AddressPtr> &outputs) {
return true;
}
virtual bool AllReduce() = 0;

// Server kernel's memory allocation method, which is different from the workflow in
// Session(GPUSession/CPUSession/AscendSession).


+ 1
- 0
mindspore/ccsrc/fl/server/kernel/dense_grad_accum_kernel.h View File

@@ -43,6 +43,7 @@ class DenseGradAccumKernel : public AggregationKernelMod {
return true;
}

bool AllReduce() override { return true; }
void Reset() { accum_count_ = 0; }

bool IsAggregationDone() { return accum_count_ >= done_count_; }


+ 39
- 63
mindspore/ccsrc/fl/server/kernel/fed_avg_kernel.h View File

@@ -48,8 +48,7 @@ class FedAvgKernel : public AggregationKernelMod {
weight_addr_(nullptr),
data_size_addr_(nullptr),
new_weight_addr_(nullptr),
new_data_size_addr_(nullptr),
participated_(false) {}
new_data_size_addr_(nullptr) {}
~FedAvgKernel() override = default;

void InitKernel(const CNodePtr &kernel_node) override {
@@ -76,45 +75,42 @@ class FedAvgKernel : public AggregationKernelMod {
.first;
MS_EXCEPTION_IF_NULL(weight_node);
name_ = cnode_name + "." + weight_node->fullname_with_scope();
first_cnt_handler_ = [&](std::shared_ptr<ps::core::MessageHandler>) {
std::unique_lock<std::mutex> lock(weight_mutex_);
if (!participated_) {
ClearWeightAndDataSize();
}
};
last_cnt_handler_ = [&](std::shared_ptr<ps::core::MessageHandler>) {
MS_ERROR_IF_NULL_WO_RET_VAL(weight_addr_);
MS_ERROR_IF_NULL_WO_RET_VAL(data_size_addr_);
MS_ERROR_IF_NULL_WO_RET_VAL(weight_addr_->addr);
MS_ERROR_IF_NULL_WO_RET_VAL(data_size_addr_->addr);
std::unique_lock<std::mutex> lock(weight_mutex_);
T *weight_addr = reinterpret_cast<T *>(weight_addr_->addr);
size_t weight_size = weight_addr_->size;
S *data_size_addr = reinterpret_cast<S *>(data_size_addr_->addr);
if (!CollectiveOpsImpl::GetInstance().AllReduce<T>(weight_addr, weight_addr, weight_size / sizeof(T))) {
MS_LOG(ERROR) << "Federated average allreduce failed.";
return;
}
if (!CollectiveOpsImpl::GetInstance().AllReduce<S>(data_size_addr, data_size_addr, 1)) {
MS_LOG(ERROR) << "Federated average allreduce failed.";
return;
}
if (data_size_addr[0] == 0) {
MS_LOG(ERROR) << "After AllReduce, the data size is 0.";
return;
}
LocalMetaStore::GetInstance().put_value(kCtxFedAvgTotalDataSize, data_size_addr[0]);
for (size_t i = 0; i < weight_size / sizeof(T); i++) {
weight_addr[i] /= data_size_addr[0];
}
done_ = true;
return;
};
DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler_, last_cnt_handler_});

MS_LOG(INFO) << "Aggregate Weight full name is " << weight_node->fullname_with_scope() << ", weight byte size is "
<< weight_size;
GenerateReuseKernelNodeInfo();
return;
}

// Aggregates this kernel's weight and data-size buffers across all servers,
// then normalizes the summed weight by the global data size (federated
// averaging). Returns false when either AllReduce fails or the reduced
// total data size is 0 (division guard).
bool AllReduce() override {
// Hold the lock so Launch() cannot accumulate into the buffers mid-reduce.
std::unique_lock<std::mutex> lock(weight_mutex_);
MS_ERROR_IF_NULL_W_RET_VAL(weight_addr_, false);
MS_ERROR_IF_NULL_W_RET_VAL(data_size_addr_, false);
MS_ERROR_IF_NULL_W_RET_VAL(weight_addr_->addr, false);
MS_ERROR_IF_NULL_W_RET_VAL(data_size_addr_->addr, false);
T *weight_addr = reinterpret_cast<T *>(weight_addr_->addr);
size_t weight_size = weight_addr_->size;  // byte size; element count is weight_size / sizeof(T)
S *data_size_addr = reinterpret_cast<S *>(data_size_addr_->addr);
// In-place AllReduce (src == dst), keyed by the kernel name so concurrent
// reductions of different weights do not collide.
if (!CollectiveOpsImpl::GetInstance().AllReduce<T>(name_, weight_addr, weight_addr, weight_size / sizeof(T))) {
MS_LOG(ERROR) << "Federated average allreduce failed.";
return false;
}
// Separate single-element reduction for the accumulated data size.
if (!CollectiveOpsImpl::GetInstance().AllReduce<S>(name_ + "_data_size", data_size_addr, data_size_addr, 1)) {
MS_LOG(ERROR) << "Federated average allreduce failed.";
return false;
}
if (data_size_addr[0] == 0) {
MS_LOG(ERROR) << "After AllReduce, the data size is 0.";
return false;
}
LocalMetaStore::GetInstance().put_value(kCtxFedAvgTotalDataSize, data_size_addr[0]);
// Weights were pre-multiplied by each client's data size, so dividing by
// the global total yields the weighted average.
for (size_t i = 0; i < weight_size / sizeof(T); i++) {
weight_addr[i] /= data_size_addr[0];
}
done_ = true;
return true;
}

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
if (inputs.size() != kFedAvgInputsNum) {
@@ -126,15 +122,16 @@ class FedAvgKernel : public AggregationKernelMod {
}

std::unique_lock<std::mutex> lock(weight_mutex_);
if (done_) {
MS_LOG(INFO) << "AllReduce for " << name_ << " has finished";
return true;
}
// The weight and new_weight values should be multiplied by clients already, so we don't need to do multiplication
// again.
T *weight_addr = reinterpret_cast<T *>(inputs[0]->addr);
S *data_size_addr = reinterpret_cast<S *>(inputs[1]->addr);
T *new_weight_addr = reinterpret_cast<T *>(inputs[2]->addr);
S *new_data_size_addr = reinterpret_cast<S *>(inputs[3]->addr);
if (accum_count_ == 0) {
ClearWeightAndDataSize();
}

MS_LOG(DEBUG) << "Iteration: " << LocalMetaStore::GetInstance().curr_iter_num() << " launching FedAvgKernel for "
<< name_ << " new data size is " << new_data_size_addr[0] << ", current total data size is "
@@ -146,17 +143,13 @@ class FedAvgKernel : public AggregationKernelMod {
lock.unlock();

accum_count_++;
participated_ = true;
return DistributedCountService::GetInstance().Count(
name_, std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_));
return true;
}

void Reset() override {
accum_count_ = 0;
done_ = false;
participated_ = false;
DistributedCountService::GetInstance().ResetCounter(name_);
return;
ClearWeightAndDataSize();
}

bool IsAggregationDone() override { return done_; }
@@ -169,18 +162,8 @@ class FedAvgKernel : public AggregationKernelMod {
new_data_size_addr_ = inputs[3];
return;
}

bool ReInitForScaling() override {
DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler_, last_cnt_handler_});
return true;
}

bool ReInitForUpdatingHyperParams(size_t aggr_threshold) override {
done_count_ = aggr_threshold;
if (!DistributedCountService::GetInstance().ReInitCounter(name_, done_count_)) {
MS_LOG(ERROR) << "Reinitializing count for " << name_ << " failed.";
return false;
}
return true;
}

@@ -211,9 +194,6 @@ class FedAvgKernel : public AggregationKernelMod {
return;
}

MessageCallback first_cnt_handler_;
MessageCallback last_cnt_handler_;

// The trainable parameter index of the kernel node which is parsed from the frontend func_graph.
size_t cnode_weight_idx_;

@@ -222,10 +202,6 @@ class FedAvgKernel : public AggregationKernelMod {
AddressPtr data_size_addr_;
AddressPtr new_weight_addr_;
AddressPtr new_data_size_addr_;

// Whether the kernel's Launch method is called.
bool participated_;

// The kernel could be called concurrently so we need lock to ensure threadsafe.
std::mutex weight_mutex_;
};


+ 1
- 1
mindspore/ccsrc/fl/server/kernel/round/push_metrics_kernel.cc View File

@@ -67,7 +67,7 @@ bool PushMetricsKernel::Reset() {

void PushMetricsKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kPushMetrics) {
FinishIteration();
FinishIteration(true);
}
return;
}


+ 1
- 1
mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc View File

@@ -72,7 +72,7 @@ bool PushWeightKernel::Reset() {

void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kPushWeight) {
FinishIteration();
FinishIteration(true);
}
return;
}


+ 1
- 1
mindspore/ccsrc/fl/server/kernel/round/reconstruct_secrets_kernel.cc View File

@@ -32,7 +32,7 @@ void ReconstructSecretsKernel::InitKernel(size_t) {
auto last_cnt_handler = [&](std::shared_ptr<ps::core::MessageHandler>) {
if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kReconstructSeccrets) {
MS_LOG(INFO) << "start FinishIteration";
FinishIteration();
FinishIteration(true);
MS_LOG(INFO) << "end FinishIteration";
}
return;


+ 6
- 8
mindspore/ccsrc/fl/server/kernel/round/round_kernel.cc View File

@@ -22,6 +22,7 @@
#include <utility>
#include <string>
#include <vector>
#include "fl/server/iteration.h"

namespace mindspore {
namespace fl {
@@ -42,21 +43,18 @@ void RoundKernel::StopTimer() const {
return;
}

void RoundKernel::FinishIteration() const {
if (finish_iteration_cb_) {
finish_iteration_cb_(true, "");
void RoundKernel::FinishIteration(bool is_last_iter_valid, const std::string &in_reason) const {
std::string reason = in_reason;
if (is_last_iter_valid) {
reason = "Round " + name_ + " finished! This iteration is valid. Proceed to next iteration.";
}
return;
Iteration::GetInstance().NotifyNext(is_last_iter_valid, reason);
}

void RoundKernel::set_name(const std::string &name) { name_ = name; }

void RoundKernel::set_stop_timer_cb(const StopTimerCb &timer_stopper) { stop_timer_cb_ = timer_stopper; }

void RoundKernel::set_finish_iteration_cb(const FinishIterCb &finish_iteration_cb) {
finish_iteration_cb_ = finish_iteration_cb;
}

void RoundKernel::GenerateOutput(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data,
size_t len) {
if (message == nullptr) {


+ 1
- 3
mindspore/ccsrc/fl/server/kernel/round/round_kernel.h View File

@@ -72,14 +72,13 @@ class RoundKernel {

// Called after this iteration(including all rounds) is finished. All rounds' Reset method will
// be called.
void FinishIteration() const;
void FinishIteration(bool is_last_iter_valid, const std::string &reason = "") const;

// Set round kernel name, which could be used in round kernel's methods.
void set_name(const std::string &name);

// Set callbacks to be called under certain triggered conditions.
void set_stop_timer_cb(const StopTimerCb &timer_stopper);
void set_finish_iteration_cb(const FinishIterCb &finish_iteration_cb);

void Summarize();

@@ -106,7 +105,6 @@ class RoundKernel {
size_t current_count_;

StopTimerCb stop_timer_cb_;
FinishIterCb finish_iteration_cb_;

// Members below are used for allocating and releasing response data on the heap.



+ 41
- 21
mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc View File

@@ -25,7 +25,8 @@ namespace mindspore {
namespace fl {
namespace server {
namespace kernel {
constexpr uint32_t kRetryCountOfWaitWeightAggregation = 30;
const char *kCountForAggregation = "count_for_aggregation";

void UpdateModelKernel::InitKernel(size_t threshold_count) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
@@ -42,6 +43,11 @@ void UpdateModelKernel::InitKernel(size_t threshold_count) {
DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxUpdateModelClientList, client_list);
LocalMetaStore::GetInstance().put_value(kCtxUpdateModelThld, threshold_count);
LocalMetaStore::GetInstance().put_value(kCtxFedAvgTotalDataSize, kInitialDataSizeSum);

auto first_cnt_handler = [](std::shared_ptr<ps::core::MessageHandler>) {};
auto last_cnt_handler = [this](std::shared_ptr<ps::core::MessageHandler>) { RunAggregation(); };
DistributedCountService::GetInstance().RegisterCounter(kCountForAggregation, threshold_count,
{first_cnt_handler, last_cnt_handler});
}

bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len,
@@ -72,7 +78,7 @@ bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len,
}

const schema::RequestUpdateModel *update_model_req = flatbuffers::GetRoot<schema::RequestUpdateModel>(req_data);
if (update_model_req == nullptr) {
if (update_model_req == nullptr || update_model_req->fl_id() == nullptr) {
std::string reason = "Building flatbuffers schema failed for RequestUpdateModel.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
MS_LOG(WARNING) << reason;
@@ -118,11 +124,16 @@ bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len,
if (result_code != ResultCode::kSuccess) {
MS_LOG(WARNING) << "Updating model failed.";
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
return false;
}
std::string update_model_fl_id = update_model_req->fl_id()->str();
IncreaseAcceptClientNum();
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());

result_code = CountForAggregation(update_model_fl_id);
if (result_code != ResultCode::kSuccess) {
return false;
}
return true;
}

@@ -130,6 +141,7 @@ bool UpdateModelKernel::Reset() {
MS_LOG(INFO) << "Update model kernel reset!";
StopTimer();
DistributedCountService::GetInstance().ResetCounter(name_);
DistributedCountService::GetInstance().ResetCounter(kCountForAggregation);
executor_->ResetAggregationStatus();
DistributedMetadataStore::GetInstance().ResetMetadata(kCtxUpdateModelClientList);
size_t &total_data_size = LocalMetaStore::GetInstance().mutable_value<size_t>(kCtxFedAvgTotalDataSize);
@@ -137,23 +149,22 @@ bool UpdateModelKernel::Reset() {
return true;
}

void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kUpdateModel) {
last_count_thread_ = std::make_unique<std::thread>([this]() {
uint32_t retryCount = 0;
while (!executor_->IsAllWeightAggregationDone() && retryCount <= kRetryCountOfWaitWeightAggregation) {
std::this_thread::sleep_for(std::chrono::seconds(1));
retryCount += 1;
}

size_t total_data_size = LocalMetaStore::GetInstance().value<size_t>(kCtxFedAvgTotalDataSize);
MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is "
<< total_data_size;
if (ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) {
FinishIteration();
}
});
last_count_thread_->detach();
void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {}

// Runs weight aggregation for the current iteration; invoked by the
// last-count handler of the aggregation counter registered in InitKernel.
// On success, finishes the iteration as valid (only when update-model is the
// resetter round and pairwise encryption is off); on failure, finishes the
// iteration as invalid with a descriptive reason.
void UpdateModelKernel::RunAggregation() {
auto is_last_iter_valid = Executor::GetInstance().RunAllWeightAggregation();
auto curr_iter_num = LocalMetaStore::GetInstance().curr_iter_num();
if (is_last_iter_valid) {
size_t total_data_size = LocalMetaStore::GetInstance().value<size_t>(kCtxFedAvgTotalDataSize);
MS_LOG(INFO) << "Total data size for iteration " << curr_iter_num << " is " << total_data_size;
// With pairwise encryption another round is responsible for ending the
// iteration, so skip FinishIteration in that mode.
if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kUpdateModel &&
ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) {
FinishIteration(is_last_iter_valid);
}
} else {
std::string reason = "Weight aggregation failed, current iteration: " + std::to_string(curr_iter_num);
MS_LOG(WARNING) << reason;
FinishIteration(is_last_iter_valid, reason);
}
}

@@ -279,6 +290,15 @@ std::map<std::string, UploadData> UpdateModelKernel::ParseFeatureMap(
return feature_map;
}

// Counts req_fl_id toward the aggregation threshold. Once the distributed
// counter reaches its target, the last-count handler registered in
// InitKernel triggers RunAggregation().
ResultCode UpdateModelKernel::CountForAggregation(const std::string &req_fl_id) {
  std::string count_reason = "";
  const bool counted = DistributedCountService::GetInstance().Count(kCountForAggregation, req_fl_id, &count_reason);
  if (counted) {
    return ResultCode::kSuccess;
  }
  MS_LOG(ERROR) << "Counting for aggregation failed. reason: " + count_reason;
  return ResultCode::kFail;
}

ResultCode UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestUpdateModel *update_model_req) {
MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kSuccessAndReturn);


+ 3
- 2
mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h View File

@@ -55,6 +55,9 @@ class UpdateModelKernel : public RoundKernel {
ResultCode UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr<FBBuilder> &fbb,
const DeviceMeta &device_meta);
std::map<std::string, UploadData> ParseFeatureMap(const schema::RequestUpdateModel *update_model_req);

void RunAggregation();
ResultCode CountForAggregation(const std::string &req_fl_id);
ResultCode CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestUpdateModel *update_model_req);
sigVerifyResult VerifySignature(const schema::RequestUpdateModel *update_model_req);
@@ -67,8 +70,6 @@ class UpdateModelKernel : public RoundKernel {

// The time window of one iteration.
size_t iteration_time_window_{0};

std::unique_ptr<std::thread> last_count_thread_;
};
} // namespace kernel
} // namespace server


+ 11
- 0
mindspore/ccsrc/fl/server/parameter_aggregator.cc View File

@@ -152,6 +152,17 @@ bool ParameterAggregator::IsAggregationDone() const {
return true;
}

// Runs the AllReduce step of every aggregation kernel held by this
// aggregator, in registration order. Returns false as soon as a kernel is
// null or its AllReduce fails; true when all kernels aggregated successfully.
bool ParameterAggregator::RunAggregation() {
  for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
    // Bind by const reference: copying the shared_ptr would perform an atomic
    // refcount increment/decrement per kernel for no benefit.
    const std::shared_ptr<kernel::AggregationKernelMod> &aggr_kernel = aggregator_with_params.first;
    MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false);
    if (!aggr_kernel->AllReduce()) {
      return false;
    }
  }
  return true;
}

bool ParameterAggregator::IsOptimizingDone() const { return optimizing_done_; }

bool ParameterAggregator::IsPullingDone() const { return pulling_done_; }


+ 1
- 0
mindspore/ccsrc/fl/server/parameter_aggregator.h View File

@@ -90,6 +90,7 @@ class ParameterAggregator {

// Returns the aggregation/optimizing/pulling status to the caller.
bool IsAggregationDone() const;
bool RunAggregation();
bool IsOptimizingDone() const;
bool IsPullingDone() const;



+ 0
- 7
mindspore/ccsrc/fl/server/round.cc View File

@@ -45,12 +45,6 @@ void Round::RegisterMsgCallBack(const std::shared_ptr<ps::core::CommunicatorBase

void Round::Initialize(const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb) {
MS_LOG(INFO) << "Round " << name_ << " start initialize.";
// Callback when the iteration is finished.
finish_iteration_cb_ = [this, finish_iteration_cb](bool, const std::string &) -> void {
std::string reason = "Round " + name_ + " finished! This iteration is valid. Proceed to next iteration.";
finish_iteration_cb(true, reason);
};

if (check_timeout_) {
iter_timer_ = std::make_shared<IterationTimer>();
MS_EXCEPTION_IF_NULL(iter_timer_);
@@ -115,7 +109,6 @@ void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel)
MS_EXCEPTION_IF_NULL(kernel);
kernel_ = kernel;
kernel_->set_stop_timer_cb(stop_timer_cb_);
kernel_->set_finish_iteration_cb(finish_iteration_cb_);
return;
}



+ 0
- 1
mindspore/ccsrc/fl/server/round.h View File

@@ -110,7 +110,6 @@ class Round {

// The callbacks which will be set to the round kernel.
StopTimerCb stop_timer_cb_;
FinishIterCb finish_iteration_cb_;
};
} // namespace server
} // namespace fl


+ 9
- 0
mindspore/ccsrc/fl/server/server.cc View File

@@ -159,6 +159,13 @@ void Server::InitCluster() {
return;
}

// Submits a task to the server's task executor thread pool.
// Returns false when the executor has not been created yet.
bool Server::SubmitTask(std::function<void()> &&task) {
  if (task_executor_ == nullptr) {
    return false;
  }
  // `task` arrives as an rvalue; pass it on as one so Submit can move the
  // callable (and its captured state) instead of copying it.
  return task_executor_->Submit(std::move(task));
}

bool Server::InitCommunicatorWithServer() {
MS_EXCEPTION_IF_NULL(task_executor_);
MS_EXCEPTION_IF_NULL(server_node_);
@@ -396,6 +403,8 @@ void Server::InitExecutor() {
MS_LOG(INFO) << "Required count for push-type and pull-type kernels is " << executor_threshold_;
Executor::GetInstance().Initialize(func_graph, executor_threshold_);
ModelStore::GetInstance().Initialize();
// init weight memory to 0 after get model
Executor::GetInstance().ResetAggregationStatus();
return;
}



+ 2
- 0
mindspore/ccsrc/fl/server/server.h View File

@@ -76,6 +76,8 @@ class Server {
// Whether the training job of the server is enabled.
InstanceState instance_state() const;

bool SubmitTask(std::function<void()> &&task);

private:
Server()
: server_node_(nullptr),


+ 139
- 6
mindspore/ccsrc/ps/core/abstract_node.cc View File

@@ -341,6 +341,136 @@ uint64_t AbstractNode::CollectiveSendAsync(const NodeRole &node_role, const uint
return SendMessageAsync(client, message_meta, Protos::RAW, data, size);
}

// Renders a CollectiveMessageMeta as a compact human-readable string for
// the debug/error logs in the FL collective send/receive path.
static std::string CollectiveMetaToString(const CollectiveMessageMeta &meta) {
  std::ostringstream buffer;
  buffer << "{iteration:" << meta.iteration();
  buffer << ", data:" << meta.weight_name();
  buffer << ", send rank:" << meta.send_rank_id();
  buffer << ", recv rank:" << meta.recv_rank_id();
  buffer << ", phase:" << meta.phase();
  buffer << ", chunk index:" << meta.chunk_index();
  buffer << ", for index:" << meta.for_index();
  buffer << "}";
  return buffer.str();
}

// Asynchronously sends a collective payload to the server identified by
// collective_meta.recv_rank_id(). The send rank and enable flag are filled
// in here; the remaining routing fields come from collective_meta.
// Returns the async request id, or 0 when the destination rank is invalid.
uint64_t AbstractNode::FlCollectiveSendAsync(const CollectiveMessageMeta &collective_meta, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(data);
auto recv_rank_id = collective_meta.recv_rank_id();
// Only server ranks participate in FL collectives.
if (!CommUtil::ValidateRankId(SERVER, recv_rank_id, worker_num_, server_num_)) {
MS_LOG(ERROR) << "The node role or rank_id is illegal, the worker num:" << worker_num_
<< ", the server num:" << server_num_ << ", the rank id:" << recv_rank_id;
return 0;
}
std::shared_ptr<MessageMeta> message_meta = std::make_shared<MessageMeta>();
MS_EXCEPTION_IF_NULL(message_meta);
message_meta->set_cmd(NodeCommand::COLLECTIVE_SEND_DATA);
message_meta->set_rank_id(node_info_.rank_id_);
message_meta->set_role(node_info_.node_role_);
*(message_meta->mutable_collective_meta()) = collective_meta;
// enable_flag routes the message into OnRecvCollectiveData on the receiver
// instead of the generic receive-callback path.
message_meta->mutable_collective_meta()->set_enable_flag(true);
message_meta->mutable_collective_meta()->set_send_rank_id(node_info_.rank_id_);

MS_LOG(DEBUG) << "Send data to rank id:" << recv_rank_id
<< ", send meta:" << CollectiveMetaToString(message_meta->collective_meta());
auto client = GetOrCreateTcpClient(recv_rank_id, SERVER);
MS_EXCEPTION_IF_NULL(client);
return SendMessageAsync(client, message_meta, Protos::RAW, data, size);
}

// Waits for a collective chunk from expect_meta.send_rank_id() that matches
// expect_meta exactly. Polls the per-rank receive queue in 1-second slices up
// to `timeout` iterations and aborts early when another server has reported
// this iteration as failed. On success stores the received buffer in *output
// and returns true; returns false on timeout, validation failure, or a
// mismatched chunk within the same iteration.
bool AbstractNode::FlCollectiveWaitInner(const CollectiveMessageMeta &expect_meta, VectorPtr *output,
                                         const uint32_t &timeout) {
  if (output == nullptr) {
    return false;
  }
  auto send_rank_id = expect_meta.send_rank_id();
  if (!CommUtil::ValidateRankId(SERVER, send_rank_id, worker_num_, server_num_)) {
    MS_LOG(ERROR) << "The node role or rank_id is illegal, the worker num:" << worker_num_
                  << ", the server num:" << server_num_ << ", the rank id:" << send_rank_id;
    return false;
  }
  // Two metas describe the same chunk only when every routing field matches.
  auto check_meta = [](const CollectiveMessageMeta &left, const CollectiveMessageMeta &right) {
    return left.iteration() == right.iteration() && left.weight_name() == right.weight_name() &&
           left.recv_rank_id() == right.recv_rank_id() && left.send_rank_id() == right.send_rank_id() &&
           left.phase() == right.phase() && left.chunk_index() == right.chunk_index() &&
           left.for_index() == right.for_index();
  };
  auto iteration_num = expect_meta.iteration();
  std::unique_lock<std::mutex> lock(fl_receive_mutex_);
  auto &recv_data_list = fl_received_data_[send_rank_id];
  for (uint32_t i = 0; i < timeout; i++) {
    if (recv_data_list.empty()) {
      fl_receive_cond_.wait_for(lock, std::chrono::seconds(1), [&recv_data_list]() { return !recv_data_list.empty(); });
      if (recv_data_list.empty()) {  // timeout
        if (HasIterationFailed(iteration_num)) {  // if result of iteration reported by other server is failed
          MS_LOG(WARNING) << "Detect iteration " << iteration_num << " has failed";
          return false;
        }
        continue;
      }
    }
    while (!recv_data_list.empty()) {
      auto first = recv_data_list.begin();
      auto recv_meta = std::move(first->first);
      auto recv_data = std::move(first->second);
      recv_data_list.erase(first);
      MS_LOG(DEBUG) << "Handle receive data from rank id:" << send_rank_id
                    << ", recv meta:" << CollectiveMetaToString(recv_meta);
      // Chunks from another iteration are stale; drop them and keep waiting.
      if (recv_meta.iteration() != expect_meta.iteration()) {
        MS_LOG(WARNING) << "Skip recv data, iteration of recv meta " << recv_meta.iteration()
                        << " != iteration of expected meta " << expect_meta.iteration();
        continue;
      }
      // A mismatched chunk in the same iteration means the protocol is out of
      // sync: fail fast instead of waiting forever.
      // (Fixed log typo: "recv mata" -> "recv meta".)
      if (!check_meta(recv_meta, expect_meta)) {
        MS_LOG(WARNING) << "Recv meta not match expected meta, recv meta: " << CollectiveMetaToString(recv_meta)
                        << ", expected meta: " << CollectiveMetaToString(expect_meta);
        return false;
      }
      *output = recv_data;
      return true;  // success to recv data
    }
  }
  return false;
}

// Waits for the chunk described by expect_meta and validates the received
// buffer: it must be non-null and exactly expect_size bytes long. Returns
// true only when a matching, correctly-sized buffer was stored in *output.
bool AbstractNode::FlCollectiveWait(const CollectiveMessageMeta &expect_meta, size_t expect_size, VectorPtr *output,
                                    const uint32_t &timeout) {
  if (output == nullptr) {
    MS_LOG(ERROR) << "FlCollectiveWait failed, parameter output invalid";
    return false;
  }
  if (!FlCollectiveWaitInner(expect_meta, output, timeout)) {
    MS_LOG(ERROR) << "FlCollectiveWait failed, expect meta: " << CollectiveMetaToString(expect_meta);
    return false;
  }
  if (*output == nullptr) {
    MS_LOG(ERROR) << "FlCollectiveWait failed, recv buffer invalid";
    return false;
  }
  if ((*output)->size() != expect_size) {
    MS_LOG(ERROR) << "Expected data size " << expect_size << " != recv data size " << (*output)->size()
                  << CollectiveMetaToString(expect_meta);
    return false;
  }
  return true;
}

// Queues a collective chunk received from a peer server into the per-sender
// receive list and wakes any thread blocked in FlCollectiveWaitInner.
void AbstractNode::OnRecvCollectiveData(const MessageMeta &message_meta, const VectorPtr &data) {
  std::lock_guard<std::mutex> lock(fl_receive_mutex_);
  const auto &recv_meta = message_meta.collective_meta();
  auto send_rank_id = recv_meta.send_rank_id();
  MS_LOG(DEBUG) << "Receive data from rank id:" << send_rank_id << ", recv meta:" << CollectiveMetaToString(recv_meta);
  fl_received_data_[send_rank_id].emplace_back(recv_meta, data);
  fl_receive_cond_.notify_all();
}

// Records the outcome of `last_iteration` so that FlCollectiveWaitInner
// (via HasIterationFailed) can abort waits that would otherwise time out.
// NOTE(review): both members are written without holding fl_receive_mutex_
// while HasIterationFailed reads them from the waiting thread — assumes
// these unsynchronized plain writes are acceptable here; confirm.
void AbstractNode::SetIterationResult(size_t last_iteration, bool is_iteration_valid) {
iteration_failed_ = !is_iteration_valid;
failed_iteration_num_ = last_iteration;
}

// Returns true only when `iteration_num` is exactly the iteration that was
// marked as failed by SetIterationResult.
bool AbstractNode::HasIterationFailed(uint32_t iteration_num) const {
  if (!iteration_failed_) {
    return false;
  }
  return failed_iteration_num_ == iteration_num;
}

std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const NodeRole &node_role, const uint32_t &rank_id,
VectorPtr *output) {
MS_EXCEPTION_IF_NULL(output);
@@ -1081,19 +1211,22 @@ void AbstractNode::RunReceiveCallback(const std::shared_ptr<MessageMeta> &meta,
size_t size) {
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
receive_callbacks_mutex_.lock();
uint32_t rank_id = meta->rank_id();
// When receiving a collective message, Then generate rank request id,compare with the desired rank request id,
// If they are equal, then call the callback function
uint64_t rank_request_id = NextActualRankRequestId(rank_id);
std::shared_ptr<std::vector<unsigned char>> received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
size_t dest_size = size;
size_t src_size = size;
int ret = memcpy_s(received_data->data(), dest_size, data, src_size);
if (ret != 0) {
receive_callbacks_mutex_.unlock();
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
if (meta->collective_meta().enable_flag()) {
OnRecvCollectiveData(*meta, received_data);
return;
}
receive_callbacks_mutex_.lock();
uint32_t rank_id = meta->rank_id();
// When receiving a collective message, Then generate rank request id,compare with the desired rank request id,
// If they are equal, then call the callback function
uint64_t rank_request_id = NextActualRankRequestId(rank_id);
received_data_[std::make_pair(rank_id, rank_request_id)] = received_data;
MS_LOG(DEBUG) << "Run Receive data callback,the rank id:" << rank_id << ", the rank request id is:" << rank_request_id
<< ", the send request id is:" << meta->request_id() << " the size is:" << size;


+ 22
- 0
mindspore/ccsrc/ps/core/abstract_node.h View File

@@ -23,6 +23,7 @@
#include <map>
#include <vector>
#include <unordered_map>
#include <functional>

#include "ps/core/node.h"
#include "ps/core/communicator/message.h"
@@ -103,6 +104,12 @@ class AbstractNode : public Node {
const uint32_t &timeout = kCommTimeoutInSeconds);

uint64_t CollectiveSendAsync(const NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size);

using CheckFailReturnFun = std::function<bool()>;
uint64_t FlCollectiveSendAsync(const CollectiveMessageMeta &collective_meta, const void *data, size_t size);
bool FlCollectiveWait(const CollectiveMessageMeta &expect_meta, size_t expect_size, VectorPtr *output,
const uint32_t &timeout = kCommTimeoutInSeconds);

std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const NodeRole &node_role, const uint32_t &rank_id,
VectorPtr *output);
bool CollectiveWait(const std::pair<uint32_t, uint64_t> &request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
@@ -145,6 +152,9 @@ class AbstractNode : public Node {
uint32_t worker_num, uint32_t server_num,
const std::shared_ptr<TaskExecutor> &task_executor);

void SetIterationResult(size_t last_iteration, bool is_iteration_valid);
bool HasIterationFailed(uint32_t iteration_num) const;

protected:
virtual void Register(const std::shared_ptr<TcpClient> &client);
bool Heartbeat(const std::shared_ptr<TcpClient> &client);
@@ -235,6 +245,9 @@ class AbstractNode : public Node {
const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
size_t size);

bool FlCollectiveWaitInner(const CollectiveMessageMeta &expect_meta, VectorPtr *output, const uint32_t &timeout);
void OnRecvCollectiveData(const MessageMeta &message_meta, const VectorPtr &data);

std::unique_ptr<std::thread> heart_beat_thread_;
std::unique_ptr<std::thread> client_to_scheduler_thread_;
std::shared_ptr<TcpClient> client_to_scheduler_;
@@ -260,6 +273,12 @@ class AbstractNode : public Node {
std::unordered_map<NodeCommand, ResponseHandler> handlers_;
std::unordered_map<NodeCommand, ServerHandler> server_handler_;

// send_rank_id, recv CollectiveMessageMeta and data
std::unordered_map<uint32_t, std::vector<std::pair<CollectiveMessageMeta, std::shared_ptr<std::vector<uint8_t>>>>>
fl_received_data_;
std::mutex fl_receive_mutex_;
std::condition_variable fl_receive_cond_;

// Workers and servers launch the server to process command: FINISH,SCALE_OUT,SCALE_IN,SEND_METADATA
std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> server_thread_;
@@ -302,6 +321,9 @@ class AbstractNode : public Node {
std::unordered_map<std::string, std::shared_ptr<CommunicatorBase>> communicators_;
std::mutex communicator_mutex_;
std::mutex cluster_state_mutex_;

size_t failed_iteration_num_ = 0;
bool iteration_failed_ = false;
};
} // namespace core
} // namespace ps


+ 15
- 2
mindspore/ccsrc/ps/core/protos/comm.proto View File

@@ -68,7 +68,18 @@ enum PersistentState {
PREPARING_PERSIST = 1;
READY_PERSIST = 2;
PERSISTING = 3;
FINISH_PERSIST = 4;
FINISH_PERSIST = 4;
}

// Metadata attached (inside MessageMeta) to server-to-server collective
// traffic so the receiver can match each incoming chunk against the
// transfer it is waiting for.
message CollectiveMessageMeta {
// Marks the message as FL collective traffic; the receiver routes such
// messages to the collective receive path instead of the generic callback.
bool enable_flag = 1;
// Rank of the sending server (filled in by the sender).
uint32 send_rank_id = 2;
// Rank of the destination server.
uint32 recv_rank_id = 3;
// FL iteration this chunk belongs to; chunks from other iterations are
// skipped by the waiting side.
uint32 iteration = 4;
// Name of the weight being aggregated.
bytes weight_name = 5;
bytes phase = 6; // ring, gather, reduce, broadcast
// Position of this chunk within the transfer.
uint32 chunk_index = 7;
// Index within the current phase — presumably the algorithm step; confirm.
uint32 for_index = 8;
}

message MessageMeta {
@@ -82,6 +93,8 @@ message MessageMeta {
uint32 rank_id = 4;
// User-defined commands
int32 user_cmd = 5;

CollectiveMessageMeta collective_meta = 6;
}

message RegisterMessage {
@@ -147,7 +160,7 @@ message ServersMeta {
bool is_alive = 4;
NodeRole role = 5;
string node_id = 6;
PersistentState persistent_state = 7;
PersistentState persistent_state = 7;
}

message SendMetadataMessage {


+ 1
- 0
mindspore/ccsrc/ps/core/protos/fl.proto View File

@@ -213,6 +213,7 @@ message SyncIterationResponse {
message PrepareForNextIterRequest {
bool is_last_iter_valid = 2;
string reason = 3;
uint64 last_iteration = 4;
}

message PrepareForNextIterResponse {


Loading…
Cancel
Save