|
|
|
@@ -424,6 +424,41 @@ class GraphSplitGpu(GraphSplitByPattern): |
|
|
|
fused.append(a) |
|
|
|
return fused, False |
|
|
|
|
|
|
|
def _stitch_axis(shape): |
|
|
|
stitch_axis = [] |
|
|
|
size = 1 |
|
|
|
for i in shape: |
|
|
|
size = size * i |
|
|
|
stitch_axis.append(i) |
|
|
|
if size >= 1024 * 8: |
|
|
|
return stitch_axis |
|
|
|
return None |
|
|
|
|
|
|
|
def _same_stitch_axis(a, b): |
|
|
|
x = [] |
|
|
|
x.extend(a) |
|
|
|
x.extend(b) |
|
|
|
stitch_axis = _stitch_axis(x[0].shape) |
|
|
|
for item in x: |
|
|
|
i_stitch_axis = _stitch_axis(item.shape) |
|
|
|
if i_stitch_axis is None or i_stitch_axis != stitch_axis: |
|
|
|
return False |
|
|
|
return True |
|
|
|
|
|
|
|
def _may_stitch(dom, a, r): |
|
|
|
if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a): |
|
|
|
if _reduce_nums(a.ops) < 2: |
|
|
|
dom_outs = [op.output for op in dom.ops] |
|
|
|
a_ins = [input for op in a.ops for input in op.inputs] |
|
|
|
a_outs = [op.output for op in a.ops] |
|
|
|
a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins] |
|
|
|
stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins] |
|
|
|
if _same_stitch_axis(stitch_tensors, a_final_outs): |
|
|
|
for tensor in stitch_tensors: |
|
|
|
if _tensor_size(tensor) >= 1024 * 1024 * 12: |
|
|
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
def _reduce_stitch(dom): |
|
|
|
if dom.pattern != PrimLib.REDUCE: |
|
|
|
return None |
|
|
|
@@ -434,12 +469,14 @@ class GraphSplitGpu(GraphSplitByPattern): |
|
|
|
|
|
|
|
fused = [] |
|
|
|
for a, r in dom.out_relations.items(): |
|
|
|
if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a): |
|
|
|
if _reduce_nums(a.ops) < 2: |
|
|
|
# softmax |
|
|
|
if len(a.ops) > 4 and len(a.ops[0].inputs[0].shape) == 4: |
|
|
|
if _may_stitch(dom, a, r): |
|
|
|
if a.pattern == PrimLib.REDUCE: |
|
|
|
if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']: |
|
|
|
dom.stitch_info.stitch_ops.add(dom.ops[0].output.name) |
|
|
|
fused.append(a) |
|
|
|
elif a.pattern == PrimLib.BROADCAST: |
|
|
|
dom.stitch_info.stitch_ops.add(dom.ops[0].output.name) |
|
|
|
fused.append(a) |
|
|
|
return fused, False |
|
|
|
|
|
|
|
def _transpose(dom): |
|
|
|
|