
Redesigned bijector classes, changed batch_shape calculation in TransformedDistribution, and upgraded the dtype logic of the Bijector class

tags/v1.1.0
Xun Deng 5 years ago
parent
commit
fb0263f869
13 changed files with 451 additions and 158 deletions
  1. +105 -22   mindspore/nn/probability/bijector/bijector.py
  2. +28 -39    mindspore/nn/probability/bijector/gumbel_cdf.py
  3. +3 -9      mindspore/nn/probability/bijector/invert.py
  4. +15 -13    mindspore/nn/probability/bijector/power_transform.py
  5. +18 -18    mindspore/nn/probability/bijector/scalar_affine.py
  6. +13 -15    mindspore/nn/probability/bijector/softplus.py
  7. +6 -1      mindspore/nn/probability/distribution/_utils/utils.py
  8. +15 -16    mindspore/nn/probability/distribution/distribution.py
  9. +6 -6      mindspore/nn/probability/distribution/gumbel.py
  10. +11 -6    mindspore/nn/probability/distribution/log_normal.py
  11. +36 -13   mindspore/nn/probability/distribution/transformed_distribution.py
  12. +191 -0   tests/ut/python/nn/probability/bijector/test_bijector.py
  13. +4 -0     tests/ut/python/nn/probability/distribution/test_distribution.py

+105 -22  mindspore/nn/probability/bijector/bijector.py

@@ -16,8 +16,10 @@
from mindspore import context from mindspore import context
from mindspore.nn.cell import Cell from mindspore.nn.cell import Cell
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import CheckTensor, cast_to_tensor
from ..distribution._utils.utils import CheckTensor, cast_to_tensor, raise_type_error
from ..distribution import Distribution from ..distribution import Distribution
from ..distribution import TransformedDistribution from ..distribution import TransformedDistribution


@@ -32,6 +34,17 @@ class Bijector(Cell):
name (str): The name of the Bijector. Default: None. name (str): The name of the Bijector. Default: None.
dtype (mindspore.dtype): The type of the distributions that the Bijector can operate on. Default: None. dtype (mindspore.dtype): The type of the distributions that the Bijector can operate on. Default: None.
param (dict): The parameters used to initialize the Bijector. Default: None. param (dict): The parameters used to initialize the Bijector. Default: None.

Note:
`dtype` of a bijector represents the type of the distributions that the bijector could operate on.
When `dtype` is None, there is no enforcement on the type of the input value except that the input
value has to be a float type. During initialization, when `dtype` is None, there is no enforcement
on the dtype of the parameters either, but all parameters must share the same float type, otherwise
a TypeError will be raised. Specifically, the parameter type will follow the dtype of the input
value, i.e. the parameters of the bijector will be cast into the same type as the input value when
`dtype` is None. When `dtype` is specified, both the parameters and the input value are forced to
have that dtype; when the type of the parameters or the type of the input value differs from
`dtype`, a TypeError will be raised. Only subtypes of mindspore.float_type can be used to specify
the bijector's `dtype`.
""" """


def __init__(self, def __init__(self,
@@ -48,6 +61,8 @@ class Bijector(Cell):
validator.check_value_type( validator.check_value_type(
'is_constant_jacobian', is_constant_jacobian, [bool], name) 'is_constant_jacobian', is_constant_jacobian, [bool], name)
validator.check_value_type('is_injective', is_injective, [bool], name) validator.check_value_type('is_injective', is_injective, [bool], name)
if dtype is not None:
validator.check_type_name("dtype", dtype, mstype.float_type, type(self).__name__)
self._name = name self._name = name
self._dtype = dtype self._dtype = dtype
self._parameters = {} self._parameters = {}
@@ -57,6 +72,12 @@ class Bijector(Cell):
continue continue
if not(k == 'self' or k.startswith('_')): if not(k == 'self' or k.startswith('_')):
self._parameters[k] = param[k] self._parameters[k] = param[k]

# if no bijector is used as an argument during initialization
if 'bijector' not in param.keys():
self._batch_shape = self._calc_batch_shape()
self._is_scalar_batch = self._check_is_scalar_batch()

self._is_constant_jacobian = is_constant_jacobian self._is_constant_jacobian = is_constant_jacobian
self._is_injective = is_injective self._is_injective = is_injective


@@ -68,6 +89,8 @@ class Bijector(Cell):
self.dtype_base = P.DType() self.dtype_base = P.DType()
self.shape_base = P.Shape() self.shape_base = P.Shape()
self.fill_base = P.Fill() self.fill_base = P.Fill()
self.sametypeshape_base = P.SameTypeShape()
self.issubclass_base = P.IsSubClass()


@property @property
def name(self): def name(self):
@@ -89,6 +112,38 @@ class Bijector(Cell):
def is_injective(self): def is_injective(self):
return self._is_injective return self._is_injective


@property
def batch_shape(self):
return self._batch_shape

@property
def is_scalar_batch(self):
return self._is_scalar_batch

def _check_value_dtype(self, value):
"""
First, check whether the input value is a Tensor. Then, if `self.dtype` is None, check
that the input tensor has a float dtype. If `self.dtype` is not None, check that the
input tensor's dtype is `self.dtype`.
"""
self.checktensor(value, 'input value of bijector')
value_type = self.dtype_base(value)
if self.dtype is None:
if self.issubclass_base(value_type, mstype.float_):
return value
return raise_type_error('input value of bijector', value_type, mstype.float_)
dtype_tensor = self.fill_base(self.dtype, self.shape_base(value), 0.0)
self.sametypeshape_base(value, dtype_tensor)
return value

def _shape_mapping(self, shape):
shape_tensor = self.fill_base(self.parameter_type, shape, 0.0)
dist_shape_tensor = self.fill_base(self.parameter_type, self.batch_shape, 0.0)
return (shape_tensor + dist_shape_tensor).shape

def shape_mapping(self, shape):
return self._shape_mapping(shape)

def _add_parameter(self, value, name): def _add_parameter(self, value, name):
""" """
Cast `value` to a tensor and add it to `self.default_parameters`. Cast `value` to a tensor and add it to `self.default_parameters`.
@@ -98,26 +153,51 @@ class Bijector(Cell):
if not hasattr(self, 'default_parameters'): if not hasattr(self, 'default_parameters'):
self.default_parameters = [] self.default_parameters = []
self.parameter_names = [] self.parameter_names = []
self.common_dtype = None
# cast value to a tensor if it is not None # cast value to a tensor if it is not None
value_t = None if value is None else cast_to_tensor(value, self.parameter_type)
self.default_parameters += [value_t,]
if isinstance(value, bool) or value is None:
raise TypeError(f"{name} cannot be type {type(value)}")
value_t = Tensor(value)
# if the bijector's dtype is not specified
if self.dtype is None:
if self.common_dtype is None:
self.common_dtype = value_t.dtype
elif value_t.dtype != self.common_dtype:
raise TypeError(f"{name} should have the same dtype as other arguments.")
# check if the dtype of the input_parameter agrees with the bijector's dtype
elif value_t.dtype != self.dtype:
raise TypeError(f"{name} should have the same dtype as the bijector's dtype.")
self.default_parameters += [value,]
self.parameter_names += [name,] self.parameter_names += [name,]
return value_t return value_t


def _calc_event_shape(self):
def _calc_batch_shape(self):
""" """
Calculate event_shape based on parameters.
Calculate batch_shape based on parameters.
""" """
broadcast_shape = None
for param in self.default_parameters:
if broadcast_shape is None:
broadcast_shape = self.shape_base(param)
broadcast_shape_tensor = self.fill_base(self.parameter_type, broadcast_shape, 0.0)
param_dict = self.parameters['param_dict']
broadcast_shape_tensor = None
for value in param_dict.values():
if value is None:
return None
if broadcast_shape_tensor is None:
broadcast_shape_tensor = cast_to_tensor(value)
else: else:
broadcast_shape = self.shape_base(param + broadcast_shape_tensor)
broadcast_shape_tensor = self.fill_base(self.parameter_type, broadcast_shape, 0.0)
return broadcast_shape
value = cast_to_tensor(value)
broadcast_shape_tensor = (value + broadcast_shape_tensor)
return broadcast_shape_tensor.shape


def _check_is_scalar_batch(self):
"""
Check if the parameters used during initialization are scalars.
"""
param_dict = self.parameters['param_dict']
for value in param_dict.values():
if value is None:
continue
if not isinstance(value, (int, float)):
return False
return True


def _check_value(self, value, name): def _check_value(self, value, name):
""" """
@@ -127,32 +207,35 @@ class Bijector(Cell):
return value return value


def cast_param_by_value(self, value, para): def cast_param_by_value(self, value, para):
"""
Cast the parameter(s) of the bijector to be the same type as the input value.
"""
local = self.cast_base(para, self.dtype_base(value)) local = self.cast_base(para, self.dtype_base(value))
return local return local


def forward(self, *args, **kwargs):
def forward(self, value, *args, **kwargs):
""" """
Forward transformation: transform the input value to another distribution. Forward transformation: transform the input value to another distribution.
""" """
return self._forward(*args, **kwargs)
return self._forward(value, *args, **kwargs)


def inverse(self, *args, **kwargs):
def inverse(self, value, *args, **kwargs):
""" """
Inverse transformation: transform the input value back to the original distribution. Inverse transformation: transform the input value back to the original distribution.
""" """
return self._inverse(*args, **kwargs)
return self._inverse(value, *args, **kwargs)


def forward_log_jacobian(self, *args, **kwargs):
def forward_log_jacobian(self, value, *args, **kwargs):
""" """
Logarithm of the derivative of the forward transformation. Logarithm of the derivative of the forward transformation.
""" """
return self._forward_log_jacobian(*args, **kwargs)
return self._forward_log_jacobian(value, *args, **kwargs)


def inverse_log_jacobian(self, *args, **kwargs):
def inverse_log_jacobian(self, value, *args, **kwargs):
""" """
Logarithm of the derivative of the inverse transformation. Logarithm of the derivative of the inverse transformation.
""" """
return self._inverse_log_jacobian(*args, **kwargs)
return self._inverse_log_jacobian(value, *args, **kwargs)


def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
""" """
@@ -167,7 +250,7 @@ class Bijector(Cell):
*args: args[0] shall be either a distribution or the name of a Bijector function. *args: args[0] shall be either a distribution or the name of a Bijector function.
""" """
if isinstance(args[0], Distribution): if isinstance(args[0], Distribution):
return TransformedDistribution(self, args[0], self.distribution.dtype)
return TransformedDistribution(self, args[0])
return super(Bijector, self).__call__(*args, **kwargs) return super(Bijector, self).__call__(*args, **kwargs)


def construct(self, name, *args, **kwargs): def construct(self, name, *args, **kwargs):
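A minimal sketch (not part of this commit) of how the new `batch_shape`, `is_scalar_batch`, and parameter handling behave, assuming the redesigned base class above together with the ScalarAffine changes later in this diff; the parameter values are illustrative only:

import numpy as np
import mindspore.nn.probability.bijector as msb

# Array-valued parameters now give the bijector a non-trivial batch_shape,
# computed by broadcasting the entries of param_dict against each other.
affine = msb.ScalarAffine(scale=np.array([1.0, 2.0], np.float32),
                          shift=np.array([0.0, 0.5], np.float32))
print(affine.batch_shape)      # expected: (2,)
print(affine.is_scalar_batch)  # expected: False, the parameters are not python scalars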


+28 -39  mindspore/nn/probability/bijector/gumbel_cdf.py

@@ -13,10 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""GumbelCDF Bijector""" """GumbelCDF Bijector"""
from mindspore.common import dtype as mstype
from mindspore._checkparam import Validator
from mindspore.ops import operations as P from mindspore.ops import operations as P
from ..distribution._utils.utils import check_greater_zero, set_param_type
from ..distribution._utils.utils import check_greater_zero
from ..distribution._utils.custom_ops import exp_generic, log_generic from ..distribution._utils.custom_ops import exp_generic, log_generic
from .bijector import Bijector from .bijector import Bijector


@@ -30,12 +28,11 @@ class GumbelCDF(Bijector):
Y = \exp(-\exp(\frac{-(X - loc)}{scale})) Y = \exp(-\exp(\frac{-(X - loc)}{scale}))


Note: Note:
For `reverse` and `reverse_log_jacobian`, input should be in range of (0, 1).
For `inverse` and `inverse_log_jacobian`, input should be in range of (0, 1).


Args: Args:
loc (int, float, list, numpy.ndarray, Tensor): The location. Default: 0..
scale (int, float, list, numpy.ndarray, Tensor): The scale. Default: 1.0.
dtype (mindspore.dtype): Type of the distribution which the bijector operates on. Default: float32.
loc (float, list, numpy.ndarray, Tensor): The location. Default: 0..
scale (float, list, numpy.ndarray, Tensor): The scale. Default: 1.0.
name (str): The name of the Bijector. Default: 'Gumbel_CDF'. name (str): The name of the Bijector. Default: 'Gumbel_CDF'.


Examples: Examples:
@@ -61,22 +58,18 @@ class GumbelCDF(Bijector):
def __init__(self, def __init__(self,
loc=0.0, loc=0.0,
scale=1.0, scale=1.0,
dtype=mstype.float32,
name='GumbelCDF'): name='GumbelCDF'):
""" """
Constructor of GumbelCDF Bijector. Constructor of GumbelCDF Bijector.
""" """
param = dict(locals()) param = dict(locals())
valid_dtype = mstype.float_type + mstype.int_type + mstype.uint_type
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
parameter_type = set_param_type({'loc': loc, "scale": scale}, dtype)
super(GumbelCDF, self).__init__(name=name, dtype=dtype, param=param)
param['param_dict'] = {'loc': loc, 'scale': scale}
super(GumbelCDF, self).__init__(name=name, param=param)


self._parameter_type = parameter_type
self._loc = self._add_parameter(loc, 'loc') self._loc = self._add_parameter(loc, 'loc')
self._scale = self._add_parameter(scale, 'scale') self._scale = self._add_parameter(scale, 'scale')
check_greater_zero(self._scale, "scale") check_greater_zero(self._scale, "scale")
self._event_shape = self._calc_event_shape()


self.cast = P.Cast() self.cast = P.Cast()
self.exp = exp_generic self.exp = exp_generic
@@ -91,38 +84,34 @@ class GumbelCDF(Bijector):
def scale(self): def scale(self):
return self._scale return self._scale


@property
def event_shape(self):
return self._event_shape

@property
def parameter_type(self):
return self._parameter_type

def extend_repr(self): def extend_repr(self):
return f'loc = {self.loc}, scale = {self.scale}'

def shape_mapping(self, shape):
return shape
if self.is_scalar_batch:
str_info = f'loc = {self.loc}, scale = {self.scale}'
else:
str_info = f'batch_shape = {self.batch_shape}'
return str_info


def _forward(self, x): def _forward(self, x):
x = self._check_value(x, 'value')
x = self.cast(x, self.parameter_type)
z = (x - self.loc) / self.scale
x = self._check_value_dtype(x)
loc_local = self.cast_param_by_value(x, self.loc)
scale_local = self.cast_param_by_value(x, self.scale)
z = (x - loc_local) / scale_local
return self.exp(-self.exp(-z)) return self.exp(-self.exp(-z))


def _inverse(self, y): def _inverse(self, y):
y = self._check_value(y, 'value')
y = self.cast(y, self.parameter_type)
return self.loc - self.scale * self.log(-self.log(y))
y = self._check_value_dtype(y)
loc_local = self.cast_param_by_value(y, self.loc)
scale_local = self.cast_param_by_value(y, self.scale)
return loc_local - scale_local * self.log(-self.log(y))


def _forward_log_jacobian(self, x): def _forward_log_jacobian(self, x):
x = self._check_value(x, 'value')
x = self.cast(x, self.parameter_type)
z = (x - self.loc) / self.scale
return -z - self.exp(-z) - self.log(self.scale)
x = self._check_value_dtype(x)
loc_local = self.cast_param_by_value(x, self.loc)
scale_local = self.cast_param_by_value(x, self.scale)
z = (x - loc_local) / scale_local
return -z - self.exp(-z) - self.log(scale_local)


def _inverse_log_jacobian(self, y): def _inverse_log_jacobian(self, y):
y = self._check_value(y, 'value')
y = self.cast(y, self.parameter_type)
return self.log(self.scale / (-1. * y * self.log(y)))
y = self._check_value_dtype(y)
scale_local = self.cast_param_by_value(y, self.scale)
return self.log(scale_local / (-1. * y * self.log(y)))
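A hedged usage sketch (not in the commit) of the updated GumbelCDF, which no longer takes a `dtype` argument; `loc` and `scale` are now cast to the dtype of the input value at call time. The wrapper cell below is illustrative only:

import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor, dtype as mstype

class GumbelForward(nn.Cell):
    """Apply the forward transform of a GumbelCDF bijector."""
    def __init__(self):
        super(GumbelForward, self).__init__()
        # No dtype argument any more; the parameters follow the input value's float type.
        self.gumbel_cdf = msb.GumbelCDF(loc=1.0, scale=2.0)

    def construct(self, x):
        return self.gumbel_cdf.forward(x)

net = GumbelForward()
out = net(Tensor([0.5, 1.0], dtype=mstype.float16))  # output should keep the float16 dtype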

+3 -9  mindspore/nn/probability/bijector/invert.py

@@ -53,23 +53,17 @@ class Invert(Bijector):
name = (name + bijector.name) if name == 'Invert' else name name = (name + bijector.name) if name == 'Invert' else name
super(Invert, self).__init__(is_constant_jacobian=bijector.is_constant_jacobian, super(Invert, self).__init__(is_constant_jacobian=bijector.is_constant_jacobian,
is_injective=bijector.is_injective, is_injective=bijector.is_injective,
dtype=bijector.dtype,
name=name, name=name,
dtype=bijector.dtype,
param=param) param=param)
self._bijector = bijector self._bijector = bijector
if hasattr(self._bijector, 'event_shape'):
self._event_shape = self.bijector.event_shape
else:
self._event_shape = ()
self._batch_shape = self.bijector.batch_shape
self._is_scalar_batch = self.bijector.is_scalar_batch


@property @property
def bijector(self): def bijector(self):
return self._bijector return self._bijector


@property
def event_shape(self):
return self._event_shape

def inverse(self, y): def inverse(self, y):
return self.bijector("forward", y) return self.bijector("forward", y)




+15 -13  mindspore/nn/probability/bijector/power_transform.py

@@ -14,8 +14,7 @@
# ============================================================================ # ============================================================================
"""Power Bijector""" """Power Bijector"""
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from ..distribution._utils.utils import check_greater_equal_zero
from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic
from .bijector import Bijector from .bijector import Bijector


@@ -37,7 +36,7 @@ class PowerTransform(Bijector):
ValueError: When the power is less than 0 or is not known statically. ValueError: When the power is less than 0 or is not known statically.


Args: Args:
power (int or float): The scale factor. Default: 0.
power (float, list, numpy.ndarray, Tensor): The scale factor. Default: 0.
name (str): The name of the bijector. Default: 'PowerTransform'. name (str): The name of the bijector. Default: 'PowerTransform'.


Examples: Examples:
@@ -64,10 +63,11 @@ class PowerTransform(Bijector):
power=0, power=0,
name='PowerTransform'): name='PowerTransform'):
param = dict(locals()) param = dict(locals())
param['param_dict'] = {'power': power}
super(PowerTransform, self).__init__(name=name, param=param) super(PowerTransform, self).__init__(name=name, param=param)
validator.check_value_type('power', power, [int, float], self.name)
validator.check_number("power", power, 0, Rel.GE, self.name)
self._power = power
self._power = self._add_parameter(power, 'power')
check_greater_equal_zero(self._power, 'Power')
self.pow = P.Pow() self.pow = P.Pow()
self.dtypeop = P.DType() self.dtypeop = P.DType()
self.cast = P.Cast() self.cast = P.Cast()
@@ -81,13 +81,15 @@ class PowerTransform(Bijector):
return self._power return self._power


def extend_repr(self): def extend_repr(self):
return f'power = {self.power}'
if self.is_scalar_batch:
str_info = f'power = {self.power}'
else:
str_info = f'batch_shape = {self.batch_shape}'
return str_info


def shape_mapping(self, shape):
return shape


def _forward(self, x): def _forward(self, x):
x = self._check_value(x, 'value')
x = self._check_value_dtype(x)
power_local = self.cast_param_by_value(x, self.power) power_local = self.cast_param_by_value(x, self.power)
if power_local == 0: if power_local == 0:
forward_v = self.exp(x) forward_v = self.exp(x)
@@ -96,7 +98,7 @@ class PowerTransform(Bijector):
return forward_v return forward_v


def _inverse(self, y): def _inverse(self, y):
y = self._check_value(y, 'value')
y = self._check_value_dtype(y)
power_local = self.cast_param_by_value(y, self.power) power_local = self.cast_param_by_value(y, self.power)
if power_local == 0: if power_local == 0:
inverse_v = self.log(y) inverse_v = self.log(y)
@@ -116,7 +118,7 @@ class PowerTransform(Bijector):
f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1} f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1}
\log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1) \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1)
""" """
x = self._check_value(x, 'value')
x = self._check_value_dtype(x)
power_local = self.cast_param_by_value(x, self.power) power_local = self.cast_param_by_value(x, self.power)
if power_local == 0: if power_local == 0:
forward_log_j = x forward_log_j = x
@@ -136,7 +138,7 @@ class PowerTransform(Bijector):
f'(x) = \frac{e^c\log(y)}{y} f'(x) = \frac{e^c\log(y)}{y}
\log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y) \log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y)
""" """
y = self._check_value(y, 'value')
y = self._check_value_dtype(y)
power_local = self.cast_param_by_value(y, self.power) power_local = self.cast_param_by_value(y, self.power)
inverse_log_j = (power_local - 1) * self.log(y) inverse_log_j = (power_local - 1) * self.log(y)
return inverse_log_j return inverse_log_j

+18 -18  mindspore/nn/probability/bijector/scalar_affine.py

@@ -14,8 +14,6 @@
# ============================================================================ # ============================================================================
"""Scalar Affine Bijector""" """Scalar Affine Bijector"""
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import cast_to_tensor
from ..distribution._utils.custom_ops import log_generic from ..distribution._utils.custom_ops import log_generic
from .bijector import Bijector from .bijector import Bijector


@@ -30,10 +28,14 @@ class ScalarAffine(Bijector):
where a is the scale factor and b is the shift factor. where a is the scale factor and b is the shift factor.


Args: Args:
scale (float): The scale factor. Default: 1.0.
shift (float): The shift factor. Default: 0.0.
scale (float, list, numpy.ndarray, Tensor): The scale factor. Default: 1.0.
shift (float, list, numpy.ndarray, Tensor): The shift factor. Default: 0.0.
name (str): The name of the bijector. Default: 'ScalarAffine'. name (str): The name of the bijector. Default: 'ScalarAffine'.


Note:
If `shift` and `scale` are passed in as numpy.ndarray or Tensor, they have to have
the same dtype, otherwise an error will be raised.

Examples: Examples:
>>> # To initialize a ScalarAffine bijector of scale 1 and shift 2. >>> # To initialize a ScalarAffine bijector of scale 1 and shift 2.
>>> scalaraffine = nn.probability.bijector.ScalarAffine(1, 2) >>> scalaraffine = nn.probability.bijector.ScalarAffine(1, 2)
@@ -61,10 +63,7 @@ class ScalarAffine(Bijector):
Constructor of ScalarAffine Bijector. Constructor of ScalarAffine Bijector.
""" """
param = dict(locals()) param = dict(locals())
validator.check_value_type(
'scale', scale, [int, float], type(self).__name__)
validator.check_value_type(
'shift', shift, [int, float], type(self).__name__)
param['param_dict'] = {'scale': scale, 'shift': shift}
super(ScalarAffine, self).__init__( super(ScalarAffine, self).__init__(
is_constant_jacobian=True, is_constant_jacobian=True,
is_injective=True, is_injective=True,
@@ -72,8 +71,8 @@ class ScalarAffine(Bijector):
dtype=None, dtype=None,
param=param) param=param)


self._scale = cast_to_tensor(scale)
self._shift = cast_to_tensor(shift)
self._scale = self._add_parameter(scale, 'scale')
self._shift = self._add_parameter(shift, 'shift')


self.abs = P.Abs() self.abs = P.Abs()
self.oneslike = P.OnesLike() self.oneslike = P.OnesLike()
@@ -90,17 +89,18 @@ class ScalarAffine(Bijector):
return self._shift return self._shift


def extend_repr(self): def extend_repr(self):
return f'scale = {self.scale}, shift = {self.shift}'

def shape_mapping(self, shape):
return shape
if self.is_scalar_batch:
str_info = f'scale = {self.scale}, shift = {self.shift}'
else:
str_info = f'batch_shape = {self.batch_shape}'
return str_info


def _forward(self, x): def _forward(self, x):
r""" r"""
.. math:: .. math::
f(x) = a * x + b f(x) = a * x + b
""" """
x = self._check_value(x, 'value')
x = self._check_value_dtype(x)
scale_local = self.cast_param_by_value(x, self.scale) scale_local = self.cast_param_by_value(x, self.scale)
shift_local = self.cast_param_by_value(x, self.shift) shift_local = self.cast_param_by_value(x, self.shift)
forward_v = scale_local * x + shift_local * self.oneslike(x) forward_v = scale_local * x + shift_local * self.oneslike(x)
@@ -111,7 +111,7 @@ class ScalarAffine(Bijector):
.. math:: .. math::
f(y) = \frac{y - b}{a} f(y) = \frac{y - b}{a}
""" """
y = self._check_value(y, 'value')
y = self._check_value_dtype(y)
scale_local = self.cast_param_by_value(y, self.scale) scale_local = self.cast_param_by_value(y, self.scale)
shift_local = self.cast_param_by_value(y, self.shift) shift_local = self.cast_param_by_value(y, self.shift)
inverse_v = (y - shift_local) / scale_local inverse_v = (y - shift_local) / scale_local
@@ -124,7 +124,7 @@ class ScalarAffine(Bijector):
f'(x) = a f'(x) = a
\log(f'(x)) = \log(a) \log(f'(x)) = \log(a)
""" """
x = self._check_value(x, 'value')
x = self._check_value_dtype(x)
scale_local = self.cast_param_by_value(x, self.scale) scale_local = self.cast_param_by_value(x, self.scale)
forward_log_j = self.log(self.abs(scale_local)) forward_log_j = self.log(self.abs(scale_local))
return forward_log_j return forward_log_j
@@ -136,7 +136,7 @@ class ScalarAffine(Bijector):
f'(x) = \frac{1.0}{a} f'(x) = \frac{1.0}{a}
\log(f'(x)) = - \log(a) \log(f'(x)) = - \log(a)
""" """
y = self._check_value(y, 'value')
y = self._check_value_dtype(y)
scale_local = self.cast_param_by_value(y, self.scale) scale_local = self.cast_param_by_value(y, self.scale)
inverse_log_j = -1. * self.log(self.abs(scale_local)) inverse_log_j = -1. * self.log(self.abs(scale_local))
return inverse_log_j return inverse_log_j
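A small sketch (not part of the commit) of the dtype rule stated in the note above: with `dtype=None`, parameters passed as arrays or tensors must share one float type, otherwise `_add_parameter` raises a TypeError. The values are illustrative only:

import numpy as np
import mindspore.nn.probability.bijector as msb

try:
    # scale registers float16 as the common dtype, so the float32 shift mismatches.
    msb.ScalarAffine(scale=np.array(1.0, np.float16),
                     shift=np.array(2.0, np.float32))
except TypeError as err:
    print(err)  # "shift should have the same dtype as other arguments."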

+13 -15  mindspore/nn/probability/bijector/softplus.py

@@ -16,8 +16,6 @@
import numpy as np import numpy as np
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.nn.layer.activation import LogSigmoid from mindspore.nn.layer.activation import LogSigmoid
from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import cast_to_tensor
from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic
from .bijector import Bijector from .bijector import Bijector


@@ -32,7 +30,7 @@ class Softplus(Bijector):
where k is the sharpness factor. where k is the sharpness factor.


Args: Args:
sharpness (float): The scale factor. Default: 1.0.
sharpness (float, list, numpy.ndarray, Tensor): The scale factor. Default: 1.0.
name (str): The name of the Bijector. Default: 'Softplus'. name (str): The name of the Bijector. Default: 'Softplus'.


Examples: Examples:
@@ -61,10 +59,9 @@ class Softplus(Bijector):
Constructor of Softplus Bijector. Constructor of Softplus Bijector.
""" """
param = dict(locals()) param = dict(locals())
validator.check_value_type('sharpness', sharpness,
[int, float], type(self).__name__)
super(Softplus, self).__init__(name=name, param=param)
self._sharpness = cast_to_tensor(sharpness)
param['param_dict'] = {'sharpness': sharpness}
super(Softplus, self).__init__(name=name, dtype=None, param=param)
self._sharpness = self._add_parameter(sharpness, 'sharpness')


self.exp = exp_generic self.exp = exp_generic
self.log = log_generic self.log = log_generic
@@ -118,13 +115,14 @@ class Softplus(Bijector):
return self._sharpness return self._sharpness


def extend_repr(self): def extend_repr(self):
return f'sharpness = {self.sharpness}'

def shape_mapping(self, shape):
return shape
if self.is_scalar_batch:
str_info = f'sharpness = {self.sharpness}'
else:
str_info = f'batch_shape = {self.batch_shape}'
return str_info


def _forward(self, x): def _forward(self, x):
x = self._check_value(x, 'value')
x = self._check_value_dtype(x)
sharpness_local = self.cast_param_by_value(x, self.sharpness) sharpness_local = self.cast_param_by_value(x, self.sharpness)
scaled_value = sharpness_local * x scaled_value = sharpness_local * x
forward_v = self.softplus(scaled_value) / sharpness_local forward_v = self.softplus(scaled_value) / sharpness_local
@@ -136,7 +134,7 @@ class Softplus(Bijector):
f(x) = \frac{\log(1 + e^{kx}))}{k} f(x) = \frac{\log(1 + e^{kx}))}{k}
f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k} f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k}
""" """
y = self._check_value(y, 'value')
y = self._check_value_dtype(y)
sharpness_local = self.cast_param_by_value(y, self.sharpness) sharpness_local = self.cast_param_by_value(y, self.sharpness)
scaled_value = sharpness_local * y scaled_value = sharpness_local * y
inverse_v = self.inverse_softplus(scaled_value) / sharpness_local inverse_v = self.inverse_softplus(scaled_value) / sharpness_local
@@ -149,7 +147,7 @@ class Softplus(Bijector):
f'(x) = \frac{e^{kx}}{ 1 + e^{kx}} f'(x) = \frac{e^{kx}}{ 1 + e^{kx}}
\log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx) \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx)
""" """
x = self._check_value(x, 'value')
x = self._check_value_dtype(x)
sharpness_local = self.cast_param_by_value(x, self.sharpness) sharpness_local = self.cast_param_by_value(x, self.sharpness)
scaled_value = sharpness_local * x scaled_value = sharpness_local * x
forward_log_j = self.log_sigmoid(scaled_value) forward_log_j = self.log_sigmoid(scaled_value)
@@ -162,7 +160,7 @@ class Softplus(Bijector):
f'(y) = \frac{e^{ky}}{e^{ky} - 1} f'(y) = \frac{e^{ky}}{e^{ky} - 1}
\log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky) \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky)
""" """
y = self._check_value(y, 'value')
y = self._check_value_dtype(y)
sharpness_local = self.cast_param_by_value(y, self.sharpness) sharpness_local = self.cast_param_by_value(y, self.sharpness)
scaled_value = sharpness_local * y scaled_value = sharpness_local * y
inverse_log_j = scaled_value - self.inverse_softplus(scaled_value) inverse_log_j = scaled_value - self.inverse_softplus(scaled_value)


+6 -1  mindspore/nn/probability/distribution/_utils/utils.py

@@ -229,6 +229,11 @@ def raise_not_implemented_util(func_name, obj, *args, **kwargs):
raise NotImplementedError( raise NotImplementedError(
f"{func_name} is not implemented for {obj} distribution.") f"{func_name} is not implemented for {obj} distribution.")


@constexpr
def raise_type_error(name, cur_type, required_type):
raise TypeError(
f"For {name} , the type should be or be subclass of {required_type}, but got {cur_type}")



@constexpr @constexpr
def check_distribution_name(name, expected_name): def check_distribution_name(name, expected_name):
@@ -304,7 +309,7 @@ def set_param_type(args, hint_type):
TypeError: if tensors in args are not the same dtype. TypeError: if tensors in args are not the same dtype.
""" """
int_type = mstype.int_type + mstype.uint_type int_type = mstype.int_type + mstype.uint_type
if hint_type in int_type:
if hint_type in int_type or hint_type is None:
hint_type = mstype.float32 hint_type = mstype.float32
common_dtype = None common_dtype = None
for name, arg in args.items(): for name, arg in args.items():


+15 -16  mindspore/nn/probability/distribution/distribution.py

@@ -72,13 +72,12 @@ class Distribution(Cell):
if not(k == 'self' or k.startswith('_')): if not(k == 'self' or k.startswith('_')):
self._parameters[k] = param[k] self._parameters[k] = param[k]


# some attributes
if 'distribution' in self.parameters.keys():
self.parameter_type = self.parameters['distribution'].parameter_type
else:
# if not a transformed distribution, set the following attribute
if 'distribution' not in self.parameters.keys():
self.parameter_type = set_param_type(self.parameters['param_dict'], dtype) self.parameter_type = set_param_type(self.parameters['param_dict'], dtype)
self._broadcast_shape = self._calc_broadcast_shape()
self._is_scalar_batch = self._check_is_scalar_batch()
self._batch_shape = self._calc_batch_shape()
self._is_scalar_batch = self._check_is_scalar_batch()
self._broadcast_shape = self._batch_shape


# set the function to call according to the derived class's attributes # set the function to call according to the derived class's attributes
self._set_prob() self._set_prob()
@@ -128,6 +127,10 @@ class Distribution(Cell):
def is_scalar_batch(self): def is_scalar_batch(self):
return self._is_scalar_batch return self._is_scalar_batch


@property
def batch_shape(self):
return self._batch_shape

@property @property
def broadcast_shape(self): def broadcast_shape(self):
return self._broadcast_shape return self._broadcast_shape
@@ -208,8 +211,6 @@ class Distribution(Cell):
""" """
Check if the parameters used during initialization are scalars. Check if the parameters used during initialization are scalars.
""" """
if 'distribution' in self.parameters.keys():
return self.parameters['distribution'].is_scalar_batch
param_dict = self.parameters['param_dict'] param_dict = self.parameters['param_dict']
for value in param_dict.values(): for value in param_dict.values():
if value is None: if value is None:
@@ -218,12 +219,10 @@ class Distribution(Cell):
return False return False
return True return True


def _calc_broadcast_shape(self):
def _calc_batch_shape(self):
""" """
Calculate the broadcast shape of the parameters used during initialization. Calculate the broadcast shape of the parameters used during initialization.
""" """
if 'distribution' in self.parameters.keys():
return self.parameters['distribution'].broadcast_shape
param_dict = self.parameters['param_dict'] param_dict = self.parameters['param_dict']
broadcast_shape_tensor = None broadcast_shape_tensor = None
for value in param_dict.values(): for value in param_dict.values():
@@ -362,14 +361,14 @@ class Distribution(Cell):
""" """
return self._get_dist_args(*args, **kwargs) return self._get_dist_args(*args, **kwargs)


def _get_dist_type(self, *args, **kwargs):
return raise_not_implemented_util('get_dist_type', self.name, *args, **kwargs)
def _get_dist_type(self):
return raise_not_implemented_util('get_dist_type', self.name)


def get_dist_type(self, *args, **kwargs):
def get_dist_type(self):
""" """
Return the type of the distribution. Return the type of the distribution.
""" """
return self._get_dist_type(*args, **kwargs)
return self._get_dist_type()


def _raise_not_implemented_error(self, func_name): def _raise_not_implemented_error(self, func_name):
name = self.name name = self.name
@@ -751,5 +750,5 @@ class Distribution(Cell):
if name == 'get_dist_args': if name == 'get_dist_args':
return self._get_dist_args(*args, **kwargs) return self._get_dist_args(*args, **kwargs)
if name == 'get_dist_type': if name == 'get_dist_type':
return self._get_dist_type(*args, **kwargs)
return self._get_dist_type()
return raise_not_implemented_util(name, self.name, *args, **kwargs) return raise_not_implemented_util(name, self.name, *args, **kwargs)

+6 -6  mindspore/nn/probability/distribution/gumbel.py

@@ -103,17 +103,12 @@ class Gumbel(TransformedDistribution):
""" """
valid_dtype = mstype.float_type valid_dtype = mstype.float_type
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
gumbel_cdf = msb.GumbelCDF(loc, scale, dtype)
gumbel_cdf = msb.GumbelCDF(loc, scale)
super(Gumbel, self).__init__( super(Gumbel, self).__init__(
distribution=msd.Uniform(0.0, 1.0, dtype=dtype), distribution=msd.Uniform(0.0, 1.0, dtype=dtype),
bijector=msb.Invert(gumbel_cdf), bijector=msb.Invert(gumbel_cdf),
seed=seed, name=name) seed=seed, name=name)


self.parameter_type = gumbel_cdf.parameter_type
self._broadcast_shape = gumbel_cdf.event_shape
if self._broadcast_shape != ():
self._is_scalar_batch = False

# overwrite default_parameters and parameter_names # overwrite default_parameters and parameter_names
self._reset_parameters() self._reset_parameters()
self._loc = self._add_parameter(loc, 'loc') self._loc = self._add_parameter(loc, 'loc')
@@ -202,6 +197,7 @@ class Gumbel(TransformedDistribution):
where z = \frac{x - loc}{scale} where z = \frac{x - loc}{scale}
""" """
value = self._check_value(value, 'value') value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
z = (value - self.loc) / self.scale z = (value - self.loc) / self.scale
return -(z + self.exp(-z)) - self.log(self.scale) return -(z + self.exp(-z)) - self.log(self.scale)


@@ -210,6 +206,8 @@ class Gumbel(TransformedDistribution):
.. math:: .. math::
cdf_pdf(X) = \exp(-\exp(-\frac{x - loc}{scale}) cdf_pdf(X) = \exp(-\exp(-\frac{x - loc}{scale})
""" """
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
return self._gumbel_bijector("forward", value) return self._gumbel_bijector("forward", value)


def _cross_entropy(self, dist, loc_b, scale_b): def _cross_entropy(self, dist, loc_b, scale_b):
@@ -251,12 +249,14 @@ class Gumbel(TransformedDistribution):
self.expm1((loc_b - self.loc) / scale_b + self.lgamma(self.scale / scale_b + 1.)) self.expm1((loc_b - self.loc) / scale_b + self.lgamma(self.scale / scale_b + 1.))


def _sample(self, shape=()): def _sample(self, shape=()):
shape = self.checktuple(shape, 'shape')
origin_shape = shape + self._broadcast_shape origin_shape = shape + self._broadcast_shape
if origin_shape == (): if origin_shape == ():
sample_shape = (1,) sample_shape = (1,)
else: else:
sample_shape = origin_shape sample_shape = origin_shape
org_sample = self.distribution("sample", sample_shape) org_sample = self.distribution("sample", sample_shape)
org_sample = self.cast(org_sample, self.dtype)
value = self.bijector("forward", org_sample) value = self.bijector("forward", org_sample)
if origin_shape == (): if origin_shape == ():
value = self.squeeze(value) value = self.squeeze(value)


+11 -6  mindspore/nn/probability/distribution/log_normal.py

@@ -137,6 +137,11 @@ class LogNormal(msd.TransformedDistribution):
bijector=msb.Exp(), bijector=msb.Exp(),
seed=seed, name=name) seed=seed, name=name)


# overwrite default_parameters and parameter_names
self._reset_parameters()
self._loc = self._add_parameter(loc, 'loc')
self._scale = self._add_parameter(scale, 'scale')

self.log_2pi = np.log(2 * np.pi) self.log_2pi = np.log(2 * np.pi)


#ops needed for the class #ops needed for the class
@@ -154,12 +159,12 @@ class LogNormal(msd.TransformedDistribution):
@property @property
def loc(self): def loc(self):
"""Distribution parameter for the pre-transformed mean.""" """Distribution parameter for the pre-transformed mean."""
return self.distribution("mean")
return self._loc


@property @property
def scale(self): def scale(self):
"""Distribution parameter for the pre-transformed standard deviation.""" """Distribution parameter for the pre-transformed standard deviation."""
return self.distribution("sd")
return self._scale


def _get_dist_type(self): def _get_dist_type(self):
return "LogNormal" return "LogNormal"
@@ -168,18 +173,18 @@ class LogNormal(msd.TransformedDistribution):
if loc is not None: if loc is not None:
self.checktensor(loc, 'loc') self.checktensor(loc, 'loc')
else: else:
loc = self.distribution("mean")
loc = self.loc
if scale is not None: if scale is not None:
self.checktensor(scale, 'scale') self.checktensor(scale, 'scale')
else: else:
scale = self.distribution("sd")
scale = self.scale
return loc, scale return loc, scale


def extend_repr(self): def extend_repr(self):
if self.is_scalar_batch: if self.is_scalar_batch:
s = f'loc = {self._mean_value}, scale = {self._sd_value}'
s = f'loc = {self.loc}, scale = {self.scale}'
else: else:
s = f'batch_shape = {self._broadcast_shape}'
s = f'batch_shape = {self.broadcast_shape}'
return s return s


def _mean(self, loc=None, scale=None): def _mean(self, loc=None, scale=None):
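A brief sketch (not in the commit) of the LogNormal change above: `loc` and `scale` are now stored through `_add_parameter` and returned directly, instead of being recovered from the underlying Normal via self.distribution("mean") and self.distribution("sd"). The values are illustrative only:

import mindspore.nn.probability.distribution as msd
from mindspore import dtype as mstype

ln = msd.LogNormal(loc=[0.0, 1.0], scale=[1.0, 2.0], dtype=mstype.float32)
print(ln.loc)          # the stored pre-transformed mean
print(ln.scale)        # the stored pre-transformed standard deviation
print(ln.batch_shape)  # expected: (2,), taken from the underlying Normal distribution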


+36 -13  mindspore/nn/probability/distribution/transformed_distribution.py

@@ -16,6 +16,7 @@
import numpy as np import numpy as np
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
import mindspore.nn as nn import mindspore.nn as nn
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import raise_not_impl_error from ._utils.utils import raise_not_impl_error
@@ -30,7 +31,7 @@ class TransformedDistribution(Distribution):


Args: Args:
bijector (Bijector): The transformation to perform. bijector (Bijector): The transformation to perform.
distribution (Distribution): The original distribution.
distribution (Distribution): The original distribution. Must have a dtype of mindspore.float_type.
seed (int): The seed is used in sampling. The global seed is used if it is None. Default:None. seed (int): The seed is used in sampling. The global seed is used if it is None. Default:None.
If this seed is given when a TransformedDistribution object is initialised, the object's sampling function If this seed is given when a TransformedDistribution object is initialised, the object's sampling function
will use this seed; elsewise, the underlying distribution's seed will be used. will use this seed; elsewise, the underlying distribution's seed will be used.
@@ -40,6 +41,12 @@ class TransformedDistribution(Distribution):
The arguments used to initialize the original distribution cannot be None. The arguments used to initialize the original distribution cannot be None.
For example, mynormal = nn.Normal(dtype=dtyple.float32) cannot be used to initialized a For example, mynormal = nn.Normal(dtype=dtyple.float32) cannot be used to initialized a
TransformedDistribution since `mean` and `sd` are not specified. TransformedDistribution since `mean` and `sd` are not specified.
`batch_shape` is the batch_shape of the original distribution.
`broadcast_shape` is the broadcast shape between the original distribution and the bijector.
`is_scalar_batch` is only True when both the original distribution and the bijector are scalar batches.
`default_parameters`, `parameter_names` and `parameter_type` are set to be consistent with the original
distribution. Derived classes can overwrite `default_parameters` and `parameter_names` by calling
`_reset_parameters` followed by `_add_parameter`.


Examples: Examples:
>>> # To initialize a transformed distribution, e.g. a lognormal distribution, >>> # To initialize a transformed distribution, e.g. a lognormal distribution,
@@ -75,28 +82,34 @@ class TransformedDistribution(Distribution):
[nn.probability.bijector.Bijector], type(self).__name__) [nn.probability.bijector.Bijector], type(self).__name__)
validator.check_value_type('distribution', distribution, validator.check_value_type('distribution', distribution,
[Distribution], type(self).__name__) [Distribution], type(self).__name__)
validator.check_type_name("dtype", distribution.dtype, mstype.float_type, type(self).__name__)
super(TransformedDistribution, self).__init__(seed, distribution.dtype, name, param) super(TransformedDistribution, self).__init__(seed, distribution.dtype, name, param)


self._bijector = bijector self._bijector = bijector
self._distribution = distribution self._distribution = distribution
self._is_linear_transformation = bijector.is_constant_jacobian
self.default_parameters = distribution.default_parameters
self.parameter_names = distribution.parameter_names

# set attributes
self._is_linear_transformation = self.bijector.is_constant_jacobian
self._dtype = self.distribution.dtype
self._is_scalar_batch = self.distribution.is_scalar_batch and self.bijector.is_scalar_batch
self._batch_shape = self.distribution.batch_shape

self.default_parameters = self.distribution.default_parameters
self.parameter_names = self.distribution.parameter_names
# by default, set the parameter_type to be the distribution's parameter_type
self.parameter_type = self.distribution.parameter_type


self.exp = exp_generic self.exp = exp_generic
self.log = log_generic self.log = log_generic
self.isnan = P.IsNan() self.isnan = P.IsNan()
self.cast_base = P.Cast()
self.equal_base = P.Equal() self.equal_base = P.Equal()
self.select_base = P.Select() self.select_base = P.Select()
self.fill = P.Fill()
self.fill_base = P.Fill()

# broadcast bijector batch_shape and distribution batch_shape
self._broadcast_shape = self._broadcast_bijector_dist()


# check if batch shape of the distribution and event shape is broadcastable
if hasattr(self.bijector, 'event_shape'):
event_shape_tensor = self.fill(self.dtype, self.bijector.event_shape, 0.0)
broadcast_shape_tensor = self.fill(self.dtype, self.broadcast_shape, 0.0)
self._batch_event = (event_shape_tensor + broadcast_shape_tensor).shape
else:
self._batch_event = self.broadcast_shape


@property @property
def bijector(self): def bijector(self):
@@ -108,12 +121,22 @@ class TransformedDistribution(Distribution):


@property @property
def dtype(self): def dtype(self):
return self.distribution.dtype
return self._dtype


@property @property
def is_linear_transformation(self): def is_linear_transformation(self):
return self._is_linear_transformation return self._is_linear_transformation


def _broadcast_bijector_dist(self):
"""
Check whether the batch shapes of the base distribution and the bijector are broadcastable.
"""
if self.batch_shape is None or self.bijector.batch_shape is None:
return None
bijector_shape_tensor = self.fill_base(self.dtype, self.bijector.batch_shape, 0.0)
dist_shape_tensor = self.fill_base(self.dtype, self.batch_shape, 0.0)
return (bijector_shape_tensor + dist_shape_tensor).shape

def _cdf(self, value, *args, **kwargs): def _cdf(self, value, *args, **kwargs):
r""" r"""
.. math:: .. math::
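A hedged sketch (not part of the commit) of the new shape attributes described in the note above, wrapping a scalar-batch bijector around a vector-batch base distribution; all values are illustrative:

import mindspore.nn.probability.bijector as msb
import mindspore.nn.probability.distribution as msd
from mindspore import dtype as mstype

td = msd.TransformedDistribution(
    bijector=msb.ScalarAffine(scale=2.0, shift=1.0),
    distribution=msd.Normal([0.0, 1.0], 1.0, dtype=mstype.float32))
print(td.batch_shape)      # expected: (2,), inherited from the base Normal
print(td.broadcast_shape)  # expected: (2,), base batch_shape broadcast with the bijector's
print(td.is_scalar_batch)  # expected: False, the Normal parameters are vectors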


+191 -0  tests/ut/python/nn/probability/bijector/test_bijector.py

@@ -0,0 +1,191 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for exp"""
import numpy as np
import pytest

import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor
from mindspore import dtype

class MyBijector(msb.Bijector):
"""
Customized bijector class with dtype not specified.
"""
def __init__(self, param1, param2):
param = dict(locals())
param['param_dict'] = {'param1': param1, 'param2': param2}
super(MyBijector, self).__init__(name='MyBijector', dtype=None, param=param)

self._param1 = self._add_parameter(param1, 'param1')
self._param2 = self._add_parameter(param2, 'param2')

@property
def param1(self):
return self._param1

@property
def param2(self):
return self._param2

def _forward(self, value):
value = self._check_value_dtype(value)
param1_local = self.cast_param_by_value(value, self.param1)
param2_local = self.cast_param_by_value(value, self.param2)
return value * param1_local + param2_local

class MySecondBijector(msb.Bijector):
"""
Customized bijector class with dtype specified.
"""
def __init__(self, param1, param2):
param = dict(locals())
param['param_dict'] = {'param1': param1, 'param2': param2}
super(MySecondBijector, self).__init__(name='MySecondBijector', dtype=dtype.float32, param=param)

self._param1 = self._add_parameter(param1, 'param1')
self._param2 = self._add_parameter(param2, 'param2')

@property
def param1(self):
return self._param1

@property
def param2(self):
return self._param2

def _forward(self, value):
value = self._check_value_dtype(value)
param1_local = self.cast_param_by_value(value, self.param1)
param2_local = self.cast_param_by_value(value, self.param2)
return value * param1_local + param2_local

def test_arguments_same_type():
"""
Test bijector initializations.
"""
param1_1 = np.array(1.0).astype(np.float16)
param2_1 = np.array(2.0).astype(np.float32)
with pytest.raises(TypeError):
MyBijector(param1_1, param2_1)
param1_2 = Tensor(1.0, dtype=dtype.float16)
param2_2 = Tensor(2.0, dtype=dtype.float32)
with pytest.raises(TypeError):
MyBijector(param1_2, param2_2)
with pytest.raises(TypeError):
MyBijector(True, param2_2)
with pytest.raises(TypeError):
MyBijector(None, param2_2)
param1_3 = Tensor(1.0, dtype=dtype.float32)
param2_3 = Tensor(2.0, dtype=dtype.float32)
bijector = MyBijector(param1_3, param2_3)
assert isinstance(bijector, msb.Bijector)
param1_4 = np.array([1.0, 2.0]).astype(np.float16)
param2_4 = np.array([1.0, 2.0]).astype(np.float16)
bijector = MyBijector(param1_4, param2_4)
assert isinstance(bijector, msb.Bijector)
bijector = MyBijector(1.0, 2.0)
assert isinstance(bijector, msb.Bijector)

def test_arguments_with_dtype_specified():
"""
Test bijector initializations when `dtype` is specified.
"""
param1_1 = np.array(1.0).astype(np.float16)
param2_1 = np.array(2.0).astype(np.float16)
with pytest.raises(TypeError):
MySecondBijector(param1_1, param2_1)
param1_2 = Tensor(1.0, dtype=dtype.float16)
param2_2 = Tensor(2.0, dtype=dtype.float32)
with pytest.raises(TypeError):
MySecondBijector(param1_2, param2_2)
with pytest.raises(TypeError):
MySecondBijector(True, param2_2)
with pytest.raises(TypeError):
MySecondBijector(None, param2_2)
param1_3 = Tensor(1.0, dtype=dtype.float32)
param2_3 = Tensor(2.0, dtype=dtype.float32)
bijector = MySecondBijector(param1_3, param2_3)
assert isinstance(bijector, msb.Bijector)
param1_4 = np.array(2.0).astype(np.float32)
param2_4 = np.array(1.0).astype(np.float32)
bijector = MySecondBijector(param1_4, param2_4)
assert isinstance(bijector, msb.Bijector)

class Net1(nn.Cell):
"""
Test input value when bijector's dtype is not specified.
"""
def __init__(self):
super(Net1, self).__init__()
self.bijector = MyBijector(np.array(1.0).astype(np.float32), np.array(2.0).astype(np.float32))

def construct(self, value):
return self.bijector.forward(value)

class Net2(nn.Cell):
"""
Test input value when bijector's dtype is specified.
"""
def __init__(self):
super(Net2, self).__init__()
self.bijector = MySecondBijector(np.array(1.0).astype(np.float32), np.array(2.0).astype(np.float32))

def construct(self, value):
return self.bijector.forward(value)

def test_input_value():
"""
Test validity of input value.
"""
net = Net1()
value = None
with pytest.raises(TypeError):
ans = net(value)
value = 1.0
with pytest.raises(TypeError):
ans = net(value)
value = Tensor(1.0, dtype=dtype.int32)
with pytest.raises(TypeError):
ans = net(value)
value = Tensor(1.0, dtype=dtype.float32)
ans = net(value)
assert ans.dtype == dtype.float32
value = Tensor(1.0, dtype=dtype.float16)
ans = net(value)
assert ans.dtype == dtype.float16

def test_input_value2():
"""
Test validity of input value.
"""
net = Net2()
value = None
with pytest.raises(TypeError):
ans = net(value)
value = 1.0
with pytest.raises(TypeError):
ans = net(value)
value = Tensor(1.0, dtype=dtype.int32)
with pytest.raises(TypeError):
ans = net(value)
value = Tensor(1.0, dtype=dtype.float16)
with pytest.raises(TypeError):
ans = net(value)
value = Tensor(1.0, dtype=dtype.float32)
ans = net(value)
assert ans.dtype == dtype.float32

+4 -0  tests/ut/python/nn/probability/distribution/test_distribution.py

@@ -1,3 +1,7 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0

