Browse Source

[GraphKernel] Support reshape/elewise/broadcast + transdata fusion

pull/16071/head
hanhuifeng2020 4 years ago
parent
commit
bc46d644fe
2 changed files with 34 additions and 15 deletions
  1. +17
    -0
      mindspore/_extends/graph_kernel/model/graph_split.py
  2. +17
    -15
      mindspore/_extends/graph_kernel/model/model_builder.py

+ 17
- 0
mindspore/_extends/graph_kernel/model/graph_split.py View File

@@ -15,7 +15,9 @@
"""Cost model splitter"""
import os
from functools import reduce
from mindspore import log as logger
from .model import PrimLib, Graph, Tensor
from .model import DataFormat as DF

class GraphSplitByPattern:
"""Graph splitter"""
@@ -674,6 +676,21 @@ class GraphSplitAscend(GraphSplitByPattern):

if a.dom_op().prim == "MatMul" and len(dom.ops) == 1:
return True

# reshape/elewise/broadcast + transdata
if a.pattern <= PrimLib.BROADCAST and len(dom.ops) == 1:
op_attrs = dom.dom_op().attrs
if 'src_format' not in op_attrs.keys() \
or 'dst_format' not in op_attrs.keys():
logger.error("src_format or dst_format not be found in the attrs of Transdata op")
return False
src_format, dst_format = op_attrs['src_format'], op_attrs['dst_format']
if src_format == DF.FRAC_NZ and dst_format in (DF.DEFAULT, DF.NCHW):
return True
# For the Default/NCHW to FRAC_NZ, currently only the Cast+Transdata is supported
if src_format in (DF.DEFAULT, DF.NCHW) and dst_format == DF.FRAC_NZ\
and len(a.ops) == 1 and a.dom_op().prim == "Cast" and not a.is_output:
return True
return False

def _transdata(dom):


+ 17
- 15
mindspore/_extends/graph_kernel/model/model_builder.py View File

@@ -134,21 +134,23 @@ class CompositeGraph:
return red_axis

attr = {}
if op['name'] not in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
return attr
for a in op['attr']:
if a['name'] == 'axis':
red_axis, dim_size = [], len(inputs[0].shape)
if not a['value']:
red_axis = _get_axis_while_none(inputs[0].shape, output.shape)
else:
if isinstance(a['value'], int):
a['value'] = [a['value']]
for i in a['value']:
red_axis.append(i if i >= 0 else dim_size + i)
attr['reduce_axis'] = red_axis
if a['name'] == "reduce_output_fuse":
attr['reduce_output_fuse'] = a['value']
if op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
for a in op['attr']:
if a['name'] == 'axis':
red_axis, dim_size = [], len(inputs[0].shape)
if not a['value']:
red_axis = _get_axis_while_none(inputs[0].shape, output.shape)
else:
if isinstance(a['value'], int):
a['value'] = [a['value']]
for i in a['value']:
red_axis.append(i if i >= 0 else dim_size + i)
attr['reduce_axis'] = red_axis
if a['name'] == "reduce_output_fuse":
attr['reduce_output_fuse'] = a['value']
elif op['attr']:
for a in op['attr']:
attr[a['name']] = a['value']
return attr

builder = GraphBuilder()


Loading…
Cancel
Save