You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

executor.cc 9.9 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "fl/server/executor.h"
#include <algorithm>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <vector>
  21. namespace mindspore {
  22. namespace fl {
  23. namespace server {
  24. void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) {
  25. MS_EXCEPTION_IF_NULL(func_graph);
  26. if (aggregation_count == 0) {
  27. MS_LOG(EXCEPTION) << "Server aggregation count must be greater than 0";
  28. return;
  29. }
  30. aggregation_count_ = aggregation_count;
  31. // Initialize each trainable parameter's aggregator, including memory register, aggregation algorithms and optimizers.
  32. bool ret = InitParamAggregator(func_graph);
  33. if (!ret) {
  34. MS_LOG(EXCEPTION) << "Initializing parameter aggregators failed.";
  35. return;
  36. }
  37. initialized_ = true;
  38. return;
  39. }
  40. bool Executor::ReInitForScaling() {
  41. auto result = std::find_if(param_aggrs_.begin(), param_aggrs_.end(),
  42. [](auto param_aggr) { return !param_aggr.second->ReInitForScaling(); });
  43. if (result != param_aggrs_.end()) {
  44. MS_LOG(ERROR) << "Reinitializing aggregator of " << result->first << " for scaling failed.";
  45. return false;
  46. }
  47. return true;
  48. }
  49. bool Executor::ReInitForUpdatingHyperParams(size_t aggr_threshold) {
  50. aggregation_count_ = aggr_threshold;
  51. auto result = std::find_if(param_aggrs_.begin(), param_aggrs_.end(), [this](auto param_aggr) {
  52. return !param_aggr.second->ReInitForUpdatingHyperParams(aggregation_count_);
  53. });
  54. if (result != param_aggrs_.end()) {
  55. MS_LOG(ERROR) << "Reinitializing aggregator of " << result->first << " for scaling failed.";
  56. return false;
  57. }
  58. return true;
  59. }
  60. bool Executor::initialized() const { return initialized_; }
  61. bool Executor::HandleModelUpdate(const std::string &param_name, const UploadData &upload_data) {
  62. MS_LOG(DEBUG) << "Do UpdateModel for parameter " << param_name;
  63. if (param_aggrs_.count(param_name) == 0) {
  64. // The param_name could include some other parameters like momentum, but we don't think it's invalid. So here we
  65. // just print a warning log and return true.
  66. MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
  67. return true;
  68. }
  69. std::mutex &mtx = parameter_mutex_[param_name];
  70. std::unique_lock<std::mutex> lock(mtx);
  71. auto &param_aggr = param_aggrs_[param_name];
  72. MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  73. if (!param_aggr->UpdateData(upload_data)) {
  74. MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
  75. return false;
  76. }
  77. // Different from Push, UpdateModel doesn't need to checkout the aggregation status.
  78. if (!param_aggr->LaunchAggregators()) {
  79. MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
  80. return false;
  81. }
  82. return true;
  83. }
  84. bool Executor::HandlePushWeight(const std::map<std::string, Address> &feature_map) {
  85. for (const auto &trainable_param : feature_map) {
  86. const std::string &param_name = trainable_param.first;
  87. if (param_aggrs_.count(param_name) == 0) {
  88. MS_LOG(WARNING) << "Weight " << param_name << " is not registered in server.";
  89. continue;
  90. }
  91. std::mutex &mtx = parameter_mutex_[param_name];
  92. std::unique_lock<std::mutex> lock(mtx);
  93. auto &param_aggr = param_aggrs_[param_name];
  94. MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  95. AddressPtr old_weight = param_aggr->GetWeight();
  96. const Address &new_weight = trainable_param.second;
  97. MS_ERROR_IF_NULL_W_RET_VAL(old_weight, false);
  98. MS_ERROR_IF_NULL_W_RET_VAL(old_weight->addr, false);
  99. MS_ERROR_IF_NULL_W_RET_VAL(new_weight.addr, false);
  100. int ret = memcpy_s(old_weight->addr, old_weight->size, new_weight.addr, new_weight.size);
  101. if (ret != 0) {
  102. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  103. return false;
  104. }
  105. }
  106. return true;
  107. }
  108. std::map<std::string, AddressPtr> Executor::HandlePullWeight(const std::vector<std::string> &param_names) {
  109. std::map<std::string, AddressPtr> weights;
  110. for (const auto &param_name : param_names) {
  111. if (param_aggrs_.count(param_name) == 0) {
  112. MS_LOG(ERROR) << "Parameter " << param_name << " is not registered in server.";
  113. return weights;
  114. }
  115. std::mutex &mtx = parameter_mutex_[param_name];
  116. std::unique_lock<std::mutex> lock(mtx);
  117. const auto &param_aggr = param_aggrs_[param_name];
  118. MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, weights);
  119. AddressPtr addr = param_aggr->GetWeight();
  120. if (addr == nullptr) {
  121. MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
  122. continue;
  123. }
  124. weights[param_name] = addr;
  125. }
  126. return weights;
  127. }
  128. bool Executor::IsAllWeightAggregationDone() { return IsWeightAggrDone(param_names_); }
  129. bool Executor::RunAllWeightAggregation() {
  130. for (const auto &name : param_names_) {
  131. if (param_aggrs_.count(name) == 0) {
  132. MS_LOG(ERROR) << "Weight " << name << " is invalid in server.";
  133. return false;
  134. }
  135. std::mutex &mtx = parameter_mutex_[name];
  136. std::unique_lock<std::mutex> lock(mtx);
  137. auto &param_aggr = param_aggrs_[name];
  138. MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  139. if (!param_aggr->requires_aggr()) {
  140. continue;
  141. }
  142. if (!param_aggr->RunAggregation()) {
  143. MS_LOG(WARNING) << "Failed to run aggregation for " << name;
  144. return false;
  145. }
  146. }
  147. return true;
  148. }
  149. bool Executor::IsWeightAggrDone(const std::vector<std::string> &param_names) {
  150. for (const auto &name : param_names) {
  151. if (param_aggrs_.count(name) == 0) {
  152. MS_LOG(ERROR) << "Weight " << name << " is invalid in server.";
  153. return false;
  154. }
  155. std::mutex &mtx = parameter_mutex_[name];
  156. std::unique_lock<std::mutex> lock(mtx);
  157. auto &param_aggr = param_aggrs_[name];
  158. MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  159. if (!param_aggr->requires_aggr()) {
  160. continue;
  161. }
  162. if (!param_aggr->IsAggregationDone()) {
  163. MS_LOG(DEBUG) << "Update model for " << name << " is not done yet.";
  164. return false;
  165. }
  166. }
  167. return true;
  168. }
  169. void Executor::ResetAggregationStatus() {
  170. for (const auto &param_name : param_names_) {
  171. std::mutex &mtx = parameter_mutex_[param_name];
  172. std::unique_lock<std::mutex> lock(mtx);
  173. auto &param_aggr = param_aggrs_[param_name];
  174. MS_ERROR_IF_NULL_WO_RET_VAL(param_aggr);
  175. param_aggr->ResetAggregationStatus();
  176. }
  177. return;
  178. }
  179. std::map<std::string, AddressPtr> Executor::GetModel() {
  180. std::map<std::string, AddressPtr> model = {};
  181. for (const auto &name : param_names_) {
  182. std::mutex &mtx = parameter_mutex_[name];
  183. std::unique_lock<std::mutex> lock(mtx);
  184. AddressPtr addr = param_aggrs_[name]->GetWeight();
  185. if (addr == nullptr) {
  186. MS_LOG(WARNING) << "Get weight of " << name << " failed.";
  187. continue;
  188. }
  189. model[name] = addr;
  190. }
  191. return model;
  192. }
  193. const std::vector<std::string> &Executor::param_names() const { return param_names_; }
// Removes the secure-aggregation masks from the current model via the cipher
// unmask helper. Only available when the build enables MindArmour
// (ENABLE_ARMOUR); otherwise unmasking is unsupported and false is returned.
bool Executor::Unmask() {
#ifdef ENABLE_ARMOUR
  auto model = GetModel();
  return cipher_unmask_.UnMask(model);
#else
  return false;
#endif
}
  202. void Executor::set_unmasked(bool unmasked) { unmasked_ = unmasked; }
  203. bool Executor::unmasked() const {
  204. std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
  205. if (encrypt_type == ps::kPWEncryptType) {
  206. return unmasked_.load();
  207. } else {
  208. // If the algorithm of mind armour is not enabled, consider unmasked_ flag as true.
  209. return true;
  210. }
  211. }
  212. std::string Executor::GetTrainableParamName(const CNodePtr &cnode) {
  213. MS_EXCEPTION_IF_NULL(cnode);
  214. std::string cnode_name = common::AnfAlgo::GetCNodeName(cnode);
  215. if (kNameToIdxMap.count(cnode_name) == 0) {
  216. return "";
  217. }
  218. const OptimParamNameToIndex &index_info = kNameToIdxMap.at(cnode_name);
  219. size_t weight_idx = index_info.at("inputs").at(kWeight);
  220. AnfNodePtr weight_node =
  221. common::AnfAlgo::VisitKernelWithReturnType(common::AnfAlgo::GetInputNode(cnode, weight_idx), 0).first;
  222. MS_EXCEPTION_IF_NULL(weight_node);
  223. if (!weight_node->isa<Parameter>()) {
  224. MS_LOG(EXCEPTION) << weight_idx << " input of " << cnode_name << " is not a Parameter.";
  225. return "";
  226. }
  227. return weight_node->fullname_with_scope();
  228. }
  229. bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
  230. MS_EXCEPTION_IF_NULL(func_graph);
  231. const auto &cnodes = func_graph->GetOrderedCnodes();
  232. for (const auto &cnode : cnodes) {
  233. MS_EXCEPTION_IF_NULL(cnode);
  234. const std::string &param_name = GetTrainableParamName(cnode);
  235. if (param_name.empty()) {
  236. continue;
  237. }
  238. if (param_aggrs_.count(param_name) != 0) {
  239. MS_LOG(WARNING) << param_name << " already has parameter aggregator registered.";
  240. continue;
  241. }
  242. std::shared_ptr<ParameterAggregator> param_aggr = std::make_shared<ParameterAggregator>();
  243. MS_EXCEPTION_IF_NULL(param_aggr);
  244. param_names_.push_back(param_name);
  245. param_aggrs_[param_name] = param_aggr;
  246. parameter_mutex_[param_name];
  247. if (!param_aggr->Init(cnode, aggregation_count_)) {
  248. MS_LOG(EXCEPTION) << "Initializing parameter aggregator for param_name " << param_name << " failed.";
  249. return false;
  250. }
  251. MS_LOG(DEBUG) << "Initializing parameter aggregator for param_name " << param_name << " success.";
  252. }
  253. return true;
  254. }
  255. } // namespace server
  256. } // namespace fl
  257. } // namespace mindspore