Browse Source

update akg/switch off h-fuse

tags/v1.6.0
Yang Jiao 4 years ago
parent
commit
07a0c24126
2 changed files with 2 additions and 27 deletions
  1. +1
    -1
      akg
  2. +1
    -26
      mindspore/_extends/graph_kernel/model/graph_split.py

+ 1
- 1
akg

@@ -1 +1 @@
Subproject commit a26bad300e932615d175fc7d34b9e213b5e811aa
Subproject commit 1c72d655e24d6def68e3345f3e6cd1a86efde64b

+ 1
- 26
mindspore/_extends/graph_kernel/model/graph_split.py View File

@@ -395,7 +395,7 @@ class GraphSplitByPattern:
dom = areas[i]
for a in areas[i + 1:]:
if dom.check_acyclic(a) and a.check_acyclic(dom) and \
selector(dom, a) and self.limit_area_size(dom, [a], 64):
selector(dom, a) and self.limit_area_size(dom, [a], 64):
dom.fuse(a)
self.set_area_map(a.ops, dom)
self.areas.remove(a)
@@ -909,28 +909,6 @@ class GraphSplitGpu(GraphSplitByPattern):
return [a], True
return None

def _h_broadcast(dom, a):
if dom.pattern > PrimLib.BROADCAST:
return None
return a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape

def _h_reduce(dom, a):
if dom.pattern != PrimLib.REDUCE or dom.stitch_info.stitch_ops:
return None
dom_op = dom.ops[0]
if not PrimLib.is_reduce(dom_op) or _is_atomic_add_available(dom):
return None
op = a.ops[0]
return a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \
PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \
dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis")

def _h_opaque(dom, a):
if dom.ops[0].prim not in {"StridedSlice"}:
return None
return a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \
dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape

def _fuse_loop():
changed = True
while changed:
@@ -948,9 +926,6 @@ class GraphSplitGpu(GraphSplitByPattern):
if self.enable_stitch_fusion:
changed = self.fuse(_reduce_stitch) or changed
self.fuse(_transpose)
self.hfuse(_h_broadcast)
self.hfuse(_h_reduce)
self.hfuse(_h_opaque)

def _fuse_once(fuse_func):
if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \


Loading…
Cancel
Save