You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

executor.cc 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "ps/server/executor.h"
#include <chrono>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <vector>
  21. namespace mindspore {
  22. namespace ps {
  23. namespace server {
  24. void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) {
  25. MS_EXCEPTION_IF_NULL(func_graph);
  26. if (aggregation_count == 0) {
  27. MS_LOG(EXCEPTION) << "Server aggregation count must be greater than 0";
  28. return;
  29. }
  30. aggregation_count_ = aggregation_count;
  31. // Initialize each trainable parameter's aggregator, including memory register, aggregation algorithms and optimizers.
  32. bool ret = InitParamAggregator(func_graph);
  33. if (!ret) {
  34. MS_LOG(EXCEPTION) << "Initializing parameter aggregators failed.";
  35. return;
  36. }
  37. initialized_ = true;
  38. return;
  39. }
  40. bool Executor::initialized() const { return initialized_; }
// Handles one worker's Push of upload_data for param_name:
// update -> aggregate -> (once aggregation completes) optimize -> reset statuses.
// Returns false if the parameter is unknown or any stage fails.
bool Executor::HandlePush(const std::string &param_name, const UploadData &upload_data) {
  MS_LOG(DEBUG) << "Do Push for parameter " << param_name;
  if (param_aggrs_.count(param_name) == 0) {
    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
    return false;
  }
  // Per-parameter lock: serializes Push/Pull/overwrite on this parameter only.
  std::mutex &mtx = parameter_mutex_[param_name];
  std::unique_lock<std::mutex> lock(mtx);
  auto &param_aggr = param_aggrs_[param_name];
  // Push operation needs to wait until the pulling process is done.
  // The lock is released while sleeping so concurrent HandlePull calls can
  // make progress and eventually flip IsPullingDone().
  while (!param_aggr->IsPullingDone()) {
    lock.unlock();
    std::this_thread::sleep_for(std::chrono::milliseconds(5));
    lock.lock();
  }
  // 1.Update data with the uploaded data of the worker.
  if (!param_aggr->UpdateData(upload_data)) {
    MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
    return false;
  }
  // 2.Launch aggregation for this trainable parameter.
  if (!param_aggr->LaunchAggregators()) {
    MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
    return false;
  }
  if (param_aggr->IsAggregationDone()) {
    // 3.After the aggregation is done, optimize the trainable parameter.
    if (!param_aggr->LaunchOptimizers()) {
      MS_LOG(ERROR) << "Optimizing for parameter " << param_name << " failed.";
      return false;
    }
    // 4.Reset pulling and aggregation status after optimizing is done.
    param_aggr->ResetPullingStatus();
    param_aggr->ResetAggregationStatus();
  }
  return true;
}
  78. bool Executor::HandleModelUpdate(const std::string &param_name, const UploadData &upload_data) {
  79. MS_LOG(DEBUG) << "Do UpdateModel for parameter " << param_name;
  80. if (param_aggrs_.count(param_name) == 0) {
  81. // The param_name could include some other parameters like momentum, but we don't think it's invalid. So here we
  82. // just print a warning log and return true.
  83. MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
  84. return true;
  85. }
  86. std::mutex &mtx = parameter_mutex_[param_name];
  87. std::unique_lock<std::mutex> lock(mtx);
  88. auto &param_aggr = param_aggrs_[param_name];
  89. if (!param_aggr->UpdateData(upload_data)) {
  90. MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
  91. return false;
  92. }
  93. // Different from Push, UpdateModel doesn't need to checkout the aggregation status.
  94. if (!param_aggr->LaunchAggregators()) {
  95. MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
  96. return false;
  97. }
  98. return true;
  99. }
  100. bool Executor::HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map) {
  101. std::unique_lock<std::mutex> model_lock(model_mutex_);
  102. for (const auto &trainable_param : feature_map) {
  103. const std::string &param_name = trainable_param.first;
  104. if (param_aggrs_.count(param_name) == 0) {
  105. MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
  106. continue;
  107. }
  108. std::mutex &mtx = parameter_mutex_[param_name];
  109. std::unique_lock<std::mutex> lock(mtx);
  110. auto &param_aggr = param_aggrs_[param_name];
  111. const UploadData &upload_data = trainable_param.second;
  112. if (!param_aggr->UpdateData(upload_data)) {
  113. MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
  114. return false;
  115. }
  116. if (!param_aggr->LaunchAggregators()) {
  117. MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
  118. return false;
  119. }
  120. }
  121. return true;
  122. }
  123. bool Executor::HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map) {
  124. for (const auto &trainable_param : feature_map) {
  125. const std::string &param_name = trainable_param.first;
  126. if (param_aggrs_.count(param_name) == 0) {
  127. MS_LOG(WARNING) << "Weight " << param_name << " is not registered in server.";
  128. continue;
  129. }
  130. std::mutex &mtx = parameter_mutex_[param_name];
  131. std::unique_lock<std::mutex> lock(mtx);
  132. auto &param_aggr = param_aggrs_[param_name];
  133. AddressPtr old_weight = param_aggr->GetWeight();
  134. if (old_weight == nullptr) {
  135. MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
  136. return false;
  137. }
  138. const Address &new_weight = trainable_param.second;
  139. if (new_weight.addr == nullptr) {
  140. MS_LOG(ERROR) << "The new weight is nullptr.";
  141. return false;
  142. }
  143. int ret = memcpy_s(old_weight->addr, old_weight->size, new_weight.addr, new_weight.size);
  144. if (ret != 0) {
  145. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  146. return false;
  147. }
  148. }
  149. return true;
  150. }
// Returns the weight buffer for param_name for a blocking Pull, or nullptr if
// the parameter is unknown. Blocks until the optimizing process for the
// current round has finished.
AddressPtr Executor::HandlePull(const std::string &param_name) {
  MS_LOG(INFO) << "Handle blocking pull message for parameter " << param_name;
  if (param_aggrs_.count(param_name) == 0) {
    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
    return nullptr;
  }
  // Per-parameter lock: serializes Push/Pull/overwrite on this parameter only.
  std::mutex &mtx = parameter_mutex_[param_name];
  std::unique_lock<std::mutex> lock(mtx);
  auto &param_aggr = param_aggrs_[param_name];
  // Pulling must wait until the optimizing process is done.
  // The lock is released while sleeping so HandlePush can run the optimizer
  // and eventually flip IsOptimizingDone().
  while (!param_aggr->IsOptimizingDone()) {
    lock.unlock();
    std::this_thread::sleep_for(std::chrono::milliseconds(5));
    lock.lock();
  }
  AddressPtr addr = param_aggr->Pull();
  // If this Pull is the last one, reset pulling and optimizing status.
  if (param_aggr->IsPullingDone()) {
    param_aggr->ResetOptimizingStatus();
  }
  return addr;
}
  173. std::map<std::string, AddressPtr> Executor::HandleGetWeightsByKey(const std::vector<std::string> &param_names) {
  174. std::map<std::string, AddressPtr> weights;
  175. for (const auto &param_name : param_names) {
  176. if (param_aggrs_.count(param_name) == 0) {
  177. MS_LOG(ERROR) << "Parameter " << param_name << " is not registered in server.";
  178. return weights;
  179. }
  180. std::mutex &mtx = parameter_mutex_[param_name];
  181. std::unique_lock<std::mutex> lock(mtx);
  182. const auto &param_aggr = param_aggrs_[param_name];
  183. AddressPtr addr = param_aggr->GetWeight();
  184. if (addr == nullptr) {
  185. MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
  186. continue;
  187. }
  188. weights[param_name] = addr;
  189. }
  190. return weights;
  191. }
  192. bool Executor::IsAllWeightAggregationDone() { return IsWeightAggrDone(param_names_); }
  193. bool Executor::IsWeightAggrDone(const std::vector<std::string> &param_names) {
  194. for (const auto &name : param_names) {
  195. if (param_aggrs_.count(name) == 0) {
  196. MS_LOG(ERROR) << "Weight " << name << " is invalid in server.";
  197. return false;
  198. }
  199. std::mutex &mtx = parameter_mutex_[name];
  200. std::unique_lock<std::mutex> lock(mtx);
  201. if (!param_aggrs_[name]->IsAggregationDone()) {
  202. MS_LOG(DEBUG) << "Update model for " << name << " is not done yet.";
  203. return false;
  204. }
  205. }
  206. return true;
  207. }
  208. void Executor::ResetAggregationStatus() {
  209. for (const auto &param_name : param_names_) {
  210. std::mutex &mtx = parameter_mutex_[param_name];
  211. std::unique_lock<std::mutex> lock(mtx);
  212. param_aggrs_[param_name]->ResetAggregationStatus();
  213. }
  214. return;
  215. }
  216. std::map<std::string, AddressPtr> Executor::GetModel() {
  217. std::map<std::string, AddressPtr> model = {};
  218. for (const auto &name : param_names_) {
  219. std::mutex &mtx = parameter_mutex_[name];
  220. std::unique_lock<std::mutex> lock(mtx);
  221. AddressPtr addr = param_aggrs_[name]->GetWeight();
  222. if (addr == nullptr) {
  223. MS_LOG(WARNING) << "Get weight of " << name << " failed.";
  224. continue;
  225. }
  226. model[name] = addr;
  227. }
  228. return model;
  229. }
  230. // bool Executor::Unmask() {
  231. // auto model = GetModel();
  232. // return mindarmour::CipherMgr::GetInstance().UnMask(model);
  233. // }
  234. const std::vector<std::string> &Executor::param_names() const { return param_names_; }
  235. std::string Executor::GetTrainableParamName(const CNodePtr &cnode) {
  236. MS_EXCEPTION_IF_NULL(cnode);
  237. std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
  238. if (kNameToIdxMap.count(cnode_name) == 0) {
  239. return "";
  240. }
  241. const OptimParamNameToIndex &index_info = kNameToIdxMap.at(cnode_name);
  242. size_t weight_idx = index_info.at("inputs").at(kWeight);
  243. AnfNodePtr weight_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, weight_idx), 0).first;
  244. MS_EXCEPTION_IF_NULL(weight_node);
  245. if (!weight_node->isa<Parameter>()) {
  246. MS_LOG(EXCEPTION) << weight_idx << " input of " << cnode_name << " is not a Parameter.";
  247. }
  248. return weight_node->fullname_with_scope();
  249. }
  250. bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
  251. MS_EXCEPTION_IF_NULL(func_graph);
  252. const auto &cnodes = func_graph->GetOrderedCnodes();
  253. for (const auto &cnode : cnodes) {
  254. MS_EXCEPTION_IF_NULL(cnode);
  255. const std::string &param_name = GetTrainableParamName(cnode);
  256. if (param_name.empty()) {
  257. continue;
  258. }
  259. if (param_aggrs_.count(param_name) != 0) {
  260. MS_LOG(WARNING) << param_name << " already has its control flow.";
  261. continue;
  262. }
  263. std::shared_ptr<ParameterAggregator> param_aggr = std::make_shared<ParameterAggregator>();
  264. MS_EXCEPTION_IF_NULL(param_aggr);
  265. param_names_.push_back(param_name);
  266. param_aggrs_[param_name] = param_aggr;
  267. parameter_mutex_[param_name];
  268. param_aggr->Init(cnode, aggregation_count_);
  269. MS_LOG(DEBUG) << "Initializing control flow for param_name " << param_name << " success.";
  270. }
  271. return true;
  272. }
  273. } // namespace server
  274. } // namespace ps
  275. } // namespace mindspore