
parameter_aggregator.cc

/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "fl/server/parameter_aggregator.h"
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <utility>
#include <algorithm>

namespace mindspore {
namespace fl {
namespace server {
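
// Init prepares this aggregator for a single weight: it creates the MemoryRegister that owns all
// named buffers, records the push/pull thresholds, and initializes the aggregation and optimizer
// kernels for the given kernel node.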
bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) {
  MS_EXCEPTION_IF_NULL(cnode);
  memory_register_ = std::make_shared<MemoryRegister>();
  MS_EXCEPTION_IF_NULL(memory_register_);

  required_push_count_ = threshold_count;
  // The required_pull_count_ is the count for Pull, which should be the same as required_push_count_.
  // required_pull_count_ is normally used in parameter server training mode.
  required_pull_count_ = threshold_count;

  MS_LOG(INFO) << "Start initializing kernels for " << common::AnfAlgo::GetCNodeName(cnode);
  if (!InitAggregationKernels(cnode)) {
    MS_LOG(EXCEPTION) << "Initializing aggregation kernels failed.";
    return false;
  }
  if (!InitOptimizerKernels(cnode)) {
    MS_LOG(EXCEPTION) << "Initializing optimizer kernels failed.";
    return false;
  }
  return true;
}
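
// ReInitForScaling re-initializes every registered aggregation kernel after the cluster is scaled
// in or out; it fails as soon as any kernel's own ReInitForScaling fails.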
bool ParameterAggregator::ReInitForScaling() {
  auto result = std::find_if(aggregation_kernel_parameters_.begin(), aggregation_kernel_parameters_.end(),
                             [](auto aggregation_kernel) {
                               MS_ERROR_IF_NULL_W_RET_VAL(aggregation_kernel.first, true);
                               return !aggregation_kernel.first->ReInitForScaling();
                             });
  if (result != aggregation_kernel_parameters_.end()) {
    MS_LOG(ERROR) << "Reinitializing aggregation kernel after scaling failed";
    return false;
  }
  return true;
}
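
// ReInitForUpdatingHyperParams applies a new aggregation threshold both to this aggregator's
// push/pull counts and to every aggregation kernel it owns.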
bool ParameterAggregator::ReInitForUpdatingHyperParams(size_t aggr_threshold) {
  required_push_count_ = aggr_threshold;
  required_pull_count_ = aggr_threshold;
  auto result = std::find_if(aggregation_kernel_parameters_.begin(), aggregation_kernel_parameters_.end(),
                             [aggr_threshold](auto aggregation_kernel) {
                               MS_ERROR_IF_NULL_W_RET_VAL(aggregation_kernel.first, true);
                               return !aggregation_kernel.first->ReInitForUpdatingHyperParams(aggr_threshold);
                             });
  if (result != aggregation_kernel_parameters_.end()) {
    MS_LOG(ERROR) << "Reinitializing aggregation kernel after updating hyper-parameters failed";
    return false;
  }
  return true;
}
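
// UpdateData copies each incoming buffer into the registered address with the same name.
// Names that were never registered are silently skipped; a failed memcpy_s aborts the whole update.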
bool ParameterAggregator::UpdateData(const std::map<std::string, Address> &new_data) {
  std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses();
  for (const auto &data : new_data) {
    const std::string &name = data.first;
    if (name_to_addr.count(name) == 0) {
      continue;
    }

    MS_ERROR_IF_NULL_W_RET_VAL(name_to_addr[name], false);
    MS_ERROR_IF_NULL_W_RET_VAL(name_to_addr[name]->addr, false);
    MS_ERROR_IF_NULL_W_RET_VAL(data.second.addr, false);
    MS_LOG(DEBUG) << "Update data for " << name << ". Destination size: " << name_to_addr[name]->size
                  << ". Source size: " << data.second.size;
    int ret = memcpy_s(name_to_addr[name]->addr, name_to_addr[name]->size, data.second.addr, data.second.size);
    if (ret != 0) {
      MS_LOG(ERROR) << "memcpy_s error, error code (" << ret << ")";
      return false;
    }
  }
  return true;
}
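
// LaunchAggregators launches every aggregation kernel once, using the input/workspace/output
// addresses prepared by GenerateAggregationKernelParams.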
bool ParameterAggregator::LaunchAggregators() {
  for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
    KernelParams &params = aggregator_with_params.second;
    std::shared_ptr<kernel::AggregationKernelMod> aggr_kernel = aggregator_with_params.first;
    MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false);

    bool ret = aggr_kernel->Launch(params.inputs, params.workspace, params.outputs);
    if (!ret) {
      MS_LOG(ERROR) << "Launching aggregation kernel " << typeid(aggr_kernel.get()).name() << " failed.";
      return false;
    }
  }
  return true;
}
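
// GetWeight returns the buffer registered under the fixed name "weight", i.e. the aggregated
// model weight this instance manages.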
AddressPtr ParameterAggregator::GetWeight() {
  if (memory_register_ == nullptr) {
    MS_LOG(ERROR)
      << "The memory register of ParameterAggregator is nullptr. Please initialize ParameterAggregator first.";
    return nullptr;
  }
  std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses();
  return name_to_addr["weight"];
}

void ParameterAggregator::ResetAggregationStatus() {
  for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
    std::shared_ptr<kernel::AggregationKernelMod> aggr_kernel = aggregator_with_params.first;
    if (aggr_kernel == nullptr) {
      MS_LOG(ERROR) << "The aggregation kernel is nullptr.";
      continue;
    }
    aggr_kernel->Reset();
  }
  return;
}

void ParameterAggregator::ResetOptimizingStatus() { optimizing_done_ = false; }

void ParameterAggregator::ResetPullingStatus() {
  pulling_done_ = false;
  current_pull_count_ = 0;
}

bool ParameterAggregator::IsAggregationDone() const {
  // Only consider aggregation done after each aggregation kernel is done.
  for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
    std::shared_ptr<kernel::AggregationKernelMod> aggr_kernel = aggregator_with_params.first;
    MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false);
    if (!aggr_kernel->IsAggregationDone()) {
      return false;
    }
  }
  return true;
}
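
// RunAggregation invokes the AllReduce step of every aggregation kernel; it fails as soon as
// any kernel's AllReduce fails.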
bool ParameterAggregator::RunAggregation() {
  for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
    std::shared_ptr<kernel::AggregationKernelMod> aggr_kernel = aggregator_with_params.first;
    MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false);
    if (!aggr_kernel->AllReduce()) {
      return false;
    }
  }
  return true;
}

bool ParameterAggregator::IsOptimizingDone() const { return optimizing_done_; }

bool ParameterAggregator::IsPullingDone() const { return pulling_done_; }

bool ParameterAggregator::requires_aggr() const { return requires_aggr_; }
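
// InitAggregationKernels selects the aggregation algorithm(s) for the current server mode, creates
// the corresponding kernels through the factory, and assigns their memory and launch parameters.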
bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  if (!JudgeRequiredAggr(cnode)) {
    MS_LOG(WARNING) << "Aggregation for weight of kernel " << common::AnfAlgo::GetCNodeName(cnode)
                    << " is not required.";
  }

  std::vector<std::string> aggr_kernel_names = SelectAggregationAlgorithm(cnode);
  for (const std::string &name : aggr_kernel_names) {
    auto aggr_kernel = kernel::AggregationKernelFactory::GetInstance().Create(name, cnode);
    if (aggr_kernel == nullptr) {
      MS_LOG(EXCEPTION) << "Fail to create aggregation kernel " << name << " for "
                        << common::AnfAlgo::GetCNodeName(cnode);
      return false;
    }

    // set_done_count must be called before InitKernel because InitKernel may use this count.
    aggr_kernel->set_done_count(required_push_count_);
    aggr_kernel->InitKernel(cnode);

    const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info = aggr_kernel->reuse_kernel_node_inputs_info();
    if (!AssignMemory(aggr_kernel, cnode, reuse_kernel_node_inputs_info, memory_register_)) {
      MS_LOG(EXCEPTION) << "Assigning memory for kernel " << name << " failed.";
      return false;
    }

    if (!GenerateAggregationKernelParams(aggr_kernel, memory_register_)) {
      MS_LOG(EXCEPTION) << "Generating aggregation kernel parameters for " << name << " failed.";
      return false;
    }
  }
  return true;
}

bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &) {
  if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
      ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) {
    MS_LOG(INFO) << "Federated learning mode doesn't need optimizer kernel.";
    return true;
  }
  return false;
}
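
// AssignMemory registers one buffer per kernel input: names listed in reuse_kernel_node_inputs_info
// reuse the memory of the corresponding front-end parameter node, while all other inputs get
// freshly allocated arrays of the size reported by the kernel.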
template <typename K>
bool ParameterAggregator::AssignMemory(const K server_kernel, const CNodePtr &cnode,
                                       const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
                                       const std::shared_ptr<MemoryRegister> &memory_register) {
  MS_EXCEPTION_IF_NULL(server_kernel);
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(memory_register);

  const std::vector<std::string> &input_names = server_kernel->input_names();
  const std::vector<size_t> &input_size_list = server_kernel->GetInputSizeList();
  if (input_names.size() != input_size_list.size()) {
    MS_LOG(EXCEPTION) << "Server kernel " << typeid(server_kernel.get()).name()
                      << " input number is not matched: input_names size is " << input_names.size()
                      << ", input_size_list size is " << input_size_list.size();
    return false;
  }

  if (reuse_kernel_node_inputs_info.size() > input_names.size()) {
    MS_LOG(EXCEPTION) << "The reuse kernel node information number is invalid: got "
                      << reuse_kernel_node_inputs_info.size() << ", but input_names size is " << input_names.size();
    return false;
  }

  for (size_t i = 0; i < input_names.size(); i++) {
    const std::string &name = input_names[i];
    if (memory_register->addresses().count(name) != 0) {
      MS_LOG(DEBUG) << "The memory for " << name << " is already assigned.";
      continue;
    }

    if (reuse_kernel_node_inputs_info.count(name) != 0) {
      // Reusing memory of the kernel node means the memory of the input is already assigned by the front end,
      // which is to say, the input node is a parameter node.
      size_t index = reuse_kernel_node_inputs_info.at(name);
      MS_LOG(INFO) << "Try to reuse memory of kernel node " << common::AnfAlgo::GetCNodeName(cnode)
                   << " for parameter " << name << ", kernel node index " << index;
      AddressPtr input_addr = GenerateParameterNodeAddrPtr(cnode, index);
      MS_EXCEPTION_IF_NULL(input_addr);
      memory_register->RegisterAddressPtr(name, input_addr);
    } else {
      MS_LOG(INFO) << "Assign new memory for " << name;
      auto input_addr = std::make_unique<char[]>(input_size_list[i]);
      MS_EXCEPTION_IF_NULL(input_addr);
      memory_register->RegisterArray(name, &input_addr, input_size_list[i]);
    }
  }
  return true;
}
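
// GenerateAggregationKernelParams resolves the kernel's input/workspace/output names into the
// addresses held by the memory register and caches the resulting KernelParams alongside the kernel.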
bool ParameterAggregator::GenerateAggregationKernelParams(
  const std::shared_ptr<kernel::AggregationKernelMod> &aggr_kernel,
  const std::shared_ptr<MemoryRegister> &memory_register) {
  MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false);
  MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false);
  KernelParams aggr_params = {};

  const std::vector<std::string> &input_names = aggr_kernel->input_names();
  (void)std::transform(input_names.begin(), input_names.end(), std::back_inserter(aggr_params.inputs),
                       [&](const std::string &name) { return memory_register->addresses()[name]; });

  const std::vector<std::string> &workspace_names = aggr_kernel->workspace_names();
  (void)std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(aggr_params.workspace),
                       [&](const std::string &name) { return memory_register->addresses()[name]; });

  const std::vector<std::string> &output_names = aggr_kernel->output_names();
  (void)std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs),
                       [&](const std::string &name) { return memory_register->addresses()[name]; });

  aggr_kernel->SetParameterAddress(aggr_params.inputs, aggr_params.workspace, aggr_params.outputs);
  aggregation_kernel_parameters_.push_back(std::make_pair(aggr_kernel, aggr_params));
  return true;
}
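
// SelectAggregationAlgorithm maps the server mode to kernel names: "FedAvg" in federated/hybrid
// mode, "DenseGradAccum" in parameter server mode.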
std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &) {
  std::vector<std::string> aggregation_algorithm = {};
  if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
      ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) {
    (void)aggregation_algorithm.emplace_back("FedAvg");
  } else if (ps::PSContext::instance()->server_mode() == ps::kServerModePS) {
    (void)aggregation_algorithm.emplace_back("DenseGradAccum");
  } else {
    MS_LOG(EXCEPTION) << "Server doesn't support mode " << ps::PSContext::instance()->server_mode();
    return aggregation_algorithm;
  }

  MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm;
  return aggregation_algorithm;
}
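
// JudgeRequiredAggr locates the weight input of the kernel node and reads requires_aggr from its
// param_info to decide whether aggregation is needed for this weight at all.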
bool ParameterAggregator::JudgeRequiredAggr(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::string cnode_name = common::AnfAlgo::GetCNodeName(cnode);
  if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 ||
      kNameToIdxMap.at(cnode_name).at("inputs").count("weight") == 0) {
    MS_LOG(EXCEPTION) << "Can't find index info of weight for kernel " << cnode_name;
    return false;
  }

  size_t cnode_weight_idx = kNameToIdxMap.at(cnode_name).at("inputs").at("weight");
  auto weight_node =
    common::AnfAlgo::VisitKernelWithReturnType(common::AnfAlgo::GetInputNode(cnode, cnode_weight_idx), 0).first;
  MS_EXCEPTION_IF_NULL(weight_node);
  if (!weight_node->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << weight_node->fullname_with_scope() << " is not a parameter.";
    return false;
  }

  auto param_info = weight_node->cast<ParameterPtr>()->param_info();
  MS_EXCEPTION_IF_NULL(param_info);
  requires_aggr_ = param_info->requires_aggr();
  return requires_aggr_;
}

template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::OptimizerKernelMod> server_kernel,
                                                const CNodePtr &cnode,
                                                const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
                                                const std::shared_ptr<MemoryRegister> &memory_register);

template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::AggregationKernelMod> server_kernel,
                                                const CNodePtr &cnode,
                                                const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
                                                const std::shared_ptr<MemoryRegister> &memory_register);
}  // namespace server
}  // namespace fl
}  // namespace mindspore