| @@ -624,7 +624,6 @@ bool StartServerAction(const ResourcePtr &res) { | |||||
| FuncGraphPtr func_graph = res->func_graph(); | FuncGraphPtr func_graph = res->func_graph(); | ||||
| const std::string &server_mode_ = ps::PSContext::instance()->server_mode(); | const std::string &server_mode_ = ps::PSContext::instance()->server_mode(); | ||||
| size_t worker_num = ps::PSContext::instance()->initial_worker_num(); | size_t worker_num = ps::PSContext::instance()->initial_worker_num(); | ||||
| size_t server_num = ps::PSContext::instance()->initial_server_num(); | |||||
| uint64_t fl_server_port = ps::PSContext::instance()->fl_server_port(); | uint64_t fl_server_port = ps::PSContext::instance()->fl_server_port(); | ||||
| // Update model threshold is a certain ratio of start_fl_job threshold. | // Update model threshold is a certain ratio of start_fl_job threshold. | ||||
| @@ -633,17 +632,9 @@ bool StartServerAction(const ResourcePtr &res) { | |||||
| float percent_for_update_model = 1; | float percent_for_update_model = 1; | ||||
| size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * percent_for_update_model)); | size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * percent_for_update_model)); | ||||
| std::vector<ps::server::RoundConfig> rounds_config = { | |||||
| {"startFLJob", false, 3000, false, start_fl_job_threshold}, | |||||
| {"updateModel", false, 3000, false, update_model_threshold}, | |||||
| {"getModel", false, 3000}, | |||||
| {"asyncUpdateModel"}, | |||||
| {"asyncGetModel"}, | |||||
| {"push", false, 3000, true, worker_num}, | |||||
| {"pull", false, 3000, true, worker_num}, | |||||
| {"getWeightsByKey", false, 3000, true, 1}, | |||||
| {"overwriteWeightsByKey", false, 3000, true, server_num}, | |||||
| }; | |||||
| std::vector<ps::server::RoundConfig> rounds_config = {{"startFLJob", false, 3000, true, start_fl_job_threshold}, | |||||
| {"updateModel", false, 3000, true, update_model_threshold}, | |||||
| {"getModel", false, 3000}}; | |||||
| size_t executor_threshold = 0; | size_t executor_threshold = 0; | ||||
| if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) { | if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) { | ||||
| @@ -235,7 +235,7 @@ void PSContext::GenerateResetterRound() { | |||||
| } | } | ||||
| binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) | | binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) | | ||||
| (is_mixed_training_mode << 2) | (secure_aggregation_ << 3) | (worker_overwrite_weights_ << 4); | |||||
| (is_mixed_training_mode << 2) | (secure_aggregation_ << 3) | (worker_upload_weights_ << 4); | |||||
| if (kServerContextToResetRoundMap.count(binary_server_context) == 0) { | if (kServerContextToResetRoundMap.count(binary_server_context) == 0) { | ||||
| resetter_round_ = ResetterRound::kNoNeedToReset; | resetter_round_ = ResetterRound::kNoNeedToReset; | ||||
| } else { | } else { | ||||
| @@ -277,11 +277,11 @@ void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch | |||||
| uint64_t PSContext::client_batch_size() const { return client_batch_size_; } | uint64_t PSContext::client_batch_size() const { return client_batch_size_; } | ||||
| void PSContext::set_worker_overwrite_weights(uint64_t worker_overwrite_weights) { | |||||
| worker_overwrite_weights_ = worker_overwrite_weights; | |||||
| void PSContext::set_worker_upload_weights(uint64_t worker_upload_weights) { | |||||
| worker_upload_weights_ = worker_upload_weights; | |||||
| } | } | ||||
| uint64_t PSContext::worker_overwrite_weights() const { return worker_overwrite_weights_; } | |||||
| uint64_t PSContext::worker_upload_weights() const { return worker_upload_weights_; } | |||||
| void PSContext::set_secure_aggregation(bool secure_aggregation) { secure_aggregation_ = secure_aggregation; } | void PSContext::set_secure_aggregation(bool secure_aggregation) { secure_aggregation_ = secure_aggregation; } | ||||
| @@ -132,8 +132,8 @@ class PSContext { | |||||
| uint64_t client_batch_size() const; | uint64_t client_batch_size() const; | ||||
| // Set true if worker will overwrite weights on server. Used in hybrid training. | // Set true if worker will overwrite weights on server. Used in hybrid training. | ||||
| void set_worker_overwrite_weights(uint64_t worker_overwrite_weights); | |||||
| uint64_t worker_overwrite_weights() const; | |||||
| void set_worker_upload_weights(uint64_t worker_upload_weights); | |||||
| uint64_t worker_upload_weights() const; | |||||
| // Set true if using secure aggregation for federated learning. | // Set true if using secure aggregation for federated learning. | ||||
| void set_secure_aggregation(bool secure_aggregation); | void set_secure_aggregation(bool secure_aggregation); | ||||
| @@ -149,7 +149,19 @@ class PSContext { | |||||
| worker_num_(0), | worker_num_(0), | ||||
| server_num_(0), | server_num_(0), | ||||
| scheduler_host_(""), | scheduler_host_(""), | ||||
| scheduler_port_(0) {} | |||||
| scheduler_port_(0), | |||||
| role_(kEnvRoleOfNotPS), | |||||
| server_mode_(""), | |||||
| resetter_round_(ResetterRound::kNoNeedToReset), | |||||
| fl_server_port_(0), | |||||
| fl_client_enable_(false), | |||||
| fl_name_(""), | |||||
| start_fl_job_threshold_(0), | |||||
| fl_iteration_num_(0), | |||||
| client_epoch_num_(0), | |||||
| client_batch_size_(0), | |||||
| secure_aggregation_(false), | |||||
| worker_upload_weights_(false) {} | |||||
| bool ps_enabled_; | bool ps_enabled_; | ||||
| bool is_worker_; | bool is_worker_; | ||||
| bool is_pserver_; | bool is_pserver_; | ||||
| @@ -160,22 +172,42 @@ class PSContext { | |||||
| std::string scheduler_host_; | std::string scheduler_host_; | ||||
| uint16_t scheduler_port_; | uint16_t scheduler_port_; | ||||
| // The server process's role. | |||||
| std::string role_; | std::string role_; | ||||
| // Members for federated learning. | |||||
| // Server mode which could be Parameter Server, Federated Learning and Hybrid Training mode. | |||||
| std::string server_mode_; | std::string server_mode_; | ||||
| // The round which will reset the iteration. Used in federated learning for now. | |||||
| ResetterRound resetter_round_; | ResetterRound resetter_round_; | ||||
| // Http port of federated learning server. | |||||
| uint16_t fl_server_port_; | uint16_t fl_server_port_; | ||||
| // Whether this process is the federated client. Used in cross-silo scenario of federated learning. | |||||
| bool fl_client_enable_; | bool fl_client_enable_; | ||||
| // Federated learning job name. | |||||
| std::string fl_name_; | std::string fl_name_; | ||||
| // The threshold count of startFLJob round. Used in federated learning for now. | |||||
| size_t start_fl_job_threshold_; | size_t start_fl_job_threshold_; | ||||
| // Iteration number of federeated learning, which is the number of interactions between client and server. | |||||
| uint64_t fl_iteration_num_; | uint64_t fl_iteration_num_; | ||||
| // Client training epoch number. Used in federated learning for now. | |||||
| uint64_t client_epoch_num_; | uint64_t client_epoch_num_; | ||||
| // Client training data batch size. Used in federated learning for now. | |||||
| uint64_t client_batch_size_; | uint64_t client_batch_size_; | ||||
| bool worker_overwrite_weights_; | |||||
| // Federated learning security. | |||||
| // Whether to use secure aggregation algorithm. Used in federated learning for now. | |||||
| bool secure_aggregation_; | bool secure_aggregation_; | ||||
| // Whether there's a federated learning worker uploading weights to federated learning server. Used in hybrid training | |||||
| // mode for now. | |||||
| bool worker_upload_weights_; | |||||
| }; | }; | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -58,6 +58,12 @@ class AggregationKernel : public CPUKernel { | |||||
| virtual bool IsAggregationDone() = 0; | virtual bool IsAggregationDone() = 0; | ||||
| // Some kernels should know the inputs/workspace/outputs addresses at initializing phase. For example, FedAvgKernel. | |||||
| virtual void SetParameterAddress(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs) { | |||||
| return; | |||||
| } | |||||
| // Setter and getter of kernels parameters information. | // Setter and getter of kernels parameters information. | ||||
| void set_params_info(const ParamsInfo ¶ms_info) { params_info_ = params_info; } | void set_params_info(const ParamsInfo ¶ms_info) { params_info_ = params_info; } | ||||
| const std::vector<std::string> &input_names() { return params_info_.inputs_names(); } | const std::vector<std::string> &input_names() { return params_info_.inputs_names(); } | ||||
| @@ -44,6 +44,7 @@ class FedAvgKernel : public AggregationKernel { | |||||
| public: | public: | ||||
| FedAvgKernel() : participated_(false) {} | FedAvgKernel() : participated_(false) {} | ||||
| ~FedAvgKernel() override = default; | ~FedAvgKernel() override = default; | ||||
| void InitKernel(const CNodePtr &kernel_node) override { | void InitKernel(const CNodePtr &kernel_node) override { | ||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node); | std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node); | ||||
| @@ -97,6 +98,7 @@ class FedAvgKernel : public AggregationKernel { | |||||
| DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler, last_cnt_handler}); | DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler, last_cnt_handler}); | ||||
| return; | return; | ||||
| } | } | ||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | ||||
| const std::vector<AddressPtr> &outputs) override { | const std::vector<AddressPtr> &outputs) override { | ||||
| std::unique_lock<std::mutex> lock(weight_mutex_); | std::unique_lock<std::mutex> lock(weight_mutex_); | ||||
| @@ -125,14 +127,25 @@ class FedAvgKernel : public AggregationKernel { | |||||
| name_, std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_)); | name_, std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_)); | ||||
| return true; | return true; | ||||
| } | } | ||||
| void Reset() { | |||||
| void Reset() override { | |||||
| accum_count_ = 0; | accum_count_ = 0; | ||||
| done_ = false; | done_ = false; | ||||
| participated_ = false; | participated_ = false; | ||||
| DistributedCountService::GetInstance().ResetCounter(name_); | DistributedCountService::GetInstance().ResetCounter(name_); | ||||
| return; | return; | ||||
| } | } | ||||
| bool IsAggregationDone() { return done_; } | |||||
| bool IsAggregationDone() override { return done_; } | |||||
| void SetParameterAddress(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs) { | |||||
| weight_addr_ = inputs[0]; | |||||
| data_size_addr_ = inputs[1]; | |||||
| new_weight_addr_ = inputs[2]; | |||||
| new_data_size_addr_ = inputs[3]; | |||||
| return; | |||||
| } | |||||
| private: | private: | ||||
| void GenerateReuseKernelNodeInfo() override { | void GenerateReuseKernelNodeInfo() override { | ||||
| @@ -96,7 +96,7 @@ void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHand | |||||
| MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is " | MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is " | ||||
| << total_data_size; | << total_data_size; | ||||
| FinishIterCb(); | |||||
| FinishIteration(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -29,6 +29,8 @@ void MemoryRegister::StoreFloatArray(std::unique_ptr<float[]> *array) { float_ar | |||||
| void MemoryRegister::StoreInt32Array(std::unique_ptr<int[]> *array) { int32_arrays_.push_back(std::move(*array)); } | void MemoryRegister::StoreInt32Array(std::unique_ptr<int[]> *array) { int32_arrays_.push_back(std::move(*array)); } | ||||
| void MemoryRegister::StoreUint64Array(std::unique_ptr<size_t[]> *array) { uint64_arrays_.push_back(std::move(*array)); } | void MemoryRegister::StoreUint64Array(std::unique_ptr<size_t[]> *array) { uint64_arrays_.push_back(std::move(*array)); } | ||||
| void MemoryRegister::StoreCharArray(std::unique_ptr<char[]> *array) { char_arrays_.push_back(std::move(*array)); } | |||||
| } // namespace server | } // namespace server | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -58,6 +58,9 @@ class MemoryRegister { | |||||
| } else if (typeid(T) == typeid(size_t)) { | } else if (typeid(T) == typeid(size_t)) { | ||||
| auto uint64_arr = CastUniquePtr<size_t, T>(array); | auto uint64_arr = CastUniquePtr<size_t, T>(array); | ||||
| StoreUint64Array(&uint64_arr); | StoreUint64Array(&uint64_arr); | ||||
| } else if (typeid(T) == typeid(char)) { | |||||
| auto char_arr = CastUniquePtr<char, T>(array); | |||||
| StoreCharArray(&char_arr); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "MemoryRegister does not support type " << typeid(T).name(); | MS_LOG(ERROR) << "MemoryRegister does not support type " << typeid(T).name(); | ||||
| return; | return; | ||||
| @@ -72,10 +75,12 @@ class MemoryRegister { | |||||
| std::vector<std::unique_ptr<float[]>> float_arrays_; | std::vector<std::unique_ptr<float[]>> float_arrays_; | ||||
| std::vector<std::unique_ptr<int[]>> int32_arrays_; | std::vector<std::unique_ptr<int[]>> int32_arrays_; | ||||
| std::vector<std::unique_ptr<size_t[]>> uint64_arrays_; | std::vector<std::unique_ptr<size_t[]>> uint64_arrays_; | ||||
| std::vector<std::unique_ptr<char[]>> char_arrays_; | |||||
| void StoreInt32Array(std::unique_ptr<int[]> *array); | void StoreInt32Array(std::unique_ptr<int[]> *array); | ||||
| void StoreFloatArray(std::unique_ptr<float[]> *array); | void StoreFloatArray(std::unique_ptr<float[]> *array); | ||||
| void StoreUint64Array(std::unique_ptr<size_t[]> *array); | void StoreUint64Array(std::unique_ptr<size_t[]> *array); | ||||
| void StoreCharArray(std::unique_ptr<char[]> *array); | |||||
| template <typename T, typename S> | template <typename T, typename S> | ||||
| std::unique_ptr<T[]> CastUniquePtr(std::unique_ptr<S[]> *array) { | std::unique_ptr<T[]> CastUniquePtr(std::unique_ptr<S[]> *array) { | ||||
| @@ -68,7 +68,7 @@ bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::strin | |||||
| auto &stored_model = memory_register->addresses(); | auto &stored_model = memory_register->addresses(); | ||||
| for (const auto &weight : new_model) { | for (const auto &weight : new_model) { | ||||
| const std::string &weight_name = weight.first; | const std::string &weight_name = weight.first; | ||||
| if (stored_model.count(weight_name) != 0) { | |||||
| if (stored_model.count(weight_name) == 0) { | |||||
| MS_LOG(ERROR) << "The stored model has no weight " << weight_name; | MS_LOG(ERROR) << "The stored model has no weight " << weight_name; | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -183,10 +183,10 @@ bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) { | |||||
| } | } | ||||
| bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) { | bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) { | ||||
| // if (PSContext::instance()->server_mode() == kServerModeFL) { | |||||
| // MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel."; | |||||
| // return false; | |||||
| // } | |||||
| if (PSContext::instance()->server_mode() == kServerModeFL) { | |||||
| MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel."; | |||||
| return false; | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| const std::string &name = AnfAlgo::GetCNodeName(cnode); | const std::string &name = AnfAlgo::GetCNodeName(cnode); | ||||
| auto optimizer_kernel = kernel::OptimizerKernelFactory::GetInstance().Create(name, cnode); | auto optimizer_kernel = kernel::OptimizerKernelFactory::GetInstance().Create(name, cnode); | ||||
| @@ -275,6 +275,7 @@ bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr< | |||||
| std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs), | std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs), | ||||
| [&](const std::string &name) { return memory_register->addresses()[name]; }); | [&](const std::string &name) { return memory_register->addresses()[name]; }); | ||||
| aggr_kernel->SetParameterAddress(aggr_params.inputs, aggr_params.workspace, aggr_params.outputs); | |||||
| aggregation_kernel_parameters_.push_back(std::make_pair(aggr_kernel, aggr_params)); | aggregation_kernel_parameters_.push_back(std::make_pair(aggr_kernel, aggr_params)); | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -302,7 +303,16 @@ bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<ke | |||||
| } | } | ||||
| std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &cnode) { | std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &cnode) { | ||||
| std::vector<std::string> aggregation_algorithm = {kFedAvg}; | |||||
| std::vector<std::string> aggregation_algorithm = {}; | |||||
| if (PSContext::instance()->server_mode() == kServerModeFL || | |||||
| PSContext::instance()->server_mode() == kServerModeHybrid) { | |||||
| aggregation_algorithm.push_back("FedAvg"); | |||||
| } else if (PSContext::instance()->server_mode() == kServerModePS) { | |||||
| aggregation_algorithm.push_back("DenseGradAccum"); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Server doesn't support mode " << PSContext::instance()->server_mode(); | |||||
| } | |||||
| MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm; | MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm; | ||||
| return aggregation_algorithm; | return aggregation_algorithm; | ||||
| } | } | ||||