diff --git a/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/ops/_op_impl/aicpu/__init__.py index c83a6ec46e..dc5fba988c 100644 --- a/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/ops/_op_impl/aicpu/__init__.py @@ -26,7 +26,6 @@ from .squeeze import _squeeze_aicpu from .expand_dims import _expand_dims_aicpu from .random_choice_with_mask import _random_choice_with_mask_aicpu from .pack import _pack_aicpu -from .normal import _normal_aicpu from .ctcloss import _ctcloss_aicpu from .reverse_sequence import _reverse_sequence_aicpu from .crop_and_resize import _crop_and_resize_aicpu @@ -34,3 +33,8 @@ from .rnnt_loss import _rnnt_loss_aicpu from .random_categorical import _random_categorical_aicpu from .cast import _cast_aicpu from .mirror_pad import _mirror_pad_aicpu +from .normal import _normal_aicpu +from .gamma import _gamma_aicpu +from .poisson import _poisson_aicpu +from .uniform_int import _uniform_int_aicpu +from .uniform_real import _uniform_real_aicpu diff --git a/mindspore/ops/_op_impl/aicpu/gamma.py b/mindspore/ops/_op_impl/aicpu/gamma.py new file mode 100644 index 0000000000..b6b92a9da4 --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/gamma.py @@ -0,0 +1,33 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""RandomGamma op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +gamma_op_info = AiCPURegOp("Gamma") \ + .fusion_type("OPAQUE") \ + .input(0, "shape", "required") \ + .input(1, "alpha", "required") \ + .input(2, "beta", "required") \ + .output(0, "output", "required") \ + .attr("seed", "int") \ + .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \ + .get_op_info() + +@op_info_register(gamma_op_info) +def _gamma_aicpu(): + """RandomGamma AiCPU register""" + return diff --git a/mindspore/ops/_op_impl/aicpu/normal.py b/mindspore/ops/_op_impl/aicpu/normal.py index fdb96e362f..5597dc7200 100644 --- a/mindspore/ops/_op_impl/aicpu/normal.py +++ b/mindspore/ops/_op_impl/aicpu/normal.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ -"""Normal op""" +"""RandomNormal op""" from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType normal_op_info = AiCPURegOp("Normal") \ @@ -21,7 +21,7 @@ normal_op_info = AiCPURegOp("Normal") \ .input(0, "shape", "required") \ .input(1, "mean", "required") \ .input(2, "stddev", "required") \ - .output(0, "y", "required") \ + .output(0, "output", "required") \ .attr("seed", "int") \ .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \ @@ -29,5 +29,5 @@ normal_op_info = AiCPURegOp("Normal") \ @op_info_register(normal_op_info) def _normal_aicpu(): - """Normal AiCPU register""" + """RandomNormal AiCPU register""" return diff --git a/mindspore/ops/_op_impl/aicpu/poisson.py b/mindspore/ops/_op_impl/aicpu/poisson.py new file mode 100644 index 0000000000..3d3e5e4c35 --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/poisson.py @@ -0,0 +1,32 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""RandomPoisson op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +poisson_op_info = AiCPURegOp("Poisson") \ + .fusion_type("OPAQUE") \ + .input(0, "shape", "required") \ + .input(1, "mean", "required") \ + .output(0, "output", "required") \ + .attr("seed", "int") \ + .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW) \ + .get_op_info() + +@op_info_register(poisson_op_info) +def _poisson_aicpu(): + """RandomPoisson AiCPU register""" + return diff --git a/mindspore/ops/_op_impl/aicpu/uniform_int.py b/mindspore/ops/_op_impl/aicpu/uniform_int.py new file mode 100644 index 0000000000..3a55a399a9 --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/uniform_int.py @@ -0,0 +1,33 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""RandomUniformInt op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +uniform_int_op_info = AiCPURegOp("UniformInt") \ + .fusion_type("OPAQUE") \ + .input(0, "shape", "required") \ + .input(1, "a", "required") \ + .input(2, "b", "required") \ + .output(0, "output", "required") \ + .attr("seed", "int") \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW) \ + .get_op_info() + +@op_info_register(uniform_int_op_info) +def _uniform_int_aicpu(): + """RandomUniformInt AiCPU register""" + return diff --git a/mindspore/ops/_op_impl/aicpu/uniform_real.py b/mindspore/ops/_op_impl/aicpu/uniform_real.py new file mode 100644 index 0000000000..51824fbb2c --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/uniform_real.py @@ -0,0 +1,33 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""RandomUniformReal op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +uniform_real_op_info = AiCPURegOp("UniformReal") \ + .fusion_type("OPAQUE") \ + .input(0, "shape", "required") \ + .input(1, "a", "required") \ + .input(2, "b", "required") \ + .output(0, "output", "required") \ + .attr("seed", "int") \ + .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \ + .get_op_info() + +@op_info_register(uniform_real_op_info) +def _uniform_real_aicpu(): + """RandomUniformReal AiCPU register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index dec223193a..6ba138bc22 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -54,7 +54,8 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps) -from .random_ops import (RandomChoiceWithMask, Normal, RandomCategorical) +from .random_ops import (RandomChoiceWithMask, Normal, Gamma, Poisson, UniformInt, UniformReal, + RandomCategorical) from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm, BiasAdd, Conv2D, DepthwiseConv2dNative, @@ -172,6 +173,10 @@ __all__ = [ 'Tanh', 'RandomChoiceWithMask', 'Normal', + 'Gamma', + 'Poisson', + 'UniformInt', + 'UniformReal', 'RandomCategorical', 'ResizeBilinear', 'ScalarSummary', diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index cde7dd41e3..251acf33f0 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -19,6 +19,269 @@ from ..._checkparam import Validator as validator from ..._checkparam import Rel from ...common import dtype as mstype from ..primitive import PrimitiveWithInfer, prim_attr_register +from .._utils import get_broadcast_shape + + +class Normal(PrimitiveWithInfer): + r""" + Generates random numbers according to the Normal (or Gaussian) random number distribution. + It is defined as: + + .. math:: + \text{f}(x;μ,σ) = \frac{1}{σ\sqrt{2π}}\exp(-\frac{1}{2}(\frac{x-μ}{σ})^2), + + Args: + seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers. + Default: 0. + + Inputs: + - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. + - **mean** (Tensor) - The mean μ distribution parameter, The mean specifies the location of the peak. + With float32 data type. + - **stddev** (Tensor) - the deviation σ distribution parameter. With float32 data type. + + Outputs: + Tensor, has the shape 'shape' input and dtype as float32. + + Examples: + >>> shape = (4, 16) + >>> mean = Tensor(1.0, mstype.float32) + >>> stddev = Tensor(1.0, mstype.float32) + >>> normal = P.Normal(seed=2) + >>> output = normal(shape, mean, stddev) + """ + + @prim_attr_register + def __init__(self, seed=0): + """Init Normal""" + self.init_prim_io_names(inputs=['shape', 'mean', 'stddev'], outputs=['output']) + validator.check_value_type('seed', seed, [int], self.name) + + def __infer__(self, shape, mean, stddev): + shape_v = shape["value"] + if shape_v is None: + raise ValueError(f"For {self.name}, shape must be const.") + validator.check_value_type("shape", shape_v, [tuple], self.name) + for i, shape_i in enumerate(shape_v): + validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name) + validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name) + validator.check_tensor_type_same({"stddev": stddev["dtype"]}, [mstype.float32], self.name) + broadcast_shape = get_broadcast_shape(mean['shape'], stddev['shape'], self.name) + broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name) + out = { + 'shape': broadcast_shape, + 'dtype': mstype.float32, + 'value': None} + return out + + +class Gamma(PrimitiveWithInfer): + r""" + Produces random positive floating-point values x, distributed according to probability density function: + + .. math:: + \text{P}(x|α,β) = \frac{\exp(-x/β)}{{β^α}\cdot{\Gamma(α)}}\cdot{x^{α-1}}, + + Args: + seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers. + Default: 0. + + Inputs: + - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. + - **alpha** (Tensor) - The α distribution parameter. + It is also known as the shape parameter. With float32 data type. + - **beta** (Tensor) - The β distribution parameter. + It is also known as the scale parameter. With float32 data type. + + Outputs: + Tensor, has the shape 'shape' input and dtype as float32. + + Examples: + >>> shape = (4, 16) + >>> alpha = Tensor(1.0, mstype.float32) + >>> beta = Tensor(1.0, mstype.float32) + >>> gamma = P.Gamma(seed=3) + >>> output = normal(shape, alpha, beta) + """ + + @prim_attr_register + def __init__(self, seed=0): + """Init Gamma""" + self.init_prim_io_names(inputs=['shape', 'alpha', 'beta'], outputs=['output']) + validator.check_value_type('seed', seed, [int], self.name) + + def __infer__(self, shape, alpha, beta): + shape_v = shape["value"] + if shape_v is None: + raise ValueError(f"For {self.name}, shape must be const.") + validator.check_value_type("shape", shape_v, [tuple], self.name) + for i, shape_i in enumerate(shape_v): + validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name) + validator.check_tensor_type_same({"alpha": alpha["dtype"]}, [mstype.float32], self.name) + validator.check_tensor_type_same({"beta": beta["dtype"]}, [mstype.float32], self.name) + broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name) + broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name) + out = { + 'shape': broadcast_shape, + 'dtype': mstype.float32, + 'value': None} + return out + + +class Poisson(PrimitiveWithInfer): + r""" + Produces random non-negative integer values i, distributed according to discrete probability function: + + .. math:: + \text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!}, + + Args: + seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers. + Default: 0. + + Inputs: + - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. + - **mean** (Tensor) - μ parameter the distribution was constructed with. + The parameter defines mean number of occurrences of the event. With float32 data type. + + Outputs: + Tensor, has the shape 'shape' input and dtype as int32. + + Examples: + >>> shape = (4, 16) + >>> mean = Tensor(5.0, mstype.float32) + >>> poisson = P.Poisson(seed=5) + >>> output = poisson(shape, mean) + """ + + @prim_attr_register + def __init__(self, seed=0): + """Init Poisson""" + self.init_prim_io_names(inputs=['shape', 'mean'], outputs=['output']) + validator.check_value_type('seed', seed, [int], self.name) + + def __infer__(self, shape, mean): + shape_v = shape["value"] + if shape_v is None: + raise ValueError(f"For {self.name}, shape must be const.") + validator.check_value_type("shape", shape_v, [tuple], self.name) + for i, shape_i in enumerate(shape_v): + validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name) + validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name) + broadcast_shape = get_broadcast_shape(mean['shape'], shape_v, self.name) + out = { + 'shape': broadcast_shape, + 'dtype': mstype.int32, + 'value': None} + return out + + +class UniformInt(PrimitiveWithInfer): + r""" + Produces random integer values i, uniformly distributed on the closed interval [a, b], that is, + distributed according to the discrete probability function: + + .. math:: + \text{P}(i|a,b) = \frac{1}{b-a+1}, + + Args: + seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers. + Default: 0. + + Inputs: + - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. + - **a** (Tensor) - The a distribution parameter. + It defines the minimum possibly generated value. With int32 data type. + - **b** (Tensor) - The b distribution parameter. + It defines the maximum possibly generated value. With int32 data type. + + Outputs: + Tensor, has the shape 'shape' input and dtype as int32. + + Examples: + >>> shape = (4, 16) + >>> a = Tensor(1, mstype.int32) + >>> b = Tensor(5, mstype.int32) + >>> uniform_int = P.UniformInt(seed=10) + >>> output = uniform_int(shape, a, b) + """ + + @prim_attr_register + def __init__(self, seed=0): + """Init UniformInt""" + self.init_prim_io_names(inputs=['shape', 'a', 'b'], outputs=['output']) + validator.check_value_type('seed', seed, [int], self.name) + + def __infer__(self, shape, a, b): + shape_v = shape["value"] + if shape_v is None: + raise ValueError(f"For {self.name}, shape must be const.") + validator.check_value_type("shape", shape_v, [tuple], self.name) + for i, shape_i in enumerate(shape_v): + validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name) + validator.check_tensor_type_same({"a": a["dtype"]}, [mstype.int32], self.name) + validator.check_tensor_type_same({"b": b["dtype"]}, [mstype.int32], self.name) + broadcast_shape = get_broadcast_shape(a['shape'], b['shape'], self.name) + broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name) + out = { + 'shape': broadcast_shape, + 'dtype': mstype.int32, + 'value': None} + return out + + +class UniformReal(PrimitiveWithInfer): + r""" + Produces random floating-point values i, uniformly distributed on the interval [a, b), that is,\ + distributed according to the probability density function: + + .. math:: + \text{P}(i|a,b) = \frac{1}{b-a}, + + Args: + seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers. + Default: 0. + + Inputs: + - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. + - **a** (Tensor) - The a distribution parameter. + It defines the minimum possibly generated value. With float32 data type. + - **b** (Tensor) - The b distribution parameter. + It defines the maximum possibly generated value. With float32 data type. + + Outputs: + Tensor, has the shape 'shape' input and dtype as int32. + + Examples: + >>> shape = (4, 16) + >>> a = Tensor(1.0, mstype.float32) + >>> b = Tensor(5.0, mstype.float32) + >>> uniform_real = P.UniformReal(seed=10) + >>> output = uniform_real(shape, a, b) + """ + + @prim_attr_register + def __init__(self, seed=0): + """Init UniformReal""" + self.init_prim_io_names(inputs=['shape', 'a', 'b'], outputs=['output']) + validator.check_value_type('seed', seed, [int], self.name) + + def __infer__(self, shape, a, b): + shape_v = shape["value"] + if shape_v is None: + raise ValueError(f"For {self.name}, shape must be const.") + validator.check_value_type("shape", shape_v, [tuple], self.name) + for i, shape_i in enumerate(shape_v): + validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name) + validator.check_tensor_type_same({"a": a["dtype"]}, [mstype.float32], self.name) + validator.check_tensor_type_same({"b": b["dtype"]}, [mstype.float32], self.name) + broadcast_shape = get_broadcast_shape(a['shape'], b['shape'], self.name) + broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name) + out = { + 'shape': broadcast_shape, + 'dtype': mstype.float32, + 'value': None} + return out class RandomChoiceWithMask(PrimitiveWithInfer): @@ -66,49 +329,6 @@ class RandomChoiceWithMask(PrimitiveWithInfer): return (mstype.int32, mstype.bool_) -class Normal(PrimitiveWithInfer): - """ - Generates random samples from a normal(Gaussian) distribution. - - Args: - seed (int): Random seed. Default: 0. - - Inputs: - - **shape** (tuple[int]) - The shape of output tensor. Only constant value is allowed. - - **mean** (Tensor) - The mean of the distribution, with float32 data type. - - **stddev** (Tensor) - The standard deviation of the distribution, with float32 data type. - - Outputs: - Tensor, with the given shape from the specific distribution and float32 data type. - - Examples: - >>> normal = P.Normal() - >>> mean = Tensor(0., mstype.float32) - >>> stddev = Tensor(1., mstype.float32) - >>> out = normal((32, 3, 3), mean, stddev) - """ - - @prim_attr_register - def __init__(self, seed=0): - """Init Normal""" - validator.check_value_type("seed", seed, [int], self.name) - - def __infer__(self, shape, mean, stddev): - shape_value = shape["value"] - if shape_value is None: - raise ValueError(f"For {self.name}, shape must be const.") - validator.check_value_type("shape", shape_value, [tuple], self.name) - for i, shape_i in enumerate(shape_value): - validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GE, self.name) - - validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name) - validator.check_tensor_type_same({"stddev": stddev["dtype"]}, [mstype.float32], self.name) - - out = {"shape": shape_value, - "dtype": mstype.float32, - "value": None} - return out - class RandomCategorical(PrimitiveWithInfer): """ Generates random samples from a given categorical distribution tensor. diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_gamma.py b/tests/st/ops/ascend/test_aicpu_ops/test_gamma.py new file mode 100644 index 0000000000..4b685df16b --- /dev/null +++ b/tests/st/ops/ascend/test_aicpu_ops/test_gamma.py @@ -0,0 +1,58 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class Net(nn.Cell): + def __init__(self, shape, seed=0): + super(Net, self).__init__() + self.gamma = P.Gamma(seed=seed) + self.shape = shape + + def construct(self, alpha, beta): + return self.gamma(self.shape, alpha, beta) + + +def test_net_1D(): + seed = 10 + shape = (3, 2, 4) + alpha = 1.0 + beta = 1.0 + net = Net(shape, seed) + talpha, tbeta = Tensor(alpha, mstype.float32), Tensor(beta, mstype.float32) + output = net(talpha, tbeta) + print(output.asnumpy()) + assert output.shape == (3, 2, 4) + + +def test_net_ND(): + seed = 10 + shape = (3, 1, 2) + alpha = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32) + beta = np.array([1.0]).astype(np.float32) + net = Net(shape, seed) + talpha, tbeta = Tensor(alpha), Tensor(beta) + output = net(talpha, tbeta) + print(output.asnumpy()) + assert output.shape == (3, 2, 2) + diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_normal.py b/tests/st/ops/ascend/test_aicpu_ops/test_normal.py index 66254caf21..1a94d61463 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_normal.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_normal.py @@ -12,32 +12,46 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +import numpy as np + import mindspore.context as context import mindspore.nn as nn +from mindspore import Tensor from mindspore.ops import operations as P -from mindspore.common import Tensor from mindspore.common import dtype as mstype - -context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") class Net(nn.Cell): - def __init__(self, shape=None, mean=0.0, stddev=1.0, seed=0): + def __init__(self, shape, seed=0): super(Net, self).__init__() - self._mean = Tensor(mean, mstype.float32) - self._stddev = Tensor(stddev, mstype.float32) - self._normal = P.Normal(seed=seed) - self._shape = shape + self.normal = P.Normal(seed=seed) + self.shape = shape - def construct(self): - return self._normal(self._shape, self._mean, self._stddev) + def construct(self, mean, stddev): + return self.normal(self.shape, mean, stddev) -def test_net_3x2x4(): - mean = 0.0 +def test_net_1D(): + seed = 10 + shape = (3, 2, 4) + mean = 1.0 stddev = 1.0 - seed = 0 - net = Net((3, 2, 4), mean, stddev, seed) - out = net() - assert out.shape == (3, 2, 4) + net = Net(shape, seed) + tmean, tstddev = Tensor(mean, mstype.float32), Tensor(stddev, mstype.float32) + output = net(tmean, tstddev) + print(output.asnumpy()) + assert output.shape == (3, 2, 4) + + +def test_net_ND(): + seed = 10 + shape = (3, 1, 2) + mean = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32) + stddev = np.array([1.0]).astype(np.float32) + net = Net(shape, seed) + tmean, tstddev = Tensor(mean), Tensor(stddev) + output = net(tmean, tstddev) + print(output.asnumpy()) + assert output.shape == (3, 2, 2) \ No newline at end of file diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_poisson.py b/tests/st/ops/ascend/test_aicpu_ops/test_poisson.py new file mode 100644 index 0000000000..29af6cbb09 --- /dev/null +++ b/tests/st/ops/ascend/test_aicpu_ops/test_poisson.py @@ -0,0 +1,53 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class Net(nn.Cell): + def __init__(self, shape): + super(Net, self).__init__() + self.poisson = P.Poisson() + self.shape = shape + + def construct(self, mean): + return self.poisson(self.shape, mean) + + +def test_net_1(): + shape = (2, 16) + mean = np.array([5.0]).astype(np.float32) + net = Net(shape) + tmean = Tensor(mean) + output = net(tmean) + print(output.asnumpy()) + assert output.shape == (2, 16) + + +def test_net_2(): + shape = (4, 1) + mean = np.array([5.0, 10.0]).astype(np.float32) + net = Net(shape) + tmean = Tensor(mean) + output = net(tmean) + print(output.asnumpy()) + assert output.shape == (4, 2) \ No newline at end of file diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_uniform_int.py b/tests/st/ops/ascend/test_aicpu_ops/test_uniform_int.py new file mode 100644 index 0000000000..cbd39f4706 --- /dev/null +++ b/tests/st/ops/ascend/test_aicpu_ops/test_uniform_int.py @@ -0,0 +1,57 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class Net(nn.Cell): + def __init__(self, shape, seed=0): + super(Net, self).__init__() + self.uniformint = P.UniformInt(seed=seed) + self.shape = shape + + def construct(self, a, b): + return self.uniformint(self.shape, a, b) + + +def test_net_1D(): + seed = 10 + shape = (3, 2, 4) + a = 1 + b = 5 + net = Net(shape, seed) + ta, tb = Tensor(a, mstype.int32), Tensor(b, mstype.int32) + output = net(ta, tb) + print(output.asnumpy()) + assert output.shape == (3, 2, 4) + + +def test_net_ND(): + seed = 10 + shape = (3, 2, 1) + a = np.array([[[1, 2]], [[3, 4]], [[5, 6]]]).astype(np.int32) + b = np.array([10]).astype(np.int32) + net = Net(shape, seed) + ta, tb = Tensor(a), Tensor(b) + output = net(ta, tb) + print(output.asnumpy()) + assert output.shape == (3, 2, 2) \ No newline at end of file diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_uniform_real.py b/tests/st/ops/ascend/test_aicpu_ops/test_uniform_real.py new file mode 100644 index 0000000000..635eb3fa28 --- /dev/null +++ b/tests/st/ops/ascend/test_aicpu_ops/test_uniform_real.py @@ -0,0 +1,57 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class Net(nn.Cell): + def __init__(self, shape, seed=0): + super(Net, self).__init__() + self.uniformreal = P.UniformReal(seed=seed) + self.shape = shape + + def construct(self, a, b): + return self.uniformreal(self.shape, a, b) + + +def test_net_1D(): + seed = 10 + shape = (3, 2, 4) + a = 1.0 + b = 5.0 + net = Net(shape, seed) + ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32) + output = net(ta, tb) + print(output.asnumpy()) + assert output.shape == (3, 2, 4) + + +def test_net_ND(): + seed = 10 + shape = (3, 2, 1) + a = np.array([[[1, 2]], [[3, 4]], [[5, 6]]]).astype(np.float32) + b = np.array([10]).astype(np.float32) + net = Net(shape, seed) + ta, tb = Tensor(a), Tensor(b) + output = net(ta, tb) + print(output.asnumpy()) + assert output.shape == (3, 2, 2) \ No newline at end of file diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index cf6a6705ab..c9dab27b66 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -400,15 +400,57 @@ class InplaceSubNet(nn.Cell): class NormalNet(nn.Cell): - def __init__(self, shape=None, mean=0.0, stddev=1.0, seed=0): + def __init__(self, shape=None, seed=0): super(NormalNet, self).__init__() self.normal = P.Normal(seed=seed) self.shape = shape - self.mean = Tensor(mean, mstype.float32) - self.stddev = Tensor(stddev, mstype.float32) - def construct(self): - out = self.normal(self.shape, self.mean, self.stddev) + def construct(self, mean, stddev): + out = self.normal(self.shape, mean, stddev) + return out + + +class GammaNet(nn.Cell): + def __init__(self, shape=None, seed=0): + super(GammaNet, self).__init__() + self.gamma = P.Gamma(seed=seed) + self.shape = shape + + def construct(self, alpha, beta): + out = self.gamma(self.shape, alpha, beta) + return out + + +class PoissonNet(nn.Cell): + def __init__(self, shape=None, seed=0): + super(PoissonNet, self).__init__() + self.poisson = P.Poisson(seed=seed) + self.shape = shape + + def construct(self, mean): + out = self.poisson(self.shape, mean) + return out + + +class UniformIntNet(nn.Cell): + def __init__(self, shape=None, seed=0): + super(UniformIntNet, self).__init__() + self.uniformint = P.UniformInt(seed=seed) + self.shape = shape + + def construct(self, a, b): + out = self.uniformint(self.shape, a, b) + return out + + +class UniformRealNet(nn.Cell): + def __init__(self, shape=None, seed=0): + super(UniformRealNet, self).__init__() + self.uniformreal = P.UniformReal(seed=seed) + self.shape = shape + + def construct(self, a, b): + out = self.uniformreal(self.shape, a, b) return out @@ -620,6 +662,26 @@ test_case_math_ops = [ (1, 1, 1)], 'desc_inputs': [[64, 128, 1024]], 'skip': ['backward']}), + ('Normal', { + 'block': NormalNet((3, 2, 4), 0), + 'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(1.0, mstype.float32)], + 'skip': ['backward']}), + ('Gamma', { + 'block': GammaNet((3, 2, 4), 0), + 'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(1.0, mstype.float32)], + 'skip': ['backward']}), + ('Poisson', { + 'block': PoissonNet((3, 2, 4), 0), + 'desc_inputs': [Tensor(2.0, mstype.float32)], + 'skip': ['backward']}), + ('UniformInt', { + 'block': UniformIntNet((3, 2, 4), 0), + 'desc_inputs': [Tensor(1, mstype.int32), Tensor(15, mstype.int32)], + 'skip': ['backward']}), + ('UniformReal', { + 'block': UniformRealNet((3, 2, 4), 0), + 'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(5.0, mstype.float32)], + 'skip': ['backward']}), ('RandomChoiceWithMask', { 'block': P.RandomChoiceWithMask(256), 'desc_inputs': [Tensor(np.random.rand(24000, 4).astype(np.bool_))], @@ -908,10 +970,6 @@ test_case_math_ops = [ 'desc_inputs': [Tensor([-1.0, 0.0, 1.5, 2.0, 5.0, 15], mstype.float16), Tensor([0.0, 5.0], mstype.float16)], 'desc_bprop': [], 'skip': ['backward']}), - ('Normal', { - 'block': NormalNet((3, 2, 4), 0.0, 1.0, 0), - 'desc_inputs': [], - 'skip': ['backward']}), ] test_case_nn_ops = [