|
|
|
@@ -685,6 +685,19 @@ class GraphSplitAscend(GraphSplitByPattern): |
|
|
|
fused.append(a) |
|
|
|
return fused, True |
|
|
|
|
|
|
|
def _reduce_output(dom): |
|
|
|
if dom.pattern != PrimLib.REDUCE: |
|
|
|
return None |
|
|
|
op_attrs = dom.dom_op().attrs |
|
|
|
if not op_attrs.get('reduce_output_fuse'): |
|
|
|
return None |
|
|
|
fused = [] |
|
|
|
for a, r in dom.out_relations.items(): |
|
|
|
if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \ |
|
|
|
dom.check_acyclic(a): |
|
|
|
fused.append(a) |
|
|
|
return fused, False |
|
|
|
|
|
|
|
changed = True |
|
|
|
while changed: |
|
|
|
changed = self.fuse(_reshape) |
|
|
|
@@ -695,6 +708,7 @@ class GraphSplitAscend(GraphSplitByPattern): |
|
|
|
changed = self.fuse(_broadcast_depth) or changed |
|
|
|
changed = self.fuse(_broadcast_width) or changed |
|
|
|
changed = self.fuse(_matmul_depth) or changed |
|
|
|
changed = self.fuse(_reduce_output) or changed |
|
|
|
self.fuse(_transdata) |
|
|
|
|
|
|
|
|
|
|
|
|