You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

costmodel.cc 5.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "parallel/auto_parallel/costmodel.h"
  17. #include <cmath>
  18. #include <numeric>
  19. #include <utility>
  20. #include "parallel/auto_parallel/graph_costmodel.h"
  21. namespace mindspore {
  22. namespace parallel {
  23. void Simplify(CostPtrList *clist_ptrs) {
  24. // Sort the cost_list with the computation_cost_ increasing, and communication_cost decreasing order. This method
  25. // excludes the cost with greater computation_cost_ and greater communication_cost.
  26. // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>}
  27. if (!COST_MODEL_SIMPLIFY_CALCULATION) {
  28. return;
  29. }
  30. MS_EXCEPTION_IF_NULL(clist_ptrs);
  31. std::vector<size_t> id(clist_ptrs->size());
  32. std::iota(id.begin(), id.end(), size_t(0));
  33. std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) {
  34. return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_;
  35. });
  36. CostPtrList ret;
  37. for (size_t i = 0; i < clist_ptrs->size(); ++i) {
  38. if ((ret.size() == size_t(0)) || (clist_ptrs->at(id[i])->communication_cost_ < ret.back()->communication_cost_)) {
  39. ret.emplace_back(std::move(clist_ptrs->at(id[i])));
  40. }
  41. }
  42. *clist_ptrs = std::move(ret);
  43. }
  44. void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
  45. // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing
  46. // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost.
  47. if (!COST_MODEL_SIMPLIFY_CALCULATION) {
  48. return;
  49. }
  50. MS_EXCEPTION_IF_NULL(clist_ptrs);
  51. std::vector<size_t> id(clist_ptrs->size());
  52. std::iota(id.begin(), id.end(), size_t(0));
  53. std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) {
  54. return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_;
  55. });
  56. CostPtrList ret;
  57. for (size_t i = 0; i < clist_ptrs->size(); ++i) {
  58. if ((ret.size() == size_t(0)) ||
  59. (clist_ptrs->at(id[i])->communication_with_partial_para_ < ret.back()->communication_with_partial_para_)) {
  60. ret.emplace_back(std::move(clist_ptrs->at(id[i])));
  61. }
  62. }
  63. *clist_ptrs = std::move(ret);
  64. }
  65. void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution) {
  66. MS_EXCEPTION_IF_NULL(origin_cost);
  67. if (is_redistribution) {
  68. // Redistribution cost
  69. if ((origin_cost->communication_redis_forward_ > EPS) &&
  70. (origin_cost->communication_redis_forward_ <= COST_MODEL_COMMUNI_THRESHOLD)) {
  71. origin_cost->communication_redis_forward_ = COST_MODEL_COMMUNI_CONST;
  72. } else if (origin_cost->communication_redis_forward_ > COST_MODEL_COMMUNI_THRESHOLD) {
  73. origin_cost->communication_redis_forward_ += COST_MODEL_COMMUNI_BIAS;
  74. }
  75. if ((origin_cost->communication_redis_backward_ > EPS) &&
  76. (origin_cost->communication_redis_backward_ <= COST_MODEL_COMMUNI_THRESHOLD)) {
  77. origin_cost->communication_redis_backward_ = COST_MODEL_COMMUNI_CONST;
  78. } else if (origin_cost->communication_redis_backward_ > COST_MODEL_COMMUNI_THRESHOLD) {
  79. origin_cost->communication_redis_backward_ += COST_MODEL_COMMUNI_BIAS;
  80. }
  81. origin_cost->communication_cost_ =
  82. origin_cost->communication_redis_forward_ + origin_cost->communication_redis_backward_;
  83. origin_cost->communication_without_parameter_ = origin_cost->communication_cost_;
  84. origin_cost->communication_with_partial_para_ = origin_cost->communication_cost_;
  85. } else {
  86. // Operator cost
  87. double backward = 0.0;
  88. if (std::abs(origin_cost->communication_cost_ - origin_cost->communication_without_parameter_) > EPS) {
  89. backward = origin_cost->communication_cost_ - origin_cost->communication_without_parameter_;
  90. }
  91. // forward cost
  92. if ((origin_cost->communication_without_parameter_ > EPS) &&
  93. (origin_cost->communication_without_parameter_ <= COST_MODEL_COMMUNI_THRESHOLD)) {
  94. origin_cost->communication_without_parameter_ = COST_MODEL_COMMUNI_CONST;
  95. } else if (origin_cost->communication_without_parameter_ > COST_MODEL_COMMUNI_THRESHOLD) {
  96. origin_cost->communication_without_parameter_ += COST_MODEL_COMMUNI_BIAS;
  97. }
  98. // total
  99. if (origin_cost->communication_cost_ > EPS) {
  100. origin_cost->communication_cost_ = origin_cost->communication_without_parameter_ + backward;
  101. }
  102. if (origin_cost->communication_with_partial_para_ > EPS) {
  103. origin_cost->communication_with_partial_para_ =
  104. origin_cost->communication_without_parameter_ + COST_MODEL_GAMMA * backward;
  105. }
  106. }
  107. }
  108. } // namespace parallel
  109. } // namespace mindspore