Merge pull request !1026 from peixu_ren/custom_pp_opstags/v0.3.0-alpha
| @@ -17,7 +17,7 @@ Layer. | |||||
| The high-level components(Cells) used to construct the neural network. | The high-level components(Cells) used to construct the neural network. | ||||
| """ | """ | ||||
| from . import activation, normalization, container, conv, lstm, basic, embedding, pooling, image, quant | |||||
| from . import activation, normalization, container, conv, lstm, basic, embedding, pooling, image, quant, math | |||||
| from .activation import * | from .activation import * | ||||
| from .normalization import * | from .normalization import * | ||||
| from .container import * | from .container import * | ||||
| @@ -28,6 +28,7 @@ from .embedding import * | |||||
| from .pooling import * | from .pooling import * | ||||
| from .image import * | from .image import * | ||||
| from .quant import * | from .quant import * | ||||
| from .math import * | |||||
| __all__ = [] | __all__ = [] | ||||
| __all__.extend(activation.__all__) | __all__.extend(activation.__all__) | ||||
| @@ -40,3 +41,4 @@ __all__.extend(embedding.__all__) | |||||
| __all__.extend(pooling.__all__) | __all__.extend(pooling.__all__) | ||||
| __all__.extend(image.__all__) | __all__.extend(image.__all__) | ||||
| __all__.extend(quant.__all__) | __all__.extend(quant.__all__) | ||||
| __all__.extend(math.__all__) | |||||
| @@ -35,6 +35,7 @@ __all__ = ['Softmax', | |||||
| 'HSigmoid', | 'HSigmoid', | ||||
| 'HSwish', | 'HSwish', | ||||
| 'ELU', | 'ELU', | ||||
| 'LogSigmoid', | |||||
| ] | ] | ||||
| @@ -476,6 +477,49 @@ class HSigmoid(Cell): | |||||
| return self.hsigmoid(x) | return self.hsigmoid(x) | ||||
| class LogSigmoid(Cell): | |||||
| r""" | |||||
| Logsigmoid activation function. | |||||
| Applies logsigmoid activation element-wise. The input is a Tensor with any valid shape. | |||||
| Logsigmoid is defined as: | |||||
| .. math:: | |||||
| \text{logsigmoid}(x_{i}) = log(\frac{1}{1 + \exp(-x_i)}), | |||||
| where :math:`x_{i}` is the element of the input. | |||||
| Inputs: | |||||
| - **input_data** (Tensor) - The input of LogSigmoid. | |||||
| Outputs: | |||||
| Tensor, with the same type and shape as the `input_data`. | |||||
| Examples: | |||||
| >>> net = nn.LogSigmoid() | |||||
| >>> input_x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32) | |||||
| >>> logsigmoid = net(input_x) | |||||
| [-3.1326166e-01, -1.2692806e-01, -4.8587345e-02] | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LogSigmoid, self).__init__() | |||||
| self.mul = P.Mul() | |||||
| self.exp = P.Exp() | |||||
| self.add = P.TensorAdd() | |||||
| self.rec = P.Reciprocal() | |||||
| self.log = P.Log() | |||||
| def construct(self, input_x): | |||||
| neg_input = self.mul(input_x, -1) | |||||
| exp_neg_input = self.exp(neg_input) | |||||
| exp_neg_input_1 = self.add(exp_neg_input, 1) | |||||
| rec_exp_neg_input_1 = self.rec(exp_neg_input_1) | |||||
| ret = self.log(rec_exp_neg_input_1) | |||||
| return ret | |||||
| _activation = { | _activation = { | ||||
| 'softmax': Softmax, | 'softmax': Softmax, | ||||
| 'logsoftmax': LogSoftmax, | 'logsoftmax': LogSoftmax, | ||||
| @@ -488,6 +532,7 @@ _activation = { | |||||
| 'leakyrelu': LeakyReLU, | 'leakyrelu': LeakyReLU, | ||||
| 'hswish': HSwish, | 'hswish': HSwish, | ||||
| 'hsigmoid': HSigmoid, | 'hsigmoid': HSigmoid, | ||||
| 'logsigmoid': LogSigmoid, | |||||
| } | } | ||||
| @@ -0,0 +1,68 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """math""" | |||||
| from mindspore.ops import operations as P | |||||
| from ..cell import Cell | |||||
| from ..._checkparam import Validator as validator | |||||
| __all__ = ['ReduceLogSumExp'] | |||||
| class ReduceLogSumExp(Cell): | |||||
| r""" | |||||
| Reduce a dimension of a tensor by calculating exponential for all elements in the dimension, | |||||
| then calculate logarithm of the sum. | |||||
| The dtype of the tensor to be reduced is number. | |||||
| Args: | |||||
| keep_dims (bool): If True, keep these reduced dimensions and the length is 1. | |||||
| If False, don't keep these dimensions. | |||||
| Default : False. | |||||
| Inputs: | |||||
| - **input_x** (Tensor[Number]) - The input tensor. | |||||
| - **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions. | |||||
| Only constant value is allowed. | |||||
| Outputs: | |||||
| Tensor, has the same dtype as the 'input_x'. | |||||
| - If axis is (), and keep_dims is false, | |||||
| the output is a 0-D tensor representing the sum of all elements in the input tensor. | |||||
| - If axis is int, set as 2, and keep_dims is false, | |||||
| the shape of output is :math:`(x_1, x_3, ..., x_R)`. | |||||
| - If axis is tuple(int), set as (2, 3), and keep_dims is false, | |||||
| the shape of output is :math:`(x_1, x_4, ..., x_R)`. | |||||
| Examples: | |||||
| >>> input_x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32)) | |||||
| >>> op = P.ReduceLogSumExp(keep_dims=True) | |||||
| >>> output = op(input_x, 1) | |||||
| """ | |||||
| def __init__(self, axis, keep_dims=False): | |||||
| super(ReduceLogSumExp, self).__init__() | |||||
| validator.check_value_type('axis', axis, [int, list, tuple], self.cls_name) | |||||
| validator.check_value_type('keep_dims', keep_dims, [bool], self.cls_name) | |||||
| self.axis = axis | |||||
| self.exp = P.Exp() | |||||
| self.sum = P.ReduceSum(keep_dims) | |||||
| self.log = P.Log() | |||||
| def construct(self, input_x): | |||||
| exp = self.exp(input_x) | |||||
| sumexp = self.sum(exp, self.axis) | |||||
| logsumexp = self.log(sumexp) | |||||
| return logsumexp | |||||
| @@ -522,6 +522,16 @@ test_cases = [ | |||||
| 'desc_inputs': [Tensor(np.ones([1, 1, 3, 3], np.float32))], | 'desc_inputs': [Tensor(np.ones([1, 1, 3, 3], np.float32))], | ||||
| 'desc_bprop': [Tensor(np.ones([1, 4, 2, 2], np.float32))], | 'desc_bprop': [Tensor(np.ones([1, 4, 2, 2], np.float32))], | ||||
| 'skip': ['backward']}), | 'skip': ['backward']}), | ||||
| ('LogSigmoid', { | |||||
| 'block': nn.LogSigmoid(), | |||||
| 'desc_inputs': [Tensor(np.array([1, 2, 3, 4]).astype(np.float32))], | |||||
| 'desc_bprop': [Tensor(np.array([1, 2, 3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| ('ReduceLogSumExp', { | |||||
| 'block': nn.ReduceLogSumExp((0, ), False), | |||||
| 'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))], | |||||
| 'desc_bprop': [Tensor(np.array([1, 2, 3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| ] | ] | ||||
| test_cases_for_verify_exception = [ | test_cases_for_verify_exception = [ | ||||
| @@ -621,6 +631,20 @@ test_cases_for_verify_exception = [ | |||||
| ), | ), | ||||
| 'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))], | 'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))], | ||||
| }), | }), | ||||
| ('ReduceLogsumexp_TypeError_1', { | |||||
| 'block': ( | |||||
| lambda _: nn.ReduceLogSumExp(axis=(0,), keep_dims=2), | |||||
| {'exception': TypeError}, | |||||
| ), | |||||
| 'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))], | |||||
| }), | |||||
| ('ReduceLogsumexp_TypeError_2', { | |||||
| 'block': ( | |||||
| lambda _: nn.ReduceLogSumExp(axis=1.2, keep_dims=True), | |||||
| {'exception': TypeError}, | |||||
| ), | |||||
| 'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))], | |||||
| }), | |||||
| ] | ] | ||||