@@ -19,6 +19,7 @@ Distributions are the high-level components used to construct the probabilistic
 from .distribution import Distribution
 from .transformed_distribution import TransformedDistribution
 from .bernoulli import Bernoulli
+from .beta import Beta
 from .categorical import Categorical
 from .cauchy import Cauchy
 from .exponential import Exponential
@@ -34,6 +35,7 @@ from .uniform import Uniform
 __all__ = ['Distribution',
            'TransformedDistribution',
            'Bernoulli',
+           'Beta',
            'Categorical',
            'Cauchy',
            'Exponential',
@@ -0,0 +1,333 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Beta Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
import mindspore.nn as nn
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_greater_zero, check_distribution_name
from ._utils.custom_ops import log_generic

class Beta(Distribution):
    """
    Beta distribution.

    Args:
        concentration1 (int, float, list, numpy.ndarray, Tensor, Parameter): The concentration1,
            also known as alpha of the Beta distribution.
        concentration0 (int, float, list, numpy.ndarray, Tensor, Parameter): The concentration0,
            also known as beta of the Beta distribution.
        seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
        dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
        name (str): The name of the distribution. Default: 'Beta'.

    Note:
        `concentration1` and `concentration0` must be greater than zero.
        `dist_spec_args` are `concentration1` and `concentration0`.
        `dtype` must be a float type because Beta distributions are continuous.

    Examples:
        >>> # To initialize a Beta distribution with concentration1 3.0 and concentration0 4.0.
        >>> import mindspore.nn.probability.distribution as msd
        >>> b = msd.Beta(3.0, 4.0, dtype=mstype.float32)
        >>>
        >>> # The following creates two independent Beta distributions.
        >>> b = msd.Beta([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
        >>>
        >>> # A Beta distribution can be initialized without arguments.
        >>> # In this case, `concentration1` and `concentration0` must be passed in through function calls.
        >>> b = msd.Beta(dtype=mstype.float32)
        >>>
        >>> # To use a Beta distribution in a network.
        >>> class net(Cell):
        ...     def __init__(self):
        ...         super(net, self).__init__()
        ...         self.b1 = msd.Beta(1.0, 1.0, dtype=mstype.float32)
        ...         self.b2 = msd.Beta(dtype=mstype.float32)
        ...
        ...     # The following calls are valid in construct.
        ...     def construct(self, value, concentration1_b, concentration0_b, concentration1_a, concentration0_a):
        ...
        ...         # Private interfaces of probability functions corresponding to public interfaces, including
        ...         # `prob` and `log_prob`, have the same arguments as follows.
        ...         # Args:
        ...         #     value (Tensor): the value to be evaluated.
        ...         #     concentration1 (Tensor): the concentration1 of the distribution. Default: self._concentration1.
        ...         #     concentration0 (Tensor): the concentration0 of the distribution. Default: self._concentration0.
        ...
        ...         # Examples of `prob`.
        ...         # Similar calls can be made to other probability functions
        ...         # by replacing `prob` by the name of the function.
        ...         ans = self.b1.prob(value)
        ...         # Evaluate with respect to distribution b.
        ...         ans = self.b1.prob(value, concentration1_b, concentration0_b)
        ...         # `concentration1` and `concentration0` must be passed in during function calls.
        ...         ans = self.b2.prob(value, concentration1_a, concentration0_a)
        ...
        ...         # Functions `mean`, `sd`, `mode`, `var`, and `entropy` have the same arguments.
        ...         # Args:
        ...         #     concentration1 (Tensor): the concentration1 of the distribution. Default: self._concentration1.
        ...         #     concentration0 (Tensor): the concentration0 of the distribution. Default: self._concentration0.
        ...
        ...         # Examples of `mean`; `sd`, `mode`, `var`, and `entropy` are similar.
        ...         ans = self.b1.mean()    # return 1.0 / (1.0 + 1.0) = 0.5
        ...         ans = self.b1.mean(concentration1_b, concentration0_b)
        ...         # `concentration1` and `concentration0` must be passed in during function calls.
        ...         ans = self.b2.mean(concentration1_a, concentration0_a)
        ...
        ...         # Interfaces of `kl_loss` and `cross_entropy` are the same:
        ...         # Args:
        ...         #     dist (str): the type of the distributions. Only "Beta" is supported.
        ...         #     concentration1_b (Tensor): the concentration1 of distribution b.
        ...         #     concentration0_b (Tensor): the concentration0 of distribution b.
        ...         #     concentration1_a (Tensor): the concentration1 of distribution a.
        ...         #         Default: self._concentration1.
        ...         #     concentration0_a (Tensor): the concentration0 of distribution a.
        ...         #         Default: self._concentration0.
        ...
        ...         # Examples of `kl_loss`. `cross_entropy` is similar.
        ...         ans = self.b1.kl_loss('Beta', concentration1_b, concentration0_b)
        ...         ans = self.b1.kl_loss('Beta', concentration1_b, concentration0_b,
        ...                               concentration1_a, concentration0_a)
        ...         # Additional `concentration1` and `concentration0` must be passed in.
        ...         ans = self.b2.kl_loss('Beta', concentration1_b, concentration0_b,
        ...                               concentration1_a, concentration0_a)
        ...
        ...         # Examples of `sample`.
        ...         # Args:
        ...         #     shape (tuple): the shape of the sample. Default: ().
        ...         #     concentration1 (Tensor): the concentration1 of the distribution. Default: self._concentration1.
        ...         #     concentration0 (Tensor): the concentration0 of the distribution. Default: self._concentration0.
        ...         ans = self.b1.sample()
        ...         ans = self.b1.sample((2, 3))
        ...         ans = self.b1.sample((2, 3), concentration1_b, concentration0_b)
        ...         ans = self.b2.sample((2, 3), concentration1_a, concentration0_a)
    """

    def __init__(self,
                 concentration1=None,
                 concentration0=None,
                 seed=None,
                 dtype=mstype.float32,
                 name="Beta"):
        """
        Constructor of Beta.
        """
        param = dict(locals())
        param['param_dict'] = {'concentration1': concentration1, 'concentration0': concentration0}
        valid_dtype = mstype.float_type
        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
        super(Beta, self).__init__(seed, dtype, name, param)

        self._concentration1 = self._add_parameter(concentration1, 'concentration1')
        self._concentration0 = self._add_parameter(concentration0, 'concentration0')
        if self._concentration1 is not None:
            check_greater_zero(self._concentration1, "concentration1")
        if self._concentration0 is not None:
            check_greater_zero(self._concentration0, "concentration0")

        # ops needed for the class
        self.log = log_generic
        self.log1p = P.Log1p()
        self.neg = P.Neg()
        self.pow = P.Pow()
        self.squeeze = P.Squeeze(0)
        self.cast = P.Cast()
        self.fill = P.Fill()
        self.shape = P.Shape()
        self.select = P.Select()
        self.logicaland = P.LogicalAnd()
        self.greater = P.Greater()
        self.digamma = nn.DiGamma()
        self.lbeta = nn.LBeta()

    def extend_repr(self):
        if self.is_scalar_batch:
            s = f'concentration1 = {self._concentration1}, concentration0 = {self._concentration0}'
        else:
            s = f'batch_shape = {self._broadcast_shape}'
        return s

    @property
    def concentration1(self):
        """
        Return the concentration1, also known as the alpha of the Beta distribution.
        """
        return self._concentration1

    @property
    def concentration0(self):
        """
        Return the concentration0, also known as the beta of the Beta distribution.
        """
        return self._concentration0

    def _get_dist_type(self):
        return "Beta"

    def _get_dist_args(self, concentration1=None, concentration0=None):
        if concentration1 is not None:
            self.checktensor(concentration1, 'concentration1')
        else:
            concentration1 = self._concentration1
        if concentration0 is not None:
            self.checktensor(concentration0, 'concentration0')
        else:
            concentration0 = self._concentration0
        return concentration1, concentration0

    def _mean(self, concentration1=None, concentration0=None):
        """
        The mean of the distribution.
        """
        concentration1, concentration0 = self._check_param_type(concentration1, concentration0)
        return concentration1 / (concentration1 + concentration0)

    def _var(self, concentration1=None, concentration0=None):
        """
        The variance of the distribution.
        """
        concentration1, concentration0 = self._check_param_type(concentration1, concentration0)
        total_concentration = concentration1 + concentration0
        return concentration1 * concentration0 / (self.pow(total_concentration, 2) * (total_concentration + 1.))

    def _mode(self, concentration1=None, concentration0=None):
        """
        The mode of the distribution.
        """
        concentration1, concentration0 = self._check_param_type(concentration1, concentration0)
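        # The mode (concentration1 - 1) / (concentration1 + concentration0 - 2) is only
        # defined when both concentrations are greater than 1; NaN is returned otherwise.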
        comp1 = self.greater(concentration1, 1.)
        comp2 = self.greater(concentration0, 1.)
        cond = self.logicaland(comp1, comp2)
        nan = self.fill(self.dtype, self.broadcast_shape, np.nan)
        mode = (concentration1 - 1.) / (concentration1 + concentration0 - 2.)
        return self.select(cond, mode, nan)

    def _entropy(self, concentration1=None, concentration0=None):
        r"""
        Evaluate entropy.

        .. math::
            H(X) = \log(\Beta(\alpha, \beta)) - (\alpha - 1) * \digamma(\alpha)
                   - (\beta - 1) * \digamma(\beta) + (\alpha + \beta - 2) * \digamma(\alpha + \beta)
        """
        concentration1, concentration0 = self._check_param_type(concentration1, concentration0)
        total_concentration = concentration1 + concentration0
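        # This closed form is the differential entropy of the Beta distribution; the
        # accompanying unit tests check it against scipy.stats.beta as a benchmark.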
        return self.lbeta(concentration1, concentration0) \
               - (concentration1 - 1.) * self.digamma(concentration1) \
               - (concentration0 - 1.) * self.digamma(concentration0) \
               + (total_concentration - 2.) * self.digamma(total_concentration)

    def _cross_entropy(self, dist, concentration1_b, concentration0_b, concentration1=None, concentration0=None):
        r"""
        Evaluate cross entropy between Beta distributions.

        Args:
            dist (str): Type of the distributions. Should be "Beta" in this case.
            concentration1_b (Tensor): concentration1 of distribution b.
            concentration0_b (Tensor): concentration0 of distribution b.
            concentration1_a (Tensor): concentration1 of distribution a. Default: self._concentration1.
            concentration0_a (Tensor): concentration0 of distribution a. Default: self._concentration0.
        """
        check_distribution_name(dist, 'Beta')
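        # Uses the identity H(a, b) = H(a) + KL(a || b), which holds for any two
        # distributions defined over the same support.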
        return self._entropy(concentration1, concentration0) \
               + self._kl_loss(dist, concentration1_b, concentration0_b, concentration1, concentration0)

    def _log_prob(self, value, concentration1=None, concentration0=None):
        r"""
        Evaluate log probability.

        Args:
            value (Tensor): The value to be evaluated.
            concentration1 (Tensor): The concentration1 of the distribution. Default: self._concentration1.
            concentration0 (Tensor): The concentration0 of the distribution. Default: self._concentration0.

        .. math::
            L(x) = (\alpha - 1) * \log(x) + (\beta - 1) * \log(1 - x) - \log(\Beta(\alpha, \beta))
        """
        value = self._check_value(value, 'value')
        value = self.cast(value, self.dtype)
        concentration1, concentration0 = self._check_param_type(concentration1, concentration0)
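        # log1p(-x) evaluates log(1 - x) with better numerical accuracy than
        # self.log(1. - value) when value is close to zero.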
        log_unnormalized_prob = (concentration1 - 1.) * self.log(value) \
                                + (concentration0 - 1.) * self.log1p(self.neg(value))
        return log_unnormalized_prob - self.lbeta(concentration1, concentration0)

    def _kl_loss(self, dist, concentration1_b, concentration0_b, concentration1=None, concentration0=None):
        r"""
        Evaluate Beta-Beta KL divergence, i.e. KL(a||b).

        Args:
            dist (str): The type of the distributions. Should be "Beta" in this case.
            concentration1_b (Tensor): The concentration1 of distribution b.
            concentration0_b (Tensor): The concentration0 of distribution b.
            concentration1_a (Tensor): The concentration1 of distribution a. Default: self._concentration1.
            concentration0_a (Tensor): The concentration0 of distribution a. Default: self._concentration0.

        .. math::
            KL(a||b) = \log(\Beta(\alpha_{b}, \beta_{b})) - \log(\Beta(\alpha_{a}, \beta_{a}))
                       - \digamma(\alpha_{a}) * (\alpha_{b} - \alpha_{a})
                       - \digamma(\beta_{a}) * (\beta_{b} - \beta_{a})
                       + \digamma(\alpha_{a} + \beta_{a}) * (\alpha_{b} + \beta_{b} - \alpha_{a} - \beta_{a})
        """
        check_distribution_name(dist, 'Beta')
        concentration1_b = self._check_value(concentration1_b, 'concentration1_b')
        concentration0_b = self._check_value(concentration0_b, 'concentration0_b')
        concentration1_b = self.cast(concentration1_b, self.parameter_type)
        concentration0_b = self.cast(concentration0_b, self.parameter_type)
        concentration1_a, concentration0_a = self._check_param_type(concentration1, concentration0)
        total_concentration_a = concentration1_a + concentration0_a
        total_concentration_b = concentration1_b + concentration0_b
        log_normalization_a = self.lbeta(concentration1_a, concentration0_a)
        log_normalization_b = self.lbeta(concentration1_b, concentration0_b)
        return (log_normalization_b - log_normalization_a) \
               - (self.digamma(concentration1_a) * (concentration1_b - concentration1_a)) \
               - (self.digamma(concentration0_a) * (concentration0_b - concentration0_a)) \
               + (self.digamma(total_concentration_a) * (total_concentration_b - total_concentration_a))

    def _sample(self, shape=(), concentration1=None, concentration0=None):
        """
        Sampling.

        Args:
            shape (tuple): The shape of the sample. Default: ().
            concentration1 (Tensor): The concentration1 of the samples. Default: self._concentration1.
            concentration0 (Tensor): The concentration0 of the samples. Default: self._concentration0.

        Returns:
            Tensor, with the shape being shape + batch_shape.
        """
        shape = self.checktuple(shape, 'shape')
        concentration1, concentration0 = self._check_param_type(concentration1, concentration0)
        batch_shape = self.shape(concentration1 + concentration0)
        origin_shape = shape + batch_shape
        if origin_shape == ():
            sample_shape = (1,)
        else:
            sample_shape = origin_shape
        ones = self.fill(self.dtype, sample_shape, 1.0)
        sample_gamma1 = C.gamma(sample_shape, alpha=concentration1, beta=ones, seed=self.seed)
        sample_gamma2 = C.gamma(sample_shape, alpha=concentration0, beta=ones, seed=self.seed)
        sample_beta = sample_gamma1 / (sample_gamma1 + sample_gamma2)
        value = self.cast(sample_beta, self.dtype)
        if origin_shape == ():
            value = self.squeeze(value)
        return value
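
`_sample` draws a Beta variate as the normalized ratio of two Gamma variates: if G1 ~ Gamma(alpha, 1) and G2 ~ Gamma(beta, 1) are independent, then G1 / (G1 + G2) ~ Beta(alpha, beta). A minimal NumPy sketch of the same construction, for illustration only (the names below are ours, not part of the patch):

import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
alpha, beta = 3.0, 4.0

# Two independent Gamma(shape, scale=1) draws; their normalized ratio is Beta(alpha, beta).
g1 = rng.gamma(shape=alpha, scale=1.0, size=100_000)
g2 = rng.gamma(shape=beta, scale=1.0, size=100_000)
samples = g1 / (g1 + g2)

# The empirical moments should be close to the Beta(3, 4) moments:
# mean = alpha / (alpha + beta), var = alpha * beta / ((alpha + beta)**2 * (alpha + beta + 1)).
print(samples.mean(), stats.beta(alpha, beta).mean())
print(samples.std(), stats.beta(alpha, beta).std())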

@@ -81,12 +81,12 @@ class Gamma(Distribution):
 ...     ans = self.g2.prob(value, concentration_a, rate_a)
 ...
 ...
-...     # Functions `concentration`, `rate`, `mean`, `sd`, `var`, and `entropy` have the same arguments.
+...     # Functions `mean`, `sd`, `mode`, `var`, and `entropy` have the same arguments.
 ...     # Args:
 ...     #     concentration (Tensor): the concentration of the distribution. Default: self._concentration.
 ...     #     rate (Tensor): the rate of the distribution. Default: self._rate.
 ...
-...     # Example of `concentration`, `rate`, `mean`. `sd`, `var`, and `entropy` are similar.
+...     # Examples of `mean`; `sd`, `mode`, `var`, and `entropy` are similar.
 ...     ans = self.g1.concentration() # return 1.0
 ...     ans = self.g1.concentration(concentration_b, rate_b) # return concentration_b
 ...     # `concentration` and `rate` must be passed in during function calls.

@@ -76,11 +76,11 @@ class Poisson(Distribution):
 ...     ans = self.p2.prob(value, rate_a)
 ...
 ...
-...     # Functions `mean`, `sd`, and 'var' have the same arguments as follows.
+...     # Functions `mean`, `mode`, `sd`, and `var` have the same arguments as follows.
 ...     # Args:
 ...     #     rate (Tensor): the rate of the distribution. Default: self.rate.
 ...
-...     # Examples of `mean`. `sd`, `var`, and `entropy` are similar.
+...     # Examples of `mean`; `sd`, `mode`, `var`, and `entropy` are similar.
 ...     ans = self.p1.mean() # return 2
 ...     ans = self.p1.mean(rate_b) # return rate_b
 ...     # `rate` must be passed in during function calls.

@@ -0,0 +1,245 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for Beta distribution"""
import numpy as np
from scipy import stats
from scipy import special
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Prob(nn.Cell):
    """
    Test class: probability of Beta distribution.
    """
    def __init__(self):
        super(Prob, self).__init__()
        self.b = msd.Beta(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_):
        return self.b.prob(x_)


def test_pdf():
    """
    Test pdf.
    """
    beta_benchmark = stats.beta(np.array([3.0]), np.array([1.0]))
    expect_pdf = beta_benchmark.pdf([0.25, 0.75]).astype(np.float32)
    pdf = Prob()
    output = pdf(Tensor([0.25, 0.75], dtype=dtype.float32))
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_pdf) < tol).all()


class LogProb(nn.Cell):
    """
    Test class: log probability of Beta distribution.
    """
    def __init__(self):
        super(LogProb, self).__init__()
        self.b = msd.Beta(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_):
        return self.b.log_prob(x_)


def test_log_likelihood():
    """
    Test log_pdf.
    """
    beta_benchmark = stats.beta(np.array([3.0]), np.array([1.0]))
    expect_logpdf = beta_benchmark.logpdf([0.25, 0.75]).astype(np.float32)
    logprob = LogProb()
    output = logprob(Tensor([0.25, 0.75], dtype=dtype.float32))
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all()


class KL(nn.Cell):
    """
    Test class: kl_loss of Beta distribution.
    """
    def __init__(self):
        super(KL, self).__init__()
        self.b = msd.Beta(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)

    def construct(self, x_, y_):
        return self.b.kl_loss('Beta', x_, y_)


def test_kl_loss():
    """
    Test kl_loss.
    """
    concentration1_a = np.array([3.0]).astype(np.float32)
    concentration0_a = np.array([4.0]).astype(np.float32)
    concentration1_b = np.array([1.0]).astype(np.float32)
    concentration0_b = np.array([1.0]).astype(np.float32)
    total_concentration_a = concentration1_a + concentration0_a
    total_concentration_b = concentration1_b + concentration0_b
    log_normalization_a = np.log(special.beta(concentration1_a, concentration0_a))
    log_normalization_b = np.log(special.beta(concentration1_b, concentration0_b))
    expect_kl_loss = (log_normalization_b - log_normalization_a) \
                     - (special.digamma(concentration1_a) * (concentration1_b - concentration1_a)) \
                     - (special.digamma(concentration0_a) * (concentration0_b - concentration0_a)) \
                     + (special.digamma(total_concentration_a) * (total_concentration_b - total_concentration_a))
    kl_loss = KL()
    concentration1 = Tensor(concentration1_b, dtype=dtype.float32)
    concentration0 = Tensor(concentration0_b, dtype=dtype.float32)
    output = kl_loss(concentration1, concentration0)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()


class Basics(nn.Cell):
    """
    Test class: mean/sd/mode of Beta distribution.
    """
    def __init__(self):
        super(Basics, self).__init__()
        self.b = msd.Beta(np.array([3.0]), np.array([3.0]), dtype=dtype.float32)

    def construct(self):
        return self.b.mean(), self.b.sd(), self.b.mode()


def test_basics():
    """
    Test mean/standard deviation/mode.
    """
    basics = Basics()
    mean, sd, mode = basics()
    beta_benchmark = stats.beta(np.array([3.0]), np.array([3.0]))
    expect_mean = beta_benchmark.mean().astype(np.float32)
    expect_sd = beta_benchmark.std().astype(np.float32)
    expect_mode = [0.5]
    tol = 1e-6
    assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
    assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()
    assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()


class Sampling(nn.Cell):
    """
    Test class: sample of Beta distribution.
    """
    def __init__(self, shape, seed=0):
        super(Sampling, self).__init__()
        self.b = msd.Beta(np.array([3.0]), np.array([1.0]), seed=seed, dtype=dtype.float32)
        self.shape = shape

    def construct(self, concentration1=None, concentration0=None):
        return self.b.sample(self.shape, concentration1, concentration0)


def test_sample():
    """
    Test sample.
    """
    shape = (2, 3)
    seed = 10
    concentration1 = Tensor([2.0], dtype=dtype.float32)
    concentration0 = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32)
    sample = Sampling(shape, seed=seed)
    output = sample(concentration1, concentration0)
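    # The sample shape (2, 3) is prepended to the batch shape (3,), which comes from
    # broadcasting concentration1 (shape (1,)) against concentration0 (shape (3,)).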
    assert output.shape == (2, 3, 3)


class EntropyH(nn.Cell):
    """
    Test class: entropy of Beta distribution.
    """
    def __init__(self):
        super(EntropyH, self).__init__()
        self.b = msd.Beta(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self):
        return self.b.entropy()


def test_entropy():
    """
    Test entropy.
    """
    beta_benchmark = stats.beta(np.array([3.0]), np.array([1.0]))
    expect_entropy = beta_benchmark.entropy().astype(np.float32)
    entropy = EntropyH()
    output = entropy()
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_entropy) < tol).all()


class CrossEntropy(nn.Cell):
    """
    Test class: cross entropy between Beta distributions.
    """
    def __init__(self):
        super(CrossEntropy, self).__init__()
        self.b = msd.Beta(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_, y_):
        entropy = self.b.entropy()
        kl_loss = self.b.kl_loss('Beta', x_, y_)
        h_sum_kl = entropy + kl_loss
        cross_entropy = self.b.cross_entropy('Beta', x_, y_)
        return h_sum_kl - cross_entropy


def test_cross_entropy():
    """
    Test cross_entropy.
    """
    cross_entropy = CrossEntropy()
    concentration1 = Tensor([3.0], dtype=dtype.float32)
    concentration0 = Tensor([2.0], dtype=dtype.float32)
    diff = cross_entropy(concentration1, concentration0)
    tol = 1e-6
    assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all()


class Net(nn.Cell):
    """
    Test class: expand single distribution instance to multiple graphs
    by specifying the attributes.
    """
    def __init__(self):
        super(Net, self).__init__()
        self.beta = msd.Beta(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_, y_):
        kl = self.beta.kl_loss('Beta', x_, y_)
        prob = self.beta.prob(kl)
        return prob


def test_multiple_graphs():
    """
    Test multiple graphs case.
    """
    prob = Net()
    concentration1_a = np.array([3.0]).astype(np.float32)
    concentration0_a = np.array([1.0]).astype(np.float32)
    concentration1_b = np.array([2.0]).astype(np.float32)
    concentration0_b = np.array([1.0]).astype(np.float32)
    ans = prob(Tensor(concentration1_b), Tensor(concentration0_b))

    total_concentration_a = concentration1_a + concentration0_a
    total_concentration_b = concentration1_b + concentration0_b
    log_normalization_a = np.log(special.beta(concentration1_a, concentration0_a))
    log_normalization_b = np.log(special.beta(concentration1_b, concentration0_b))
    expect_kl_loss = (log_normalization_b - log_normalization_a) \
                     - (special.digamma(concentration1_a) * (concentration1_b - concentration1_a)) \
                     - (special.digamma(concentration0_a) * (concentration0_b - concentration0_a)) \
                     + (special.digamma(total_concentration_a) * (total_concentration_b - total_concentration_a))

    beta_benchmark = stats.beta(np.array([3.0]), np.array([1.0]))
    expect_prob = beta_benchmark.pdf(expect_kl_loss).astype(np.float32)
    tol = 1e-6
    assert (np.abs(ans.asnumpy() - expect_prob) < tol).all()

@@ -298,11 +298,11 @@ class Net(nn.Cell):
     def __init__(self):
         super(Net, self).__init__()
-        self.Gamma = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)
+        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

     def construct(self, x_, y_):
-        kl = self.Gamma.kl_loss('Gamma', x_, y_)
-        prob = self.Gamma.prob(kl)
+        kl = self.g.kl_loss('Gamma', x_, y_)
+        prob = self.g.prob(kl)
         return prob

 def test_multiple_graphs():

@@ -0,0 +1,212 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test nn.probability.distribution.Gamma.
"""
import numpy as np
import pytest

import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype
from mindspore import Tensor


def test_gamma_shape_error():
    """
    Invalid shapes.
    """
    with pytest.raises(ValueError):
        msd.Gamma([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)


def test_type():
    with pytest.raises(TypeError):
        msd.Gamma(0., 1., dtype=dtype.int32)


def test_name():
    with pytest.raises(TypeError):
        msd.Gamma(0., 1., name=1.0)


def test_seed():
    with pytest.raises(TypeError):
        msd.Gamma(0., 1., seed='seed')


def test_concentration():
    with pytest.raises(ValueError):
        msd.Gamma(0., 1.)
    with pytest.raises(ValueError):
        msd.Gamma(-1., 1.)


def test_rate():
    with pytest.raises(ValueError):
        msd.Gamma(1., 0.)
    with pytest.raises(ValueError):
        msd.Gamma(1., -1.)


def test_arguments():
    """
    args passing during initialization.
    """
    g = msd.Gamma()
    assert isinstance(g, msd.Distribution)
    g = msd.Gamma([3.0], [4.0], dtype=dtype.float32)
    assert isinstance(g, msd.Distribution)


class GammaProb(nn.Cell):
    """
    Gamma distribution: initialize with concentration/rate.
    """
    def __init__(self):
        super(GammaProb, self).__init__()
        self.gamma = msd.Gamma([3.0, 4.0], [1.0, 1.0], dtype=dtype.float32)

    def construct(self, value):
        prob = self.gamma.prob(value)
        log_prob = self.gamma.log_prob(value)
        return prob + log_prob


def test_gamma_prob():
    """
    Test probability functions: passing value through construct.
    """
    net = GammaProb()
    value = Tensor([0.5, 1.0], dtype=dtype.float32)
    ans = net(value)
    assert isinstance(ans, Tensor)


class GammaProb1(nn.Cell):
    """
    Gamma distribution: initialize without concentration/rate.
    """
    def __init__(self):
        super(GammaProb1, self).__init__()
        self.gamma = msd.Gamma()

    def construct(self, value, concentration, rate):
        prob = self.gamma.prob(value, concentration, rate)
        log_prob = self.gamma.log_prob(value, concentration, rate)
        return prob + log_prob


def test_gamma_prob1():
    """
    Test probability functions: passing concentration/rate and value through construct.
    """
    net = GammaProb1()
    value = Tensor([0.5, 1.0], dtype=dtype.float32)
    concentration = Tensor([2.0, 3.0], dtype=dtype.float32)
    rate = Tensor([1.0], dtype=dtype.float32)
    ans = net(value, concentration, rate)
    assert isinstance(ans, Tensor)


class GammaKl(nn.Cell):
    """
    Test class: kl_loss of Gamma distribution.
    """
    def __init__(self):
        super(GammaKl, self).__init__()
        self.g1 = msd.Gamma(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
        self.g2 = msd.Gamma(dtype=dtype.float32)

    def construct(self, concentration_b, rate_b, concentration_a, rate_a):
        kl1 = self.g1.kl_loss('Gamma', concentration_b, rate_b)
        kl2 = self.g2.kl_loss('Gamma', concentration_b, rate_b, concentration_a, rate_a)
        return kl1 + kl2


def test_kl():
    """
    Test kl_loss.
    """
    net = GammaKl()
    concentration_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
    rate_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
    concentration_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32)
    rate_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32)
    ans = net(concentration_b, rate_b, concentration_a, rate_a)
    assert isinstance(ans, Tensor)


class GammaCrossEntropy(nn.Cell):
    """
    Test class: cross_entropy of Gamma distribution.
    """
    def __init__(self):
        super(GammaCrossEntropy, self).__init__()
        self.g1 = msd.Gamma(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
        self.g2 = msd.Gamma(dtype=dtype.float32)

    def construct(self, concentration_b, rate_b, concentration_a, rate_a):
        h1 = self.g1.cross_entropy('Gamma', concentration_b, rate_b)
        h2 = self.g2.cross_entropy('Gamma', concentration_b, rate_b, concentration_a, rate_a)
        return h1 + h2


def test_cross_entropy():
    """
    Test cross entropy between Gamma distributions.
    """
    net = GammaCrossEntropy()
    concentration_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
    rate_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
    concentration_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32)
    rate_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32)
    ans = net(concentration_b, rate_b, concentration_a, rate_a)
    assert isinstance(ans, Tensor)


class GammaBasics(nn.Cell):
    """
    Test class: basic mean/sd/mode functions.
    """
    def __init__(self):
        super(GammaBasics, self).__init__()
        self.g = msd.Gamma(np.array([3.0, 4.0]), np.array([4.0, 6.0]), dtype=dtype.float32)

    def construct(self):
        mean = self.g.mean()
        sd = self.g.sd()
        mode = self.g.mode()
        return mean + sd + mode


def test_basics():
    """
    Test mean/sd/mode functionality of Gamma.
    """
    net = GammaBasics()
    ans = net()
    assert isinstance(ans, Tensor)


class GammaConstruct(nn.Cell):
    """
    Gamma distribution: going through construct.
    """
    def __init__(self):
        super(GammaConstruct, self).__init__()
        self.gamma = msd.Gamma([3.0], [4.0])
        self.gamma1 = msd.Gamma()

    def construct(self, value, concentration, rate):
        prob = self.gamma('prob', value)
        prob1 = self.gamma('prob', value, concentration, rate)
        prob2 = self.gamma1('prob', value, concentration, rate)
        return prob + prob1 + prob2


def test_gamma_construct():
    """
    Test probability function going through construct.
    """
    net = GammaConstruct()
    value = Tensor([0.5, 1.0], dtype=dtype.float32)
    concentration = Tensor([0.0], dtype=dtype.float32)
    rate = Tensor([1.0], dtype=dtype.float32)
    ans = net(value, concentration, rate)
    assert isinstance(ans, Tensor)