You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

costmodel.h 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_
  17. #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_
  18. #include <algorithm>
  19. #include <memory>
  20. #include <string>
  21. #include <utility>
  22. #include <vector>
  23. #include "frontend/parallel/strategy.h"
  24. #include "frontend/parallel/tensor_layout/tensor_info.h"
  25. namespace mindspore {
  26. namespace parallel {
  27. struct Decision;
  28. using OperatorName = std::string;
  29. using Attr = std::pair<std::string, ValuePtr>;
  30. using Param = std::pair<std::pair<std::string, ValuePtr>, int32_t>;
  31. using OperatorParams = std::vector<Param>;
  32. using OperatorAttrs = std::vector<Attr>;
  33. // OutPutInfo.fist: true if the operator's output is a tuple
  34. // OutPutInfo.second: elements number of the tuple output. Only meaningful if OutPutInfo.fist is true.
  35. using OutPutInfo = std::pair<bool, uint32_t>;
  36. using OutPutInfoVector = std::vector<OutPutInfo>;
  37. using OperatorArgs = std::pair<OperatorAttrs, OperatorParams>;
  38. using Operator = std::pair<OperatorName, OperatorArgs>;
  39. using OperatorVector = std::vector<Operator>;
  40. using RedistributionOpListPtr = std::shared_ptr<std::pair<OperatorVector, OutPutInfoVector>>;
  41. struct Cost {
  42. Cost();
  43. Cost(double computation, double commuication, const std::shared_ptr<Decision> &decision_ = nullptr)
  44. : computation_cost_(computation), communication_cost_(commuication), decision_ptr_(std::move(decision_)) {
  45. memory_with_reuse_ = 0.0;
  46. communication_without_parameter_ = 0.0;
  47. communication_with_partial_para_ = 0.0;
  48. communication_redis_forward_ = 0.0;
  49. communication_redis_backward_ = 0.0;
  50. communication_forward_ = 0.0;
  51. }
  52. // 'memory_with_reuse_' calculates the peak memory usage in a training (or inference) phase
  53. double memory_with_reuse_;
  54. // 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated
  55. // by ONLY forward phase
  56. double computation_cost_;
  57. // 'communication_cost_' includes communications from operators (forward and backward) and edges (redistribution)
  58. double communication_cost_;
  59. // communication_without_parameter_ = communication_cost_ - (backward communication from operators)
  60. double communication_without_parameter_;
  61. // communication_with_partial_para_ =
  62. // communication_without_parameter_ + COST_MODEL_GAMMA * (communication_cost_ - communication_without_parameter_ )
  63. double communication_with_partial_para_;
  64. // communication_forward_ = communication cost from operators (only forward phase) and forward redistribution.
  65. double communication_forward_;
  66. double communication_redis_forward_;
  67. double communication_redis_backward_;
  68. std::shared_ptr<Decision> decision_ptr_;
  69. };
  70. using CostPtr = std::shared_ptr<Cost>;
  71. using CostPtrList = std::vector<std::shared_ptr<Cost>>;
  72. class StrategyWithCost {
  73. public:
  74. StrategyWithCost(StrategyPtr strategy, std::vector<TensorInfo> inputs_, std::vector<TensorInfo> outputs_)
  75. : strategy_ptr(std::move(strategy)), inputs_ptr(std::move(inputs_)), outputs_ptr(std::move(outputs_)) {}
  76. StrategyWithCost(StrategyPtr strategy, CostPtrList c_list)
  77. : strategy_ptr(std::move(strategy)), cost_list(std::move(c_list)) {}
  78. StrategyWithCost(const StrategyWithCost &swc) = delete;
  79. StrategyWithCost(StrategyWithCost &&swc)
  80. : strategy_ptr(swc.strategy_ptr),
  81. inputs_ptr(swc.inputs_ptr),
  82. outputs_ptr(swc.outputs_ptr),
  83. cost_list(swc.cost_list) {}
  84. ~StrategyWithCost() = default;
  85. StrategyPtr strategy_ptr;
  86. std::vector<TensorInfo> inputs_ptr;
  87. std::vector<TensorInfo> outputs_ptr;
  88. CostPtrList cost_list;
  89. };
  // Tag stored in Decision::type_ telling which concrete decision struct an object is.
  // Values are spelled out; they equal the implicit numbering (0..8).
  enum DecisionType {
    OP_ELIMINATION = 0,
    EDGE_ELIMINATION = 1,
    MERGE_ELIMINATION = 2,
    CONTRACT_ELIMINATION = 3,
    SOURCE_ELIMINATION = 4,
    TRIANGLE_ELIMINATION = 5,
    STAR_ELIMINATION = 6,
    FINAL_TYPE = 7,
    FINAL_SINGLE = 8
  };
  101. struct Decision : public Base {
  102. ~Decision() override = default;
  103. DecisionType type_;
  104. };
  105. // 'OpEliminationDecision' is for the Operator Elimination in DP algorithm: u --> v --> w ==> u --> w.
  106. // This data structure records the strategy 'op_strategy_' for v, the edge cost 'left_cost_' for 'u --> v', the
  107. // operator cost 'middle_cost_' for v, and the edge cost 'right_cost_' for 'v --> w'
  108. struct OpEliminationDecision : public Decision {
  109. OpEliminationDecision(StrategyPtr op_stra, CostPtr l_cost, CostPtr m_cost, CostPtr r_cost)
  110. : op_strategy_(std::move(op_stra)),
  111. left_cost_(std::move(l_cost)),
  112. middle_cost_(std::move(m_cost)),
  113. right_cost_(std::move(r_cost)) {
  114. type_ = DecisionType::OP_ELIMINATION;
  115. }
  116. StrategyPtr op_strategy_;
  117. CostPtr left_cost_;
  118. CostPtr middle_cost_;
  119. CostPtr right_cost_;
  120. MS_DECLARE_PARENT(OpEliminationDecision, Decision);
  121. };
  122. /* 'EdgeEliminationDecision' is for the Edge Elimination in DP algorithm:
  123. ____
  124. / \
  125. u v ==> u --> v, which replace the multi-edges by a single edge.
  126. \____/
  127. This data structure records the cost list for all edges 'edges_cost_list_'
  128. */
  129. struct EdgeEliminationDecision : public Decision {
  130. explicit EdgeEliminationDecision(CostPtrList cost_list) : edges_cost_list_(std::move(cost_list)) {
  131. type_ = DecisionType::EDGE_ELIMINATION;
  132. }
  133. CostPtrList edges_cost_list_;
  134. MS_DECLARE_PARENT(EdgeEliminationDecision, Decision);
  135. };
  136. // 'MergeEliminationDecision' is for the Merge Elimination in DP algorithm:
  137. // w
  138. // |
  139. // | ==> u --> v
  140. // u --> v In the original graph, v has two alive incoming edges, w has one alive outgoing edge,
  141. // and w has zero alive incoming edges. After the Merge Elimination, the result graph contains only 'u -- >v'.
  142. // This data structure records the strategy 'merged_op_strategy_' for operator 'w',
  143. // the cost 'merged_op_cost_' for operator 'w', and the edge cost 'edge_cost_' for 'w --> v'.
  144. struct MergeEliminationDecision : public Decision {
  145. MergeEliminationDecision(StrategyPtr op_stra, CostPtr op_cost, CostPtr edge_c, StrategyPtr tar_op_stra,
  146. CostPtr target_op_c)
  147. : merged_op_strategy_(std::move(op_stra)),
  148. merged_op_cost_(std::move(op_cost)),
  149. edge_cost_(std::move(edge_c)),
  150. target_op_strategy_(std::move(tar_op_stra)),
  151. target_op_cost_(std::move(target_op_c)) {
  152. type_ = DecisionType::MERGE_ELIMINATION;
  153. }
  154. StrategyPtr merged_op_strategy_;
  155. CostPtr merged_op_cost_;
  156. CostPtr edge_cost_;
  157. StrategyPtr target_op_strategy_;
  158. CostPtr target_op_cost_;
  159. MS_DECLARE_PARENT(MergeEliminationDecision, Decision);
  160. };
  161. // 'ContractEliminationDecision' is for the Contract Elimination in DP algorithm:
  162. // u --> v
  163. // |
  164. // | ==> u --> w
  165. // w In the original graph, u has two alive outgoing edges, v has one alive incoming edge,
  166. // and v has zero outgoing edge. After the Contract Elimination, the result graph contains only 'u --> w'.
  167. // This data structure records the strategy 'contracted_op_strategy_' for operator 'v', the cost for
  168. // operator 'contracted_op_cost_', and the edge cost for 'edge_cost_'.
  169. struct ContractEliminationDecision : public Decision {
  170. ContractEliminationDecision(StrategyPtr contra_stra, CostPtr contra_op_cost, CostPtr edge_cost,
  171. StrategyPtr target_stra, CostPtr tar_cost)
  172. : contracted_op_strategy_(std::move(contra_stra)),
  173. contracted_op_cost_(std::move(contra_op_cost)),
  174. edge_cost_(std::move(edge_cost)),
  175. target_op_strategy_(std::move(target_stra)),
  176. target_cost_(std::move(tar_cost)) {
  177. type_ = DecisionType::CONTRACT_ELIMINATION;
  178. }
  179. StrategyPtr contracted_op_strategy_;
  180. CostPtr contracted_op_cost_;
  181. CostPtr edge_cost_;
  182. StrategyPtr target_op_strategy_;
  183. CostPtr target_cost_;
  184. MS_DECLARE_PARENT(ContractEliminationDecision, Decision);
  185. };
  186. /* 'SourceEliminationDecision' is for the source Elimination in DP algorithm:
  187. * 1 1,5
  188. * / \ // \\
  189. * / \ // \\
  190. * / \ // \\
  191. * / \ // \\
  192. * 2 <- 5 -> 3 ==> 2 3
  193. * \ / \ /
  194. * \ / \ /
  195. * \ / \ /
  196. * 4 4
  197. *
  198. * In the original graph, '1' has two alive outgoing edges and no incoming edges. '5' has two alive outgoing edges and
  199. * no incoming edges. '4' has two alive incoming edges and no outgoing edges. Source Elimination will merge '5' into
  200. * '1' new edges are generated to replace the old ones incident to '1' and '5'.
  201. *
  202. */
  203. struct SourceEliminationDecision : public Decision {
  204. SourceEliminationDecision(StrategyPtr op1_stra, CostPtr op1_c, StrategyPtr op2_stra, CostPtr op2_c)
  205. : op1_strategy_(std::move(op1_stra)),
  206. op1_cost_(std::move(op1_c)),
  207. op2_strategy_(std::move(op2_stra)),
  208. op2_cost_(std::move(op2_c)) {
  209. type_ = DecisionType::SOURCE_ELIMINATION;
  210. }
  211. StrategyPtr op1_strategy_;
  212. CostPtr op1_cost_;
  213. StrategyPtr op2_strategy_;
  214. CostPtr op2_cost_;
  215. MS_DECLARE_PARENT(SourceEliminationDecision, Decision);
  216. };
  217. /* 'TriangleEliminationDecision' is for the Triangle Elimination in DP algorithm:
  218. *
  219. * u
  220. * / \
  221. * / \
  222. * v --- w ==> v --- w In the original graph, u has 2 outgoing edges, v has 1 outgoing edge,
  223. * and w has 2 incoming edges, u can be eliminated into v.
  224. * 'eliminated_op_strategy_' is for u, 'eliminated_op_cost_' is for u, 'eliminated_left_edge_' is for edge u --> v,
  225. * 'eliminated_right_edge_' is for edge u --> w.
  226. */
  227. struct TriangleEliminationDecision : public Decision {
  228. TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost,
  229. StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra)
  230. : eliminated_op_strategy_(std::move(elimi_stra)),
  231. eliminated_op_cost_(std::move(elimi_op_cost)),
  232. left_edge_cost_(std::move(l_edge_cost)),
  233. right_edge_cost_(std::move(r_edge_cost)),
  234. left_node_strategy_(std::move(left_stra)),
  235. left_node_cost_(std::move(l_node_cost)),
  236. right_node_strategy_(std::move(right_stra)) {
  237. type_ = DecisionType::TRIANGLE_ELIMINATION;
  238. }
  239. StrategyPtr eliminated_op_strategy_;
  240. CostPtr eliminated_op_cost_;
  241. CostPtr left_edge_cost_;
  242. CostPtr right_edge_cost_;
  243. StrategyPtr left_node_strategy_;
  244. CostPtr left_node_cost_;
  245. StrategyPtr right_node_strategy_;
  246. MS_DECLARE_PARENT(TriangleEliminationDecision, Decision);
  247. };
  248. /* 'StarEliminationDecision' is for the Star Elimination in DP algorithm:
  249. *
  250. * v <--- u ---> w ==> v w In the original graph, u has 0 incoming edges, and multiple outgoing edges.
  251. * In addition, v and w have other complicated connections, resulting in v and w can not be performed other
  252. * eliminations. After the StarElimination, u is merged into v, and the resulting graph is splitted into multiple
  253. * connected components.
  254. * NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied.
  255. */
  256. struct StarEliminationDecision : public Decision {
  257. StarEliminationDecision(StrategyPtr elimi_op_stra, CostPtr elimi_op_cost, CostPtrList succ_edges_clist,
  258. std::vector<StrategyPtr> succ_ops_stra_list, CostPtrList succ_ops_clist)
  259. : eliminated_op_strategy_(std::move(elimi_op_stra)),
  260. eliminated_op_cost_(std::move(elimi_op_cost)),
  261. succ_edges_cost_list_(std::move(succ_edges_clist)),
  262. succ_ops_stra_list_(std::move(succ_ops_stra_list)),
  263. succ_ops_cost_list_(std::move(succ_ops_clist)) {
  264. type_ = DecisionType::STAR_ELIMINATION;
  265. }
  266. StrategyPtr eliminated_op_strategy_;
  267. CostPtr eliminated_op_cost_;
  268. CostPtrList succ_edges_cost_list_;
  269. std::vector<StrategyPtr> succ_ops_stra_list_;
  270. CostPtrList succ_ops_cost_list_;
  271. MS_DECLARE_PARENT(StarEliminationDecision, Decision);
  272. };
  273. // This data structure records the decision for the graph which contains two nodes: u --> v. This includes
  274. // the strategy 'u_strategy_' for 'u', the strategy 'v_strategy_' for 'v', the cost 'left_cost_' for 'u'.
  275. struct FinalDecision : public Decision {
  276. FinalDecision(StrategyPtr u_stra, StrategyPtr v_stra, CostPtr l_cost, CostPtr m_cost, CostPtr r_cost)
  277. : u_strategy_(std::move(u_stra)),
  278. v_strategy_(std::move(v_stra)),
  279. left_cost_(std::move(l_cost)),
  280. middle_cost_(std::move(m_cost)),
  281. right_cost_(std::move(r_cost)) {
  282. type_ = DecisionType::FINAL_TYPE;
  283. }
  284. StrategyPtr u_strategy_;
  285. StrategyPtr v_strategy_;
  286. CostPtr left_cost_;
  287. CostPtr middle_cost_;
  288. CostPtr right_cost_;
  289. MS_DECLARE_PARENT(FinalDecision, Decision);
  290. };
  291. // This data structure records the final decision for the graph containing a single node: u. This includes
  292. // the strategy 'u_strategy_' for 'u', the cost 'u_cost_' for 'u'.
  293. struct FinalSingleDecision : public Decision {
  294. FinalSingleDecision(StrategyPtr u_stra, CostPtr u_cost) : u_strategy_(std::move(u_stra)), u_cost_(std::move(u_cost)) {
  295. type_ = DecisionType::FINAL_SINGLE;
  296. }
  297. StrategyPtr u_strategy_;
  298. CostPtr u_cost_;
  299. MS_DECLARE_PARENT(FinalSingleDecision, Decision);
  300. };
  301. using DecisionPtr = std::shared_ptr<Decision>;
  302. using OpEliminationDecisionPtr = std::shared_ptr<OpEliminationDecision>;
  303. using EdgeEliminationDecisionPtr = std::shared_ptr<EdgeEliminationDecision>;
  304. using MergeEliminationDecisionPtr = std::shared_ptr<MergeEliminationDecision>;
  305. using ContractEliminationDecisionPtr = std::shared_ptr<ContractEliminationDecision>;
  306. using SourceEliminationDecisionPtr = std::shared_ptr<SourceEliminationDecision>;
  307. using TriangleEliminationDecisionPtr = std::shared_ptr<TriangleEliminationDecision>;
  308. using StarEliminationDecisionPtr = std::shared_ptr<StarEliminationDecision>;
  309. using FinalDecisionPtr = std::shared_ptr<FinalDecision>;
  310. using FinalSingleDecisionPtr = std::shared_ptr<FinalSingleDecision>;
  311. void Simplify(CostPtrList *clist);
  312. void SimplifyForDecreasingCommunicationForward(CostPtrList *clist);
  313. void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist);
  314. void RefineForPracticalCost(const CostPtr &, bool is_redistribution);
  315. } // namespace parallel
  316. } // namespace mindspore
  317. #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_