|
|
|
@@ -14,10 +14,9 @@ |
|
|
|
# ============================================================================ |
|
|
|
"""Transformed Distribution""" |
|
|
|
from mindspore._checkparam import Validator as validator |
|
|
|
from mindspore.common import dtype as mstype |
|
|
|
import mindspore.nn as nn |
|
|
|
from .distribution import Distribution |
|
|
|
from ._utils.utils import check_type, raise_not_impl_error |
|
|
|
from ._utils.utils import raise_not_impl_error |
|
|
|
from ._utils.custom_ops import exp_generic, log_generic |
|
|
|
|
|
|
|
|
|
|
|
@@ -30,7 +29,6 @@ class TransformedDistribution(Distribution): |
|
|
|
Args: |
|
|
|
bijector (Bijector): The transformation to perform. |
|
|
|
distribution (Distribution): The original distribution. |
|
|
|
dtype (mindspore.dtype): The type of the event samples. |
|
|
|
seed (int): The seed is used in sampling. The global seed is used if it is None. |
|
|
|
name (str): The name of the transformed distribution. Default: 'transformed_distribution'. |
|
|
|
|
|
|
|
@@ -45,16 +43,14 @@ class TransformedDistribution(Distribution): |
|
|
|
>>> import mindspore.nn.probability.distribution as msd |
|
|
|
>>> import mindspore.nn.probability.bijector as msb |
|
|
|
>>> ln = msd.TransformedDistribution(msb.Exp(), |
|
|
|
>>> msd.Normal(0.0, 1.0, dtype=mstype.float32), |
|
|
|
>>> dtype=mstype.float32) |
|
|
|
>>> msd.Normal(0.0, 1.0, dtype=mstype.float32)) |
|
|
|
>>> |
|
|
|
>>> # To use a transformed distribution in a network. |
|
|
|
>>> class net(Cell): |
|
|
|
>>> def __init__(self): |
|
|
|
>>> super(net, self).__init__(): |
|
|
|
>>> self.ln = msd.TransformedDistribution(msb.Exp(), |
|
|
|
>>> msd.Normal(0.0, 1.0, dtype=mstype.float32), |
|
|
|
>>> dtype=mstype.float32) |
|
|
|
>>> msd.Normal(0.0, 1.0, dtype=mstype.float32)) |
|
|
|
>>> |
|
|
|
>>> def construct(self, value): |
|
|
|
>>> # Similar calls can be made to other functions |
|
|
|
@@ -65,7 +61,6 @@ class TransformedDistribution(Distribution): |
|
|
|
def __init__(self, |
|
|
|
bijector, |
|
|
|
distribution, |
|
|
|
dtype, |
|
|
|
seed=None, |
|
|
|
name="transformed_distribution"): |
|
|
|
""" |
|
|
|
@@ -76,9 +71,7 @@ class TransformedDistribution(Distribution): |
|
|
|
[nn.probability.bijector.Bijector], type(self).__name__) |
|
|
|
validator.check_value_type('distribution', distribution, |
|
|
|
[Distribution], type(self).__name__) |
|
|
|
valid_dtype = mstype.number_type |
|
|
|
check_type(dtype, valid_dtype, type(self).__name__) |
|
|
|
super(TransformedDistribution, self).__init__(seed, dtype, name, param) |
|
|
|
super(TransformedDistribution, self).__init__(seed, distribution.dtype, name, param) |
|
|
|
|
|
|
|
self._bijector = bijector |
|
|
|
self._distribution = distribution |
|
|
|
@@ -96,6 +89,10 @@ class TransformedDistribution(Distribution): |
|
|
|
def distribution(self): |
|
|
|
return self._distribution |
|
|
|
|
|
|
|
@property |
|
|
|
def dtype(self): |
|
|
|
return self.distribution.dtype |
|
|
|
|
|
|
|
@property |
|
|
|
def is_linear_transformation(self): |
|
|
|
return self._is_linear_transformation |
|
|
|
|