You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

edge_costmodel.cc 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "parallel/auto_parallel/edge_costmodel.h"
  17. #include <algorithm>
  18. #include <functional>
  19. #include <iterator>
  20. #include <utility>
  21. #include "parallel/auto_parallel/costmodel.h"
  22. #include "parallel/auto_parallel/graph_costmodel.h"
  23. #include "parallel/tensor_layout/tensor_redistribution.h"
  24. namespace mindspore {
  25. namespace parallel {
  26. Status Edge::InitEdgeCost() {
  27. bool has_available_cost = false;
  28. for (auto &swc : prev_op_->GetStrategyCost()) {
  29. MS_EXCEPTION_IF_NULL(swc);
  30. pre_op_output_.emplace_back(std::make_pair(swc->strategy_ptr, swc->outputs_ptr));
  31. }
  32. for (auto &swc : next_op_->GetStrategyCost()) {
  33. MS_EXCEPTION_IF_NULL(swc);
  34. next_op_input_.emplace_back(std::make_pair(swc->strategy_ptr, swc->inputs_ptr));
  35. }
  36. if (is_identity_edge) {
  37. for (auto &target_output : pre_op_output_) {
  38. auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout();
  39. auto target_output_str = target_output.first;
  40. for (auto &target_input : next_op_input_) {
  41. auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout();
  42. auto target_input_str = target_input.first;
  43. if (target_output_lyt == target_input_lyt) {
  44. CostPtrKey ck = {target_output_str, target_input_str};
  45. CostPtr cost = std::make_shared<Cost>(0.0, 0.0);
  46. MS_EXCEPTION_IF_NULL(cost);
  47. cost->communication_without_parameter_ = 0.0;
  48. cost->communication_with_partial_para_ = 0.0;
  49. CostPtrList cl;
  50. cl.push_back(cost);
  51. (void)cost_map_.emplace(std::make_pair(ck, cl));
  52. has_available_cost = true;
  53. }
  54. }
  55. }
  56. } else {
  57. for (auto &target_output : pre_op_output_) {
  58. auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout();
  59. auto target_output_str = target_output.first;
  60. auto type_length = prev_op_->GetOutputTypeLengths()[prev_op_output_index_];
  61. auto type = prev_op_->outputs_type()[prev_op_output_index_];
  62. for (auto &target_input : next_op_input_) {
  63. auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout();
  64. auto target_input_str = target_input.first;
  65. CostPtr cost;
  66. if (GetRedistributionCost(target_output_lyt, target_input_lyt, type_length, type, &cost) != SUCCESS) {
  67. MS_LOG(EXCEPTION) << "Failure: redistribution cost calculation failed";
  68. }
  69. MS_EXCEPTION_IF_NULL(cost);
  70. MS_LOG(DEBUG) << "The redistribution cost: computation_cost: " << cost->computation_cost_
  71. << ", communication_cost: " << cost->communication_cost_
  72. << ", communication_without_parameter_: " << cost->communication_without_parameter_
  73. << ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << ".";
  74. // refine communication cost calculation for practice
  75. RefineForPracticalCost(cost, true);
  76. cost->communication_forward_ = cost->communication_redis_forward_;
  77. CostPtrKey ck = {target_output_str, target_input_str};
  78. CostPtrList cl;
  79. cl.push_back(cost);
  80. (void)cost_map_.emplace(std::make_pair(ck, cl));
  81. has_available_cost = true;
  82. }
  83. }
  84. }
  85. if (!has_available_cost) {
  86. if (FULLY_USE_DEVICES) {
  87. MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
  88. << " failed, it may be caused by setting 'fully_use_devices' true. Try to set "
  89. "'fully_use_devices' false.";
  90. } else if (ELEMENTWISE_OP_STRA_FOLLOW) {
  91. MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
  92. << " failed, it may be caused by setting 'elementwise_op_strategy_follow' true. "
  93. "Try to set 'elementwise_op_strategy_follow' false.";
  94. }
  95. if (edge_name_.find(RESHAPE) != std::string::npos) {
  96. MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
  97. << " failed, it may be caused by setting different strategies for operators following Reshape. "
  98. "Try to fix that.";
  99. }
  100. MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ << " failed.";
  101. }
  102. return Status::SUCCESS;
  103. }
  104. Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout,
  105. size_t type_length, TypePtr type, CostPtr *cost) {
  106. MS_EXCEPTION_IF_NULL(prev_op_);
  107. MS_EXCEPTION_IF_NULL(cost);
  108. RankList dev_list = prev_op_->global_device_list();
  109. TensorRedistribution tensor_redistribution(false);
  110. // Init TensorRedistribution
  111. if (tensor_redistribution.Init(prev_op_output_layout, next_op_input_layout, dev_list) == FAILED) {
  112. MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
  113. }
  114. if (tensor_redistribution.ComputeCost() == FAILED) {
  115. MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed.";
  116. }
  117. double comm_cost = tensor_redistribution.comm_cost();
  118. double forward_comm_cost = tensor_redistribution.forward_comm_cost();
  119. double backward_comm_cost = tensor_redistribution.backward_comm_cost();
  120. double computation_cost = tensor_redistribution.computation_cost();
  121. double mem_cost = tensor_redistribution.memory_cost();
  122. // Now AllGather, ReduceScatter, AlltoAll don't support bool type
  123. MS_EXCEPTION_IF_NULL(type);
  124. if ((type->type_id() == kNumberTypeBool) && (comm_cost > 0)) {
  125. computation_cost = INF;
  126. comm_cost = INF;
  127. MS_LOG(WARNING) << "Communication Operators don't support bool dtype!";
  128. }
  129. *cost = std::make_shared<Cost>(type_length * computation_cost, type_length * comm_cost);
  130. (*cost)->communication_without_parameter_ = type_length * comm_cost;
  131. (*cost)->communication_with_partial_para_ =
  132. (*cost)->communication_without_parameter_ +
  133. COST_MODEL_GAMMA * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_);
  134. (*cost)->communication_redis_forward_ = type_length * forward_comm_cost;
  135. (*cost)->communication_redis_backward_ = type_length * backward_comm_cost;
  136. (*cost)->memory_with_reuse_ = mem_cost;
  137. return Status::SUCCESS;
  138. }
  139. CostPtrList Edge::GetCostList(StrategyPtr output_str, StrategyPtr input_str) {
  140. CostPtrKey ck = {output_str, input_str};
  141. CostPtrList result;
  142. if (cost_map_.find(ck) != cost_map_.end()) {
  143. return cost_map_.at(ck);
  144. }
  145. return result;
  146. }
// Merge the cost lists of several parallel edges (all connecting the same operator pair)
// for one fixed (output strategy, input strategy) combination. Every way of choosing
// exactly one Cost from each edge's list is summed into a single candidate Cost carrying
// an EdgeEliminationDecision; the candidate list is then pruned by Simplify().
CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, const std::vector<EdgePtr> &edges,
                                                const StrategyPtr &input_st_ptr) {
  // Fetch each edge's cost list for the fixed boundary strategy pair.
  std::function<CostPtrList(EdgePtr)> LocalGetCostList = [&](const EdgePtr &edge) {
    MS_EXCEPTION_IF_NULL(edge);
    return edge->GetCostList(output_st_ptr, input_st_ptr);
  };
  CostPtrList result;
  std::vector<CostPtrList> all_cost_list;
  all_cost_list.resize(edges.size());
  (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);
  // selected_cost_list[k] holds the Cost currently chosen for edge k during the
  // depth-first enumeration below; it is captured by the decision at each leaf.
  CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
  // Depth-first enumeration of one cost per edge, accumulating the running sums in the
  // parameters. At depth k == edges.size() a combined Cost is materialized.
  std::function<void(size_t, double, double, double, double, double)> recursive =
    [&](size_t k, double computation, double memory, double communication, double communication_without_para,
        double communication_forward) {
      if (k == edges.size()) {
        auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
        CostPtr new_cost = std::make_shared<Cost>(computation, communication);
        MS_EXCEPTION_IF_NULL(new_cost);
        new_cost->communication_without_parameter_ = communication_without_para;
        // Partial-parameter communication interpolates between the without-parameter
        // total and the full total using COST_MODEL_GAMMA.
        new_cost->communication_with_partial_para_ =
          communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
        new_cost->memory_with_reuse_ = memory;
        new_cost->communication_forward_ = communication_forward;
        new_cost->decision_ptr_ = decision;
        result.push_back(new_cost);
        return;
      }
      for (auto &c : all_cost_list[k]) {
        MS_EXCEPTION_IF_NULL(c);
        selected_cost_list[k] = c;
        recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_,
                  communication + c->communication_cost_,
                  communication_without_para + c->communication_without_parameter_,
                  communication_forward + c->communication_forward_);
      }
    };
  recursive(0, 0.0, 0.0, 0.0, 0.0, 0.0);
  // Drop dominated candidates before returning.
  Simplify(&result);
  return result;
}
  187. void Edge::EdgeEliminationSetNewCost(OperatorInfoPtr, const std::vector<EdgePtr> &edges, OperatorInfoPtr) {
  188. bool valid = false;
  189. for (const auto &output_pair : pre_op_output_) {
  190. StrategyPtr output_st_ptr = output_pair.first;
  191. for (const auto &input_pair : next_op_input_) {
  192. StrategyPtr input_st_ptr = input_pair.first;
  193. CostPtrList clist = CreateEdgeEliminationCostList(output_st_ptr, edges, input_st_ptr);
  194. CostPtrKey key = {output_st_ptr, input_st_ptr};
  195. cost_map_[key] = clist;
  196. if ((!valid) && (!clist.empty())) {
  197. valid = true;
  198. }
  199. }
  200. }
  201. if (!valid) {
  202. MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
  203. }
  204. }
  205. void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list,
  206. const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list,
  207. CostPtrList *ret_cost_list) {
  208. for (auto &left_cost : left_cost_list) {
  209. MS_EXCEPTION_IF_NULL(left_cost);
  210. for (auto &middle_cost : middle_cost_list) {
  211. MS_EXCEPTION_IF_NULL(middle_cost);
  212. for (auto &right_cost : right_cost_list) {
  213. MS_EXCEPTION_IF_NULL(right_cost);
  214. double computation =
  215. left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_;
  216. double communication =
  217. left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_;
  218. double communication_forward =
  219. left_cost->communication_forward_ + middle_cost->communication_forward_ + right_cost->communication_forward_;
  220. double communication_without_para = left_cost->communication_without_parameter_ +
  221. middle_cost->communication_without_parameter_ +
  222. right_cost->communication_without_parameter_;
  223. double memory_cost =
  224. left_cost->memory_with_reuse_ + middle_cost->memory_with_reuse_ + right_cost->memory_with_reuse_;
  225. auto decision = std::make_shared<OpEliminationDecision>(op_strategy, left_cost, middle_cost, right_cost);
  226. auto cost = std::make_shared<Cost>(computation, communication, decision);
  227. MS_EXCEPTION_IF_NULL(cost);
  228. cost->communication_without_parameter_ = communication_without_para;
  229. cost->communication_with_partial_para_ =
  230. communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
  231. cost->memory_with_reuse_ = memory_cost;
  232. cost->communication_forward_ = communication_forward;
  233. ret_cost_list->emplace_back(std::move(cost));
  234. }
  235. }
  236. }
  237. }
  238. CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyPtr &output_st_ptr,
  239. const OperatorInfoPtr &op, const EdgePtr &e2,
  240. const StrategyPtr &input_st_ptr) {
  241. MS_EXCEPTION_IF_NULL(op);
  242. MS_EXCEPTION_IF_NULL(e1);
  243. MS_EXCEPTION_IF_NULL(e2);
  244. CostPtrList result;
  245. for (const auto &op_strategy : op->GetStrategyCost()) {
  246. MS_EXCEPTION_IF_NULL(op_strategy);
  247. auto middle_strategy = op_strategy->strategy_ptr;
  248. CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy),
  249. op_strategy->cost_list, e2->GetCostList(middle_strategy, input_st_ptr), &result);
  250. }
  251. Simplify(&result);
  252. return result;
  253. }
  254. void Edge::OpEliminationSetNewCost(const EdgePtr &e1, const OperatorInfoPtr &op, const EdgePtr &e2) {
  255. bool valid = false;
  256. for (const auto &output_pair : pre_op_output_) {
  257. StrategyPtr output_st_ptr = output_pair.first;
  258. for (const auto &input_pair : next_op_input_) {
  259. StrategyPtr input_st_ptr = input_pair.first;
  260. CostPtrList clist = CreateOpEliminationCostList(e1, output_st_ptr, op, e2, input_st_ptr);
  261. CostPtrKey key = {output_st_ptr, input_st_ptr};
  262. cost_map_[key] = clist;
  263. if ((!valid) && (!clist.empty())) {
  264. valid = true;
  265. }
  266. }
  267. }
  268. if (!valid) {
  269. MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
  270. }
  271. }
  272. Status Edge::CalculateMemoryCost() {
  273. if (is_output_parameter_involve_ == -1) {
  274. MS_LOG(ERROR) << "is_output_parameter_involve_ is unset.";
  275. return FAILED;
  276. }
  277. if (is_output_parameter_involve_ == 0) {
  278. // In this case, it is sure that the tensor redistribution along this edge is NOT parameter-involved, thus it is
  279. // unnecessary to keep them in memory.
  280. for (auto &cost_kv : cost_map_) {
  281. auto &cost_v = cost_kv.second;
  282. if (!cost_v.empty()) {
  283. cost_v[0]->memory_with_reuse_ = 0;
  284. }
  285. }
  286. }
  287. return SUCCESS;
  288. }
  289. Status Edge::CalculateMemoryCostForInference() {
  290. // Currently, memory cost is NOT calculated for redistribution
  291. if ((is_output_critical_ != 0) && (is_output_critical_ != 1)) {
  292. MS_LOG(ERROR) << "Failure: unexpected output critical flag value: " << is_output_critical_;
  293. return FAILED;
  294. }
  295. for (auto &cost_kv : cost_map_) {
  296. auto &cost_v = cost_kv.second;
  297. if (!cost_v.empty()) {
  298. cost_v[0]->memory_with_reuse_ = 0;
  299. }
  300. }
  301. return SUCCESS;
  302. }
  303. } // namespace parallel
  304. } // namespace mindspore