Browse Source

!5602 Fix log prob and survival function of exponential distribution

Merge pull request !5602 from XunDeng/pp_issue_branch
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
1d0e0ae27c
1 changed files with 27 additions and 5 deletions
  1. +27
    -5
      mindspore/nn/probability/distribution/exponential.py

+ 27
- 5
mindspore/nn/probability/distribution/exponential.py View File

@@ -198,9 +198,9 @@ class Exponential(Distribution):
return self._entropy(rate) + self._kl_loss(dist, rate_b, rate)


def _prob(self, value, rate=None):
def _log_prob(self, value, rate=None):
r"""
pdf of Exponential distribution.
log_pdf of Exponential distribution.

Args:
Args:
@@ -211,15 +211,16 @@ class Exponential(Distribution):
Value should be greater or equal to zero.

.. math::
pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0
log_pdf(x) = \log(rate) - rate * x if x >= 0 else 0
"""
value = self._check_value(value, "value")
value = self.cast(value, self.dtype)
rate = self._check_param(rate)
prob = self.exp(self.log(rate) - rate * value)
prob = self.log(rate) - rate * value
zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
neginf = self.fill(self.dtypeop(prob), self.shape(prob), -np.inf)
comp = self.less(value, zeros)
return self.select(comp, zeros, prob)
return self.select(comp, neginf, prob)

def _cdf(self, value, rate=None):
r"""
@@ -243,6 +244,27 @@ class Exponential(Distribution):
comp = self.less(value, zeros)
return self.select(comp, zeros, cdf)

def _log_survival(self, value, rate=None):
r"""
log survival_function of Exponential distribution.

Args:
value (Tensor): value to be evaluated.
rate (Tensor): rate of the distribution. Default: self.rate.

Note:
Value should be greater or equal to zero.

.. math::
log_survival_function(x) = -1 * \lambda * x if x >= 0 else 0
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
rate = self._check_param(rate)
sf = -1. * rate * value
zeros = self.fill(self.dtypeop(sf), self.shape(sf), 0.0)
comp = self.less(value, zeros)
return self.select(comp, zeros, sf)

def _kl_loss(self, dist, rate_b, rate=None):
"""


Loading…
Cancel
Save