From: @tronzhang Reviewed-by: @gaoxiong1,@gaoxiong1,@dylangeng Signed-off-by: @dylangengpull/13834/MERGE
| @@ -1 +1 @@ | |||||
| Subproject commit 78799c123d966ce78ae7c9f6860264628262a16e | |||||
| Subproject commit 1b0aacec7c125083dbe9fd54174495f92f6bb191 | |||||
| @@ -24,12 +24,10 @@ class ClipByNormNoDivSum(Expander): | |||||
| input_x0, input_x1, input_x2, input_x3 = self.inputs | input_x0, input_x1, input_x2, input_x3 = self.inputs | ||||
| # cal result | # cal result | ||||
| greater_res = graph_builder.emit('Greater', [input_x0, input_x1], attrs={'fusion': 'SelectGT_000'}) | |||||
| select_res0 = graph_builder.emit('Select', [greater_res, input_x0, input_x2], | |||||
| attrs={'fusion': 'SelectGT_000_end'}) | |||||
| greater_res = graph_builder.emit('Greater', [input_x0, input_x1]) | |||||
| select_res0 = graph_builder.emit('Select', [greater_res, input_x0, input_x2]) | |||||
| sqrt_res = graph_builder.emit('Sqrt', [select_res0]) | sqrt_res = graph_builder.emit('Sqrt', [select_res0]) | ||||
| select_res1 = graph_builder.emit('Select', [greater_res, sqrt_res, input_x0], | |||||
| attrs={'fusion': 'SelectGT_000_end'}) | |||||
| select_res1 = graph_builder.emit('Select', [greater_res, sqrt_res, input_x0]) | |||||
| result = graph_builder.emit('Maximum', [select_res1, input_x3]) | result = graph_builder.emit('Maximum', [select_res1, input_x3]) | ||||
| return result | return result | ||||
| @@ -123,7 +123,7 @@ class GraphSplitByPattern: | |||||
| self.output_excluded.update(area.output_excluded) | self.output_excluded.update(area.output_excluded) | ||||
| self.update_stitch_info(area.stitch_info) | self.update_stitch_info(area.stitch_info) | ||||
| def check_circle(self, to): | |||||
| def check_acyclic(self, to): | |||||
| """Check circle. It returns false if circle exists""" | """Check circle. It returns false if circle exists""" | ||||
| def _reached(area, to): | def _reached(area, to): | ||||
| for out, _ in area.out_relations.items(): | for out, _ in area.out_relations.items(): | ||||
| @@ -215,6 +215,10 @@ class GraphSplitByPattern: | |||||
| with open(filename, 'w') as f: | with open(filename, 'w') as f: | ||||
| f.write(subgraphs_str) | f.write(subgraphs_str) | ||||
| def do_split(self): | |||||
| """Split graph by pattern""" | |||||
| raise Exception("do_split() is not implemented in {}".format(self.__class__.__name__)) | |||||
| def split(self): | def split(self): | ||||
| """Split graph by pattern""" | """Split graph by pattern""" | ||||
| self.do_split() | self.do_split() | ||||
| @@ -270,11 +274,11 @@ class GraphSplitGpu(GraphSplitByPattern): | |||||
| return None | return None | ||||
| min_area, forward_fuse = None, False | min_area, forward_fuse = None, False | ||||
| for a, _ in dom.out_relations.items(): | for a, _ in dom.out_relations.items(): | ||||
| if a.pattern <= PrimLib.BROADCAST and dom.check_circle(a) and \ | |||||
| if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \ | |||||
| (min_area is None or a.pattern < min_area.pattern): | (min_area is None or a.pattern < min_area.pattern): | ||||
| min_area = a | min_area = a | ||||
| for a, _ in dom.in_relations.items(): | for a, _ in dom.in_relations.items(): | ||||
| if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom) and \ | |||||
| if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \ | |||||
| len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \ | len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \ | ||||
| (min_area is None or a.pattern < min_area.pattern): | (min_area is None or a.pattern < min_area.pattern): | ||||
| min_area, forward_fuse = a, True | min_area, forward_fuse = a, True | ||||
| @@ -294,7 +298,7 @@ class GraphSplitGpu(GraphSplitByPattern): | |||||
| return None | return None | ||||
| fused = [] | fused = [] | ||||
| for a, r in dom.in_relations.items(): | for a, r in dom.in_relations.items(): | ||||
| if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom) and \ | |||||
| if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_acyclic(dom) and \ | |||||
| a.dom_op().output.shape == dom.dom_op().output.shape: | a.dom_op().output.shape == dom.dom_op().output.shape: | ||||
| fused.append(a) | fused.append(a) | ||||
| return fused, True | return fused, True | ||||
| @@ -319,7 +323,7 @@ class GraphSplitGpu(GraphSplitByPattern): | |||||
| return None | return None | ||||
| fused = [] | fused = [] | ||||
| for a, r in dom.out_relations.items(): | for a, r in dom.out_relations.items(): | ||||
| if _broadcast_pat_exclude(dom, a, r) or not dom.check_circle(a) or \ | |||||
| if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a) or \ | |||||
| (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape): | (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape): | ||||
| return None | return None | ||||
| fused.append(a) | fused.append(a) | ||||
| @@ -341,7 +345,7 @@ class GraphSplitGpu(GraphSplitByPattern): | |||||
| return True | return True | ||||
| return False | return False | ||||
| def _reduce_pat_exclude(dom, a, r): | |||||
| def _reduce_pat_exclude(_, a, r): | |||||
| if len(a.ops) > self.REDUCE_FUSE_DEPTH: | if len(a.ops) > self.REDUCE_FUSE_DEPTH: | ||||
| return True | return True | ||||
| if use_poly_reduce: | if use_poly_reduce: | ||||
| @@ -373,7 +377,7 @@ class GraphSplitGpu(GraphSplitByPattern): | |||||
| _is_atomic_add_available(dom): | _is_atomic_add_available(dom): | ||||
| # to evade the precision problem. | # to evade the precision problem. | ||||
| continue | continue | ||||
| if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom): | |||||
| if not _reduce_pat_exclude(dom, a, r) and a.check_acyclic(dom): | |||||
| fused.append(a) | fused.append(a) | ||||
| return fused, True | return fused, True | ||||
| @@ -415,7 +419,7 @@ class GraphSplitGpu(GraphSplitByPattern): | |||||
| fused = [] | fused = [] | ||||
| for a, r in dom.out_relations.items(): | for a, r in dom.out_relations.items(): | ||||
| if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \ | if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \ | ||||
| dom.check_circle(a) and not dom.reduce_out_exclude(a): | |||||
| dom.check_acyclic(a) and not dom.reduce_out_exclude(a): | |||||
| fused.append(a) | fused.append(a) | ||||
| return fused, False | return fused, False | ||||
| @@ -429,7 +433,7 @@ class GraphSplitGpu(GraphSplitByPattern): | |||||
| fused = [] | fused = [] | ||||
| for a, r in dom.out_relations.items(): | for a, r in dom.out_relations.items(): | ||||
| if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_circle(a): | |||||
| if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a): | |||||
| if _reduce_nums(a.ops) < 2: | if _reduce_nums(a.ops) < 2: | ||||
| # softmax | # softmax | ||||
| if len(a.ops) > 4 and len(a.ops[0].inputs[0].shape) == 4: | if len(a.ops) > 4 and len(a.ops[0].inputs[0].shape) == 4: | ||||
| @@ -442,7 +446,7 @@ class GraphSplitGpu(GraphSplitByPattern): | |||||
| return None | return None | ||||
| fused = [] | fused = [] | ||||
| for a, _ in dom.in_relations.items(): | for a, _ in dom.in_relations.items(): | ||||
| if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom): | |||||
| if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom): | |||||
| fused.append(a) | fused.append(a) | ||||
| return fused, True | return fused, True | ||||
| @@ -490,11 +494,11 @@ class GraphSplitAscend(GraphSplitByPattern): | |||||
| return None | return None | ||||
| min_area, forward_fuse = None, False | min_area, forward_fuse = None, False | ||||
| for a, _ in dom.out_relations.items(): | for a, _ in dom.out_relations.items(): | ||||
| if a.pattern <= PrimLib.BROADCAST and dom.check_circle(a) and \ | |||||
| if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \ | |||||
| (min_area is None or a.pattern < min_area.pattern): | (min_area is None or a.pattern < min_area.pattern): | ||||
| min_area = a | min_area = a | ||||
| for a, _ in dom.in_relations.items(): | for a, _ in dom.in_relations.items(): | ||||
| if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom) and \ | |||||
| if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \ | |||||
| len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \ | len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \ | ||||
| (min_area is None or a.pattern < min_area.pattern): | (min_area is None or a.pattern < min_area.pattern): | ||||
| min_area, forward_fuse = a, True | min_area, forward_fuse = a, True | ||||
| @@ -514,7 +518,7 @@ class GraphSplitAscend(GraphSplitByPattern): | |||||
| return None | return None | ||||
| fused = [] | fused = [] | ||||
| for a, r in dom.in_relations.items(): | for a, r in dom.in_relations.items(): | ||||
| if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom) and \ | |||||
| if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_acyclic(dom) and \ | |||||
| a.dom_op().output.shape == dom.dom_op().output.shape: | a.dom_op().output.shape == dom.dom_op().output.shape: | ||||
| fused.append(a) | fused.append(a) | ||||
| return fused, True | return fused, True | ||||
| @@ -537,7 +541,7 @@ class GraphSplitAscend(GraphSplitByPattern): | |||||
| return None | return None | ||||
| fused = [] | fused = [] | ||||
| for a, r in dom.out_relations.items(): | for a, r in dom.out_relations.items(): | ||||
| if _broadcast_pat_exclude(dom, a, r) or not dom.check_circle(a) or \ | |||||
| if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a) or \ | |||||
| (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape): | (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape): | ||||
| return None | return None | ||||
| fused.append(a) | fused.append(a) | ||||
| @@ -564,7 +568,7 @@ class GraphSplitAscend(GraphSplitByPattern): | |||||
| return None | return None | ||||
| fused = [] | fused = [] | ||||
| for a, r in dom.in_relations.items(): | for a, r in dom.in_relations.items(): | ||||
| if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom): | |||||
| if not _reduce_pat_exclude(dom, a, r) and a.check_acyclic(dom): | |||||
| fused.append(a) | fused.append(a) | ||||
| return fused, True | return fused, True | ||||
| @@ -665,10 +665,6 @@ bool AkgKernelJsonGenerator::GenSingleJsons(const std::vector<AnfNodePtr> &anf_n | |||||
| auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); | auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| if (primitive->GetAttr("fusion") != nullptr) { | |||||
| node_json["fusion"] = primitive->GetAttr("fusion")->ToString(); | |||||
| } | |||||
| (*node_json_map)[anf_node] = node_json; | (*node_json_map)[anf_node] = node_json; | ||||
| } | } | ||||
| return true; | return true; | ||||
| @@ -49,7 +49,6 @@ constexpr auto kJsonKeyPtrAddress = "ptr_address"; | |||||
| constexpr auto kJsonKeyCompositeGraph = "composite_graph"; | constexpr auto kJsonKeyCompositeGraph = "composite_graph"; | ||||
| constexpr auto kJsonKeyPlatform = "platform"; | constexpr auto kJsonKeyPlatform = "platform"; | ||||
| constexpr auto kJsonKeyOpFullName = "op_full_name"; | constexpr auto kJsonKeyOpFullName = "op_full_name"; | ||||
| constexpr auto kJsonKeyFusion = "fusion"; | |||||
| constexpr auto kJsonKeyParallelFusion = "parallel_fusion"; | constexpr auto kJsonKeyParallelFusion = "parallel_fusion"; | ||||
| constexpr auto kJsonKeyFusionType = "fusion_type"; | constexpr auto kJsonKeyFusionType = "fusion_type"; | ||||
| constexpr auto kJsonKeySubGraph = "sub_graph"; | constexpr auto kJsonKeySubGraph = "sub_graph"; | ||||
| @@ -207,7 +207,7 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *i | |||||
| MS_EXCEPTION_IF_NULL(inputs_ptr); | MS_EXCEPTION_IF_NULL(inputs_ptr); | ||||
| auto nodes = TopoSort(fg->get_return()); | auto nodes = TopoSort(fg->get_return()); | ||||
| OrderedMap<ValuePtr, AnfNodePtrList> vmap; | |||||
| std::vector<std::pair<tensor::TensorPtr, AnfNodePtrList>> v_replace; | |||||
| std::vector<AnfNodePtr> scalar_tensors; | std::vector<AnfNodePtr> scalar_tensors; | ||||
| for (const auto &node : nodes) { | for (const auto &node : nodes) { | ||||
| if (!node->isa<CNode>()) { | if (!node->isa<CNode>()) { | ||||
| @@ -223,14 +223,21 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *i | |||||
| if (TensorElementAllTheSame(tensor)) { | if (TensorElementAllTheSame(tensor)) { | ||||
| scalar_tensors.emplace_back(tnode); | scalar_tensors.emplace_back(tnode); | ||||
| } else { | } else { | ||||
| vmap[GetValueNode(tnode)].push_back(tnode); | |||||
| auto tensor_iter = std::find_if( | |||||
| v_replace.begin(), v_replace.end(), | |||||
| [&tensor](const std::pair<tensor::TensorPtr, AnfNodePtrList> &vl) { return vl.first->ValueEqual(*tensor); }); | |||||
| if (tensor_iter == v_replace.end()) { | |||||
| v_replace.emplace_back(tensor, AnfNodePtrList{tnode}); | |||||
| } else { | |||||
| tensor_iter->second.push_back(tnode); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| ReplaceTensorWithScalar(fg, scalar_tensors); | ReplaceTensorWithScalar(fg, scalar_tensors); | ||||
| if (vmap.empty()) { | |||||
| if (v_replace.empty()) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -241,7 +248,7 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *i | |||||
| } | } | ||||
| auto &inputs = *inputs_ptr; | auto &inputs = *inputs_ptr; | ||||
| for (auto iter : vmap) { | |||||
| for (auto iter : v_replace) { | |||||
| auto value_nodes = iter.second; | auto value_nodes = iter.second; | ||||
| if (value_nodes.empty()) { | if (value_nodes.empty()) { | ||||
| MS_LOG(EXCEPTION) << "Invalid value in map!"; | MS_LOG(EXCEPTION) << "Invalid value in map!"; | ||||
| @@ -1,152 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "frontend/optimizer/graph_kernel_reuse.h" | |||||
| #include <vector> | |||||
| #include <algorithm> | |||||
| #include <string> | |||||
| #include "ir/graph_utils.h" | |||||
| namespace mindspore { | |||||
| /* namespace to support opt */ | |||||
| namespace opt { | |||||
| bool GraphKernelReuse::CompareNode(const AnfNodePtr a, const AnfNodePtr b) { | |||||
| if (a->abstract() && b->abstract()) { | |||||
| auto a_type = a->abstract()->GetTypeTrack(); | |||||
| auto b_type = b->abstract()->GetTypeTrack(); | |||||
| if (a_type != b_type) { | |||||
| return false; | |||||
| } | |||||
| auto a_shape = a->abstract()->GetShapeTrack(); | |||||
| auto b_shape = b->abstract()->GetShapeTrack(); | |||||
| if (a_shape != nullptr && a_shape == b_shape) { | |||||
| return true; | |||||
| } | |||||
| if (a_shape != nullptr && b_shape != nullptr && a_shape->isa<abstract::Shape>() && | |||||
| b_shape->isa<abstract::Shape>()) { | |||||
| return a_shape->cast<abstract::ShapePtr>()->shape() == b_shape->cast<abstract::ShapePtr>()->shape(); | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool GraphKernelReuse::DoReplace(const FuncGraphManagerPtr manager) { | |||||
| bool changed = false; | |||||
| auto fgs = manager->func_graphs(); | |||||
| for (FuncGraphPtr &fg : fgs) { | |||||
| if (!fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||||
| continue; | |||||
| } | |||||
| std::string key = GetValue<std::string>(fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); | |||||
| if (graph_kernel_ops.find(key) != graph_kernel_ops.end()) { | |||||
| if (find(graph_kernel_ops[key].begin(), graph_kernel_ops[key].end(), fg) == graph_kernel_ops[key].end()) { | |||||
| FuncGraphPtr new_fg = nullptr; | |||||
| for (auto &cfg : graph_kernel_ops[key]) { | |||||
| // If two graphs have different size then continue | |||||
| auto fg_topos = TopoSort(fg->get_return()); | |||||
| auto cfg_topos = TopoSort(cfg->get_return()); | |||||
| if (fg_topos.size() != cfg_topos.size()) { | |||||
| continue; | |||||
| } | |||||
| // Compare const tensor | |||||
| bool has_same = true; | |||||
| for (size_t i = 0; i < fg_topos.size(); ++i) { | |||||
| if (IsValueNode<tensor::Tensor>(fg_topos[i])) { | |||||
| if (!IsValueNode<tensor::Tensor>(cfg_topos[i])) { | |||||
| has_same = false; | |||||
| break; | |||||
| } | |||||
| auto tensor1 = GetValueNode<tensor::TensorPtr>(fg_topos[i]); | |||||
| auto tensor2 = GetValueNode<tensor::TensorPtr>(cfg_topos[i]); | |||||
| if (!tensor1->ValueEqual(*tensor2)) { | |||||
| has_same = false; | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (!has_same) { | |||||
| continue; | |||||
| } | |||||
| auto fg_input = fg->parameters(); | |||||
| auto cfg_input = cfg->parameters(); | |||||
| if (fg_input.size() != cfg_input.size()) { | |||||
| continue; | |||||
| } | |||||
| // Compare input | |||||
| for (size_t i = 0; i < fg_input.size(); ++i) { | |||||
| if (!CompareNode(fg_input[i], cfg_input[i])) { | |||||
| has_same = false; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (!has_same) { | |||||
| continue; | |||||
| } | |||||
| // Compare output | |||||
| if (!CompareNode(fg->output(), cfg->output())) { | |||||
| continue; | |||||
| } | |||||
| // Find reusable fg | |||||
| new_fg = cfg; | |||||
| break; | |||||
| } | |||||
| if (new_fg != nullptr) { | |||||
| // Replace current fg with existing fg | |||||
| auto users = fg->func_graph_cnodes_index(); | |||||
| for (auto &iter : users) { | |||||
| auto cnode = iter.first->first->cast<CNodePtr>(); | |||||
| auto new_input = cnode->inputs(); | |||||
| auto main_graph = cnode->func_graph(); | |||||
| MS_EXCEPTION_IF_NULL(main_graph); | |||||
| if (IsPrimitiveCNode(cnode, prim::kPrimPartial)) { | |||||
| new_input[1] = NewValueNode(new_fg); | |||||
| } else { | |||||
| new_input[0] = NewValueNode(new_fg); | |||||
| } | |||||
| auto new_cnode = main_graph->NewCNode(new_input); | |||||
| manager->Replace(iter.first->first, new_cnode); | |||||
| changed = true; | |||||
| } | |||||
| } else { | |||||
| // Add current fg to map | |||||
| graph_kernel_ops[key].push_back(fg); | |||||
| } | |||||
| } | |||||
| } else { | |||||
| graph_kernel_ops[key] = {fg}; | |||||
| } | |||||
| } | |||||
| return changed; | |||||
| } | |||||
| bool GraphKernelReuse::ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager) { | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| manager->AddFuncGraph(root); | |||||
| return DoReplace(manager); | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -1,50 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H | |||||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H | |||||
| #include <unordered_map> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "mindspore/ccsrc/backend/session/anf_runtime_algorithm.h" | |||||
| #include "frontend/optimizer/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| // Common subexpression elimination. | |||||
| class GraphKernelReuse { | |||||
| public: | |||||
| GraphKernelReuse() : count(0) {} | |||||
| virtual ~GraphKernelReuse() = default; | |||||
| bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) { | |||||
| bool chg = ReuseGraphKernel(root, optimizer->resource()->manager()); | |||||
| return chg; | |||||
| } | |||||
| bool CompareNode(const AnfNodePtr a, const AnfNodePtr other); | |||||
| bool DoReplace(const FuncGraphManagerPtr manager); | |||||
| bool ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager); | |||||
| private: | |||||
| std::unordered_map<std::string, std::vector<FuncGraphPtr>> graph_kernel_ops; | |||||
| int64_t count; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H | |||||
| @@ -29,7 +29,6 @@ | |||||
| #include "frontend/optimizer/irpass/incorporate_call.h" | #include "frontend/optimizer/irpass/incorporate_call.h" | ||||
| #include "frontend/optimizer/irpass/incorporate_getitem.h" | #include "frontend/optimizer/irpass/incorporate_getitem.h" | ||||
| #include "frontend/optimizer/irpass/item_tuple_or_list_eliminate.h" | #include "frontend/optimizer/irpass/item_tuple_or_list_eliminate.h" | ||||
| #include "frontend/optimizer/irpass/mark_interface_fusion.h" | |||||
| #include "frontend/optimizer/irpass/merge_addn.h" | #include "frontend/optimizer/irpass/merge_addn.h" | ||||
| #include "frontend/optimizer/irpass/accumulaten_eliminate.h" | #include "frontend/optimizer/irpass/accumulaten_eliminate.h" | ||||
| #include "frontend/optimizer/irpass/minmax_grad.h" | #include "frontend/optimizer/irpass/minmax_grad.h" | ||||
| @@ -198,10 +197,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| // AddN eliminate | // AddN eliminate | ||||
| addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel); | addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel); | ||||
| // Mark interface fusion | |||||
| mark_interface_fusion_ = | |||||
| MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect); | |||||
| // RowTensor Eliminate | // RowTensor Eliminate | ||||
| row_tensor_eliminate_ = MakeSubstitution( | row_tensor_eliminate_ = MakeSubstitution( | ||||
| std::make_shared<RowTensorEliminater>(), "row_tensor_eliminate", | std::make_shared<RowTensorEliminater>(), "row_tensor_eliminate", | ||||
| @@ -124,9 +124,6 @@ class OptimizeIRPassLib { | |||||
| // AddN eliminate | // AddN eliminate | ||||
| SubstitutionPtr addn_eliminate_; | SubstitutionPtr addn_eliminate_; | ||||
| // Fusion | |||||
| SubstitutionPtr mark_interface_fusion_; | |||||
| // RowTensor Eliminate | // RowTensor Eliminate | ||||
| SubstitutionPtr row_tensor_eliminate_; | SubstitutionPtr row_tensor_eliminate_; | ||||
| @@ -1,84 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H | |||||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H | |||||
| #include <string> | |||||
| #include <sstream> | |||||
| #include <unordered_map> | |||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| #include "frontend/optimizer/optimizer.h" | |||||
| #include "frontend/optimizer/irpass.h" | |||||
| #include "frontend/optimizer/anf_visitor.h" | |||||
| #include "frontend/operator/ops.h" | |||||
| #include "ir/graph_utils.h" | |||||
| #include "frontend/operator/composite/composite.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace irpass { | |||||
| static int64_t count = 0; | |||||
| std::string GetFusionNumber() { | |||||
| std::stringstream ss; | |||||
| ss << std::setw(4) << std::setfill('0') << count; | |||||
| std::string num = ss.str(); | |||||
| ++count; | |||||
| return "_" + num; | |||||
| } | |||||
| // Mark CNodes which can be merged in kernel build | |||||
| class MarkInterfaceFusion : public AnfVisitor { | |||||
| public: | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||||
| if (node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsPrimitiveCNode(node, prim::kPrimSelect)) { | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| auto condition = cnode->input(1); | |||||
| std::string cmp; | |||||
| std::unordered_map<std::string, std::string> cmp_list = {{"GreaterEqual", "GE"}, {"Greater", "GT"}, | |||||
| {"LessEqual", "LE"}, {"Less", "LT"}, | |||||
| {"Equal", "EQ"}, {"NotEqual", "NE"}}; | |||||
| if (IsPrimitiveCNode(condition)) { | |||||
| auto prim_name = GetCNodeFuncName(condition->cast<CNodePtr>()); | |||||
| if (cmp_list.count(prim_name) != 0) { | |||||
| // Mark Select and compare node | |||||
| cmp = cmp_list[prim_name]; | |||||
| auto cnt = GetFusionNumber(); | |||||
| AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), condition); | |||||
| AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt + "_end"), node); | |||||
| for (size_t i = 1; i < cnode->inputs().size(); ++i) { | |||||
| if (IsPrimitiveCNode(cnode->input(i), prim::kPrimZerosLike)) { | |||||
| AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), cnode->input(i)); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| void Visit(const AnfNodePtr &) override {} | |||||
| private: | |||||
| AnfNodePtr y_{nullptr}; | |||||
| }; | |||||
| } // namespace irpass | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H | |||||
| @@ -29,7 +29,6 @@ | |||||
| #include "pipeline/jit/remove_value_node_dup.h" | #include "pipeline/jit/remove_value_node_dup.h" | ||||
| #include "frontend/optimizer/optimizer.h" | #include "frontend/optimizer/optimizer.h" | ||||
| #include "frontend/optimizer/cse_pass.h" | #include "frontend/optimizer/cse_pass.h" | ||||
| #include "frontend/optimizer/graph_kernel_reuse.h" | |||||
| #include "frontend/optimizer/clean.h" | #include "frontend/optimizer/clean.h" | ||||
| #include "frontend/optimizer/irpass.h" | #include "frontend/optimizer/irpass.h" | ||||
| #include "frontend/optimizer/graph_transform.h" | #include "frontend/optimizer/graph_transform.h" | ||||
| @@ -261,12 +260,7 @@ OptPassGroupMap GetOptPassesPynativeElim(const opt::irpass::OptimizeIRPassLib &i | |||||
| } | } | ||||
| OptPassGroupMap GetOptPassesGraphKernelA(const opt::irpass::OptimizeIRPassLib &irpass) { | OptPassGroupMap GetOptPassesGraphKernelA(const opt::irpass::OptimizeIRPassLib &irpass) { | ||||
| opt::OptPassConfig interface_fusion = opt::OptPassConfig({ | |||||
| irpass.mark_interface_fusion_, | |||||
| }); | |||||
| OptPassGroupMap map({ | OptPassGroupMap map({ | ||||
| {"graph_kernel_reuse", opt::OptPassConfig(opt::GraphKernelReuse())}, | |||||
| {"interface_fusion", interface_fusion}, | |||||
| {"renormalize", opt::OptPassConfig::Renormalize()}, | {"renormalize", opt::OptPassConfig::Renormalize()}, | ||||
| {"cse", opt::OptPassConfig(opt::CSEPass(false))}, | {"cse", opt::OptPassConfig(opt::CSEPass(false))}, | ||||
| }); | }); | ||||