/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ps/server/executor.h"
// NOTE(review): the original four #include targets were stripped by extraction;
// these are the standard headers the code below demonstrably requires.
#include <chrono>
#include <map>
#include <memory>
#include <string>
#include <thread>
#include <vector>

namespace mindspore {
namespace ps {
namespace server {
// Prepares the executor: validates inputs and builds one ParameterAggregator
// per trainable parameter found in the graph.
void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) {
  MS_EXCEPTION_IF_NULL(func_graph);
  if (aggregation_count == 0) {
    MS_LOG(EXCEPTION) << "Server aggregation count must be greater than 0";
    return;
  }
  aggregation_count_ = aggregation_count;

  // Initialize each trainable parameter's aggregator, including memory register,
  // aggregation algorithms and optimizers.
  bool ret = InitParamAggregator(func_graph);
  if (!ret) {
    MS_LOG(EXCEPTION) << "Initializing parameter aggregators failed.";
    return;
  }
  initialized_ = true;
  return;
}

bool Executor::initialized() const { return initialized_; }

// Handles one Push for a single parameter: uploads worker data, aggregates,
// and — once aggregation completes — optimizes and resets the round status.
bool Executor::HandlePush(const std::string &param_name, const UploadData &upload_data) {
  MS_LOG(DEBUG) << "Do Push for parameter " << param_name;
  if (param_aggrs_.count(param_name) == 0) {
    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
    return false;
  }

  std::mutex &mtx = parameter_mutex_[param_name];
  std::unique_lock<std::mutex> lock(mtx);
  auto &param_aggr = param_aggrs_[param_name];

  // Push operation needs to wait until the pulling process is done. The lock is
  // released while sleeping so pulls can make progress.
  while (!param_aggr->IsPullingDone()) {
    lock.unlock();
    std::this_thread::sleep_for(std::chrono::milliseconds(5));
    lock.lock();
  }

  // 1.Update data with the uploaded data of the worker.
  if (!param_aggr->UpdateData(upload_data)) {
    MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
    return false;
  }
  // 2.Launch aggregation for this trainable parameter.
  if (!param_aggr->LaunchAggregators()) {
    MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
    return false;
  }

  if (param_aggr->IsAggregationDone()) {
    // 3.After the aggregation is done, optimize the trainable parameter.
    if (!param_aggr->LaunchOptimizers()) {
      MS_LOG(ERROR) << "Optimizing for parameter " << param_name << " failed.";
      return false;
    }
    // 4.Reset pulling and aggregation status after optimizing is done.
    param_aggr->ResetPullingStatus();
    param_aggr->ResetAggregationStatus();
  }
  return true;
}

// Handles one UpdateModel upload for a single parameter. Unlike Push, it does
// not wait on pulling, optimize, or reset any round status.
bool Executor::HandleModelUpdate(const std::string &param_name, const UploadData &upload_data) {
  MS_LOG(DEBUG) << "Do UpdateModel for parameter " << param_name;
  if (param_aggrs_.count(param_name) == 0) {
    // The param_name could include some other parameters like momentum, but we don't think it's invalid. So here we
    // just print a warning log and return true.
    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
    return true;
  }

  std::mutex &mtx = parameter_mutex_[param_name];
  std::unique_lock<std::mutex> lock(mtx);
  auto &param_aggr = param_aggrs_[param_name];
  if (!param_aggr->UpdateData(upload_data)) {
    MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
    return false;
  }
  // Different from Push, UpdateModel doesn't need to checkout the aggregation status.
  if (!param_aggr->LaunchAggregators()) {
    MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
    return false;
  }
  return true;
}

// Uploads and aggregates a whole feature map under the model-wide lock.
// Unregistered names are skipped with a warning; any update/aggregation
// failure aborts the batch and returns false.
bool Executor::HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map) {
  std::unique_lock<std::mutex> model_lock(model_mutex_);
  for (const auto &trainable_param : feature_map) {
    const std::string &param_name = trainable_param.first;
    if (param_aggrs_.count(param_name) == 0) {
      MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
      continue;
    }

    std::mutex &mtx = parameter_mutex_[param_name];
    std::unique_lock<std::mutex> lock(mtx);
    auto &param_aggr = param_aggrs_[param_name];
    const UploadData &upload_data = trainable_param.second;
    if (!param_aggr->UpdateData(upload_data)) {
      MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
      return false;
    }
    if (!param_aggr->LaunchAggregators()) {
      MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
      return false;
    }
  }
  return true;
}

// Overwrites registered weights in place (memcpy into the aggregator's buffer)
// with the raw addresses supplied in feature_map.
bool Executor::HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map) {
  for (const auto &trainable_param : feature_map) {
    const std::string &param_name = trainable_param.first;
    if (param_aggrs_.count(param_name) == 0) {
      MS_LOG(WARNING) << "Weight " << param_name << " is not registered in server.";
      continue;
    }

    std::mutex &mtx = parameter_mutex_[param_name];
    std::unique_lock<std::mutex> lock(mtx);
    auto &param_aggr = param_aggrs_[param_name];
    AddressPtr old_weight = param_aggr->GetWeight();
    if (old_weight == nullptr) {
      MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
      return false;
    }
    const Address &new_weight = trainable_param.second;
    if (new_weight.addr == nullptr) {
      MS_LOG(ERROR) << "The new weight is nullptr.";
      return false;
    }

    int ret = memcpy_s(old_weight->addr, old_weight->size, new_weight.addr, new_weight.size);
    if (ret != 0) {
      MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
      return false;
    }
  }
  return true;
}

// Blocking Pull for one parameter: waits until optimizing finishes, then
// returns the weight address. Resets optimizing status after the last pull.
AddressPtr Executor::HandlePull(const std::string &param_name) {
  MS_LOG(INFO) << "Handle blocking pull message for parameter " << param_name;
  if (param_aggrs_.count(param_name) == 0) {
    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
    return nullptr;
  }

  std::mutex &mtx = parameter_mutex_[param_name];
  std::unique_lock<std::mutex> lock(mtx);
  auto &param_aggr = param_aggrs_[param_name];
  // Pulling must wait until the optimizing process is done.
  while (!param_aggr->IsOptimizingDone()) {
    lock.unlock();
    std::this_thread::sleep_for(std::chrono::milliseconds(5));
    lock.lock();
  }
  AddressPtr addr = param_aggr->Pull();
  // If this Pull is the last one, reset pulling and optimizing status.
  if (param_aggr->IsPullingDone()) {
    param_aggr->ResetOptimizingStatus();
  }
  return addr;
}

// Returns the weight addresses for the requested names. An unregistered name
// stops the lookup and returns the (possibly partial) map collected so far.
std::map<std::string, AddressPtr> Executor::HandleGetWeightsByKey(const std::vector<std::string> &param_names) {
  std::map<std::string, AddressPtr> weights;
  for (const auto &param_name : param_names) {
    if (param_aggrs_.count(param_name) == 0) {
      MS_LOG(ERROR) << "Parameter " << param_name << " is not registered in server.";
      return weights;
    }

    std::mutex &mtx = parameter_mutex_[param_name];
    std::unique_lock<std::mutex> lock(mtx);
    const auto &param_aggr = param_aggrs_[param_name];
    AddressPtr addr = param_aggr->GetWeight();
    if (addr == nullptr) {
      MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
      continue;
    }
    weights[param_name] = addr;
  }
  return weights;
}

bool Executor::IsAllWeightAggregationDone() { return IsWeightAggrDone(param_names_); }

// True only if every named parameter is registered and its aggregation is done.
bool Executor::IsWeightAggrDone(const std::vector<std::string> &param_names) {
  for (const auto &name : param_names) {
    if (param_aggrs_.count(name) == 0) {
      MS_LOG(ERROR) << "Weight " << name << " is invalid in server.";
      return false;
    }

    std::mutex &mtx = parameter_mutex_[name];
    std::unique_lock<std::mutex> lock(mtx);
    if (!param_aggrs_[name]->IsAggregationDone()) {
      MS_LOG(DEBUG) << "Update model for " << name << " is not done yet.";
      return false;
    }
  }
  return true;
}

// Resets aggregation status of every registered parameter, one lock at a time.
void Executor::ResetAggregationStatus() {
  for (const auto &param_name : param_names_) {
    std::mutex &mtx = parameter_mutex_[param_name];
    std::unique_lock<std::mutex> lock(mtx);
    param_aggrs_[param_name]->ResetAggregationStatus();
  }
  return;
}

// Snapshot of all trainable weights; parameters with a null weight address are
// skipped with a warning.
std::map<std::string, AddressPtr> Executor::GetModel() {
  std::map<std::string, AddressPtr> model = {};
  for (const auto &name : param_names_) {
    std::mutex &mtx = parameter_mutex_[name];
    std::unique_lock<std::mutex> lock(mtx);
    AddressPtr addr = param_aggrs_[name]->GetWeight();
    if (addr == nullptr) {
      MS_LOG(WARNING) << "Get weight of " << name << " failed.";
      continue;
    }
    model[name] = addr;
  }
  return model;
}

// bool Executor::Unmask() {
//   auto model = GetModel();
//   return mindarmour::CipherMgr::GetInstance().UnMask(model);
// }

const std::vector<std::string> &Executor::param_names() const { return param_names_; }

// Maps an optimizer cnode to the fullname of the trainable Parameter it
// updates; returns "" when the cnode is not a known optimizer kind.
std::string Executor::GetTrainableParamName(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
  if (kNameToIdxMap.count(cnode_name) == 0) {
    return "";
  }
  const OptimParamNameToIndex &index_info = kNameToIdxMap.at(cnode_name);
  size_t weight_idx = index_info.at("inputs").at(kWeight);
  AnfNodePtr weight_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, weight_idx), 0).first;
  MS_EXCEPTION_IF_NULL(weight_node);
  if (!weight_node->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << weight_idx << " input of " << cnode_name << " is not a Parameter.";
  }
  return weight_node->fullname_with_scope();
}

// Walks the graph's cnodes and creates one ParameterAggregator (plus its
// mutex) per distinct trainable parameter.
bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  const auto &cnodes = func_graph->GetOrderedCnodes();
  for (const auto &cnode : cnodes) {
    MS_EXCEPTION_IF_NULL(cnode);
    const std::string &param_name = GetTrainableParamName(cnode);
    if (param_name.empty()) {
      continue;
    }
    if (param_aggrs_.count(param_name) != 0) {
      MS_LOG(WARNING) << param_name << " already has its control flow.";
      continue;
    }

    std::shared_ptr<ParameterAggregator> param_aggr = std::make_shared<ParameterAggregator>();
    MS_EXCEPTION_IF_NULL(param_aggr);
    param_names_.push_back(param_name);
    param_aggrs_[param_name] = param_aggr;
    // Default-construct the per-parameter mutex entry so later lookups never race.
    parameter_mutex_[param_name];
    param_aggr->Init(cnode, aggregation_count_);
    MS_LOG(DEBUG) << "Initializing control flow for param_name " << param_name << " success.";
  }
  return true;
}
}  // namespace server
}  // namespace ps
}  // namespace mindspore