You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

graph_costmodel.h 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_
  17. #define MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_
  18. #include <map>
  19. #include <memory>
  20. #include <string>
  21. #include <utility>
  22. #include <vector>
  23. #include "../../common.h"
  24. #include "common/utils.h"
  25. #include "parallel/auto_parallel/edge_costmodel.h"
  26. #include "parallel/costmodel_context.h"
  27. #include "parallel/ops_info/operator_info.h"
  28. #include "parallel/ops_info/tmp_identity_info.h"
  29. namespace mindspore {
  30. namespace parallel {
// Separator used when composing two operator names into a unique edge name.
#define OPERATOR_TO_OPERATOR_CONNECTOR "-"
// Default device memory capacity: 16 GiB, expressed in bytes.
#define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0)
// Default cost-model coefficients; alpha and beta presumably weight computation
// vs. communication cost — confirm against the cost formulas in the .cc file.
#define DEFAULT_COST_MODEL_ALPHA 1.0
#define DEFAULT_COST_MODEL_BETA 400.0
#define DEFAULT_COST_MODEL_GAMMA 0.001
// Whether to use the simplified cost calculation by default.
#define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true
// Communication cost parameters (threshold/const/bias); units are not stated
// here — NOTE(review): verify units (bytes?) against CostModelContext.
#define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0
#define DEFAULT_COST_MODEL_COMMUNI_CONST 3072.0
#define DEFAULT_COST_MODEL_COMMUNI_BIAS 1024.0
// Tensor-slice alignment: disabled by default; alignment granularity is 16.
#define DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE false
#define DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE 16
// By default, generated strategies must use all available devices.
#define DEFAULT_FULLY_USE_DEVICES true
// By default, element-wise operators do not follow their input's strategy.
#define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false
#define DEFAULT_IS_MULTI_SUBGRAPHS false
// Run phase: 0 = training, 1 = inference (see TRAINING_PHASE / INFERENCE_PHASE).
#define DEFAULT_RUN_PHASE 0
#define TRAINING_PHASE 0
#define INFERENCE_PHASE 1
class CostGraph;
using CostGraphPtr = std::shared_ptr<CostGraph>;
// Global singleton cost graph and the tunable cost-model parameters.
// These are declared here and defined in the corresponding .cc file; they are
// initialized from the DEFAULT_* macros above / CostModelContext.
extern CostGraphPtr entire_costgraph;
extern size_t TOTAL_OPS;
extern double COST_MODEL_GAMMA;
extern bool COST_MODEL_SIMPLIFY_CALCULATION;
extern double DEVICE_MEMORY_CAPACITY;
extern double COST_MODEL_COMMUNI_THRESHOLD;
extern double COST_MODEL_COMMUNI_CONST;
extern double COST_MODEL_COMMUNI_BIAS;
extern bool TENSOR_SLICE_ALIGNMENT_ENABLE;
extern size_t TENSOR_SLICE_ALIGNMENT_SIZE;
extern bool FULLY_USE_DEVICES;
extern bool ELEMENTWISE_OP_STRA_FOLLOW;
extern bool MULTI_SUBGRAPHS;
extern int32_t RUN_PHASE;
  64. class CostGraph {
  65. // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have
  66. // output-input dependency relationship.
  67. public:
  68. CostGraph() {
  69. dev_memory_ = DEFAULT_DEVICE_MEMORY_CAPACITY;
  70. costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA;
  71. costmodel_beta_ = DEFAULT_COST_MODEL_BETA;
  72. }
  73. ~CostGraph() = default;
  74. void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); }
  75. OperatorInfoPtr FindOperatorByIndex(size_t index) {
  76. if (index >= ops_.size()) {
  77. MS_LOG(ERROR) << "The index: " << index << " is out of the range of ops_: " << ops_.size() << ".";
  78. return nullptr;
  79. }
  80. return ops_[index];
  81. }
  82. void RemoveOperator(const OperatorInfoPtr &op);
  83. bool IsOperatorInCostGraph(const OperatorInfoPtr &op);
  84. // the edge is in the form: u --> v
  85. void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge);
  86. std::vector<std::shared_ptr<Edge>> GetOriginalPrevEdges(OperatorInfoPtr v_node) { return in_edges_[v_node]; }
  87. std::vector<std::shared_ptr<Edge>> GetOriginalNextEdges(OperatorInfoPtr u_node) { return out_edges_[u_node]; }
  88. // An edge is uniquely identified by its name, and its output index and input index.
  89. bool IsEdgeInCostGraph(const std::string &, size_t, size_t);
  90. void SetDeviceMemoryAndCostParameter();
  91. std::vector<std::shared_ptr<CostGraph>> ConstructConnectedComponents(std::vector<OperatorInfoPtr>);
  92. void DFS(const OperatorInfoPtr &current_op, std::map<OperatorInfoPtr, bool> *visited,
  93. const std::shared_ptr<CostGraph> &component);
  94. CostPtrList CreateFinalCostList(const OperatorInfoPtr &u, const EdgePtr &e, const OperatorInfoPtr &v);
  95. CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr &u);
  96. CostPtr SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory);
  97. CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory);
  98. CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_costlist, double memory);
  99. Status SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &);
  100. std::vector<std::shared_ptr<Edge>> GetOriginalEdgeBetweenOperators(OperatorInfoPtr u_node, OperatorInfoPtr v_node) {
  101. return edges_[{u_node, v_node}];
  102. }
  103. double GetDeviceMemory() const { return dev_memory_; }
  104. // Search the cost_list in the final graph, and determine the optimal one
  105. Status SearchStrategy();
  106. // Given a graph which contains the following subgraph: u --> v --> w, the node v can be eliminated
  107. OperatorInfoPtr CheckOpElimination() const;
  108. // Given a graph which contains the following subgraph where there are multiple edges between u and v, these edges
  109. // can be eliminated into one
  110. std::vector<EdgePtr> CheckEdgeElimination() const;
  111. // Given a graph which contains the following subgraph:
  112. // u
  113. // |
  114. // w --- v --- x
  115. // where u has 0 incoming edge, u has 1 outgoing edge, and v has > 1 incoming edges, u can be merged into v.
  116. // u is returned.
  117. OperatorInfoPtr CheckMergeElimination() const;
  118. // Given a graph which contains the following subgraph:
  119. // u
  120. // |
  121. // v --- x
  122. // where v has 2 outgoing edges, and u has 1 incoming edges and no outgoing edges. In this case, u can be contracted
  123. // into v. u is returned.
  124. OperatorInfoPtr CheckContractElimination() const;
  125. /* Given a graph which contains the following subgraph:
  126. * u
  127. * / \
  128. * / \
  129. * v --- w
  130. * where u has 2 outgoing edges, v has 1 outgoing edge, and w has 2 incoming edges, u can be eliminated into v.
  131. * The returned value includes u and the edge <u, <v, w>>.
  132. */
  133. std::pair<OperatorInfoPtr, EdgePtr> CheckTriangleElimination() const;
  134. /* Given a graph which contains the following subgraph:
  135. * v <--- u ---> w
  136. * where u has 0 incoming edges, and multiple outgoing edges. In addition, v and w have other complicated connections,
  137. * resulting in v and w can not be performed ContractElimination. u is returned.
  138. * NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied.
  139. */
  140. OperatorInfoPtr CheckStarElimination() const;
  141. // Applying Operator Elimination in DP algorithm
  142. EdgePtr EliminationOp(const OperatorInfoPtr &op);
  143. // Applying Edge Elimination in DP algorithm
  144. EdgePtr EliminationEdges(const std::vector<EdgePtr> &edges);
  145. // Applying Merge Elimination in DP algorithm
  146. OperatorInfoPtr EliminationMerge(const OperatorInfoPtr &op);
  147. void CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list,
  148. const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy,
  149. const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new);
  150. // Applying Contract Elimination in DP algorithm
  151. OperatorInfoPtr EliminationContract(const OperatorInfoPtr &op);
  152. void CreateContractEliminationSubCostList(StrategyPtr, const CostPtrList &, const CostPtrList &, StrategyPtr,
  153. const CostPtrList &, CostPtrList *);
  154. // Applying Triangle Elimination in DP algorithm. return the left_node
  155. OperatorInfoPtr EliminationTriangle(const OperatorInfoPtr &elimi_op, const EdgePtr &edge_left_right);
  156. void CreateTriangleEliminationCostList(const OperatorInfoPtr &, const CostPtrList &, const CostPtrList &,
  157. const StrategyPtr &, const StrategyPtr &, const StrategyPtr &,
  158. const CostPtrList &, const CostPtrList &, const CostPtrList &, CostPtrList *);
  159. // Given the relevant costlist, create the TriangleElimination cost
  160. void CreateTriangleEliminationSubCostList(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr &, const CostPtrList &,
  161. const CostPtrList &, const CostPtr &, const CostPtrList &, CostPtrList *);
  162. // Applying the Star Elimination in DP algorithm. Return the successive edges of this merged_op
  163. // NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied.
  164. std::vector<EdgePtr> EliminationStar(const OperatorInfoPtr &op);
  165. void CreateStarEliminationCostList(std::vector<EdgePtr> &, const StrategyPtr &, const CostPtrList &,
  166. const CostPtrList &, const StrategyPtr &, const CostPtrList &, CostPtrList *);
  167. void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &,
  168. const StrategyPtr &, const CostPtrList &, std::vector<StrategyPtr>,
  169. CostPtrList &, CostPtrList &, CostPtrList *);
  170. // Calculate memory cost for training phase or inference phase.
  171. Status CalculateMemoryCost();
  172. // When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then
  173. // the memory cost can be resused. This is used to calculate memory in the training phase.
  174. Status CalculateOpsMemoryCost();
  175. // When the input of the edge is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then
  176. // the memory cost can be reused. This is used to calculate memory in the training phase.
  177. Status CalculateEdgesMemoryCost();
  178. // Calculate memory cost of operators in the inference phase.
  179. Status CalculateOpsMemoryCostForInference();
  180. // Calculate memory cost of edges in the inference phase.
  181. Status CalculateEdgesMemoryCostForInference();
  182. Status ComputeOpsAndEdgesParameterInvolved();
  183. // Compute for each operator whether the output is critical.
  184. Status ComputeOpsAndEdgesOutputCritical();
  185. std::vector<OperatorInfoPtr> GetOperators() const { return ops_; }
  186. size_t GetNumEdges() const;
  187. Status InitSelectedStrategy();
  188. OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const;
  189. // When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only
  190. // once (instead of multiple times), this method is used to correct this.
  191. Status CorrectOpsMemoryCost();
  192. // Needed by rec_parser
  193. void add_inputs_tensor_name(const std::vector<std::string> &inputs_tensor_name) {
  194. inputs_tensor_name_list_.push_back(inputs_tensor_name);
  195. }
  196. const std::vector<std::vector<std::string>> get_inputs_tensor_name_list() const { return inputs_tensor_name_list_; }
  197. void add_tuple_getitem(const std::pair<std::string, std::string> &tuple_getitem) {
  198. auto ret = tuple_getitem_list_.insert(tuple_getitem);
  199. if (ret.second == false) {
  200. MS_LOG(EXCEPTION) << "The insert item is already exist.";
  201. }
  202. }
  203. const std::map<std::string, std::string> get_tuple_getitem_list() const { return tuple_getitem_list_; }
  204. private:
  205. void TopologyOrder(std::vector<OperatorInfoPtr> *);
  206. void DFSForTopoOrder(const OperatorInfoPtr &, std::map<OperatorInfoPtr, bool> *, std::vector<OperatorInfoPtr> *);
  207. Status DetermineCriticalOps(const std::vector<OperatorInfoPtr> &);
  208. void MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int> &);
  209. // Needed by rec_parser
  210. std::vector<std::vector<std::string>> inputs_tensor_name_list_;
  211. std::map<std::string, std::string> tuple_getitem_list_;
  212. double dev_memory_;
  213. double costmodel_alpha_;
  214. double costmodel_beta_;
  215. std::vector<OperatorInfoPtr> ops_;
  216. std::map<std::pair<OperatorInfoPtr, OperatorInfoPtr>, std::vector<EdgePtr>> edges_;
  217. std::vector<std::shared_ptr<CostGraph>> connected_compoents_;
  218. std::map<OperatorInfoPtr, std::vector<EdgePtr>> out_edges_;
  219. std::map<OperatorInfoPtr, std::vector<EdgePtr>> in_edges_;
  220. };
  221. } // namespace parallel
  222. } // namespace mindspore
  223. #endif // MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_