diff --git a/mindspore/nn/probability/bijector/exp.py b/mindspore/nn/probability/bijector/exp.py index 13114cd6ac..96a349e897 100644 --- a/mindspore/nn/probability/bijector/exp.py +++ b/mindspore/nn/probability/bijector/exp.py @@ -49,5 +49,4 @@ class Exp(PowerTransform): def __init__(self, name='Exp'): - param = dict(locals()) - super(Exp, self).__init__(name=name, param=param) + super(Exp, self).__init__(name=name) diff --git a/mindspore/nn/probability/bijector/power_transform.py b/mindspore/nn/probability/bijector/power_transform.py index 3242ce6e9b..dd1b017c70 100644 --- a/mindspore/nn/probability/bijector/power_transform.py +++ b/mindspore/nn/probability/bijector/power_transform.py @@ -39,9 +39,6 @@ class PowerTransform(Bijector): Args: power (int or float): The scale factor. Default: 0. name (str): The name of the bijector. Default: 'PowerTransform'. - param (dict): The parameters used to initialize the bijector. These parameters are only used when other - Bijectors inherit from powertransform to pass in parameters. In this case the derived Bijector may overwrite - the argument `param`. Default: None. Examples: >>> # To initialize a PowerTransform bijector of power 0.5. @@ -65,9 +62,8 @@ class PowerTransform(Bijector): def __init__(self, power=0, - name='PowerTransform', - param=None): - param = dict(locals()) if param is None else param + name='PowerTransform'): + param = dict(locals()) super(PowerTransform, self).__init__(name=name, param=param) validator.check_value_type('power', power, [int, float], self.name) validator.check_number("power", power, 0, Rel.GE, self.name) diff --git a/mindspore/nn/probability/distribution/log_normal.py b/mindspore/nn/probability/distribution/log_normal.py index f3d1eebc7b..69d4059a34 100644 --- a/mindspore/nn/probability/distribution/log_normal.py +++ b/mindspore/nn/probability/distribution/log_normal.py @@ -135,7 +135,7 @@ class LogNormal(msd.TransformedDistribution): """ super(LogNormal, self).__init__(distribution=msd.Normal(loc, scale, dtype=dtype), bijector=msb.Exp(), - dtype=dtype, seed=seed, name=name) + seed=seed, name=name) self.log_2pi = np.log(2 * np.pi) diff --git a/mindspore/nn/probability/distribution/transformed_distribution.py b/mindspore/nn/probability/distribution/transformed_distribution.py index a4a91d0832..cab8f2662f 100644 --- a/mindspore/nn/probability/distribution/transformed_distribution.py +++ b/mindspore/nn/probability/distribution/transformed_distribution.py @@ -14,10 +14,9 @@ # ============================================================================ """Transformed Distribution""" from mindspore._checkparam import Validator as validator -from mindspore.common import dtype as mstype import mindspore.nn as nn from .distribution import Distribution -from ._utils.utils import check_type, raise_not_impl_error +from ._utils.utils import raise_not_impl_error from ._utils.custom_ops import exp_generic, log_generic @@ -30,7 +29,6 @@ class TransformedDistribution(Distribution): Args: bijector (Bijector): The transformation to perform. distribution (Distribution): The original distribution. - dtype (mindspore.dtype): The type of the event samples. seed (int): The seed is used in sampling. The global seed is used if it is None. name (str): The name of the transformed distribution. Default: 'transformed_distribution'. @@ -45,16 +43,14 @@ class TransformedDistribution(Distribution): >>> import mindspore.nn.probability.distribution as msd >>> import mindspore.nn.probability.bijector as msb >>> ln = msd.TransformedDistribution(msb.Exp(), - >>> msd.Normal(0.0, 1.0, dtype=mstype.float32), - >>> dtype=mstype.float32) + >>> msd.Normal(0.0, 1.0, dtype=mstype.float32)) >>> >>> # To use a transformed distribution in a network. >>> class net(Cell): >>> def __init__(self): >>> super(net, self).__init__(): >>> self.ln = msd.TransformedDistribution(msb.Exp(), - >>> msd.Normal(0.0, 1.0, dtype=mstype.float32), - >>> dtype=mstype.float32) + >>> msd.Normal(0.0, 1.0, dtype=mstype.float32)) >>> >>> def construct(self, value): >>> # Similar calls can be made to other functions @@ -65,7 +61,6 @@ class TransformedDistribution(Distribution): def __init__(self, bijector, distribution, - dtype, seed=None, name="transformed_distribution"): """ @@ -76,9 +71,7 @@ class TransformedDistribution(Distribution): [nn.probability.bijector.Bijector], type(self).__name__) validator.check_value_type('distribution', distribution, [Distribution], type(self).__name__) - valid_dtype = mstype.number_type - check_type(dtype, valid_dtype, type(self).__name__) - super(TransformedDistribution, self).__init__(seed, dtype, name, param) + super(TransformedDistribution, self).__init__(seed, distribution.dtype, name, param) self._bijector = bijector self._distribution = distribution @@ -96,6 +89,10 @@ class TransformedDistribution(Distribution): def distribution(self): return self._distribution + @property + def dtype(self): + return self.distribution.dtype + @property def is_linear_transformation(self): return self._is_linear_transformation