Browse Source

Refactor random uniform ops and complete gamma and poisson

tags/v0.7.0-beta
peixu_ren 5 years ago
parent
commit
374772035a
15 changed files with 180 additions and 134 deletions
  1. +1
    -0
      mindspore/ops/_op_impl/aicpu/gamma.py
  2. +1
    -0
      mindspore/ops/_op_impl/aicpu/poisson.py
  3. +1
    -0
      mindspore/ops/_op_impl/aicpu/uniform_int.py
  4. +3
    -4
      mindspore/ops/_op_impl/aicpu/uniform_real.py
  5. +3
    -1
      mindspore/ops/composite/__init__.py
  6. +57
    -8
      mindspore/ops/composite/random_ops.py
  7. +33
    -44
      mindspore/ops/operations/random_ops.py
  8. +4
    -6
      tests/st/ops/ascend/test_aicpu_ops/test_gamma.py
  9. +1
    -2
      tests/st/ops/ascend/test_aicpu_ops/test_normal.py
  10. +3
    -5
      tests/st/ops/ascend/test_aicpu_ops/test_poisson.py
  11. +1
    -1
      tests/st/ops/ascend/test_aicpu_ops/test_standard_normal.py
  12. +57
    -0
      tests/st/ops/ascend/test_aicpu_ops/test_uniform.py
  13. +2
    -16
      tests/st/ops/ascend/test_aicpu_ops/test_uniform_int.py
  14. +6
    -25
      tests/st/ops/ascend/test_aicpu_ops/test_uniform_real.py
  15. +7
    -22
      tests/ut/python/ops/test_ops.py

+ 1
- 0
mindspore/ops/_op_impl/aicpu/gamma.py View File

@@ -23,6 +23,7 @@ gamma_op_info = AiCPURegOp("Gamma") \
.input(2, "beta", "required") \ .input(2, "beta", "required") \
.output(0, "output", "required") \ .output(0, "output", "required") \
.attr("seed", "int") \ .attr("seed", "int") \
.attr("seed2", "int") \
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \ .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
.get_op_info() .get_op_info()


+ 1
- 0
mindspore/ops/_op_impl/aicpu/poisson.py View File

@@ -22,6 +22,7 @@ poisson_op_info = AiCPURegOp("Poisson") \
.input(1, "mean", "required") \ .input(1, "mean", "required") \
.output(0, "output", "required") \ .output(0, "output", "required") \
.attr("seed", "int") \ .attr("seed", "int") \
.attr("seed2", "int") \
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.I32_Default) \ .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW) \ .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW) \
.get_op_info() .get_op_info()


+ 1
- 0
mindspore/ops/_op_impl/aicpu/uniform_int.py View File

@@ -23,6 +23,7 @@ uniform_int_op_info = AiCPURegOp("UniformInt") \
.input(2, "b", "required") \ .input(2, "b", "required") \
.output(0, "output", "required") \ .output(0, "output", "required") \
.attr("seed", "int") \ .attr("seed", "int") \
.attr("seed2", "int") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW) \ .dtype_format(DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW) \
.get_op_info() .get_op_info()


+ 3
- 4
mindspore/ops/_op_impl/aicpu/uniform_real.py View File

@@ -19,12 +19,11 @@ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataTyp
uniform_real_op_info = AiCPURegOp("UniformReal") \ uniform_real_op_info = AiCPURegOp("UniformReal") \
.fusion_type("OPAQUE") \ .fusion_type("OPAQUE") \
.input(0, "shape", "required") \ .input(0, "shape", "required") \
.input(1, "a", "required") \
.input(2, "b", "required") \
.output(0, "output", "required") \ .output(0, "output", "required") \
.attr("seed", "int") \ .attr("seed", "int") \
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
.attr("seed2", "int") \
.dtype_format(DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_NCHW, DataType.F32_NCHW) \
.get_op_info() .get_op_info()


@op_info_register(uniform_real_op_info) @op_info_register(uniform_real_op_info)


+ 3
- 1
mindspore/ops/composite/__init__.py View File

@@ -27,7 +27,7 @@ from .clip_ops import clip_by_value
from .multitype_ops.add_impl import hyper_add from .multitype_ops.add_impl import hyper_add
from .multitype_ops.ones_like_impl import ones_like from .multitype_ops.ones_like_impl import ones_like
from .multitype_ops.zeros_like_impl import zeros_like from .multitype_ops.zeros_like_impl import zeros_like
from .random_ops import normal
from .random_ops import set_seed, normal, uniform




__all__ = [ __all__ = [
@@ -48,5 +48,7 @@ __all__ = [
'zeros_like', 'zeros_like',
'ones_like', 'ones_like',
'zip_operation', 'zip_operation',
'set_seed',
'uniform',
'normal', 'normal',
'clip_by_value',] 'clip_by_value',]

+ 57
- 8
mindspore/ops/composite/random_ops.py View File

@@ -13,10 +13,13 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================


"""Operations for random number generatos."""
"""Operations for random number generators."""


from mindspore.ops.primitive import constexpr
from .. import operations as P from .. import operations as P
from .. import functional as F
from ..primitive import constexpr
from .multitype_ops import _constexpr_utils as const_utils
from ...common import dtype as mstype


# set graph-level RNG seed # set graph-level RNG seed
_GRAPH_SEED = 0 _GRAPH_SEED = 0
@@ -31,17 +34,17 @@ def get_seed():
return _GRAPH_SEED return _GRAPH_SEED




def normal(shape, mean, stddev, seed):
def normal(shape, mean, stddev, seed=0):
""" """
Generates random numbers according to the Normal (or Gaussian) random number distribution. Generates random numbers according to the Normal (or Gaussian) random number distribution.
It is defined as: It is defined as:


Args: Args:
- **shape** (tuple) - The shape of random tensor to be generated.
- **mean** (Tensor) - The mean μ distribution parameter, which specifies the location of the peak.
shape (tuple): The shape of random tensor to be generated.
mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak.
With float32 data type. With float32 data type.
- **stddev** (Tensor) - The deviation σ distribution parameter. With float32 data type.
- **seed** (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
stddev (Tensor): The deviation σ distribution parameter. With float32 data type.
seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
Default: 0. Default: 0.


Returns: Returns:
@@ -52,12 +55,58 @@ def normal(shape, mean, stddev, seed):
>>> shape = (4, 16) >>> shape = (4, 16)
>>> mean = Tensor(1.0, mstype.float32) >>> mean = Tensor(1.0, mstype.float32)
>>> stddev = Tensor(1.0, mstype.float32) >>> stddev = Tensor(1.0, mstype.float32)
>>> C.set_seed(10)
>>> output = C.normal(shape, mean, stddev, seed=5) >>> output = C.normal(shape, mean, stddev, seed=5)
""" """
set_seed(10)
mean_dtype = F.dtype(mean)
stddev_dtype = F.dtype(stddev)
const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "normal")
const_utils.check_tensors_dtype_same(stddev_dtype, mstype.float32, "normal")
seed1 = get_seed() seed1 = get_seed()
seed2 = seed seed2 = seed
stdnormal = P.StandardNormal(seed1, seed2) stdnormal = P.StandardNormal(seed1, seed2)
rnd = stdnormal(shape) rnd = stdnormal(shape)
value = rnd * stddev + mean value = rnd * stddev + mean
return value return value

def uniform(shape, a, b, seed=0, dtype=mstype.float32):
"""
Generates random numbers according to the Uniform (or Gaussian) random number distribution.
It is defined as:

Args:
shape (tuple): The shape of random tensor to be generated.
a (Tensor): The a distribution parameter.
It defines the minimum possibly generated value. With int32 or float32 data type.
If dtype is int32, only one number is allowed.
b (Tensor): The b distribution parameter.
It defines the maximum possibly generated value. With int32 or float32 data type.
If dtype is int32, only one number is allowed.
seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
Default: 0.

Returns:
Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of a and b.
The dtype is float32.

Examples:
>>> shape = (4, 16)
>>> a = Tensor(1.0, mstype.float32)
>>> b = Tensor(1.0, mstype.float32)
>>> C.set_seed(10)
>>> output = C.uniform(shape, a, b, seed=5)
"""
a_dtype = F.dtype(a)
b_dtype = F.dtype(b)
const_utils.check_tensors_dtype_same(a_dtype, dtype, "uniform")
const_utils.check_tensors_dtype_same(b_dtype, dtype, "uniform")
seed1 = get_seed()
seed2 = seed
if const_utils.is_same_type(dtype, mstype.int32):
rnd = P.UniformInt(seed1, seed2)
value = rnd(shape, a, b)
else:
uniform_real = P.UniformReal(seed1, seed2)
rnd = uniform_real(shape)
value = rnd * (b - a) + a
return value

+ 33
- 44
mindspore/ops/operations/random_ops.py View File

@@ -34,8 +34,7 @@ class StandardNormal(PrimitiveWithInfer):
- **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.


Outputs: Outputs:
Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of mean and stddev.
The dtype is float32.
Tensor. The shape that the input 'shape' denotes. The dtype is float32.


Examples: Examples:
>>> shape = (4, 16) >>> shape = (4, 16)
@@ -126,8 +125,8 @@ class Gamma(PrimitiveWithInfer):
\text{P}(x|α,β) = \frac{\exp(-x/β)}{{β^α}\cdot{\Gamma(α)}}\cdot{x^{α-1}}, \text{P}(x|α,β) = \frac{\exp(-x/β)}{{β^α}\cdot{\Gamma(α)}}\cdot{x^{α-1}},


Args: Args:
seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
Default: 0.
seed (int): Random seed. Default: 0.
seed2 (int): Random seed2. Default: 0.


Inputs: Inputs:
- **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
@@ -149,10 +148,11 @@ class Gamma(PrimitiveWithInfer):
""" """


@prim_attr_register @prim_attr_register
def __init__(self, seed=0):
def __init__(self, seed=0, seed2=0):
"""Init Gamma""" """Init Gamma"""
self.init_prim_io_names(inputs=['shape', 'alpha', 'beta'], outputs=['output']) self.init_prim_io_names(inputs=['shape', 'alpha', 'beta'], outputs=['output'])
validator.check_value_type('seed', seed, [int], self.name) validator.check_value_type('seed', seed, [int], self.name)
validator.check_value_type('seed2', seed2, [int], self.name)


def __infer__(self, shape, alpha, beta): def __infer__(self, shape, alpha, beta):
shape_v = shape["value"] shape_v = shape["value"]
@@ -180,8 +180,8 @@ class Poisson(PrimitiveWithInfer):
\text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!}, \text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!},


Args: Args:
seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
Default: 0.
seed (int): Random seed. Default: 0.
seed2 (int): Random seed2. Default: 0.


Inputs: Inputs:
- **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
@@ -200,10 +200,11 @@ class Poisson(PrimitiveWithInfer):
""" """


@prim_attr_register @prim_attr_register
def __init__(self, seed=0):
def __init__(self, seed=0, seed2=0):
"""Init Poisson""" """Init Poisson"""
self.init_prim_io_names(inputs=['shape', 'mean'], outputs=['output']) self.init_prim_io_names(inputs=['shape', 'mean'], outputs=['output'])
validator.check_value_type('seed', seed, [int], self.name) validator.check_value_type('seed', seed, [int], self.name)
validator.check_value_type('seed2', seed2, [int], self.name)


def __infer__(self, shape, mean): def __infer__(self, shape, mean):
shape_v = shape["value"] shape_v = shape["value"]
@@ -223,7 +224,7 @@ class Poisson(PrimitiveWithInfer):


class UniformInt(PrimitiveWithInfer): class UniformInt(PrimitiveWithInfer):
r""" r"""
Produces random integer values i, uniformly distributed on the closed interval [a, b], that is,
Produces random integer values i, uniformly distributed on the closed interval [a, b), that is,
distributed according to the discrete probability function: distributed according to the discrete probability function:


.. math:: .. math::
@@ -233,19 +234,18 @@ class UniformInt(PrimitiveWithInfer):
The number in tensor a should be strictly less than b at any position after broadcasting. The number in tensor a should be strictly less than b at any position after broadcasting.


Args: Args:
seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
Default: 0.
seed (int): Random seed. Default: 0.
seed2 (int): Random seed2. Default: 0.


Inputs: Inputs:
- **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
- **a** (Tensor) - The a distribution parameter. - **a** (Tensor) - The a distribution parameter.
It defines the minimum possibly generated value. With int32 data type.
It defines the minimum possibly generated value. With int32 data type. Only one number is supported.
- **b** (Tensor) - The b distribution parameter. - **b** (Tensor) - The b distribution parameter.
It defines the maximum possibly generated value. With int32 data type.
It defines the maximum possibly generated value. With int32 data type. Only one number is supported.


Outputs: Outputs:
Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of a and b.
The dtype is int32.
Tensor. The shape that the input 'shape' denotes. The dtype is int32.


Examples: Examples:
>>> shape = (4, 16) >>> shape = (4, 16)
@@ -256,10 +256,11 @@ class UniformInt(PrimitiveWithInfer):
""" """


@prim_attr_register @prim_attr_register
def __init__(self, seed=0):
def __init__(self, seed=0, seed2=0):
"""Init UniformInt""" """Init UniformInt"""
self.init_prim_io_names(inputs=['shape', 'a', 'b'], outputs=['output']) self.init_prim_io_names(inputs=['shape', 'a', 'b'], outputs=['output'])
validator.check_value_type('seed', seed, [int], self.name) validator.check_value_type('seed', seed, [int], self.name)
validator.check_value_type('seed2', seed2, [int], self.name)


def __infer__(self, shape, a, b): def __infer__(self, shape, a, b):
shape_v = shape["value"] shape_v = shape["value"]
@@ -270,10 +271,12 @@ class UniformInt(PrimitiveWithInfer):
validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name) validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
validator.check_tensor_type_same({"a": a["dtype"]}, [mstype.int32], self.name) validator.check_tensor_type_same({"a": a["dtype"]}, [mstype.int32], self.name)
validator.check_tensor_type_same({"b": b["dtype"]}, [mstype.int32], self.name) validator.check_tensor_type_same({"b": b["dtype"]}, [mstype.int32], self.name)
broadcast_shape = get_broadcast_shape(a['shape'], b['shape'], self.name)
broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name)
a_shape = a['shape']
b_shape = b['shape']
validator.check("dim of a", len(a_shape), '0(scalar)', 0, Rel.EQ, self.name)
validator.check("dim of b", len(b_shape), '0(scalar)', 0, Rel.EQ, self.name)
out = { out = {
'shape': broadcast_shape,
'shape': shape_v,
'dtype': mstype.int32, 'dtype': mstype.int32,
'value': None} 'value': None}
return out return out
@@ -281,54 +284,40 @@ class UniformInt(PrimitiveWithInfer):


class UniformReal(PrimitiveWithInfer): class UniformReal(PrimitiveWithInfer):
r""" r"""
Produces random floating-point values i, uniformly distributed on the interval [min(a, b), max(a, b)), that is,\
distributed according to the probability density function:

.. math::
\text{P}(i|a,b) = \frac{1}{b-a},
Produces random floating-point values i, uniformly distributed on the interval [0, 1).


Args: Args:
seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
Default: 0.
seed (int): Random seed. Default: 0.
seed2 (int): Random seed2. Default: 0.


Inputs: Inputs:
- **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
- **a** (Tensor) - The a distribution parameter.
It defines the minimum possibly generated value. With float32 data type.
- **b** (Tensor) - The b distribution parameter.
It defines the maximum possibly generated value. With float32 data type.


Outputs: Outputs:
Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of a and b.
The dtype is float32.
Tensor. The shape that the input 'shape' denotes. The dtype is float32.


Examples: Examples:
>>> shape = (4, 16) >>> shape = (4, 16)
>>> a = Tensor(1.0, mstype.float32)
>>> b = Tensor(5.0, mstype.float32)
>>> uniform_real = P.UniformReal(seed=10)
>>> output = uniform_real(shape, a, b)
>>> uniformreal = P.UniformReal(seed=2)
>>> output = uniformreal(shape)
""" """


@prim_attr_register @prim_attr_register
def __init__(self, seed=0):
def __init__(self, seed=0, seed2=0):
"""Init UniformReal""" """Init UniformReal"""
self.init_prim_io_names(inputs=['shape', 'a', 'b'], outputs=['output'])
self.init_prim_io_names(inputs=['shape'], outputs=['output'])
validator.check_value_type('seed', seed, [int], self.name) validator.check_value_type('seed', seed, [int], self.name)
validator.check_value_type('seed2', seed2, [int], self.name)


def __infer__(self, shape, a, b):
def __infer__(self, shape):
shape_v = shape["value"] shape_v = shape["value"]
if shape_v is None: if shape_v is None:
raise ValueError(f"For {self.name}, shape must be const.") raise ValueError(f"For {self.name}, shape must be const.")
validator.check_value_type("shape", shape_v, [tuple], self.name) validator.check_value_type("shape", shape_v, [tuple], self.name)
for i, shape_i in enumerate(shape_v): for i, shape_i in enumerate(shape_v):
validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name) validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
validator.check_tensor_type_same({"a": a["dtype"]}, [mstype.float32], self.name)
validator.check_tensor_type_same({"b": b["dtype"]}, [mstype.float32], self.name)
broadcast_shape = get_broadcast_shape(a['shape'], b['shape'], self.name)
broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name)
out = { out = {
'shape': broadcast_shape,
'shape': shape_v,
'dtype': mstype.float32, 'dtype': mstype.float32,
'value': None} 'value': None}
return out return out


+ 4
- 6
tests/st/ops/ascend/test_aicpu_ops/test_gamma.py View File

@@ -24,9 +24,9 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")




class Net(nn.Cell): class Net(nn.Cell):
def __init__(self, shape, seed=0):
def __init__(self, shape, seed=0, seed2=0):
super(Net, self).__init__() super(Net, self).__init__()
self.gamma = P.Gamma(seed=seed)
self.gamma = P.Gamma(seed=seed, seed2=seed2)
self.shape = shape self.shape = shape


def construct(self, alpha, beta): def construct(self, alpha, beta):
@@ -38,10 +38,9 @@ def test_net_1D():
shape = (3, 2, 4) shape = (3, 2, 4)
alpha = 1.0 alpha = 1.0
beta = 1.0 beta = 1.0
net = Net(shape, seed)
net = Net(shape=shape, seed=seed)
talpha, tbeta = Tensor(alpha, mstype.float32), Tensor(beta, mstype.float32) talpha, tbeta = Tensor(alpha, mstype.float32), Tensor(beta, mstype.float32)
output = net(talpha, tbeta) output = net(talpha, tbeta)
print(output.asnumpy())
assert output.shape == (3, 2, 4) assert output.shape == (3, 2, 4)




@@ -50,9 +49,8 @@ def test_net_ND():
shape = (3, 1, 2) shape = (3, 1, 2)
alpha = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32) alpha = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32)
beta = np.array([1.0]).astype(np.float32) beta = np.array([1.0]).astype(np.float32)
net = Net(shape, seed)
net = Net(shape=shape, seed=seed)
talpha, tbeta = Tensor(alpha), Tensor(beta) talpha, tbeta = Tensor(alpha), Tensor(beta)
output = net(talpha, tbeta) output = net(talpha, tbeta)
print(output.asnumpy())
assert output.shape == (3, 2, 2) assert output.shape == (3, 2, 2)



+ 1
- 2
tests/st/ops/ascend/test_aicpu_ops/test_normal.py View File

@@ -32,6 +32,7 @@ class Net(nn.Cell):
self.seed = seed self.seed = seed


def construct(self, mean, stddev): def construct(self, mean, stddev):
C.set_seed(20)
return C.normal(self.shape, mean, stddev, self.seed) return C.normal(self.shape, mean, stddev, self.seed)




@@ -43,7 +44,6 @@ def test_net_1D():
net = Net(shape, seed) net = Net(shape, seed)
tmean, tstddev = Tensor(mean, mstype.float32), Tensor(stddev, mstype.float32) tmean, tstddev = Tensor(mean, mstype.float32), Tensor(stddev, mstype.float32)
output = net(tmean, tstddev) output = net(tmean, tstddev)
print(output.asnumpy())
assert output.shape == (3, 2, 4) assert output.shape == (3, 2, 4)




@@ -55,5 +55,4 @@ def test_net_ND():
net = Net(shape, seed) net = Net(shape, seed)
tmean, tstddev = Tensor(mean, mstype.float32), Tensor(stddev, mstype.float32) tmean, tstddev = Tensor(mean, mstype.float32), Tensor(stddev, mstype.float32)
output = net(tmean, tstddev) output = net(tmean, tstddev)
print(output.asnumpy())
assert output.shape == (3, 2, 2) assert output.shape == (3, 2, 2)

+ 3
- 5
tests/st/ops/ascend/test_aicpu_ops/test_poisson.py View File

@@ -24,7 +24,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")




class Net(nn.Cell): class Net(nn.Cell):
def __init__(self, shape):
def __init__(self, shape, seed=0, seed2=0):
super(Net, self).__init__() super(Net, self).__init__()
self.poisson = P.Poisson() self.poisson = P.Poisson()
self.shape = shape self.shape = shape
@@ -36,18 +36,16 @@ class Net(nn.Cell):
def test_net_1(): def test_net_1():
shape = (2, 16) shape = (2, 16)
mean = np.array([5.0]).astype(np.float32) mean = np.array([5.0]).astype(np.float32)
net = Net(shape)
net = Net(shape=shape)
tmean = Tensor(mean) tmean = Tensor(mean)
output = net(tmean) output = net(tmean)
print(output.asnumpy())
assert output.shape == (2, 16) assert output.shape == (2, 16)




def test_net_2(): def test_net_2():
shape = (4, 1) shape = (4, 1)
mean = np.array([5.0, 10.0]).astype(np.float32) mean = np.array([5.0, 10.0]).astype(np.float32)
net = Net(shape)
net = Net(shape=shape)
tmean = Tensor(mean) tmean = Tensor(mean)
output = net(tmean) output = net(tmean)
print(output.asnumpy())
assert output.shape == (4, 2) assert output.shape == (4, 2)

+ 1
- 1
tests/st/ops/ascend/test_aicpu_ops/test_standard_normal.py View File

@@ -34,7 +34,7 @@ class Net(nn.Cell):
self.stdnormal = P.StandardNormal(seed, seed2) self.stdnormal = P.StandardNormal(seed, seed2)


def construct(self): def construct(self):
return self.stdnormal(self.shape, self.seed, self.seed2)
return self.stdnormal(self.shape)




def test_net(): def test_net():


+ 57
- 0
tests/st/ops/ascend/test_aicpu_ops/test_uniform.py View File

@@ -0,0 +1,57 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net(nn.Cell):
def __init__(self, shape, seed=0):
super(Net, self).__init__()
self.shape = shape
self.seed = seed

def construct(self, a, b):
C.set_seed(20)
return C.uniform(self.shape, a, b, self.seed)


def test_net_1D():
seed = 10
shape = (3, 2, 4)
a = 1.0
b = 6.0
net = Net(shape, seed)
ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32)
output = net(ta, tb)
assert output.shape == (3, 2, 4)


def test_net_ND():
seed = 10
shape = (3, 1, 2)
a = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32)
b = np.array([1.0]).astype(np.float32)
net = Net(shape, seed)
ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32)
output = net(ta, tb)
assert output.shape == (3, 2, 2)

+ 2
- 16
tests/st/ops/ascend/test_aicpu_ops/test_uniform_int.py View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
import numpy as np


import mindspore.context as context import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
@@ -24,7 +23,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")




class Net(nn.Cell): class Net(nn.Cell):
def __init__(self, shape, seed=0):
def __init__(self, shape, seed=0, seed2=0):
super(Net, self).__init__() super(Net, self).__init__()
self.uniformint = P.UniformInt(seed=seed) self.uniformint = P.UniformInt(seed=seed)
self.shape = shape self.shape = shape
@@ -38,20 +37,7 @@ def test_net_1D():
shape = (3, 2, 4) shape = (3, 2, 4)
a = 1 a = 1
b = 5 b = 5
net = Net(shape, seed)
net = Net(shape, seed=seed)
ta, tb = Tensor(a, mstype.int32), Tensor(b, mstype.int32) ta, tb = Tensor(a, mstype.int32), Tensor(b, mstype.int32)
output = net(ta, tb) output = net(ta, tb)
print(output.asnumpy())
assert output.shape == (3, 2, 4) assert output.shape == (3, 2, 4)


def test_net_ND():
seed = 10
shape = (3, 2, 1)
a = np.array([[[1, 2]], [[3, 4]], [[5, 6]]]).astype(np.int32)
b = np.array([10]).astype(np.int32)
net = Net(shape, seed)
ta, tb = Tensor(a), Tensor(b)
output = net(ta, tb)
print(output.asnumpy())
assert output.shape == (3, 2, 2)

+ 6
- 25
tests/st/ops/ascend/test_aicpu_ops/test_uniform_real.py View File

@@ -12,46 +12,27 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
import numpy as np


import mindspore.context as context import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common import dtype as mstype


context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")




class Net(nn.Cell): class Net(nn.Cell):
def __init__(self, shape, seed=0):
def __init__(self, shape, seed=0, seed2=0):
super(Net, self).__init__() super(Net, self).__init__()
self.uniformreal = P.UniformReal(seed=seed) self.uniformreal = P.UniformReal(seed=seed)
self.shape = shape self.shape = shape


def construct(self, a, b):
return self.uniformreal(self.shape, a, b)
def construct(self):
return self.uniformreal(self.shape)




def test_net_1D():
def test_net():
seed = 10 seed = 10
shape = (3, 2, 4) shape = (3, 2, 4)
a = 1.0
b = 5.0
net = Net(shape, seed)
ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32)
output = net(ta, tb)
print(output.asnumpy())
net = Net(shape, seed=seed)
output = net()
assert output.shape == (3, 2, 4) assert output.shape == (3, 2, 4)


def test_net_ND():
seed = 10
shape = (3, 2, 1)
a = np.array([[[1, 2]], [[3, 4]], [[5, 6]]]).astype(np.float32)
b = np.array([10]).astype(np.float32)
net = Net(shape, seed)
ta, tb = Tensor(a), Tensor(b)
output = net(ta, tb)
print(output.asnumpy())
assert output.shape == (3, 2, 2)

+ 7
- 22
tests/ut/python/ops/test_ops.py View File

@@ -573,25 +573,14 @@ class PoissonNet(nn.Cell):
return out return out




class UniformIntNet(nn.Cell):
class UniformNet(nn.Cell):
def __init__(self, shape=None, seed=0): def __init__(self, shape=None, seed=0):
super(UniformIntNet, self).__init__()
self.uniformint = P.UniformInt(seed=seed)
self.shape = shape

def construct(self, a, b):
out = self.uniformint(self.shape, a, b)
return out


class UniformRealNet(nn.Cell):
def __init__(self, shape=None, seed=0):
super(UniformRealNet, self).__init__()
self.uniformreal = P.UniformReal(seed=seed)
super(UniformNet, self).__init__()
self.shape = shape self.shape = shape
self.seed = seed


def construct(self, a, b): def construct(self, a, b):
out = self.uniformreal(self.shape, a, b)
out = C.uniform(self.shape, a, b, self.seed)
return out return out




@@ -882,13 +871,9 @@ test_case_math_ops = [
'block': PoissonNet((3, 2, 4), 0), 'block': PoissonNet((3, 2, 4), 0),
'desc_inputs': [Tensor(2.0, mstype.float32)], 'desc_inputs': [Tensor(2.0, mstype.float32)],
'skip': ['backward']}), 'skip': ['backward']}),
('UniformInt', {
'block': UniformIntNet((3, 2, 4), 0),
'desc_inputs': [Tensor(1, mstype.int32), Tensor(15, mstype.int32)],
'skip': ['backward']}),
('UniformReal', {
'block': UniformRealNet((3, 2, 4), 0),
'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(5.0, mstype.float32)],
('Uniform', {
'block': UniformNet((3, 2, 4), 0),
'desc_inputs': [Tensor(0.0, mstype.float32), Tensor(1.0, mstype.float32)],
'skip': ['backward']}), 'skip': ['backward']}),
('RandomChoiceWithMask', { ('RandomChoiceWithMask', {
'block': P.RandomChoiceWithMask(256), 'block': P.RandomChoiceWithMask(256),


Loading…
Cancel
Save