You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

util.cc 6.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ps/util.h"
  17. #include <unordered_map>
  18. #include <vector>
  19. #include "ps/common.h"
  20. #include "ps/ps_context.h"
  21. #include "utils/ms_utils.h"
  22. namespace mindspore {
  23. namespace ps {
// Rank of this process in the PS cluster. Initialized to -1 — presumably
// meaning "not yet assigned"; set elsewhere (not visible in this chunk).
int64_t Util::rank_id_ = -1;

// Optimizer name -> numeric id (the id is what the tables below are keyed on).
std::unordered_map<std::string, int64_t> Util::optimizer_to_ids{
  {kApplyMomentum, 0},
  {kSparseAdam, 1},
  {kSparseLazyAdam, 2},
  {kSparseFtrl, 3},
};

// Inverse of optimizer_to_ids: numeric id -> optimizer name.
std::unordered_map<int64_t, std::string> Util::id_to_optimizers{
  {0, kApplyMomentum},
  {1, kSparseAdam},
  {2, kSparseLazyAdam},
  {3, kSparseFtrl},
};

// Numeric id -> optimizer node/operator name (the *Op constants).
std::unordered_map<int64_t, std::string> Util::id_to_optimizer_nodes{
  {0, kApplyMomentumOp},
  {1, kSparseAdamOp},
  {2, kSparseLazyAdamOp},
  {3, kSparseFtrlOp},
};
// True when parameter-server training mode is enabled (delegates to the
// process-wide PSContext singleton).
bool Util::IsParamServerMode() { return PSContext::instance()->is_ps_enabled(); }
// Role of the current process within the PS cluster, as recorded in PSContext.
bool Util::IsRoleOfWorker() { return PSContext::instance()->is_role_worker(); }
bool Util::IsRoleOfPServer() { return PSContext::instance()->is_role_pserver(); }
bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_role_sched(); }
  47. void Util::SetInternalEnvVar() {
  48. if (IsParamServerMode()) {
  49. auto comm_type = common::GetEnv(kEnvCommType);
  50. if (!comm_type.empty()) {
  51. (void)common::SetEnv(kDmlcCommType, comm_type.c_str());
  52. }
  53. auto interface = common::GetEnv(kEnvInterface);
  54. if (!interface.empty()) {
  55. (void)common::SetEnv(kDmlcInterface, interface.c_str());
  56. }
  57. auto server_num = common::GetEnv(kEnvPServerNum);
  58. if (!server_num.empty()) {
  59. (void)common::SetEnv(kDmlcPServerNum, server_num.c_str());
  60. }
  61. auto worker_num = common::GetEnv(kEnvWorkerNum);
  62. if (!worker_num.empty()) {
  63. (void)common::SetEnv(kDmlcWorkerNum, worker_num.c_str());
  64. }
  65. if (IsRoleOfScheduler()) {
  66. (void)common::SetEnv(kDmlcRole, kRoleOfScheduler);
  67. } else if (IsRoleOfPServer()) {
  68. (void)common::SetEnv(kDmlcRole, kRoleOfPServer);
  69. } else if (IsRoleOfWorker()) {
  70. (void)common::SetEnv(kDmlcRole, kRoleOfWorker);
  71. }
  72. auto scheduler_host = common::GetEnv(kEnvSchedulerHost);
  73. if (!scheduler_host.empty()) {
  74. (void)common::SetEnv(kDmlcSchedulerHost, scheduler_host.c_str());
  75. }
  76. auto scheduler_port = common::GetEnv(kEnvSchedulerPort);
  77. if (!scheduler_port.empty()) {
  78. (void)common::SetEnv(kDmlcSchedulerPort, scheduler_port.c_str());
  79. }
  80. }
  81. }
  82. int64_t Util::optimizer_id(std::string name) {
  83. if (optimizer_to_ids.count(name) > 0) {
  84. return optimizer_to_ids[name];
  85. }
  86. return -1;
  87. }
  88. std::string Util::optimizer_name(int64_t id) {
  89. if (id_to_optimizers.count(id) > 0) {
  90. return id_to_optimizers[id];
  91. }
  92. return "";
  93. }
  94. std::string Util::optimizer_node_name(int64_t id) {
  95. if (id_to_optimizer_nodes.count(id) > 0) {
  96. return id_to_optimizer_nodes[id];
  97. }
  98. return "";
  99. }
  100. bool Util::is_optimizer(std::string name) { return optimizer_to_ids.count(name) > 0; }
  101. int64_t Util::LocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num) {
  102. std::map<int64_t, int64_t> shard_dims = AllRankLocalShard(first_dim, rank_id, server_num);
  103. if (shard_dims.count(rank_id) == 0) {
  104. MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id;
  105. }
  106. return shard_dims[rank_id];
  107. }
  108. std::map<int64_t, int64_t> Util::AllRankLocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num) {
  109. if (first_dim <= 0 || server_num <= 0 || rank_id < 0) {
  110. MS_LOG(EXCEPTION) << "Input values are invalid.";
  111. }
  112. if (rank_id >= server_num) {
  113. MS_LOG(EXCEPTION) << "The rank ID " << rank_id << " should be less than the number of servers " << server_num;
  114. }
  115. std::map<int64_t, int64_t> shard_dims;
  116. for (int64_t i = 0; i < server_num; i++) {
  117. shard_dims[i] = 0;
  118. }
  119. if (server_num != static_cast<int64_t>(shard_dims.size())) {
  120. MS_LOG(EXCEPTION) << "Inconsistent server num " << server_num << " shard dims counter size " << shard_dims.size();
  121. }
  122. int64_t server_index = -1;
  123. for (int64_t i = 0; i < first_dim; i++) {
  124. server_index = (server_index + 1) % server_num;
  125. shard_dims[server_index] = shard_dims[server_index] + 1;
  126. }
  127. if (shard_dims.count(rank_id) == 0) {
  128. MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id << ", total server num " << server_num;
  129. }
  130. return shard_dims;
  131. }
  132. void Util::ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
  133. const size_t first_dim_size, const size_t outer_dim_size,
  134. mindspore::kernel::SparseGradient<int> *unique_sparse_grad) {
  135. size_t slice_segment_size = indices_size * segment_size;
  136. std::vector<float> workspace_grad(slice_segment_size);
  137. std::vector<int> workspace_indices(indices_size);
  138. MS_EXCEPTION_IF_NULL(gradients);
  139. MS_EXCEPTION_IF_NULL(indices);
  140. mindspore::kernel::SparseGradient<int> workspace_sparse_grad(
  141. {workspace_grad.data(), workspace_indices.data(), indices_size});
  142. mindspore::kernel::SparseGradient<int> input_sparse_grad({gradients, indices, indices_size});
  143. mindspore::kernel::ReduceSparseGradientParam<int> param;
  144. param.input_grad_ = &input_sparse_grad;
  145. param.workspace_grad_ = &workspace_sparse_grad;
  146. param.output_grad_ = unique_sparse_grad;
  147. param.max_index_ = first_dim_size;
  148. param.value_stride_ = outer_dim_size;
  149. mindspore::kernel::SparseOptimizerCPUKernel::BucketReduceSparseGradient(param);
  150. }
  151. } // namespace ps
  152. } // namespace mindspore