You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

executor.cc 9.9 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "fl/server/executor.h"

#include <algorithm>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <vector>
  21. namespace mindspore {
  22. namespace fl {
  23. namespace server {
  24. void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) {
  25. MS_EXCEPTION_IF_NULL(func_graph);
  26. MS_LOG(INFO) << "Start Initialize Executor.";
  27. if (aggregation_count == 0) {
  28. MS_LOG(EXCEPTION) << "Server aggregation count must be greater than 0";
  29. return;
  30. }
  31. aggregation_count_ = aggregation_count;
  32. // Initialize each trainable parameter's aggregator, including memory register, aggregation algorithms and
  33. // optimizers.
  34. bool ret = InitParamAggregator(func_graph);
  35. if (!ret) {
  36. MS_LOG(EXCEPTION) << "Initializing parameter aggregators failed.";
  37. return;
  38. }
  39. initialized_ = true;
  40. return;
  41. }
  42. bool Executor::ReInitForScaling() {
  43. auto result = std::find_if(param_aggrs_.begin(), param_aggrs_.end(),
  44. [](auto param_aggr) { return !param_aggr.second->ReInitForScaling(); });
  45. if (result != param_aggrs_.end()) {
  46. MS_LOG(ERROR) << "Reinitializing aggregator of " << result->first << " for scaling failed.";
  47. return false;
  48. }
  49. return true;
  50. }
  51. bool Executor::ReInitForUpdatingHyperParams(size_t aggr_threshold) {
  52. aggregation_count_ = aggr_threshold;
  53. auto result = std::find_if(param_aggrs_.begin(), param_aggrs_.end(), [this](auto param_aggr) {
  54. return !param_aggr.second->ReInitForUpdatingHyperParams(aggregation_count_);
  55. });
  56. if (result != param_aggrs_.end()) {
  57. MS_LOG(ERROR) << "Reinitializing aggregator of " << result->first << " for scaling failed.";
  58. return false;
  59. }
  60. return true;
  61. }
// True once Initialize() has completed successfully.
bool Executor::initialized() const { return initialized_; }
  63. bool Executor::HandleModelUpdate(const std::string &param_name, const UploadData &upload_data) {
  64. MS_LOG(DEBUG) << "Do UpdateModel for parameter " << param_name;
  65. if (param_aggrs_.count(param_name) == 0) {
  66. // The param_name could include some other parameters like momentum, but we don't think it's invalid. So here we
  67. // just print a warning log and return true.
  68. MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
  69. return true;
  70. }
  71. std::mutex &mtx = parameter_mutex_[param_name];
  72. std::unique_lock<std::mutex> lock(mtx);
  73. auto &param_aggr = param_aggrs_[param_name];
  74. MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  75. if (!param_aggr->UpdateData(upload_data)) {
  76. MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
  77. return false;
  78. }
  79. // Different from Push, UpdateModel doesn't need to checkout the aggregation status.
  80. if (!param_aggr->LaunchAggregators()) {
  81. MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
  82. return false;
  83. }
  84. return true;
  85. }
  86. bool Executor::HandlePushWeight(const std::map<std::string, Address> &feature_map) {
  87. for (const auto &trainable_param : feature_map) {
  88. const std::string &param_name = trainable_param.first;
  89. if (param_aggrs_.count(param_name) == 0) {
  90. MS_LOG(WARNING) << "Weight " << param_name << " is not registered in server.";
  91. continue;
  92. }
  93. std::mutex &mtx = parameter_mutex_[param_name];
  94. std::unique_lock<std::mutex> lock(mtx);
  95. auto &param_aggr = param_aggrs_[param_name];
  96. MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  97. AddressPtr old_weight = param_aggr->GetWeight();
  98. const Address &new_weight = trainable_param.second;
  99. MS_ERROR_IF_NULL_W_RET_VAL(old_weight, false);
  100. MS_ERROR_IF_NULL_W_RET_VAL(old_weight->addr, false);
  101. MS_ERROR_IF_NULL_W_RET_VAL(new_weight.addr, false);
  102. int ret = memcpy_s(old_weight->addr, old_weight->size, new_weight.addr, new_weight.size);
  103. if (ret != 0) {
  104. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  105. return false;
  106. }
  107. }
  108. return true;
  109. }
  110. std::map<std::string, AddressPtr> Executor::HandlePullWeight(const std::vector<std::string> &param_names) {
  111. std::map<std::string, AddressPtr> weights;
  112. for (const auto &param_name : param_names) {
  113. if (param_aggrs_.count(param_name) == 0) {
  114. MS_LOG(ERROR) << "Parameter " << param_name << " is not registered in server.";
  115. return weights;
  116. }
  117. std::mutex &mtx = parameter_mutex_[param_name];
  118. std::unique_lock<std::mutex> lock(mtx);
  119. const auto &param_aggr = param_aggrs_[param_name];
  120. MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, weights);
  121. AddressPtr addr = param_aggr->GetWeight();
  122. if (addr == nullptr) {
  123. MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
  124. continue;
  125. }
  126. weights[param_name] = addr;
  127. }
  128. return weights;
  129. }
// Convenience wrapper: true when aggregation is done for every registered parameter.
bool Executor::IsAllWeightAggregationDone() { return IsWeightAggrDone(param_names_); }
  131. bool Executor::RunAllWeightAggregation() {
  132. for (const auto &name : param_names_) {
  133. if (param_aggrs_.count(name) == 0) {
  134. MS_LOG(ERROR) << "Weight " << name << " is invalid in server.";
  135. return false;
  136. }
  137. std::mutex &mtx = parameter_mutex_[name];
  138. std::unique_lock<std::mutex> lock(mtx);
  139. auto &param_aggr = param_aggrs_[name];
  140. MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  141. if (!param_aggr->requires_aggr()) {
  142. continue;
  143. }
  144. if (!param_aggr->RunAggregation()) {
  145. MS_LOG(WARNING) << "Failed to run aggregation for " << name;
  146. return false;
  147. }
  148. }
  149. return true;
  150. }
  151. bool Executor::IsWeightAggrDone(const std::vector<std::string> &param_names) {
  152. for (const auto &name : param_names) {
  153. if (param_aggrs_.count(name) == 0) {
  154. MS_LOG(ERROR) << "Weight " << name << " is invalid in server.";
  155. return false;
  156. }
  157. std::mutex &mtx = parameter_mutex_[name];
  158. std::unique_lock<std::mutex> lock(mtx);
  159. auto &param_aggr = param_aggrs_[name];
  160. MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  161. if (!param_aggr->requires_aggr()) {
  162. continue;
  163. }
  164. if (!param_aggr->IsAggregationDone()) {
  165. MS_LOG(DEBUG) << "Update model for " << name << " is not done yet.";
  166. return false;
  167. }
  168. }
  169. return true;
  170. }
  171. void Executor::ResetAggregationStatus() {
  172. for (const auto &param_name : param_names_) {
  173. std::mutex &mtx = parameter_mutex_[param_name];
  174. std::unique_lock<std::mutex> lock(mtx);
  175. auto &param_aggr = param_aggrs_[param_name];
  176. MS_ERROR_IF_NULL_WO_RET_VAL(param_aggr);
  177. param_aggr->ResetAggregationStatus();
  178. }
  179. return;
  180. }
  181. std::map<std::string, AddressPtr> Executor::GetModel() {
  182. std::map<std::string, AddressPtr> model = {};
  183. for (const auto &name : param_names_) {
  184. std::mutex &mtx = parameter_mutex_[name];
  185. std::unique_lock<std::mutex> lock(mtx);
  186. AddressPtr addr = param_aggrs_[name]->GetWeight();
  187. if (addr == nullptr) {
  188. MS_LOG(WARNING) << "Get weight of " << name << " failed.";
  189. continue;
  190. }
  191. model[name] = addr;
  192. }
  193. return model;
  194. }
// Names of all registered trainable parameters, in registration order.
const std::vector<std::string> &Executor::param_names() const { return param_names_; }
// Unmasks the aggregated model via the cipher when MindArmour support is
// compiled in (ENABLE_ARMOUR); otherwise reports failure.
bool Executor::Unmask() {
#ifdef ENABLE_ARMOUR
  auto model = GetModel();
  return cipher_unmask_.UnMask(model);
#else
  // Without ENABLE_ARMOUR there is no cipher to unmask with.
  return false;
#endif
}
// Records whether the model has been unmasked (atomic store; see unmasked()).
void Executor::set_unmasked(bool unmasked) { unmasked_ = unmasked; }
  205. bool Executor::unmasked() const {
  206. std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
  207. if (encrypt_type == ps::kPWEncryptType) {
  208. return unmasked_.load();
  209. } else {
  210. // If the algorithm of mind armour is not enabled, consider unmasked_ flag as true.
  211. return true;
  212. }
  213. }
// Maps an optimizer CNode to the fullname of the trainable Parameter it
// updates, or "" when the node is not a recognized optimizer kernel.
// Throws (MS_LOG(EXCEPTION)) when the expected weight input is not a Parameter.
std::string Executor::GetTrainableParamName(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::string cnode_name = common::AnfAlgo::GetCNodeName(cnode);
  // Only optimizer kernels listed in kNameToIdxMap carry a trainable weight.
  if (kNameToIdxMap.count(cnode_name) == 0) {
    return "";
  }
  const OptimParamNameToIndex &index_info = kNameToIdxMap.at(cnode_name);
  // Position of the weight among this optimizer's inputs.
  size_t weight_idx = index_info.at("inputs").at(kWeight);
  // Resolve through possible wrapper nodes to the real kernel/input node.
  AnfNodePtr weight_node =
    common::AnfAlgo::VisitKernelWithReturnType(common::AnfAlgo::GetInputNode(cnode, weight_idx), 0).first;
  MS_EXCEPTION_IF_NULL(weight_node);
  if (!weight_node->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << weight_idx << " input of " << cnode_name << " is not a Parameter.";
    // Unreachable: MS_LOG(EXCEPTION) throws; kept so all paths return a value.
    return "";
  }
  return weight_node->fullname_with_scope();
}
// Walks the graph's CNodes in order and, for every optimizer node that updates
// a trainable parameter, registers a ParameterAggregator (plus its mutex)
// keyed by the parameter's fullname. Throws on a failed aggregator Init.
bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  const auto &cnodes = func_graph->GetOrderedCnodes();
  for (const auto &cnode : cnodes) {
    MS_EXCEPTION_IF_NULL(cnode);
    const std::string &param_name = GetTrainableParamName(cnode);
    if (param_name.empty()) {
      // Not an optimizer node with a trainable weight; skip.
      continue;
    }
    if (param_aggrs_.count(param_name) != 0) {
      MS_LOG(WARNING) << param_name << " already has parameter aggregator registered.";
      continue;
    }
    std::shared_ptr<ParameterAggregator> param_aggr = std::make_shared<ParameterAggregator>();
    MS_EXCEPTION_IF_NULL(param_aggr);
    param_names_.push_back(param_name);
    param_aggrs_[param_name] = param_aggr;
    // Deliberate: operator[] default-constructs the per-parameter mutex here so
    // later lookups find it already present.
    parameter_mutex_[param_name];
    if (!param_aggr->Init(cnode, aggregation_count_)) {
      MS_LOG(EXCEPTION) << "Initializing parameter aggregator for param_name " << param_name << " failed.";
      // Unreachable: MS_LOG(EXCEPTION) throws; kept so all paths return a value.
      return false;
    }
    MS_LOG(INFO) << "Initializing parameter aggregator for param_name " << param_name << " success.";
  }
  return true;
}
  257. } // namespace server
  258. } // namespace fl
  259. } // namespace mindspore