Merge pull request !588 from liuxiao/add-reluv2
@@ -33,6 +33,7 @@ static std::map<string, string> tbe_func_adapter_map = {
   {"re_lu6", "relu6"},
   {"re_lu6_grad", "relu6_grad"},
   {"re_lu", "relu"},
+  {"re_luv2", "relu_v2"},
   {"tensor_add", "add"},
   {"reduce_mean", "reduce_mean_d"},
   {"reduce_max", "reduce_max_d"},
@@ -227,6 +227,18 @@ def get_bprop_relu6(self):
     return bprop


+@bprop_getters.register(P.ReLUV2)
+def get_bprop_relu_v2(self):
+    """Grad definition for `ReLUV2` operation."""
+    input_grad = G.ReluGradV2()
+
+    def bprop(x, out, dout):
+        mask = out[1]
+        dx = input_grad(dout[0], mask)
+        return (dx,)
+
+    return bprop
+
+
 @bprop_getters.register(P.HSwish)
 def get_bprop_hswish(self):
     """Grad definition for `HSwish` operation."""
@@ -33,6 +33,7 @@ from .cast import _cast_tbe
 from .conv2d import _conv2d_tbe
 from .conv2d_backprop_filter import _conv2d_backprop_filter_tbe
 from .conv2d_backprop_input import _conv2d_backprop_input_tbe
+from .confusion_mul_grad import _confusion_mul_grad_tbe
 from .dropout_do_mask import _dropout_do_mask_tbe
 from .gelu import _gelu_tbe
 from .gelu_grad import _gelu_grad_tbe
@@ -46,6 +47,8 @@ from .relu import _relu_tbe
 from .relu_grad import _relu_grad_tbe
 from .relu6 import _relu6_tbe
 from .relu6_grad import _relu6_grad_tbe
+from .relu_v2 import _relu_v2_tbe
+from .relu_grad_v2 import _relu_grad_v2_tbe
 from .softmax_cross_entropy_with_logits import _softmax_cross_entropy_with_logits_tbe
 from .sigmoid_cross_entropy_with_logits import _sigmoid_cross_entropy_with_logits_tbe
 from .sigmoid_cross_entropy_with_logits_grad import _sigmoid_cross_entropy_with_logits_grad_tbe
@@ -0,0 +1,38 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""ConfusionMulGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+confusion_mul_grad_op_info = TBERegOp("ConfusionMulGrad") \
+    .fusion_type("OPAQUE") \
+    .attr("axis", "required", "listInt", "all") \
+    .attr("keep_dims", "required", "bool", "all") \
+    .input(0, "input0", False, "required", "all") \
+    .input(1, "input1", False, "required", "all") \
+    .input(2, "input2", False, "required", "all") \
+    .output(0, "output0", False, "required", "all") \
+    .output(1, "output1", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
+                  DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
+                  DataType.F32_Default, DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(confusion_mul_grad_op_info)
+def _confusion_mul_grad_tbe():
+    """ConfusionMulGrad TBE register"""
+    return
@@ -0,0 +1,40 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""ReluGradV2 op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+relu_grad_v2_op_info = TBERegOp("ReluGradV2") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("relu_grad_v2.so") \
+    .compute_cost(10) \
+    .kernel_name("relu_grad_v2") \
+    .partial_flag(True) \
+    .input(0, "gradients", False, "required", "all") \
+    .input(1, "mask", False, "required", "all") \
+    .output(0, "backprops", True, "required", "all") \
+    .dtype_format(DataType.F16_5HD, DataType.U8_Default, DataType.F16_5HD) \
+    .dtype_format(DataType.F32_5HD, DataType.U8_Default, DataType.F32_5HD) \
+    .dtype_format(DataType.I32_5HD, DataType.U8_Default, DataType.I32_5HD) \
+    .dtype_format(DataType.I8_5HD, DataType.U8_Default, DataType.I8_5HD) \
+    .dtype_format(DataType.U8_5HD, DataType.U8_Default, DataType.U8_5HD) \
+    .get_op_info()
+
+
+@op_info_register(relu_grad_v2_op_info)
+def _relu_grad_v2_tbe():
+    """ReluGradV2 TBE register"""
+    return
@@ -0,0 +1,40 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""ReluV2 op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+relu_v2_op_info = TBERegOp("ReLUV2") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("relu_v2.so") \
+    .compute_cost(10) \
+    .kernel_name("relu_v2") \
+    .partial_flag(True) \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .output(1, "mask", False, "required", "all") \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.U8_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.U8_Default) \
+    .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.U8_Default) \
+    .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.U8_Default) \
+    .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_Default) \
+    .get_op_info()
+
+
+@op_info_register(relu_v2_op_info)
+def _relu_v2_tbe():
+    """ReluV2 TBE register"""
+    return
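All three new files follow the same TBERegOp builder pattern: kernel metadata (`fusion_type`, `binfile_name`, `kernel_name`), the I/O signature, then one `dtype_format` row per supported type combination, listing the dtype/format of every input and output in declaration order. A condensed, hypothetical registration showing just the skeleton (op and file names below are placeholders, not a real kernel):

```python
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

my_op_info = TBERegOp("MyOp") \
    .fusion_type("ELEMWISE") \
    .binfile_name("my_op.so") \
    .kernel_name("my_op") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()

@op_info_register(my_op_info)
def _my_op_tbe():
    """MyOp TBE register"""
    return
```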
@@ -58,8 +58,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
                      GetNext, L2Normalize, LayerNorm, L2Loss,
                      LogSoftmax,
                      MaxPool, ExtractImagePatches,
-                     AvgPool, Conv2DBackpropInput,
-                     MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, HSwish, HSigmoid,
+                     AvgPool, Conv2DBackpropInput, ConfusionMulGrad,
+                     MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
                      ResizeBilinear, Sigmoid,
                      SigmoidCrossEntropyWithLogits,
                      SmoothL1Loss, Softmax,
@@ -101,6 +101,7 @@ __all__ = [
     'LogSoftmax',
     'SoftmaxCrossEntropyWithLogits',
     'ROIAlign',
+    'ConfusionMulGrad',
     'SparseSoftmaxCrossEntropyWithLogits',
     'SGD',
     'ApplyMomentum',
@@ -138,6 +139,7 @@ __all__ = [
     'Split',
     'ReLU',
     'ReLU6',
+    'ReLUV2',
     'Elu',
     'Erf',
     'Sigmoid',
@@ -730,6 +730,27 @@ class ReLU6Grad(PrimitiveWithInfer):
         return x_dtype


+class ReluGradV2(PrimitiveWithInfer):
+    """Performs grad of ReLUV2 operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['gradients', 'mask'], outputs=['output'])
+
+    def __call__(self, gradients, mask):
+        raise NotImplementedError
+
+    def infer_shape(self, gradients_shape, mask_shape):
+        return gradients_shape
+
+    def infer_dtype(self, gradients_dtype, mask_dtype):
+        args_type = {'gradients': gradients_dtype, 'mask': mask_dtype}
+        validator.check_args_tensor(args_type)
+        validator.check_typename("gradients_dtype", gradients_dtype, mstype.number_type)
+        validator.check_typename("mask_dtype", mask_dtype, (mstype.uint8,))
+        return gradients_dtype
+
+
 class EluGrad(PrimitiveWithInfer):
     """Performs grad of Elu operation."""
@@ -1329,7 +1329,7 @@ class Concat(PrimitiveWithInfer):
 def _get_pack_shape(x_shape, x_type, axis):
     """for pack output shape"""
-    validator.check_type("shape", x_shape, [tuple])
+    validator.check_type("shape", x_shape, [tuple, list])
     validator.check_integer("len of input_x shape", len(x_shape), 0, Rel.GT)
     validator.check_subclass("shape0", x_type[0], mstype.tensor)
     validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT)
@@ -28,6 +28,7 @@ from ..._checkparam import Validator as validator
 from ..._checkparam import Rel
 from ...common import dtype as mstype
 from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
+from ..operations.math_ops import _infer_shape_reduce


 def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, ret_four=False):
@@ -233,6 +234,62 @@ class ReLU6(PrimitiveWithInfer):
         return input_x


+class ReLUV2(PrimitiveWithInfer):
+    r"""
+    Computes ReLU (Rectified Linear Unit) of the input tensor element-wise.
+
+    It returns :math:`\max(x,\ 0)` element-wise.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor should be a 4-D tensor.
+
+    Outputs:
+        - **output** (Tensor) - Has the same type and shape as the `input_x`.
+        - **mask** (Tensor) - A tensor whose data type must be uint8.
+
+    Examples:
+        >>> input_x = Tensor(np.array([[[[1, -2], [-3, 4]], [[-5, 6], [7, -8]]]]), mindspore.float32)
+        >>> relu_v2 = P.ReLUV2()
+        >>> output = relu_v2(input_x)
+        ([[[[1., 0.], [0., 4.]], [[0., 6.], [7., 0.]]]],
+         [[[[1, 0], [2, 0]], [[2, 0], [1, 0]]]])
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init ReLUV2"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output', 'mask'])
+
+    def __infer__(self, input_x):
+        input_shape = list(input_x['shape'])
+        input_dtype = input_x['dtype']
+        mask_shape = []
+        if len(input_shape) != 4:
+            raise ValueError("The `input_x` should be a 4-D tensor, "
+                             f"but got a {len(input_shape)}-D tensor whose shape is {input_shape}")
+        for axis, dim in enumerate(input_shape):
+            if axis == 1:
+                # the channel dim is packed into the uint8 mask
+                if input_dtype in (mstype.uint8, mstype.int8):
+                    mask_shape.append((dim + 31) // 32)
+                else:
+                    mask_shape.append((dim + 15) // 16)
+            else:
+                mask_shape.append(dim)
+        if input_dtype in (mstype.uint8, mstype.int8):
+            mask_shape.append(4)
+        else:
+            mask_shape.append(2)
+
+        output_shape = (input_x['shape'], mask_shape)
+        validator.check_subclass("input_x", input_dtype, mstype.tensor, self.name)
+        validator.check_tensor_type_same({'input_x': input_dtype}, mstype.number_type, self.name)
+        mask_dtype = mstype.uint8
+        output_dtype = (input_dtype, mask_dtype)
+
+        return {'shape': output_shape,
+                'dtype': output_dtype,
+                'value': None}
+
+
 class Elu(PrimitiveWithInfer):
     r"""
     Computes exponential linear: `alpha * (exp(x) - 1)` if x < 0, `x` otherwise.
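Note the dtype checks in `__infer__` above use `in (...)` where the submitted code had `and`, which could never be true for a single dtype; `or`-style membership is clearly the intent. The mask-shape rule itself replaces the channel dim of an NCHW input with `ceil(C / 16)` and appends a trailing 2 (or `ceil(C / 32)` and 4 for the 8-bit dtypes), matching the packed uint8 mask the kernel produces. A standalone sketch of just that rule, under the same 4-D assumption:

```python
def relu_v2_mask_shape(input_shape, is_8bit_dtype=False):
    # input_shape is assumed to be 4-D NCHW, as ReLUV2.__infer__ enforces.
    n, c, h, w = input_shape
    if is_8bit_dtype:
        return [n, (c + 31) // 32, h, w, 4]
    return [n, (c + 15) // 16, h, w, 2]

assert relu_v2_mask_shape([2, 3, 3, 4]) == [2, 1, 3, 4, 2]
assert relu_v2_mask_shape([1, 33, 5, 5], is_8bit_dtype=True) == [1, 2, 5, 5, 4]
```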
@@ -2580,3 +2637,51 @@ class ExtractImagePatches(PrimitiveWithInfer):
     def infer_dtype(self, input_x):
         validator.check_tensor_type_same({"input_x": input_x}, (mstype.int8, mstype.float16, mstype.float32), self.name)
         return input_x
+
+
+class ConfusionMulGrad(PrimitiveWithInfer):
+    """
+    `output0` is the result of multiplying `input0` by `input1` element-wise.
+    `output1` is that element-wise product reduce-summed along `axis`.
+
+    Args:
+        axis (Union[int, tuple[int], list[int]]): The dimensions to reduce.
+            Default: (), reduce all dimensions. Only constant value is allowed.
+        keep_dims (bool):
+            - If true, keep these reduced dimensions and the length is 1.
+            - If false, don't keep these dimensions. Default: False.
+
+    Inputs:
+        - **input_0** (Tensor) - The input Tensor.
+        - **input_1** (Tensor) - The input Tensor.
+        - **input_2** (Tensor) - The input Tensor.
+
+    Outputs:
+        - **output_0** (Tensor) - The same shape as `input0`.
+        - **output_1** (Tensor)
+
+            - If axis is (), and keep_dims is false, the output is a 0-D array representing
+              the sum of all elements in the input array.
+            - If axis is int, set as 2, and keep_dims is false,
+              the shape of output is :math:`(x_1, x_3, ..., x_R)`.
+            - If axis is tuple(int), set as (2, 3), and keep_dims is false,
+              the shape of output is :math:`(x_1, x_4, ..., x_R)`.
+    """
+
+    @prim_attr_register
+    def __init__(self, axis=(), keep_dims=False):
+        self.init_prim_io_names(inputs=["input0", "input1", "input2"], outputs=["output0", "output1"])
+        self.axis_ = validator.check_value_type("axis", axis, [int, tuple, list], self.name)
+        self.keep_dims_ = validator.check_value_type("keep_dims", keep_dims, [bool], self.name)
+
+    def infer_shape(self, input0_shape, input1_shape, input2_shape):
+        outshape0 = input0_shape
+        outshape1 = _infer_shape_reduce(input1_shape, self.axis_, self.keep_dims_, self.name)
+        return outshape0, outshape1
+
+    def infer_dtype(self, input0_dtype, input1_dtype, input2_dtype):
+        validator.check_subclass("input0_dtype", input0_dtype, mstype.tensor, self.name)
+        validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor, self.name)
+        validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor, self.name)
+        return input0_dtype, input1_dtype
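A NumPy sketch of the semantics the docstring describes: `output0` is the element-wise product, and `output1` reduce-sums that product over `axis`. How the fused Ascend kernel uses `input2` is not spelled out in this diff, so the sketch ignores it:

```python
import numpy as np

def confusion_mul_grad(input0, input1, input2, axis=(), keep_dims=False):
    # output0: element-wise product, same shape as input0.
    output0 = input0 * input1
    # output1: the product reduce-summed; axis=() means reduce everything.
    reduce_axis = None if axis == () else tuple(np.atleast_1d(axis))
    output1 = np.sum(output0, axis=reduce_axis, keepdims=keep_dims)
    return output0, output1

a = np.ones((3, 2), np.float32)
out0, out1 = confusion_mul_grad(a, a, a, axis=[0], keep_dims=False)
assert out0.shape == (3, 2) and out1.shape == (2,)
```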
@@ -0,0 +1,53 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.common.api import ms_function
+from mindspore.common.initializer import initializer
+from mindspore.common.parameter import Parameter
+from mindspore.ops import operations as P
+from mindspore.ops.composite import GradOperation
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Grad(nn.Cell):
+    def __init__(self, network):
+        super(Grad, self).__init__()
+        self.grad = GradOperation(name="get_all", get_all=True)
+        self.network = network
+
+    @ms_function
+    def construct(self, input_x):
+        return self.grad(self.network)(input_x)
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.relu_v2 = P.ReLUV2()
+
+    def construct(self, x):
+        return self.relu_v2(x)
+
+
+def test_net():
+    x = Tensor(np.ones((2, 3, 3, 4)).astype(np.float32))
+    relu_net = Net()
+    relu_output = relu_net(x)
+    net = Grad(Net())
+    output_grad = net(x)
+    print(relu_output[0].asnumpy())
+    print(relu_output[1].asnumpy())
+    print(len(output_grad))
+    print(output_grad[0].asnumpy())
@@ -582,6 +582,10 @@ test_case_nn_ops = [
         'block': P.ReLU6(),
         'desc_inputs': [[1, 3, 4, 4]],
         'desc_bprop': [[1, 3, 4, 4]]}),
+    ('ReLUV2', {
+        'block': P.ReLUV2(),
+        'desc_inputs': [[1, 3, 4, 4]],
+        'desc_bprop': [[1, 3, 4, 4], [1, 3, 4, 4]]}),
     ('ReLUGrad', {
         'block': G.ReluGrad(),
         'desc_inputs': [[1, 3, 4, 4], [1, 3, 4, 4]],
@@ -1134,6 +1138,21 @@ test_case_other_ops = [
         'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)),
                         Tensor(np.array([1.2]).astype(np.float32))],
         'skip': ['backward']}),
+    ('ConfusionMulGrad_1', {
+        'block': P.ConfusionMulGrad(axis=[0], keep_dims=False),
+        'desc_inputs': [[3, 2], [3, 2], [3, 2]],
+        'desc_bprop': [[3, 2], [2]],
+        'skip': ['backward']}),
+    ('ConfusionMulGrad_2', {
+        'block': P.ConfusionMulGrad(axis=[0], keep_dims=True),
+        'desc_inputs': [[3, 2], [3, 2], [3, 2]],
+        'desc_bprop': [[3, 2], [1, 2]],
+        'skip': ['backward']}),
+    ('ConfusionMulGrad_3', {
+        'block': P.ConfusionMulGrad(axis=(), keep_dims=True),
+        'desc_inputs': [[2, 3, 4], [2, 3, 4], [2, 3, 4]],
+        'desc_bprop': [[2, 3, 4], [1, 1, 1]],
+        'skip': ['backward']}),
     ('HistogramSummary', {
         'block': HistogramSummaryNet(),
         'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)),