Browse Source

!15951 [GraphKernel]add the attribute reduce_output_fuse to enable fuse for the reduce_output on Ascend

From: @hanhuifeng2020
Reviewed-by: @gaoxiong1,@dylangeng
Signed-off-by: @dylangeng
pull/15951/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
0887d35b1c
2 changed files with 16 additions and 1 deletions
  1. +14
    -0
      mindspore/_extends/graph_kernel/model/graph_split.py
  2. +2
    -1
      mindspore/_extends/graph_kernel/model/model_builder.py

+ 14
- 0
mindspore/_extends/graph_kernel/model/graph_split.py View File

@@ -685,6 +685,19 @@ class GraphSplitAscend(GraphSplitByPattern):
fused.append(a)
return fused, True

def _reduce_output(dom):
if dom.pattern != PrimLib.REDUCE:
return None
op_attrs = dom.dom_op().attrs
if not op_attrs.get('reduce_output_fuse'):
return None
fused = []
for a, r in dom.out_relations.items():
if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \
dom.check_acyclic(a):
fused.append(a)
return fused, False

changed = True
while changed:
changed = self.fuse(_reshape)
@@ -695,6 +708,7 @@ class GraphSplitAscend(GraphSplitByPattern):
changed = self.fuse(_broadcast_depth) or changed
changed = self.fuse(_broadcast_width) or changed
changed = self.fuse(_matmul_depth) or changed
changed = self.fuse(_reduce_output) or changed
self.fuse(_transdata)




+ 2
- 1
mindspore/_extends/graph_kernel/model/model_builder.py View File

@@ -147,7 +147,8 @@ class CompositeGraph:
for i in a['value']:
red_axis.append(i if i >= 0 else dim_size + i)
attr['reduce_axis'] = red_axis
break
if a['name'] == "reduce_output_fuse":
attr['reduce_output_fuse'] = a['value']
return attr

builder = GraphBuilder()


Loading…
Cancel
Save