
!9136 Fix issues in distribution class

From: @shallydeng
Reviewed-by: @sunnybeike
Signed-off-by: @sunnybeike
Tag: v1.1.0
Committed by: mindspore-ci-bot (Gitee), 5 years ago
Parent commit: dd64eadd75
7 changed files with 437 additions and 321 deletions:

  1. +6   -3   mindspore/nn/probability/distribution/_utils/utils.py
  2. +91  -64  mindspore/nn/probability/distribution/categorical.py
  3. +89  -75  mindspore/nn/probability/distribution/cauchy.py
  4. +43  -48  mindspore/nn/probability/distribution/gumbel.py
  5. +142 -77  mindspore/nn/probability/distribution/log_normal.py
  6. +65  -53  mindspore/nn/probability/distribution/logistic.py
  7. +1   -1   mindspore/nn/probability/distribution/uniform.py

+6 -3  mindspore/nn/probability/distribution/_utils/utils.py

@@ -168,9 +168,12 @@ def check_sum_equal_one(probs):
         if not isinstance(probs.data, Tensor):
             return
         probs = probs.data
-    prob_sum = np.sum(probs.asnumpy(), axis=-1)
-    comp = np.equal(np.ones(prob_sum.shape), prob_sum)
-    if not comp.all():
+    if isinstance(probs, Tensor):
+        probs = probs.asnumpy()
+    prob_sum = np.sum(probs, axis=-1)
+    # add a small tolerance here to increase numerical stability
+    comp = np.allclose(prob_sum, np.ones(prob_sum.shape), rtol=1e-14, atol=1e-14)
+    if not comp:
         raise ValueError('Probabilities for each category should sum to one for Categorical distribution.')
 
 def check_rank(probs):

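The tolerance change above is easiest to see with plain NumPy (an editor's sketch, not part of the patch): probabilities that are perfectly valid in floating point routinely miss exact equality with 1.0 by one ulp, so the old `np.equal` check rejected them.

# Editor's NumPy sketch (not patch code) of why exact equality was replaced.
import numpy as np

probs = np.array([1/3, 1/3, 1/3])
prob_sum = np.sum(probs, axis=-1)   # 0.9999999999999999, one ulp short of 1.0

# old check: exact equality fails, so a valid input raised ValueError
print(np.equal(np.ones(prob_sum.shape), prob_sum).all())                        # False

# new check: the tiny tolerance from the patch accepts it
print(np.allclose(prob_sum, np.ones(prob_sum.shape), rtol=1e-14, atol=1e-14))   # True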

+91 -64  mindspore/nn/probability/distribution/categorical.py

@@ -40,65 +40,84 @@ class Categorical(Distribution):
     `probs` must have rank at least 1, values are proper probabilities and sum to 1.
 
     Examples:
-        >>> # To initialize a Categorical distribution of probs [0.5, 0.5]
+        >>> import mindspore
         >>> import mindspore.nn as nn
         >>> import mindspore.nn.probability.distribution as msd
-        >>> b = msd.Categorical(probs = [0.5, 0.5], dtype=mstype.int32)
-        >>>
-        >>> # To use a Categorical distribution in a network
-        >>> class net(Cell):
-        ...     def __init__(self, probs):
-        ...         super(net, self).__init__():
-        ...         self.ca = msd.Categorical(probs=[0.2, 0.8], dtype=mstype.int32)
-        ...         self.ca1 = msd.Categorical(dtype=mstype.int32)
-        ...
-        ...     # All the following calls in construct are valid
-        ...     def construct(self, value):
-        ...
-        ...         # Private interfaces of probability functions corresponding to public interfaces, including
-        ...         # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, are the same as follows.
-        ...         # Args:
-        ...         #     value (Tensor): the value to be evaluated.
-        ...         #     probs (Tensor): event probabilities. Default: self.probs.
-        ...
-        ...         # Examples of `prob`.
-        ...         # Similar calls can be made to other probability functions
-        ...         # by replacing `prob` by the name of the function.
-        ...         ans = self.ca.prob(value)
-        ...         # Evaluate `prob` with respect to distribution b.
-        ...         ans = self.ca.prob(value, probs_b)
-        ...         # `probs` must be passed in during function calls.
-        ...         ans = self.ca1.prob(value, probs_a)
-        ...
-        ...         # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments.
-        ...         # Args:
-        ...         #     probs (Tensor): event probabilities. Default: self.probs.
-        ...
-        ...         # Examples of `mean`. `sd`, `var`, and `entropy` are similar.
-        ...         ans = self.ca.mean() # return 0.8
-        ...         ans = self.ca.mean(probs_b)
-        ...         # `probs` must be passed in during function calls.
-        ...         ans = self.ca1.mean(probs_a)
-        ...
-        ...         # Interfaces of `kl_loss` and `cross_entropy` are the same as follows:
-        ...         # Args:
-        ...         #     dist (str): the name of the distribution. Only 'Categorical' is supported.
-        ...         #     probs_b (Tensor): event probabilities of distribution b.
-        ...         #     probs (Tensor): event probabilities of distribution a. Default: self.probs.
-        ...
-        ...         # Examples of kl_loss. `cross_entropy` is similar.
-        ...         ans = self.ca.kl_loss('Categorical', probs_b)
-        ...         ans = self.ca.kl_loss('Categorical', probs_b, probs_a)
-        ...         # An additional `probs` must be passed in.
-        ...         ans = self.ca1.kl_loss('Categorical', probs_b, probs_a)
-        ...
-        ...         # Examples of `sample`.
-        ...         # Args:
-        ...         #     shape (tuple): the shape of the sample. Default: ().
-        ...         #     probs (Tensor): event probabilities. Default: self.probs.
-        ...         ans = self.ca.sample()
-        ...         ans = self.ca.sample((2,3))
-        ...         ans = self.ca.sample((2,3), probs_b)
-        ...         ans = self.ca1.sample((2,3), probs_a)
+        >>> from mindspore import Tensor
+        >>> # To initialize a Categorical distribution of probs [0.5, 0.5]
+        >>> ca1 = msd.Categorical(probs=[0.2, 0.8], dtype=mindspore.int32)
+        >>> # A Categorical distribution can be initialized without arguments.
+        >>> # In this case, `probs` must be passed in through arguments during function calls.
+        >>> ca2 = msd.Categorical(dtype=mindspore.int32)
+        >>> # Here are some tensors used below for testing
+        >>> value = Tensor([1, 0], dtype=mindspore.int32)
+        >>> probs_a = Tensor([0.5, 0.5], dtype=mindspore.float32)
+        >>> probs_b = Tensor([0.35, 0.65], dtype=mindspore.float32)
+        >>> # Private interfaces of probability functions corresponding to public interfaces, including
+        >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, are the same as follows.
+        >>> # Args:
+        >>> #     value (Tensor): the value to be evaluated.
+        >>> #     probs (Tensor): event probabilities. Default: self.probs.
+        >>> # Examples of `prob`.
+        >>> # Similar calls can be made to other probability functions
+        >>> # by replacing `prob` by the name of the function.
+        >>> ans = ca1.prob(value)
+        >>> print(ans)
+        [0.8 0.2]
+        >>> # Evaluate `prob` with respect to distribution b.
+        >>> ans = ca1.prob(value, probs_b)
+        >>> print(ans)
+        [0.65 0.35]
+        >>> # `probs` must be passed in during function calls.
+        >>> ans = ca2.prob(value, probs_a)
+        >>> print(ans)
+        [0.5 0.5]
+        >>> # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments.
+        >>> # Args:
+        >>> #     probs (Tensor): event probabilities. Default: self.probs.
+        >>> # Examples of `mean`. `sd`, `var`, and `entropy` are similar.
+        >>> ans = ca1.mean() # return 0.8
+        >>> print(ans)
+        [0.8]
+        >>> ans = ca1.mean(probs_b)
+        >>> print(ans)
+        [0.65]
+        >>> # `probs` must be passed in during function calls.
+        >>> ans = ca2.mean(probs_a)
+        >>> print(ans)
+        [0.5]
+        >>> # Interfaces of `kl_loss` and `cross_entropy` are the same as follows:
+        >>> # Args:
+        >>> #     dist (str): the name of the distribution. Only 'Categorical' is supported.
+        >>> #     probs_b (Tensor): event probabilities of distribution b.
+        >>> #     probs (Tensor): event probabilities of distribution a. Default: self.probs.
+        >>> # Examples of kl_loss. `cross_entropy` is similar.
+        >>> ans = ca1.kl_loss('Categorical', probs_b)
+        >>> print(ans)
+        0.05418826
+        >>> ans = ca1.kl_loss('Categorical', probs_b, probs_a)
+        >>> print(ans)
+        0.04715523
+        >>> # An additional `probs` must be passed in.
+        >>> ans = ca2.kl_loss('Categorical', probs_b, probs_a)
+        >>> print(ans)
+        0.04715523
+        >>> # Examples of `sample`.
+        >>> # Args:
+        >>> #     shape (tuple): the shape of the sample. Default: ().
+        >>> #     probs (Tensor): event probabilities. Default: self.probs.
+        >>> ans = ca1.sample()
+        >>> print(ans.shape)
+        ()
+        >>> ans = ca1.sample((2,3))
+        >>> print(ans.shape)
+        (2, 3)
+        >>> ans = ca1.sample((2,3), probs_b)
+        >>> print(ans.shape)
+        (2, 3)
+        >>> ans = ca2.sample((2,3), probs_a)
+        >>> print(ans.shape)
+        (2, 3)
     """
 
     def __init__(self,
@@ -108,7 +127,7 @@ class Categorical(Distribution):
                  name="Categorical"):
         param = dict(locals())
         param['param_dict'] = {'probs': probs}
-        valid_dtype = mstype.int_type + mstype.float_type
+        valid_dtype = mstype.uint_type + mstype.int_type + mstype.float_type
         Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
         super(Categorical, self).__init__(seed, dtype, name, param)
@@ -116,7 +135,7 @@ class Categorical(Distribution):
         if self.probs is not None:
             check_rank(self.probs)
             check_prob(self.probs)
-            check_sum_equal_one(self.probs)
+            check_sum_equal_one(probs)
 
             # update is_scalar_batch and broadcast_shape
             # drop one dimension
@@ -124,7 +143,7 @@ class Categorical(Distribution):
             self._is_scalar_batch = True
             self._broadcast_shape = self._broadcast_shape[:-1]
 
-        self.argmax = P.Argmax()
+        self.argmax = P.ArgMaxWithValue(axis=-1)
         self.broadcast = broadcast_to
         self.cast = P.Cast()
         self.clip_by_value = C.clip_by_value
@@ -140,6 +159,7 @@ class Categorical(Distribution):
         self.log = log_generic
         self.log_softmax = P.LogSoftmax()
         self.logicor = P.LogicalOr()
+        self.logicand = P.LogicalAnd()
         self.multinomial = P.Multinomial(seed=self.seed)
         self.reshape = P.Reshape()
         self.reduce_sum = P.ReduceSum(keep_dims=True)
@@ -192,8 +212,9 @@ class Categorical(Distribution):
 
     def _mode(self, probs=None):
         probs = self._check_param_type(probs)
-        mode = self.cast(self.argmax(probs), self.dtype)
-        return self.squeeze(mode)
+        index, _ = self.argmax(probs)
+        mode = self.cast(index, self.dtype)
+        return mode
 
     def _var(self, probs=None):
         r"""
@@ -232,7 +253,7 @@ class Categorical(Distribution):
         probs_a = self._check_param_type(probs)
         logits_a = self.log(probs_a)
         logits_b = self.log(probs_b)
-        return self.squeeze(-self.reduce_sum(
+        return self.squeeze(self.reduce_sum(
             self.softmax(logits_a) * (self.log_softmax(logits_a) - (self.log_softmax(logits_b))), -1))
 
     def _cross_entropy(self, dist, probs_b, probs=None):
@@ -256,10 +277,13 @@ class Categorical(Distribution):
             probs (Tensor): Event probabilities. Default: self.probs.
         """
         value = self._check_value(value, 'value')
+        # cast value to int to find the right integer to compute the index
+        if self.issubclass(self.dtype, mstype.float_):
+            value = self.cast(value, self.index_type)
+        else:
+            value = self.cast(value, self.dtype)
+        # cast int to float for the broadcasting below
         value = self.cast(value, mstype.float32)
         probs = self._check_param_type(probs)
         logits = self.log(probs)
@@ -289,6 +313,9 @@ class Categorical(Distribution):
         value = self.reshape(value, (-1, 1))
         out_of_bound = self.squeeze_last_axis(self.logicor(\
             self.less(value, 0.0), self.less(num_classes-1, value)))
+        # deal with the case where there is only one class
+        zeros = self.fill(mstype.float32, self.shape(out_of_bound), 0.0)
+        out_of_bound = self.logicand(out_of_bound, self.less(zeros, num_classes-1))
         value_clipped = self.clip_by_value(value, 0.0, num_classes - 1)
         value_clipped = self.cast(value_clipped, self.index_type)
         # create index from 0 ... NumOfLabels

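The `_kl_loss` hunk above is a sign fix worth double-checking: categorical KL divergence is non-negative, so the old leading minus in `-self.reduce_sum(...)` negated a correct value. A standalone NumPy check (editor's sketch, not patch code) reproduces the docstring output:

# KL(a||b) = sum_i p_a[i] * (log p_a[i] - log p_b[i]) >= 0
import numpy as np

probs_a = np.array([0.2, 0.8])
probs_b = np.array([0.35, 0.65])
kl = np.sum(probs_a * (np.log(probs_a) - np.log(probs_b)), axis=-1)
print(kl)        # ~0.05418826, matching ca1.kl_loss('Categorical', probs_b) above
print(kl >= 0)   # True; the pre-patch code would have returned -0.05418826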

+89 -75  mindspore/nn/probability/distribution/cauchy.py

@@ -41,77 +41,91 @@ class Cauchy(Distribution):
     Cauchy distribution is not supported on GPU backend.
 
     Examples:
-        >>> # To initialize a Cauchy distribution of loc 3.0 and scale 4.0.
+        >>> import mindspore
+        >>> import mindspore.nn as nn
         >>> import mindspore.nn.probability.distribution as msd
-        >>> cauchy = msd.Cauchy(3.0, 4.0, dtype=mstype.float32)
-        >>>
-        >>> # The following creates two independent Cauchy distributions.
-        >>> cauchy = msd.Cauchy([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
-        >>>
+        >>> from mindspore import Tensor
+        >>> # To initialize a Cauchy distribution of loc 3.0 and scale 4.0.
+        >>> cauchy1 = msd.Cauchy(3.0, 4.0, dtype=mindspore.float32)
         >>> # A Cauchy distribution can be initialized without arguments.
         >>> # In this case, 'loc' and `scale` must be passed in through arguments.
-        >>> cauchy = msd.Cauchy(dtype=mstype.float32)
-        >>>
-        >>> # To use a Cauchy distribution in a network.
-        >>> class net(Cell):
-        ...     def __init__(self):
-        ...         super(net, self).__init__():
-        ...         self.cau1 = msd.Cauchy(0.0, 1.0, dtype=mstype.float32)
-        ...         self.cau2 = msd.Cauchy(dtype=mstype.float32)
-        ...
-        ...     # The following calls are valid in construct.
-        ...     def construct(self, value, loc_b, scale_b, loc_a, scale_a):
-        ...
-        ...         # Private interfaces of probability functions corresponding to public interfaces, including
-        ...         # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same arguments as follows.
-        ...         # Args:
-        ...         #     value (Tensor): the value to be evaluated.
-        ...         #     loc (Tensor): the location of the distribution. Default: self.loc.
-        ...         #     scale (Tensor): the scale of the distribution. Default: self.scale.
-        ...
-        ...         # Examples of `prob`.
-        ...         # Similar calls can be made to other probability functions
-        ...         # by replacing 'prob' by the name of the function
-        ...         ans = self.cau1.prob(value)
-        ...         # Evaluate with respect to distribution b.
-        ...         ans = self.cau1.prob(value, loc_b, scale_b)
-        ...         # `loc` and `scale` must be passed in during function calls
-        ...         ans = self.cau2.prob(value, loc_a, scale_a)
-        ...
-        ...         # Functions `mode` and `entropy` have the same arguments.
-        ...         # Args:
-        ...         #     loc (Tensor): the location of the distribution. Default: self.loc.
-        ...         #     scale (Tensor): the scale of the distribution. Default: self.scale.
-        ...
-        ...         # Example of `mode`.
-        ...         ans = self.cau1.mode() # return 0.0
-        ...         ans = self.cau1.mode(loc_b, scale_b) # return loc_b
-        ...         # `loc` and `scale` must be passed in during function calls.
-        ...         ans = self.cau2.mode(loc_a, scale_a)
-        ...
-        ...         # Interfaces of 'kl_loss' and 'cross_entropy' are the same:
-        ...         # Args:
-        ...         #     dist (str): the type of the distributions. Only "Cauchy" is supported.
-        ...         #     loc_b (Tensor): the loc of distribution b.
-        ...         #     scale_b (Tensor): the scale distribution b.
-        ...         #     loc (Tensor): the loc of distribution a. Default: self.loc.
-        ...         #     scale (Tensor): the scale distribution a. Default: self.scale.
-        ...
-        ...         # Examples of `kl_loss`. `cross_entropy` is similar.
-        ...         ans = self.cau1.kl_loss('Cauchy', loc_b, scale_b)
-        ...         ans = self.cau1.kl_loss('Cauchy', loc_b, scale_b, loc_a, scale_a)
-        ...         # Additional `loc` and `scale` must be passed in.
-        ...         ans = self.cau2.kl_loss('Cauchy', loc_b, scale_b, loc_a, scale_a)
-        ...
-        ...         # Examples of `sample`.
-        ...         # Args:
-        ...         #     shape (tuple): the shape of the sample. Default: ()
-        ...         #     loc (Tensor): the location of the distribution. Default: self.loc.
-        ...         #     scale (Tensor): the scale of the distribution. Default: self.scale.
-        ...         ans = self.cau1.sample()
-        ...         ans = self.cau1.sample((2,3))
-        ...         ans = self.cau1.sample((2,3), loc_b, s_b)
-        ...         ans = self.cau2.sample((2,3), loc_a, s_a)
+        >>> cauchy2 = msd.Cauchy(dtype=mindspore.float32)
+        >>> # Here are some tensors used below for testing
+        >>> value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32)
+        >>> loc_a = Tensor([2.0], dtype=mindspore.float32)
+        >>> scale_a = Tensor([2.0, 2.0, 2.0], dtype=mindspore.float32)
+        >>> loc_b = Tensor([1.0], dtype=mindspore.float32)
+        >>> scale_b = Tensor([1.0, 1.5, 2.0], dtype=mindspore.float32)
+        >>> # Private interfaces of probability functions corresponding to public interfaces, including
+        >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same arguments as follows.
+        >>> # Args:
+        >>> #     value (Tensor): the value to be evaluated.
+        >>> #     loc (Tensor): the location of the distribution. Default: self.loc.
+        >>> #     scale (Tensor): the scale of the distribution. Default: self.scale.
+        >>> # Examples of `prob`.
+        >>> # Similar calls can be made to other probability functions
+        >>> # by replacing 'prob' by the name of the function
+        >>> ans = cauchy1.prob(value)
+        >>> print(ans)
+        [0.06366198 0.07489645 0.07957747]
+        >>> # Evaluate with respect to distribution b.
+        >>> ans = cauchy1.prob(value, loc_b, scale_b)
+        >>> print(ans)
+        [0.31830987 0.14691226 0.07957747]
+        >>> # `loc` and `scale` must be passed in during function calls
+        >>> ans = cauchy2.prob(value, loc_a, scale_a)
+        >>> print(ans)
+        [0.12732396 0.15915494 0.12732396]
+        >>> # Functions `mode` and `entropy` have the same arguments.
+        >>> # Args:
+        >>> #     loc (Tensor): the location of the distribution. Default: self.loc.
+        >>> #     scale (Tensor): the scale of the distribution. Default: self.scale.
+        >>> # Example of `mode`.
+        >>> ans = cauchy1.mode() # return 3.0
+        >>> print(ans)
+        3.0
+        >>> ans = cauchy1.mode(loc_b, scale_b) # return loc_b
+        >>> print(ans)
+        [1. 1. 1.]
+        >>> # `loc` and `scale` must be passed in during function calls.
+        >>> ans = cauchy2.mode(loc_a, scale_a)
+        >>> print(ans)
+        [2. 2. 2.]
+        >>> # Interfaces of 'kl_loss' and 'cross_entropy' are the same:
+        >>> # Args:
+        >>> #     dist (str): the type of the distributions. Only "Cauchy" is supported.
+        >>> #     loc_b (Tensor): the loc of distribution b.
+        >>> #     scale_b (Tensor): the scale of distribution b.
+        >>> #     loc (Tensor): the loc of distribution a. Default: self.loc.
+        >>> #     scale (Tensor): the scale of distribution a. Default: self.scale.
+        >>> # Examples of `kl_loss`. `cross_entropy` is similar.
+        >>> ans = cauchy1.kl_loss('Cauchy', loc_b, scale_b)
+        >>> print(ans)
+        [0.594707   0.35563278 0.22314358]
+        >>> ans = cauchy1.kl_loss('Cauchy', loc_b, scale_b, loc_a, scale_a)
+        >>> print(ans)
+        [0.22314358 0.09909081 0.0606246 ]
+        >>> # Additional `loc` and `scale` must be passed in.
+        >>> ans = cauchy2.kl_loss('Cauchy', loc_b, scale_b, loc_a, scale_a)
+        >>> print(ans)
+        [0.22314358 0.09909081 0.0606246 ]
+        >>> # Examples of `sample`.
+        >>> # Args:
+        >>> #     shape (tuple): the shape of the sample. Default: ()
+        >>> #     loc (Tensor): the location of the distribution. Default: self.loc.
+        >>> #     scale (Tensor): the scale of the distribution. Default: self.scale.
+        >>> ans = cauchy1.sample()
+        >>> print(ans.shape)
+        ()
+        >>> ans = cauchy1.sample((2,3))
+        >>> print(ans.shape)
+        (2, 3)
+        >>> ans = cauchy1.sample((2,3), loc_b, scale_b)
+        >>> print(ans.shape)
+        (2, 3, 3)
+        >>> ans = cauchy2.sample((2,3), loc_a, scale_a)
+        >>> print(ans.shape)
+        (2, 3, 3)
     """
 
     def __init__(self,
@@ -275,7 +289,7 @@ class Cauchy(Distribution):
         loc, scale = self._check_param_type(loc, scale)
         return loc + scale * self.tan(np.pi * (p - 0.5))
 
-    def _kl_loss(self, dist, loc_b, scale_b, loc=None, scale=None):
+    def _kl_loss(self, dist, loc_b, scale_b, loc_a=None, scale_a=None):
         r"""
         Evaluate Cauchy-Cauchy kl divergence, i.e. KL(a||b).

@@ -291,17 +305,17 @@ class Cauchy(Distribution):
                 {4 * scale_a * scale_b})
         """
         check_distribution_name(dist, 'Cauchy')
-        loc, scale = self._check_param_type(loc, scale)
+        loc_a, scale_a = self._check_param_type(loc_a, scale_a)
         loc_b = self._check_value(loc_b, 'loc_b')
         loc_b = self.cast(loc_b, self.parameter_type)
         scale_b = self._check_value(scale_b, 'scale_b')
         scale_b = self.cast(scale_b, self.parameter_type)
-        sum_square = self.sq(scale + scale_b)
-        square_diff = self.sq(loc - loc_b)
+        sum_square = self.sq(scale_a + scale_b)
+        square_diff = self.sq(loc_a - loc_b)
         return self.log(sum_square + square_diff) - \
-            self.log(self.const(4.0)) - self.log(scale) - self.log(scale_b)
+            self.log(self.const(4.0)) - self.log(scale_a) - self.log(scale_b)
 
-    def _cross_entropy(self, dist, loc_b, scale_b, loc=None, scale=None):
+    def _cross_entropy(self, dist, loc_b, scale_b, loc_a=None, scale_a=None):
         r"""
         Evaluate cross entropy between Cauchy distributions.

@@ -313,7 +327,7 @@ class Cauchy(Distribution):
             scale (Tensor): The scale of distribution a. Default: self.scale.
         """
         check_distribution_name(dist, 'Cauchy')
-        return self._entropy(loc, scale) + self._kl_loss(dist, loc_b, scale_b, loc, scale)
+        return self._entropy(loc_a, scale_a) + self._kl_loss(dist, loc_b, scale_b, loc_a, scale_a)
 
     def _sample(self, shape=(), loc=None, scale=None):
         """

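The renamed `loc_a`/`scale_a` parameters feed the closed-form Cauchy-Cauchy KL shown in the hunk above. A standalone NumPy evaluation (editor's sketch, not patch code) with the docstring's parameters reproduces the printed values:

# KL(a||b) = log(((scale_a + scale_b)^2 + (loc_a - loc_b)^2) / (4 * scale_a * scale_b))
import numpy as np

loc_a, scale_a = 3.0, 4.0                 # cauchy1 above
loc_b = np.array([1.0])
scale_b = np.array([1.0, 1.5, 2.0])
kl = np.log(((scale_a + scale_b) ** 2 + (loc_a - loc_b) ** 2) / (4.0 * scale_a * scale_b))
print(kl)   # [0.594707 0.35563278 0.22314358], matching cauchy1.kl_loss('Cauchy', loc_b, scale_b)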

+43 -48  mindspore/nn/probability/distribution/gumbel.py

@@ -42,55 +42,49 @@ class Gumbel(TransformedDistribution):
     `kl_loss` and `cross_entropy` are not supported on GPU backend.
 
     Examples:
+        >>> import mindspore
+        >>> import mindspore.context as context
+        >>> import mindspore.nn as nn
+        >>> import mindspore.nn.probability.distribution as msd
+        >>> from mindspore import Tensor
+        >>> context.set_context(mode=1, device_target="GPU")
         >>> # To initialize a Gumbel distribution of `loc` 3.0 and `scale` 4.0.
-        >>> gum = msd.Gumbel(3.0, 4.0, dtype=mstype.float32)
-        >>>
-        >>> # The following creates two independent Gumbel distributions.
-        >>> gum = msd.Gumbel([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
-        >>>
-        >>> # To use a Gumbel distribution in a network.
-        >>> class net(Cell):
-        ...     def __init__(self):
-        ...         super(net, self).__init__():
-        ...         self.g1 = msd.Gumbel(0.0, 1.0, dtype=mstype.float32)
-        ...
-        ...     # The following calls are valid in construct.
-        ...     def construct(self, value, loc_b, scale_b):
-        ...
-        ...         # Private interfaces of probability functions corresponding to public interfaces, including
-        ...         # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same
-        ...         # arguments as follows.
-        ...         # Args:
-        ...         #     value (Tensor): the value to be evaluated.
-        ...
-        ...         # Examples of `prob`.
-        ...         # Similar calls can be made to other probability functions
-        ...         # by replacing 'prob' by the name of the function.
-        ...         ans = self.g1.prob(value)
-        ...
-        ...         # Functions `mean`, `mode`, sd`, `var`, and `entropy` do not take in any argument.
-        ...         ans = self.g1.mean()
-        ...         ans = self.g1.mode()
-        ...         ans = self.g1.sd()
-        ...         ans = self.g1.entropy()
-        ...         ans = self.g1.var()
-        ...
-        ...         # Interfaces of 'kl_loss' and 'cross_entropy' are the same:
-        ...         # Args:
-        ...         #     dist (str): the type of the distributions. Only "Gumbel" is supported.
-        ...         #     loc_b (Tensor): the loc of distribution b.
-        ...         #     scale_b (Tensor): the scale distribution b.
-        ...
-        ...         # Examples of `kl_loss`. `cross_entropy` is similar.
-        ...         ans = self.g1.kl_loss('Gumbel', loc_b, scale_b)
-        ...         ans = self.g1.cross_entropy('Gumbel', loc_b, scale_b)
-        ...
-        ...         # Examples of `sample`.
-        ...         # Args:
-        ...         #     shape (tuple): the shape of the sample. Default: ()
-        ...
-        ...         ans = self.g1.sample()
-        ...         ans = self.g1.sample((2,3))
+        >>> gumbel = msd.Gumbel(3.0, 4.0, dtype=mindspore.float32)
+        >>> # Private interfaces of probability functions corresponding to public interfaces, including
+        >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same
+        >>> # arguments as follows.
+        >>> # Args:
+        >>> #     value (Tensor): the value to be evaluated.
+        >>> # Examples of `prob`.
+        >>> # Similar calls can be made to other probability functions
+        >>> # by replacing 'prob' by the name of the function.
+        >>> value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32)
+        >>> ans = gumbel.prob(value)
+        >>> print(ans)
+        [0.07926048 0.08889321 0.09196986]
+        >>> # Functions `mean`, `mode`, `sd`, `var`, and `entropy` do not take in any argument.
+        >>> ans = gumbel.mean()
+        >>> print(ans)
+        5.3088627
+        >>> # Interfaces of 'kl_loss' and 'cross_entropy' are the same:
+        >>> # Args:
+        >>> #     dist (str): the type of the distributions. Only "Gumbel" is supported.
+        >>> #     loc_b (Tensor): the loc of distribution b.
+        >>> #     scale_b (Tensor): the scale of distribution b.
+        >>> # Examples of `kl_loss`. `cross_entropy` is similar.
+        >>> loc_b = Tensor([1.0], dtype=mindspore.float32)
+        >>> scale_b = Tensor([1.0, 1.5, 2.0], dtype=mindspore.float32)
+        >>> ans = gumbel.kl_loss('Gumbel', loc_b, scale_b)
+        >>> print(ans)
+        [ 2.5934026   0.03880269 -0.38017237]
+        >>> # Examples of `sample`.
+        >>> # Args:
+        >>> #     shape (tuple): the shape of the sample. Default: ()
+        >>> ans = gumbel.sample()
+        >>> print(ans.shape)
+        ()
+        >>> ans = gumbel.sample((2,3))
+        >>> print(ans.shape)
+        (2, 3)
     """
 
     def __init__(self,
@@ -125,6 +119,7 @@ class Gumbel(TransformedDistribution):
         self.lgamma = nn.LGamma()
         self.log = log_generic
         self.shape = P.Shape()
+        self.squeeze = P.Squeeze(0)
         self.sqrt = P.Sqrt()
 
     @property

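The docstring values above are easy to cross-check against the standard Gumbel formulas (editor's NumPy sketch, not patch code): mean = loc + scale * gamma (the Euler-Mascheroni constant), and pdf(x) = exp(-(z + exp(-z))) / scale with z = (x - loc) / scale.

import numpy as np

loc, scale = 3.0, 4.0
print(loc + scale * np.euler_gamma)        # 5.3088627, the gumbel.mean() value above

z = (np.array([1.0, 2.0, 3.0]) - loc) / scale
print(np.exp(-(z + np.exp(-z))) / scale)   # [0.07926048 0.08889321 0.09196986]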

+142 -77  mindspore/nn/probability/distribution/log_normal.py

@@ -41,87 +41,104 @@ class LogNormal(msd.TransformedDistribution):
     `dtype` must be a float type because LogNormal distributions are continuous.
 
     Examples:
+        >>> import mindspore
+        >>> import mindspore.context as context
+        >>> import mindspore.nn as nn
+        >>> import mindspore.nn.probability.distribution as msd
+        >>> from mindspore import Tensor
+        >>> context.set_context(mode=1)
         >>> # To initialize a LogNormal distribution of `loc` 3.0 and `scale` 4.0.
-        >>> n = msd.LogNormal(3.0, 4.0, dtype=mstype.float32)
-        >>>
-        >>> # The following creates two independent LogNormal distributions.
-        >>> n = msd.LogNormal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
-        >>>
+        >>> n1 = msd.LogNormal(3.0, 4.0, dtype=mindspore.float32)
         >>> # A LogNormal distribution can be initialized without arguments.
        >>> # In this case, `loc` and `scale` must be passed in during function calls.
-        >>> n = msd.LogNormal(dtype=mstype.float32)
+        >>> n2 = msd.LogNormal(dtype=mindspore.float32)
         >>>
+        >>> # Here are some tensors used below for testing
+        >>> value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32)
+        >>> loc_a = Tensor([2.0], dtype=mindspore.float32)
+        >>> scale_a = Tensor([2.0, 2.0, 2.0], dtype=mindspore.float32)
+        >>> loc_b = Tensor([1.0], dtype=mindspore.float32)
+        >>> scale_b = Tensor([1.0, 1.5, 2.0], dtype=mindspore.float32)
+        >>>
-        >>> # To use a LogNormal distribution in a network.
-        >>> class net(Cell):
-        ...     def __init__(self):
-        ...         super(net, self).__init__():
-        ...         self.n1 = msd.LogNormal(0.0, 1.0, dtype=mstype.float32)
-        ...         self.n2 = msd.LogNormal(dtype=mstype.float32)
-        ...
-        ...     # The following calls are valid in construct.
-        ...     def construct(self, value, loc_b, scale_b, loc_a, scale_a):
-        ...
-        ...         # Private interfaces of probability functions corresponding to public interfaces, including
-        ...         # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same
-        ...         # arguments as follows.
-        ...         # Args:
-        ...         #     value (Tensor): the value to be evaluated.
-        ...         #     loc (Tensor): the loc of distribution. Default: None. If `loc` is passed in as None,
-        ...         #         the mean of the underlying Normal distribution will be used.
-        ...         #     scale (Tensor): the scale of distribution. Default: None. If `scale` is passed in as None,
-        ...         #         the standard deviation of the underlying Normal distribution will be used.
-        ...
-        ...         # Examples of `prob`.
-        ...         # Similar calls can be made to other probability functions
-        ...         # by replacing 'prob' by the name of the function.
-        ...         ans = self.n1.prob(value)
-        ...         # Evaluate with respect to distribution b.
-        ...         ans = self.n1.prob(value, loc_b, scale_b)
-        ...         # `loc` and `scale` must be passed in during function calls since they were not passed in construct.
-        ...         ans = self.n2.prob(value, loc_a, scale_a)
-        ...
-        ...
-        ...         # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments.
-        ...         # Args:
-        ...         #     loc (Tensor): the loc of distribution. Default: None. If `loc` is passed in as None,
-        ...         #         the mean of the underlying Normal distribution will be used.
-        ...         #     scale (Tensor): the scale of distribution. Default: None. If `scale` is passed in as None,
-        ...         #         the standard deviation of the underlying Normal distribution will be used.
-        ...
-        ...         # Example of `mean`. `sd`, `var`, and `entropy` are similar.
-        ...         ans = self.n1.mean() # return 0.0
-        ...         ans = self.n1.mean(loc_b, scale_b) # return mean_b
-        ...         # `loc` and `scale` must be passed in during function calls since they were not passed in construct.
-        ...         ans = self.n2.mean(loc_a, scale_a)
-        ...
-        ...
-        ...         # Interfaces of 'kl_loss' and 'cross_entropy' are the same:
-        ...         # Args:
-        ...         #     dist (str): the type of the distributions. Only "Normal" is supported.
-        ...         #     loc_b (Tensor): the loc of distribution b.
-        ...         #     scale_b (Tensor): the scale distribution b.
-        ...         #     loc_a (Tensor): the loc of distribution a. Default: None. If `loc` is passed in as None,
-        ...         #         the mean of the underlying Normal distribution will be used.
-        ...         #     scale_a (Tensor): the scale distribution a. Default: None. If `scale` is passed in as None,
-        ...         #         the standard deviation of the underlying Normal distribution will be used.
-        ...
-        ...         # Examples of `kl_loss`. `cross_entropy` is similar.
-        ...         ans = self.n1.kl_loss('Normal', loc_b, scale_b)
-        ...         ans = self.n1.kl_loss('Normal', loc_b, scale_b, loc_a, scale_a)
-        ...         # Additional `loc` and `scale` must be passed in since they were not passed in construct.
-        ...         ans = self.n2.kl_loss('Normal', loc_b, scale_b, loc_a, scale_a)
-        ...
-        ...         # Examples of `sample`.
-        ...         # Args:
-        ...         #     shape (tuple): the shape of the sample. Default: ()
-        ...         #     loc (Tensor): the loc of the distribution. Default: None. If `loc` is passed in as None,
-        ...         #         the mean of the underlying Normal distribution will be used.
-        ...         #     scale (Tensor): the scale of the distribution. Default: None. If `scale` is passed in as None,
-        ...         #         the standard deviation of the underlying Normal distribution will be used.
-        ...         ans = self.n1.sample()
-        ...         ans = self.n1.sample((2,3))
-        ...         ans = self.n1.sample((2,3), loc_b, scale_b)
-        ...         ans = self.n2.sample((2,3), loc_a, scale_a)
+        >>> # Private interfaces of probability functions corresponding to public interfaces, including
+        >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same
+        >>> # arguments as follows.
+        >>> # Args:
+        >>> #     value (Tensor): the value to be evaluated.
+        >>> #     loc (Tensor): the loc of distribution. Default: None. If `loc` is passed in as None,
+        >>> #         the mean of the underlying Normal distribution will be used.
+        >>> #     scale (Tensor): the scale of distribution. Default: None. If `scale` is passed in as None,
+        >>> #         the standard deviation of the underlying Normal distribution will be used.
+        >>> # Examples of `prob`.
+        >>> # Similar calls can be made to other probability functions
+        >>> # by replacing 'prob' by the name of the function.
+        >>> ans = n1.prob(value)
+        >>> print(ans)
+        [0.07528435 0.04222769 0.02969363]
+        >>> # Evaluate with respect to distribution b.
+        >>> ans = n1.prob(value, loc_b, scale_b)
+        >>> print(ans)
+        [0.24197072 0.13022715 0.0664096 ]
+        >>> # `loc` and `scale` must be passed in during function calls since they were not passed in construct.
+        >>> ans = n2.prob(value, loc_a, scale_a)
+        >>> print(ans)
+        [0.12098535 0.08056299 0.06006904]
+        >>> # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments.
+        >>> # Args:
+        >>> #     loc (Tensor): the loc of distribution. Default: None. If `loc` is passed in as None,
+        >>> #         the mean of the underlying Normal distribution will be used.
+        >>> #     scale (Tensor): the scale of distribution. Default: None. If `scale` is passed in as None,
+        >>> #         the standard deviation of the underlying Normal distribution will be used.
+        >>> # Example of `mean`. `sd`, `var`, and `entropy` are similar.
+        >>> ans = n1.mean()
+        >>> print(ans)
+        59874.14
+        >>> ans = n1.mean(loc_b, scale_b)
+        >>> print(ans)
+        [ 4.481689  8.372897 20.085537]
+        >>> # `loc` and `scale` must be passed in during function calls since they were not passed in construct.
+        >>> ans = n2.mean(loc_a, scale_a)
+        >>> print(ans)
+        [54.59815 54.59815 54.59815]
+        >>> # Interfaces of 'kl_loss' and 'cross_entropy' are the same:
+        >>> # Args:
+        >>> #     dist (str): the type of the distributions. Only "Normal" is supported.
+        >>> #     loc_b (Tensor): the loc of distribution b.
+        >>> #     scale_b (Tensor): the scale of distribution b.
+        >>> #     loc_a (Tensor): the loc of distribution a. Default: None. If `loc` is passed in as None,
+        >>> #         the mean of the underlying Normal distribution will be used.
+        >>> #     scale_a (Tensor): the scale of distribution a. Default: None. If `scale` is passed in as None,
+        >>> #         the standard deviation of the underlying Normal distribution will be used.
+        >>> # Examples of `kl_loss`. `cross_entropy` is similar.
+        >>> ans = n1.kl_loss('LogNormal', loc_b, scale_b)
+        >>> print(ans)
+        [8.113706  2.963615  1.3068528]
+        >>> ans = n1.kl_loss('LogNormal', loc_b, scale_b, loc_a, scale_a)
+        >>> print(ans)
+        [1.3068528  0.32342905 0.125     ]
+        >>> # Additional `loc` and `scale` must be passed in since they were not passed in construct.
+        >>> ans = n2.kl_loss('LogNormal', loc_b, scale_b, loc_a, scale_a)
+        >>> print(ans)
+        [1.3068528  0.32342905 0.125     ]
+        >>> # Examples of `sample`.
+        >>> # Args:
+        >>> #     shape (tuple): the shape of the sample. Default: ()
+        >>> #     loc (Tensor): the loc of the distribution. Default: None. If `loc` is passed in as None,
+        >>> #         the mean of the underlying Normal distribution will be used.
+        >>> #     scale (Tensor): the scale of the distribution. Default: None. If `scale` is passed in as None,
+        >>> #         the standard deviation of the underlying Normal distribution will be used.
+        >>> ans = n1.sample()
+        >>> print(ans.shape)
+        ()
+        >>> ans = n1.sample((2,3))
+        >>> print(ans.shape)
+        (2, 3)
+        >>> ans = n1.sample((2,3), loc_b, scale_b)
+        >>> print(ans.shape)
+        (2, 3, 3)
+        >>> ans = n2.sample((2,3), loc_a, scale_a)
+        >>> print(ans.shape)
+        (2, 3, 3)
     """
 
     def __init__(self,
@@ -154,6 +171,8 @@ class LogNormal(msd.TransformedDistribution):
         self.shape = P.Shape()
         self.sq = P.Square()
         self.sqrt = P.Sqrt()
+        self.cast = P.Cast()
+        self.squeeze = P.Squeeze(0)
         self.zeroslike = P.ZerosLike()
 
     @property
@@ -221,6 +240,35 @@ class LogNormal(msd.TransformedDistribution):
         mean, sd = self._check_param_type(loc, scale)
         return mean + 0.5 + self.log(sd) + 0.5 * self.log_2pi
 
+    def _cdf(self, value, loc=None, scale=None):
+        r"""
+        Compute the cdf via the formula below, where g is the exp bijector
+        and P is the cdf of the underlying normal distribution.
+
+        .. math::
+            Y = g(X)
+            P(Y <= a) = P(X <= g^{-1}(a))
+        """
+        mean, sd = self._check_param_type(loc, scale)
+        inverse_value = self.bijector("inverse", value)
+        return self.distribution("cdf", inverse_value, mean, sd)
+
+    def _log_prob(self, value, loc=None, scale=None):
+        r"""
+        Compute the log prob via the formula below, where g is the exp bijector
+        and Px is the pdf of the underlying normal distribution.
+
+        .. math::
+            Y = g(X)
+            Py(a) = Px(g^{-1}(a)) * |(g^{-1})'(a)|
+            \log(Py(a)) = \log(Px(g^{-1}(a))) + \log(|(g^{-1})'(a)|)
+        """
+        mean, sd = self._check_param_type(loc, scale)
+        inverse_value = self.bijector("inverse", value)
+        unadjust_prob = self.distribution("log_prob", inverse_value, mean, sd)
+        log_jacobian = self.bijector("inverse_log_jacobian", value)
+        return unadjust_prob + log_jacobian
+
     def _cross_entropy(self, dist, loc_b, scale_b, loc_a=None, scale_a=None):
         r"""
         Evaluate cross entropy between lognormal distributions.
@@ -252,3 +300,20 @@ class LogNormal(msd.TransformedDistribution):
         """
         check_distribution_name(dist, 'LogNormal')
         return self.distribution("kl_loss", 'Normal', loc_b, scale_b, loc_a, scale_a)
+
+    def _sample(self, shape=(), loc=None, scale=None):
+        r"""
+        Generate samples by mapping samples from the underlying normal distribution.
+        """
+        shape = self.checktuple(shape, 'shape')
+        mean, sd = self._check_param_type(loc, scale)
+        if shape == ():
+            sample_shape = (1,)
+        else:
+            sample_shape = shape
+        org_sample = self.distribution("sample", sample_shape, mean, sd)
+        org_sample = self.cast(org_sample, self.dtype)
+        value = self.bijector("forward", org_sample)
+        if shape == ():
+            value = self.squeeze(value)
+        return value

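The new `_sample` path is a standard transformed-distribution trick: draw from the underlying Normal, push the draws through the exp bijector, and pad/squeeze an empty sample shape. An editor's NumPy paraphrase (not the MindSpore implementation) of the same logic, plus a check of the docstring's mean value exp(loc + scale^2/2):

import numpy as np

loc, scale = 3.0, 4.0
print(np.exp(loc + 0.5 * scale ** 2))      # 59874.14..., the n1.mean() printed above

rng = np.random.default_rng(0)

def lognormal_sample(shape=()):
    sample_shape = (1,) if shape == () else shape          # pad an empty shape, as _sample does
    org_sample = rng.normal(loc, scale, size=sample_shape) # draw from the underlying Normal
    value = np.exp(org_sample)                             # forward pass of the exp bijector
    return value.squeeze(0) if shape == () else value      # squeeze the padding back off

print(lognormal_sample((2, 3)).shape)      # (2, 3)
print(lognormal_sample().shape)            # ()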
+65 -53  mindspore/nn/probability/distribution/logistic.py

@@ -40,63 +40,75 @@ class Logistic(Distribution):
     `dtype` must be a float type because Logistic distributions are continuous.
 
     Examples:
-        >>> # To initialize a Logistic distribution of loc 3.0 and scale 4.0.
+        >>> import mindspore
+        >>> import mindspore.nn as nn
         >>> import mindspore.nn.probability.distribution as msd
-        >>> n = msd.Logistic(3.0, 4.0, dtype=mstype.float32)
-        >>>
-        >>> # The following creates two independent Logistic distributions.
-        >>> n = msd.Logistic([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
-        >>>
+        >>> from mindspore import Tensor
+        >>> # To initialize a Logistic distribution of loc 3.0 and scale 4.0.
+        >>> l1 = msd.Logistic(3.0, 4.0, dtype=mindspore.float32)
         >>> # A Logistic distribution can be initialized without arguments.
         >>> # In this case, `loc` and `scale` must be passed in through arguments.
-        >>> n = msd.Logistic(dtype=mstype.float32)
+        >>> l2 = msd.Logistic(dtype=mindspore.float32)
         >>>
+        >>> # Here are some tensors used below for testing
+        >>> value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32)
+        >>> loc_a = Tensor([2.0], dtype=mindspore.float32)
+        >>> scale_a = Tensor([2.0, 2.0, 2.0], dtype=mindspore.float32)
+        >>> loc_b = Tensor([1.0], dtype=mindspore.float32)
+        >>> scale_b = Tensor([1.0, 1.5, 2.0], dtype=mindspore.float32)
+        >>>
-        >>> # To use a Normal distribution in a network.
-        >>> class net(Cell):
-        ...     def __init__(self):
-        ...         super(net, self).__init__():
-        ...         self.l1 = msd.Logistic(0.0, 1.0, dtype=mstype.float32)
-        ...         self.l2 = msd.Logistic(dtype=mstype.float32)
-        ...
-        ...     # The following calls are valid in construct.
-        ...     def construct(self, value, loc_b, scale_b, loc_a, scale_a):
-        ...
-        ...         # Private interfaces of probability functions corresponding to public interfaces, including
-        ...         # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same arguments as follows.
-        ...         # Args:
-        ...         #     value (Tensor): the value to be evaluated.
-        ...         #     loc (Tensor): the location of the distribution. Default: self.loc.
-        ...         #     scale (Tensor): the scale of the distribution. Default: self.scale.
-        ...
-        ...         # Examples of `prob`.
-        ...         # Similar calls can be made to other probability functions
-        ...         # by replacing 'prob' by the name of the function
-        ...         ans = self.l1.prob(value)
-        ...         # Evaluate with respect to distribution b.
-        ...         ans = self.l1.prob(value, loc_b, scale_b)
-        ...         # `loc` and `scale` must be passed in during function calls
-        ...         ans = self.l2.prob(value, loc_a, scale_a)
-        ...
-        ...         # Functions `mean`, `mode`, `sd`, `var`, and `entropy` have the same arguments.
-        ...         # Args:
-        ...         #     loc (Tensor): the location of the distribution. Default: self.loc.
-        ...         #     scale (Tensor): the scale of the distribution. Default: self.scale.
-        ...
-        ...         # Example of `mean`. `mode`, `sd`, `var`, and `entropy` are similar.
-        ...         ans = self.l1.mean() # return 0.0
-        ...         ans = self.l1.mean(loc_b, scale_b) # return loc_b
-        ...         # `loc` and `scale` must be passed in during function calls.
-        ...         ans = self.l2.mean(loc_a, scale_a)
-        ...
-        ...         # Examples of `sample`.
-        ...         # Args:
-        ...         #     shape (tuple): the shape of the sample. Default: ()
-        ...         #     loc (Tensor): the location of the distribution. Default: self.loc.
-        ...         #     scale (Tensor): the scale of the distribution. Default: self.scale.
-        ...         ans = self.l1.sample()
-        ...         ans = self.l1.sample((2,3))
-        ...         ans = self.l1.sample((2,3), loc_b, scale_b)
-        ...         ans = self.l2.sample((2,3), loc_a, scale_a)
+        >>> # Private interfaces of probability functions corresponding to public interfaces, including
+        >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same arguments as follows.
+        >>> # Args:
+        >>> #     value (Tensor): the value to be evaluated.
+        >>> #     loc (Tensor): the location of the distribution. Default: self.loc.
+        >>> #     scale (Tensor): the scale of the distribution. Default: self.scale.
+        >>> # Examples of `prob`.
+        >>> # Similar calls can be made to other probability functions
+        >>> # by replacing 'prob' by the name of the function
+        >>> ans = l1.prob(value)
+        >>> print(ans)
+        [0.05875093 0.06153353 0.0625    ]
+        >>> # Evaluate with respect to distribution b.
+        >>> ans = l1.prob(value, loc_b, scale_b)
+        >>> print(ans)
+        [0.25       0.14943825 0.09830598]
+        >>> # `loc` and `scale` must be passed in during function calls
+        >>> ans = l2.prob(value, loc_a, scale_a)
+        >>> print(ans)
+        [0.11750185 0.125      0.11750185]
+        >>> # Functions `mean`, `mode`, `sd`, `var`, and `entropy` have the same arguments.
+        >>> # Args:
+        >>> #     loc (Tensor): the location of the distribution. Default: self.loc.
+        >>> #     scale (Tensor): the scale of the distribution. Default: self.scale.
+        >>> # Example of `mean`. `mode`, `sd`, `var`, and `entropy` are similar.
+        >>> ans = l1.mean()
+        >>> print(ans)
+        3.0
+        >>> ans = l1.mean(loc_b, scale_b)
+        >>> print(ans)
+        [1. 1. 1.]
+        >>> # `loc` and `scale` must be passed in during function calls.
+        >>> ans = l2.mean(loc_a, scale_a)
+        >>> print(ans)
+        [2. 2. 2.]
+        >>> # Examples of `sample`.
+        >>> # Args:
+        >>> #     shape (tuple): the shape of the sample. Default: ()
+        >>> #     loc (Tensor): the location of the distribution. Default: self.loc.
+        >>> #     scale (Tensor): the scale of the distribution. Default: self.scale.
+        >>> ans = l1.sample()
+        >>> print(ans.shape)
+        ()
+        >>> ans = l1.sample((2,3))
+        >>> print(ans.shape)
+        (2, 3)
+        >>> ans = l1.sample((2,3), loc_b, scale_b)
+        >>> print(ans.shape)
+        (2, 3, 3)
+        >>> ans = l2.sample((2,3), loc_a, scale_a)
+        >>> print(ans.shape)
+        (2, 3, 3)
     """
 
     def __init__(self,

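The new doctest outputs above match the standard logistic density, which can be confirmed offline (editor's NumPy sketch, not patch code): pdf(x) = exp(-z) / (scale * (1 + exp(-z))^2) with z = (x - loc) / scale.

import numpy as np

def logistic_pdf(x, loc, scale):
    z = (x - loc) / scale
    return np.exp(-z) / (scale * (1.0 + np.exp(-z)) ** 2)

print(logistic_pdf(np.array([1.0, 2.0, 3.0]), 3.0, 4.0))
# [0.05875093 0.06153353 0.0625    ] -- matches l1.prob(value) above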

+1 -1  mindspore/nn/probability/distribution/uniform.py

@@ -200,7 +200,7 @@ class Uniform(Distribution):
             self.checktensor(high, 'high')
         else:
             high = self.high
-        return high, low
+        return low, high
 
     def _range(self, low=None, high=None):
         r"""

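The one-line uniform.py fix matters because callers unpack the returned pair positionally, so the old order handed every downstream formula `low` and `high` reversed. A minimal sketch (hypothetical caller, not the actual file):

def check_param_swapped(low, high):
    return high, low            # old behavior

def check_param_fixed(low, high):
    return low, high            # patched behavior

low, high = check_param_swapped(0.0, 1.0)
print(high - low)               # -1.0: negative range, inverted sample bounds

low, high = check_param_fixed(0.0, 1.0)
print(high - low)               # 1.0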
