Browse Source

Add sampling functions in exponential, geometric and uniform distributions

tags/v0.7.0-beta
peixu_ren Xun Deng 5 years ago
parent
commit
292f2de0cf
14 changed files with 284 additions and 197 deletions
  1. +29
    -40
      mindspore/nn/distribution/bernoulli.py
  2. +14
    -6
      mindspore/nn/distribution/distribution.py
  3. +37
    -23
      mindspore/nn/distribution/exponential.py
  4. +37
    -34
      mindspore/nn/distribution/geometric.py
  5. +15
    -20
      mindspore/nn/distribution/normal.py
  6. +52
    -39
      mindspore/nn/distribution/uniform.py
  7. +9
    -13
      tests/st/ops/ascend/test_distribution/test_bernoulli.py
  8. +25
    -1
      tests/st/ops/ascend/test_distribution/test_exponential.py
  9. +24
    -6
      tests/st/ops/ascend/test_distribution/test_geometric.py
  10. +11
    -11
      tests/st/ops/ascend/test_distribution/test_normal.py
  11. +26
    -1
      tests/st/ops/ascend/test_distribution/test_uniform.py
  12. +0
    -1
      tests/ut/python/nn/distribution/test_bernoulli.py
  13. +4
    -1
      tests/ut/python/nn/distribution/test_normal.py
  14. +1
    -1
      tests/ut/python/nn/distribution/test_uniform.py

+ 29
- 40
mindspore/nn/distribution/bernoulli.py View File

@@ -14,7 +14,6 @@
# ============================================================================
"""Bernoulli Distribution"""
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob
from ...common import dtype as mstype
@@ -37,10 +36,10 @@ class Bernoulli(Distribution):
>>> # To initialize a Bernoulli distribution of prob 0.5
>>> n = nn.Bernoulli(0.5, dtype=mstype.int32)
>>>
>>> # The following create two independent Bernoulli distributions
>>> # The following creates two independent Bernoulli distributions
>>> n = nn.Bernoulli([0.5, 0.5], dtype=mstype.int32)
>>>
>>> # A Bernoulli distribution can be initilize without arguments
>>> # A Bernoulli distribution can be initilized without arguments
>>> # In this case, probs must be passed in through construct.
>>> n = nn.Bernoulli(dtype=mstype.int32)
>>>
@@ -54,29 +53,29 @@ class Bernoulli(Distribution):
>>> # All the following calls in construct are valid
>>> def construct(self, value, probs_b, probs_a):
>>>
>>> # Similar to calls can be made to other probability functions
>>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' with the name of the function
>>> ans = self.b1('prob', value)
>>> # Evaluate with the respect to distribution b
>>> ans = self.b1('prob', value, probs_b)
>>>
>>> # Additional probs must be passed in through construct
>>> # probs must be passed in through construct
>>> ans = self.b2('prob', value, probs_a)
>>>
>>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean'
>>> # Functions 'sd', 'var', 'entropy' have the same usage like 'mean'
>>> # Will return [0.0]
>>> ans = self.b1('mean')
>>> # Will return mean_b
>>> ans = self.b1('mean', probs_b)
>>>
>>> # Additional probs must be passed in through construct
>>> # probs must be passed in through construct
>>> ans = self.b2('mean', probs_a)
>>>
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> ans = self.b1('kl_loss', 'Bernoulli', probs_b)
>>> ans = self.b1('kl_loss', 'Bernoulli', probs_b, probs_a)
>>>
>>> # Additional probs must be passed in through construct
>>> # Additional probs_a must be passed in through construct
>>> ans = self.b2('kl_loss', 'Bernoulli', probs_b, probs_a)
>>>
>>> # Sample Usage
@@ -110,18 +109,12 @@ class Bernoulli(Distribution):
self.erf = P.Erf()
self.fill = P.Fill()
self.log = P.Log()
self.add = P.TensorAdd()
self.sq = P.Square()
self.mul = P.Mul()
self.sqrt = P.Sqrt()
self.realdiv = P.RealDiv()
self.shape = P.Shape()
self.const = P.ScalarToArray()
self.less = P.Less()
self.cast = P.Cast()
self.erf = P.Erf()
self.shape = P.Shape()
self.select = P.Select()
self.fill = P.Fill()
self.sq = P.Square()
self.sqrt = P.Sqrt()
self.uniform = P.UniformReal(seed=seed)

def extend_repr(self):
if self.is_scalar_batch:
@@ -143,7 +136,7 @@ class Bernoulli(Distribution):
MEAN(B) = probs1
"""
if name == 'mean':
return self._probs if probs1 is None else probs1
return self.probs if probs1 is None else probs1
return None

def _mode(self, name='mode', probs1=None):
@@ -166,9 +159,9 @@ class Bernoulli(Distribution):
VAR(B) = probs1 * probs0
"""
if name in self._variance_functions:
probs1 = self._probs if probs1 is None else probs1
probs1 = self.probs if probs1 is None else probs1
probs0 = 1.0 - probs1
return self.mul(probs0, probs1)
return probs0 * probs1
return None

def _entropy(self, name='entropy', probs=None):
@@ -177,9 +170,9 @@ class Bernoulli(Distribution):
H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
"""
if name == 'entropy':
probs1 = self._probs if probs is None else probs
probs1 = self.probs if probs is None else probs
probs0 = 1 - probs1
return -self.mul(probs0, self.log(probs0)) - self.mul(probs1, self.log(probs1))
return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1))
return None

def _cross_entropy(self, name, dist, probs1_b, probs1_a=None):
@@ -190,7 +183,7 @@ class Bernoulli(Distribution):
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Bernoulli" in this case.
probs1_b (Tensor): probs1 of distribution b.
probs1_a (Tensor): probs1 of distribution a. Default: self._probs.
probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
"""
if name == 'cross_entropy' and dist == 'Bernoulli':
return self._entropy(probs=probs1_a) + self._kl_loss(name, dist, probs1_b, probs1_a)
@@ -203,14 +196,14 @@ class Bernoulli(Distribution):
Args:
name (str): name of the function. Should be "prob" when passed in from construct.
value (Tensor): a Tensor composed of only zeros and ones.
probs (Tensor): probability of outcome is 1. Default: self._probs.
probs (Tensor): probability of outcome is 1. Default: self.probs.

.. math::
pmf(k) = probs1 if k = 1;
pmf(k) = probs0 if k = 0;
"""
if name in self._prob_functions:
probs1 = self._probs if probs is None else probs
probs1 = self.probs if probs is None else probs
probs0 = 1.0 - probs1
return (probs1 * value) + (probs0 * (1.0 - value))
return None
@@ -222,7 +215,7 @@ class Bernoulli(Distribution):
Args:
name (str): name of the function.
value (Tensor): value to be evaluated.
probs (Tensor): probability of outcome is 1. Default: self._probs.
probs (Tensor): probability of outcome is 1. Default: self.probs.

.. math::
cdf(k) = 0 if k < 0;
@@ -250,17 +243,17 @@ class Bernoulli(Distribution):
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Bernoulli" in this case.
probs1_b (Tensor): probs1 of distribution b.
probs1_a (Tensor): probs1 of distribution a. Default: self._probs.
probs1_a (Tensor): probs1 of distribution a. Default: self.probs.

.. math::
KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) +
probs0_a * \log(\fract{probs0_a}{probs0_b})
"""
if name in self._divergence_functions and dist == 'Bernoulli':
probs1_a = self._probs if probs1_a is None else probs1_a
probs1_a = self.probs if probs1_a is None else probs1_a
probs0_a = 1.0 - probs1_a
probs0_b = 1.0 - probs1_b
return self.mul(probs1_a, self.log(probs1_a / probs1_b)) + self.mul(probs0_a, self.log(probs0_a / probs0_b))
return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b)
return None

def _sample(self, name, shape=(), probs=None):
@@ -270,21 +263,17 @@ class Bernoulli(Distribution):
Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: ().
probs (Tensor): probs1 of the samples. Default: self._probs.
probs (Tensor): probs1 of the samples. Default: self.probs.

Returns:
Tensor, shape is shape + batch_shape.
"""
if name == 'sample':
probs1 = self._probs if probs is None else probs
batch_shape = self.shape(probs1)
sample_shape = shape + batch_shape
mean_zero = self.const(0.0)
sd_one = self.const(1.0)
sqrt_two = self.sqrt(self.const(2.0))
sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed)
sample_uniform = 0.5 * (1 + self.erf(self.realdiv(sample_norm, sqrt_two)))
probs1 = self.probs if probs is None else probs
l_zero = self.const(0.0)
h_one = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one)
sample = self.less(sample_uniform, probs1)
sample = self.cast(sample, self._dtype)
sample = self.cast(sample, self.dtype)
return sample
return None

+ 14
- 6
mindspore/nn/distribution/distribution.py View File

@@ -30,12 +30,14 @@ class Distribution(Cell):
and _log_prob. Functions should be called through construct when
used inside a network. Arguments should be passed in through *args
in the form of function name followed by additional arguments.
Functions such as cdf and prob, requires a value to be passed in while
functions such as mean, and sd does not require arguments other than name.
Functions such as cdf and prob, require a value to be passed in while
functions such as mean and sd do not require arguments other than name.

Dist_spec_args are unique for each distribution. For example, mean and sd
are the dist_spec_args for a Normal distribution. For all functions, dist_spec_args, are optional. Passing in
the additional dist_spec_args will make the result to be evaluated with
Dist_spec_args are unique for each type of distribution. For example, mean and sd
are the dist_spec_args for a Normal distribution.

For all functions, passing in dist_spec_args, are optional.
Passing in the additional dist_spec_args will make the result to be evaluated with
new distribution specified by the dist_spec_args. But it won't change the
original distribuion.
"""
@@ -258,7 +260,8 @@ class Distribution(Cell):
Evaluate the log cdf at given value.

Note:
Args must include value, and dist_spec_args are optional.
Args must include name of the function and value.
Dist_spec_args are optional.
"""
return self._call_log_cdf(*args)

@@ -428,6 +431,11 @@ class Distribution(Cell):
"""
Override construct in Cell.

Note:
Names of supported functions:
'prob', 'log_prob', 'cdf', 'log_cdf', 'survival_function', 'log_survival'
'var', 'sd', 'entropy', 'kl_loss', 'cross_entropy', 'sample'.

Args:
*inputs (list): inputs[0] is always the name of the function.
"""


+ 37
- 23
mindspore/nn/distribution/exponential.py View File

@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""Exponential Distribution"""
import numpy as np
from mindspore.ops import operations as P
from .distribution import Distribution
from ...common import dtype as mstype
@@ -36,10 +37,10 @@ class Exponential(Distribution):
>>> # To initialize an Exponential distribution of rate 0.5
>>> n = nn.Exponential(0.5, dtype=mstype.float32)
>>>
>>> # The following create two independent Exponential distributions
>>> # The following creates two independent Exponential distributions
>>> n = nn.Exponential([0.5, 0.5], dtype=mstype.float32)
>>>
>>> # A Exponential distribution can be initilize without arguments
>>> # A Exponential distribution can be initilized without arguments
>>> # In this case, rate must be passed in through construct.
>>> n = nn.Exponential(dtype=mstype.float32)
>>>
@@ -53,13 +54,13 @@ class Exponential(Distribution):
>>> # All the following calls in construct are valid
>>> def construct(self, value, rate_b, rate_a):
>>>
>>> # Similar to calls can be made to other probability functions
>>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' with the name of the function
>>> ans = self.e1('prob', value)
>>> # Evaluate with the respect to distribution b
>>> ans = self.e1('prob', value, rate_b)
>>>
>>> # Additional rate must be passed in through construct
>>> # Rate must be passed in through construct
>>> ans = self.e2('prob', value, rate_a)
>>>
>>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean'
@@ -68,7 +69,7 @@ class Exponential(Distribution):
>>> # Will return mean_b
>>> ans = self.e1('mean', rate_b)
>>>
>>> # Additional rate must be passed in through construct
>>> # Rate must be passed in through construct
>>> ans = self.e2('mean', rate_a)
>>>
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
@@ -101,21 +102,20 @@ class Exponential(Distribution):
else:
self._rate = rate

self.minval = np.finfo(np.float).tiny

# ops needed for the class
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
self.exp = P.Exp()
self.log = P.Log()
self.add = P.TensorAdd()
self.mul = P.Mul()
self.sqrt = P.Sqrt()
self.realdiv = P.RealDiv()
self.shape = P.Shape()
self.normal = P.Normal(seed=seed)
self.sq = P.Square()
self.fill = P.Fill()
self.less = P.Less()
self.log = P.Log()
self.select = P.Select()
self.shape = P.Shape()
self.sqrt = P.Sqrt()
self.sq = P.Square()
self.uniform = P.UniformReal(seed=seed)

def extend_repr(self):
if self.is_scalar_batch:
@@ -137,7 +137,7 @@ class Exponential(Distribution):
MEAN(EXP) = \fract{1.0}{\lambda}.
"""
if name == 'mean':
rate = self._rate if rate is None else rate
rate = self.rate if rate is None else rate
return 1.0 / rate
return None

@@ -157,7 +157,7 @@ class Exponential(Distribution):
sd(EXP) = \fract{1.0}{\lambda}.
"""
if name in self._variance_functions:
rate = self._rate if rate is None else rate
rate = self.rate if rate is None else rate
return 1.0 / rate
return None

@@ -166,7 +166,7 @@ class Exponential(Distribution):
.. math::
H(Exp) = 1 - \log(\lambda).
"""
rate = self._rate if rate is None else rate
rate = self.rate if rate is None else rate
if name == 'entropy':
return 1.0 - self.log(rate)
return None
@@ -179,7 +179,7 @@ class Exponential(Distribution):
name (str): name of the funtion. Should always be "cross_entropy" when passed in from construct.
dist (str): type of the distributions. Should be "Exponential" in this case.
rate_b (Tensor): rate of distribution b.
rate_a (Tensor): rate of distribution a. Default: self._rate.
rate_a (Tensor): rate of distribution a. Default: self.rate.
"""
if name == 'cross_entropy' and dist == 'Exponential':
return self._entropy(rate=rate_a) + self._kl_loss(name, dist, rate_b, rate_a)
@@ -193,7 +193,7 @@ class Exponential(Distribution):
Args:
name (str): name of the function.
value (Tensor): value to be evaluated.
rate (Tensor): rate of the distribution. Default: self._rate.
rate (Tensor): rate of the distribution. Default: self.rate.

Note:
Value should be greater or equal to zero.
@@ -216,7 +216,7 @@ class Exponential(Distribution):
Args:
name (str): name of the function.
value (Tensor): value to be evaluated.
rate (Tensor): rate of the distribution. Default: self._rate.
rate (Tensor): rate of the distribution. Default: self.rate.

Note:
Value should be greater or equal to zero.
@@ -240,15 +240,29 @@ class Exponential(Distribution):
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Exponential" in this case.
rate_b (Tensor): rate of distribution b.
rate_a (Tensor): rate of distribution a. Default: self._rate.
rate_a (Tensor): rate of distribution a. Default: self.rate.
"""
if name in self._divergence_functions and dist == 'Exponential':
rate_a = self._rate if rate_a is None else rate_a
rate_a = self.rate if rate_a is None else rate_a
return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0
return None

def _sample(self, name, shape=(), rate=None):
"""
Sampling.

Args:
name (str): name of the function.
shape (tuple): shape of the sample. Default: ().
rate (Tensor): rate of the distribution. Default: self.rate.

Returns:
Tensor, shape is shape + batch_shape.
"""
if name == 'sample':
rate = self._rate if rate is None else rate
return self.fill(mstype.float32, shape + self.shape(rate), 1.0)
rate = self.rate if rate is None else rate
minval = self.const(self.minval)
maxval = self.const(1.0)
sample = self.uniform(shape + self.shape(rate), minval, maxval)
return -self.log(sample) / rate
return None

+ 37
- 34
mindspore/nn/distribution/geometric.py View File

@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""Geometric Distribution"""
import numpy as np
from mindspore.ops import operations as P
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob
@@ -37,10 +38,10 @@ class Geometric(Distribution):
>>> # To initialize a Geometric distribution of prob 0.5
>>> n = nn.Geometric(0.5, dtype=mstype.int32)
>>>
>>> # The following create two independent Geometric distributions
>>> # The following creates two independent Geometric distributions
>>> n = nn.Geometric([0.5, 0.5], dtype=mstype.int32)
>>>
>>> # A Geometric distribution can be initilize without arguments
>>> # A Geometric distribution can be initilized without arguments
>>> # In this case, probs must be passed in through construct.
>>> n = nn.Geometric(dtype=mstype.int32)
>>>
@@ -51,16 +52,16 @@ class Geometric(Distribution):
>>> self.g1 = nn.Geometric(0.5, dtype=mstype.int32)
>>> self.g2 = nn.Geometric(dtype=mstype.int32)
>>>
>>> # All the following calls in construct are valid
>>> # Tthe following calls are valid in construct
>>> def construct(self, value, probs_b, probs_a):
>>>
>>> # Similar to calls can be made to other probability functions
>>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' with the name of the function
>>> ans = self.g1('prob', value)
>>> # Evaluate with the respect to distribution b
>>> ans = self.g1('prob', value, probs_b)
>>>
>>> # Additional probs must be passed in through construct
>>> # Probs must be passed in through construct
>>> ans = self.g2('prob', value, probs_a)
>>>
>>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean'
@@ -69,7 +70,7 @@ class Geometric(Distribution):
>>> # Will return mean_b
>>> ans = self.g1('mean', probs_b)
>>>
>>> # Additional probs must be passed in through construct
>>> # Probs must be passed in through construct
>>> ans = self.g2('mean', probs_a)
>>>
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
@@ -102,23 +103,22 @@ class Geometric(Distribution):
else:
self._probs = probs

self.minval = np.finfo(np.float).tiny

# ops needed for the class
self.log = P.Log()
self.add = P.TensorAdd()
self.mul = P.Mul()
self.sqrt = P.Sqrt()
self.realdiv = P.RealDiv()
self.shape = P.Shape()
self.dType = P.DType()
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
self.fill = P.Fill()
self.floor = P.Floor()
self.issubclass = P.IsSubClass()
self.const = P.ScalarToArray()
self.less = P.Less()
self.normal = P.Normal(seed=seed)
self.sq = P.Square()
self.select = P.Select()
self.fill = P.Fill()
self.log = P.Log()
self.pow = P.Pow()
self.select = P.Select()
self.shape = P.Shape()
self.sq = P.Square()
self.sqrt = P.Sqrt()
self.uniform = P.UniformReal(seed=seed)

def extend_repr(self):
if self.is_scalar_batch:
@@ -140,7 +140,7 @@ class Geometric(Distribution):
MEAN(Geo) = \fratc{1 - probs1}{probs1}
"""
if name == 'mean':
probs1 = self._probs if probs1 is None else probs1
probs1 = self.probs if probs1 is None else probs1
return (1. - probs1) / probs1
return None

@@ -160,7 +160,7 @@ class Geometric(Distribution):
VAR(Geo) = \fract{1 - probs1}{probs1 ^ {2}}
"""
if name in self._variance_functions:
probs1 = self._probs if probs1 is None else probs1
probs1 = self.probs if probs1 is None else probs1
return (1.0 - probs1) / self.sq(probs1)
return None

@@ -170,7 +170,7 @@ class Geometric(Distribution):
H(Geo) = \fract{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1}
"""
if name == 'entropy':
probs1 = self._probs if probs is None else probs
probs1 = self.probs if probs is None else probs
probs0 = 1.0 - probs1
return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1
return None
@@ -183,7 +183,7 @@ class Geometric(Distribution):
name (str): name of the funtion. Should always be "cross_entropy" when passed in from construct.
dist (str): type of the distributions. Should be "Geometric" in this case.
probs1_b (Tensor): probability of success of distribution b.
probs1_a (Tensor): probability of success of distribution a. Default: self._probs.
probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
"""
if name == 'cross_entropy' and dist == 'Geometric':
return self._entropy(probs=probs1_a) + self._kl_loss(name, dist, probs1_b, probs1_a)
@@ -196,15 +196,15 @@ class Geometric(Distribution):
Args:
name (str): name of the function. Should be "prob" when passed in from construct.
value (Tensor): a Tensor composed of only natural numbers.
probs (Tensor): probability of success. Default: self._probs.
probs (Tensor): probability of success. Default: self.probs.

.. math::
pmf(k) = probs0 ^k * probs1 if k >= 0;
pmf(k) = 0 if k < 0.
"""
if name in self._prob_functions:
probs1 = self._probs if probs is None else probs
dtype = self.dType(value)
probs1 = self.probs if probs is None else probs
dtype = self.dtypeop(value)
if self.issubclass(dtype, mstype.int_):
pass
elif self.issubclass(dtype, mstype.float_):
@@ -224,7 +224,7 @@ class Geometric(Distribution):
Args:
name (str): name of the function.
value (Tensor): a Tensor composed of only natural numbers.
probs (Tensor): probability of success. Default: self._probs.
probs (Tensor): probability of success. Default: self.probs.

.. math::
cdf(k) = 1 - probs0 ^ (k+1) if k >= 0;
@@ -232,9 +232,9 @@ class Geometric(Distribution):

"""
if name in self._cdf_survival_functions:
probs1 = self._probs if probs is None else probs
probs1 = self.probs if probs is None else probs
probs0 = 1.0 - probs1
dtype = self.dType(value)
dtype = self.dtypeop(value)
if self.issubclass(dtype, mstype.int_):
pass
elif self.issubclass(dtype, mstype.float_):
@@ -255,16 +255,16 @@ class Geometric(Distribution):
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Geometric" in this case.
probs1_b (Tensor): probability of success of distribution b.
probs1_a (Tensor): probability of success of distribution a. Default: self._probs.
probs1_a (Tensor): probability of success of distribution a. Default: self.probs.

.. math::
KL(a||b) = \log(\fract{probs1_a}{probs1_b}) + \fract{probs0_a}{probs1_a} * \log(\fract{probs0_a}{probs0_b})
"""
if name in self._divergence_functions and dist == 'Geometric':
probs1_a = self._probs if probs1_a is None else probs1_a
probs1_a = self.probs if probs1_a is None else probs1_a
probs0_a = 1.0 - probs1_a
probs0_b = 1.0 - probs1_b
return self.log(probs1_a / probs1_b) + self.mul(probs0_a / probs1_a, self.log(probs0_a / probs0_b))
return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b)
return None

def _sample(self, name, shape=(), probs=None):
@@ -274,12 +274,15 @@ class Geometric(Distribution):
Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: ().
probs (Tensor): probs1 of the samples. Default: self._probs.
probs (Tensor): probability of success. Default: self.probs.

Returns:
Tensor, shape is shape + batch_shape.
"""
if name == 'sample':
probs = self._probs if probs is None else probs
return self.fill(mstype.float32, shape + self.shape(probs), 1.0)
probs = self.probs if probs is None else probs
minval = self.const(self.minval)
maxval = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval)
return self.floor(self.log(sample_uniform) / self.log(1.0 - probs))
return None

+ 15
- 20
mindspore/nn/distribution/normal.py View File

@@ -26,22 +26,21 @@ class Normal(Distribution):
Normal distribution.

Args:
mean (int, float, list, numpy.ndarray, Tensor, Parameter): mean of the Gaussian distribution.
sd (int, float, list, numpy.ndarray, Tensor, Parameter): stddev of the Gaussian distribution.
mean (int, float, list, numpy.ndarray, Tensor, Parameter): mean of the Normal distribution.
sd (int, float, list, numpy.ndarray, Tensor, Parameter): stddev of the Normal distribution.
seed (int): seed to use in sampling. Default: 0.
dtype (mindspore.dtype): type of the distribution. Default: mstype.float32.
name (str): name of the distribution. Default: Normal.


Note:
Standard deviation should be greater than zero.
Dist_spec_args are mean and sd.

Examples:
>>> # To initialize a normal distribution of mean 3.0 and standard deviation 4.0
>>> # To initialize a Normal distribution of mean 3.0 and standard deviation 4.0
>>> n = nn.Normal(3.0, 4.0, dtype=mstype.float32)
>>>
>>> # The following create two independent normal distributions
>>> # The following creates two independent Normal distributions
>>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
>>>
>>> # A normal distribution can be initilize without arguments
@@ -55,16 +54,16 @@ class Normal(Distribution):
>>> self.n1 = nn.Normal(0.0, 1.0, dtype=mstype.float32)
>>> self.n2 = nn.Normal(dtype=mstype.float32)
>>>
>>> # All the following calls in construct are valid
>>> # The following calls are valid in construct
>>> def construct(self, value, mean_b, sd_b, mean_a, sd_a):
>>>
>>> # Similar to calls can be made to other probability functions
>>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' with the name of the function
>>> ans = self.n1('prob', value)
>>> # Evaluate with the respect to distribution b
>>> ans = self.n1('prob', value, mean_b, sd_b)
>>>
>>> # Additional mean and sd must be passed in through construct
>>> # mean and sd must be passed in through construct
>>> ans = self.n2('prob', value, mean_a, sd_a)
>>>
>>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean'
@@ -73,7 +72,7 @@ class Normal(Distribution):
>>> # Will return mean_b
>>> ans = self.n1('mean', mean_b, sd_b)
>>>
>>> # Additional mean and sd must be passed in through construct
>>> # mean and sd must be passed in through construct
>>> ans = self.n2('mean', mean_a, sd_a)
>>>
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
@@ -111,20 +110,16 @@ class Normal(Distribution):
self.seed = seed

#ops needed for the class
self.const = P.ScalarToArray()
self.erf = P.Erf()
self.exp = P.Exp()
self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step
self.fill = P.Fill()
self.log = P.Log()
self.shape = P.Shape()
self.sq = P.Square()
self.log = P.Log()
self.sqrt = P.Sqrt()
self.realdiv = P.RealDiv()
self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step
self.shape = P.Shape()
self.zeroslike = P.ZerosLike()
self.const = P.ScalarToArray()
self.erf = P.Erf()
self.fill = P.Fill()

def extend_repr(self):
if self.is_scalar_batch:
@@ -231,8 +226,8 @@ class Normal(Distribution):
if name in self._cdf_survival_functions:
mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd
sqrt2 = self.sqrt(self.fill(mstype.float32, self.shape(sd), 2.0))
adjusted = (value - mean) / self.mul(sd, sqrt2)
sqrt2 = self.sqrt(self.const(2.0))
adjusted = (value - mean) / (sd * sqrt2)
return 0.5 * (1.0 + self.erf(adjusted))
return None

@@ -276,11 +271,11 @@ class Normal(Distribution):
if name == 'sample':
mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd
batch_shape = self.shape(self.add(self.zeroslike(mean), self.zeroslike(sd)))
batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd))
sample_shape = shape + batch_shape
mean_zero = self.const(0.0)
sd_one = self.const(1.0)
sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed)
sample = self.add(mean, self.mul(sample_norm, sd))
sample = mean + sample_norm * sd
return sample
return None

+ 52
- 39
mindspore/nn/distribution/uniform.py View File

@@ -37,10 +37,10 @@ class Uniform(Distribution):
>>> # To initialize a Uniform distribution of mean 3.0 and standard deviation 4.0
>>> n = nn.Uniform(0.0, 1.0, dtype=mstype.float32)
>>>
>>> # The following create two independent Uniform distributions
>>> # The following creates two independent Uniform distributions
>>> n = nn.Uniform([0.0, 0.0], [1.0, 2.0], dtype=mstype.float32)
>>>
>>> # A Uniform distribution can be initilize without arguments
>>> # A Uniform distribution can be initilized without arguments
>>> # In this case, high and low must be passed in through construct.
>>> n = nn.Uniform(dtype=mstype.float32)
>>>
@@ -54,13 +54,13 @@ class Uniform(Distribution):
>>> # All the following calls in construct are valid
>>> def construct(self, value, low_b, high_b, low_a, high_a):
>>>
>>> # Similar to calls can be made to other probability functions
>>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' with the name of the function
>>> ans = self.u1('prob', value)
>>> # Evaluate with the respect to distribution b
>>> ans = self.u1('prob', value, low_b, high_b)
>>>
>>> # Additional high and low must be passed in through construct
>>> # High and low must be passed in through construct
>>> ans = self.u2('prob', value, low_a, high_a)
>>>
>>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean'
@@ -69,7 +69,7 @@ class Uniform(Distribution):
>>> # Will return low_b
>>> ans = self.u1('mean', low_b, high_b)
>>>
>>> # Additional high and low must be passed in through construct
>>> # High and low must be passed in through construct
>>> ans = self.u2('mean', low_a, high_a)
>>>
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
@@ -100,7 +100,7 @@ class Uniform(Distribution):
if low is not None and high is not None:
self._low = convert_to_batch(low, self._broadcast_shape, dtype)
self._high = convert_to_batch(high, self._broadcast_shape, dtype)
check_greater(self._low, self._high, "low value", "high value")
check_greater(self.low, self.high, "low value", "high value")
else:
self._low = low
self._high = high
@@ -109,20 +109,17 @@ class Uniform(Distribution):
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
self.exp = P.Exp()
self.log = P.Log()
self.add = P.TensorAdd()
self.mul = P.Mul()
self.sqrt = P.Sqrt()
self.realdiv = P.RealDiv()
self.fill = P.Fill()
self.less = P.Less()
self.lessequal = P.LessEqual()
self.sq = P.Square()
self.select = P.Select()
self.zeroslike = P.ZerosLike()
self.log = P.Log()
self.logicaland = P.LogicalAnd()
self.fill = P.Fill()
self.select = P.Select()
self.shape = P.Shape()
self.normal = P.Normal(seed=seed)
self.sq = P.Square()
self.sqrt = P.Sqrt()
self.uniform = P.UniformReal(seed=seed)
self.zeroslike = P.ZerosLike()

def extend_repr(self):
if self.is_scalar_batch:
@@ -152,8 +149,8 @@ class Uniform(Distribution):
range(U) = high -low
"""
if name == 'range':
low = self._low if low is None else low
high = self._high if high is None else high
low = self.low if low is None else low
high = self.high if high is None else high
return high - low
return None

@@ -163,8 +160,8 @@ class Uniform(Distribution):
MEAN(U) = \fract{low + high}{2}.
"""
if name == 'mean':
low = self._low if low is None else low
high = self._high if high is None else high
low = self.low if low is None else low
high = self.high if high is None else high
return (low + high) / 2.
return None

@@ -174,8 +171,8 @@ class Uniform(Distribution):
VAR(U) = \fract{(high -low) ^ 2}{12}.
"""
if name in self._variance_functions:
low = self._low if low is None else low
high = self._high if high is None else high
low = self.low if low is None else low
high = self.high if high is None else high
return self.sq(high - low) / 12.0
return None

@@ -185,8 +182,8 @@ class Uniform(Distribution):
H(U) = \log(high - low).
"""
if name == 'entropy':
low = self._low if low is None else low
high = self._high if high is None else high
low = self.low if low is None else low
high = self.high if high is None else high
return self.log(high - low)
return None

@@ -199,8 +196,8 @@ class Uniform(Distribution):
dist (str): type of the distributions. Should be "Uniform" in this case.
low_b (Tensor): lower bound of distribution b.
high_b (Tensor): upper bound of distribution b.
low_a (Tensor): lower bound of distribution a. Default: self._low.
high_a (Tensor): upper bound of distribution a. Default: self._high.
low_a (Tensor): lower bound of distribution a. Default: self.low.
high_a (Tensor): upper bound of distribution a. Default: self.high.
"""
if name == 'cross_entropy' and dist == 'Uniform':
return self._entropy(low=low_a, high=high_a) + self._kl_loss(name, dist, low_b, high_b, low_a, high_a)
@@ -213,8 +210,8 @@ class Uniform(Distribution):
Args:
name (str): name of the function.
value (Tensor): value to be evaluated.
low (Tensor): lower bound of the distribution. Default: self._low.
high (Tensor): upper bound of the distribution. Default: self._high.
low (Tensor): lower bound of the distribution. Default: self.low.
high (Tensor): upper bound of the distribution. Default: self.high.

.. math::
pdf(x) = 0 if x < low;
@@ -243,12 +240,12 @@ class Uniform(Distribution):
dist (str): type of the distributions. Should be "Uniform" in this case.
low_b (Tensor): lower bound of distribution b.
high_b (Tensor): upper bound of distribution b.
low_a (Tensor): lower bound of distribution a. Default: self._low.
high_a (Tensor): upper bound of distribution a. Default: self._high.
low_a (Tensor): lower bound of distribution a. Default: self.low.
high_a (Tensor): upper bound of distribution a. Default: self.high.
"""
if name in self._divergence_functions and dist == 'Uniform':
low_a = self._low if low_a is None else low_a
high_a = self._high if high_a is None else high_a
low_a = self.low if low_a is None else low_a
high_a = self.high if high_a is None else high_a
kl = self.log(high_b - low_b) / self.log(high_a - low_a)
comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b))
return self.select(comp, kl, self.log(self.zeroslike(kl)))
@@ -261,8 +258,8 @@ class Uniform(Distribution):
Args:
name (str): name of the function.
value (Tensor): value to be evaluated.
low (Tensor): lower bound of the distribution. Default: self._low.
high (Tensor): upper bound of the distribution. Default: self._high.
low (Tensor): lower bound of the distribution. Default: self.low.
high (Tensor): upper bound of the distribution. Default: self.high.

.. math::
cdf(x) = 0 if x < low;
@@ -270,8 +267,8 @@ class Uniform(Distribution):
cdf(x) = 1 if x > high;
"""
if name in self._cdf_survival_functions:
low = self._low if low is None else low
high = self._high if high is None else high
low = self.low if low is None else low
high = self.high if high is None else high
prob = (value - low) / (high - low)
broadcast_shape = self.shape(prob)
zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
@@ -283,9 +280,25 @@ class Uniform(Distribution):
return None

def _sample(self, name, shape=(), low=None, high=None):
"""
Sampling.

Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: ().
low (Tensor): lower bound of the distribution. Default: self.low.
high (Tensor): upper bound of the distribution. Default: self.high.

Returns:
Tensor, shape is shape + batch_shape.
"""
if name == 'sample':
low = self._low if low is None else low
high = self._high if high is None else high
low = self.low if low is None else low
high = self.high if high is None else high
broadcast_shape = self.shape(low + high)
return self.fill(mstype.float32, shape + broadcast_shape, 1.0)
l_zero = self.const(0.0)
h_one = self.const(1.0)
sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one)
sample = (high - low) * sample_uniform + low
return sample
return None

+ 9
- 13
tests/st/ops/ascend/test_distribution/test_bernoulli.py View File

@@ -25,7 +25,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

class Prob(nn.Cell):
"""
Test class: probability of bernoulli distribution.
Test class: probability of Bernoulli distribution.
"""
def __init__(self):
super(Prob, self).__init__()
@@ -50,7 +50,7 @@ def test_pmf():

class LogProb(nn.Cell):
"""
Test class: log probability of bernoulli distribution.
Test class: log probability of Bernoulli distribution.
"""
def __init__(self):
super(LogProb, self).__init__()
@@ -74,7 +74,7 @@ def test_log_likelihood():

class KL(nn.Cell):
"""
Test class: kl_loss between bernoulli distributions.
Test class: kl_loss between Bernoulli distributions.
"""
def __init__(self):
super(KL, self).__init__()
@@ -100,7 +100,7 @@ def test_kl_loss():

class Basics(nn.Cell):
"""
Test class: mean/sd/mode of bernoulli distribution.
Test class: mean/sd/mode of Bernoulli distribution.
"""
def __init__(self):
super(Basics, self).__init__()
@@ -112,7 +112,7 @@ class Basics(nn.Cell):

def test_basics():
"""
Test mean/standard deviation/mode and probs.
Test mean/standard deviation/mode.
"""
basics = Basics()
mean, sd, mode = basics()
@@ -123,14 +123,10 @@ def test_basics():
assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()
assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()
b = nn.Bernoulli([0.7, 0.5], dtype=dtype.int32)
probs = b.probs()
expect_probs = [0.7, 0.5]
assert (np.abs(probs.asnumpy() - expect_probs) < tol).all()

class Sampling(nn.Cell):
"""
Test class: log probability of bernoulli distribution.
Test class: log probability of Bernoulli distribution.
"""
def __init__(self, shape, seed=0):
super(Sampling, self).__init__()
@@ -202,7 +198,7 @@ def test_logcdf():

class SF(nn.Cell):
"""
Test class: survival function of bernoulli distributions.
Test class: survival function of Bernoulli distributions.
"""
def __init__(self):
super(SF, self).__init__()
@@ -227,7 +223,7 @@ def test_survival():

class LogSF(nn.Cell):
"""
Test class: log survival function of bernoulli distributions.
Test class: log survival function of Bernoulli distributions.
"""
def __init__(self):
super(LogSF, self).__init__()
@@ -251,7 +247,7 @@ def test_log_survival():

class EntropyH(nn.Cell):
"""
Test class: entropy of bernoulli distributions.
Test class: entropy of Bernoulli distributions.
"""
def __init__(self):
super(EntropyH, self).__init__()


+ 25
- 1
tests/st/ops/ascend/test_distribution/test_exponential.py View File

@@ -109,7 +109,7 @@ class Basics(nn.Cell):

def test_basics():
"""
Test mean/standard deviation and range.
Test mean/standard/mode deviation.
"""
basics = Basics()
mean, sd, mode = basics()
@@ -121,6 +121,30 @@ def test_basics():
assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()
assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()

class Sampling(nn.Cell):
"""
Test class: sample of Exponential distribution.
"""
def __init__(self, shape, seed=0):
super(Sampling, self).__init__()
self.e = nn.Exponential([[1.0], [0.5]], seed=seed, dtype=dtype.float32)
self.shape = shape

@ms_function
def construct(self, rate=None):
return self.e('sample', self.shape, rate)

def test_sample():
"""
Test sample.
"""
shape = (2, 3)
seed = 10
rate = Tensor([1.0, 2.0, 3.0], dtype=dtype.float32)
sample = Sampling(shape, seed=seed)
output = sample(rate)
assert output.shape == (2, 3, 3)

class CDF(nn.Cell):
"""
Test class: cdf of Exponential distribution.


+ 24
- 6
tests/st/ops/ascend/test_distribution/test_geometric.py View File

@@ -99,7 +99,7 @@ def test_kl_loss():

class Basics(nn.Cell):
"""
Test class: mean/sd of Geometric distribution.
Test class: mean/sd/mode of Geometric distribution.
"""
def __init__(self):
super(Basics, self).__init__()
@@ -111,7 +111,7 @@ class Basics(nn.Cell):

def test_basics():
"""
Test mean/standard deviation/mode and probs.
Test mean/standard deviation/mode.
"""
basics = Basics()
mean, sd, mode = basics()
@@ -122,10 +122,28 @@ def test_basics():
assert (np.abs(mean.asnumpy()- expect_mean) < tol).all()
assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()
assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()
b = nn.Geometric([0.7, 0.5], dtype=dtype.int32)
probs = b.probs()
expect_probs = [0.7, 0.5]
assert (np.abs(probs.asnumpy() - expect_probs) < tol).all()

class Sampling(nn.Cell):
"""
Test class: log probability of bernoulli distribution.
"""
def __init__(self, shape, seed=0):
super(Sampling, self).__init__()
self.g = nn.Geometric([0.7, 0.5], seed=seed, dtype=dtype.int32)
self.shape = shape

@ms_function
def construct(self, probs=None):
return self.g('sample', self.shape, probs)

def test_sample():
"""
Test sample.
"""
shape = (2, 3)
sample = Sampling(shape)
output = sample()
assert output.shape == (2, 3, 2)

class CDF(nn.Cell):
"""


+ 11
- 11
tests/st/ops/ascend/test_distribution/test_normal.py View File

@@ -25,7 +25,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

class Prob(nn.Cell):
"""
Test class: probability of normal distribution.
Test class: probability of Normal distribution.
"""
def __init__(self):
super(Prob, self).__init__()
@@ -48,7 +48,7 @@ def test_pdf():

class LogProb(nn.Cell):
"""
Test class: log probability of normal distribution.
Test class: log probability of Normal distribution.
"""
def __init__(self):
super(LogProb, self).__init__()
@@ -72,7 +72,7 @@ def test_log_likelihood():

class KL(nn.Cell):
"""
Test class: kl_loss of normal distribution.
Test class: kl_loss of Normal distribution.
"""
def __init__(self):
super(KL, self).__init__()
@@ -106,7 +106,7 @@ def test_kl_loss():

class Basics(nn.Cell):
"""
Test class: mean/sd of normal distribution.
Test class: mean/sd/mode of Normal distribution.
"""
def __init__(self):
super(Basics, self).__init__()
@@ -131,7 +131,7 @@ def test_basics():

class Sampling(nn.Cell):
"""
Test class: sample of normal distribution.
Test class: sample of Normal distribution.
"""
def __init__(self, shape, seed=0):
super(Sampling, self).__init__()
@@ -156,7 +156,7 @@ def test_sample():

class CDF(nn.Cell):
"""
Test class: cdf of normal distribution.
Test class: cdf of Normal distribution.
"""
def __init__(self):
super(CDF, self).__init__()
@@ -180,7 +180,7 @@ def test_cdf():

class LogCDF(nn.Cell):
"""
Test class: log_cdf of normal distribution.
Test class: log_cdf of Mormal distribution.
"""
def __init__(self):
super(LogCDF, self).__init__()
@@ -203,7 +203,7 @@ def test_log_cdf():

class SF(nn.Cell):
"""
Test class: survival function of normal distribution.
Test class: survival function of Normal distribution.
"""
def __init__(self):
super(SF, self).__init__()
@@ -226,7 +226,7 @@ def test_survival():

class LogSF(nn.Cell):
"""
Test class: log survival function of normal distribution.
Test class: log survival function of Normal distribution.
"""
def __init__(self):
super(LogSF, self).__init__()
@@ -249,7 +249,7 @@ def test_log_survival():

class EntropyH(nn.Cell):
"""
Test class: entropy of normal distribution.
Test class: entropy of Normal distribution.
"""
def __init__(self):
super(EntropyH, self).__init__()
@@ -272,7 +272,7 @@ def test_entropy():

class CrossEntropy(nn.Cell):
"""
Test class: cross entropy between normal distribution.
Test class: cross entropy between Normal distributions.
"""
def __init__(self):
super(CrossEntropy, self).__init__()


+ 26
- 1
tests/st/ops/ascend/test_distribution/test_uniform.py View File

@@ -111,7 +111,7 @@ class Basics(nn.Cell):

def test_basics():
"""
Test mean/standard deviation/mode.
Test mean/standard deviation.
"""
basics = Basics()
mean, sd = basics()
@@ -121,6 +121,31 @@ def test_basics():
assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()

class Sampling(nn.Cell):
"""
Test class: sample of Uniform distribution.
"""
def __init__(self, shape, seed=0):
super(Sampling, self).__init__()
self.u = nn.Uniform([0.0], [[1.0], [2.0]], seed=seed, dtype=dtype.float32)
self.shape = shape

@ms_function
def construct(self, low=None, high=None):
return self.u('sample', self.shape, low, high)

def test_sample():
"""
Test sample.
"""
shape = (2, 3)
seed = 10
low = Tensor([1.0], dtype=dtype.float32)
high = Tensor([2.0, 3.0, 4.0], dtype=dtype.float32)
sample = Sampling(shape, seed=seed)
output = sample(low, high)
assert output.shape == (2, 3, 3)

class CDF(nn.Cell):
"""
Test class: cdf of Uniform distribution.


+ 0
- 1
tests/ut/python/nn/distribution/test_bernoulli.py View File

@@ -21,7 +21,6 @@ import mindspore.nn as nn
from mindspore import dtype
from mindspore import Tensor


def test_arguments():
"""
Args passing during initialization.


+ 4
- 1
tests/ut/python/nn/distribution/test_normal.py View File

@@ -111,7 +111,7 @@ class NormalKl(nn.Cell):

def test_kl():
"""
Test kl_loss
Test kl_loss.
"""
net = NormalKl()
mean_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
@@ -136,6 +136,9 @@ class NormalCrossEntropy(nn.Cell):
return h1 + h2

def test_cross_entropy():
"""
Test cross entropy between Normal distributions.
"""
net = NormalCrossEntropy()
mean_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
sd_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)


+ 1
- 1
tests/ut/python/nn/distribution/test_uniform.py View File

@@ -120,7 +120,7 @@ class UniformKl(nn.Cell):

def test_kl():
"""
Test kl_loss
Test kl_loss.
"""
net = UniformKl()
low_b = Tensor(np.array([0.0]).astype(np.float32), dtype=dtype.float32)


Loading…
Cancel
Save