Browse Source

[GraphKernel]add the attr reduce_output_fuse to enable fuse for reduce_output on Ascend

tags/v1.3.0
hanhuifeng2020 5 years ago
parent
commit
425d401e85
2 changed files with 16 additions and 1 deletions
  1. +14
    -0
      mindspore/_extends/graph_kernel/model/graph_split.py
  2. +2
    -1
      mindspore/_extends/graph_kernel/model/model_builder.py

+ 14
- 0
mindspore/_extends/graph_kernel/model/graph_split.py View File

@@ -617,6 +617,19 @@ class GraphSplitAscend(GraphSplitByPattern):
fused.append(a)
return fused, True

def _reduce_output(dom):
if dom.pattern != PrimLib.REDUCE:
return None
op_attrs = dom.dom_op().attrs
if not op_attrs.get('reduce_output_fuse'):
return None
fused = []
for a, r in dom.out_relations.items():
if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \
dom.check_acyclic(a):
fused.append(a)
return fused, False

changed = True
while changed:
changed = self.fuse(_reshape)
@@ -627,6 +640,7 @@ class GraphSplitAscend(GraphSplitByPattern):
changed = self.fuse(_broadcast_depth) or changed
changed = self.fuse(_broadcast_width) or changed
changed = self.fuse(_matmul_depth) or changed
changed = self.fuse(_reduce_output) or changed
self.fuse(_transdata)




+ 2
- 1
mindspore/_extends/graph_kernel/model/model_builder.py View File

@@ -147,7 +147,8 @@ class CompositeGraph:
for i in a['value']:
red_axis.append(i if i >= 0 else dim_size + i)
attr['reduce_axis'] = red_axis
break
if a['name'] == "reduce_output_fuse":
attr['reduce_output_fuse'] = a['value']
return attr

builder = GraphBuilder()


Loading…
Cancel
Save