|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310 |
- /**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include "ps/server/executor.h"
- #include <set>
- #include <memory>
- #include <string>
- #include <vector>
-
- namespace mindspore {
- namespace ps {
- namespace server {
- void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) {
- MS_EXCEPTION_IF_NULL(func_graph);
- if (aggregation_count == 0) {
- MS_LOG(EXCEPTION) << "Server aggregation count must be greater than 0";
- return;
- }
- aggregation_count_ = aggregation_count;
-
- // Initialize each trainable parameter's aggregator, including memory register, aggregation algorithms and optimizers.
- bool ret = InitParamAggregator(func_graph);
- if (!ret) {
- MS_LOG(EXCEPTION) << "Initializing parameter aggregators failed.";
- return;
- }
- initialized_ = true;
- return;
- }
-
- bool Executor::initialized() const { return initialized_; }
-
// Serves one worker's blocking Push for a single trainable parameter: merges
// the uploaded data into the parameter's aggregator and, once aggregation is
// complete, runs the optimizer and resets the round's status flags.
// Returns false when the parameter is unregistered or any update/aggregation/
// optimizing step fails.
bool Executor::HandlePush(const std::string &param_name, const UploadData &upload_data) {
  MS_LOG(DEBUG) << "Do Push for parameter " << param_name;
  if (param_aggrs_.count(param_name) == 0) {
    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
    return false;
  }

  std::mutex &mtx = parameter_mutex_[param_name];
  std::unique_lock<std::mutex> lock(mtx);
  auto &param_aggr = param_aggrs_[param_name];

  // Push operation needs to wait until the pulling process is done. The lock
  // is released while sleeping so pulling threads can make progress; it is
  // re-taken before each IsPullingDone() re-check.
  while (!param_aggr->IsPullingDone()) {
    lock.unlock();
    std::this_thread::sleep_for(std::chrono::milliseconds(5));
    lock.lock();
  }

  // 1.Update data with the uploaded data of the worker.
  if (!param_aggr->UpdateData(upload_data)) {
    MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
    return false;
  }
  // 2.Launch aggregation for this trainable parameter.
  if (!param_aggr->LaunchAggregators()) {
    MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
    return false;
  }
  if (param_aggr->IsAggregationDone()) {
    // 3.After the aggregation is done, optimize the trainable parameter.
    if (!param_aggr->LaunchOptimizers()) {
      MS_LOG(ERROR) << "Optimizing for parameter " << param_name << " failed.";
      return false;
    }
    // 4.Reset pulling and aggregation status after optimizing is done.
    param_aggr->ResetPullingStatus();
    param_aggr->ResetAggregationStatus();
  }
  return true;
}
-
- bool Executor::HandleModelUpdate(const std::string ¶m_name, const UploadData &upload_data) {
- MS_LOG(DEBUG) << "Do UpdateModel for parameter " << param_name;
- if (param_aggrs_.count(param_name) == 0) {
- // The param_name could include some other parameters like momentum, but we don't think it's invalid. So here we
- // just print a warning log and return true.
- MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
- return true;
- }
-
- std::mutex &mtx = parameter_mutex_[param_name];
- std::unique_lock<std::mutex> lock(mtx);
- auto ¶m_aggr = param_aggrs_[param_name];
-
- if (!param_aggr->UpdateData(upload_data)) {
- MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
- return false;
- }
- // Different from Push, UpdateModel doesn't need to checkout the aggregation status.
- if (!param_aggr->LaunchAggregators()) {
- MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
- return false;
- }
- return true;
- }
-
- bool Executor::HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map) {
- std::unique_lock<std::mutex> model_lock(model_mutex_);
- for (const auto &trainable_param : feature_map) {
- const std::string ¶m_name = trainable_param.first;
- if (param_aggrs_.count(param_name) == 0) {
- MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
- continue;
- }
-
- std::mutex &mtx = parameter_mutex_[param_name];
- std::unique_lock<std::mutex> lock(mtx);
- auto ¶m_aggr = param_aggrs_[param_name];
-
- const UploadData &upload_data = trainable_param.second;
- if (!param_aggr->UpdateData(upload_data)) {
- MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
- return false;
- }
- if (!param_aggr->LaunchAggregators()) {
- MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
- return false;
- }
- }
- return true;
- }
-
- bool Executor::HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map) {
- for (const auto &trainable_param : feature_map) {
- const std::string ¶m_name = trainable_param.first;
- if (param_aggrs_.count(param_name) == 0) {
- MS_LOG(WARNING) << "Weight " << param_name << " is not registered in server.";
- continue;
- }
-
- std::mutex &mtx = parameter_mutex_[param_name];
- std::unique_lock<std::mutex> lock(mtx);
- auto ¶m_aggr = param_aggrs_[param_name];
-
- AddressPtr old_weight = param_aggr->GetWeight();
- if (old_weight == nullptr) {
- MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
- return false;
- }
-
- const Address &new_weight = trainable_param.second;
- if (new_weight.addr == nullptr) {
- MS_LOG(ERROR) << "The new weight is nullptr.";
- return false;
- }
-
- int ret = memcpy_s(old_weight->addr, old_weight->size, new_weight.addr, new_weight.size);
- if (ret != 0) {
- MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
- return false;
- }
- }
- return true;
- }
-
// Serves one worker's blocking Pull of a parameter's latest weight. Blocks
// (polling) until optimizing for this parameter has finished, then returns the
// weight buffer, or nullptr when the parameter is unregistered.
AddressPtr Executor::HandlePull(const std::string &param_name) {
  MS_LOG(INFO) << "Handle blocking pull message for parameter " << param_name;
  if (param_aggrs_.count(param_name) == 0) {
    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
    return nullptr;
  }

  std::mutex &mtx = parameter_mutex_[param_name];
  std::unique_lock<std::mutex> lock(mtx);
  auto &param_aggr = param_aggrs_[param_name];

  // Pulling must wait until the optimizing process is done. The lock is
  // dropped while sleeping so the optimizing thread can make progress; it is
  // re-taken before each IsOptimizingDone() re-check.
  while (!param_aggr->IsOptimizingDone()) {
    lock.unlock();
    std::this_thread::sleep_for(std::chrono::milliseconds(5));
    lock.lock();
  }
  AddressPtr addr = param_aggr->Pull();
  // If this Pull is the last one, reset pulling and optimizing status.
  if (param_aggr->IsPullingDone()) {
    param_aggr->ResetOptimizingStatus();
  }
  return addr;
}
-
- std::map<std::string, AddressPtr> Executor::HandleGetWeightsByKey(const std::vector<std::string> ¶m_names) {
- std::map<std::string, AddressPtr> weights;
- for (const auto ¶m_name : param_names) {
- if (param_aggrs_.count(param_name) == 0) {
- MS_LOG(ERROR) << "Parameter " << param_name << " is not registered in server.";
- return weights;
- }
-
- std::mutex &mtx = parameter_mutex_[param_name];
- std::unique_lock<std::mutex> lock(mtx);
- const auto ¶m_aggr = param_aggrs_[param_name];
-
- AddressPtr addr = param_aggr->GetWeight();
- if (addr == nullptr) {
- MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
- continue;
- }
- weights[param_name] = addr;
- }
- return weights;
- }
-
- bool Executor::IsAllWeightAggregationDone() { return IsWeightAggrDone(param_names_); }
-
- bool Executor::IsWeightAggrDone(const std::vector<std::string> ¶m_names) {
- for (const auto &name : param_names) {
- if (param_aggrs_.count(name) == 0) {
- MS_LOG(ERROR) << "Weight " << name << " is invalid in server.";
- return false;
- }
-
- std::mutex &mtx = parameter_mutex_[name];
- std::unique_lock<std::mutex> lock(mtx);
- if (!param_aggrs_[name]->IsAggregationDone()) {
- MS_LOG(DEBUG) << "Update model for " << name << " is not done yet.";
- return false;
- }
- }
- return true;
- }
-
- void Executor::ResetAggregationStatus() {
- for (const auto ¶m_name : param_names_) {
- std::mutex &mtx = parameter_mutex_[param_name];
- std::unique_lock<std::mutex> lock(mtx);
- param_aggrs_[param_name]->ResetAggregationStatus();
- }
- return;
- }
-
- std::map<std::string, AddressPtr> Executor::GetModel() {
- std::map<std::string, AddressPtr> model = {};
- for (const auto &name : param_names_) {
- std::mutex &mtx = parameter_mutex_[name];
- std::unique_lock<std::mutex> lock(mtx);
- AddressPtr addr = param_aggrs_[name]->GetWeight();
- if (addr == nullptr) {
- MS_LOG(WARNING) << "Get weight of " << name << " failed.";
- continue;
- }
- model[name] = addr;
- }
- return model;
- }
-
- // bool Executor::Unmask() {
- // auto model = GetModel();
- // return mindarmour::CipherMgr::GetInstance().UnMask(model);
- // }
-
- const std::vector<std::string> &Executor::param_names() const { return param_names_; }
-
- std::string Executor::GetTrainableParamName(const CNodePtr &cnode) {
- MS_EXCEPTION_IF_NULL(cnode);
- std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
- if (kNameToIdxMap.count(cnode_name) == 0) {
- return "";
- }
- const OptimParamNameToIndex &index_info = kNameToIdxMap.at(cnode_name);
- size_t weight_idx = index_info.at("inputs").at(kWeight);
- AnfNodePtr weight_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, weight_idx), 0).first;
- MS_EXCEPTION_IF_NULL(weight_node);
- if (!weight_node->isa<Parameter>()) {
- MS_LOG(EXCEPTION) << weight_idx << " input of " << cnode_name << " is not a Parameter.";
- }
- return weight_node->fullname_with_scope();
- }
-
- bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
- MS_EXCEPTION_IF_NULL(func_graph);
- const auto &cnodes = func_graph->GetOrderedCnodes();
- for (const auto &cnode : cnodes) {
- MS_EXCEPTION_IF_NULL(cnode);
- const std::string ¶m_name = GetTrainableParamName(cnode);
- if (param_name.empty()) {
- continue;
- }
- if (param_aggrs_.count(param_name) != 0) {
- MS_LOG(WARNING) << param_name << " already has its control flow.";
- continue;
- }
-
- std::shared_ptr<ParameterAggregator> param_aggr = std::make_shared<ParameterAggregator>();
- MS_EXCEPTION_IF_NULL(param_aggr);
- param_names_.push_back(param_name);
- param_aggrs_[param_name] = param_aggr;
- parameter_mutex_[param_name];
- param_aggr->Init(cnode, aggregation_count_);
- MS_LOG(DEBUG) << "Initializing control flow for param_name " << param_name << " success.";
- }
- return true;
- }
- } // namespace server
- } // namespace ps
- } // namespace mindspore
|