diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 44db02baa9..60ca74824d 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -624,7 +624,6 @@ bool StartServerAction(const ResourcePtr &res) { FuncGraphPtr func_graph = res->func_graph(); const std::string &server_mode_ = ps::PSContext::instance()->server_mode(); size_t worker_num = ps::PSContext::instance()->initial_worker_num(); - size_t server_num = ps::PSContext::instance()->initial_server_num(); uint64_t fl_server_port = ps::PSContext::instance()->fl_server_port(); // Update model threshold is a certain ratio of start_fl_job threshold. @@ -633,17 +632,9 @@ bool StartServerAction(const ResourcePtr &res) { float percent_for_update_model = 1; size_t update_model_threshold = static_cast(std::ceil(start_fl_job_threshold * percent_for_update_model)); - std::vector rounds_config = { - {"startFLJob", false, 3000, false, start_fl_job_threshold}, - {"updateModel", false, 3000, false, update_model_threshold}, - {"getModel", false, 3000}, - {"asyncUpdateModel"}, - {"asyncGetModel"}, - {"push", false, 3000, true, worker_num}, - {"pull", false, 3000, true, worker_num}, - {"getWeightsByKey", false, 3000, true, 1}, - {"overwriteWeightsByKey", false, 3000, true, server_num}, - }; + std::vector rounds_config = {{"startFLJob", false, 3000, true, start_fl_job_threshold}, + {"updateModel", false, 3000, true, update_model_threshold}, + {"getModel", false, 3000}}; size_t executor_threshold = 0; if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) { diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc index 8b77ff8b6c..937bf64baf 100644 --- a/mindspore/ccsrc/ps/ps_context.cc +++ b/mindspore/ccsrc/ps/ps_context.cc @@ -235,7 +235,7 @@ void PSContext::GenerateResetterRound() { } binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) | - 
(is_mixed_training_mode << 2) | (secure_aggregation_ << 3) | (worker_overwrite_weights_ << 4); + (is_mixed_training_mode << 2) | (secure_aggregation_ << 3) | (worker_upload_weights_ << 4); if (kServerContextToResetRoundMap.count(binary_server_context) == 0) { resetter_round_ = ResetterRound::kNoNeedToReset; } else { @@ -277,11 +277,11 @@ void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch uint64_t PSContext::client_batch_size() const { return client_batch_size_; } -void PSContext::set_worker_overwrite_weights(uint64_t worker_overwrite_weights) { - worker_overwrite_weights_ = worker_overwrite_weights; +void PSContext::set_worker_upload_weights(uint64_t worker_upload_weights) { + worker_upload_weights_ = worker_upload_weights; } -uint64_t PSContext::worker_overwrite_weights() const { return worker_overwrite_weights_; } +uint64_t PSContext::worker_upload_weights() const { return worker_upload_weights_; } void PSContext::set_secure_aggregation(bool secure_aggregation) { secure_aggregation_ = secure_aggregation; } diff --git a/mindspore/ccsrc/ps/ps_context.h b/mindspore/ccsrc/ps/ps_context.h index 68e45db5db..f1199c1d30 100644 --- a/mindspore/ccsrc/ps/ps_context.h +++ b/mindspore/ccsrc/ps/ps_context.h @@ -132,8 +132,8 @@ class PSContext { uint64_t client_batch_size() const; // Set true if worker will overwrite weights on server. Used in hybrid training. - void set_worker_overwrite_weights(uint64_t worker_overwrite_weights); - uint64_t worker_overwrite_weights() const; + void set_worker_upload_weights(uint64_t worker_upload_weights); + uint64_t worker_upload_weights() const; // Set true if using secure aggregation for federated learning. 
void set_secure_aggregation(bool secure_aggregation); @@ -149,7 +149,19 @@ class PSContext { worker_num_(0), server_num_(0), scheduler_host_(""), - scheduler_port_(0) {} + scheduler_port_(0), + role_(kEnvRoleOfNotPS), + server_mode_(""), + resetter_round_(ResetterRound::kNoNeedToReset), + fl_server_port_(0), + fl_client_enable_(false), + fl_name_(""), + start_fl_job_threshold_(0), + fl_iteration_num_(0), + client_epoch_num_(0), + client_batch_size_(0), + secure_aggregation_(false), + worker_upload_weights_(false) {} bool ps_enabled_; bool is_worker_; bool is_pserver_; @@ -160,22 +172,42 @@ class PSContext { std::string scheduler_host_; uint16_t scheduler_port_; + // The server process's role. std::string role_; - // Members for federated learning. + // Server mode which could be Parameter Server, Federated Learning and Hybrid Training mode. std::string server_mode_; + + // The round which will reset the iteration. Used in federated learning for now. ResetterRound resetter_round_; + + // Http port of federated learning server. uint16_t fl_server_port_; + + // Whether this process is the federated client. Used in cross-silo scenario of federated learning. bool fl_client_enable_; + + // Federated learning job name. std::string fl_name_; + + // The threshold count of startFLJob round. Used in federated learning for now. size_t start_fl_job_threshold_; + + // Iteration number of federated learning, which is the number of interactions between client and server. uint64_t fl_iteration_num_; + + // Client training epoch number. Used in federated learning for now. uint64_t client_epoch_num_; + + // Client training data batch size. Used in federated learning for now. uint64_t client_batch_size_; - bool worker_overwrite_weights_; - // Federated learning security. + // Whether to use secure aggregation algorithm. Used in federated learning for now. bool secure_aggregation_; + + // Whether there's a federated learning worker uploading weights to federated learning server. 
Used in hybrid training + // mode for now. + bool worker_upload_weights_; }; } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/kernel/aggregation_kernel.h b/mindspore/ccsrc/ps/server/kernel/aggregation_kernel.h index 5c44afd200..2a76a87ba5 100644 --- a/mindspore/ccsrc/ps/server/kernel/aggregation_kernel.h +++ b/mindspore/ccsrc/ps/server/kernel/aggregation_kernel.h @@ -58,6 +58,12 @@ class AggregationKernel : public CPUKernel { virtual bool IsAggregationDone() = 0; + // Some kernels should know the inputs/workspace/outputs addresses at initializing phase. For example, FedAvgKernel. + virtual void SetParameterAddress(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + return; + } + // Setter and getter of kernels parameters information. void set_params_info(const ParamsInfo &params_info) { params_info_ = params_info; } const std::vector &input_names() { return params_info_.inputs_names(); } diff --git a/mindspore/ccsrc/ps/server/kernel/fed_avg_kernel.h b/mindspore/ccsrc/ps/server/kernel/fed_avg_kernel.h index 2a313affc6..c8a53dd80a 100644 --- a/mindspore/ccsrc/ps/server/kernel/fed_avg_kernel.h +++ b/mindspore/ccsrc/ps/server/kernel/fed_avg_kernel.h @@ -44,6 +44,7 @@ class FedAvgKernel : public AggregationKernel { public: FedAvgKernel() : participated_(false) {} ~FedAvgKernel() override = default; + void InitKernel(const CNodePtr &kernel_node) override { MS_EXCEPTION_IF_NULL(kernel_node); std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node); @@ -97,6 +98,7 @@ class FedAvgKernel : public AggregationKernel { DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler, last_cnt_handler}); return; } + bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override { std::unique_lock lock(weight_mutex_); @@ -125,14 +127,25 @@ class FedAvgKernel : public AggregationKernel { name_, 
std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_)); return true; } - void Reset() { + + void Reset() override { accum_count_ = 0; done_ = false; participated_ = false; DistributedCountService::GetInstance().ResetCounter(name_); return; } - bool IsAggregationDone() { return done_; } + + bool IsAggregationDone() override { return done_; } + + void SetParameterAddress(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + weight_addr_ = inputs[0]; + data_size_addr_ = inputs[1]; + new_weight_addr_ = inputs[2]; + new_data_size_addr_ = inputs[3]; + return; + } private: void GenerateReuseKernelNodeInfo() override { diff --git a/mindspore/ccsrc/ps/server/kernel/round/update_model_kernel.cc b/mindspore/ccsrc/ps/server/kernel/round/update_model_kernel.cc index 766d8afe5f..ab745701e5 100644 --- a/mindspore/ccsrc/ps/server/kernel/round/update_model_kernel.cc +++ b/mindspore/ccsrc/ps/server/kernel/round/update_model_kernel.cc @@ -96,7 +96,7 @@ void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr *array) { float_ar void MemoryRegister::StoreInt32Array(std::unique_ptr *array) { int32_arrays_.push_back(std::move(*array)); } void MemoryRegister::StoreUint64Array(std::unique_ptr *array) { uint64_arrays_.push_back(std::move(*array)); } + +void MemoryRegister::StoreCharArray(std::unique_ptr *array) { char_arrays_.push_back(std::move(*array)); } } // namespace server } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/memory_register.h b/mindspore/ccsrc/ps/server/memory_register.h index 161de5c7ba..1b48c341d7 100644 --- a/mindspore/ccsrc/ps/server/memory_register.h +++ b/mindspore/ccsrc/ps/server/memory_register.h @@ -58,6 +58,9 @@ class MemoryRegister { } else if (typeid(T) == typeid(size_t)) { auto uint64_arr = CastUniquePtr(array); StoreUint64Array(&uint64_arr); + } else if (typeid(T) == typeid(char)) { + auto char_arr = CastUniquePtr(array); + 
StoreCharArray(&char_arr); } else { MS_LOG(ERROR) << "MemoryRegister does not support type " << typeid(T).name(); return; @@ -72,10 +75,12 @@ class MemoryRegister { std::vector> float_arrays_; std::vector> int32_arrays_; std::vector> uint64_arrays_; + std::vector> char_arrays_; void StoreInt32Array(std::unique_ptr *array); void StoreFloatArray(std::unique_ptr *array); void StoreUint64Array(std::unique_ptr *array); + void StoreCharArray(std::unique_ptr *array); template std::unique_ptr CastUniquePtr(std::unique_ptr *array) { diff --git a/mindspore/ccsrc/ps/server/model_store.cc b/mindspore/ccsrc/ps/server/model_store.cc index eb78495263..e8eac5a4d3 100644 --- a/mindspore/ccsrc/ps/server/model_store.cc +++ b/mindspore/ccsrc/ps/server/model_store.cc @@ -68,7 +68,7 @@ bool ModelStore::StoreModelByIterNum(size_t iteration, const std::mapaddresses(); for (const auto &weight : new_model) { const std::string &weight_name = weight.first; - if (stored_model.count(weight_name) != 0) { + if (stored_model.count(weight_name) == 0) { MS_LOG(ERROR) << "The stored model has no weight " << weight_name; continue; } diff --git a/mindspore/ccsrc/ps/server/parameter_aggregator.cc b/mindspore/ccsrc/ps/server/parameter_aggregator.cc index a683d68f8a..e595624b58 100644 --- a/mindspore/ccsrc/ps/server/parameter_aggregator.cc +++ b/mindspore/ccsrc/ps/server/parameter_aggregator.cc @@ -183,10 +183,10 @@ bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) { } bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) { - // if (PSContext::instance()->server_mode() == kServerModeFL) { - // MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel."; - // return false; - // } + if (PSContext::instance()->server_mode() == kServerModeFL) { + MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel."; + return false; + } MS_EXCEPTION_IF_NULL(cnode); const std::string &name = AnfAlgo::GetCNodeName(cnode); auto optimizer_kernel = 
kernel::OptimizerKernelFactory::GetInstance().Create(name, cnode); @@ -275,6 +275,7 @@ bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr< std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs), [&](const std::string &name) { return memory_register->addresses()[name]; }); + aggr_kernel->SetParameterAddress(aggr_params.inputs, aggr_params.workspace, aggr_params.outputs); aggregation_kernel_parameters_.push_back(std::make_pair(aggr_kernel, aggr_params)); return true; } @@ -302,7 +303,16 @@ bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &cnode) { - std::vector aggregation_algorithm = {kFedAvg}; + std::vector aggregation_algorithm = {}; + if (PSContext::instance()->server_mode() == kServerModeFL || + PSContext::instance()->server_mode() == kServerModeHybrid) { + aggregation_algorithm.push_back("FedAvg"); + } else if (PSContext::instance()->server_mode() == kServerModePS) { + aggregation_algorithm.push_back("DenseGradAccum"); + } else { + MS_LOG(ERROR) << "Server doesn't support mode " << PSContext::instance()->server_mode(); + } + MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm; return aggregation_algorithm; }