From 87bf1ec80fb9fa06ae2714e6704d56964f5ae7f9 Mon Sep 17 00:00:00 2001
From: tronzhang
Date: Fri, 26 Mar 2021 17:09:30 +0800
Subject: [PATCH] delete mark_interface_fusion and tensor reuse frontend pass for graph kernel

---
 akg                                          |   2 +-
 .../expanders/clip_by_norm_no_div_sum.py     |   8 +-
 .../graph_kernel/model/graph_split.py        |  34 ++--
 .../akg/akg_kernel_json_generator.cc         |   4 -
 .../akg/akg_kernel_json_generator.h          |   1 -
 .../graph_kernel/graph_kernel_helper.cc      |  15 +-
 .../frontend/optimizer/graph_kernel_reuse.cc | 152 ------------------
 .../frontend/optimizer/graph_kernel_reuse.h  |  50 ------
 mindspore/ccsrc/frontend/optimizer/irpass.cc |   5 -
 mindspore/ccsrc/frontend/optimizer/irpass.h  |   3 -
 .../optimizer/irpass/mark_interface_fusion.h |  84 ----------
 mindspore/ccsrc/pipeline/jit/pass.cc         |   6 -
 12 files changed, 34 insertions(+), 330 deletions(-)
 delete mode 100644 mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.cc
 delete mode 100644 mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.h
 delete mode 100644 mindspore/ccsrc/frontend/optimizer/irpass/mark_interface_fusion.h

diff --git a/akg b/akg
index 78799c123d..1b0aacec7c 160000
--- a/akg
+++ b/akg
@@ -1 +1 @@
-Subproject commit 78799c123d966ce78ae7c9f6860264628262a16e
+Subproject commit 1b0aacec7c125083dbe9fd54174495f92f6bb191
diff --git a/mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py b/mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py
index b8138a3b34..e6c345f4c2 100644
--- a/mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py
+++ b/mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py
@@ -24,12 +24,10 @@ class ClipByNormNoDivSum(Expander):
         input_x0, input_x1, input_x2, input_x3 = self.inputs

         # cal result
-        greater_res = graph_builder.emit('Greater', [input_x0, input_x1], attrs={'fusion': 'SelectGT_000'})
-        select_res0 = graph_builder.emit('Select', [greater_res, input_x0, input_x2],
-                                         attrs={'fusion': 'SelectGT_000_end'})
+        greater_res = graph_builder.emit('Greater', [input_x0, input_x1])
+        select_res0 = graph_builder.emit('Select', [greater_res, input_x0, input_x2])
         sqrt_res = graph_builder.emit('Sqrt', [select_res0])
-        select_res1 = graph_builder.emit('Select', [greater_res, sqrt_res, input_x0],
-                                         attrs={'fusion': 'SelectGT_000_end'})
+        select_res1 = graph_builder.emit('Select', [greater_res, sqrt_res, input_x0])
         result = graph_builder.emit('Maximum', [select_res1, input_x3])

         return result
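Dropping the fusion hints leaves the expander's dataflow unchanged; numerically the emitted chain computes the following (a numpy restatement for orientation only, not the graph_builder API):

    import numpy as np

    def clip_by_norm_no_div_sum(x0, x1, x2, x3):
        """Numeric mirror of the emitted ops: Greater/Select/Sqrt/Select/Maximum."""
        greater = x0 > x1
        select0 = np.where(greater, x0, x2)    # Select(greater_res, input_x0, input_x2)
        sqrt = np.sqrt(select0)
        select1 = np.where(greater, sqrt, x0)  # Select(greater_res, sqrt_res, input_x0)
        return np.maximum(select1, x3)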
diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py
index f72a06511f..3051de7253 100644
--- a/mindspore/_extends/graph_kernel/model/graph_split.py
+++ b/mindspore/_extends/graph_kernel/model/graph_split.py
@@ -123,7 +123,7 @@ class GraphSplitByPattern:
             self.output_excluded.update(area.output_excluded)
             self.update_stitch_info(area.stitch_info)

-        def check_circle(self, to):
+        def check_acyclic(self, to):
             """Check circle. It returns false if circle exists"""
             def _reached(area, to):
                 for out, _ in area.out_relations.items():
@@ -215,6 +215,10 @@ class GraphSplitByPattern:
         with open(filename, 'w') as f:
             f.write(subgraphs_str)

+    def do_split(self):
+        """Split graph by pattern"""
+        raise Exception("do_split() is not implemented in {}".format(self.__class__.__name__))
+
     def split(self):
         """Split graph by pattern"""
         self.do_split()
@@ -270,11 +274,11 @@ class GraphSplitGpu(GraphSplitByPattern):
                 return None
            min_area, forward_fuse = None, False
            for a, _ in dom.out_relations.items():
-                if a.pattern <= PrimLib.BROADCAST and dom.check_circle(a) and \
+                if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
                        (min_area is None or a.pattern < min_area.pattern):
                    min_area = a
            for a, _ in dom.in_relations.items():
-                if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom) and \
+                if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
                        len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \
                        (min_area is None or a.pattern < min_area.pattern):
                    min_area, forward_fuse = a, True
@@ -294,7 +298,7 @@ class GraphSplitGpu(GraphSplitByPattern):
                return None
            fused = []
            for a, r in dom.in_relations.items():
-                if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom) and \
+                if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_acyclic(dom) and \
                        a.dom_op().output.shape == dom.dom_op().output.shape:
                    fused.append(a)
            return fused, True
@@ -319,7 +323,7 @@ class GraphSplitGpu(GraphSplitByPattern):
                return None
            fused = []
            for a, r in dom.out_relations.items():
-                if _broadcast_pat_exclude(dom, a, r) or not dom.check_circle(a) or \
+                if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a) or \
                        (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape):
                    return None
                fused.append(a)
@@ -341,7 +345,7 @@ class GraphSplitGpu(GraphSplitByPattern):
                    return True
            return False

-        def _reduce_pat_exclude(dom, a, r):
+        def _reduce_pat_exclude(_, a, r):
            if len(a.ops) > self.REDUCE_FUSE_DEPTH:
                return True
            if use_poly_reduce:
@@ -373,7 +377,7 @@ class GraphSplitGpu(GraphSplitByPattern):
                        _is_atomic_add_available(dom):
                    # to evade the precision problem.
                    continue
-                if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom):
+                if not _reduce_pat_exclude(dom, a, r) and a.check_acyclic(dom):
                    fused.append(a)
            return fused, True
@@ -415,7 +419,7 @@ class GraphSplitGpu(GraphSplitByPattern):
            fused = []
            for a, r in dom.out_relations.items():
                if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \
-                        dom.check_circle(a) and not dom.reduce_out_exclude(a):
+                        dom.check_acyclic(a) and not dom.reduce_out_exclude(a):
                    fused.append(a)
            return fused, False
@@ -429,7 +433,7 @@ class GraphSplitGpu(GraphSplitByPattern):
            fused = []
            for a, r in dom.out_relations.items():
-                if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_circle(a):
+                if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
                    if _reduce_nums(a.ops) < 2:
                        # softmax
                        if len(a.ops) > 4 and len(a.ops[0].inputs[0].shape) == 4:
@@ -442,7 +446,7 @@ class GraphSplitGpu(GraphSplitByPattern):
                return None
            fused = []
            for a, _ in dom.in_relations.items():
-                if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom):
+                if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom):
                    fused.append(a)
            return fused, True
@@ -490,11 +494,11 @@ class GraphSplitAscend(GraphSplitByPattern):
                return None
            min_area, forward_fuse = None, False
            for a, _ in dom.out_relations.items():
-                if a.pattern <= PrimLib.BROADCAST and dom.check_circle(a) and \
+                if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
                        (min_area is None or a.pattern < min_area.pattern):
                    min_area = a
            for a, _ in dom.in_relations.items():
-                if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom) and \
+                if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
                        len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \
                        (min_area is None or a.pattern < min_area.pattern):
                    min_area, forward_fuse = a, True
@@ -514,7 +518,7 @@ class GraphSplitAscend(GraphSplitByPattern):
                return None
            fused = []
            for a, r in dom.in_relations.items():
-                if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom) and \
+                if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_acyclic(dom) and \
                        a.dom_op().output.shape == dom.dom_op().output.shape:
                    fused.append(a)
            return fused, True
@@ -537,7 +541,7 @@ class GraphSplitAscend(GraphSplitByPattern):
                return None
            fused = []
            for a, r in dom.out_relations.items():
-                if _broadcast_pat_exclude(dom, a, r) or not dom.check_circle(a) or \
+                if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a) or \
                        (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape):
                    return None
                fused.append(a)
@@ -564,7 +568,7 @@ class GraphSplitAscend(GraphSplitByPattern):
                return None
            fused = []
            for a, r in dom.in_relations.items():
-                if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom):
+                if not _reduce_pat_exclude(dom, a, r) and a.check_acyclic(dom):
                    fused.append(a)
            return fused, True
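Every fusion rule above funnels through the renamed check_acyclic: an area may only fuse with a neighbour if no other path still connects them, otherwise the merged area would depend on itself. A minimal standalone sketch of that reachability test (hypothetical Area stand-in; the real class also carries patterns, relations and stitch info):

    class Area:
        """Hypothetical stand-in: only the out-edges matter for the check."""

        def __init__(self, name):
            self.name = name
            self.out_relations = {}  # successor Area -> relation kind

        def check_acyclic(self, to):
            """Return False if fusing `self` with `to` would close a circle."""
            def _reached(area, to):
                for out in area.out_relations:
                    if out == to or _reached(out, to):
                        return True
                return False
            # A circle appears when a successor other than `to` still reaches `to`.
            for out in self.out_relations:
                if out != to and _reached(out, to):
                    return False
            return True

    a, b, c = Area('a'), Area('b'), Area('c')
    a.out_relations = {b: 1, c: 1}   # a -> b -> c plus the direct edge a -> c
    b.out_relations = {c: 1}
    assert not a.check_acyclic(c)    # a still reaches c via b: fusing would form a cycle
    assert a.check_acyclic(b)        # direct edge only: safe to fuse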
diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc
index 7508476cb5..46b67c77e6 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc
@@ -665,10 +665,6 @@ bool AkgKernelJsonGenerator::GenSingleJsons(const std::vector<AnfNodePtr> &anf_n
     auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
     MS_EXCEPTION_IF_NULL(primitive);

-    if (primitive->GetAttr("fusion") != nullptr) {
-      node_json["fusion"] = primitive->GetAttr("fusion")->ToString();
-    }
-
     (*node_json_map)[anf_node] = node_json;
   }
   return true;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h
index 99bca58db8..fa41f97d10 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h
@@ -49,7 +49,6 @@ constexpr auto kJsonKeyPtrAddress = "ptr_address";
 constexpr auto kJsonKeyCompositeGraph = "composite_graph";
 constexpr auto kJsonKeyPlatform = "platform";
 constexpr auto kJsonKeyOpFullName = "op_full_name";
-constexpr auto kJsonKeyFusion = "fusion";
 constexpr auto kJsonKeyParallelFusion = "parallel_fusion";
 constexpr auto kJsonKeyFusionType = "fusion_type";
 constexpr auto kJsonKeySubGraph = "sub_graph";
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
index 3de37c8798..99b3650943 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
@@ -206,7 +206,7 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *i
   MS_EXCEPTION_IF_NULL(inputs_ptr);
   auto nodes = TopoSort(fg->get_return());

-  OrderedMap<ValuePtr, AnfNodePtrList> vmap;
+  std::vector<std::pair<tensor::TensorPtr, AnfNodePtrList>> v_replace;
   std::vector<ValueNodePtr> scalar_tensors;
   for (const auto &node : nodes) {
     if (!node->isa<CNode>()) {
@@ -222,14 +222,21 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *i
       if (TensorElementAllTheSame(tensor)) {
         scalar_tensors.emplace_back(tnode);
       } else {
-        vmap[GetValueNode(tnode)].push_back(tnode);
+        auto tensor_iter = std::find_if(
+          v_replace.begin(), v_replace.end(),
+          [&tensor](const std::pair<tensor::TensorPtr, AnfNodePtrList> &vl) { return vl.first->ValueEqual(*tensor); });
+        if (tensor_iter == v_replace.end()) {
+          v_replace.emplace_back(tensor, AnfNodePtrList{tnode});
+        } else {
+          tensor_iter->second.push_back(tnode);
+        }
       }
     }
   }

   ReplaceTensorWithScalar(fg, scalar_tensors);

-  if (vmap.empty()) {
+  if (v_replace.empty()) {
     return false;
   }
@@ -240,7 +247,7 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *i
   }

   auto &inputs = *inputs_ptr;
-  for (auto iter : vmap) {
+  for (auto iter : v_replace) {
     auto value_nodes = iter.second;
     if (value_nodes.empty()) {
       MS_LOG(EXCEPTION) << "Invalid value in map!";
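The graph_kernel_helper.cc hunk replaces a map keyed by ValuePtr, i.e. by pointer identity, with a linear scan that compares tensor contents via ValueEqual, so two distinct ValueNodes holding identical data now collapse into one parameter. The same grouping logic restated in Python, as an illustration only (numpy equality standing in for Tensor::ValueEqual):

    import numpy as np

    def group_by_value(pairs):
        """Group (tensor, node) pairs so that value-equal tensors share one slot."""
        groups = []  # list of [tensor, nodes] entries, searched linearly
        for tensor, node in pairs:
            for entry in groups:
                # Deep value comparison, not object identity.
                if np.array_equal(entry[0], tensor):
                    entry[1].append(node)
                    break
            else:
                groups.append([tensor, [node]])
        return groups

    t1 = np.array([1.0, 2.0])
    t2 = np.array([1.0, 2.0])          # equal in value, distinct object
    grouped = group_by_value([(t1, 'node_a'), (t2, 'node_b')])
    assert len(grouped) == 1           # keyed by identity, this would have been 2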
diff --git a/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.cc b/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.cc
deleted file mode 100644
index ce3a5d434e..0000000000
--- a/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.cc
+++ /dev/null
@@ -1,152 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "frontend/optimizer/graph_kernel_reuse.h"
-#include <algorithm>
-#include <string>
-#include <vector>
-#include "ir/graph_utils.h"
-
-namespace mindspore {
-/* namespace to support opt */
-namespace opt {
-bool GraphKernelReuse::CompareNode(const AnfNodePtr a, const AnfNodePtr b) {
-  if (a->abstract() && b->abstract()) {
-    auto a_type = a->abstract()->GetTypeTrack();
-    auto b_type = b->abstract()->GetTypeTrack();
-    if (a_type != b_type) {
-      return false;
-    }
-
-    auto a_shape = a->abstract()->GetShapeTrack();
-    auto b_shape = b->abstract()->GetShapeTrack();
-    if (a_shape != nullptr && a_shape == b_shape) {
-      return true;
-    }
-
-    if (a_shape != nullptr && b_shape != nullptr && a_shape->isa<abstract::Shape>() &&
-        b_shape->isa<abstract::Shape>()) {
-      return a_shape->cast<abstract::ShapePtr>()->shape() == b_shape->cast<abstract::ShapePtr>()->shape();
-    }
-  }
-  return false;
-}
-
-bool GraphKernelReuse::DoReplace(const FuncGraphManagerPtr manager) {
-  bool changed = false;
-  auto fgs = manager->func_graphs();
-  for (FuncGraphPtr &fg : fgs) {
-    if (!fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
-      continue;
-    }
-    std::string key = GetValue<std::string>(fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
-    if (graph_kernel_ops.find(key) != graph_kernel_ops.end()) {
-      if (find(graph_kernel_ops[key].begin(), graph_kernel_ops[key].end(), fg) == graph_kernel_ops[key].end()) {
-        FuncGraphPtr new_fg = nullptr;
-        for (auto &cfg : graph_kernel_ops[key]) {
-          // If two graphs have different size then continue
-          auto fg_topos = TopoSort(fg->get_return());
-          auto cfg_topos = TopoSort(cfg->get_return());
-          if (fg_topos.size() != cfg_topos.size()) {
-            continue;
-          }
-
-          // Compare const tensor
-          bool has_same = true;
-          for (size_t i = 0; i < fg_topos.size(); ++i) {
-            if (IsValueNode<tensor::Tensor>(fg_topos[i])) {
-              if (!IsValueNode<tensor::Tensor>(cfg_topos[i])) {
-                has_same = false;
-                break;
-              }
-
-              auto tensor1 = GetValueNode<tensor::TensorPtr>(fg_topos[i]);
-              auto tensor2 = GetValueNode<tensor::TensorPtr>(cfg_topos[i]);
-              if (!tensor1->ValueEqual(*tensor2)) {
-                has_same = false;
-                break;
-              }
-            }
-          }
-
-          if (!has_same) {
-            continue;
-          }
-
-          auto fg_input = fg->parameters();
-          auto cfg_input = cfg->parameters();
-          if (fg_input.size() != cfg_input.size()) {
-            continue;
-          }
-          // Compare input
-          for (size_t i = 0; i < fg_input.size(); ++i) {
-            if (!CompareNode(fg_input[i], cfg_input[i])) {
-              has_same = false;
-              break;
-            }
-          }
-          if (!has_same) {
-            continue;
-          }
-
-          // Compare output
-          if (!CompareNode(fg->output(), cfg->output())) {
-            continue;
-          }
-
-          // Find reusable fg
-          new_fg = cfg;
-          break;
-        }
-
-        if (new_fg != nullptr) {
-          // Replace current fg with existing fg
-          auto users = fg->func_graph_cnodes_index();
-          for (auto &iter : users) {
-            auto cnode = iter.first->first->cast<CNodePtr>();
-            auto new_input = cnode->inputs();
-            auto main_graph = cnode->func_graph();
-            MS_EXCEPTION_IF_NULL(main_graph);
-            if (IsPrimitiveCNode(cnode, prim::kPrimPartial)) {
-              new_input[1] = NewValueNode(new_fg);
-            } else {
-              new_input[0] = NewValueNode(new_fg);
-            }
-            auto new_cnode = main_graph->NewCNode(new_input);
-            manager->Replace(iter.first->first, new_cnode);
-            changed = true;
-          }
-        } else {
-          // Add current fg to map
-          graph_kernel_ops[key].push_back(fg);
-        }
-      }
-    } else {
-      graph_kernel_ops[key] = {fg};
-    }
-  }
-
-  return changed;
-}
-
-bool GraphKernelReuse::ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager) {
-  MS_EXCEPTION_IF_NULL(manager);
-  manager->AddFuncGraph(root);
-
-  return DoReplace(manager);
-}
-}  // namespace opt
-}  // namespace mindspore
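The deleted pass deduplicated graph-kernel subgraphs by structural comparison: equal topological size, value-equal constant tensors at the same positions, and matching parameter and output types/shapes. Condensed to a Python skeleton for orientation (hypothetical graph objects and predicates, not the removed C++ itself):

    def find_reusable(fg, candidates, node_equal, io_equal):
        """Return an earlier graph structurally equal to `fg`, or None."""
        topo = fg.topo_sort()
        for cfg in candidates:
            ctopo = cfg.topo_sort()
            if len(topo) != len(ctopo):
                continue              # different node count: cannot match
            if not all(node_equal(x, y) for x, y in zip(topo, ctopo)):
                continue              # some constant tensor differs
            if not io_equal(fg, cfg):
                continue              # parameter or output type/shape differs
            return cfg                # reuse the earlier, equal graph
        return None

A caller would then redirect every user of fg to the returned graph, which is what DoReplace did through manager->Replace.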
deleted file mode 100644
index 276927a5dc..0000000000
--- a/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.h
+++ /dev/null
@@ -1,50 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H
-#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H
-
-#include <string>
-#include <unordered_map>
-#include <vector>
-#include "mindspore/ccsrc/backend/session/anf_runtime_algorithm.h"
-#include "frontend/optimizer/optimizer.h"
-
-namespace mindspore {
-namespace opt {
-// Common subexpression elimination.
-class GraphKernelReuse {
- public:
-  GraphKernelReuse() : count(0) {}
-  virtual ~GraphKernelReuse() = default;
-
-  bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) {
-    bool chg = ReuseGraphKernel(root, optimizer->resource()->manager());
-    return chg;
-  }
-
-  bool CompareNode(const AnfNodePtr a, const AnfNodePtr other);
-  bool DoReplace(const FuncGraphManagerPtr manager);
-
-  bool ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager);
-
- private:
-  std::unordered_map<std::string, std::vector<FuncGraphPtr>> graph_kernel_ops;
-  int64_t count;
-};
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc
index c866d83f25..db5b26de5a 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass.cc
+++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc
@@ -29,7 +29,6 @@
 #include "frontend/optimizer/irpass/incorporate_call.h"
 #include "frontend/optimizer/irpass/incorporate_getitem.h"
 #include "frontend/optimizer/irpass/item_tuple_or_list_eliminate.h"
-#include "frontend/optimizer/irpass/mark_interface_fusion.h"
 #include "frontend/optimizer/irpass/merge_addn.h"
 #include "frontend/optimizer/irpass/accumulaten_eliminate.h"
 #include "frontend/optimizer/irpass/minmax_grad.h"
@@ -198,10 +197,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   // AddN eliminate
   addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel);

-  // Mark interface fusion
-  mark_interface_fusion_ =
-    MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect);
-
   // RowTensor Eliminate
   row_tensor_eliminate_ = MakeSubstitution(
     std::make_shared<RowTensorEliminater>(), "row_tensor_eliminate",
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h
index 9ad4ab9ec5..7be50cb9eb 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass.h
@@ -124,9 +124,6 @@ class OptimizeIRPassLib {
   // AddN eliminate
   SubstitutionPtr addn_eliminate_;

-  // Fusion
-  SubstitutionPtr mark_interface_fusion_;
-
   // RowTensor Eliminate
   SubstitutionPtr row_tensor_eliminate_;

diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/mark_interface_fusion.h b/mindspore/ccsrc/frontend/optimizer/irpass/mark_interface_fusion.h
deleted file mode 100644
index 1f85f53c64..0000000000
--- a/mindspore/ccsrc/frontend/optimizer/irpass/mark_interface_fusion.h
+++ /dev/null
@@ -1,84 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H
-#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H
-
-#include <sstream>
-#include <string>
-#include <unordered_map>
-
-#include "backend/session/anf_runtime_algorithm.h"
-#include "frontend/optimizer/optimizer.h"
-#include "frontend/optimizer/irpass.h"
-#include "frontend/optimizer/anf_visitor.h"
-#include "frontend/operator/ops.h"
-#include "ir/graph_utils.h"
-#include "frontend/operator/composite/composite.h"
-
-namespace mindspore {
-namespace opt {
-namespace irpass {
-static int64_t count = 0;
-
-std::string GetFusionNumber() {
-  std::stringstream ss;
-  ss << std::setw(4) << std::setfill('0') << count;
-  std::string num = ss.str();
-  ++count;
-
-  return "_" + num;
-}
-
-// Mark CNodes which can be merged in kernel build
-class MarkInterfaceFusion : public AnfVisitor {
- public:
-  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
-    if (node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsPrimitiveCNode(node, prim::kPrimSelect)) {
-      auto cnode = node->cast<CNodePtr>();
-      auto condition = cnode->input(1);
-      std::string cmp;
-      std::unordered_map<std::string, std::string> cmp_list = {{"GreaterEqual", "GE"}, {"Greater", "GT"},
-                                                               {"LessEqual", "LE"},    {"Less", "LT"},
-                                                               {"Equal", "EQ"},        {"NotEqual", "NE"}};
-      if (IsPrimitiveCNode(condition)) {
-        auto prim_name = GetCNodeFuncName(condition->cast<CNodePtr>());
-        if (cmp_list.count(prim_name) != 0) {
-          // Mark Select and compare node
-          cmp = cmp_list[prim_name];
-          auto cnt = GetFusionNumber();
-          AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), condition);
-          AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt + "_end"), node);
-          for (size_t i = 1; i < cnode->inputs().size(); ++i) {
-            if (IsPrimitiveCNode(cnode->input(i), prim::kPrimZerosLike)) {
-              AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), cnode->input(i));
-            }
-          }
-        }
-      }
-    }
-    return nullptr;
-  }
-
-  void Visit(const AnfNodePtr &) override {}
-
- private:
-  AnfNodePtr y_{nullptr};
-};
-}  // namespace irpass
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H
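For orientation: the deleted pass tagged each Select, its comparison input, and any ZerosLike operand with matching fusion names that the AKG JSON generator forwarded to the kernel builder (hence the kJsonKeyFusion removal above). Its naming scheme, restated in Python (hypothetical helper mirroring GetFusionNumber):

    import itertools

    _counter = itertools.count()

    def fusion_tags(cmp_name):
        """Return the tag pair for a compare/Select chain, e.g. Greater feeding a Select."""
        cmp_map = {'GreaterEqual': 'GE', 'Greater': 'GT', 'LessEqual': 'LE',
                   'Less': 'LT', 'Equal': 'EQ', 'NotEqual': 'NE'}
        num = '_{:04d}'.format(next(_counter))  # zero-padded, like std::setw(4)/setfill('0')
        base = 'Select' + cmp_map[cmp_name] + num
        return base, base + '_end'              # tag for the compare node, tag for the Select

    assert fusion_tags('Greater') == ('SelectGT_0000', 'SelectGT_0000_end')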
diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc
index 8e1f489758..7184cbd471 100644
--- a/mindspore/ccsrc/pipeline/jit/pass.cc
+++ b/mindspore/ccsrc/pipeline/jit/pass.cc
@@ -29,7 +29,6 @@
 #include "pipeline/jit/remove_value_node_dup.h"
 #include "frontend/optimizer/optimizer.h"
 #include "frontend/optimizer/cse_pass.h"
-#include "frontend/optimizer/graph_kernel_reuse.h"
 #include "frontend/optimizer/clean.h"
 #include "frontend/optimizer/irpass.h"
 #include "frontend/optimizer/graph_transform.h"
@@ -261,12 +260,7 @@ OptPassGroupMap GetOptPassesPynativeElim(const opt::irpass::OptimizeIRPassLib &i
 }

 OptPassGroupMap GetOptPassesGraphKernelA(const opt::irpass::OptimizeIRPassLib &irpass) {
-  opt::OptPassConfig interface_fusion = opt::OptPassConfig({
-    irpass.mark_interface_fusion_,
-  });
   OptPassGroupMap map({
-    {"graph_kernel_reuse", opt::OptPassConfig(opt::GraphKernelReuse())},
-    {"interface_fusion", interface_fusion},
     {"renormalize", opt::OptPassConfig::Renormalize()},
     {"cse", opt::OptPassConfig(opt::CSEPass(false))},
   });