diff --git a/mindspore/_extends/graph_kernel/expanders/__init__.py b/mindspore/_extends/graph_kernel/expanders/__init__.py
index 870e55816e..6d84481ee6 100644
--- a/mindspore/_extends/graph_kernel/expanders/__init__.py
+++ b/mindspore/_extends/graph_kernel/expanders/__init__.py
@@ -32,3 +32,5 @@ from .layernorm_grad import expand_layernormgrad
 from .logsoftmax import expand_logsoftmax
 from .logsoftmax_grad import expand_logsoftmaxgrad
 from .gkdropout import expand_gkdropout
+from .tile import expand_tile
+from .sqrt_grad import expand_sqrtgrad
diff --git a/mindspore/_extends/graph_kernel/expanders/sqrt_grad.py b/mindspore/_extends/graph_kernel/expanders/sqrt_grad.py
new file mode 100644
index 0000000000..e9dacaaed4
--- /dev/null
+++ b/mindspore/_extends/graph_kernel/expanders/sqrt_grad.py
@@ -0,0 +1,45 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for sqrtgrad"""
+from mindspore._extends.graph_kernel.model import model_builder as builder
+
+
+def expand_sqrtgrad(expand_info):
+    """SqrtGrad expander"""
+    # calculation formula is:
+    # sqrt_grad(x, dout) = dout / (2 * x)
+
+    # get op info.
+    input_desc_0 = expand_info['input_desc'][0]
+    input_desc_1 = expand_info['input_desc'][1]
+    graph_builder = builder.GraphBuilder()
+
+    # generate a graph.
+    with graph_builder.graph_scope('main') as graph_scope:
+        # create tensor input.
+        input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
+        input_dout = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
+        graph_scope.set_input(input_x, input_dout)
+
+        # calculate the result
+        const_two = graph_builder.value(input_x.dtype, 2, input_x.data_format)
+        divisor = graph_builder.emit('Mul', [input_x, const_two])
+        result = graph_builder.emit('RealDiv', [input_dout, divisor])
+
+        # set graph output.
+        graph_scope.set_output(result)
+
+    graph = graph_builder.get()[0]
+    return graph
diff --git a/mindspore/_extends/graph_kernel/expanders/tile.py b/mindspore/_extends/graph_kernel/expanders/tile.py
new file mode 100644
index 0000000000..4a2638dc07
--- /dev/null
+++ b/mindspore/_extends/graph_kernel/expanders/tile.py
@@ -0,0 +1,87 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for Tile"""
+from mindspore._extends.graph_kernel.model import model_builder as builder
+
+
+def _get_tile_output_shape(shape, multiples):
+    """compute output shape of tile"""
+
+    if multiples is None:
+        return list(shape), list(shape), list(shape), True
+    if not isinstance(shape, (list, tuple)):
+        raise TypeError("Input shape of Tile must be of type list or tuple")
+    if not isinstance(multiples, (list, tuple)):
+        raise TypeError("multiples of Tile must be of type list or tuple")
+
+    shape = list(shape)
+    multiples = list(multiples)
+    diff_len = len(multiples) - len(shape)
+    if diff_len < 0:
+        raise ValueError("Dimensions of multiples {} < dimensions of input {} in Tile".format(multiples, shape))
+    if diff_len > 0:
+        for _ in range(diff_len):
+            shape.insert(0, 1)
+
+    shape_compatible = True
+    output_shape = []
+    input_reshape = []
+    output_reshape = []
+    for sh, mul in zip(shape, multiples):
+        dim = sh * mul
+        output_shape.append(dim)
+        if sh == 1 or mul == 1:
+            input_reshape.append(sh)
+            output_reshape.append(dim)
+        else:
+            shape_compatible = False
+            input_reshape.append(1)
+            input_reshape.append(sh)
+            output_reshape.append(mul)
+            output_reshape.append(sh)
+
+    return output_shape, input_reshape, output_reshape, shape_compatible
+
+
+def expand_tile(expand_info):
+    """Tile expander"""
+
+    # get op info.
+    input_desc = expand_info['input_desc'][0]
+    attrs = expand_info['attr']
+    multiples = None
+    for item in attrs:
+        if 'multiples' in item:
+            multiples = item['multiples']
+    output_shape, input_reshape, output_reshape, shape_compatible = _get_tile_output_shape(input_desc['shape'],
+                                                                                           multiples)
+    graph_builder = builder.GraphBuilder()
+
+    # generate a graph.
+    with graph_builder.graph_scope('main') as graph_scope:
+        # create tensor input.
+        input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
+        # create op.
+        if shape_compatible:
+            result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape})
+        else:
+            input_x_reshape = graph_builder.emit('Reshape', [input_x], attrs={'shape': input_reshape})
+            reshape_broadcast = graph_builder.emit('BroadcastTo', [input_x_reshape], attrs={'shape': output_reshape})
+            result = graph_builder.emit('Reshape', [reshape_broadcast], attrs={'shape': output_shape})
+        # set graph output.
+        graph_scope.set_output(result)
+
+    graph = graph_builder.get()[0]
+    return graph
diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py
index b92f5f930f..9120230726 100644
--- a/mindspore/_extends/graph_kernel/model/graph_split.py
+++ b/mindspore/_extends/graph_kernel/model/graph_split.py
@@ -31,7 +31,8 @@ class GraphSplitByPattern:
         self.in_relations = dict()  # {area1: relation1, area2: relation2, ...}
         self.out_relations = dict()  # {area1: relation1, area2: relation2, ...}
         self.mode = self.MODE_BASIC
-        if self.pattern == PrimLib.TRANSFORM or (use_poly_reduce and self.pattern == PrimLib.REDUCE):
+        if self.pattern == PrimLib.TRANSFORM or self.pattern == PrimLib.BROADCAST or \
+                (use_poly_reduce and self.pattern == PrimLib.REDUCE):
             self.mode = self.MODE_COMPOSITE
         self.is_output = is_output
         self.output_excluded = set()
diff --git a/mindspore/_extends/graph_kernel/model/model.py b/mindspore/_extends/graph_kernel/model/model.py
index f3d2e1c27f..81822a2b12 100644
--- a/mindspore/_extends/graph_kernel/model/model.py
+++ b/mindspore/_extends/graph_kernel/model/model.py
@@ -170,6 +170,7 @@ class PrimLib:
         'FlattenGrad': Prim(RESHAPE),
         'Transpose': Prim(TRANSFORM),
         'Tile': Prim(BROADCAST),
+        'BroadcastTo': Prim(BROADCAST),
     }
 
     default_primtive = Prim(UNKNOWN)
diff --git a/mindspore/_extends/graph_kernel/model/model_builder.py b/mindspore/_extends/graph_kernel/model/model_builder.py
index 21722e9439..54c190015d 100644
--- a/mindspore/_extends/graph_kernel/model/model_builder.py
+++ b/mindspore/_extends/graph_kernel/model/model_builder.py
@@ -73,6 +73,7 @@ class OpInfer:
         # add special infer func here
         'InplaceAssign': lambda inputs, attrs: inputs[2].shape,
         'Reshape': lambda inputs, attrs: attrs["shape"],
+        'BroadcastTo': lambda inputs, attrs: attrs["shape"],
     }
     infer_dtype_func = {
         # add special infer func here
diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc
index 10a8179c76..ba7bb02ba3 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc
@@ -248,6 +248,7 @@ class CNodeDecoder {
     {kReduceSumOpName, std::vector<std::string>{kAttrKeepDims}},
     {kReduceMaxOpName, std::vector<std::string>{kAttrKeepDims}},
     {kReduceMinOpName, std::vector<std::string>{kAttrKeepDims}},
+    {kBroadcastToOpName, std::vector<std::string>{kAttrShape}},
   };
 
   PrimitivePtr GetPrimitive(const std::string &op_name) {
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
index 7bd937ec63..b37dea0e35 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
@@ -701,11 +701,14 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs) {
 std::unordered_set<PrimitivePtr> GetExpandOps() {
   std::unordered_set<PrimitivePtr> expand_ops = {
     prim::kPrimSquare,
-#if ENABLE_GPU
+    prim::kPrimGeluGrad,
+#if ENABLE_D
+    prim::kPrimTile,
+    prim::kPrimSqrtGrad,
+#elif ENABLE_GPU
     prim::kPrimBiasAdd,
     prim::kPrimBiasAddGrad,
     prim::kPrimGelu,
-    prim::kPrimGeluGrad,
     prim::kPrimFusedAdam,
     prim::kPrimFusedAdamWeightDecay,
     prim::kPrimTanhGrad,
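
For reference, the Tile expansion in tile.py rewrites Tile as a single BroadcastTo when every axis has sh == 1 or mul == 1, and as Reshape -> BroadcastTo -> Reshape otherwise. Below is a minimal standalone Python sketch of that shape arithmetic, not part of the patch; tile_shapes is a hypothetical name used only for illustration, and input validation is omitted for brevity.

# Standalone sketch of the shape arithmetic in _get_tile_output_shape.
# `tile_shapes` is a hypothetical helper name, not MindSpore code.

def tile_shapes(shape, multiples):
    """Return (output_shape, input_reshape, output_reshape, compatible)."""
    # left-pad the input shape with 1s so both lists have the same rank
    shape = [1] * (len(multiples) - len(shape)) + list(shape)
    output_shape, input_reshape, output_reshape = [], [], []
    compatible = True
    for sh, mul in zip(shape, multiples):
        output_shape.append(sh * mul)
        if sh == 1 or mul == 1:
            # plain broadcasting already reproduces tiling on this axis
            input_reshape.append(sh)
            output_reshape.append(sh * mul)
        else:
            # split the axis: reshape (sh,) to (1, sh), broadcast to (mul, sh);
            # the final reshape flattens (mul, sh) back into sh * mul
            compatible = False
            input_reshape += [1, sh]
            output_reshape += [mul, sh]
    return output_shape, input_reshape, output_reshape, compatible

# Tiling a (2, 3) tensor by (2, 2) is not broadcast-compatible, so expand_tile
# emits Reshape([1, 2, 1, 3]) -> BroadcastTo([2, 2, 2, 3]) -> Reshape([4, 6]):
assert tile_shapes([2, 3], [2, 2]) == ([4, 6], [1, 2, 1, 3], [2, 2, 2, 3], False)

# Tiling a (1, 3) tensor by (2, 1) is broadcast-compatible, so a single
# BroadcastTo([2, 3]) suffices:
assert tile_shapes([1, 3], [2, 1]) == ([2, 3], [1, 3], [2, 3], True)

The SqrtGrad expansion is simpler: since d/dx sqrt(x) = 1 / (2 * sqrt(x)), the gradient is dout / (2 * x) with x the op's first input, which is exactly the Mul and RealDiv pair emitted in sqrt_grad.py.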