|
|
|
@@ -13,7 +13,9 @@ |
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================ |
|
|
|
"""Transformed Distribution""" |
|
|
|
import numpy as np |
|
|
|
from mindspore._checkparam import Validator as validator |
|
|
|
from mindspore.ops import operations as P |
|
|
|
import mindspore.nn as nn |
|
|
|
from .distribution import Distribution |
|
|
|
from ._utils.utils import raise_not_impl_error |
|
|
|
@@ -80,6 +82,8 @@ class TransformedDistribution(Distribution): |
|
|
|
self.parameter_names = distribution.parameter_names |
|
|
|
self.exp = exp_generic |
|
|
|
self.log = log_generic |
|
|
|
self.equal_base = P.Equal() |
|
|
|
self.select_base = P.Select() |
|
|
|
|
|
|
|
@property |
|
|
|
def bijector(self): |
|
|
|
@@ -125,7 +129,8 @@ class TransformedDistribution(Distribution): |
|
|
|
inverse_value = self.bijector("inverse", value) |
|
|
|
unadjust_prob = self.distribution("log_prob", inverse_value, *args, **kwargs) |
|
|
|
log_jacobian = self.bijector("inverse_log_jacobian", value) |
|
|
|
return unadjust_prob + log_jacobian |
|
|
|
isneginf = self.equal_base(unadjust_prob, -np.inf) |
|
|
|
return self.select_base(isneginf, unadjust_prob, unadjust_prob + log_jacobian) |
|
|
|
|
|
|
|
def _prob(self, value, *args, **kwargs): |
|
|
|
return self.exp(self._log_prob(value, *args, **kwargs)) |
|
|
|
|