Browse Source

[GraphKernel]Support reshape/elewise/broadcast+transdata fusion

pull/16071/head
hanhuifeng2020 4 years ago
parent
commit
bc46d644fe
2 changed files with 34 additions and 15 deletions
  1. +17
    -0
      mindspore/_extends/graph_kernel/model/graph_split.py
  2. +17
    -15
      mindspore/_extends/graph_kernel/model/model_builder.py

+ 17
- 0
mindspore/_extends/graph_kernel/model/graph_split.py View File

@@ -15,7 +15,9 @@
"""Cost model splitter""" """Cost model splitter"""
import os import os
from functools import reduce from functools import reduce
from mindspore import log as logger
from .model import PrimLib, Graph, Tensor from .model import PrimLib, Graph, Tensor
from .model import DataFormat as DF


class GraphSplitByPattern: class GraphSplitByPattern:
"""Graph splitter""" """Graph splitter"""
@@ -674,6 +676,21 @@ class GraphSplitAscend(GraphSplitByPattern):


if a.dom_op().prim == "MatMul" and len(dom.ops) == 1: if a.dom_op().prim == "MatMul" and len(dom.ops) == 1:
return True return True

# reshape/elewise/broadcast + transdata
if a.pattern <= PrimLib.BROADCAST and len(dom.ops) == 1:
op_attrs = dom.dom_op().attrs
if 'src_format' not in op_attrs.keys() \
or 'dst_format' not in op_attrs.keys():
logger.error("src_format or dst_format not be found in the attrs of Transdata op")
return False
src_format, dst_format = op_attrs['src_format'], op_attrs['dst_format']
if src_format == DF.FRAC_NZ and dst_format in (DF.DEFAULT, DF.NCHW):
return True
# For the Default/NCHW to FRAC_NZ, currently only the Cast+Transdata is supported
if src_format in (DF.DEFAULT, DF.NCHW) and dst_format == DF.FRAC_NZ\
and len(a.ops) == 1 and a.dom_op().prim == "Cast" and not a.is_output:
return True
return False return False


def _transdata(dom): def _transdata(dom):


+ 17
- 15
mindspore/_extends/graph_kernel/model/model_builder.py View File

@@ -134,21 +134,23 @@ class CompositeGraph:
return red_axis return red_axis


attr = {} attr = {}
if op['name'] not in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
return attr
for a in op['attr']:
if a['name'] == 'axis':
red_axis, dim_size = [], len(inputs[0].shape)
if not a['value']:
red_axis = _get_axis_while_none(inputs[0].shape, output.shape)
else:
if isinstance(a['value'], int):
a['value'] = [a['value']]
for i in a['value']:
red_axis.append(i if i >= 0 else dim_size + i)
attr['reduce_axis'] = red_axis
if a['name'] == "reduce_output_fuse":
attr['reduce_output_fuse'] = a['value']
if op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
for a in op['attr']:
if a['name'] == 'axis':
red_axis, dim_size = [], len(inputs[0].shape)
if not a['value']:
red_axis = _get_axis_while_none(inputs[0].shape, output.shape)
else:
if isinstance(a['value'], int):
a['value'] = [a['value']]
for i in a['value']:
red_axis.append(i if i >= 0 else dim_size + i)
attr['reduce_axis'] = red_axis
if a['name'] == "reduce_output_fuse":
attr['reduce_output_fuse'] = a['value']
elif op['attr']:
for a in op['attr']:
attr[a['name']] = a['value']
return attr return attr


builder = GraphBuilder() builder = GraphBuilder()


Loading…
Cancel
Save