Browse Source

!7051 Fixed zero plus neg_inf issue under fp16

Merge pull request !7051 from XunDeng/pp_issue_branch
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
a75b3161e1
1 changed files with 6 additions and 1 deletions
  1. +6
    -1
      mindspore/nn/probability/distribution/transformed_distribution.py

+ 6
- 1
mindspore/nn/probability/distribution/transformed_distribution.py View File

@@ -13,7 +13,9 @@
# limitations under the License.
# ============================================================================
"""Transformed Distribution"""
import numpy as np
from mindspore._checkparam import Validator as validator
from mindspore.ops import operations as P
import mindspore.nn as nn
from .distribution import Distribution
from ._utils.utils import raise_not_impl_error
@@ -80,6 +82,8 @@ class TransformedDistribution(Distribution):
self.parameter_names = distribution.parameter_names
self.exp = exp_generic
self.log = log_generic
self.equal_base = P.Equal()
self.select_base = P.Select()

@property
def bijector(self):
@@ -125,7 +129,8 @@ class TransformedDistribution(Distribution):
inverse_value = self.bijector("inverse", value)
unadjust_prob = self.distribution("log_prob", inverse_value, *args, **kwargs)
log_jacobian = self.bijector("inverse_log_jacobian", value)
return unadjust_prob + log_jacobian
isneginf = self.equal_base(unadjust_prob, -np.inf)
return self.select_base(isneginf, unadjust_prob, unadjust_prob + log_jacobian)

def _prob(self, value, *args, **kwargs):
return self.exp(self._log_prob(value, *args, **kwargs))


Loading…
Cancel
Save