add tile expander; add BroadcastTo in model; fix BroadcastTo op calling error and infer shape; rewrite tile expander to not split broadcast_to; add SqrtGrad expander (tags/v1.1.0)
| @@ -32,3 +32,5 @@ from .layernorm_grad import expand_layernormgrad | |||||
| from .logsoftmax import expand_logsoftmax | from .logsoftmax import expand_logsoftmax | ||||
| from .logsoftmax_grad import expand_logsoftmaxgrad | from .logsoftmax_grad import expand_logsoftmaxgrad | ||||
| from .gkdropout import expand_gkdropout | from .gkdropout import expand_gkdropout | ||||
| from .tile import expand_tile | |||||
| from .sqrt_grad import expand_sqrtgrad | |||||
| @@ -0,0 +1,45 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # =========================================================================== | |||||
| """generate json desc for sqrtgrad""" | |||||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||||
def expand_sqrtgrad(expand_info):
    """Expand SqrtGrad into basic ops.

    Formula: sqrt_grad(x, dout) = dout / (2 * x)

    Args:
        expand_info (dict): op description; 'input_desc' holds two tensor
            descriptors (x, dout), each with 'shape', 'data_type', 'format'.

    Returns:
        The expanded graph produced by the graph builder.
    """
    x_desc, dout_desc = expand_info['input_desc'][0], expand_info['input_desc'][1]
    gb = builder.GraphBuilder()
    with gb.graph_scope('main') as scope:
        # Declare the two graph inputs from their descriptors.
        x = gb.tensor(x_desc['shape'], x_desc['data_type'], x_desc['format'])
        dout = gb.tensor(dout_desc['shape'], dout_desc['data_type'], dout_desc['format'])
        scope.set_input(x, dout)
        # dout / (2 * x)
        two = gb.value(x.dtype, 2, x.data_format)
        doubled = gb.emit('Mul', [x, two])
        result = gb.emit('RealDiv', [dout, doubled])
        scope.set_output(result)
    return gb.get()[0]
| @@ -0,0 +1,87 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # =========================================================================== | |||||
| """generate json desc for Tile""" | |||||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||||
| def _get_tile_output_shape(shape, multiples): | |||||
| """compute output shape of tile""" | |||||
| if multiples is None: | |||||
| return shape | |||||
| if not isinstance(shape, (list, tuple)): | |||||
| raise TypeError("Input shape of Tile must be of type list or tuple") | |||||
| if not isinstance(multiples, (list, tuple)): | |||||
| raise TypeError("multiples of Tile must be of type list or tuple") | |||||
| shape = list(shape) | |||||
| multiples = list(multiples) | |||||
| diff_len = len(multiples) - len(shape) | |||||
| if diff_len < 0: | |||||
| raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape)) | |||||
| if diff_len > 0: | |||||
| for _ in range(diff_len): | |||||
| shape.insert(0, 1) | |||||
| shape_compatible = True | |||||
| output_shape = [] | |||||
| input_reshape = [] | |||||
| output_reshape = [] | |||||
| for sh, mul in list(zip(shape, multiples)): | |||||
| dim = sh * mul | |||||
| output_shape.append(dim) | |||||
| if sh == 1 or mul == 1: | |||||
| input_reshape.append(sh) | |||||
| output_reshape.append(dim) | |||||
| else: | |||||
| shape_compatible = False | |||||
| input_reshape.append(1) | |||||
| input_reshape.append(sh) | |||||
| output_reshape.append(mul) | |||||
| output_reshape.append(sh) | |||||
| return output_shape, input_reshape, output_reshape, shape_compatible | |||||
def expand_tile(expand_info):
    """Expand Tile into BroadcastTo (plus Reshape pairs when needed).

    Args:
        expand_info (dict): op description; 'input_desc' holds the input
            tensor descriptor and 'attr' is a list of attribute dicts, one of
            which carries 'multiples'.

    Returns:
        The expanded graph produced by the graph builder.
    """
    input_desc = expand_info['input_desc'][0]
    # Pull 'multiples' out of the attribute list; stays None when absent.
    multiples = None
    for attr_item in expand_info['attr']:
        if 'multiples' in attr_item:
            multiples = attr_item['multiples']
    shapes = _get_tile_output_shape(input_desc['shape'], multiples)
    output_shape, input_reshape, output_reshape, compatible = shapes
    gb = builder.GraphBuilder()
    with gb.graph_scope('main') as scope:
        input_x = gb.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
        if compatible:
            # Every tiled dim is broadcast-compatible: one BroadcastTo is enough.
            result = gb.emit('BroadcastTo', [input_x], attrs={'shape': output_shape})
        else:
            # Split incompatible dims, broadcast, then fold back to the
            # final tiled shape.
            reshaped = gb.emit('Reshape', [input_x], attrs={'shape': input_reshape})
            broadcast = gb.emit('BroadcastTo', [reshaped], attrs={'shape': output_reshape})
            result = gb.emit('Reshape', [broadcast], attrs={'shape': output_shape})
        scope.set_output(result)
    return gb.get()[0]
| @@ -31,7 +31,8 @@ class GraphSplitByPattern: | |||||
| self.in_relations = dict() # {area1: relation1, area2: relation2, ...} | self.in_relations = dict() # {area1: relation1, area2: relation2, ...} | ||||
| self.out_relations = dict() # {area1: relation1, area2: relation2, ...} | self.out_relations = dict() # {area1: relation1, area2: relation2, ...} | ||||
| self.mode = self.MODE_BASIC | self.mode = self.MODE_BASIC | ||||
| if self.pattern == PrimLib.TRANSFORM or (use_poly_reduce and self.pattern == PrimLib.REDUCE): | |||||
| if self.pattern == PrimLib.TRANSFORM or self.pattern == PrimLib.BROADCAST or \ | |||||
| (use_poly_reduce and self.pattern == PrimLib.REDUCE): | |||||
| self.mode = self.MODE_COMPOSITE | self.mode = self.MODE_COMPOSITE | ||||
| self.is_output = is_output | self.is_output = is_output | ||||
| self.output_excluded = set() | self.output_excluded = set() | ||||
| @@ -170,6 +170,7 @@ class PrimLib: | |||||
| 'FlattenGrad': Prim(RESHAPE), | 'FlattenGrad': Prim(RESHAPE), | ||||
| 'Transpose': Prim(TRANSFORM), | 'Transpose': Prim(TRANSFORM), | ||||
| 'Tile': Prim(BROADCAST), | 'Tile': Prim(BROADCAST), | ||||
| 'BroadcastTo': Prim(BROADCAST), | |||||
| } | } | ||||
| default_primtive = Prim(UNKNOWN) | default_primtive = Prim(UNKNOWN) | ||||
| @@ -73,6 +73,7 @@ class OpInfer: | |||||
| # add special infer func here | # add special infer func here | ||||
| 'InplaceAssign': lambda inputs, attrs: inputs[2].shape, | 'InplaceAssign': lambda inputs, attrs: inputs[2].shape, | ||||
| 'Reshape': lambda inputs, attrs: attrs["shape"], | 'Reshape': lambda inputs, attrs: attrs["shape"], | ||||
| 'BroadcastTo': lambda inputs, attrs: attrs["shape"], | |||||
| } | } | ||||
| infer_dtype_func = { | infer_dtype_func = { | ||||
| # add special infer func here | # add special infer func here | ||||
| @@ -248,6 +248,7 @@ class CNodeDecoder { | |||||
| {kReduceSumOpName, std::vector<std::string>{kAttrKeepDims}}, | {kReduceSumOpName, std::vector<std::string>{kAttrKeepDims}}, | ||||
| {kReduceMaxOpName, std::vector<std::string>{kAttrKeepDims}}, | {kReduceMaxOpName, std::vector<std::string>{kAttrKeepDims}}, | ||||
| {kReduceMinOpName, std::vector<std::string>{kAttrKeepDims}}, | {kReduceMinOpName, std::vector<std::string>{kAttrKeepDims}}, | ||||
| {kBroadcastToOpName, std::vector<std::string>{kAttrShape}}, | |||||
| }; | }; | ||||
| PrimitivePtr GetPrimitive(const std::string &op_name) { | PrimitivePtr GetPrimitive(const std::string &op_name) { | ||||
| @@ -701,11 +701,14 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo | |||||
| std::unordered_set<PrimitivePtr> GetExpandOps() { | std::unordered_set<PrimitivePtr> GetExpandOps() { | ||||
| std::unordered_set<PrimitivePtr> expand_ops = { | std::unordered_set<PrimitivePtr> expand_ops = { | ||||
| prim::kPrimSquare, | prim::kPrimSquare, | ||||
| #if ENABLE_GPU | |||||
| prim::kPrimGeluGrad, | |||||
| #if ENABLE_D | |||||
| prim::kPrimTile, | |||||
| prim::kPrimSqrtGrad, | |||||
| #elif ENABLE_GPU | |||||
| prim::kPrimBiasAdd, | prim::kPrimBiasAdd, | ||||
| prim::kPrimBiasAddGrad, | prim::kPrimBiasAddGrad, | ||||
| prim::kPrimGelu, | prim::kPrimGelu, | ||||
| prim::kPrimGeluGrad, | |||||
| prim::kPrimFusedAdam, | prim::kPrimFusedAdam, | ||||
| prim::kPrimFusedAdamWeightDecay, | prim::kPrimFusedAdamWeightDecay, | ||||
| prim::kPrimTanhGrad, | prim::kPrimTanhGrad, | ||||