From: @tronzhang
Reviewed-by: @gaoxiong1, @dylangeng
Signed-off-by: @dylangeng
pull/13834/MERGE
@@ -1 +1 @@
-Subproject commit 78799c123d966ce78ae7c9f6860264628262a16e
+Subproject commit 1b0aacec7c125083dbe9fd54174495f92f6bb191
@@ -24,12 +24,10 @@ class ClipByNormNoDivSum(Expander):
         input_x0, input_x1, input_x2, input_x3 = self.inputs

         # cal result
-        greater_res = graph_builder.emit('Greater', [input_x0, input_x1], attrs={'fusion': 'SelectGT_000'})
-        select_res0 = graph_builder.emit('Select', [greater_res, input_x0, input_x2],
-                                         attrs={'fusion': 'SelectGT_000_end'})
+        greater_res = graph_builder.emit('Greater', [input_x0, input_x1])
+        select_res0 = graph_builder.emit('Select', [greater_res, input_x0, input_x2])
         sqrt_res = graph_builder.emit('Sqrt', [select_res0])
-        select_res1 = graph_builder.emit('Select', [greater_res, sqrt_res, input_x0],
-                                         attrs={'fusion': 'SelectGT_000_end'})
+        select_res1 = graph_builder.emit('Select', [greater_res, sqrt_res, input_x0])
         result = graph_builder.emit('Maximum', [select_res1, input_x3])
         return result
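For orientation, the expander above builds the dataflow max(where(x0 > x1, sqrt(where(x0 > x1, x0, x2)), x0), x3); the hunk only drops the fusion tags, not the math. A minimal NumPy sketch of the same chain (names are illustrative, not part of the patch):

import numpy as np

def clip_by_norm_no_div_sum(x0, x1, x2, x3):
    # Mirrors the emitted Greater -> Select -> Sqrt -> Select -> Maximum chain.
    greater_res = np.greater(x0, x1)
    select_res0 = np.where(greater_res, x0, x2)
    sqrt_res = np.sqrt(select_res0)
    select_res1 = np.where(greater_res, sqrt_res, x0)
    return np.maximum(select_res1, x3)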
@@ -123,7 +123,7 @@ class GraphSplitByPattern:
             self.output_excluded.update(area.output_excluded)
             self.update_stitch_info(area.stitch_info)

-        def check_circle(self, to):
+        def check_acyclic(self, to):
             """Check circle. Returns False if a circle exists."""
             def _reached(area, to):
                 for out, _ in area.out_relations.items():
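The rename matches what the predicate actually computes: it answers "is this fusion acyclic?", returning True when fusing is safe. A simplified sketch of the reachability walk behind it, assuming each area exposes an out_relations mapping as in the hunk above (the visited set is an addition to keep the sketch terminating on shared paths):

def check_acyclic(area, to):
    """Return False if fusing `area` with `to` would close a circle."""
    def _reached(cur, target, visited):
        for out in cur.out_relations:
            if out is target:
                return True
            if out not in visited:
                visited.add(out)
                if _reached(out, target, visited):
                    return True
        return False

    # A circle exists when `to` is also reachable from `area` through
    # some successor other than `to` itself.
    for out in area.out_relations:
        if out is not to and _reached(out, to, set()):
            return False
    return True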
@@ -215,6 +215,10 @@ class GraphSplitByPattern:
         with open(filename, 'w') as f:
             f.write(subgraphs_str)

+    def do_split(self):
+        """Split graph by pattern"""
+        raise Exception("do_split() is not implemented in {}".format(self.__class__.__name__))
     def split(self):
         """Split graph by pattern"""
+        self.do_split()
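This turns split() into a template method: the base class owns the entry point while the backend subclasses (GraphSplitGpu and GraphSplitAscend below) supply do_split(). A minimal sketch of the pattern as used here:

class GraphSplitByPattern:
    def do_split(self):
        # Backend-specific splitting; must be overridden.
        raise Exception("do_split() is not implemented in {}".format(self.__class__.__name__))

    def split(self):
        """Split graph by pattern."""
        self.do_split()


class GraphSplitGpu(GraphSplitByPattern):
    def do_split(self):
        # GPU fusion strategies would run here.
        print("splitting with GPU rules")


GraphSplitGpu().split()  # -> splitting with GPU rules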
@@ -270,11 +274,11 @@ class GraphSplitGpu(GraphSplitByPattern):
                 return None
             min_area, forward_fuse = None, False
             for a, _ in dom.out_relations.items():
-                if a.pattern <= PrimLib.BROADCAST and dom.check_circle(a) and \
+                if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
                         (min_area is None or a.pattern < min_area.pattern):
                     min_area = a
             for a, _ in dom.in_relations.items():
-                if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom) and \
+                if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
                         len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \
                         (min_area is None or a.pattern < min_area.pattern):
                     min_area, forward_fuse = a, True
@@ -294,7 +298,7 @@ class GraphSplitGpu(GraphSplitByPattern):
                 return None
             fused = []
             for a, r in dom.in_relations.items():
-                if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom) and \
+                if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_acyclic(dom) and \
                         a.dom_op().output.shape == dom.dom_op().output.shape:
                     fused.append(a)
             return fused, True
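All of these _xxx_fuse helpers share one protocol: return None when the dominant area does not match the rule, otherwise return (areas, forward), where the boolean marks the fusion direction (in the hunks shown, True goes with areas collected from in_relations). A sketch of a driver consuming that protocol (apply_fuse_rule and the fuse() calls are hypothetical stand-ins, not the patch's API):

def apply_fuse_rule(dom, rule):
    # `rule` is one of the _xxx_fuse helpers: returns None or (areas, forward).
    result = rule(dom)
    if result is None:
        return False
    areas, forward = result
    for a in areas:
        if forward:
            a.fuse(dom)   # fold dom into the input-side area
        else:
            dom.fuse(a)   # fold the output-side area into dom
    return True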
@@ -319,7 +323,7 @@ class GraphSplitGpu(GraphSplitByPattern):
                 return None
             fused = []
             for a, r in dom.out_relations.items():
-                if _broadcast_pat_exclude(dom, a, r) or not dom.check_circle(a) or \
+                if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a) or \
                         (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape):
                     return None
                 fused.append(a)
@@ -341,7 +345,7 @@ class GraphSplitGpu(GraphSplitByPattern):
                     return True
                 return False

-            def _reduce_pat_exclude(dom, a, r):
+            def _reduce_pat_exclude(_, a, r):
                 if len(a.ops) > self.REDUCE_FUSE_DEPTH:
                     return True
                 if use_poly_reduce:
@@ -373,7 +377,7 @@ class GraphSplitGpu(GraphSplitByPattern):
                         _is_atomic_add_available(dom):
                     # to evade the precision problem.
                     continue
-                if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom):
+                if not _reduce_pat_exclude(dom, a, r) and a.check_acyclic(dom):
                     fused.append(a)
             return fused, True
@@ -415,7 +419,7 @@ class GraphSplitGpu(GraphSplitByPattern):
             fused = []
             for a, r in dom.out_relations.items():
                 if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \
-                        dom.check_circle(a) and not dom.reduce_out_exclude(a):
+                        dom.check_acyclic(a) and not dom.reduce_out_exclude(a):
                     fused.append(a)
             return fused, False
@@ -429,7 +433,7 @@ class GraphSplitGpu(GraphSplitByPattern):
             fused = []
             for a, r in dom.out_relations.items():
-                if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_circle(a):
+                if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
                     if _reduce_nums(a.ops) < 2:
                         # softmax
                         if len(a.ops) > 4 and len(a.ops[0].inputs[0].shape) == 4:
@@ -442,7 +446,7 @@ class GraphSplitGpu(GraphSplitByPattern):
                 return None
             fused = []
             for a, _ in dom.in_relations.items():
-                if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom):
+                if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom):
                     fused.append(a)
             return fused, True
@@ -490,11 +494,11 @@ class GraphSplitAscend(GraphSplitByPattern):
                 return None
             min_area, forward_fuse = None, False
             for a, _ in dom.out_relations.items():
-                if a.pattern <= PrimLib.BROADCAST and dom.check_circle(a) and \
+                if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
                         (min_area is None or a.pattern < min_area.pattern):
                     min_area = a
             for a, _ in dom.in_relations.items():
-                if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom) and \
+                if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
                         len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \
                         (min_area is None or a.pattern < min_area.pattern):
                     min_area, forward_fuse = a, True
@@ -514,7 +518,7 @@ class GraphSplitAscend(GraphSplitByPattern):
                 return None
             fused = []
             for a, r in dom.in_relations.items():
-                if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom) and \
+                if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_acyclic(dom) and \
                         a.dom_op().output.shape == dom.dom_op().output.shape:
                     fused.append(a)
             return fused, True
@@ -537,7 +541,7 @@ class GraphSplitAscend(GraphSplitByPattern):
                 return None
             fused = []
             for a, r in dom.out_relations.items():
-                if _broadcast_pat_exclude(dom, a, r) or not dom.check_circle(a) or \
+                if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a) or \
                         (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape):
                     return None
                 fused.append(a)
@@ -564,7 +568,7 @@ class GraphSplitAscend(GraphSplitByPattern):
                 return None
             fused = []
             for a, r in dom.in_relations.items():
-                if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom):
+                if not _reduce_pat_exclude(dom, a, r) and a.check_acyclic(dom):
                     fused.append(a)
             return fused, True
@@ -665,10 +665,6 @@ bool AkgKernelJsonGenerator::GenSingleJsons(const std::vector<AnfNodePtr> &anf_n
     auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
     MS_EXCEPTION_IF_NULL(primitive);
-    if (primitive->GetAttr("fusion") != nullptr) {
-      node_json["fusion"] = primitive->GetAttr("fusion")->ToString();
-    }
     (*node_json_map)[anf_node] = node_json;
   }
   return true;
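With the MarkInterfaceFusion pass gone (deleted below), no primitive carries a "fusion" attribute any more, so the generator stops forwarding it into the kernel JSON. The removed branch was a plain conditional attribute copy; in dict terms (a Python stand-in for the nlohmann::json and Primitive APIs, not the real signatures):

def attach_fusion_tag(prim_attrs, node_json):
    # Copy the optional 'fusion' tag only when the primitive defines one.
    if prim_attrs.get('fusion') is not None:
        node_json['fusion'] = str(prim_attrs['fusion'])
    return node_json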
@@ -49,7 +49,6 @@ constexpr auto kJsonKeyPtrAddress = "ptr_address";
 constexpr auto kJsonKeyCompositeGraph = "composite_graph";
 constexpr auto kJsonKeyPlatform = "platform";
 constexpr auto kJsonKeyOpFullName = "op_full_name";
-constexpr auto kJsonKeyFusion = "fusion";
 constexpr auto kJsonKeyParallelFusion = "parallel_fusion";
 constexpr auto kJsonKeyFusionType = "fusion_type";
 constexpr auto kJsonKeySubGraph = "sub_graph";
@@ -207,7 +207,7 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *i
   MS_EXCEPTION_IF_NULL(inputs_ptr);
   auto nodes = TopoSort(fg->get_return());
-  OrderedMap<ValuePtr, AnfNodePtrList> vmap;
+  std::vector<std::pair<tensor::TensorPtr, AnfNodePtrList>> v_replace;
   std::vector<AnfNodePtr> scalar_tensors;
   for (const auto &node : nodes) {
     if (!node->isa<CNode>()) {
@@ -223,14 +223,21 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *i
       if (TensorElementAllTheSame(tensor)) {
         scalar_tensors.emplace_back(tnode);
       } else {
-        vmap[GetValueNode(tnode)].push_back(tnode);
+        auto tensor_iter = std::find_if(
+          v_replace.begin(), v_replace.end(),
+          [&tensor](const std::pair<tensor::TensorPtr, AnfNodePtrList> &vl) { return vl.first->ValueEqual(*tensor); });
+        if (tensor_iter == v_replace.end()) {
+          v_replace.emplace_back(tensor, AnfNodePtrList{tnode});
+        } else {
+          tensor_iter->second.push_back(tnode);
+        }
       }
     }
   }
   ReplaceTensorWithScalar(fg, scalar_tensors);

-  if (vmap.empty()) {
+  if (v_replace.empty()) {
     return false;
   }
@@ -241,7 +248,7 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *i
   }

   auto &inputs = *inputs_ptr;
-  for (auto iter : vmap) {
+  for (auto iter : v_replace) {
     auto value_nodes = iter.second;
     if (value_nodes.empty()) {
       MS_LOG(EXCEPTION) << "Invalid value in map!";
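The data-structure change here is deliberate: keying on ValuePtr deduplicated tensors by pointer identity, so two ValueNodes holding equal tensors got two parameters; the vector plus std::find_if with ValueEqual merges them by content instead. The same grouping logic in a minimal Python sketch (illustrative types, not a port of the C++):

def group_value_nodes(tensor_node_pairs):
    """Group nodes whose tensors are equal by value, not by identity."""
    v_replace = []  # list of (representative_tensor, [nodes]) pairs
    for tensor, node in tensor_node_pairs:
        for rep, nodes in v_replace:
            if rep == tensor:  # value equality, like tensor::Tensor::ValueEqual
                nodes.append(node)
                break
        else:
            v_replace.append((tensor, [node]))
    return v_replace

# Two distinct-but-equal tensors end up in one group:
groups = group_value_nodes([((1, 2), 'n1'), ((1, 2), 'n2'), ((3,), 'n3')])
assert len(groups) == 2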
@@ -1,152 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "frontend/optimizer/graph_kernel_reuse.h"
-#include <vector>
-#include <algorithm>
-#include <string>
-#include "ir/graph_utils.h"
-
-namespace mindspore {
-/* namespace to support opt */
-namespace opt {
-
-bool GraphKernelReuse::CompareNode(const AnfNodePtr a, const AnfNodePtr b) {
-  if (a->abstract() && b->abstract()) {
-    auto a_type = a->abstract()->GetTypeTrack();
-    auto b_type = b->abstract()->GetTypeTrack();
-    if (a_type != b_type) {
-      return false;
-    }
-    auto a_shape = a->abstract()->GetShapeTrack();
-    auto b_shape = b->abstract()->GetShapeTrack();
-    if (a_shape != nullptr && a_shape == b_shape) {
-      return true;
-    }
-    if (a_shape != nullptr && b_shape != nullptr && a_shape->isa<abstract::Shape>() &&
-        b_shape->isa<abstract::Shape>()) {
-      return a_shape->cast<abstract::ShapePtr>()->shape() == b_shape->cast<abstract::ShapePtr>()->shape();
-    }
-  }
-  return false;
-}
-
-bool GraphKernelReuse::DoReplace(const FuncGraphManagerPtr manager) {
-  bool changed = false;
-  auto fgs = manager->func_graphs();
-  for (FuncGraphPtr &fg : fgs) {
-    if (!fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
-      continue;
-    }
-    std::string key = GetValue<std::string>(fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
-    if (graph_kernel_ops.find(key) != graph_kernel_ops.end()) {
-      if (find(graph_kernel_ops[key].begin(), graph_kernel_ops[key].end(), fg) == graph_kernel_ops[key].end()) {
-        FuncGraphPtr new_fg = nullptr;
-        for (auto &cfg : graph_kernel_ops[key]) {
-          // If two graphs have different size then continue
-          auto fg_topos = TopoSort(fg->get_return());
-          auto cfg_topos = TopoSort(cfg->get_return());
-          if (fg_topos.size() != cfg_topos.size()) {
-            continue;
-          }
-          // Compare const tensor
-          bool has_same = true;
-          for (size_t i = 0; i < fg_topos.size(); ++i) {
-            if (IsValueNode<tensor::Tensor>(fg_topos[i])) {
-              if (!IsValueNode<tensor::Tensor>(cfg_topos[i])) {
-                has_same = false;
-                break;
-              }
-              auto tensor1 = GetValueNode<tensor::TensorPtr>(fg_topos[i]);
-              auto tensor2 = GetValueNode<tensor::TensorPtr>(cfg_topos[i]);
-              if (!tensor1->ValueEqual(*tensor2)) {
-                has_same = false;
-                break;
-              }
-            }
-          }
-          if (!has_same) {
-            continue;
-          }
-          auto fg_input = fg->parameters();
-          auto cfg_input = cfg->parameters();
-          if (fg_input.size() != cfg_input.size()) {
-            continue;
-          }
-          // Compare input
-          for (size_t i = 0; i < fg_input.size(); ++i) {
-            if (!CompareNode(fg_input[i], cfg_input[i])) {
-              has_same = false;
-              break;
-            }
-          }
-          if (!has_same) {
-            continue;
-          }
-          // Compare output
-          if (!CompareNode(fg->output(), cfg->output())) {
-            continue;
-          }
-          // Find reusable fg
-          new_fg = cfg;
-          break;
-        }
-        if (new_fg != nullptr) {
-          // Replace current fg with existing fg
-          auto users = fg->func_graph_cnodes_index();
-          for (auto &iter : users) {
-            auto cnode = iter.first->first->cast<CNodePtr>();
-            auto new_input = cnode->inputs();
-            auto main_graph = cnode->func_graph();
-            MS_EXCEPTION_IF_NULL(main_graph);
-            if (IsPrimitiveCNode(cnode, prim::kPrimPartial)) {
-              new_input[1] = NewValueNode(new_fg);
-            } else {
-              new_input[0] = NewValueNode(new_fg);
-            }
-            auto new_cnode = main_graph->NewCNode(new_input);
-            manager->Replace(iter.first->first, new_cnode);
-            changed = true;
-          }
-        } else {
-          // Add current fg to map
-          graph_kernel_ops[key].push_back(fg);
-        }
-      }
-    } else {
-      graph_kernel_ops[key] = {fg};
-    }
-  }
-
-  return changed;
-}
-
-bool GraphKernelReuse::ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager) {
-  MS_EXCEPTION_IF_NULL(manager);
-  manager->AddFuncGraph(root);
-
-  return DoReplace(manager);
-}
-
-}  // namespace opt
-}  // namespace mindspore
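The deleted pass reused a graph-kernel subgraph whenever an earlier one under the same key was structurally identical: equal topological size, value-equal constant tensors at the same positions, and matching types/shapes for inputs and output. Compressed into a Python sketch (a simplified stand-in for DoReplace's comparison, not a faithful port):

def same_signature(a, b):
    # CompareNode analogue: same dtype and same static shape.
    return a.dtype == b.dtype and a.shape == b.shape

def graphs_reusable(fg, cfg):
    if len(fg.topo_nodes) != len(cfg.topo_nodes):
        return False
    for n1, n2 in zip(fg.topo_nodes, cfg.topo_nodes):
        t1 = getattr(n1, 'const_tensor', None)
        t2 = getattr(n2, 'const_tensor', None)
        if (t1 is None) != (t2 is None) or (t1 is not None and t1 != t2):
            return False  # constants must match by value and position
    if len(fg.inputs) != len(cfg.inputs):
        return False
    return (all(same_signature(a, b) for a, b in zip(fg.inputs, cfg.inputs))
            and same_signature(fg.output, cfg.output))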
@@ -1,50 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H
-#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H
-
-#include <unordered_map>
-#include <string>
-#include <vector>
-#include "mindspore/ccsrc/backend/session/anf_runtime_algorithm.h"
-#include "frontend/optimizer/optimizer.h"
-
-namespace mindspore {
-namespace opt {
-
-// Common subexpression elimination.
-class GraphKernelReuse {
- public:
-  GraphKernelReuse() : count(0) {}
-  virtual ~GraphKernelReuse() = default;
-
-  bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) {
-    bool chg = ReuseGraphKernel(root, optimizer->resource()->manager());
-    return chg;
-  }
-
-  bool CompareNode(const AnfNodePtr a, const AnfNodePtr other);
-  bool DoReplace(const FuncGraphManagerPtr manager);
-  bool ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager);
-
- private:
-  std::unordered_map<std::string, std::vector<FuncGraphPtr>> graph_kernel_ops;
-  int64_t count;
-};
-
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H
@@ -29,7 +29,6 @@
 #include "frontend/optimizer/irpass/incorporate_call.h"
 #include "frontend/optimizer/irpass/incorporate_getitem.h"
 #include "frontend/optimizer/irpass/item_tuple_or_list_eliminate.h"
-#include "frontend/optimizer/irpass/mark_interface_fusion.h"
 #include "frontend/optimizer/irpass/merge_addn.h"
 #include "frontend/optimizer/irpass/accumulaten_eliminate.h"
 #include "frontend/optimizer/irpass/minmax_grad.h"
@@ -198,10 +197,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   // AddN eliminate
   addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel);

-  // Mark interface fusion
-  mark_interface_fusion_ =
-    MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect);
-
   // RowTensor Eliminate
   row_tensor_eliminate_ = MakeSubstitution(
     std::make_shared<RowTensorEliminater>(), "row_tensor_eliminate",
@@ -124,9 +124,6 @@ class OptimizeIRPassLib {
   // AddN eliminate
   SubstitutionPtr addn_eliminate_;

-  // Fusion
-  SubstitutionPtr mark_interface_fusion_;
-
   // RowTensor Eliminate
   SubstitutionPtr row_tensor_eliminate_;
@@ -1,84 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H
-#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H
-
-#include <string>
-#include <sstream>
-#include <unordered_map>
-
-#include "backend/session/anf_runtime_algorithm.h"
-#include "frontend/optimizer/optimizer.h"
-#include "frontend/optimizer/irpass.h"
-#include "frontend/optimizer/anf_visitor.h"
-#include "frontend/operator/ops.h"
-#include "ir/graph_utils.h"
-#include "frontend/operator/composite/composite.h"
-
-namespace mindspore {
-namespace opt {
-namespace irpass {
-static int64_t count = 0;
-
-std::string GetFusionNumber() {
-  std::stringstream ss;
-  ss << std::setw(4) << std::setfill('0') << count;
-  std::string num = ss.str();
-  ++count;
-
-  return "_" + num;
-}
-
-// Mark CNodes which can be merged in kernel build
-class MarkInterfaceFusion : public AnfVisitor {
- public:
-  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
-    if (node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsPrimitiveCNode(node, prim::kPrimSelect)) {
-      auto cnode = node->cast<CNodePtr>();
-      auto condition = cnode->input(1);
-      std::string cmp;
-      std::unordered_map<std::string, std::string> cmp_list = {{"GreaterEqual", "GE"}, {"Greater", "GT"},
-                                                               {"LessEqual", "LE"},    {"Less", "LT"},
-                                                               {"Equal", "EQ"},        {"NotEqual", "NE"}};
-      if (IsPrimitiveCNode(condition)) {
-        auto prim_name = GetCNodeFuncName(condition->cast<CNodePtr>());
-        if (cmp_list.count(prim_name) != 0) {
-          // Mark Select and compare node
-          cmp = cmp_list[prim_name];
-          auto cnt = GetFusionNumber();
-          AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), condition);
-          AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt + "_end"), node);
-          for (size_t i = 1; i < cnode->inputs().size(); ++i) {
-            if (IsPrimitiveCNode(cnode->input(i), prim::kPrimZerosLike)) {
-              AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), cnode->input(i));
-            }
-          }
-        }
-      }
-    }
-    return nullptr;
-  }
-
-  void Visit(const AnfNodePtr &) override {}
-
- private:
-  AnfNodePtr y_{nullptr};
-};
-}  // namespace irpass
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H
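For context on the expander hunk at the top of this patch: this pass produced the 'Select<CMP>_<counter>' tags (with '_end' marking the consumer side) that ClipByNormNoDivSum matched against, which is why producer and consumer are removed together. Its naming scheme, as a small sketch:

import itertools

_counter = itertools.count()

def fusion_tag(cmp_prim_name):
    # Comparison op -> short code, plus a zero-padded counter,
    # mirroring GetFusionNumber() above.
    cmp_map = {'GreaterEqual': 'GE', 'Greater': 'GT', 'LessEqual': 'LE',
               'Less': 'LT', 'Equal': 'EQ', 'NotEqual': 'NE'}
    return 'Select{}_{:04d}'.format(cmp_map[cmp_prim_name], next(_counter))

print(fusion_tag('Greater'))  # -> 'SelectGT_0000'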
@@ -29,7 +29,6 @@
 #include "pipeline/jit/remove_value_node_dup.h"
 #include "frontend/optimizer/optimizer.h"
 #include "frontend/optimizer/cse_pass.h"
-#include "frontend/optimizer/graph_kernel_reuse.h"
 #include "frontend/optimizer/clean.h"
 #include "frontend/optimizer/irpass.h"
 #include "frontend/optimizer/graph_transform.h"
@@ -261,12 +260,7 @@ OptPassGroupMap GetOptPassesPynativeElim(const opt::irpass::OptimizeIRPassLib &i
 }

 OptPassGroupMap GetOptPassesGraphKernelA(const opt::irpass::OptimizeIRPassLib &irpass) {
-  opt::OptPassConfig interface_fusion = opt::OptPassConfig({
-    irpass.mark_interface_fusion_,
-  });
   OptPassGroupMap map({
-    {"graph_kernel_reuse", opt::OptPassConfig(opt::GraphKernelReuse())},
-    {"interface_fusion", interface_fusion},
     {"renormalize", opt::OptPassConfig::Renormalize()},
     {"cse", opt::OptPassConfig(opt::CSEPass(false))},
   });