/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fl/server/executor.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>

namespace mindspore {
namespace fl {
namespace server {
void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) {
  MS_EXCEPTION_IF_NULL(func_graph);
  if (aggregation_count == 0) {
    MS_LOG(EXCEPTION) << "Server aggregation count must be greater than 0";
    return;
  }
  aggregation_count_ = aggregation_count;

  // Initialize the aggregator for each trainable parameter, including its memory register,
  // aggregation algorithms and optimizers.
  bool ret = InitParamAggregator(func_graph);
  if (!ret) {
    MS_LOG(EXCEPTION) << "Initializing parameter aggregators failed.";
    return;
  }
  initialized_ = true;
  return;
}

bool Executor::ReInitForScaling() {
  auto result = std::find_if(param_aggrs_.begin(), param_aggrs_.end(),
                             [](auto param_aggr) { return !param_aggr.second->ReInitForScaling(); });
  if (result != param_aggrs_.end()) {
    MS_LOG(ERROR) << "Reinitializing aggregator of " << result->first << " for scaling failed.";
    return false;
  }
  return true;
}

bool Executor::ReInitForUpdatingHyperParams(size_t aggr_threshold) {
  aggregation_count_ = aggr_threshold;
  auto result = std::find_if(param_aggrs_.begin(), param_aggrs_.end(), [this](auto param_aggr) {
    return !param_aggr.second->ReInitForUpdatingHyperParams(aggregation_count_);
  });
  if (result != param_aggrs_.end()) {
    MS_LOG(ERROR) << "Reinitializing aggregator of " << result->first << " for updating hyper-parameters failed.";
    return false;
  }
  return true;
}

bool Executor::initialized() const { return initialized_; }

bool Executor::HandleModelUpdate(const std::string &param_name, const UploadData &upload_data) {
  MS_LOG(DEBUG) << "Do UpdateModel for parameter " << param_name;
  if (param_aggrs_.count(param_name) == 0) {
    // The param_name may refer to auxiliary parameters such as momentum, which are not registered
    // but are not invalid either. So here we just print a warning log and return true.
    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
    return true;
  }

  std::mutex &mtx = parameter_mutex_[param_name];
  std::unique_lock<std::mutex> lock(mtx);
  auto &param_aggr = param_aggrs_[param_name];
  MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  if (!param_aggr->UpdateData(upload_data)) {
    MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
    return false;
  }
  // Different from Push, UpdateModel does not need to check the aggregation status first.
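  // A brief sketch of the flow here (inferred from this file, not from external docs):
  // LaunchAggregators hands the data accumulated by UpdateData to this parameter's
  // aggregation kernels; whether enough updates have arrived is tracked separately and
  // queried later through IsAggregationDone()/IsWeightAggrDone().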
  if (!param_aggr->LaunchAggregators()) {
    MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
    return false;
  }
  return true;
}

bool Executor::HandlePushWeight(const std::map<std::string, Address> &feature_map) {
  for (const auto &trainable_param : feature_map) {
    const std::string &param_name = trainable_param.first;
    if (param_aggrs_.count(param_name) == 0) {
      MS_LOG(WARNING) << "Weight " << param_name << " is not registered in server.";
      continue;
    }

    std::mutex &mtx = parameter_mutex_[param_name];
    std::unique_lock<std::mutex> lock(mtx);
    auto &param_aggr = param_aggrs_[param_name];
    MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
    AddressPtr old_weight = param_aggr->GetWeight();
    const Address &new_weight = trainable_param.second;
    MS_ERROR_IF_NULL_W_RET_VAL(old_weight, false);
    MS_ERROR_IF_NULL_W_RET_VAL(old_weight->addr, false);
    MS_ERROR_IF_NULL_W_RET_VAL(new_weight.addr, false);
    int ret = memcpy_s(old_weight->addr, old_weight->size, new_weight.addr, new_weight.size);
    if (ret != 0) {
      MS_LOG(ERROR) << "memcpy_s error, errno(" << ret << ")";
      return false;
    }
  }
  return true;
}

std::map<std::string, AddressPtr> Executor::HandlePullWeight(const std::vector<std::string> &param_names) {
  std::map<std::string, AddressPtr> weights;
  for (const auto &param_name : param_names) {
    if (param_aggrs_.count(param_name) == 0) {
      MS_LOG(ERROR) << "Parameter " << param_name << " is not registered in server.";
      return weights;
    }

    std::mutex &mtx = parameter_mutex_[param_name];
    std::unique_lock<std::mutex> lock(mtx);
    const auto &param_aggr = param_aggrs_[param_name];
    MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, weights);
    AddressPtr addr = param_aggr->GetWeight();
    if (addr == nullptr) {
      MS_LOG(ERROR) << "Getting weight of " << param_name << " failed: the AddressPtr is nullptr.";
      continue;
    }
    weights[param_name] = addr;
  }
  return weights;
}

bool Executor::IsAllWeightAggregationDone() { return IsWeightAggrDone(param_names_); }

bool Executor::RunAllWeightAggregation() {
  for (const auto &name : param_names_) {
    if (param_aggrs_.count(name) == 0) {
      MS_LOG(ERROR) << "Weight " << name << " is invalid in server.";
      return false;
    }

    std::mutex &mtx = parameter_mutex_[name];
    std::unique_lock<std::mutex> lock(mtx);
    auto &param_aggr = param_aggrs_[name];
    MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
    if (!param_aggr->requires_aggr()) {
      continue;
    }
    if (!param_aggr->RunAggregation()) {
      MS_LOG(WARNING) << "Failed to run aggregation for " << name;
      return false;
    }
  }
  return true;
}

bool Executor::IsWeightAggrDone(const std::vector<std::string> &param_names) {
  for (const auto &name : param_names) {
    if (param_aggrs_.count(name) == 0) {
      MS_LOG(ERROR) << "Weight " << name << " is invalid in server.";
      return false;
    }

    std::mutex &mtx = parameter_mutex_[name];
    std::unique_lock<std::mutex> lock(mtx);
    auto &param_aggr = param_aggrs_[name];
    MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
    if (!param_aggr->requires_aggr()) {
      continue;
    }
    if (!param_aggr->IsAggregationDone()) {
      MS_LOG(DEBUG) << "Update model for " << name << " is not done yet.";
      return false;
    }
  }
  return true;
}

void Executor::ResetAggregationStatus() {
  for (const auto &param_name : param_names_) {
    std::mutex &mtx = parameter_mutex_[param_name];
    std::unique_lock<std::mutex> lock(mtx);
    auto &param_aggr = param_aggrs_[param_name];
    MS_ERROR_IF_NULL_WO_RET_VAL(param_aggr);
    param_aggr->ResetAggregationStatus();
  }
  return;
}

std::map<std::string, AddressPtr> Executor::GetModel() {
  std::map<std::string, AddressPtr> model = {};
  for (const auto &name : param_names_) {
    std::mutex &mtx = parameter_mutex_[name];
    std::unique_lock<std::mutex> lock(mtx);
    AddressPtr addr = param_aggrs_[name]->GetWeight();
    if (addr == nullptr) {
      MS_LOG(WARNING) << "Getting weight of " << name << " failed.";
      continue;
    }
    model[name] = addr;
  }
  return model;
}
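// Note: HandlePullWeight() and GetModel() return AddressPtr handles into the aggregators'
// own weight buffers rather than copies, so callers should consume them within the round
// boundaries enforced by parameter_mutex_ (an inference from the locking above, not a
// documented guarantee).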
const std::vector<std::string> &Executor::param_names() const { return param_names_; }

bool Executor::Unmask() {
#ifdef ENABLE_ARMOUR
  auto model = GetModel();
  return cipher_unmask_.UnMask(model);
#else
  return false;
#endif
}

void Executor::set_unmasked(bool unmasked) { unmasked_ = unmasked; }

bool Executor::unmasked() const {
  std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
  if (encrypt_type == ps::kPWEncryptType) {
    return unmasked_.load();
  } else {
    // If the MindArmour masking algorithm is not enabled, treat the unmasked_ flag as true.
    return true;
  }
}

std::string Executor::GetTrainableParamName(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::string cnode_name = common::AnfAlgo::GetCNodeName(cnode);
  if (kNameToIdxMap.count(cnode_name) == 0) {
    return "";
  }
  const OptimParamNameToIndex &index_info = kNameToIdxMap.at(cnode_name);
  size_t weight_idx = index_info.at("inputs").at(kWeight);
  AnfNodePtr weight_node =
    common::AnfAlgo::VisitKernelWithReturnType(common::AnfAlgo::GetInputNode(cnode, weight_idx), 0).first;
  MS_EXCEPTION_IF_NULL(weight_node);
  if (!weight_node->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "Input " << weight_idx << " of " << cnode_name << " is not a Parameter.";
    return "";
  }
  return weight_node->fullname_with_scope();
}

bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  const auto &cnodes = func_graph->GetOrderedCnodes();
  for (const auto &cnode : cnodes) {
    MS_EXCEPTION_IF_NULL(cnode);
    const std::string &param_name = GetTrainableParamName(cnode);
    if (param_name.empty()) {
      continue;
    }
    if (param_aggrs_.count(param_name) != 0) {
      MS_LOG(WARNING) << param_name << " already has a parameter aggregator registered.";
      continue;
    }

    std::shared_ptr<ParameterAggregator> param_aggr = std::make_shared<ParameterAggregator>();
    MS_EXCEPTION_IF_NULL(param_aggr);
    param_names_.push_back(param_name);
    param_aggrs_[param_name] = param_aggr;
    // Touching the map entry default-constructs the per-parameter mutex.
    parameter_mutex_[param_name];
    if (!param_aggr->Init(cnode, aggregation_count_)) {
      MS_LOG(EXCEPTION) << "Initializing parameter aggregator for param_name " << param_name << " failed.";
      return false;
    }
    MS_LOG(DEBUG) << "Initializing parameter aggregator for param_name " << param_name << " succeeded.";
  }
  return true;
}
}  // namespace server
}  // namespace fl
}  // namespace mindspore
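// A minimal usage sketch (hypothetical driver code; the real call sites live in the
// round/iteration logic of other server components, and the GetInstance() accessor is
// assumed here rather than shown in this file):
//
//   auto &executor = Executor::GetInstance();
//   executor.Initialize(func_graph, aggregation_count);          // one aggregator per trainable weight
//   (void)executor.HandleModelUpdate(param_name, upload_data);   // accumulate one client update
//   if (executor.IsAllWeightAggregationDone()) {
//     std::map<std::string, AddressPtr> model = executor.GetModel();  // read aggregated weights
//     executor.ResetAggregationStatus();                              // prepare for the next round
//   }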