| @@ -15,10 +15,19 @@ | |||
| */ | |||
| #include "fl/server/collective_ops_impl.h" | |||
| #include "fl/server/local_meta_store.h" | |||
| #include "fl/server/iteration.h" | |||
| namespace mindspore { | |||
| namespace fl { | |||
| namespace server { | |||
namespace {
// Phase tags handed to CollectiveMessageMeta::set_phase by the all-reduce implementations in this file:
// "ring" and "gather" label the two stages of the ring all-reduce, while "reduce" and "broadcast" label
// the two stages of the reduce-then-broadcast all-reduce.
// Declared constexpr so the pointers themselves (not just the characters) are immutable.
constexpr const char *kCollectivePhaseRing = "ring";
constexpr const char *kCollectivePhaseGather = "gather";
constexpr const char *kCollectivePhaseReduce = "reduce";
constexpr const char *kCollectivePhaseBroadcast = "broadcast";
}  // namespace
| void CollectiveOpsImpl::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node) { | |||
| MS_EXCEPTION_IF_NULL(server_node); | |||
| server_node_ = server_node; | |||
| @@ -28,16 +37,15 @@ void CollectiveOpsImpl::Initialize(const std::shared_ptr<ps::core::ServerNode> & | |||
| } | |||
| template <typename T> | |||
| bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size_t count) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(server_node_, false); | |||
| bool CollectiveOpsImpl::RingAllReduce(const std::string &data_name, const void *sendbuff, void *recvbuff, | |||
| size_t count) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false); | |||
| int ret; | |||
| if (recvbuff != sendbuff) { | |||
| size_t src_size = count * sizeof(T); | |||
| size_t dst_size = count * sizeof(T); | |||
| ret = memcpy_s(recvbuff, dst_size, sendbuff, src_size); | |||
| auto ret = memcpy_s(recvbuff, dst_size, sendbuff, src_size); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||
| return false; | |||
| @@ -68,32 +76,61 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size | |||
| << ", chunk_sizes:" << chunk_sizes << ", send_to_rank:" << send_to_rank | |||
| << ", recv_from_rank:" << recv_from_rank; | |||
| return RunRingAllReduce<T>(data_name, send_to_rank, recv_from_rank, chunk_sizes, chunk_offset, output_buff); | |||
| } | |||
| // Implementation of RingAllReduce. | |||
| template <typename T> | |||
| bool CollectiveOpsImpl::RunRingAllReduce(const std::string &data_name, uint32_t send_to_rank, uint32_t recv_from_rank, | |||
| const std::vector<size_t> &chunk_sizes, | |||
| const std::vector<size_t> &chunk_offset, T *output_buff) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(server_node_, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(output_buff, false); | |||
| auto curr_iteration_num = LocalMetaStore::GetInstance().curr_iter_num(); | |||
| ps::core::CollectiveMessageMeta send_meta; | |||
| send_meta.set_enable_flag(true); | |||
| send_meta.set_send_rank_id(rank_id_); | |||
| send_meta.set_recv_rank_id(send_to_rank); | |||
| send_meta.set_iteration(curr_iteration_num); | |||
| send_meta.set_weight_name(data_name); | |||
| ps::core::CollectiveMessageMeta recv_meta; | |||
| recv_meta.set_enable_flag(true); | |||
| recv_meta.set_send_rank_id(recv_from_rank); | |||
| recv_meta.set_recv_rank_id(rank_id_); | |||
| recv_meta.set_iteration(curr_iteration_num); | |||
| recv_meta.set_weight_name(data_name); | |||
| // Ring ReduceScatter. | |||
| MS_LOG(DEBUG) << "Start Ring ReduceScatter."; | |||
| send_meta.set_phase(kCollectivePhaseRing); | |||
| recv_meta.set_phase(kCollectivePhaseRing); | |||
| uint32_t rank_size = server_num_; | |||
| for (size_t i = 0; i < rank_size - 1; i++) { | |||
| // Step 1: Async send data to next rank. | |||
| size_t send_chunk_index = (rank_id_ - i + rank_size) % rank_size; | |||
| T *send_chunk = output_buff + chunk_offset[send_chunk_index]; | |||
| send_meta.set_chunk_index(send_chunk_index); | |||
| send_meta.set_for_index(i); | |||
| auto send_chunk_count = chunk_sizes[send_chunk_index]; | |||
| auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, send_to_rank, send_chunk, | |||
| send_chunk_count * sizeof(T)); | |||
| auto send_req_id = server_node_->FlCollectiveSendAsync(send_meta, send_chunk, send_chunk_count * sizeof(T)); | |||
| // Step 2: Async receive data to next rank and wait until it's done. | |||
| size_t recv_chunk_index = (rank_id_ - i - 1 + rank_size) % rank_size; | |||
| recv_meta.set_chunk_index(recv_chunk_index); | |||
| recv_meta.set_for_index(i); | |||
| T *recv_chunk = output_buff + chunk_offset[recv_chunk_index]; | |||
| auto recv_chunk_count = chunk_sizes[recv_chunk_index]; | |||
| MS_LOG(DEBUG) << "Ring ReduceScatter send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank | |||
| << ", send count:" << send_chunk_count << ", recv count:" << recv_chunk_count << ", iteration:" << i; | |||
| << ", send chunk index:" << send_chunk_index << ", send count:" << send_chunk_count | |||
| << ", recv chunk index:" << recv_chunk_index << ", recv count:" << recv_chunk_count | |||
| << ", for index:" << i; | |||
| std::shared_ptr<std::vector<uint8_t>> recv_str; | |||
| auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str); | |||
| if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) { | |||
| MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed."; | |||
| return false; | |||
| } | |||
| if (recv_chunk_count * sizeof(T) != recv_str->size()) { | |||
| MS_LOG(ERROR) << "Expect receive chunk size " << recv_chunk_count * sizeof(T) << " from rank " << recv_from_rank | |||
| << " != real receive chunk size " << recv_str->size() << ", current rank: " << rank_id_ | |||
| << ", total data size " << count * sizeof(T); | |||
| auto expect_size = recv_chunk_count * sizeof(T); | |||
| if (!server_node_->FlCollectiveWait(recv_meta, expect_size, &recv_str, kCollectiveCommTimeout)) { | |||
| MS_LOG(ERROR) << "FlCollectiveWait failed, send rank id: " << recv_meta.send_rank_id(); | |||
| return false; | |||
| } | |||
| auto tmp_recv_chunk = reinterpret_cast<T *>(recv_str->data()); | |||
| @@ -103,7 +140,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size | |||
| } | |||
| // Step 4: Wait until send is done. | |||
| if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) { | |||
| MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed."; | |||
| MS_LOG(ERROR) << "Wait response of rank " << send_req_id << " failed."; | |||
| return false; | |||
| } | |||
| } | |||
| @@ -111,30 +148,40 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size | |||
| // Ring AllGather. | |||
| MS_LOG(DEBUG) << "Start Ring AllGather."; | |||
| send_meta.set_phase(kCollectivePhaseGather); | |||
| recv_meta.set_phase(kCollectivePhaseGather); | |||
| for (size_t i = 0; i < rank_size - 1; i++) { | |||
| size_t send_chunk_index = (rank_id_ - i + 1 + rank_size) % rank_size; | |||
| T *send_chunk = output_buff + chunk_offset[send_chunk_index]; | |||
| auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, send_to_rank, send_chunk, | |||
| chunk_sizes[send_chunk_index] * sizeof(T)); | |||
| send_meta.set_chunk_index(send_chunk_index); | |||
| send_meta.set_for_index(i); | |||
| auto send_chunk_count = chunk_sizes[send_chunk_index]; | |||
| auto send_req_id = server_node_->FlCollectiveSendAsync(send_meta, send_chunk, send_chunk_count * sizeof(T)); | |||
| size_t recv_chunk_index = (rank_id_ - i + rank_size) % rank_size; | |||
| T *recv_chunk = output_buff + chunk_offset[recv_chunk_index]; | |||
| recv_meta.set_chunk_index(recv_chunk_index); | |||
| recv_meta.set_for_index(i); | |||
| auto recv_chunk_count = chunk_sizes[recv_chunk_index]; | |||
| MS_LOG(DEBUG) << "Ring AllGather send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank | |||
| << ", send count:" << chunk_sizes[send_chunk_index] | |||
| << ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i; | |||
| << ", send chunk index:" << send_chunk_index << ", send count:" << send_chunk_count | |||
| << ", recv chunk index:" << recv_chunk_index << ", recv count:" << recv_chunk_count | |||
| << ", for index:" << i; | |||
| std::shared_ptr<std::vector<unsigned char>> recv_str; | |||
| auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str); | |||
| if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) { | |||
| MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed."; | |||
| auto expect_size = recv_chunk_count * sizeof(T); | |||
| if (!server_node_->FlCollectiveWait(recv_meta, expect_size, &recv_str, kCollectiveCommTimeout)) { | |||
| MS_LOG(ERROR) << "FlCollectiveWait failed, send rank id: " << recv_meta.send_rank_id(); | |||
| return false; | |||
| } | |||
| ret = memcpy_s(recv_chunk, chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size()); | |||
| auto ret = memcpy_s(recv_chunk, expect_size, recv_str->data(), recv_str->size()); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")" | |||
| << ", dest size is " << recv_chunk_count * sizeof(T) << ", src size is " << recv_str->size(); | |||
| return false; | |||
| } | |||
| if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) { | |||
| MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed."; | |||
| MS_LOG(ERROR) << "Wait response of rank " << send_req_id << " failed."; | |||
| return false; | |||
| } | |||
| } | |||
| @@ -143,7 +190,8 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size | |||
| } | |||
| template <typename T> | |||
| bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count) { | |||
| bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const std::string &data_name, const void *sendbuff, void *recvbuff, | |||
| size_t count) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(server_node_, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false); | |||
| @@ -155,37 +203,54 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec | |||
| size_t dst_size = count * sizeof(T); | |||
| int ret = memcpy_s(recvbuff, dst_size, sendbuff, src_size); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")" | |||
| << ", dest size is " << dst_size << ", src size is " << src_size; | |||
| return false; | |||
| } | |||
| T *output_buff = reinterpret_cast<T *>(recvbuff); | |||
| // Reduce data to rank 0 process. | |||
| auto curr_iteration_num = LocalMetaStore::GetInstance().curr_iter_num(); | |||
| ps::core::CollectiveMessageMeta send_meta; | |||
| send_meta.set_enable_flag(true); | |||
| send_meta.set_send_rank_id(rank_id_); | |||
| send_meta.set_iteration(curr_iteration_num); | |||
| send_meta.set_weight_name(data_name); | |||
| send_meta.set_chunk_index(0); | |||
| send_meta.set_for_index(0); | |||
| ps::core::CollectiveMessageMeta recv_meta; | |||
| recv_meta.set_enable_flag(true); | |||
| recv_meta.set_recv_rank_id(rank_id_); | |||
| recv_meta.set_iteration(curr_iteration_num); | |||
| recv_meta.set_weight_name(data_name); | |||
| recv_meta.set_chunk_index(0); | |||
| recv_meta.set_for_index(0); | |||
| send_meta.set_phase(kCollectivePhaseReduce); | |||
| recv_meta.set_phase(kCollectivePhaseReduce); | |||
| MS_LOG(DEBUG) << "Start Reduce to rank 0 process."; | |||
| if (rank_id_ == 0) { | |||
| for (uint32_t i = 1; i < rank_size; i++) { | |||
| std::shared_ptr<std::vector<unsigned char>> recv_str; | |||
| MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i; | |||
| auto recv_req_id1 = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, i, &recv_str); | |||
| if (!server_node_->CollectiveWait(recv_req_id1, kCollectiveCommTimeout)) { | |||
| MS_LOG(ERROR) << "CollectiveWait " << recv_req_id1 << " failed."; | |||
| return false; | |||
| } | |||
| if (count * sizeof(T) != recv_str->size()) { | |||
| MS_LOG(ERROR) << "Expect receive chunk size " << count * sizeof(T) << " from rank " << i | |||
| << " != real receive chunk size " << recv_str->size() << ", current rank: " << rank_id_ | |||
| << ", total data size " << count * sizeof(T); | |||
| recv_meta.set_send_rank_id(i); | |||
| auto expect_size = count * sizeof(T); | |||
| if (!server_node_->FlCollectiveWait(recv_meta, expect_size, &recv_str, kCollectiveCommTimeout)) { | |||
| MS_LOG(ERROR) << "FlCollectiveWait failed, send rank id: " << recv_meta.send_rank_id(); | |||
| return false; | |||
| } | |||
| auto tmp_recv_chunk = reinterpret_cast<T *>(recv_str->data()); | |||
| auto tmp_recv_chunk = reinterpret_cast<T *>(recv_str->data()); // recv_str size has checked in FlCollectiveWait | |||
| for (size_t j = 0; j < count; j++) { | |||
| output_buff[j] += tmp_recv_chunk[j]; | |||
| } | |||
| } | |||
| } else { | |||
| MS_LOG(DEBUG) << "Reduce send data to rank 0 process."; | |||
| auto send_req_id1 = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T)); | |||
| send_meta.set_recv_rank_id(0); | |||
| auto send_req_id1 = server_node_->FlCollectiveSendAsync(send_meta, sendbuff, count * sizeof(T)); | |||
| if (!server_node_->Wait(send_req_id1, kCollectiveCommTimeout)) { | |||
| MS_LOG(ERROR) << "CollectiveWait " << send_req_id1 << " failed."; | |||
| MS_LOG(ERROR) << "Wait response of rank " << send_req_id1 << " failed."; | |||
| return false; | |||
| } | |||
| } | |||
| @@ -193,27 +258,31 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec | |||
| // Broadcast data to not 0 rank process. | |||
| MS_LOG(DEBUG) << "Start broadcast from rank 0 to other processes."; | |||
| send_meta.set_phase(kCollectivePhaseBroadcast); | |||
| recv_meta.set_phase(kCollectivePhaseBroadcast); | |||
| if (rank_id_ == 0) { | |||
| for (uint32_t i = 1; i < rank_size; i++) { | |||
| MS_LOG(DEBUG) << "Broadcast data to process " << i; | |||
| auto send_req_id2 = | |||
| server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, i, output_buff, count * sizeof(T)); | |||
| send_meta.set_recv_rank_id(i); | |||
| auto send_req_id2 = server_node_->FlCollectiveSendAsync(send_meta, output_buff, count * sizeof(T)); | |||
| if (!server_node_->Wait(send_req_id2, kCollectiveCommTimeout)) { | |||
| MS_LOG(ERROR) << "CollectiveWait " << send_req_id2 << " failed."; | |||
| MS_LOG(ERROR) << "Wait response of rank " << send_req_id2 << " failed."; | |||
| return false; | |||
| } | |||
| } | |||
| } else { | |||
| MS_LOG(DEBUG) << "Broadcast receive from rank 0."; | |||
| recv_meta.set_send_rank_id(0); | |||
| std::shared_ptr<std::vector<unsigned char>> recv_str; | |||
| auto recv_req_id2 = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, 0, &recv_str); | |||
| if (!server_node_->CollectiveWait(recv_req_id2, kCollectiveCommTimeout)) { | |||
| MS_LOG(ERROR) << "CollectiveWait " << recv_req_id2 << " failed."; | |||
| auto expect_size = count * sizeof(T); | |||
| if (!server_node_->FlCollectiveWait(recv_meta, expect_size, &recv_str, kCollectiveCommTimeout)) { | |||
| MS_LOG(ERROR) << "FlCollectiveWait failed, send rank id: " << recv_meta.send_rank_id(); | |||
| return false; | |||
| } | |||
| ret = memcpy_s(output_buff, count * sizeof(T), recv_str->data(), recv_str->size()); | |||
| ret = memcpy_s(output_buff, expect_size, recv_str->data(), recv_str->size()); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")" | |||
| << ", dest size is " << expect_size << ", src size is " << recv_str->size(); | |||
| return false; | |||
| } | |||
| } | |||
| @@ -249,7 +318,8 @@ bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *const recvbuff | |||
| size_t dst_size = send_count * sizeof(T); | |||
| int ret = memcpy_s(output_buff + chunk_offset[rank_id_], dst_size, sendbuff, src_size); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")" | |||
| << ", dest size is " << dst_size << ", src size is " << src_size; | |||
| return false; | |||
| } | |||
| @@ -273,7 +343,9 @@ bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *const recvbuff | |||
| } | |||
| ret = memcpy_s(recv_chunk, chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size()); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")" | |||
| << ", dest size is " << (chunk_sizes[recv_chunk_index] * sizeof(T)) << ", src size is " | |||
| << recv_str->size(); | |||
| return false; | |||
| } | |||
| if (!node_->Wait(send_req_id, kCollectiveCommTimeout)) { | |||
| @@ -322,7 +394,8 @@ bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *recvbuff, size_t c | |||
| } | |||
| int ret = memcpy_s(recvbuff, count * sizeof(T), recv_str->data(), recv_str->size()); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")" | |||
| << ", dest size is " << (count * sizeof(T)) << ", src size is " << recv_str->size(); | |||
| return false; | |||
| } | |||
| } | |||
| @@ -331,11 +404,12 @@ bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *recvbuff, size_t c | |||
| } | |||
| template <typename T> | |||
| bool CollectiveOpsImpl::AllReduce(const void *sendbuff, void *recvbuff, size_t count) { | |||
| bool CollectiveOpsImpl::AllReduce(const std::string &data_name, const void *sendbuff, void *recvbuff, size_t count) { | |||
| // The collective communication API does not support calling Send and Recv concurrently with multiple threads; | |||
| std::unique_lock<std::mutex> lock(mtx_); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(server_node_, false); | |||
| uint32_t rank_size = server_num_; | |||
| if (rank_size == 0) { | |||
| @@ -346,11 +420,16 @@ bool CollectiveOpsImpl::AllReduce(const void *sendbuff, void *recvbuff, size_t c | |||
| MS_LOG(INFO) << "Rank size is 1. Do nothing."; | |||
| return true; | |||
| } | |||
| auto cur_iteration_num = LocalMetaStore::GetInstance().curr_iter_num(); | |||
| if (server_node_->HasIterationFailed(cur_iteration_num)) { | |||
| MS_LOG(WARNING) << "Detect iteration " << cur_iteration_num << " has failed"; | |||
| return false; | |||
| } | |||
| if (count >= rank_size) { | |||
| return RingAllReduce<T>(sendbuff, recvbuff, count); | |||
| return RingAllReduce<T>(data_name, sendbuff, recvbuff, count); | |||
| } else { | |||
| return ReduceBroadcastAllReduce<T>(sendbuff, recvbuff, count); | |||
| return ReduceBroadcastAllReduce<T>(data_name, sendbuff, recvbuff, count); | |||
| } | |||
| } | |||
| @@ -429,17 +508,12 @@ bool CollectiveOpsImpl::ReInitForScaling() { | |||
| return true; | |||
| } | |||
| template bool CollectiveOpsImpl::RingAllReduce<float>(const void *sendbuff, void *recvbuff, size_t count); | |||
| template bool CollectiveOpsImpl::RingAllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count); | |||
| template bool CollectiveOpsImpl::RingAllReduce<int>(const void *sendbuff, void *recvbuff, size_t count); | |||
| template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<float>(const void *sendbuff, void *recvbuff, size_t count); | |||
| template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count); | |||
| template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<int>(const void *sendbuff, void *recvbuff, size_t count); | |||
| template bool CollectiveOpsImpl::AllReduce<float>(const void *sendbuff, void *recvbuff, size_t count); | |||
| template bool CollectiveOpsImpl::AllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count); | |||
| template bool CollectiveOpsImpl::AllReduce<int>(const void *sendbuff, void *recvbuff, size_t count); | |||
| template bool CollectiveOpsImpl::AllReduce<float>(const std::string &data_name, const void *sendbuff, void *recvbuff, | |||
| size_t count); | |||
| template bool CollectiveOpsImpl::AllReduce<size_t>(const std::string &data_name, const void *sendbuff, void *recvbuff, | |||
| size_t count); | |||
| template bool CollectiveOpsImpl::AllReduce<int>(const std::string &data_name, const void *sendbuff, void *recvbuff, | |||
| size_t count); | |||
| template bool CollectiveOpsImpl::AllGather<float>(const void *sendbuff, void *recvbuff, size_t send_count, | |||
| const std::shared_ptr<ps::core::AbstractNode> &node); | |||
| @@ -62,7 +62,7 @@ class CollectiveOpsImpl { | |||
| void Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node); | |||
| template <typename T> | |||
| bool AllReduce(const void *sendbuff, void *recvbuff, size_t count); | |||
| bool AllReduce(const std::string &data_name, const void *sendbuff, void *recvbuff, size_t count); | |||
| template <typename T> | |||
| bool AllGather(const void *sendbuff, void *recvbuff, size_t send_count, | |||
| @@ -91,11 +91,17 @@ class CollectiveOpsImpl { | |||
| // Implementation of RingAllReduce. | |||
| template <typename T> | |||
| bool RingAllReduce(const void *sendbuff, void *recvbuff, size_t count); | |||
| bool RunRingAllReduce(const std::string &data_name, uint32_t send_to_rank, uint32_t recv_from_rank, | |||
| const std::vector<size_t> &chunk_sizes, const std::vector<size_t> &chunk_offset, | |||
| T *output_buff); | |||
| // Implementation of RingAllReduce. | |||
| template <typename T> | |||
| bool RingAllReduce(const std::string &data_name, const void *sendbuff, void *recvbuff, size_t count); | |||
| // Implementation of BroadcastAllReduce. | |||
| template <typename T> | |||
| bool ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count); | |||
| bool ReduceBroadcastAllReduce(const std::string &data_name, const void *sendbuff, void *recvbuff, size_t count); | |||
| // Implementation of RingAllGather. | |||
| template <typename T> | |||
| @@ -18,6 +18,8 @@ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "fl/server/iteration.h" | |||
| #include "fl/server/server.h" | |||
| namespace mindspore { | |||
| namespace fl { | |||
| @@ -252,20 +254,17 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core: | |||
| MS_LOG(INFO) << "Global current count for " << name << " is: " << global_current_count_[name].size() << "/" | |||
| << global_threshold_count_[name]; | |||
| } | |||
| std::string reason = "success"; | |||
| if (!TriggerCounterEvent(name, &reason)) { | |||
| count_rsp.set_result(false); | |||
| count_rsp.set_reason(reason); | |||
| } else { | |||
| count_rsp.set_result(true); | |||
| count_rsp.set_reason(reason); | |||
| } | |||
| count_rsp.set_result(true); | |||
| count_rsp.set_reason("success"); | |||
| if (!communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), | |||
| message)) { | |||
| MS_LOG(WARNING) << "Sending response failed."; | |||
| return; | |||
| } | |||
| return; | |||
| std::string reason = "success"; | |||
| if (!TriggerCounterEvent(name, &reason)) { | |||
| Iteration::GetInstance().NotifyNext(false, reason); | |||
| } | |||
| } | |||
| void DistributedCountService::HandleCountReachThresholdRequest( | |||
| @@ -360,15 +359,16 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st | |||
| return false; | |||
| } | |||
| } | |||
| if (counter_handlers_.count(name) == 0) { | |||
| auto counter_it = counter_handlers_.find(name); | |||
| if (counter_it == counter_handlers_.end() || !counter_it->second.first_count_handler) { | |||
| MS_LOG(WARNING) << "The counter handler of " << name << " is not registered."; | |||
| return false; | |||
| } | |||
| // Leader server directly calls the callback. | |||
| MS_LOG(DEBUG) << "Leader server call first count handler for " << name; | |||
| counter_handlers_[name].first_count_handler(nullptr); | |||
| MS_LOG(DEBUG) << "First count handler for " << name << " is successfully called."; | |||
| auto count_handler = counter_it->second.first_count_handler; | |||
| Server::GetInstance().SubmitTask([count_handler]() { count_handler(nullptr); }); | |||
| MS_LOG(DEBUG) << "First count handler for " << name << " is successfully submitted."; | |||
| return true; | |||
| } | |||
| @@ -389,15 +389,16 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std | |||
| return false; | |||
| } | |||
| } | |||
| if (counter_handlers_.count(name) == 0) { | |||
| auto counter_it = counter_handlers_.find(name); | |||
| if (counter_it == counter_handlers_.end() || !counter_it->second.last_count_handler) { | |||
| MS_LOG(WARNING) << "The counter handler of " << name << " is not registered."; | |||
| return false; | |||
| } | |||
| // Leader server directly calls the callback. | |||
| MS_LOG(DEBUG) << "Leader server call last count handler for " << name; | |||
| counter_handlers_[name].last_count_handler(nullptr); | |||
| MS_LOG(INFO) << "Last count handler for " << name << " is successfully called."; | |||
| auto count_handler = counter_it->second.last_count_handler; | |||
| Server::GetInstance().SubmitTask([count_handler]() { count_handler(nullptr); }); | |||
| MS_LOG(INFO) << "Last count handler for " << name << " is successfully submitted."; | |||
| return true; | |||
| } | |||
| } // namespace server | |||
| @@ -140,6 +140,27 @@ std::map<std::string, AddressPtr> Executor::HandlePullWeight(const std::vector<s | |||
| bool Executor::IsAllWeightAggregationDone() { return IsWeightAggrDone(param_names_); } | |||
| bool Executor::RunAllWeightAggregation() { | |||
| for (const auto &name : param_names_) { | |||
| if (param_aggrs_.count(name) == 0) { | |||
| MS_LOG(ERROR) << "Weight " << name << " is invalid in server."; | |||
| return false; | |||
| } | |||
| std::mutex &mtx = parameter_mutex_[name]; | |||
| std::unique_lock<std::mutex> lock(mtx); | |||
| auto ¶m_aggr = param_aggrs_[name]; | |||
| MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false); | |||
| if (!param_aggr->requires_aggr()) { | |||
| continue; | |||
| } | |||
| if (!param_aggr->RunAggregation()) { | |||
| MS_LOG(WARNING) << "Failed to run aggregation for " << name; | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool Executor::IsWeightAggrDone(const std::vector<std::string> ¶m_names) { | |||
| for (const auto &name : param_names) { | |||
| if (param_aggrs_.count(name) == 0) { | |||
| @@ -69,6 +69,8 @@ class Executor { | |||
| // Judge whether aggregation processes for all weights/gradients are completed. | |||
| bool IsAllWeightAggregationDone(); | |||
| bool RunAllWeightAggregation(); | |||
// Judge whether the aggregation processes for the given param_names are completed.
bool IsWeightAggrDone(const std::vector<std::string> &param_names);
| @@ -121,7 +121,7 @@ void Iteration::MoveToNextIteration(bool is_last_iter_valid, const std::string & | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(server_node_); | |||
| if (server_node_->rank_id() == kLeaderServerRank) { | |||
| if (!BroadcastPrepareForNextIterRequest(is_last_iter_valid, reason)) { | |||
| if (!BroadcastPrepareForNextIterRequest(iteration_num_, is_last_iter_valid, reason)) { | |||
| MS_LOG(ERROR) << "Broadcast prepare for next iteration request failed."; | |||
| return; | |||
| } | |||
| @@ -452,7 +452,7 @@ void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps | |||
| return; | |||
| } | |||
| if (!BroadcastPrepareForNextIterRequest(is_last_iter_valid, reason)) { | |||
| if (!BroadcastPrepareForNextIterRequest(iter_num, is_last_iter_valid, reason)) { | |||
| MS_LOG(ERROR) << "Broadcast prepare for next iteration request failed."; | |||
| return; | |||
| } | |||
| @@ -466,13 +466,15 @@ void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps | |||
| } | |||
| } | |||
| bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason) { | |||
| bool Iteration::BroadcastPrepareForNextIterRequest(size_t last_iteration, bool is_last_iter_valid, | |||
| const std::string &reason) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false); | |||
| PrepareForNextIter(); | |||
| PrepareForNextIter(last_iteration, is_last_iter_valid); | |||
| MS_LOG(INFO) << "Notify all follower servers to prepare for next iteration."; | |||
| PrepareForNextIterRequest prepare_next_iter_req; | |||
| prepare_next_iter_req.set_is_last_iter_valid(is_last_iter_valid); | |||
| prepare_next_iter_req.set_reason(reason); | |||
| prepare_next_iter_req.set_last_iteration(last_iteration); | |||
| std::vector<uint32_t> offline_servers = {}; | |||
| for (uint32_t i = 1; i < server_node_->server_num(); i++) { | |||
| @@ -504,8 +506,12 @@ void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core:: | |||
| PrepareForNextIterRequest prepare_next_iter_req; | |||
| (void)prepare_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len())); | |||
| const auto &reason = prepare_next_iter_req.reason(); | |||
| MS_LOG(INFO) << "Prepare next iteration for this rank " << server_node_->rank_id() << ", reason: " << reason; | |||
| PrepareForNextIter(); | |||
| auto is_last_iter_valid = prepare_next_iter_req.is_last_iter_valid(); | |||
| auto last_iteration = prepare_next_iter_req.last_iteration(); | |||
| MS_LOG(INFO) << "Prepare next iteration for this rank " << server_node_->rank_id() | |||
| << ", last iteration: " << last_iteration << ", last iteration valid: " << is_last_iter_valid | |||
| << ", reason: " << reason; | |||
| PrepareForNextIter(last_iteration, is_last_iter_valid); | |||
| PrepareForNextIterResponse prepare_next_iter_rsp; | |||
| prepare_next_iter_rsp.set_result("success"); | |||
| @@ -516,9 +522,12 @@ void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core:: | |||
| } | |||
| } | |||
| void Iteration::PrepareForNextIter() { | |||
| void Iteration::PrepareForNextIter(size_t last_iteration, bool is_last_iter_valid) { | |||
| MS_LOG(INFO) << "Prepare for next iteration. Switch the server to safemode."; | |||
| Server::GetInstance().SwitchToSafeMode(); | |||
| if (server_node_) { | |||
| server_node_->SetIterationResult(last_iteration, is_last_iter_valid); | |||
| } | |||
| MS_LOG(INFO) << "Start waiting for rounds to finish."; | |||
| WaitAllRoundsFinish(); | |||
| MS_LOG(INFO) << "End waiting for rounds to finish."; | |||
| @@ -189,10 +189,10 @@ class Iteration { | |||
| void HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message); | |||
| // Step 2: leader server broadcasts to all follower servers to prepare for next iteration and switch to safemode.. | |||
| bool BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason); | |||
| bool BroadcastPrepareForNextIterRequest(size_t last_iteration, bool is_last_iter_valid, const std::string &reason); | |||
| void HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message); | |||
| // The server prepare for the next iteration. This method will switch the server to safemode. | |||
| void PrepareForNextIter(); | |||
| void PrepareForNextIter(size_t last_iteration, bool is_last_iter_valid); | |||
| // Step 3: leader server broadcasts to all follower servers to move to next iteration. | |||
| bool BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason); | |||
| @@ -45,6 +45,7 @@ class AggregationKernelMod : public NativeCpuKernelMod { | |||
| const std::vector<AddressPtr> &outputs) { | |||
| return true; | |||
| } | |||
| virtual bool AllReduce() = 0; | |||
| // Server kernel's memory allocation method, which is different from the workflow in | |||
| // Session(GPUSession/CPUSession/AscendSession). | |||
| @@ -43,6 +43,7 @@ class DenseGradAccumKernel : public AggregationKernelMod { | |||
| return true; | |||
| } | |||
| bool AllReduce() override { return true; } | |||
| void Reset() { accum_count_ = 0; } | |||
| bool IsAggregationDone() { return accum_count_ >= done_count_; } | |||
| @@ -48,8 +48,7 @@ class FedAvgKernel : public AggregationKernelMod { | |||
| weight_addr_(nullptr), | |||
| data_size_addr_(nullptr), | |||
| new_weight_addr_(nullptr), | |||
| new_data_size_addr_(nullptr), | |||
| participated_(false) {} | |||
| new_data_size_addr_(nullptr) {} | |||
| ~FedAvgKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override { | |||
| @@ -76,45 +75,42 @@ class FedAvgKernel : public AggregationKernelMod { | |||
| .first; | |||
| MS_EXCEPTION_IF_NULL(weight_node); | |||
| name_ = cnode_name + "." + weight_node->fullname_with_scope(); | |||
| first_cnt_handler_ = [&](std::shared_ptr<ps::core::MessageHandler>) { | |||
| std::unique_lock<std::mutex> lock(weight_mutex_); | |||
| if (!participated_) { | |||
| ClearWeightAndDataSize(); | |||
| } | |||
| }; | |||
| last_cnt_handler_ = [&](std::shared_ptr<ps::core::MessageHandler>) { | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(weight_addr_); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(data_size_addr_); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(weight_addr_->addr); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(data_size_addr_->addr); | |||
| std::unique_lock<std::mutex> lock(weight_mutex_); | |||
| T *weight_addr = reinterpret_cast<T *>(weight_addr_->addr); | |||
| size_t weight_size = weight_addr_->size; | |||
| S *data_size_addr = reinterpret_cast<S *>(data_size_addr_->addr); | |||
| if (!CollectiveOpsImpl::GetInstance().AllReduce<T>(weight_addr, weight_addr, weight_size / sizeof(T))) { | |||
| MS_LOG(ERROR) << "Federated average allreduce failed."; | |||
| return; | |||
| } | |||
| if (!CollectiveOpsImpl::GetInstance().AllReduce<S>(data_size_addr, data_size_addr, 1)) { | |||
| MS_LOG(ERROR) << "Federated average allreduce failed."; | |||
| return; | |||
| } | |||
| if (data_size_addr[0] == 0) { | |||
| MS_LOG(ERROR) << "After AllReduce, the data size is 0."; | |||
| return; | |||
| } | |||
| LocalMetaStore::GetInstance().put_value(kCtxFedAvgTotalDataSize, data_size_addr[0]); | |||
| for (size_t i = 0; i < weight_size / sizeof(T); i++) { | |||
| weight_addr[i] /= data_size_addr[0]; | |||
| } | |||
| done_ = true; | |||
| return; | |||
| }; | |||
| DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler_, last_cnt_handler_}); | |||
| MS_LOG(INFO) << "Aggregate Weight full name is " << weight_node->fullname_with_scope() << ", weight byte size is " | |||
| << weight_size; | |||
| GenerateReuseKernelNodeInfo(); | |||
| return; | |||
| } | |||
| // Runs the collective AllReduce on this kernel's weight buffer and on its data-size scalar, then divides every | |||
| // weight element by the reduced data size. Returns false when a buffer is missing, when either AllReduce call | |||
| // fails, or when the reduced data size is 0. On success the reduced data size is stored in LocalMetaStore under | |||
| // kCtxFedAvgTotalDataSize and done_ is set. | |||
| bool AllReduce() override { | |||
| std::unique_lock<std::mutex> lock(weight_mutex_); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(weight_addr_, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(data_size_addr_, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(weight_addr_->addr, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(data_size_addr_->addr, false); | |||
| auto *weights = reinterpret_cast<T *>(weight_addr_->addr); | |||
| auto *data_size = reinterpret_cast<S *>(data_size_addr_->addr); | |||
| // Number of T elements held by the weight buffer. | |||
| const size_t weight_count = weight_addr_->size / sizeof(T); | |||
| // The weight is reduced in place: the same buffer is passed as send and receive buffer. | |||
| if (!CollectiveOpsImpl::GetInstance().AllReduce<T>(name_, weights, weights, weight_count)) { | |||
| MS_LOG(ERROR) << "Federated average allreduce failed."; | |||
| return false; | |||
| } | |||
| // The data size uses its own data name so it is not confused with the weight transfer. | |||
| if (!CollectiveOpsImpl::GetInstance().AllReduce<S>(name_ + "_data_size", data_size, data_size, 1)) { | |||
| MS_LOG(ERROR) << "Federated average allreduce failed."; | |||
| return false; | |||
| } | |||
| // Guard the division below. | |||
| if (data_size[0] == 0) { | |||
| MS_LOG(ERROR) << "After AllReduce, the data size is 0."; | |||
| return false; | |||
| } | |||
| LocalMetaStore::GetInstance().put_value(kCtxFedAvgTotalDataSize, data_size[0]); | |||
| for (T *cur = weights; cur != weights + weight_count; ++cur) { | |||
| *cur /= data_size[0]; | |||
| } | |||
| done_ = true; | |||
| return true; | |||
| } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override { | |||
| if (inputs.size() != kFedAvgInputsNum) { | |||
| @@ -126,15 +122,16 @@ class FedAvgKernel : public AggregationKernelMod { | |||
| } | |||
| std::unique_lock<std::mutex> lock(weight_mutex_); | |||
| if (done_) { | |||
| MS_LOG(INFO) << "AllReduce for " << name_ << " has finished"; | |||
| return true; | |||
| } | |||
| // The weight and new_weight values should be multiplied by clients already, so we don't need to do multiplication | |||
| // again. | |||
| T *weight_addr = reinterpret_cast<T *>(inputs[0]->addr); | |||
| S *data_size_addr = reinterpret_cast<S *>(inputs[1]->addr); | |||
| T *new_weight_addr = reinterpret_cast<T *>(inputs[2]->addr); | |||
| S *new_data_size_addr = reinterpret_cast<S *>(inputs[3]->addr); | |||
| if (accum_count_ == 0) { | |||
| ClearWeightAndDataSize(); | |||
| } | |||
| MS_LOG(DEBUG) << "Iteration: " << LocalMetaStore::GetInstance().curr_iter_num() << " launching FedAvgKernel for " | |||
| << name_ << " new data size is " << new_data_size_addr[0] << ", current total data size is " | |||
| @@ -146,17 +143,13 @@ class FedAvgKernel : public AggregationKernelMod { | |||
| lock.unlock(); | |||
| accum_count_++; | |||
| participated_ = true; | |||
| return DistributedCountService::GetInstance().Count( | |||
| name_, std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_)); | |||
| return true; | |||
| } | |||
| void Reset() override { | |||
| accum_count_ = 0; | |||
| done_ = false; | |||
| participated_ = false; | |||
| DistributedCountService::GetInstance().ResetCounter(name_); | |||
| return; | |||
| ClearWeightAndDataSize(); | |||
| } | |||
| bool IsAggregationDone() override { return done_; } | |||
| @@ -169,18 +162,8 @@ class FedAvgKernel : public AggregationKernelMod { | |||
| new_data_size_addr_ = inputs[3]; | |||
| return; | |||
| } | |||
| bool ReInitForScaling() override { | |||
| DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler_, last_cnt_handler_}); | |||
| return true; | |||
| } | |||
| bool ReInitForUpdatingHyperParams(size_t aggr_threshold) override { | |||
| done_count_ = aggr_threshold; | |||
| if (!DistributedCountService::GetInstance().ReInitCounter(name_, done_count_)) { | |||
| MS_LOG(ERROR) << "Reinitializing count for " << name_ << " failed."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| @@ -211,9 +194,6 @@ class FedAvgKernel : public AggregationKernelMod { | |||
| return; | |||
| } | |||
| MessageCallback first_cnt_handler_; | |||
| MessageCallback last_cnt_handler_; | |||
| // The trainable parameter index of the kernel node which is parsed from the frontend func_graph. | |||
| size_t cnode_weight_idx_; | |||
| @@ -222,10 +202,6 @@ class FedAvgKernel : public AggregationKernelMod { | |||
| AddressPtr data_size_addr_; | |||
| AddressPtr new_weight_addr_; | |||
| AddressPtr new_data_size_addr_; | |||
| // Whether the kernel's Launch method is called. | |||
| bool participated_; | |||
| // The kernel could be called concurrently so we need lock to ensure threadsafe. | |||
| std::mutex weight_mutex_; | |||
| }; | |||
| @@ -67,7 +67,7 @@ bool PushMetricsKernel::Reset() { | |||
| void PushMetricsKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) { | |||
| if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kPushMetrics) { | |||
| FinishIteration(); | |||
| FinishIteration(true); | |||
| } | |||
| return; | |||
| } | |||
| @@ -72,7 +72,7 @@ bool PushWeightKernel::Reset() { | |||
| void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) { | |||
| if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kPushWeight) { | |||
| FinishIteration(); | |||
| FinishIteration(true); | |||
| } | |||
| return; | |||
| } | |||
| @@ -32,7 +32,7 @@ void ReconstructSecretsKernel::InitKernel(size_t) { | |||
| auto last_cnt_handler = [&](std::shared_ptr<ps::core::MessageHandler>) { | |||
| if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kReconstructSeccrets) { | |||
| MS_LOG(INFO) << "start FinishIteration"; | |||
| FinishIteration(); | |||
| FinishIteration(true); | |||
| MS_LOG(INFO) << "end FinishIteration"; | |||
| } | |||
| return; | |||
| @@ -22,6 +22,7 @@ | |||
| #include <utility> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "fl/server/iteration.h" | |||
| namespace mindspore { | |||
| namespace fl { | |||
| @@ -42,21 +43,18 @@ void RoundKernel::StopTimer() const { | |||
| return; | |||
| } | |||
| void RoundKernel::FinishIteration() const { | |||
| if (finish_iteration_cb_) { | |||
| finish_iteration_cb_(true, ""); | |||
| void RoundKernel::FinishIteration(bool is_last_iter_valid, const std::string &in_reason) const { | |||
| std::string reason = in_reason; | |||
| if (is_last_iter_valid) { | |||
| reason = "Round " + name_ + " finished! This iteration is valid. Proceed to next iteration."; | |||
| } | |||
| return; | |||
| Iteration::GetInstance().NotifyNext(is_last_iter_valid, reason); | |||
| } | |||
| void RoundKernel::set_name(const std::string &name) { name_ = name; } | |||
| void RoundKernel::set_stop_timer_cb(const StopTimerCb &timer_stopper) { stop_timer_cb_ = timer_stopper; } | |||
| void RoundKernel::set_finish_iteration_cb(const FinishIterCb &finish_iteration_cb) { | |||
| finish_iteration_cb_ = finish_iteration_cb; | |||
| } | |||
| void RoundKernel::GenerateOutput(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data, | |||
| size_t len) { | |||
| if (message == nullptr) { | |||
| @@ -72,14 +72,13 @@ class RoundKernel { | |||
| // Called after this iteration(including all rounds) is finished. All rounds' Reset method will | |||
| // be called. | |||
| void FinishIteration() const; | |||
| void FinishIteration(bool is_last_iter_valid, const std::string &reason = "") const; | |||
| // Set round kernel name, which could be used in round kernel's methods. | |||
| void set_name(const std::string &name); | |||
| // Set callbacks to be called under certain triggered conditions. | |||
| void set_stop_timer_cb(const StopTimerCb &timer_stopper); | |||
| void set_finish_iteration_cb(const FinishIterCb &finish_iteration_cb); | |||
| void Summarize(); | |||
| @@ -106,7 +105,6 @@ class RoundKernel { | |||
| size_t current_count_; | |||
| StopTimerCb stop_timer_cb_; | |||
| FinishIterCb finish_iteration_cb_; | |||
| // Members below are used for allocating and releasing response data on the heap. | |||
| @@ -25,7 +25,8 @@ namespace mindspore { | |||
| namespace fl { | |||
| namespace server { | |||
| namespace kernel { | |||
| constexpr uint32_t kRetryCountOfWaitWeightAggregation = 30; | |||
| const char *kCountForAggregation = "count_for_aggregation"; | |||
| void UpdateModelKernel::InitKernel(size_t threshold_count) { | |||
| if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { | |||
| iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); | |||
| @@ -42,6 +43,11 @@ void UpdateModelKernel::InitKernel(size_t threshold_count) { | |||
| DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxUpdateModelClientList, client_list); | |||
| LocalMetaStore::GetInstance().put_value(kCtxUpdateModelThld, threshold_count); | |||
| LocalMetaStore::GetInstance().put_value(kCtxFedAvgTotalDataSize, kInitialDataSizeSum); | |||
| auto first_cnt_handler = [](std::shared_ptr<ps::core::MessageHandler>) {}; | |||
| auto last_cnt_handler = [this](std::shared_ptr<ps::core::MessageHandler>) { RunAggregation(); }; | |||
| DistributedCountService::GetInstance().RegisterCounter(kCountForAggregation, threshold_count, | |||
| {first_cnt_handler, last_cnt_handler}); | |||
| } | |||
| bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len, | |||
| @@ -72,7 +78,7 @@ bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len, | |||
| } | |||
| const schema::RequestUpdateModel *update_model_req = flatbuffers::GetRoot<schema::RequestUpdateModel>(req_data); | |||
| if (update_model_req == nullptr) { | |||
| if (update_model_req == nullptr || update_model_req->fl_id() == nullptr) { | |||
| std::string reason = "Building flatbuffers schema failed for RequestUpdateModel."; | |||
| BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, ""); | |||
| MS_LOG(WARNING) << reason; | |||
| @@ -118,11 +124,16 @@ bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len, | |||
| if (result_code != ResultCode::kSuccess) { | |||
| MS_LOG(WARNING) << "Updating model failed."; | |||
| GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| return ConvertResultCode(result_code); | |||
| return false; | |||
| } | |||
| std::string update_model_fl_id = update_model_req->fl_id()->str(); | |||
| IncreaseAcceptClientNum(); | |||
| GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize()); | |||
| result_code = CountForAggregation(update_model_fl_id); | |||
| if (result_code != ResultCode::kSuccess) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| @@ -130,6 +141,7 @@ bool UpdateModelKernel::Reset() { | |||
| MS_LOG(INFO) << "Update model kernel reset!"; | |||
| StopTimer(); | |||
| DistributedCountService::GetInstance().ResetCounter(name_); | |||
| DistributedCountService::GetInstance().ResetCounter(kCountForAggregation); | |||
| executor_->ResetAggregationStatus(); | |||
| DistributedMetadataStore::GetInstance().ResetMetadata(kCtxUpdateModelClientList); | |||
| size_t &total_data_size = LocalMetaStore::GetInstance().mutable_value<size_t>(kCtxFedAvgTotalDataSize); | |||
| @@ -137,23 +149,22 @@ bool UpdateModelKernel::Reset() { | |||
| return true; | |||
| } | |||
| void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) { | |||
| if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kUpdateModel) { | |||
| last_count_thread_ = std::make_unique<std::thread>([this]() { | |||
| uint32_t retryCount = 0; | |||
| while (!executor_->IsAllWeightAggregationDone() && retryCount <= kRetryCountOfWaitWeightAggregation) { | |||
| std::this_thread::sleep_for(std::chrono::seconds(1)); | |||
| retryCount += 1; | |||
| } | |||
| size_t total_data_size = LocalMetaStore::GetInstance().value<size_t>(kCtxFedAvgTotalDataSize); | |||
| MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is " | |||
| << total_data_size; | |||
| if (ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) { | |||
| FinishIteration(); | |||
| } | |||
| }); | |||
| last_count_thread_->detach(); | |||
| void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {} | |||
| // Runs the weight aggregation of all parameters and reports the outcome of the iteration. | |||
| // On failure the iteration is finished as invalid with a reason string. On success the iteration is finished | |||
| // only when update-model is the resetter round and the encrypt type is not kPWEncryptType. | |||
| void UpdateModelKernel::RunAggregation() { | |||
| auto aggregation_ok = Executor::GetInstance().RunAllWeightAggregation(); | |||
| auto iter_num = LocalMetaStore::GetInstance().curr_iter_num(); | |||
| if (!aggregation_ok) { | |||
| std::string reason = "Weight aggregation failed, current iteration: " + std::to_string(iter_num); | |||
| MS_LOG(WARNING) << reason; | |||
| FinishIteration(aggregation_ok, reason); | |||
| return; | |||
| } | |||
| size_t total_data_size = LocalMetaStore::GetInstance().value<size_t>(kCtxFedAvgTotalDataSize); | |||
| MS_LOG(INFO) << "Total data size for iteration " << iter_num << " is " << total_data_size; | |||
| const bool is_resetter_round = ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kUpdateModel; | |||
| if (is_resetter_round && ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) { | |||
| FinishIteration(aggregation_ok); | |||
| } | |||
| } | |||
| @@ -279,6 +290,15 @@ std::map<std::string, UploadData> UpdateModelKernel::ParseFeatureMap( | |||
| return feature_map; | |||
| } | |||
| // Counts the client identified by req_fl_id on the kCountForAggregation counter. | |||
| // Returns ResultCode::kSuccess when the count is accepted, otherwise logs the reason and returns ResultCode::kFail. | |||
| ResultCode UpdateModelKernel::CountForAggregation(const std::string &req_fl_id) { | |||
| std::string failure_reason; | |||
| const bool counted = | |||
| DistributedCountService::GetInstance().Count(kCountForAggregation, req_fl_id, &failure_reason); | |||
| if (counted) { | |||
| return ResultCode::kSuccess; | |||
| } | |||
| MS_LOG(ERROR) << "Counting for aggregation failed. reason: " + failure_reason; | |||
| return ResultCode::kFail; | |||
| } | |||
| ResultCode UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb, | |||
| const schema::RequestUpdateModel *update_model_req) { | |||
| MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kSuccessAndReturn); | |||
| @@ -55,6 +55,9 @@ class UpdateModelKernel : public RoundKernel { | |||
| ResultCode UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr<FBBuilder> &fbb, | |||
| const DeviceMeta &device_meta); | |||
| std::map<std::string, UploadData> ParseFeatureMap(const schema::RequestUpdateModel *update_model_req); | |||
| void RunAggregation(); | |||
| ResultCode CountForAggregation(const std::string &req_fl_id); | |||
| ResultCode CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb, | |||
| const schema::RequestUpdateModel *update_model_req); | |||
| sigVerifyResult VerifySignature(const schema::RequestUpdateModel *update_model_req); | |||
| @@ -67,8 +70,6 @@ class UpdateModelKernel : public RoundKernel { | |||
| // The time window of one iteration. | |||
| size_t iteration_time_window_{0}; | |||
| std::unique_ptr<std::thread> last_count_thread_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace server | |||
| @@ -152,6 +152,17 @@ bool ParameterAggregator::IsAggregationDone() const { | |||
| return true; | |||
| } | |||
| // Calls AllReduce() on every aggregation kernel of this parameter, in registration order. | |||
| // Stops and returns false at the first null kernel or the first kernel whose AllReduce() fails. | |||
| bool ParameterAggregator::RunAggregation() { | |||
| for (const auto &kernel_with_params : aggregation_kernel_parameters_) { | |||
| const std::shared_ptr<kernel::AggregationKernelMod> &aggr_kernel = kernel_with_params.first; | |||
| MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false); | |||
| if (aggr_kernel->AllReduce()) { | |||
| continue; | |||
| } | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool ParameterAggregator::IsOptimizingDone() const { return optimizing_done_; } | |||
| bool ParameterAggregator::IsPullingDone() const { return pulling_done_; } | |||
| @@ -90,6 +90,7 @@ class ParameterAggregator { | |||
| // Returns the aggregation/optimizing/pulling status to the caller. | |||
| bool IsAggregationDone() const; | |||
| bool RunAggregation(); | |||
| bool IsOptimizingDone() const; | |||
| bool IsPullingDone() const; | |||
| @@ -45,12 +45,6 @@ void Round::RegisterMsgCallBack(const std::shared_ptr<ps::core::CommunicatorBase | |||
| void Round::Initialize(const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb) { | |||
| MS_LOG(INFO) << "Round " << name_ << " start initialize."; | |||
| // Callback when the iteration is finished. | |||
| finish_iteration_cb_ = [this, finish_iteration_cb](bool, const std::string &) -> void { | |||
| std::string reason = "Round " + name_ + " finished! This iteration is valid. Proceed to next iteration."; | |||
| finish_iteration_cb(true, reason); | |||
| }; | |||
| if (check_timeout_) { | |||
| iter_timer_ = std::make_shared<IterationTimer>(); | |||
| MS_EXCEPTION_IF_NULL(iter_timer_); | |||
| @@ -115,7 +109,6 @@ void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel) | |||
| MS_EXCEPTION_IF_NULL(kernel); | |||
| kernel_ = kernel; | |||
| kernel_->set_stop_timer_cb(stop_timer_cb_); | |||
| kernel_->set_finish_iteration_cb(finish_iteration_cb_); | |||
| return; | |||
| } | |||
| @@ -110,7 +110,6 @@ class Round { | |||
| // The callbacks which will be set to the round kernel. | |||
| StopTimerCb stop_timer_cb_; | |||
| FinishIterCb finish_iteration_cb_; | |||
| }; | |||
| } // namespace server | |||
| } // namespace fl | |||
| @@ -159,6 +159,13 @@ void Server::InitCluster() { | |||
| return; | |||
| } | |||
| // Hands the task to the server's task executor. Returns false when no executor has been created yet, | |||
| // otherwise returns whatever the executor's Submit reports. | |||
| bool Server::SubmitTask(std::function<void()> &&task) { | |||
| const bool has_executor = task_executor_ != nullptr; | |||
| return has_executor && task_executor_->Submit(task); | |||
| } | |||
| bool Server::InitCommunicatorWithServer() { | |||
| MS_EXCEPTION_IF_NULL(task_executor_); | |||
| MS_EXCEPTION_IF_NULL(server_node_); | |||
| @@ -396,6 +403,8 @@ void Server::InitExecutor() { | |||
| MS_LOG(INFO) << "Required count for push-type and pull-type kernels is " << executor_threshold_; | |||
| Executor::GetInstance().Initialize(func_graph, executor_threshold_); | |||
| ModelStore::GetInstance().Initialize(); | |||
| // init weight memory to 0 after get model | |||
| Executor::GetInstance().ResetAggregationStatus(); | |||
| return; | |||
| } | |||
| @@ -76,6 +76,8 @@ class Server { | |||
| // Whether the training job of the server is enabled. | |||
| InstanceState instance_state() const; | |||
| bool SubmitTask(std::function<void()> &&task); | |||
| private: | |||
| Server() | |||
| : server_node_(nullptr), | |||
| @@ -341,6 +341,136 @@ uint64_t AbstractNode::CollectiveSendAsync(const NodeRole &node_role, const uint | |||
| return SendMessageAsync(client, message_meta, Protos::RAW, data, size); | |||
| } | |||
| // Formats the identifying fields of a collective message meta as one "{key:value, ...}" string for logging. | |||
| static std::string CollectiveMetaToString(const CollectiveMessageMeta &meta) { | |||
| std::ostringstream text; | |||
| text << "{iteration:" << meta.iteration(); | |||
| text << ", data:" << meta.weight_name(); | |||
| text << ", send rank:" << meta.send_rank_id(); | |||
| text << ", recv rank:" << meta.recv_rank_id(); | |||
| text << ", phase:" << meta.phase(); | |||
| text << ", chunk index:" << meta.chunk_index(); | |||
| text << ", for index:" << meta.for_index() << "}"; | |||
| return text.str(); | |||
| } | |||
| // Sends `size` bytes at `data` to the server whose rank is collective_meta.recv_rank_id(), tagged with a copy of | |||
| // collective_meta. The copy gets enable_flag set and send_rank_id overwritten with this node's rank. | |||
| // Returns 0 when the destination rank is invalid, otherwise the value returned by SendMessageAsync. | |||
| uint64_t AbstractNode::FlCollectiveSendAsync(const CollectiveMessageMeta &collective_meta, const void *data, | |||
| size_t size) { | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| auto target_rank = collective_meta.recv_rank_id(); | |||
| const bool rank_ok = CommUtil::ValidateRankId(SERVER, target_rank, worker_num_, server_num_); | |||
| if (!rank_ok) { | |||
| MS_LOG(ERROR) << "The node role or rank_id is illegal, the worker num:" << worker_num_ | |||
| << ", the server num:" << server_num_ << ", the rank id:" << target_rank; | |||
| return 0; | |||
| } | |||
| auto message_meta = std::make_shared<MessageMeta>(); | |||
| MS_EXCEPTION_IF_NULL(message_meta); | |||
| message_meta->set_cmd(NodeCommand::COLLECTIVE_SEND_DATA); | |||
| message_meta->set_rank_id(node_info_.rank_id_); | |||
| message_meta->set_role(node_info_.node_role_); | |||
| // Copy the caller's meta, then stamp the fields this node is responsible for. | |||
| auto *outgoing_meta = message_meta->mutable_collective_meta(); | |||
| *outgoing_meta = collective_meta; | |||
| outgoing_meta->set_enable_flag(true); | |||
| outgoing_meta->set_send_rank_id(node_info_.rank_id_); | |||
| MS_LOG(DEBUG) << "Send data to rank id:" << target_rank | |||
| << ", send meta:" << CollectiveMetaToString(message_meta->collective_meta()); | |||
| auto client = GetOrCreateTcpClient(target_rank, SERVER); | |||
| MS_EXCEPTION_IF_NULL(client); | |||
| return SendMessageAsync(client, message_meta, Protos::RAW, data, size); | |||
| } | |||
| // Waits for a collective message sent by rank expect_meta.send_rank_id() and hands its payload to *output. | |||
| // expect_meta: the meta the next message of the expected iteration must carry. | |||
| // output: receives the payload pointer on success; a null `output` makes the call return false. | |||
| // timeout: number of passes of the wait loop; a pass that finds the queue empty waits up to one second. | |||
| // Returns true when a message of the expected iteration matching expect_meta was taken from the queue. | |||
| // Returns false on an invalid rank, when the iteration is detected as failed, on a meta mismatch within the | |||
| // expected iteration, or when all passes are used up. | |||
| bool AbstractNode::FlCollectiveWaitInner(const CollectiveMessageMeta &expect_meta, VectorPtr *output, | |||
| const uint32_t &timeout) { | |||
| if (output == nullptr) { | |||
| return false; | |||
| } | |||
| auto send_rank_id = expect_meta.send_rank_id(); | |||
| if (!CommUtil::ValidateRankId(SERVER, send_rank_id, worker_num_, server_num_)) { | |||
| MS_LOG(ERROR) << "The node role or rank_id is illegal, the worker num:" << worker_num_ | |||
| << ", the server num:" << server_num_ << ", the rank id:" << send_rank_id; | |||
| return false; | |||
| } | |||
| // Field-by-field equality of the identifying fields; enable_flag is not compared. | |||
| auto check_meta = [](const CollectiveMessageMeta &left, const CollectiveMessageMeta &right) { | |||
| return left.iteration() == right.iteration() && left.weight_name() == right.weight_name() && | |||
| left.recv_rank_id() == right.recv_rank_id() && left.send_rank_id() == right.send_rank_id() && | |||
| left.phase() == right.phase() && left.chunk_index() == right.chunk_index() && | |||
| left.for_index() == right.for_index(); | |||
| }; | |||
| auto iteration_num = expect_meta.iteration(); | |||
| // The lock is held for the rest of the function except while blocked inside wait_for. | |||
| std::unique_lock<std::mutex> lock(fl_receive_mutex_); | |||
| // Queue of messages received from this sender; operator[] creates it if it does not exist yet. | |||
| auto &recv_data_list = fl_received_data_[send_rank_id]; | |||
| // NOTE(review): a pass that only drains messages of other iterations uses up one unit of `timeout` without | |||
| // waiting, so the total wait can be shorter than `timeout` seconds. | |||
| for (uint32_t i = 0; i < timeout; i++) { | |||
| if (recv_data_list.empty()) { | |||
| fl_receive_cond_.wait_for(lock, std::chrono::seconds(1), [&recv_data_list]() { return !recv_data_list.empty(); }); | |||
| if (recv_data_list.empty()) { // timeout | |||
| if (HasIterationFailed(iteration_num)) { // if result of iteration reported by other server is failed | |||
| MS_LOG(WARNING) << "Detect iteration " << iteration_num << " has failed"; | |||
| return false; | |||
| } | |||
| continue; | |||
| } | |||
| } | |||
| // Consume queued messages front-first. Each message is removed from the queue before it is inspected, | |||
| // so rejected messages are dropped. | |||
| while (!recv_data_list.empty()) { | |||
| auto first = recv_data_list.begin(); | |||
| auto recv_meta = std::move(first->first); | |||
| auto recv_data = std::move(first->second); | |||
| recv_data_list.erase(first); | |||
| MS_LOG(DEBUG) << "Handle receive data from rank id:" << send_rank_id | |||
| << ", recv meta:" << CollectiveMetaToString(recv_meta); | |||
| // Messages of a different iteration are skipped and the next queued message is tried. | |||
| if (recv_meta.iteration() != expect_meta.iteration()) { | |||
| MS_LOG(WARNING) << "Skip recv data, iteration of recv meta " << recv_meta.iteration() | |||
| << " != iteration of expected meta " << expect_meta.iteration(); | |||
| continue; | |||
| } | |||
| // error data in the same iteration | |||
| if (!check_meta(recv_meta, expect_meta)) { | |||
| MS_LOG(WARNING) << "Recv meta not match expected meta, recv mata: " << CollectiveMetaToString(recv_meta) | |||
| << ", expected meta: " << CollectiveMetaToString(expect_meta); | |||
| return false; | |||
| } | |||
| *output = recv_data; | |||
| return true; // success to recv data | |||
| } | |||
| } | |||
| // All passes used without receiving a matching message. | |||
| return false; | |||
| } | |||
| // Waits for the collective message described by expect_meta and validates the received payload. | |||
| // expect_meta: meta the received message must match (checked by FlCollectiveWaitInner). | |||
| // expect_size: required size() of the received payload vector. | |||
| // output: receives the payload pointer; must not be null. | |||
| // timeout: forwarded unchanged to FlCollectiveWaitInner. | |||
| // Returns true only when a message was received, its payload pointer is non-null and its size equals | |||
| // expect_size. Every failure is logged. | |||
| bool AbstractNode::FlCollectiveWait(const CollectiveMessageMeta &expect_meta, size_t expect_size, VectorPtr *output, | |||
| const uint32_t &timeout) { | |||
| if (output == nullptr) { | |||
| MS_LOG(ERROR) << "FlCollectiveWait failed, parameter output invalid"; | |||
| return false; | |||
| } | |||
| auto data_recved = FlCollectiveWaitInner(expect_meta, output, timeout); | |||
| if (!data_recved) { | |||
| MS_LOG(ERROR) << "FlCollectiveWait failed, expect meta: " << CollectiveMetaToString(expect_meta); | |||
| return false; | |||
| } | |||
| if (*output == nullptr) { | |||
| MS_LOG(ERROR) << "FlCollectiveWait failed, recv buffer invalid"; | |||
| return false; | |||
| } | |||
| if (expect_size != (*output)->size()) { | |||
| // Label the meta so it is not glued onto the size number in the log line. | |||
| MS_LOG(ERROR) << "Expected data size " << expect_size << " != recv data size " << (*output)->size() | |||
| << ", expect meta: " << CollectiveMetaToString(expect_meta); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| // Appends a received collective message (meta plus payload) to the queue of its sender rank and wakes every | |||
| // thread waiting on fl_receive_cond_. The whole body runs under fl_receive_mutex_. | |||
| void AbstractNode::OnRecvCollectiveData(const MessageMeta &message_meta, const VectorPtr &data) { | |||
| std::lock_guard<std::mutex> guard(fl_receive_mutex_); | |||
| const auto &collective_meta = message_meta.collective_meta(); | |||
| const auto sender_rank = collective_meta.send_rank_id(); | |||
| MS_LOG(DEBUG) << "Receive data from rank id:" << sender_rank | |||
| << ", recv meta:" << CollectiveMetaToString(collective_meta); | |||
| auto &sender_queue = fl_received_data_[sender_rank]; | |||
| sender_queue.emplace_back(collective_meta, data); | |||
| fl_receive_cond_.notify_all(); | |||
| } | |||
| // Records the outcome of iteration `last_iteration`: it is remembered as failed when is_iteration_valid is false. | |||
| // FlCollectiveWaitInner reads both fields through HasIterationFailed() while holding fl_receive_mutex_ on a | |||
| // waiting thread, so they are written under the same mutex. This keeps the pair consistent and removes the | |||
| // unsynchronized cross-thread write. | |||
| void AbstractNode::SetIterationResult(size_t last_iteration, bool is_iteration_valid) { | |||
| std::lock_guard<std::mutex> lock(fl_receive_mutex_); | |||
| failed_iteration_num_ = last_iteration; | |||
| iteration_failed_ = !is_iteration_valid; | |||
| } | |||
| // Returns true when the most recently recorded iteration result is a failure and belongs to `iteration_num`. | |||
| // Takes no lock itself; FlCollectiveWaitInner calls it while already holding fl_receive_mutex_. | |||
| bool AbstractNode::HasIterationFailed(uint32_t iteration_num) const { | |||
| if (!iteration_failed_) { | |||
| return false; | |||
| } | |||
| return iteration_num == failed_iteration_num_; | |||
| } | |||
| std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const NodeRole &node_role, const uint32_t &rank_id, | |||
| VectorPtr *output) { | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| @@ -1081,19 +1211,22 @@ void AbstractNode::RunReceiveCallback(const std::shared_ptr<MessageMeta> &meta, | |||
| size_t size) { | |||
| MS_EXCEPTION_IF_NULL(meta); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| receive_callbacks_mutex_.lock(); | |||
| uint32_t rank_id = meta->rank_id(); | |||
| // When receiving a collective message, Then generate rank request id,compare with the desired rank request id, | |||
| // If they are equal, then call the callback function | |||
| uint64_t rank_request_id = NextActualRankRequestId(rank_id); | |||
| std::shared_ptr<std::vector<unsigned char>> received_data = std::make_shared<std::vector<unsigned char>>(size, 0); | |||
| size_t dest_size = size; | |||
| size_t src_size = size; | |||
| int ret = memcpy_s(received_data->data(), dest_size, data, src_size); | |||
| if (ret != 0) { | |||
| receive_callbacks_mutex_.unlock(); | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| if (meta->collective_meta().enable_flag()) { | |||
| OnRecvCollectiveData(*meta, received_data); | |||
| return; | |||
| } | |||
| receive_callbacks_mutex_.lock(); | |||
| uint32_t rank_id = meta->rank_id(); | |||
| // When receiving a collective message, Then generate rank request id,compare with the desired rank request id, | |||
| // If they are equal, then call the callback function | |||
| uint64_t rank_request_id = NextActualRankRequestId(rank_id); | |||
| received_data_[std::make_pair(rank_id, rank_request_id)] = received_data; | |||
| MS_LOG(DEBUG) << "Run Receive data callback,the rank id:" << rank_id << ", the rank request id is:" << rank_request_id | |||
| << ", the send request id is:" << meta->request_id() << " the size is:" << size; | |||
| @@ -23,6 +23,7 @@ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include <functional> | |||
| #include "ps/core/node.h" | |||
| #include "ps/core/communicator/message.h" | |||
| @@ -103,6 +104,12 @@ class AbstractNode : public Node { | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| uint64_t CollectiveSendAsync(const NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size); | |||
| using CheckFailReturnFun = std::function<bool()>; | |||
| uint64_t FlCollectiveSendAsync(const CollectiveMessageMeta &collective_meta, const void *data, size_t size); | |||
| bool FlCollectiveWait(const CollectiveMessageMeta &expect_meta, size_t expect_size, VectorPtr *output, | |||
| const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const NodeRole &node_role, const uint32_t &rank_id, | |||
| VectorPtr *output); | |||
| bool CollectiveWait(const std::pair<uint32_t, uint64_t> &request_id, const uint32_t &timeout = kCommTimeoutInSeconds); | |||
| @@ -145,6 +152,9 @@ class AbstractNode : public Node { | |||
| uint32_t worker_num, uint32_t server_num, | |||
| const std::shared_ptr<TaskExecutor> &task_executor); | |||
| void SetIterationResult(size_t last_iteration, bool is_iteration_valid); | |||
| bool HasIterationFailed(uint32_t iteration_num) const; | |||
| protected: | |||
| virtual void Register(const std::shared_ptr<TcpClient> &client); | |||
| bool Heartbeat(const std::shared_ptr<TcpClient> &client); | |||
| @@ -235,6 +245,9 @@ class AbstractNode : public Node { | |||
| const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, | |||
| size_t size); | |||
| bool FlCollectiveWaitInner(const CollectiveMessageMeta &expect_meta, VectorPtr *output, const uint32_t &timeout); | |||
| void OnRecvCollectiveData(const MessageMeta &message_meta, const VectorPtr &data); | |||
| std::unique_ptr<std::thread> heart_beat_thread_; | |||
| std::unique_ptr<std::thread> client_to_scheduler_thread_; | |||
| std::shared_ptr<TcpClient> client_to_scheduler_; | |||
| @@ -260,6 +273,12 @@ class AbstractNode : public Node { | |||
| std::unordered_map<NodeCommand, ResponseHandler> handlers_; | |||
| std::unordered_map<NodeCommand, ServerHandler> server_handler_; | |||
| // Received FL collective messages, keyed by the sender's rank id (send_rank_id). Each entry pairs the | |||
| // message's CollectiveMessageMeta with its payload bytes. | |||
| std::unordered_map<uint32_t, std::vector<std::pair<CollectiveMessageMeta, std::shared_ptr<std::vector<uint8_t>>>>> | |||
| fl_received_data_; | |||
| // NOTE(review): presumably guards fl_received_data_ -- confirm against the definitions that touch it. | |||
| std::mutex fl_receive_mutex_; | |||
| // NOTE(review): presumably signalled when new data is stored so that waiters can re-check -- confirm. | |||
| std::condition_variable fl_receive_cond_; | |||
| // Workers and servers launch this server to process the commands FINISH, SCALE_OUT, SCALE_IN and | |||
| // SEND_METADATA. | |||
| std::shared_ptr<TcpServer> server_; | |||
| // NOTE(review): presumably the thread that runs server_ -- confirm against the definition. | |||
| std::unique_ptr<std::thread> server_thread_; | |||
| @@ -302,6 +321,9 @@ class AbstractNode : public Node { | |||
| // CommunicatorBase instances keyed by a string. | |||
| // NOTE(review): presumably communicator_mutex_ guards communicators_ and cluster_state_mutex_ guards the | |||
| // cluster state -- confirm against the definitions. | |||
| std::unordered_map<std::string, std::shared_ptr<CommunicatorBase>> communicators_; | |||
| std::mutex communicator_mutex_; | |||
| std::mutex cluster_state_mutex_; | |||
| // Iteration-failure state, initialised to "iteration 0, not failed". | |||
| // NOTE(review): presumably written by SetIterationResult and read by HasIterationFailed; no lock is | |||
| // visible next to these members -- confirm how concurrent access is handled. | |||
| size_t failed_iteration_num_ = 0; | |||
| bool iteration_failed_ = false; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| @@ -68,7 +68,18 @@ enum PersistentState { | |||
| PREPARING_PERSIST = 1; | |||
| READY_PERSIST = 2; | |||
| PERSISTING = 3; | |||
| FINISH_PERSIST = 4; | |||
| FINISH_PERSIST = 4; | |||
| } | |||
| // Metadata describing one collective-communication message; carried in MessageMeta.collective_meta. | |||
| message CollectiveMessageMeta { | |||
| // NOTE(review): presumably true when this metadata is populated/in use -- confirm against the senders. | |||
| bool enable_flag = 1; | |||
| // Rank ids of the sending and the receiving node. | |||
| uint32 send_rank_id = 2; | |||
| uint32 recv_rank_id = 3; | |||
| // Iteration number the message belongs to. | |||
| uint32 iteration = 4; | |||
| // Name of the weight the payload belongs to. | |||
| bytes weight_name = 5; | |||
| bytes phase = 6; // ring, gather, reduce, broadcast | |||
| // NOTE(review): presumably the chunk index and the loop-step index within the phase -- confirm against | |||
| // the code that fills these fields. | |||
| uint32 chunk_index = 7; | |||
| uint32 for_index = 8; | |||
| } | |||
| message MessageMeta { | |||
| @@ -82,6 +93,8 @@ message MessageMeta { | |||
| uint32 rank_id = 4; | |||
| // User-defined commands | |||
| int32 user_cmd = 5; | |||
| // Collective-communication metadata for this message (see CollectiveMessageMeta). | |||
| CollectiveMessageMeta collective_meta = 6; | |||
| } | |||
| message RegisterMessage { | |||
| @@ -147,7 +160,7 @@ message ServersMeta { | |||
| bool is_alive = 4; | |||
| NodeRole role = 5; | |||
| string node_id = 6; | |||
| PersistentState persistent_state = 7; | |||
| PersistentState persistent_state = 7; | |||
| } | |||
| message SendMetadataMessage { | |||
| @@ -213,6 +213,7 @@ message SyncIterationResponse { | |||
| // Request carrying the validity of the last iteration, a reason string and an iteration number. | |||
| message PrepareForNextIterRequest { | |||
| bool is_last_iter_valid = 2; | |||
| string reason = 3; | |||
| // NOTE(review): presumably the number of the iteration that is_last_iter_valid refers to -- confirm | |||
| // against the code that fills this request. | |||
| uint64 last_iteration = 4; | |||
| } | |||
| message PrepareForNextIterResponse { | |||