From: @wenfangpei Reviewed-by: @gaoxiong1,@ckey_dou Signed-off-by: @ckey_dou pull/15985/MERGE
| @@ -44,7 +44,7 @@ class GraphKernelExpander : public Pass { | |||
| public: | |||
| GraphKernelExpander() : Pass("graph_kernel_expander") {} | |||
| ~GraphKernelExpander() override = default; | |||
| bool Run(const FuncGraphPtr &func_graph); | |||
| bool Run(const FuncGraphPtr &func_graph) override; | |||
| private: | |||
| ExpanderPtr GetExpander(const AnfNodePtr &node); | |||
| @@ -45,7 +45,7 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| PassManagerPtr GraphKernelOptimizer::PreProcess() { | |||
| PassManagerPtr GraphKernelOptimizer::PreProcess() const { | |||
| auto pm = std::make_shared<PassManager>("graphkernel_stage1_preprocess"); | |||
| // Change Assign(p, a, U) to Assign(Depend(p, U), a) | |||
| pm->AddPass(std::make_shared<SplitAssign>()); | |||
| @@ -60,7 +60,7 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() { | |||
| return pm; | |||
| } | |||
| PassManagerPtr GraphKernelOptimizer::Cluster() { | |||
| PassManagerPtr GraphKernelOptimizer::Cluster() const { | |||
| auto pm = std::make_shared<PassManager>("graphkernel_stage2_cluster"); | |||
| // Expand complex basic kernels to composite kernels | |||
| pm->AddPass(std::make_shared<GraphKernelExpander>()); | |||
| @@ -73,7 +73,7 @@ PassManagerPtr GraphKernelOptimizer::Cluster() { | |||
| return pm; | |||
| } | |||
| PassManagerPtr GraphKernelOptimizer::HighLevelOpt1() { | |||
| PassManagerPtr GraphKernelOptimizer::HighLevelOpt1() const { | |||
| auto pm = std::make_shared<PassManager>("graphkernel_stage3_highlevelopt1"); | |||
| // Reorder Cast and Type-insensitive node | |||
| pm->AddPass(std::make_shared<ReorderOps>()); | |||
| @@ -98,7 +98,7 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt1() { | |||
| return pm; | |||
| } | |||
| PassManagerPtr GraphKernelOptimizer::Split() { | |||
| PassManagerPtr GraphKernelOptimizer::Split() const { | |||
| auto pm = std::make_shared<PassManager>("graphkernel_stage4_split"); | |||
| // Move the non-scalar tensor (in composite node) to parameter list | |||
| @@ -126,7 +126,7 @@ PassManagerPtr GraphKernelOptimizer::Split() { | |||
| return pm; | |||
| } | |||
| PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() { | |||
| PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() const { | |||
| auto pm = std::make_shared<PassManager>("graphkernel_stage5_highlevelopt2"); | |||
| // Enable atomic add | |||
| pm->AddPass(std::make_shared<AtomicCleanInsertter>()); | |||
| @@ -136,7 +136,7 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() { | |||
| return pm; | |||
| } | |||
| PassManagerPtr GraphKernelOptimizer::Combine() { | |||
| PassManagerPtr GraphKernelOptimizer::Combine() const { | |||
| auto pm = std::make_shared<PassManager>("graphkernel_stage6_combine"); | |||
| // Enable parallel fusion | |||
| if (is_gpu && context::GraphKernelFlags::GetInstance().enable_parallel_fusion) { | |||
| @@ -146,7 +146,7 @@ PassManagerPtr GraphKernelOptimizer::Combine() { | |||
| return pm; | |||
| } | |||
| PassManagerPtr GraphKernelOptimizer::PostProcess() { | |||
| PassManagerPtr GraphKernelOptimizer::PostProcess() const { | |||
| auto pm = std::make_shared<PassManager>("graphkernel_stage7_postprocess"); | |||
| // Add the new tensors to the kernel_graph | |||
| pm->AddPass(std::make_shared<BindValueToGraph>()); | |||
| @@ -30,19 +30,19 @@ class GraphKernelOptimizer { | |||
| private: | |||
| // Pre-process | |||
| PassManagerPtr PreProcess(); | |||
| PassManagerPtr PreProcess() const; | |||
| // Cluster kernels | |||
| PassManagerPtr Cluster(); | |||
| PassManagerPtr Cluster() const; | |||
| // High level optimize 1 | |||
| PassManagerPtr HighLevelOpt1(); | |||
| PassManagerPtr HighLevelOpt1() const; | |||
| // Split kernels | |||
| PassManagerPtr Split(); | |||
| PassManagerPtr Split() const; | |||
| // High level optimize 2 | |||
| PassManagerPtr HighLevelOpt2(); | |||
| PassManagerPtr HighLevelOpt2() const; | |||
| // Combine kernels | |||
| PassManagerPtr Combine(); | |||
| PassManagerPtr Combine() const; | |||
| // Post-process | |||
| PassManagerPtr PostProcess(); | |||
| PassManagerPtr PostProcess() const; | |||
| bool is_gpu{false}; | |||
| bool is_ascend{false}; | |||
| @@ -17,6 +17,7 @@ | |||
| #include "backend/optimizer/graph_kernel/parallel_fusion.h" | |||
| #include <algorithm> | |||
| #include <cstddef> | |||
| #include <list> | |||
| #include <map> | |||
| #include <memory> | |||
| @@ -454,10 +455,10 @@ std::vector<std::vector<AnfNodePtrList>> ParallelOpFusion::SearchParallelGroups( | |||
| } | |||
| std::tuple<AnfNodePtrList, std::vector<int>> ParallelOpFusion::GetAvaliableNodesByOffset( | |||
| int start, const std::vector<int> &offsets, const std::vector<bool> &used, const AnfNodePtrList &nodes, | |||
| int start, const std::vector<size_t> &offsets, const std::vector<bool> &used, const AnfNodePtrList &nodes, | |||
| const std::set<int> &excludes) { | |||
| // Get unused nodes by offset index, the result will contain the node with start index. | |||
| int node_limit = nodes.size(); | |||
| int node_limit = static_cast<int>(nodes.size()); | |||
| if (start >= node_limit) { | |||
| MS_LOG(EXCEPTION) << "Index offset is exceed the limit of given nodes."; | |||
| } | |||
| @@ -469,7 +470,7 @@ std::tuple<AnfNodePtrList, std::vector<int>> ParallelOpFusion::GetAvaliableNodes | |||
| unused.push_back(i); | |||
| } | |||
| } | |||
| int limit = unused.size(); | |||
| size_t limit = unused.size(); | |||
| for (auto offset : offsets) { | |||
| if (offset >= limit) { | |||
| MS_LOG(EXCEPTION) << "Index offset is exceed the limit of unused nodes."; | |||
| @@ -520,7 +521,7 @@ std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSea | |||
| size_t begin = 1, end = unused_num; | |||
| while (begin <= end) { | |||
| size_t mid = (begin + end) / 2; | |||
| std::vector<int> tc(mid); | |||
| std::vector<size_t> tc(mid); | |||
| std::iota(tc.begin(), tc.end(), 1); | |||
| AnfNodePtrList other_candidates; | |||
| std::tie(other_candidates, std::ignore) = | |||
| @@ -535,7 +536,7 @@ std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSea | |||
| } | |||
| if (begin > 1) { | |||
| std::vector<int> tc(begin - 1); | |||
| std::vector<size_t> tc(begin - 1); | |||
| std::iota(tc.begin(), tc.end(), 1); | |||
| AnfNodePtrList other_candidates; | |||
| std::tie(other_candidates, std::ignore) = | |||
| @@ -560,7 +561,7 @@ std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSea | |||
| // Current nodes is not suitable to fuse, so pop first node to try other fusion possibility. | |||
| if (parallel_infos.size() == 0) { | |||
| origin_candidates_used[get_index(origin_indices, candidates[0])] = true; | |||
| origin_candidates_used[get_index(origin_indices, candidates[parallel_infos.size()])] = true; | |||
| } | |||
| return std::make_tuple(origin_candidates_used, parallel_infos); | |||
| @@ -68,7 +68,7 @@ class ParallelConfig { | |||
| explicit ParallelConfig(size_t max_n) : max_num_for_fuse_(max_n) {} | |||
| explicit ParallelConfig(const ParallelConfig &obj) { max_num_for_fuse_ = obj.max_num_for_fuse_; } | |||
| ~ParallelConfig() = default; | |||
| size_t max_num_for_fuse() { return max_num_for_fuse_; } | |||
| size_t max_num_for_fuse() const { return max_num_for_fuse_; } | |||
| private: | |||
| size_t max_num_for_fuse_{10}; // Too many nodes to fuse together may produce bad result. | |||
| @@ -90,7 +90,7 @@ class ParallelOpFusion : public Pass { | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| private: | |||
| std::tuple<AnfNodePtrList, std::vector<int>> GetAvaliableNodesByOffset(int start, const std::vector<int> &offsets, | |||
| std::tuple<AnfNodePtrList, std::vector<int>> GetAvaliableNodesByOffset(int start, const std::vector<size_t> &offsets, | |||
| const std::vector<bool> &used, | |||
| const AnfNodePtrList &nodes, | |||
| const std::set<int> &excludes); | |||
| @@ -32,11 +32,12 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| bool RaiseReductionPrecision::IsFp16ReduceSum(const AnfNodePtr &node) { | |||
| bool RaiseReductionPrecision::IsFp16ReduceSum(const AnfNodePtr &node) const { | |||
| return IsPrimitiveCNode(node, prim::kPrimReduceSum) && AnfAlgo::GetInputDeviceDataType(node, 0) == kNumberTypeFloat16; | |||
| } | |||
| AnfNodePtr RaiseReductionPrecision::CreateCast(const AnfNodePtr &input, const TypePtr &dst_type, std::string format) { | |||
| AnfNodePtr RaiseReductionPrecision::CreateCast(const AnfNodePtr &input, const TypePtr &dst_type, | |||
| const std::string &format) const { | |||
| auto func_graph = input->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| AnfNodePtrList inputs = {NewValueNode(prim::kPrimCast), input}; | |||
| @@ -45,7 +46,7 @@ AnfNodePtr RaiseReductionPrecision::CreateCast(const AnfNodePtr &input, const Ty | |||
| return cnode; | |||
| } | |||
| AnfNodePtr RaiseReductionPrecision::CreateReduceSum(const AnfNodePtr &node, const AnfNodePtr &input) { | |||
| AnfNodePtr RaiseReductionPrecision::CreateReduceSum(const AnfNodePtr &node, const AnfNodePtr &input) const { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| cnode->set_input(1, input); | |||
| @@ -62,17 +63,19 @@ AnfNodePtr RaiseReductionPrecision::CreateReduceSum(const AnfNodePtr &node, cons | |||
| return node; | |||
| } | |||
| void RaiseReductionPrecision::ReplaceNode(const AnfNodePtr &reduce_node, const AnfNodePtr &cast_node) { | |||
| void RaiseReductionPrecision::ReplaceNode(const AnfNodePtr &reduce_node, const AnfNodePtr &cast_node) const { | |||
| auto mng = reduce_node->func_graph()->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| // use a copy of user, since the following `mng->Replace` will change the original users of reduce_node. | |||
| auto users = mng->node_users()[reduce_node]; | |||
| for (const auto &user : users) { | |||
| auto user_node = user.first; | |||
| auto user_index = user.second; | |||
| size_t user_index = static_cast<size_t>(user.second); | |||
| if (IsPrimitiveCNode(user_node, prim::kPrimCast) && | |||
| AnfAlgo::GetOutputDeviceDataType(user_node, 0) == kNumberTypeFloat32) { | |||
| mng->Replace(user_node, reduce_node); | |||
| if (!(mng->Replace(user_node, reduce_node))) { | |||
| MS_LOG(ERROR) << "Something happened error, when replacing nodes."; | |||
| } | |||
| } else { | |||
| if (user_node->isa<CNode>()) { | |||
| user_node->cast<CNodePtr>()->set_input(user_index, cast_node); | |||
| @@ -25,14 +25,14 @@ class RaiseReductionPrecision : public Pass { | |||
| public: | |||
| RaiseReductionPrecision() : Pass("raise_reduction_precision") {} | |||
| ~RaiseReductionPrecision() override = default; | |||
| bool Run(const FuncGraphPtr &func_graph); | |||
| bool Run(const FuncGraphPtr &func_graph) override; | |||
| private: | |||
| bool IsFp16ReduceSum(const AnfNodePtr &node); | |||
| bool IsFp16ReduceSum(const AnfNodePtr &node) const; | |||
| bool Process(const FuncGraphPtr &func_graph); | |||
| AnfNodePtr CreateCast(const AnfNodePtr &input, const TypePtr &dst_type, std::string format); | |||
| AnfNodePtr CreateReduceSum(const AnfNodePtr &node, const AnfNodePtr &input); | |||
| void ReplaceNode(const AnfNodePtr &src_node, const AnfNodePtr &dst_node); | |||
| AnfNodePtr CreateCast(const AnfNodePtr &input, const TypePtr &dst_type, const std::string &format) const; | |||
| AnfNodePtr CreateReduceSum(const AnfNodePtr &node, const AnfNodePtr &input) const; | |||
| void ReplaceNode(const AnfNodePtr &src_node, const AnfNodePtr &dst_node) const; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||