diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py index 7bef5ff347..b59e289b9f 100644 --- a/mindspore/nn/probability/distribution/_utils/utils.py +++ b/mindspore/nn/probability/distribution/_utils/utils.py @@ -168,9 +168,12 @@ def check_sum_equal_one(probs): if not isinstance(probs.data, Tensor): return probs = probs.data - prob_sum = np.sum(probs.asnumpy(), axis=-1) - comp = np.equal(np.ones(prob_sum.shape), prob_sum) - if not comp.all(): + if isinstance(probs, Tensor): + probs = probs.asnumpy() + prob_sum = np.sum(probs, axis=-1) + # add a small tolerance here to increase numerical stability + comp = np.allclose(prob_sum, np.ones(prob_sum.shape), rtol=1e-14, atol=1e-14) + if not comp: raise ValueError('Probabilities for each category should sum to one for Categorical distribution.') def check_rank(probs): diff --git a/mindspore/nn/probability/distribution/categorical.py b/mindspore/nn/probability/distribution/categorical.py index 44b53ff4d0..2af21d9f49 100644 --- a/mindspore/nn/probability/distribution/categorical.py +++ b/mindspore/nn/probability/distribution/categorical.py @@ -40,65 +40,84 @@ class Categorical(Distribution): `probs` must have rank at least 1, values are proper probabilities and sum to 1. Examples: - >>> # To initialize a Categorical distribution of probs [0.5, 0.5] + >>> import mindspore + >>> import mindspore.nn as nn >>> import mindspore.nn.probability.distribution as msd - >>> b = msd.Categorical(probs = [0.5, 0.5], dtype=mstype.int32) - >>> - >>> # To use a Categorical distribution in a network - >>> class net(Cell): - ... def __init__(self, probs): - ... super(net, self).__init__(): - ... self.ca = msd.Categorical(probs=[0.2, 0.8], dtype=mstype.int32) - ... self.ca1 = msd.Categorical(dtype=mstype.int32) - ... - ... # All the following calls in construct are valid - ... def construct(self, value): - ... - ... # Private interfaces of probability functions corresponding to public interfaces, including - ... # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, are the same as follows. - ... # Args: - ... # value (Tensor): the value to be evaluated. - ... # probs (Tensor): event probabilities. Default: self.probs. - ... - ... # Examples of `prob`. - ... # Similar calls can be made to other probability functions - ... # by replacing `prob` by the name of the function. - ... ans = self.ca.prob(value) - ... # Evaluate `prob` with respect to distribution b. - ... ans = self.ca.prob(value, probs_b) - ... # `probs` must be passed in during function calls. - ... ans = self.ca1.prob(value, probs_a) - ... - ... # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments. - ... # Args: - ... # probs (Tensor): event probabilities. Default: self.probs. - ... - ... # Examples of `mean`. `sd`, `var`, and `entropy` are similar. - ... ans = self.ca.mean() # return 0.8 - ... ans = self.ca.mean(probs_b) - ... # `probs` must be passed in during function calls. - ... ans = self.ca1.mean(probs_a) - ... - ... # Interfaces of `kl_loss` and `cross_entropy` are the same as follows: - ... # Args: - ... # dist (str): the name of the distribution. Only 'Categorical' is supported. - ... # probs_b (Tensor): event probabilities of distribution b. - ... # probs (Tensor): event probabilities of distribution a. Default: self.probs. - ... - ... # Examples of kl_loss. `cross_entropy` is similar. - ... ans = self.ca.kl_loss('Categorical', probs_b) - ... ans = self.ca.kl_loss('Categorical', probs_b, probs_a) - ... # An additional `probs` must be passed in. - ... ans = self.ca1.kl_loss('Categorical', probs_b, probs_a) - ... - ... # Examples of `sample`. - ... # Args: - ... # shape (tuple): the shape of the sample. Default: (). - ... # probs (Tensor): event probabilities. Default: self.probs. - ... ans = self.ca.sample() - ... ans = self.ca.sample((2,3)) - ... ans = self.ca.sample((2,3), probs_b) - ... ans = self.ca1.sample((2,3), probs_a) + >>> from mindspore import Tensor + >>> # To initialize a Categorical distribution of probs [0.5, 0.5] + >>> ca1 = msd.Categorical(probs=[0.2, 0.8], dtype=mindspore.int32) + >>> # A Categorical distribution can be initialized without arguments. + >>> # In this case, `probs` must be passed in through arguments during function calls. + >>> ca2 = msd.Categorical(dtype=mindspore.int32) + >>> # Here are some tensors used below for testing + >>> value = Tensor([1, 0], dtype=mindspore.int32) + >>> probs_a = Tensor([0.5, 0.5], dtype=mindspore.float32) + >>> probs_b = Tensor([0.35, 0.65], dtype=mindspore.float32) + >>> # Private interfaces of probability functions corresponding to public interfaces, including + >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, are the same as follows. + >>> # Args: + >>> # value (Tensor): the value to be evaluated. + >>> # probs (Tensor): event probabilities. Default: self.probs. + >>> # Examples of `prob`. + >>> # Similar calls can be made to other probability functions + >>> # by replacing `prob` by the name of the function. + >>> ans = ca1.prob(value) + >>> print(ans) + [0.8 0.2] + >>> # Evaluate `prob` with respect to distribution b. + >>> ans = ca1.prob(value, probs_b) + >>> print(ans) + [0.65 0.35] + >>> # `probs` must be passed in during function calls. + >>> ans = ca2.prob(value, probs_a) + >>> print(ans) + [0.5 0.5] + >>> # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments. + >>> # Args: + >>> # probs (Tensor): event probabilities. Default: self.probs. + >>> # Examples of `mean`. `sd`, `var`, and `entropy` are similar. + >>> ans = ca1.mean() # return 0.8 + >>> print(ans) + [0.8] + >>> ans = ca1.mean(probs_b) + >>> print(ans) + [0.65] + >>> # `probs` must be passed in during function calls. + >>> ans = ca2.mean(probs_a) + >>> print(ans) + [0.5] + >>> # Interfaces of `kl_loss` and `cross_entropy` are the same as follows: + >>> # Args: + >>> # dist (str): the name of the distribution. Only 'Categorical' is supported. + >>> # probs_b (Tensor): event probabilities of distribution b. + >>> # probs (Tensor): event probabilities of distribution a. Default: self.probs. + >>> # Examples of kl_loss. `cross_entropy` is similar. + >>> ans = ca1.kl_loss('Categorical', probs_b) + >>> print(ans) + 0.05418826 + >>> ans = ca1.kl_loss('Categorical', probs_b, probs_a) + >>> print(ans) + 0.04715523 + >>> # An additional `probs` must be passed in. + >>> ans = ca2.kl_loss('Categorical', probs_b, probs_a) + >>> print(ans) + 0.04715523 + >>> # Examples of `sample`. + >>> # Args: + >>> # shape (tuple): the shape of the sample. Default: (). + >>> # probs (Tensor): event probabilities. Default: self.probs. + >>> ans = ca1.sample() + >>> print(ans.shape) + () + >>> ans = ca1.sample((2,3)) + >>> print(ans.shape) + (2, 3) + >>> ans = ca1.sample((2,3), probs_b) + >>> print(ans.shape) + (2, 3) + >>> ans = ca2.sample((2,3), probs_a) + >>> print(ans.shape) + (2, 3) """ def __init__(self, @@ -108,7 +127,7 @@ class Categorical(Distribution): name="Categorical"): param = dict(locals()) param['param_dict'] = {'probs': probs} - valid_dtype = mstype.int_type + mstype.float_type + valid_dtype = mstype.uint_type + mstype.int_type + mstype.float_type Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) super(Categorical, self).__init__(seed, dtype, name, param) @@ -116,7 +135,7 @@ class Categorical(Distribution): if self.probs is not None: check_rank(self.probs) check_prob(self.probs) - check_sum_equal_one(self.probs) + check_sum_equal_one(probs) # update is_scalar_batch and broadcast_shape # drop one dimension @@ -124,7 +143,7 @@ class Categorical(Distribution): self._is_scalar_batch = True self._broadcast_shape = self._broadcast_shape[:-1] - self.argmax = P.Argmax() + self.argmax = P.ArgMaxWithValue(axis=-1) self.broadcast = broadcast_to self.cast = P.Cast() self.clip_by_value = C.clip_by_value @@ -140,6 +159,7 @@ class Categorical(Distribution): self.log = log_generic self.log_softmax = P.LogSoftmax() self.logicor = P.LogicalOr() + self.logicand = P.LogicalAnd() self.multinomial = P.Multinomial(seed=self.seed) self.reshape = P.Reshape() self.reduce_sum = P.ReduceSum(keep_dims=True) @@ -192,8 +212,9 @@ class Categorical(Distribution): def _mode(self, probs=None): probs = self._check_param_type(probs) - mode = self.cast(self.argmax(probs), self.dtype) - return self.squeeze(mode) + index, _ = self.argmax(probs) + mode = self.cast(index, self.dtype) + return mode def _var(self, probs=None): r""" @@ -232,7 +253,7 @@ class Categorical(Distribution): probs_a = self._check_param_type(probs) logits_a = self.log(probs_a) logits_b = self.log(probs_b) - return self.squeeze(-self.reduce_sum( + return self.squeeze(self.reduce_sum( self.softmax(logits_a) * (self.log_softmax(logits_a) - (self.log_softmax(logits_b))), -1)) def _cross_entropy(self, dist, probs_b, probs=None): @@ -256,10 +277,13 @@ class Categorical(Distribution): probs (Tensor): Event probabilities. Default: self.probs. """ value = self._check_value(value, 'value') + # cast value to int to find the right integer to compute index if self.issubclass(self.dtype, mstype.float_): value = self.cast(value, self.index_type) else: value = self.cast(value, self.dtype) + # cast int to float for the broadcasting below + value = self.cast(value, mstype.float32) probs = self._check_param_type(probs) logits = self.log(probs) @@ -289,6 +313,9 @@ class Categorical(Distribution): value = self.reshape(value, (-1, 1)) out_of_bound = self.squeeze_last_axis(self.logicor(\ self.less(value, 0.0), self.less(num_classes-1, value))) + # deal with the case the there is only one class. + zeros = self.fill(mstype.float32, self.shape(out_of_bound), 0.0) + out_of_bound = self.logicand(out_of_bound, self.less(zeros, num_classes-1)) value_clipped = self.clip_by_value(value, 0.0, num_classes - 1) value_clipped = self.cast(value_clipped, self.index_type) # create index from 0 ... NumOfLabels diff --git a/mindspore/nn/probability/distribution/cauchy.py b/mindspore/nn/probability/distribution/cauchy.py index 1726c515d4..f1ae92b459 100644 --- a/mindspore/nn/probability/distribution/cauchy.py +++ b/mindspore/nn/probability/distribution/cauchy.py @@ -41,77 +41,91 @@ class Cauchy(Distribution): Cauchy distribution is not supported on GPU backend. Examples: - >>> # To initialize a Cauchy distribution of loc 3.0 and scale 4.0. + >>> import mindspore + >>> import mindspore.nn as nn >>> import mindspore.nn.probability.distribution as msd - >>> cauchy = msd.Cauchy(3.0, 4.0, dtype=mstype.float32) - >>> - >>> # The following creates two independent Cauchy distributions. - >>> cauchy = msd.Cauchy([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32) - >>> + >>> from mindspore import Tensor + >>> # To initialize a Cauchy distribution of loc 3.0 and scale 4.0. + >>> cauchy1 = msd.Cauchy(3.0, 4.0, dtype=mindspore.float32) >>> # A Cauchy distribution can be initialized without arguments. >>> # In this case, 'loc' and `scale` must be passed in through arguments. - >>> cauchy = msd.Cauchy(dtype=mstype.float32) - >>> - >>> # To use a Cauchy distribution in a network. - >>> class net(Cell): - ... def __init__(self): - ... super(net, self).__init__(): - ... self.cau1 = msd.Cauchy(0.0, 1.0, dtype=mstype.float32) - ... self.cau2 = msd.Cauchy(dtype=mstype.float32) - ... - ... # The following calls are valid in construct. - ... def construct(self, value, loc_b, scale_b, loc_a, scale_a): - ... - ... # Private interfaces of probability functions corresponding to public interfaces, including - ... # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same arguments as follows. - ... # Args: - ... # value (Tensor): the value to be evaluated. - ... # loc (Tensor): the location of the distribution. Default: self.loc. - ... # scale (Tensor): the scale of the distribution. Default: self.scale. - ... - ... # Examples of `prob`. - ... # Similar calls can be made to other probability functions - ... # by replacing 'prob' by the name of the function - ... ans = self.cau1.prob(value) - ... # Evaluate with respect to distribution b. - ... ans = self.cau1.prob(value, loc_b, scale_b) - ... # `loc` and `scale` must be passed in during function calls - ... ans = self.cau2.prob(value, loc_a, scale_a) - ... - ... # Functions `mode` and `entropy` have the same arguments. - ... # Args: - ... # loc (Tensor): the location of the distribution. Default: self.loc. - ... # scale (Tensor): the scale of the distribution. Default: self.scale. - ... - ... # Example of `mode`. - ... ans = self.cau1.mode() # return 0.0 - ... ans = self.cau1.mode(loc_b, scale_b) # return loc_b - ... # `loc` and `scale` must be passed in during function calls. - ... ans = self.cau2.mode(loc_a, scale_a) - ... - ... # Interfaces of 'kl_loss' and 'cross_entropy' are the same: - ... # Args: - ... # dist (str): the type of the distributions. Only "Cauchy" is supported. - ... # loc_b (Tensor): the loc of distribution b. - ... # scale_b (Tensor): the scale distribution b. - ... # loc (Tensor): the loc of distribution a. Default: self.loc. - ... # scale (Tensor): the scale distribution a. Default: self.scale. - ... - ... # Examples of `kl_loss`. `cross_entropy` is similar. - ... ans = self.cau1.kl_loss('Cauchy', loc_b, scale_b) - ... ans = self.cau1.kl_loss('Cauchy', loc_b, scale_b, loc_a, scale_a) - ... # Additional `loc` and `scale` must be passed in. - ... ans = self.cau2.kl_loss('Cauchy', loc_b, scale_b, loc_a, scale_a) - ... - ... # Examples of `sample`. - ... # Args: - ... # shape (tuple): the shape of the sample. Default: () - ... # loc (Tensor): the location of the distribution. Default: self.loc. - ... # scale (Tensor): the scale of the distribution. Default: self.scale. - ... ans = self.cau1.sample() - ... ans = self.cau1.sample((2,3)) - ... ans = self.cau1.sample((2,3), loc_b, s_b) - ... ans = self.cau2.sample((2,3), loc_a, s_a) + >>> cauchy2 = msd.Cauchy(dtype=mindspore.float32) + >>> # Here are some tensors used below for testing + >>> value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32) + >>> loc_a = Tensor([2.0], dtype=mindspore.float32) + >>> scale_a = Tensor([2.0, 2.0, 2.0], dtype=mindspore.float32) + >>> loc_b = Tensor([1.0], dtype=mindspore.float32) + >>> scale_b = Tensor([1.0, 1.5, 2.0], dtype=mindspore.float32) + >>> # Private interfaces of probability functions corresponding to public interfaces, including + >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same arguments as follows. + >>> # Args: + >>> # value (Tensor): the value to be evaluated. + >>> # loc (Tensor): the location of the distribution. Default: self.loc. + >>> # scale (Tensor): the scale of the distribution. Default: self.scale. + >>> # Examples of `prob`. + >>> # Similar calls can be made to other probability functions + >>> # by replacing 'prob' by the name of the function + >>> ans = cauchy1.prob(value) + >>> print(ans) + [0.06366198 0.07489645 0.07957747] + >>> # Evaluate with respect to distribution b. + >>> ans = cauchy1.prob(value, loc_b, scale_b) + >>> print(ans) + [0.31830987 0.14691226 0.07957747] + >>> # `loc` and `scale` must be passed in during function calls + >>> ans = cauchy2.prob(value, loc_a, scale_a) + >>> print(ans) + [0.12732396 0.15915494 0.12732396] + >>> # Functions `mode` and `entropy` have the same arguments. + >>> # Args: + >>> # loc (Tensor): the location of the distribution. Default: self.loc. + >>> # scale (Tensor): the scale of the distribution. Default: self.scale. + >>> # Example of `mode`. + >>> ans = cauchy1.mode() # return 3.0 + >>> print(ans) + 3.0 + >>> ans = cauchy1.mode(loc_b, scale_b) # return loc_b + >>> print(ans) + [1. 1. 1.] + >>> # `loc` and `scale` must be passed in during function calls. + >>> ans = cauchy2.mode(loc_a, scale_a) + >>> print(ans) + [2. 2. 2.] + >>> # Interfaces of 'kl_loss' and 'cross_entropy' are the same: + >>> # Args: + >>> # dist (str): the type of the distributions. Only "Cauchy" is supported. + >>> # loc_b (Tensor): the loc of distribution b. + >>> # scale_b (Tensor): the scale distribution b. + >>> # loc (Tensor): the loc of distribution a. Default: self.loc. + >>> # scale (Tensor): the scale distribution a. Default: self.scale. + >>> # Examples of `kl_loss`. `cross_entropy` is similar. + >>> ans = cauchy1.kl_loss('Cauchy', loc_b, scale_b) + >>> print(ans) + [0.594707 0.35563278 0.22314358] + >>> ans = cauchy1.kl_loss('Cauchy', loc_b, scale_b, loc_a, scale_a) + >>> print(ans) + [0.22314358 0.09909081 0.0606246 ] + >>> # Additional `loc` and `scale` must be passed in. + >>> ans = cauchy2.kl_loss('Cauchy', loc_b, scale_b, loc_a, scale_a) + >>> print(ans) + [0.22314358 0.09909081 0.0606246 ] + >>> # Examples of `sample`. + >>> # Args: + >>> # shape (tuple): the shape of the sample. Default: () + >>> # loc (Tensor): the location of the distribution. Default: self.loc. + >>> # scale (Tensor): the scale of the distribution. Default: self.scale. + >>> ans = cauchy1.sample() + >>> print(ans.shape) + () + >>> ans = cauchy1.sample((2,3)) + >>> print(ans.shape) + (2, 3) + >>> ans = cauchy1.sample((2,3), loc_b, scale_b) + >>> print(ans.shape) + (2, 3, 3) + >>> ans = cauchy2.sample((2,3), loc_a, scale_a) + >>> print(ans.shape) + (2, 3, 3) """ def __init__(self, @@ -275,7 +289,7 @@ class Cauchy(Distribution): loc, scale = self._check_param_type(loc, scale) return loc + scale * self.tan(np.pi * (p - 0.5)) - def _kl_loss(self, dist, loc_b, scale_b, loc=None, scale=None): + def _kl_loss(self, dist, loc_b, scale_b, loc_a=None, scale_a=None): r""" Evaluate Cauchy-Cauchy kl divergence, i.e. KL(a||b). @@ -291,17 +305,17 @@ class Cauchy(Distribution): {4 * scale_a * scale_b}) """ check_distribution_name(dist, 'Cauchy') - loc, scale = self._check_param_type(loc, scale) + loc_a, scale_a = self._check_param_type(loc_a, scale_a) loc_b = self._check_value(loc_b, 'loc_b') loc_b = self.cast(loc_b, self.parameter_type) scale_b = self._check_value(scale_b, 'scale_b') scale_b = self.cast(scale_b, self.parameter_type) - sum_square = self.sq(scale + scale_b) - square_diff = self.sq(loc - loc_b) + sum_square = self.sq(scale_a + scale_b) + square_diff = self.sq(loc_a - loc_b) return self.log(sum_square + square_diff) - \ - self.log(self.const(4.0)) - self.log(scale) - self.log(scale_b) + self.log(self.const(4.0)) - self.log(scale_a) - self.log(scale_b) - def _cross_entropy(self, dist, loc_b, scale_b, loc=None, scale=None): + def _cross_entropy(self, dist, loc_b, scale_b, loc_a=None, scale_a=None): r""" Evaluate cross entropy between Cauchy distributions. @@ -313,7 +327,7 @@ class Cauchy(Distribution): scale (Tensor): The scale of distribution a. Default: self.scale. """ check_distribution_name(dist, 'Cauchy') - return self._entropy(loc, scale) + self._kl_loss(dist, loc_b, scale_b, loc, scale) + return self._entropy(loc_a, scale_a) + self._kl_loss(dist, loc_b, scale_b, loc_a, scale_a) def _sample(self, shape=(), loc=None, scale=None): """ diff --git a/mindspore/nn/probability/distribution/gumbel.py b/mindspore/nn/probability/distribution/gumbel.py index bf99630e3b..f6e8257e91 100644 --- a/mindspore/nn/probability/distribution/gumbel.py +++ b/mindspore/nn/probability/distribution/gumbel.py @@ -42,55 +42,49 @@ class Gumbel(TransformedDistribution): `kl_loss` and `cross_entropy` are not supported on GPU backend. Examples: + >>> import mindspore + >>> import mindspore.context as context + >>> import mindspore.nn as nn + >>> import mindspore.nn.probability.distribution as msd + >>> from mindspore import Tensor + >>> context.set_context(mode=1, device_target="GPU") >>> # To initialize a Gumbel distribution of `loc` 3.0 and `scale` 4.0. - >>> gum = msd.Gumbel(3.0, 4.0, dtype=mstype.float32) - >>> - >>> # The following creates two independent Gumbel distributions. - >>> gum = msd.Gumbel([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32) - >>> - >>> # To use a Gumbel distribution in a network. - >>> class net(Cell): - ... def __init__(self): - ... super(net, self).__init__(): - ... self.g1 = msd.Gumbel(0.0, 1.0, dtype=mstype.float32) - ... - ... # The following calls are valid in construct. - ... def construct(self, value, loc_b, scale_b): - ... - ... # Private interfaces of probability functions corresponding to public interfaces, including - ... # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same - ... # arguments as follows. - ... # Args: - ... # value (Tensor): the value to be evaluated. - ... - ... # Examples of `prob`. - ... # Similar calls can be made to other probability functions - ... # by replacing 'prob' by the name of the function. - ... ans = self.g1.prob(value) - ... - ... # Functions `mean`, `mode`, sd`, `var`, and `entropy` do not take in any argument. - ... ans = self.g1.mean() - ... ans = self.g1.mode() - ... ans = self.g1.sd() - ... ans = self.g1.entropy() - ... ans = self.g1.var() - ... - ... # Interfaces of 'kl_loss' and 'cross_entropy' are the same: - ... # Args: - ... # dist (str): the type of the distributions. Only "Gumbel" is supported. - ... # loc_b (Tensor): the loc of distribution b. - ... # scale_b (Tensor): the scale distribution b. - ... - ... # Examples of `kl_loss`. `cross_entropy` is similar. - ... ans = self.g1.kl_loss('Gumbel', loc_b, scale_b) - ... ans = self.g1.cross_entropy('Gumbel', loc_b, scale_b) - ... - ... # Examples of `sample`. - ... # Args: - ... # shape (tuple): the shape of the sample. Default: () - ... - ... ans = self.g1.sample() - ... ans = self.g1.sample((2,3)) + >>> gumbel = msd.Gumbel(3.0, 4.0, dtype=mindspore.float32) + >>> # Private interfaces of probability functions corresponding to public interfaces, including + >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same + >>> # arguments as follows. + >>> # Args: + >>> # value (Tensor): the value to be evaluated. + >>> # Examples of `prob`. + >>> # Similar calls can be made to other probability functions + >>> # by replacing 'prob' by the name of the function. + >>> value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32) + >>> ans = gumbel.prob(value) + >>> print(ans) + [0.07926048 0.08889321 0.09196986] + >>> # Functions `mean`, `mode`, sd`, `var`, and `entropy` do not take in any argument. + >>> ans = gumbel.mean() + >>> print(ans) + 5.3088627 + >>> # Interfaces of 'kl_loss' and 'cross_entropy' are the same: + >>> # Args: + >>> # dist (str): the type of the distributions. Only "Gumbel" is supported. + >>> # loc_b (Tensor): the loc of distribution b. + >>> # scale_b (Tensor): the scale distribution b. + >>> # Examples of `kl_loss`. `cross_entropy` is similar. + >>> loc_b = Tensor([1.0], dtype=mindspore.float32) + >>> scale_b = Tensor([1.0, 1.5, 2.0], dtype=mindspore.float32) + >>> ans = gumbel.kl_loss('Gumbel', loc_b, scale_b) + >>> print(ans) + [ 2.5934026 0.03880269 -0.38017237] + >>> # Examples of `sample`. + >>> # Args: + >>> # shape (tuple): the shape of the sample. Default: () + >>> ans = gumbel.sample() + >>> print(ans.shape) + () + >>> ans = gumbel.sample((2,3)) + >>> print(ans.shape) """ def __init__(self, @@ -125,6 +119,7 @@ class Gumbel(TransformedDistribution): self.lgamma = nn.LGamma() self.log = log_generic self.shape = P.Shape() + self.squeeze = P.Squeeze(0) self.sqrt = P.Sqrt() @property diff --git a/mindspore/nn/probability/distribution/log_normal.py b/mindspore/nn/probability/distribution/log_normal.py index 4cd8ebb196..28c1467e1d 100644 --- a/mindspore/nn/probability/distribution/log_normal.py +++ b/mindspore/nn/probability/distribution/log_normal.py @@ -41,87 +41,104 @@ class LogNormal(msd.TransformedDistribution): `dtype` must be a float type because LogNormal distributions are continuous. Examples: + >>> import mindspore + >>> import mindspore.context as context + >>> import mindspore.nn as nn + >>> import mindspore.nn.probability.distribution as msd + >>> from mindspore import Tensor + >>> context.set_context(mode=1) >>> # To initialize a LogNormal distribution of `loc` 3.0 and `scale` 4.0. - >>> n = msd.LogNormal(3.0, 4.0, dtype=mstype.float32) - >>> - >>> # The following creates two independent LogNormal distributions. - >>> n = msd.LogNormal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32) - >>> + >>> n1 = msd.LogNormal(3.0, 4.0, dtype=mindspore.float32) >>> # A LogNormal distribution can be initialized without arguments. >>> # In this case, `loc` and `scale` must be passed in during function calls. - >>> n = msd.LogNormal(dtype=mstype.float32) + >>> n2 = msd.LogNormal(dtype=mindspore.float32) + >>> + >>> # Here are some tensors used below for testing + >>> value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32) + >>> loc_a = Tensor([2.0], dtype=mindspore.float32) + >>> scale_a = Tensor([2.0, 2.0, 2.0], dtype=mindspore.float32) + >>> loc_b = Tensor([1.0], dtype=mindspore.float32) + >>> scale_b = Tensor([1.0, 1.5, 2.0], dtype=mindspore.float32) >>> - >>> # To use a LogNormal distribution in a network. - >>> class net(Cell): - ... def __init__(self): - ... super(net, self).__init__(): - ... self.n1 = msd.LogNormal(0.0, 1.0, dtype=mstype.float32) - ... self.n2 = msd.LogNormal(dtype=mstype.float32) - ... - ... # The following calls are valid in construct. - ... def construct(self, value, loc_b, scale_b, loc_a, scale_a): - ... - ... # Private interfaces of probability functions corresponding to public interfaces, including - ... # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same - ... # arguments as follows. - ... # Args: - ... # value (Tensor): the value to be evaluated. - ... # loc (Tensor): the loc of distribution. Default: None. If `loc` is passed in as None, - ... # the mean of the underlying Normal distribution will be used. - ... # scale (Tensor): the scale of distribution. Default: None. If `scale` is passed in as None, - ... # the standard deviation of the underlying Normal distribution will be used. - ... - ... # Examples of `prob`. - ... # Similar calls can be made to other probability functions - ... # by replacing 'prob' by the name of the function. - ... ans = self.n1.prob(value) - ... # Evaluate with respect to distribution b. - ... ans = self.n1.prob(value, loc_b, scale_b) - ... # `loc` and `scale` must be passed in during function calls since they were not passed in construct. - ... ans = self.n2.prob(value, loc_a, scale_a) - ... - ... - ... # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments. - ... # Args: - ... # loc (Tensor): the loc of distribution. Default: None. If `loc` is passed in as None, - ... # the mean of the underlying Normal distribution will be used. - ... # scale (Tensor): the scale of distribution. Default: None. If `scale` is passed in as None, - ... # the standard deviation of the underlying Normal distribution will be used. - ... - ... # Example of `mean`. `sd`, `var`, and `entropy` are similar. - ... ans = self.n1.mean() # return 0.0 - ... ans = self.n1.mean(loc_b, scale_b) # return mean_b - ... # `loc` and `scale` must be passed in during function calls since they were not passed in construct. - ... ans = self.n2.mean(loc_a, scale_a) - ... - ... - ... # Interfaces of 'kl_loss' and 'cross_entropy' are the same: - ... # Args: - ... # dist (str): the type of the distributions. Only "Normal" is supported. - ... # loc_b (Tensor): the loc of distribution b. - ... # scale_b (Tensor): the scale distribution b. - ... # loc_a (Tensor): the loc of distribution a. Default: None. If `loc` is passed in as None, - ... # the mean of the underlying Normal distribution will be used. - ... # scale_a (Tensor): the scale distribution a. Default: None. If `scale` is passed in as None, - ... # the standard deviation of the underlying Normal distribution will be used. - ... - ... # Examples of `kl_loss`. `cross_entropy` is similar. - ... ans = self.n1.kl_loss('Normal', loc_b, scale_b) - ... ans = self.n1.kl_loss('Normal', loc_b, scale_b, loc_a, scale_a) - ... # Additional `loc` and `scale` must be passed in since they were not passed in construct. - ... ans = self.n2.kl_loss('Normal', loc_b, scale_b, loc_a, scale_a) - ... - ... # Examples of `sample`. - ... # Args: - ... # shape (tuple): the shape of the sample. Default: () - ... # loc (Tensor): the loc of the distribution. Default: None. If `loc` is passed in as None, - ... # the mean of the underlying Normal distribution will be used. - ... # scale (Tensor): the scale of the distribution. Default: None. If `scale` is passed in as None, - ... # the standard deviation of the underlying Normal distribution will be used. - ... ans = self.n1.sample() - ... ans = self.n1.sample((2,3)) - ... ans = self.n1.sample((2,3), loc_b, scale_b) - ... ans = self.n2.sample((2,3), loc_a, scale_a) + >>> # Private interfaces of probability functions corresponding to public interfaces, including + >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same + >>> # arguments as follows. + >>> # Args: + >>> # value (Tensor): the value to be evaluated. + >>> # loc (Tensor): the loc of distribution. Default: None. If `loc` is passed in as None, + >>> # the mean of the underlying Normal distribution will be used. + >>> # scale (Tensor): the scale of distribution. Default: None. If `scale` is passed in as None, + >>> # the standard deviation of the underlying Normal distribution will be used. + >>> # Examples of `prob`. + >>> # Similar calls can be made to other probability functions + >>> # by replacing 'prob' by the name of the function. + >>> ans = n1.prob(value) + >>> print(ans) + [0.07528435 0.04222769 0.02969363] + >>> # Evaluate with respect to distribution b. + >>> ans = n1.prob(value, loc_b, scale_b) + >>> print(ans) + [0.24197072 0.13022715 0.0664096 ] + >>> # `loc` and `scale` must be passed in during function calls since they were not passed in construct. + >>> ans = n2.prob(value, loc_a, scale_a) + >>> print(ans) + [0.12098535 0.08056299 0.06006904] + >>> # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments. + >>> # Args: + >>> # loc (Tensor): the loc of distribution. Default: None. If `loc` is passed in as None, + >>> # the mean of the underlying Normal distribution will be used. + >>> # scale (Tensor): the scale of distribution. Default: None. If `scale` is passed in as None, + >>> # the standard deviation of the underlying Normal distribution will be used. + >>> # Example of `mean`. `sd`, `var`, and `entropy` are similar. + >>> ans = n1.mean() + >>> print(ans) + 59874.14 + >>> ans = n1.mean(loc_b, scale_b) + >>> print(ans) + [ 4.481689 8.372897 20.085537] + >>> # `loc` and `scale` must be passed in during function calls since they were not passed in construct. + >>> ans = n2.mean(loc_a, scale_a) + >>> print(ans) + [54.59815 54.59815 54.59815] + >>> # Interfaces of 'kl_loss' and 'cross_entropy' are the same: + >>> # Args: + >>> # dist (str): the type of the distributions. Only "Normal" is supported. + >>> # loc_b (Tensor): the loc of distribution b. + >>> # scale_b (Tensor): the scale distribution b. + >>> # loc_a (Tensor): the loc of distribution a. Default: None. If `loc` is passed in as None, + >>> # the mean of the underlying Normal distribution will be used. + >>> # scale_a (Tensor): the scale distribution a. Default: None. If `scale` is passed in as None, + >>> # the standard deviation of the underlying Normal distribution will be used. + >>> # Examples of `kl_loss`. `cross_entropy` is similar. + >>> ans = n1.kl_loss('LogNormal', loc_b, scale_b) + >>> print(ans) + [8.113706 2.963615 1.3068528] + >>> ans = n1.kl_loss('LogNormal', loc_b, scale_b, loc_a, scale_a) + >>> print(ans) + [1.3068528 0.32342905 0.125 ] + >>> # Additional `loc` and `scale` must be passed in since they were not passed in construct. + >>> ans = n2.kl_loss('LogNormal', loc_b, scale_b, loc_a, scale_a) + >>> print(ans) + [1.3068528 0.32342905 0.125 ] + >>> # Examples of `sample`. + >>> # Args: + >>> # shape (tuple): the shape of the sample. Default: () + >>> # loc (Tensor): the loc of the distribution. Default: None. If `loc` is passed in as None, + >>> # the mean of the underlying Normal distribution will be used. + >>> # scale (Tensor): the scale of the distribution. Default: None. If `scale` is passed in as None, + >>> # the standard deviation of the underlying Normal distribution will be used. + >>> ans = n1.sample() + >>> print(ans.shape) + () + >>> ans = n1.sample((2,3)) + >>> print(ans.shape) + (2, 3) + >>> ans = n1.sample((2,3), loc_b, scale_b) + >>> print(ans.shape) + (2, 3, 3) + >>> ans = n2.sample((2,3), loc_a, scale_a) + >>> print(ans.shape) + (2, 3, 3) """ def __init__(self, @@ -154,6 +171,8 @@ class LogNormal(msd.TransformedDistribution): self.shape = P.Shape() self.sq = P.Square() self.sqrt = P.Sqrt() + self.cast = P.Cast() + self.squeeze = P.Squeeze(0) self.zeroslike = P.ZerosLike() @property @@ -221,6 +240,35 @@ class LogNormal(msd.TransformedDistribution): mean, sd = self._check_param_type(loc, scale) return mean + 0.5 + self.log(sd) + 0.5 * self.log_2pi + def _cdf(self, value, loc=None, scale=None): + r""" + Compute the cdf via the below formula, + where g is the exp bijector, + and P is the cdf of the underlying normal dist + .. math:: + Y = g(X) + P(Y <= a) = P(X <= g^{-1}(a)) + """ + mean, sd = self._check_param_type(loc, scale) + inverse_value = self.bijector("inverse", value) + return self.distribution("cdf", inverse_value, mean, sd) + + def _log_prob(self, value, loc=None, scale=None): + r""" + Compute the log prob via the below formula, + where g is the exp bijector, + and P is the pdf of the underlying normal dist + .. math:: + Y = g(X) + Py(a) = Px(g^{-1}(a)) * (g^{-1})'(a) + \log(Py(a)) = \log(Px(g^{-1}(a))) + \log((g^{-1})'(a)) + """ + mean, sd = self._check_param_type(loc, scale) + inverse_value = self.bijector("inverse", value) + unadjust_prob = self.distribution("log_prob", inverse_value, mean, sd) + log_jacobian = self.bijector("inverse_log_jacobian", value) + return unadjust_prob + log_jacobian + def _cross_entropy(self, dist, loc_b, scale_b, loc_a=None, scale_a=None): r""" Evaluate cross entropy between lognormal distributions. @@ -252,3 +300,20 @@ class LogNormal(msd.TransformedDistribution): """ check_distribution_name(dist, 'LogNormal') return self.distribution("kl_loss", 'Normal', loc_b, scale_b, loc_a, scale_a) + + def _sample(self, shape=(), loc=None, scale=None): + r""" + Generate samples via mapping the samples from the underlying normal dist. + """ + shape = self.checktuple(shape, 'shape') + mean, sd = self._check_param_type(loc, scale) + if shape == (): + sample_shape = (1,) + else: + sample_shape = shape + org_sample = self.distribution("sample", sample_shape, mean, sd) + org_sample = self.cast(org_sample, self.dtype) + value = self.bijector("forward", org_sample) + if shape == (): + value = self.squeeze(value) + return value diff --git a/mindspore/nn/probability/distribution/logistic.py b/mindspore/nn/probability/distribution/logistic.py index 5128850896..d5ad182b2e 100644 --- a/mindspore/nn/probability/distribution/logistic.py +++ b/mindspore/nn/probability/distribution/logistic.py @@ -40,63 +40,75 @@ class Logistic(Distribution): `dtype` must be a float type because Logistic distributions are continuous. Examples: - >>> # To initialize a Logistic distribution of loc 3.0 and scale 4.0. + >>> import mindspore + >>> import mindspore.nn as nn >>> import mindspore.nn.probability.distribution as msd - >>> n = msd.Logistic(3.0, 4.0, dtype=mstype.float32) - >>> - >>> # The following creates two independent Logistic distributions. - >>> n = msd.Logistic([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32) - >>> + >>> from mindspore import Tensor + >>> # To initialize a Logistic distribution of loc 3.0 and scale 4.0. + >>> l1 = msd.Logistic(3.0, 4.0, dtype=mindspore.float32) >>> # A Logistic distribution can be initialized without arguments. >>> # In this case, `loc` and `scale` must be passed in through arguments. - >>> n = msd.Logistic(dtype=mstype.float32) + >>> l2 = msd.Logistic(dtype=mindspore.float32) + >>> + >>> # Here are some tensors used below for testing + >>> value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32) + >>> loc_a = Tensor([2.0], dtype=mindspore.float32) + >>> scale_a = Tensor([2.0, 2.0, 2.0], dtype=mindspore.float32) + >>> loc_b = Tensor([1.0], dtype=mindspore.float32) + >>> scale_b = Tensor([1.0, 1.5, 2.0], dtype=mindspore.float32) >>> - >>> # To use a Normal distribution in a network. - >>> class net(Cell): - ... def __init__(self): - ... super(net, self).__init__(): - ... self.l1 = msd.Logistic(0.0, 1.0, dtype=mstype.float32) - ... self.l2 = msd.Logistic(dtype=mstype.float32) - ... - ... # The following calls are valid in construct. - ... def construct(self, value, loc_b, scale_b, loc_a, scale_a): - ... - ... # Private interfaces of probability functions corresponding to public interfaces, including - ... # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same arguments as follows. - ... # Args: - ... # value (Tensor): the value to be evaluated. - ... # loc (Tensor): the location of the distribution. Default: self.loc. - ... # scale (Tensor): the scale of the distribution. Default: self.scale. - ... - ... # Examples of `prob`. - ... # Similar calls can be made to other probability functions - ... # by replacing 'prob' by the name of the function - ... ans = self.l1.prob(value) - ... # Evaluate with respect to distribution b. - ... ans = self.l1.prob(value, loc_b, scale_b) - ... # `loc` and `scale` must be passed in during function calls - ... ans = self.l2.prob(value, loc_a, scale_a) - ... - ... # Functions `mean`, `mode`, `sd`, `var`, and `entropy` have the same arguments. - ... # Args: - ... # loc (Tensor): the location of the distribution. Default: self.loc. - ... # scale (Tensor): the scale of the distribution. Default: self.scale. - ... - ... # Example of `mean`. `mode`, `sd`, `var`, and `entropy` are similar. - ... ans = self.l1.mean() # return 0.0 - ... ans = self.l1.mean(loc_b, scale_b) # return loc_b - ... # `loc` and `scale` must be passed in during function calls. - ... ans = self.l2.mean(loc_a, scale_a) - ... - ... # Examples of `sample`. - ... # Args: - ... # shape (tuple): the shape of the sample. Default: () - ... # loc (Tensor): the location of the distribution. Default: self.loc. - ... # scale (Tensor): the scale of the distribution. Default: self.scale. - ... ans = self.l1.sample() - ... ans = self.l1.sample((2,3)) - ... ans = self.l1.sample((2,3), loc_b, scale_b) - ... ans = self.l2.sample((2,3), loc_a, scale_a) + >>> # Private interfaces of probability functions corresponding to public interfaces, including + >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same arguments as follows. + >>> # Args: + >>> # value (Tensor): the value to be evaluated. + >>> # loc (Tensor): the location of the distribution. Default: self.loc. + >>> # scale (Tensor): the scale of the distribution. Default: self.scale. + >>> # Examples of `prob`. + >>> # Similar calls can be made to other probability functions + >>> # by replacing 'prob' by the name of the function + >>> ans = l1.prob(value) + >>> print(ans) + [0.05875093 0.06153353 0.0625 ] + >>> # Evaluate with respect to distribution b. + >>> ans = l1.prob(value, loc_b, scale_b) + >>> print(ans) + [0.25 0.14943825 0.09830598] + >>> # `loc` and `scale` must be passed in during function calls + >>> ans = l1.prob(value, loc_a, scale_a) + >>> print(ans) + [0.11750185 0.125 0.11750185] + >>> # Functions `mean`, `mode`, `sd`, `var`, and `entropy` have the same arguments. + >>> # Args: + >>> # loc (Tensor): the location of the distribution. Default: self.loc. + >>> # scale (Tensor): the scale of the distribution. Default: self.scale. + >>> # Example of `mean`. `mode`, `sd`, `var`, and `entropy` are similar. + >>> ans = l1.mean() + >>> print(ans) + 3.0 + >>> ans = l1.mean(loc_b, scale_b) + >>> print(ans) + [1. 1. 1.] + >>> # `loc` and `scale` must be passed in during function calls. + >>> ans = l1.mean(loc_a, scale_a) + >>> print(ans) + [2. 2. 2.] + >>> # Examples of `sample`. + >>> # Args: + >>> # shape (tuple): the shape of the sample. Default: () + >>> # loc (Tensor): the location of the distribution. Default: self.loc. + >>> # scale (Tensor): the scale of the distribution. Default: self.scale. + >>> ans = l1.sample() + >>> print(ans.shape) + () + >>> ans = l1.sample((2,3)) + >>> print(ans.shape) + (2, 3) + >>> ans = l1.sample((2,3), loc_b, scale_b) + >>> print(ans.shape) + (2, 3, 3) + >>> ans = l1.sample((2,3), loc_a, scale_a) + >>> print(ans.shape) + (2, 3, 3) """ def __init__(self, diff --git a/mindspore/nn/probability/distribution/uniform.py b/mindspore/nn/probability/distribution/uniform.py index 1ea9c7ad41..192110a708 100644 --- a/mindspore/nn/probability/distribution/uniform.py +++ b/mindspore/nn/probability/distribution/uniform.py @@ -200,7 +200,7 @@ class Uniform(Distribution): self.checktensor(high, 'high') else: high = self.high - return high, low + return low, high def _range(self, low=None, high=None): r"""