diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py
index a5ad57fdd9..665c0b98ff 100644
--- a/mindspore/_extends/graph_kernel/model/graph_split.py
+++ b/mindspore/_extends/graph_kernel/model/graph_split.py
@@ -15,7 +15,9 @@
 """Cost model splitter"""
 import os
 from functools import reduce
+from mindspore import log as logger
 from .model import PrimLib, Graph, Tensor
+from .model import DataFormat as DF
 
 class GraphSplitByPattern:
     """Graph splitter"""
@@ -674,6 +676,21 @@ class GraphSplitAscend(GraphSplitByPattern):
 
             if a.dom_op().prim == "MatMul" and len(dom.ops) == 1:
                 return True
+
+            # reshape/elewise/broadcast + transdata
+            if a.pattern <= PrimLib.BROADCAST and len(dom.ops) == 1:
+                op_attrs = dom.dom_op().attrs
+                if 'src_format' not in op_attrs.keys() \
+                        or 'dst_format' not in op_attrs.keys():
+                    logger.error("src_format or dst_format not found in the attrs of the TransData op")
+                    return False
+                src_format, dst_format = op_attrs['src_format'], op_attrs['dst_format']
+                if src_format == DF.FRAC_NZ and dst_format in (DF.DEFAULT, DF.NCHW):
+                    return True
+                # For Default/NCHW to FRAC_NZ, currently only Cast + TransData is supported
+                if src_format in (DF.DEFAULT, DF.NCHW) and dst_format == DF.FRAC_NZ \
+                        and len(a.ops) == 1 and a.dom_op().prim == "Cast" and not a.is_output:
+                    return True
             return False
 
         def _transdata(dom):
diff --git a/mindspore/_extends/graph_kernel/model/model_builder.py b/mindspore/_extends/graph_kernel/model/model_builder.py
index 5ff88f4e5e..0913329bed 100644
--- a/mindspore/_extends/graph_kernel/model/model_builder.py
+++ b/mindspore/_extends/graph_kernel/model/model_builder.py
@@ -134,21 +134,23 @@ class CompositeGraph:
                 return red_axis
 
             attr = {}
-            if op['name'] not in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
-                return attr
-            for a in op['attr']:
-                if a['name'] == 'axis':
-                    red_axis, dim_size = [], len(inputs[0].shape)
-                    if not a['value']:
-                        red_axis = _get_axis_while_none(inputs[0].shape, output.shape)
-                    else:
-                        if isinstance(a['value'], int):
-                            a['value'] = [a['value']]
-                        for i in a['value']:
-                            red_axis.append(i if i >= 0 else dim_size + i)
-                    attr['reduce_axis'] = red_axis
-                if a['name'] == "reduce_output_fuse":
-                    attr['reduce_output_fuse'] = a['value']
+            if op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
+                for a in op['attr']:
+                    if a['name'] == 'axis':
+                        red_axis, dim_size = [], len(inputs[0].shape)
+                        if not a['value']:
+                            red_axis = _get_axis_while_none(inputs[0].shape, output.shape)
+                        else:
+                            if isinstance(a['value'], int):
+                                a['value'] = [a['value']]
+                            for i in a['value']:
+                                red_axis.append(i if i >= 0 else dim_size + i)
+                        attr['reduce_axis'] = red_axis
+                    if a['name'] == "reduce_output_fuse":
+                        attr['reduce_output_fuse'] = a['value']
+            elif op['attr']:
+                for a in op['attr']:
+                    attr[a['name']] = a['value']
             return attr
 
         builder = GraphBuilder()