@@ -26,7 +26,7 @@ class Bijector(Cell):
         is_constant_jacobian (bool): if the bijector has constant derivative. Default: False.
         is_injective (bool): if the bijector is a one-to-one mapping. Default: True.
         name (str): name of the bijector. Default: None.
-        dtype (mstype): type of the distribution the bijector can operate on. Default: None.
+        dtype (mindspore.dtype): type of the distribution the bijector can operate on. Default: None.
         param (dict): parameters used to initialize the bijector. Default: None.
     """
     def __init__(self,
@@ -110,7 +110,7 @@ class Bijector(Cell):
            *args: args[0] shall be either a distribution or the name of a bijector function.
        """
        if isinstance(args[0], Distribution):
-           return TransformedDistribution(self, args[0])
+           return TransformedDistribution(self, args[0], self.distribution.dtype)
        return super(Bijector, self).__call__(*args, **kwargs)

    def construct(self, name, *args, **kwargs):
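Calling a bijector on a distribution now builds the transformed distribution with the base distribution's dtype. For background, the density of a transformed variable follows the change-of-variables formula; a minimal NumPy sketch of that rule (illustrative only, not the MindSpore implementation; `transformed_log_prob` and its arguments are hypothetical stand-ins):

    import numpy as np

    # Change of variables for Y = g(X) with bijective g:
    # log p_Y(y) = log p_X(g^{-1}(y)) + log |d/dy g^{-1}(y)|
    def transformed_log_prob(base_log_prob, inverse, inverse_log_det_jacobian, y):
        x = inverse(y)
        return base_log_prob(x) + inverse_log_det_jacobian(y)

    # Example: Y = exp(X) with X ~ N(0, 1), i.e. Y is log-normal.
    normal_log_prob = lambda x: -0.5 * x**2 - 0.5 * np.log(2.0 * np.pi)
    print(transformed_log_prob(normal_log_prob, np.log, lambda y: -np.log(y), 2.0))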
@@ -22,7 +22,10 @@ from .bijector import Bijector

 class Softplus(Bijector):
    r"""
    Softplus Bijector.
-   This Bijector performs the operation: Y = \frac{\log(1 + e ^ {kX})}{k}, where k is the sharpness factor.
+   This Bijector performs the operation, where k is the sharpness factor.
+
+   .. math::
+       Y = \frac{\log(1 + e ^ {kX})}{k}

    Args:
        sharpness (float): scale factor. Default: 1.0.
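A note on evaluating the formula above: computing log(1 + e^{kx}) naively overflows for large kx. A small NumPy sketch of a numerically safe softplus (the generic stable form, not necessarily what the MindSpore kernel does):

    import numpy as np

    def softplus(x, sharpness=1.0):
        kx = sharpness * x
        # max(kx, 0) + log1p(exp(-|kx|)) equals log(1 + exp(kx)) without overflow
        return (np.maximum(kx, 0.0) + np.log1p(np.exp(-np.abs(kx)))) / sharpness

    print(softplus(np.array([-1000.0, 0.0, 1000.0])))  # [0., 0.693..., 1000.]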
@@ -184,7 +184,7 @@ def check_greater(a, b, name_a, name_b):

 def check_prob(p):
    """
-   Check if p is a proper probability, i.e. 0 <= p <= 1.
+   Check if p is a proper probability, i.e. 0 < p < 1.

    Args:
        p (Tensor, Parameter): value to be checked.
@@ -196,12 +196,12 @@ def check_prob(p):
        if not isinstance(p.default_input, Tensor):
            return
        p = p.default_input
-   comp = np.less(p.asnumpy(), np.zeros(p.shape))
-   if comp.any():
-       raise ValueError('Probabilities should be greater than or equal to zero')
-   comp = np.greater(p.asnumpy(), np.ones(p.shape))
-   if comp.any():
-       raise ValueError('Probabilities should be less than or equal to one')
+   comp = np.less(np.zeros(p.shape), p.asnumpy())
+   if not comp.all():
+       raise ValueError('Probabilities should be greater than zero')
+   comp = np.greater(np.ones(p.shape), p.asnumpy())
+   if not comp.all():
+       raise ValueError('Probabilities should be less than one')

 def logits_to_probs(logits, is_binary=False):
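The check now enforces the open interval 0 < p < 1 rather than the closed one. The motivation, presumably, is that the log-space formulas introduced elsewhere in this patch (see the Bernoulli and Geometric hunks below) take log(p) and log(1 - p), which diverge at the endpoints. A quick NumPy illustration:

    import numpy as np

    p = np.array([0.0, 0.5, 1.0])
    with np.errstate(divide='ignore'):
        print(np.log(p))        # [-inf, -0.693..., 0.]
        print(np.log(1.0 - p))  # [0., -0.693..., -inf]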
@@ -110,6 +110,7 @@ class Bernoulli(Distribution):
        self.const = P.ScalarToArray()
        self.dtypeop = P.DType()
        self.erf = P.Erf()
+       self.exp = P.Exp()
        self.fill = P.Fill()
        self.log = P.Log()
        self.less = P.Less()
@@ -159,7 +160,7 @@ class Bernoulli(Distribution):
        """
        probs1 = self.probs if probs1 is None else probs1
        probs0 = 1.0 - probs1
-       return probs0 * probs1
+       return self.exp(self.log(probs0) + self.log(probs1))

    def _entropy(self, probs=None):
        r"""
@@ -183,7 +184,7 @@ class Bernoulli(Distribution):
            return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
        return None

-   def _prob(self, value, probs=None):
+   def _log_prob(self, value, probs=None):
        r"""
        pmf of Bernoulli distribution.
@@ -197,7 +198,7 @@ class Bernoulli(Distribution):
        """
        probs1 = self.probs if probs is None else probs
        probs0 = 1.0 - probs1
-       return (probs1 * value) + (probs0 * (1.0 - value))
+       return self.log(probs1) * value + self.log(probs0) * (1.0 - value)

    def _cdf(self, value, probs=None):
        r"""
@@ -15,6 +15,7 @@
 """basic"""
 from mindspore.nn.cell import Cell
 from mindspore._checkparam import Validator as validator
+from mindspore._checkparam import Rel
 from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param

 class Distribution(Cell):
@@ -28,12 +29,15 @@ class Distribution(Cell):

    Note:
        Derived class should override operations such as _mean, _prob,
-       and _log_prob. Arguments should be passed in through *args or **kwargs.
+       and _log_prob. Required arguments, such as value for _prob,
+       should be passed in through args or kwargs. dist_spec_args which specify
+       a new distribution are optional.

-       Dist_spec_args are unique for each type of distribution. For example, mean and sd
-       are the dist_spec_args for a Normal distribution.
+       dist_spec_args are unique for each type of distribution. For example, mean and sd
+       are the dist_spec_args for a Normal distribution, while rate is the dist_spec_args
+       for an exponential distribution.

-       For all functions, passing in dist_spec_args, are optional.
+       For all functions, passing in dist_spec_args is optional.
        Passing in the additional dist_spec_args will make the result be evaluated with
        a new distribution specified by the dist_spec_args, but it won't change the
        original distribution.
@@ -49,7 +53,7 @@ class Distribution(Cell):
        """
        super(Distribution, self).__init__()
        validator.check_value_type('name', name, [str], 'distribution_name')
-       validator.check_value_type('seed', seed, [int], name)
+       validator.check_integer('seed', seed, 0, Rel.GE, name)
        self._name = name
        self._seed = seed
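The seed is now validated as a non-negative integer rather than merely type-checked. A rough pure-Python equivalent of what `check_integer(..., 0, Rel.GE, ...)` appears to enforce (a sketch, not the validator's actual implementation):

    def check_seed(seed):
        # Reject non-int values (TypeError) and negative ints (ValueError)
        if not isinstance(seed, int) or isinstance(seed, bool):
            raise TypeError("seed must be an int")
        if seed < 0:
            raise ValueError("seed must be >= 0")
        return seed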
@@ -191,7 +195,7 @@ class Distribution(Cell):

        Note:
            Args must include value.
-           Dist_spec_args are optional.
+           dist_spec_args are optional.
        """
        return self._call_log_prob(*args, **kwargs)

@@ -210,7 +214,7 @@ class Distribution(Cell):

        Note:
            Args must include value.
-           Dist_spec_args are optional.
+           dist_spec_args are optional.
        """
        return self._call_prob(*args, **kwargs)

@@ -229,7 +233,7 @@ class Distribution(Cell):

        Note:
            Args must include value.
-           Dist_spec_args are optional.
+           dist_spec_args are optional.
        """
        return self._call_cdf(*args, **kwargs)

@@ -266,7 +270,7 @@ class Distribution(Cell):

        Note:
            Args must include value.
-           Dist_spec_args are optional.
+           dist_spec_args are optional.
        """
        return self._call_log_cdf(*args, **kwargs)

@@ -285,7 +289,7 @@ class Distribution(Cell):

        Note:
            Args must include value.
-           Dist_spec_args are optional.
+           dist_spec_args are optional.
        """
        return self._call_survival(*args, **kwargs)

@@ -313,7 +317,7 @@ class Distribution(Cell):

        Note:
            Args must include value.
-           Dist_spec_args are optional.
+           dist_spec_args are optional.
        """
        return self._call_log_survival(*args, **kwargs)

@@ -341,7 +345,7 @@ class Distribution(Cell):
        Evaluate the mean.

        Note:
-           Dist_spec_args are optional.
+           dist_spec_args are optional.
        """
        return self._mean(*args, **kwargs)

@@ -350,7 +354,7 @@ class Distribution(Cell):
        Evaluate the mode.

        Note:
-           Dist_spec_args are optional.
+           dist_spec_args are optional.
        """
        return self._mode(*args, **kwargs)

@@ -359,7 +363,7 @@ class Distribution(Cell):
        Evaluate the standard deviation.

        Note:
-           Dist_spec_args are optional.
+           dist_spec_args are optional.
        """
        return self._call_sd(*args, **kwargs)

@@ -368,7 +372,7 @@ class Distribution(Cell):
        Evaluate the variance.

        Note:
-           Dist_spec_args are optional.
+           dist_spec_args are optional.
        """
        return self._call_var(*args, **kwargs)

@@ -395,7 +399,7 @@ class Distribution(Cell):
        Evaluate the entropy.

        Note:
-           Dist_spec_args are optional.
+           dist_spec_args are optional.
        """
        return self._entropy(*args, **kwargs)

@@ -424,7 +428,7 @@ class Distribution(Cell):

        Note:
            The shape of the sample defaults to ().
-           Dist_spec_args are optional.
+           dist_spec_args are optional.
        """
        return self._sample(*args, **kwargs)
@@ -199,7 +199,7 @@ class Exponential(Distribution):
        pdf(x) = \lambda * \exp(-\lambda * x) if x >= 0 else 0
        """
        rate = self.rate if rate is None else rate
-       prob = rate * self.exp(-1. * rate * value)
+       prob = self.exp(self.log(rate) - rate * value)
        zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
        comp = self.less(value, zeros)
        return self.select(comp, zeros, prob)
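The pdf rate * exp(-rate * x) is rewritten as exp(log(rate) - rate * x), a single exponentiation of a log-space sum, matching the style of the other hunks. The two agree for rate > 0; a quick NumPy check with illustrative values:

    import numpy as np

    rate, x = 2.0, np.array([0.5, 1.0, 4.0])
    direct = rate * np.exp(-rate * x)
    log_space = np.exp(np.log(rate) - rate * x)
    print(np.allclose(direct, log_space))  # True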
@@ -113,6 +113,7 @@ class Geometric(Distribution):
        self.cast = P.Cast()
        self.const = P.ScalarToArray()
        self.dtypeop = P.DType()
+       self.exp = P.Exp()
        self.fill = P.Fill()
        self.floor = P.Floor()
        self.issubclass = P.IsSubClass()
@@ -205,7 +206,7 @@ class Geometric(Distribution):
            value = self.floor(value)
        else:
            return None
-       pmf = self.pow((1.0 - probs1), value) * probs1
+       pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
        zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
        comp = self.less(value, zeros)
        return self.select(comp, zeros, pmf)
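Likewise, the geometric pmf (1 - p)^k * p becomes exp(k * log(1 - p) + log(p)); the open-interval check on p guarantees both logs are finite, and the two forms agree wherever both are representable. A NumPy sketch:

    import numpy as np

    p, k = 0.3, np.arange(5.0)
    direct = np.power(1.0 - p, k) * p
    log_space = np.exp(k * np.log(1.0 - p) + np.log(p))
    print(np.allclose(direct, log_space))  # True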
@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import convert_to_batch, check_greater_equal_zero, check_type
+from ._utils.utils import convert_to_batch, check_greater_zero, check_type

 class Normal(Distribution):
@@ -106,7 +106,7 @@ class Normal(Distribution):
        if mean is not None and sd is not None:
            self._mean_value = convert_to_batch(mean, self.broadcast_shape, dtype)
            self._sd_value = convert_to_batch(sd, self.broadcast_shape, dtype)
-           check_greater_equal_zero(self._sd_value, "Standard deviation")
+           check_greater_zero(self._sd_value, "Standard deviation")
        else:
            self._mean_value = mean
            self._sd_value = sd
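The standard deviation is now required to be strictly positive rather than non-negative, since sd = 0 makes log(sd) and the division by sd in `_log_prob` undefined. A sketch of such a strict check with NumPy (a hypothetical helper, mirroring what `check_greater_zero` presumably does):

    import numpy as np

    def check_greater_zero(value, name):
        # All entries must be strictly positive
        if not np.all(np.greater(np.asarray(value), 0.0)):
            raise ValueError(f'{name} should be greater than zero.')

    check_greater_zero([0.5, 1.0], "Standard deviation")  # passes
    # check_greater_zero([0.0], "Standard deviation")     # would raise ValueError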
@@ -166,7 +166,7 @@ class Normal(Distribution):
        H(X) = \log(\sqrt{2 \pi e \sigma^2})
        """
        sd = self._sd_value if sd is None else sd
-       return self.log(self.sqrt(np.e * 2. * np.pi * self.sq(sd)))
+       return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd)

    def _cross_entropy(self, dist, mean_b, sd_b, mean_a=None, sd_a=None):
        r"""
@@ -198,7 +198,7 @@ class Normal(Distribution):
        mean = self._mean_value if mean is None else mean
        sd = self._sd_value if sd is None else sd
        unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
-       neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd)))
+       neg_normalization = -1. * self.log(self.sqrt(self.const(2. * np.pi))) - self.log(sd)
        return unnormalized_log_prob + neg_normalization

    def _cdf(self, value, mean=None, sd=None):
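The same identity applied to the normalization term: log(sqrt(2π σ²)) = log(sqrt(2π)) + log(σ), which avoids squaring sd inside the square root. Checked quickly across several scales:

    import numpy as np

    sd = np.array([1e-4, 1.0, 1e4])
    print(np.allclose(np.log(np.sqrt(2.0 * np.pi * sd**2)),
                      np.log(np.sqrt(2.0 * np.pi)) + np.log(sd)))  # True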
@@ -216,8 +216,8 @@ class Uniform(Distribution):
        """
        low = self.low if low is None else low
        high = self.high if high is None else high
-       ones = self.fill(self.dtype, self.shape(value), 1.0)
-       prob = ones / (high - low)
+       neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
+       prob = self.exp(neg_ones * self.log(high - low))
        broadcast_shape = self.shape(prob)
        zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
        comp_lo = self.less(value, low)
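The uniform density 1/(high - low) becomes exp(-log(high - low)), keeping this branch in the same exp/log style as the other distributions. Equivalence for high > low:

    import numpy as np

    low, high = 0.0, 2.5
    print(np.isclose(1.0 / (high - low), np.exp(-np.log(high - low))))  # True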
@@ -28,7 +28,7 @@ def test_arguments():
    """
    b = msd.Bernoulli()
    assert isinstance(b, msd.Distribution)
-   b = msd.Bernoulli([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32)
+   b = msd.Bernoulli([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32)
    assert isinstance(b, msd.Distribution)

 def test_type():
@@ -51,6 +51,10 @@ def test_prob():
        msd.Bernoulli([-0.1], dtype=dtype.int32)
    with pytest.raises(ValueError):
        msd.Bernoulli([1.1], dtype=dtype.int32)
+   with pytest.raises(ValueError):
+       msd.Bernoulli([0.0], dtype=dtype.int32)
+   with pytest.raises(ValueError):
+       msd.Bernoulli([1.0], dtype=dtype.int32)

 class BernoulliProb(nn.Cell):
    """
@@ -29,7 +29,7 @@ def test_arguments():
    """
    g = msd.Geometric()
    assert isinstance(g, msd.Distribution)
-   g = msd.Geometric([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32)
+   g = msd.Geometric([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32)
    assert isinstance(g, msd.Distribution)

 def test_type():
@@ -52,6 +52,10 @@ def test_prob():
        msd.Geometric([-0.1], dtype=dtype.int32)
    with pytest.raises(ValueError):
        msd.Geometric([1.1], dtype=dtype.int32)
+   with pytest.raises(ValueError):
+       msd.Geometric([0.0], dtype=dtype.int32)
+   with pytest.raises(ValueError):
+       msd.Geometric([1.0], dtype=dtype.int32)

 class GeometricProb(nn.Cell):
    """
@@ -42,6 +42,12 @@ def test_seed():
    with pytest.raises(TypeError):
        msd.Normal(0., 1., seed='seed')

+def test_sd():
+   with pytest.raises(ValueError):
+       msd.Normal(0., 0.)
+   with pytest.raises(ValueError):
+       msd.Normal(0., -1.)

 def test_arguments():
    """
    args passing during initialization.