You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

executor.cc 13 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "fl/server/executor.h"
#include <algorithm>
#include <chrono>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <vector>
  21. namespace mindspore {
  22. namespace fl {
  23. namespace server {
  24. void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) {
  25. MS_EXCEPTION_IF_NULL(func_graph);
  26. if (aggregation_count == 0) {
  27. MS_LOG(EXCEPTION) << "Server aggregation count must be greater than 0";
  28. return;
  29. }
  30. aggregation_count_ = aggregation_count;
  31. // Initialize each trainable parameter's aggregator, including memory register, aggregation algorithms and optimizers.
  32. bool ret = InitParamAggregator(func_graph);
  33. if (!ret) {
  34. MS_LOG(EXCEPTION) << "Initializing parameter aggregators failed.";
  35. return;
  36. }
  37. initialized_ = true;
  38. return;
  39. }
  40. bool Executor::ReInitForScaling() {
  41. auto result = std::find_if(param_aggrs_.begin(), param_aggrs_.end(),
  42. [](auto param_aggr) { return !param_aggr.second->ReInitForScaling(); });
  43. if (result != param_aggrs_.end()) {
  44. MS_LOG(ERROR) << "Reinitializing aggregator of " << result->first << " for scaling failed.";
  45. return false;
  46. }
  47. return true;
  48. }
  49. bool Executor::ReInitForUpdatingHyperParams(size_t aggr_threshold) {
  50. aggregation_count_ = aggr_threshold;
  51. auto result = std::find_if(param_aggrs_.begin(), param_aggrs_.end(), [this](auto param_aggr) {
  52. return !param_aggr.second->ReInitForUpdatingHyperParams(aggregation_count_);
  53. });
  54. if (result != param_aggrs_.end()) {
  55. MS_LOG(ERROR) << "Reinitializing aggregator of " << result->first << " for scaling failed.";
  56. return false;
  57. }
  58. return true;
  59. }
// Whether Initialize() has completed successfully.
bool Executor::initialized() const { return initialized_; }
// Handles one Push request from a worker for a single trainable parameter:
// updates the aggregator with the uploaded data, launches aggregation, and —
// once the aggregator reports aggregation done — runs the optimizers and
// resets the per-round pulling/aggregation status.
// Returns false if the parameter is unknown or any step fails.
bool Executor::HandlePush(const std::string &param_name, const UploadData &upload_data) {
  MS_LOG(DEBUG) << "Do Push for parameter " << param_name;
  if (param_aggrs_.count(param_name) == 0) {
    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
    return false;
  }
  std::mutex &mtx = parameter_mutex_[param_name];
  std::unique_lock<std::mutex> lock(mtx);
  auto &param_aggr = param_aggrs_[param_name];
  MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  // Push operation needs to wait until the pulling process is done.
  // The mutex is released while sleeping so pulling threads can progress.
  while (!param_aggr->IsPullingDone()) {
    lock.unlock();
    std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
    lock.lock();
  }
  // 1.Update data with the uploaded data of the worker.
  if (!param_aggr->UpdateData(upload_data)) {
    MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
    return false;
  }
  // 2.Launch aggregation for this trainable parameter.
  if (!param_aggr->LaunchAggregators()) {
    MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
    return false;
  }
  if (param_aggr->IsAggregationDone()) {
    // 3.After the aggregation is done, optimize the trainable parameter.
    if (!param_aggr->LaunchOptimizers()) {
      MS_LOG(ERROR) << "Optimizing for parameter " << param_name << " failed.";
      return false;
    }
    // 4.Reset pulling and aggregation status after optimizing is done.
    param_aggr->ResetPullingStatus();
    param_aggr->ResetAggregationStatus();
  }
  return true;
}
  99. bool Executor::HandleModelUpdate(const std::string &param_name, const UploadData &upload_data) {
  100. MS_LOG(DEBUG) << "Do UpdateModel for parameter " << param_name;
  101. if (param_aggrs_.count(param_name) == 0) {
  102. // The param_name could include some other parameters like momentum, but we don't think it's invalid. So here we
  103. // just print a warning log and return true.
  104. MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
  105. return true;
  106. }
  107. std::mutex &mtx = parameter_mutex_[param_name];
  108. std::unique_lock<std::mutex> lock(mtx);
  109. auto &param_aggr = param_aggrs_[param_name];
  110. MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  111. if (!param_aggr->UpdateData(upload_data)) {
  112. MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
  113. return false;
  114. }
  115. // Different from Push, UpdateModel doesn't need to checkout the aggregation status.
  116. if (!param_aggr->LaunchAggregators()) {
  117. MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
  118. return false;
  119. }
  120. return true;
  121. }
  122. bool Executor::HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map) {
  123. std::unique_lock<std::mutex> model_lock(model_mutex_);
  124. for (const auto &trainable_param : feature_map) {
  125. const std::string &param_name = trainable_param.first;
  126. if (param_aggrs_.count(param_name) == 0) {
  127. MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
  128. continue;
  129. }
  130. std::mutex &mtx = parameter_mutex_[param_name];
  131. std::unique_lock<std::mutex> lock(mtx);
  132. auto &param_aggr = param_aggrs_[param_name];
  133. MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  134. const UploadData &upload_data = trainable_param.second;
  135. if (!param_aggr->UpdateData(upload_data)) {
  136. MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
  137. return false;
  138. }
  139. if (!param_aggr->LaunchAggregators()) {
  140. MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
  141. return false;
  142. }
  143. }
  144. return true;
  145. }
  146. bool Executor::HandlePushWeight(const std::map<std::string, Address> &feature_map) {
  147. for (const auto &trainable_param : feature_map) {
  148. const std::string &param_name = trainable_param.first;
  149. if (param_aggrs_.count(param_name) == 0) {
  150. MS_LOG(WARNING) << "Weight " << param_name << " is not registered in server.";
  151. continue;
  152. }
  153. std::mutex &mtx = parameter_mutex_[param_name];
  154. std::unique_lock<std::mutex> lock(mtx);
  155. auto &param_aggr = param_aggrs_[param_name];
  156. MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  157. AddressPtr old_weight = param_aggr->GetWeight();
  158. const Address &new_weight = trainable_param.second;
  159. MS_ERROR_IF_NULL_W_RET_VAL(old_weight, false);
  160. MS_ERROR_IF_NULL_W_RET_VAL(old_weight->addr, false);
  161. MS_ERROR_IF_NULL_W_RET_VAL(new_weight.addr, false);
  162. int ret = memcpy_s(old_weight->addr, old_weight->size, new_weight.addr, new_weight.size);
  163. if (ret != 0) {
  164. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  165. return false;
  166. }
  167. }
  168. return true;
  169. }
// Blocking Pull handler: waits until optimizing for param_name is done, then
// returns the address produced by the aggregator's Pull(). Returns nullptr
// for an unregistered parameter or a null aggregator.
AddressPtr Executor::HandlePull(const std::string &param_name) {
  MS_LOG(INFO) << "Handle blocking pull message for parameter " << param_name;
  if (param_aggrs_.count(param_name) == 0) {
    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
    return nullptr;
  }
  std::mutex &mtx = parameter_mutex_[param_name];
  std::unique_lock<std::mutex> lock(mtx);
  auto &param_aggr = param_aggrs_[param_name];
  MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, nullptr);
  // Pulling must wait until the optimizing process is done.
  // The mutex is released while sleeping so the optimizing side can progress.
  while (!param_aggr->IsOptimizingDone()) {
    lock.unlock();
    std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
    lock.lock();
  }
  AddressPtr addr = param_aggr->Pull();
  // If this Pull is the last one, reset pulling and optimizing status.
  if (param_aggr->IsPullingDone()) {
    param_aggr->ResetOptimizingStatus();
  }
  return addr;
}
  193. std::map<std::string, AddressPtr> Executor::HandlePullWeight(const std::vector<std::string> &param_names) {
  194. std::map<std::string, AddressPtr> weights;
  195. for (const auto &param_name : param_names) {
  196. if (param_aggrs_.count(param_name) == 0) {
  197. MS_LOG(ERROR) << "Parameter " << param_name << " is not registered in server.";
  198. return weights;
  199. }
  200. std::mutex &mtx = parameter_mutex_[param_name];
  201. std::unique_lock<std::mutex> lock(mtx);
  202. const auto &param_aggr = param_aggrs_[param_name];
  203. MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, weights);
  204. AddressPtr addr = param_aggr->GetWeight();
  205. if (addr == nullptr) {
  206. MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
  207. continue;
  208. }
  209. weights[param_name] = addr;
  210. }
  211. return weights;
  212. }
// Convenience wrapper: checks aggregation completion over every registered parameter.
bool Executor::IsAllWeightAggregationDone() { return IsWeightAggrDone(param_names_); }
  214. bool Executor::IsWeightAggrDone(const std::vector<std::string> &param_names) {
  215. for (const auto &name : param_names) {
  216. if (param_aggrs_.count(name) == 0) {
  217. MS_LOG(ERROR) << "Weight " << name << " is invalid in server.";
  218. return false;
  219. }
  220. std::mutex &mtx = parameter_mutex_[name];
  221. std::unique_lock<std::mutex> lock(mtx);
  222. auto &param_aggr = param_aggrs_[name];
  223. MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
  224. if (!param_aggr->requires_aggr()) {
  225. continue;
  226. }
  227. if (!param_aggr->IsAggregationDone()) {
  228. MS_LOG(DEBUG) << "Update model for " << name << " is not done yet.";
  229. return false;
  230. }
  231. }
  232. return true;
  233. }
  234. void Executor::ResetAggregationStatus() {
  235. for (const auto &param_name : param_names_) {
  236. std::mutex &mtx = parameter_mutex_[param_name];
  237. std::unique_lock<std::mutex> lock(mtx);
  238. auto &param_aggr = param_aggrs_[param_name];
  239. MS_ERROR_IF_NULL_WO_RET_VAL(param_aggr);
  240. param_aggr->ResetAggregationStatus();
  241. }
  242. return;
  243. }
  244. std::map<std::string, AddressPtr> Executor::GetModel() {
  245. std::map<std::string, AddressPtr> model = {};
  246. for (const auto &name : param_names_) {
  247. std::mutex &mtx = parameter_mutex_[name];
  248. std::unique_lock<std::mutex> lock(mtx);
  249. AddressPtr addr = param_aggrs_[name]->GetWeight();
  250. if (addr == nullptr) {
  251. MS_LOG(WARNING) << "Get weight of " << name << " failed.";
  252. continue;
  253. }
  254. model[name] = addr;
  255. }
  256. return model;
  257. }
// Read-only view of all registered trainable parameter names, in registration order.
const std::vector<std::string> &Executor::param_names() const { return param_names_; }
// Removes the cipher masks from the aggregated model via cipher_unmask_.
// Only compiled in when ENABLE_ARMOUR is defined; otherwise reports failure.
bool Executor::Unmask() {
#ifdef ENABLE_ARMOUR
  auto model = GetModel();
  return cipher_unmask_.UnMask(model);
#else
  return false;
#endif
}
// Records whether the model has been unmasked (read back by unmasked()).
void Executor::set_unmasked(bool unmasked) { unmasked_ = unmasked; }
  268. bool Executor::unmasked() const {
  269. std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
  270. if (encrypt_type == ps::kPWEncryptType) {
  271. return unmasked_.load();
  272. } else {
  273. // If the algorithm of pairwise encrypt is not enabled, consider_ unmasked flag as true.
  274. return true;
  275. }
  276. }
// Maps a CNode to the fullname of the trainable Parameter it updates.
// Returns an empty string when the node kind is not listed in kNameToIdxMap;
// raises (via MS_LOG(EXCEPTION)) when the expected weight input is not a
// Parameter node.
std::string Executor::GetTrainableParamName(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
  // Only node kinds present in kNameToIdxMap carry a known weight-input index.
  if (kNameToIdxMap.count(cnode_name) == 0) {
    return "";
  }
  const OptimParamNameToIndex &index_info = kNameToIdxMap.at(cnode_name);
  size_t weight_idx = index_info.at("inputs").at(kWeight);
  AnfNodePtr weight_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, weight_idx), 0).first;
  MS_EXCEPTION_IF_NULL(weight_node);
  if (!weight_node->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << weight_idx << " input of " << cnode_name << " is not a Parameter.";
    return "";
  }
  return weight_node->fullname_with_scope();
}
// Walks the graph's ordered CNodes and registers one ParameterAggregator per
// trainable parameter found: records its name, aggregator and mutex entry.
// Raises (via MS_LOG(EXCEPTION)) if any aggregator fails to initialize.
bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  const auto &cnodes = func_graph->GetOrderedCnodes();
  for (const auto &cnode : cnodes) {
    MS_EXCEPTION_IF_NULL(cnode);
    const std::string &param_name = GetTrainableParamName(cnode);
    if (param_name.empty()) {
      // GetTrainableParamName returns "" for node kinds it does not recognize.
      continue;
    }
    if (param_aggrs_.count(param_name) != 0) {
      MS_LOG(WARNING) << param_name << " already has parameter aggregator registered.";
      continue;
    }
    std::shared_ptr<ParameterAggregator> param_aggr = std::make_shared<ParameterAggregator>();
    MS_EXCEPTION_IF_NULL(param_aggr);
    param_names_.push_back(param_name);
    param_aggrs_[param_name] = param_aggr;
    // Eagerly create the per-parameter mutex entry (map operator[]
    // default-constructs the std::mutex in place).
    parameter_mutex_[param_name];
    if (!param_aggr->Init(cnode, aggregation_count_)) {
      MS_LOG(EXCEPTION) << "Initializing parameter aggregator for " << param_name << " failed.";
      return false;
    }
    MS_LOG(DEBUG) << "Initializing parameter aggregator for param_name " << param_name << " success.";
  }
  return true;
}
  319. } // namespace server
  320. } // namespace fl
  321. } // namespace mindspore