You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

parallel_fusion.h 5.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. /**
  2. * Copyright 2021-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_
  17. #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_
  18. #include <map>
  19. #include <memory>
  20. #include <set>
  21. #include <sstream>
  22. #include <string>
  23. #include <tuple>
  24. #include <vector>
  25. #include "base/base.h"
  26. #include "backend/common/session/anf_runtime_algorithm.h"
  27. #include "include/common/utils/anfalgo.h"
  28. #include "backend/common/optimizer/optimizer.h"
  29. #include "common/graph_kernel/parallel_cost_model.h"
  30. #include "backend/common/session/kernel_graph.h"
  31. #include "utils/ms_context.h"
  32. namespace mindspore::graphkernel {
  33. class ParallelInfo {
  34. public:
  35. ParallelInfo() = default;
  36. ParallelInfo(const AnfNodePtrList &nodes, const std::vector<DimInfoPtr> &dims, const FusionInfoPtr &fusion_info)
  37. : nodes_(nodes), dims_(dims), fusion_info_(fusion_info) {}
  38. ParallelInfo(const ParallelInfo &obj) {
  39. nodes_ = obj.nodes_;
  40. dims_ = obj.dims_;
  41. fusion_info_ = obj.fusion_info_;
  42. }
  43. ~ParallelInfo() = default;
  44. size_t GetSize() const {
  45. if (nodes_.size() != dims_.size()) {
  46. MS_LOG(EXCEPTION) << "Internal error in parallel info! nodes' size is different from dims' size: "
  47. << nodes_.size() << " vs " << dims_.size();
  48. }
  49. return nodes_.size();
  50. }
  51. const AnfNodePtrList &nodes() const { return nodes_; }
  52. const std::vector<DimInfoPtr> &dims() const { return dims_; }
  53. const FusionInfoPtr &fusion_info() const { return fusion_info_; }
  54. private:
  55. AnfNodePtrList nodes_;
  56. std::vector<DimInfoPtr> dims_;
  57. FusionInfoPtr fusion_info_;
  58. };
  59. class ParallelConfig {
  60. public:
  61. ParallelConfig() = default;
  62. explicit ParallelConfig(size_t max_n) : max_num_for_fuse_(max_n) {}
  63. explicit ParallelConfig(const ParallelConfig &obj) { max_num_for_fuse_ = obj.max_num_for_fuse_; }
  64. ~ParallelConfig() = default;
  65. size_t max_num_for_fuse() const { return max_num_for_fuse_; }
  66. private:
  67. size_t max_num_for_fuse_{10}; // Too many nodes to fuse together may produce bad result.
  68. };
  69. struct NodeRelation {
  70. public:
  71. NodeRelation() {}
  72. ~NodeRelation() = default;
  73. OrderedSet<AnfNodePtr> pres;
  74. OrderedSet<AnfNodePtr> nexts;
  75. };
// Graph-kernel optimization pass that searches a kernel graph for groups of
// mutually independent kernels and fuses them so they can execute in parallel.
// Profitability decisions are delegated to the ParallelCostModel held in
// cost_model_ptr_. Definitions live in the corresponding .cc file.
class ParallelOpFusion : public opt::Pass {
 public:
  // target: backend target name; config: fusion limits (e.g. max nodes fused).
  ParallelOpFusion(const std::string &target, const ParallelConfig &config)
      : Pass("parallel_fusion"), target_(target), config_(config) {}
  ~ParallelOpFusion() override = default;

  // Entry point of the pass. NOTE(review): presumably returns true when the
  // graph was modified, per the opt::Pass convention — confirm in the .cc file.
  bool Run(const FuncGraphPtr &graph) override;

 private:
  // Picks the nodes at positions start + offsets[i] in `nodes` that are not yet
  // `used` and whose index is not in `excludes`; returns them with their
  // indices. (Keeps the original spelling "Avaliable" — part of the interface.)
  std::tuple<AnfNodePtrList, std::vector<int>> GetAvaliableNodesByOffset(int start, const std::vector<size_t> &offsets,
                                                                         const std::vector<bool> &used,
                                                                         const AnfNodePtrList &nodes,
                                                                         const std::set<int> &excludes);

  // Searches the sorted `candidates` for fusable subsets. Returns a flag vector
  // (sized by origin_size) plus the discovered ParallelInfo records; the two
  // index maps translate between original and sorted node positions.
  std::tuple<std::vector<bool>, std::vector<ParallelInfo>> DoSearchInSortedCandidates(
    size_t origin_size, const AnfNodePtrList &candidates, std::map<AnfNodePtr, int> *origin_indices,
    std::map<AnfNodePtr, int> *sorted_indices);

  // Same search restricted to a single candidate list `cs`.
  std::tuple<std::vector<bool>, std::vector<ParallelInfo>> SearchFuseNodesInCandidates(const AnfNodePtrList &cs);

  // Searches one parallel group (several independent node streams) and appends
  // any discovered fusions to *parallel_infos.
  void SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> &group,
                                      std::vector<ParallelInfo> *parallel_infos);

  // Top-level search across all parallel groups.
  std::vector<ParallelInfo> SearchFusableParallelCNodes(const std::vector<std::vector<AnfNodePtrList>> &groups);

  // Records the fusion decision of `parallel_info` on `node` as node attributes.
  void SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo &parallel_info);

  // Marks the return node of the fused region with the parallel-fusion attribute.
  void SetFusedParallelOpAttrToReturnNode(const ParallelInfo &parallel_info);

  // Materializes each ParallelInfo as a fused sub-graph inside `kernel_graph`.
  // NOTE(review): presumably returns true if any sub-graph was created — confirm.
  bool CreateParallelOpSubGraphs(const std::vector<ParallelInfo> &parallel_infos,
                                 const std::shared_ptr<session::KernelGraph> &kernel_graph);

  // Builds the predecessor/successor relation (NodeRelation) for `nodes`.
  OrderedMap<AnfNodePtr, NodeRelation> GenAnalysisGraph(const AnfNodePtrList &nodes);

  // Partitions the dependency graph into groups of independent node streams
  // that are candidates for parallel execution.
  std::vector<std::vector<AnfNodePtrList>> SearchParallelGroups(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels);

  std::string target_;                        // backend target name
  ParallelConfig config_;                     // fusion limits
  ParallelCostModelPtr cost_model_ptr_;       // cost model deciding profitability
  std::set<AnfNodePtr> virtual_noout_nodes_;  // nodes treated as producing no output
  std::set<AnfNodePtr> ignore_noin_nodes_;    // nodes ignored for having no input
  unsigned int parallel_level_{0};            // 0 = default level; semantics set elsewhere — confirm in .cc
};
  107. using ParallelOpFusionPtr = std::shared_ptr<ParallelOpFusion>;
  108. } // namespace mindspore::graphkernel
  109. #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_