Merge pull request !7416 from XunDeng/gumbeltags/v1.1.0
| @@ -17,7 +17,7 @@ from mindspore import context | |||
| from mindspore.nn.cell import Cell | |||
| from mindspore.ops import operations as P | |||
| from mindspore._checkparam import Validator as validator | |||
| from ..distribution._utils.utils import CheckTensor | |||
| from ..distribution._utils.utils import CheckTensor, cast_to_tensor | |||
| from ..distribution import Distribution | |||
| from ..distribution import TransformedDistribution | |||
| @@ -66,6 +66,8 @@ class Bijector(Cell): | |||
| # ops needed for the base class | |||
| self.cast_base = P.Cast() | |||
| self.dtype_base = P.DType() | |||
| self.shape_base = P.Shape() | |||
| self.fill_base = P.Fill() | |||
| @property | |||
| def name(self): | |||
| @@ -87,6 +89,36 @@ class Bijector(Cell): | |||
| def is_injective(self): | |||
| return self._is_injective | |||
| def _add_parameter(self, value, name): | |||
| """ | |||
| Cast `value` to a tensor and add it to `self.default_parameters`. | |||
| Add `name` into and `self.parameter_names`. | |||
| """ | |||
| # initialize the attributes if they do not exist yet | |||
| if not hasattr(self, 'default_parameters'): | |||
| self.default_parameters = [] | |||
| self.parameter_names = [] | |||
| # cast value to a tensor if it is not None | |||
| value_t = None if value is None else cast_to_tensor(value, self.parameter_type) | |||
| self.default_parameters += [value_t,] | |||
| self.parameter_names += [name,] | |||
| return value_t | |||
| def _calc_event_shape(self): | |||
| """ | |||
| Calculate event_shape based on parameters. | |||
| """ | |||
| broadcast_shape = None | |||
| for param in self.default_parameters: | |||
| if broadcast_shape is None: | |||
| broadcast_shape = self.shape_base(param) | |||
| broadcast_shape_tensor = self.fill_base(self.parameter_type, broadcast_shape, 0.0) | |||
| else: | |||
| broadcast_shape = self.shape_base(param + broadcast_shape_tensor) | |||
| broadcast_shape_tensor = self.fill_base(self.parameter_type, broadcast_shape, 0.0) | |||
| return broadcast_shape | |||
| def _check_value(self, value, name): | |||
| """ | |||
| Check availability of `value` as a Tensor. | |||
| @@ -14,7 +14,9 @@ | |||
| # ============================================================================ | |||
| """GumbelCDF Bijector""" | |||
| from mindspore.common import dtype as mstype | |||
| from ..distribution._utils.utils import cast_to_tensor, check_greater_zero, set_param_type | |||
| from mindspore._checkparam import Validator | |||
| from mindspore.ops import operations as P | |||
| from ..distribution._utils.utils import check_greater_zero, set_param_type | |||
| from ..distribution._utils.custom_ops import exp_generic, log_generic | |||
| from .bijector import Bijector | |||
| @@ -33,6 +35,7 @@ class GumbelCDF(Bijector): | |||
| Args: | |||
| loc (int, float, list, numpy.ndarray, Tensor): The location. Default: 0.. | |||
| scale (int, float, list, numpy.ndarray, Tensor): The scale. Default: 1.0. | |||
| dtype (mindspore.dtype): Type of the distribution which the bijector operates on. Default: float32. | |||
| name (str): The name of the Bijector. Default: 'Gumbel_CDF'. | |||
| Examples: | |||
| @@ -58,17 +61,24 @@ class GumbelCDF(Bijector): | |||
| def __init__(self, | |||
| loc=0.0, | |||
| scale=1.0, | |||
| dtype=mstype.float32, | |||
| name='GumbelCDF'): | |||
| """ | |||
| Constructor of GumbelCDF Bijector. | |||
| """ | |||
| param = dict(locals()) | |||
| parameter_type = set_param_type({'loc': loc, "scale": scale}, mstype.float32) | |||
| super(GumbelCDF, self).__init__(name=name, dtype=parameter_type, param=param) | |||
| self._loc = cast_to_tensor(loc, parameter_type) | |||
| self._scale = cast_to_tensor(scale, parameter_type) | |||
| valid_dtype = mstype.float_type + mstype.int_type + mstype.uint_type | |||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||
| parameter_type = set_param_type({'loc': loc, "scale": scale}, dtype) | |||
| super(GumbelCDF, self).__init__(name=name, dtype=dtype, param=param) | |||
| self._parameter_type = parameter_type | |||
| self._loc = self._add_parameter(loc, 'loc') | |||
| self._scale = self._add_parameter(scale, 'scale') | |||
| check_greater_zero(self._scale, "scale") | |||
| self._event_shape = self._calc_event_shape() | |||
| self.cast = P.Cast() | |||
| self.exp = exp_generic | |||
| self.log = log_generic | |||
| @@ -81,6 +91,14 @@ class GumbelCDF(Bijector): | |||
| def scale(self): | |||
| return self._scale | |||
| @property | |||
| def event_shape(self): | |||
| return self._event_shape | |||
| @property | |||
| def parameter_type(self): | |||
| return self._parameter_type | |||
| def extend_repr(self): | |||
| str_info = f'loc = {self.loc}, scale = {self.scale}' | |||
| return str_info | |||
| @@ -90,18 +108,22 @@ class GumbelCDF(Bijector): | |||
| def _forward(self, x): | |||
| x = self._check_value(x, 'value') | |||
| x = self.cast(x, self.parameter_type) | |||
| z = (x - self.loc) / self.scale | |||
| return self.exp(-self.exp(-z)) | |||
| def _inverse(self, y): | |||
| y = self._check_value(y, 'value') | |||
| y = self.cast(y, self.parameter_type) | |||
| return self.loc - self.scale * self.log(-self.log(y)) | |||
| def _forward_log_jacobian(self, x): | |||
| x = self._check_value(x, 'value') | |||
| x = self.cast(x, self.parameter_type) | |||
| z = (x - self.loc) / self.scale | |||
| return -z - self.exp(-z) - self.log(self.scale) | |||
| def _inverse_log_jacobian(self, y): | |||
| y = self._check_value(y, 'value') | |||
| return self.log(self.scale / (-y * self.log(y))) | |||
| y = self.cast(y, self.parameter_type) | |||
| return self.log(self.scale / (-1. * y * self.log(y))) | |||
| @@ -57,11 +57,19 @@ class Invert(Bijector): | |||
| name=name, | |||
| param=param) | |||
| self._bijector = bijector | |||
| if hasattr(self._bijector, 'event_shape'): | |||
| self._event_shape = self.bijector.event_shape | |||
| else: | |||
| self._event_shape = () | |||
| @property | |||
| def bijector(self): | |||
| return self._bijector | |||
| @property | |||
| def event_shape(self): | |||
| return self._event_shape | |||
| def inverse(self, y): | |||
| return self.bijector("forward", y) | |||
| @@ -26,6 +26,7 @@ from .geometric import Geometric | |||
| from .categorical import Categorical | |||
| from .log_normal import LogNormal | |||
| from .logistic import Logistic | |||
| from .gumbel import Gumbel | |||
| __all__ = ['Distribution', | |||
| 'TransformedDistribution', | |||
| @@ -37,4 +38,5 @@ __all__ = ['Distribution', | |||
| 'Geometric', | |||
| 'LogNormal', | |||
| 'Logistic', | |||
| 'Gumbel', | |||
| ] | |||
| @@ -132,6 +132,10 @@ class Distribution(Cell): | |||
| def broadcast_shape(self): | |||
| return self._broadcast_shape | |||
| def _reset_parameters(self): | |||
| self.default_parameters = [] | |||
| self.parameter_names = [] | |||
| def _add_parameter(self, value, name): | |||
| """ | |||
| Cast `value` to a tensor and add it to `self.default_parameters`. | |||
| @@ -0,0 +1,249 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Gumbel Distribution""" | |||
| import numpy as np | |||
| from mindspore.ops import operations as P | |||
| from mindspore._checkparam import Validator | |||
| from mindspore.common import dtype as mstype | |||
| import mindspore.nn as nn | |||
| import mindspore.nn.probability.bijector as msb | |||
| import mindspore.nn.probability.distribution as msd | |||
| from .transformed_distribution import TransformedDistribution | |||
| from ._utils.utils import check_distribution_name, raise_not_implemented_util | |||
| from ._utils.custom_ops import exp_generic, expm1_generic, log_generic | |||
| class Gumbel(TransformedDistribution): | |||
| """ | |||
| Gumbel distribution. | |||
| Args: | |||
| loc (int, float, list, numpy.ndarray, Tensor, Parameter): The location of Gumbel distribution. | |||
| scale (int, float, list, numpy.ndarray, Tensor, Parameter): The scale of Gumbel distribution. | |||
| seed (int): the seed used in sampling. The global seed is used if it is None. Default: None. | |||
| dtype (mindspore.dtype): type of the distribution. Default: mstype.float32. | |||
| name (str): the name of the distribution. Default: 'Gumbel'. | |||
| Note: | |||
| `scale` must be greater than zero. | |||
| `dist_spec_args` are `loc` and `scale`. | |||
| `dtype` must be a float type because Gumbel distributions are continuous. | |||
| Examples: | |||
| >>> # To initialize a Gumbel distribution of `loc` 3.0 and `scale` 4.0. | |||
| >>> gum = msd.Gumbel(3.0, 4.0, dtype=mstype.float32) | |||
| >>> | |||
| >>> # The following creates two independent Gumbel distributions. | |||
| >>> gum = msd.Gumbel([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32) | |||
| >>> | |||
| >>> # To use a Gumbel distribution in a network. | |||
| >>> class net(Cell): | |||
| >>> def __init__(self): | |||
| >>> super(net, self).__init__(): | |||
| >>> self.g1 = msd.Gumbel(0.0, 1.0, dtype=mstype.float32) | |||
| >>> | |||
| >>> # The following calls are valid in construct. | |||
| >>> def construct(self, value, loc_b, scale_b): | |||
| >>> | |||
| >>> # Private interfaces of probability functions corresponding to public interfaces, including | |||
| >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same | |||
| >>> # arguments as follows. | |||
| >>> # Args: | |||
| >>> # value (Tensor): the value to be evaluated. | |||
| >>> | |||
| >>> # Examples of `prob`. | |||
| >>> # Similar calls can be made to other probability functions | |||
| >>> # by replacing 'prob' by the name of the function. | |||
| >>> ans = self.g1.prob(value) | |||
| >>> | |||
| >>> # Functions `mean`, `mode`, sd`, `var`, and `entropy` do not take in any argument. | |||
| >>> ans = self.g1.mean() | |||
| >>> ans = self.g1.mode() | |||
| >>> ans = self.g1.sd() | |||
| >>> ans = self.g1.entropy() | |||
| >>> ans = self.g1.var() | |||
| >>> | |||
| >>> # Interfaces of 'kl_loss' and 'cross_entropy' are the same: | |||
| >>> # Args: | |||
| >>> # dist (str): the type of the distributions. Only "Gumbel" is supported. | |||
| >>> # loc_b (Tensor): the loc of distribution b. | |||
| >>> # scale_b (Tensor): the scale distribution b. | |||
| >>> | |||
| >>> # Examples of `kl_loss`. `cross_entropy` is similar. | |||
| >>> ans = self.g1.kl_loss('Gumbel', loc_b, scale_b) | |||
| >>> ans = self.g1.cross_entropy('Gumbel', loc_b, scale_b) | |||
| >>> | |||
| >>> # Examples of `sample`. | |||
| >>> # Args: | |||
| >>> # shape (tuple): the shape of the sample. Default: () | |||
| >>> | |||
| >>> ans = self.g1.sample() | |||
| >>> ans = self.g1.sample((2,3)) | |||
| """ | |||
| def __init__(self, | |||
| loc, | |||
| scale, | |||
| seed=0, | |||
| dtype=mstype.float32, | |||
| name="Gumbel"): | |||
| """ | |||
| Constructor of Gumbel distribution. | |||
| """ | |||
| valid_dtype = mstype.float_type | |||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||
| gumbel_cdf = msb.GumbelCDF(loc, scale, dtype) | |||
| super(Gumbel, self).__init__( | |||
| distribution=msd.Uniform(0.0, 1.0, dtype=dtype), | |||
| bijector=msb.Invert(gumbel_cdf), | |||
| seed=seed, name=name) | |||
| self._parameter_type = gumbel_cdf.parameter_type | |||
| self._broadcast_shape = gumbel_cdf.event_shape | |||
| if self._broadcast_shape != (): | |||
| self._is_scalar_batch = False | |||
| # overwrite default_parameters and parameter_names | |||
| self._reset_parameters() | |||
| self._loc = self._add_parameter(loc, 'loc') | |||
| self._scale = self._add_parameter(scale, 'scale') | |||
| self._gumbel_bijector = gumbel_cdf | |||
| # ops needed for the class | |||
| self.cast = P.Cast() | |||
| self.const = P.ScalarToArray() | |||
| self.exp = exp_generic | |||
| self.expm1 = expm1_generic | |||
| self.fill = P.Fill() | |||
| self.lgamma = nn.LGamma() | |||
| self.log = log_generic | |||
| self.shape = P.Shape() | |||
| self.sqrt = P.Sqrt() | |||
| @property | |||
| def loc(self): | |||
| return self._loc | |||
| @property | |||
| def scale(self): | |||
| return self._scale | |||
| def extend_repr(self): | |||
| if self.is_scalar_batch: | |||
| str_info = f'loc = {self._loc}, scale = {self._scale}' | |||
| else: | |||
| str_info = f'batch_shape = {self._broadcast_shape}' | |||
| return str_info | |||
| def _mean(self): | |||
| r""" | |||
| The mean of the distribution. | |||
| .. math:: | |||
| MEAN(X) = loc + scale * Euler-Mascheroni_constant | |||
| """ | |||
| return self.loc + self.scale * np.euler_gamma | |||
| def _mode(self): | |||
| """ | |||
| The mode of the distribution. | |||
| """ | |||
| return self.loc * self.fill(self.parameter_type, self.shape(self.scale), 1.0) | |||
| def _sd(self): | |||
| r""" | |||
| The standard deviation of the distribution. | |||
| .. math:: | |||
| STD(X) = \frac{\pi}{\sqrt(6)} * scale | |||
| """ | |||
| scale = self.scale * self.fill(self.parameter_type, self.broadcast_shape, 1.0) | |||
| return scale * np.pi / self.sqrt(self.const(6.)) | |||
| def _entropy(self): | |||
| r""" | |||
| Evaluate entropy. | |||
| .. math:: | |||
| H(X) = 1. + \log(scale) + Euler-Mascheroni_constant | |||
| """ | |||
| scale = self.scale * self.fill(self.parameter_type, self.broadcast_shape, 1.0) | |||
| return 1. + self.log(scale) + np.euler_gamma | |||
| def _log_prob(self, value): | |||
| r""" | |||
| .. math:: | |||
| log_pdf(X) = -(z + \exp(-z)) - \log(scale) | |||
| where z = \frac{x - loc}{scale} | |||
| """ | |||
| value = self._check_value(value, 'value') | |||
| z = (value - self.loc) / self.scale | |||
| return -(z + self.exp(-z)) - self.log(self.scale) | |||
| def _cdf(self, value): | |||
| r""" | |||
| .. math:: | |||
| cdf_pdf(X) = \exp(-\exp(-\frac{x - loc}{scale}) | |||
| """ | |||
| return self._gumbel_bijector("forward", value) | |||
| def _cross_entropy(self, dist, loc_b, scale_b): | |||
| r""" | |||
| Evaluate cross entropy between Gumbel distributions. | |||
| Args: | |||
| dist (str): The type of the distributions. Should be "Gumbel" in this case. | |||
| loc_b (Tensor): The loc of distribution b. | |||
| scale_b (Tensor): The scale of distribution b. | |||
| """ | |||
| if self.device_target == 'GPU': | |||
| raise_not_implemented_util('On GPU backend, cross_entropy', self.name) | |||
| check_distribution_name(dist, 'Gumbel') | |||
| return self._entropy() + self._kl_loss(dist, loc_b, scale_b) | |||
| def _kl_loss(self, dist, loc_b, scale_b): | |||
| r""" | |||
| Evaluate Gumbel-Gumbel kl divergence, i.e. KL(a||b). | |||
| Args: | |||
| dist (str): The type of the distributions. Should be "Gumbel" in this case. | |||
| loc_b (Tensor): The loc of distribution b. | |||
| scale_b (Tensor): The scale of distribution b. | |||
| .. math:: | |||
| KL(a||b) = \log(scale_b / scale_a) + Euler-Mascheroni_constant * (scale_a / scale_b - 1.) + | |||
| \exp(\frac{(loc_b - loc_a)}{scale_b}) * \Gamma(scale_a / scale_b + 1.) - 1. | |||
| """ | |||
| if self.device_target == 'GPU': | |||
| raise_not_implemented_util('On GPU backend, kl_loss', self.name) | |||
| check_distribution_name(dist, 'Gumbel') | |||
| loc_b = self._check_value(loc_b, 'loc_b') | |||
| scale_b = self._check_value(scale_b, 'scale_b') | |||
| loc_b = self.cast(loc_b, self.parameter_type) | |||
| scale_b = self.cast(scale_b, self.parameter_type) | |||
| return self.log(scale_b) - self.log(self.scale) +\ | |||
| np.euler_gamma * (self.scale / scale_b - 1.) +\ | |||
| self.expm1((loc_b - self.loc) / scale_b + self.lgamma(self.scale / scale_b + 1.)) | |||
| def _sample(self, shape=()): | |||
| origin_shape = shape + self._broadcast_shape | |||
| if origin_shape == (): | |||
| sample_shape = (1,) | |||
| else: | |||
| sample_shape = origin_shape | |||
| org_sample = self.distribution("sample", sample_shape) | |||
| value = self.bijector("forward", org_sample) | |||
| if origin_shape == (): | |||
| value = self.squeeze(value) | |||
| return value | |||
| @@ -82,11 +82,21 @@ class TransformedDistribution(Distribution): | |||
| self._is_linear_transformation = bijector.is_constant_jacobian | |||
| self.default_parameters = distribution.default_parameters | |||
| self.parameter_names = distribution.parameter_names | |||
| self.exp = exp_generic | |||
| self.log = log_generic | |||
| self.isnan = P.IsNan() | |||
| self.equal_base = P.Equal() | |||
| self.select_base = P.Select() | |||
| self.fill = P.Fill() | |||
| # check if batch shape of the distribution and event shape is broadcastable | |||
| if hasattr(self.bijector, 'event_shape'): | |||
| event_shape_tensor = self.fill(self.dtype, self.bijector.event_shape, 0.0) | |||
| broadcast_shape_tensor = self.fill(self.dtype, self.broadcast_shape, 0.0) | |||
| self._batch_event = (event_shape_tensor + broadcast_shape_tensor).shape | |||
| else: | |||
| self._batch_event = self.broadcast_shape | |||
| @property | |||
| def bijector(self): | |||
| @@ -0,0 +1,303 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test cases for Gumbel distribution""" | |||
| import numpy as np | |||
| from scipy import stats | |||
| from scipy import special | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| import mindspore.nn.probability.distribution as msd | |||
| from mindspore import Tensor | |||
| from mindspore import dtype | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| class Prob(nn.Cell): | |||
| """ | |||
| Test class: probability of Gumbel distribution. | |||
| """ | |||
| def __init__(self): | |||
| super(Prob, self).__init__() | |||
| self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32) | |||
| def construct(self, x_): | |||
| return self.gum.prob(x_) | |||
| def test_pdf(): | |||
| """ | |||
| Test pdf. | |||
| """ | |||
| loc = np.array([0.0]).astype(np.float32) | |||
| scale = np.array([[1.0], [2.0]]).astype(np.float32) | |||
| gumbel_benchmark = stats.gumbel_r(loc, scale) | |||
| value = np.array([1.0, 2.0]).astype(np.float32) | |||
| expect_pdf = gumbel_benchmark.pdf(value).astype(np.float32) | |||
| pdf = Prob() | |||
| output = pdf(Tensor(value, dtype=dtype.float32)) | |||
| tol = 1e-6 | |||
| assert (np.abs(output.asnumpy() - expect_pdf) < tol).all() | |||
| class LogProb(nn.Cell): | |||
| """ | |||
| Test class: log probability of Gumbel distribution. | |||
| """ | |||
| def __init__(self): | |||
| super(LogProb, self).__init__() | |||
| self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32) | |||
| def construct(self, x_): | |||
| return self.gum.log_prob(x_) | |||
| def test_log_likelihood(): | |||
| """ | |||
| Test log_pdf. | |||
| """ | |||
| loc = np.array([0.0]).astype(np.float32) | |||
| scale = np.array([[1.0], [2.0]]).astype(np.float32) | |||
| gumbel_benchmark = stats.gumbel_r(loc, scale) | |||
| expect_logpdf = gumbel_benchmark.logpdf([1.0, 2.0]).astype(np.float32) | |||
| logprob = LogProb() | |||
| output = logprob(Tensor([1.0, 2.0], dtype=dtype.float32)) | |||
| tol = 1e-6 | |||
| assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all() | |||
| class KL(nn.Cell): | |||
| """ | |||
| Test class: kl_loss of Gumbel distribution. | |||
| """ | |||
| def __init__(self): | |||
| super(KL, self).__init__() | |||
| self.gum = msd.Gumbel(np.array([0.0]), np.array([1.0, 2.0]), dtype=dtype.float32) | |||
| def construct(self, loc_b, scale_b): | |||
| return self.gum.kl_loss('Gumbel', loc_b, scale_b) | |||
| def test_kl_loss(): | |||
| """ | |||
| Test kl_loss. | |||
| """ | |||
| loc = np.array([0.0]).astype(np.float32) | |||
| scale = np.array([1.0, 2.0]).astype(np.float32) | |||
| loc_b = np.array([1.0]).astype(np.float32) | |||
| scale_b = np.array([1.0, 2.0]).astype(np.float32) | |||
| expect_kl_loss = np.log(scale_b) - np.log(scale) +\ | |||
| np.euler_gamma * (scale / scale_b - 1.) +\ | |||
| np.expm1((loc_b - loc) / scale_b + special.loggamma(scale / scale_b + 1.)) | |||
| kl_loss = KL() | |||
| loc_b = Tensor(loc_b, dtype=dtype.float32) | |||
| scale_b = Tensor(scale_b, dtype=dtype.float32) | |||
| output = kl_loss(loc_b, scale_b) | |||
| tol = 1e-5 | |||
| assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() | |||
| class Basics(nn.Cell): | |||
| """ | |||
| Test class: mean/sd/mode of Gumbel distribution. | |||
| """ | |||
| def __init__(self): | |||
| super(Basics, self).__init__() | |||
| self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32) | |||
| def construct(self): | |||
| return self.gum.mean(), self.gum.sd(), self.gum.mode() | |||
| def test_basics(): | |||
| """ | |||
| Test mean/standard deviation/mode. | |||
| """ | |||
| basics = Basics() | |||
| mean, sd, mode = basics() | |||
| loc = np.array([0.0]).astype(np.float32) | |||
| scale = np.array([[1.0], [2.0]]).astype(np.float32) | |||
| gumbel_benchmark = stats.gumbel_r(loc, scale) | |||
| expect_mean = gumbel_benchmark.mean().astype(np.float32) | |||
| expect_sd = gumbel_benchmark.std().astype(np.float32) | |||
| expect_mode = np.array([[0.0], [0.0]]).astype(np.float32) | |||
| tol = 1e-6 | |||
| assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() | |||
| assert (np.abs(mode.asnumpy() - expect_mode) < tol).all() | |||
| assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() | |||
| class Sampling(nn.Cell): | |||
| """ | |||
| Test class: sample of Gumbel distribution. | |||
| """ | |||
| def __init__(self, shape, seed=0): | |||
| super(Sampling, self).__init__() | |||
| self.gum = msd.Gumbel(np.array([0.0]), np.array([1.0, 2.0, 3.0]), dtype=dtype.float32, seed=seed) | |||
| self.shape = shape | |||
| def construct(self): | |||
| return self.gum.sample(self.shape) | |||
| def test_sample(): | |||
| """ | |||
| Test sample. | |||
| """ | |||
| shape = (2, 3) | |||
| seed = 10 | |||
| sample = Sampling(shape, seed=seed) | |||
| output = sample() | |||
| assert output.shape == (2, 3, 3) | |||
| class CDF(nn.Cell): | |||
| """ | |||
| Test class: cdf of Gumbel distribution. | |||
| """ | |||
| def __init__(self): | |||
| super(CDF, self).__init__() | |||
| self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32) | |||
| def construct(self, x_): | |||
| return self.gum.cdf(x_) | |||
| def test_cdf(): | |||
| """ | |||
| Test cdf. | |||
| """ | |||
| loc = np.array([0.0]).astype(np.float32) | |||
| scale = np.array([[1.0], [2.0]]).astype(np.float32) | |||
| gumbel_benchmark = stats.gumbel_r(loc, scale) | |||
| expect_cdf = gumbel_benchmark.cdf([1.0, 2.0]).astype(np.float32) | |||
| cdf = CDF() | |||
| output = cdf(Tensor([1.0, 2.0], dtype=dtype.float32)) | |||
| tol = 2e-5 | |||
| assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() | |||
| class LogCDF(nn.Cell): | |||
| """ | |||
| Test class: log_cdf of Gumbel distribution. | |||
| """ | |||
| def __init__(self): | |||
| super(LogCDF, self).__init__() | |||
| self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32) | |||
| def construct(self, x_): | |||
| return self.gum.log_cdf(x_) | |||
| def test_log_cdf(): | |||
| """ | |||
| Test log cdf. | |||
| """ | |||
| loc = np.array([0.0]).astype(np.float32) | |||
| scale = np.array([[1.0], [2.0]]).astype(np.float32) | |||
| gumbel_benchmark = stats.gumbel_r(loc, scale) | |||
| expect_logcdf = gumbel_benchmark.logcdf([1.0, 2.0]).astype(np.float32) | |||
| logcdf = LogCDF() | |||
| output = logcdf(Tensor([1.0, 2.0], dtype=dtype.float32)) | |||
| tol = 1e-4 | |||
| assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() | |||
| class SF(nn.Cell): | |||
| """ | |||
| Test class: survival function of Gumbel distribution. | |||
| """ | |||
| def __init__(self): | |||
| super(SF, self).__init__() | |||
| self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32) | |||
| def construct(self, x_): | |||
| return self.gum.survival_function(x_) | |||
| def test_survival(): | |||
| """ | |||
| Test log_survival. | |||
| """ | |||
| loc = np.array([0.0]).astype(np.float32) | |||
| scale = np.array([[1.0], [2.0]]).astype(np.float32) | |||
| gumbel_benchmark = stats.gumbel_r(loc, scale) | |||
| expect_survival = gumbel_benchmark.sf([1.0, 2.0]).astype(np.float32) | |||
| survival_function = SF() | |||
| output = survival_function(Tensor([1.0, 2.0], dtype=dtype.float32)) | |||
| tol = 2e-5 | |||
| assert (np.abs(output.asnumpy() - expect_survival) < tol).all() | |||
| class LogSF(nn.Cell): | |||
| """ | |||
| Test class: log survival function of Gumbel distribution. | |||
| """ | |||
| def __init__(self): | |||
| super(LogSF, self).__init__() | |||
| self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32) | |||
| def construct(self, x_): | |||
| return self.gum.log_survival(x_) | |||
| def test_log_survival(): | |||
| """ | |||
| Test log_survival. | |||
| """ | |||
| loc = np.array([0.0]).astype(np.float32) | |||
| scale = np.array([[1.0], [2.0]]).astype(np.float32) | |||
| gumbel_benchmark = stats.gumbel_r(loc, scale) | |||
| expect_log_survival = gumbel_benchmark.logsf([1.0, 2.0]).astype(np.float32) | |||
| log_survival = LogSF() | |||
| output = log_survival(Tensor([1.0, 2.0], dtype=dtype.float32)) | |||
| tol = 5e-4 | |||
| assert (np.abs(output.asnumpy() - expect_log_survival) < tol).all() | |||
| class EntropyH(nn.Cell): | |||
| """ | |||
| Test class: entropy of Gumbel distribution. | |||
| """ | |||
| def __init__(self): | |||
| super(EntropyH, self).__init__() | |||
| self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32) | |||
| def construct(self): | |||
| return self.gum.entropy() | |||
| def test_entropy(): | |||
| """ | |||
| Test entropy. | |||
| """ | |||
| loc = np.array([0.0]).astype(np.float32) | |||
| scale = np.array([[1.0], [2.0]]).astype(np.float32) | |||
| gumbel_benchmark = stats.gumbel_r(loc, scale) | |||
| expect_entropy = gumbel_benchmark.entropy().astype(np.float32) | |||
| entropy = EntropyH() | |||
| output = entropy() | |||
| tol = 1e-6 | |||
| assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() | |||
| class CrossEntropy(nn.Cell): | |||
| """ | |||
| Test class: cross entropy between Gumbel distributions. | |||
| """ | |||
| def __init__(self): | |||
| super(CrossEntropy, self).__init__() | |||
| self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32) | |||
| def construct(self, x_, y_): | |||
| entropy = self.gum.entropy() | |||
| kl_loss = self.gum.kl_loss('Gumbel', x_, y_) | |||
| h_sum_kl = entropy + kl_loss | |||
| cross_entropy = self.gum.cross_entropy('Gumbel', x_, y_) | |||
| return h_sum_kl - cross_entropy | |||
| def test_cross_entropy(): | |||
| """ | |||
| Test cross_entropy. | |||
| """ | |||
| cross_entropy = CrossEntropy() | |||
| loc = Tensor([1.0], dtype=dtype.float32) | |||
| scale = Tensor([1.0], dtype=dtype.float32) | |||
| diff = cross_entropy(loc, scale) | |||
| tol = 1e-6 | |||
| assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all() | |||
| @@ -0,0 +1,153 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Test nn.probability.distribution.gumbel. | |||
| """ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| import mindspore.nn.probability.distribution as msd | |||
| from mindspore import dtype | |||
| from mindspore import Tensor | |||
| def test_gumbel_shape_errpr(): | |||
| """ | |||
| Invalid shapes. | |||
| """ | |||
| with pytest.raises(ValueError): | |||
| msd.Gumbel([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32) | |||
| def test_type(): | |||
| with pytest.raises(TypeError): | |||
| msd.Gumbel(0., 1., dtype=dtype.int32) | |||
| def test_name(): | |||
| with pytest.raises(TypeError): | |||
| msd.Gumbel(0., 1., name=1.0) | |||
| def test_seed(): | |||
| with pytest.raises(TypeError): | |||
| msd.Gumbel(0., 1., seed='seed') | |||
| def test_scale(): | |||
| with pytest.raises(ValueError): | |||
| msd.Gumbel(0., 0.) | |||
| with pytest.raises(ValueError): | |||
| msd.Gumbel(0., -1.) | |||
| def test_arguments(): | |||
| """ | |||
| args passing during initialization. | |||
| """ | |||
| l = msd.Gumbel([3.0], [4.0], dtype=dtype.float32) | |||
| assert isinstance(l, msd.Distribution) | |||
| class GumbelProb(nn.Cell): | |||
| """ | |||
| Gumbel distribution: initialize with loc/scale. | |||
| """ | |||
| def __init__(self): | |||
| super(GumbelProb, self).__init__() | |||
| self.gumbel = msd.Gumbel(3.0, 4.0, dtype=dtype.float32) | |||
| def construct(self, value): | |||
| prob = self.gumbel.prob(value) | |||
| log_prob = self.gumbel.log_prob(value) | |||
| cdf = self.gumbel.cdf(value) | |||
| log_cdf = self.gumbel.log_cdf(value) | |||
| sf = self.gumbel.survival_function(value) | |||
| log_sf = self.gumbel.log_survival(value) | |||
| return prob + log_prob + cdf + log_cdf + sf + log_sf | |||
| def test_gumbel_prob(): | |||
| """ | |||
| Test probability functions: passing value through construct. | |||
| """ | |||
| net = GumbelProb() | |||
| value = Tensor([0.5, 1.0], dtype=dtype.float32) | |||
| ans = net(value) | |||
| assert isinstance(ans, Tensor) | |||
| class KL(nn.Cell): | |||
| """ | |||
| Test kl_loss. | |||
| """ | |||
| def __init__(self): | |||
| super(KL, self).__init__() | |||
| self.gumbel = msd.Gumbel(3.0, 4.0) | |||
| def construct(self, mu, s): | |||
| kl = self.gumbel.kl_loss('Gumbel', mu, s) | |||
| cross_entropy = self.gumbel.cross_entropy('Gumbel', mu, s) | |||
| return kl + cross_entropy | |||
| def test_kl_cross_entropy(): | |||
| """ | |||
| Test kl_loss and cross_entropy. | |||
| """ | |||
| net = KL() | |||
| loc_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) | |||
| scale_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) | |||
| ans = net(loc_b, scale_b) | |||
| assert isinstance(ans, Tensor) | |||
| class GumbelBasics(nn.Cell): | |||
| """ | |||
| Test class: basic loc/scale function. | |||
| """ | |||
| def __init__(self): | |||
| super(GumbelBasics, self).__init__() | |||
| self.gumbel = msd.Gumbel(3.0, 4.0, dtype=dtype.float32) | |||
| def construct(self): | |||
| mean = self.gumbel.mean() | |||
| sd = self.gumbel.sd() | |||
| mode = self.gumbel.mode() | |||
| entropy = self.gumbel.entropy() | |||
| return mean + sd + mode + entropy | |||
| def test_bascis(): | |||
| """ | |||
| Test mean/sd/mode/entropy functionality of Gumbel. | |||
| """ | |||
| net = GumbelBasics() | |||
| ans = net() | |||
| assert isinstance(ans, Tensor) | |||
| class GumbelConstruct(nn.Cell): | |||
| """ | |||
| Gumbel distribution: going through construct. | |||
| """ | |||
| def __init__(self): | |||
| super(GumbelConstruct, self).__init__() | |||
| self.gumbel = msd.Gumbel(3.0, 4.0) | |||
| def construct(self, value): | |||
| prob = self.gumbel('prob', value) | |||
| prob1 = self.gumbel.prob(value) | |||
| return prob + prob1 | |||
| def test_gumbel_construct(): | |||
| """ | |||
| Test probability function going through construct. | |||
| """ | |||
| net = GumbelConstruct() | |||
| value = Tensor([0.5, 1.0], dtype=dtype.float32) | |||
| ans = net(value) | |||
| assert isinstance(ans, Tensor) | |||