You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

util.cc 9.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ps/util.h"
  17. #include <vector>
  18. #include <memory>
  19. #include "utils/hash_map.h"
  20. #include "ps/constants.h"
  21. #include "ps/ps_context.h"
  22. #include "utils/ms_utils.h"
  23. namespace mindspore {
  24. namespace ps {
// Forward table: optimizer kernel name -> dense numeric ID used on the wire.
mindspore::HashMap<std::string, int64_t> Util::optimizer_to_ids{
  {kApplyMomentum, 0},
  {kSparseAdam, 1},
  {kSparseLazyAdam, 2},
  {kSparseFtrl, 3},
};
// Reverse table: numeric ID -> optimizer kernel name. Must stay the exact
// inverse of optimizer_to_ids above.
mindspore::HashMap<int64_t, std::string> Util::id_to_optimizers{
  {0, kApplyMomentum},
  {1, kSparseAdam},
  {2, kSparseLazyAdam},
  {3, kSparseFtrl},
};
// Numeric ID -> graph node (op) name for the same optimizers; indices are
// aligned with the two tables above.
mindspore::HashMap<int64_t, std::string> Util::id_to_optimizer_nodes{
  {0, kApplyMomentumOp},
  {1, kSparseAdamOp},
  {2, kSparseLazyAdamOp},
  {3, kSparseFtrlOp},
};
  43. bool Util::IsRoleOfPServer() { return PSContext::instance()->is_server(); }
  44. bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_scheduler(); }
  45. int64_t Util::optimizer_id(const std::string &name) {
  46. if (optimizer_to_ids.count(name) > 0) {
  47. return optimizer_to_ids[name];
  48. }
  49. return -1;
  50. }
  51. std::string Util::optimizer_name(int64_t id) {
  52. if (id_to_optimizers.count(id) > 0) {
  53. return id_to_optimizers[id];
  54. }
  55. return "";
  56. }
  57. std::string Util::optimizer_node_name(int64_t id) {
  58. if (id_to_optimizer_nodes.count(id) > 0) {
  59. return id_to_optimizer_nodes[id];
  60. }
  61. return "";
  62. }
  63. bool Util::is_optimizer(const std::string &name) { return optimizer_to_ids.count(name) > 0; }
  64. int64_t Util::LocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num) {
  65. std::map<int64_t, int64_t> shard_dims = AllRankLocalShard(first_dim, rank_id, server_num);
  66. if (shard_dims.count(rank_id) == 0) {
  67. MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id;
  68. }
  69. return shard_dims[rank_id];
  70. }
  71. std::map<int64_t, int64_t> Util::AllRankLocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num) {
  72. if (first_dim <= 0 || server_num <= 0 || rank_id < 0) {
  73. MS_LOG(EXCEPTION) << "Input values are invalid.";
  74. }
  75. if (rank_id >= server_num) {
  76. MS_LOG(EXCEPTION) << "The rank ID " << rank_id << " should be less than the number of servers " << server_num;
  77. }
  78. std::map<int64_t, int64_t> shard_dims;
  79. for (int64_t i = 0; i < server_num; i++) {
  80. shard_dims[i] = 0;
  81. }
  82. if (server_num != static_cast<int64_t>(shard_dims.size())) {
  83. MS_LOG(EXCEPTION) << "Inconsistent server num " << server_num << " shard dims counter size " << shard_dims.size();
  84. }
  85. int64_t server_index = -1;
  86. for (int64_t i = 0; i < first_dim; i++) {
  87. server_index = (server_index + 1) % server_num;
  88. shard_dims[server_index] = shard_dims[server_index] + 1;
  89. }
  90. if (shard_dims.count(rank_id) == 0) {
  91. MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id << ", total server num " << server_num;
  92. }
  93. return shard_dims;
  94. }
  95. void Util::ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
  96. const size_t first_dim_size, const size_t outer_dim_size,
  97. mindspore::kernel::SparseGradient<int> *unique_sparse_grad) {
  98. size_t slice_segment_size = indices_size * segment_size;
  99. std::vector<float> workspace_grad(slice_segment_size);
  100. std::vector<int> workspace_indices(indices_size);
  101. MS_EXCEPTION_IF_NULL(gradients);
  102. MS_EXCEPTION_IF_NULL(indices);
  103. mindspore::kernel::SparseGradient<int> workspace_sparse_grad(
  104. {workspace_grad.data(), workspace_indices.data(), indices_size});
  105. mindspore::kernel::SparseGradient<int> input_sparse_grad({gradients, indices, indices_size});
  106. mindspore::kernel::ReduceSparseGradientParam<int> param;
  107. param.input_grad_ = &input_sparse_grad;
  108. param.workspace_grad_ = &workspace_sparse_grad;
  109. param.output_grad_ = unique_sparse_grad;
  110. param.max_index_ = first_dim_size;
  111. param.value_stride_ = outer_dim_size;
  112. mindspore::kernel::SparseOptimizerCPUKernel::BucketReduceSparseGradient(param);
  113. }
  114. bool Util::FuseServerCommOps(const pipeline::ResourcePtr &res) {
  115. FuncGraphPtr func_graph = res->func_graph();
  116. MS_EXCEPTION_IF_NULL(func_graph);
  117. DoFusion(func_graph, kPullWeightOpName, kFusedPullWeightOpName);
  118. DoFusion(func_graph, kPushWeightOpName, kFusedPushWeightOpName);
  119. return true;
  120. }
  121. void Util::DoFusion(const FuncGraphPtr &func_graph, const std::string &cnode_name,
  122. const std::string &fused_cnode_name) {
  123. MS_EXCEPTION_IF_NULL(func_graph);
  124. std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
  125. std::vector<AnfNodePtr> single_nodes;
  126. std::vector<std::string> weight_names;
  127. std::vector<int64_t> indices;
  128. for (const AnfNodePtr &node : node_list) {
  129. if (node != nullptr && node->isa<CNode>()) {
  130. if (AnfAlgo::GetCNodeName(node) == cnode_name) {
  131. single_nodes.push_back(node);
  132. auto weight_name_value_node =
  133. AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kNodeInputWeightNameOffset)->cast<ValueNodePtr>();
  134. const std::string &weight_name = GetValue<std::string>(weight_name_value_node->value());
  135. weight_names.push_back(weight_name);
  136. auto weight_index_value_node =
  137. AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kNodeInputWeightIndexOffset)->cast<ValueNodePtr>();
  138. int64_t weight_index = GetValue<int64_t>(weight_index_value_node->value());
  139. indices.push_back(weight_index);
  140. }
  141. }
  142. }
  143. auto prim = std::make_shared<Primitive>(fused_cnode_name);
  144. MS_EXCEPTION_IF_NULL(prim);
  145. std::vector<AnfNodePtr> fused_node_inputs = {};
  146. fused_node_inputs.push_back(NewValueNode(prim));
  147. (void)std::for_each(single_nodes.begin(), single_nodes.end(), [&](const AnfNodePtr &node) {
  148. fused_node_inputs.push_back(AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0));
  149. });
  150. auto fused_cnode = func_graph->NewCNode(fused_node_inputs);
  151. MS_EXCEPTION_IF_NULL(fused_cnode);
  152. AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(weight_names), fused_cnode);
  153. AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue(indices), fused_cnode);
  154. AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice), fused_cnode);
  155. auto kernel_info = std::make_shared<device::KernelInfo>();
  156. MS_EXCEPTION_IF_NULL(kernel_info);
  157. fused_cnode->set_kernel_info(kernel_info);
  158. auto kernel_build_info = GenerateKernelBuildInfo(single_nodes);
  159. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_cnode.get());
  160. AbstractBasePtrList abstract_list;
  161. for (const auto &node : single_nodes) {
  162. auto cnode = node->cast<CNodePtr>();
  163. MS_EXCEPTION_IF_NULL(cnode);
  164. abstract_list.push_back(cnode->abstract());
  165. }
  166. auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
  167. MS_EXCEPTION_IF_NULL(abstract_tuple);
  168. fused_cnode->set_abstract(abstract_tuple);
  169. auto manager = func_graph->manager();
  170. MS_EXCEPTION_IF_NULL(manager);
  171. for (const auto &node : single_nodes) {
  172. if (!manager->Replace(node, fused_cnode)) {
  173. MS_LOG(EXCEPTION) << "manager replace node failed";
  174. }
  175. }
  176. return;
  177. }
  178. kernel::KernelBuildInfoPtr Util::GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list) {
  179. std::vector<std::string> inputs_device_format;
  180. std::vector<std::string> outputs_device_format;
  181. std::vector<TypeId> inputs_device_type;
  182. std::vector<TypeId> outputs_device_type;
  183. std::vector<std::vector<size_t>> outputs_shape;
  184. kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  185. for (size_t idx = 0; idx < node_list.size(); ++idx) {
  186. auto cnode = utils::cast<CNodePtr>(node_list[idx]);
  187. MS_EXCEPTION_IF_NULL(cnode);
  188. size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
  189. for (size_t input_index = 0; input_index < input_num; ++input_index) {
  190. (void)inputs_device_format.emplace_back(kOpFormat_DEFAULT);
  191. inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
  192. }
  193. size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
  194. for (size_t output_index = 0; output_index < output_num; ++output_index) {
  195. (void)outputs_device_format.emplace_back(kOpFormat_DEFAULT);
  196. outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
  197. outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
  198. }
  199. }
  200. builder.SetInputsFormat(inputs_device_format);
  201. builder.SetOutputsFormat(outputs_device_format);
  202. builder.SetInputsDeviceType(inputs_device_type);
  203. builder.SetOutputsDeviceType(outputs_device_type);
  204. return builder.Build();
  205. }
  206. } // namespace ps
  207. } // namespace mindspore