You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

parameter_aggregator.cc 14 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "fl/server/parameter_aggregator.h"
  17. #include <map>
  18. #include <memory>
  19. #include <string>
  20. #include <vector>
  21. #include <utility>
  22. #include <algorithm>
  23. namespace mindspore {
  24. namespace fl {
  25. namespace server {
// Initializes this aggregator for one weight: allocates the memory register,
// records the push/pull thresholds and builds the aggregation/optimizer kernels
// for kernel node `cnode`. `threshold_count` is the number of client updates
// required before aggregation is considered done.
// NOTE(review): MS_LOG(EXCEPTION) presumably throws, which would make the
// `return false` statements below unreachable defensive fallbacks — confirm
// the macro's semantics.
bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) {
  MS_EXCEPTION_IF_NULL(cnode);
  memory_register_ = std::make_shared<MemoryRegister>();
  MS_EXCEPTION_IF_NULL(memory_register_);
  required_push_count_ = threshold_count;
  // The required_pull_count_ is the count for Pull, which should be the same as required_push_count_.
  // required_pull_count_ normally used in parameter server training mode.
  required_pull_count_ = threshold_count;
  MS_LOG(DEBUG) << "Start initializing kernels for " << AnfAlgo::GetCNodeName(cnode);
  if (!InitAggregationKernels(cnode)) {
    MS_LOG(EXCEPTION) << "Initializing aggregation kernels failed.";
    return false;
  }
  if (!InitOptimizerKernels(cnode)) {
    MS_LOG(EXCEPTION) << "Initializing optimizer kernels failed.";
    return false;
  }
  return true;
}
  45. bool ParameterAggregator::ReInitForScaling() {
  46. auto result = std::find_if(aggregation_kernel_parameters_.begin(), aggregation_kernel_parameters_.end(),
  47. [](auto aggregation_kernel) {
  48. MS_ERROR_IF_NULL_W_RET_VAL(aggregation_kernel.first, true);
  49. return !aggregation_kernel.first->ReInitForScaling();
  50. });
  51. if (result != aggregation_kernel_parameters_.end()) {
  52. MS_LOG(ERROR) << "Reinitializing aggregation kernel after scaling failed";
  53. return false;
  54. }
  55. return true;
  56. }
  57. bool ParameterAggregator::ReInitForUpdatingHyperParams(size_t aggr_threshold) {
  58. required_push_count_ = aggr_threshold;
  59. required_pull_count_ = aggr_threshold;
  60. auto result = std::find_if(aggregation_kernel_parameters_.begin(), aggregation_kernel_parameters_.end(),
  61. [aggr_threshold](auto aggregation_kernel) {
  62. MS_ERROR_IF_NULL_W_RET_VAL(aggregation_kernel.first, true);
  63. return !aggregation_kernel.first->ReInitForUpdatingHyperParams(aggr_threshold);
  64. });
  65. if (result != aggregation_kernel_parameters_.end()) {
  66. MS_LOG(ERROR) << "Reinitializing aggregation kernel after scaling failed";
  67. return false;
  68. }
  69. return true;
  70. }
  71. bool ParameterAggregator::UpdateData(const std::map<std::string, Address> &new_data) {
  72. std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses();
  73. for (const auto &data : new_data) {
  74. const std::string &name = data.first;
  75. if (name_to_addr.count(name) == 0) {
  76. continue;
  77. }
  78. MS_ERROR_IF_NULL_W_RET_VAL(name_to_addr[name], false);
  79. MS_ERROR_IF_NULL_W_RET_VAL(name_to_addr[name]->addr, false);
  80. MS_ERROR_IF_NULL_W_RET_VAL(data.second.addr, false);
  81. MS_LOG(DEBUG) << "Update data for " << name << ". Destination size: " << name_to_addr[name]->size
  82. << ". Source size: " << data.second.size;
  83. int ret = memcpy_s(name_to_addr[name]->addr, name_to_addr[name]->size, data.second.addr, data.second.size);
  84. if (ret != 0) {
  85. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  86. return false;
  87. }
  88. }
  89. return true;
  90. }
  91. bool ParameterAggregator::LaunchAggregators() {
  92. for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
  93. KernelParams &params = aggregator_with_params.second;
  94. std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first;
  95. MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false);
  96. bool ret = aggr_kernel->Launch(params.inputs, params.workspace, params.outputs);
  97. if (!ret) {
  98. MS_LOG(ERROR) << "Launching aggregation kernel " << typeid(aggr_kernel.get()).name() << " failed.";
  99. return false;
  100. }
  101. }
  102. return true;
  103. }
  104. AddressPtr ParameterAggregator::GetWeight() {
  105. if (memory_register_ == nullptr) {
  106. MS_LOG(ERROR)
  107. << "The memory register of ParameterAggregator is nullptr. Please initialize ParameterAggregator first.";
  108. return nullptr;
  109. }
  110. std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses();
  111. return name_to_addr["weight"];
  112. }
  113. void ParameterAggregator::ResetAggregationStatus() {
  114. for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
  115. std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first;
  116. if (aggr_kernel == nullptr) {
  117. MS_LOG(ERROR) << "The aggregation kernel is nullptr.";
  118. continue;
  119. }
  120. aggr_kernel->Reset();
  121. }
  122. return;
  123. }
// Clears the optimizing flag so a new optimizing round can start.
void ParameterAggregator::ResetOptimizingStatus() { optimizing_done_ = false; }
// Clears the pulling flag and the pull counter so a new pulling round can start.
void ParameterAggregator::ResetPullingStatus() {
  pulling_done_ = false;
  current_pull_count_ = 0;
}
  129. bool ParameterAggregator::IsAggregationDone() const {
  130. // Only consider aggregation done after each aggregation kernel is done.
  131. for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
  132. std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first;
  133. MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false);
  134. if (!aggr_kernel->IsAggregationDone()) {
  135. return false;
  136. }
  137. }
  138. return true;
  139. }
// Returns true when the optimizing step of the current round has finished.
bool ParameterAggregator::IsOptimizingDone() const { return optimizing_done_; }
// Returns true when the pulling step of the current round has finished.
bool ParameterAggregator::IsPullingDone() const { return pulling_done_; }
// Returns whether this weight requires aggregation (set by JudgeRequiredAggr).
bool ParameterAggregator::requires_aggr() const { return requires_aggr_; }
// Creates, initializes and registers the aggregation kernels selected for the
// weight of `cnode`. Memory for each kernel input is either reused from the
// kernel node's own parameter inputs or newly allocated through the memory
// register, then the kernel's launch parameters are generated.
// NOTE(review): MS_LOG(EXCEPTION) presumably throws, which would make the
// `return false` statements defensive fallbacks — confirm the macro's semantics.
bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  if (!JudgeRequiredAggr(cnode)) {
    MS_LOG(WARNING) << "Aggregation for weight of kernel " << AnfAlgo::GetCNodeName(cnode) << " is not required.";
  }
  std::vector<std::string> aggr_kernel_names = SelectAggregationAlgorithm(cnode);
  for (const std::string &name : aggr_kernel_names) {
    auto aggr_kernel = kernel::AggregationKernelFactory::GetInstance().Create(name, cnode);
    if (aggr_kernel == nullptr) {
      MS_LOG(EXCEPTION) << "Fail to create aggregation kernel " << name << " for " << AnfAlgo::GetCNodeName(cnode);
      return false;
    }
    // set_done_count must be called before InitKernel because InitKernel may use this count.
    aggr_kernel->set_done_count(required_push_count_);
    aggr_kernel->InitKernel(cnode);
    // Inputs listed here are backed by the kernel node's own parameters instead
    // of freshly allocated buffers.
    const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info = aggr_kernel->reuse_kernel_node_inputs_info();
    if (!AssignMemory(aggr_kernel, cnode, reuse_kernel_node_inputs_info, memory_register_)) {
      MS_LOG(EXCEPTION) << "Assigning memory for kernel " << name << " failed.";
      return false;
    }
    if (!GenerateAggregationKernelParams(aggr_kernel, memory_register_)) {
      MS_LOG(EXCEPTION) << "Generating aggregation kernel parameters for " << name << " failed.";
      return false;
    }
  }
  return true;
}
  170. bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &) {
  171. if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
  172. ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) {
  173. MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel.";
  174. return true;
  175. }
  176. return true;
  177. }
// Assigns memory for every input of `server_kernel` and records it in
// `memory_register`. For inputs listed in `reuse_kernel_node_inputs_info` the
// address of the corresponding parameter input of `cnode` is reused; all other
// inputs get a newly allocated buffer of the size reported by the kernel.
// Inputs already present in the register are left untouched.
// K is a shared_ptr to a server kernel type (see the explicit instantiations
// at the bottom of this file).
template <typename K>
bool ParameterAggregator::AssignMemory(const K server_kernel, const CNodePtr &cnode,
                                       const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
                                       const std::shared_ptr<MemoryRegister> &memory_register) {
  MS_EXCEPTION_IF_NULL(server_kernel);
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(memory_register);
  const std::vector<std::string> &input_names = server_kernel->input_names();
  const std::vector<size_t> &input_size_list = server_kernel->GetInputSizeList();
  // names and sizes are parallel arrays; a mismatch means the kernel is broken.
  if (input_names.size() != input_size_list.size()) {
    MS_LOG(EXCEPTION) << "Server kernel " << typeid(server_kernel.get()).name()
                      << " input number is not matched: input_names size is " << input_names.size()
                      << ", input_size_list size is " << input_size_list.size();
    return false;
  }
  if (reuse_kernel_node_inputs_info.size() > input_names.size()) {
    MS_LOG(EXCEPTION) << "The reuse kernel node information number is invalid: got "
                      << reuse_kernel_node_inputs_info.size() << ", but input_names size is " << input_names.size();
    return false;
  }
  for (size_t i = 0; i < input_names.size(); i++) {
    const std::string &name = input_names[i];
    if (memory_register->addresses().count(name) != 0) {
      MS_LOG(DEBUG) << "The memory for " << name << " is already assigned.";
      continue;
    }
    if (reuse_kernel_node_inputs_info.count(name) != 0) {
      // Reusing memory of the kernel node means the memory of the input is already assigned by the front end, which
      // is to say, the input node is a parameter node.
      size_t index = reuse_kernel_node_inputs_info.at(name);
      MS_LOG(INFO) << "Try to reuse memory of kernel node " << AnfAlgo::GetCNodeName(cnode) << " for parameter " << name
                   << ", kernel node index " << index;
      AddressPtr input_addr = GenerateParameterNodeAddrPtr(cnode, index);
      MS_EXCEPTION_IF_NULL(input_addr);
      memory_register->RegisterAddressPtr(name, input_addr);
    } else {
      MS_LOG(INFO) << "Assign new memory for " << name;
      // Ownership of the buffer is handed to the register via RegisterArray.
      auto input_addr = std::make_unique<char[]>(input_size_list[i]);
      MS_EXCEPTION_IF_NULL(input_addr);
      memory_register->RegisterArray(name, &input_addr, input_size_list[i]);
    }
  }
  return true;
}
  222. bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> &aggr_kernel,
  223. const std::shared_ptr<MemoryRegister> &memory_register) {
  224. MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false);
  225. MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false);
  226. KernelParams aggr_params = {};
  227. const std::vector<std::string> &input_names = aggr_kernel->input_names();
  228. (void)std::transform(input_names.begin(), input_names.end(), std::back_inserter(aggr_params.inputs),
  229. [&](const std::string &name) { return memory_register->addresses()[name]; });
  230. const std::vector<std::string> &workspace_names = aggr_kernel->workspace_names();
  231. (void)std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(aggr_params.workspace),
  232. [&](const std::string &name) { return memory_register->addresses()[name]; });
  233. const std::vector<std::string> &output_names = aggr_kernel->output_names();
  234. (void)std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs),
  235. [&](const std::string &name) { return memory_register->addresses()[name]; });
  236. aggr_kernel->SetParameterAddress(aggr_params.inputs, aggr_params.workspace, aggr_params.outputs);
  237. aggregation_kernel_parameters_.push_back(std::make_pair(aggr_kernel, aggr_params));
  238. return true;
  239. }
  240. std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &) {
  241. std::vector<std::string> aggregation_algorithm = {};
  242. if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
  243. ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) {
  244. (void)aggregation_algorithm.emplace_back("FedAvg");
  245. } else if (ps::PSContext::instance()->server_mode() == ps::kServerModePS) {
  246. (void)aggregation_algorithm.emplace_back("DenseGradAccum");
  247. } else {
  248. MS_LOG(EXCEPTION) << "Server doesn't support mode " << ps::PSContext::instance()->server_mode();
  249. return aggregation_algorithm;
  250. }
  251. MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm;
  252. return aggregation_algorithm;
  253. }
// Determines whether the weight of `cnode` requires aggregation by looking up
// the weight input index in kNameToIdxMap, resolving that input to its
// Parameter node, and reading the parameter's requires_aggr flag.
// Stores the result in requires_aggr_ and returns it.
bool ParameterAggregator::JudgeRequiredAggr(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
  // The kernel must have a known "weight" input index registered in the map.
  if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 ||
      kNameToIdxMap.at(cnode_name).at("inputs").count("weight") == 0) {
    MS_LOG(EXCEPTION) << "Can't find index info of weight for kernel " << cnode_name;
    return false;
  }
  size_t cnode_weight_idx = kNameToIdxMap.at(cnode_name).at("inputs").at("weight");
  // Resolve the input through any intermediate nodes to the real weight node.
  auto weight_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, cnode_weight_idx), 0).first;
  MS_EXCEPTION_IF_NULL(weight_node);
  if (!weight_node->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << weight_node->fullname_with_scope() << " is not a parameter.";
    return false;
  }
  auto param_info = weight_node->cast<ParameterPtr>()->param_info();
  MS_EXCEPTION_IF_NULL(param_info);
  requires_aggr_ = param_info->requires_aggr();
  return requires_aggr_;
}
// Explicit instantiations of AssignMemory for the two kernel kinds used by
// this translation unit, so the template definition can stay in this .cc file.
template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::OptimizerKernel> server_kernel,
                                                const CNodePtr &cnode,
                                                const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
                                                const std::shared_ptr<MemoryRegister> &memory_register);
template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::AggregationKernel> server_kernel,
                                                const CNodePtr &cnode,
                                                const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
                                                const std::shared_ptr<MemoryRegister> &memory_register);
  282. } // namespace server
  283. } // namespace fl
  284. } // namespace mindspore