You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

costmodel.cc 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "frontend/parallel/auto_parallel/costmodel.h"
  #include <algorithm>
  #include <cmath>
  #include <numeric>
  #include <utility>
  #include <vector>
  #include "frontend/parallel/auto_parallel/graph_costmodel.h"
  21. namespace mindspore {
  22. namespace parallel {
  23. void Simplify(CostPtrList *clist_ptrs) {
  24. if (RUN_PHASE == TRAINING_PHASE) {
  25. // training phase
  26. SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs);
  27. } else {
  28. // inference phase
  29. SimplifyForDecreasingCommunicationForward(clist_ptrs);
  30. }
  31. }
  32. void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) {
  33. // Sort the cost_list with the computation_cost_ increasing, and communication_forward decreasing order. This method
  34. // excludes the cost with greater computation_cost_ and greater communication_forward.
  35. // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>}
  36. if (!COST_MODEL_SIMPLIFY_CALCULATION) {
  37. return;
  38. }
  39. MS_EXCEPTION_IF_NULL(clist_ptrs);
  40. std::vector<size_t> id(clist_ptrs->size());
  41. std::iota(id.begin(), id.end(), size_t(0));
  42. std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) {
  43. return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_;
  44. });
  45. CostPtrList ret;
  46. for (size_t i = 0; i < clist_ptrs->size(); ++i) {
  47. if ((ret.size() == size_t(0)) ||
  48. (clist_ptrs->at(id[i])->communication_forward_ < ret.back()->communication_forward_)) {
  49. ret.emplace_back(std::move(clist_ptrs->at(id[i])));
  50. }
  51. }
  52. *clist_ptrs = std::move(ret);
  53. }
  54. void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
  55. // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing
  56. // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost.
  57. if (!COST_MODEL_SIMPLIFY_CALCULATION) {
  58. return;
  59. }
  60. MS_EXCEPTION_IF_NULL(clist_ptrs);
  61. std::vector<size_t> id(clist_ptrs->size());
  62. std::iota(id.begin(), id.end(), size_t(0));
  63. std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) {
  64. return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_;
  65. });
  66. CostPtrList ret;
  67. for (size_t i = 0; i < clist_ptrs->size(); ++i) {
  68. if ((ret.size() == size_t(0)) ||
  69. (clist_ptrs->at(id[i])->communication_with_partial_para_ < ret.back()->communication_with_partial_para_)) {
  70. ret.emplace_back(std::move(clist_ptrs->at(id[i])));
  71. }
  72. }
  73. *clist_ptrs = std::move(ret);
  74. }
  75. void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution) {
  76. MS_EXCEPTION_IF_NULL(origin_cost);
  77. if (is_redistribution) {
  78. // Redistribution cost
  79. if ((origin_cost->communication_redis_forward_ > EPS) &&
  80. (origin_cost->communication_redis_forward_ <= COST_MODEL_COMMUNI_THRESHOLD)) {
  81. origin_cost->communication_redis_forward_ = COST_MODEL_COMMUNI_CONST;
  82. } else if (origin_cost->communication_redis_forward_ > COST_MODEL_COMMUNI_THRESHOLD) {
  83. origin_cost->communication_redis_forward_ += COST_MODEL_COMMUNI_BIAS;
  84. }
  85. if ((origin_cost->communication_redis_backward_ > EPS) &&
  86. (origin_cost->communication_redis_backward_ <= COST_MODEL_COMMUNI_THRESHOLD)) {
  87. origin_cost->communication_redis_backward_ = COST_MODEL_COMMUNI_CONST;
  88. } else if (origin_cost->communication_redis_backward_ > COST_MODEL_COMMUNI_THRESHOLD) {
  89. origin_cost->communication_redis_backward_ += COST_MODEL_COMMUNI_BIAS;
  90. }
  91. origin_cost->communication_cost_ =
  92. origin_cost->communication_redis_forward_ + origin_cost->communication_redis_backward_;
  93. origin_cost->communication_without_parameter_ = origin_cost->communication_cost_;
  94. origin_cost->communication_with_partial_para_ = origin_cost->communication_cost_;
  95. } else {
  96. // Operator cost
  97. double backward = 0.0;
  98. if (std::abs(origin_cost->communication_cost_ - origin_cost->communication_without_parameter_) > EPS) {
  99. backward = origin_cost->communication_cost_ - origin_cost->communication_without_parameter_;
  100. }
  101. // forward cost
  102. if ((origin_cost->communication_without_parameter_ > EPS) &&
  103. (origin_cost->communication_without_parameter_ <= COST_MODEL_COMMUNI_THRESHOLD)) {
  104. origin_cost->communication_without_parameter_ = COST_MODEL_COMMUNI_CONST;
  105. } else if (origin_cost->communication_without_parameter_ > COST_MODEL_COMMUNI_THRESHOLD) {
  106. origin_cost->communication_without_parameter_ += COST_MODEL_COMMUNI_BIAS;
  107. }
  108. // total
  109. if (origin_cost->communication_cost_ > EPS) {
  110. origin_cost->communication_cost_ = origin_cost->communication_without_parameter_ + backward;
  111. }
  112. if (origin_cost->communication_with_partial_para_ > EPS) {
  113. origin_cost->communication_with_partial_para_ =
  114. origin_cost->communication_without_parameter_ + COST_MODEL_GAMMA * backward;
  115. }
  116. }
  117. }
  118. } // namespace parallel
  119. } // namespace mindspore