Browse Source

!7413 Solve the problem with input of Nan in logrithm calculation

Merge pull request !7413 from peixu_ren/custom_bijector
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
c1285335e2
1 changed files with 5 additions and 1 deletions
  1. +5
    -1
      mindspore/nn/probability/distribution/transformed_distribution.py

+ 5
- 1
mindspore/nn/probability/distribution/transformed_distribution.py View File

@@ -84,6 +84,7 @@ class TransformedDistribution(Distribution):
self.parameter_names = distribution.parameter_names self.parameter_names = distribution.parameter_names
self.exp = exp_generic self.exp = exp_generic
self.log = log_generic self.log = log_generic
self.isnan = P.IsNan()
self.equal_base = P.Equal() self.equal_base = P.Equal()
self.select_base = P.Select() self.select_base = P.Select()


@@ -132,7 +133,10 @@ class TransformedDistribution(Distribution):
unadjust_prob = self.distribution("log_prob", inverse_value, *args, **kwargs) unadjust_prob = self.distribution("log_prob", inverse_value, *args, **kwargs)
log_jacobian = self.bijector("inverse_log_jacobian", value) log_jacobian = self.bijector("inverse_log_jacobian", value)
isneginf = self.equal_base(unadjust_prob, -np.inf) isneginf = self.equal_base(unadjust_prob, -np.inf)
return self.select_base(isneginf, unadjust_prob, unadjust_prob + log_jacobian)
isnan = self.equal_base(unadjust_prob + log_jacobian, np.nan)
return self.select_base(isneginf,
self.select_base(isnan, unadjust_prob + log_jacobian, unadjust_prob),
unadjust_prob + log_jacobian)


def _prob(self, value, *args, **kwargs): def _prob(self, value, *args, **kwargs):
return self.exp(self._log_prob(value, *args, **kwargs)) return self.exp(self._log_prob(value, *args, **kwargs))


Loading…
Cancel
Save