From 8dcff8d83c971f94865105c4d0d8899f298a8c33 Mon Sep 17 00:00:00 2001
From: zengzitao
Date: Sun, 25 Apr 2021 16:08:35 +0800
Subject: [PATCH] refactor Tile op expander and enable it on GPU

---
 .../_extends/graph_kernel/expanders/tile.py   | 36 +++++++++++----
 .../_extends/graph_kernel/model/op_infer.py   | 45 -------------------
 mindspore/_extends/graph_kernel/splitter.py   |  2 +-
 .../graph_kernel/graph_kernel_expander.cc     |  2 +-
 4 files changed, 29 insertions(+), 56 deletions(-)

diff --git a/mindspore/_extends/graph_kernel/expanders/tile.py b/mindspore/_extends/graph_kernel/expanders/tile.py
index 23c2584105..83b3f2f002 100644
--- a/mindspore/_extends/graph_kernel/expanders/tile.py
+++ b/mindspore/_extends/graph_kernel/expanders/tile.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ===========================================================================
 """generate json desc for Tile"""
-from mindspore._extends.graph_kernel.model.op_infer import Tile as TileInfer
+from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
 from mindspore._extends.graph_kernel.model.model import DataFormat as DF
 from ._utils import Expander, ExpanderInfoValidator as VLD
 
@@ -23,14 +23,32 @@ from ._utils import Expander, ExpanderInfoValidator as VLD
 class Tile(Expander):
     """Tile expander"""
 
+    def _get_output_shape(self):
+        """Get output shape"""
+        shape = self.inputs[0].shape
+        multiples = self.attrs["multiples"]
+
+        shape = list(shape)
+        multiples = list(multiples)
+        diff_len = len(multiples) - len(shape)
+        if diff_len < 0:
+            raise GKException("Dimensions of multiples {} < dimensions of input {} in Tile".format(multiples, shape))
+        if diff_len > 0:
+            for _ in range(diff_len):
+                shape.insert(0, 1)
+
+        output_shape = []
+
+        for sh, mul in zip(shape, multiples):
+            if sh != 1 and mul != 1:
+                raise GKException("Tile op in expander only supports automatic broadcast!")
+            dim = sh * mul
+            output_shape.append(dim)
+        return output_shape
+
     def _expand(self, graph_builder):
         input_x = self.inputs[0]
-        multiples = self.attrs['multiples']
-
-        tile_infer = TileInfer(self.name, self.inputs, self.attrs)
-        output_shape, _, _ = tile_infer.infer()
-        if tile_infer.broadcast_compatible:
-            result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape})
-        else:
-            result = graph_builder.emit('Tile', [input_x], attrs={'multiples': multiples})
+        output_shape = self._get_output_shape()
+
+        result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape})
         return result
diff --git a/mindspore/_extends/graph_kernel/model/op_infer.py b/mindspore/_extends/graph_kernel/model/op_infer.py
index 0009bd6d12..ae48de3ce8 100644
--- a/mindspore/_extends/graph_kernel/model/op_infer.py
+++ b/mindspore/_extends/graph_kernel/model/op_infer.py
@@ -198,51 +198,6 @@ class BroadcastTo(OpInfer):
         return self.inputs[0].data_format
 
 
-class Tile(OpInfer):
-    """Op Tile"""
-
-    def __init__(self, op_name, inputs, attrs):
-        super().__init__(op_name, inputs, attrs)
-        self.input_reshape = None
-        self.output_reshape = None
-        self.broadcast_compatible = True
-
-    def _infer_shape(self):
-        shape = self.inputs[0].shape
-        multiples = self.attrs["multiples"]
-
-        shape = list(shape)
-        multiples = list(multiples)
-        diff_len = len(multiples) - len(shape)
-        if diff_len < 0:
-            raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
-        if diff_len > 0:
-            for _ in range(diff_len):
-                shape.insert(0, 1)
-
-        self.broadcast_compatible = True
-        output_shape = []
-        self.input_reshape = []
-        self.output_reshape = []
-        for sh, mul in list(zip(shape, multiples)):
-            dim = sh * mul
-            output_shape.append(dim)
-            if sh == 1 or mul == 1:
-                self.input_reshape.append(sh)
-                self.output_reshape.append(dim)
-            else:
-                self.broadcast_compatible = False
-                self.input_reshape.append(1)
-                self.input_reshape.append(sh)
-                self.output_reshape.append(mul)
-                self.output_reshape.append(sh)
-
-        return output_shape
-
-    def _infer_format(self):
-        return DF.DEFAULT
-
-
 class _CompareOp(_Elemwise):
     """Compare operators"""
 
diff --git a/mindspore/_extends/graph_kernel/splitter.py b/mindspore/_extends/graph_kernel/splitter.py
index b2d2253cc7..c622159ac1 100644
--- a/mindspore/_extends/graph_kernel/splitter.py
+++ b/mindspore/_extends/graph_kernel/splitter.py
@@ -61,7 +61,7 @@ def _dump_split_info(flags, graph_json, graph_desc, subgraphs, graph_mode):
         f.write("********** main graph: {} **********\n".format(graph_desc.name))
         f.write("input json:\n{}\n".format(graph_json))
         f.write("graph desc:\n{}\n".format(str(graph_desc)))
-        if len(subgraphs) > 1 or subgraphs[0].stitch_info is not None:
+        if len(subgraphs) > 1:
             for i, g in enumerate(subgraphs):
                 f.write("-------- subgraph {}, mode: {} --------\n".format(i, graph_mode[i]))
                 f.write("{}\n".format(str(g)))
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
index eeb530a72f..45029f3112 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
@@ -50,8 +50,8 @@ std::vector<PrimitivePtr> GetExpandOps() {
     prim::kPrimLayerNorm,
     prim::kPrimLayerNormGrad,
     prim::kPrimExpandDims,
-#if ENABLE_D
     prim::kPrimTile,
+#if ENABLE_D
     prim::kPrimSqrtGrad,
     prim::kPrimClipByNormNoDivSum,
     prim::kLambApplyOptimizerAssign,
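
Reviewer note (not part of the patch): after this change the expander accepts only
Tile calls that are expressible as a broadcast, i.e. for every aligned dimension
either the input size or the multiple is 1; anything else raises
GraphKernelUnsupportedException, so the op is left unexpanded and the original Tile
kernel is used. A minimal standalone sketch of that rule, with hypothetical names,
assuming plain Python lists:

    def tile_as_broadcast_shape(shape, multiples):
        """Return the BroadcastTo target shape, or None when a real Tile (copy) is required."""
        if len(multiples) < len(shape):
            raise ValueError("multiples has fewer dimensions than the input")
        # Left-pad the input shape with 1s, mirroring the expander's alignment step.
        shape = [1] * (len(multiples) - len(shape)) + list(shape)
        out = []
        for sh, mul in zip(shape, multiples):
            if sh != 1 and mul != 1:
                return None  # tiling a non-unit dimension copies data; a broadcast cannot express it
            out.append(sh * mul)
        return out

    assert tile_as_broadcast_shape([2, 3], [2, 1, 1]) == [2, 2, 3]
    assert tile_as_broadcast_shape([2, 3], [1, 2, 2]) is None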