Browse Source

Fix errors in transformed_distribution

tags/v0.7.0-beta
peixu_ren 5 years ago
parent
commit
98e2a48e4c
1 changed files with 2 additions and 2 deletions
  1. +2
    -2
      mindspore/nn/probability/distribution/transformed_distribution.py

+ 2
- 2
mindspore/nn/probability/distribution/transformed_distribution.py View File

@@ -104,7 +104,7 @@ class TransformedDistribution(Distribution):
return self.exp(self._log_prob(*args, **kwargs))

def _sample(self, *args, **kwargs):
org_sample = self.distribution("sample", shape)
org_sample = self.distribution("sample", *args, **kwargs)
return self.bijector("forward", org_sample)

def _mean(self, *args, **kwargs):
@@ -113,6 +113,6 @@ class TransformedDistribution(Distribution):
This function maybe overridden by derived class.
"""
if not self.is_linear_transformation:
raise_not_impl_error(mean)
raise_not_impl_error("mean")

return self.bijector("forward", self.distribution("mean"))

Loading…
Cancel
Save