Browse Source

!9699 【GraphKernel】Add a float16 restriction in graph_split

From: @dayschan
Reviewed-by: @gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
56e54a3737
1 changed files with 7 additions and 4 deletions
  1. +7
    -4
      mindspore/_extends/graph_kernel/model/graph_split.py

+ 7
- 4
mindspore/_extends/graph_kernel/model/graph_split.py View File

@@ -283,8 +283,9 @@ class GraphSplitByPattern:
if _check_reduce_exclude(dom):
return None
a, r = list(dom.in_relations.items())[0]
if a.is_output and len(a.ops) >= 10 and _is_atomic_add_available(dom):
# to evade the precision problem in akg
if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
_is_atomic_add_available(dom):
# to evade the precision problem.
return None
if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1:
return None
@@ -295,10 +296,12 @@ class GraphSplitByPattern:
return None
if _check_reduce_exclude(dom):
return None
if len(dom.ops) == 1:
return None
fused = []
for a, r in dom.in_relations.items():
if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
_is_atomic_add_available(dom):
# to evade the precision problem.
continue
if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom):
fused.append(a)
return fused, True


Loading…
Cancel
Save