You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

edge_costmodel.cc 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "parallel/auto_parallel/edge_costmodel.h"
  17. #include <algorithm>
  18. #include <functional>
  19. #include <iterator>
  20. #include <utility>
  21. #include "parallel/auto_parallel/costmodel.h"
  22. #include "parallel/auto_parallel/graph_costmodel.h"
  23. #include "parallel/tensor_layout/tensor_redistribution.h"
  24. namespace mindspore {
  25. namespace parallel {
// Builds cost_map_ for this edge: one CostPtrList per pair of
// (prev-op output strategy, next-op input strategy).
// - Identity edges get a zero cost, created only when the two layouts match.
// - Other edges get a tensor-redistribution cost between the two layouts.
// Returns Status::SUCCESS when at least one pair produced a cost; otherwise
// raises via MS_LOG(EXCEPTION) (which does not return).
Status Edge::InitEdgeCost() {
  bool has_available_cost = false;
  // Cache each strategy of the previous operator with its output tensor infos.
  for (auto &swc : prev_op_->GetStrategyCost()) {
    MS_EXCEPTION_IF_NULL(swc);
    pre_op_output_.emplace_back(std::make_pair(swc->strategy_ptr, swc->outputs_ptr));
  }
  // Cache each strategy of the next operator with its input tensor infos.
  for (auto &swc : next_op_->GetStrategyCost()) {
    MS_EXCEPTION_IF_NULL(swc);
    next_op_input_.emplace_back(std::make_pair(swc->strategy_ptr, swc->inputs_ptr));
  }
  if (is_identity_edge) {
    // Identity edge: a cost entry exists only when the producer's output layout
    // equals the consumer's input layout, and that cost is zero (no
    // redistribution is performed).
    for (auto &target_output : pre_op_output_) {
      auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout();
      auto target_output_str = target_output.first;
      for (auto &target_input : next_op_input_) {
        auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout();
        auto target_input_str = target_input.first;
        if (target_output_lyt == target_input_lyt) {
          CostPtrKey ck = {target_output_str, target_input_str};
          CostPtr cost = std::make_shared<Cost>(0.0, 0.0);
          MS_EXCEPTION_IF_NULL(cost);
          cost->communication_without_parameter_ = 0.0;
          cost->communication_with_partial_para_ = 0.0;
          CostPtrList cl;
          cl.push_back(cost);
          (void)cost_map_.emplace(std::make_pair(ck, cl));
          has_available_cost = true;
        }
      }
    }
  } else {
    // Non-identity edge: every (output strategy, input strategy) pair gets a
    // redistribution cost computed from the two layouts.
    for (auto &target_output : pre_op_output_) {
      auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout();
      auto target_output_str = target_output.first;
      // Element byte length and dtype of the produced tensor, used by
      // GetRedistributionCost to scale and validate the cost.
      auto type_length = prev_op_->GetOutputTypeLengths()[prev_op_output_index_];
      auto type = prev_op_->outputs_type()[prev_op_output_index_];
      for (auto &target_input : next_op_input_) {
        auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout();
        auto target_input_str = target_input.first;
        CostPtr cost;
        if (GetRedistributionCost(target_output_lyt, target_input_lyt, type_length, type, &cost) != SUCCESS) {
          MS_LOG(EXCEPTION) << "Failure: redistribution cost calculation failed";
        }
        MS_EXCEPTION_IF_NULL(cost);
        MS_LOG(DEBUG) << "The redistribution cost: computation_cost: " << cost->computation_cost_
                      << ", communication_cost: " << cost->communication_cost_
                      << ", communication_without_parameter_: " << cost->communication_without_parameter_
                      << ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << ".";
        // refine communication cost calculation for practice
        RefineForPracticalCost(cost, true);
        CostPtrKey ck = {target_output_str, target_input_str};
        CostPtrList cl;
        cl.push_back(cost);
        (void)cost_map_.emplace(std::make_pair(ck, cl));
        has_available_cost = true;
      }
    }
  }
  if (!has_available_cost) {
    // No pair produced a cost: point the user at the configuration flag that
    // most likely caused the failure before the generic error.
    if (FULLY_USE_DEVICES) {
      MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
                        << " failed, it may be caused by setting 'fully_use_devices' true. Try to set "
                           "'fully_use_devices' false.";
    } else if (ELEMENTWISE_OP_STRA_FOLLOW) {
      MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
                        << " failed, it may be caused by setting 'elementwise_op_strategy_follow' true. "
                           "Try to set 'elementwise_op_strategy_follow' false.";
    }
    MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ << " failed.";
  }
  return Status::SUCCESS;
}
  98. Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout,
  99. size_t type_length, TypePtr type, CostPtr *cost) {
  100. MS_EXCEPTION_IF_NULL(prev_op_);
  101. MS_EXCEPTION_IF_NULL(cost);
  102. RankList dev_list = prev_op_->global_device_list();
  103. TensorRedistribution tensor_redistribution(false);
  104. // Init TensorRedistribution
  105. if (tensor_redistribution.Init(prev_op_output_layout, next_op_input_layout, dev_list) == FAILED) {
  106. MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
  107. }
  108. if (tensor_redistribution.ComputeCost() == FAILED) {
  109. MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed.";
  110. }
  111. double comm_cost = tensor_redistribution.comm_cost();
  112. double forward_comm_cost = tensor_redistribution.forward_comm_cost();
  113. double backward_comm_cost = tensor_redistribution.backward_comm_cost();
  114. double computation_cost = tensor_redistribution.computation_cost();
  115. double mem_cost = tensor_redistribution.memory_cost();
  116. // Now AllGather, ReduceScatter, AlltoAll don't support bool type
  117. MS_EXCEPTION_IF_NULL(type);
  118. if ((type->type_id() == kNumberTypeBool) && (comm_cost > 0)) {
  119. computation_cost = INF;
  120. comm_cost = INF;
  121. MS_LOG(WARNING) << "Communication Operators don't support bool dtype!";
  122. }
  123. *cost = std::make_shared<Cost>(type_length * computation_cost, type_length * comm_cost);
  124. (*cost)->communication_without_parameter_ = type_length * comm_cost;
  125. (*cost)->communication_with_partial_para_ =
  126. (*cost)->communication_without_parameter_ +
  127. COST_MODEL_GAMMA * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_);
  128. (*cost)->communication_redis_forward_ = type_length * forward_comm_cost;
  129. (*cost)->communication_redis_backward_ = type_length * backward_comm_cost;
  130. (*cost)->memory_with_reuse_ = mem_cost;
  131. return Status::SUCCESS;
  132. }
  133. CostPtrList Edge::GetCostList(StrategyPtr output_str, StrategyPtr input_str) {
  134. CostPtrKey ck = {output_str, input_str};
  135. CostPtrList result;
  136. if (cost_map_.find(ck) != cost_map_.end()) {
  137. return cost_map_.at(ck);
  138. }
  139. return result;
  140. }
// Merges the costs of several parallel edges (all joining the same pair of
// operators) into one candidate cost list for the strategy pair
// (output_st_ptr, input_st_ptr). Each candidate picks exactly one cost from
// every edge; its components are the sums over the picked costs.
CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, const std::vector<EdgePtr> &edges,
                                                const StrategyPtr &input_st_ptr) {
  // Fetches one edge's cost list for this strategy pair.
  std::function<CostPtrList(EdgePtr)> LocalGetCostList = [&](const EdgePtr &edge) {
    MS_EXCEPTION_IF_NULL(edge);
    return edge->GetCostList(output_st_ptr, input_st_ptr);
  };
  CostPtrList result;
  std::vector<CostPtrList> all_cost_list;
  all_cost_list.resize(edges.size());
  (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);
  // selected_cost_list[k] holds the cost currently chosen from edge k.
  CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
  // Depth-first enumeration of the cartesian product of the edges' cost lists,
  // accumulating every cost component along the path.
  std::function<void(size_t, double, double, double, double)> recursive =
    [&](size_t k, double computation, double memory, double communication, double communication_without_para) {
      if (k == edges.size()) {
        // A full combination has been selected: record it as one candidate,
        // remembering the selection in the decision object.
        auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
        CostPtr new_cost = std::make_shared<Cost>(computation, communication);
        MS_EXCEPTION_IF_NULL(new_cost);
        new_cost->communication_without_parameter_ = communication_without_para;
        new_cost->communication_with_partial_para_ =
          communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
        new_cost->memory_with_reuse_ = memory;
        new_cost->decision_ptr_ = decision;
        result.push_back(new_cost);
        return;
      }
      for (auto &c : all_cost_list[k]) {
        MS_EXCEPTION_IF_NULL(c);
        selected_cost_list[k] = c;
        recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_,
                  communication + c->communication_cost_,
                  communication_without_para + c->communication_without_parameter_);
      }
    };
  recursive(0, 0.0, 0.0, 0.0, 0.0);
  // Prune dominated candidates (project helper; name spelled as declared).
  SimplifyForDreasingCommunicationWithPartialPara(&result);
  return result;
}
  178. void Edge::EdgeEliminationSetNewCost(OperatorInfoPtr, const std::vector<EdgePtr> &edges, OperatorInfoPtr) {
  179. bool valid = false;
  180. for (const auto &output_pair : pre_op_output_) {
  181. StrategyPtr output_st_ptr = output_pair.first;
  182. for (const auto &input_pair : next_op_input_) {
  183. StrategyPtr input_st_ptr = input_pair.first;
  184. CostPtrList clist = CreateEdgeEliminationCostList(output_st_ptr, edges, input_st_ptr);
  185. CostPtrKey key = {output_st_ptr, input_st_ptr};
  186. cost_map_[key] = clist;
  187. if ((!valid) && (!clist.empty())) {
  188. valid = true;
  189. }
  190. }
  191. }
  192. if (!valid) {
  193. MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
  194. }
  195. }
  196. void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list,
  197. const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list,
  198. CostPtrList *ret_cost_list) {
  199. for (auto &left_cost : left_cost_list) {
  200. MS_EXCEPTION_IF_NULL(left_cost);
  201. for (auto &middle_cost : middle_cost_list) {
  202. MS_EXCEPTION_IF_NULL(middle_cost);
  203. for (auto &right_cost : right_cost_list) {
  204. MS_EXCEPTION_IF_NULL(right_cost);
  205. double computation =
  206. left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_;
  207. double communication =
  208. left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_;
  209. double communication_without_para = left_cost->communication_without_parameter_ +
  210. middle_cost->communication_without_parameter_ +
  211. right_cost->communication_without_parameter_;
  212. double memory_cost =
  213. left_cost->memory_with_reuse_ + middle_cost->memory_with_reuse_ + right_cost->memory_with_reuse_;
  214. auto decision = std::make_shared<OpEliminationDecision>(op_strategy, left_cost, middle_cost, right_cost);
  215. auto cost = std::make_shared<Cost>(computation, communication, decision);
  216. MS_EXCEPTION_IF_NULL(cost);
  217. cost->communication_without_parameter_ = communication_without_para;
  218. cost->communication_with_partial_para_ =
  219. communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
  220. cost->memory_with_reuse_ = memory_cost;
  221. ret_cost_list->emplace_back(std::move(cost));
  222. }
  223. }
  224. }
  225. }
  226. CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyPtr &output_st_ptr,
  227. const OperatorInfoPtr &op, const EdgePtr &e2,
  228. const StrategyPtr &input_st_ptr) {
  229. MS_EXCEPTION_IF_NULL(op);
  230. MS_EXCEPTION_IF_NULL(e1);
  231. MS_EXCEPTION_IF_NULL(e2);
  232. CostPtrList result;
  233. for (const auto &op_strategy : op->GetStrategyCost()) {
  234. MS_EXCEPTION_IF_NULL(op_strategy);
  235. auto middle_strategy = op_strategy->strategy_ptr;
  236. CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy),
  237. op_strategy->cost_list, e2->GetCostList(middle_strategy, input_st_ptr), &result);
  238. }
  239. SimplifyForDreasingCommunicationWithPartialPara(&result);
  240. return result;
  241. }
  242. void Edge::OpEliminationSetNewCost(const EdgePtr &e1, const OperatorInfoPtr &op, const EdgePtr &e2) {
  243. bool valid = false;
  244. for (const auto &output_pair : pre_op_output_) {
  245. StrategyPtr output_st_ptr = output_pair.first;
  246. for (const auto &input_pair : next_op_input_) {
  247. StrategyPtr input_st_ptr = input_pair.first;
  248. CostPtrList clist = CreateOpEliminationCostList(e1, output_st_ptr, op, e2, input_st_ptr);
  249. CostPtrKey key = {output_st_ptr, input_st_ptr};
  250. cost_map_[key] = clist;
  251. if ((!valid) && (!clist.empty())) {
  252. valid = true;
  253. }
  254. }
  255. }
  256. if (!valid) {
  257. MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
  258. }
  259. }
  260. Status Edge::CalculateMemoryCost() {
  261. if (is_output_parameter_involve_ == -1) {
  262. MS_LOG(ERROR) << "is_output_parameter_involve_ is unset.";
  263. return FAILED;
  264. }
  265. if (is_output_parameter_involve_ == 0) {
  266. // In this case, it is sure that the tensor redistribution along this edge is NOT parameter-involved, thus it is
  267. // unnecessary to keep them in memory.
  268. for (auto &cost_kv : cost_map_) {
  269. auto &cost_v = cost_kv.second;
  270. if (!cost_v.empty()) {
  271. cost_v[0]->memory_with_reuse_ = 0;
  272. }
  273. }
  274. }
  275. return SUCCESS;
  276. }
  277. } // namespace parallel
  278. } // namespace mindspore