Browse Source

Modify the name of parameters in uniform

tags/v1.0.0
peixu_ren 5 years ago
parent
commit
10f381d662
6 changed files with 64 additions and 64 deletions
  1. +4
    -4
      mindspore/nn/probability/bnn_layers/conv_variational.py
  2. +19
    -19
      mindspore/ops/composite/random_ops.py
  3. +15
    -15
      mindspore/ops/operations/random_ops.py
  4. +10
    -10
      tests/st/ops/ascend/test_aicpu_ops/test_uniform_int.py
  5. +10
    -10
      tests/st/ops/ascend/test_compoite_random_ops/test_uniform.py
  6. +6
    -6
      tests/st/ops/gpu/test_uniform_real.py

+ 4
- 4
mindspore/nn/probability/bnn_layers/conv_variational.py View File

@@ -250,10 +250,10 @@ class ConvReparam(_ConvVariational):
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples: Examples:
>>> net = ConvReparam(120, 240, 4, has_bias=False)
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> net(input).shape
(1, 240, 1024, 640)
>>> net = ConvReparam(120, 240, 4, has_bias=False)
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> net(input).shape
(1, 240, 1024, 640)
""" """
def __init__( def __init__(


+ 19
- 19
mindspore/ops/composite/random_ops.py View File

@@ -92,55 +92,55 @@ def normal(shape, mean, stddev, seed=0):
value = random_normal * stddev + mean value = random_normal * stddev + mean
return value return value


def uniform(shape, a, b, seed=0, dtype=mstype.float32):
def uniform(shape, minval, maxval, seed=0, dtype=mstype.float32):
""" """
Generates random numbers according to the Uniform random number distribution. Generates random numbers according to the Uniform random number distribution.


Note: Note:
The number in tensor a should be strictly less than b at any position after broadcasting.
The number in tensor minval should be strictly less than maxval at any position after broadcasting.


Args: Args:
shape (tuple): The shape of random tensor to be generated. shape (tuple): The shape of random tensor to be generated.
a (Tensor): The a distribution parameter.
minval (Tensor): The a distribution parameter.
It defines the minimum possibly generated value. With int32 or float32 data type. It defines the minimum possibly generated value. With int32 or float32 data type.
If dtype is int32, only one number is allowed. If dtype is int32, only one number is allowed.
b (Tensor): The b distribution parameter.
maxval (Tensor): The b distribution parameter.
It defines the maximum possibly generated value. With int32 or float32 data type. It defines the maximum possibly generated value. With int32 or float32 data type.
If dtype is int32, only one number is allowed. If dtype is int32, only one number is allowed.
seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
Must be non-negative. Default: 0. Must be non-negative. Default: 0.


Returns: Returns:
Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of a and b.
Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of minval and maxval.
The dtype is designated as the input `dtype`. The dtype is designated as the input `dtype`.


Examples: Examples:
>>> For discrete uniform distribution, only one number is allowed for both a and b:
>>> For discrete uniform distribution, only one number is allowed for both minval and maxval:
>>> shape = (4, 2) >>> shape = (4, 2)
>>> a = Tensor(1, mstype.int32)
>>> b = Tensor(2, mstype.int32)
>>> output = C.uniform(shape, a, b, seed=5)
>>> minval = Tensor(1, mstype.int32)
>>> maxval = Tensor(2, mstype.int32)
>>> output = C.uniform(shape, minval, maxval, seed=5)
>>> >>>
>>> For continuous uniform distribution, a and b can be multi-dimentional:
>>> For continuous uniform distribution, minval and maxval can be multi-dimentional:
>>> shape = (4, 2) >>> shape = (4, 2)
>>> a = Tensor([1.0, 2.0], mstype.float32)
>>> b = Tensor([4.0, 5.0], mstype.float32)
>>> output = C.uniform(shape, a, b, seed=5)
>>> minval = Tensor([1.0, 2.0], mstype.float32)
>>> maxval = Tensor([4.0, 5.0], mstype.float32)
>>> output = C.uniform(shape, minval, maxval, seed=5)
""" """
a_dtype = F.dtype(a)
b_dtype = F.dtype(b)
const_utils.check_tensors_dtype_same(a_dtype, dtype, "uniform")
const_utils.check_tensors_dtype_same(b_dtype, dtype, "uniform")
minval_dtype = F.dtype(minval)
maxval_dtype = F.dtype(maxval)
const_utils.check_tensors_dtype_same(minval_dtype, dtype, "uniform")
const_utils.check_tensors_dtype_same(maxval_dtype, dtype, "uniform")
const_utils.check_non_negative("seed", seed, "uniform") const_utils.check_non_negative("seed", seed, "uniform")
seed1 = get_seed() seed1 = get_seed()
seed2 = seed seed2 = seed
if const_utils.is_same_type(dtype, mstype.int32): if const_utils.is_same_type(dtype, mstype.int32):
random_uniform = P.UniformInt(seed1, seed2) random_uniform = P.UniformInt(seed1, seed2)
value = random_uniform(shape, a, b)
value = random_uniform(shape, minval, maxval)
else: else:
uniform_real = P.UniformReal(seed1, seed2) uniform_real = P.UniformReal(seed1, seed2)
random_uniform = uniform_real(shape) random_uniform = uniform_real(shape)
value = random_uniform * (b - a) + a
value = random_uniform * (maxval - minval) + minval
return value return value


def gamma(shape, alpha, beta, seed=0): def gamma(shape, alpha, beta, seed=0):


+ 15
- 15
mindspore/ops/operations/random_ops.py View File

@@ -224,14 +224,14 @@ class Poisson(PrimitiveWithInfer):


class UniformInt(PrimitiveWithInfer): class UniformInt(PrimitiveWithInfer):
r""" r"""
Produces random integer values i, uniformly distributed on the closed interval [a, b), that is,
Produces random integer values i, uniformly distributed on the closed interval [minval, maxval), that is,
distributed according to the discrete probability function: distributed according to the discrete probability function:


.. math:: .. math::
\text{P}(i|a,b) = \frac{1}{b-a+1}, \text{P}(i|a,b) = \frac{1}{b-a+1},


Note: Note:
The number in tensor a should be strictly less than b at any position after broadcasting.
The number in tensor minval should be strictly less than maxval at any position after broadcasting.


Args: Args:
seed (int): Random seed. Must be non-negative. Default: 0. seed (int): Random seed. Must be non-negative. Default: 0.
@@ -239,9 +239,9 @@ class UniformInt(PrimitiveWithInfer):


Inputs: Inputs:
- **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
- **a** (Tensor) - The a distribution parameter.
- **minval** (Tensor) - The a distribution parameter.
It defines the minimum possibly generated value. With int32 data type. Only one number is supported. It defines the minimum possibly generated value. With int32 data type. Only one number is supported.
- **b** (Tensor) - The b distribution parameter.
- **maxval** (Tensor) - The b distribution parameter.
It defines the maximum possibly generated value. With int32 data type. Only one number is supported. It defines the maximum possibly generated value. With int32 data type. Only one number is supported.


Outputs: Outputs:
@@ -249,32 +249,32 @@ class UniformInt(PrimitiveWithInfer):


Examples: Examples:
>>> shape = (4, 16) >>> shape = (4, 16)
>>> a = Tensor(1, mstype.int32)
>>> b = Tensor(5, mstype.int32)
>>> minval = Tensor(1, mstype.int32)
>>> maxval = Tensor(5, mstype.int32)
>>> uniform_int = P.UniformInt(seed=10) >>> uniform_int = P.UniformInt(seed=10)
>>> output = uniform_int(shape, a, b)
>>> output = uniform_int(shape, minval, maxval)
""" """


@prim_attr_register @prim_attr_register
def __init__(self, seed=0, seed2=0): def __init__(self, seed=0, seed2=0):
"""Init UniformInt""" """Init UniformInt"""
self.init_prim_io_names(inputs=['shape', 'a', 'b'], outputs=['output'])
self.init_prim_io_names(inputs=['shape', 'minval', 'maxval'], outputs=['output'])
validator.check_integer("seed", seed, 0, Rel.GE, self.name) validator.check_integer("seed", seed, 0, Rel.GE, self.name)
validator.check_integer("seed2", seed2, 0, Rel.GE, self.name) validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)


def __infer__(self, shape, a, b):
def __infer__(self, shape, minval, maxval):
shape_v = shape["value"] shape_v = shape["value"]
if shape_v is None: if shape_v is None:
raise ValueError(f"For {self.name}, shape must be const.") raise ValueError(f"For {self.name}, shape must be const.")
validator.check_value_type("shape", shape_v, [tuple], self.name) validator.check_value_type("shape", shape_v, [tuple], self.name)
for i, shape_i in enumerate(shape_v): for i, shape_i in enumerate(shape_v):
validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name) validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
validator.check_tensor_type_same({"a": a["dtype"]}, [mstype.int32], self.name)
validator.check_tensor_type_same({"b": b["dtype"]}, [mstype.int32], self.name)
a_shape = a['shape']
b_shape = b['shape']
validator.check("dim of a", len(a_shape), '0(scalar)', 0, Rel.EQ, self.name)
validator.check("dim of b", len(b_shape), '0(scalar)', 0, Rel.EQ, self.name)
validator.check_tensor_type_same({"minval": minval["dtype"]}, [mstype.int32], self.name)
validator.check_tensor_type_same({"maxval": maxval["dtype"]}, [mstype.int32], self.name)
minval_shape = minval['shape']
maxval_shape = maxval['shape']
validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name)
validator.check("dim of maxval", len(maxval_shape), '0(scalar)', 0, Rel.EQ, self.name)
out = { out = {
'shape': shape_v, 'shape': shape_v,
'dtype': mstype.int32, 'dtype': mstype.int32,


+ 10
- 10
tests/st/ops/ascend/test_aicpu_ops/test_uniform_int.py View File

@@ -28,28 +28,28 @@ class Net(nn.Cell):
self.uniformint = P.UniformInt(seed=seed) self.uniformint = P.UniformInt(seed=seed)
self.shape = shape self.shape = shape


def construct(self, a, b):
return self.uniformint(self.shape, a, b)
def construct(self, minval, maxval):
return self.uniformint(self.shape, minval, maxval)




def test_net_1D(): def test_net_1D():
seed = 10 seed = 10
shape = (3, 2, 4) shape = (3, 2, 4)
a = 1
b = 5
minval = 1
maxval = 5
net = Net(shape, seed=seed) net = Net(shape, seed=seed)
ta, tb = Tensor(a, mstype.int32), Tensor(b, mstype.int32)
output = net(ta, tb)
tminval, tmaxval = Tensor(minval, mstype.int32), Tensor(maxval, mstype.int32)
output = net(tminval, tmaxval)
assert output.shape == (3, 2, 4) assert output.shape == (3, 2, 4)




def test_net_ND(): def test_net_ND():
seed = 10 seed = 10
shape = (3, 2, 1) shape = (3, 2, 1)
a = np.array([[[1, 2]], [[3, 4]], [[5, 6]]]).astype(np.int32)
b = np.array([10]).astype(np.int32)
minval = np.array([[[1, 2]], [[3, 4]], [[5, 6]]]).astype(np.int32)
maxval = np.array([10]).astype(np.int32)
net = Net(shape, seed) net = Net(shape, seed)
ta, tb = Tensor(a), Tensor(b)
output = net(ta, tb)
tminval, tmaxval = Tensor(minval), Tensor(maxval)
output = net(tminval, tmaxval)
print(output.asnumpy()) print(output.asnumpy())
assert output.shape == (3, 2, 2) assert output.shape == (3, 2, 2)

+ 10
- 10
tests/st/ops/ascend/test_compoite_random_ops/test_uniform.py View File

@@ -29,28 +29,28 @@ class Net(nn.Cell):
self.shape = shape self.shape = shape
self.seed = seed self.seed = seed


def construct(self, a, b):
def construct(self, minval, maxval):
C.set_seed(20) C.set_seed(20)
return C.uniform(self.shape, a, b, self.seed)
return C.uniform(self.shape, minval, maxval, self.seed)




def test_net_1D(): def test_net_1D():
seed = 10 seed = 10
shape = (3, 2, 4) shape = (3, 2, 4)
a = 1.0
b = 6.0
minval = 1.0
maxval = 6.0
net = Net(shape, seed) net = Net(shape, seed)
ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32)
output = net(ta, tb)
tminval, tmaxval = Tensor(minval, mstype.float32), Tensor(maxval, mstype.float32)
output = net(tminval, tmaxval)
assert output.shape == (3, 2, 4) assert output.shape == (3, 2, 4)




def test_net_ND(): def test_net_ND():
seed = 10 seed = 10
shape = (3, 1, 2) shape = (3, 1, 2)
a = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32)
b = np.array([1.0]).astype(np.float32)
minval = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32)
maxval = np.array([1.0]).astype(np.float32)
net = Net(shape, seed) net = Net(shape, seed)
ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32)
output = net(ta, tb)
tminval, tmaxval = Tensor(minval, mstype.float32), Tensor(maxval, mstype.float32)
output = net(tminval, tmaxval)
assert output.shape == (3, 2, 2) assert output.shape == (3, 2, 2)

+ 6
- 6
tests/st/ops/gpu/test_uniform_real.py View File

@@ -27,17 +27,17 @@ class Net(nn.Cell):
self.uniformreal = P.UniformReal(seed=seed) self.uniformreal = P.UniformReal(seed=seed)
self.shape = shape self.shape = shape


def construct(self, a, b):
return self.uniformreal(self.shape, a, b)
def construct(self, minval, maxval):
return self.uniformreal(self.shape, minval, maxval)




def test_net_1D(): def test_net_1D():
seed = 10 seed = 10
shape = (3, 2, 4) shape = (3, 2, 4)
a = 0.0
b = 1.0
minval = 0.0
maxval = 1.0
net = Net(shape, seed) net = Net(shape, seed)
ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32)
output = net(ta, tb)
tminval, tmaxval = Tensor(minval, mstype.float32), Tensor(maxval, mstype.float32)
output = net(tminval, tmaxval)
print(output.asnumpy()) print(output.asnumpy())
assert output.shape == (3, 2, 4) assert output.shape == (3, 2, 4)

Loading…
Cancel
Save