浏览代码

!6862 Remove the parameter dtype from transformed distribution

Merge pull request !6862 from peixu_ren/custom_bijector
tags/v1.1.0
mindspore-ci-bot Gitee 5 年前
父节点
当前提交
7814565d26
共有 4 个文件被更改,包括 12 次插入20 次删除
  1. +1
    -2
      mindspore/nn/probability/bijector/exp.py
  2. +2
    -6
      mindspore/nn/probability/bijector/power_transform.py
  3. +1
    -1
      mindspore/nn/probability/distribution/log_normal.py
  4. +8
    -11
      mindspore/nn/probability/distribution/transformed_distribution.py

+ 1
- 2
mindspore/nn/probability/bijector/exp.py 查看文件

@@ -49,5 +49,4 @@ class Exp(PowerTransform):

def __init__(self,
name='Exp'):
param = dict(locals())
super(Exp, self).__init__(name=name, param=param)
super(Exp, self).__init__(name=name)

+ 2
- 6
mindspore/nn/probability/bijector/power_transform.py 查看文件

@@ -39,9 +39,6 @@ class PowerTransform(Bijector):
Args:
power (int or float): The scale factor. Default: 0.
name (str): The name of the bijector. Default: 'PowerTransform'.
param (dict): The parameters used to initialize the bijector. These parameters are only used when other
Bijectors inherit from powertransform to pass in parameters. In this case the derived Bijector may overwrite
the argument `param`. Default: None.

Examples:
>>> # To initialize a PowerTransform bijector of power 0.5.
@@ -65,9 +62,8 @@ class PowerTransform(Bijector):

def __init__(self,
power=0,
name='PowerTransform',
param=None):
param = dict(locals()) if param is None else param
name='PowerTransform'):
param = dict(locals())
super(PowerTransform, self).__init__(name=name, param=param)
validator.check_value_type('power', power, [int, float], self.name)
validator.check_number("power", power, 0, Rel.GE, self.name)


+ 1
- 1
mindspore/nn/probability/distribution/log_normal.py 查看文件

@@ -135,7 +135,7 @@ class LogNormal(msd.TransformedDistribution):
"""
super(LogNormal, self).__init__(distribution=msd.Normal(loc, scale, dtype=dtype),
bijector=msb.Exp(),
dtype=dtype, seed=seed, name=name)
seed=seed, name=name)

self.log_2pi = np.log(2 * np.pi)



+ 8
- 11
mindspore/nn/probability/distribution/transformed_distribution.py 查看文件

@@ -14,10 +14,9 @@
# ============================================================================
"""Transformed Distribution"""
from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype
import mindspore.nn as nn
from .distribution import Distribution
from ._utils.utils import check_type, raise_not_impl_error
from ._utils.utils import raise_not_impl_error
from ._utils.custom_ops import exp_generic, log_generic


@@ -30,7 +29,6 @@ class TransformedDistribution(Distribution):
Args:
bijector (Bijector): The transformation to perform.
distribution (Distribution): The original distribution.
dtype (mindspore.dtype): The type of the event samples.
seed (int): The seed is used in sampling. The global seed is used if it is None.
name (str): The name of the transformed distribution. Default: 'transformed_distribution'.

@@ -45,16 +43,14 @@ class TransformedDistribution(Distribution):
>>> import mindspore.nn.probability.distribution as msd
>>> import mindspore.nn.probability.bijector as msb
>>> ln = msd.TransformedDistribution(msb.Exp(),
>>> msd.Normal(0.0, 1.0, dtype=mstype.float32),
>>> dtype=mstype.float32)
>>> msd.Normal(0.0, 1.0, dtype=mstype.float32))
>>>
>>> # To use a transformed distribution in a network.
>>> class net(Cell):
>>> def __init__(self):
>>> super(net, self).__init__():
>>> self.ln = msd.TransformedDistribution(msb.Exp(),
>>> msd.Normal(0.0, 1.0, dtype=mstype.float32),
>>> dtype=mstype.float32)
>>> msd.Normal(0.0, 1.0, dtype=mstype.float32))
>>>
>>> def construct(self, value):
>>> # Similar calls can be made to other functions
@@ -65,7 +61,6 @@ class TransformedDistribution(Distribution):
def __init__(self,
bijector,
distribution,
dtype,
seed=None,
name="transformed_distribution"):
"""
@@ -76,9 +71,7 @@ class TransformedDistribution(Distribution):
[nn.probability.bijector.Bijector], type(self).__name__)
validator.check_value_type('distribution', distribution,
[Distribution], type(self).__name__)
valid_dtype = mstype.number_type
check_type(dtype, valid_dtype, type(self).__name__)
super(TransformedDistribution, self).__init__(seed, dtype, name, param)
super(TransformedDistribution, self).__init__(seed, distribution.dtype, name, param)

self._bijector = bijector
self._distribution = distribution
@@ -96,6 +89,10 @@ class TransformedDistribution(Distribution):
def distribution(self):
return self._distribution

@property
def dtype(self):
return self.distribution.dtype

@property
def is_linear_transformation(self):
return self._is_linear_transformation


正在加载...
取消
保存