update threshold of softplus computation; support float dtypes for bernoulli and geometric distributions (tag: v1.0.0)
@@ -20,6 +20,7 @@ from ..distribution._utils.utils import CheckTensor
 from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step, log1p_by_step
 from .bijector import Bijector

+
 class PowerTransform(Bijector):
     r"""
     Power Bijector.
@@ -49,6 +50,7 @@ class PowerTransform(Bijector):
         >>> # by replacing 'forward' with the name of the function
         >>> ans = self.p1.forward(value)
     """
+
     def __init__(self,
                  power=0,
                  name='PowerTransform',
@@ -78,13 +80,13 @@ class PowerTransform(Bijector):
         return shape

     def _forward(self, x):
-        self.checktensor(x, 'x')
+        self.checktensor(x, 'value')
         if self.power == 0:
             return self.exp(x)
         return self.exp(self.log1p(x * self.power) / self.power)

     def _inverse(self, y):
-        self.checktensor(y, 'y')
+        self.checktensor(y, 'value')
         if self.power == 0:
             return self.log(y)
         return self.expm1(self.log(y) * self.power) / self.power
@@ -101,7 +103,7 @@ class PowerTransform(Bijector):
            f'(x) = e^{\frac{\log(xc + 1)}{c}} * \frac{1}{xc + 1}
            \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1)
        """
-        self.checktensor(x, 'x')
+        self.checktensor(x, 'value')
         if self.power == 0:
             return x
         return (1. / self.power - 1) * self.log1p(x * self.power)
@@ -118,5 +120,5 @@ class PowerTransform(Bijector):
            f'(y) = \frac{e^{c\log(y)}}{y}
            \log(f'(y)) = \log(\frac{e^{c\log(y)}}{y}) = (c-1) * \log(y)
        """
-        self.checktensor(y, 'y')
+        self.checktensor(y, 'value')
         return (self.power - 1) * self.log(y)
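
Reviewer note: the forward/inverse and log-Jacobian formulas in these docstrings are easy to sanity-check outside the framework. A minimal NumPy sketch (c stands for the `power` parameter; illustrative only, not MindSpore code):

    import numpy as np

    c, x = 2.0, 1.5
    y = np.exp(np.log1p(x * c) / c)               # forward: y = (1 + c*x)**(1/c)
    x_back = np.expm1(np.log(y) * c) / c          # inverse: x = (y**c - 1) / c
    fwd_logdet = (1. / c - 1) * np.log1p(x * c)   # log|dy/dx| from the docstring
    inv_logdet = (c - 1) * np.log(y)              # log|dx/dy| from the docstring
    assert np.isclose(x_back, x)                  # round trip recovers x
    assert np.isclose(fwd_logdet, -inv_logdet)    # the two log-Jacobians cancel
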
@@ -19,6 +19,7 @@ from ..distribution._utils.utils import cast_to_tensor, CheckTensor
 from ..distribution._utils.custom_ops import log_by_step
 from .bijector import Bijector

+
 class ScalarAffine(Bijector):
     """
     Scalar Affine Bijector.
@@ -47,6 +48,7 @@ class ScalarAffine(Bijector):
         >>> ans = self.s1.forward_log_jacobian(value)
         >>> ans = self.s1.inverse_log_jacobian(value)
     """
+
     def __init__(self,
                  scale=1.0,
                  shift=0.0,
@@ -91,7 +93,7 @@ class ScalarAffine(Bijector):
         .. math::
             f(x) = a * x + b
         """
-        self.checktensor(x, 'x')
+        self.checktensor(x, 'value')
         return self.scale * x + self.shift

     def _inverse(self, y):
@@ -99,7 +101,7 @@ class ScalarAffine(Bijector):
         .. math::
             f(y) = \frac{y - b}{a}
         """
-        self.checktensor(y, 'y')
+        self.checktensor(y, 'value')
         return (y - self.shift) / self.scale

     def _forward_log_jacobian(self, x):
@@ -109,7 +111,7 @@ class ScalarAffine(Bijector):
             f'(x) = a
             \log(f'(x)) = \log(a)
         """
-        self.checktensor(x, 'x')
+        self.checktensor(x, 'value')
         return self.log(self.abs(self.scale))

     def _inverse_log_jacobian(self, y):
@@ -119,5 +121,5 @@ class ScalarAffine(Bijector):
             f'(y) = \frac{1.0}{a}
             \log(f'(y)) = - \log(a)
         """
-        self.checktensor(y, 'y')
+        self.checktensor(y, 'value')
         return -1. * self.log(self.abs(self.scale))
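
Reviewer note: same pattern as PowerTransform; a quick NumPy sketch of the affine identities (a is `scale`, b is `shift`; illustrative only):

    import numpy as np

    a, b, x = 2.0, 0.5, 3.0
    y = a * x + b                        # forward: f(x) = a*x + b
    assert np.isclose((y - b) / a, x)    # inverse recovers x
    fwd = np.log(np.abs(a))              # forward log-Jacobian:  log|a|
    inv = -np.log(np.abs(a))             # inverse log-Jacobian: -log|a|
    assert np.isclose(fwd + inv, 0.0)    # they cancel on a round trip
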
@@ -22,6 +22,7 @@ from ..distribution._utils.utils import cast_to_tensor, CheckTensor
 from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step
 from .bijector import Bijector

+
 class Softplus(Bijector):
     r"""
     Softplus Bijector.
@@ -51,6 +52,7 @@ class Softplus(Bijector):
         >>> ans = self.sp1.forward_log_jacobian(value)
         >>> ans = self.sp1.inverse_log_jacobian(value)
     """
+
     def __init__(self,
                  sharpness=1.0,
                  name='Softplus'):
@@ -76,6 +78,7 @@ class Softplus(Bijector):
         self.checktensor = CheckTensor()
         self.threshold = np.log(np.finfo(np.float32).eps) + 1
+        self.tiny = np.exp(self.threshold)

     def _softplus(self, x):
         too_small = self.less(x, self.threshold)
@@ -94,7 +97,7 @@ class Softplus(Bijector):
            f(x) = \log(1 + e^{x})
            f^{-1}(y) = \log(e^{y} - 1)
        """
-        too_small = self.less(x, self.threshold)
+        too_small = self.less(x, self.tiny)
         too_large = self.greater(x, -self.threshold)
         too_small_value = self.log(x)
         too_large_value = x
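
Reviewer note: this is the threshold fix from the commit title. The inverse-softplus helper works in y-space, where softplus(x) ≈ e^x for very negative x, so the small-input cutoff should be tiny = exp(threshold) (a small positive number) rather than threshold itself (a negative x-space quantity). A minimal NumPy sketch of the intended branching, independent of MindSpore:

    import numpy as np

    threshold = np.log(np.finfo(np.float32).eps) + 1   # about -14.9, x-space
    tiny = np.exp(threshold)                           # about 3.2e-7, y-space

    def inverse_softplus(y):
        # log(e**y - 1), with numerically stable branches at the extremes
        if y < tiny:            # softplus(x) ~ e**x here, so x ~ log(y)
            return np.log(y)
        if y > -threshold:      # softplus(x) ~ x here
            return y
        return np.log(np.expm1(y))
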
@@ -116,7 +119,7 @@ class Softplus(Bijector):
         return shape

     def _forward(self, x):
-        self.checktensor(x, 'x')
+        self.checktensor(x, 'value')
         scaled_value = self.sharpness * x
         return self.softplus(scaled_value) / self.sharpness
@@ -126,7 +129,7 @@ class Softplus(Bijector):
            f(x) = \frac{\log(1 + e^{kx})}{k}
            f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k}
        """
-        self.checktensor(y, 'y')
+        self.checktensor(y, 'value')
         scaled_value = self.sharpness * y
         return self.inverse_softplus(scaled_value) / self.sharpness
@@ -137,7 +140,7 @@ class Softplus(Bijector):
            f'(x) = \frac{e^{kx}}{1 + e^{kx}}
            \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx)
        """
-        self.checktensor(x, 'x')
+        self.checktensor(x, 'value')
         scaled_value = self.sharpness * x
         return self.log_sigmoid(scaled_value)
@@ -148,6 +151,6 @@ class Softplus(Bijector):
            f'(y) = \frac{e^{ky}}{e^{ky} - 1}
            \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky)
        """
-        self.checktensor(y, 'y')
+        self.checktensor(y, 'value')
         scaled_value = self.sharpness * y
         return scaled_value - self.inverse_softplus(scaled_value)
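
Reviewer note: a quick NumPy check that the docstring formulas round-trip once the sharpness k is folded in (illustrative sketch only):

    import numpy as np

    k, x = 1.5, 0.7
    y = np.log1p(np.exp(k * x)) / k        # f(x) = log(1 + e**(kx)) / k
    x_back = np.log(np.expm1(k * y)) / k   # f^-1(y) = log(e**(ky) - 1) / k
    assert np.isclose(x_back, x)
    # forward log-Jacobian: log(sigmoid(kx)) = kx - log(1 + e**(kx))
    assert np.isclose(k * x - np.log1p(np.exp(k * x)),
                      -np.log1p(np.exp(-k * x)))
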
@@ -26,6 +26,7 @@ from mindspore import context
 import mindspore.nn as nn
 import mindspore.nn.probability as msp

+
 def cast_to_tensor(t, hint_type=mstype.float32):
     """
     Cast a user input value into a Tensor of dtype.
@@ -47,7 +48,7 @@ def cast_to_tensor(t, hint_type=mstype.float32):
         return t
     t_type = hint_type
     if isinstance(t, Tensor):
-        #convert the type of tensor to dtype
+        # convert the type of tensor to dtype
         return Tensor(t.asnumpy(), dtype=t_type)
     if isinstance(t, (list, np.ndarray)):
         return Tensor(t, dtype=t_type)
@@ -56,7 +57,8 @@ def cast_to_tensor(t, hint_type=mstype.float32):
     if isinstance(t, (int, float)):
         return Tensor(t, dtype=t_type)
     invalid_type = type(t)
-    raise TypeError(f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}")
+    raise TypeError(
+        f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}")


 def convert_to_batch(t, batch_shape, required_type):
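
Reviewer note: for context, a hedged usage sketch of the input types cast_to_tensor accepts (assumes a MindSpore environment with cast_to_tensor in scope):

    import numpy as np
    from mindspore import Tensor
    from mindspore.common import dtype as mstype

    cast_to_tensor(0.5)                            # Python scalar -> float32 Tensor
    cast_to_tensor([0.1, 0.9])                     # list -> float32 Tensor
    cast_to_tensor(np.array([0.3]))                # ndarray -> float32 Tensor
    cast_to_tensor(Tensor([1.0]), mstype.float32)  # re-typed Tensor
    cast_to_tensor(object())                       # raises the TypeError above
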
@@ -79,6 +81,7 @@ def convert_to_batch(t, batch_shape, required_type):
     t = cast_to_tensor(t, required_type)
     return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=required_type)

+
 def check_scalar_from_param(params):
     """
     Check if params are all scalars.
@@ -93,11 +96,7 @@ def check_scalar_from_param(params):
             return params['distribution'].is_scalar_batch
         if isinstance(value, Parameter):
             return False
-        if isinstance(value, (str, type(params['dtype']))):
-            continue
-        elif isinstance(value, (int, float)):
-            continue
-        else:
+        if not isinstance(value, (int, float, str, type(params['dtype']))):
             return False
     return True
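
Reviewer note: the collapsed condition is behavior-preserving; the continue/elif/else ladder amounted to "fail unless the value is an int, float, str, or the dtype class". A small stand-alone sketch of the equivalence (DummyDtype stands in for type(params['dtype'])):

    class DummyDtype:
        pass

    def old_check(value):
        if isinstance(value, (str, DummyDtype)):
            return True          # was 'continue'
        elif isinstance(value, (int, float)):
            return True          # was 'continue'
        else:
            return False

    def new_check(value):
        return isinstance(value, (int, float, str, DummyDtype))

    for v in (1, 0.5, "probs", DummyDtype(), [0.1], None):
        assert old_check(v) == new_check(v)
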
@@ -124,7 +123,8 @@ def calc_broadcast_shape_from_param(params):
             value_t = value.default_input
         else:
             value_t = cast_to_tensor(value, mstype.float32)
-        broadcast_shape = utils.get_broadcast_shape(broadcast_shape, list(value_t.shape), params['name'])
+        broadcast_shape = utils.get_broadcast_shape(
+            broadcast_shape, list(value_t.shape), params['name'])
     return tuple(broadcast_shape)
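
Reviewer note: get_broadcast_shape folds each parameter's shape into a running broadcast shape; NumPy broadcasting gives the same result and is a handy mental model (a sketch, not the MindSpore utility):

    import numpy as np

    shapes = [(2, 1), (1, 3), (3,)]         # e.g. shapes of several parameters
    out = ()
    for s in shapes:
        out = np.broadcast_shapes(out, s)   # raises on incompatible shapes
    assert out == (2, 3)
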
@@ -148,6 +148,7 @@ def check_greater_equal_zero(value, name):
     if comp.any():
         raise ValueError(f'{name} should be greater than or equal to zero.')

+
 def check_greater_zero(value, name):
     """
     Check if the given Tensor is strictly greater than zero.
@@ -251,6 +252,7 @@ def probs_to_logits(probs, is_binary=False):
         return P.Log()(ps_clamped) - P.Log()(1-ps_clamped)
     return P.Log()(ps_clamped)

+
 def check_tensor_type(name, inputs, valid_type):
     """
     Check if inputs is proper.
@@ -268,25 +270,34 @@ def check_tensor_type(name, inputs, valid_type):
     if input_type not in valid_type:
         raise TypeError(f"{name} dtype is invalid")

+
 def check_type(data_type, value_type, name):
     if not data_type in value_type:
-        raise TypeError(f"For {name}, valid types include {value_type}, {data_type} is invalid")
+        raise TypeError(
+            f"For {name}, valid types include {value_type}, {data_type} is invalid")

+
 @constexpr
 def raise_none_error(name):
     raise TypeError(f"the type {name} should be subclass of Tensor."
                     f" It should not be None since it is not specified during initialization.")

+
 @constexpr
 def raise_not_impl_error(name):
-    raise ValueError(f"{name} function should be implemented for non-linear transformation")
+    raise ValueError(
+        f"{name} function should be implemented for non-linear transformation")

+
 @constexpr
 def check_distribution_name(name, expected_name):
     if name is None:
-        raise ValueError(f"Distribution should be a constant which is not None.")
+        raise ValueError(
+            f"Distribution should be a constant which is not None.")
     if name != expected_name:
-        raise ValueError(f"Expected distribution name is {expected_name}, but got {name}.")
+        raise ValueError(
+            f"Expected distribution name is {expected_name}, but got {name}.")

+
 class CheckTuple(PrimitiveWithInfer):
     """
@@ -294,13 +305,13 @@ class CheckTuple(PrimitiveWithInfer):
     """
     @prim_attr_register
     def __init__(self):
-        """init Cast"""
         super(CheckTuple, self).__init__("CheckTuple")
-        self.init_prim_io_names(inputs=['x'], outputs=['dummy_output'])
+        self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])

     def __infer__(self, x, name):
         if not isinstance(x['dtype'], tuple):
-            raise TypeError(f"For {name['value']}, input type should be a tuple.")
+            raise TypeError(
+                f"For {name['value']}, input type should be a tuple.")

         out = {'shape': None,
                'dtype': None,
@@ -310,24 +321,25 @@ class CheckTuple(PrimitiveWithInfer):
     def __call__(self, x, name):
         if context.get_context("mode") == 0:
             return x["value"]
-        #Pynative mode
+        # Pynative mode
         if isinstance(x, tuple):
             return x
         raise TypeError(f"For {name['value']}, input type should be a tuple.")

+
 class CheckTensor(PrimitiveWithInfer):
     """
     Check if input is a Tensor.
     """
     @prim_attr_register
     def __init__(self):
-        """init Cast"""
         super(CheckTensor, self).__init__("CheckTensor")
-        self.init_prim_io_names(inputs=['x'], outputs=['dummy_output'])
+        self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])

     def __infer__(self, x, name):
         src_type = x['dtype']
-        validator.check_subclass("input", src_type, [mstype.tensor], name["value"])
+        validator.check_subclass(
+            "input", src_type, [mstype.tensor], name["value"])

         out = {'shape': None,
                'dtype': None,
@@ -20,6 +20,7 @@ from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
 from ._utils.custom_ops import exp_by_step, log_by_step

+
 class Bernoulli(Distribution):
     """
     Bernoulli Distribution.
@@ -97,7 +98,7 @@ class Bernoulli(Distribution):
         Constructor of Bernoulli distribution.
         """
         param = dict(locals())
-        valid_dtype = mstype.int_type + mstype.uint_type
+        valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
         check_type(dtype, valid_dtype, "Bernoulli")
         super(Bernoulli, self).__init__(seed, dtype, name, param)
         self.parameter_type = mstype.float32
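
Reviewer note: with mstype.float_type added to valid_dtype, Bernoulli now accepts floating-point sample dtypes, which the old check rejected with a TypeError. Hedged sketch (mirrors the updated test below):

    import mindspore.nn.probability.distribution as msd
    from mindspore import dtype

    b_int = msd.Bernoulli(0.5, dtype=dtype.int32)    # worked before and after
    b_flt = msd.Bernoulli(0.5, dtype=dtype.float32)  # accepted after this change
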
@@ -211,7 +212,6 @@ class Bernoulli(Distribution):
         """
         self.checktensor(value, 'value')
         value = self.cast(value, mstype.float32)
-        value = self.floor(value)
         probs1 = self._check_param(probs1)
         probs0 = 1.0 - probs1
         return self.log(probs1) * value + self.log(probs0) * (1.0 - value)
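
Reviewer note: with the floor() removed, _log_prob evaluates the cross-entropy form at the value as given; for 0/1 inputs it still equals the Bernoulli log-pmf. NumPy check of the returned expression:

    import numpy as np

    probs1 = 0.7
    for v in (0.0, 1.0):
        log_p = np.log(probs1) * v + np.log(1.0 - probs1) * (1.0 - v)
        ref = np.log(probs1 if v == 1.0 else 1.0 - probs1)
        assert np.isclose(log_p, ref)
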
@@ -19,9 +19,10 @@ from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
-                          raise_none_error
+    raise_none_error
 from ._utils.custom_ops import exp_by_step, log_by_step

+
 class Geometric(Distribution):
     """
     Geometric Distribution.
@@ -100,7 +101,7 @@ class Geometric(Distribution):
         Constructor of Geometric distribution.
         """
         param = dict(locals())
-        valid_dtype = mstype.int_type + mstype.uint_type
+        valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
         check_type(dtype, valid_dtype, "Geometric")
         super(Geometric, self).__init__(seed, dtype, name, param)
         self.parameter_type = mstype.float32
@@ -130,7 +131,6 @@ class Geometric(Distribution):
         self.sqrt = P.Sqrt()
         self.uniform = C.uniform

-
     def extend_repr(self):
         if self.is_scalar_batch:
             str_info = f'probs = {self.probs}'
@@ -243,7 +243,6 @@ class Geometric(Distribution):
         comp = self.less(value, zeros)
         return self.select(comp, zeros, cdf)

-
     def _kl_loss(self, dist, probs1_b, probs1=None):
         r"""
         Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b).
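
Reviewer note: for reference, Geometric-Geometric KL has a closed form; a NumPy sketch under the convention P(K=k) = (1-p)^k * p for k = 0, 1, 2, ... (illustrative, not the method body):

    import numpy as np

    def geometric_kl(pa, pb):
        # KL(a||b) = log(pa/pb) + E_a[K] * log((1-pa)/(1-pb)), with E_a[K] = (1-pa)/pa
        return np.log(pa / pb) + (1.0 - pa) / pa * np.log((1.0 - pa) / (1.0 - pb))

    assert np.isclose(geometric_kl(0.5, 0.5), 0.0)
    assert geometric_kl(0.7, 0.3) > 0.0    # KL is non-negative
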
@@ -22,6 +22,7 @@ import mindspore.nn.probability.distribution as msd
 from mindspore import dtype
 from mindspore import Tensor

+
 def test_arguments():
     """
     Args passing during initialization.
@@ -31,18 +32,22 @@ def test_arguments():
     b = msd.Bernoulli([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32)
     assert isinstance(b, msd.Distribution)

+
 def test_type():
     with pytest.raises(TypeError):
-        msd.Bernoulli([0.1], dtype=dtype.float32)
+        msd.Bernoulli([0.1], dtype=dtype.bool_)

+
 def test_name():
     with pytest.raises(TypeError):
         msd.Bernoulli([0.1], name=1.0)

+
 def test_seed():
     with pytest.raises(TypeError):
         msd.Bernoulli([0.1], seed='seed')

+
 def test_prob():
     """
     Invalid probability.
@@ -56,10 +61,12 @@ def test_prob():
     with pytest.raises(ValueError):
         msd.Bernoulli([1.0], dtype=dtype.int32)

+
 class BernoulliProb(nn.Cell):
     """
     Bernoulli distribution: initialize with probs.
     """
+
     def __init__(self):
         super(BernoulliProb, self).__init__()
         self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
@@ -73,6 +80,7 @@ class BernoulliProb(nn.Cell):
         log_sf = self.b.log_survival(value)
         return prob + log_prob + cdf + log_cdf + sf + log_sf

+
 def test_bernoulli_prob():
     """
     Test probability functions: passing value through construct.
@@ -82,10 +90,12 @@ def test_bernoulli_prob():
     ans = net(value)
     assert isinstance(ans, Tensor)

+
 class BernoulliProb1(nn.Cell):
     """
     Bernoulli distribution: initialize without probs.
     """
+
     def __init__(self):
         super(BernoulliProb1, self).__init__()
         self.b = msd.Bernoulli(dtype=dtype.int32)
@@ -99,6 +109,7 @@ class BernoulliProb1(nn.Cell):
         log_sf = self.b.log_survival(value, probs)
         return prob + log_prob + cdf + log_cdf + sf + log_sf

+
 def test_bernoulli_prob1():
     """
     Test probability functions: passing value/probs through construct.
@@ -109,10 +120,12 @@ def test_bernoulli_prob1():
     ans = net(value, probs)
     assert isinstance(ans, Tensor)

+
 class BernoulliKl(nn.Cell):
     """
     Test class: kl_loss between Bernoulli distributions.
     """
+
     def __init__(self):
         super(BernoulliKl, self).__init__()
         self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32)
@@ -123,6 +136,7 @@ class BernoulliKl(nn.Cell):
         kl2 = self.b2.kl_loss('Bernoulli', probs_b, probs_a)
         return kl1 + kl2

+
 def test_kl():
     """
     Test kl_loss function.
@@ -133,10 +147,12 @@ def test_kl():
     ans = ber_net(probs_b, probs_a)
     assert isinstance(ans, Tensor)

+
 class BernoulliCrossEntropy(nn.Cell):
     """
     Test class: cross_entropy of Bernoulli distribution.
     """
+
     def __init__(self):
         super(BernoulliCrossEntropy, self).__init__()
         self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32)
@@ -147,6 +163,7 @@ class BernoulliCrossEntropy(nn.Cell):
         h2 = self.b2.cross_entropy('Bernoulli', probs_b, probs_a)
         return h1 + h2

+
 def test_cross_entropy():
     """
     Test cross_entropy between Bernoulli distributions.
@@ -157,10 +174,12 @@ def test_cross_entropy():
     ans = net(probs_b, probs_a)
     assert isinstance(ans, Tensor)

+
 class BernoulliConstruct(nn.Cell):
     """
     Bernoulli distribution: going through construct.
     """
+
     def __init__(self):
         super(BernoulliConstruct, self).__init__()
         self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
@@ -172,6 +191,7 @@ class BernoulliConstruct(nn.Cell):
         prob2 = self.b1('prob', value, probs)
         return prob + prob1 + prob2

+
 def test_bernoulli_construct():
     """
     Test probability function going through construct.
@@ -182,10 +202,12 @@ def test_bernoulli_construct():
     ans = net(value, probs)
     assert isinstance(ans, Tensor)

+
 class BernoulliMean(nn.Cell):
     """
     Test class: basic mean/sd/var/mode/entropy function.
     """
+
     def __init__(self):
         super(BernoulliMean, self).__init__()
         self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
@@ -194,6 +216,7 @@ class BernoulliMean(nn.Cell):
         mean = self.b.mean()
         return mean

+
 def test_mean():
     """
     Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
@@ -202,10 +225,12 @@ def test_mean():
     ans = net()
     assert isinstance(ans, Tensor)

+
 class BernoulliSd(nn.Cell):
     """
     Test class: basic mean/sd/var/mode/entropy function.
     """
+
     def __init__(self):
         super(BernoulliSd, self).__init__()
         self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
@@ -214,6 +239,7 @@ class BernoulliSd(nn.Cell):
         sd = self.b.sd()
         return sd

+
 def test_sd():
     """
     Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
@@ -222,10 +248,12 @@ def test_sd():
     ans = net()
     assert isinstance(ans, Tensor)

+
 class BernoulliVar(nn.Cell):
     """
     Test class: basic mean/sd/var/mode/entropy function.
     """
+
     def __init__(self):
         super(BernoulliVar, self).__init__()
         self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
@@ -234,6 +262,7 @@ class BernoulliVar(nn.Cell):
         var = self.b.var()
         return var

+
 def test_var():
     """
     Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
@@ -242,10 +271,12 @@ def test_var():
     ans = net()
     assert isinstance(ans, Tensor)

+
 class BernoulliMode(nn.Cell):
     """
     Test class: basic mean/sd/var/mode/entropy function.
     """
+
     def __init__(self):
         super(BernoulliMode, self).__init__()
         self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
@@ -254,6 +285,7 @@ class BernoulliMode(nn.Cell):
         mode = self.b.mode()
         return mode

+
 def test_mode():
     """
     Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
@@ -262,10 +294,12 @@ def test_mode():
     ans = net()
     assert isinstance(ans, Tensor)

+
 class BernoulliEntropy(nn.Cell):
     """
     Test class: basic mean/sd/var/mode/entropy function.
     """
+
     def __init__(self):
         super(BernoulliEntropy, self).__init__()
         self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
@@ -274,6 +308,7 @@ class BernoulliEntropy(nn.Cell):
         entropy = self.b.entropy()
         return entropy

+
 def test_entropy():
     """
     Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
@@ -32,18 +32,22 @@ def test_arguments():
     g = msd.Geometric([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32)
     assert isinstance(g, msd.Distribution)

+
 def test_type():
     with pytest.raises(TypeError):
-        msd.Geometric([0.1], dtype=dtype.float32)
+        msd.Geometric([0.1], dtype=dtype.bool_)

+
 def test_name():
     with pytest.raises(TypeError):
         msd.Geometric([0.1], name=1.0)

+
 def test_seed():
     with pytest.raises(TypeError):
         msd.Geometric([0.1], seed='seed')

+
 def test_prob():
     """
     Invalid probability.
@@ -57,10 +61,12 @@ def test_prob():
     with pytest.raises(ValueError):
         msd.Geometric([1.0], dtype=dtype.int32)

+
 class GeometricProb(nn.Cell):
     """
     Geometric distribution: initialize with probs.
     """
+
     def __init__(self):
         super(GeometricProb, self).__init__()
         self.g = msd.Geometric(0.5, dtype=dtype.int32)
@@ -74,6 +80,7 @@ class GeometricProb(nn.Cell):
         log_sf = self.g.log_survival(value)
         return prob + log_prob + cdf + log_cdf + sf + log_sf

+
 def test_geometric_prob():
     """
     Test probability functions: passing value through construct.
@@ -83,10 +90,12 @@ def test_geometric_prob():
     ans = net(value)
     assert isinstance(ans, Tensor)

+
 class GeometricProb1(nn.Cell):
     """
     Geometric distribution: initialize without probs.
     """
+
     def __init__(self):
         super(GeometricProb1, self).__init__()
         self.g = msd.Geometric(dtype=dtype.int32)
@@ -100,6 +109,7 @@ class GeometricProb1(nn.Cell):
         log_sf = self.g.log_survival(value, probs)
         return prob + log_prob + cdf + log_cdf + sf + log_sf

+
 def test_geometric_prob1():
     """
     Test probability functions: passing value/probs through construct.
@@ -115,6 +125,7 @@ class GeometricKl(nn.Cell):
     """
     Test class: kl_loss between Geometric distributions.
     """
+
     def __init__(self):
         super(GeometricKl, self).__init__()
         self.g1 = msd.Geometric(0.7, dtype=dtype.int32)
@@ -125,6 +136,7 @@ class GeometricKl(nn.Cell):
         kl2 = self.g2.kl_loss('Geometric', probs_b, probs_a)
         return kl1 + kl2

+
 def test_kl():
     """
     Test kl_loss function.
@@ -135,10 +147,12 @@ def test_kl():
     ans = ber_net(probs_b, probs_a)
     assert isinstance(ans, Tensor)

+
 class GeometricCrossEntropy(nn.Cell):
     """
     Test class: cross_entropy of Geometric distribution.
     """
+
     def __init__(self):
         super(GeometricCrossEntropy, self).__init__()
         self.g1 = msd.Geometric(0.3, dtype=dtype.int32)
@@ -149,6 +163,7 @@ class GeometricCrossEntropy(nn.Cell):
         h2 = self.g2.cross_entropy('Geometric', probs_b, probs_a)
         return h1 + h2

+
 def test_cross_entropy():
     """
     Test cross_entropy between Geometric distributions.
@@ -159,10 +174,12 @@ def test_cross_entropy():
     ans = net(probs_b, probs_a)
     assert isinstance(ans, Tensor)

+
 class GeometricBasics(nn.Cell):
     """
     Test class: basic mean/sd/mode/entropy function.
     """
+
     def __init__(self):
         super(GeometricBasics, self).__init__()
         self.g = msd.Geometric([0.3, 0.5], dtype=dtype.int32)
@@ -175,6 +192,7 @@ class GeometricBasics(nn.Cell):
         entropy = self.g.entropy()
         return mean + sd + var + mode + entropy

+
 def test_bascis():
     """
     Test mean/sd/mode/entropy functionality of Geometric distribution.
@@ -188,6 +206,7 @@ class GeoConstruct(nn.Cell):
     """
     Geometric distribution: going through construct.
     """
+
     def __init__(self):
         super(GeoConstruct, self).__init__()
         self.g = msd.Geometric(0.5, dtype=dtype.int32)
@@ -199,6 +218,7 @@ class GeoConstruct(nn.Cell):
         prob2 = self.g1('prob', value, probs)
         return prob + prob1 + prob2

+
 def test_geo_construct():
     """
     Test probability function going through construct.