@@ -16,6 +16,7 @@
 from mindspore.ops import operations as P
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
+from ..distribution._utils.utils import CheckTensor
 from .bijector import Bijector

 class PowerTransform(Bijector):
@@ -62,6 +63,8 @@ class PowerTransform(Bijector):
         self.log1p = self._log1p_by_step
         self.expm1 = self._expm1_by_step

+        self.checktensor = CheckTensor()
+
     def _log1p_by_step(self, x):
         """
         Log1p ops on GPU device or when device_target == GPU.
@@ -86,11 +89,13 @@ class PowerTransform(Bijector):
         return shape

     def _forward(self, x):
+        self.checktensor(x, 'x')
         if self.power == 0:
             return self.exp(x)
         return self.exp(self.log1p(x * self.power) / self.power)

     def _inverse(self, y):
+        self.checktensor(y, 'y')
         if self.power == 0:
             return self.log(y)
         return self.expm1(self.log(y) * self.power) / self.power
@@ -107,6 +112,7 @@ class PowerTransform(Bijector):
             f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1}
             \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1)
         """
+        self.checktensor(x, 'x')
         if self.power == 0:
             return x
         return (1. / self.power - 1) * self.log1p(x * self.power)
@@ -123,4 +129,5 @@ class PowerTransform(Bijector):
             f'(x) = \frac{e^c\log(y)}{y}
             \log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y)
         """
+        self.checktensor(y, 'y')
         return (self.power - 1) * self.log(y)
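As a quick sanity check on the PowerTransform formulas in the docstrings above (forward y = (1 + c*x)^(1/c), its inverse, and the forward log-Jacobian), the following NumPy sketch is illustrative only and not part of the change:

import numpy as np

c = 0.5                                    # the 'power' parameter, c != 0
x = np.array([0.1, 0.5, 1.0, 2.0])

forward = np.exp(np.log1p(x * c) / c)      # y = (1 + c*x)**(1/c)
inverse = np.expm1(np.log(forward) * c) / c
log_jac = (1. / c - 1.) * np.log1p(x * c)

assert np.allclose(inverse, x)             # the round trip recovers x
# dy/dx = (1 + c*x)**(1/c - 1), so log(dy/dx) matches the docstring formula
assert np.allclose(log_jac, np.log((1. + c * x) ** (1. / c - 1.)))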
@@ -15,7 +15,7 @@
 """Scalar Affine Bijector"""
 from mindspore.ops import operations as P
 from mindspore._checkparam import Validator as validator
-from ..distribution._utils.utils import cast_to_tensor
+from ..distribution._utils.utils import cast_to_tensor, CheckTensor
 from .bijector import Bijector

 class ScalarAffine(Bijector):
@@ -54,8 +54,8 @@ class ScalarAffine(Bijector):
         Constructor of scalar affine bijector.
         """
         param = dict(locals())
-        validator.check_value_type('scale', scale, [float], name)
-        validator.check_value_type('shift', shift, [float], name)
+        validator.check_value_type('scale', scale, [int, float], name)
+        validator.check_value_type('shift', shift, [int, float], name)
         self._scale = cast_to_tensor(scale)
         self._shift = cast_to_tensor(shift)
         super(ScalarAffine, self).__init__(
@@ -65,8 +65,10 @@ class ScalarAffine(Bijector):
             dtype=None,
             param=param)

+        self.abs = P.Abs()
         self.log = P.Log()
-        self.oneslike = P.OnesLike()
+
+        self.checktensor = CheckTensor()

     @property
     def scale(self):
@@ -88,6 +90,7 @@ class ScalarAffine(Bijector):
         .. math::
             f(x) = a * x + b
         """
+        self.checktensor(x, 'x')
         return self.scale * x + self.shift

     def _inverse(self, y):
@@ -95,22 +98,25 @@ class ScalarAffine(Bijector):
         .. math::
             f(y) = \frac{y - b}{a}
         """
+        self.checktensor(y, 'y')
         return (y - self.shift) / self.scale

-    def _forward_log_jacobian(self, value):
+    def _forward_log_jacobian(self, x):
         r"""
         .. math::
             f(x) = a * x + b
             f'(x) = a
             \log(f'(x)) = \log(a)
         """
-        return self.log(self.scale) * self.oneslike(value)
+        self.checktensor(x, 'x')
+        return self.log(self.abs(self.scale))

-    def _inverse_log_jacobian(self, value):
+    def _inverse_log_jacobian(self, y):
         r"""
         .. math::
             f(y) = \frac{(y - b)}{a}
             f'(x) = \frac{1.0}{a}
             \log(f'(x)) = - \log(a)
         """
-        return -1. * self.log(self.scale) * self.oneslike(value)
+        self.checktensor(y, 'y')
+        return -1. * self.log(self.abs(self.scale))
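The rewritten log-Jacobian methods return a scalar because the log-determinant of y = a*x + b is log|a|, independent of the input, and wrapping the scale in Abs keeps the log finite when the scale is negative (the old log(scale) * oneslike(value) form would take the log of a negative number). A small illustrative NumPy check, not part of the change:

import numpy as np

a, b = -3.0, 1.5                     # a negative scale is the case Abs now covers
x = np.linspace(-2.0, 2.0, 5)

dy_dx = np.gradient(a * x + b, x)    # numerically equal to a everywhere
assert np.allclose(np.log(np.abs(dy_dx)), np.log(np.abs(a)))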
@@ -13,10 +13,12 @@
 # limitations under the License.
 # ============================================================================
 """Softplus Bijector"""
+import numpy as np
 from mindspore.ops import operations as P
+from mindspore.common import dtype as mstype
 from mindspore.nn.layer.activation import LogSigmoid
 from mindspore._checkparam import Validator as validator
-from ..distribution._utils.utils import cast_to_tensor
+from ..distribution._utils.utils import cast_to_tensor, CheckTensor
 from .bijector import Bijector

 class Softplus(Bijector):
@@ -52,19 +54,28 @@ class Softplus(Bijector):
                  sharpness=1.0,
                  name='Softplus'):
         param = dict(locals())
-        validator.check_value_type('sharpness', sharpness, [float], name)
+        validator.check_value_type('sharpness', sharpness, [int, float], name)
         super(Softplus, self).__init__(name=name, param=param)
         self._sharpness = cast_to_tensor(sharpness)

+        self.abs = P.Abs()
         self.exp = P.Exp()
         self.expm1 = self._expm1_by_step
+        self.fill = P.Fill()
+        self.greater = P.Greater()
+        self.less = P.Less()
         self.log_sigmoid = LogSigmoid()
         self.log = P.Log()
+        self.logicalor = P.LogicalOr()
+        self.select = P.Select()
+        self.shape = P.Shape()
         self.sigmoid = P.Sigmoid()
         self.softplus = self._softplus
         self.inverse_softplus = self._inverse_softplus

+        self.checktensor = CheckTensor()
+        self.threshold = np.log(np.finfo(np.float32).eps) + 1
+
     def _expm1_by_step(self, x):
         """
         Expm1 ops under GPU context.
@@ -72,7 +83,15 @@ class Softplus(Bijector):
         """
         return self.exp(x) - 1.0

     def _softplus(self, x):
-        return self.log(self.exp(x) + 1.0)
+        too_small = self.less(x, self.threshold)
+        too_large = self.greater(x, -self.threshold)
+        too_small_value = self.exp(x)
+        too_large_value = x
+        ones = self.fill(mstype.float32, self.shape(x), 1.0)
+        too_small_or_too_large = self.logicalor(too_small, too_large)
+        x = self.select(too_small_or_too_large, ones, x)
+        y = self.log(self.exp(x) + 1.0)
+        return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y))

     def _inverse_softplus(self, x):
         r"""
@@ -80,7 +99,15 @@ class Softplus(Bijector):
         .. math::
             f(x) = \log(1 + e^{x})
             f^{-1}(y) = \log(e^{y} - 1)
         """
-        return self.log(self.expm1(x))
+        too_small = self.less(x, self.threshold)
+        too_large = self.greater(x, -self.threshold)
+        too_small_value = self.log(x)
+        too_large_value = x
+        ones = self.fill(mstype.float32, self.shape(x), 1.0)
+        too_small_or_too_large = self.logicalor(too_small, too_large)
+        x = self.select(too_small_or_too_large, ones, x)
+        y = x + self.log(self.abs(self.expm1(-x)))
+        return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y))

     @property
     def sharpness(self):
@@ -94,6 +121,7 @@ class Softplus(Bijector):
         return shape

     def _forward(self, x):
+        self.checktensor(x, 'x')
         scaled_value = self.sharpness * x
         return self.softplus(scaled_value) / self.sharpness
@@ -103,6 +131,7 @@ class Softplus(Bijector):
             f(x) = \frac{\log(1 + e^{kx}))}{k}
             f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k}
         """
+        self.checktensor(y, 'y')
         scaled_value = self.sharpness * y
         return self.inverse_softplus(scaled_value) / self.sharpness
@@ -113,6 +142,7 @@ class Softplus(Bijector):
             f'(x) = \frac{e^{kx}}{ 1 + e^{kx}}
             \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx)
         """
+        self.checktensor(x, 'x')
         scaled_value = self.sharpness * x
         return self.log_sigmoid(scaled_value)
@@ -123,5 +153,6 @@ class Softplus(Bijector):
             f'(y) = \frac{e^{ky}}{e^{ky} - 1}
             \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky)
         """
+        self.checktensor(y, 'y')
         scaled_value = self.sharpness * y
         return scaled_value - self.inverse_softplus(scaled_value)
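The new _softplus/_inverse_softplus guard against float32 overflow and underflow: threshold = log(eps_float32) + 1 is roughly -15, below which log(1 + e^x) is numerically e^x and above whose negation it is numerically x, while the inverse uses the identity log(e^y - 1) = y + log|expm1(-y)|. Out-of-range inputs are swapped for 1.0 before the naive formula is evaluated, so no branch ever overflows. An illustrative NumPy mirror of the same selection logic (not MindSpore code):

import numpy as np

threshold = np.log(np.finfo(np.float32).eps) + 1       # about -15 for float32

def softplus_stable(x):
    too_small = x < threshold
    too_large = x > -threshold
    safe_x = np.where(too_small | too_large, 1.0, x)   # keep the naive formula in range
    naive = np.log(np.exp(safe_x) + 1.0)
    return np.where(too_small, np.exp(x), np.where(too_large, x, naive))

x = np.array([-50.0, -1.0, 0.5, 30.0, 100.0])
print(softplus_stable(x))      # finite everywhere; exp(100.) already overflows in float32
print(np.log1p(np.exp(x)))     # float64 reference values agree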
@@ -15,7 +15,8 @@
 """Utility functions to help distribution class."""
 import numpy as np
 from mindspore.ops import _utils as utils
-from mindspore.ops.primitive import constexpr
+from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
+from mindspore._checkparam import Validator as validator
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
@@ -53,7 +54,9 @@ def cast_to_tensor(t, hint_type=mstype.float32):
         raise TypeError(f'Input cannot be Type Bool')
     if isinstance(t, (int, float)):
        return Tensor(t, dtype=t_type)
-    raise TypeError("Input type is not supported.")
+    invalid_type = type(t)
+    raise TypeError(f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}")

 def convert_to_batch(t, batch_shape, required_type):
     """
@@ -274,5 +277,51 @@ def raise_none_error(name):
 @constexpr
 def check_distribution_name(name, expected_name):
+    if name is None:
+        raise ValueError(f"Distribution should be a constant which is not None.")
     if name != expected_name:
-        raise ValueError(f"Distribution should be {expected_name}.")
+        raise ValueError(f"Expected distribution name is {expected_name}, but got {name}.")
+
+
+class CheckTuple(PrimitiveWithInfer):
+    """
+    Check if input is a tuple.
+    """
+    @prim_attr_register
+    def __init__(self):
+        """init CheckTuple"""
+        super(CheckTuple, self).__init__("CheckTuple")
+        self.init_prim_io_names(inputs=['x'], outputs=['dummy_output'])
+
+    def __infer__(self, x, name):
+        if not isinstance(x['dtype'], tuple):
+            raise TypeError("Input type should be a tuple: " + name["value"])
+        out = {'shape': None,
+               'dtype': None,
+               'value': None}
+        return out
+
+    def __call__(self, *args):
+        return
+
+
+class CheckTensor(PrimitiveWithInfer):
+    """
+    Check if input is a Tensor.
+    """
+    @prim_attr_register
+    def __init__(self):
+        """init CheckTensor"""
+        super(CheckTensor, self).__init__("CheckTensor")
+        self.init_prim_io_names(inputs=['x'], outputs=['dummy_output'])
+
+    def __infer__(self, x, name):
+        src_type = x['dtype']
+        validator.check_subclass("input", src_type, [mstype.tensor], name["value"])
+        out = {'shape': None,
+               'dtype': None,
+               'value': None}
+        return out
+
+    def __call__(self, *args):
+        return
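The two new primitives defer their real work to graph compilation: __infer__ receives the abstract value of the input (a dict carrying 'shape'/'dtype'/'value') together with the constant name string, validates the dtype, and returns a dummy abstract output, while __call__ makes them no-ops in PyNative mode. A hedged illustration of the CheckTuple contract, calling __infer__ directly with hand-built abstract values (the exact dict layout used here is an assumption for illustration and the import path presumes the patch is applied; none of this is part of the change):

from mindspore.nn.probability.distribution._utils.utils import CheckTuple

checker = CheckTuple()
# A tuple-typed input passes and yields the dummy abstract output.
print(checker.__infer__({'shape': None, 'dtype': (), 'value': (2, 3)}, {'value': 'shape'}))
# A non-tuple dtype is rejected, reporting the offending argument by name.
try:
    checker.__infer__({'shape': [2], 'dtype': None, 'value': None}, {'value': 'shape'})
except TypeError as err:
    print(err)        # Input type should be a tuple: shape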
@@ -18,6 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
+from ._utils.utils import CheckTensor, CheckTuple

 class Bernoulli(Distribution):
     """
@@ -123,6 +124,9 @@ class Bernoulli(Distribution):
         self.sqrt = P.Sqrt()
         self.uniform = C.uniform

+        self.checktensor = CheckTensor()
+        self.checktuple = CheckTuple()
+
     def extend_repr(self):
         if self.is_scalar_batch:
             str_info = f'probs = {self.probs}'
@@ -137,14 +141,21 @@ class Bernoulli(Distribution):
         """
         return self._probs

+    def _check_param(self, probs1):
+        """
+        Check availability of distribution specific args probs1.
+        """
+        if probs1 is not None:
+            self.checktensor(probs1, 'probs1')
+            return self.cast(probs1, self.parameter_type)
+        return self.probs if self.probs is not None else raise_none_error('probs1')
+
     def _mean(self, probs1=None):
         r"""
         .. math::
             MEAN(B) = probs1
         """
-        probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs1")
+        probs1 = self._check_param(probs1)
         return probs1

     def _mode(self, probs1=None):
@@ -152,9 +163,7 @@ class Bernoulli(Distribution):
         .. math::
             MODE(B) = 1 if probs1 > 0.5 else = 0
         """
-        probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs1")
+        probs1 = self._check_param(probs1)
         prob_type = self.dtypeop(probs1)
         zeros = self.fill(prob_type, self.shape(probs1), 0.0)
         ones = self.fill(prob_type, self.shape(probs1), 1.0)
@@ -166,24 +175,20 @@ class Bernoulli(Distribution):
         .. math::
             VAR(B) = probs1 * probs0
         """
-        probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs1")
+        probs1 = self._check_param(probs1)
         probs0 = 1.0 - probs1
         return self.exp(self.log(probs0) + self.log(probs1))

-    def _entropy(self, probs=None):
+    def _entropy(self, probs1=None):
         r"""
         .. math::
             H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
         """
-        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs")
+        probs1 = self._check_param(probs1)
         probs0 = 1 - probs1
         return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1))

-    def _cross_entropy(self, dist, probs1_b, probs1_a=None):
+    def _cross_entropy(self, dist, probs1_b, probs1=None):
         """
         Evaluate cross_entropy between Bernoulli distributions.
@@ -193,9 +198,9 @@ class Bernoulli(Distribution):
             probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
         """
         check_distribution_name(dist, 'Bernoulli')
-        return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
+        return self._entropy(probs1) + self._kl_loss(dist, probs1_b, probs1)

-    def _log_prob(self, value, probs=None):
+    def _log_prob(self, value, probs1=None):
         r"""
         pmf of Bernoulli distribution.
@@ -207,17 +212,14 @@ class Bernoulli(Distribution):
             pmf(k) = probs1 if k = 1;
             pmf(k) = probs0 if k = 0;
         """
-        if value is None:
-            raise_none_error("value")
+        self.checktensor(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
-        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs")
+        probs1 = self._check_param(probs1)
         probs0 = 1.0 - probs1
         return self.log(probs1) * value + self.log(probs0) * (1.0 - value)

-    def _cdf(self, value, probs=None):
+    def _cdf(self, value, probs1=None):
         r"""
         cdf of Bernoulli distribution.
@@ -230,13 +232,10 @@ class Bernoulli(Distribution):
             cdf(k) = probs0 if 0 <= k <1;
             cdf(k) = 1 if k >=1;
         """
-        if value is None:
-            raise_none_error("value")
+        self.checktensor(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
-        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs")
+        probs1 = self._check_param(probs1)
         prob_type = self.dtypeop(probs1)
         value = value * self.fill(prob_type, self.shape(probs1), 1.0)
         probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0)
@@ -247,7 +246,7 @@ class Bernoulli(Distribution):
         less_than_zero = self.select(comp_zero, zeros, probs0)
         return self.select(comp_one, less_than_zero, ones)

-    def _kl_loss(self, dist, probs1_b, probs1_a=None):
+    def _kl_loss(self, dist, probs1_b, probs1=None):
         r"""
         Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b).
@@ -261,17 +260,14 @@ class Bernoulli(Distribution):
             probs0_a * \log(\frac{probs0_a}{probs0_b})
         """
         check_distribution_name(dist, 'Bernoulli')
-        if probs1_b is None:
-            raise_none_error("probs1_b")
+        self.checktensor(probs1_b, 'probs1_b')
         probs1_b = self.cast(probs1_b, self.parameter_type)
-        probs1_a = self.cast(probs1_a, self.parameter_type) if probs1_a is not None else self.probs
-        if probs1_a is None:
-            raise_none_error("probs1_a")
+        probs1_a = self._check_param(probs1)
         probs0_a = 1.0 - probs1_a
         probs0_b = 1.0 - probs1_b
         return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b)

-    def _sample(self, shape=(), probs=None):
+    def _sample(self, shape=(), probs1=None):
         """
         Sampling.
@@ -282,9 +278,8 @@ class Bernoulli(Distribution):
         Returns:
             Tensor, shape is shape + batch_shape.
         """
-        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs")
+        self.checktuple(shape, 'shape')
+        probs1 = self._check_param(probs1)
         origin_shape = shape + self.shape(probs1)
         if origin_shape == ():
             sample_shape = (1,)
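The Bernoulli KL formula used by _kl_loss (and, through H(a) + KL(a||b), by _cross_entropy) can be checked against the two-point definition; an illustrative NumPy sketch, not part of the change:

import numpy as np

p_a, p_b = 0.3, 0.7      # probs1 of distribution a and distribution b

# _kl_loss formula from the patch
kl = p_a * np.log(p_a / p_b) + (1 - p_a) * np.log((1 - p_a) / (1 - p_b))
# _entropy formula from the patch
h_a = -(1 - p_a) * np.log(1 - p_a) - p_a * np.log(p_a)

# _cross_entropy returns H(a) + KL(a||b); it should equal -sum_k pmf_a(k) * log(pmf_b(k))
cross_entropy_direct = -(p_a * np.log(p_b) + (1 - p_a) * np.log(1 - p_b))
assert np.isclose(h_a + kl, cross_entropy_direct)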
@@ -20,6 +20,7 @@ from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
     raise_none_error
+from ._utils.utils import CheckTensor, CheckTuple

 class Exponential(Distribution):
     """
@@ -125,6 +126,9 @@ class Exponential(Distribution):
         self.sq = P.Square()
         self.uniform = C.uniform

+        self.checktensor = CheckTensor()
+        self.checktuple = CheckTuple()
+
     def extend_repr(self):
         if self.is_scalar_batch:
             str_info = f'rate = {self.rate}'
@@ -139,14 +143,21 @@ class Exponential(Distribution):
         """
         return self._rate

+    def _check_param(self, rate):
+        """
+        Check availability of distribution specific args rate.
+        """
+        if rate is not None:
+            self.checktensor(rate, 'rate')
+            return self.cast(rate, self.parameter_type)
+        return self.rate if self.rate is not None else raise_none_error('rate')
+
     def _mean(self, rate=None):
         r"""
         .. math::
             MEAN(EXP) = \frac{1.0}{\lambda}.
         """
-        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
-        if rate is None:
-            raise_none_error("rate")
+        rate = self._check_param(rate)
         return 1.0 / rate

     def _mode(self, rate=None):
@@ -154,9 +165,7 @@ class Exponential(Distribution):
         .. math::
             MODE(EXP) = 0.
         """
-        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
-        if rate is None:
-            raise_none_error("rate")
+        rate = self._check_param(rate)
         return self.fill(self.dtype, self.shape(rate), 0.)

     def _sd(self, rate=None):
@@ -164,9 +173,7 @@ class Exponential(Distribution):
         .. math::
             sd(EXP) = \frac{1.0}{\lambda}.
         """
-        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
-        if rate is None:
-            raise_none_error("rate")
+        rate = self._check_param(rate)
         return 1.0 / rate

     def _entropy(self, rate=None):
@@ -174,13 +181,10 @@ class Exponential(Distribution):
         .. math::
             H(Exp) = 1 - \log(\lambda).
         """
-        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
-        if rate is None:
-            raise_none_error("rate")
+        rate = self._check_param(rate)
         return 1.0 - self.log(rate)

-    def _cross_entropy(self, dist, rate_b, rate_a=None):
+    def _cross_entropy(self, dist, rate_b, rate=None):
         """
         Evaluate cross_entropy between Exponential distributions.
@@ -190,7 +194,7 @@ class Exponential(Distribution):
             rate_a (Tensor): rate of distribution a. Default: self.rate.
         """
         check_distribution_name(dist, 'Exponential')
-        return self._entropy(rate=rate_a) + self._kl_loss(dist, rate_b, rate_a)
+        return self._entropy(rate) + self._kl_loss(dist, rate_b, rate)

     def _prob(self, value, rate=None):
@@ -208,12 +212,9 @@ class Exponential(Distribution):
         .. math::
             pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0
         """
-        if value is None:
-            raise_none_error("value")
+        self.checktensor(value, "value")
         value = self.cast(value, self.dtype)
-        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
-        if rate is None:
-            raise_none_error("rate")
+        rate = self._check_param(rate)
         prob = self.exp(self.log(rate) - rate * value)
         zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
         comp = self.less(value, zeros)
@@ -233,19 +234,16 @@ class Exponential(Distribution):
         .. math::
             cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0
         """
-        if value is None:
-            raise_none_error("value")
+        self.checktensor(value, 'value')
         value = self.cast(value, self.dtype)
-        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
-        if rate is None:
-            raise_none_error("rate")
+        rate = self._check_param(rate)
         cdf = 1.0 - self.exp(-1. * rate * value)
         zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
         comp = self.less(value, zeros)
         return self.select(comp, zeros, cdf)

-    def _kl_loss(self, dist, rate_b, rate_a=None):
+    def _kl_loss(self, dist, rate_b, rate=None):
         """
         Evaluate exp-exp kl divergence, i.e. KL(a||b).
@@ -255,12 +253,9 @@ class Exponential(Distribution):
             rate_a (Tensor): rate of distribution a. Default: self.rate.
         """
         check_distribution_name(dist, 'Exponential')
-        if rate_b is None:
-            raise_none_error("rate_b")
+        self.checktensor(rate_b, 'rate_b')
         rate_b = self.cast(rate_b, self.parameter_type)
-        rate_a = self.cast(rate_a, self.parameter_type) if rate_a is not None else self.rate
-        if rate_a is None:
-            raise_none_error("rate_a")
+        rate_a = self._check_param(rate)
         return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0

     def _sample(self, shape=(), rate=None):
@@ -274,9 +269,8 @@ class Exponential(Distribution):
         Returns:
             Tensor, shape is shape + batch_shape.
         """
-        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
-        if rate is None:
-            raise_none_error("rate")
+        self.checktuple(shape, 'shape')
+        rate = self._check_param(rate)
         origin_shape = shape + self.shape(rate)
         if origin_shape == ():
             sample_shape = (1,)
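Likewise, the exponential KL expression in _kl_loss, log(rate_a) - log(rate_b) + rate_b/rate_a - 1, matches a numerical evaluation of the definition. An illustrative NumPy check, not part of the change:

import numpy as np

rate_a, rate_b = 2.0, 0.5

kl_formula = np.log(rate_a) - np.log(rate_b) + rate_b / rate_a - 1.0

# integrate p_a(x) * log(p_a(x) / p_b(x)) over a range holding nearly all the mass
x = np.linspace(0.0, 60.0, 200_001)
p_a = rate_a * np.exp(-rate_a * x)
p_b = rate_b * np.exp(-rate_b * x)
f = p_a * np.log(p_a / p_b)
kl_numeric = np.sum(0.5 * (f[1:] + f[:-1]) * np.diff(x))   # trapezoid rule

assert np.isclose(kl_formula, kl_numeric, atol=1e-4)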
@@ -20,6 +20,7 @@ from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
     raise_none_error
+from ._utils.utils import CheckTensor, CheckTuple

 class Geometric(Distribution):
     """
@@ -129,6 +130,9 @@ class Geometric(Distribution):
         self.sqrt = P.Sqrt()
         self.uniform = C.uniform

+        self.checktensor = CheckTensor()
+        self.checktuple = CheckTuple()
+
     def extend_repr(self):
         if self.is_scalar_batch:
             str_info = f'probs = {self.probs}'
@@ -143,14 +147,21 @@ class Geometric(Distribution):
         """
         return self._probs

+    def _check_param(self, probs1):
+        """
+        Check availability of distribution specific args probs1.
+        """
+        if probs1 is not None:
+            self.checktensor(probs1, 'probs1')
+            return self.cast(probs1, self.parameter_type)
+        return self.probs if self.probs is not None else raise_none_error('probs1')
+
     def _mean(self, probs1=None):
         r"""
         .. math::
             MEAN(Geo) = \frac{1 - probs1}{probs1}
         """
-        probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs1")
+        probs1 = self._check_param(probs1)
         return (1. - probs1) / probs1

     def _mode(self, probs1=None):
@@ -158,9 +169,7 @@ class Geometric(Distribution):
         .. math::
             MODE(Geo) = 0
         """
-        probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs1")
+        probs1 = self._check_param(probs1)
         return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)

     def _var(self, probs1=None):
@@ -168,23 +177,19 @@ class Geometric(Distribution):
         .. math::
             VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}}
         """
-        probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs1")
+        probs1 = self._check_param(probs1)
         return (1.0 - probs1) / self.sq(probs1)

-    def _entropy(self, probs=None):
+    def _entropy(self, probs1=None):
         r"""
         .. math::
             H(Geo) = \frac{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1}
         """
-        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs")
+        probs1 = self._check_param(probs1)
         probs0 = 1.0 - probs1
         return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1

-    def _cross_entropy(self, dist, probs1_b, probs1_a=None):
+    def _cross_entropy(self, dist, probs1_b, probs1=None):
         r"""
         Evaluate cross_entropy between Geometric distributions.
@@ -194,9 +199,9 @@ class Geometric(Distribution):
             probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
         """
         check_distribution_name(dist, 'Geometric')
-        return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
+        return self._entropy(probs1) + self._kl_loss(dist, probs1_b, probs1)

-    def _prob(self, value, probs=None):
+    def _prob(self, value, probs1=None):
         r"""
         pmf of Geometric distribution.
@@ -208,19 +213,16 @@ class Geometric(Distribution):
             pmf(k) = probs0 ^k * probs1 if k >= 0;
             pmf(k) = 0 if k < 0.
         """
-        if value is None:
-            raise_none_error("value")
+        self.checktensor(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
-        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs")
+        probs1 = self._check_param(probs1)
         pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
         zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
         comp = self.less(value, zeros)
         return self.select(comp, zeros, pmf)

-    def _cdf(self, value, probs=None):
+    def _cdf(self, value, probs1=None):
         r"""
         cdf of Geometric distribution.
@@ -233,13 +235,10 @@ class Geometric(Distribution):
             cdf(k) = 0 if k < 0.
         """
-        if value is None:
-            raise_none_error("value")
+        self.checktensor(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
-        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs")
+        probs1 = self._check_param(probs1)
         probs0 = 1.0 - probs1
         cdf = 1.0 - self.pow(probs0, value + 1.0)
         zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
@@ -247,7 +246,7 @@ class Geometric(Distribution):
         return self.select(comp, zeros, cdf)

-    def _kl_loss(self, dist, probs1_b, probs1_a=None):
+    def _kl_loss(self, dist, probs1_b, probs1=None):
         r"""
         Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b).
@@ -260,17 +259,14 @@ class Geometric(Distribution):
             KL(a||b) = \log(\frac{probs1_a}{probs1_b}) + \frac{probs0_a}{probs1_a} * \log(\frac{probs0_a}{probs0_b})
         """
         check_distribution_name(dist, 'Geometric')
-        if probs1_b is None:
-            raise_none_error("probs1_b")
+        self.checktensor(probs1_b, 'probs1_b')
         probs1_b = self.cast(probs1_b, self.parameter_type)
-        probs1_a = self.cast(probs1_a, self.parameter_type) if probs1_a is not None else self.probs
-        if probs1_a is None:
-            raise_none_error("probs1_a")
+        probs1_a = self._check_param(probs1)
         probs0_a = 1.0 - probs1_a
         probs0_b = 1.0 - probs1_b
         return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b)

-    def _sample(self, shape=(), probs=None):
+    def _sample(self, shape=(), probs1=None):
         """
         Sampling.
@@ -281,9 +277,8 @@ class Geometric(Distribution):
         Returns:
             Tensor, shape is shape + batch_shape.
         """
-        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs")
+        self.checktuple(shape, 'shape')
+        probs1 = self._check_param(probs1)
         origin_shape = shape + self.shape(probs1)
         if origin_shape == ():
             sample_shape = (1,)
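The geometric KL expression, log(p_a/p_b) + (q_a/p_a) * log(q_a/q_b) with q = 1 - p and the "number of failures" convention pmf(k) = q^k * p, can be recovered by summing the definition. Illustrative NumPy only, not part of the change:

import numpy as np

p_a, p_b = 0.4, 0.6
q_a, q_b = 1.0 - p_a, 1.0 - p_b

kl_formula = np.log(p_a / p_b) + (q_a / p_a) * np.log(q_a / q_b)

k = np.arange(200)                    # the tail beyond k = 200 is negligible here
pmf_a = q_a ** k * p_a
pmf_b = q_b ** k * p_b
kl_direct = np.sum(pmf_a * np.log(pmf_a / pmf_b))

assert np.isclose(kl_formula, kl_direct)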
@@ -20,6 +20,7 @@ from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
     raise_none_error
+from ._utils.utils import CheckTensor, CheckTuple

 class Normal(Distribution):
     """
@@ -112,7 +113,6 @@ class Normal(Distribution):
         self._mean_value = mean
         self._sd_value = sd
         #ops needed for the class
         self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
@@ -127,6 +127,9 @@ class Normal(Distribution):
         self.sqrt = P.Sqrt()
         self.zeroslike = P.ZerosLike()

+        self.checktensor = CheckTensor()
+        self.checktuple = CheckTuple()
+
     def extend_repr(self):
         if self.is_scalar_batch:
             str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}'
@@ -140,40 +143,44 @@ class Normal(Distribution):
         """
         return self.exp(x) - 1.0

+    def _check_param(self, mean, sd):
+        """
+        Check availability of distribution specific args mean and sd.
+        """
+        if mean is not None:
+            self.checktensor(mean, 'mean')
+            mean = self.cast(mean, self.parameter_type)
+        else:
+            mean = self._mean_value if self._mean_value is not None else raise_none_error('mean')
+        if sd is not None:
+            self.checktensor(sd, 'sd')
+            sd = self.cast(sd, self.parameter_type)
+        else:
+            sd = self._sd_value if self._sd_value is not None else raise_none_error('sd')
+        batch_shape = self.shape(mean + sd)
+        mean = mean * self.fill(self.dtype, batch_shape, 1.0)
+        sd = sd * self.fill(self.dtype, batch_shape, 1.0)
+        return mean, sd
+
     def _mean(self, mean=None, sd=None):
         """
         Mean of the distribution.
         """
-        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
-        if mean is None:
-            raise_none_error("mean")
-        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
-        if sd is None:
-            raise_none_error("sd")
+        mean, sd = self._check_param(mean, sd)
         return mean

     def _mode(self, mean=None, sd=None):
         """
         Mode of the distribution.
         """
-        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
-        if mean is None:
-            raise_none_error("mean")
-        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
-        if sd is None:
-            raise_none_error("sd")
+        mean, sd = self._check_param(mean, sd)
         return mean

     def _sd(self, mean=None, sd=None):
         """
         Standard deviation of the distribution.
         """
-        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
-        if mean is None:
-            raise_none_error("mean")
-        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
-        if sd is None:
-            raise_none_error("sd")
+        mean, sd = self._check_param(mean, sd)
         return sd

     def _entropy(self, mean=None, sd=None):
@@ -183,15 +190,10 @@ class Normal(Distribution):
         .. math::
             H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma)))
         """
-        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
-        if mean is None:
-            raise_none_error("mean")
-        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
-        if sd is None:
-            raise_none_error("sd")
+        mean, sd = self._check_param(mean, sd)
         return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd)

-    def _cross_entropy(self, dist, mean_b, sd_b, mean_a=None, sd_a=None):
+    def _cross_entropy(self, dist, mean_b, sd_b, mean=None, sd=None):
         r"""
         Evaluate cross_entropy between normal distributions.
@@ -203,7 +205,7 @@ class Normal(Distribution):
             sd_a (Tensor): standard deviation distribution a. Default: self._sd_value.
         """
         check_distribution_name(dist, 'Normal')
-        return self._entropy(mean=mean_a, sd=sd_a) + self._kl_loss(dist, mean_b, sd_b, mean_a, sd_a)
+        return self._entropy(mean, sd) + self._kl_loss(dist, mean_b, sd_b, mean, sd)

     def _log_prob(self, value, mean=None, sd=None):
         r"""
@@ -217,15 +219,9 @@ class Normal(Distribution):
         .. math::
             L(x) = -1* \frac{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
         """
-        if value is None:
-            raise_none_error("value")
+        self.checktensor(value, 'value')
         value = self.cast(value, self.dtype)
-        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
-        if mean is None:
-            raise_none_error("mean")
-        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
-        if sd is None:
-            raise_none_error("sd")
+        mean, sd = self._check_param(mean, sd)
         unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
         neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd)
         return unnormalized_log_prob + neg_normalization
@@ -242,20 +238,14 @@ class Normal(Distribution):
         .. math::
             cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2))))
         """
-        if value is None:
-            raise_none_error("value")
+        self.checktensor(value, 'value')
         value = self.cast(value, self.dtype)
-        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
-        if mean is None:
-            raise_none_error("mean")
-        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
-        if sd is None:
-            raise_none_error("sd")
+        mean, sd = self._check_param(mean, sd)
         sqrt2 = self.sqrt(self.const(2.0))
         adjusted = (value - mean) / (sd * sqrt2)
         return 0.5 * (1.0 + self.erf(adjusted))

-    def _kl_loss(self, dist, mean_b, sd_b, mean_a=None, sd_a=None):
+    def _kl_loss(self, dist, mean_b, sd_b, mean=None, sd=None):
         r"""
         Evaluate Normal-Normal kl divergence, i.e. KL(a||b).
@@ -271,23 +261,15 @@ class Normal(Distribution):
             0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
         """
         check_distribution_name(dist, 'Normal')
-        if mean_b is None:
-            raise_none_error("mean_b")
-        if sd_b is None:
-            raise_none_error("sd_b")
+        self.checktensor(mean_b, 'mean_b')
+        self.checktensor(sd_b, 'sd_b')
         mean_b = self.cast(mean_b, self.parameter_type)
         sd_b = self.cast(sd_b, self.parameter_type)
-        mean_a = self.cast(mean_a, self.parameter_type) if mean_a is not None else self._mean_value
-        sd_a = self.cast(sd_a, self.parameter_type) if sd_a is not None else self._sd_value
-        if mean_a is None:
-            raise_none_error("mean_a")
-        if sd_a is None:
-            raise_none_error("sd_a")
+        mean_a, sd_a = self._check_param(mean, sd)
         diff_log_scale = self.log(sd_a) - self.log(sd_b)
         squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b)
         return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale

     def _sample(self, shape=(), mean=None, sd=None):
         """
         Sampling.
@@ -300,12 +282,8 @@ class Normal(Distribution):
         Returns:
             Tensor, shape is shape + batch_shape.
         """
-        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
-        if mean is None:
-            raise_none_error("mean")
-        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
-        if sd is None:
-            raise_none_error("sd")
+        self.checktuple(shape, 'shape')
+        mean, sd = self._check_param(mean, sd)
         batch_shape = self.shape(mean + sd)
         origin_shape = shape + batch_shape
         if origin_shape == ():
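The normal-normal KL in _kl_loss is the standard closed form written in terms of expm1 of the log-scale difference. A short NumPy check (illustrative, not part of the change) that it agrees with the textbook expression log(sd_b/sd_a) + (sd_a^2 + (mean_a - mean_b)^2) / (2 * sd_b^2) - 1/2:

import numpy as np

mean_a, sd_a = 1.0, 2.0
mean_b, sd_b = -0.5, 1.5

# formula from the patch, with d = log(sd_a) - log(sd_b)
d = np.log(sd_a) - np.log(sd_b)
kl_patch = 0.5 * ((mean_a - mean_b) / sd_b) ** 2 + 0.5 * np.expm1(2.0 * d) - d

kl_ref = np.log(sd_b / sd_a) + (sd_a ** 2 + (mean_a - mean_b) ** 2) / (2.0 * sd_b ** 2) - 0.5

assert np.isclose(kl_patch, kl_ref)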
| @@ -19,6 +19,7 @@ from mindspore.common import dtype as mstype | |||||
| from .distribution import Distribution | from .distribution import Distribution | ||||
| from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\ | from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\ | ||||
| raise_none_error | raise_none_error | ||||
| from ._utils.utils import CheckTensor, CheckTuple | |||||
| class Uniform(Distribution): | class Uniform(Distribution): | ||||
| """ | """ | ||||
| @@ -129,6 +130,9 @@ class Uniform(Distribution): | |||||
| self.zeroslike = P.ZerosLike() | self.zeroslike = P.ZerosLike() | ||||
| self.uniform = C.uniform | self.uniform = C.uniform | ||||
| self.checktensor = CheckTensor() | |||||
| self.checktuple = CheckTuple() | |||||
| def extend_repr(self): | def extend_repr(self): | ||||
| if self.is_scalar_batch: | if self.is_scalar_batch: | ||||
| str_info = f'low = {self.low}, high = {self.high}' | str_info = f'low = {self.low}, high = {self.high}' | ||||
| @@ -136,6 +140,25 @@ class Uniform(Distribution): | |||||
| str_info = f'batch_shape = {self._broadcast_shape}' | str_info = f'batch_shape = {self._broadcast_shape}' | ||||
| return str_info | return str_info | ||||
| def _check_param(self, low, high): | |||||
| """ | |||||
| Check availablity of distribution specific args low and high. | |||||
| """ | |||||
| if low is not None: | |||||
| self.checktensor(low, 'low') | |||||
| low = self.cast(low, self.parameter_type) | |||||
| else: | |||||
| low = self.low if self.low is not None else raise_none_error('low') | |||||
| if high is not None: | |||||
| self.checktensor(high, 'high') | |||||
| high = self.cast(high, self.parameter_type) | |||||
| else: | |||||
| high = self.high if self.high is not None else raise_none_error('high') | |||||
| batch_shape = self.shape(high - low) | |||||
| high = high * self.fill(self.dtype, batch_shape, 1.0) | |||||
| low = low * self.fill(self.dtype, batch_shape, 1.0) | |||||
| return low, high | |||||
| @property | @property | ||||
| def low(self): | def low(self): | ||||
| """ | """ | ||||
| @@ -156,12 +179,7 @@ class Uniform(Distribution): | |||||
| .. math:: | .. math:: | ||||
| range(U) = high -low | range(U) = high -low | ||||
| """ | """ | ||||
| low = self.cast(low, self.parameter_type) if low is not None else self.low | |||||
| if low is None: | |||||
| raise_none_error("low") | |||||
| high = self.cast(high, self.parameter_type) if high is not None else self.high | |||||
| if high is None: | |||||
| raise_none_error("high") | |||||
| low, high = self._check_param(low, high) | |||||
| return high - low | return high - low | ||||
| def _mean(self, low=None, high=None): | def _mean(self, low=None, high=None): | ||||
| @@ -169,12 +187,7 @@ class Uniform(Distribution): | |||||
| .. math:: | .. math:: | ||||
| MEAN(U) = \frac{low + high}{2}. | MEAN(U) = \frac{low + high}{2}. | ||||
| """ | """ | ||||
| low = self.cast(low, self.parameter_type) if low is not None else self.low | |||||
| if low is None: | |||||
| raise_none_error("low") | |||||
| high = self.cast(high, self.parameter_type) if high is not None else self.high | |||||
| if high is None: | |||||
| raise_none_error("high") | |||||
| low, high = self._check_param(low, high) | |||||
| return (low + high) / 2. | return (low + high) / 2. | ||||
| def _var(self, low=None, high=None): | def _var(self, low=None, high=None): | ||||
| @@ -182,12 +195,7 @@ class Uniform(Distribution): | |||||
| .. math:: | .. math:: | ||||
| VAR(U) = \frac{(high - low) ^ 2}{12}. | VAR(U) = \frac{(high - low) ^ 2}{12}. | ||||
| """ | """ | ||||
| low = self.cast(low, self.parameter_type) if low is not None else self.low | |||||
| if low is None: | |||||
| raise_none_error("low") | |||||
| high = self.cast(high, self.parameter_type) if high is not None else self.high | |||||
| if high is None: | |||||
| raise_none_error("high") | |||||
| low, high = self._check_param(low, high) | |||||
| return self.sq(high - low) / 12.0 | return self.sq(high - low) / 12.0 | ||||
| def _entropy(self, low=None, high=None): | def _entropy(self, low=None, high=None): | ||||
| @@ -195,15 +203,10 @@ class Uniform(Distribution): | |||||
| .. math:: | .. math:: | ||||
| H(U) = \log(high - low). | H(U) = \log(high - low). | ||||
| """ | """ | ||||
| low = self.cast(low, self.parameter_type) if low is not None else self.low | |||||
| if low is None: | |||||
| raise_none_error("low") | |||||
| high = self.cast(high, self.parameter_type) if high is not None else self.high | |||||
| if high is None: | |||||
| raise_none_error("high") | |||||
| low, high = self._check_param(low, high) | |||||
| return self.log(high - low) | return self.log(high - low) | ||||
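The closed forms behind _range, _mean, _var, and _entropy are the textbook results for a continuous uniform distribution. A quick NumPy sanity check with illustrative bounds (not part of the patch):

import numpy as np

low, high = 0.0, 1.5
samples = np.random.uniform(low, high, size=1_000_000)

print(samples.mean(), (low + high) / 2.0)       # MEAN(U) = (low + high) / 2 ≈ 0.75
print(samples.var(), (high - low) ** 2 / 12.0)  # VAR(U) = (high - low)^2 / 12 = 0.1875
print(np.log(high - low))                       # H(U) = log(high - low) ≈ 0.405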
| def _cross_entropy(self, dist, low_b, high_b, low_a=None, high_a=None): | |||||
| def _cross_entropy(self, dist, low_b, high_b, low=None, high=None): | |||||
| """ | """ | ||||
| Evaluate cross_entropy between Uniform distributions. | Evaluate cross_entropy between Uniform distributions. | ||||
| @@ -215,7 +218,7 @@ class Uniform(Distribution): | |||||
| high_a (Tensor): upper bound of distribution a. Default: self.high. | high_a (Tensor): upper bound of distribution a. Default: self.high. | ||||
| """ | """ | ||||
| check_distribution_name(dist, 'Uniform') | check_distribution_name(dist, 'Uniform') | ||||
| return self._entropy(low=low_a, high=high_a) + self._kl_loss(dist, low_b, high_b, low_a, high_a) | |||||
| return self._entropy(low, high) + self._kl_loss(dist, low_b, high_b, low, high) | |||||
| def _prob(self, value, low=None, high=None): | def _prob(self, value, low=None, high=None): | ||||
| r""" | r""" | ||||
| @@ -231,15 +234,9 @@ class Uniform(Distribution): | |||||
| pdf(x) = \frac{1.0}{high - low} if low <= x <= high; | pdf(x) = \frac{1.0}{high - low} if low <= x <= high; | ||||
| pdf(x) = 0 if x > high; | pdf(x) = 0 if x > high; | ||||
| """ | """ | ||||
| if value is None: | |||||
| raise_none_error("value") | |||||
| self.checktensor(value, 'value') | |||||
| value = self.cast(value, self.dtype) | value = self.cast(value, self.dtype) | ||||
| low = self.cast(low, self.parameter_type) if low is not None else self.low | |||||
| if low is None: | |||||
| raise_none_error("low") | |||||
| high = self.cast(high, self.parameter_type) if high is not None else self.high | |||||
| if high is None: | |||||
| raise_none_error("high") | |||||
| low, high = self._check_param(low, high) | |||||
| neg_ones = self.fill(self.dtype, self.shape(value), -1.0) | neg_ones = self.fill(self.dtype, self.shape(value), -1.0) | ||||
| prob = self.exp(neg_ones * self.log(high - low)) | prob = self.exp(neg_ones * self.log(high - low)) | ||||
| broadcast_shape = self.shape(prob) | broadcast_shape = self.shape(prob) | ||||
| @@ -249,7 +246,7 @@ class Uniform(Distribution): | |||||
| less_than_low = self.select(comp_lo, zeros, prob) | less_than_low = self.select(comp_lo, zeros, prob) | ||||
| return self.select(comp_hi, less_than_low, zeros) | return self.select(comp_hi, less_than_low, zeros) | ||||
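In _prob the constant density exp(-log(high - low)) = 1 / (high - low) is broadcast to the value's shape and then masked to zero outside [low, high] by the two select calls. A hedged NumPy analogue of that masking:

import numpy as np

def uniform_prob(value, low, high):
    value = np.asarray(value, dtype=np.float32)
    prob = np.exp(-np.log(high - low)) * np.ones_like(value)  # 1 / (high - low)
    zeros = np.zeros_like(prob)
    prob = np.where(value < low, zeros, prob)    # zero below the support
    return np.where(value <= high, prob, zeros)  # zero above the support

print(uniform_prob([-1.0, 0.5, 2.0], low=0.0, high=1.5))  # [0.        0.6666667 0.       ]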
| def _kl_loss(self, dist, low_b, high_b, low_a=None, high_a=None): | |||||
| def _kl_loss(self, dist, low_b, high_b, low=None, high=None): | |||||
| """ | """ | ||||
| Evaluate uniform-uniform KL divergence, i.e. KL(a||b). | Evaluate uniform-uniform KL divergence, i.e. KL(a||b). | ||||
| @@ -261,19 +258,12 @@ class Uniform(Distribution): | |||||
| high_a (Tensor): upper bound of distribution a. Default: self.high. | high_a (Tensor): upper bound of distribution a. Default: self.high. | ||||
| """ | """ | ||||
| check_distribution_name(dist, 'Uniform') | check_distribution_name(dist, 'Uniform') | ||||
| if low_b is None: | |||||
| raise_none_error("low_b") | |||||
| if high_b is None: | |||||
| raise_none_error("high_b") | |||||
| self.checktensor(low_b, 'low_b') | |||||
| low_b = self.cast(low_b, self.parameter_type) | low_b = self.cast(low_b, self.parameter_type) | ||||
| self.checktensor(high_b, 'high_b') | |||||
| high_b = self.cast(high_b, self.parameter_type) | high_b = self.cast(high_b, self.parameter_type) | ||||
| low_a = self.cast(low_a, self.parameter_type) if low_a is not None else self.low | |||||
| if low_a is None: | |||||
| raise_none_error("low_a") | |||||
| high_a = self.cast(high_a, self.parameter_type) if high_a is not None else self.high | |||||
| if high_a is None: | |||||
| raise_none_error("high_a") | |||||
| kl = self.log(high_b - low_b) / self.log(high_a - low_a) | |||||
| low_a, high_a = self._check_param(low, high) | |||||
| kl = self.log(high_b - low_b) - self.log(high_a - low_a) | |||||
| comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b)) | comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b)) | ||||
| return self.select(comp, kl, self.log(self.zeroslike(kl))) | return self.select(comp, kl, self.log(self.zeroslike(kl))) | ||||
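The substantive fix in this hunk is the KL expression itself: for two uniform distributions with the support of a contained in the support of b, KL(a||b) = log((high_b - low_b) / (high_a - low_a)), a difference of logs rather than the quotient the old code computed. A quick check with illustrative bounds (low_a = 0.0 is illustrative; only the other three values appear in the test further down):

import numpy as np

low_a, high_a, low_b, high_b = 0.0, 1.5, -1.0, 2.0

kl_fixed = np.log(high_b - low_b) - np.log(high_a - low_a)  # log(3) - log(1.5) = log(2)
kl_old = np.log(high_b - low_b) / np.log(high_a - low_a)    # log(3) / log(1.5), not a KL

print(kl_fixed)  # ≈ 0.6931
print(kl_old)    # ≈ 2.7095, clearly not a KL divergence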
| @@ -291,15 +281,9 @@ class Uniform(Distribution): | |||||
| cdf(x) = \frac{x - low}{high - low} if low <= x <= high; | cdf(x) = \frac{x - low}{high - low} if low <= x <= high; | ||||
| cdf(x) = 1 if x > high; | cdf(x) = 1 if x > high; | ||||
| """ | """ | ||||
| if value is None: | |||||
| raise_none_error("value") | |||||
| self.checktensor(value, 'value') | |||||
| value = self.cast(value, self.dtype) | value = self.cast(value, self.dtype) | ||||
| low = self.cast(low, self.parameter_type) if low is not None else self.low | |||||
| if low is None: | |||||
| raise_none_error("low") | |||||
| high = self.cast(high, self.parameter_type) if high is not None else self.high | |||||
| if high is None: | |||||
| raise_none_error("high") | |||||
| low, high = self._check_param(low, high) | |||||
| prob = (value - low) / (high - low) | prob = (value - low) / (high - low) | ||||
| broadcast_shape = self.shape(prob) | broadcast_shape = self.shape(prob) | ||||
| zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) | zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) | ||||
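Likewise, _cdf builds the linear ramp (value - low) / (high - low) and, per the docstring above, clamps it so the result is 0 below low and 1 above high (the clamping code sits below this hunk). A hedged NumPy analogue:

import numpy as np

def uniform_cdf(value, low, high):
    value = np.asarray(value, dtype=np.float32)
    ramp = (value - low) / (high - low)
    # 0 below low, the linear ramp inside [low, high], 1 above high.
    return np.clip(ramp, 0.0, 1.0)

print(uniform_cdf([-1.0, 0.75, 3.0], low=0.0, high=1.5))  # [0.  0.5 1. ]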
| @@ -321,12 +305,8 @@ class Uniform(Distribution): | |||||
| Returns: | Returns: | ||||
| Tensor, shape is shape + batch_shape. | Tensor, shape is shape + batch_shape. | ||||
| """ | """ | ||||
| low = self.cast(low, self.parameter_type) if low is not None else self.low | |||||
| if low is None: | |||||
| raise_none_error("low") | |||||
| high = self.cast(high, self.parameter_type) if high is not None else self.high | |||||
| if high is None: | |||||
| raise_none_error("high") | |||||
| self.checktuple(shape, 'shape') | |||||
| low, high = self._check_param(low, high) | |||||
| broadcast_shape = self.shape(low + high) | broadcast_shape = self.shape(low + high) | ||||
| origin_shape = shape + broadcast_shape | origin_shape = shape + broadcast_shape | ||||
| if origin_shape == (): | if origin_shape == (): | ||||
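_sample now validates the requested shape with checktuple before resolving the bounds; the returned tensor has shape shape + batch_shape, where batch_shape comes from broadcasting low against high. A minimal NumPy sketch of that shape arithmetic, assuming the usual low + (high - low) * U(0, 1) construction (sample_uniform is a hypothetical stand-in, not the patched method):

import numpy as np

def sample_uniform(shape, low, high):
    if not isinstance(shape, tuple):
        raise TypeError("'shape' must be a tuple.")
    low = np.asarray(low, dtype=np.float32)
    high = np.asarray(high, dtype=np.float32)
    batch_shape = np.broadcast(low, high).shape
    origin_shape = shape + batch_shape
    # An empty origin_shape corresponds to a single scalar draw.
    u = np.random.uniform(size=origin_shape if origin_shape != () else (1,))
    return low + (high - low) * u

print(sample_uniform((2, 3), low=[0.0, 0.0], high=[1.0, 2.0]).shape)  # (2, 3, 2)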
| @@ -75,7 +75,7 @@ def test_forward_jacobian(): | |||||
| forward_jacobian = Net2() | forward_jacobian = Net2() | ||||
| x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) | x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) | ||||
| ans = forward_jacobian(x) | ans = forward_jacobian(x) | ||||
| expected = np.log([2.0, 2.0, 2.0, 2.0]) | |||||
| expected = np.log([2.0]) | |||||
| tol = 1e-6 | tol = 1e-6 | ||||
| assert (np.abs(ans.asnumpy() - expected) < tol).all() | assert (np.abs(ans.asnumpy() - expected) < tol).all() | ||||
| @@ -94,6 +94,6 @@ def test_backward_jacobian(): | |||||
| backward_jacobian = Net3() | backward_jacobian = Net3() | ||||
| x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) | x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) | ||||
| ans = backward_jacobian(x) | ans = backward_jacobian(x) | ||||
| expected = np.log([0.5, 0.5, 0.5, 0.5]) | |||||
| expected = np.log([0.5]) | |||||
| tol = 1e-6 | tol = 1e-6 | ||||
| assert (np.abs(ans.asnumpy() - expected) < tol).all() | assert (np.abs(ans.asnumpy() - expected) < tol).all() | ||||
| @@ -20,7 +20,7 @@ import mindspore.nn.probability.bijector as msb | |||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore import dtype | from mindspore import dtype | ||||
| context.set_context(device_target="Ascend") | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| class Net(nn.Cell): | class Net(nn.Cell): | ||||
| """ | """ | ||||
| @@ -88,7 +88,7 @@ def test_kl_loss(): | |||||
| high_a = 1.5 | high_a = 1.5 | ||||
| low_b = -1.0 | low_b = -1.0 | ||||
| high_b = 2.0 | high_b = 2.0 | ||||
| expect_kl_loss = np.log(high_b - low_b) / np.log(high_a - low_a) | |||||
| expect_kl_loss = np.log(high_b - low_b) - np.log(high_a - low_a) | |||||
| kl = KL() | kl = KL() | ||||
| output = kl(Tensor(low_b, dtype=dtype.float32), Tensor(high_b, dtype=dtype.float32)) | output = kl(Tensor(low_b, dtype=dtype.float32), Tensor(high_b, dtype=dtype.float32)) | ||||
| tol = 1e-6 | tol = 1e-6 | ||||