diff --git a/mindspore/nn/probability/distribution/transformed_distribution.py b/mindspore/nn/probability/distribution/transformed_distribution.py index 5d06614017..b1fa0fc9c2 100644 --- a/mindspore/nn/probability/distribution/transformed_distribution.py +++ b/mindspore/nn/probability/distribution/transformed_distribution.py @@ -84,6 +84,7 @@ class TransformedDistribution(Distribution): self.parameter_names = distribution.parameter_names self.exp = exp_generic self.log = log_generic + self.isnan = P.IsNan() self.equal_base = P.Equal() self.select_base = P.Select() @@ -132,7 +133,10 @@ class TransformedDistribution(Distribution): unadjust_prob = self.distribution("log_prob", inverse_value, *args, **kwargs) log_jacobian = self.bijector("inverse_log_jacobian", value) isneginf = self.equal_base(unadjust_prob, -np.inf) - return self.select_base(isneginf, unadjust_prob, unadjust_prob + log_jacobian) + isnan = self.equal_base(unadjust_prob + log_jacobian, np.nan) + return self.select_base(isneginf, + self.select_base(isnan, unadjust_prob + log_jacobian, unadjust_prob), + unadjust_prob + log_jacobian) def _prob(self, value, *args, **kwargs): return self.exp(self._log_prob(value, *args, **kwargs))