| @@ -26,7 +26,7 @@ class Bijector(Cell): | |||||
| is_constant_jacobian (bool): if the bijector has constant derivative. Default: False. | is_constant_jacobian (bool): if the bijector has constant derivative. Default: False. | ||||
| is_injective (bool): if the bijector is an one-to-one mapping. Default: True. | is_injective (bool): if the bijector is an one-to-one mapping. Default: True. | ||||
| name (str): name of the bijector. Default: None. | name (str): name of the bijector. Default: None. | ||||
| dtype (mstype): type of the distribution the bijector can operate on. Default: None. | |||||
| dtype (mindspore.dtype): type of the distribution the bijector can operate on. Default: None. | |||||
| param (dict): parameters used to initialize the bijector. Default: None. | param (dict): parameters used to initialize the bijector. Default: None. | ||||
| """ | """ | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -110,7 +110,7 @@ class Bijector(Cell): | |||||
| *args: args[0] shall be either a distribution or the name of a bijector function. | *args: args[0] shall be either a distribution or the name of a bijector function. | ||||
| """ | """ | ||||
| if isinstance(args[0], Distribution): | if isinstance(args[0], Distribution): | ||||
| return TransformedDistribution(self, args[0]) | |||||
| return TransformedDistribution(self, args[0], self.distribution.dtype) | |||||
| return super(Bijector, self).__call__(*args, **kwargs) | return super(Bijector, self).__call__(*args, **kwargs) | ||||
| def construct(self, name, *args, **kwargs): | def construct(self, name, *args, **kwargs): | ||||
| @@ -22,7 +22,10 @@ from .bijector import Bijector | |||||
| class Softplus(Bijector): | class Softplus(Bijector): | ||||
| r""" | r""" | ||||
| Softplus Bijector. | Softplus Bijector. | ||||
| This Bijector performs the operation: Y = \frac{\log(1 + e ^ {kX})}{k}, where k is the sharpness factor. | |||||
| This Bijector performs the operation, where k is the sharpness factor. | |||||
| .. math:: | |||||
| Y = \frac{\log(1 + e ^ {kX})}{k} | |||||
| Args: | Args: | ||||
| sharpness (float): scale factor. Default: 1.0. | sharpness (float): scale factor. Default: 1.0. | ||||
| @@ -184,7 +184,7 @@ def check_greater(a, b, name_a, name_b): | |||||
| def check_prob(p): | def check_prob(p): | ||||
| """ | """ | ||||
| Check if p is a proper probability, i.e. 0 <= p <=1. | |||||
| Check if p is a proper probability, i.e. 0 < p <1. | |||||
| Args: | Args: | ||||
| p (Tensor, Parameter): value to be checked. | p (Tensor, Parameter): value to be checked. | ||||
| @@ -196,12 +196,12 @@ def check_prob(p): | |||||
| if not isinstance(p.default_input, Tensor): | if not isinstance(p.default_input, Tensor): | ||||
| return | return | ||||
| p = p.default_input | p = p.default_input | ||||
| comp = np.less(p.asnumpy(), np.zeros(p.shape)) | |||||
| if comp.any(): | |||||
| raise ValueError('Probabilities should be greater than or equal to zero') | |||||
| comp = np.greater(p.asnumpy(), np.ones(p.shape)) | |||||
| if comp.any(): | |||||
| raise ValueError('Probabilities should be less than or equal to one') | |||||
| comp = np.less(np.zeros(p.shape), p.asnumpy()) | |||||
| if not comp.all(): | |||||
| raise ValueError('Probabilities should be greater than zero') | |||||
| comp = np.greater(np.ones(p.shape), p.asnumpy()) | |||||
| if not comp.all(): | |||||
| raise ValueError('Probabilities should be less than one') | |||||
| def logits_to_probs(logits, is_binary=False): | def logits_to_probs(logits, is_binary=False): | ||||
| @@ -110,6 +110,7 @@ class Bernoulli(Distribution): | |||||
| self.const = P.ScalarToArray() | self.const = P.ScalarToArray() | ||||
| self.dtypeop = P.DType() | self.dtypeop = P.DType() | ||||
| self.erf = P.Erf() | self.erf = P.Erf() | ||||
| self.exp = P.Exp() | |||||
| self.fill = P.Fill() | self.fill = P.Fill() | ||||
| self.log = P.Log() | self.log = P.Log() | ||||
| self.less = P.Less() | self.less = P.Less() | ||||
| @@ -159,7 +160,7 @@ class Bernoulli(Distribution): | |||||
| """ | """ | ||||
| probs1 = self.probs if probs1 is None else probs1 | probs1 = self.probs if probs1 is None else probs1 | ||||
| probs0 = 1.0 - probs1 | probs0 = 1.0 - probs1 | ||||
| return probs0 * probs1 | |||||
| return self.exp(self.log(probs0) + self.log(probs1)) | |||||
| def _entropy(self, probs=None): | def _entropy(self, probs=None): | ||||
| r""" | r""" | ||||
| @@ -183,7 +184,7 @@ class Bernoulli(Distribution): | |||||
| return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a) | return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a) | ||||
| return None | return None | ||||
| def _prob(self, value, probs=None): | |||||
| def _log_prob(self, value, probs=None): | |||||
| r""" | r""" | ||||
| pmf of Bernoulli distribution. | pmf of Bernoulli distribution. | ||||
| @@ -197,7 +198,7 @@ class Bernoulli(Distribution): | |||||
| """ | """ | ||||
| probs1 = self.probs if probs is None else probs | probs1 = self.probs if probs is None else probs | ||||
| probs0 = 1.0 - probs1 | probs0 = 1.0 - probs1 | ||||
| return (probs1 * value) + (probs0 * (1.0 - value)) | |||||
| return self.log(probs1) * value + self.log(probs0) * (1.0 - value) | |||||
| def _cdf(self, value, probs=None): | def _cdf(self, value, probs=None): | ||||
| r""" | r""" | ||||
| @@ -15,6 +15,7 @@ | |||||
| """basic""" | """basic""" | ||||
| from mindspore.nn.cell import Cell | from mindspore.nn.cell import Cell | ||||
| from mindspore._checkparam import Validator as validator | from mindspore._checkparam import Validator as validator | ||||
| from mindspore._checkparam import Rel | |||||
| from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param | from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param | ||||
| class Distribution(Cell): | class Distribution(Cell): | ||||
| @@ -28,12 +29,15 @@ class Distribution(Cell): | |||||
| Note: | Note: | ||||
| Derived class should override operations such as ,_mean, _prob, | Derived class should override operations such as ,_mean, _prob, | ||||
| and _log_prob. Arguments should be passed in through *args or **kwargs. | |||||
| and _log_prob. Required arguments, such as value for _prob, | |||||
| should be passed in through args or kwargs. dist_spec_args which specify | |||||
| a new distribution are optional. | |||||
| Dist_spec_args are unique for each type of distribution. For example, mean and sd | |||||
| are the dist_spec_args for a Normal distribution. | |||||
| dist_spec_args are unique for each type of distribution. For example, mean and sd | |||||
| are the dist_spec_args for a Normal distribution, while rate is the dist_spec_args | |||||
| for exponential distribution. | |||||
| For all functions, passing in dist_spec_args, are optional. | |||||
| For all functions, passing in dist_spec_args, is optional. | |||||
| Passing in the additional dist_spec_args will make the result to be evaluated with | Passing in the additional dist_spec_args will make the result to be evaluated with | ||||
| new distribution specified by the dist_spec_args. But it won't change the | new distribution specified by the dist_spec_args. But it won't change the | ||||
| original distribuion. | original distribuion. | ||||
| @@ -49,7 +53,7 @@ class Distribution(Cell): | |||||
| """ | """ | ||||
| super(Distribution, self).__init__() | super(Distribution, self).__init__() | ||||
| validator.check_value_type('name', name, [str], 'distribution_name') | validator.check_value_type('name', name, [str], 'distribution_name') | ||||
| validator.check_value_type('seed', seed, [int], name) | |||||
| validator.check_integer('seed', seed, 0, Rel.GE, name) | |||||
| self._name = name | self._name = name | ||||
| self._seed = seed | self._seed = seed | ||||
| @@ -191,7 +195,7 @@ class Distribution(Cell): | |||||
| Note: | Note: | ||||
| Args must include value. | Args must include value. | ||||
| Dist_spec_args are optional. | |||||
| dist_spec_args are optional. | |||||
| """ | """ | ||||
| return self._call_log_prob(*args, **kwargs) | return self._call_log_prob(*args, **kwargs) | ||||
| @@ -210,7 +214,7 @@ class Distribution(Cell): | |||||
| Note: | Note: | ||||
| Args must include value. | Args must include value. | ||||
| Dist_spec_args are optional. | |||||
| dist_spec_args are optional. | |||||
| """ | """ | ||||
| return self._call_prob(*args, **kwargs) | return self._call_prob(*args, **kwargs) | ||||
| @@ -229,7 +233,7 @@ class Distribution(Cell): | |||||
| Note: | Note: | ||||
| Args must include value. | Args must include value. | ||||
| Dist_spec_args are optional. | |||||
| dist_spec_args are optional. | |||||
| """ | """ | ||||
| return self._call_cdf(*args, **kwargs) | return self._call_cdf(*args, **kwargs) | ||||
| @@ -266,7 +270,7 @@ class Distribution(Cell): | |||||
| Note: | Note: | ||||
| Args must include value. | Args must include value. | ||||
| Dist_spec_args are optional. | |||||
| dist_spec_args are optional. | |||||
| """ | """ | ||||
| return self._call_log_cdf(*args, **kwargs) | return self._call_log_cdf(*args, **kwargs) | ||||
| @@ -285,7 +289,7 @@ class Distribution(Cell): | |||||
| Note: | Note: | ||||
| Args must include value. | Args must include value. | ||||
| Dist_spec_args are optional. | |||||
| dist_spec_args are optional. | |||||
| """ | """ | ||||
| return self._call_survival(*args, **kwargs) | return self._call_survival(*args, **kwargs) | ||||
| @@ -313,7 +317,7 @@ class Distribution(Cell): | |||||
| Note: | Note: | ||||
| Args must include value. | Args must include value. | ||||
| Dist_spec_args are optional. | |||||
| dist_spec_args are optional. | |||||
| """ | """ | ||||
| return self._call_log_survival(*args, **kwargs) | return self._call_log_survival(*args, **kwargs) | ||||
| @@ -341,7 +345,7 @@ class Distribution(Cell): | |||||
| Evaluate the mean. | Evaluate the mean. | ||||
| Note: | Note: | ||||
| Dist_spec_args are optional. | |||||
| dist_spec_args are optional. | |||||
| """ | """ | ||||
| return self._mean(*args, **kwargs) | return self._mean(*args, **kwargs) | ||||
| @@ -350,7 +354,7 @@ class Distribution(Cell): | |||||
| Evaluate the mode. | Evaluate the mode. | ||||
| Note: | Note: | ||||
| Dist_spec_args are optional. | |||||
| dist_spec_args are optional. | |||||
| """ | """ | ||||
| return self._mode(*args, **kwargs) | return self._mode(*args, **kwargs) | ||||
| @@ -359,7 +363,7 @@ class Distribution(Cell): | |||||
| Evaluate the standard deviation. | Evaluate the standard deviation. | ||||
| Note: | Note: | ||||
| Dist_spec_args are optional. | |||||
| dist_spec_args are optional. | |||||
| """ | """ | ||||
| return self._call_sd(*args, **kwargs) | return self._call_sd(*args, **kwargs) | ||||
| @@ -368,7 +372,7 @@ class Distribution(Cell): | |||||
| Evaluate the variance. | Evaluate the variance. | ||||
| Note: | Note: | ||||
| Dist_spec_args are optional. | |||||
| dist_spec_args are optional. | |||||
| """ | """ | ||||
| return self._call_var(*args, **kwargs) | return self._call_var(*args, **kwargs) | ||||
| @@ -395,7 +399,7 @@ class Distribution(Cell): | |||||
| Evaluate the entropy. | Evaluate the entropy. | ||||
| Note: | Note: | ||||
| Dist_spec_args are optional. | |||||
| dist_spec_args are optional. | |||||
| """ | """ | ||||
| return self._entropy(*args, **kwargs) | return self._entropy(*args, **kwargs) | ||||
| @@ -424,7 +428,7 @@ class Distribution(Cell): | |||||
| Note: | Note: | ||||
| Shape of the sample is default to (). | Shape of the sample is default to (). | ||||
| Dist_spec_args are optional. | |||||
| dist_spec_args are optional. | |||||
| """ | """ | ||||
| return self._sample(*args, **kwargs) | return self._sample(*args, **kwargs) | ||||
| @@ -199,7 +199,7 @@ class Exponential(Distribution): | |||||
| pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0 | pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0 | ||||
| """ | """ | ||||
| rate = self.rate if rate is None else rate | rate = self.rate if rate is None else rate | ||||
| prob = rate * self.exp(-1. * rate * value) | |||||
| prob = self.exp(self.log(rate) - rate * value) | |||||
| zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0) | zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0) | ||||
| comp = self.less(value, zeros) | comp = self.less(value, zeros) | ||||
| return self.select(comp, zeros, prob) | return self.select(comp, zeros, prob) | ||||
| @@ -113,6 +113,7 @@ class Geometric(Distribution): | |||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| self.const = P.ScalarToArray() | self.const = P.ScalarToArray() | ||||
| self.dtypeop = P.DType() | self.dtypeop = P.DType() | ||||
| self.exp = P.Exp() | |||||
| self.fill = P.Fill() | self.fill = P.Fill() | ||||
| self.floor = P.Floor() | self.floor = P.Floor() | ||||
| self.issubclass = P.IsSubClass() | self.issubclass = P.IsSubClass() | ||||
| @@ -205,7 +206,7 @@ class Geometric(Distribution): | |||||
| value = self.floor(value) | value = self.floor(value) | ||||
| else: | else: | ||||
| return None | return None | ||||
| pmf = self.pow((1.0 - probs1), value) * probs1 | |||||
| pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1)) | |||||
| zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0) | zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0) | ||||
| comp = self.less(value, zeros) | comp = self.less(value, zeros) | ||||
| return self.select(comp, zeros, pmf) | return self.select(comp, zeros, pmf) | ||||
| @@ -18,7 +18,7 @@ from mindspore.ops import operations as P | |||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from .distribution import Distribution | from .distribution import Distribution | ||||
| from ._utils.utils import convert_to_batch, check_greater_equal_zero, check_type | |||||
| from ._utils.utils import convert_to_batch, check_greater_zero, check_type | |||||
| class Normal(Distribution): | class Normal(Distribution): | ||||
| @@ -106,7 +106,7 @@ class Normal(Distribution): | |||||
| if mean is not None and sd is not None: | if mean is not None and sd is not None: | ||||
| self._mean_value = convert_to_batch(mean, self.broadcast_shape, dtype) | self._mean_value = convert_to_batch(mean, self.broadcast_shape, dtype) | ||||
| self._sd_value = convert_to_batch(sd, self.broadcast_shape, dtype) | self._sd_value = convert_to_batch(sd, self.broadcast_shape, dtype) | ||||
| check_greater_equal_zero(self._sd_value, "Standard deviation") | |||||
| check_greater_zero(self._sd_value, "Standard deviation") | |||||
| else: | else: | ||||
| self._mean_value = mean | self._mean_value = mean | ||||
| self._sd_value = sd | self._sd_value = sd | ||||
| @@ -166,7 +166,7 @@ class Normal(Distribution): | |||||
| H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma))) | H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma))) | ||||
| """ | """ | ||||
| sd = self._sd_value if sd is None else sd | sd = self._sd_value if sd is None else sd | ||||
| return self.log(self.sqrt(np.e * 2. * np.pi * self.sq(sd))) | |||||
| return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd) | |||||
| def _cross_entropy(self, dist, mean_b, sd_b, mean_a=None, sd_a=None): | def _cross_entropy(self, dist, mean_b, sd_b, mean_a=None, sd_a=None): | ||||
| r""" | r""" | ||||
| @@ -198,7 +198,7 @@ class Normal(Distribution): | |||||
| mean = self._mean_value if mean is None else mean | mean = self._mean_value if mean is None else mean | ||||
| sd = self._sd_value if sd is None else sd | sd = self._sd_value if sd is None else sd | ||||
| unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd)) | unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd)) | ||||
| neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd))) | |||||
| neg_normalization = -1. * self.log(self.sqrt(self.const(2. * np.pi))) - self.log(sd) | |||||
| return unnormalized_log_prob + neg_normalization | return unnormalized_log_prob + neg_normalization | ||||
| def _cdf(self, value, mean=None, sd=None): | def _cdf(self, value, mean=None, sd=None): | ||||
| @@ -216,8 +216,8 @@ class Uniform(Distribution): | |||||
| """ | """ | ||||
| low = self.low if low is None else low | low = self.low if low is None else low | ||||
| high = self.high if high is None else high | high = self.high if high is None else high | ||||
| ones = self.fill(self.dtype, self.shape(value), 1.0) | |||||
| prob = ones / (high - low) | |||||
| neg_ones = self.fill(self.dtype, self.shape(value), -1.0) | |||||
| prob = self.exp(neg_ones * self.log(high - low)) | |||||
| broadcast_shape = self.shape(prob) | broadcast_shape = self.shape(prob) | ||||
| zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) | zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) | ||||
| comp_lo = self.less(value, low) | comp_lo = self.less(value, low) | ||||
| @@ -28,7 +28,7 @@ def test_arguments(): | |||||
| """ | """ | ||||
| b = msd.Bernoulli() | b = msd.Bernoulli() | ||||
| assert isinstance(b, msd.Distribution) | assert isinstance(b, msd.Distribution) | ||||
| b = msd.Bernoulli([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32) | |||||
| b = msd.Bernoulli([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32) | |||||
| assert isinstance(b, msd.Distribution) | assert isinstance(b, msd.Distribution) | ||||
| def test_type(): | def test_type(): | ||||
| @@ -51,6 +51,10 @@ def test_prob(): | |||||
| msd.Bernoulli([-0.1], dtype=dtype.int32) | msd.Bernoulli([-0.1], dtype=dtype.int32) | ||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| msd.Bernoulli([1.1], dtype=dtype.int32) | msd.Bernoulli([1.1], dtype=dtype.int32) | ||||
| with pytest.raises(ValueError): | |||||
| msd.Bernoulli([0.0], dtype=dtype.int32) | |||||
| with pytest.raises(ValueError): | |||||
| msd.Bernoulli([1.0], dtype=dtype.int32) | |||||
| class BernoulliProb(nn.Cell): | class BernoulliProb(nn.Cell): | ||||
| """ | """ | ||||
| @@ -29,7 +29,7 @@ def test_arguments(): | |||||
| """ | """ | ||||
| g = msd.Geometric() | g = msd.Geometric() | ||||
| assert isinstance(g, msd.Distribution) | assert isinstance(g, msd.Distribution) | ||||
| g = msd.Geometric([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32) | |||||
| g = msd.Geometric([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32) | |||||
| assert isinstance(g, msd.Distribution) | assert isinstance(g, msd.Distribution) | ||||
| def test_type(): | def test_type(): | ||||
| @@ -52,6 +52,10 @@ def test_prob(): | |||||
| msd.Geometric([-0.1], dtype=dtype.int32) | msd.Geometric([-0.1], dtype=dtype.int32) | ||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| msd.Geometric([1.1], dtype=dtype.int32) | msd.Geometric([1.1], dtype=dtype.int32) | ||||
| with pytest.raises(ValueError): | |||||
| msd.Geometric([0.0], dtype=dtype.int32) | |||||
| with pytest.raises(ValueError): | |||||
| msd.Geometric([1.0], dtype=dtype.int32) | |||||
| class GeometricProb(nn.Cell): | class GeometricProb(nn.Cell): | ||||
| """ | """ | ||||
| @@ -42,6 +42,12 @@ def test_seed(): | |||||
| with pytest.raises(TypeError): | with pytest.raises(TypeError): | ||||
| msd.Normal(0., 1., seed='seed') | msd.Normal(0., 1., seed='seed') | ||||
| def test_sd(): | |||||
| with pytest.raises(ValueError): | |||||
| msd.Normal(0., 0.) | |||||
| with pytest.raises(ValueError): | |||||
| msd.Normal(0., -1.) | |||||
| def test_arguments(): | def test_arguments(): | ||||
| """ | """ | ||||
| args passing during initialization. | args passing during initialization. | ||||