
add tile to expand list

add tile expander

add BroadcastTo in model

fix BroadcastTo op calling error and shape inference

rewrite tile expander

do not split broadcast_to

add SqrtGrad expander
tags/v1.1.0
looop5 · 5 years ago · commit 848be9b07c
8 changed files with 144 additions and 3 deletions
  1. +2  -0   mindspore/_extends/graph_kernel/expanders/__init__.py
  2. +45 -0   mindspore/_extends/graph_kernel/expanders/sqrt_grad.py
  3. +87 -0   mindspore/_extends/graph_kernel/expanders/tile.py
  4. +2  -1   mindspore/_extends/graph_kernel/model/graph_split.py
  5. +1  -0   mindspore/_extends/graph_kernel/model/model.py
  6. +1  -0   mindspore/_extends/graph_kernel/model/model_builder.py
  7. +1  -0   mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc
  8. +5  -2   mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc

+2 -0  mindspore/_extends/graph_kernel/expanders/__init__.py

@@ -32,3 +32,5 @@ from .layernorm_grad import expand_layernormgrad
from .logsoftmax import expand_logsoftmax
from .logsoftmax_grad import expand_logsoftmaxgrad
from .gkdropout import expand_gkdropout
+from .tile import expand_tile
+from .sqrt_grad import expand_sqrtgrad
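Note: these package-level imports are what make the new expanders discoverable; the graph-kernel pass appears to resolve an expander from the op name. A minimal sketch of that kind of name-based dispatch (resolve_expander and the lowercase naming rule are assumptions for illustration, not the exact MindSpore dispatcher):

# Hypothetical sketch: resolve an 'expand_<op>' function by op name.
# The helper name and lookup rule are assumptions, not MindSpore API.
from mindspore._extends.graph_kernel import expanders

def resolve_expander(op_name):
    """Map an op name such as 'Tile' or 'SqrtGrad' to its expand function."""
    func_name = 'expand_' + op_name.lower()  # 'SqrtGrad' -> 'expand_sqrtgrad'
    if not hasattr(expanders, func_name):
        raise RuntimeError('no expander registered for ' + op_name)
    return getattr(expanders, func_name)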

+45 -0  mindspore/_extends/graph_kernel/expanders/sqrt_grad.py

@@ -0,0 +1,45 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for sqrtgrad"""
from mindspore._extends.graph_kernel.model import model_builder as builder


def expand_sqrtgrad(expand_info):
    """SqrtGrad expander"""
    # computation formula:
    # sqrt_grad(x, dout) = dout / (2 * x)

    # get op info.
    input_desc_0 = expand_info['input_desc'][0]
    input_desc_1 = expand_info['input_desc'][1]
    graph_builder = builder.GraphBuilder()

    # generate a graph.
    with graph_builder.graph_scope('main') as graph_scope:
        # create tensor input.
        input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
        input_dout = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
        graph_scope.set_input(input_x, input_dout)

        # compute result: dout / (2 * x)
        const_two = graph_builder.value(input_x.dtype, 2, input_x.data_format)
        divisor = graph_builder.emit('Mul', [input_x, const_two])  # 2 * x is the divisor
        result = graph_builder.emit('RealDiv', [input_dout, divisor])

        # set graph output.
        graph_scope.set_output(result)

    graph = graph_builder.get()[0]
    return graph
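For intuition, the formula the expander emits (Mul then RealDiv) can be checked numerically. A minimal NumPy sketch, assuming x is the forward output sqrt(u); NumPy stands in for the emitted ops and is not MindSpore code:

import numpy as np

def sqrt_grad_ref(x, dout):
    # reference for the expanded graph: dout / (2 * x)
    return dout / (2.0 * x)

u = np.array([1.0, 4.0, 9.0])
x = np.sqrt(u)              # forward output of Sqrt
dout = np.ones_like(u)      # incoming gradient
# d(sqrt(u))/du = 1 / (2 * sqrt(u)), so the check below holds:
assert np.allclose(sqrt_grad_ref(x, dout), 0.5 / np.sqrt(u))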

+87 -0  mindspore/_extends/graph_kernel/expanders/tile.py

@@ -0,0 +1,87 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for Tile"""
from mindspore._extends.graph_kernel.model import model_builder as builder


def _get_tile_output_shape(shape, multiples):
    """compute output shape of Tile"""
    if multiples is None:
        # no multiples: Tile is an identity mapping (keep the 4-tuple shape
        # of the return value so the caller's unpacking still works)
        return list(shape), list(shape), list(shape), True
    if not isinstance(shape, (list, tuple)):
        raise TypeError("Input shape of Tile must be of type list or tuple")
    if not isinstance(multiples, (list, tuple)):
        raise TypeError("multiples of Tile must be of type list or tuple")

    shape = list(shape)
    multiples = list(multiples)
    diff_len = len(multiples) - len(shape)
    if diff_len < 0:
        raise ValueError("Dimensions of multiples {} are fewer than dimensions of input {} in Tile"
                         .format(multiples, shape))
    if diff_len > 0:
        for _ in range(diff_len):
            shape.insert(0, 1)

    shape_compatible = True
    output_shape = []
    input_reshape = []
    output_reshape = []
    for sh, mul in zip(shape, multiples):
        dim = sh * mul
        output_shape.append(dim)
        if sh == 1 or mul == 1:
            # this axis can be broadcast directly
            input_reshape.append(sh)
            output_reshape.append(dim)
        else:
            # split the axis into (1, sh) and broadcast the leading 1 to mul
            shape_compatible = False
            input_reshape.append(1)
            input_reshape.append(sh)
            output_reshape.append(mul)
            output_reshape.append(sh)

    return output_shape, input_reshape, output_reshape, shape_compatible


def expand_tile(expand_info):
    """Tile expander"""
    # get op info.
    input_desc = expand_info['input_desc'][0]
    attrs = expand_info['attr']
    multiples = None
    for item in attrs:
        if 'multiples' in item:
            multiples = item['multiples']
    output_shape, input_reshape, output_reshape, shape_compatible = \
        _get_tile_output_shape(input_desc['shape'], multiples)
    graph_builder = builder.GraphBuilder()

    # generate a graph.
    with graph_builder.graph_scope('main') as graph_scope:
        # create tensor input.
        input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
        # create op.
        if shape_compatible:
            # every axis has sh == 1 or mul == 1: a single BroadcastTo suffices
            result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape})
        else:
            # otherwise expand as Reshape -> BroadcastTo -> Reshape
            input_x_reshape = graph_builder.emit('Reshape', [input_x], attrs={'shape': input_reshape})
            reshape_broadcast = graph_builder.emit('BroadcastTo', [input_x_reshape], attrs={'shape': output_reshape})
            result = graph_builder.emit('Reshape', [reshape_broadcast], attrs={'shape': output_shape})
        # set graph output.
        graph_scope.set_output(result)

    graph = graph_builder.get()[0]
    return graph
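The Reshape/BroadcastTo/Reshape decomposition can be sanity-checked against numpy.tile. A short sketch of the incompatible case, shape (2, 3) with multiples (2, 2), using NumPy in place of the emitted ops:

import numpy as np

x = np.arange(6).reshape(2, 3)
# _get_tile_output_shape((2, 3), (2, 2)) yields:
#   output_shape = [4, 6], input_reshape = [1, 2, 1, 3],
#   output_reshape = [2, 2, 2, 3], shape_compatible = False
expanded = np.broadcast_to(x.reshape(1, 2, 1, 3), (2, 2, 2, 3)).reshape(4, 6)
assert (expanded == np.tile(x, (2, 2))).all()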

+2 -1  mindspore/_extends/graph_kernel/model/graph_split.py

@@ -31,7 +31,8 @@ class GraphSplitByPattern:
            self.in_relations = dict()  # {area1: relation1, area2: relation2, ...}
            self.out_relations = dict()  # {area1: relation1, area2: relation2, ...}
            self.mode = self.MODE_BASIC
-           if self.pattern == PrimLib.TRANSFORM or (use_poly_reduce and self.pattern == PrimLib.REDUCE):
+           if self.pattern == PrimLib.TRANSFORM or self.pattern == PrimLib.BROADCAST or \
+                   (use_poly_reduce and self.pattern == PrimLib.REDUCE):
                self.mode = self.MODE_COMPOSITE
            self.is_output = is_output
            self.output_excluded = set()


+1 -0  mindspore/_extends/graph_kernel/model/model.py

@@ -170,6 +170,7 @@ class PrimLib:
        'FlattenGrad': Prim(RESHAPE),
        'Transpose': Prim(TRANSFORM),
        'Tile': Prim(BROADCAST),
+       'BroadcastTo': Prim(BROADCAST),
    }

    default_primtive = Prim(UNKNOWN)


+1 -0  mindspore/_extends/graph_kernel/model/model_builder.py

@@ -73,6 +73,7 @@ class OpInfer:
        # add special infer func here
        'InplaceAssign': lambda inputs, attrs: inputs[2].shape,
        'Reshape': lambda inputs, attrs: attrs["shape"],
+       'BroadcastTo': lambda inputs, attrs: attrs["shape"],
    }
    infer_dtype_func = {
        # add special infer func here


+1 -0  mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc

@@ -248,6 +248,7 @@ class CNodeDecoder {
    {kReduceSumOpName, std::vector<std::string>{kAttrKeepDims}},
    {kReduceMaxOpName, std::vector<std::string>{kAttrKeepDims}},
    {kReduceMinOpName, std::vector<std::string>{kAttrKeepDims}},
+   {kBroadcastToOpName, std::vector<std::string>{kAttrShape}},
  };

  PrimitivePtr GetPrimitive(const std::string &op_name) {


+5 -2  mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc

@@ -701,11 +701,14 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo
std::unordered_set<PrimitivePtr> GetExpandOps() {
  std::unordered_set<PrimitivePtr> expand_ops = {
    prim::kPrimSquare,
-#if ENABLE_GPU
+   prim::kPrimGeluGrad,
+#if ENABLE_D
+   prim::kPrimTile,
+   prim::kPrimSqrtGrad,
+#elif ENABLE_GPU
    prim::kPrimBiasAdd,
    prim::kPrimBiasAddGrad,
    prim::kPrimGelu,
-   prim::kPrimGeluGrad,
    prim::kPrimFusedAdam,
    prim::kPrimFusedAdamWeightDecay,
    prim::kPrimTanhGrad,

