From 292f2de0cf9d756ba5bd46d7699bcac4e9f78dae Mon Sep 17 00:00:00 2001 From: peixu_ren Date: Mon, 20 Jul 2020 13:25:55 -0700 Subject: [PATCH] Add sampling functions in exponential, geometric and uniform distributions --- mindspore/nn/distribution/bernoulli.py | 69 ++++++-------- mindspore/nn/distribution/distribution.py | 20 ++-- mindspore/nn/distribution/exponential.py | 60 +++++++----- mindspore/nn/distribution/geometric.py | 71 ++++++++------- mindspore/nn/distribution/normal.py | 35 +++---- mindspore/nn/distribution/uniform.py | 91 +++++++++++-------- .../test_distribution/test_bernoulli.py | 22 ++--- .../test_distribution/test_exponential.py | 26 +++++- .../test_distribution/test_geometric.py | 30 ++++-- .../ascend/test_distribution/test_normal.py | 22 ++--- .../ascend/test_distribution/test_uniform.py | 27 +++++- .../python/nn/distribution/test_bernoulli.py | 1 - .../ut/python/nn/distribution/test_normal.py | 5 +- .../ut/python/nn/distribution/test_uniform.py | 2 +- 14 files changed, 284 insertions(+), 197 deletions(-) diff --git a/mindspore/nn/distribution/bernoulli.py b/mindspore/nn/distribution/bernoulli.py index 90e1a77930..f047326798 100644 --- a/mindspore/nn/distribution/bernoulli.py +++ b/mindspore/nn/distribution/bernoulli.py @@ -14,7 +14,6 @@ # ============================================================================ """Bernoulli Distribution""" from mindspore.ops import operations as P -from mindspore.ops import composite as C from .distribution import Distribution from ._utils.utils import cast_to_tensor, check_prob from ...common import dtype as mstype @@ -37,10 +36,10 @@ class Bernoulli(Distribution): >>> # To initialize a Bernoulli distribution of prob 0.5 >>> n = nn.Bernoulli(0.5, dtype=mstype.int32) >>> - >>> # The following create two independent Bernoulli distributions + >>> # The following creates two independent Bernoulli distributions >>> n = nn.Bernoulli([0.5, 0.5], dtype=mstype.int32) >>> - >>> # A Bernoulli distribution can be initilize without arguments + >>> # A Bernoulli distribution can be initilized without arguments >>> # In this case, probs must be passed in through construct. >>> n = nn.Bernoulli(dtype=mstype.int32) >>> @@ -54,29 +53,29 @@ class Bernoulli(Distribution): >>> # All the following calls in construct are valid >>> def construct(self, value, probs_b, probs_a): >>> - >>> # Similar to calls can be made to other probability functions + >>> # Similar calls can be made to other probability functions >>> # by replacing 'prob' with the name of the function >>> ans = self.b1('prob', value) >>> # Evaluate with the respect to distribution b >>> ans = self.b1('prob', value, probs_b) >>> - >>> # Additional probs must be passed in through construct + >>> # probs must be passed in through construct >>> ans = self.b2('prob', value, probs_a) >>> - >>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean' + >>> # Functions 'sd', 'var', 'entropy' have the same usage like 'mean' >>> # Will return [0.0] >>> ans = self.b1('mean') >>> # Will return mean_b >>> ans = self.b1('mean', probs_b) >>> - >>> # Additional probs must be passed in through construct + >>> # probs must be passed in through construct >>> ans = self.b2('mean', probs_a) >>> >>> # Usage of 'kl_loss' and 'cross_entropy' are similar >>> ans = self.b1('kl_loss', 'Bernoulli', probs_b) >>> ans = self.b1('kl_loss', 'Bernoulli', probs_b, probs_a) >>> - >>> # Additional probs must be passed in through construct + >>> # Additional probs_a must be passed in through construct >>> ans = self.b2('kl_loss', 'Bernoulli', probs_b, probs_a) >>> >>> # Sample Usage @@ -110,18 +109,12 @@ class Bernoulli(Distribution): self.erf = P.Erf() self.fill = P.Fill() self.log = P.Log() - self.add = P.TensorAdd() - self.sq = P.Square() - self.mul = P.Mul() - self.sqrt = P.Sqrt() - self.realdiv = P.RealDiv() - self.shape = P.Shape() - self.const = P.ScalarToArray() self.less = P.Less() - self.cast = P.Cast() - self.erf = P.Erf() + self.shape = P.Shape() self.select = P.Select() - self.fill = P.Fill() + self.sq = P.Square() + self.sqrt = P.Sqrt() + self.uniform = P.UniformReal(seed=seed) def extend_repr(self): if self.is_scalar_batch: @@ -143,7 +136,7 @@ class Bernoulli(Distribution): MEAN(B) = probs1 """ if name == 'mean': - return self._probs if probs1 is None else probs1 + return self.probs if probs1 is None else probs1 return None def _mode(self, name='mode', probs1=None): @@ -166,9 +159,9 @@ class Bernoulli(Distribution): VAR(B) = probs1 * probs0 """ if name in self._variance_functions: - probs1 = self._probs if probs1 is None else probs1 + probs1 = self.probs if probs1 is None else probs1 probs0 = 1.0 - probs1 - return self.mul(probs0, probs1) + return probs0 * probs1 return None def _entropy(self, name='entropy', probs=None): @@ -177,9 +170,9 @@ class Bernoulli(Distribution): H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1) """ if name == 'entropy': - probs1 = self._probs if probs is None else probs + probs1 = self.probs if probs is None else probs probs0 = 1 - probs1 - return -self.mul(probs0, self.log(probs0)) - self.mul(probs1, self.log(probs1)) + return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1)) return None def _cross_entropy(self, name, dist, probs1_b, probs1_a=None): @@ -190,7 +183,7 @@ class Bernoulli(Distribution): name (str): name of the funtion. dist (str): type of the distributions. Should be "Bernoulli" in this case. probs1_b (Tensor): probs1 of distribution b. - probs1_a (Tensor): probs1 of distribution a. Default: self._probs. + probs1_a (Tensor): probs1 of distribution a. Default: self.probs. """ if name == 'cross_entropy' and dist == 'Bernoulli': return self._entropy(probs=probs1_a) + self._kl_loss(name, dist, probs1_b, probs1_a) @@ -203,14 +196,14 @@ class Bernoulli(Distribution): Args: name (str): name of the function. Should be "prob" when passed in from construct. value (Tensor): a Tensor composed of only zeros and ones. - probs (Tensor): probability of outcome is 1. Default: self._probs. + probs (Tensor): probability of outcome is 1. Default: self.probs. .. math:: pmf(k) = probs1 if k = 1; pmf(k) = probs0 if k = 0; """ if name in self._prob_functions: - probs1 = self._probs if probs is None else probs + probs1 = self.probs if probs is None else probs probs0 = 1.0 - probs1 return (probs1 * value) + (probs0 * (1.0 - value)) return None @@ -222,7 +215,7 @@ class Bernoulli(Distribution): Args: name (str): name of the function. value (Tensor): value to be evaluated. - probs (Tensor): probability of outcome is 1. Default: self._probs. + probs (Tensor): probability of outcome is 1. Default: self.probs. .. math:: cdf(k) = 0 if k < 0; @@ -250,17 +243,17 @@ class Bernoulli(Distribution): name (str): name of the funtion. dist (str): type of the distributions. Should be "Bernoulli" in this case. probs1_b (Tensor): probs1 of distribution b. - probs1_a (Tensor): probs1 of distribution a. Default: self._probs. + probs1_a (Tensor): probs1 of distribution a. Default: self.probs. .. math:: KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) + probs0_a * \log(\fract{probs0_a}{probs0_b}) """ if name in self._divergence_functions and dist == 'Bernoulli': - probs1_a = self._probs if probs1_a is None else probs1_a + probs1_a = self.probs if probs1_a is None else probs1_a probs0_a = 1.0 - probs1_a probs0_b = 1.0 - probs1_b - return self.mul(probs1_a, self.log(probs1_a / probs1_b)) + self.mul(probs0_a, self.log(probs0_a / probs0_b)) + return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b) return None def _sample(self, name, shape=(), probs=None): @@ -270,21 +263,17 @@ class Bernoulli(Distribution): Args: name (str): name of the function. Should always be 'sample' when passed in from construct. shape (tuple): shape of the sample. Default: (). - probs (Tensor): probs1 of the samples. Default: self._probs. + probs (Tensor): probs1 of the samples. Default: self.probs. Returns: Tensor, shape is shape + batch_shape. """ if name == 'sample': - probs1 = self._probs if probs is None else probs - batch_shape = self.shape(probs1) - sample_shape = shape + batch_shape - mean_zero = self.const(0.0) - sd_one = self.const(1.0) - sqrt_two = self.sqrt(self.const(2.0)) - sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed) - sample_uniform = 0.5 * (1 + self.erf(self.realdiv(sample_norm, sqrt_two))) + probs1 = self.probs if probs is None else probs + l_zero = self.const(0.0) + h_one = self.const(1.0) + sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one) sample = self.less(sample_uniform, probs1) - sample = self.cast(sample, self._dtype) + sample = self.cast(sample, self.dtype) return sample return None diff --git a/mindspore/nn/distribution/distribution.py b/mindspore/nn/distribution/distribution.py index bdc6d44b17..52e23f0e9a 100644 --- a/mindspore/nn/distribution/distribution.py +++ b/mindspore/nn/distribution/distribution.py @@ -30,12 +30,14 @@ class Distribution(Cell): and _log_prob. Functions should be called through construct when used inside a network. Arguments should be passed in through *args in the form of function name followed by additional arguments. - Functions such as cdf and prob, requires a value to be passed in while - functions such as mean, and sd does not require arguments other than name. + Functions such as cdf and prob, require a value to be passed in while + functions such as mean and sd do not require arguments other than name. - Dist_spec_args are unique for each distribution. For example, mean and sd - are the dist_spec_args for a Normal distribution. For all functions, dist_spec_args, are optional. Passing in - the additional dist_spec_args will make the result to be evaluated with + Dist_spec_args are unique for each type of distribution. For example, mean and sd + are the dist_spec_args for a Normal distribution. + + For all functions, passing in dist_spec_args, are optional. + Passing in the additional dist_spec_args will make the result to be evaluated with new distribution specified by the dist_spec_args. But it won't change the original distribuion. """ @@ -258,7 +260,8 @@ class Distribution(Cell): Evaluate the log cdf at given value. Note: - Args must include value, and dist_spec_args are optional. + Args must include name of the function and value. + Dist_spec_args are optional. """ return self._call_log_cdf(*args) @@ -428,6 +431,11 @@ class Distribution(Cell): """ Override construct in Cell. + Note: + Names of supported functions: + 'prob', 'log_prob', 'cdf', 'log_cdf', 'survival_function', 'log_survival' + 'var', 'sd', 'entropy', 'kl_loss', 'cross_entropy', 'sample'. + Args: *inputs (list): inputs[0] is always the name of the function. """ diff --git a/mindspore/nn/distribution/exponential.py b/mindspore/nn/distribution/exponential.py index 7796e3fb1d..9816369e0b 100644 --- a/mindspore/nn/distribution/exponential.py +++ b/mindspore/nn/distribution/exponential.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """Exponential Distribution""" +import numpy as np from mindspore.ops import operations as P from .distribution import Distribution from ...common import dtype as mstype @@ -36,10 +37,10 @@ class Exponential(Distribution): >>> # To initialize an Exponential distribution of rate 0.5 >>> n = nn.Exponential(0.5, dtype=mstype.float32) >>> - >>> # The following create two independent Exponential distributions + >>> # The following creates two independent Exponential distributions >>> n = nn.Exponential([0.5, 0.5], dtype=mstype.float32) >>> - >>> # A Exponential distribution can be initilize without arguments + >>> # A Exponential distribution can be initilized without arguments >>> # In this case, rate must be passed in through construct. >>> n = nn.Exponential(dtype=mstype.float32) >>> @@ -53,13 +54,13 @@ class Exponential(Distribution): >>> # All the following calls in construct are valid >>> def construct(self, value, rate_b, rate_a): >>> - >>> # Similar to calls can be made to other probability functions + >>> # Similar calls can be made to other probability functions >>> # by replacing 'prob' with the name of the function >>> ans = self.e1('prob', value) >>> # Evaluate with the respect to distribution b >>> ans = self.e1('prob', value, rate_b) >>> - >>> # Additional rate must be passed in through construct + >>> # Rate must be passed in through construct >>> ans = self.e2('prob', value, rate_a) >>> >>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean' @@ -68,7 +69,7 @@ class Exponential(Distribution): >>> # Will return mean_b >>> ans = self.e1('mean', rate_b) >>> - >>> # Additional rate must be passed in through construct + >>> # Rate must be passed in through construct >>> ans = self.e2('mean', rate_a) >>> >>> # Usage of 'kl_loss' and 'cross_entropy' are similar @@ -101,21 +102,20 @@ class Exponential(Distribution): else: self._rate = rate + self.minval = np.finfo(np.float).tiny + # ops needed for the class self.const = P.ScalarToArray() self.dtypeop = P.DType() self.exp = P.Exp() - self.log = P.Log() - self.add = P.TensorAdd() - self.mul = P.Mul() - self.sqrt = P.Sqrt() - self.realdiv = P.RealDiv() - self.shape = P.Shape() - self.normal = P.Normal(seed=seed) - self.sq = P.Square() self.fill = P.Fill() self.less = P.Less() + self.log = P.Log() self.select = P.Select() + self.shape = P.Shape() + self.sqrt = P.Sqrt() + self.sq = P.Square() + self.uniform = P.UniformReal(seed=seed) def extend_repr(self): if self.is_scalar_batch: @@ -137,7 +137,7 @@ class Exponential(Distribution): MEAN(EXP) = \fract{1.0}{\lambda}. """ if name == 'mean': - rate = self._rate if rate is None else rate + rate = self.rate if rate is None else rate return 1.0 / rate return None @@ -157,7 +157,7 @@ class Exponential(Distribution): sd(EXP) = \fract{1.0}{\lambda}. """ if name in self._variance_functions: - rate = self._rate if rate is None else rate + rate = self.rate if rate is None else rate return 1.0 / rate return None @@ -166,7 +166,7 @@ class Exponential(Distribution): .. math:: H(Exp) = 1 - \log(\lambda). """ - rate = self._rate if rate is None else rate + rate = self.rate if rate is None else rate if name == 'entropy': return 1.0 - self.log(rate) return None @@ -179,7 +179,7 @@ class Exponential(Distribution): name (str): name of the funtion. Should always be "cross_entropy" when passed in from construct. dist (str): type of the distributions. Should be "Exponential" in this case. rate_b (Tensor): rate of distribution b. - rate_a (Tensor): rate of distribution a. Default: self._rate. + rate_a (Tensor): rate of distribution a. Default: self.rate. """ if name == 'cross_entropy' and dist == 'Exponential': return self._entropy(rate=rate_a) + self._kl_loss(name, dist, rate_b, rate_a) @@ -193,7 +193,7 @@ class Exponential(Distribution): Args: name (str): name of the function. value (Tensor): value to be evaluated. - rate (Tensor): rate of the distribution. Default: self._rate. + rate (Tensor): rate of the distribution. Default: self.rate. Note: Value should be greater or equal to zero. @@ -216,7 +216,7 @@ class Exponential(Distribution): Args: name (str): name of the function. value (Tensor): value to be evaluated. - rate (Tensor): rate of the distribution. Default: self._rate. + rate (Tensor): rate of the distribution. Default: self.rate. Note: Value should be greater or equal to zero. @@ -240,15 +240,29 @@ class Exponential(Distribution): name (str): name of the funtion. dist (str): type of the distributions. Should be "Exponential" in this case. rate_b (Tensor): rate of distribution b. - rate_a (Tensor): rate of distribution a. Default: self._rate. + rate_a (Tensor): rate of distribution a. Default: self.rate. """ if name in self._divergence_functions and dist == 'Exponential': - rate_a = self._rate if rate_a is None else rate_a + rate_a = self.rate if rate_a is None else rate_a return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0 return None def _sample(self, name, shape=(), rate=None): + """ + Sampling. + + Args: + name (str): name of the function. + shape (tuple): shape of the sample. Default: (). + rate (Tensor): rate of the distribution. Default: self.rate. + + Returns: + Tensor, shape is shape + batch_shape. + """ if name == 'sample': - rate = self._rate if rate is None else rate - return self.fill(mstype.float32, shape + self.shape(rate), 1.0) + rate = self.rate if rate is None else rate + minval = self.const(self.minval) + maxval = self.const(1.0) + sample = self.uniform(shape + self.shape(rate), minval, maxval) + return -self.log(sample) / rate return None diff --git a/mindspore/nn/distribution/geometric.py b/mindspore/nn/distribution/geometric.py index ca869863c0..0a9da3b244 100644 --- a/mindspore/nn/distribution/geometric.py +++ b/mindspore/nn/distribution/geometric.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """Geometric Distribution""" +import numpy as np from mindspore.ops import operations as P from .distribution import Distribution from ._utils.utils import cast_to_tensor, check_prob @@ -37,10 +38,10 @@ class Geometric(Distribution): >>> # To initialize a Geometric distribution of prob 0.5 >>> n = nn.Geometric(0.5, dtype=mstype.int32) >>> - >>> # The following create two independent Geometric distributions + >>> # The following creates two independent Geometric distributions >>> n = nn.Geometric([0.5, 0.5], dtype=mstype.int32) >>> - >>> # A Geometric distribution can be initilize without arguments + >>> # A Geometric distribution can be initilized without arguments >>> # In this case, probs must be passed in through construct. >>> n = nn.Geometric(dtype=mstype.int32) >>> @@ -51,16 +52,16 @@ class Geometric(Distribution): >>> self.g1 = nn.Geometric(0.5, dtype=mstype.int32) >>> self.g2 = nn.Geometric(dtype=mstype.int32) >>> - >>> # All the following calls in construct are valid + >>> # Tthe following calls are valid in construct >>> def construct(self, value, probs_b, probs_a): >>> - >>> # Similar to calls can be made to other probability functions + >>> # Similar calls can be made to other probability functions >>> # by replacing 'prob' with the name of the function >>> ans = self.g1('prob', value) >>> # Evaluate with the respect to distribution b >>> ans = self.g1('prob', value, probs_b) >>> - >>> # Additional probs must be passed in through construct + >>> # Probs must be passed in through construct >>> ans = self.g2('prob', value, probs_a) >>> >>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean' @@ -69,7 +70,7 @@ class Geometric(Distribution): >>> # Will return mean_b >>> ans = self.g1('mean', probs_b) >>> - >>> # Additional probs must be passed in through construct + >>> # Probs must be passed in through construct >>> ans = self.g2('mean', probs_a) >>> >>> # Usage of 'kl_loss' and 'cross_entropy' are similar @@ -102,23 +103,22 @@ class Geometric(Distribution): else: self._probs = probs + self.minval = np.finfo(np.float).tiny + # ops needed for the class - self.log = P.Log() - self.add = P.TensorAdd() - self.mul = P.Mul() - self.sqrt = P.Sqrt() - self.realdiv = P.RealDiv() - self.shape = P.Shape() - self.dType = P.DType() + self.const = P.ScalarToArray() + self.dtypeop = P.DType() + self.fill = P.Fill() self.floor = P.Floor() self.issubclass = P.IsSubClass() - self.const = P.ScalarToArray() self.less = P.Less() - self.normal = P.Normal(seed=seed) - self.sq = P.Square() - self.select = P.Select() - self.fill = P.Fill() + self.log = P.Log() self.pow = P.Pow() + self.select = P.Select() + self.shape = P.Shape() + self.sq = P.Square() + self.sqrt = P.Sqrt() + self.uniform = P.UniformReal(seed=seed) def extend_repr(self): if self.is_scalar_batch: @@ -140,7 +140,7 @@ class Geometric(Distribution): MEAN(Geo) = \fratc{1 - probs1}{probs1} """ if name == 'mean': - probs1 = self._probs if probs1 is None else probs1 + probs1 = self.probs if probs1 is None else probs1 return (1. - probs1) / probs1 return None @@ -160,7 +160,7 @@ class Geometric(Distribution): VAR(Geo) = \fract{1 - probs1}{probs1 ^ {2}} """ if name in self._variance_functions: - probs1 = self._probs if probs1 is None else probs1 + probs1 = self.probs if probs1 is None else probs1 return (1.0 - probs1) / self.sq(probs1) return None @@ -170,7 +170,7 @@ class Geometric(Distribution): H(Geo) = \fract{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1} """ if name == 'entropy': - probs1 = self._probs if probs is None else probs + probs1 = self.probs if probs is None else probs probs0 = 1.0 - probs1 return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1 return None @@ -183,7 +183,7 @@ class Geometric(Distribution): name (str): name of the funtion. Should always be "cross_entropy" when passed in from construct. dist (str): type of the distributions. Should be "Geometric" in this case. probs1_b (Tensor): probability of success of distribution b. - probs1_a (Tensor): probability of success of distribution a. Default: self._probs. + probs1_a (Tensor): probability of success of distribution a. Default: self.probs. """ if name == 'cross_entropy' and dist == 'Geometric': return self._entropy(probs=probs1_a) + self._kl_loss(name, dist, probs1_b, probs1_a) @@ -196,15 +196,15 @@ class Geometric(Distribution): Args: name (str): name of the function. Should be "prob" when passed in from construct. value (Tensor): a Tensor composed of only natural numbers. - probs (Tensor): probability of success. Default: self._probs. + probs (Tensor): probability of success. Default: self.probs. .. math:: pmf(k) = probs0 ^k * probs1 if k >= 0; pmf(k) = 0 if k < 0. """ if name in self._prob_functions: - probs1 = self._probs if probs is None else probs - dtype = self.dType(value) + probs1 = self.probs if probs is None else probs + dtype = self.dtypeop(value) if self.issubclass(dtype, mstype.int_): pass elif self.issubclass(dtype, mstype.float_): @@ -224,7 +224,7 @@ class Geometric(Distribution): Args: name (str): name of the function. value (Tensor): a Tensor composed of only natural numbers. - probs (Tensor): probability of success. Default: self._probs. + probs (Tensor): probability of success. Default: self.probs. .. math:: cdf(k) = 1 - probs0 ^ (k+1) if k >= 0; @@ -232,9 +232,9 @@ class Geometric(Distribution): """ if name in self._cdf_survival_functions: - probs1 = self._probs if probs is None else probs + probs1 = self.probs if probs is None else probs probs0 = 1.0 - probs1 - dtype = self.dType(value) + dtype = self.dtypeop(value) if self.issubclass(dtype, mstype.int_): pass elif self.issubclass(dtype, mstype.float_): @@ -255,16 +255,16 @@ class Geometric(Distribution): name (str): name of the funtion. dist (str): type of the distributions. Should be "Geometric" in this case. probs1_b (Tensor): probability of success of distribution b. - probs1_a (Tensor): probability of success of distribution a. Default: self._probs. + probs1_a (Tensor): probability of success of distribution a. Default: self.probs. .. math:: KL(a||b) = \log(\fract{probs1_a}{probs1_b}) + \fract{probs0_a}{probs1_a} * \log(\fract{probs0_a}{probs0_b}) """ if name in self._divergence_functions and dist == 'Geometric': - probs1_a = self._probs if probs1_a is None else probs1_a + probs1_a = self.probs if probs1_a is None else probs1_a probs0_a = 1.0 - probs1_a probs0_b = 1.0 - probs1_b - return self.log(probs1_a / probs1_b) + self.mul(probs0_a / probs1_a, self.log(probs0_a / probs0_b)) + return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b) return None def _sample(self, name, shape=(), probs=None): @@ -274,12 +274,15 @@ class Geometric(Distribution): Args: name (str): name of the function. Should always be 'sample' when passed in from construct. shape (tuple): shape of the sample. Default: (). - probs (Tensor): probs1 of the samples. Default: self._probs. + probs (Tensor): probability of success. Default: self.probs. Returns: Tensor, shape is shape + batch_shape. """ if name == 'sample': - probs = self._probs if probs is None else probs - return self.fill(mstype.float32, shape + self.shape(probs), 1.0) + probs = self.probs if probs is None else probs + minval = self.const(self.minval) + maxval = self.const(1.0) + sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval) + return self.floor(self.log(sample_uniform) / self.log(1.0 - probs)) return None diff --git a/mindspore/nn/distribution/normal.py b/mindspore/nn/distribution/normal.py index 7139794a26..7bfea6c7e9 100644 --- a/mindspore/nn/distribution/normal.py +++ b/mindspore/nn/distribution/normal.py @@ -26,22 +26,21 @@ class Normal(Distribution): Normal distribution. Args: - mean (int, float, list, numpy.ndarray, Tensor, Parameter): mean of the Gaussian distribution. - sd (int, float, list, numpy.ndarray, Tensor, Parameter): stddev of the Gaussian distribution. + mean (int, float, list, numpy.ndarray, Tensor, Parameter): mean of the Normal distribution. + sd (int, float, list, numpy.ndarray, Tensor, Parameter): stddev of the Normal distribution. seed (int): seed to use in sampling. Default: 0. dtype (mindspore.dtype): type of the distribution. Default: mstype.float32. name (str): name of the distribution. Default: Normal. - Note: Standard deviation should be greater than zero. Dist_spec_args are mean and sd. Examples: - >>> # To initialize a normal distribution of mean 3.0 and standard deviation 4.0 + >>> # To initialize a Normal distribution of mean 3.0 and standard deviation 4.0 >>> n = nn.Normal(3.0, 4.0, dtype=mstype.float32) >>> - >>> # The following create two independent normal distributions + >>> # The following creates two independent Normal distributions >>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32) >>> >>> # A normal distribution can be initilize without arguments @@ -55,16 +54,16 @@ class Normal(Distribution): >>> self.n1 = nn.Normal(0.0, 1.0, dtype=mstype.float32) >>> self.n2 = nn.Normal(dtype=mstype.float32) >>> - >>> # All the following calls in construct are valid + >>> # The following calls are valid in construct >>> def construct(self, value, mean_b, sd_b, mean_a, sd_a): >>> - >>> # Similar to calls can be made to other probability functions + >>> # Similar calls can be made to other probability functions >>> # by replacing 'prob' with the name of the function >>> ans = self.n1('prob', value) >>> # Evaluate with the respect to distribution b >>> ans = self.n1('prob', value, mean_b, sd_b) >>> - >>> # Additional mean and sd must be passed in through construct + >>> # mean and sd must be passed in through construct >>> ans = self.n2('prob', value, mean_a, sd_a) >>> >>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean' @@ -73,7 +72,7 @@ class Normal(Distribution): >>> # Will return mean_b >>> ans = self.n1('mean', mean_b, sd_b) >>> - >>> # Additional mean and sd must be passed in through construct + >>> # mean and sd must be passed in through construct >>> ans = self.n2('mean', mean_a, sd_a) >>> >>> # Usage of 'kl_loss' and 'cross_entropy' are similar @@ -111,20 +110,16 @@ class Normal(Distribution): self.seed = seed #ops needed for the class + self.const = P.ScalarToArray() + self.erf = P.Erf() self.exp = P.Exp() self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step + self.fill = P.Fill() self.log = P.Log() self.shape = P.Shape() self.sq = P.Square() - self.log = P.Log() self.sqrt = P.Sqrt() - self.realdiv = P.RealDiv() - self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step - self.shape = P.Shape() self.zeroslike = P.ZerosLike() - self.const = P.ScalarToArray() - self.erf = P.Erf() - self.fill = P.Fill() def extend_repr(self): if self.is_scalar_batch: @@ -231,8 +226,8 @@ class Normal(Distribution): if name in self._cdf_survival_functions: mean = self._mean_value if mean is None else mean sd = self._sd_value if sd is None else sd - sqrt2 = self.sqrt(self.fill(mstype.float32, self.shape(sd), 2.0)) - adjusted = (value - mean) / self.mul(sd, sqrt2) + sqrt2 = self.sqrt(self.const(2.0)) + adjusted = (value - mean) / (sd * sqrt2) return 0.5 * (1.0 + self.erf(adjusted)) return None @@ -276,11 +271,11 @@ class Normal(Distribution): if name == 'sample': mean = self._mean_value if mean is None else mean sd = self._sd_value if sd is None else sd - batch_shape = self.shape(self.add(self.zeroslike(mean), self.zeroslike(sd))) + batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd)) sample_shape = shape + batch_shape mean_zero = self.const(0.0) sd_one = self.const(1.0) sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed) - sample = self.add(mean, self.mul(sample_norm, sd)) + sample = mean + sample_norm * sd return sample return None diff --git a/mindspore/nn/distribution/uniform.py b/mindspore/nn/distribution/uniform.py index 890b2283bc..3b90bbe736 100644 --- a/mindspore/nn/distribution/uniform.py +++ b/mindspore/nn/distribution/uniform.py @@ -37,10 +37,10 @@ class Uniform(Distribution): >>> # To initialize a Uniform distribution of mean 3.0 and standard deviation 4.0 >>> n = nn.Uniform(0.0, 1.0, dtype=mstype.float32) >>> - >>> # The following create two independent Uniform distributions + >>> # The following creates two independent Uniform distributions >>> n = nn.Uniform([0.0, 0.0], [1.0, 2.0], dtype=mstype.float32) >>> - >>> # A Uniform distribution can be initilize without arguments + >>> # A Uniform distribution can be initilized without arguments >>> # In this case, high and low must be passed in through construct. >>> n = nn.Uniform(dtype=mstype.float32) >>> @@ -54,13 +54,13 @@ class Uniform(Distribution): >>> # All the following calls in construct are valid >>> def construct(self, value, low_b, high_b, low_a, high_a): >>> - >>> # Similar to calls can be made to other probability functions + >>> # Similar calls can be made to other probability functions >>> # by replacing 'prob' with the name of the function >>> ans = self.u1('prob', value) >>> # Evaluate with the respect to distribution b >>> ans = self.u1('prob', value, low_b, high_b) >>> - >>> # Additional high and low must be passed in through construct + >>> # High and low must be passed in through construct >>> ans = self.u2('prob', value, low_a, high_a) >>> >>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean' @@ -69,7 +69,7 @@ class Uniform(Distribution): >>> # Will return low_b >>> ans = self.u1('mean', low_b, high_b) >>> - >>> # Additional high and low must be passed in through construct + >>> # High and low must be passed in through construct >>> ans = self.u2('mean', low_a, high_a) >>> >>> # Usage of 'kl_loss' and 'cross_entropy' are similar @@ -100,7 +100,7 @@ class Uniform(Distribution): if low is not None and high is not None: self._low = convert_to_batch(low, self._broadcast_shape, dtype) self._high = convert_to_batch(high, self._broadcast_shape, dtype) - check_greater(self._low, self._high, "low value", "high value") + check_greater(self.low, self.high, "low value", "high value") else: self._low = low self._high = high @@ -109,20 +109,17 @@ class Uniform(Distribution): self.const = P.ScalarToArray() self.dtypeop = P.DType() self.exp = P.Exp() - self.log = P.Log() - self.add = P.TensorAdd() - self.mul = P.Mul() - self.sqrt = P.Sqrt() - self.realdiv = P.RealDiv() + self.fill = P.Fill() self.less = P.Less() self.lessequal = P.LessEqual() - self.sq = P.Square() - self.select = P.Select() - self.zeroslike = P.ZerosLike() + self.log = P.Log() self.logicaland = P.LogicalAnd() - self.fill = P.Fill() + self.select = P.Select() self.shape = P.Shape() - self.normal = P.Normal(seed=seed) + self.sq = P.Square() + self.sqrt = P.Sqrt() + self.uniform = P.UniformReal(seed=seed) + self.zeroslike = P.ZerosLike() def extend_repr(self): if self.is_scalar_batch: @@ -152,8 +149,8 @@ class Uniform(Distribution): range(U) = high -low """ if name == 'range': - low = self._low if low is None else low - high = self._high if high is None else high + low = self.low if low is None else low + high = self.high if high is None else high return high - low return None @@ -163,8 +160,8 @@ class Uniform(Distribution): MEAN(U) = \fract{low + high}{2}. """ if name == 'mean': - low = self._low if low is None else low - high = self._high if high is None else high + low = self.low if low is None else low + high = self.high if high is None else high return (low + high) / 2. return None @@ -174,8 +171,8 @@ class Uniform(Distribution): VAR(U) = \fract{(high -low) ^ 2}{12}. """ if name in self._variance_functions: - low = self._low if low is None else low - high = self._high if high is None else high + low = self.low if low is None else low + high = self.high if high is None else high return self.sq(high - low) / 12.0 return None @@ -185,8 +182,8 @@ class Uniform(Distribution): H(U) = \log(high - low). """ if name == 'entropy': - low = self._low if low is None else low - high = self._high if high is None else high + low = self.low if low is None else low + high = self.high if high is None else high return self.log(high - low) return None @@ -199,8 +196,8 @@ class Uniform(Distribution): dist (str): type of the distributions. Should be "Uniform" in this case. low_b (Tensor): lower bound of distribution b. high_b (Tensor): upper bound of distribution b. - low_a (Tensor): lower bound of distribution a. Default: self._low. - high_a (Tensor): upper bound of distribution a. Default: self._high. + low_a (Tensor): lower bound of distribution a. Default: self.low. + high_a (Tensor): upper bound of distribution a. Default: self.high. """ if name == 'cross_entropy' and dist == 'Uniform': return self._entropy(low=low_a, high=high_a) + self._kl_loss(name, dist, low_b, high_b, low_a, high_a) @@ -213,8 +210,8 @@ class Uniform(Distribution): Args: name (str): name of the function. value (Tensor): value to be evaluated. - low (Tensor): lower bound of the distribution. Default: self._low. - high (Tensor): upper bound of the distribution. Default: self._high. + low (Tensor): lower bound of the distribution. Default: self.low. + high (Tensor): upper bound of the distribution. Default: self.high. .. math:: pdf(x) = 0 if x < low; @@ -243,12 +240,12 @@ class Uniform(Distribution): dist (str): type of the distributions. Should be "Uniform" in this case. low_b (Tensor): lower bound of distribution b. high_b (Tensor): upper bound of distribution b. - low_a (Tensor): lower bound of distribution a. Default: self._low. - high_a (Tensor): upper bound of distribution a. Default: self._high. + low_a (Tensor): lower bound of distribution a. Default: self.low. + high_a (Tensor): upper bound of distribution a. Default: self.high. """ if name in self._divergence_functions and dist == 'Uniform': - low_a = self._low if low_a is None else low_a - high_a = self._high if high_a is None else high_a + low_a = self.low if low_a is None else low_a + high_a = self.high if high_a is None else high_a kl = self.log(high_b - low_b) / self.log(high_a - low_a) comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b)) return self.select(comp, kl, self.log(self.zeroslike(kl))) @@ -261,8 +258,8 @@ class Uniform(Distribution): Args: name (str): name of the function. value (Tensor): value to be evaluated. - low (Tensor): lower bound of the distribution. Default: self._low. - high (Tensor): upper bound of the distribution. Default: self._high. + low (Tensor): lower bound of the distribution. Default: self.low. + high (Tensor): upper bound of the distribution. Default: self.high. .. math:: cdf(x) = 0 if x < low; @@ -270,8 +267,8 @@ class Uniform(Distribution): cdf(x) = 1 if x > high; """ if name in self._cdf_survival_functions: - low = self._low if low is None else low - high = self._high if high is None else high + low = self.low if low is None else low + high = self.high if high is None else high prob = (value - low) / (high - low) broadcast_shape = self.shape(prob) zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) @@ -283,9 +280,25 @@ class Uniform(Distribution): return None def _sample(self, name, shape=(), low=None, high=None): + """ + Sampling. + + Args: + name (str): name of the function. Should always be 'sample' when passed in from construct. + shape (tuple): shape of the sample. Default: (). + low (Tensor): lower bound of the distribution. Default: self.low. + high (Tensor): upper bound of the distribution. Default: self.high. + + Returns: + Tensor, shape is shape + batch_shape. + """ if name == 'sample': - low = self._low if low is None else low - high = self._high if high is None else high + low = self.low if low is None else low + high = self.high if high is None else high broadcast_shape = self.shape(low + high) - return self.fill(mstype.float32, shape + broadcast_shape, 1.0) + l_zero = self.const(0.0) + h_one = self.const(1.0) + sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one) + sample = (high - low) * sample_uniform + low + return sample return None diff --git a/tests/st/ops/ascend/test_distribution/test_bernoulli.py b/tests/st/ops/ascend/test_distribution/test_bernoulli.py index 035e315964..451530116b 100644 --- a/tests/st/ops/ascend/test_distribution/test_bernoulli.py +++ b/tests/st/ops/ascend/test_distribution/test_bernoulli.py @@ -25,7 +25,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") class Prob(nn.Cell): """ - Test class: probability of bernoulli distribution. + Test class: probability of Bernoulli distribution. """ def __init__(self): super(Prob, self).__init__() @@ -50,7 +50,7 @@ def test_pmf(): class LogProb(nn.Cell): """ - Test class: log probability of bernoulli distribution. + Test class: log probability of Bernoulli distribution. """ def __init__(self): super(LogProb, self).__init__() @@ -74,7 +74,7 @@ def test_log_likelihood(): class KL(nn.Cell): """ - Test class: kl_loss between bernoulli distributions. + Test class: kl_loss between Bernoulli distributions. """ def __init__(self): super(KL, self).__init__() @@ -100,7 +100,7 @@ def test_kl_loss(): class Basics(nn.Cell): """ - Test class: mean/sd/mode of bernoulli distribution. + Test class: mean/sd/mode of Bernoulli distribution. """ def __init__(self): super(Basics, self).__init__() @@ -112,7 +112,7 @@ class Basics(nn.Cell): def test_basics(): """ - Test mean/standard deviation/mode and probs. + Test mean/standard deviation/mode. """ basics = Basics() mean, sd, mode = basics() @@ -123,14 +123,10 @@ def test_basics(): assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() assert (np.abs(mode.asnumpy() - expect_mode) < tol).all() - b = nn.Bernoulli([0.7, 0.5], dtype=dtype.int32) - probs = b.probs() - expect_probs = [0.7, 0.5] - assert (np.abs(probs.asnumpy() - expect_probs) < tol).all() class Sampling(nn.Cell): """ - Test class: log probability of bernoulli distribution. + Test class: log probability of Bernoulli distribution. """ def __init__(self, shape, seed=0): super(Sampling, self).__init__() @@ -202,7 +198,7 @@ def test_logcdf(): class SF(nn.Cell): """ - Test class: survival function of bernoulli distributions. + Test class: survival function of Bernoulli distributions. """ def __init__(self): super(SF, self).__init__() @@ -227,7 +223,7 @@ def test_survival(): class LogSF(nn.Cell): """ - Test class: log survival function of bernoulli distributions. + Test class: log survival function of Bernoulli distributions. """ def __init__(self): super(LogSF, self).__init__() @@ -251,7 +247,7 @@ def test_log_survival(): class EntropyH(nn.Cell): """ - Test class: entropy of bernoulli distributions. + Test class: entropy of Bernoulli distributions. """ def __init__(self): super(EntropyH, self).__init__() diff --git a/tests/st/ops/ascend/test_distribution/test_exponential.py b/tests/st/ops/ascend/test_distribution/test_exponential.py index acf05f17ff..823f9b0e1a 100644 --- a/tests/st/ops/ascend/test_distribution/test_exponential.py +++ b/tests/st/ops/ascend/test_distribution/test_exponential.py @@ -109,7 +109,7 @@ class Basics(nn.Cell): def test_basics(): """ - Test mean/standard deviation and range. + Test mean/standard/mode deviation. """ basics = Basics() mean, sd, mode = basics() @@ -121,6 +121,30 @@ def test_basics(): assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() assert (np.abs(mode.asnumpy() - expect_mode) < tol).all() +class Sampling(nn.Cell): + """ + Test class: sample of Exponential distribution. + """ + def __init__(self, shape, seed=0): + super(Sampling, self).__init__() + self.e = nn.Exponential([[1.0], [0.5]], seed=seed, dtype=dtype.float32) + self.shape = shape + + @ms_function + def construct(self, rate=None): + return self.e('sample', self.shape, rate) + +def test_sample(): + """ + Test sample. + """ + shape = (2, 3) + seed = 10 + rate = Tensor([1.0, 2.0, 3.0], dtype=dtype.float32) + sample = Sampling(shape, seed=seed) + output = sample(rate) + assert output.shape == (2, 3, 3) + class CDF(nn.Cell): """ Test class: cdf of Exponential distribution. diff --git a/tests/st/ops/ascend/test_distribution/test_geometric.py b/tests/st/ops/ascend/test_distribution/test_geometric.py index b056d4acab..b3b9995bcb 100644 --- a/tests/st/ops/ascend/test_distribution/test_geometric.py +++ b/tests/st/ops/ascend/test_distribution/test_geometric.py @@ -99,7 +99,7 @@ def test_kl_loss(): class Basics(nn.Cell): """ - Test class: mean/sd of Geometric distribution. + Test class: mean/sd/mode of Geometric distribution. """ def __init__(self): super(Basics, self).__init__() @@ -111,7 +111,7 @@ class Basics(nn.Cell): def test_basics(): """ - Test mean/standard deviation/mode and probs. + Test mean/standard deviation/mode. """ basics = Basics() mean, sd, mode = basics() @@ -122,10 +122,28 @@ def test_basics(): assert (np.abs(mean.asnumpy()- expect_mean) < tol).all() assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() assert (np.abs(mode.asnumpy() - expect_mode) < tol).all() - b = nn.Geometric([0.7, 0.5], dtype=dtype.int32) - probs = b.probs() - expect_probs = [0.7, 0.5] - assert (np.abs(probs.asnumpy() - expect_probs) < tol).all() + +class Sampling(nn.Cell): + """ + Test class: log probability of bernoulli distribution. + """ + def __init__(self, shape, seed=0): + super(Sampling, self).__init__() + self.g = nn.Geometric([0.7, 0.5], seed=seed, dtype=dtype.int32) + self.shape = shape + + @ms_function + def construct(self, probs=None): + return self.g('sample', self.shape, probs) + +def test_sample(): + """ + Test sample. + """ + shape = (2, 3) + sample = Sampling(shape) + output = sample() + assert output.shape == (2, 3, 2) class CDF(nn.Cell): """ diff --git a/tests/st/ops/ascend/test_distribution/test_normal.py b/tests/st/ops/ascend/test_distribution/test_normal.py index 5eab57973c..d3a93c244c 100644 --- a/tests/st/ops/ascend/test_distribution/test_normal.py +++ b/tests/st/ops/ascend/test_distribution/test_normal.py @@ -25,7 +25,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") class Prob(nn.Cell): """ - Test class: probability of normal distribution. + Test class: probability of Normal distribution. """ def __init__(self): super(Prob, self).__init__() @@ -48,7 +48,7 @@ def test_pdf(): class LogProb(nn.Cell): """ - Test class: log probability of normal distribution. + Test class: log probability of Normal distribution. """ def __init__(self): super(LogProb, self).__init__() @@ -72,7 +72,7 @@ def test_log_likelihood(): class KL(nn.Cell): """ - Test class: kl_loss of normal distribution. + Test class: kl_loss of Normal distribution. """ def __init__(self): super(KL, self).__init__() @@ -106,7 +106,7 @@ def test_kl_loss(): class Basics(nn.Cell): """ - Test class: mean/sd of normal distribution. + Test class: mean/sd/mode of Normal distribution. """ def __init__(self): super(Basics, self).__init__() @@ -131,7 +131,7 @@ def test_basics(): class Sampling(nn.Cell): """ - Test class: sample of normal distribution. + Test class: sample of Normal distribution. """ def __init__(self, shape, seed=0): super(Sampling, self).__init__() @@ -156,7 +156,7 @@ def test_sample(): class CDF(nn.Cell): """ - Test class: cdf of normal distribution. + Test class: cdf of Normal distribution. """ def __init__(self): super(CDF, self).__init__() @@ -180,7 +180,7 @@ def test_cdf(): class LogCDF(nn.Cell): """ - Test class: log_cdf of normal distribution. + Test class: log_cdf of Mormal distribution. """ def __init__(self): super(LogCDF, self).__init__() @@ -203,7 +203,7 @@ def test_log_cdf(): class SF(nn.Cell): """ - Test class: survival function of normal distribution. + Test class: survival function of Normal distribution. """ def __init__(self): super(SF, self).__init__() @@ -226,7 +226,7 @@ def test_survival(): class LogSF(nn.Cell): """ - Test class: log survival function of normal distribution. + Test class: log survival function of Normal distribution. """ def __init__(self): super(LogSF, self).__init__() @@ -249,7 +249,7 @@ def test_log_survival(): class EntropyH(nn.Cell): """ - Test class: entropy of normal distribution. + Test class: entropy of Normal distribution. """ def __init__(self): super(EntropyH, self).__init__() @@ -272,7 +272,7 @@ def test_entropy(): class CrossEntropy(nn.Cell): """ - Test class: cross entropy between normal distribution. + Test class: cross entropy between Normal distributions. """ def __init__(self): super(CrossEntropy, self).__init__() diff --git a/tests/st/ops/ascend/test_distribution/test_uniform.py b/tests/st/ops/ascend/test_distribution/test_uniform.py index b237390e1c..bfcf9b7235 100644 --- a/tests/st/ops/ascend/test_distribution/test_uniform.py +++ b/tests/st/ops/ascend/test_distribution/test_uniform.py @@ -111,7 +111,7 @@ class Basics(nn.Cell): def test_basics(): """ - Test mean/standard deviation/mode. + Test mean/standard deviation. """ basics = Basics() mean, sd = basics() @@ -121,6 +121,31 @@ def test_basics(): assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() +class Sampling(nn.Cell): + """ + Test class: sample of Uniform distribution. + """ + def __init__(self, shape, seed=0): + super(Sampling, self).__init__() + self.u = nn.Uniform([0.0], [[1.0], [2.0]], seed=seed, dtype=dtype.float32) + self.shape = shape + + @ms_function + def construct(self, low=None, high=None): + return self.u('sample', self.shape, low, high) + +def test_sample(): + """ + Test sample. + """ + shape = (2, 3) + seed = 10 + low = Tensor([1.0], dtype=dtype.float32) + high = Tensor([2.0, 3.0, 4.0], dtype=dtype.float32) + sample = Sampling(shape, seed=seed) + output = sample(low, high) + assert output.shape == (2, 3, 3) + class CDF(nn.Cell): """ Test class: cdf of Uniform distribution. diff --git a/tests/ut/python/nn/distribution/test_bernoulli.py b/tests/ut/python/nn/distribution/test_bernoulli.py index 838cda6004..9233e2d395 100644 --- a/tests/ut/python/nn/distribution/test_bernoulli.py +++ b/tests/ut/python/nn/distribution/test_bernoulli.py @@ -21,7 +21,6 @@ import mindspore.nn as nn from mindspore import dtype from mindspore import Tensor - def test_arguments(): """ Args passing during initialization. diff --git a/tests/ut/python/nn/distribution/test_normal.py b/tests/ut/python/nn/distribution/test_normal.py index 1773337dec..87a92ad8da 100644 --- a/tests/ut/python/nn/distribution/test_normal.py +++ b/tests/ut/python/nn/distribution/test_normal.py @@ -111,7 +111,7 @@ class NormalKl(nn.Cell): def test_kl(): """ - Test kl_loss + Test kl_loss. """ net = NormalKl() mean_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) @@ -136,6 +136,9 @@ class NormalCrossEntropy(nn.Cell): return h1 + h2 def test_cross_entropy(): + """ + Test cross entropy between Normal distributions. + """ net = NormalCrossEntropy() mean_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) sd_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) diff --git a/tests/ut/python/nn/distribution/test_uniform.py b/tests/ut/python/nn/distribution/test_uniform.py index b568b4dab4..7f9b442816 100644 --- a/tests/ut/python/nn/distribution/test_uniform.py +++ b/tests/ut/python/nn/distribution/test_uniform.py @@ -120,7 +120,7 @@ class UniformKl(nn.Cell): def test_kl(): """ - Test kl_loss + Test kl_loss. """ net = UniformKl() low_b = Tensor(np.array([0.0]).astype(np.float32), dtype=dtype.float32)