Browse Source

Solve the problem with input of Nan in logrithm calculation

tags/v1.1.0
peixu_ren 5 years ago
parent
commit
3df62c759c
1 changed files with 5 additions and 1 deletions
  1. +5
    -1
      mindspore/nn/probability/distribution/transformed_distribution.py

+ 5
- 1
mindspore/nn/probability/distribution/transformed_distribution.py View File

@@ -84,6 +84,7 @@ class TransformedDistribution(Distribution):
self.parameter_names = distribution.parameter_names
self.exp = exp_generic
self.log = log_generic
self.isnan = P.IsNan()
self.equal_base = P.Equal()
self.select_base = P.Select()

@@ -132,7 +133,10 @@ class TransformedDistribution(Distribution):
unadjust_prob = self.distribution("log_prob", inverse_value, *args, **kwargs)
log_jacobian = self.bijector("inverse_log_jacobian", value)
isneginf = self.equal_base(unadjust_prob, -np.inf)
return self.select_base(isneginf, unadjust_prob, unadjust_prob + log_jacobian)
isnan = self.equal_base(unadjust_prob + log_jacobian, np.nan)
return self.select_base(isneginf,
self.select_base(isnan, unadjust_prob + log_jacobian, unadjust_prob),
unadjust_prob + log_jacobian)

def _prob(self, value, *args, **kwargs):
return self.exp(self._log_prob(value, *args, **kwargs))


Loading…
Cancel
Save