From: @zpac
Reviewed-by: @cristoval
Signed-off-by: pull/16035/MERGE
@@ -47,6 +47,7 @@
 #include "ps/parameter_server.h"
 #include "ps/scheduler.h"
 #include "ps/worker.h"
+#include "ps/server/server.h"
 #endif

 namespace mindspore {

@@ -619,6 +620,47 @@ bool StartPSServerAction(const ResourcePtr &res) {
   return true;
 }
+
+bool StartServerAction(const ResourcePtr &res) {
+  FuncGraphPtr func_graph = res->func_graph();
+  const std::string &server_mode_ = ps::PSContext::instance()->server_mode();
+  size_t worker_num = ps::PSContext::instance()->initial_worker_num();
+  size_t server_num = ps::PSContext::instance()->initial_server_num();
+  uint64_t fl_server_port = ps::PSContext::instance()->fl_server_port();
+
+  // Update model threshold is a certain ratio of start_fl_job threshold.
+  // update_model_threshold_ = start_fl_job_threshold_ * percent_for_update_model_.
+  size_t start_fl_job_threshold = ps::PSContext::instance()->start_fl_job_threshold();
+  float percent_for_update_model = 1;
+  size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * percent_for_update_model));
+  std::vector<ps::server::RoundConfig> rounds_config = {
+    {"startFLJob", false, 3000, false, start_fl_job_threshold},
+    {"updateModel", false, 3000, false, update_model_threshold},
+    {"getModel", false, 3000},
+    {"asyncUpdateModel"},
+    {"asyncGetModel"},
+    {"push", false, 3000, true, worker_num},
+    {"pull", false, 3000, true, worker_num},
+    {"getWeightsByKey", false, 3000, true, 1},
+    {"overwriteWeightsByKey", false, 3000, true, server_num},
+  };
+
+  size_t executor_threshold = 0;
+  if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
+    executor_threshold = update_model_threshold;
+    ps::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, func_graph,
+                                                 executor_threshold);
+  } else if (server_mode_ == ps::kServerModePS) {
+    executor_threshold = worker_num;
+    ps::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, func_graph, executor_threshold);
+  } else {
+    MS_LOG(EXCEPTION) << "Server mode " << server_mode_ << " is not supported.";
+    return false;
+  }
+  ps::server::Server::GetInstance().Run();
+  return true;
+}
+
 bool StartPSSchedulerAction(const ResourcePtr &res) {
   ps::Scheduler::GetInstance().Run();
   return true;
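Review note: each brace-initialized entry in rounds_config above follows the field order of the RoundConfig struct introduced later in this patch; omitted trailing fields keep the struct's defaults. A minimal sketch (standalone, hypothetical values):

    // {name, check_timeout, time_window(ms), check_count, threshold_count}
    ps::server::RoundConfig get_model = {"getModel", false, 3000};  // count checking stays disabled (false, 0)
    ps::server::RoundConfig async_upd = {"asyncUpdateModel"};       // everything except the name is defaulted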
@@ -797,6 +839,14 @@ std::vector<ActionItem> VmPipeline() {
 }

 #if (ENABLE_CPU && !_WIN32)
+std::vector<ActionItem> ServerPipeline() {
+  auto actions = CommonPipeline();
+  actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
+  actions.emplace_back(std::make_pair("validate", ValidateAction));
+  actions.emplace_back(std::make_pair("server", StartServerAction));
+  return actions;
+}
+
 std::vector<ActionItem> PServerPipeline() {
   auto actions = CommonPipeline();
   actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
@@ -43,10 +43,14 @@ bool ExecuteAction(const ResourcePtr &res);
 bool StartPSWorkerAction(const ResourcePtr &res);
 bool StartPSServerAction(const ResourcePtr &res);
 bool StartPSSchedulerAction(const ResourcePtr &res);
+// This action is for federated learning only. In a later version, parameter server mode and federated learning
+// will use the same action.
+bool StartServerAction(const ResourcePtr &res);

 std::vector<ActionItem> GePipeline();
 std::vector<ActionItem> VmPipeline();
 std::vector<ActionItem> PServerPipeline();
+std::vector<ActionItem> ServerPipeline();
 std::vector<ActionItem> PSchedulerPipeline();
 abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
                                          const abstract::AbstractBasePtrList &args_spec, bool clear = false);
@@ -326,7 +326,24 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("insert_accumu_init_info", &PSContext::InsertAccumuInitInfo, "Insert accumulation initialization value.")
     .def("clone_hash_table", &PSContext::CloneHashTable, "Clone a hash table.")
     .def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not.")
-    .def("set_rank_id", &PSContext::set_rank_id, "Set rank id for worker on ps mode.");
+    .def("set_rank_id", &PSContext::set_rank_id, "Set rank id for worker on ps mode.")
+    .def("set_server_mode", &PSContext::set_server_mode, "Set server mode.")
+    .def("server_mode", &PSContext::server_mode, "Get server mode.")
+    .def("set_ms_role", &PSContext::set_ms_role, "Set role for this process.")
+    .def("ms_role", &PSContext::ms_role, "Get role for this process.")
+    .def("set_worker_num", &PSContext::set_worker_num, "Set worker number.")
+    .def("set_server_num", &PSContext::set_server_num, "Set server number.")
+    .def("set_scheduler_ip", &PSContext::set_scheduler_ip, "Set scheduler ip.")
+    .def("set_scheduler_port", &PSContext::set_scheduler_port, "Set scheduler port.")
+    .def("set_fl_server_port", &PSContext::set_fl_server_port, "Set federated learning server port.")
+    .def("set_fl_client_enable", &PSContext::set_fl_client_enable, "Set federated learning client.")
+    .def("set_start_fl_job_threshold", &PSContext::set_start_fl_job_threshold, "Set threshold count for start_fl_job.")
+    .def("set_fl_name", &PSContext::set_fl_name, "Set federated learning name.")
+    .def("set_fl_iteration_num", &PSContext::set_fl_iteration_num, "Set federated learning iteration number.")
+    .def("set_client_epoch_num", &PSContext::set_client_epoch_num, "Set federated learning client epoch number.")
+    .def("set_client_batch_size", &PSContext::set_client_batch_size, "Set federated learning client batch size.")
+    .def("set_secure_aggregation", &PSContext::set_secure_aggregation,
+         "Set federated learning client using secure aggregation.");

 (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
   .def(py::init())
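Review note: the bindings above mirror the native PSContext setters one-to-one. A hedged sketch of the equivalent C++ configuration (the include path and every value below are placeholders, not part of this patch):

    #include "ps/ps_context.h"  // assumed header path for PSContext

    void ConfigureFlServerExample() {
      auto ctx = mindspore::ps::PSContext::instance();
      ctx->set_server_mode(mindspore::ps::kServerModeFL);  // must precede set_ms_role
      ctx->set_ms_role(mindspore::ps::kEnvRoleOfServer);
      ctx->set_worker_num(4);
      ctx->set_server_num(2);
      ctx->set_scheduler_ip("127.0.0.1");
      ctx->set_scheduler_port(8113);
      ctx->set_fl_server_port(6666);
      ctx->set_start_fl_job_threshold(1000);
      ctx->set_secure_aggregation(true);
    }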
@@ -55,6 +55,7 @@
 #include "ps/worker.h"
 #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
 #include "ps/ps_cache/ps_cache_manager.h"
+#include "ps/server/server.h"
 #endif
 #if (ENABLE_GE || ENABLE_D)

@@ -529,6 +530,11 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
   std::string backend = MsContext::GetInstance()->backend_policy();

 #if (ENABLE_CPU && !_WIN32)
+  const std::string &server_mode = ps::PSContext::instance()->server_mode();
+  if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) &&
+      ps::PSContext::instance()->is_server()) {
+    return ServerPipeline();
+  }
   if (ps::PSContext::instance()->is_server()) {
     resource->results()[kBackend] = compile::CreateBackend();
     return PServerPipeline();
@@ -50,10 +50,13 @@ if(NOT ENABLE_CPU OR WIN32)
   list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/apply_momentum_kernel.cc")
   list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/aggregation_kernel_factory.cc")
   list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc")
+  list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/fed_avg_kernel.cc")
   list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/optimizer_kernel_factory.cc")
   list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel_factory.cc")
   list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel.cc")
   list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/start_fl_job_kernel.cc")
+  list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/update_model_kernel.cc")
+  list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/get_model_kernel.cc")
   list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/params_info.cc")
   list(REMOVE_ITEM _PS_SRC_FILES "server/consistent_hash_ring.cc")
   list(REMOVE_ITEM _PS_SRC_FILES "server/iteration_timer.cc")

@@ -67,6 +70,7 @@ if(NOT ENABLE_CPU OR WIN32)
   list(REMOVE_ITEM _PS_SRC_FILES "server/iteration.cc")
   list(REMOVE_ITEM _PS_SRC_FILES "server/model_store.cc")
   list(REMOVE_ITEM _PS_SRC_FILES "server/round.cc")
+  list(REMOVE_ITEM _PS_SRC_FILES "server/server.cc")
 endif()

 list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc")
@@ -61,7 +61,12 @@ void PSContext::SetPSEnable(bool enabled) {
   }
 }

-bool PSContext::is_ps_mode() const { return ps_enabled_; }
+bool PSContext::is_ps_mode() const {
+  if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
+    return true;
+  }
+  return ps_enabled_;
+}

 void PSContext::Reset() {
   ps_enabled_ = false;

@@ -77,6 +82,9 @@ void PSContext::Reset() {
 }

 std::string PSContext::ms_role() const {
+  if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
+    return role_;
+  }
   if (is_worker_) {
     return kEnvRoleOfWorker;
   } else if (is_pserver_) {

@@ -88,11 +96,26 @@ std::string PSContext::ms_role() const {
   }
 }

-bool PSContext::is_worker() const { return is_worker_; }
+bool PSContext::is_worker() const {
+  if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
+    return role_ == kRoleOfWorker;
+  }
+  return is_worker_;
+}

-bool PSContext::is_server() const { return is_pserver_; }
+bool PSContext::is_server() const {
+  if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
+    return role_ == kEnvRoleOfServer;
+  }
+  return is_pserver_;
+}

-bool PSContext::is_scheduler() const { return is_sched_; }
+bool PSContext::is_scheduler() const {
+  if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
+    return role_ == kEnvRoleOfScheduler;
+  }
+  return is_sched_;
+}

 uint32_t PSContext::initial_worker_num() { return worker_num_; }
@@ -150,6 +173,94 @@ void PSContext::set_rank_id(int rank_id) const {
 #endif
 }

+void PSContext::set_server_mode(const std::string &server_mode) {
+  if (server_mode != kServerModePS && server_mode != kServerModeFL && server_mode != kServerModeHybrid) {
+    MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
+                      << " or " << kServerModeHybrid;
+    return;
+  }
+  server_mode_ = server_mode;
+}
+
+const std::string &PSContext::server_mode() const { return server_mode_; }
+
+void PSContext::set_ms_role(const std::string &role) {
+  if (server_mode_ != kServerModeFL && server_mode_ != kServerModeHybrid) {
+    MS_LOG(EXCEPTION) << "Only federated learning supports setting the role via the ps context.";
+    return;
+  }
+  if (role != kEnvRoleOfWorker && role != kEnvRoleOfServer && role != kEnvRoleOfScheduler) {
+    MS_LOG(EXCEPTION) << "ms_role " << role << " is invalid.";
+    return;
+  }
+  role_ = role;
+}
+
+void PSContext::set_worker_num(uint32_t worker_num) { worker_num_ = worker_num; }
+uint32_t PSContext::worker_num() const { return worker_num_; }
+
+void PSContext::set_server_num(uint32_t server_num) {
+  if (server_num == 0) {
+    MS_LOG(EXCEPTION) << "Server number must be greater than 0.";
+    return;
+  }
+  server_num_ = server_num;
+}
+uint32_t PSContext::server_num() const { return server_num_; }
+
+void PSContext::set_scheduler_ip(const std::string &sched_ip) { scheduler_host_ = sched_ip; }
+std::string PSContext::scheduler_ip() const { return scheduler_host_; }
+
+void PSContext::set_scheduler_port(uint16_t sched_port) { scheduler_port_ = sched_port; }
+uint16_t PSContext::scheduler_port() const { return scheduler_port_; }
+
+void PSContext::GenerateResetterRound() {
+  uint32_t binary_server_context = 0;
+  bool is_parameter_server_mode = false;
+  bool is_federated_learning_mode = false;
+  bool is_mixed_training_mode = false;
+  if (server_mode_ == kServerModePS) {
+    is_parameter_server_mode = true;
+  } else if (server_mode_ == kServerModeFL) {
+    is_federated_learning_mode = true;
+  } else if (server_mode_ == kServerModeHybrid) {
+    is_mixed_training_mode = true;
+  } else {
+    MS_LOG(EXCEPTION) << server_mode_ << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
+                      << " or " << kServerModeHybrid;
+    return;
+  }
+
+  binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) |
+                          (is_mixed_training_mode << 2) | (secure_aggregation_ << 3) | (worker_overwrite_weights_ << 4);
+  if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
+    resetter_round_ = ResetterRound::kNoNeedToReset;
+  } else {
+    resetter_round_ = kServerContextToResetRoundMap.at(binary_server_context);
+  }
+  MS_LOG(INFO) << "Server context is " << binary_server_context << ". Resetter round is " << resetter_round_;
+  return;
+}
+
+ResetterRound PSContext::resetter_round() const { return resetter_round_; }
+
+void PSContext::set_fl_server_port(uint16_t fl_server_port) { fl_server_port_ = fl_server_port; }
+uint16_t PSContext::fl_server_port() const { return fl_server_port_; }
+
+void PSContext::set_fl_client_enable(bool enabled) { fl_client_enable_ = enabled; }
+bool PSContext::fl_client_enable() { return fl_client_enable_; }
+
+void PSContext::set_start_fl_job_threshold(size_t start_fl_job_threshold) {
+  start_fl_job_threshold_ = start_fl_job_threshold;
+}
+size_t PSContext::start_fl_job_threshold() const { return start_fl_job_threshold_; }
+
 void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }
 const std::string &PSContext::fl_name() const { return fl_name_; }
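Review note: a worked example of the context encoding computed by GenerateResetterRound, using the bit legend documented in ps_context.h below (sketch only, values hypothetical):

    // Hybrid training with secure aggregation and worker-overwrite enabled.
    // Bits (LSB first): PS mode, FL mode, hybrid mode, secure aggregation, worker overwrite.
    uint32_t binary_server_context = (0 << 0) | (0 << 1) | (1 << 2) | (1 << 3) | (1 << 4);  // 0b11100 == 28
    // kServerContextToResetRoundMap maps 0b11100 to ResetterRound::kWorkerOverwriteWeights.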
@@ -165,5 +276,15 @@ uint64_t PSContext::client_epoch_num() const { return client_epoch_num_; }

 void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch_size_ = client_batch_size; }
 uint64_t PSContext::client_batch_size() const { return client_batch_size_; }

+void PSContext::set_worker_overwrite_weights(uint64_t worker_overwrite_weights) {
+  worker_overwrite_weights_ = worker_overwrite_weights;
+}
+uint64_t PSContext::worker_overwrite_weights() const { return worker_overwrite_weights_; }
+
+void PSContext::set_secure_aggregation(bool secure_aggregation) { secure_aggregation_ = secure_aggregation; }
+bool PSContext::secure_aggregation() const { return secure_aggregation_; }
+
 }  // namespace ps
 }  // namespace mindspore
@@ -17,6 +17,7 @@
 #ifndef MINDSPORE_CCSRC_PS_CONTEXT_H_
 #define MINDSPORE_CCSRC_PS_CONTEXT_H_

+#include <map>
 #include <string>
 #include <memory>
 #include "ps/constants.h"

@@ -24,12 +25,32 @@
 namespace mindspore {
 namespace ps {

+constexpr char kServerModePS[] = "PARAMETER_SERVER";
+constexpr char kServerModeFL[] = "FEDERATED_LEARNING";
+constexpr char kServerModeHybrid[] = "HYBRID_TRAINING";
 constexpr char kEnvRole[] = "MS_ROLE";
 constexpr char kEnvRoleOfPServer[] = "MS_PSERVER";
+constexpr char kEnvRoleOfServer[] = "MS_SERVER";
 constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
 constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
 constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";

+// Use binary data to represent the federated learning server's context so that we can judge which round resets the
+// iteration. From right to left, each bit stands for:
+// 0: Server is in parameter server mode.
+// 1: Server is in federated learning mode.
+// 2: Server is in mixed training mode.
+// 3: Server enables secure aggregation.
+// 4: Server needs workers to overwrite weights.
+// For example: 01010 stands for the server being in federated learning mode with secure aggregation enabled.
+enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kWorkerOverwriteWeights };
+const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = {
+  {0b00010, ResetterRound::kUpdateModel},
+  {0b01010, ResetterRound::kReconstructSeccrets},
+  {0b11100, ResetterRound::kWorkerOverwriteWeights},
+  {0b10100, ResetterRound::kWorkerOverwriteWeights},
+  {0b00100, ResetterRound::kUpdateModel}};
+
 class PSContext {
  public:
   ~PSContext() = default;
@@ -60,19 +81,64 @@ class PSContext {
   void set_cache_enable(bool cache_enable) const;
   void set_rank_id(int rank_id) const;

+  // Setters and getters for federated learning.
+  // In the new server framework, the process role, worker number, server number, scheduler ip and scheduler port
+  // should be set via ps_context.
+  void set_server_mode(const std::string &server_mode);
+  const std::string &server_mode() const;
+
+  void set_ms_role(const std::string &role);
+
+  void set_worker_num(uint32_t worker_num);
+  uint32_t worker_num() const;
+
+  void set_server_num(uint32_t server_num);
+  uint32_t server_num() const;
+
+  void set_scheduler_ip(const std::string &sched_ip);
+  std::string scheduler_ip() const;
+
+  void set_scheduler_port(uint16_t sched_port);
+  uint16_t scheduler_port() const;
+
+  // Methods for federated learning.
+  // Generate which round should reset the iteration.
+  void GenerateResetterRound();
+  ResetterRound resetter_round() const;
+
+  void set_fl_server_port(uint16_t fl_server_port);
+  uint16_t fl_server_port() const;
+
+  // Set true if this process is a federated learning worker in the cross-silo scenario.
+  void set_fl_client_enable(bool enabled);
+  bool fl_client_enable();
+
+  void set_start_fl_job_threshold(size_t start_fl_job_threshold);
+  size_t start_fl_job_threshold() const;
+
   void set_fl_name(const std::string &fl_name);
   const std::string &fl_name() const;

+  // Set the iteration number of federated learning.
   void set_fl_iteration_num(uint64_t fl_iteration_num);
   uint64_t fl_iteration_num() const;

+  // Set the training epoch number of the client.
   void set_client_epoch_num(uint64_t client_epoch_num);
   uint64_t client_epoch_num() const;

+  // Set the data batch size of the client.
   void set_client_batch_size(uint64_t client_batch_size);
   uint64_t client_batch_size() const;

+  // Set true if the worker will overwrite weights on the server. Used in hybrid training.
+  void set_worker_overwrite_weights(uint64_t worker_overwrite_weights);
+  uint64_t worker_overwrite_weights() const;
+
+  // Set true if secure aggregation is used for federated learning.
+  void set_secure_aggregation(bool secure_aggregation);
+  bool secure_aggregation() const;
+
  private:
   PSContext()
     : ps_enabled_(false),
@@ -94,11 +160,22 @@ class PSContext {
   std::string scheduler_host_;
   uint16_t scheduler_port_;
+  std::string role_;

   // Members for federated learning.
+  std::string server_mode_;
+  ResetterRound resetter_round_;
+  uint16_t fl_server_port_;
+  bool fl_client_enable_;
   std::string fl_name_;
+  size_t start_fl_job_threshold_;
   uint64_t fl_iteration_num_;
   uint64_t client_epoch_num_;
   uint64_t client_batch_size_;
+  bool worker_overwrite_weights_;
+
+  // Federated learning security.
+  bool secure_aggregation_;
 };
 }  // namespace ps
 }  // namespace mindspore
@@ -20,6 +20,9 @@ namespace mindspore {
 namespace ps {
 void Scheduler::Run() {
   MS_LOG(INFO) << "Start scheduler.";
+  core::ClusterMetadata::instance()->Init(
+    PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(),
+    PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port());
   scheduler_node_.Start();
   scheduler_node_.Finish();
   scheduler_node_.Stop();
@@ -44,6 +44,14 @@ enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER };
 enum CommType { HTTP = 0, TCP };
 enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum };

+// Configuration for one server round kernel: the round name, whether its timer is checked, the time window in
+// milliseconds, whether a message count is checked, and the count threshold.
+struct RoundConfig {
+  std::string name;
+  bool check_timeout = false;
+  size_t time_window = 3000;
+  bool check_count = false;
+  size_t threshold_count = 0;
+};
+
 using mindspore::kernel::Address;
 using mindspore::kernel::AddressPtr;
 using mindspore::kernel::CPUKernel;
@@ -73,6 +81,7 @@ using ReuseKernelNodeInfo = std::map<std::string, size_t>;
 using UploadData = std::map<std::string, Address>;

 constexpr auto kWeight = "weight";
+constexpr auto kNewWeight = "new_weight";
 constexpr auto kAccumulation = "accum";
 constexpr auto kLearningRate = "lr";
 constexpr auto kGradient = "grad";

@@ -87,6 +96,8 @@ constexpr auto kAdamBeta1 = "beta1";
 constexpr auto kAdamBeta2 = "beta2";
 constexpr auto kAdamEps = "eps";
 constexpr auto kFtrlLinear = "linear";
+constexpr auto kDataSize = "data_size";
+constexpr auto kNewDataSize = "new_data_size";

 // OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is
 // launched.

@@ -137,6 +148,7 @@ constexpr size_t kExecutorMaxTaskNum = 32;
 constexpr int kHttpSuccess = 200;
 constexpr auto kPBProtocol = "PB";
 constexpr auto kFBSProtocol = "FBS";
+constexpr auto kFedAvg = "FedAvg";
 constexpr auto kAggregationKernelType = "Aggregation";
 constexpr auto kOptimizerKernelType = "Optimizer";
 constexpr auto kCtxFuncGraph = "FuncGraph";

@@ -145,6 +157,8 @@ constexpr auto kCtxDeviceMetas = "device_metas";
 constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration";
 constexpr auto kCtxUpdateModelClientList = "update_model_client_list";
 constexpr auto kCtxUpdateModelClientNum = "update_model_client_num";
+constexpr auto kCtxUpdateModelThld = "update_model_threshold";
+constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size";

 // This macro returns the current timestamp in milliseconds.
 #define CURRENT_TIME_MILLI \
@@ -112,19 +112,19 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {
     std::unique_lock<std::mutex> lock(mutex_[name]);
     return global_current_count_[name].size() == global_threshold_count_[name];
   } else {
-    CountReachThresholdRequest count_reach_threashold_req;
-    count_reach_threashold_req.set_name(name);
+    CountReachThresholdRequest count_reach_threshold_req;
+    count_reach_threshold_req.set_name(name);

     std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr;
-    if (!communicator_->SendPbRequest(count_reach_threashold_req, counting_server_rank_,
+    if (!communicator_->SendPbRequest(count_reach_threshold_req, counting_server_rank_,
                                       core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) {
       MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name;
       return false;
     }

-    CountReachThresholdResponse count_reach_threashold_rsp;
-    count_reach_threashold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), query_cnt_enough_rsp_msg->size());
-    return count_reach_threashold_rsp.is_enough();
+    CountReachThresholdResponse count_reach_threshold_rsp;
+    count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), query_cnt_enough_rsp_msg->size());
+    return count_reach_threshold_rsp.is_enough();
   }
 }
@@ -200,9 +200,9 @@ void DistributedCountService::HandleCountReachThresholdRequest(const std::shared
     return;
   }

-  CountReachThresholdRequest count_reach_threashold_req;
-  count_reach_threashold_req.ParseFromArray(message->data(), message->len());
-  const std::string &name = count_reach_threashold_req.name();
+  CountReachThresholdRequest count_reach_threshold_req;
+  count_reach_threshold_req.ParseFromArray(message->data(), message->len());
+  const std::string &name = count_reach_threshold_req.name();

   std::unique_lock<std::mutex> lock(mutex_[name]);
   if (global_threshold_count_.count(name) == 0) {

@@ -210,10 +210,10 @@ void DistributedCountService::HandleCountReachThresholdRequest(const std::shared
     return;
   }

-  CountReachThresholdResponse count_reach_threashold_rsp;
-  count_reach_threashold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]);
-  communicator_->SendResponse(count_reach_threashold_rsp.SerializeAsString().data(),
-                              count_reach_threashold_rsp.SerializeAsString().size(), message);
+  CountReachThresholdResponse count_reach_threshold_rsp;
+  count_reach_threshold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]);
+  communicator_->SendResponse(count_reach_threshold_rsp.SerializeAsString().data(),
+                              count_reach_threshold_rsp.SerializeAsString().size(), message);
   return;
 }
@@ -193,7 +193,29 @@ void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<co
 bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const PBMetadata &meta) {
   std::unique_lock<std::mutex> lock(mutex_[name]);
-  metadata_[name] = meta;
+  if (meta.has_device_meta()) {
+    auto &fl_id_to_meta_map = *metadata_[name].mutable_device_metas()->mutable_fl_id_to_meta();
+    auto &fl_id = meta.device_meta().fl_id();
+    auto &device_meta = meta.device_meta();
+    fl_id_to_meta_map[fl_id] = device_meta;
+  } else if (meta.has_fl_id()) {
+    auto client_list = metadata_[name].mutable_client_list();
+    auto &fl_id = meta.fl_id().fl_id();
+    // Check whether the new item already exists.
+    bool add_flag = true;
+    for (int i = 0; i < client_list->fl_id_size(); i++) {
+      if (fl_id == client_list->fl_id(i)) {
+        add_flag = false;
+        break;
+      }
+    }
+    if (add_flag) {
+      client_list->add_fl_id(fl_id);
+    }
+  } else if (meta.has_update_model_threshold()) {
+    auto update_model_threshold = metadata_[name].mutable_update_model_threshold();
+    *update_model_threshold = meta.update_model_threshold();
+  }
   return true;
 }
 }  // namespace server
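Review note: DoUpdateMetadata now merges the incoming PBMetadata into the stored entry instead of overwriting it. A sketch of the client-list deduplication path; it calls DoUpdateMetadata directly for illustration and assumes the standard protobuf-generated setter for the fl_id field (the client id value is hypothetical):

    PBMetadata meta;
    meta.mutable_fl_id()->set_fl_id("client_0");  // setter assumed from protobuf codegen; hypothetical id
    auto &store = DistributedMetadataStore::GetInstance();
    store.DoUpdateMetadata(kCtxUpdateModelClientList, meta);
    store.DoUpdateMetadata(kCtxUpdateModelClientList, meta);
    // The duplicate check above guarantees "client_0" is recorded in the client list only once.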
@@ -23,7 +23,7 @@
 namespace mindspore {
 namespace ps {
 namespace server {
-void Executor::Init(const FuncGraphPtr &func_graph, size_t aggregation_count) {
+void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) {
   MS_EXCEPTION_IF_NULL(func_graph);
   if (aggregation_count == 0) {
     MS_LOG(EXCEPTION) << "Server aggregation count must be greater than 0";

@@ -43,7 +43,7 @@ class Executor {
   // be used for aggregators.
   // As noted in header file parameter_aggregator.h, we create aggregators by trainable parameters, which is the
   // optimizer cnode's input. So we need to initialize server executor using func_graph.
-  void Init(const FuncGraphPtr &func_graph, size_t aggregation_count);
+  void Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count);

   // Called in parameter server training mode to do Push operation.
   // For the same trainable parameter, HandlePush method must be called aggregation_count_ times before it's considered
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ps/server/kernel/fed_avg_kernel.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+namespace kernel {
+REG_AGGREGATION_KERNEL_TWO(FedAvg,
+                           ParamsInfo()
+                             .AddInputNameType(kWeight, kNumberTypeFloat32)
+                             .AddInputNameType(kDataSize, kNumberTypeUInt64)
+                             .AddInputNameType(kNewWeight, kNumberTypeFloat32)
+                             .AddInputNameType(kNewDataSize, kNumberTypeUInt64),
+                           FedAvgKernel, float, size_t)
+}  // namespace kernel
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
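Review note: the registration above pins the input layout that FedAvgKernel::Launch (in the header below) relies on. A sketch of the correspondence, expressed as comments in the same order:

    // inputs[0]: kWeight      (float*,  running weighted sum of client weights)
    // inputs[1]: kDataSize    (size_t*, running total of client data sizes)
    // inputs[2]: kNewWeight   (float*,  one client's upload, pre-scaled by its data size)
    // inputs[3]: kNewDataSize (size_t*, that client's data size)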
@@ -0,0 +1,179 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_
+#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include <functional>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "ps/server/common.h"
+#include "ps/server/collective_ops_impl.h"
+#include "ps/server/distributed_count_service.h"
+#include "ps/server/local_meta_store.h"
+#include "ps/server/kernel/aggregation_kernel.h"
+#include "ps/server/kernel/aggregation_kernel_factory.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+namespace kernel {
+// The implementation of federated averaging. We do a weighted average of the weights. The uploaded weights from
+// FL-clients are already multiplied by their data sizes, so only summation and division are done in this kernel.
+// Note that this kernel is the distributed version of federated averaging, which means each server node in the
+// cluster is involved in the aggregation process, so DistributedCountService and CollectiveOpsImpl are called.
+template <typename T, typename S>
+class FedAvgKernel : public AggregationKernel {
+ public:
+  FedAvgKernel() : participated_(false) {}
+  ~FedAvgKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override {
+    MS_EXCEPTION_IF_NULL(kernel_node);
+    std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node);
+    if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 ||
+        kNameToIdxMap.at(cnode_name).at("inputs").count("weight") == 0) {
+      MS_LOG(EXCEPTION) << "Can't find index info of weight for kernel " << cnode_name;
+      return;
+    }
+    cnode_weight_idx_ = kNameToIdxMap.at(cnode_name).at("inputs").at("weight");
+    std::vector<size_t> weight_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, cnode_weight_idx_);
+    size_t weight_size =
+      std::accumulate(weight_shape.begin(), weight_shape.end(), sizeof(T), std::multiplies<size_t>());
+    size_t new_weight_size = weight_size;
+    input_size_list_.push_back(weight_size);
+    input_size_list_.push_back(sizeof(size_t));
+    input_size_list_.push_back(new_weight_size);
+    input_size_list_.push_back(sizeof(size_t));
+    auto weight_node =
+      AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, cnode_weight_idx_), 0).first;
+    MS_EXCEPTION_IF_NULL(weight_node);
+    name_ = cnode_name + "." + weight_node->fullname_with_scope();
+    MS_LOG(INFO) << "Register counter for " << name_;
+    auto first_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) {
+      std::unique_lock<std::mutex> lock(weight_mutex_);
+      if (!participated_) {
+        ClearWeightAndDataSize();
+      }
+    };
+    auto last_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) {
+      T *weight_addr = reinterpret_cast<T *>(weight_addr_->addr);
+      size_t weight_size = weight_addr_->size;
+      S *data_size_addr = reinterpret_cast<S *>(data_size_addr_->addr);
+      if (!CollectiveOpsImpl::GetInstance().AllReduce<T>(weight_addr, weight_addr, weight_size / sizeof(T))) {
+        MS_LOG(ERROR) << "Federated average allreduce failed.";
+        return;
+      }
+      if (!CollectiveOpsImpl::GetInstance().AllReduce<S>(data_size_addr, data_size_addr, 1)) {
+        MS_LOG(ERROR) << "Federated average allreduce failed.";
+        return;
+      }
+      LocalMetaStore::GetInstance().put_value(kCtxFedAvgTotalDataSize, data_size_addr[0]);
+      for (size_t i = 0; i < weight_size / sizeof(T); i++) {
+        weight_addr[i] /= data_size_addr[0];
+      }
+      done_ = true;
+      DistributedCountService::GetInstance().ResetCounter(name_);
+      return;
+    };
+    DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler, last_cnt_handler});
+    return;
+  }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override {
+    std::unique_lock<std::mutex> lock(weight_mutex_);
+    // The weight and new_weight values are already multiplied by the clients' data sizes, so we don't need to do
+    // the multiplication again.
+    T *weight_addr = reinterpret_cast<T *>(inputs[0]->addr);
+    S *data_size_addr = reinterpret_cast<S *>(inputs[1]->addr);
+    T *new_weight_addr = reinterpret_cast<T *>(inputs[2]->addr);
+    S *new_data_size_addr = reinterpret_cast<S *>(inputs[3]->addr);
+    if (accum_count_ == 0) {
+      ClearWeightAndDataSize();
+    }
+    MS_LOG(DEBUG) << "Iteration: " << LocalMetaStore::GetInstance().curr_iter_num() << " launching FedAvgKernel for "
+                  << name_ << " new data size is " << new_data_size_addr[0] << ", current total data size is "
+                  << data_size_addr[0];
+    for (size_t i = 0; i < inputs[2]->size / sizeof(T); i++) {
+      weight_addr[i] += new_weight_addr[i];
+    }
+    data_size_addr[0] += new_data_size_addr[0];
+    lock.unlock();
+
+    accum_count_++;
+    participated_ = true;
+    DistributedCountService::GetInstance().Count(
+      name_, std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_));
+    return true;
+  }
+
+  void Reset() {
+    accum_count_ = 0;
+    done_ = false;
+    participated_ = false;
+    DistributedCountService::GetInstance().ResetCounter(name_);
+    return;
+  }
+
+  bool IsAggregationDone() { return done_; }
+
+ private:
+  void GenerateReuseKernelNodeInfo() override {
+    // Only the trainable parameter is reused for federated average.
+    reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, cnode_weight_idx_));
+    return;
+  }
+
+  // In some cases, the Launch method is not called and the weights involved in AllReduce should be set to 0.
+  void ClearWeightAndDataSize() {
+    int ret = memset_s(weight_addr_->addr, weight_addr_->size, 0x00, weight_addr_->size);
+    if (ret != 0) {
+      MS_LOG(ERROR) << "memset_s error, errno(" << ret << ")";
+      return;
+    }
+    ret = memset_s(data_size_addr_->addr, data_size_addr_->size, 0x00, data_size_addr_->size);
+    if (ret != 0) {
+      MS_LOG(ERROR) << "memset_s error, errno(" << ret << ")";
+      return;
+    }
+    return;
+  }
+
+  // The trainable parameter index of the kernel node, parsed from the frontend func_graph.
+  size_t cnode_weight_idx_;
+
+  // The address pointers of the inputs.
+  AddressPtr weight_addr_;
+  AddressPtr data_size_addr_;
+  AddressPtr new_weight_addr_;
+  AddressPtr new_data_size_addr_;
+
+  // Whether the kernel's Launch method has been called.
+  bool participated_;
+
+  // The kernel could be called concurrently, so we need a lock to ensure thread safety.
+  std::mutex weight_mutex_;
+};
+}  // namespace kernel
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_
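Review note: the two AllReduce calls followed by the element-wise division implement a weighted federated average. Because each client uploads its weights already multiplied by its local data size n_k, the server only sums and divides:

    \bar{w} = \frac{\sum_{k} n_k w_k}{\sum_{k} n_k}

where the denominator is the allreduced total data size that the kernel stores under kCtxFedAvgTotalDataSize.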
@@ -0,0 +1,125 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ps/server/kernel/round/get_model_kernel.h"
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+#include "ps/server/model_store.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+namespace kernel {
+void GetModelKernel::InitKernel(size_t) {
+  if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
+    iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
+  }
+
+  executor_ = &Executor::GetInstance();
+  MS_EXCEPTION_IF_NULL(executor_);
+  if (!executor_->initialized()) {
+    MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
+    return;
+  }
+}
+
+bool GetModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                            const std::vector<AddressPtr> &outputs) {
+  MS_LOG(INFO) << "Launching GetModelKernel kernel.";
+  void *req_data = inputs[0]->addr;
+  std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
+  if (fbb == nullptr || req_data == nullptr) {
+    MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
+    return false;
+  }
+
+  const schema::RequestGetModel *get_model_req = flatbuffers::GetRoot<schema::RequestGetModel>(req_data);
+  GetModel(get_model_req, fbb);
+  GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
+  return true;
+}
+
+bool GetModelKernel::Reset() {
+  MS_LOG(INFO) << "Get model kernel reset!";
+  StopTimer();
+  return true;
+}
+
+void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, const std::shared_ptr<FBBuilder> &fbb) {
+  std::map<std::string, AddressPtr> feature_maps;
+  size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
+  size_t get_model_iter = static_cast<size_t>(get_model_req->iteration());
+  const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
+  size_t latest_iter_num = iter_to_model.rbegin()->first;
+  if ((current_iter == get_model_iter && latest_iter_num != current_iter) || current_iter == get_model_iter - 1) {
+    std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter);
+    BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps,
+                     std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
+    MS_LOG(WARNING) << reason;
+    return;
+  }
+
+  if (iter_to_model.count(get_model_iter) == 0) {
+    std::string reason = "The iteration of the GetModel request " + std::to_string(get_model_iter) +
+                         " is invalid. Current iteration is " + std::to_string(current_iter);
+    BuildGetModelRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter, feature_maps,
+                     std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
+    MS_LOG(ERROR) << reason;
+    return;
+  }
+
+  feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter);
+  BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED,
+                   "Get model for iteration " + std::to_string(get_model_iter) + " success.", current_iter,
+                   feature_maps, std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
+  return;
+}
+
+void GetModelKernel::BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
+                                      const std::string &reason, const size_t iter,
+                                      const std::map<std::string, AddressPtr> &feature_maps,
+                                      const std::string &timestamp) {
+  auto fbs_reason = fbb->CreateString(reason);
+  auto fbs_timestamp = fbb->CreateString(timestamp);
+  std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
+  for (const auto &feature_map : feature_maps) {
+    auto fbs_weight_fullname = fbb->CreateString(feature_map.first);
+    auto fbs_weight_data =
+      fbb->CreateVector(reinterpret_cast<float *>(feature_map.second->addr), feature_map.second->size / sizeof(float));
+    auto fbs_feature_map = schema::CreateFeatureMap(*(fbb.get()), fbs_weight_fullname, fbs_weight_data);
+    fbs_feature_maps.push_back(fbs_feature_map);
+  }
+  auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);
+
+  schema::ResponseGetModelBuilder rsp_get_model_builder(*(fbb.get()));
+  rsp_get_model_builder.add_retcode(retcode);
+  rsp_get_model_builder.add_reason(fbs_reason);
+  rsp_get_model_builder.add_iteration(static_cast<int>(iter));
+  rsp_get_model_builder.add_feature_map(fbs_feature_maps_vector);
+  rsp_get_model_builder.add_timestamp(fbs_timestamp);
+  auto rsp_get_model = rsp_get_model_builder.Finish();
+  fbb->Finish(rsp_get_model);
+  return;
+}
+
+REG_ROUND_KERNEL(getModel, GetModelKernel)
+}  // namespace kernel
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
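Review note: the readiness check in GetModel rejects requests for a model whose aggregation has not finished. A worked example with hypothetical values:

    // Sketch: hypothetical values for the readiness condition above.
    size_t current_iter = 5;     // server is currently running iteration 5
    size_t get_model_iter = 5;   // client asks for the model of iteration 5
    size_t latest_iter_num = 4;  // newest stored model belongs to iteration 4
    bool not_ready = (current_iter == get_model_iter && latest_iter_num != current_iter) ||
                     current_iter == get_model_iter - 1;
    // not_ready == true: the client receives ResponseCode_SucNotReady and should retry later.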
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_
+#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_
+
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+#include "ps/server/common.h"
+#include "ps/server/executor.h"
+#include "ps/server/kernel/round/round_kernel.h"
+#include "ps/server/kernel/round/round_kernel_factory.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+namespace kernel {
+class GetModelKernel : public RoundKernel {
+ public:
+  GetModelKernel() = default;
+  ~GetModelKernel() override = default;
+
+  void InitKernel(size_t) override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs);
+  bool Reset() override;
+
+ private:
+  void GetModel(const schema::RequestGetModel *get_model_req, const std::shared_ptr<FBBuilder> &fbb);
+  void BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
+                        const std::string &reason, const size_t iter,
+                        const std::map<std::string, AddressPtr> &feature_maps, const std::string &timestamp);
+
+  // The executor is for getting the model for the getModel request.
+  Executor *executor_;
+
+  // The time window of one iteration.
+  size_t iteration_time_window_;
+};
+}  // namespace kernel
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_
@@ -49,7 +49,7 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
     return false;
   }
   void *req_data = inputs[0]->addr;
-  const std::shared_ptr<FBBuilder> &fbb = std::make_shared<FBBuilder>();
+  std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
   if (fbb == nullptr || req_data == nullptr) {
     MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
     return false;
| @@ -0,0 +1,203 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "ps/server/kernel/round/update_model_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| namespace kernel { | |||||
| void UpdateModelKernel::InitKernel(size_t threshold_count) { | |||||
| if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { | |||||
| iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); | |||||
| } | |||||
| executor_ = &Executor::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(executor_); | |||||
| if (!executor_->initialized()) { | |||||
| MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline."; | |||||
| return; | |||||
| } | |||||
| PBMetadata client_list; | |||||
| DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxUpdateModelClientList, client_list); | |||||
| LocalMetaStore::GetInstance().put_value(kCtxUpdateModelThld, threshold_count); | |||||
| } | |||||
| bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs) { | |||||
| if (inputs.size() != 1 || outputs.size() != 1) { | |||||
| MS_LOG(ERROR) << "inputs or outputs size is invalid."; | |||||
| return false; | |||||
| } | |||||
| void *req_data = inputs[0]->addr; | |||||
| std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>(); | |||||
| if (fbb == nullptr || req_data == nullptr) { | |||||
| MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr."; | |||||
| return false; | |||||
| } | |||||
| MS_LOG(INFO) << "Launching UpdateModelKernel kernel."; | |||||
| if (!ReachThresholdForUpdateModel(fbb)) { | |||||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||||
| return false; | |||||
| } | |||||
| const schema::RequestUpdateModel *update_model_req = flatbuffers::GetRoot<schema::RequestUpdateModel>(req_data); | |||||
| if (!UpdateModel(update_model_req, fbb)) { | |||||
| MS_LOG(ERROR) << "Updating model failed."; | |||||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||||
| return false; | |||||
| } | |||||
| if (!CountForUpdateModel(fbb, update_model_req)) { | |||||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||||
| return false; | |||||
| } | |||||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||||
| return true; | |||||
| } | |||||
| bool UpdateModelKernel::Reset() { | |||||
| MS_LOG(INFO) << "Update model kernel reset!"; | |||||
| StopTimer(); | |||||
| DistributedCountService::GetInstance().ResetCounter(name_); | |||||
| executor_->ResetAggregationStatus(); | |||||
| DistributedMetadataStore::GetInstance().ResetMetadata(kCtxUpdateModelClientList); | |||||
| size_t &total_data_size = LocalMetaStore::GetInstance().mutable_value<size_t>(kCtxFedAvgTotalDataSize); | |||||
| total_data_size = 0; | |||||
| return true; | |||||
| } | |||||
| void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) { | |||||
| if (PSContext::instance()->resetter_round() == ResetterRound::kUpdateModel) { | |||||
| while (!executor_->IsAllWeightAggregationDone()) { | |||||
| std::this_thread::sleep_for(std::chrono::milliseconds(5)); | |||||
| } | |||||
| size_t total_data_size = LocalMetaStore::GetInstance().value<size_t>(kCtxFedAvgTotalDataSize); | |||||
| MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is " | |||||
| << total_data_size; | |||||
| FinishIterCb(); | |||||
| } | |||||
| } | |||||
| bool UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb) { | |||||
| if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { | |||||
| std::string reason = "Current amount for updateModel is enough."; | |||||
| BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, | |||||
| std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); | |||||
| MS_LOG(ERROR) << reason; | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req, | |||||
| const std::shared_ptr<FBBuilder> &fbb) { | |||||
| size_t iteration = static_cast<size_t>(update_model_req->iteration()); | |||||
| if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) { | |||||
| std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) + | |||||
| ", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num()); | |||||
| BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, | |||||
| std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); | |||||
| MS_LOG(ERROR) << reason; | |||||
| return false; | |||||
| } | |||||
| PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas); | |||||
| FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas(); | |||||
| std::string update_model_fl_id = update_model_req->fl_id()->str(); | |||||
| if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) { | |||||
| std::string reason = "devices_meta for " + update_model_fl_id + " is not set."; | |||||
| BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, | |||||
| std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); | |||||
| MS_LOG(ERROR) << reason; | |||||
| return false; | |||||
| } | |||||
| size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size(); | |||||
| auto feature_map = ParseFeatureMap(update_model_req); | |||||
| for (auto &weight : feature_map) { | |||||
| weight.second[kNewDataSize].addr = &data_size; | |||||
| weight.second[kNewDataSize].size = sizeof(size_t); | |||||
| executor_->HandleModelUpdate(weight.first, weight.second); | |||||
| } | |||||
| FLId fl_id; | |||||
| fl_id.set_fl_id(update_model_fl_id); | |||||
| PBMetadata comm_value; | |||||
| *comm_value.mutable_fl_id() = fl_id; | |||||
| DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value); | |||||
| BuildUpdateModelRsp(fbb, schema::ResponseCode_SucNotReady, "success not ready", | |||||
| std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); | |||||
| return true; | |||||
| } | |||||
| std::map<std::string, UploadData> UpdateModelKernel::ParseFeatureMap( | |||||
| const schema::RequestUpdateModel *update_model_req) { | |||||
| RETURN_IF_NULL(update_model_req, {}); | |||||
| std::map<std::string, UploadData> feature_map; | |||||
| auto fbs_feature_map = update_model_req->feature_map(); | |||||
| for (size_t i = 0; i < fbs_feature_map->size(); i++) { | |||||
| std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str(); | |||||
| float *weight_data = const_cast<float *>(fbs_feature_map->Get(i)->data()->data()); | |||||
| size_t weight_size = fbs_feature_map->Get(i)->data()->size() * sizeof(float); | |||||
| UploadData upload_data; | |||||
| upload_data[kNewWeight].addr = weight_data; | |||||
| upload_data[kNewWeight].size = weight_size; | |||||
| feature_map[weight_full_name] = upload_data; | |||||
| } | |||||
| return feature_map; | |||||
| } | |||||
| bool UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb, | |||||
| const schema::RequestUpdateModel *update_model_req) { | |||||
| if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str())) { | |||||
| std::string reason = "UpdateModel counting failed."; | |||||
| BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, | |||||
| std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); | |||||
| MS_LOG(ERROR) << reason; | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, | |||||
| const std::string &reason, const std::string &next_req_time) { | |||||
| auto fbs_reason = fbb->CreateString(reason); | |||||
| auto fbs_next_req_time = fbb->CreateString(next_req_time); | |||||
| schema::ResponseUpdateModelBuilder rsp_update_model_builder(*(fbb.get())); | |||||
| rsp_update_model_builder.add_retcode(retcode); | |||||
| rsp_update_model_builder.add_reason(fbs_reason); | |||||
| rsp_update_model_builder.add_next_req_time(fbs_next_req_time); | |||||
| auto rsp_update_model = rsp_update_model_builder.Finish(); | |||||
| fbb->Finish(rsp_update_model); | |||||
| return; | |||||
| } | |||||
| REG_ROUND_KERNEL(updateModel, UpdateModelKernel) | |||||
| } // namespace kernel | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
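Note: UpdateModel attaches the client's data_size to every uploaded weight (kNewDataSize) before handing it to the executor, and the aggregation algorithm defaults to kFedAvg (see the ParameterAggregator change below). A minimal NumPy sketch of that data-size-weighted averaging, assuming clients upload weights already scaled by their data size and the server divides the running sum by the total data size it tracks in kCtxFedAvgTotalDataSize:

    import numpy as np

    def fed_avg(uploads):
        """uploads: list of (scaled_weight, data_size) pairs, one per client."""
        total_data_size = sum(n for _, n in uploads)
        acc = np.zeros_like(uploads[0][0])
        for scaled_weight, _ in uploads:
            acc += scaled_weight  # each upload is already weight * data_size
        return acc / total_data_size

    # Two hypothetical clients with equal data sizes recover the plain average.
    w = np.ones(4)
    print(fed_avg([(w * 32, 32), (3 * w * 32, 32)]))  # [2. 2. 2. 2.]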
| @@ -0,0 +1,64 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "ps/server/common.h" | |||||
| #include "ps/server/kernel/round/round_kernel.h" | |||||
| #include "ps/server/kernel/round/round_kernel_factory.h" | |||||
| #include "ps/server/executor.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| namespace kernel { | |||||
| class UpdateModelKernel : public RoundKernel { | |||||
| public: | |||||
| UpdateModelKernel() = default; | |||||
| ~UpdateModelKernel() override = default; | |||||
| void InitKernel(size_t threshold_count) override; | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs); | |||||
| bool Reset() override; | |||||
| // In some cases, the last updateModel message means this server iteration is finished. | |||||
| void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override; | |||||
| private: | |||||
| bool ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb); | |||||
| bool UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr<FBBuilder> &fbb); | |||||
| std::map<std::string, UploadData> ParseFeatureMap(const schema::RequestUpdateModel *update_model_req); | |||||
| bool CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestUpdateModel *update_model_req); | |||||
| void BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, | |||||
| const std::string &reason, const std::string &next_req_time); | |||||
| // The executor is used to update the model for the updateModel request. | |||||
| Executor *executor_; | |||||
| // The time window of one iteration. | |||||
| size_t iteration_time_window_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_ | |||||
| @@ -23,7 +23,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace server { | namespace server { | ||||
| void ModelStore::Init(uint32_t max_count) { | |||||
| void ModelStore::Initialize(uint32_t max_count) { | |||||
| if (!Executor::GetInstance().initialized()) { | if (!Executor::GetInstance().initialized()) { | ||||
| MS_LOG(EXCEPTION) << "Server's executor must be initialized before model storage."; | MS_LOG(EXCEPTION) << "Server's executor must be initialized before model storage."; | ||||
| return; | return; | ||||
| @@ -40,7 +40,7 @@ class ModelStore { | |||||
| } | } | ||||
| // Initialize ModelStore with max count of models need to be stored. | // Initialize ModelStore with max count of models need to be stored. | ||||
| void Init(uint32_t max_count = 3); | |||||
| void Initialize(uint32_t max_count = 3); | |||||
| // Store the model of the given iteration. The model is acquired from Executor. If the current model count is already | // Store the model of the given iteration. The model is acquired from Executor. If the current model count is already | ||||
| // max_model_count_, the earliest model will be replaced. | // max_model_count_, the earliest model will be replaced. | ||||
| @@ -302,7 +302,7 @@ bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<ke | |||||
| } | } | ||||
| std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &cnode) { | std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &cnode) { | ||||
| std::vector<std::string> aggregation_algorithm = {}; | |||||
| std::vector<std::string> aggregation_algorithm = {kFedAvg}; | |||||
| MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm; | MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm; | ||||
| return aggregation_algorithm; | return aggregation_algorithm; | ||||
| } | } | ||||
| @@ -0,0 +1,251 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ps/server/server.h" | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <csignal> | |||||
| #include "ps/server/round.h" | |||||
| #include "ps/server/model_store.h" | |||||
| #include "ps/server/iteration.h" | |||||
| #include "ps/server/collective_ops_impl.h" | |||||
| #include "ps/server/distributed_metadata_store.h" | |||||
| #include "ps/server/distributed_count_service.h" | |||||
| #include "ps/server/kernel/round/round_kernel_factory.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| static std::vector<std::shared_ptr<core::CommunicatorBase>> global_worker_server_comms = {}; | |||||
| // This signal handler stops the communicators so that the server process can exit when an interrupt signal is captured. | |||||
| void SignalHandler(int signal) { | |||||
| MS_LOG(INFO) << "Interrupt signal captured: " << signal; | |||||
| std::for_each(global_worker_server_comms.begin(), global_worker_server_comms.end(), | |||||
| [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); }); | |||||
| return; | |||||
| } | |||||
| void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config, | |||||
| const FuncGraphPtr &func_graph, size_t executor_threshold) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| func_graph_ = func_graph; | |||||
| if (rounds_config.empty()) { | |||||
| MS_LOG(EXCEPTION) << "Rounds are empty."; | |||||
| return; | |||||
| } | |||||
| rounds_config_ = rounds_config; | |||||
| use_tcp_ = use_tcp; | |||||
| use_http_ = use_http; | |||||
| http_port_ = http_port; | |||||
| executor_threshold_ = executor_threshold; | |||||
| return; | |||||
| } | |||||
| // Each step of the server pipeline may depend on other steps, which includes: | |||||
| // InitServerContext must be the first step because it sets contexts for later steps. | |||||
| // Running the server relies on URL or message type registration: | |||||
| // StartCommunicator---->InitIteration | |||||
| // Metadata registration relies on the hash ring of servers, which relies on network building completion: | |||||
| // RegisterRoundKernel---->StartCommunicator | |||||
| // Kernel initialization relies on executor initialization: | |||||
| // RegisterRoundKernel---->InitExecutor | |||||
| // Getting the model size relies on ModelStore initialization, which relies on executor initialization: | |||||
| // InitCipher---->InitExecutor | |||||
| void Server::Run() { | |||||
| signal(SIGINT, SignalHandler); | |||||
| InitServerContext(); | |||||
| InitCluster(); | |||||
| InitIteration(); | |||||
| StartCommunicator(); | |||||
| InitExecutor(); | |||||
| RegisterRoundKernel(); | |||||
| MS_LOG(INFO) << "Server started successfully."; | |||||
| // Wait for the communicators to stop so that the main thread blocks here until the server is finalized. | |||||
| std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), | |||||
| [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Join(); }); | |||||
| communicator_with_server_->Join(); | |||||
| MsException::Instance().CheckException(); | |||||
| return; | |||||
| } | |||||
| void Server::InitServerContext() { | |||||
| PSContext::instance()->GenerateResetterRound(); | |||||
| scheduler_ip_ = PSContext::instance()->scheduler_host(); | |||||
| scheduler_port_ = PSContext::instance()->scheduler_port(); | |||||
| worker_num_ = PSContext::instance()->initial_worker_num(); | |||||
| server_num_ = PSContext::instance()->initial_server_num(); | |||||
| return; | |||||
| } | |||||
| void Server::InitCluster() { | |||||
| server_node_ = std::make_shared<core::ServerNode>(); | |||||
| MS_EXCEPTION_IF_NULL(server_node_); | |||||
| task_executor_ = std::make_shared<core::TaskExecutor>(32); | |||||
| MS_EXCEPTION_IF_NULL(task_executor_); | |||||
| if (!InitCommunicatorWithServer()) { | |||||
| MS_LOG(EXCEPTION) << "Initializing cross-server communicator failed."; | |||||
| return; | |||||
| } | |||||
| if (!InitCommunicatorWithWorker()) { | |||||
| MS_LOG(EXCEPTION) << "Initializing worker-server communicator failed."; | |||||
| return; | |||||
| } | |||||
| global_worker_server_comms = communicators_with_worker_; | |||||
| return; | |||||
| } | |||||
| bool Server::InitCommunicatorWithServer() { | |||||
| MS_EXCEPTION_IF_NULL(task_executor_); | |||||
| MS_EXCEPTION_IF_NULL(server_node_); | |||||
| communicator_with_server_ = | |||||
| server_node_->GetOrCreateTcpComm(scheduler_ip_, scheduler_port_, worker_num_, server_num_, task_executor_); | |||||
| MS_EXCEPTION_IF_NULL(communicator_with_server_); | |||||
| // Set exception event callbacks for server. | |||||
| auto tcp_comm = std::dynamic_pointer_cast<core::TcpCommunicator>(communicator_with_server_); | |||||
| MS_EXCEPTION_IF_NULL(tcp_comm); | |||||
| tcp_comm->RegisterEventCallback(core::CLUSTER_TIMEOUT, [&]() { | |||||
| MS_LOG(ERROR) << "Event CLUSTER_TIMEOUT is captured. This is because some nodes(Scheduler/Server/Worker) are not " | |||||
| "started during network building phase."; | |||||
| std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), | |||||
| [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); }); | |||||
| communicator_with_server_->Stop(); | |||||
| }); | |||||
| tcp_comm->RegisterEventCallback(core::SCHEDULER_TIMEOUT, [&]() { | |||||
| MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed."; | |||||
| std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), | |||||
| [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); }); | |||||
| communicator_with_server_->Stop(); | |||||
| }); | |||||
| tcp_comm->RegisterEventCallback(core::NODE_TIMEOUT, [&]() { | |||||
| MS_LOG(ERROR) | |||||
| << "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the " | |||||
| "network building phase."; | |||||
| std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), | |||||
| [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); }); | |||||
| communicator_with_server_->Stop(); | |||||
| }); | |||||
| return true; | |||||
| } | |||||
| bool Server::InitCommunicatorWithWorker() { | |||||
| MS_EXCEPTION_IF_NULL(server_node_); | |||||
| MS_EXCEPTION_IF_NULL(task_executor_); | |||||
| if (!use_tcp_ && !use_http_) { | |||||
| MS_LOG(EXCEPTION) << "At least one type of protocol should be set."; | |||||
| return false; | |||||
| } | |||||
| if (use_tcp_) { | |||||
| auto tcp_comm = communicator_with_server_; | |||||
| MS_EXCEPTION_IF_NULL(tcp_comm); | |||||
| communicators_with_worker_.push_back(tcp_comm); | |||||
| } | |||||
| if (use_http_) { | |||||
| auto http_comm = server_node_->GetOrCreateHttpComm("0.0.0.0", http_port_, task_executor_); | |||||
| MS_EXCEPTION_IF_NULL(http_comm); | |||||
| communicators_with_worker_.push_back(http_comm); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void Server::InitIteration() { | |||||
| iteration_ = std::make_shared<Iteration>(); | |||||
| MS_EXCEPTION_IF_NULL(iteration_); | |||||
| // 1. Add rounds to the iteration according to the server mode. | |||||
| for (const RoundConfig &config : rounds_config_) { | |||||
| std::shared_ptr<Round> round = std::make_shared<Round>(config.name, config.check_timeout, config.time_window, | |||||
| config.check_count, config.threshold_count); | |||||
| MS_LOG(INFO) << "Add round " << config.name << ", check_count: " << config.check_count | |||||
| << ", threshold:" << config.threshold_count; | |||||
| iteration_->AddRound(round); | |||||
| } | |||||
| // 2. Initialize all the rounds. | |||||
| TimeOutCb time_out_cb = std::bind(&Iteration::ProceedToNextIter, iteration_); | |||||
| FinishIterCb finish_iter_cb = std::bind(&Iteration::ProceedToNextIter, iteration_); | |||||
| iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb); | |||||
| return; | |||||
| } | |||||
| void Server::InitExecutor() { | |||||
| if (executor_threshold_ == 0) { | |||||
| MS_LOG(EXCEPTION) << "The executor's threshold should greater than 0."; | |||||
| return; | |||||
| } | |||||
| // The train engine instance is used in both push-type and pull-type kernels, | |||||
| // so the required_cnt of these kernels must be the same as update_model_threshold_. | |||||
| MS_LOG(INFO) << "Required count for push-type and pull-type kernels is " << executor_threshold_; | |||||
| Executor::GetInstance().Initialize(func_graph_, executor_threshold_); | |||||
| ModelStore::GetInstance().Initialize(); | |||||
| return; | |||||
| } | |||||
| void Server::RegisterRoundKernel() { | |||||
| MS_EXCEPTION_IF_NULL(iteration_); | |||||
| auto &rounds = iteration_->rounds(); | |||||
| if (rounds.empty()) { | |||||
| MS_LOG(EXCEPTION) << "Server has no round registered."; | |||||
| return; | |||||
| } | |||||
| for (auto &round : rounds) { | |||||
| const std::string &name = round->name(); | |||||
| std::shared_ptr<kernel::RoundKernel> round_kernel = kernel::RoundKernelFactory::GetInstance().Create(name); | |||||
| if (round_kernel == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Round kernel for round " << name << " is not registered."; | |||||
| return; | |||||
| } | |||||
| // For some round kernels, the threshold count should be set. | |||||
| round_kernel->InitKernel(round->threshold_count()); | |||||
| round->BindRoundKernel(round_kernel); | |||||
| } | |||||
| return; | |||||
| } | |||||
| void Server::StartCommunicator() { | |||||
| MS_EXCEPTION_IF_NULL(communicator_with_server_); | |||||
| if (communicators_with_worker_.empty()) { | |||||
| MS_LOG(EXCEPTION) << "Communicators for communication with worker is empty."; | |||||
| return; | |||||
| } | |||||
| MS_LOG(INFO) << "Start communicator with server."; | |||||
| communicator_with_server_->Start(); | |||||
| DistributedMetadataStore::GetInstance().Initialize(server_node_); | |||||
| CollectiveOpsImpl::GetInstance().Initialize(server_node_); | |||||
| DistributedCountService::GetInstance().Initialize(server_node_, kLeaderServerRank); | |||||
| MS_LOG(INFO) << "Start communicator with worker."; | |||||
| std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), | |||||
| [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Start(); }); | |||||
| } | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,131 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_SERVER_H_ | |||||
| #define MINDSPORE_CCSRC_PS_SERVER_SERVER_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "ps/core/communicator/communicator_base.h" | |||||
| #include "ps/core/communicator/tcp_communicator.h" | |||||
| #include "ps/core/communicator/task_executor.h" | |||||
| #include "ps/server/common.h" | |||||
| #include "ps/server/executor.h" | |||||
| #include "ps/server/iteration.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| // Class Server is the entry point of MindSpore's parameter server training mode and federated learning. | |||||
| class Server { | |||||
| public: | |||||
| static Server &GetInstance() { | |||||
| static Server instance; | |||||
| return instance; | |||||
| } | |||||
| void Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config, | |||||
| const FuncGraphPtr &func_graph, size_t executor_threshold); | |||||
| // According to the current MindSpore framework, method Run is a step of the server pipeline. This method | |||||
| // blocks until the server is finalized. | |||||
| // func_graph is the frontend graph which will be parsed in the server's executor and aggregator. | |||||
| void Run(); | |||||
| private: | |||||
| Server() | |||||
| : server_node_(nullptr), | |||||
| task_executor_(nullptr), | |||||
| use_tcp_(false), | |||||
| use_http_(false), | |||||
| http_port_(0), | |||||
| func_graph_(nullptr), | |||||
| executor_threshold_(0), | |||||
| communicator_with_server_(nullptr), | |||||
| communicators_with_worker_({}), | |||||
| iteration_(nullptr), | |||||
| scheduler_ip_(""), | |||||
| scheduler_port_(0), | |||||
| server_num_(0), | |||||
| worker_num_(0) {} | |||||
| ~Server() = default; | |||||
| Server(const Server &) = delete; | |||||
| Server &operator=(const Server &) = delete; | |||||
| // Load variables which are set by ps_context. | |||||
| void InitServerContext(); | |||||
| // Initialize the server cluster, server node and communicators. | |||||
| void InitCluster(); | |||||
| bool InitCommunicatorWithServer(); | |||||
| bool InitCommunicatorWithWorker(); | |||||
| // Initialize the iteration with rounds. The rounds to use can also be set by ps_context. | |||||
| void InitIteration(); | |||||
| // Initialize executor according to the server mode. | |||||
| void InitExecutor(); | |||||
| // Create round kernels and bind these kernels with corresponding Round. | |||||
| void RegisterRoundKernel(); | |||||
| // The communicators should be started after all initializations are completed. | |||||
| void StartCommunicator(); | |||||
| // The server node is initialized in Server. | |||||
| std::shared_ptr<core::ServerNode> server_node_; | |||||
| // The task executor of the communicators. This helps the server handle network messages concurrently. The tasks | |||||
| // submitted to this task executor are asynchronous. | |||||
| std::shared_ptr<core::TaskExecutor> task_executor_; | |||||
| // Which protocols the communicators should use. | |||||
| bool use_tcp_; | |||||
| bool use_http_; | |||||
| uint16_t http_port_; | |||||
| // The configuration of all rounds. | |||||
| std::vector<RoundConfig> rounds_config_; | |||||
| // The graph passed by the frontend without backend optimization. | |||||
| FuncGraphPtr func_graph_; | |||||
| // The threshold count for the executor to do aggregation or optimization. | |||||
| size_t executor_threshold_; | |||||
| // The server needs a TCP communicator to communicate with other servers for counting, metadata storing, collective | |||||
| // operations, etc. | |||||
| std::shared_ptr<core::CommunicatorBase> communicator_with_server_; | |||||
| // The communication with workers (including mobile devices) has multiple protocol types: HTTP and TCP. | |||||
| // In some cases, both types should be supported in one distributed training job. So here we may have multiple | |||||
| // communicators. | |||||
| std::vector<std::shared_ptr<core::CommunicatorBase>> communicators_with_worker_; | |||||
| // Iteration consists of multiple kinds of rounds. | |||||
| std::shared_ptr<Iteration> iteration_; | |||||
| // Variables set by ps context. | |||||
| std::string scheduler_ip_; | |||||
| uint16_t scheduler_port_; | |||||
| uint32_t server_num_; | |||||
| uint32_t worker_num_; | |||||
| }; | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PS_SERVER_SERVER_H_ | |||||
| @@ -33,7 +33,21 @@ def ps_context(): | |||||
| return _ps_context | return _ps_context | ||||
| _set_ps_context_func_map = { | _set_ps_context_func_map = { | ||||
| "enable_ps": ps_context().set_ps_enable | |||||
| "server_mode": ps_context().set_server_mode, | |||||
| "ms_role": ps_context().set_ms_role, | |||||
| "enable_ps": ps_context().set_ps_enable, | |||||
| "worker_num": ps_context().set_worker_num, | |||||
| "server_num": ps_context().set_server_num, | |||||
| "scheduler_ip": ps_context().set_scheduler_ip, | |||||
| "scheduler_port": ps_context().set_scheduler_port, | |||||
| "fl_server_port": ps_context().set_fl_server_port, | |||||
| "enable_fl_client": ps_context().set_fl_client_enable, | |||||
| "start_fl_job_threshold": ps_context().set_start_fl_job_threshold, | |||||
| "fl_name": ps_context().set_fl_name, | |||||
| "fl_iteration_num": ps_context().set_fl_iteration_num, | |||||
| "client_epoch_num": ps_context().set_client_epoch_num, | |||||
| "client_batch_size": ps_context().set_client_batch_size, | |||||
| "secure_aggregation": ps_context().set_secure_aggregation | |||||
| } | } | ||||
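Note: the map above is keyed by the public keyword arguments of set_ps_context, so each recognized keyword dispatches straight to its native setter. A minimal sketch of the setter-side dispatch this map implies (the surrounding module code is not part of this diff, so the helper below is an assumption):

    def _set_ps_context(**kwargs):
        for key, value in kwargs.items():
            if key not in _set_ps_context_func_map:
                raise ValueError("Set PS context keyword %s is not recognized!" % key)
            set_func = _set_ps_context_func_map[key]
            set_func(value)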
| _get_ps_context_func_map = { | _get_ps_context_func_map = { | ||||
| @@ -0,0 +1,30 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import argparse | |||||
| import subprocess | |||||
| if __name__ == "__main__": | |||||
| parser = argparse.ArgumentParser(description="Finish test_mobile_lenet.py case") | |||||
| parser.add_argument("--scheduler_port", type=int, default=8113) | |||||
| args, _ = parser.parse_known_args() | |||||
| scheduler_port = args.scheduler_port | |||||
| cmd = "pid=`ps -ef|grep \"scheduler_port=" + str(scheduler_port) + "\" " | |||||
| cmd += " | grep -v \"grep\" | grep -v \"finish\" |awk '{print $2}'` && " | |||||
| cmd += "for id in $pid; do kill -9 $id && echo \"killed $id\"; done" | |||||
| subprocess.call(['bash', '-c', cmd]) | |||||
| @@ -0,0 +1,52 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import argparse | |||||
| import subprocess | |||||
| parser = argparse.ArgumentParser(description="Run test_mobile_lenet.py case") | |||||
| parser.add_argument("--device_target", type=str, default="CPU") | |||||
| parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING") | |||||
| parser.add_argument("--worker_num", type=int, default=0) | |||||
| parser.add_argument("--server_num", type=int, default=2) | |||||
| parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1") | |||||
| parser.add_argument("--scheduler_port", type=int, default=8113) | |||||
| parser.add_argument("--fl_server_port", type=int, default=6666) | |||||
| if __name__ == "__main__": | |||||
| args, _ = parser.parse_known_args() | |||||
| device_target = args.device_target | |||||
| server_mode = args.server_mode | |||||
| worker_num = args.worker_num | |||||
| server_num = args.server_num | |||||
| scheduler_ip = args.scheduler_ip | |||||
| scheduler_port = args.scheduler_port | |||||
| fl_server_port = args.fl_server_port | |||||
| cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&" | |||||
| cmd_sched += "mkdir ${execute_path}/scheduler/ &&" | |||||
| cmd_sched += "cd ${execute_path}/scheduler/ || exit && export GLOG_v=1 &&" | |||||
| cmd_sched += "python ${self_path}/../test_mobile_lenet.py" | |||||
| cmd_sched += " --device_target=" + device_target | |||||
| cmd_sched += " --server_mode=" + server_mode | |||||
| cmd_sched += " --ms_role=MS_SCHED" | |||||
| cmd_sched += " --worker_num=" + str(worker_num) | |||||
| cmd_sched += " --server_num=" + str(server_num) | |||||
| cmd_sched += " --scheduler_ip=" + scheduler_ip | |||||
| cmd_sched += " --scheduler_port=" + str(scheduler_port) | |||||
| cmd_sched += " --fl_server_port=" + str(fl_server_port) | |||||
| cmd_sched += " > scheduler.log 2>&1 &" | |||||
| subprocess.call(['bash', '-c', cmd_sched]) | |||||
| @@ -0,0 +1,82 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import ast | |||||
| import argparse | |||||
| import subprocess | |||||
| import time | |||||
| parser = argparse.ArgumentParser(description="Run test_mobile_lenet.py case") | |||||
| parser.add_argument("--device_target", type=str, default="CPU") | |||||
| parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING") | |||||
| parser.add_argument("--worker_num", type=int, default=0) | |||||
| parser.add_argument("--server_num", type=int, default=2) | |||||
| parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1") | |||||
| parser.add_argument("--scheduler_port", type=int, default=8113) | |||||
| parser.add_argument("--fl_server_port", type=int, default=6666) | |||||
| parser.add_argument("--start_fl_job_threshold", type=int, default=1) | |||||
| parser.add_argument("--fl_name", type=str, default="Lenet") | |||||
| parser.add_argument("--fl_iteration_num", type=int, default=25) | |||||
| parser.add_argument("--client_epoch_num", type=int, default=20) | |||||
| parser.add_argument("--client_batch_size", type=int, default=32) | |||||
| parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False) | |||||
| parser.add_argument("--local_server_num", type=int, default=-1) | |||||
| if __name__ == "__main__": | |||||
| args, _ = parser.parse_known_args() | |||||
| device_target = args.device_target | |||||
| server_mode = args.server_mode | |||||
| worker_num = args.worker_num | |||||
| server_num = args.server_num | |||||
| scheduler_ip = args.scheduler_ip | |||||
| scheduler_port = args.scheduler_port | |||||
| fl_server_port = args.fl_server_port | |||||
| start_fl_job_threshold = args.start_fl_job_threshold | |||||
| fl_name = args.fl_name | |||||
| fl_iteration_num = args.fl_iteration_num | |||||
| client_epoch_num = args.client_epoch_num | |||||
| client_batch_size = args.client_batch_size | |||||
| secure_aggregation = args.secure_aggregation | |||||
| local_server_num = args.local_server_num | |||||
| if local_server_num == -1: | |||||
| local_server_num = server_num | |||||
| assert local_server_num <= server_num, "The local server number should not be greater than the total server number." | |||||
| for i in range(local_server_num): | |||||
| cmd_server = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && " | |||||
| cmd_server += "rm -rf ${execute_path}/server_" + str(i) + "/ &&" | |||||
| cmd_server += "mkdir ${execute_path}/server_" + str(i) + "/ &&" | |||||
| cmd_server += "cd ${execute_path}/server_" + str(i) + "/ || exit && export GLOG_v=1 &&" | |||||
| cmd_server += "python ${self_path}/../test_mobile_lenet.py" | |||||
| cmd_server += " --device_target=" + device_target | |||||
| cmd_server += " --server_mode=" + server_mode | |||||
| cmd_server += " --ms_role=MS_SERVER" | |||||
| cmd_server += " --worker_num=" + str(worker_num) | |||||
| cmd_server += " --server_num=" + str(server_num) | |||||
| cmd_server += " --scheduler_ip=" + scheduler_ip | |||||
| cmd_server += " --scheduler_port=" + str(scheduler_port) | |||||
| cmd_server += " --fl_server_port=" + str(fl_server_port + i) | |||||
| cmd_server += " --start_fl_job_threshold=" + str(start_fl_job_threshold) | |||||
| cmd_server += " --fl_name=" + fl_name | |||||
| cmd_server += " --fl_iteration_num=" + str(fl_iteration_num) | |||||
| cmd_server += " --client_epoch_num=" + str(client_epoch_num) | |||||
| cmd_server += " --client_batch_size=" + str(client_batch_size) | |||||
| cmd_server += " --secure_aggregation=" + str(secure_aggregation) | |||||
| cmd_server += " > server.log 2>&1 &" | |||||
| time.sleep(0.3) | |||||
| subprocess.call(['bash', '-c', cmd_server]) | |||||
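Note: each launched server gets its own working directory (server_i) and its own port fl_server_port + i, so a local fleet occupies a contiguous port range. A tiny sketch of the resulting layout under the defaults above:

    # Hypothetical helper: the ports the launched servers will bind.
    def server_ports(fl_server_port, local_server_num):
        return [fl_server_port + i for i in range(local_server_num)]

    print(server_ports(6666, 2))  # [6666, 6667]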
| @@ -0,0 +1,29 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| export PYTHONPATH=../../../../:$PYTHONPATH | |||||
| server_num=$1 | |||||
| worker_num=$2 | |||||
| ip=$3 | |||||
| port=$4 | |||||
| for((i=0;i<worker_num;i++)); | |||||
| do | |||||
| ofs=`expr $i % $server_num` | |||||
| real_port=`expr $port + $ofs` | |||||
| echo $real_port | |||||
| python simulator.py --pid=$i --http_ip=$ip --http_port=$port --use_elb=True --server_num=$server_num > simulator_$i.log 2>&1 & | |||||
| done | |||||
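Note: real_port above is only echoed for reference; the actual load spreading happens inside simulator.py, whose generate_port (shown in the next file) picks a random port in [http_port, http_port + server_num) per request when --use_elb=True. A quick standalone check of that spread, assuming 2 servers on base port 6666:

    import random
    from collections import Counter

    # Sample the same expression generate_port uses and count hits per port.
    counts = Counter(random.randint(0, 100000) % 2 + 6666 for _ in range(1000))
    print(counts)  # roughly even split between ports 6666 and 6667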
| @@ -0,0 +1,191 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import argparse | |||||
| import time | |||||
| import random | |||||
| import sys | |||||
| import requests | |||||
| import flatbuffers | |||||
| import numpy as np | |||||
| from mindspore.schema import (RequestFLJob, ResponseFLJob, ResponseCode, | |||||
| RequestUpdateModel, FeatureMap, RequestGetModel, ResponseGetModel) | |||||
| parser = argparse.ArgumentParser() | |||||
| parser.add_argument("--pid", type=int, default=0) | |||||
| parser.add_argument("--http_ip", type=str, default="10.113.216.106") | |||||
| parser.add_argument("--http_port", type=int, default=6666) | |||||
| parser.add_argument("--use_elb", type=bool, default=False) | |||||
| parser.add_argument("--server_num", type=int, default=1) | |||||
| args, _ = parser.parse_known_args() | |||||
| pid = args.pid | |||||
| http_ip = args.http_ip | |||||
| http_port = args.http_port | |||||
| use_elb = args.use_elb | |||||
| server_num = args.server_num | |||||
| str_fl_id = 'fl_lenet_' + str(pid) | |||||
| def generate_port(): | |||||
| if not use_elb: | |||||
| return http_port | |||||
| port = random.randint(0, 100000) % server_num + http_port | |||||
| return port | |||||
| def build_start_fl_job(iteration): | |||||
| start_fl_job_builder = flatbuffers.Builder(1024) | |||||
| fl_name = start_fl_job_builder.CreateString('fl_test_job') | |||||
| fl_id = start_fl_job_builder.CreateString(str_fl_id) | |||||
| data_size = 32 | |||||
| timestamp = start_fl_job_builder.CreateString('2020/11/16/19/18') | |||||
| RequestFLJob.RequestFLJobStart(start_fl_job_builder) | |||||
| RequestFLJob.RequestFLJobAddFlName(start_fl_job_builder, fl_name) | |||||
| RequestFLJob.RequestFLJobAddFlId(start_fl_job_builder, fl_id) | |||||
| RequestFLJob.RequestFLJobAddIteration(start_fl_job_builder, iteration) | |||||
| RequestFLJob.RequestFLJobAddDataSize(start_fl_job_builder, data_size) | |||||
| RequestFLJob.RequestFLJobAddTimestamp(start_fl_job_builder, timestamp) | |||||
| fl_job_req = RequestFLJob.RequestFLJobEnd(start_fl_job_builder) | |||||
| start_fl_job_builder.Finish(fl_job_req) | |||||
| buf = start_fl_job_builder.Output() | |||||
| return buf | |||||
| def build_feature_map(builder, names, lengths): | |||||
| if len(names) != len(lengths): | |||||
| return None, None | |||||
| feature_maps = [] | |||||
| np_data = [] | |||||
| for j, _ in enumerate(names): | |||||
| name = names[j] | |||||
| length = lengths[j] | |||||
| weight_full_name = builder.CreateString(name) | |||||
| FeatureMap.FeatureMapStartDataVector(builder, length) | |||||
| weight = np.random.rand(length) * 32 | |||||
| np_data.append(weight) | |||||
| for idx in range(length - 1, -1, -1): | |||||
| builder.PrependFloat32(weight[idx]) | |||||
| data = builder.EndVector(length) | |||||
| FeatureMap.FeatureMapStart(builder) | |||||
| FeatureMap.FeatureMapAddData(builder, data) | |||||
| FeatureMap.FeatureMapAddWeightFullname(builder, weight_full_name) | |||||
| feature_map = FeatureMap.FeatureMapEnd(builder) | |||||
| feature_maps.append(feature_map) | |||||
| return feature_maps, np_data | |||||
| def build_update_model(iteration): | |||||
| builder_update_model = flatbuffers.Builder(1) | |||||
| fl_name = builder_update_model.CreateString('fl_test_job') | |||||
| fl_id = builder_update_model.CreateString(str_fl_id) | |||||
| timestamp = builder_update_model.CreateString('2020/11/16/19/18') | |||||
| feature_maps, np_data = build_feature_map(builder_update_model, | |||||
| ["conv1.weight", "conv2.weight", "fc1.weight", | |||||
| "fc2.weight", "fc3.weight", "fc1.bias", "fc2.bias", "fc3.bias"], | |||||
| [450, 2400, 48000, 10080, 5208, 120, 84, 62]) | |||||
| RequestUpdateModel.RequestUpdateModelStartFeatureMapVector(builder_update_model, len(feature_maps)) | |||||
| for single_feature_map in feature_maps: | |||||
| builder_update_model.PrependUOffsetTRelative(single_feature_map) | |||||
| feature_map = builder_update_model.EndVector(len(feature_maps)) | |||||
| RequestUpdateModel.RequestUpdateModelStart(builder_update_model) | |||||
| RequestUpdateModel.RequestUpdateModelAddFlName(builder_update_model, fl_name) | |||||
| RequestUpdateModel.RequestUpdateModelAddFlId(builder_update_model, fl_id) | |||||
| RequestUpdateModel.RequestUpdateModelAddIteration(builder_update_model, iteration) | |||||
| RequestUpdateModel.RequestUpdateModelAddFeatureMap(builder_update_model, feature_map) | |||||
| RequestUpdateModel.RequestUpdateModelAddTimestamp(builder_update_model, timestamp) | |||||
| req_update_model = RequestUpdateModel.RequestUpdateModelEnd(builder_update_model) | |||||
| builder_update_model.Finish(req_update_model) | |||||
| buf = builder_update_model.Output() | |||||
| return buf, np_data | |||||
| def build_get_model(iteration): | |||||
| builder_get_model = flatbuffers.Builder(1) | |||||
| fl_name = builder_get_model.CreateString('fl_test_job') | |||||
| timestamp = builder_get_model.CreateString('2020/12/16/19/18') | |||||
| RequestGetModel.RequestGetModelStart(builder_get_model) | |||||
| RequestGetModel.RequestGetModelAddFlName(builder_get_model, fl_name) | |||||
| RequestGetModel.RequestGetModelAddIteration(builder_get_model, iteration) | |||||
| RequestGetModel.RequestGetModelAddTimestamp(builder_get_model, timestamp) | |||||
| req_get_model = RequestGetModel.RequestGetModelEnd(builder_get_model) | |||||
| builder_get_model.Finish(req_get_model) | |||||
| buf = builder_get_model.Output() | |||||
| return buf | |||||
| weight_name_to_idx = { | |||||
| "conv1.weight": 0, | |||||
| "conv2.weight": 1, | |||||
| "fc1.weight": 2, | |||||
| "fc2.weight": 3, | |||||
| "fc3.weight": 4, | |||||
| "fc1.bias": 5, | |||||
| "fc2.bias": 6, | |||||
| "fc3.bias": 7 | |||||
| } | |||||
| session = requests.Session() | |||||
| current_iteration = 1 | |||||
| url = "http://" + http_ip + ":" + str(generate_port()) | |||||
| np.random.seed(0) | |||||
| while True: | |||||
| url1 = "http://" + http_ip + ":" + str(generate_port()) + '/startFLJob' | |||||
| print("start url is ", url1) | |||||
| x = requests.post(url1, data=build_start_fl_job(current_iteration)) | |||||
| rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0) | |||||
| print("start fl job iteration:", current_iteration, ", id:", args.pid) | |||||
| while rsp_fl_job.Retcode() != ResponseCode.ResponseCode.SUCCEED: | |||||
| x = requests.post(url1, data=build_start_fl_job(current_iteration)) | |||||
| rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0) | |||||
| print("epoch is", rsp_fl_job.FlPlanConfig().Epochs()) | |||||
| sys.stdout.flush() | |||||
| url2 = "http://" + http_ip + ":" + str(generate_port()) + '/updateModel' | |||||
| print("req update model iteration:", current_iteration, ", id:", args.pid) | |||||
| update_model_buf, update_model_np_data = build_update_model(current_iteration) | |||||
| x = session.post(url2, data=update_model_buf) | |||||
| print("rsp update model iteration:", current_iteration, ", id:", args.pid) | |||||
| sys.stdout.flush() | |||||
| url3 = "http://" + http_ip + ":" + str(generate_port()) + '/getModel' | |||||
| print("req get model iteration:", current_iteration, ", id:", args.pid) | |||||
| x = session.post(url3, data=build_get_model(current_iteration)) | |||||
| rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0) | |||||
| print("rsp get model iteration:", current_iteration, ", id:", args.pid, rsp_get_model.Retcode()) | |||||
| sys.stdout.flush() | |||||
| repeat_time = 0 | |||||
| while rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady: | |||||
| time.sleep(0.1) | |||||
| x = session.post(url3, data=build_get_model(current_iteration)) | |||||
| rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0) | |||||
| repeat_time += 1 | |||||
| if repeat_time > 1000: | |||||
| print("GetModel try timeout ", args.pid) | |||||
| sys.exit(0) | |||||
| for i in range(0, 1): | |||||
| print(rsp_get_model.FeatureMap(i).WeightFullname()) | |||||
| origin = update_model_np_data[weight_name_to_idx[rsp_get_model.FeatureMap(i).WeightFullname().decode('utf-8')]] | |||||
| after = rsp_get_model.FeatureMap(i).DataAsNumpy() * 32 | |||||
| print("Before update model", args.pid, origin[0:10]) | |||||
| print("After get model", args.pid, after[0:10]) | |||||
| sys.stdout.flush() | |||||
| assert np.allclose(origin, after, rtol=1e-05, atol=1e-05) | |||||
| current_iteration += 1 | |||||
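Note: the closing assert checks the full round trip. The client uploads values already scaled by its data size (np.random.rand(length) * 32 with data_size = 32), the server's FedAvg divides the aggregated sum by the total data size (32 here, since this simulator is the only contributor), and the client multiplies the downloaded model by 32 again. A minimal sketch of that identity, assuming the server-side division described in update_model_kernel.cc:

    import numpy as np

    data_size = 32
    origin = np.random.rand(10) * data_size  # what the client uploads
    aggregated = origin / data_size          # server: sum / total_data_size, single client
    after = aggregated * data_size           # client rescales the downloaded model
    assert np.allclose(origin, after, rtol=1e-05, atol=1e-05)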
| @@ -0,0 +1,423 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """AdamWeightDecayForBert, a customized Adam for bert. Input: gradient, overflow flag.""" | |||||
| import numpy as np | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import composite as C | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore.common.tensor import Tensor | |||||
| from mindspore._checkparam import Validator as validator | |||||
| from mindspore._checkparam import Rel | |||||
| from mindspore.nn.optim.optimizer import Optimizer | |||||
| _adam_opt = C.MultitypeFuncGraph("adam_opt") | |||||
| _scaler_one = Tensor(1, mstype.int32) | |||||
| _scaler_ten = Tensor(10, mstype.float32) | |||||
| @_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", | |||||
| "Tensor", "Bool", "Bool") | |||||
| def _update_run_kernel(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter): | |||||
| """ | |||||
| Update parameters by AdamWeightDecay op. | |||||
| """ | |||||
| if optim_filter: | |||||
| adam = P.AdamWeightDecay() | |||||
| if decay_flags: | |||||
| next_param = adam(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient) | |||||
| else: | |||||
| next_param = adam(param, m, v, lr, beta1, beta2, eps, 0.0, gradient) | |||||
| return next_param | |||||
| return gradient | |||||
| @_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", | |||||
| "Tensor", "Bool", "Bool") | |||||
| def _update_run_op(beta1, beta2, eps, lr, overflow, weight_decay, param, m, v, gradient, decay_flag, optim_filter): | |||||
| """ | |||||
| Update parameters. | |||||
| Args: | |||||
| beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0). | |||||
| beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0). | |||||
| eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0. | |||||
| lr (Tensor): Learning rate. | |||||
| overflow (Tensor): Whether overflow occurs. | |||||
| weight_decay (Number): Weight decay. Should be equal to or greater than 0. | |||||
| param (Tensor): Parameters. | |||||
| m (Tensor): m value of parameters. | |||||
| v (Tensor): v value of parameters. | |||||
| gradient (Tensor): Gradient of parameters. | |||||
| decay_flag (bool): Applies weight decay or not. | |||||
| optim_filter (bool): Applies parameter update or not. | |||||
| Returns: | |||||
| Tensor, the new parameter value after updating (the original gradient is returned when optim_filter is False). | |||||
| """ | |||||
| if optim_filter: | |||||
| op_mul = P.Mul() | |||||
| op_square = P.Square() | |||||
| op_sqrt = P.Sqrt() | |||||
| op_cast = P.Cast() | |||||
| op_reshape = P.Reshape() | |||||
| op_shape = P.Shape() | |||||
| op_select = P.Select() | |||||
| param_fp32 = op_cast(param, mstype.float32) | |||||
| m_fp32 = op_cast(m, mstype.float32) | |||||
| v_fp32 = op_cast(v, mstype.float32) | |||||
| gradient_fp32 = op_cast(gradient, mstype.float32) | |||||
| cond = op_cast(F.fill(mstype.int32, op_shape(m_fp32), 1) * op_reshape(overflow, (())), mstype.bool_) | |||||
| next_m = op_mul(beta1, m_fp32) + op_select(cond, m_fp32,\ | |||||
| op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)) | |||||
| next_v = op_mul(beta2, v_fp32) + op_select(cond, v_fp32,\ | |||||
| op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient_fp32))) | |||||
| update = next_m / (eps + op_sqrt(next_v)) | |||||
| if decay_flag: | |||||
| update = op_mul(weight_decay, param_fp32) + update | |||||
| update_with_lr = op_mul(lr, update) | |||||
| zeros = F.fill(mstype.float32, op_shape(param_fp32), 0) | |||||
| next_param = param_fp32 - op_select(cond, zeros, op_reshape(update_with_lr, op_shape(param_fp32))) | |||||
| next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param)))) | |||||
| next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m)))) | |||||
| next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v)))) | |||||
| return op_cast(next_param, F.dtype(param)) | |||||
| return gradient | |||||
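Note: in _update_run_op the overflow flag feeds Select, so an overflowed step applies a zero parameter delta while the moment updates substitute the previous m and v for the gradient terms. A minimal NumPy sketch of one step under the same Select semantics (plain arrays standing in for MindSpore ops; hypothetical helper, not part of this patch):

    import numpy as np

    def adam_w_step(param, m, v, grad, lr, beta1, beta2, eps, weight_decay, overflow):
        # Select semantics: on overflow the gradient terms are replaced by the
        # previous moments and the parameter delta is zeroed.
        m_term = m if overflow else (1.0 - beta1) * grad
        v_term = v if overflow else (1.0 - beta2) * np.square(grad)
        next_m = beta1 * m + m_term
        next_v = beta2 * v + v_term
        update = next_m / (eps + np.sqrt(next_v)) + weight_decay * param  # decay_flag set
        next_param = param if overflow else param - lr * update
        return next_param, next_m, next_v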
| @_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", | |||||
| "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") | |||||
| def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, | |||||
| beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable): | |||||
| """Apply sparse adam optimizer to the weight parameter when the gradient is sparse.""" | |||||
| success = True | |||||
| indices = gradient.indices | |||||
| values = gradient.values | |||||
| if ps_parameter and not cache_enable: | |||||
| op_shape = P.Shape() | |||||
| shapes = (op_shape(param), op_shape(m), op_shape(v), | |||||
| op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1), | |||||
| op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices)) | |||||
| success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, | |||||
| eps, values, indices), shapes), param)) | |||||
| return success | |||||
| if not target: | |||||
| success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2, | |||||
| eps, values, indices)) | |||||
| else: | |||||
| op_mul = P.Mul() | |||||
| op_square = P.Square() | |||||
| op_sqrt = P.Sqrt() | |||||
| scatter_add = P.ScatterAdd(use_locking) | |||||
| assign_m = F.assign(m, op_mul(beta1, m)) | |||||
| assign_v = F.assign(v, op_mul(beta2, v)) | |||||
| grad_indices = gradient.indices | |||||
| grad_value = gradient.values | |||||
| next_m = scatter_add(m, | |||||
| grad_indices, | |||||
| op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value)) | |||||
| next_v = scatter_add(v, | |||||
| grad_indices, | |||||
| op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value))) | |||||
| if use_nesterov: | |||||
| m_temp = next_m * _scaler_ten | |||||
| assign_m_nesterov = F.assign(m, op_mul(beta1, next_m)) | |||||
| div_value = scatter_add(m, | |||||
| op_mul(grad_indices, _scaler_one), | |||||
| op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value)) | |||||
| param_update = div_value / (op_sqrt(next_v) + eps) | |||||
| m_recover = F.assign(m, m_temp / _scaler_ten) | |||||
| F.control_depend(m_temp, assign_m_nesterov) | |||||
| F.control_depend(assign_m_nesterov, div_value) | |||||
| F.control_depend(param_update, m_recover) | |||||
| else: | |||||
| param_update = next_m / (op_sqrt(next_v) + eps) | |||||
| lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power) | |||||
| next_param = param - lr_t * param_update | |||||
| F.control_depend(assign_m, next_m) | |||||
| F.control_depend(assign_v, next_v) | |||||
| success = F.depend(success, F.assign(param, next_param)) | |||||
| success = F.depend(success, F.assign(m, next_m)) | |||||
| success = F.depend(success, F.assign(v, next_v)) | |||||
| return success | |||||
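To make the host (`target`) sparse path above easier to check, here is a hedged NumPy sketch of the non-Nesterov branch: both moments are decayed densely, the gradient rows are scatter-added, and the step uses the bias-corrected learning rate. The function and all names are illustrative, not part of this diff:

```python
# NumPy sketch of the dense-decay sparse Adam step (non-Nesterov branch above).
import numpy as np

def sparse_adam_step(param, m, v, indices, values, lr, beta1, beta2, eps, step):
    m *= beta1                                                # decay 1st moment on every row
    v *= beta2                                                # decay 2nd moment on every row
    np.add.at(m, indices, (1.0 - beta1) * values)             # ScatterAdd of gradient rows
    np.add.at(v, indices, (1.0 - beta2) * np.square(values))  # ScatterAdd of squared rows
    update = m / (np.sqrt(v) + eps)
    lr_t = lr * np.sqrt(1.0 - beta2 ** step) / (1.0 - beta1 ** step)  # bias correction
    param -= lr_t * update
    return param, m, v

# Example: a 5-row embedding table where only rows 1 and 3 received gradients.
p = np.ones((5, 2), np.float32); m0 = np.zeros_like(p); v0 = np.zeros_like(p)
idx = np.array([1, 3]); vals = np.full((2, 2), 0.5, np.float32)
print(sparse_adam_step(p, m0, v0, idx, vals, 0.01, 0.9, 0.999, 1e-8, 1)[0])
```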
| @_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", | |||||
| "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") | |||||
| def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, | |||||
| beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, | |||||
| moment1, moment2, ps_parameter, cache_enable): | |||||
| """Apply adam optimizer to the weight parameter using Tensor.""" | |||||
| success = True | |||||
| if ps_parameter and not cache_enable: | |||||
| op_shape = P.Shape() | |||||
| success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient), | |||||
| (op_shape(param), op_shape(moment1), op_shape(moment2))), param)) | |||||
| else: | |||||
| success = F.depend(success, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, | |||||
| eps, gradient)) | |||||
| return success | |||||
| @_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", | |||||
| "Tensor", "Tensor") | |||||
| def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2): | |||||
| """Apply AdamOffload optimizer to the weight parameter using Tensor.""" | |||||
| success = True | |||||
| delta_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient) | |||||
| success = F.depend(success, F.assign_add(param, delta_param)) | |||||
| return success | |||||
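Unlike the paths above, the AdamOffload op returns only the parameter delta, which is applied afterwards with `assign_add`. A NumPy stand-in for that contract, assuming the kernel follows the standard bias-corrected Adam formula (an assumption, since the kernel itself is not shown in this diff):

```python
# Hedged stand-in for the fused AdamOffload kernel: update the moments, return
# only the delta; the caller applies it where `param` lives (param += delta).
import numpy as np

def adam_offload_delta(m, v, beta1_power, beta2_power, lr, beta1, beta2, eps, grad):
    m[:] = beta1 * m + (1.0 - beta1) * grad
    v[:] = beta2 * v + (1.0 - beta2) * np.square(grad)
    lr_t = lr * np.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)  # assumed bias correction
    return -lr_t * m / (np.sqrt(v) + eps)

param = np.ones(3, np.float32); m = np.zeros(3, np.float32); v = np.zeros(3, np.float32)
param += adam_offload_delta(m, v, 0.9, 0.999, 0.01, 0.9, 0.999, 1e-8,
                            np.full(3, 0.2, np.float32))
```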
| def _check_param_value(beta1, beta2, eps, prim_name): | |||||
| """Check the type of inputs.""" | |||||
| validator.check_value_type("beta1", beta1, [float], prim_name) | |||||
| validator.check_value_type("beta2", beta2, [float], prim_name) | |||||
| validator.check_value_type("eps", eps, [float], prim_name) | |||||
| validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name) | |||||
| validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name) | |||||
| validator.check_positive_float(eps, "eps", prim_name) | |||||
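A quick illustration of what this validator enforces; the second call is expected to raise (a ValueError in the current validator, as far as I can tell), since `beta1` must lie strictly inside (0.0, 1.0):

```python
_check_param_value(0.9, 0.999, 1e-6, "Adam")   # valid: passes silently
try:
    _check_param_value(1.0, 0.999, 1e-6, "Adam")
except ValueError as e:
    print(e)   # beta1 is outside the open interval (0.0, 1.0)
```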
| class AdamWeightDecayForBert(Optimizer): | |||||
| """ | |||||
| Implements the Adam algorithm with decoupled (fixed) weight decay. | |||||
| Note: | |||||
| When separating parameter groups, the weight decay in each group will be applied on the parameters if the | |||||
| weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied | |||||
| on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. | |||||
| To improve performance when using parameter groups, a customized order of parameters is supported. | |||||
| Args: | |||||
| params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, | |||||
| the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params", | |||||
| "lr", "weight_decay" and "order_params" are the keys can be parsed. | |||||
| - params: Required. The value must be a list of `Parameter`. | |||||
| - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used. | |||||
| If not, the `learning_rate` in the API will be used. | |||||
| - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay | |||||
| will be used. If not, the `weight_decay` in the API will be used. | |||||
| - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and | |||||
| the order will be followed in the optimizer. There should be no other keys in this `dict`, and every | |||||
| parameter listed in 'order_params' must appear in one of the parameter groups. | |||||
| learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate. | |||||
| When the learning_rate is an Iterable or a one-dimensional Tensor, a dynamic learning rate is used: the | |||||
| i-th step takes the i-th value as the learning rate. When the learning_rate is a LearningRateSchedule, a | |||||
| dynamic learning rate is also used, with the i-th learning rate computed during training according to the | |||||
| formula of the LearningRateSchedule. When the learning_rate is a float or a zero-dimensional Tensor, a | |||||
| fixed learning rate is used. Other cases are not supported. The float learning rate must be | |||||
| equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. | |||||
| Default: 1e-3. | |||||
| beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9. | |||||
| Should be in range (0.0, 1.0). | |||||
| beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999. | |||||
| Should be in range (0.0, 1.0). | |||||
| eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6. | |||||
| Should be greater than 0. | |||||
| weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0. | |||||
| Inputs: | |||||
| - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. | |||||
| - **overflow** (Tensor) - The overflow flag produced by dynamic loss scaling. | |||||
| Outputs: | |||||
| tuple[bool], all elements are True. | |||||
| Supported Platforms: | |||||
| ``Ascend`` ``GPU`` | |||||
| Examples: | |||||
| >>> net = Net() | |||||
| >>> #1) All parameters use the same learning rate and weight decay | |||||
| >>> optim = AdamWeightDecayForBert(params=net.trainable_params()) | |||||
| >>> | |||||
| >>> #2) Use parameter groups and set different values | |||||
| >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) | |||||
| >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) | |||||
| >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, | |||||
| ... {'params': no_conv_params, 'lr': 0.01}, | |||||
| ... {'order_params': net.trainable_params()}] | |||||
| >>> optim = AdamWeightDecayForBert(group_params, learning_rate=0.1, weight_decay=0.0) | |||||
| >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. | |||||
| >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0. | |||||
| >>> # The final order of parameters that the optimizer follows is given by the value of 'order_params'. | |||||
| >>> | |||||
| >>> loss = nn.SoftmaxCrossEntropyWithLogits() | |||||
| >>> model = Model(net, loss_fn=loss, optimizer=optim) | |||||
| """ | |||||
| def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): | |||||
| super(AdamWeightDecayForBert, self).__init__(learning_rate, params, weight_decay) | |||||
| _check_param_value(beta1, beta2, eps, self.cls_name) | |||||
| self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) | |||||
| self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) | |||||
| self.eps = Tensor(np.array([eps]).astype(np.float32)) | |||||
| self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros') | |||||
| self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros') | |||||
| self.hyper_map = C.HyperMap() | |||||
| self.op_select = P.Select() | |||||
| self.op_cast = P.Cast() | |||||
| self.op_reshape = P.Reshape() | |||||
| self.op_shape = P.Shape() | |||||
| def construct(self, gradients, overflow): | |||||
| """AdamWeightDecayForBert""" | |||||
| lr = self.get_lr() | |||||
| cond = self.op_cast(F.fill(mstype.int32, self.op_shape(self.beta1), 1) *\ | |||||
| self.op_reshape(overflow, (())), mstype.bool_) | |||||
| beta1 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta1) | |||||
| beta2 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta2) | |||||
| if self.is_group: | |||||
| if self.is_group_lr: | |||||
| optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps), | |||||
| lr, self.weight_decay, self.parameters, self.moments1, self.moments2, | |||||
| gradients, self.decay_flags, self.optim_filter) | |||||
| else: | |||||
| optim_result = self.hyper_map(F.partial(_adam_opt, beta1, beta2, self.eps, lr, overflow), | |||||
| self.weight_decay, self.parameters, self.moments1, self.moments2, | |||||
| gradients, self.decay_flags, self.optim_filter) | |||||
| else: | |||||
| optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay), | |||||
| self.parameters, self.moments1, self.moments2, | |||||
| gradients, self.decay_flags, self.optim_filter) | |||||
| if self.use_parallel: | |||||
| self.broadcast_params(optim_result) | |||||
| return optim_result | |||||
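A minimal usage sketch for the class above. The tiny `nn.Dense` network, the zero gradients, and the overflow dtype are all placeholders; in a real BERT script the overflow flag comes from the dynamic-loss-scale train cell rather than a constant:

```python
# Sketch only: exercises construct(gradients, overflow) on a toy network.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

net = nn.Dense(4, 2)                                    # stand-in network
optim = AdamWeightDecayForBert(net.trainable_params(),
                               learning_rate=1e-3, weight_decay=0.01)
grads = tuple(Tensor(np.zeros_like(p.data.asnumpy()))   # zero grads, shapes match params
              for p in net.trainable_params())
no_overflow = Tensor(np.array(0.0, np.float32))         # 0: apply update; nonzero: skip it
result = optim(grads, no_overflow)
```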
| class AdamWeightDecayOp(Optimizer): | |||||
| """ | |||||
| Implements the Adam algorithm with decoupled (fixed) weight decay. It is a single fused operator, not a combination of other ops. | |||||
| Note: | |||||
| When separating parameter groups, the weight decay in each group will be applied on the parameters if the | |||||
| weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied | |||||
| on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. | |||||
| To improve performance when using parameter groups, a customized order of parameters is supported. | |||||
| Args: | |||||
| params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, | |||||
| the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params", | |||||
| "lr", "weight_decay" and "order_params" are the keys can be parsed. | |||||
| - params: Required. The value must be a list of `Parameter`. | |||||
| - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used. | |||||
| If not, the `learning_rate` in the API will be used. | |||||
| - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay | |||||
| will be used. If not, the `weight_decay` in the API will be used. | |||||
| - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and | |||||
| the order will be followed in the optimizer. There should be no other keys in this `dict`, and every | |||||
| parameter listed in 'order_params' must appear in one of the parameter groups. | |||||
| learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate. | |||||
| When the learning_rate is an Iterable or a one-dimensional Tensor, a dynamic learning rate is used: the | |||||
| i-th step takes the i-th value as the learning rate. When the learning_rate is a LearningRateSchedule, a | |||||
| dynamic learning rate is also used, with the i-th learning rate computed during training according to the | |||||
| formula of the LearningRateSchedule. When the learning_rate is a float or a zero-dimensional Tensor, a | |||||
| fixed learning rate is used. Other cases are not supported. The float learning rate must be | |||||
| equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. | |||||
| Default: 1e-3. | |||||
| beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9. | |||||
| Should be in range (0.0, 1.0). | |||||
| beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999. | |||||
| Should be in range (0.0, 1.0). | |||||
| eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6. | |||||
| Should be greater than 0. | |||||
| weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0. | |||||
| Inputs: | |||||
| - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. | |||||
| Outputs: | |||||
| tuple[bool], all elements are True. | |||||
| Supported Platforms: | |||||
| ``GPU`` | |||||
| Examples: | |||||
| >>> net = Net() | |||||
| >>> #1) All parameters use the same learning rate and weight decay | |||||
| >>> optim = AdamWeightDecayOp(params=net.trainable_params()) | |||||
| >>> | |||||
| >>> #2) Use parameter groups and set different values | |||||
| >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) | |||||
| >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) | |||||
| >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, | |||||
| ... {'params': no_conv_params, 'lr': 0.01}, | |||||
| ... {'order_params': net.trainable_params()}] | |||||
| >>> optim = AdamWeightDecayOp(group_params, learning_rate=0.1, weight_decay=0.0) | |||||
| >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. | |||||
| >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0. | |||||
| >>> # The final order of parameters that the optimizer follows is given by the value of 'order_params'. | |||||
| >>> | |||||
| >>> loss = nn.SoftmaxCrossEntropyWithLogits() | |||||
| >>> model = Model(net, loss_fn=loss, optimizer=optim) | |||||
| """ | |||||
| def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): | |||||
| super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay) | |||||
| _check_param_value(beta1, beta2, eps, self.cls_name) | |||||
| self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) | |||||
| self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) | |||||
| self.eps = Tensor(np.array([eps]).astype(np.float32)) | |||||
| self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros') | |||||
| self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros') | |||||
| self.hyper_map = C.HyperMap() | |||||
| def construct(self, gradients): | |||||
| """AdamWeightDecayOp""" | |||||
| lr = self.get_lr() | |||||
| if self.is_group: | |||||
| if self.is_group_lr: | |||||
| optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps), | |||||
| lr, self.weight_decay, self.parameters, self.moments1, self.moments2, | |||||
| gradients, self.decay_flags, self.optim_filter) | |||||
| else: | |||||
| optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr), | |||||
| self.weight_decay, self.parameters, self.moments1, self.moments2, | |||||
| gradients, self.decay_flags, self.optim_filter) | |||||
| else: | |||||
| optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay), | |||||
| self.parameters, self.moments1, self.moments2, | |||||
| gradients, self.decay_flags, self.optim_filter) | |||||
| if self.use_parallel: | |||||
| self.broadcast_params(optim_result) | |||||
| return optim_result | |||||
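Since both classes document the same per-parameter update as `_update_run_op` (decoupled decay, no bias correction), a small NumPy re-implementation is handy for spot-checking the graph versions on toy inputs. Illustrative only; the names are not part of the diff:

```python
# One AdamWeightDecay step in plain NumPy, matching the composed update above.
import numpy as np

def adamw_step(w, m, v, g, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, wd=0.01):
    m = beta1 * m + (1.0 - beta1) * g
    v = beta2 * v + (1.0 - beta2) * g * g
    update = m / (eps + np.sqrt(v)) + wd * w   # decoupled weight decay, no bias correction
    return w - lr * update, m, v

w = np.ones(4, np.float32); m = np.zeros(4, np.float32); v = np.zeros(4, np.float32)
g = np.full(4, 0.5, np.float32)
print(adamw_step(w, m, v, g)[0])
```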
| @@ -0,0 +1,72 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import mindspore.nn as nn | |||||
| from mindspore.common.initializer import TruncatedNormal | |||||
| def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): | |||||
| """weight initial for conv layer""" | |||||
| weight = weight_variable() | |||||
| return nn.Conv2d( | |||||
| in_channels, | |||||
| out_channels, | |||||
| kernel_size=kernel_size, | |||||
| stride=stride, | |||||
| padding=padding, | |||||
| weight_init=weight, | |||||
| has_bias=False, | |||||
| pad_mode="valid", | |||||
| ) | |||||
| def fc_with_initialize(input_channels, out_channels): | |||||
| """weight initial for fc layer""" | |||||
| weight = weight_variable() | |||||
| bias = weight_variable() | |||||
| return nn.Dense(input_channels, out_channels, weight, bias) | |||||
| def weight_variable(): | |||||
| """weight initial""" | |||||
| return TruncatedNormal(0.02) | |||||
| class LeNet5(nn.Cell): | |||||
| def __init__(self, num_class=10, channel=3): | |||||
| super(LeNet5, self).__init__() | |||||
| self.num_class = num_class | |||||
| self.conv1 = conv(channel, 6, 5) | |||||
| self.conv2 = conv(6, 16, 5) | |||||
| self.fc1 = fc_with_initialize(16 * 5 * 5, 120) | |||||
| self.fc2 = fc_with_initialize(120, 84) | |||||
| self.fc3 = fc_with_initialize(84, self.num_class) | |||||
| self.relu = nn.ReLU() | |||||
| self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||||
| self.flatten = nn.Flatten() | |||||
| def construct(self, x): | |||||
| x = self.conv1(x) | |||||
| x = self.relu(x) | |||||
| x = self.max_pool2d(x) | |||||
| x = self.conv2(x) | |||||
| x = self.relu(x) | |||||
| x = self.max_pool2d(x) | |||||
| x = self.flatten(x) | |||||
| x = self.fc1(x) | |||||
| x = self.relu(x) | |||||
| x = self.fc2(x) | |||||
| x = self.relu(x) | |||||
| x = self.fc3(x) | |||||
| return x | |||||
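A quick shape check for the network above (hedged example with random input; with 32×32 inputs the two conv+pool stages leave 16 feature maps of 5×5, matching `fc1`'s 16*5*5 input size):

```python
import numpy as np
from mindspore import Tensor

net = LeNet5(num_class=10, channel=3)
x = Tensor(np.random.rand(1, 3, 32, 32).astype(np.float32))
print(net(x).shape)   # (1, 10): 32 -> conv5 -> 28 -> pool -> 14 -> conv5 -> 10 -> pool -> 5
```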
| @@ -0,0 +1,96 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import ast | |||||
| import argparse | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||||
| from src.model import LeNet5 | |||||
| from src.adam import AdamWeightDecayOp | |||||
| parser = argparse.ArgumentParser(description="test_fl_lenet") | |||||
| parser.add_argument("--device_target", type=str, default="CPU") | |||||
| parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING") | |||||
| parser.add_argument("--ms_role", type=str, default="MS_WORKER") | |||||
| parser.add_argument("--worker_num", type=int, default=0) | |||||
| parser.add_argument("--server_num", type=int, default=1) | |||||
| parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1") | |||||
| parser.add_argument("--scheduler_port", type=int, default=8113) | |||||
| parser.add_argument("--fl_server_port", type=int, default=6666) | |||||
| parser.add_argument("--start_fl_job_threshold", type=int, default=1) | |||||
| parser.add_argument("--fl_name", type=str, default="Lenet") | |||||
| parser.add_argument("--fl_iteration_num", type=int, default=25) | |||||
| parser.add_argument("--client_epoch_num", type=int, default=20) | |||||
| parser.add_argument("--client_batch_size", type=int, default=32) | |||||
| parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False) | |||||
| args, _ = parser.parse_known_args() | |||||
| device_target = args.device_target | |||||
| server_mode = args.server_mode | |||||
| ms_role = args.ms_role | |||||
| worker_num = args.worker_num | |||||
| server_num = args.server_num | |||||
| scheduler_ip = args.scheduler_ip | |||||
| scheduler_port = args.scheduler_port | |||||
| fl_server_port = args.fl_server_port | |||||
| start_fl_job_threshold = args.start_fl_job_threshold | |||||
| fl_name = args.fl_name | |||||
| fl_iteration_num = args.fl_iteration_num | |||||
| client_epoch_num = args.client_epoch_num | |||||
| client_batch_size = args.client_batch_size | |||||
| secure_aggregation = args.secure_aggregation | |||||
| ctx = { | |||||
| "enable_ps": False, | |||||
| "server_mode": server_mode, | |||||
| "ms_role": ms_role, | |||||
| "worker_num": worker_num, | |||||
| "server_num": server_num, | |||||
| "scheduler_ip": scheduler_ip, | |||||
| "scheduler_port": scheduler_port, | |||||
| "fl_server_port": fl_server_port, | |||||
| "start_fl_job_threshold": start_fl_job_threshold, | |||||
| "fl_name": fl_name, | |||||
| "fl_iteration_num": fl_iteration_num, | |||||
| "client_epoch_num": client_epoch_num, | |||||
| "client_batch_size": client_batch_size, | |||||
| "secure_aggregation": secure_aggregation | |||||
| } | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False) | |||||
| context.set_ps_context(**ctx) | |||||
| if __name__ == "__main__": | |||||
| epoch = 5 | |||||
| np.random.seed(0) | |||||
| network = LeNet5(62) | |||||
| criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | |||||
| net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9) | |||||
| net_adam_opt = AdamWeightDecayOp(network.trainable_params(), weight_decay=0.1)  # constructed to exercise AdamWeightDecayOp; the training step below uses Momentum | |||||
| net_with_criterion = WithLossCell(network, criterion) | |||||
| train_network = TrainOneStepCell(net_with_criterion, net_opt) | |||||
| train_network.set_train() | |||||
| losses = [] | |||||
| for _ in range(epoch): | |||||
| data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32)) | |||||
| label = Tensor(np.random.randint(0, 61, (32)).astype(np.int32)) | |||||
| loss = train_network(data, label).asnumpy() | |||||
| losses.append(loss) | |||||
| print(losses) | |||||