| @@ -18,7 +18,7 @@ from mindspore.ops import operations as P | |||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore._checkparam import Validator | from mindspore._checkparam import Validator | ||||
| from .distribution import Distribution | from .distribution import Distribution | ||||
| from ._utils.utils import check_prob, check_distribution_name | |||||
| from ._utils.utils import check_prob, check_distribution_name, clamp_probs | |||||
| from ._utils.custom_ops import exp_generic, log_generic | from ._utils.custom_ops import exp_generic, log_generic | ||||
| @@ -86,7 +86,6 @@ class Bernoulli(Distribution): | |||||
| >>> ans = b2.mean(probs_a) | >>> ans = b2.mean(probs_a) | ||||
| >>> print(ans.shape) | >>> print(ans.shape) | ||||
| (1,) | (1,) | ||||
| >>> print(ans.shape) | |||||
| >>> # Interfaces of `kl_loss` and `cross_entropy` are the same as follows: | >>> # Interfaces of `kl_loss` and `cross_entropy` are the same as follows: | ||||
| >>> # Args: | >>> # Args: | ||||
| >>> # dist (str): the name of the distribution. Only 'Bernoulli' is supported. | >>> # dist (str): the name of the distribution. Only 'Bernoulli' is supported. | ||||
| @@ -132,7 +131,8 @@ class Bernoulli(Distribution): | |||||
| param = dict(locals()) | param = dict(locals()) | ||||
| param['param_dict'] = {'probs': probs} | param['param_dict'] = {'probs': probs} | ||||
| valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | ||||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||||
| Validator.check_type_name( | |||||
| "dtype", dtype, valid_dtype, type(self).__name__) | |||||
| super(Bernoulli, self).__init__(seed, dtype, name, param) | super(Bernoulli, self).__init__(seed, dtype, name, param) | ||||
| self._probs = self._add_parameter(probs, 'probs') | self._probs = self._add_parameter(probs, 'probs') | ||||
| @@ -241,6 +241,9 @@ class Bernoulli(Distribution): | |||||
| value = self._check_value(value, 'value') | value = self._check_value(value, 'value') | ||||
| value = self.cast(value, self.parameter_type) | value = self.cast(value, self.parameter_type) | ||||
| probs1 = self._check_param_type(probs1) | probs1 = self._check_param_type(probs1) | ||||
| # clamp value for numerical stability | |||||
| probs1 = clamp_probs(probs1) | |||||
| probs0 = 1.0 - probs1 | probs0 = 1.0 - probs1 | ||||
| return self.log(probs1) * value + self.log(probs0) * (1.0 - value) | return self.log(probs1) * value + self.log(probs0) * (1.0 - value) | ||||
| @@ -266,8 +269,10 @@ class Bernoulli(Distribution): | |||||
| probs0 = self.broadcast((1.0 - probs1), broadcast_shape_tensor) | probs0 = self.broadcast((1.0 - probs1), broadcast_shape_tensor) | ||||
| comp_zero = self.less(value, 0.0) | comp_zero = self.less(value, 0.0) | ||||
| comp_one = self.less(value, 1.0) | comp_one = self.less(value, 1.0) | ||||
| zeros = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 0.0) | |||||
| ones = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 1.0) | |||||
| zeros = self.fill(self.parameter_type, self.shape( | |||||
| broadcast_shape_tensor), 0.0) | |||||
| ones = self.fill(self.parameter_type, self.shape( | |||||
| broadcast_shape_tensor), 1.0) | |||||
| less_than_zero = self.select(comp_zero, zeros, probs0) | less_than_zero = self.select(comp_zero, zeros, probs0) | ||||
| return self.select(comp_one, less_than_zero, ones) | return self.select(comp_one, less_than_zero, ones) | ||||