| @@ -370,7 +370,7 @@ class CheckTensor(PrimitiveWithInfer): | |||||
| def __call__(self, x, name): | def __call__(self, x, name): | ||||
| if isinstance(x, Tensor): | if isinstance(x, Tensor): | ||||
| return x | return x | ||||
| raise TypeError(f"For {name}, input type should be a Tensor.") | |||||
| raise TypeError(f"For {name}, input type should be a Tensor or Parameter.") | |||||
| def common_dtype(arg_a, name_a, arg_b, name_b, hint_type): | def common_dtype(arg_a, name_a, arg_b, name_b, hint_type): | ||||
| """ | """ | ||||
| @@ -186,8 +186,8 @@ class Bernoulli(Distribution): | |||||
| H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1) | H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1) | ||||
| """ | """ | ||||
| probs1 = self._check_param(probs1) | probs1 = self._check_param(probs1) | ||||
| probs0 = 1 - probs1 | |||||
| return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1)) | |||||
| probs0 = 1.0 - probs1 | |||||
| return -(probs0 * self.log(probs0)) - (probs1 * self.log(probs1)) | |||||
| def _cross_entropy(self, dist, probs1_b, probs1=None): | def _cross_entropy(self, dist, probs1_b, probs1=None): | ||||
| """ | """ | ||||
| @@ -27,7 +27,7 @@ class Distribution(Cell): | |||||
| Args: | Args: | ||||
| seed (int): random seed used in sampling. | seed (int): random seed used in sampling. | ||||
| dtype (mindspore.dtype): type of the distribution. | |||||
| dtype (mindspore.dtype): the type of the event samples. Default: subclass dtype. | |||||
| name (str): Python str name prefixed to Ops created by this class. Default: subclass name. | name (str): Python str name prefixed to Ops created by this class. Default: subclass name. | ||||
| param (dict): parameters used to initialize the distribution. | param (dict): parameters used to initialize the distribution. | ||||