You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

step_parallel.h 7.2 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_PARALLEL_H_
  17. #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_PARALLEL_H_
  18. #include <vector>
  19. #include <map>
  20. #include <memory>
  21. #include <set>
  22. #include <string>
  23. #include <unordered_map>
  24. #include <utility>
  25. #include "frontend/optimizer/opt.h"
  26. #include "frontend/parallel/strategy.h"
  27. #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
  28. #include "pipeline/jit/pipeline.h"
  29. #include "frontend/parallel/ops_info/ops_utils.h"
  30. #include "frontend/parallel/auto_parallel/operator_costmodel.h"
  31. #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
  32. using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>;
  33. namespace mindspore {
  34. namespace parallel {
  // Microseconds per second (10^6).
  constexpr uint64_t kUSecondInSecond = 1'000'000;
  // Recursion depth limit for this module.
  // NOTE(review): presumably compared against depth counters such as the
  // `curr_depth` argument of NodeParameterName — confirm in the definitions.
  constexpr int32_t RECURSION_LIMIT = 3;
  37. struct LossNodeInfo {
  38. bool has_tuple_getitem = false;
  39. int64_t dout_index = 0; // now don't support the sens is a tuple
  40. CNodePtr loss_node = nullptr;
  41. };
  // Plain aggregate of communication settings. The defaults describe a single
  // device at global rank 0, with empty world-group and backend names.
  struct CommInfo {
    int64_t device_num{1};
    int64_t global_rank{0};
    std::string world_group{};
    std::string communication_backend{};
  };
  48. std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name);
  49. std::string CreateInstanceName(const CNodePtr &node, size_t index);
  50. void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node);
  51. void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node,
  52. const FuncGraphPtr &func_graph, int64_t pos, const CNodePtr &pre_node);
  53. TensorLayout GetTensorInLayout(const CNodePtr &pre_node, const PrimitivePtr &pre_prim,
  54. const OperatorInfoPtr &distribute_operator_pre);
  55. OperatorInfoPtr GetDistributeOperator(const CNodePtr &node);
  56. void Redistribution(const std::pair<AnfNodePtr, int64_t> &node_pair, const OperatorInfoPtr &distribute_operator,
  57. const CNodePtr &middle_node, int64_t index, TensorRedistribution tensor_redistribution,
  58. const CNodePtr &pre_node);
  59. bool StrategyFound(std::unordered_map<std::string, ValuePtr> attrs);
  60. void MarkForwardCNode(const FuncGraphPtr &root);
  61. bool FindCommunicationOp(const std::vector<AnfNodePtr> &all_nodes);
  62. void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node,
  63. const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node);
  64. std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
  65. const CNodePtr &node);
  66. void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node);
  67. void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node);
  68. std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph);
  69. std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph);
  70. // Generate and init parallel operator
  71. OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
  72. const std::vector<Shapes> &shape_list);
  73. // Generate without initing parallel operator
  74. OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
  75. std::vector<Shapes> shape_list);
  76. // Extract strategy from attr
  77. StrategyPtr ExtractStrategy(const ValuePtr &strategy);
  78. // Extract shape from anfnode
  79. std::vector<Shapes> ExtractShape(const CNodePtr &node);
  80. // Find finally sub graph
  81. std::pair<AnfNodePtr, int64_t> FindSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &parameter);
  82. // Set distribute shape for parameters abstract
  83. std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int64_t> &res);
  84. // change parameters'shape in resource
  85. void CoverSliceShape(const FuncGraphPtr &root);
  86. void SetVirtualDatasetStrategy(const CNodePtr &node);
  87. bool IsInsertVirtualOutput(const FuncGraphPtr &root);
  88. // Create parallel operator for primitive node(has strategy)
  89. void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_training = true);
  90. TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int64_t> &node_pair);
  91. std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &node);
  92. std::shared_ptr<TensorLayout> GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index);
  93. std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index);
  94. std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node);
  95. void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes);
  96. StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim);
  97. // Add node for whole graph
  98. void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
  99. const FuncGraphManagerPtr &manager);
  100. ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth);
  101. void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
  102. // main step of Parallel
  103. bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer);
  104. int64_t GetTupleGetItemIndex(const CNodePtr &cnode);
  105. Status ParallelInit();
  106. std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root);
  107. std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
  108. std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node);
  109. bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter);
  110. void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr &distribute_operator,
  111. const CNodePtr &cnode, const AnfNodePtr &parameter, size_t index);
  112. void SetLastNodeStrategy(const StrategyPtr strategyPtr);
  113. bool CreateGroupsByCkptFile(const std::string &file);
  114. void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector<std::string> *unique_ids,
  115. std::vector<size_t> *indexes);
  116. void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes);
  117. std::string MirrorOpName();
  118. CommInfo GetCommInfo();
  119. std::string GetPrimName(const CNodePtr &node);
  120. void ReorderForPipelineSplit(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager, int64_t pipeline_stages);
  121. } // namespace parallel
  122. } // namespace mindspore
  123. #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_PARALLEL_H_