diff --git a/mindspore/ops/_op_impl/aicpu/gamma.py b/mindspore/ops/_op_impl/aicpu/gamma.py index b6b92a9da4..801ae41d6a 100644 --- a/mindspore/ops/_op_impl/aicpu/gamma.py +++ b/mindspore/ops/_op_impl/aicpu/gamma.py @@ -23,6 +23,7 @@ gamma_op_info = AiCPURegOp("Gamma") \ .input(2, "beta", "required") \ .output(0, "output", "required") \ .attr("seed", "int") \ + .attr("seed2", "int") \ .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/aicpu/poisson.py b/mindspore/ops/_op_impl/aicpu/poisson.py index 59d2c1b957..4569efe40e 100644 --- a/mindspore/ops/_op_impl/aicpu/poisson.py +++ b/mindspore/ops/_op_impl/aicpu/poisson.py @@ -22,6 +22,7 @@ poisson_op_info = AiCPURegOp("Poisson") \ .input(1, "mean", "required") \ .output(0, "output", "required") \ .attr("seed", "int") \ + .attr("seed2", "int") \ .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.I32_Default) \ .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/aicpu/uniform_int.py b/mindspore/ops/_op_impl/aicpu/uniform_int.py index 35cfbec11c..3e76dc794a 100644 --- a/mindspore/ops/_op_impl/aicpu/uniform_int.py +++ b/mindspore/ops/_op_impl/aicpu/uniform_int.py @@ -23,6 +23,7 @@ uniform_int_op_info = AiCPURegOp("UniformInt") \ .input(2, "b", "required") \ .output(0, "output", "required") \ .attr("seed", "int") \ + .attr("seed2", "int") \ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ .dtype_format(DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/aicpu/uniform_real.py b/mindspore/ops/_op_impl/aicpu/uniform_real.py index 51824fbb2c..9e0876d317 100644 --- a/mindspore/ops/_op_impl/aicpu/uniform_real.py +++ b/mindspore/ops/_op_impl/aicpu/uniform_real.py @@ -19,12 +19,11 @@ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataTyp uniform_real_op_info = AiCPURegOp("UniformReal") \ .fusion_type("OPAQUE") \ .input(0, "shape", "required") \ - .input(1, "a", "required") \ - .input(2, "b", "required") \ .output(0, "output", "required") \ .attr("seed", "int") \ - .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \ + .attr("seed2", "int") \ + .dtype_format(DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW) \ .get_op_info() @op_info_register(uniform_real_op_info) diff --git a/mindspore/ops/composite/__init__.py b/mindspore/ops/composite/__init__.py index bb5e2960ff..ab35dd65fb 100644 --- a/mindspore/ops/composite/__init__.py +++ b/mindspore/ops/composite/__init__.py @@ -27,7 +27,7 @@ from .clip_ops import clip_by_value from .multitype_ops.add_impl import hyper_add from .multitype_ops.ones_like_impl import ones_like from .multitype_ops.zeros_like_impl import zeros_like -from .random_ops import normal +from .random_ops import set_seed, normal, uniform __all__ = [ @@ -48,5 +48,7 @@ __all__ = [ 'zeros_like', 'ones_like', 'zip_operation', + 'set_seed', + 'uniform', 'normal', 'clip_by_value',] diff --git a/mindspore/ops/composite/random_ops.py b/mindspore/ops/composite/random_ops.py index db338f5672..88037aefb7 100644 --- a/mindspore/ops/composite/random_ops.py +++ b/mindspore/ops/composite/random_ops.py @@ -13,10 +13,13 @@ # limitations under the License. # ============================================================================ -"""Operations for random number generatos.""" +"""Operations for random number generators.""" -from mindspore.ops.primitive import constexpr from .. import operations as P +from .. import functional as F +from ..primitive import constexpr +from .multitype_ops import _constexpr_utils as const_utils +from ...common import dtype as mstype # set graph-level RNG seed _GRAPH_SEED = 0 @@ -31,17 +34,17 @@ def get_seed(): return _GRAPH_SEED -def normal(shape, mean, stddev, seed): +def normal(shape, mean, stddev, seed=0): """ Generates random numbers according to the Normal (or Gaussian) random number distribution. It is defined as: Args: - - **shape** (tuple) - The shape of random tensor to be generated. - - **mean** (Tensor) - The mean μ distribution parameter, which specifies the location of the peak. + shape (tuple): The shape of random tensor to be generated. + mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak. With float32 data type. - - **stddev** (Tensor) - The deviation σ distribution parameter. With float32 data type. - - **seed** (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. + stddev (Tensor): The deviation σ distribution parameter. With float32 data type. + seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. Default: 0. Returns: @@ -52,12 +55,58 @@ def normal(shape, mean, stddev, seed): >>> shape = (4, 16) >>> mean = Tensor(1.0, mstype.float32) >>> stddev = Tensor(1.0, mstype.float32) + >>> C.set_seed(10) >>> output = C.normal(shape, mean, stddev, seed=5) """ - set_seed(10) + mean_dtype = F.dtype(mean) + stddev_dtype = F.dtype(stddev) + const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "normal") + const_utils.check_tensors_dtype_same(stddev_dtype, mstype.float32, "normal") seed1 = get_seed() seed2 = seed stdnormal = P.StandardNormal(seed1, seed2) rnd = stdnormal(shape) value = rnd * stddev + mean return value + +def uniform(shape, a, b, seed=0, dtype=mstype.float32): + """ + Generates random numbers according to the Uniform (or Gaussian) random number distribution. + It is defined as: + + Args: + shape (tuple): The shape of random tensor to be generated. + a (Tensor): The a distribution parameter. + It defines the minimum possibly generated value. With int32 or float32 data type. + If dtype is int32, only one number is allowed. + b (Tensor): The b distribution parameter. + It defines the maximum possibly generated value. With int32 or float32 data type. + If dtype is int32, only one number is allowed. + seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. + Default: 0. + + Returns: + Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of a and b. + The dtype is float32. + + Examples: + >>> shape = (4, 16) + >>> a = Tensor(1.0, mstype.float32) + >>> b = Tensor(1.0, mstype.float32) + >>> C.set_seed(10) + >>> output = C.uniform(shape, a, b, seed=5) + """ + a_dtype = F.dtype(a) + b_dtype = F.dtype(b) + const_utils.check_tensors_dtype_same(a_dtype, dtype, "uniform") + const_utils.check_tensors_dtype_same(b_dtype, dtype, "uniform") + seed1 = get_seed() + seed2 = seed + if const_utils.is_same_type(dtype, mstype.int32): + rnd = P.UniformInt(seed1, seed2) + value = rnd(shape, a, b) + else: + uniform_real = P.UniformReal(seed1, seed2) + rnd = uniform_real(shape) + value = rnd * (b - a) + a + return value diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index d2c67b8f1b..59b28cf09d 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -34,8 +34,7 @@ class StandardNormal(PrimitiveWithInfer): - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. Outputs: - Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of mean and stddev. - The dtype is float32. + Tensor. The shape that the input 'shape' denotes. The dtype is float32. Examples: >>> shape = (4, 16) @@ -126,8 +125,8 @@ class Gamma(PrimitiveWithInfer): \text{P}(x|α,β) = \frac{\exp(-x/β)}{{β^α}\cdot{\Gamma(α)}}\cdot{x^{α-1}}, Args: - seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers. - Default: 0. + seed (int): Random seed. Default: 0. + seed2 (int): Random seed2. Default: 0. Inputs: - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. @@ -149,10 +148,11 @@ class Gamma(PrimitiveWithInfer): """ @prim_attr_register - def __init__(self, seed=0): + def __init__(self, seed=0, seed2=0): """Init Gamma""" self.init_prim_io_names(inputs=['shape', 'alpha', 'beta'], outputs=['output']) validator.check_value_type('seed', seed, [int], self.name) + validator.check_value_type('seed2', seed2, [int], self.name) def __infer__(self, shape, alpha, beta): shape_v = shape["value"] @@ -180,8 +180,8 @@ class Poisson(PrimitiveWithInfer): \text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!}, Args: - seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers. - Default: 0. + seed (int): Random seed. Default: 0. + seed2 (int): Random seed2. Default: 0. Inputs: - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. @@ -200,10 +200,11 @@ class Poisson(PrimitiveWithInfer): """ @prim_attr_register - def __init__(self, seed=0): + def __init__(self, seed=0, seed2=0): """Init Poisson""" self.init_prim_io_names(inputs=['shape', 'mean'], outputs=['output']) validator.check_value_type('seed', seed, [int], self.name) + validator.check_value_type('seed2', seed2, [int], self.name) def __infer__(self, shape, mean): shape_v = shape["value"] @@ -223,7 +224,7 @@ class Poisson(PrimitiveWithInfer): class UniformInt(PrimitiveWithInfer): r""" - Produces random integer values i, uniformly distributed on the closed interval [a, b], that is, + Produces random integer values i, uniformly distributed on the closed interval [a, b), that is, distributed according to the discrete probability function: .. math:: @@ -233,19 +234,18 @@ class UniformInt(PrimitiveWithInfer): The number in tensor a should be strictly less than b at any position after broadcasting. Args: - seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers. - Default: 0. + seed (int): Random seed. Default: 0. + seed2 (int): Random seed2. Default: 0. Inputs: - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. - **a** (Tensor) - The a distribution parameter. - It defines the minimum possibly generated value. With int32 data type. + It defines the minimum possibly generated value. With int32 data type. Only one number is supported. - **b** (Tensor) - The b distribution parameter. - It defines the maximum possibly generated value. With int32 data type. + It defines the maximum possibly generated value. With int32 data type. Only one number is supported. Outputs: - Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of a and b. - The dtype is int32. + Tensor. The shape that the input 'shape' denotes. The dtype is int32. Examples: >>> shape = (4, 16) @@ -256,10 +256,11 @@ class UniformInt(PrimitiveWithInfer): """ @prim_attr_register - def __init__(self, seed=0): + def __init__(self, seed=0, seed2=0): """Init UniformInt""" self.init_prim_io_names(inputs=['shape', 'a', 'b'], outputs=['output']) validator.check_value_type('seed', seed, [int], self.name) + validator.check_value_type('seed2', seed2, [int], self.name) def __infer__(self, shape, a, b): shape_v = shape["value"] @@ -270,10 +271,12 @@ class UniformInt(PrimitiveWithInfer): validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name) validator.check_tensor_type_same({"a": a["dtype"]}, [mstype.int32], self.name) validator.check_tensor_type_same({"b": b["dtype"]}, [mstype.int32], self.name) - broadcast_shape = get_broadcast_shape(a['shape'], b['shape'], self.name) - broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name) + a_shape = a['shape'] + b_shape = b['shape'] + validator.check("dim of a", len(a_shape), '0(scalar)', 0, Rel.EQ, self.name) + validator.check("dim of b", len(b_shape), '0(scalar)', 0, Rel.EQ, self.name) out = { - 'shape': broadcast_shape, + 'shape': shape_v, 'dtype': mstype.int32, 'value': None} return out @@ -281,54 +284,40 @@ class UniformInt(PrimitiveWithInfer): class UniformReal(PrimitiveWithInfer): r""" - Produces random floating-point values i, uniformly distributed on the interval [min(a, b), max(a, b)), that is,\ - distributed according to the probability density function: - - .. math:: - \text{P}(i|a,b) = \frac{1}{b-a}, + Produces random floating-point values i, uniformly distributed on the interval [0, 1). Args: - seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers. - Default: 0. + seed (int): Random seed. Default: 0. + seed2 (int): Random seed2. Default: 0. Inputs: - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. - - **a** (Tensor) - The a distribution parameter. - It defines the minimum possibly generated value. With float32 data type. - - **b** (Tensor) - The b distribution parameter. - It defines the maximum possibly generated value. With float32 data type. Outputs: - Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of a and b. - The dtype is float32. + Tensor. The shape that the input 'shape' denotes. The dtype is float32. Examples: >>> shape = (4, 16) - >>> a = Tensor(1.0, mstype.float32) - >>> b = Tensor(5.0, mstype.float32) - >>> uniform_real = P.UniformReal(seed=10) - >>> output = uniform_real(shape, a, b) + >>> uniformreal = P.UniformReal(seed=2) + >>> output = uniformreal(shape) """ @prim_attr_register - def __init__(self, seed=0): + def __init__(self, seed=0, seed2=0): """Init UniformReal""" - self.init_prim_io_names(inputs=['shape', 'a', 'b'], outputs=['output']) + self.init_prim_io_names(inputs=['shape'], outputs=['output']) validator.check_value_type('seed', seed, [int], self.name) + validator.check_value_type('seed2', seed2, [int], self.name) - def __infer__(self, shape, a, b): + def __infer__(self, shape): shape_v = shape["value"] if shape_v is None: raise ValueError(f"For {self.name}, shape must be const.") validator.check_value_type("shape", shape_v, [tuple], self.name) for i, shape_i in enumerate(shape_v): validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name) - validator.check_tensor_type_same({"a": a["dtype"]}, [mstype.float32], self.name) - validator.check_tensor_type_same({"b": b["dtype"]}, [mstype.float32], self.name) - broadcast_shape = get_broadcast_shape(a['shape'], b['shape'], self.name) - broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name) out = { - 'shape': broadcast_shape, + 'shape': shape_v, 'dtype': mstype.float32, 'value': None} return out diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_gamma.py b/tests/st/ops/ascend/test_aicpu_ops/test_gamma.py index 4b685df16b..61bb3f8476 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_gamma.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_gamma.py @@ -24,9 +24,9 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") class Net(nn.Cell): - def __init__(self, shape, seed=0): + def __init__(self, shape, seed=0, seed2=0): super(Net, self).__init__() - self.gamma = P.Gamma(seed=seed) + self.gamma = P.Gamma(seed=seed, seed2=seed2) self.shape = shape def construct(self, alpha, beta): @@ -38,10 +38,9 @@ def test_net_1D(): shape = (3, 2, 4) alpha = 1.0 beta = 1.0 - net = Net(shape, seed) + net = Net(shape=shape, seed=seed) talpha, tbeta = Tensor(alpha, mstype.float32), Tensor(beta, mstype.float32) output = net(talpha, tbeta) - print(output.asnumpy()) assert output.shape == (3, 2, 4) @@ -50,9 +49,8 @@ def test_net_ND(): shape = (3, 1, 2) alpha = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32) beta = np.array([1.0]).astype(np.float32) - net = Net(shape, seed) + net = Net(shape=shape, seed=seed) talpha, tbeta = Tensor(alpha), Tensor(beta) output = net(talpha, tbeta) - print(output.asnumpy()) assert output.shape == (3, 2, 2) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_normal.py b/tests/st/ops/ascend/test_aicpu_ops/test_normal.py index 77f0b7e8e8..70c9c2c68f 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_normal.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_normal.py @@ -32,6 +32,7 @@ class Net(nn.Cell): self.seed = seed def construct(self, mean, stddev): + C.set_seed(20) return C.normal(self.shape, mean, stddev, self.seed) @@ -43,7 +44,6 @@ def test_net_1D(): net = Net(shape, seed) tmean, tstddev = Tensor(mean, mstype.float32), Tensor(stddev, mstype.float32) output = net(tmean, tstddev) - print(output.asnumpy()) assert output.shape == (3, 2, 4) @@ -55,5 +55,4 @@ def test_net_ND(): net = Net(shape, seed) tmean, tstddev = Tensor(mean, mstype.float32), Tensor(stddev, mstype.float32) output = net(tmean, tstddev) - print(output.asnumpy()) assert output.shape == (3, 2, 2) \ No newline at end of file diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_poisson.py b/tests/st/ops/ascend/test_aicpu_ops/test_poisson.py index 29af6cbb09..dd5ada2712 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_poisson.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_poisson.py @@ -24,7 +24,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") class Net(nn.Cell): - def __init__(self, shape): + def __init__(self, shape, seed=0, seed2=0): super(Net, self).__init__() self.poisson = P.Poisson() self.shape = shape @@ -36,18 +36,16 @@ class Net(nn.Cell): def test_net_1(): shape = (2, 16) mean = np.array([5.0]).astype(np.float32) - net = Net(shape) + net = Net(shape=shape) tmean = Tensor(mean) output = net(tmean) - print(output.asnumpy()) assert output.shape == (2, 16) def test_net_2(): shape = (4, 1) mean = np.array([5.0, 10.0]).astype(np.float32) - net = Net(shape) + net = Net(shape=shape) tmean = Tensor(mean) output = net(tmean) - print(output.asnumpy()) assert output.shape == (4, 2) \ No newline at end of file diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_standard_normal.py b/tests/st/ops/ascend/test_aicpu_ops/test_standard_normal.py index 5cc21fac80..f45cb462cc 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_standard_normal.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_standard_normal.py @@ -34,7 +34,7 @@ class Net(nn.Cell): self.stdnormal = P.StandardNormal(seed, seed2) def construct(self): - return self.stdnormal(self.shape, self.seed, self.seed2) + return self.stdnormal(self.shape) def test_net(): diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_uniform.py b/tests/st/ops/ascend/test_aicpu_ops/test_uniform.py new file mode 100644 index 0000000000..cef50bdbc7 --- /dev/null +++ b/tests/st/ops/ascend/test_aicpu_ops/test_uniform.py @@ -0,0 +1,57 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.ops import composite as C + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class Net(nn.Cell): + def __init__(self, shape, seed=0): + super(Net, self).__init__() + self.shape = shape + self.seed = seed + + def construct(self, a, b): + C.set_seed(20) + return C.uniform(self.shape, a, b, self.seed) + + +def test_net_1D(): + seed = 10 + shape = (3, 2, 4) + a = 1.0 + b = 6.0 + net = Net(shape, seed) + ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32) + output = net(ta, tb) + assert output.shape == (3, 2, 4) + + +def test_net_ND(): + seed = 10 + shape = (3, 1, 2) + a = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32) + b = np.array([1.0]).astype(np.float32) + net = Net(shape, seed) + ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32) + output = net(ta, tb) + assert output.shape == (3, 2, 2) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_uniform_int.py b/tests/st/ops/ascend/test_aicpu_ops/test_uniform_int.py index cbd39f4706..fef34fad07 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_uniform_int.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_uniform_int.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -import numpy as np import mindspore.context as context import mindspore.nn as nn @@ -24,7 +23,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") class Net(nn.Cell): - def __init__(self, shape, seed=0): + def __init__(self, shape, seed=0, seed2=0): super(Net, self).__init__() self.uniformint = P.UniformInt(seed=seed) self.shape = shape @@ -38,20 +37,7 @@ def test_net_1D(): shape = (3, 2, 4) a = 1 b = 5 - net = Net(shape, seed) + net = Net(shape, seed=seed) ta, tb = Tensor(a, mstype.int32), Tensor(b, mstype.int32) output = net(ta, tb) - print(output.asnumpy()) assert output.shape == (3, 2, 4) - - -def test_net_ND(): - seed = 10 - shape = (3, 2, 1) - a = np.array([[[1, 2]], [[3, 4]], [[5, 6]]]).astype(np.int32) - b = np.array([10]).astype(np.int32) - net = Net(shape, seed) - ta, tb = Tensor(a), Tensor(b) - output = net(ta, tb) - print(output.asnumpy()) - assert output.shape == (3, 2, 2) \ No newline at end of file diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_uniform_real.py b/tests/st/ops/ascend/test_aicpu_ops/test_uniform_real.py index 635eb3fa28..7ab0b42e11 100644 --- a/tests/st/ops/ascend/test_aicpu_ops/test_uniform_real.py +++ b/tests/st/ops/ascend/test_aicpu_ops/test_uniform_real.py @@ -12,46 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -import numpy as np import mindspore.context as context import mindspore.nn as nn -from mindspore import Tensor from mindspore.ops import operations as P -from mindspore.common import dtype as mstype context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") class Net(nn.Cell): - def __init__(self, shape, seed=0): + def __init__(self, shape, seed=0, seed2=0): super(Net, self).__init__() self.uniformreal = P.UniformReal(seed=seed) self.shape = shape - def construct(self, a, b): - return self.uniformreal(self.shape, a, b) + def construct(self): + return self.uniformreal(self.shape) -def test_net_1D(): +def test_net(): seed = 10 shape = (3, 2, 4) - a = 1.0 - b = 5.0 - net = Net(shape, seed) - ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32) - output = net(ta, tb) - print(output.asnumpy()) + net = Net(shape, seed=seed) + output = net() assert output.shape == (3, 2, 4) - - -def test_net_ND(): - seed = 10 - shape = (3, 2, 1) - a = np.array([[[1, 2]], [[3, 4]], [[5, 6]]]).astype(np.float32) - b = np.array([10]).astype(np.float32) - net = Net(shape, seed) - ta, tb = Tensor(a), Tensor(b) - output = net(ta, tb) - print(output.asnumpy()) - assert output.shape == (3, 2, 2) \ No newline at end of file diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 31ca540f74..7f53d40469 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -573,25 +573,14 @@ class PoissonNet(nn.Cell): return out -class UniformIntNet(nn.Cell): +class UniformNet(nn.Cell): def __init__(self, shape=None, seed=0): - super(UniformIntNet, self).__init__() - self.uniformint = P.UniformInt(seed=seed) - self.shape = shape - - def construct(self, a, b): - out = self.uniformint(self.shape, a, b) - return out - - -class UniformRealNet(nn.Cell): - def __init__(self, shape=None, seed=0): - super(UniformRealNet, self).__init__() - self.uniformreal = P.UniformReal(seed=seed) + super(UniformNet, self).__init__() self.shape = shape + self.seed = seed def construct(self, a, b): - out = self.uniformreal(self.shape, a, b) + out = C.uniform(self.shape, a, b, self.seed) return out @@ -882,13 +871,9 @@ test_case_math_ops = [ 'block': PoissonNet((3, 2, 4), 0), 'desc_inputs': [Tensor(2.0, mstype.float32)], 'skip': ['backward']}), - ('UniformInt', { - 'block': UniformIntNet((3, 2, 4), 0), - 'desc_inputs': [Tensor(1, mstype.int32), Tensor(15, mstype.int32)], - 'skip': ['backward']}), - ('UniformReal', { - 'block': UniformRealNet((3, 2, 4), 0), - 'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(5.0, mstype.float32)], + ('Uniform', { + 'block': UniformNet((3, 2, 4), 0), + 'desc_inputs': [Tensor(0.0, mstype.float32), Tensor(1.0, mstype.float32)], 'skip': ['backward']}), ('RandomChoiceWithMask', { 'block': P.RandomChoiceWithMask(256),