@@ -280,7 +280,7 @@ def _get_mean_matrix(x_shape, ksize, stride, pad_mode, x_dtype):
    the value of a padded element is 0; all others are 1.
    For each output element, the mapped slide window is `[h*h_stride : h*h_stride + h_ksize,
    w*w_stride : w*w_stride + w_ksize]` of `assist_input_matrix`, so the sum over the slide window is the
-   number of input that assosiate with output element.
+   number of input elements associated with that output element.
    """
    n_input, c_input, h_input, w_input = x_shape
@@ -416,6 +416,58 @@ def get_bprop_dropout_do_mask(self):
    return bprop


@bprop_getters.register(P.Mish)
def get_bprop_mish(self):
    """Grad definition for `Mish` operation."""
    tanh = P.Tanh()
    tanh_grad = SG.TanhGrad()
    softplus = P.Softplus()
    softplus_grad = G.SoftplusGrad()

    def bprop(x, out, dout):
        # mish(x) = x * tanh(softplus(x)), so
        # d(mish)/dx = tanh(softplus(x)) + x * tanh'(softplus(x)) * sigmoid(x)
        dx1 = tanh(softplus(x))
        dx2 = softplus_grad(tanh_grad(dx1, x * dout), x)
        dx = dx1 * dout + dx2
        return (dx,)

    return bprop
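
Reviewer note: as a sanity check on the formula above, here is a minimal NumPy sketch (not part of the patch; `softplus`, `sigmoid`, `mish`, and `mish_grad` are local helpers) comparing the analytic gradient the bprop computes against a central finite difference:

```python
import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def mish(x):
    return x * np.tanh(softplus(x))

def mish_grad(x):
    # Analytic gradient, matching the bprop:
    # tanh(softplus(x)) + x * (1 - tanh(softplus(x))**2) * sigmoid(x)
    t = np.tanh(softplus(x))
    return t + x * (1.0 - t * t) * sigmoid(x)

x = np.linspace(-4.0, 4.0, 9)
eps = 1e-5
numeric = (mish(x + eps) - mish(x - eps)) / (2.0 * eps)
assert np.allclose(mish_grad(x), numeric, atol=1e-6)
```
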
@bprop_getters.register(P.SeLU)
def get_bprop_selu(self):
    """Grad definition for `SeLU` operation."""
    scale = 1.0507009873554804934193349852946
    elu_grad = G.EluGrad()

    def bprop(x, out, dout):
        # Reuse EluGrad and rescale by the SeLU `scale` constant.
        dx = elu_grad(dout, out) * scale
        return (dx,)

    return bprop
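
For reference, a standalone NumPy sketch of the SeLU forward pass and its analytic derivative, using the standard constants (scale ≈ 1.0507, alpha ≈ 1.6733). This is an illustration only, not the kernel this patch registers:

```python
import numpy as np

SCALE = 1.0507009873554804934193349852946
ALPHA = 1.6732632423543772848170429916717

def selu(x):
    return SCALE * np.where(x >= 0, x, ALPHA * (np.exp(x) - 1.0))

def selu_grad(x):
    # d(selu)/dx = scale            for x >= 0
    #            = scale*alpha*e^x  for x <  0
    return SCALE * np.where(x >= 0, 1.0, ALPHA * np.exp(x))

# Avoid x = 0, where the derivative has a kink.
x = np.array([-3.0, -1.5, -0.5, 0.5, 1.5, 3.0])
eps = 1e-6
numeric = (selu(x + eps) - selu(x - eps)) / (2.0 * eps)
assert np.allclose(selu_grad(x), numeric, atol=1e-6)
```
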
@bprop_getters.register(P.MulNoNan)
def get_bprop_mul_no_nan(self):
    """Grad definition for `MulNoNan` operation."""
    mul_no_nan = P.MulNoNan()
    reduce_sum = P.ReduceSum()
    reshape = P.Reshape()

    def bprop(x, y, out, dout):
        x_shape = F.shape(x)
        y_shape = F.shape(y)
        dx = mul_no_nan(dout, y)
        dy = mul_no_nan(x, dout)
        # Sum the gradients over any broadcast axes so dx/dy match the
        # original input shapes.
        broadcast_x, broadcast_y = F.broadcast_gradient_args(x_shape, y_shape)
        if broadcast_x != ():
            dx = reshape(reduce_sum(dx, broadcast_x), x_shape)
        if broadcast_y != ():
            dy = reshape(reduce_sum(dy, broadcast_y), y_shape)
        return dx, dy

    return bprop
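
The reduction step follows the usual broadcast-gradient rule. A NumPy sketch, with a hypothetical `broadcast_gradient_axes` helper standing in for `F.broadcast_gradient_args`, shows why it is needed:

```python
import numpy as np

def broadcast_gradient_axes(x_shape, y_shape):
    """Axes of the broadcast output to sum over to recover each input's shape."""
    ndim = max(len(x_shape), len(y_shape))
    xs = (1,) * (ndim - len(x_shape)) + tuple(x_shape)
    ys = (1,) * (ndim - len(y_shape)) + tuple(y_shape)
    rx = tuple(i for i, (a, b) in enumerate(zip(xs, ys)) if a == 1 and b > 1)
    ry = tuple(i for i, (a, b) in enumerate(zip(xs, ys)) if b == 1 and a > 1)
    return rx, ry

x = np.ones((3, 1, 5))
y = np.ones((4, 5))
dout = np.ones((3, 4, 5))          # gradient w.r.t. the broadcast output
rx, ry = broadcast_gradient_axes(x.shape, y.shape)
dx = (dout * y).sum(axis=rx).reshape(x.shape)
dy = (x * dout).sum(axis=ry).reshape(y.shape)
assert dx.shape == x.shape and dy.shape == y.shape
```
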

@bprop_getters.register(P.ReLU)
def get_bprop_relu(self):
    """Grad definition for `ReLU` operation."""
@@ -355,3 +355,6 @@ from .lamb_apply_optimizer_assign import _lamb_apply_optimizer_assign_tbe
from .lamb_apply_weight_assign import _lamb_apply_weight_assign_tbe
from .nll_loss import _nll_loss_tbe
from .nll_loss_grad import _nll_loss_grad_tbe
from .mish import _mish_tbe
from .mul_no_nan import _mul_no_nan_tbe
from .selu import _selu_tbe
@@ -0,0 +1,37 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mish op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

mish_op_info = TBERegOp("Mish") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("mish.so") \
    .compute_cost(10) \
    .kernel_name("mish") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("formatAgnostic") \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .get_op_info()


@op_info_register(mish_op_info)
def _mish_tbe():
    """Mish TBE register"""
    return
@@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MulNoNan op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

mul_no_nan_op_info = TBERegOp("MulNoNan") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("mul_no_nan.so") \
    .compute_cost(10) \
    .kernel_name("mul_no_nan") \
    .partial_flag(True) \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("broadcast") \
    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
    .get_op_info()


@op_info_register(mul_no_nan_op_info)
def _mul_no_nan_tbe():
    """MulNoNan TBE register"""
    return
@@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Selu op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

selu_op_info = TBERegOp("Selu") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("selu.so") \
    .compute_cost(10) \
    .kernel_name("selu") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", True, "required", "all") \
    .op_pattern("formatAgnostic") \
    .dtype_format(DataType.I8_None, DataType.I8_None) \
    .dtype_format(DataType.I32_None, DataType.I32_None) \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .get_op_info()


@op_info_register(selu_op_info)
def _selu_tbe():
    """Selu TBE register"""
    return
@@ -48,7 +48,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
                        ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, ReduceAny,
                        Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil,
                        Acosh, Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
-                       LogicalNot, LogicalOr, MatMul, Maximum,
+                       LogicalNot, LogicalOr, MatMul, Maximum, MulNoNan,
                        Minimum, Mul, Neg, NMSWithMask, NotEqual,
                        NPUAllocFloatStatus, NPUClearFloatStatus, LinSpace,
                        NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
@@ -70,8 +70,8 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
                     LogSoftmax,
                     MaxPool, DataFormatDimMap,
                     AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
-                    MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
-                    ResizeBilinear, Sigmoid,
+                    MaxPoolWithArgmax, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
+                    ResizeBilinear, Sigmoid, SeLU,
                     SigmoidCrossEntropyWithLogits, NLLLoss,
                     SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
                     SoftmaxCrossEntropyWithLogits, ROIAlign,
@@ -194,6 +194,9 @@ __all__ = [
    'ZerosLike',
    'Select',
    'Split',
    'Mish',
    'SeLU',
    'MulNoNan',
    'ReLU',
    'ReLU6',
    'Elu',
@@ -2035,6 +2035,58 @@ class DivNoNan(_MathBinaryOp):
        return None


class MulNoNan(_MathBinaryOp):
    r"""
    Computes x * y element-wise. If y is zero, the result is 0, no matter what value x holds
    (including NaN or Inf).

    Inputs of `input_x` and `input_y` comply with the implicit type conversion rules to make the
    data types consistent. The inputs must be two tensors or one tensor and one scalar.
    When the inputs are two tensors, their shapes can be broadcast.
    When the inputs are one tensor and one scalar, the scalar can only be a constant.

    Note:
        The shapes of `x` and `y` should be the same or broadcastable.

    Inputs:
        - **input_x** (Union[Tensor]) - The first input is a tensor whose data type is number.
        - **input_y** (Union[Tensor]) - The second input is a tensor whose data type is number.

    Outputs:
        Tensor, the shape is the same as the one after broadcasting,
        and the data type is the one with higher precision or higher digits among the two inputs.

    Raises:
        TypeError: If `x` or `y` is a bool tensor.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x = Tensor(np.array([[-1.0, 6.0, np.inf], [np.nan, -7.0, 4.0]]), ms.float32)
        >>> y = Tensor(np.array([[-1.0, 4.0, 0], [0, -3.0, 1.0]]), ms.float32)
        >>> mul_no_nan = ops.MulNoNan()
        >>> output = mul_no_nan(x, y)
        >>> print(output)
        [[ 1. 24. 0.]
         [ 0. 21. 4.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MulNoNan."""
        self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])

    def infer_value(self, x, y):
        if x is not None and y is not None:
            x = x.asnumpy()
            y = y.asnumpy()
            with np.errstate(divide='ignore', invalid='ignore'):
                out = np.multiply(x, y)
            # Force zeros wherever y is zero, even if x is NaN/Inf.
            out[y == 0] = 0
            return out
        return None
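
The masking step in `infer_value` matters because IEEE arithmetic gives `nan * 0 = nan` and `inf * 0 = nan`. A quick NumPy check (illustrative only):

```python
import numpy as np

x = np.array([np.nan, np.inf, 2.0])
y = np.array([0.0, 0.0, 3.0])

plain = np.multiply(x, y)   # [nan, nan, 6.]
masked = plain.copy()
masked[y == 0] = 0          # MulNoNan semantics: [0., 0., 6.]
print(plain, masked)
```
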

class FloorDiv(_MathBinaryOp):
    """
    Divides the first input tensor by the second input tensor element-wise and rounds down to the closest integer.
@@ -4041,6 +4093,7 @@ class LinSpace(PrimitiveWithInfer):
               'value': None}
        return out


class MatrixInverse(PrimitiveWithInfer):
    """
    Returns the inverse of the input matrix. If the matrix is not invertible, an error may be reported or an unknown
@@ -329,6 +329,99 @@ class ReLU(PrimitiveWithCheck):
        validator.check_tensor_dtype_valid('input_x', input_x, mstype.number_type, self.name)

class Mish(PrimitiveWithInfer):
    r"""
    Computes MISH (A Self Regularized Non-Monotonic Neural Activation Function) of input tensors element-wise.

    The function is shown as follows:

    .. math::

        \text{output} = x * \tanh(\log(1 + \exp(x)))

    Inputs:
        - **x** (Tensor) - The input tensor. Only float16 and float32 are supported.

    Outputs:
        Tensor, with the same type and shape as the `x`.

    Raises:
        TypeError: If dtype of `x` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
        >>> mish = ops.Mish()
        >>> output = mish(input_x)
        >>> print(output)
        [[-3.034014e-01 3.997413e+00 -2.682209e-03]
         [ 1.943959e+00 -3.357619e-02 8.999999e+00]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Mish."""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name)
        return x_dtype
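
A quick NumPy cross-check of the docstring values (an illustration; the op itself runs on Ascend via the TBE kernel registered above):

```python
import numpy as np

x = np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]], dtype=np.float32)
ref = x * np.tanh(np.log1p(np.exp(x)))
print(ref)   # matches the Examples output above to float32 precision
```
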

class SeLU(PrimitiveWithInfer):
    r"""
    Computes SeLU (Scaled Exponential Linear Unit) of input tensors element-wise.

    The activation function is defined as:

    .. math::

        E_{i} =
        scale *
        \begin{cases}
        x_{i}, &\text{if } x_{i} \geq 0; \cr
        \text{alpha} * (\exp(x_{i}) - 1), &\text{otherwise.}
        \end{cases}

    where :math:`alpha \approx 1.67326` and :math:`scale \approx 1.05070` are pre-defined constants.

    Inputs:
        - **input_x** (Tensor) - The input tensor.

    Outputs:
        Tensor, with the same type and shape as the `input_x`.

    Raises:
        TypeError: If dtype of `input_x` is not int8, int32, float16 or float32.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
        >>> selu = ops.SeLU()
        >>> output = selu(input_x)
        >>> print(output)
        [[-1.1113307  4.202804  -1.7575096]
         [ 2.101402  -1.7462534  9.456309 ]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize SeLU."""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        valid_dtypes = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
        validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
        return x_dtype
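
The `scale` and `alpha` constants come from the self-normalizing property (Klambauer et al., 2017): for standard-normal inputs, SeLU approximately preserves zero mean and unit variance. A small Monte-Carlo illustration (values are approximate):

```python
import numpy as np

SCALE = 1.0507009873554804934193349852946
ALPHA = 1.6732632423543772848170429916717

rng = np.random.default_rng(0)
x = rng.standard_normal(1_000_000)
y = SCALE * np.where(x >= 0, x, ALPHA * (np.exp(x) - 1.0))
print(y.mean(), y.var())   # close to 0 and 1 respectively
```
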

class ReLU6(PrimitiveWithInfer):
    r"""
    Computes ReLU (Rectified Linear Unit) upper bounded by 6 of input tensors element-wise.
@@ -1338,7 +1431,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
    :math:`\text{in_channels} * \text{channel_multiplier}` channels.

    Args:
-       channel_multiplier (int): The multipiler for the original output convolution. Its value must be greater than 0.
+       channel_multiplier (int): The multiplier for the original output convolution. Its value must be greater than 0.
        kernel_size (Union[int, tuple[int]]): The size of the convolution kernel.
        mode (int): Modes for different convolutions. 0 Math convolution, 1 cross-correlation convolution,
            2 deconvolution, 3 depthwise convolution. Default: 3.
@@ -320,6 +320,42 @@ class CountNonZero(nn.Cell):
        return nonzero_num


class Mish(nn.Cell):
    """Mish net definition"""

    def __init__(self):
        super(Mish, self).__init__()
        self.mish = P.Mish()

    def construct(self, input_x):
        out = self.mish(input_x)
        return out


class SeLU(nn.Cell):
    """SeLU net definition"""

    def __init__(self):
        super(SeLU, self).__init__()
        self.selu = P.SeLU()

    def construct(self, input_x):
        out = self.selu(input_x)
        return out


class MulNoNan(nn.Cell):
    """MulNoNan net definition"""

    def __init__(self):
        super(MulNoNan, self).__init__()
        self.mul_no_nan = P.MulNoNan()

    def construct(self, input_x, input_y):
        out = self.mul_no_nan(input_x, input_y)
        return out
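
Outside the test harness, wrapper cells like these can also be exercised directly; a minimal sketch, assuming a PyNative-capable environment (the context/device setup below is an assumption, not part of the patch):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, context

# Assumed setup; adjust device_target to your environment.
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

net = MulNoNan()
x = Tensor(np.array([[np.nan, 2.0]]), ms.float32)
y = Tensor(np.array([[0.0, 3.0]]), ms.float32)
print(net(x, y))   # expected: [[0. 6.]]
```
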

class ScatterUpdate(nn.Cell):
    """ScatterUpdate net definition"""
@@ -1315,6 +1351,19 @@ test_case_math_ops = [
        Tensor(np.array([-6, -1, -2, -3]), mstype.float32),
        Tensor(np.array([6, 1, 2, 3]), mstype.float32)],
        'desc_bprop': [Tensor(np.random.rand(3, 16, 5, 4), mstype.float32)]}),
    ('Mish', {
        'block': Mish(),
        'desc_inputs': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)],
        'desc_bprop': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)]}),
    ('SeLU', {
        'block': SeLU(),
        'desc_inputs': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)],
        'desc_bprop': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)]}),
    ('MulNoNan', {
        'block': MulNoNan(),
        'desc_inputs': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32),
                        Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)],
        'desc_bprop': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)]}),
    ('Rank', {
        'block': P.Rank(),
        'desc_inputs': [[2, 3]],