diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py index c2b67eafce..a5ad57fdd9 100644 --- a/mindspore/_extends/graph_kernel/model/graph_split.py +++ b/mindspore/_extends/graph_kernel/model/graph_split.py @@ -685,6 +685,19 @@ class GraphSplitAscend(GraphSplitByPattern): fused.append(a) return fused, True + def _reduce_output(dom): + if dom.pattern != PrimLib.REDUCE: + return None + op_attrs = dom.dom_op().attrs + if not op_attrs.get('reduce_output_fuse'): + return None + fused = [] + for a, r in dom.out_relations.items(): + if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \ + dom.check_acyclic(a): + fused.append(a) + return fused, False + changed = True while changed: changed = self.fuse(_reshape) @@ -695,6 +708,7 @@ class GraphSplitAscend(GraphSplitByPattern): changed = self.fuse(_broadcast_depth) or changed changed = self.fuse(_broadcast_width) or changed changed = self.fuse(_matmul_depth) or changed + changed = self.fuse(_reduce_output) or changed self.fuse(_transdata) diff --git a/mindspore/_extends/graph_kernel/model/model_builder.py b/mindspore/_extends/graph_kernel/model/model_builder.py index 4125b75e1b..5ff88f4e5e 100644 --- a/mindspore/_extends/graph_kernel/model/model_builder.py +++ b/mindspore/_extends/graph_kernel/model/model_builder.py @@ -147,7 +147,8 @@ class CompositeGraph: for i in a['value']: red_axis.append(i if i >= 0 else dim_size + i) attr['reduce_axis'] = red_axis - break + if a['name'] == "reduce_output_fuse": + attr['reduce_output_fuse'] = a['value'] return attr builder = GraphBuilder()