|
|
|
@@ -18,7 +18,7 @@ from mindspore.ops import operations as P |
|
|
|
from mindspore.ops import composite as C |
|
|
|
from mindspore._checkparam import Validator |
|
|
|
from .distribution import Distribution |
|
|
|
from ._utils.utils import check_prob, check_distribution_name |
|
|
|
from ._utils.utils import check_prob, check_distribution_name, clamp_probs |
|
|
|
from ._utils.custom_ops import exp_generic, log_generic |
|
|
|
|
|
|
|
|
|
|
|
@@ -86,7 +86,6 @@ class Bernoulli(Distribution): |
|
|
|
>>> ans = b2.mean(probs_a) |
|
|
|
>>> print(ans.shape) |
|
|
|
(1,) |
|
|
|
>>> print(ans.shape) |
|
|
|
>>> # Interfaces of `kl_loss` and `cross_entropy` are the same as follows: |
|
|
|
>>> # Args: |
|
|
|
>>> # dist (str): the name of the distribution. Only 'Bernoulli' is supported. |
|
|
|
@@ -132,7 +131,8 @@ class Bernoulli(Distribution): |
|
|
|
param = dict(locals()) |
|
|
|
param['param_dict'] = {'probs': probs} |
|
|
|
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type |
|
|
|
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) |
|
|
|
Validator.check_type_name( |
|
|
|
"dtype", dtype, valid_dtype, type(self).__name__) |
|
|
|
super(Bernoulli, self).__init__(seed, dtype, name, param) |
|
|
|
|
|
|
|
self._probs = self._add_parameter(probs, 'probs') |
|
|
|
@@ -241,6 +241,9 @@ class Bernoulli(Distribution): |
|
|
|
value = self._check_value(value, 'value') |
|
|
|
value = self.cast(value, self.parameter_type) |
|
|
|
probs1 = self._check_param_type(probs1) |
|
|
|
|
|
|
|
# clamp value for numerical stability |
|
|
|
probs1 = clamp_probs(probs1) |
|
|
|
probs0 = 1.0 - probs1 |
|
|
|
return self.log(probs1) * value + self.log(probs0) * (1.0 - value) |
|
|
|
|
|
|
|
@@ -266,8 +269,10 @@ class Bernoulli(Distribution): |
|
|
|
probs0 = self.broadcast((1.0 - probs1), broadcast_shape_tensor) |
|
|
|
comp_zero = self.less(value, 0.0) |
|
|
|
comp_one = self.less(value, 1.0) |
|
|
|
zeros = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 0.0) |
|
|
|
ones = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 1.0) |
|
|
|
zeros = self.fill(self.parameter_type, self.shape( |
|
|
|
broadcast_shape_tensor), 0.0) |
|
|
|
ones = self.fill(self.parameter_type, self.shape( |
|
|
|
broadcast_shape_tensor), 1.0) |
|
|
|
less_than_zero = self.select(comp_zero, zeros, probs0) |
|
|
|
return self.select(comp_one, less_than_zero, ones) |
|
|
|
|
|
|
|
|