Browse Source

!14796 [GRAPH KERNEL]optimize stitch fusion strategy

From: @r1chardf1d0
Reviewed-by: @gaoxiong1,@dylangeng
Signed-off-by: @dylangeng
pull/14796/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
7585362148
2 changed files with 42 additions and 5 deletions
  1. +1
    -1
      akg
  2. +41
    -4
      mindspore/_extends/graph_kernel/model/graph_split.py

+ 1
- 1
akg

@@ -1 +1 @@
Subproject commit e2a30d6b8ece4a69790ac9e37ae862fe8124ad7c
Subproject commit d91d772a3a913f20eaef6c47517b9ca140edaee2

+ 41
- 4
mindspore/_extends/graph_kernel/model/graph_split.py View File

@@ -424,6 +424,41 @@ class GraphSplitGpu(GraphSplitByPattern):
fused.append(a)
return fused, False

def _stitch_axis(shape):
stitch_axis = []
size = 1
for i in shape:
size = size * i
stitch_axis.append(i)
if size >= 1024 * 8:
return stitch_axis
return None

def _same_stitch_axis(a, b):
x = []
x.extend(a)
x.extend(b)
stitch_axis = _stitch_axis(x[0].shape)
for item in x:
i_stitch_axis = _stitch_axis(item.shape)
if i_stitch_axis is None or i_stitch_axis != stitch_axis:
return False
return True

def _may_stitch(dom, a, r):
if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
if _reduce_nums(a.ops) < 2:
dom_outs = [op.output for op in dom.ops]
a_ins = [input for op in a.ops for input in op.inputs]
a_outs = [op.output for op in a.ops]
a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins]
stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins]
if _same_stitch_axis(stitch_tensors, a_final_outs):
for tensor in stitch_tensors:
if _tensor_size(tensor) >= 1024 * 1024 * 12:
return True
return False

def _reduce_stitch(dom):
if dom.pattern != PrimLib.REDUCE:
return None
@@ -434,12 +469,14 @@ class GraphSplitGpu(GraphSplitByPattern):

fused = []
for a, r in dom.out_relations.items():
if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
if _reduce_nums(a.ops) < 2:
# softmax
if len(a.ops) > 4 and len(a.ops[0].inputs[0].shape) == 4:
if _may_stitch(dom, a, r):
if a.pattern == PrimLib.REDUCE:
if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']:
dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
fused.append(a)
elif a.pattern == PrimLib.BROADCAST:
dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
fused.append(a)
return fused, False

def _transpose(dom):


Loading…
Cancel
Save