| @@ -283,8 +283,9 @@ class GraphSplitByPattern: | |||||
| if _check_reduce_exclude(dom): | if _check_reduce_exclude(dom): | ||||
| return None | return None | ||||
| a, r = list(dom.in_relations.items())[0] | a, r = list(dom.in_relations.items())[0] | ||||
| if a.is_output and len(a.ops) >= 10 and _is_atomic_add_available(dom): | |||||
| # to evade the precision problem in akg | |||||
| if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \ | |||||
| _is_atomic_add_available(dom): | |||||
| # to evade the precision problem. | |||||
| return None | return None | ||||
| if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1: | if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1: | ||||
| return None | return None | ||||
| @@ -295,10 +296,12 @@ class GraphSplitByPattern: | |||||
| return None | return None | ||||
| if _check_reduce_exclude(dom): | if _check_reduce_exclude(dom): | ||||
| return None | return None | ||||
| if len(dom.ops) == 1: | |||||
| return None | |||||
| fused = [] | fused = [] | ||||
| for a, r in dom.in_relations.items(): | for a, r in dom.in_relations.items(): | ||||
| if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \ | |||||
| _is_atomic_add_available(dom): | |||||
| # to evade the precision problem. | |||||
| continue | |||||
| if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom): | if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom): | ||||
| fused.append(a) | fused.append(a) | ||||
| return fused, True | return fused, True | ||||