You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

costmodel.cc 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "frontend/parallel/auto_parallel/costmodel.h"
#include <algorithm>
#include <cmath>
#include <numeric>
#include <utility>
#include <vector>
#include "frontend/parallel/auto_parallel/graph_costmodel.h"
  21. namespace mindspore {
  22. namespace parallel {
  23. void Simplify(CostPtrList *clist_ptrs) {
  24. const auto run_phase = CostModelContext::GetInstance()->run_phase();
  25. if (run_phase == TRAINING_PHASE) {
  26. // training phase
  27. SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs);
  28. } else {
  29. // inference phase
  30. SimplifyForDecreasingCommunicationForward(clist_ptrs);
  31. }
  32. }
  33. void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) {
  34. // Sort the cost_list with the computation_cost_ increasing, and communication_forward decreasing order. This method
  35. // excludes the cost with greater computation_cost_ and greater communication_forward.
  36. // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>}
  37. const auto simplify_cal = CostModelContext::GetInstance()->costmodel_simplify_cal();
  38. if (!simplify_cal) {
  39. return;
  40. }
  41. MS_EXCEPTION_IF_NULL(clist_ptrs);
  42. std::vector<size_t> id(clist_ptrs->size());
  43. std::iota(id.begin(), id.end(), size_t(0));
  44. std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) {
  45. return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_;
  46. });
  47. CostPtrList ret;
  48. for (size_t i = 0; i < clist_ptrs->size(); ++i) {
  49. if ((ret.size() == size_t(0)) ||
  50. (clist_ptrs->at(id[i])->communication_forward_ < ret.back()->communication_forward_)) {
  51. ret.emplace_back(std::move(clist_ptrs->at(id[i])));
  52. }
  53. }
  54. *clist_ptrs = std::move(ret);
  55. }
  56. void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
  57. // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing
  58. // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost.
  59. const auto simplify_cal = CostModelContext::GetInstance()->costmodel_simplify_cal();
  60. if (!simplify_cal) {
  61. return;
  62. }
  63. MS_EXCEPTION_IF_NULL(clist_ptrs);
  64. std::vector<size_t> id(clist_ptrs->size());
  65. std::iota(id.begin(), id.end(), size_t(0));
  66. std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) {
  67. return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_;
  68. });
  69. CostPtrList ret;
  70. for (size_t i = 0; i < clist_ptrs->size(); ++i) {
  71. if ((ret.size() == size_t(0)) ||
  72. (clist_ptrs->at(id[i])->communication_with_partial_para_ < ret.back()->communication_with_partial_para_)) {
  73. ret.emplace_back(std::move(clist_ptrs->at(id[i])));
  74. }
  75. }
  76. *clist_ptrs = std::move(ret);
  77. }
  78. void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution) {
  79. MS_EXCEPTION_IF_NULL(origin_cost);
  80. const auto comm_threshold = CostModelContext::GetInstance()->costmodel_communi_threshold();
  81. const auto comm_const = CostModelContext::GetInstance()->costmodel_communi_const();
  82. const auto comm_bias = CostModelContext::GetInstance()->costmodel_communi_bias();
  83. const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
  84. if (is_redistribution) {
  85. // Redistribution cost
  86. if ((origin_cost->communication_redis_forward_ > EPS) &&
  87. (origin_cost->communication_redis_forward_ <= comm_threshold)) {
  88. origin_cost->communication_redis_forward_ = comm_const;
  89. } else if (origin_cost->communication_redis_forward_ > comm_threshold) {
  90. origin_cost->communication_redis_forward_ += comm_bias;
  91. }
  92. if ((origin_cost->communication_redis_backward_ > EPS) &&
  93. (origin_cost->communication_redis_backward_ <= comm_threshold)) {
  94. origin_cost->communication_redis_backward_ = comm_const;
  95. } else if (origin_cost->communication_redis_backward_ > comm_threshold) {
  96. origin_cost->communication_redis_backward_ += comm_bias;
  97. }
  98. origin_cost->communication_cost_ =
  99. origin_cost->communication_redis_forward_ + origin_cost->communication_redis_backward_;
  100. origin_cost->communication_without_parameter_ = origin_cost->communication_cost_;
  101. origin_cost->communication_with_partial_para_ = origin_cost->communication_cost_;
  102. } else {
  103. // Operator cost
  104. double backward = 0.0;
  105. if (std::abs(origin_cost->communication_cost_ - origin_cost->communication_without_parameter_) > EPS) {
  106. backward = origin_cost->communication_cost_ - origin_cost->communication_without_parameter_;
  107. }
  108. // forward cost
  109. if ((origin_cost->communication_without_parameter_ > EPS) &&
  110. (origin_cost->communication_without_parameter_ <= comm_threshold)) {
  111. origin_cost->communication_without_parameter_ = comm_const;
  112. } else if (origin_cost->communication_without_parameter_ > comm_threshold) {
  113. origin_cost->communication_without_parameter_ += comm_bias;
  114. }
  115. // total
  116. if (origin_cost->communication_cost_ > EPS) {
  117. origin_cost->communication_cost_ = origin_cost->communication_without_parameter_ + backward;
  118. }
  119. if (origin_cost->communication_with_partial_para_ > EPS) {
  120. origin_cost->communication_with_partial_para_ = origin_cost->communication_without_parameter_ + gamma * backward;
  121. }
  122. }
  123. }
  124. } // namespace parallel
  125. } // namespace mindspore