Merge pull request !8023 from ZengZitao/expand_bias_add
@@ -18,3 +18,5 @@ from .gelu import expand_gelu
 from .layernorm import expand_layernorm
 from .softmax import expand_softmax
 from .square import expand_square
+from .bias_add import expand_biasadd
+from .bias_add_grad import expand_biasaddgrad
@@ -0,0 +1,62 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for bias_add"""
+from mindspore._extends.graph_kernel.model import model_builder as builder
+
+
+def expand_biasadd(expand_info):
+    """BiasAdd expander"""
+    # get op info.
+    input_desc_0 = expand_info['input_desc'][0]
+    input_desc_1 = expand_info['input_desc'][1]
+    graph_builder = builder.GraphBuilder()
+    # generate a graph.
+    with graph_builder.graph_scope('main') as graph_scope:
+        # create tensor input.
+        input_x = graph_builder.tensor(
+            input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
+        input_y = graph_builder.tensor(
+            input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
+        graph_scope.set_input(input_x, input_y)
+        if input_x.data_format == "NCHW":
+            input_y_expand = graph_builder.emit(
+                'ExpandDims', [input_y], attrs={'axis': 1})
+            input_y_expand = graph_builder.emit(
+                'ExpandDims', [input_y_expand], attrs={'axis': 2})
+            result = graph_builder.emit('TensorAdd', [input_x, input_y_expand])
+        elif input_x.data_format == "DefaultFormat":
+            if len(input_x.shape) == 2:
+                result = graph_builder.emit('TensorAdd', [input_x, input_y])
+            elif len(input_x.shape) == 3:
+                input_y_expand = graph_builder.emit(
+                    'ExpandDims', [input_y], attrs={'axis': 1})
+                result = graph_builder.emit(
+                    'TensorAdd', [input_x, input_y_expand])
+            else:
+                input_y_expand = graph_builder.emit(
+                    'ExpandDims', [input_y], attrs={'axis': 1})
+                input_y_expand = graph_builder.emit(
+                    'ExpandDims', [input_y_expand], attrs={'axis': 2})
+                result = graph_builder.emit(
+                    'TensorAdd', [input_x, input_y_expand])
+        else:
+            result = graph_builder.emit('TensorAdd', [input_x, input_y])
+        # set graph output.
+        graph_scope.set_output(result)
+
+    graph = graph_builder.get()[0]
+    return graph
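To see what this expansion computes, here is a minimal NumPy sketch (hypothetical shapes, not part of the PR) checking that the two `ExpandDims` emitted for NCHW reshape the bias so that plain broadcasting reproduces BiasAdd semantics:

```python
import numpy as np

# Hypothetical NCHW input and per-channel bias.
x = np.random.rand(2, 3, 4, 4).astype(np.float32)  # (N, C, H, W)
b = np.random.rand(3).astype(np.float32)           # (C,)

# The expander emits ExpandDims(axis=1) then ExpandDims(axis=2),
# turning (C,) into (C, 1, 1); broadcasting that against (N, C, H, W)
# adds the bias along the channel axis, which is what BiasAdd does.
y_expand = b.reshape(3, 1, 1)
reference = x + b[np.newaxis, :, np.newaxis, np.newaxis]
assert np.allclose(x + y_expand, reference)
```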
@@ -0,0 +1,48 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for bias_add_grad"""
+from mindspore._extends.graph_kernel.model import model_builder as builder
+
+
+def expand_biasaddgrad(expand_info):
+    """BiasAddGrad expander"""
+    # get op info.
+    input_desc_0 = expand_info['input_desc'][0]
+    graph_builder = builder.GraphBuilder()
+    # generate a graph.
+    with graph_builder.graph_scope('main') as graph_scope:
+        # create tensor input.
+        input_x = graph_builder.tensor(
+            input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
+        graph_scope.set_input(input_x)
+        reduce_axis = ()
+        if input_x.data_format == 'NHWC':
+            reduce_axis = (0, 1, 2)
+        elif input_x.data_format == 'NCHW':
+            reduce_axis = (0, 2, 3)
+        else:
+            # For DefaultFormat, the shape's length may be 2 to 4, so the
+            # reduce axis depends on the length of the shape.
+            if len(input_x.shape) == 2:
+                reduce_axis = (0,)
+            elif len(input_x.shape) == 3:
+                reduce_axis = (0, 1)
+            else:
+                reduce_axis = (0, 2, 3)
+        result = graph_builder.emit(
+            'ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
+        # set graph output.
+        graph_scope.set_output(result)
+
+    graph = graph_builder.get()[0]
+    return graph
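As a sanity check on the reduce axes, a small NumPy sketch (hypothetical shapes): BiasAddGrad is the sum of the output gradient over every axis except the channel axis, so NHWC reduces (0, 1, 2) and NCHW reduces (0, 2, 3):

```python
import numpy as np

# Hypothetical output gradients in both layouts.
dout_nhwc = np.random.rand(2, 4, 4, 3).astype(np.float32)  # (N, H, W, C)
dout_nchw = np.random.rand(2, 3, 4, 4).astype(np.float32)  # (N, C, H, W)

# Summing over all non-channel axes yields a bias gradient of shape (C,),
# matching ReduceSum with keep_dims=False in the expander above.
assert dout_nhwc.sum(axis=(0, 1, 2)).shape == (3,)
assert dout_nchw.sum(axis=(0, 2, 3)).shape == (3,)
```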
@@ -296,6 +296,7 @@ class Graph:
     def __init__(self, name, ops):
         self.name = name
         self.ops = ops  # in topo order, can not use set
+        self.inputs = []
         self.outputs = []

     def set_processor(self, processor):
@@ -341,6 +342,9 @@ class Graph:
                     if d not in self.ops:
                         outputs.append(op.output)
                         break
+        if self.inputs:
+            inputs = self.inputs
+
         if self.outputs:
             outputs = self.outputs
         return inputs, outputs
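The effect of these two hunks, as a standalone sketch (hypothetical stand-in, not the real `Graph` class): tensors registered via the new `self.inputs` list override whatever the deduction loop collects, mirroring the existing `self.outputs` behavior:

```python
# Hypothetical stand-in for the precedence rule added above: explicitly
# registered inputs/outputs win over the ones deduced from the op list.
def resolve_io(deduced_inputs, deduced_outputs, explicit_inputs, explicit_outputs):
    inputs = explicit_inputs if explicit_inputs else deduced_inputs
    outputs = explicit_outputs if explicit_outputs else deduced_outputs
    return inputs, outputs

# Nothing registered: deduction wins. Explicit registration: it wins,
# which is what set_input()/set_output() rely on.
assert resolve_io(['a'], ['b'], [], []) == (['a'], ['b'])
assert resolve_io(['a'], ['b'], ['x'], []) == (['x'], ['b'])
```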
@@ -22,14 +22,36 @@ class OpInfer:
     """Op infer"""
     @staticmethod
     def default_reduce_infer(inputs, attrs):
+        """Default reduce infer"""
         shape = copy.deepcopy(inputs[0].shape)
-        for i in attrs['reduce_axis']:
-            shape[i] = 1
+        if attrs['keep_dims']:
+            for i in attrs['reduce_axis']:
+                shape[i] = 1
+            return shape
+
+        real_shape = []
+        for i, _ in enumerate(shape):
+            if i not in attrs['reduce_axis']:
+                real_shape.append(shape[i])
+        return real_shape
+
+    @staticmethod
+    def default_elementwise_infer(inputs, attrs):
+        """Default elementwise infer"""
+        shape = (1,)
+        max_flatten_shape = 1
+        for t in inputs:
+            flatten_shape = 1
+            for s in t.shape:
+                flatten_shape *= s
+            if flatten_shape >= max_flatten_shape:
+                max_flatten_shape = flatten_shape
+                shape = t.shape
         return shape

     default_infer_shape_func = [
         None,
-        lambda inputs, attrs: max([t.shape for t in inputs]),
+        default_elementwise_infer.__func__,
         lambda inputs, attrs: max([t.shape for t in inputs]),
         default_reduce_infer.__func__,
         None,
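A standalone sketch of the two inference rules (plain lists stand in for the tensor objects; the function names are illustrative): the reduce rule keeps reduced axes as size 1 only under `keep_dims`, otherwise drops them, and the elementwise rule picks the shape of the input with the largest flattened element count:

```python
import copy
import math

def reduce_infer(shape, reduce_axis, keep_dims):
    """Mirror of default_reduce_infer on a plain shape list."""
    shape = copy.deepcopy(shape)
    if keep_dims:
        for i in reduce_axis:
            shape[i] = 1  # reduced axes collapse to 1 but stay in place
        return shape
    # Without keep_dims, the reduced axes are dropped entirely.
    return [s for i, s in enumerate(shape) if i not in reduce_axis]

def elementwise_infer(shapes):
    """Mirror of default_elementwise_infer: the input with the largest
    flattened element count determines the output shape."""
    return max(shapes, key=math.prod)

assert reduce_infer([2, 4, 4, 3], (0, 1, 2), True) == [1, 1, 1, 3]
assert reduce_infer([2, 4, 4, 3], (0, 1, 2), False) == [3]
assert elementwise_infer([[2, 3, 4], [3, 4]]) == [2, 3, 4]
```

The new reduce behavior matters for BiasAddGrad above: with `keep_dims=False`, `ReduceSum` over (0, 1, 2) of an NHWC gradient infers shape `(C,)` rather than `(1, 1, 1, C)`.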
@@ -72,9 +94,16 @@ class OpInfer:
 class GraphBuilder:
     """Graph builder"""

     class GraphWrapper:
+        """Graph wrapper"""
+
         def __init__(self, name):
             self.graph = Graph(name, [])

+        def set_input(self, *para):
+            for t in para:
+                t.para_type = Tensor.PARA_INPUT
+                self.graph.inputs.append(t)
+
         def set_output(self, *para):
             for t in para:
                 t.para_type = Tensor.PARA_OUTPUT
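Putting the pieces together, a hedged usage sketch of the builder API as exercised by the expanders above (the concrete shapes and the `'float32'` dtype string are assumptions for illustration, not a verified snippet):

```python
from mindspore._extends.graph_kernel.model import model_builder as builder

gb = builder.GraphBuilder()
with gb.graph_scope('main') as scope:
    # Tensor descriptors mirror the expanders; shapes/dtype are assumed.
    x = gb.tensor([2, 3, 4, 4], 'float32', 'NCHW')
    b = gb.tensor([3], 'float32', 'DefaultFormat')
    scope.set_input(x, b)  # new: marks PARA_INPUT and pins graph.inputs
    b_expand = gb.emit('ExpandDims', [b], attrs={'axis': 1})
    b_expand = gb.emit('ExpandDims', [b_expand], attrs={'axis': 2})
    out = gb.emit('TensorAdd', [x, b_expand])
    scope.set_output(out)
graph = gb.get()[0]  # the composed BiasAdd-equivalent subgraph
```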
@@ -702,6 +702,8 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo
 std::unordered_set<PrimitivePtr> GetExpandOps() {
   std::unordered_set<PrimitivePtr> expand_ops = {
     prim::kPrimSquare,
+    prim::kPrimBiasAdd,
+    prim::kPrimBiasAddGrad,
   };
   return expand_ops;
 }
@@ -143,6 +143,7 @@ inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter =
     std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter");
 inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput =
     std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput");
+inline const PrimitivePtr kPrimBiasAdd = std::make_shared<Primitive>("BiasAdd");
 inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad");
 inline const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits =
     std::make_shared<Primitive>("SoftmaxCrossEntropyWithLogits");