Merge pull request !251 from jxlang910/mastertags/v0.6.0-beta
| @@ -38,3 +38,4 @@ from .gamma import _gamma_aicpu | |||
| from .poisson import _poisson_aicpu | |||
| from .uniform_int import _uniform_int_aicpu | |||
| from .uniform_real import _uniform_real_aicpu | |||
| from .laplace import _laplace_aicpu | |||
| @@ -0,0 +1,33 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """RandomLaplace op""" | |||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||
| laplace_op_info = AiCPURegOp("Laplace") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .input(0, "shape", "required") \ | |||
| .input(1, "mean", "required") \ | |||
| .input(2, "lambda_param", "required") \ | |||
| .output(0, "output", "required") \ | |||
| .attr("seed", "int") \ | |||
| .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \ | |||
| .get_op_info() | |||
| @op_info_register(laplace_op_info) | |||
| def _laplace_aicpu(): | |||
| """RandomLaplace AiCPU register""" | |||
| return | |||
| @@ -55,7 +55,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A | |||
| Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps) | |||
| from .random_ops import (RandomChoiceWithMask, Normal, Gamma, Poisson, UniformInt, UniformReal, | |||
| RandomCategorical) | |||
| RandomCategorical, Laplace) | |||
| from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm, | |||
| BiasAdd, Conv2D, | |||
| DepthwiseConv2dNative, | |||
| @@ -177,6 +177,7 @@ __all__ = [ | |||
| 'Poisson', | |||
| 'UniformInt', | |||
| 'UniformReal', | |||
| 'Laplace', | |||
| 'RandomCategorical', | |||
| 'ResizeBilinear', | |||
| 'ScalarSummary', | |||
| @@ -36,9 +36,9 @@ class Normal(PrimitiveWithInfer): | |||
| Inputs: | |||
| - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. | |||
| - **mean** (Tensor) - The mean μ distribution parameter, The mean specifies the location of the peak. | |||
| With float32 data type. | |||
| - **stddev** (Tensor) - the deviation σ distribution parameter. With float32 data type. | |||
| - **mean** (Tensor) - The mean μ distribution parameter, which specifies the location of the peak. | |||
| With float32 data type. | |||
| - **stddev** (Tensor) - The deviation σ distribution parameter. With float32 data type. | |||
| Outputs: | |||
| Tensor, has the shape 'shape' input and dtype as float32. | |||
| @@ -75,6 +75,60 @@ class Normal(PrimitiveWithInfer): | |||
| return out | |||
| class Laplace(PrimitiveWithInfer): | |||
| r""" | |||
| Generates random numbers according to the Laplace random number distribution. | |||
| It is defined as: | |||
| .. math:: | |||
| \text{f}(x;μ,λ) = \frac{1}{2λ}\exp(-\frac{|x-μ|}{λ}), | |||
| Args: | |||
| seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers. | |||
| Default: 0. | |||
| Inputs: | |||
| - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. | |||
| - **mean** (Tensor) - The mean μ distribution parameter, which specifies the location of the peak. | |||
| With float32 data type. | |||
| - **lambda_param** (Tensor) - The parameter used for controling the variance of this random distribution. The | |||
| variance of Laplace distribution is equal to twice the square of lambda_param. With float32 data type. | |||
| Outputs: | |||
| Tensor, has the shape 'shape' input and dtype as float32. | |||
| Examples: | |||
| >>> shape = (4, 16) | |||
| >>> mean = Tensor(1.0, mstype.float32) | |||
| >>> lambda_param = Tensor(1.0, mstype.float32) | |||
| >>> laplace = P.Laplace(seed=2) | |||
| >>> output = laplace(shape, mean, lambda_param) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, seed=0): | |||
| """Init Laplace""" | |||
| self.init_prim_io_names(inputs=['shape', 'mean', 'lambda_param'], outputs=['output']) | |||
| validator.check_value_type('seed', seed, [int], self.name) | |||
| def __infer__(self, shape, mean, lambda_param): | |||
| shape_v = shape["value"] | |||
| if shape_v is None: | |||
| raise ValueError(f"For {self.name}, shape must be const.") | |||
| validator.check_value_type("shape", shape_v, [tuple], self.name) | |||
| for i, shape_i in enumerate(shape_v): | |||
| validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name) | |||
| validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name) | |||
| validator.check_tensor_type_same({"lambda_param": lambda_param["dtype"]}, [mstype.float32], self.name) | |||
| broadcast_shape = get_broadcast_shape(mean['shape'], lambda_param['shape'], self.name) | |||
| broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name) | |||
| out = { | |||
| 'shape': broadcast_shape, | |||
| 'dtype': mstype.float32, | |||
| 'value': None} | |||
| return out | |||
| class Gamma(PrimitiveWithInfer): | |||
| r""" | |||
| Produces random positive floating-point values x, distributed according to probability density function: | |||
| @@ -101,7 +155,7 @@ class Gamma(PrimitiveWithInfer): | |||
| >>> alpha = Tensor(1.0, mstype.float32) | |||
| >>> beta = Tensor(1.0, mstype.float32) | |||
| >>> gamma = P.Gamma(seed=3) | |||
| >>> output = normal(shape, alpha, beta) | |||
| >>> output = Gamma(shape, alpha, beta) | |||
| """ | |||
| @prim_attr_register | |||
| @@ -0,0 +1,57 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common import dtype as mstype | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| class Net(nn.Cell): | |||
| def __init__(self, shape, seed=0): | |||
| super(Net, self).__init__() | |||
| self.laplace = P.Laplace(seed=seed) | |||
| self.shape = shape | |||
| def construct(self, mean, lambda_param): | |||
| return self.laplace(self.shape, mean, lambda_param) | |||
| def test_net_1D(): | |||
| seed = 10 | |||
| shape = (3, 2, 4) | |||
| mean = 1.0 | |||
| lambda_param = 1.0 | |||
| net = Net(shape, seed) | |||
| tmean, tlambda_param = Tensor(mean, mstype.float32), Tensor(lambda_param, mstype.float32) | |||
| output = net(tmean, tlambda_param) | |||
| print(output.asnumpy()) | |||
| assert output.shape == (3, 2, 4) | |||
| def test_net_ND(): | |||
| seed = 10 | |||
| shape = (3, 1, 2) | |||
| mean = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32) | |||
| lambda_param = np.array([1.0]).astype(np.float32) | |||
| net = Net(shape, seed) | |||
| tmean, tlambda_param = Tensor(mean), Tensor(lambda_param) | |||
| output = net(tmean, tlambda_param) | |||
| print(output.asnumpy()) | |||
| assert output.shape == (3, 2, 2) | |||
| @@ -410,6 +410,17 @@ class NormalNet(nn.Cell): | |||
| return out | |||
| class LaplaceNet(nn.Cell): | |||
| def __init__(self, shape=None, seed=0): | |||
| super(LaplaceNet, self).__init__() | |||
| self.laplace = P.Laplace(seed=seed) | |||
| self.shape = shape | |||
| def construct(self, mean, lambda_param): | |||
| out = self.laplace(self.shape, mean, lambda_param) | |||
| return out | |||
| class GammaNet(nn.Cell): | |||
| def __init__(self, shape=None, seed=0): | |||
| super(GammaNet, self).__init__() | |||
| @@ -666,6 +677,10 @@ test_case_math_ops = [ | |||
| 'block': NormalNet((3, 2, 4), 0), | |||
| 'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(1.0, mstype.float32)], | |||
| 'skip': ['backward']}), | |||
| ('Laplace', { | |||
| 'block': LaplaceNet((3, 2, 4), 0), | |||
| 'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(1.0, mstype.float32)], | |||
| 'skip': ['backward']}), | |||
| ('Gamma', { | |||
| 'block': GammaNet((3, 2, 4), 0), | |||
| 'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(1.0, mstype.float32)], | |||