diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py index 8ca1ac8064..944e008825 100644 --- a/mindspore/_extends/graph_kernel/model/graph_split.py +++ b/mindspore/_extends/graph_kernel/model/graph_split.py @@ -283,8 +283,9 @@ class GraphSplitByPattern: if _check_reduce_exclude(dom): return None a, r = list(dom.in_relations.items())[0] - if a.is_output and len(a.ops) >= 10 and _is_atomic_add_available(dom): - # to evade the precision problem in akg + if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \ + _is_atomic_add_available(dom): + # to evade the precision problem. return None if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1: return None @@ -295,10 +296,12 @@ class GraphSplitByPattern: return None if _check_reduce_exclude(dom): return None - if len(dom.ops) == 1: - return None fused = [] for a, r in dom.in_relations.items(): + if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \ + _is_atomic_add_available(dom): + # to evade the precision problem. + continue if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom): fused.append(a) return fused, True