|
|
|
@@ -283,8 +283,9 @@ class GraphSplitByPattern: |
|
|
|
if _check_reduce_exclude(dom): |
|
|
|
return None |
|
|
|
a, r = list(dom.in_relations.items())[0] |
|
|
|
if a.is_output and len(a.ops) >= 10 and _is_atomic_add_available(dom): |
|
|
|
# to evade the precision problem in akg |
|
|
|
if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \ |
|
|
|
_is_atomic_add_available(dom): |
|
|
|
# to evade the precision problem. |
|
|
|
return None |
|
|
|
if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1: |
|
|
|
return None |
|
|
|
@@ -295,10 +296,12 @@ class GraphSplitByPattern: |
|
|
|
return None |
|
|
|
if _check_reduce_exclude(dom): |
|
|
|
return None |
|
|
|
if len(dom.ops) == 1: |
|
|
|
return None |
|
|
|
fused = [] |
|
|
|
for a, r in dom.in_relations.items(): |
|
|
|
if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \ |
|
|
|
_is_atomic_add_available(dom): |
|
|
|
# to evade the precision problem. |
|
|
|
continue |
|
|
|
if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom): |
|
|
|
fused.append(a) |
|
|
|
return fused, True |
|
|
|
|