|
|
|
@@ -33,7 +33,7 @@ class GraphSplitByPattern: |
|
|
|
self.out_relations = dict() # {area1: relation1, area2: relation2, ...} |
|
|
|
self.mode = self.MODE_BASIC |
|
|
|
if self.pattern == PrimLib.TRANSFORM or self.pattern == PrimLib.BROADCAST or \ |
|
|
|
(use_poly_reduce and self.pattern == PrimLib.REDUCE): |
|
|
|
(use_poly_reduce and self.pattern == PrimLib.REDUCE): |
|
|
|
self.mode = self.MODE_COMPOSITE |
|
|
|
if init_op.prim == "AddN": |
|
|
|
self.mode = self.MODE_COMPOSITE |
|
|
|
@@ -283,6 +283,9 @@ class GraphSplitByPattern: |
|
|
|
if _check_reduce_exclude(dom): |
|
|
|
return None |
|
|
|
a, r = list(dom.in_relations.items())[0] |
|
|
|
if a.is_output and len(a.ops) >= 10 and _is_atomic_add_available(dom): |
|
|
|
# to evade the precision problem in akg |
|
|
|
return None |
|
|
|
if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1: |
|
|
|
return None |
|
|
|
return [a], True |
|
|
|
@@ -292,6 +295,8 @@ class GraphSplitByPattern: |
|
|
|
return None |
|
|
|
if _check_reduce_exclude(dom): |
|
|
|
return None |
|
|
|
if len(dom.ops) == 1: |
|
|
|
return None |
|
|
|
fused = [] |
|
|
|
for a, r in dom.in_relations.items(): |
|
|
|
if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom): |
|
|
|
@@ -304,16 +309,17 @@ class GraphSplitByPattern: |
|
|
|
size *= i |
|
|
|
return size |
|
|
|
|
|
|
|
def _is_atomic_add_available(dom): |
|
|
|
if any(["Reduce" in x.prim for x in dom.ops[1:]]): |
|
|
|
return False |
|
|
|
op = dom.ops[0] |
|
|
|
reduce_axis = op.attrs["reduce_axis"] |
|
|
|
if len(op.inputs[0].shape) - 1 in reduce_axis: |
|
|
|
reduce_size = reduce(lambda x, y: x * y, [op.inputs[0].shape[i] for i in reduce_axis]) |
|
|
|
return reduce_size >= 1024 |
|
|
|
return True |
|
|
|
|
|
|
|
def _reduce_output(dom): |
|
|
|
def _is_atomic_add_available(dom): |
|
|
|
if any(["Reduce" in x.prim for x in dom.ops[1:]]): |
|
|
|
return False |
|
|
|
op = dom.ops[0] |
|
|
|
reduce_axis = op.attrs["reduce_axis"] |
|
|
|
if len(op.inputs[0].shape) - 1 in reduce_axis: |
|
|
|
reduce_size = reduce(lambda x, y: x * y, [op.inputs[0].shape[i] for i in reduce_axis]) |
|
|
|
return reduce_size >= 1024 |
|
|
|
return True |
|
|
|
if dom.pattern != PrimLib.REDUCE: |
|
|
|
return None |
|
|
|
if _is_atomic_add_available(dom): |
|
|
|
|