You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

edge_costmodel.cc 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "parallel/auto_parallel/edge_costmodel.h"
  17. #include <algorithm>
  18. #include <functional>
  19. #include <iterator>
  20. #include <utility>
  21. #include "parallel/auto_parallel/costmodel.h"
  22. #include "parallel/auto_parallel/graph_costmodel.h"
  23. #include "parallel/tensor_layout/tensor_redistribution.h"
  24. namespace mindspore {
  25. namespace parallel {
// Fills cost_map_ with one cost list per (prev-op strategy, next-op strategy)
// pair. For identity edges only layout-equal pairs get a (zero-cost) entry;
// otherwise each pair is priced with the tensor-redistribution cost between
// the two layouts. Throws via MS_LOG(EXCEPTION) if no pair yields a cost.
Status Edge::InitEdgeCost() {
  bool has_available_cost = false;
  // Collect (strategy, output tensor infos) of the preceding operator.
  for (auto &swc : prev_op_->GetStrategyCost()) {
    MS_EXCEPTION_IF_NULL(swc);
    pre_op_output_.emplace_back(std::make_pair(swc->strategy_ptr, swc->outputs_ptr));
  }
  // Collect (strategy, input tensor infos) of the succeeding operator.
  for (auto &swc : next_op_->GetStrategyCost()) {
    MS_EXCEPTION_IF_NULL(swc);
    next_op_input_.emplace_back(std::make_pair(swc->strategy_ptr, swc->inputs_ptr));
  }
  if (is_identity_edge) {
    // Identity edge: a strategy pair is valid only when the producer's output
    // layout exactly equals the consumer's input layout; such pairs cost nothing.
    for (auto &target_output : pre_op_output_) {
      auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout();
      auto target_output_str = target_output.first;
      for (auto &target_input : next_op_input_) {
        auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout();
        auto target_input_str = target_input.first;
        if (target_output_lyt == target_input_lyt) {
          CostPtrKey ck = {target_output_str, target_input_str};
          CostPtr cost = std::make_shared<Cost>(0.0, 0.0);
          MS_EXCEPTION_IF_NULL(cost);
          cost->communication_without_parameter_ = 0.0;
          cost->communication_with_partial_para_ = 0.0;
          CostPtrList cl;
          cl.push_back(cost);
          (void)cost_map_.emplace(std::make_pair(ck, cl));
          has_available_cost = true;
        }
      }
    }
  } else {
    // Non-identity edge: every strategy pair is priced with the cost of
    // redistributing the output layout into the required input layout.
    for (auto &target_output : pre_op_output_) {
      auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout();
      auto target_output_str = target_output.first;
      auto type_length = prev_op_->GetOutputTypeLengths()[prev_op_output_index_];
      auto type = prev_op_->outputs_type()[prev_op_output_index_];
      for (auto &target_input : next_op_input_) {
        auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout();
        auto target_input_str = target_input.first;
        CostPtr cost;
        if (GetRedistributionCost(target_output_lyt, target_input_lyt, type_length, type, &cost) != SUCCESS) {
          MS_LOG(EXCEPTION) << "Failure: redistribution cost calculation failed";
        }
        MS_EXCEPTION_IF_NULL(cost);
        MS_LOG(DEBUG) << "The redistribution cost: computation_cost: " << cost->computation_cost_
                      << ", communication_cost: " << cost->communication_cost_
                      << ", communication_without_parameter_: " << cost->communication_without_parameter_
                      << ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << ".";
        // refine communication cost calculation for practice
        RefineForPracticalCost(cost, true);
        cost->communication_forward_ = cost->communication_redis_forward_;
        CostPtrKey ck = {target_output_str, target_input_str};
        CostPtrList cl;
        cl.push_back(cost);
        (void)cost_map_.emplace(std::make_pair(ck, cl));
        has_available_cost = true;
      }
    }
  }
  if (!has_available_cost) {
    // Hint at the costmodel switch that most likely pruned all strategy pairs.
    if (FULLY_USE_DEVICES) {
      MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
                        << " failed, it may be caused by setting 'fully_use_devices' true. Try to set "
                           "'fully_use_devices' false.";
    } else if (ELEMENTWISE_OP_STRA_FOLLOW) {
      MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
                        << " failed, it may be caused by setting 'elementwise_op_strategy_follow' true. "
                           "Try to set 'elementwise_op_strategy_follow' false.";
    }
    MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ << " failed.";
  }
  return Status::SUCCESS;
}
  99. Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout,
  100. size_t type_length, TypePtr type, CostPtr *cost) {
  101. MS_EXCEPTION_IF_NULL(prev_op_);
  102. MS_EXCEPTION_IF_NULL(cost);
  103. RankList dev_list = prev_op_->global_device_list();
  104. TensorRedistribution tensor_redistribution(false);
  105. // Init TensorRedistribution
  106. if (tensor_redistribution.Init(prev_op_output_layout, next_op_input_layout, dev_list) == FAILED) {
  107. MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
  108. }
  109. if (tensor_redistribution.ComputeCost() == FAILED) {
  110. MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed.";
  111. }
  112. double comm_cost = tensor_redistribution.comm_cost();
  113. double forward_comm_cost = tensor_redistribution.forward_comm_cost();
  114. double backward_comm_cost = tensor_redistribution.backward_comm_cost();
  115. double computation_cost = tensor_redistribution.computation_cost();
  116. double mem_cost = tensor_redistribution.memory_cost();
  117. // Now AllGather, ReduceScatter, AlltoAll don't support bool type
  118. MS_EXCEPTION_IF_NULL(type);
  119. if ((type->type_id() == kNumberTypeBool) && (comm_cost > 0)) {
  120. computation_cost = INF;
  121. comm_cost = INF;
  122. MS_LOG(WARNING) << "Communication Operators don't support bool dtype!";
  123. }
  124. *cost = std::make_shared<Cost>(type_length * computation_cost, type_length * comm_cost);
  125. (*cost)->communication_without_parameter_ = type_length * comm_cost;
  126. (*cost)->communication_with_partial_para_ =
  127. (*cost)->communication_without_parameter_ +
  128. COST_MODEL_GAMMA * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_);
  129. (*cost)->communication_redis_forward_ = type_length * forward_comm_cost;
  130. (*cost)->communication_redis_backward_ = type_length * backward_comm_cost;
  131. (*cost)->memory_with_reuse_ = mem_cost;
  132. return Status::SUCCESS;
  133. }
  134. CostPtrList Edge::GetCostList(StrategyPtr output_str, StrategyPtr input_str) {
  135. CostPtrKey ck = {output_str, input_str};
  136. CostPtrList result;
  137. if (cost_map_.find(ck) != cost_map_.end()) {
  138. return cost_map_.at(ck);
  139. }
  140. return result;
  141. }
// Combines the cost lists of several parallel edges (all connecting the same
// operator pair, under the fixed strategies output_st_ptr/input_st_ptr) into
// one list: one aggregated Cost per cross-product choice of a single cost from
// each edge, summing every cost component. The list is Simplify()-ed on return.
CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, const std::vector<EdgePtr> &edges,
                                                const StrategyPtr &input_st_ptr) {
  // Fetch each edge's cost list for this strategy pair.
  std::function<CostPtrList(EdgePtr)> LocalGetCostList = [&](const EdgePtr &edge) {
    MS_EXCEPTION_IF_NULL(edge);
    return edge->GetCostList(output_st_ptr, input_st_ptr);
  };
  CostPtrList result;
  std::vector<CostPtrList> all_cost_list;
  all_cost_list.resize(edges.size());
  (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);
  // selected_cost_list[k] holds the cost currently chosen from edge k.
  CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
  // Depth-first enumeration of the cross product; the running sums of each
  // cost component are threaded through the recursion's arguments.
  std::function<void(size_t, double, double, double, double, double)> recursive =
    [&](size_t k, double computation, double memory, double communication, double communication_without_para,
        double communication_forward) {
      if (k == edges.size()) {
        // One cost picked from every edge: materialize the aggregated Cost.
        auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
        CostPtr new_cost = std::make_shared<Cost>(computation, communication);
        MS_EXCEPTION_IF_NULL(new_cost);
        new_cost->communication_without_parameter_ = communication_without_para;
        new_cost->communication_with_partial_para_ =
          communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
        new_cost->memory_with_reuse_ = memory;
        new_cost->communication_forward_ = communication_forward;
        new_cost->decision_ptr_ = decision;
        result.push_back(new_cost);
        return;
      }
      for (auto &c : all_cost_list[k]) {
        MS_EXCEPTION_IF_NULL(c);
        selected_cost_list[k] = c;
        recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_,
                  communication + c->communication_cost_,
                  communication_without_para + c->communication_without_parameter_,
                  communication_forward + c->communication_forward_);
      }
    };
  recursive(0, 0.0, 0.0, 0.0, 0.0, 0.0);
  // Drop dominated cost entries before returning.
  Simplify(&result);
  return result;
}
  182. void Edge::EdgeEliminationSetNewCost(OperatorInfoPtr, const std::vector<EdgePtr> &edges, OperatorInfoPtr) {
  183. bool valid = false;
  184. for (const auto &output_pair : pre_op_output_) {
  185. StrategyPtr output_st_ptr = output_pair.first;
  186. for (const auto &input_pair : next_op_input_) {
  187. StrategyPtr input_st_ptr = input_pair.first;
  188. CostPtrList clist = CreateEdgeEliminationCostList(output_st_ptr, edges, input_st_ptr);
  189. CostPtrKey key = {output_st_ptr, input_st_ptr};
  190. cost_map_[key] = clist;
  191. if ((!valid) && (!clist.empty())) {
  192. valid = true;
  193. }
  194. }
  195. }
  196. if (!valid) {
  197. MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
  198. }
  199. }
  200. void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list,
  201. const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list,
  202. CostPtrList *ret_cost_list) {
  203. for (auto &left_cost : left_cost_list) {
  204. MS_EXCEPTION_IF_NULL(left_cost);
  205. for (auto &middle_cost : middle_cost_list) {
  206. MS_EXCEPTION_IF_NULL(middle_cost);
  207. for (auto &right_cost : right_cost_list) {
  208. MS_EXCEPTION_IF_NULL(right_cost);
  209. double computation =
  210. left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_;
  211. double communication =
  212. left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_;
  213. double communication_forward =
  214. left_cost->communication_forward_ + middle_cost->communication_forward_ + right_cost->communication_forward_;
  215. double communication_without_para = left_cost->communication_without_parameter_ +
  216. middle_cost->communication_without_parameter_ +
  217. right_cost->communication_without_parameter_;
  218. double memory_cost =
  219. left_cost->memory_with_reuse_ + middle_cost->memory_with_reuse_ + right_cost->memory_with_reuse_;
  220. auto decision = std::make_shared<OpEliminationDecision>(op_strategy, left_cost, middle_cost, right_cost);
  221. auto cost = std::make_shared<Cost>(computation, communication, decision);
  222. MS_EXCEPTION_IF_NULL(cost);
  223. cost->communication_without_parameter_ = communication_without_para;
  224. cost->communication_with_partial_para_ =
  225. communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
  226. cost->memory_with_reuse_ = memory_cost;
  227. cost->communication_forward_ = communication_forward;
  228. ret_cost_list->emplace_back(std::move(cost));
  229. }
  230. }
  231. }
  232. }
  233. CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyPtr &output_st_ptr,
  234. const OperatorInfoPtr &op, const EdgePtr &e2,
  235. const StrategyPtr &input_st_ptr) {
  236. MS_EXCEPTION_IF_NULL(op);
  237. MS_EXCEPTION_IF_NULL(e1);
  238. MS_EXCEPTION_IF_NULL(e2);
  239. CostPtrList result;
  240. for (const auto &op_strategy : op->GetStrategyCost()) {
  241. MS_EXCEPTION_IF_NULL(op_strategy);
  242. auto middle_strategy = op_strategy->strategy_ptr;
  243. CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy),
  244. op_strategy->cost_list, e2->GetCostList(middle_strategy, input_st_ptr), &result);
  245. }
  246. Simplify(&result);
  247. return result;
  248. }
  249. void Edge::OpEliminationSetNewCost(const EdgePtr &e1, const OperatorInfoPtr &op, const EdgePtr &e2) {
  250. bool valid = false;
  251. for (const auto &output_pair : pre_op_output_) {
  252. StrategyPtr output_st_ptr = output_pair.first;
  253. for (const auto &input_pair : next_op_input_) {
  254. StrategyPtr input_st_ptr = input_pair.first;
  255. CostPtrList clist = CreateOpEliminationCostList(e1, output_st_ptr, op, e2, input_st_ptr);
  256. CostPtrKey key = {output_st_ptr, input_st_ptr};
  257. cost_map_[key] = clist;
  258. if ((!valid) && (!clist.empty())) {
  259. valid = true;
  260. }
  261. }
  262. }
  263. if (!valid) {
  264. MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
  265. }
  266. }
  267. Status Edge::CalculateMemoryCost() {
  268. if (is_output_parameter_involve_ == -1) {
  269. MS_LOG(ERROR) << "is_output_parameter_involve_ is unset.";
  270. return FAILED;
  271. }
  272. if (is_output_parameter_involve_ == 0) {
  273. // In this case, it is sure that the tensor redistribution along this edge is NOT parameter-involved, thus it is
  274. // unnecessary to keep them in memory.
  275. for (auto &cost_kv : cost_map_) {
  276. auto &cost_v = cost_kv.second;
  277. if (!cost_v.empty()) {
  278. cost_v[0]->memory_with_reuse_ = 0;
  279. }
  280. }
  281. }
  282. return SUCCESS;
  283. }
  284. Status Edge::CalculateMemoryCostForInference() {
  285. // Currently, memory cost is NOT calculated for redistribution
  286. if ((is_output_critical_ != 0) && (is_output_critical_ != 1)) {
  287. MS_LOG(ERROR) << "Failure: unexpected output critical flag value: " << is_output_critical_;
  288. return FAILED;
  289. }
  290. for (auto &cost_kv : cost_map_) {
  291. auto &cost_v = cost_kv.second;
  292. if (!cost_v.empty()) {
  293. cost_v[0]->memory_with_reuse_ = 0;
  294. }
  295. }
  296. return SUCCESS;
  297. }
  298. } // namespace parallel
  299. } // namespace mindspore