| @@ -16,8 +16,10 @@ | |||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore.nn.cell import Cell | from mindspore.nn.cell import Cell | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.common.tensor import Tensor | |||||
| from mindspore._checkparam import Validator as validator | from mindspore._checkparam import Validator as validator | ||||
| from ..distribution._utils.utils import CheckTensor, cast_to_tensor | |||||
| from ..distribution._utils.utils import CheckTensor, cast_to_tensor, raise_type_error | |||||
| from ..distribution import Distribution | from ..distribution import Distribution | ||||
| from ..distribution import TransformedDistribution | from ..distribution import TransformedDistribution | ||||
| @@ -32,6 +34,17 @@ class Bijector(Cell): | |||||
| name (str): The name of the Bijector. Default: None. | name (str): The name of the Bijector. Default: None. | ||||
| dtype (mindspore.dtype): The type of the distributions that the Bijector can operate on. Default: None. | dtype (mindspore.dtype): The type of the distributions that the Bijector can operate on. Default: None. | ||||
| param (dict): The parameters used to initialize the Bijector. Default: None. | param (dict): The parameters used to initialize the Bijector. Default: None. | ||||
| Note: | |||||
| The `dtype` of a bijector represents the type of the distributions that the bijector can operate on. |||||
| When `dtype` is None, there is no enforcement on the type of the input value except that the input |||||
| value has to be of a float type. During initialization, when `dtype` is None, there is no enforcement |||||
| on the dtype of the parameters either; all parameters must have the same float type, otherwise a |||||
| TypeError will be raised. Specifically, the parameter type follows the dtype of the input value, i.e. |||||
| the parameters of the bijector will be cast into the same type as the input value when `dtype` is None. |||||
| When `dtype` is specified, the parameters and the input value are forced to have that `dtype`; a |||||
| TypeError will be raised when the type of the parameters or of the input value differs from `dtype`. |||||
| Only subtypes of mindspore.float_type can be used to specify the bijector's `dtype`. |||||
| """ | """ | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -48,6 +61,8 @@ class Bijector(Cell): | |||||
| validator.check_value_type( | validator.check_value_type( | ||||
| 'is_constant_jacobian', is_constant_jacobian, [bool], name) | 'is_constant_jacobian', is_constant_jacobian, [bool], name) | ||||
| validator.check_value_type('is_injective', is_injective, [bool], name) | validator.check_value_type('is_injective', is_injective, [bool], name) | ||||
| if dtype is not None: | |||||
| validator.check_type_name("dtype", dtype, mstype.float_type, type(self).__name__) | |||||
| self._name = name | self._name = name | ||||
| self._dtype = dtype | self._dtype = dtype | ||||
| self._parameters = {} | self._parameters = {} | ||||
| @@ -57,6 +72,12 @@ class Bijector(Cell): | |||||
| continue | continue | ||||
| if not(k == 'self' or k.startswith('_')): | if not(k == 'self' or k.startswith('_')): | ||||
| self._parameters[k] = param[k] | self._parameters[k] = param[k] | ||||
| # if no bijector is used as an argument during initialization |||||
| if 'bijector' not in param.keys(): | |||||
| self._batch_shape = self._calc_batch_shape() | |||||
| self._is_scalar_batch = self._check_is_scalar_batch() | |||||
| self._is_constant_jacobian = is_constant_jacobian | self._is_constant_jacobian = is_constant_jacobian | ||||
| self._is_injective = is_injective | self._is_injective = is_injective | ||||
| @@ -68,6 +89,8 @@ class Bijector(Cell): | |||||
| self.dtype_base = P.DType() | self.dtype_base = P.DType() | ||||
| self.shape_base = P.Shape() | self.shape_base = P.Shape() | ||||
| self.fill_base = P.Fill() | self.fill_base = P.Fill() | ||||
| self.sametypeshape_base = P.SameTypeShape() | |||||
| self.issubclass_base = P.IsSubClass() | |||||
| @property | @property | ||||
| def name(self): | def name(self): | ||||
| @@ -89,6 +112,38 @@ class Bijector(Cell): | |||||
| def is_injective(self): | def is_injective(self): | ||||
| return self._is_injective | return self._is_injective | ||||
| @property | |||||
| def batch_shape(self): | |||||
| return self._batch_shape | |||||
| @property | |||||
| def is_scalar_batch(self): | |||||
| return self._is_scalar_batch | |||||
| def _check_value_dtype(self, value): | |||||
| """ | |||||
| First, check that the input value is a Tensor. Then, if `self.dtype` is None, check |||||
| that the input tensor has a float dtype, and raise a TypeError otherwise. |||||
| If `self.dtype` is not None, check that the input tensor's dtype is `self.dtype`. |||||
| """ | |||||
| self.checktensor(value, 'input value of bijector') | |||||
| value_type = self.dtype_base(value) | |||||
| if self.dtype is None: | |||||
| if self.issubclass_base(value_type, mstype.float_): | |||||
| return value | |||||
| return raise_type_error('input value of bijector', value_type, mstype.float_) | |||||
| dtype_tensor = self.fill_base(self.dtype, self.shape_base(value), 0.0) | |||||
| self.sametypeshape_base(value, dtype_tensor) | |||||
| return value | |||||
| def _shape_mapping(self, shape): | |||||
| shape_tensor = self.fill_base(self.parameter_type, shape, 0.0) | |||||
| dist_shape_tensor = self.fill_base(self.parameter_type, self.batch_shape, 0.0) | |||||
| return (shape_tensor + dist_shape_tensor).shape | |||||
| def shape_mapping(self, shape): | |||||
| return self._shape_mapping(shape) | |||||
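(Aside, not part of the diff.) The fill-and-add pattern in `_shape_mapping` is a graph-friendly way to broadcast two static shapes; in plain numpy terms the equivalent is:

    import numpy as np

    sample_shape = (4, 1)
    batch_shape = (3,)
    # zero tensors of each shape are added; the result's shape is the broadcast shape
    broadcast = (np.zeros(sample_shape) + np.zeros(batch_shape)).shape   # (4, 3)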
| def _add_parameter(self, value, name): | def _add_parameter(self, value, name): | ||||
| """ | """ | ||||
| Cast `value` to a tensor and add it to `self.default_parameters`. | Cast `value` to a tensor and add it to `self.default_parameters`. | ||||
| @@ -98,26 +153,51 @@ class Bijector(Cell): | |||||
| if not hasattr(self, 'default_parameters'): | if not hasattr(self, 'default_parameters'): | ||||
| self.default_parameters = [] | self.default_parameters = [] | ||||
| self.parameter_names = [] | self.parameter_names = [] | ||||
| self.common_dtype = None | |||||
| # cast value to a tensor if it is not None | # cast value to a tensor if it is not None | ||||
| value_t = None if value is None else cast_to_tensor(value, self.parameter_type) | |||||
| self.default_parameters += [value_t,] | |||||
| if isinstance(value, bool) or value is None: | |||||
| raise TypeError(f"{name} cannot be type {type(value)}") | |||||
| value_t = Tensor(value) | |||||
| # if the bijector's dtype is not specified | |||||
| if self.dtype is None: | |||||
| if self.common_dtype is None: | |||||
| self.common_dtype = value_t.dtype | |||||
| elif value_t.dtype != self.common_dtype: | |||||
| raise TypeError(f"{name} should have the same dtype as other arguments.") | |||||
| # check if the dtype of the input_parameter agrees with the bijector's dtype | |||||
| elif value_t.dtype != self.dtype: | |||||
| raise TypeError(f"{name} should have the same dtype as the bijector's dtype.") | |||||
| self.default_parameters += [value,] | |||||
| self.parameter_names += [name,] | self.parameter_names += [name,] | ||||
| return value_t | return value_t | ||||
| def _calc_event_shape(self): | |||||
| def _calc_batch_shape(self): | |||||
| """ | """ | ||||
| Calculate event_shape based on parameters. | |||||
| Calculate batch_shape based on parameters. | |||||
| """ | """ | ||||
| broadcast_shape = None | |||||
| for param in self.default_parameters: | |||||
| if broadcast_shape is None: | |||||
| broadcast_shape = self.shape_base(param) | |||||
| broadcast_shape_tensor = self.fill_base(self.parameter_type, broadcast_shape, 0.0) | |||||
| param_dict = self.parameters['param_dict'] | |||||
| broadcast_shape_tensor = None | |||||
| for value in param_dict.values(): | |||||
| if value is None: | |||||
| return None | |||||
| if broadcast_shape_tensor is None: | |||||
| broadcast_shape_tensor = cast_to_tensor(value) | |||||
| else: | else: | ||||
| broadcast_shape = self.shape_base(param + broadcast_shape_tensor) | |||||
| broadcast_shape_tensor = self.fill_base(self.parameter_type, broadcast_shape, 0.0) | |||||
| return broadcast_shape | |||||
| value = cast_to_tensor(value) | |||||
| broadcast_shape_tensor = (value + broadcast_shape_tensor) | |||||
| return broadcast_shape_tensor.shape | |||||
| def _check_is_scalar_batch(self): | |||||
| """ | |||||
| Check if the parameters used during initialization are scalars. | |||||
| """ | |||||
| param_dict = self.parameters['param_dict'] | |||||
| for value in param_dict.values(): | |||||
| if value is None: | |||||
| continue | |||||
| if not isinstance(value, (int, float)): | |||||
| return False | |||||
| return True | |||||
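(Illustration, assuming the GumbelCDF bijector defined later in this changeset.) A bijector is a scalar batch only when all of its initialization parameters are Python scalars:

    import numpy as np
    import mindspore.nn.probability.bijector as msb

    scalar_b = msb.GumbelCDF(loc=0.0, scale=1.0)                                 # is_scalar_batch -> True
    batched_b = msb.GumbelCDF(loc=np.array([0.0, 1.0], np.float32), scale=1.0)   # is_scalar_batch -> False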
| def _check_value(self, value, name): | def _check_value(self, value, name): | ||||
| """ | """ | ||||
| @@ -127,32 +207,35 @@ class Bijector(Cell): | |||||
| return value | return value | ||||
| def cast_param_by_value(self, value, para): | def cast_param_by_value(self, value, para): | ||||
| """ | |||||
| Cast the parameter(s) of the bijector to the same type as the input value. |||||
| """ | |||||
| local = self.cast_base(para, self.dtype_base(value)) | local = self.cast_base(para, self.dtype_base(value)) | ||||
| return local | return local | ||||
| def forward(self, *args, **kwargs): | |||||
| def forward(self, value, *args, **kwargs): | |||||
| """ | """ | ||||
| Forward transformation: transform the input value to another distribution. | Forward transformation: transform the input value to another distribution. | ||||
| """ | """ | ||||
| return self._forward(*args, **kwargs) | |||||
| return self._forward(value, *args, **kwargs) | |||||
| def inverse(self, *args, **kwargs): | |||||
| def inverse(self, value, *args, **kwargs): | |||||
| """ | """ | ||||
| Inverse transformation: transform the input value back to the original distribution. | Inverse transformation: transform the input value back to the original distribution. | ||||
| """ | """ | ||||
| return self._inverse(*args, **kwargs) | |||||
| return self._inverse(value, *args, **kwargs) | |||||
| def forward_log_jacobian(self, *args, **kwargs): | |||||
| def forward_log_jacobian(self, value, *args, **kwargs): | |||||
| """ | """ | ||||
| Logarithm of the derivative of the forward transformation. | Logarithm of the derivative of the forward transformation. | ||||
| """ | """ | ||||
| return self._forward_log_jacobian(*args, **kwargs) | |||||
| return self._forward_log_jacobian(value, *args, **kwargs) | |||||
| def inverse_log_jacobian(self, *args, **kwargs): | |||||
| def inverse_log_jacobian(self, value, *args, **kwargs): | |||||
| """ | """ | ||||
| Logarithm of the derivative of the inverse transformation. | Logarithm of the derivative of the inverse transformation. | ||||
| """ | """ | ||||
| return self._inverse_log_jacobian(*args, **kwargs) | |||||
| return self._inverse_log_jacobian(value, *args, **kwargs) | |||||
| def __call__(self, *args, **kwargs): | def __call__(self, *args, **kwargs): | ||||
| """ | """ | ||||
| @@ -167,7 +250,7 @@ class Bijector(Cell): | |||||
| *args: args[0] shall be either a distribution or the name of a Bijector function. | *args: args[0] shall be either a distribution or the name of a Bijector function. | ||||
| """ | """ | ||||
| if isinstance(args[0], Distribution): | if isinstance(args[0], Distribution): | ||||
| return TransformedDistribution(self, args[0], self.distribution.dtype) | |||||
| return TransformedDistribution(self, args[0]) | |||||
| return super(Bijector, self).__call__(*args, **kwargs) | return super(Bijector, self).__call__(*args, **kwargs) | ||||
| def construct(self, name, *args, **kwargs): | def construct(self, name, *args, **kwargs): | ||||
| @@ -13,10 +13,8 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """GumbelCDF Bijector""" | """GumbelCDF Bijector""" | ||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore._checkparam import Validator | |||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from ..distribution._utils.utils import check_greater_zero, set_param_type | |||||
| from ..distribution._utils.utils import check_greater_zero | |||||
| from ..distribution._utils.custom_ops import exp_generic, log_generic | from ..distribution._utils.custom_ops import exp_generic, log_generic | ||||
| from .bijector import Bijector | from .bijector import Bijector | ||||
| @@ -30,12 +28,11 @@ class GumbelCDF(Bijector): | |||||
| Y = \exp(-\exp(\frac{-(X - loc)}{scale})) | Y = \exp(-\exp(\frac{-(X - loc)}{scale})) | ||||
| Note: | Note: | ||||
| For `reverse` and `reverse_log_jacobian`, input should be in range of (0, 1). | |||||
| For `inverse` and `inverse_log_jacobian`, the input should be within the range (0, 1). |||||
| Args: | Args: | ||||
| loc (int, float, list, numpy.ndarray, Tensor): The location. Default: 0.. | |||||
| scale (int, float, list, numpy.ndarray, Tensor): The scale. Default: 1.0. | |||||
| dtype (mindspore.dtype): Type of the distribution which the bijector operates on. Default: float32. | |||||
| loc (float, list, numpy.ndarray, Tensor): The location. Default: 0.0. |||||
| scale (float, list, numpy.ndarray, Tensor): The scale. Default: 1.0. | |||||
| name (str): The name of the Bijector. Default: 'Gumbel_CDF'. | name (str): The name of the Bijector. Default: 'Gumbel_CDF'. | ||||
| Examples: | Examples: | ||||
| @@ -61,22 +58,18 @@ class GumbelCDF(Bijector): | |||||
| def __init__(self, | def __init__(self, | ||||
| loc=0.0, | loc=0.0, | ||||
| scale=1.0, | scale=1.0, | ||||
| dtype=mstype.float32, | |||||
| name='GumbelCDF'): | name='GumbelCDF'): | ||||
| """ | """ | ||||
| Constructor of GumbelCDF Bijector. | Constructor of GumbelCDF Bijector. | ||||
| """ | """ | ||||
| param = dict(locals()) | param = dict(locals()) | ||||
| valid_dtype = mstype.float_type + mstype.int_type + mstype.uint_type | |||||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||||
| parameter_type = set_param_type({'loc': loc, "scale": scale}, dtype) | |||||
| super(GumbelCDF, self).__init__(name=name, dtype=dtype, param=param) | |||||
| param['param_dict'] = {'loc': loc, 'scale': scale} | |||||
| super(GumbelCDF, self).__init__(name=name, param=param) | |||||
| self._parameter_type = parameter_type | |||||
| self._loc = self._add_parameter(loc, 'loc') | self._loc = self._add_parameter(loc, 'loc') | ||||
| self._scale = self._add_parameter(scale, 'scale') | self._scale = self._add_parameter(scale, 'scale') | ||||
| check_greater_zero(self._scale, "scale") | check_greater_zero(self._scale, "scale") | ||||
| self._event_shape = self._calc_event_shape() | |||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| self.exp = exp_generic | self.exp = exp_generic | ||||
| @@ -91,38 +84,34 @@ class GumbelCDF(Bijector): | |||||
| def scale(self): | def scale(self): | ||||
| return self._scale | return self._scale | ||||
| @property | |||||
| def event_shape(self): | |||||
| return self._event_shape | |||||
| @property | |||||
| def parameter_type(self): | |||||
| return self._parameter_type | |||||
| def extend_repr(self): | def extend_repr(self): | ||||
| return f'loc = {self.loc}, scale = {self.scale}' | |||||
| def shape_mapping(self, shape): | |||||
| return shape | |||||
| if self.is_scalar_batch: | |||||
| str_info = f'loc = {self.loc}, scale = {self.scale}' | |||||
| else: | |||||
| str_info = f'batch_shape = {self.batch_shape}' | |||||
| return str_info | |||||
| def _forward(self, x): | def _forward(self, x): | ||||
| x = self._check_value(x, 'value') | |||||
| x = self.cast(x, self.parameter_type) | |||||
| z = (x - self.loc) / self.scale | |||||
| x = self._check_value_dtype(x) | |||||
| loc_local = self.cast_param_by_value(x, self.loc) | |||||
| scale_local = self.cast_param_by_value(x, self.scale) | |||||
| z = (x - loc_local) / scale_local | |||||
| return self.exp(-self.exp(-z)) | return self.exp(-self.exp(-z)) | ||||
| def _inverse(self, y): | def _inverse(self, y): | ||||
| y = self._check_value(y, 'value') | |||||
| y = self.cast(y, self.parameter_type) | |||||
| return self.loc - self.scale * self.log(-self.log(y)) | |||||
| y = self._check_value_dtype(y) | |||||
| loc_local = self.cast_param_by_value(y, self.loc) | |||||
| scale_local = self.cast_param_by_value(y, self.scale) | |||||
| return loc_local - scale_local * self.log(-self.log(y)) | |||||
| def _forward_log_jacobian(self, x): | def _forward_log_jacobian(self, x): | ||||
| x = self._check_value(x, 'value') | |||||
| x = self.cast(x, self.parameter_type) | |||||
| z = (x - self.loc) / self.scale | |||||
| return -z - self.exp(-z) - self.log(self.scale) | |||||
| x = self._check_value_dtype(x) | |||||
| loc_local = self.cast_param_by_value(x, self.loc) | |||||
| scale_local = self.cast_param_by_value(x, self.scale) | |||||
| z = (x - loc_local) / scale_local | |||||
| return -z - self.exp(-z) - self.log(scale_local) | |||||
| def _inverse_log_jacobian(self, y): | def _inverse_log_jacobian(self, y): | ||||
| y = self._check_value(y, 'value') | |||||
| y = self.cast(y, self.parameter_type) | |||||
| return self.log(self.scale / (-1. * y * self.log(y))) | |||||
| y = self._check_value_dtype(y) | |||||
| scale_local = self.cast_param_by_value(y, self.scale) | |||||
| return self.log(scale_local / (-1. * y * self.log(y))) | |||||
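(Numeric sketch, not part of the diff; assumes PyNative mode.) The forward/inverse pair above round-trips as expected:

    import numpy as np
    import mindspore.nn.probability.bijector as msb
    from mindspore import Tensor, dtype

    gumbel_cdf = msb.GumbelCDF(loc=0.0, scale=1.0)
    x = Tensor(np.array([0.5]), dtype=dtype.float32)
    y = gumbel_cdf.forward(x)        # exp(-exp(-0.5)) ~ 0.545
    x_back = gumbel_cdf.inverse(y)   # ~ 0.5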
| @@ -53,23 +53,17 @@ class Invert(Bijector): | |||||
| name = (name + bijector.name) if name == 'Invert' else name | name = (name + bijector.name) if name == 'Invert' else name | ||||
| super(Invert, self).__init__(is_constant_jacobian=bijector.is_constant_jacobian, | super(Invert, self).__init__(is_constant_jacobian=bijector.is_constant_jacobian, | ||||
| is_injective=bijector.is_injective, | is_injective=bijector.is_injective, | ||||
| dtype=bijector.dtype, | |||||
| name=name, | name=name, | ||||
| dtype=bijector.dtype, | |||||
| param=param) | param=param) | ||||
| self._bijector = bijector | self._bijector = bijector | ||||
| if hasattr(self._bijector, 'event_shape'): | |||||
| self._event_shape = self.bijector.event_shape | |||||
| else: | |||||
| self._event_shape = () | |||||
| self._batch_shape = self.bijector.batch_shape | |||||
| self._is_scalar_batch = self.bijector.is_scalar_batch | |||||
| @property | @property | ||||
| def bijector(self): | def bijector(self): | ||||
| return self._bijector | return self._bijector | ||||
| @property | |||||
| def event_shape(self): | |||||
| return self._event_shape | |||||
| def inverse(self, y): | def inverse(self, y): | ||||
| return self.bijector("forward", y) | return self.bijector("forward", y) | ||||
| @@ -14,8 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """Power Bijector""" | """Power Bijector""" | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore._checkparam import Validator as validator | |||||
| from mindspore._checkparam import Rel | |||||
| from ..distribution._utils.utils import check_greater_equal_zero | |||||
| from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic | from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic | ||||
| from .bijector import Bijector | from .bijector import Bijector | ||||
| @@ -37,7 +36,7 @@ class PowerTransform(Bijector): | |||||
| ValueError: When the power is less than 0 or is not known statically. | ValueError: When the power is less than 0 or is not known statically. | ||||
| Args: | Args: | ||||
| power (int or float): The scale factor. Default: 0. | |||||
| power (float, list, numpy.ndarray, Tensor): The scale factor. Default: 0. | |||||
| name (str): The name of the bijector. Default: 'PowerTransform'. | name (str): The name of the bijector. Default: 'PowerTransform'. | ||||
| Examples: | Examples: | ||||
| @@ -64,10 +63,11 @@ class PowerTransform(Bijector): | |||||
| power=0, | power=0, | ||||
| name='PowerTransform'): | name='PowerTransform'): | ||||
| param = dict(locals()) | param = dict(locals()) | ||||
| param['param_dict'] = {'power': power} | |||||
| super(PowerTransform, self).__init__(name=name, param=param) | super(PowerTransform, self).__init__(name=name, param=param) | ||||
| validator.check_value_type('power', power, [int, float], self.name) | |||||
| validator.check_number("power", power, 0, Rel.GE, self.name) | |||||
| self._power = power | |||||
| self._power = self._add_parameter(power, 'power') | |||||
| check_greater_equal_zero(self._power, 'Power') | |||||
| self.pow = P.Pow() | self.pow = P.Pow() | ||||
| self.dtypeop = P.DType() | self.dtypeop = P.DType() | ||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| @@ -81,13 +81,15 @@ class PowerTransform(Bijector): | |||||
| return self._power | return self._power | ||||
| def extend_repr(self): | def extend_repr(self): | ||||
| return f'power = {self.power}' | |||||
| if self.is_scalar_batch: | |||||
| str_info = f'power = {self.power}' | |||||
| else: | |||||
| str_info = f'batch_shape = {self.batch_shape}' | |||||
| return str_info | |||||
| def shape_mapping(self, shape): | |||||
| return shape | |||||
| def _forward(self, x): | def _forward(self, x): | ||||
| x = self._check_value(x, 'value') | |||||
| x = self._check_value_dtype(x) | |||||
| power_local = self.cast_param_by_value(x, self.power) | power_local = self.cast_param_by_value(x, self.power) | ||||
| if power_local == 0: | if power_local == 0: | ||||
| forward_v = self.exp(x) | forward_v = self.exp(x) | ||||
| @@ -96,7 +98,7 @@ class PowerTransform(Bijector): | |||||
| return forward_v | return forward_v | ||||
| def _inverse(self, y): | def _inverse(self, y): | ||||
| y = self._check_value(y, 'value') | |||||
| y = self._check_value_dtype(y) | |||||
| power_local = self.cast_param_by_value(y, self.power) | power_local = self.cast_param_by_value(y, self.power) | ||||
| if power_local == 0: | if power_local == 0: | ||||
| inverse_v = self.log(y) | inverse_v = self.log(y) | ||||
| @@ -116,7 +118,7 @@ class PowerTransform(Bijector): | |||||
| f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1} | f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1} | ||||
| \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1) | \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1) | ||||
| """ | """ | ||||
| x = self._check_value(x, 'value') | |||||
| x = self._check_value_dtype(x) | |||||
| power_local = self.cast_param_by_value(x, self.power) | power_local = self.cast_param_by_value(x, self.power) | ||||
| if power_local == 0: | if power_local == 0: | ||||
| forward_log_j = x | forward_log_j = x | ||||
| @@ -136,7 +138,7 @@ class PowerTransform(Bijector): | |||||
| f'(x) = \frac{e^{c\log(y)}}{y} | f'(x) = \frac{e^{c\log(y)}}{y} | ||||
| \log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y) | \log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y) | ||||
| """ | """ | ||||
| y = self._check_value(y, 'value') | |||||
| y = self._check_value_dtype(y) | |||||
| power_local = self.cast_param_by_value(y, self.power) | power_local = self.cast_param_by_value(y, self.power) | ||||
| inverse_log_j = (power_local - 1) * self.log(y) | inverse_log_j = (power_local - 1) * self.log(y) | ||||
| return inverse_log_j | return inverse_log_j | ||||
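(Numeric sketch, not part of the diff; assumes PyNative mode.) With a positive power c, the forward transform above computes (1 + x * c)^(1/c), falling back to exp(x) when c is 0:

    import mindspore.nn.probability.bijector as msb
    from mindspore import Tensor, dtype

    powertransform = msb.PowerTransform(power=2.0)
    x = Tensor([1.0, 2.0], dtype=dtype.float32)
    y = powertransform.forward(x)        # (1 + 2x)^(1/2) -> approximately [1.732, 2.236]
    x_back = powertransform.inverse(y)   # approximately [1.0, 2.0]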
| @@ -14,8 +14,6 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """Scalar Affine Bijector""" | """Scalar Affine Bijector""" | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore._checkparam import Validator as validator | |||||
| from ..distribution._utils.utils import cast_to_tensor | |||||
| from ..distribution._utils.custom_ops import log_generic | from ..distribution._utils.custom_ops import log_generic | ||||
| from .bijector import Bijector | from .bijector import Bijector | ||||
| @@ -30,10 +28,14 @@ class ScalarAffine(Bijector): | |||||
| where a is the scale factor and b is the shift factor. | where a is the scale factor and b is the shift factor. | ||||
| Args: | Args: | ||||
| scale (float): The scale factor. Default: 1.0. | |||||
| shift (float): The shift factor. Default: 0.0. | |||||
| scale (float, list, numpy.ndarray, Tensor): The scale factor. Default: 1.0. | |||||
| shift (float, list, numpy.ndarray, Tensor): The shift factor. Default: 0.0. | |||||
| name (str): The name of the bijector. Default: 'ScalarAffine'. | name (str): The name of the bijector. Default: 'ScalarAffine'. | ||||
| Note: | |||||
| If `shift` and `scale` are passed in as numpy.ndarray or Tensor, they must have |||||
| the same dtype; otherwise an error will be raised. |||||
| Examples: | Examples: | ||||
| >>> # To initialize a ScalarAffine bijector of scale 1 and shift 2. | >>> # To initialize a ScalarAffine bijector of scale 1 and shift 2. | ||||
| >>> scalaraffine = nn.probability.bijector.ScalarAffine(1, 2) | >>> scalaraffine = nn.probability.bijector.ScalarAffine(1, 2) | ||||
| @@ -61,10 +63,7 @@ class ScalarAffine(Bijector): | |||||
| Constructor of ScalarAffine Bijector. | Constructor of ScalarAffine Bijector. | ||||
| """ | """ | ||||
| param = dict(locals()) | param = dict(locals()) | ||||
| validator.check_value_type( | |||||
| 'scale', scale, [int, float], type(self).__name__) | |||||
| validator.check_value_type( | |||||
| 'shift', shift, [int, float], type(self).__name__) | |||||
| param['param_dict'] = {'scale': scale, 'shift': shift} | |||||
| super(ScalarAffine, self).__init__( | super(ScalarAffine, self).__init__( | ||||
| is_constant_jacobian=True, | is_constant_jacobian=True, | ||||
| is_injective=True, | is_injective=True, | ||||
| @@ -72,8 +71,8 @@ class ScalarAffine(Bijector): | |||||
| dtype=None, | dtype=None, | ||||
| param=param) | param=param) | ||||
| self._scale = cast_to_tensor(scale) | |||||
| self._shift = cast_to_tensor(shift) | |||||
| self._scale = self._add_parameter(scale, 'scale') | |||||
| self._shift = self._add_parameter(shift, 'shift') | |||||
| self.abs = P.Abs() | self.abs = P.Abs() | ||||
| self.oneslike = P.OnesLike() | self.oneslike = P.OnesLike() | ||||
| @@ -90,17 +89,18 @@ class ScalarAffine(Bijector): | |||||
| return self._shift | return self._shift | ||||
| def extend_repr(self): | def extend_repr(self): | ||||
| return f'scale = {self.scale}, shift = {self.shift}' | |||||
| def shape_mapping(self, shape): | |||||
| return shape | |||||
| if self.is_scalar_batch: | |||||
| str_info = f'scale = {self.scale}, shift = {self.shift}' | |||||
| else: | |||||
| str_info = f'batch_shape = {self.batch_shape}' | |||||
| return str_info | |||||
| def _forward(self, x): | def _forward(self, x): | ||||
| r""" | r""" | ||||
| .. math:: | .. math:: | ||||
| f(x) = a * x + b | f(x) = a * x + b | ||||
| """ | """ | ||||
| x = self._check_value(x, 'value') | |||||
| x = self._check_value_dtype(x) | |||||
| scale_local = self.cast_param_by_value(x, self.scale) | scale_local = self.cast_param_by_value(x, self.scale) | ||||
| shift_local = self.cast_param_by_value(x, self.shift) | shift_local = self.cast_param_by_value(x, self.shift) | ||||
| forward_v = scale_local * x + shift_local * self.oneslike(x) | forward_v = scale_local * x + shift_local * self.oneslike(x) | ||||
| @@ -111,7 +111,7 @@ class ScalarAffine(Bijector): | |||||
| .. math:: | .. math:: | ||||
| f(y) = \frac{y - b}{a} | f(y) = \frac{y - b}{a} | ||||
| """ | """ | ||||
| y = self._check_value(y, 'value') | |||||
| y = self._check_value_dtype(y) | |||||
| scale_local = self.cast_param_by_value(y, self.scale) | scale_local = self.cast_param_by_value(y, self.scale) | ||||
| shift_local = self.cast_param_by_value(y, self.shift) | shift_local = self.cast_param_by_value(y, self.shift) | ||||
| inverse_v = (y - shift_local) / scale_local | inverse_v = (y - shift_local) / scale_local | ||||
| @@ -124,7 +124,7 @@ class ScalarAffine(Bijector): | |||||
| f'(x) = a | f'(x) = a | ||||
| \log(f'(x)) = \log(a) | \log(f'(x)) = \log(a) | ||||
| """ | """ | ||||
| x = self._check_value(x, 'value') | |||||
| x = self._check_value_dtype(x) | |||||
| scale_local = self.cast_param_by_value(x, self.scale) | scale_local = self.cast_param_by_value(x, self.scale) | ||||
| forward_log_j = self.log(self.abs(scale_local)) | forward_log_j = self.log(self.abs(scale_local)) | ||||
| return forward_log_j | return forward_log_j | ||||
| @@ -136,7 +136,7 @@ class ScalarAffine(Bijector): | |||||
| f'(x) = \frac{1.0}{a} | f'(x) = \frac{1.0}{a} | ||||
| \log(f'(x)) = - \log(a) | \log(f'(x)) = - \log(a) | ||||
| """ | """ | ||||
| y = self._check_value(y, 'value') | |||||
| y = self._check_value_dtype(y) | |||||
| scale_local = self.cast_param_by_value(y, self.scale) | scale_local = self.cast_param_by_value(y, self.scale) | ||||
| inverse_log_j = -1. * self.log(self.abs(scale_local)) | inverse_log_j = -1. * self.log(self.abs(scale_local)) | ||||
| return inverse_log_j | return inverse_log_j | ||||
| @@ -16,8 +16,6 @@ | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.nn.layer.activation import LogSigmoid | from mindspore.nn.layer.activation import LogSigmoid | ||||
| from mindspore._checkparam import Validator as validator | |||||
| from ..distribution._utils.utils import cast_to_tensor | |||||
| from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic | from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic | ||||
| from .bijector import Bijector | from .bijector import Bijector | ||||
| @@ -32,7 +30,7 @@ class Softplus(Bijector): | |||||
| where k is the sharpness factor. | where k is the sharpness factor. | ||||
| Args: | Args: | ||||
| sharpness (float): The scale factor. Default: 1.0. | |||||
| sharpness (float, list, numpy.ndarray, Tensor): The scale factor. Default: 1.0. | |||||
| name (str): The name of the Bijector. Default: 'Softplus'. | name (str): The name of the Bijector. Default: 'Softplus'. | ||||
| Examples: | Examples: | ||||
| @@ -61,10 +59,9 @@ class Softplus(Bijector): | |||||
| Constructor of Softplus Bijector. | Constructor of Softplus Bijector. | ||||
| """ | """ | ||||
| param = dict(locals()) | param = dict(locals()) | ||||
| validator.check_value_type('sharpness', sharpness, | |||||
| [int, float], type(self).__name__) | |||||
| super(Softplus, self).__init__(name=name, param=param) | |||||
| self._sharpness = cast_to_tensor(sharpness) | |||||
| param['param_dict'] = {'sharpness': sharpness} | |||||
| super(Softplus, self).__init__(name=name, dtype=None, param=param) | |||||
| self._sharpness = self._add_parameter(sharpness, 'sharpness') | |||||
| self.exp = exp_generic | self.exp = exp_generic | ||||
| self.log = log_generic | self.log = log_generic | ||||
| @@ -118,13 +115,14 @@ class Softplus(Bijector): | |||||
| return self._sharpness | return self._sharpness | ||||
| def extend_repr(self): | def extend_repr(self): | ||||
| return f'sharpness = {self.sharpness}' | |||||
| def shape_mapping(self, shape): | |||||
| return shape | |||||
| if self.is_scalar_batch: | |||||
| str_info = f'sharpness = {self.sharpness}' | |||||
| else: | |||||
| str_info = f'batch_shape = {self.batch_shape}' | |||||
| return str_info | |||||
| def _forward(self, x): | def _forward(self, x): | ||||
| x = self._check_value(x, 'value') | |||||
| x = self._check_value_dtype(x) | |||||
| sharpness_local = self.cast_param_by_value(x, self.sharpness) | sharpness_local = self.cast_param_by_value(x, self.sharpness) | ||||
| scaled_value = sharpness_local * x | scaled_value = sharpness_local * x | ||||
| forward_v = self.softplus(scaled_value) / sharpness_local | forward_v = self.softplus(scaled_value) / sharpness_local | ||||
| @@ -136,7 +134,7 @@ class Softplus(Bijector): | |||||
| f(x) = \frac{\log(1 + e^{kx}))}{k} | f(x) = \frac{\log(1 + e^{kx}))}{k} | ||||
| f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k} | f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k} | ||||
| """ | """ | ||||
| y = self._check_value(y, 'value') | |||||
| y = self._check_value_dtype(y) | |||||
| sharpness_local = self.cast_param_by_value(y, self.sharpness) | sharpness_local = self.cast_param_by_value(y, self.sharpness) | ||||
| scaled_value = sharpness_local * y | scaled_value = sharpness_local * y | ||||
| inverse_v = self.inverse_softplus(scaled_value) / sharpness_local | inverse_v = self.inverse_softplus(scaled_value) / sharpness_local | ||||
| @@ -149,7 +147,7 @@ class Softplus(Bijector): | |||||
| f'(x) = \frac{e^{kx}}{ 1 + e^{kx}} | f'(x) = \frac{e^{kx}}{ 1 + e^{kx}} | ||||
| \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx) | \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx) | ||||
| """ | """ | ||||
| x = self._check_value(x, 'value') | |||||
| x = self._check_value_dtype(x) | |||||
| sharpness_local = self.cast_param_by_value(x, self.sharpness) | sharpness_local = self.cast_param_by_value(x, self.sharpness) | ||||
| scaled_value = sharpness_local * x | scaled_value = sharpness_local * x | ||||
| forward_log_j = self.log_sigmoid(scaled_value) | forward_log_j = self.log_sigmoid(scaled_value) | ||||
| @@ -162,7 +160,7 @@ class Softplus(Bijector): | |||||
| f'(y) = \frac{e^{ky}}{e^{ky} - 1} | f'(y) = \frac{e^{ky}}{e^{ky} - 1} | ||||
| \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky) | \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky) | ||||
| """ | """ | ||||
| y = self._check_value(y, 'value') | |||||
| y = self._check_value_dtype(y) | |||||
| sharpness_local = self.cast_param_by_value(y, self.sharpness) | sharpness_local = self.cast_param_by_value(y, self.sharpness) | ||||
| scaled_value = sharpness_local * y | scaled_value = sharpness_local * y | ||||
| inverse_log_j = scaled_value - self.inverse_softplus(scaled_value) | inverse_log_j = scaled_value - self.inverse_softplus(scaled_value) | ||||
| @@ -229,6 +229,11 @@ def raise_not_implemented_util(func_name, obj, *args, **kwargs): | |||||
| raise NotImplementedError( | raise NotImplementedError( | ||||
| f"{func_name} is not implemented for {obj} distribution.") | f"{func_name} is not implemented for {obj} distribution.") | ||||
| @constexpr | |||||
| def raise_type_error(name, cur_type, required_type): | |||||
| raise TypeError( | |||||
| f"For {name} , the type should be or be subclass of {required_type}, but got {cur_type}") | |||||
| @constexpr | @constexpr | ||||
| def check_distribution_name(name, expected_name): | def check_distribution_name(name, expected_name): | ||||
| @@ -304,7 +309,7 @@ def set_param_type(args, hint_type): | |||||
| TypeError: if tensors in args are not the same dtype. | TypeError: if tensors in args are not the same dtype. | ||||
| """ | """ | ||||
| int_type = mstype.int_type + mstype.uint_type | int_type = mstype.int_type + mstype.uint_type | ||||
| if hint_type in int_type: | |||||
| if hint_type in int_type or hint_type is None: | |||||
| hint_type = mstype.float32 | hint_type = mstype.float32 | ||||
| common_dtype = None | common_dtype = None | ||||
| for name, arg in args.items(): | for name, arg in args.items(): | ||||
| @@ -72,13 +72,12 @@ class Distribution(Cell): | |||||
| if not(k == 'self' or k.startswith('_')): | if not(k == 'self' or k.startswith('_')): | ||||
| self._parameters[k] = param[k] | self._parameters[k] = param[k] | ||||
| # some attributes | |||||
| if 'distribution' in self.parameters.keys(): | |||||
| self.parameter_type = self.parameters['distribution'].parameter_type | |||||
| else: | |||||
| # if not a transformed distribution, set the following attributes |||||
| if 'distribution' not in self.parameters.keys(): | |||||
| self.parameter_type = set_param_type(self.parameters['param_dict'], dtype) | self.parameter_type = set_param_type(self.parameters['param_dict'], dtype) | ||||
| self._broadcast_shape = self._calc_broadcast_shape() | |||||
| self._is_scalar_batch = self._check_is_scalar_batch() | |||||
| self._batch_shape = self._calc_batch_shape() | |||||
| self._is_scalar_batch = self._check_is_scalar_batch() | |||||
| self._broadcast_shape = self._batch_shape | |||||
| # set the function to call according to the derived class's attributes | # set the function to call according to the derived class's attributes | ||||
| self._set_prob() | self._set_prob() | ||||
| @@ -128,6 +127,10 @@ class Distribution(Cell): | |||||
| def is_scalar_batch(self): | def is_scalar_batch(self): | ||||
| return self._is_scalar_batch | return self._is_scalar_batch | ||||
| @property | |||||
| def batch_shape(self): | |||||
| return self._batch_shape | |||||
| @property | @property | ||||
| def broadcast_shape(self): | def broadcast_shape(self): | ||||
| return self._broadcast_shape | return self._broadcast_shape | ||||
| @@ -208,8 +211,6 @@ class Distribution(Cell): | |||||
| """ | """ | ||||
| Check if the parameters used during initialization are scalars. | Check if the parameters used during initialization are scalars. | ||||
| """ | """ | ||||
| if 'distribution' in self.parameters.keys(): | |||||
| return self.parameters['distribution'].is_scalar_batch | |||||
| param_dict = self.parameters['param_dict'] | param_dict = self.parameters['param_dict'] | ||||
| for value in param_dict.values(): | for value in param_dict.values(): | ||||
| if value is None: | if value is None: | ||||
| @@ -218,12 +219,10 @@ class Distribution(Cell): | |||||
| return False | return False | ||||
| return True | return True | ||||
| def _calc_broadcast_shape(self): | |||||
| def _calc_batch_shape(self): | |||||
| """ | """ | ||||
| Calculate the broadcast shape of the parameters used during initialization. | Calculate the broadcast shape of the parameters used during initialization. | ||||
| """ | """ | ||||
| if 'distribution' in self.parameters.keys(): | |||||
| return self.parameters['distribution'].broadcast_shape | |||||
| param_dict = self.parameters['param_dict'] | param_dict = self.parameters['param_dict'] | ||||
| broadcast_shape_tensor = None | broadcast_shape_tensor = None | ||||
| for value in param_dict.values(): | for value in param_dict.values(): | ||||
| @@ -362,14 +361,14 @@ class Distribution(Cell): | |||||
| """ | """ | ||||
| return self._get_dist_args(*args, **kwargs) | return self._get_dist_args(*args, **kwargs) | ||||
| def _get_dist_type(self, *args, **kwargs): | |||||
| return raise_not_implemented_util('get_dist_type', self.name, *args, **kwargs) | |||||
| def _get_dist_type(self): | |||||
| return raise_not_implemented_util('get_dist_type', self.name) | |||||
| def get_dist_type(self, *args, **kwargs): | |||||
| def get_dist_type(self): | |||||
| """ | """ | ||||
| Return the type of the distribution. | Return the type of the distribution. | ||||
| """ | """ | ||||
| return self._get_dist_type(*args, **kwargs) | |||||
| return self._get_dist_type() | |||||
| def _raise_not_implemented_error(self, func_name): | def _raise_not_implemented_error(self, func_name): | ||||
| name = self.name | name = self.name | ||||
| @@ -751,5 +750,5 @@ class Distribution(Cell): | |||||
| if name == 'get_dist_args': | if name == 'get_dist_args': | ||||
| return self._get_dist_args(*args, **kwargs) | return self._get_dist_args(*args, **kwargs) | ||||
| if name == 'get_dist_type': | if name == 'get_dist_type': | ||||
| return self._get_dist_type(*args, **kwargs) | |||||
| return self._get_dist_type() | |||||
| return raise_not_implemented_util(name, self.name, *args, **kwargs) | return raise_not_implemented_util(name, self.name, *args, **kwargs) | ||||
| @@ -103,17 +103,12 @@ class Gumbel(TransformedDistribution): | |||||
| """ | """ | ||||
| valid_dtype = mstype.float_type | valid_dtype = mstype.float_type | ||||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | ||||
| gumbel_cdf = msb.GumbelCDF(loc, scale, dtype) | |||||
| gumbel_cdf = msb.GumbelCDF(loc, scale) | |||||
| super(Gumbel, self).__init__( | super(Gumbel, self).__init__( | ||||
| distribution=msd.Uniform(0.0, 1.0, dtype=dtype), | distribution=msd.Uniform(0.0, 1.0, dtype=dtype), | ||||
| bijector=msb.Invert(gumbel_cdf), | bijector=msb.Invert(gumbel_cdf), | ||||
| seed=seed, name=name) | seed=seed, name=name) | ||||
| self.parameter_type = gumbel_cdf.parameter_type | |||||
| self._broadcast_shape = gumbel_cdf.event_shape | |||||
| if self._broadcast_shape != (): | |||||
| self._is_scalar_batch = False | |||||
| # overwrite default_parameters and parameter_names | # overwrite default_parameters and parameter_names | ||||
| self._reset_parameters() | self._reset_parameters() | ||||
| self._loc = self._add_parameter(loc, 'loc') | self._loc = self._add_parameter(loc, 'loc') | ||||
| @@ -202,6 +197,7 @@ class Gumbel(TransformedDistribution): | |||||
| where z = \frac{x - loc}{scale} | where z = \frac{x - loc}{scale} | ||||
| """ | """ | ||||
| value = self._check_value(value, 'value') | value = self._check_value(value, 'value') | ||||
| value = self.cast(value, self.dtype) | |||||
| z = (value - self.loc) / self.scale | z = (value - self.loc) / self.scale | ||||
| return -(z + self.exp(-z)) - self.log(self.scale) | return -(z + self.exp(-z)) - self.log(self.scale) | ||||
| @@ -210,6 +206,8 @@ class Gumbel(TransformedDistribution): | |||||
| .. math:: | .. math:: | ||||
| cdf_pdf(X) = \exp(-\exp(-\frac{x - loc}{scale}) | cdf_pdf(X) = \exp(-\exp(-\frac{x - loc}{scale}) | ||||
| """ | """ | ||||
| value = self._check_value(value, 'value') | |||||
| value = self.cast(value, self.dtype) | |||||
| return self._gumbel_bijector("forward", value) | return self._gumbel_bijector("forward", value) | ||||
| def _cross_entropy(self, dist, loc_b, scale_b): | def _cross_entropy(self, dist, loc_b, scale_b): | ||||
| @@ -251,12 +249,14 @@ class Gumbel(TransformedDistribution): | |||||
| self.expm1((loc_b - self.loc) / scale_b + self.lgamma(self.scale / scale_b + 1.)) | self.expm1((loc_b - self.loc) / scale_b + self.lgamma(self.scale / scale_b + 1.)) | ||||
| def _sample(self, shape=()): | def _sample(self, shape=()): | ||||
| shape = self.checktuple(shape, 'shape') | |||||
| origin_shape = shape + self._broadcast_shape | origin_shape = shape + self._broadcast_shape | ||||
| if origin_shape == (): | if origin_shape == (): | ||||
| sample_shape = (1,) | sample_shape = (1,) | ||||
| else: | else: | ||||
| sample_shape = origin_shape | sample_shape = origin_shape | ||||
| org_sample = self.distribution("sample", sample_shape) | org_sample = self.distribution("sample", sample_shape) | ||||
| org_sample = self.cast(org_sample, self.dtype) | |||||
| value = self.bijector("forward", org_sample) | value = self.bijector("forward", org_sample) | ||||
| if origin_shape == (): | if origin_shape == (): | ||||
| value = self.squeeze(value) | value = self.squeeze(value) | ||||
| @@ -137,6 +137,11 @@ class LogNormal(msd.TransformedDistribution): | |||||
| bijector=msb.Exp(), | bijector=msb.Exp(), | ||||
| seed=seed, name=name) | seed=seed, name=name) | ||||
| # overwrite default_parameters and parameter_names | |||||
| self._reset_parameters() | |||||
| self._loc = self._add_parameter(loc, 'loc') | |||||
| self._scale = self._add_parameter(scale, 'scale') | |||||
| self.log_2pi = np.log(2 * np.pi) | self.log_2pi = np.log(2 * np.pi) | ||||
| #ops needed for the class | #ops needed for the class | ||||
| @@ -154,12 +159,12 @@ class LogNormal(msd.TransformedDistribution): | |||||
| @property | @property | ||||
| def loc(self): | def loc(self): | ||||
| """Distribution parameter for the pre-transformed mean.""" | """Distribution parameter for the pre-transformed mean.""" | ||||
| return self.distribution("mean") | |||||
| return self._loc | |||||
| @property | @property | ||||
| def scale(self): | def scale(self): | ||||
| """Distribution parameter for the pre-transformed standard deviation.""" | """Distribution parameter for the pre-transformed standard deviation.""" | ||||
| return self.distribution("sd") | |||||
| return self._scale | |||||
| def _get_dist_type(self): | def _get_dist_type(self): | ||||
| return "LogNormal" | return "LogNormal" | ||||
| @@ -168,18 +173,18 @@ class LogNormal(msd.TransformedDistribution): | |||||
| if loc is not None: | if loc is not None: | ||||
| self.checktensor(loc, 'loc') | self.checktensor(loc, 'loc') | ||||
| else: | else: | ||||
| loc = self.distribution("mean") | |||||
| loc = self.loc | |||||
| if scale is not None: | if scale is not None: | ||||
| self.checktensor(scale, 'scale') | self.checktensor(scale, 'scale') | ||||
| else: | else: | ||||
| scale = self.distribution("sd") | |||||
| scale = self.scale | |||||
| return loc, scale | return loc, scale | ||||
| def extend_repr(self): | def extend_repr(self): | ||||
| if self.is_scalar_batch: | if self.is_scalar_batch: | ||||
| s = f'loc = {self._mean_value}, scale = {self._sd_value}' | |||||
| s = f'loc = {self.loc}, scale = {self.scale}' | |||||
| else: | else: | ||||
| s = f'batch_shape = {self._broadcast_shape}' | |||||
| s = f'batch_shape = {self.broadcast_shape}' | |||||
| return s | return s | ||||
| def _mean(self, loc=None, scale=None): | def _mean(self, loc=None, scale=None): | ||||
| @@ -16,6 +16,7 @@ | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore._checkparam import Validator as validator | from mindspore._checkparam import Validator as validator | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.common import dtype as mstype | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from .distribution import Distribution | from .distribution import Distribution | ||||
| from ._utils.utils import raise_not_impl_error | from ._utils.utils import raise_not_impl_error | ||||
| @@ -30,7 +31,7 @@ class TransformedDistribution(Distribution): | |||||
| Args: | Args: | ||||
| bijector (Bijector): The transformation to perform. | bijector (Bijector): The transformation to perform. | ||||
| distribution (Distribution): The original distribution. | |||||
| distribution (Distribution): The original distribution. Its dtype must be in mindspore.float_type. |||||
| seed (int): The seed is used in sampling. The global seed is used if it is None. Default:None. | seed (int): The seed is used in sampling. The global seed is used if it is None. Default:None. | ||||
| If this seed is given when a TransformedDistribution object is initialised, the object's sampling function | If this seed is given when a TransformedDistribution object is initialised, the object's sampling function | ||||
| will use this seed; otherwise, the underlying distribution's seed will be used. | will use this seed; otherwise, the underlying distribution's seed will be used. | ||||
| @@ -40,6 +41,12 @@ class TransformedDistribution(Distribution): | |||||
| The arguments used to initialize the original distribution cannot be None. | The arguments used to initialize the original distribution cannot be None. | ||||
| For example, mynormal = nn.Normal(dtype=dtype.float32) cannot be used to initialize a | For example, mynormal = nn.Normal(dtype=dtype.float32) cannot be used to initialize a | ||||
| TransformedDistribution since `mean` and `sd` are not specified. | TransformedDistribution since `mean` and `sd` are not specified. | ||||
| `batch_shape` is the batch_shape of the original distribution. | |||||
| `broadcast_shape` is the broadcast shape between the original distribution and bijector. | |||||
| `is_scalar_batch` is only true if both the original distribution and the bijector are scalar batches. | |||||
| `default_parameters`, `parameter_names` and `parameter_type` are set to be consistent with the original | |||||
| distribution. Derived classes can overwrite `default_parameters` and `parameter_names` by calling |||||
| `_reset_parameters` followed by `_add_parameter`. |||||
| Examples: | Examples: | ||||
| >>> # To initialize a transformed distribution, e.g. a lognormal distribution, | >>> # To initialize a transformed distribution, e.g. a lognormal distribution, | ||||
| @@ -75,28 +82,34 @@ class TransformedDistribution(Distribution): | |||||
| [nn.probability.bijector.Bijector], type(self).__name__) | [nn.probability.bijector.Bijector], type(self).__name__) | ||||
| validator.check_value_type('distribution', distribution, | validator.check_value_type('distribution', distribution, | ||||
| [Distribution], type(self).__name__) | [Distribution], type(self).__name__) | ||||
| validator.check_type_name("dtype", distribution.dtype, mstype.float_type, type(self).__name__) | |||||
| super(TransformedDistribution, self).__init__(seed, distribution.dtype, name, param) | super(TransformedDistribution, self).__init__(seed, distribution.dtype, name, param) | ||||
| self._bijector = bijector | self._bijector = bijector | ||||
| self._distribution = distribution | self._distribution = distribution | ||||
| self._is_linear_transformation = bijector.is_constant_jacobian | |||||
| self.default_parameters = distribution.default_parameters | |||||
| self.parameter_names = distribution.parameter_names | |||||
| # set attributes | |||||
| self._is_linear_transformation = self.bijector.is_constant_jacobian | |||||
| self._dtype = self.distribution.dtype | |||||
| self._is_scalar_batch = self.distribution.is_scalar_batch and self.bijector.is_scalar_batch | |||||
| self._batch_shape = self.distribution.batch_shape | |||||
| self.default_parameters = self.distribution.default_parameters | |||||
| self.parameter_names = self.distribution.parameter_names | |||||
| # by default, set the parameter_type to be the distribution's parameter_type | |||||
| self.parameter_type = self.distribution.parameter_type | |||||
| self.exp = exp_generic | self.exp = exp_generic | ||||
| self.log = log_generic | self.log = log_generic | ||||
| self.isnan = P.IsNan() | self.isnan = P.IsNan() | ||||
| self.cast_base = P.Cast() | |||||
| self.equal_base = P.Equal() | self.equal_base = P.Equal() | ||||
| self.select_base = P.Select() | self.select_base = P.Select() | ||||
| self.fill = P.Fill() | |||||
| self.fill_base = P.Fill() | |||||
| # broadcast bijector batch_shape and distribution batch_shape | |||||
| self._broadcast_shape = self._broadcast_bijector_dist() | |||||
| # check if batch shape of the distribution and event shape is broadcastable | |||||
| if hasattr(self.bijector, 'event_shape'): | |||||
| event_shape_tensor = self.fill(self.dtype, self.bijector.event_shape, 0.0) | |||||
| broadcast_shape_tensor = self.fill(self.dtype, self.broadcast_shape, 0.0) | |||||
| self._batch_event = (event_shape_tensor + broadcast_shape_tensor).shape | |||||
| else: | |||||
| self._batch_event = self.broadcast_shape | |||||
| @property | @property | ||||
| def bijector(self): | def bijector(self): | ||||
| @@ -108,12 +121,22 @@ class TransformedDistribution(Distribution): | |||||
| @property | @property | ||||
| def dtype(self): | def dtype(self): | ||||
| return self.distribution.dtype | |||||
| return self._dtype | |||||
| @property | @property | ||||
| def is_linear_transformation(self): | def is_linear_transformation(self): | ||||
| return self._is_linear_transformation | return self._is_linear_transformation | ||||
| def _broadcast_bijector_dist(self): | |||||
| """ | |||||
| Check whether the batch shapes of the base distribution and the bijector are broadcastable. |||||
| """ | |||||
| if self.batch_shape is None or self.bijector.batch_shape is None: | |||||
| return None | |||||
| bijector_shape_tensor = self.fill_base(self.dtype, self.bijector.batch_shape, 0.0) | |||||
| dist_shape_tensor = self.fill_base(self.dtype, self.batch_shape, 0.0) | |||||
| return (bijector_shape_tensor + dist_shape_tensor).shape | |||||
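(Illustration, not part of the diff; assumes the post-change constructor signature TransformedDistribution(bijector, distribution) and PyNative mode.) How the attributes described in the Note above come out for non-scalar parameters:

    import numpy as np
    import mindspore.nn.probability.bijector as msb
    import mindspore.nn.probability.distribution as msd
    from mindspore import dtype

    normal = msd.Normal(np.zeros((2, 1), np.float32), np.ones((3,), np.float32), dtype=dtype.float32)
    affine = msb.ScalarAffine(scale=2.0, shift=1.0)
    trans = msd.TransformedDistribution(affine, normal)
    # trans.batch_shape      -> (2, 3)   batch shape of the base distribution
    # trans.broadcast_shape  -> (2, 3)   broadcast of distribution and bijector batch shapes
    # trans.is_scalar_batch  -> False    the base distribution is not a scalar batch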
| def _cdf(self, value, *args, **kwargs): | def _cdf(self, value, *args, **kwargs): | ||||
| r""" | r""" | ||||
| .. math:: | .. math:: | ||||
| @@ -0,0 +1,191 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """test cases for exp""" | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.nn as nn | |||||
| import mindspore.nn.probability.bijector as msb | |||||
| from mindspore import Tensor | |||||
| from mindspore import dtype | |||||
| class MyBijector(msb.Bijector): | |||||
| """ | |||||
| Customized bijector class with dtype not specified. | |||||
| """ | |||||
| def __init__(self, param1, param2): | |||||
| param = dict(locals()) | |||||
| param['param_dict'] = {'param1': param1, 'param2': param2} | |||||
| super(MyBijector, self).__init__(name='MyBijector', dtype=None, param=param) | |||||
| self._param1 = self._add_parameter(param1, 'param1') | |||||
| self._param2 = self._add_parameter(param2, 'param2') | |||||
| @property | |||||
| def param1(self): | |||||
| return self._param1 | |||||
| @property | |||||
| def param2(self): | |||||
| return self._param2 | |||||
| def _forward(self, value): | |||||
| value = self._check_value_dtype(value) | |||||
| param1_local = self.cast_param_by_value(value, self.param1) | |||||
| param2_local = self.cast_param_by_value(value, self.param2) | |||||
| return value * param1_local + param2_local | |||||
| class MySecondBijector(msb.Bijector): | |||||
| """ | |||||
| Customized bijector class with dtype specified. | |||||
| """ | |||||
| def __init__(self, param1, param2): | |||||
| param = dict(locals()) | |||||
| param['param_dict'] = {'param1': param1, 'param2': param2} | |||||
| super(MySecondBijector, self).__init__(name='MySecondBijector', dtype=dtype.float32, param=param) | |||||
| self._param1 = self._add_parameter(param1, 'param1') | |||||
| self._param2 = self._add_parameter(param2, 'param2') | |||||
| @property | |||||
| def param1(self): | |||||
| return self._param1 | |||||
| @property | |||||
| def param2(self): | |||||
| return self._param2 | |||||
| def _forward(self, value): | |||||
| value = self._check_value_dtype(value) | |||||
| param1_local = self.cast_param_by_value(value, self.param1) | |||||
| param2_local = self.cast_param_by_value(value, self.param2) | |||||
| return value * param1_local + param2_local | |||||
| def test_arguments_same_type(): | |||||
| """ | |||||
| Test bijector initializations when `dtype` is not specified. |||||
| """ | |||||
| param1_1 = np.array(1.0).astype(np.float16) | |||||
| param2_1 = np.array(2.0).astype(np.float32) | |||||
| with pytest.raises(TypeError): | |||||
| MyBijector(param1_1, param2_1) | |||||
| param1_2 = Tensor(1.0, dtype=dtype.float16) | |||||
| param2_2 = Tensor(2.0, dtype=dtype.float32) | |||||
| with pytest.raises(TypeError): | |||||
| MyBijector(param1_2, param2_2) | |||||
| with pytest.raises(TypeError): | |||||
| MyBijector(True, param2_2) | |||||
| with pytest.raises(TypeError): | |||||
| MyBijector(None, param2_2) | |||||
| param1_3 = Tensor(1.0, dtype=dtype.float32) | |||||
| param2_3 = Tensor(2.0, dtype=dtype.float32) | |||||
| bijector = MyBijector(param1_3, param2_3) | |||||
| assert isinstance(bijector, msb.Bijector) | |||||
| param1_4 = np.array([1.0, 2.0]).astype(np.float16) | |||||
| param2_4 = np.array([1.0, 2.0]).astype(np.float16) | |||||
| bijector = MyBijector(param1_4, param2_4) | |||||
| assert isinstance(bijector, msb.Bijector) | |||||
| bijector = MyBijector(1.0, 2.0) | |||||
| assert isinstance(bijector, msb.Bijector) | |||||
| def test_arguments_with_dtype_specified(): | |||||
| """ | |||||
| Test bijector initializations when `dtype` is specified. |||||
| """ | |||||
| param1_1 = np.array(1.0).astype(np.float16) | |||||
| param2_1 = np.array(2.0).astype(np.float16) | |||||
| with pytest.raises(TypeError): | |||||
| MySecondBijector(param1_1, param2_1) | |||||
| param1_2 = Tensor(1.0, dtype=dtype.float16) | |||||
| param2_2 = Tensor(2.0, dtype=dtype.float32) | |||||
| with pytest.raises(TypeError): | |||||
| MySecondBijector(param1_2, param2_2) | |||||
| with pytest.raises(TypeError): | |||||
| MySecondBijector(True, param2_2) | |||||
| with pytest.raises(TypeError): | |||||
| MySecondBijector(None, param2_2) | |||||
| param1_3 = Tensor(1.0, dtype=dtype.float32) | |||||
| param2_3 = Tensor(2.0, dtype=dtype.float32) | |||||
| bijector = MySecondBijector(param1_3, param2_3) |||||
| assert isinstance(bijector, msb.Bijector) | |||||
| param1_4 = np.array(2.0).astype(np.float32) | |||||
| param2_4 = np.array(1.0).astype(np.float32) | |||||
| bijector = MySecondBijector(param1_4, param2_4) |||||
| assert isinstance(bijector, msb.Bijector) | |||||
| class Net1(nn.Cell): | |||||
| """ | |||||
| Test input value when bijector's dtype is not specified. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Net1, self).__init__() | |||||
| self.bijector = MyBijector(np.array(1.0).astype(np.float32), np.array(2.0).astype(np.float32)) | |||||
| def construct(self, value): | |||||
| return self.bijector.forward(value) | |||||
| class Net2(nn.Cell): | |||||
| """ | |||||
| Test input value when bijector's dtype is specified. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Net2, self).__init__() | |||||
| self.bijector = MySecondBijector(np.array(1.0).astype(np.float32), np.array(2.0).astype(np.float32)) | |||||
| def construct(self, value): | |||||
| return self.bijector.forward(value) | |||||
| def test_input_value(): | |||||
| """ | |||||
| Test validity of the input value when the bijector's dtype is not specified. |||||
| """ | |||||
| net = Net1() | |||||
| value = None | |||||
| with pytest.raises(TypeError): | |||||
| ans = net(value) | |||||
| value = 1.0 | |||||
| with pytest.raises(TypeError): | |||||
| ans = net(value) | |||||
| value = Tensor(1.0, dtype=dtype.int32) | |||||
| with pytest.raises(TypeError): | |||||
| ans = net(value) | |||||
| value = Tensor(1.0, dtype=dtype.float32) | |||||
| ans = net(value) | |||||
| assert ans.dtype == dtype.float32 | |||||
| value = Tensor(1.0, dtype=dtype.float16) | |||||
| ans = net(value) | |||||
| assert ans.dtype == dtype.float16 | |||||
| def test_input_value2(): | |||||
| """ | |||||
| Test validity of the input value when the bijector's dtype is specified. |||||
| """ | |||||
| net = Net2() | |||||
| value = None | |||||
| with pytest.raises(TypeError): | |||||
| ans = net(value) | |||||
| value = 1.0 | |||||
| with pytest.raises(TypeError): | |||||
| ans = net(value) | |||||
| value = Tensor(1.0, dtype=dtype.int32) | |||||
| with pytest.raises(TypeError): | |||||
| ans = net(value) | |||||
| value = Tensor(1.0, dtype=dtype.float16) | |||||
| with pytest.raises(TypeError): | |||||
| ans = net(value) | |||||
| value = Tensor(1.0, dtype=dtype.float32) | |||||
| ans = net(value) | |||||
| assert ans.dtype == dtype.float32 | |||||
| @@ -1,3 +1,7 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | # You may obtain a copy of the License at | ||||
| # | # | ||||
| # http://www.apache.org/licenses/LICENSE-2.0 | # http://www.apache.org/licenses/LICENSE-2.0 | ||||