update threshold of softplus computation support fp for bernoulli and geometric distributiontags/v1.0.0
| @@ -20,6 +20,7 @@ from ..distribution._utils.utils import CheckTensor | |||
| from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step, log1p_by_step | |||
| from .bijector import Bijector | |||
| class PowerTransform(Bijector): | |||
| r""" | |||
| Power Bijector. | |||
| @@ -49,6 +50,7 @@ class PowerTransform(Bijector): | |||
| >>> # by replacing 'forward' with the name of the function | |||
| >>> ans = self.p1.forward(, value) | |||
| """ | |||
| def __init__(self, | |||
| power=0, | |||
| name='PowerTransform', | |||
| @@ -78,13 +80,13 @@ class PowerTransform(Bijector): | |||
| return shape | |||
| def _forward(self, x): | |||
| self.checktensor(x, 'x') | |||
| self.checktensor(x, 'value') | |||
| if self.power == 0: | |||
| return self.exp(x) | |||
| return self.exp(self.log1p(x * self.power) / self.power) | |||
| def _inverse(self, y): | |||
| self.checktensor(y, 'y') | |||
| self.checktensor(y, 'value') | |||
| if self.power == 0: | |||
| return self.log(y) | |||
| return self.expm1(self.log(y) * self.power) / self.power | |||
| @@ -101,7 +103,7 @@ class PowerTransform(Bijector): | |||
| f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1} | |||
| \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1) | |||
| """ | |||
| self.checktensor(x, 'x') | |||
| self.checktensor(x, 'value') | |||
| if self.power == 0: | |||
| return x | |||
| return (1. / self.power - 1) * self.log1p(x * self.power) | |||
| @@ -118,5 +120,5 @@ class PowerTransform(Bijector): | |||
| f'(x) = \frac{e^c\log(y)}{y} | |||
| \log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y) | |||
| """ | |||
| self.checktensor(y, 'y') | |||
| self.checktensor(y, 'value') | |||
| return (self.power - 1) * self.log(y) | |||
| @@ -19,6 +19,7 @@ from ..distribution._utils.utils import cast_to_tensor, CheckTensor | |||
| from ..distribution._utils.custom_ops import log_by_step | |||
| from .bijector import Bijector | |||
| class ScalarAffine(Bijector): | |||
| """ | |||
| Scalar Affine Bijector. | |||
| @@ -47,6 +48,7 @@ class ScalarAffine(Bijector): | |||
| >>> ans = self.s1.forward_log_jacobian(value) | |||
| >>> ans = self.s1.inverse_log_jacobian(value) | |||
| """ | |||
| def __init__(self, | |||
| scale=1.0, | |||
| shift=0.0, | |||
| @@ -91,7 +93,7 @@ class ScalarAffine(Bijector): | |||
| .. math:: | |||
| f(x) = a * x + b | |||
| """ | |||
| self.checktensor(x, 'x') | |||
| self.checktensor(x, 'value') | |||
| return self.scale * x + self.shift | |||
| def _inverse(self, y): | |||
| @@ -99,7 +101,7 @@ class ScalarAffine(Bijector): | |||
| .. math:: | |||
| f(y) = \frac{y - b}{a} | |||
| """ | |||
| self.checktensor(y, 'y') | |||
| self.checktensor(y, 'value') | |||
| return (y - self.shift) / self.scale | |||
| def _forward_log_jacobian(self, x): | |||
| @@ -109,7 +111,7 @@ class ScalarAffine(Bijector): | |||
| f'(x) = a | |||
| \log(f'(x)) = \log(a) | |||
| """ | |||
| self.checktensor(x, 'x') | |||
| self.checktensor(x, 'value') | |||
| return self.log(self.abs(self.scale)) | |||
| def _inverse_log_jacobian(self, y): | |||
| @@ -119,5 +121,5 @@ class ScalarAffine(Bijector): | |||
| f'(x) = \frac{1.0}{a} | |||
| \log(f'(x)) = - \log(a) | |||
| """ | |||
| self.checktensor(y, 'y') | |||
| self.checktensor(y, 'value') | |||
| return -1. * self.log(self.abs(self.scale)) | |||
| @@ -22,6 +22,7 @@ from ..distribution._utils.utils import cast_to_tensor, CheckTensor | |||
| from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step | |||
| from .bijector import Bijector | |||
| class Softplus(Bijector): | |||
| r""" | |||
| Softplus Bijector. | |||
| @@ -51,6 +52,7 @@ class Softplus(Bijector): | |||
| >>> ans = self.sp1.forward_log_jacobian(value) | |||
| >>> ans = self.sp1.inverse_log_jacobian(value) | |||
| """ | |||
| def __init__(self, | |||
| sharpness=1.0, | |||
| name='Softplus'): | |||
| @@ -76,6 +78,7 @@ class Softplus(Bijector): | |||
| self.checktensor = CheckTensor() | |||
| self.threshold = np.log(np.finfo(np.float32).eps) + 1 | |||
| self.tiny = np.exp(self.threshold) | |||
| def _softplus(self, x): | |||
| too_small = self.less(x, self.threshold) | |||
| @@ -94,7 +97,7 @@ class Softplus(Bijector): | |||
| f(x) = \frac{\log(1 + e^{x}))} | |||
| f^{-1}(y) = \frac{\log(e^{y} - 1)} | |||
| """ | |||
| too_small = self.less(x, self.threshold) | |||
| too_small = self.less(x, self.tiny) | |||
| too_large = self.greater(x, -self.threshold) | |||
| too_small_value = self.log(x) | |||
| too_large_value = x | |||
| @@ -116,7 +119,7 @@ class Softplus(Bijector): | |||
| return shape | |||
| def _forward(self, x): | |||
| self.checktensor(x, 'x') | |||
| self.checktensor(x, 'value') | |||
| scaled_value = self.sharpness * x | |||
| return self.softplus(scaled_value) / self.sharpness | |||
| @@ -126,7 +129,7 @@ class Softplus(Bijector): | |||
| f(x) = \frac{\log(1 + e^{kx}))}{k} | |||
| f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k} | |||
| """ | |||
| self.checktensor(y, 'y') | |||
| self.checktensor(y, 'value') | |||
| scaled_value = self.sharpness * y | |||
| return self.inverse_softplus(scaled_value) / self.sharpness | |||
| @@ -137,7 +140,7 @@ class Softplus(Bijector): | |||
| f'(x) = \frac{e^{kx}}{ 1 + e^{kx}} | |||
| \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx) | |||
| """ | |||
| self.checktensor(x, 'x') | |||
| self.checktensor(x, 'value') | |||
| scaled_value = self.sharpness * x | |||
| return self.log_sigmoid(scaled_value) | |||
| @@ -148,6 +151,6 @@ class Softplus(Bijector): | |||
| f'(y) = \frac{e^{ky}}{e^{ky} - 1} | |||
| \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky) | |||
| """ | |||
| self.checktensor(y, 'y') | |||
| self.checktensor(y, 'value') | |||
| scaled_value = self.sharpness * y | |||
| return scaled_value - self.inverse_softplus(scaled_value) | |||
| @@ -26,6 +26,7 @@ from mindspore import context | |||
| import mindspore.nn as nn | |||
| import mindspore.nn.probability as msp | |||
| def cast_to_tensor(t, hint_type=mstype.float32): | |||
| """ | |||
| Cast an user input value into a Tensor of dtype. | |||
| @@ -47,7 +48,7 @@ def cast_to_tensor(t, hint_type=mstype.float32): | |||
| return t | |||
| t_type = hint_type | |||
| if isinstance(t, Tensor): | |||
| #convert the type of tensor to dtype | |||
| # convert the type of tensor to dtype | |||
| return Tensor(t.asnumpy(), dtype=t_type) | |||
| if isinstance(t, (list, np.ndarray)): | |||
| return Tensor(t, dtype=t_type) | |||
| @@ -56,7 +57,8 @@ def cast_to_tensor(t, hint_type=mstype.float32): | |||
| if isinstance(t, (int, float)): | |||
| return Tensor(t, dtype=t_type) | |||
| invalid_type = type(t) | |||
| raise TypeError(f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}") | |||
| raise TypeError( | |||
| f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}") | |||
| def convert_to_batch(t, batch_shape, required_type): | |||
| @@ -79,6 +81,7 @@ def convert_to_batch(t, batch_shape, required_type): | |||
| t = cast_to_tensor(t, required_type) | |||
| return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=required_type) | |||
| def check_scalar_from_param(params): | |||
| """ | |||
| Check if params are all scalars. | |||
| @@ -93,11 +96,7 @@ def check_scalar_from_param(params): | |||
| return params['distribution'].is_scalar_batch | |||
| if isinstance(value, Parameter): | |||
| return False | |||
| if isinstance(value, (str, type(params['dtype']))): | |||
| continue | |||
| elif isinstance(value, (int, float)): | |||
| continue | |||
| else: | |||
| if not isinstance(value, (int, float, str, type(params['dtype']))): | |||
| return False | |||
| return True | |||
| @@ -124,7 +123,8 @@ def calc_broadcast_shape_from_param(params): | |||
| value_t = value.default_input | |||
| else: | |||
| value_t = cast_to_tensor(value, mstype.float32) | |||
| broadcast_shape = utils.get_broadcast_shape(broadcast_shape, list(value_t.shape), params['name']) | |||
| broadcast_shape = utils.get_broadcast_shape( | |||
| broadcast_shape, list(value_t.shape), params['name']) | |||
| return tuple(broadcast_shape) | |||
| @@ -148,6 +148,7 @@ def check_greater_equal_zero(value, name): | |||
| if comp.any(): | |||
| raise ValueError(f'{name} should be greater than ot equal to zero.') | |||
| def check_greater_zero(value, name): | |||
| """ | |||
| Check if the given Tensor is strictly greater than zero. | |||
| @@ -251,6 +252,7 @@ def probs_to_logits(probs, is_binary=False): | |||
| return P.Log()(ps_clamped) - P.Log()(1-ps_clamped) | |||
| return P.Log()(ps_clamped) | |||
| def check_tensor_type(name, inputs, valid_type): | |||
| """ | |||
| Check if inputs is proper. | |||
| @@ -268,25 +270,34 @@ def check_tensor_type(name, inputs, valid_type): | |||
| if input_type not in valid_type: | |||
| raise TypeError(f"{name} dtype is invalid") | |||
| def check_type(data_type, value_type, name): | |||
| if not data_type in value_type: | |||
| raise TypeError(f"For {name}, valid type include {value_type}, {data_type} is invalid") | |||
| raise TypeError( | |||
| f"For {name}, valid type include {value_type}, {data_type} is invalid") | |||
| @constexpr | |||
| def raise_none_error(name): | |||
| raise TypeError(f"the type {name} should be subclass of Tensor." | |||
| f" It should not be None since it is not specified during initialization.") | |||
| @constexpr | |||
| def raise_not_impl_error(name): | |||
| raise ValueError(f"{name} function should be implemented for non-linear transformation") | |||
| raise ValueError( | |||
| f"{name} function should be implemented for non-linear transformation") | |||
| @constexpr | |||
| def check_distribution_name(name, expected_name): | |||
| if name is None: | |||
| raise ValueError(f"Distribution should be a constant which is not None.") | |||
| raise ValueError( | |||
| f"Distribution should be a constant which is not None.") | |||
| if name != expected_name: | |||
| raise ValueError(f"Expected distribution name is {expected_name}, but got {name}.") | |||
| raise ValueError( | |||
| f"Expected distribution name is {expected_name}, but got {name}.") | |||
| class CheckTuple(PrimitiveWithInfer): | |||
| """ | |||
| @@ -294,13 +305,13 @@ class CheckTuple(PrimitiveWithInfer): | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init Cast""" | |||
| super(CheckTuple, self).__init__("CheckTuple") | |||
| self.init_prim_io_names(inputs=['x'], outputs=['dummy_output']) | |||
| self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output']) | |||
| def __infer__(self, x, name): | |||
| if not isinstance(x['dtype'], tuple): | |||
| raise TypeError(f"For {name['value']}, Input type should b a tuple.") | |||
| raise TypeError( | |||
| f"For {name['value']}, Input type should b a tuple.") | |||
| out = {'shape': None, | |||
| 'dtype': None, | |||
| @@ -310,24 +321,25 @@ class CheckTuple(PrimitiveWithInfer): | |||
| def __call__(self, x, name): | |||
| if context.get_context("mode") == 0: | |||
| return x["value"] | |||
| #Pynative mode | |||
| # Pynative mode | |||
| if isinstance(x, tuple): | |||
| return x | |||
| raise TypeError(f"For {name['value']}, Input type should b a tuple.") | |||
| class CheckTensor(PrimitiveWithInfer): | |||
| """ | |||
| Check if input is a Tensor. | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init Cast""" | |||
| super(CheckTensor, self).__init__("CheckTensor") | |||
| self.init_prim_io_names(inputs=['x'], outputs=['dummy_output']) | |||
| self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output']) | |||
| def __infer__(self, x, name): | |||
| src_type = x['dtype'] | |||
| validator.check_subclass("input", src_type, [mstype.tensor], name["value"]) | |||
| validator.check_subclass( | |||
| "input", src_type, [mstype.tensor], name["value"]) | |||
| out = {'shape': None, | |||
| 'dtype': None, | |||
| @@ -20,6 +20,7 @@ from .distribution import Distribution | |||
| from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error | |||
| from ._utils.custom_ops import exp_by_step, log_by_step | |||
| class Bernoulli(Distribution): | |||
| """ | |||
| Bernoulli Distribution. | |||
| @@ -97,7 +98,7 @@ class Bernoulli(Distribution): | |||
| Constructor of Bernoulli distribution. | |||
| """ | |||
| param = dict(locals()) | |||
| valid_dtype = mstype.int_type + mstype.uint_type | |||
| valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | |||
| check_type(dtype, valid_dtype, "Bernoulli") | |||
| super(Bernoulli, self).__init__(seed, dtype, name, param) | |||
| self.parameter_type = mstype.float32 | |||
| @@ -211,7 +212,6 @@ class Bernoulli(Distribution): | |||
| """ | |||
| self.checktensor(value, 'value') | |||
| value = self.cast(value, mstype.float32) | |||
| value = self.floor(value) | |||
| probs1 = self._check_param(probs1) | |||
| probs0 = 1.0 - probs1 | |||
| return self.log(probs1) * value + self.log(probs0) * (1.0 - value) | |||
| @@ -19,9 +19,10 @@ from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from .distribution import Distribution | |||
| from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\ | |||
| raise_none_error | |||
| raise_none_error | |||
| from ._utils.custom_ops import exp_by_step, log_by_step | |||
| class Geometric(Distribution): | |||
| """ | |||
| Geometric Distribution. | |||
| @@ -100,7 +101,7 @@ class Geometric(Distribution): | |||
| Constructor of Geometric distribution. | |||
| """ | |||
| param = dict(locals()) | |||
| valid_dtype = mstype.int_type + mstype.uint_type | |||
| valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | |||
| check_type(dtype, valid_dtype, "Geometric") | |||
| super(Geometric, self).__init__(seed, dtype, name, param) | |||
| self.parameter_type = mstype.float32 | |||
| @@ -130,7 +131,6 @@ class Geometric(Distribution): | |||
| self.sqrt = P.Sqrt() | |||
| self.uniform = C.uniform | |||
| def extend_repr(self): | |||
| if self.is_scalar_batch: | |||
| str_info = f'probs = {self.probs}' | |||
| @@ -243,7 +243,6 @@ class Geometric(Distribution): | |||
| comp = self.less(value, zeros) | |||
| return self.select(comp, zeros, cdf) | |||
| def _kl_loss(self, dist, probs1_b, probs1=None): | |||
| r""" | |||
| Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b). | |||
| @@ -22,6 +22,7 @@ import mindspore.nn.probability.distribution as msd | |||
| from mindspore import dtype | |||
| from mindspore import Tensor | |||
| def test_arguments(): | |||
| """ | |||
| Args passing during initialization. | |||
| @@ -31,18 +32,22 @@ def test_arguments(): | |||
| b = msd.Bernoulli([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32) | |||
| assert isinstance(b, msd.Distribution) | |||
| def test_type(): | |||
| with pytest.raises(TypeError): | |||
| msd.Bernoulli([0.1], dtype=dtype.float32) | |||
| msd.Bernoulli([0.1], dtype=dtype.bool_) | |||
| def test_name(): | |||
| with pytest.raises(TypeError): | |||
| msd.Bernoulli([0.1], name=1.0) | |||
| def test_seed(): | |||
| with pytest.raises(TypeError): | |||
| msd.Bernoulli([0.1], seed='seed') | |||
| def test_prob(): | |||
| """ | |||
| Invalid probability. | |||
| @@ -56,10 +61,12 @@ def test_prob(): | |||
| with pytest.raises(ValueError): | |||
| msd.Bernoulli([1.0], dtype=dtype.int32) | |||
| class BernoulliProb(nn.Cell): | |||
| """ | |||
| Bernoulli distribution: initialize with probs. | |||
| """ | |||
| def __init__(self): | |||
| super(BernoulliProb, self).__init__() | |||
| self.b = msd.Bernoulli(0.5, dtype=dtype.int32) | |||
| @@ -73,6 +80,7 @@ class BernoulliProb(nn.Cell): | |||
| log_sf = self.b.log_survival(value) | |||
| return prob + log_prob + cdf + log_cdf + sf + log_sf | |||
| def test_bernoulli_prob(): | |||
| """ | |||
| Test probability functions: passing value through construct. | |||
| @@ -82,10 +90,12 @@ def test_bernoulli_prob(): | |||
| ans = net(value) | |||
| assert isinstance(ans, Tensor) | |||
| class BernoulliProb1(nn.Cell): | |||
| """ | |||
| Bernoulli distribution: initialize without probs. | |||
| """ | |||
| def __init__(self): | |||
| super(BernoulliProb1, self).__init__() | |||
| self.b = msd.Bernoulli(dtype=dtype.int32) | |||
| @@ -99,6 +109,7 @@ class BernoulliProb1(nn.Cell): | |||
| log_sf = self.b.log_survival(value, probs) | |||
| return prob + log_prob + cdf + log_cdf + sf + log_sf | |||
| def test_bernoulli_prob1(): | |||
| """ | |||
| Test probability functions: passing value/probs through construct. | |||
| @@ -109,10 +120,12 @@ def test_bernoulli_prob1(): | |||
| ans = net(value, probs) | |||
| assert isinstance(ans, Tensor) | |||
| class BernoulliKl(nn.Cell): | |||
| """ | |||
| Test class: kl_loss between Bernoulli distributions. | |||
| """ | |||
| def __init__(self): | |||
| super(BernoulliKl, self).__init__() | |||
| self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32) | |||
| @@ -123,6 +136,7 @@ class BernoulliKl(nn.Cell): | |||
| kl2 = self.b2.kl_loss('Bernoulli', probs_b, probs_a) | |||
| return kl1 + kl2 | |||
| def test_kl(): | |||
| """ | |||
| Test kl_loss function. | |||
| @@ -133,10 +147,12 @@ def test_kl(): | |||
| ans = ber_net(probs_b, probs_a) | |||
| assert isinstance(ans, Tensor) | |||
| class BernoulliCrossEntropy(nn.Cell): | |||
| """ | |||
| Test class: cross_entropy of Bernoulli distribution. | |||
| """ | |||
| def __init__(self): | |||
| super(BernoulliCrossEntropy, self).__init__() | |||
| self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32) | |||
| @@ -147,6 +163,7 @@ class BernoulliCrossEntropy(nn.Cell): | |||
| h2 = self.b2.cross_entropy('Bernoulli', probs_b, probs_a) | |||
| return h1 + h2 | |||
| def test_cross_entropy(): | |||
| """ | |||
| Test cross_entropy between Bernoulli distributions. | |||
| @@ -157,10 +174,12 @@ def test_cross_entropy(): | |||
| ans = net(probs_b, probs_a) | |||
| assert isinstance(ans, Tensor) | |||
| class BernoulliConstruct(nn.Cell): | |||
| """ | |||
| Bernoulli distribution: going through construct. | |||
| """ | |||
| def __init__(self): | |||
| super(BernoulliConstruct, self).__init__() | |||
| self.b = msd.Bernoulli(0.5, dtype=dtype.int32) | |||
| @@ -172,6 +191,7 @@ class BernoulliConstruct(nn.Cell): | |||
| prob2 = self.b1('prob', value, probs) | |||
| return prob + prob1 + prob2 | |||
| def test_bernoulli_construct(): | |||
| """ | |||
| Test probability function going through construct. | |||
| @@ -182,10 +202,12 @@ def test_bernoulli_construct(): | |||
| ans = net(value, probs) | |||
| assert isinstance(ans, Tensor) | |||
| class BernoulliMean(nn.Cell): | |||
| """ | |||
| Test class: basic mean/sd/var/mode/entropy function. | |||
| """ | |||
| def __init__(self): | |||
| super(BernoulliMean, self).__init__() | |||
| self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) | |||
| @@ -194,6 +216,7 @@ class BernoulliMean(nn.Cell): | |||
| mean = self.b.mean() | |||
| return mean | |||
| def test_mean(): | |||
| """ | |||
| Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. | |||
| @@ -202,10 +225,12 @@ def test_mean(): | |||
| ans = net() | |||
| assert isinstance(ans, Tensor) | |||
| class BernoulliSd(nn.Cell): | |||
| """ | |||
| Test class: basic mean/sd/var/mode/entropy function. | |||
| """ | |||
| def __init__(self): | |||
| super(BernoulliSd, self).__init__() | |||
| self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) | |||
| @@ -214,6 +239,7 @@ class BernoulliSd(nn.Cell): | |||
| sd = self.b.sd() | |||
| return sd | |||
| def test_sd(): | |||
| """ | |||
| Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. | |||
| @@ -222,10 +248,12 @@ def test_sd(): | |||
| ans = net() | |||
| assert isinstance(ans, Tensor) | |||
| class BernoulliVar(nn.Cell): | |||
| """ | |||
| Test class: basic mean/sd/var/mode/entropy function. | |||
| """ | |||
| def __init__(self): | |||
| super(BernoulliVar, self).__init__() | |||
| self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) | |||
| @@ -234,6 +262,7 @@ class BernoulliVar(nn.Cell): | |||
| var = self.b.var() | |||
| return var | |||
| def test_var(): | |||
| """ | |||
| Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. | |||
| @@ -242,10 +271,12 @@ def test_var(): | |||
| ans = net() | |||
| assert isinstance(ans, Tensor) | |||
| class BernoulliMode(nn.Cell): | |||
| """ | |||
| Test class: basic mean/sd/var/mode/entropy function. | |||
| """ | |||
| def __init__(self): | |||
| super(BernoulliMode, self).__init__() | |||
| self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) | |||
| @@ -254,6 +285,7 @@ class BernoulliMode(nn.Cell): | |||
| mode = self.b.mode() | |||
| return mode | |||
| def test_mode(): | |||
| """ | |||
| Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. | |||
| @@ -262,10 +294,12 @@ def test_mode(): | |||
| ans = net() | |||
| assert isinstance(ans, Tensor) | |||
| class BernoulliEntropy(nn.Cell): | |||
| """ | |||
| Test class: basic mean/sd/var/mode/entropy function. | |||
| """ | |||
| def __init__(self): | |||
| super(BernoulliEntropy, self).__init__() | |||
| self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) | |||
| @@ -274,6 +308,7 @@ class BernoulliEntropy(nn.Cell): | |||
| entropy = self.b.entropy() | |||
| return entropy | |||
| def test_entropy(): | |||
| """ | |||
| Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. | |||
| @@ -32,18 +32,22 @@ def test_arguments(): | |||
| g = msd.Geometric([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32) | |||
| assert isinstance(g, msd.Distribution) | |||
| def test_type(): | |||
| with pytest.raises(TypeError): | |||
| msd.Geometric([0.1], dtype=dtype.float32) | |||
| msd.Geometric([0.1], dtype=dtype.bool_) | |||
| def test_name(): | |||
| with pytest.raises(TypeError): | |||
| msd.Geometric([0.1], name=1.0) | |||
| def test_seed(): | |||
| with pytest.raises(TypeError): | |||
| msd.Geometric([0.1], seed='seed') | |||
| def test_prob(): | |||
| """ | |||
| Invalid probability. | |||
| @@ -57,10 +61,12 @@ def test_prob(): | |||
| with pytest.raises(ValueError): | |||
| msd.Geometric([1.0], dtype=dtype.int32) | |||
| class GeometricProb(nn.Cell): | |||
| """ | |||
| Geometric distribution: initialize with probs. | |||
| """ | |||
| def __init__(self): | |||
| super(GeometricProb, self).__init__() | |||
| self.g = msd.Geometric(0.5, dtype=dtype.int32) | |||
| @@ -74,6 +80,7 @@ class GeometricProb(nn.Cell): | |||
| log_sf = self.g.log_survival(value) | |||
| return prob + log_prob + cdf + log_cdf + sf + log_sf | |||
| def test_geometric_prob(): | |||
| """ | |||
| Test probability functions: passing value through construct. | |||
| @@ -83,10 +90,12 @@ def test_geometric_prob(): | |||
| ans = net(value) | |||
| assert isinstance(ans, Tensor) | |||
| class GeometricProb1(nn.Cell): | |||
| """ | |||
| Geometric distribution: initialize without probs. | |||
| """ | |||
| def __init__(self): | |||
| super(GeometricProb1, self).__init__() | |||
| self.g = msd.Geometric(dtype=dtype.int32) | |||
| @@ -100,6 +109,7 @@ class GeometricProb1(nn.Cell): | |||
| log_sf = self.g.log_survival(value, probs) | |||
| return prob + log_prob + cdf + log_cdf + sf + log_sf | |||
| def test_geometric_prob1(): | |||
| """ | |||
| Test probability functions: passing value/probs through construct. | |||
| @@ -115,6 +125,7 @@ class GeometricKl(nn.Cell): | |||
| """ | |||
| Test class: kl_loss between Geometric distributions. | |||
| """ | |||
| def __init__(self): | |||
| super(GeometricKl, self).__init__() | |||
| self.g1 = msd.Geometric(0.7, dtype=dtype.int32) | |||
| @@ -125,6 +136,7 @@ class GeometricKl(nn.Cell): | |||
| kl2 = self.g2.kl_loss('Geometric', probs_b, probs_a) | |||
| return kl1 + kl2 | |||
| def test_kl(): | |||
| """ | |||
| Test kl_loss function. | |||
| @@ -135,10 +147,12 @@ def test_kl(): | |||
| ans = ber_net(probs_b, probs_a) | |||
| assert isinstance(ans, Tensor) | |||
| class GeometricCrossEntropy(nn.Cell): | |||
| """ | |||
| Test class: cross_entropy of Geometric distribution. | |||
| """ | |||
| def __init__(self): | |||
| super(GeometricCrossEntropy, self).__init__() | |||
| self.g1 = msd.Geometric(0.3, dtype=dtype.int32) | |||
| @@ -149,6 +163,7 @@ class GeometricCrossEntropy(nn.Cell): | |||
| h2 = self.g2.cross_entropy('Geometric', probs_b, probs_a) | |||
| return h1 + h2 | |||
| def test_cross_entropy(): | |||
| """ | |||
| Test cross_entropy between Geometric distributions. | |||
| @@ -159,10 +174,12 @@ def test_cross_entropy(): | |||
| ans = net(probs_b, probs_a) | |||
| assert isinstance(ans, Tensor) | |||
| class GeometricBasics(nn.Cell): | |||
| """ | |||
| Test class: basic mean/sd/mode/entropy function. | |||
| """ | |||
| def __init__(self): | |||
| super(GeometricBasics, self).__init__() | |||
| self.g = msd.Geometric([0.3, 0.5], dtype=dtype.int32) | |||
| @@ -175,6 +192,7 @@ class GeometricBasics(nn.Cell): | |||
| entropy = self.g.entropy() | |||
| return mean + sd + var + mode + entropy | |||
| def test_bascis(): | |||
| """ | |||
| Test mean/sd/mode/entropy functionality of Geometric distribution. | |||
| @@ -188,6 +206,7 @@ class GeoConstruct(nn.Cell): | |||
| """ | |||
| Bernoulli distribution: going through construct. | |||
| """ | |||
| def __init__(self): | |||
| super(GeoConstruct, self).__init__() | |||
| self.g = msd.Geometric(0.5, dtype=dtype.int32) | |||
| @@ -199,6 +218,7 @@ class GeoConstruct(nn.Cell): | |||
| prob2 = self.g1('prob', value, probs) | |||
| return prob + prob1 + prob2 | |||
| def test_geo_construct(): | |||
| """ | |||
| Test probability function going through construct. | |||