From: @dayschan Reviewed-by: @ckey_dou,@gaoxiong1 Signed-off-by: @gaoxiong1pull/15093/MERGE
| @@ -1,197 +0,0 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/graph_kernel/basic_ops_fusion.h" | |||
| #include <algorithm> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <unordered_set> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <set> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "base/core_ops.h" | |||
| #include "ir/graph_utils.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "vm/segment_runner.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "backend/optimizer/graph_kernel/composite_ops_fusion.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||
| #include "backend/optimizer/pass/getitem_tuple.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) { | |||
| if (cur_node == node) { | |||
| return FOLLOW; | |||
| } | |||
| if (IsFusibleOp(node)) { | |||
| return FOLLOW; | |||
| } | |||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { | |||
| auto prev_node = node->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem); | |||
| if (AnfAlgo::IsGraphKernel(prev_node)) { | |||
| return FOLLOW; | |||
| } | |||
| } | |||
| return EXCLUDE; | |||
| } | |||
| // The GetItem node should be fused with its real input and users. | |||
| // If its real input is not in the fuse_list, the GetItem should be excluded. | |||
| AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) { | |||
| if (fused_op.empty()) return AnfNodePtrList(); | |||
| std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end()); | |||
| auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; }; | |||
| auto mng = fused_op[0]->func_graph()->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| bool changed = true; | |||
| while (changed) { | |||
| changed = false; | |||
| AnfNodePtrList remove_list; | |||
| for (auto getitem : fused_op_set) { | |||
| if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue; | |||
| // GetItem should be fused with its real input. | |||
| auto prev_node = getitem->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem); | |||
| if (check_include(prev_node) == EXCLUDE) { | |||
| remove_list.push_back(getitem); | |||
| break; | |||
| } | |||
| // GetItem should be fused with its all users. | |||
| const auto &users = mng->node_users()[getitem]; | |||
| if (std::any_of(users.begin(), users.end(), [check_include](const std::pair<AnfNodePtr, int> &user) { | |||
| return check_include(user.first) == EXCLUDE; | |||
| })) { | |||
| remove_list = DeepLinkedGraphSearch(getitem, check_include); | |||
| break; | |||
| } | |||
| } | |||
| if (!remove_list.empty()) { | |||
| for (auto node : remove_list) { | |||
| fused_op_set.erase(node); | |||
| } | |||
| changed = true; | |||
| } | |||
| } | |||
| // keep the original order of fused_op. | |||
| AnfNodePtrList result; | |||
| for (auto node : fused_op) { | |||
| if (fused_op_set.count(node)) { | |||
| result.push_back(node); | |||
| } | |||
| } | |||
| return result; | |||
| } | |||
| std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, | |||
| const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &dep_pri) { | |||
| // Search fusable nodes according input direction. | |||
| auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1); | |||
| auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward); | |||
| if (used_nodes.size() > 1) { | |||
| used_nodes = RemoveCircle(used_nodes, dep_pri); | |||
| } | |||
| used_nodes = RemoveWildGetitem(used_nodes); | |||
| TopoSortForNodeList(&used_nodes); | |||
| return used_nodes; | |||
| } | |||
| bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr> &todos, | |||
| std::unordered_set<AnfNodePtr> *fused_ops) { | |||
| bool changed = false; | |||
| auto mng = kernel_graph->manager(); | |||
| // depend_prior[depend] = pair(prior, behind) | |||
| std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> depend_prior; | |||
| // InitDependPrior(todos, &depend_prior); | |||
| for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { | |||
| auto node = (*iter)->cast<CNodePtr>(); | |||
| if (node == nullptr || IsKeepBasicNode(node) || fused_ops->count(node)) { | |||
| continue; | |||
| } | |||
| bool is_fusible_op = IsFusibleOp(node); | |||
| if (!is_fusible_op || !kernel_graph->nodes().contains(node)) { | |||
| continue; | |||
| } | |||
| auto fuse_nodes = FindFuseCNodes(node, depend_prior); | |||
| if (fuse_nodes.empty()) { | |||
| continue; | |||
| } | |||
| if (fuse_nodes.size() == 1) { | |||
| // Do not fuse a single GraphKernel again. | |||
| // Do not fuse a single Assign. | |||
| if (AnfAlgo::IsGraphKernel(fuse_nodes[0]) || IsPrimitiveCNode(fuse_nodes[0], prim::kPrimAssign)) { | |||
| continue; | |||
| } | |||
| } | |||
| changed = true; | |||
| fused_ops->insert(fuse_nodes.begin(), fuse_nodes.end()); | |||
| AnfNodePtr fused_new_node; | |||
| AnfNodePtrList old_outputs; | |||
| std::tie(fused_new_node, old_outputs) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "fusion"); | |||
| ReplaceNewFuseCNodeForDependPrior(&depend_prior, fused_new_node, old_outputs); | |||
| } | |||
| std::dynamic_pointer_cast<session::KernelGraph>(kernel_graph)->SetExecOrderByDefault(); | |||
| return changed; | |||
| } | |||
| } // namespace | |||
| bool FuseBasicOps(const FuncGraphPtr &func_graph) { | |||
| std::unordered_set<AnfNodePtr> fused_ops; | |||
| auto todos = TopoSort(func_graph->get_return()); | |||
| std::reverse(todos.begin(), todos.end()); | |||
| return FuseBasicOps(func_graph, todos, &fused_ops); | |||
| } | |||
| void EliminateGetitem(const FuncGraphPtr &func_graph) { | |||
| std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>(); | |||
| auto todos = TopoSort(func_graph->get_return()); | |||
| for (auto node : todos) { | |||
| if (AnfAlgo::IsGraphKernel(node)) { | |||
| eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(node)); | |||
| } | |||
| } | |||
| } | |||
| bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) { | |||
| auto mng = func_graph->manager(); | |||
| if (mng == nullptr) { | |||
| mng = Manage(func_graph, true); | |||
| func_graph->set_manager(mng); | |||
| } | |||
| bool changed = FuseBasicOps(func_graph); | |||
| if (changed) { | |||
| EliminateGetitem(func_graph); | |||
| mng->RemoveRoots(); | |||
| mng->KeepRoots({func_graph}); | |||
| } | |||
| return changed; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -1,36 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_BASIC_OPS_FUSION_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_BASIC_OPS_FUSION_H_ | |||
| #include <memory> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "backend/session/kernel_graph.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| bool FuseBasicOps(const FuncGraphPtr &kernel_graph); | |||
| class BasicOpsFusion : public Pass { | |||
| public: | |||
| BasicOpsFusion() : Pass("basic_ops_fusion") {} | |||
| ~BasicOpsFusion() override = default; | |||
| bool Run(const FuncGraphPtr &func_graph) override; | |||
| }; | |||
| using FuseBasicPtr = std::shared_ptr<BasicOpsFusion>; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_BASIC_OPS_FUSION_H_ | |||
| @@ -18,6 +18,7 @@ | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_cluster.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -99,7 +100,7 @@ bool CastMatmulFusion::Run(const FuncGraphPtr &func_graph) { | |||
| continue; | |||
| } | |||
| // Cast cannot fuse with its input | |||
| if (IsFusibleOp((cast_node->cast<CNodePtr>())->input(1))) { | |||
| if (IsClusterableOp((cast_node->cast<CNodePtr>())->input(1))) { | |||
| continue; | |||
| } | |||
| @@ -1,217 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/graph_kernel/composite_ops_fusion.h" | |||
| #include <algorithm> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <queue> | |||
| #include <string> | |||
| #include <set> | |||
| #include <unordered_set> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "frontend/operator/ops.h" | |||
| #include "utils/utils.h" | |||
| #include "utils/ordered_set.h" | |||
| #include "utils/ordered_map.h" | |||
| #include "ir/graph_utils.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "vm/segment_runner.h" | |||
| #include "debug/draw.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||
| #include "backend/optimizer/pass/getitem_tuple.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| std::vector<AnfNodePtr> DeepLinkedGraphSearch(const std::vector<AnfNodePtr> &roots, const IncludeFunc &include) { | |||
| std::vector<AnfNodePtr> inputs; | |||
| for (auto &root : roots) { | |||
| auto tmp = DeepLinkedGraphSearch(root, include); | |||
| inputs.insert(inputs.end(), tmp.begin(), tmp.end()); | |||
| } | |||
| return inputs; | |||
| } | |||
| } // namespace | |||
| bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node, | |||
| std::set<AnfNodePtr> *cached_unconnected_set, std::vector<AnfNodePtr> *circle_nodes, | |||
| const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior) { | |||
| if (!check_node->isa<CNode>() || !fused_op_set.count(check_node)) { | |||
| return false; | |||
| } | |||
| circle_nodes->clear(); | |||
| auto InputEdges = [&depend_prior](const CNodePtr &cnode) { | |||
| std::set<AnfNodePtr> edges; | |||
| auto range = depend_prior.equal_range(cnode); | |||
| for (auto iter = range.first; iter != range.second; ++iter) { | |||
| edges.insert(iter->second.first); | |||
| } | |||
| auto inputs = cnode->inputs(); | |||
| for (auto input : inputs) { | |||
| edges.insert(input); | |||
| } | |||
| return edges; | |||
| }; | |||
| // consider prior depend both in fused_op_set | |||
| auto range = depend_prior.equal_range(check_node); | |||
| for (auto iter = range.first; iter != range.second; ++iter) { | |||
| if (fused_op_set.count(iter->second.first)) { | |||
| circle_nodes->push_back(iter->second.first); | |||
| } | |||
| } | |||
| std::set<AnfNodePtr> cached_done_set; | |||
| auto cnode = check_node->cast<CNodePtr>(); | |||
| const auto &inputs = InputEdges(cnode); | |||
| // there is a input not in fused_op_set, but the input depends on the fused_op_set | |||
| for (auto input : inputs) { | |||
| if (input->isa<CNode>() && !fused_op_set.count(input)) { | |||
| bool has_circle = false; | |||
| std::set<AnfNodePtr> done; | |||
| std::vector<AnfNodePtr> todos = {input}; | |||
| while (!todos.empty()) { | |||
| auto node = todos.back(); | |||
| todos.pop_back(); | |||
| if (done.count(node) || cached_unconnected_set->count(node) || cached_done_set.count(node)) { | |||
| continue; | |||
| } | |||
| done.insert(node); | |||
| if (fused_op_set.count(node)) { | |||
| has_circle = true; | |||
| circle_nodes->push_back(node); | |||
| continue; | |||
| } | |||
| if (node->isa<CNode>()) { | |||
| auto cnode_ptr = node->cast<CNodePtr>(); | |||
| for (auto it : InputEdges(cnode_ptr)) { | |||
| if (it->isa<CNode>()) { | |||
| todos.push_back(it); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (has_circle) { | |||
| cached_done_set.insert(done.begin(), done.end()); | |||
| } else { | |||
| cached_unconnected_set->insert(done.begin(), done.end()); | |||
| } | |||
| done.clear(); | |||
| } | |||
| } | |||
| return !circle_nodes->empty(); | |||
| } | |||
| AnfNodePtrList RemoveCircle(const std::vector<AnfNodePtr> &fused_op, | |||
| const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior) { | |||
| std::set<AnfNodePtr> cached_unconnected_set; | |||
| std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end()); | |||
| auto include = [&fused_op_set](const AnfNodePtr &node) { | |||
| if (fused_op_set.count(node)) { | |||
| return FOLLOW; | |||
| } | |||
| return EXCLUDE; | |||
| }; | |||
| std::vector<AnfNodePtr> circle_nodes; | |||
| for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) { | |||
| circle_nodes.clear(); | |||
| bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes, depend_prior); | |||
| // delete the circle node and the node which depend on the circle node in fused op | |||
| if (has_circle) { | |||
| std::vector<AnfNodePtr> erase_nodes; | |||
| erase_nodes = DeepLinkedGraphSearch(circle_nodes, include); | |||
| for (auto erase_node : erase_nodes) { | |||
| fused_op_set.erase(erase_node); | |||
| } | |||
| } | |||
| } | |||
| std::vector<AnfNodePtr> res; | |||
| for (auto node : fused_op) { | |||
| if (fused_op_set.count(node)) { | |||
| res.push_back(node); | |||
| } | |||
| } | |||
| return res; | |||
| } | |||
| void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) { | |||
| if (lst->size() < 2) { | |||
| return; | |||
| } | |||
| std::vector<AnfNodePtr> res; | |||
| std::set<AnfNodePtr> node_sets(lst->begin(), lst->end()); | |||
| OrderedMap<AnfNodePtr, std::set<AnfNodePtr>> ins; | |||
| OrderedMap<AnfNodePtr, OrderedSet<AnfNodePtr>> outs; | |||
| std::queue<AnfNodePtr> q; | |||
| for (auto node : *lst) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| for (auto input : cnode->inputs()) { | |||
| if (!node_sets.count(input)) { | |||
| continue; | |||
| } | |||
| // out_degree | |||
| outs[input].insert(node); | |||
| // in_degree | |||
| ins[node].insert(input); | |||
| } | |||
| if (!ins.count(node)) { | |||
| ins[node] = {}; | |||
| } | |||
| } | |||
| for (auto p : ins) { | |||
| if (p.second.size() == 0) { | |||
| q.push(p.first); | |||
| } | |||
| } | |||
| while (!q.empty()) { | |||
| auto node = q.front(); | |||
| q.pop(); | |||
| res.push_back(node); | |||
| if (!outs.count(node)) { | |||
| continue; | |||
| } | |||
| for (auto out : outs[node]) { | |||
| if (!ins.count(out)) { | |||
| continue; | |||
| } | |||
| ins[out].erase(node); | |||
| if (ins[out].size() == 0) { | |||
| q.push(out); | |||
| } | |||
| } | |||
| } | |||
| lst->assign(res.begin(), res.end()); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -1,37 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_ | |||
| #include <limits> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "backend/session/kernel_graph.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| AnfNodePtrList RemoveCircle(const std::vector<AnfNodePtr> &fused_op, | |||
| const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior); | |||
| void TopoSortForNodeList(std::vector<AnfNodePtr> *lst); | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_ | |||
| @@ -0,0 +1,477 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_cluster.h" | |||
| #include <algorithm> | |||
| #include <map> | |||
| #include <unordered_map> | |||
| #include <set> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <fstream> | |||
| #include "base/core_ops.h" | |||
| #include "ir/graph_utils.h" | |||
| #include "debug/common.h" | |||
| #include "utils/context/graph_kernel_flags.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||
| #include "backend/optimizer/pass/getitem_tuple.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| std::vector<PrimitivePtr> GetClusterableOpList() { | |||
| std::vector<PrimitivePtr> clusterable_ops = { | |||
| prim::kPrimAbs, | |||
| prim::kPrimRound, | |||
| prim::kPrimNeg, | |||
| prim::kPrimExp, | |||
| prim::kPrimAdd, | |||
| prim::kPrimCast, | |||
| prim::kPrimMul, | |||
| prim::kPrimMinimum, | |||
| prim::kPrimMaximum, | |||
| prim::kPrimLog, | |||
| prim::kPrimPow, | |||
| prim::kPrimSub, | |||
| prim::kPrimRsqrt, | |||
| prim::kPrimSqrt, | |||
| prim::kPrimAddN, | |||
| prim::kPrimReciprocal, | |||
| prim::kPrimTanh, | |||
| prim::kPrimReshape, | |||
| prim::kPrimTranspose, | |||
| prim::kPrimRealDiv, | |||
| prim::kPrimReduceSum, | |||
| prim::kPrimEqual, | |||
| prim::kPrimAssign, | |||
| prim::kPrimInplaceAssign, | |||
| #if ENABLE_D | |||
| prim::kPrimMatMul, | |||
| prim::KPrimTransData, | |||
| #elif ENABLE_GPU | |||
| prim::kPrimReduceMax, | |||
| prim::kPrimReduceMin, | |||
| prim::kPrimGreater, | |||
| prim::kPrimLess, | |||
| prim::kPrimGreaterEqual, | |||
| prim::kPrimLessEqual, | |||
| prim::kPrimSelect, | |||
| #endif | |||
| }; | |||
| const auto &flags = context::GraphKernelFlags::GetInstance(); | |||
| OpListFilter(&clusterable_ops, flags.enable_cluster_ops_only, flags.enable_cluster_ops, flags.disable_cluster_ops); | |||
| return clusterable_ops; | |||
| } | |||
| size_t CountGraphKernelInnerNodes(const AnfNodePtr &node) { | |||
| AnfNodePtrList node_list; | |||
| kernel::GetValidKernelNodes(AnfAlgo::GetCNodeFuncGraphPtr(node), &node_list); | |||
| return node_list.size(); | |||
| } | |||
| } // namespace | |||
| bool IsClusterableOp(const AnfNodePtr &node) { | |||
| if (IsKeepBasicNode(node)) { | |||
| return false; | |||
| } | |||
| if (AnfAlgo::IsGraphKernel(node)) { | |||
| return true; | |||
| } | |||
| auto op_list = GetClusterableOpList(); | |||
| bool node_in_oplist = std::any_of(op_list.begin(), op_list.end(), | |||
| [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); | |||
| if (!node_in_oplist) { | |||
| return false; | |||
| } | |||
| #if ENABLE_D | |||
| // For AICPU operators, only the Reshape can be clustered. | |||
| if (AnfAlgo::GetProcessor(node) != kernel::Processor::AICORE && !IsPrimitiveCNode(node, prim::kPrimReshape)) { | |||
| return false; | |||
| } | |||
| #endif | |||
| return true; | |||
| } | |||
| class Graph { | |||
| struct Cluster { | |||
| size_t cluster_id_; // node_id of the representative. | |||
| size_t cluster_size_{1}; // size of cluster, composite node is considered as one node. | |||
| size_t basic_op_cnt_{1}; // basic node count, the inner nodes of composite node are counted. | |||
| std::set<size_t> inputs_; // inputs' cluster_id. | |||
| size_t seed_{0}; // visited flag of dfs. | |||
| Cluster(size_t node_id, const AnfNodePtr &node, const std::unordered_map<AnfNodePtr, size_t> &node_idx_map) | |||
| : cluster_id_(node_id) { | |||
| if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { | |||
| basic_op_cnt_ = 0; | |||
| } else if (AnfAlgo::IsGraphKernel(node)) { | |||
| // the basic_op_cnt_ is used to limit the composite op size | |||
| basic_op_cnt_ = CountGraphKernelInnerNodes(node); | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| for (const auto &inp : cnode->inputs()) { | |||
| auto iter = node_idx_map.find(inp); | |||
| if (iter != node_idx_map.end()) { | |||
| // At the beginning, cluster_id is equal to node_id | |||
| inputs_.insert(iter->second); | |||
| } | |||
| } | |||
| } | |||
| ~Cluster() = default; | |||
| void Merge(Cluster *other_cluster) { | |||
| other_cluster->cluster_id_ = cluster_id_; | |||
| cluster_size_ += other_cluster->cluster_size_; | |||
| basic_op_cnt_ += other_cluster->basic_op_cnt_; | |||
| std::for_each(other_cluster->inputs_.begin(), other_cluster->inputs_.end(), | |||
| [this](size_t inp) { this->inputs_.insert(inp); }); | |||
| other_cluster->Clean(); | |||
| } | |||
| // clean the info to free memory. | |||
| void Clean() { | |||
| inputs_.clear(); | |||
| cluster_size_ = 0; | |||
| basic_op_cnt_ = 0; | |||
| } | |||
| }; // struct Cluster | |||
| public: | |||
| // Init and build graph | |||
| Graph(const AnfNodePtrList &nodes, const std::unordered_map<AnfNodePtr, size_t> &node_idx_map) { | |||
| clusters_.reserve(nodes.size()); | |||
| for (size_t i = 0; i < nodes.size(); i++) { | |||
| clusters_.emplace_back(i, nodes[i], node_idx_map); | |||
| } | |||
| } | |||
| ~Graph() = default; | |||
| // find the representative of the cluster | |||
| int Find(size_t node_id) { | |||
| size_t &pre_id = clusters_[node_id].cluster_id_; | |||
| return (pre_id == clusters_[pre_id].cluster_id_) ? pre_id : (pre_id = Find(pre_id)); | |||
| } | |||
| // merge clusters, the smallest cluster id will be the new cluster id. | |||
| void Merge(const std::set<size_t> &candidates) { | |||
| for (auto iter = ++candidates.begin(); iter != candidates.end(); ++iter) { | |||
| clusters_[*candidates.begin()].Merge(&clusters_[*iter]); | |||
| } | |||
| } | |||
| // Collect nodes together that are in the same cluster. | |||
| std::vector<std::vector<size_t>> CollectClusters() { | |||
| std::vector<std::vector<size_t>> cluster_map(clusters_.size()); | |||
| for (size_t i = 0; i < clusters_.size(); i++) { | |||
| cluster_map[Find(i)].push_back(i); | |||
| } | |||
| return cluster_map; | |||
| } | |||
| using VisitFunc = std::function<IncludeType(size_t)>; | |||
| void Dfs(size_t node_id, VisitFunc visitor) { | |||
| ++seen_; | |||
| return DepthFirstSearch(Find(node_id), visitor); | |||
| } | |||
| // Get cluster size | |||
| size_t GetSize(size_t cluster_id) { return clusters_[Find(cluster_id)].cluster_size_; } | |||
| // Get cluster's basic op count | |||
| size_t GetBasicNodeCount(size_t cluster_id) { return clusters_[Find(cluster_id)].basic_op_cnt_; } | |||
| // Get cluster's inputs | |||
| const std::set<size_t> &GetInputs(size_t cluster_id) { | |||
| cluster_id = Find(cluster_id); | |||
| RefreshInputs(cluster_id); | |||
| return clusters_[cluster_id].inputs_; | |||
| } | |||
| private: | |||
| void RefreshInputs(size_t i) { | |||
| auto &inputs = clusters_[i].inputs_; | |||
| for (auto iter = inputs.begin(); iter != inputs.end();) { | |||
| size_t new_id = Find(*iter); | |||
| if (new_id != *iter) { | |||
| iter = inputs.erase(iter); | |||
| inputs.insert(new_id); | |||
| } else { | |||
| ++iter; | |||
| } | |||
| } | |||
| inputs.erase(i); | |||
| } | |||
| void DepthFirstSearch(size_t cluster_id, const VisitFunc &visitor) { | |||
| if (clusters_[cluster_id].seed_ >= seen_) return; | |||
| clusters_[cluster_id].seed_ = seen_; | |||
| if (visitor(cluster_id) != FOLLOW) { | |||
| return; | |||
| } | |||
| // traverse inputs in descending order. | |||
| const auto &inputs = GetInputs(cluster_id); | |||
| for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { | |||
| DepthFirstSearch(*iter, visitor); | |||
| } | |||
| } | |||
| std::vector<Cluster> clusters_; | |||
| size_t seen_{0}; | |||
| }; // class Graph | |||
| class CircleChecker { | |||
| public: | |||
| explicit CircleChecker(GraphPtr graph) : graph_(graph) {} | |||
| ~CircleChecker() = default; | |||
| void RemoveCircle(std::set<size_t> *candidates) { | |||
| if (candidates->size() <= 1) { | |||
| return; | |||
| } | |||
| candidates_ = candidates; | |||
| std::vector<size_t> tmp_list(candidates->begin(), candidates->end()); | |||
| for (auto c : tmp_list) { | |||
| if (!candidates->count(c)) continue; | |||
| circle_nodes_.clear(); | |||
| if (CheckCircle(c)) { | |||
| RemoveCircleNodesFromCandidates(); | |||
| } | |||
| } | |||
| } | |||
| private: | |||
| /** | |||
| * Check circle. the candidate is collected into circle_nodes_ if it will form a circle. | |||
| * | |||
| * algorithm: | |||
| * Search from the basenode's input that is NOT in candidates (the basenode is a candidate), | |||
| * If it depends on a node that belongs to candidates, it will form a circle. | |||
| * e.g. A -> x -> ... -> B | |||
| * -> y -> ... -> C | |||
| * In this case, A, B and C are candidates while x and y are not. | |||
| * Both x and y are inputs of A. assumes A is the basenode. | |||
| * When searching from x, the B will be found and added into circle_nodes list, | |||
| * and then when searching from y, the C will be found and added into circle_nodes list. | |||
| */ | |||
| bool CheckCircle(size_t basenode) { | |||
| const auto &inputs = graph_->GetInputs(basenode); | |||
| std::set<size_t> visited_circle_nodes; | |||
| for (auto x : inputs) { | |||
| if (candidates_->count(x)) continue; | |||
| bool has_circle = false; | |||
| std::set<size_t> done; | |||
| auto vis_func = [this, &has_circle, &done, &visited_circle_nodes](size_t node_id) { | |||
| if (done.count(node_id) || acyclic_nodes_.count(node_id) || visited_circle_nodes.count(node_id)) { | |||
| return EXCLUDE; | |||
| } | |||
| done.insert(node_id); | |||
| if (candidates_->count(node_id)) { | |||
| has_circle = true; | |||
| circle_nodes_.push_back(node_id); | |||
| return EXCLUDE; | |||
| } | |||
| // all nodes are indexed by topo order, | |||
| // so if the node_id is less than the minimal candidate, a cycle cannot be formed from this node. | |||
| if (candidates_->empty() || node_id < *candidates_->begin()) { | |||
| return EXCLUDE; | |||
| } | |||
| return FOLLOW; | |||
| }; | |||
| graph_->Dfs(x, vis_func); | |||
| if (has_circle) { | |||
| visited_circle_nodes.insert(done.begin(), done.end()); | |||
| } else { | |||
| acyclic_nodes_.insert(done.begin(), done.end()); | |||
| } | |||
| } | |||
| return !circle_nodes_.empty(); | |||
| } | |||
| // remove all circle nodes from candidates | |||
| void RemoveCircleNodesFromCandidates() { | |||
| auto remove_from_candidates = [this](size_t node_id) { | |||
| if (candidates_->count(node_id)) { | |||
| candidates_->erase(node_id); | |||
| return FOLLOW; | |||
| } | |||
| return EXCLUDE; | |||
| }; | |||
| for (auto node : circle_nodes_) { | |||
| graph_->Dfs(node, remove_from_candidates); | |||
| } | |||
| } | |||
| private: | |||
| GraphPtr graph_; // bind the global graph | |||
| std::set<size_t> *candidates_{nullptr}; // bind the input candidates | |||
| std::vector<size_t> circle_nodes_; | |||
| std::set<size_t> acyclic_nodes_; | |||
| }; // CircleChecker | |||
| std::set<size_t> GraphKernelCluster::FindCandidates(size_t basenode_id) { | |||
| std::set<size_t> candidates; | |||
| auto include = [this, &candidates, func_graph = nodes_[basenode_id]->func_graph()](size_t cluster_id) { | |||
| const AnfNodePtr &node = this->nodes_[cluster_id]; | |||
| if (node->func_graph() != func_graph) { | |||
| return EXCLUDE; | |||
| } | |||
| if (!IsClusterableOp(node) && !IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { | |||
| return EXCLUDE; | |||
| } | |||
| candidates.insert(cluster_id); | |||
| // Do not search from clustered node again. | |||
| if (this->graph_->GetSize(cluster_id) > 1) { | |||
| return NOFOLLOW; | |||
| } | |||
| return FOLLOW; | |||
| }; | |||
| graph_->Dfs(basenode_id, include); | |||
| return candidates; | |||
| } | |||
| bool GraphKernelCluster::Process(const FuncGraphPtr &func_graph) { | |||
| bool changed = false; | |||
| for (int i = nodes_.size() - 1; i >= 0; i--) { | |||
| // if the node has been clustered, it has tried to find its previous nodes, so it's unnecessary to try again. | |||
| if (graph_->GetSize(i) > 1) { | |||
| continue; | |||
| } | |||
| auto candidates = FindCandidates(i); | |||
| CircleChecker(graph_).RemoveCircle(&candidates); | |||
| RemoveWildGetitem(&candidates); | |||
| if (candidates.empty()) continue; | |||
| // merge candidates into one cluster | |||
| graph_->Merge(candidates); | |||
| } | |||
| // Rebuild func_graphs | |||
| auto clusters = graph_->CollectClusters(); | |||
| for (size_t i = 0; i < clusters.size(); i++) { | |||
| auto node_without_getitem = std::count_if(clusters[i].begin(), clusters[i].end(), [this](size_t node_id) { | |||
| return !IsPrimitiveCNode(this->nodes_[node_id], prim::kPrimTupleGetItem); | |||
| }); | |||
| if (node_without_getitem == 0) continue; | |||
| if (node_without_getitem == 1) { | |||
| // Do not cluster a single GraphKernel again. | |||
| // Do not cluster a single Assign. | |||
| const auto &node = nodes_[clusters[i][0]]; | |||
| if (AnfAlgo::IsGraphKernel(node) || IsPrimitiveCNode(node, prim::kPrimAssign) || !IsClusterableOp(node)) { | |||
| continue; | |||
| } | |||
| } | |||
| CreateFuncGraph(func_graph, clusters[i]); | |||
| changed = true; | |||
| } | |||
| return changed; | |||
| } | |||
| void GraphKernelCluster::CreateFuncGraph(const FuncGraphPtr &func_graph, const std::vector<size_t> &nodes_id) { | |||
| AnfNodePtrList old_nodes; | |||
| AnfNodePtr new_node; | |||
| std::transform(nodes_id.begin(), nodes_id.end(), std::back_inserter(old_nodes), | |||
| [this](size_t id) { return this->nodes_[id]; }); | |||
| std::tie(new_node, std::ignore) = FuseNodesToSubGraph(old_nodes, func_graph, "fusion"); | |||
| std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>(); | |||
| eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(new_node)); | |||
| if (context::GraphKernelFlags::GetInstance().dump_as_text) { | |||
| DumpClusterInfo(old_nodes, new_node); | |||
| } | |||
| } | |||
| void GraphKernelCluster::DumpClusterInfo(const AnfNodePtrList &old_nodes, const AnfNodePtr &new_node) { | |||
| #ifdef ENABLE_DUMP_IR | |||
| dump_buf_ << "Source nodes of " << new_node->fullname_with_scope() << " = " << new_node->DebugString() << std::endl; | |||
| for (const auto &node : old_nodes) { | |||
| dump_buf_ << " " << node->fullname_with_scope() << " = " << node->DebugString() << std::endl; | |||
| } | |||
| dump_buf_ << "=======================" << std::endl; | |||
| #endif | |||
| } | |||
| void GraphKernelCluster::DumpToFile() { | |||
| #ifdef ENABLE_DUMP_IR | |||
| auto pathname = std::string("./") + kGraphKernelDumpPath + "/graph_kernel_cluster.txt"; | |||
| auto realpath = Common::GetRealPath(pathname); | |||
| if (!realpath.has_value()) { | |||
| MS_LOG(ERROR) << "Get real path failed. path=" << pathname; | |||
| return; | |||
| } | |||
| std::ofstream fout(realpath.value(), std::ios::app); | |||
| if (!fout.is_open()) { | |||
| MS_LOG(ERROR) << "Open dump file '" << realpath.value() << "' failed!"; | |||
| return; | |||
| } | |||
| fout << dump_buf_.str() << std::endl; | |||
| fout.close(); | |||
| #endif | |||
| } | |||
| // The GetItem node should be clustered with its real input. | |||
| // If its real input is not in the candidates, the GetItem should be excluded. | |||
| void GraphKernelCluster::RemoveWildGetitem(std::set<size_t> *candidates) { | |||
| for (auto iter = candidates->begin(); iter != candidates->end();) { | |||
| size_t cluster_id = *iter; | |||
| /*The implied condition is graph->GetSize(cluster_id) == 1*/ | |||
| if (IsPrimitiveCNode(nodes_[cluster_id], prim::kPrimTupleGetItem)) { | |||
| const auto &inputs = graph_->GetInputs(cluster_id); | |||
| if (inputs.size() != 1) { | |||
| MS_LOG(ERROR) << "Input size of GetItem(" << cluster_id << ") should be 1, but got " << inputs.size(); | |||
| candidates->clear(); | |||
| return; | |||
| } | |||
| auto prev_id = *(inputs.begin()); | |||
| if (!candidates->count(prev_id)) { | |||
| iter = candidates->erase(iter); | |||
| continue; | |||
| } | |||
| } | |||
| ++iter; | |||
| } | |||
| } | |||
| void GraphKernelCluster::Init(const FuncGraphPtr &func_graph) { | |||
| // process cnode only | |||
| nodes_ = TopoSort(func_graph->get_return(), SuccIncoming, | |||
| [](const AnfNodePtr &node) { return node->isa<CNode>() ? FOLLOW : EXCLUDE; }); | |||
| for (size_t i = 0; i < nodes_.size(); i++) { | |||
| node_idx_map_[nodes_[i]] = i; | |||
| } | |||
| graph_ = std::make_shared<Graph>(nodes_, node_idx_map_); | |||
| MS_EXCEPTION_IF_NULL(graph_); | |||
| } | |||
| bool GraphKernelCluster::Run(const FuncGraphPtr &func_graph) { | |||
| auto mng = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| Init(func_graph); | |||
| bool changed = Process(func_graph); | |||
| if (changed) { | |||
| if (context::GraphKernelFlags::GetInstance().dump_as_text) { | |||
| DumpToFile(); | |||
| } | |||
| mng->RemoveRoots(); | |||
| mng->KeepRoots({func_graph}); | |||
| } | |||
| Clean(); | |||
| return changed; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,61 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CLUSTER_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CLUSTER_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <set> | |||
| #include <memory> | |||
| #include "ir/anf.h" | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class Graph; | |||
| using GraphPtr = std::shared_ptr<Graph>; | |||
| class GraphKernelCluster : public Pass { | |||
| public: | |||
| GraphKernelCluster() : Pass("graph_kernel_cluster") {} | |||
| ~GraphKernelCluster() override = default; | |||
| bool Run(const FuncGraphPtr &func_graph) override; | |||
| private: | |||
| void Init(const FuncGraphPtr &func_graph); | |||
| bool Process(const FuncGraphPtr &func_graph); | |||
| std::set<size_t> FindCandidates(size_t basenode_id); | |||
| void RemoveWildGetitem(std::set<size_t> *candidates); | |||
| void CreateFuncGraph(const FuncGraphPtr &func_graph, const std::vector<size_t> &nodes_id); | |||
| void DumpClusterInfo(const AnfNodePtrList &old_nodes, const AnfNodePtr &new_node); | |||
| void DumpToFile(); | |||
| void Clean() { | |||
| std::vector<AnfNodePtr>().swap(nodes_); | |||
| node_idx_map_.clear(); | |||
| graph_ = nullptr; | |||
| } | |||
| GraphPtr graph_{nullptr}; | |||
| std::vector<AnfNodePtr> nodes_; | |||
| std::unordered_map<AnfNodePtr, size_t> node_idx_map_; | |||
| std::stringstream dump_buf_; | |||
| }; | |||
| bool IsClusterableOp(const AnfNodePtr &node); | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CLUSTER_H_ | |||
| @@ -593,77 +593,6 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p | |||
| return name.str(); | |||
| } | |||
| std::vector<PrimitivePtr> GetFusibleOpList() { | |||
| #if ENABLE_D | |||
| std::vector<PrimitivePtr> fusible_basic_ops = { | |||
| prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd, | |||
| prim::kPrimCast, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, | |||
| prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, | |||
| prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose, | |||
| prim::kPrimRealDiv, prim::kPrimMatMul, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimInplaceAssign, | |||
| prim::KPrimTransData}; | |||
| #elif ENABLE_GPU | |||
| std::vector<PrimitivePtr> fusible_basic_ops = { | |||
| prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd, | |||
| prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, | |||
| prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, | |||
| prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater, | |||
| prim::kPrimCast, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose, | |||
| prim::kPrimAssign, prim::kPrimLessEqual, prim::kPrimGreaterEqual, prim::kPrimReduceMax, prim::kPrimReduceMin, | |||
| prim::kPrimLess, prim::kPrimInplaceAssign}; | |||
| #else | |||
| std::vector<PrimitivePtr> fusible_basic_ops; | |||
| #endif | |||
| const auto &flags = context::GraphKernelFlags::GetInstance(); | |||
| OpListFilter(&fusible_basic_ops, flags.enable_cluster_ops_only, flags.enable_cluster_ops, flags.disable_cluster_ops); | |||
| return fusible_basic_ops; | |||
| } | |||
| bool CheckProcessor(const AnfNodePtr &node, kernel::Processor processor = kernel::Processor::AICORE) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto node_kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); | |||
| if (node_kernel_info == nullptr) { | |||
| return false; | |||
| } | |||
| auto node_build_info = node_kernel_info->GetMutableSelectKernelBuildInfo(); | |||
| if (node_build_info == nullptr) { | |||
| return false; | |||
| } | |||
| return node_build_info->processor() == processor; | |||
| } | |||
| bool IsBasicFuseOp(const AnfNodePtr &node) { | |||
| std::vector<PrimitivePtr> basic_ops = GetFusibleOpList(); | |||
| #if ENABLE_D | |||
| if (!CheckProcessor(node)) { | |||
| std::vector<PrimitivePtr> fused_aicpu_op = {prim::kPrimExpandDims, prim::kPrimReshape}; | |||
| if (!std::any_of(fused_aicpu_op.begin(), fused_aicpu_op.end(), | |||
| [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) { | |||
| return false; | |||
| } | |||
| } | |||
| #endif | |||
| return std::any_of(basic_ops.begin(), basic_ops.end(), | |||
| [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); | |||
| } | |||
| bool IsFusibleOp(const AnfNodePtr &node) { | |||
| #if ENABLE_D | |||
| const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward", | |||
| "LambNextMV", "LambUpdateWithLR"}; | |||
| if (AnfAlgo::IsGraphKernel(node)) { | |||
| auto fg_attr = AnfAlgo::GetCNodeFuncGraphPtr(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); | |||
| if (fg_attr != nullptr) { | |||
| return graph_kernel_black_list.count(GetValue<std::string>(fg_attr)) == 0; | |||
| } | |||
| } | |||
| #endif | |||
| return IsBasicFuseOp(node) || AnfAlgo::IsGraphKernel(node); | |||
| } | |||
| void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| @@ -674,37 +603,6 @@ void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) { | |||
| #endif | |||
| } | |||
| void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior, | |||
| const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs) { | |||
| std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_fuse_cnode_dep_pri; | |||
| for (size_t out_idx = 0; out_idx < outputs.size(); ++out_idx) { | |||
| if (IsPrimitiveCNode(outputs[out_idx], prim::kPrimMakeTuple)) { | |||
| MS_LOG(ERROR) << "Need real outputs of makeTuple"; | |||
| } | |||
| if (IsPrimitiveCNode(outputs[out_idx], prim::kPrimTupleGetItem)) { | |||
| continue; | |||
| } | |||
| for (auto iter = (*depend_prior).begin(); iter != (*depend_prior).end();) { | |||
| if (iter->first == outputs[out_idx]) { | |||
| new_fuse_cnode_dep_pri.insert({new_fuse_cnode, iter->second}); | |||
| iter = depend_prior->erase(iter); | |||
| continue; | |||
| } | |||
| if (iter->second.first == outputs[out_idx]) { | |||
| new_fuse_cnode_dep_pri.insert({iter->first, std::make_pair(new_fuse_cnode, iter->second.second)}); | |||
| iter = depend_prior->erase(iter); | |||
| continue; | |||
| } | |||
| ++iter; | |||
| } | |||
| } | |||
| for (auto item : new_fuse_cnode_dep_pri) { | |||
| depend_prior->insert(item); | |||
| } | |||
| } | |||
| std::string GetFormat(const AnfNodePtr &node) { | |||
| auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| @@ -47,6 +47,8 @@ constexpr auto kJsonKeyMultiGraph = "multi_graph"; | |||
| constexpr auto kJsonKeyGraphDesc = "graph_desc"; | |||
| constexpr auto kJsonKeyGraphMode = "graph_mode"; | |||
| constexpr auto kGraphKernelDumpPath = "graph_kernel_dump"; | |||
| struct DataInfo { | |||
| std::string format{kOpFormat_DEFAULT}; | |||
| ShapeVector shape{1}; | |||
| @@ -75,12 +77,7 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n | |||
| bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc); | |||
| FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs); | |||
| std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = ""); | |||
| std::vector<PrimitivePtr> GetFusibleOpList(); | |||
| bool IsBasicFuseOp(const AnfNodePtr &node); | |||
| bool IsFusibleOp(const AnfNodePtr &node); | |||
| void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); | |||
| void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior, | |||
| const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs); | |||
| std::string GetFormat(const AnfNodePtr &node); | |||
| TypePtr GetType(const AnfNodePtr &node); | |||
| @@ -25,7 +25,7 @@ | |||
| #include "backend/optimizer/graph_kernel/add_atomic_clean.h" | |||
| #include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h" | |||
| #include "backend/optimizer/graph_kernel/arithmetic_simplify.h" | |||
| #include "backend/optimizer/graph_kernel/basic_ops_fusion.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_cluster.h" | |||
| #include "backend/optimizer/graph_kernel/eliminate_redundant_output.h" | |||
| #include "backend/optimizer/graph_kernel/tensor_promotion.h" | |||
| #include "backend/optimizer/graph_kernel/graph_kernel_splitter.h" | |||
| @@ -65,8 +65,8 @@ PassManagerPtr GraphKernelOptimizer::Cluster() const { | |||
| // Expand complex basic kernels to composite kernels | |||
| pm->AddPass(std::make_shared<GraphKernelExpander>()); | |||
| // Fuse basic kernels and composite kernels | |||
| pm->AddPass(std::make_shared<BasicOpsFusion>()); | |||
| // Cluster basic kernels and composite kernels | |||
| pm->AddPass(std::make_shared<GraphKernelCluster>()); | |||
| // Eliminate the outputs without external user | |||
| pm->AddPass(std::make_shared<EliminateRedundantOutput>()); | |||