From 3df62c759c67f53a2d8d24e2d58f0282f8e96ecc Mon Sep 17 00:00:00 2001 From: peixu_ren Date: Thu, 15 Oct 2020 16:15:14 -0400 Subject: [PATCH] Solve the problem with input of Nan in logrithm calculation --- .../nn/probability/distribution/transformed_distribution.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mindspore/nn/probability/distribution/transformed_distribution.py b/mindspore/nn/probability/distribution/transformed_distribution.py index 5d06614017..b1fa0fc9c2 100644 --- a/mindspore/nn/probability/distribution/transformed_distribution.py +++ b/mindspore/nn/probability/distribution/transformed_distribution.py @@ -84,6 +84,7 @@ class TransformedDistribution(Distribution): self.parameter_names = distribution.parameter_names self.exp = exp_generic self.log = log_generic + self.isnan = P.IsNan() self.equal_base = P.Equal() self.select_base = P.Select() @@ -132,7 +133,10 @@ class TransformedDistribution(Distribution): unadjust_prob = self.distribution("log_prob", inverse_value, *args, **kwargs) log_jacobian = self.bijector("inverse_log_jacobian", value) isneginf = self.equal_base(unadjust_prob, -np.inf) - return self.select_base(isneginf, unadjust_prob, unadjust_prob + log_jacobian) + isnan = self.equal_base(unadjust_prob + log_jacobian, np.nan) + return self.select_base(isneginf, + self.select_base(isnan, unadjust_prob + log_jacobian, unadjust_prob), + unadjust_prob + log_jacobian) def _prob(self, value, *args, **kwargs): return self.exp(self._log_prob(value, *args, **kwargs))