diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py index 729c42634d..2908f33b6b 100644 --- a/mindspore/nn/probability/distribution/_utils/utils.py +++ b/mindspore/nn/probability/distribution/_utils/utils.py @@ -370,7 +370,7 @@ class CheckTensor(PrimitiveWithInfer): def __call__(self, x, name): if isinstance(x, Tensor): return x - raise TypeError(f"For {name}, input type should be a Tensor.") + raise TypeError(f"For {name}, input type should be a Tensor or Parameter.") def common_dtype(arg_a, name_a, arg_b, name_b, hint_type): """ diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py index 0dcbc59689..7fce4b7802 100644 --- a/mindspore/nn/probability/distribution/bernoulli.py +++ b/mindspore/nn/probability/distribution/bernoulli.py @@ -186,8 +186,8 @@ class Bernoulli(Distribution): H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1) """ probs1 = self._check_param(probs1) - probs0 = 1 - probs1 - return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1)) + probs0 = 1.0 - probs1 + return -(probs0 * self.log(probs0)) - (probs1 * self.log(probs1)) def _cross_entropy(self, dist, probs1_b, probs1=None): """ diff --git a/mindspore/nn/probability/distribution/distribution.py b/mindspore/nn/probability/distribution/distribution.py index dcb904aeac..943b022057 100644 --- a/mindspore/nn/probability/distribution/distribution.py +++ b/mindspore/nn/probability/distribution/distribution.py @@ -27,7 +27,7 @@ class Distribution(Cell): Args: seed (int): random seed used in sampling. - dtype (mindspore.dtype): type of the distribution. + dtype (mindspore.dtype): the type of the event samples. Default: subclass dtype. name (str): Python str name prefixed to Ops created by this class. Default: subclass name. param (dict): parameters used to initialize the distribution.