| @@ -624,7 +624,6 @@ bool StartServerAction(const ResourcePtr &res) { | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| const std::string &server_mode_ = ps::PSContext::instance()->server_mode(); | |||
| size_t worker_num = ps::PSContext::instance()->initial_worker_num(); | |||
| size_t server_num = ps::PSContext::instance()->initial_server_num(); | |||
| uint64_t fl_server_port = ps::PSContext::instance()->fl_server_port(); | |||
| // Update model threshold is a certain ratio of start_fl_job threshold. | |||
| @@ -633,17 +632,9 @@ bool StartServerAction(const ResourcePtr &res) { | |||
| float percent_for_update_model = 1; | |||
| size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * percent_for_update_model)); | |||
| std::vector<ps::server::RoundConfig> rounds_config = { | |||
| {"startFLJob", false, 3000, false, start_fl_job_threshold}, | |||
| {"updateModel", false, 3000, false, update_model_threshold}, | |||
| {"getModel", false, 3000}, | |||
| {"asyncUpdateModel"}, | |||
| {"asyncGetModel"}, | |||
| {"push", false, 3000, true, worker_num}, | |||
| {"pull", false, 3000, true, worker_num}, | |||
| {"getWeightsByKey", false, 3000, true, 1}, | |||
| {"overwriteWeightsByKey", false, 3000, true, server_num}, | |||
| }; | |||
| std::vector<ps::server::RoundConfig> rounds_config = {{"startFLJob", false, 3000, true, start_fl_job_threshold}, | |||
| {"updateModel", false, 3000, true, update_model_threshold}, | |||
| {"getModel", false, 3000}}; | |||
| size_t executor_threshold = 0; | |||
| if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) { | |||
| @@ -235,7 +235,7 @@ void PSContext::GenerateResetterRound() { | |||
| } | |||
| binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) | | |||
| (is_mixed_training_mode << 2) | (secure_aggregation_ << 3) | (worker_overwrite_weights_ << 4); | |||
| (is_mixed_training_mode << 2) | (secure_aggregation_ << 3) | (worker_upload_weights_ << 4); | |||
| if (kServerContextToResetRoundMap.count(binary_server_context) == 0) { | |||
| resetter_round_ = ResetterRound::kNoNeedToReset; | |||
| } else { | |||
| @@ -277,11 +277,11 @@ void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch | |||
| uint64_t PSContext::client_batch_size() const { return client_batch_size_; } | |||
| void PSContext::set_worker_overwrite_weights(uint64_t worker_overwrite_weights) { | |||
| worker_overwrite_weights_ = worker_overwrite_weights; | |||
| void PSContext::set_worker_upload_weights(uint64_t worker_upload_weights) { | |||
| worker_upload_weights_ = worker_upload_weights; | |||
| } | |||
| uint64_t PSContext::worker_overwrite_weights() const { return worker_overwrite_weights_; } | |||
| uint64_t PSContext::worker_upload_weights() const { return worker_upload_weights_; } | |||
| void PSContext::set_secure_aggregation(bool secure_aggregation) { secure_aggregation_ = secure_aggregation; } | |||
| @@ -132,8 +132,8 @@ class PSContext { | |||
| uint64_t client_batch_size() const; | |||
| // Set true if worker will overwrite weights on server. Used in hybrid training. | |||
| void set_worker_overwrite_weights(uint64_t worker_overwrite_weights); | |||
| uint64_t worker_overwrite_weights() const; | |||
| void set_worker_upload_weights(uint64_t worker_upload_weights); | |||
| uint64_t worker_upload_weights() const; | |||
| // Set true if using secure aggregation for federated learning. | |||
| void set_secure_aggregation(bool secure_aggregation); | |||
| @@ -149,7 +149,19 @@ class PSContext { | |||
| worker_num_(0), | |||
| server_num_(0), | |||
| scheduler_host_(""), | |||
| scheduler_port_(0) {} | |||
| scheduler_port_(0), | |||
| role_(kEnvRoleOfNotPS), | |||
| server_mode_(""), | |||
| resetter_round_(ResetterRound::kNoNeedToReset), | |||
| fl_server_port_(0), | |||
| fl_client_enable_(false), | |||
| fl_name_(""), | |||
| start_fl_job_threshold_(0), | |||
| fl_iteration_num_(0), | |||
| client_epoch_num_(0), | |||
| client_batch_size_(0), | |||
| secure_aggregation_(false), | |||
| worker_upload_weights_(false) {} | |||
| bool ps_enabled_; | |||
| bool is_worker_; | |||
| bool is_pserver_; | |||
| @@ -160,22 +172,42 @@ class PSContext { | |||
| std::string scheduler_host_; | |||
| uint16_t scheduler_port_; | |||
| // The server process's role. | |||
| std::string role_; | |||
| // Members for federated learning. | |||
| // Server mode which could be Parameter Server, Federated Learning and Hybrid Training mode. | |||
| std::string server_mode_; | |||
| // The round which will reset the iteration. Used in federated learning for now. | |||
| ResetterRound resetter_round_; | |||
| // Http port of federated learning server. | |||
| uint16_t fl_server_port_; | |||
| // Whether this process is the federated client. Used in cross-silo scenario of federated learning. | |||
| bool fl_client_enable_; | |||
| // Federated learning job name. | |||
| std::string fl_name_; | |||
| // The threshold count of startFLJob round. Used in federated learning for now. | |||
| size_t start_fl_job_threshold_; | |||
| // Iteration number of federeated learning, which is the number of interactions between client and server. | |||
| uint64_t fl_iteration_num_; | |||
| // Client training epoch number. Used in federated learning for now. | |||
| uint64_t client_epoch_num_; | |||
| // Client training data batch size. Used in federated learning for now. | |||
| uint64_t client_batch_size_; | |||
| bool worker_overwrite_weights_; | |||
| // Federated learning security. | |||
| // Whether to use secure aggregation algorithm. Used in federated learning for now. | |||
| bool secure_aggregation_; | |||
| // Whether there's a federated learning worker uploading weights to federated learning server. Used in hybrid training | |||
| // mode for now. | |||
| bool worker_upload_weights_; | |||
| }; | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -58,6 +58,12 @@ class AggregationKernel : public CPUKernel { | |||
| virtual bool IsAggregationDone() = 0; | |||
| // Some kernels should know the inputs/workspace/outputs addresses at initializing phase. For example, FedAvgKernel. | |||
| virtual void SetParameterAddress(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| return; | |||
| } | |||
| // Setter and getter of kernels parameters information. | |||
| void set_params_info(const ParamsInfo ¶ms_info) { params_info_ = params_info; } | |||
| const std::vector<std::string> &input_names() { return params_info_.inputs_names(); } | |||
| @@ -44,6 +44,7 @@ class FedAvgKernel : public AggregationKernel { | |||
| public: | |||
| FedAvgKernel() : participated_(false) {} | |||
| ~FedAvgKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| @@ -97,6 +98,7 @@ class FedAvgKernel : public AggregationKernel { | |||
| DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler, last_cnt_handler}); | |||
| return; | |||
| } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override { | |||
| std::unique_lock<std::mutex> lock(weight_mutex_); | |||
| @@ -125,14 +127,25 @@ class FedAvgKernel : public AggregationKernel { | |||
| name_, std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_)); | |||
| return true; | |||
| } | |||
| void Reset() { | |||
| void Reset() override { | |||
| accum_count_ = 0; | |||
| done_ = false; | |||
| participated_ = false; | |||
| DistributedCountService::GetInstance().ResetCounter(name_); | |||
| return; | |||
| } | |||
| bool IsAggregationDone() { return done_; } | |||
| bool IsAggregationDone() override { return done_; } | |||
| void SetParameterAddress(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| weight_addr_ = inputs[0]; | |||
| data_size_addr_ = inputs[1]; | |||
| new_weight_addr_ = inputs[2]; | |||
| new_data_size_addr_ = inputs[3]; | |||
| return; | |||
| } | |||
| private: | |||
| void GenerateReuseKernelNodeInfo() override { | |||
| @@ -96,7 +96,7 @@ void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHand | |||
| MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is " | |||
| << total_data_size; | |||
| FinishIterCb(); | |||
| FinishIteration(); | |||
| } | |||
| } | |||
| @@ -29,6 +29,8 @@ void MemoryRegister::StoreFloatArray(std::unique_ptr<float[]> *array) { float_ar | |||
| void MemoryRegister::StoreInt32Array(std::unique_ptr<int[]> *array) { int32_arrays_.push_back(std::move(*array)); } | |||
| void MemoryRegister::StoreUint64Array(std::unique_ptr<size_t[]> *array) { uint64_arrays_.push_back(std::move(*array)); } | |||
| void MemoryRegister::StoreCharArray(std::unique_ptr<char[]> *array) { char_arrays_.push_back(std::move(*array)); } | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -58,6 +58,9 @@ class MemoryRegister { | |||
| } else if (typeid(T) == typeid(size_t)) { | |||
| auto uint64_arr = CastUniquePtr<size_t, T>(array); | |||
| StoreUint64Array(&uint64_arr); | |||
| } else if (typeid(T) == typeid(char)) { | |||
| auto char_arr = CastUniquePtr<char, T>(array); | |||
| StoreCharArray(&char_arr); | |||
| } else { | |||
| MS_LOG(ERROR) << "MemoryRegister does not support type " << typeid(T).name(); | |||
| return; | |||
| @@ -72,10 +75,12 @@ class MemoryRegister { | |||
| std::vector<std::unique_ptr<float[]>> float_arrays_; | |||
| std::vector<std::unique_ptr<int[]>> int32_arrays_; | |||
| std::vector<std::unique_ptr<size_t[]>> uint64_arrays_; | |||
| std::vector<std::unique_ptr<char[]>> char_arrays_; | |||
| void StoreInt32Array(std::unique_ptr<int[]> *array); | |||
| void StoreFloatArray(std::unique_ptr<float[]> *array); | |||
| void StoreUint64Array(std::unique_ptr<size_t[]> *array); | |||
| void StoreCharArray(std::unique_ptr<char[]> *array); | |||
| template <typename T, typename S> | |||
| std::unique_ptr<T[]> CastUniquePtr(std::unique_ptr<S[]> *array) { | |||
| @@ -68,7 +68,7 @@ bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::strin | |||
| auto &stored_model = memory_register->addresses(); | |||
| for (const auto &weight : new_model) { | |||
| const std::string &weight_name = weight.first; | |||
| if (stored_model.count(weight_name) != 0) { | |||
| if (stored_model.count(weight_name) == 0) { | |||
| MS_LOG(ERROR) << "The stored model has no weight " << weight_name; | |||
| continue; | |||
| } | |||
| @@ -183,10 +183,10 @@ bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) { | |||
| } | |||
| bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) { | |||
| // if (PSContext::instance()->server_mode() == kServerModeFL) { | |||
| // MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel."; | |||
| // return false; | |||
| // } | |||
| if (PSContext::instance()->server_mode() == kServerModeFL) { | |||
| MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel."; | |||
| return false; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| const std::string &name = AnfAlgo::GetCNodeName(cnode); | |||
| auto optimizer_kernel = kernel::OptimizerKernelFactory::GetInstance().Create(name, cnode); | |||
| @@ -275,6 +275,7 @@ bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr< | |||
| std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs), | |||
| [&](const std::string &name) { return memory_register->addresses()[name]; }); | |||
| aggr_kernel->SetParameterAddress(aggr_params.inputs, aggr_params.workspace, aggr_params.outputs); | |||
| aggregation_kernel_parameters_.push_back(std::make_pair(aggr_kernel, aggr_params)); | |||
| return true; | |||
| } | |||
| @@ -302,7 +303,16 @@ bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<ke | |||
| } | |||
| std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &cnode) { | |||
| std::vector<std::string> aggregation_algorithm = {kFedAvg}; | |||
| std::vector<std::string> aggregation_algorithm = {}; | |||
| if (PSContext::instance()->server_mode() == kServerModeFL || | |||
| PSContext::instance()->server_mode() == kServerModeHybrid) { | |||
| aggregation_algorithm.push_back("FedAvg"); | |||
| } else if (PSContext::instance()->server_mode() == kServerModePS) { | |||
| aggregation_algorithm.push_back("DenseGradAccum"); | |||
| } else { | |||
| MS_LOG(ERROR) << "Server doesn't support mode " << PSContext::instance()->server_mode(); | |||
| } | |||
| MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm; | |||
| return aggregation_algorithm; | |||
| } | |||