|
|
|
@@ -84,6 +84,7 @@ class TransformedDistribution(Distribution): |
|
|
|
self.parameter_names = distribution.parameter_names |
|
|
|
self.exp = exp_generic |
|
|
|
self.log = log_generic |
|
|
|
self.isnan = P.IsNan() |
|
|
|
self.equal_base = P.Equal() |
|
|
|
self.select_base = P.Select() |
|
|
|
|
|
|
|
@@ -132,7 +133,10 @@ class TransformedDistribution(Distribution): |
|
|
|
unadjust_prob = self.distribution("log_prob", inverse_value, *args, **kwargs) |
|
|
|
log_jacobian = self.bijector("inverse_log_jacobian", value) |
|
|
|
isneginf = self.equal_base(unadjust_prob, -np.inf) |
|
|
|
return self.select_base(isneginf, unadjust_prob, unadjust_prob + log_jacobian) |
|
|
|
isnan = self.equal_base(unadjust_prob + log_jacobian, np.nan) |
|
|
|
return self.select_base(isneginf, |
|
|
|
self.select_base(isnan, unadjust_prob + log_jacobian, unadjust_prob), |
|
|
|
unadjust_prob + log_jacobian) |
|
|
|
|
|
|
|
def _prob(self, value, *args, **kwargs): |
|
|
|
return self.exp(self._log_prob(value, *args, **kwargs)) |
|
|
|
|