| @@ -13,7 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # =========================================================================== | # =========================================================================== | ||||
| """generate json desc for Tile""" | """generate json desc for Tile""" | ||||
| from mindspore._extends.graph_kernel.model.op_infer import Tile as TileInfer | |||||
| from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException | |||||
| from mindspore._extends.graph_kernel.model.model import DataFormat as DF | from mindspore._extends.graph_kernel.model.model import DataFormat as DF | ||||
| from ._utils import Expander, ExpanderInfoValidator as VLD | from ._utils import Expander, ExpanderInfoValidator as VLD | ||||
| @@ -23,14 +23,32 @@ from ._utils import Expander, ExpanderInfoValidator as VLD | |||||
| class Tile(Expander): | class Tile(Expander): | ||||
| """Tile expander""" | """Tile expander""" | ||||
| def _get_output_shape(self): | |||||
| """Get output shape""" | |||||
| shape = self.inputs[0].shape | |||||
| multiples = self.attrs["multiples"] | |||||
| shape = list(shape) | |||||
| multiples = list(multiples) | |||||
| diff_len = len(multiples) - len(shape) | |||||
| if diff_len < 0: | |||||
| raise GKException("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape)) | |||||
| if diff_len > 0: | |||||
| for _ in range(diff_len): | |||||
| shape.insert(0, 1) | |||||
| output_shape = [] | |||||
| for sh, mul in list(zip(shape, multiples)): | |||||
| if sh != 1 and mul != 1: | |||||
| raise GKException("Tile op in expander only Support Automatic Broadcast!") | |||||
| dim = sh * mul | |||||
| output_shape.append(dim) | |||||
| return output_shape | |||||
| def _expand(self, graph_builder): | def _expand(self, graph_builder): | ||||
| input_x = self.inputs[0] | input_x = self.inputs[0] | ||||
| multiples = self.attrs['multiples'] | |||||
| tile_infer = TileInfer(self.name, self.inputs, self.attrs) | |||||
| output_shape, _, _ = tile_infer.infer() | |||||
| if tile_infer.broadcast_compatible: | |||||
| result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape}) | |||||
| else: | |||||
| result = graph_builder.emit('Tile', [input_x], attrs={'multiples': multiples}) | |||||
| output_shape = self._get_output_shape() | |||||
| result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape}) | |||||
| return result | return result | ||||
| @@ -198,51 +198,6 @@ class BroadcastTo(OpInfer): | |||||
| return self.inputs[0].data_format | return self.inputs[0].data_format | ||||
| class Tile(OpInfer): | |||||
| """Op Tile""" | |||||
| def __init__(self, op_name, inputs, attrs): | |||||
| super().__init__(op_name, inputs, attrs) | |||||
| self.input_reshape = None | |||||
| self.output_reshape = None | |||||
| self.broadcast_compatible = True | |||||
| def _infer_shape(self): | |||||
| shape = self.inputs[0].shape | |||||
| multiples = self.attrs["multiples"] | |||||
| shape = list(shape) | |||||
| multiples = list(multiples) | |||||
| diff_len = len(multiples) - len(shape) | |||||
| if diff_len < 0: | |||||
| raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape)) | |||||
| if diff_len > 0: | |||||
| for _ in range(diff_len): | |||||
| shape.insert(0, 1) | |||||
| self.broadcast_compatible = True | |||||
| output_shape = [] | |||||
| self.input_reshape = [] | |||||
| self.output_reshape = [] | |||||
| for sh, mul in list(zip(shape, multiples)): | |||||
| dim = sh * mul | |||||
| output_shape.append(dim) | |||||
| if sh == 1 or mul == 1: | |||||
| self.input_reshape.append(sh) | |||||
| self.output_reshape.append(dim) | |||||
| else: | |||||
| self.broadcast_compatible = False | |||||
| self.input_reshape.append(1) | |||||
| self.input_reshape.append(sh) | |||||
| self.output_reshape.append(mul) | |||||
| self.output_reshape.append(sh) | |||||
| return output_shape | |||||
| def _infer_format(self): | |||||
| return DF.DEFAULT | |||||
| class _CompareOp(_Elemwise): | class _CompareOp(_Elemwise): | ||||
| """Compare operators""" | """Compare operators""" | ||||
| @@ -61,7 +61,7 @@ def _dump_split_info(flags, graph_json, graph_desc, subgraphs, graph_mode): | |||||
| f.write("********** main graph: {} **********\n".format(graph_desc.name)) | f.write("********** main graph: {} **********\n".format(graph_desc.name)) | ||||
| f.write("input json:\n{}\n".format(graph_json)) | f.write("input json:\n{}\n".format(graph_json)) | ||||
| f.write("graph desc:\n{}\n".format(str(graph_desc))) | f.write("graph desc:\n{}\n".format(str(graph_desc))) | ||||
| if len(subgraphs) > 1 or subgraphs[0].stitch_info is not None: | |||||
| if len(subgraphs) > 1: | |||||
| for i, g in enumerate(subgraphs): | for i, g in enumerate(subgraphs): | ||||
| f.write("-------- subgraph {}, mode: {} --------\n".format(i, graph_mode[i])) | f.write("-------- subgraph {}, mode: {} --------\n".format(i, graph_mode[i])) | ||||
| f.write("{}\n".format(str(g))) | f.write("{}\n".format(str(g))) | ||||
| @@ -50,8 +50,8 @@ std::vector<PrimitivePtr> GetExpandOps() { | |||||
| prim::kPrimLayerNorm, | prim::kPrimLayerNorm, | ||||
| prim::kPrimLayerNormGrad, | prim::kPrimLayerNormGrad, | ||||
| prim::kPrimExpandDims, | prim::kPrimExpandDims, | ||||
| #if ENABLE_D | |||||
| prim::kPrimTile, | prim::kPrimTile, | ||||
| #if ENABLE_D | |||||
| prim::kPrimSqrtGrad, | prim::kPrimSqrtGrad, | ||||
| prim::kPrimClipByNormNoDivSum, | prim::kPrimClipByNormNoDivSum, | ||||
| prim::kLambApplyOptimizerAssign, | prim::kLambApplyOptimizerAssign, | ||||