You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

executor.cc 9.2 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "fl/server/executor.h"
#include <algorithm>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <vector>
  21. namespace mindspore {
  22. namespace fl {
  23. namespace server {
void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) {
  // Builds one ParameterAggregator per trainable parameter of func_graph and
  // records the aggregation threshold. Must succeed before any other Executor
  // method is used; sets initialized_ on success.
  MS_EXCEPTION_IF_NULL(func_graph);
  if (aggregation_count == 0) {
    MS_LOG(EXCEPTION) << "Server aggregation count must be greater than 0";
    return;  // Unreachable when MS_LOG(EXCEPTION) throws; kept defensively.
  }
  aggregation_count_ = aggregation_count;
  // Initialize each trainable parameter's aggregator, including memory register, aggregation algorithms and optimizers.
  bool ret = InitParamAggregator(func_graph);
  if (!ret) {
    MS_LOG(EXCEPTION) << "Initializing parameter aggregators failed.";
    return;  // Unreachable for the same reason as above.
  }
  initialized_ = true;
  return;
}
  40. bool Executor::ReInitForScaling() {
  41. auto result = std::find_if(param_aggrs_.begin(), param_aggrs_.end(),
  42. [](auto param_aggr) { return !param_aggr.second->ReInitForScaling(); });
  43. if (result != param_aggrs_.end()) {
  44. MS_LOG(ERROR) << "Reinitializing aggregator of " << result->first << " for scaling failed.";
  45. return false;
  46. }
  47. return true;
  48. }
  49. bool Executor::ReInitForUpdatingHyperParams(size_t aggr_threshold) {
  50. aggregation_count_ = aggr_threshold;
  51. auto result = std::find_if(param_aggrs_.begin(), param_aggrs_.end(), [this](auto param_aggr) {
  52. return !param_aggr.second->ReInitForUpdatingHyperParams(aggregation_count_);
  53. });
  54. if (result != param_aggrs_.end()) {
  55. MS_LOG(ERROR) << "Reinitializing aggregator of " << result->first << " for scaling failed.";
  56. return false;
  57. }
  58. return true;
  59. }
  60. bool Executor::initialized() const { return initialized_; }
  61. bool Executor::HandleModelUpdate(const std::string &param_name, const UploadData &upload_data) {
  62. MS_LOG(DEBUG) << "Do UpdateModel for parameter " << param_name;
  63. if (param_aggrs_.count(param_name) == 0) {
  64. // The param_name could include some other parameters like momentum, but we don't think it's invalid. So here we
  65. // just print a warning log and return true.
  66. MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
  67. return true;
  68. }
  69. std::mutex &mtx = parameter_mutex_[param_name];
  70. std::unique_lock<std::mutex> lock(mtx);
  71. auto &param_aggr = param_aggrs_[param_name];
  72. MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  73. if (!param_aggr->UpdateData(upload_data)) {
  74. MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
  75. return false;
  76. }
  77. // Different from Push, UpdateModel doesn't need to checkout the aggregation status.
  78. if (!param_aggr->LaunchAggregators()) {
  79. MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
  80. return false;
  81. }
  82. return true;
  83. }
  84. bool Executor::HandlePushWeight(const std::map<std::string, Address> &feature_map) {
  85. for (const auto &trainable_param : feature_map) {
  86. const std::string &param_name = trainable_param.first;
  87. if (param_aggrs_.count(param_name) == 0) {
  88. MS_LOG(WARNING) << "Weight " << param_name << " is not registered in server.";
  89. continue;
  90. }
  91. std::mutex &mtx = parameter_mutex_[param_name];
  92. std::unique_lock<std::mutex> lock(mtx);
  93. auto &param_aggr = param_aggrs_[param_name];
  94. MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  95. AddressPtr old_weight = param_aggr->GetWeight();
  96. const Address &new_weight = trainable_param.second;
  97. MS_ERROR_IF_NULL_W_RET_VAL(old_weight, false);
  98. MS_ERROR_IF_NULL_W_RET_VAL(old_weight->addr, false);
  99. MS_ERROR_IF_NULL_W_RET_VAL(new_weight.addr, false);
  100. int ret = memcpy_s(old_weight->addr, old_weight->size, new_weight.addr, new_weight.size);
  101. if (ret != 0) {
  102. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  103. return false;
  104. }
  105. }
  106. return true;
  107. }
std::map<std::string, AddressPtr> Executor::HandlePullWeight(const std::vector<std::string> &param_names) {
  // Collects the current weight buffers for the requested parameter names.
  // NOTE(review): error handling is asymmetric — an unregistered name returns
  // the partially filled map immediately, while a null weight is merely
  // skipped. Callers should not assume the result covers every requested name;
  // confirm whether the early return is intentional.
  std::map<std::string, AddressPtr> weights;
  for (const auto &param_name : param_names) {
    if (param_aggrs_.count(param_name) == 0) {
      MS_LOG(ERROR) << "Parameter " << param_name << " is not registered in server.";
      return weights;
    }
    // Per-parameter lock: weight reads are serialized against concurrent updates.
    std::mutex &mtx = parameter_mutex_[param_name];
    std::unique_lock<std::mutex> lock(mtx);
    const auto &param_aggr = param_aggrs_[param_name];
    MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, weights);
    AddressPtr addr = param_aggr->GetWeight();
    if (addr == nullptr) {
      MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
      continue;
    }
    weights[param_name] = addr;
  }
  return weights;
}
  128. bool Executor::IsAllWeightAggregationDone() { return IsWeightAggrDone(param_names_); }
  129. bool Executor::IsWeightAggrDone(const std::vector<std::string> &param_names) {
  130. for (const auto &name : param_names) {
  131. if (param_aggrs_.count(name) == 0) {
  132. MS_LOG(ERROR) << "Weight " << name << " is invalid in server.";
  133. return false;
  134. }
  135. std::mutex &mtx = parameter_mutex_[name];
  136. std::unique_lock<std::mutex> lock(mtx);
  137. auto &param_aggr = param_aggrs_[name];
  138. MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  139. if (!param_aggr->requires_aggr()) {
  140. continue;
  141. }
  142. if (!param_aggr->IsAggregationDone()) {
  143. MS_LOG(DEBUG) << "Update model for " << name << " is not done yet.";
  144. return false;
  145. }
  146. }
  147. return true;
  148. }
  149. void Executor::ResetAggregationStatus() {
  150. for (const auto &param_name : param_names_) {
  151. std::mutex &mtx = parameter_mutex_[param_name];
  152. std::unique_lock<std::mutex> lock(mtx);
  153. auto &param_aggr = param_aggrs_[param_name];
  154. MS_ERROR_IF_NULL_WO_RET_VAL(param_aggr);
  155. param_aggr->ResetAggregationStatus();
  156. }
  157. return;
  158. }
  159. std::map<std::string, AddressPtr> Executor::GetModel() {
  160. std::map<std::string, AddressPtr> model = {};
  161. for (const auto &name : param_names_) {
  162. std::mutex &mtx = parameter_mutex_[name];
  163. std::unique_lock<std::mutex> lock(mtx);
  164. AddressPtr addr = param_aggrs_[name]->GetWeight();
  165. if (addr == nullptr) {
  166. MS_LOG(WARNING) << "Get weight of " << name << " failed.";
  167. continue;
  168. }
  169. model[name] = addr;
  170. }
  171. return model;
  172. }
  173. const std::vector<std::string> &Executor::param_names() const { return param_names_; }
bool Executor::Unmask() {
#ifdef ENABLE_ARMOUR
  // Removes the cipher masks from the aggregated model via cipher_unmask_
  // (only compiled in when MindArmour support is enabled).
  auto model = GetModel();
  return cipher_unmask_.UnMask(model);
#else
  // Without MindArmour there is nothing to unmask; report failure.
  return false;
#endif
}
  182. void Executor::set_unmasked(bool unmasked) { unmasked_ = unmasked; }
  183. bool Executor::unmasked() const {
  184. std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
  185. if (encrypt_type == ps::kPWEncryptType) {
  186. return unmasked_.load();
  187. } else {
  188. // If the algorithm of mind armour is not enabled, consider unmasked_ flag as true.
  189. return true;
  190. }
  191. }
std::string Executor::GetTrainableParamName(const CNodePtr &cnode) {
  // Maps an optimizer cnode to the fully-scoped name of the weight Parameter
  // it updates. Returns "" when the cnode is not a recognized optimizer
  // kernel (i.e. its name is absent from kNameToIdxMap).
  MS_EXCEPTION_IF_NULL(cnode);
  std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
  if (kNameToIdxMap.count(cnode_name) == 0) {
    return "";
  }
  // Look up which input slot of this optimizer kernel holds the weight.
  const OptimParamNameToIndex &index_info = kNameToIdxMap.at(cnode_name);
  size_t weight_idx = index_info.at("inputs").at(kWeight);
  AnfNodePtr weight_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, weight_idx), 0).first;
  MS_EXCEPTION_IF_NULL(weight_node);
  if (!weight_node->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << weight_idx << " input of " << cnode_name << " is not a Parameter.";
    return "";  // Unreachable when MS_LOG(EXCEPTION) throws; kept defensively.
  }
  return weight_node->fullname_with_scope();
}
bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
  // Walks the graph's cnodes in order and registers one ParameterAggregator
  // (plus its dedicated mutex) per trainable parameter. Duplicate names are
  // skipped with a warning; an aggregator init failure raises an exception.
  MS_EXCEPTION_IF_NULL(func_graph);
  const auto &cnodes = func_graph->GetOrderedCnodes();
  for (const auto &cnode : cnodes) {
    MS_EXCEPTION_IF_NULL(cnode);
    // Empty name means this cnode is not an optimizer kernel with a weight.
    const std::string &param_name = GetTrainableParamName(cnode);
    if (param_name.empty()) {
      continue;
    }
    if (param_aggrs_.count(param_name) != 0) {
      MS_LOG(WARNING) << param_name << " already has parameter aggregator registered.";
      continue;
    }
    std::shared_ptr<ParameterAggregator> param_aggr = std::make_shared<ParameterAggregator>();
    MS_EXCEPTION_IF_NULL(param_aggr);
    param_names_.push_back(param_name);
    param_aggrs_[param_name] = param_aggr;
    // Deliberate bare operator[]: default-constructs the per-parameter mutex
    // now, so later lookups in the handler methods find it already present.
    parameter_mutex_[param_name];
    if (!param_aggr->Init(cnode, aggregation_count_)) {
      MS_LOG(EXCEPTION) << "Initializing parameter aggregator for param_name " << param_name << " failed.";
      return false;  // Unreachable when MS_LOG(EXCEPTION) throws; kept defensively.
    }
    MS_LOG(DEBUG) << "Initializing parameter aggregator for param_name " << param_name << " success.";
  }
  return true;
}
  234. } // namespace server
  235. } // namespace fl
  236. } // namespace mindspore