Browse Source

Remove the redundant parameters from bijectors and transformed distribution

tags/v1.1.0
peixu_ren 5 years ago
parent
commit
c7563d53bf
4 changed files with 12 additions and 20 deletions
  1. +1
    -2
      mindspore/nn/probability/bijector/exp.py
  2. +2
    -6
      mindspore/nn/probability/bijector/power_transform.py
  3. +1
    -1
      mindspore/nn/probability/distribution/log_normal.py
  4. +8
    -11
      mindspore/nn/probability/distribution/transformed_distribution.py

+ 1
- 2
mindspore/nn/probability/bijector/exp.py View File

@@ -49,5 +49,4 @@ class Exp(PowerTransform):

def __init__(self,
name='Exp'):
param = dict(locals())
super(Exp, self).__init__(name=name, param=param)
super(Exp, self).__init__(name=name)

+ 2
- 6
mindspore/nn/probability/bijector/power_transform.py View File

@@ -39,9 +39,6 @@ class PowerTransform(Bijector):
Args:
power (int or float): The scale factor. Default: 0.
name (str): The name of the bijector. Default: 'PowerTransform'.
param (dict): The parameters used to initialize the bijector. These parameters are only used when other
Bijectors inherit from powertransform to pass in parameters. In this case the derived Bijector may overwrite
the argument `param`. Default: None.

Examples:
>>> # To initialize a PowerTransform bijector of power 0.5.
@@ -65,9 +62,8 @@ class PowerTransform(Bijector):

def __init__(self,
power=0,
name='PowerTransform',
param=None):
param = dict(locals()) if param is None else param
name='PowerTransform'):
param = dict(locals())
super(PowerTransform, self).__init__(name=name, param=param)
validator.check_value_type('power', power, [int, float], self.name)
validator.check_number("power", power, 0, Rel.GE, self.name)


+ 1
- 1
mindspore/nn/probability/distribution/log_normal.py View File

@@ -135,7 +135,7 @@ class LogNormal(msd.TransformedDistribution):
"""
super(LogNormal, self).__init__(distribution=msd.Normal(loc, scale, dtype=dtype),
bijector=msb.Exp(),
dtype=dtype, seed=seed, name=name)
seed=seed, name=name)

self.log_2pi = np.log(2 * np.pi)



+ 8
- 11
mindspore/nn/probability/distribution/transformed_distribution.py View File

@@ -14,10 +14,9 @@
# ============================================================================
"""Transformed Distribution"""
from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype
import mindspore.nn as nn
from .distribution import Distribution
from ._utils.utils import check_type, raise_not_impl_error
from ._utils.utils import raise_not_impl_error
from ._utils.custom_ops import exp_generic, log_generic


@@ -30,7 +29,6 @@ class TransformedDistribution(Distribution):
Args:
bijector (Bijector): The transformation to perform.
distribution (Distribution): The original distribution.
dtype (mindspore.dtype): The type of the event samples.
seed (int): The seed is used in sampling. The global seed is used if it is None.
name (str): The name of the transformed distribution. Default: 'transformed_distribution'.

@@ -45,16 +43,14 @@ class TransformedDistribution(Distribution):
>>> import mindspore.nn.probability.distribution as msd
>>> import mindspore.nn.probability.bijector as msb
>>> ln = msd.TransformedDistribution(msb.Exp(),
>>> msd.Normal(0.0, 1.0, dtype=mstype.float32),
>>> dtype=mstype.float32)
>>> msd.Normal(0.0, 1.0, dtype=mstype.float32))
>>>
>>> # To use a transformed distribution in a network.
>>> class net(Cell):
>>> def __init__(self):
>>> super(net, self).__init__():
>>> self.ln = msd.TransformedDistribution(msb.Exp(),
>>> msd.Normal(0.0, 1.0, dtype=mstype.float32),
>>> dtype=mstype.float32)
>>> msd.Normal(0.0, 1.0, dtype=mstype.float32))
>>>
>>> def construct(self, value):
>>> # Similar calls can be made to other functions
@@ -65,7 +61,6 @@ class TransformedDistribution(Distribution):
def __init__(self,
bijector,
distribution,
dtype,
seed=None,
name="transformed_distribution"):
"""
@@ -76,9 +71,7 @@ class TransformedDistribution(Distribution):
[nn.probability.bijector.Bijector], type(self).__name__)
validator.check_value_type('distribution', distribution,
[Distribution], type(self).__name__)
valid_dtype = mstype.number_type
check_type(dtype, valid_dtype, type(self).__name__)
super(TransformedDistribution, self).__init__(seed, dtype, name, param)
super(TransformedDistribution, self).__init__(seed, distribution.dtype, name, param)

self._bijector = bijector
self._distribution = distribution
@@ -96,6 +89,10 @@ class TransformedDistribution(Distribution):
def distribution(self):
return self._distribution

@property
def dtype(self):
return self.distribution.dtype

@property
def is_linear_transformation(self):
return self._is_linear_transformation


Loading…
Cancel
Save