|
|
|
@@ -395,7 +395,7 @@ class GraphSplitByPattern: |
|
|
|
dom = areas[i] |
|
|
|
for a in areas[i + 1:]: |
|
|
|
if dom.check_acyclic(a) and a.check_acyclic(dom) and \ |
|
|
|
selector(dom, a) and self.limit_area_size(dom, [a], 64): |
|
|
|
selector(dom, a) and self.limit_area_size(dom, [a], 64): |
|
|
|
dom.fuse(a) |
|
|
|
self.set_area_map(a.ops, dom) |
|
|
|
self.areas.remove(a) |
|
|
|
@@ -909,28 +909,6 @@ class GraphSplitGpu(GraphSplitByPattern): |
|
|
|
return [a], True |
|
|
|
return None |
|
|
|
|
|
|
|
def _h_broadcast(dom, a):
    """Selector handed to ``hfuse`` for areas whose pattern is at most BROADCAST.

    Returns:
        None when ``dom.pattern`` is above ``PrimLib.BROADCAST`` (the selector
        does not apply to ``dom``); otherwise a falsy value when ``a.pattern``
        is above ``PrimLib.BROADCAST``, else whether the first op of each area
        has the same output shape.
    """
    # `dom` must itself be no heavier than BROADCAST for this rule to apply.
    if dom.pattern > PrimLib.BROADCAST:
        return None
    if not a.pattern <= PrimLib.BROADCAST:
        return False
    # Only the first op of each area is inspected.
    dom_shape = dom.ops[0].output.shape
    other_shape = a.ops[0].output.shape
    return dom_shape == other_shape
|
|
|
|
|
|
|
def _h_reduce(dom, a):
    """Selector handed to ``hfuse`` for pairing two REDUCE-pattern areas.

    Returns:
        None when the rule does not apply to ``dom``: its pattern is not
        ``PrimLib.REDUCE``, it already has stitch ops, its first op is not a
        reduce, or ``_is_atomic_add_available(dom)`` is truthy.
        Otherwise a truthy value only if ``a`` is also a REDUCE area without
        stitch ops whose first op is a reduce with the same first-input shape
        and the same ``"reduce_axis"`` attribute as ``dom``'s first op.
    """
    if dom.pattern != PrimLib.REDUCE:
        return None
    if dom.stitch_info.stitch_ops:
        return None
    lead_op = dom.ops[0]
    if not PrimLib.is_reduce(lead_op):
        return None
    if _is_atomic_add_available(dom):
        return None

    cand_op = a.ops[0]
    if a.pattern != PrimLib.REDUCE or a.stitch_info.stitch_ops:
        return False
    # Kept as a lazy `and` chain so later operands are only evaluated when the
    # earlier ones hold, and the first falsy operand is what gets returned.
    return (
        PrimLib.is_reduce(cand_op)
        and lead_op.inputs[0].shape == cand_op.inputs[0].shape
        and lead_op.attrs.get("reduce_axis") == cand_op.attrs.get("reduce_axis")
    )
|
|
|
|
|
|
|
def _h_opaque(dom, a): |
|
|
|
if dom.ops[0].prim not in {"StridedSlice"}: |
|
|
|
return None |
|
|
|
return a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \ |
|
|
|
dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape |
|
|
|
|
|
|
|
def _fuse_loop(): |
|
|
|
changed = True |
|
|
|
while changed: |
|
|
|
@@ -948,9 +926,6 @@ class GraphSplitGpu(GraphSplitByPattern): |
|
|
|
if self.enable_stitch_fusion: |
|
|
|
changed = self.fuse(_reduce_stitch) or changed |
|
|
|
self.fuse(_transpose) |
|
|
|
self.hfuse(_h_broadcast) |
|
|
|
self.hfuse(_h_reduce) |
|
|
|
self.hfuse(_h_opaque) |
|
|
|
|
|
|
|
def _fuse_once(fuse_func): |
|
|
|
if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \ |
|
|
|
|