|
|
|
@@ -104,7 +104,7 @@ class TransformedDistribution(Distribution): |
|
|
|
return self.exp(self._log_prob(*args, **kwargs)) |
|
|
|
|
|
|
|
def _sample(self, *args, **kwargs): |
|
|
|
org_sample = self.distribution("sample", shape) |
|
|
|
org_sample = self.distribution("sample", *args, **kwargs) |
|
|
|
return self.bijector("forward", org_sample) |
|
|
|
|
|
|
|
def _mean(self, *args, **kwargs): |
|
|
|
@@ -113,6 +113,6 @@ class TransformedDistribution(Distribution): |
|
|
|
This function maybe overridden by derived class. |
|
|
|
""" |
|
|
|
if not self.is_linear_transformation: |
|
|
|
raise_not_impl_error(mean) |
|
|
|
raise_not_impl_error("mean") |
|
|
|
|
|
|
|
return self.bijector("forward", self.distribution("mean")) |